Skip to main content

better_fetch/
sse.rs

1//! Server-Sent Events (`text/event-stream`) helpers for [`StreamingResponse`](crate::StreamingResponse).
2//!
3//! Use [`SseDecoder`] for incremental parsing, [`parse_sse_events`] for complete buffers, or
4//! [`StreamingResponse::read_sse_events`](crate::StreamingResponse::read_sse_events) to buffer first.
5
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use futures_util::Stream;
10use pin_project_lite::pin_project;
11
12use crate::Result;
13
/// A single dispatched Server-Sent Event.
///
/// When a block contains several `data:` lines they are merged into [`SseEvent::data`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SseEvent {
    /// Value of the block's `event:` field, if the block had one.
    pub event: Option<String>,
    /// Payload: every `data:` line of the block, in order, joined with `\n`.
    pub data: String,
    /// Value of the block's `id:` field, if the block had one.
    pub id: Option<String>,
}
24
25/// Incrementally parses SSE events from UTF-8 chunks (blocks delimited by `\n\n`).
26#[derive(Debug, Default)]
27pub struct SseDecoder {
28    buffer: String,
29}
30
31impl SseDecoder {
32    /// Creates an empty decoder.
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    /// Appends a chunk and returns any complete events parsed from the buffer.
38    pub fn push_chunk(&mut self, chunk: &[u8]) -> Result<Vec<SseEvent>> {
39        let text = std::str::from_utf8(chunk)
40            .map_err(|e| crate::Error::Config(format!("SSE chunk is not valid UTF-8: {e}")))?;
41        self.buffer.push_str(text);
42        Ok(self.drain_complete_events())
43    }
44
45    /// Parses any trailing bytes as a final event block (if non-empty).
46    pub fn finish(mut self) -> Vec<SseEvent> {
47        let tail = std::mem::take(&mut self.buffer);
48        if tail.trim().is_empty() {
49            return Vec::new();
50        }
51        parse_sse_block(&tail).into_iter().collect()
52    }
53
54    fn drain_complete_events(&mut self) -> Vec<SseEvent> {
55        let mut events = Vec::new();
56        while let Some(pos) = self.buffer.find("\n\n") {
57            let block: String = self.buffer.drain(..pos + 2).collect();
58            let block = block.trim();
59            if block.is_empty() {
60                continue;
61            }
62            if let Some(event) = parse_sse_block(block) {
63                events.push(event);
64            }
65        }
66        events
67    }
68}
69
70pin_project! {
71    /// Stream of [`SseEvent`] parsed incrementally from a response body.
72    pub struct SseEventStream {
73        #[pin]
74        body: crate::BodyStream,
75        decoder: SseDecoder,
76        pending: std::collections::VecDeque<SseEvent>,
77        max_bytes: Option<u64>,
78        accumulated: u64,
79        finished: bool,
80    }
81}
82
83impl SseEventStream {
84    pub(crate) fn new(body: crate::BodyStream, max_bytes: Option<u64>) -> Self {
85        Self {
86            body,
87            decoder: SseDecoder::new(),
88            pending: std::collections::VecDeque::new(),
89            max_bytes,
90            accumulated: 0,
91            finished: false,
92        }
93    }
94}
95
96impl Stream for SseEventStream {
97    type Item = Result<SseEvent>;
98
99    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
100        let mut this = self.project();
101
102        if let Some(event) = this.pending.pop_front() {
103            return Poll::Ready(Some(Ok(event)));
104        }
105
106        if *this.finished {
107            return Poll::Ready(None);
108        }
109
110        loop {
111            match this.body.as_mut().poll_next(cx) {
112                Poll::Ready(Some(Ok(chunk))) => {
113                    if let Some(limit) = *this.max_bytes {
114                        *this.accumulated += chunk.len() as u64;
115                        if *this.accumulated > limit {
116                            return Poll::Ready(Some(Err(crate::Error::BodyTooLarge { limit })));
117                        }
118                    }
119                    match this.decoder.push_chunk(&chunk) {
120                        Ok(events) => {
121                            for event in events {
122                                this.pending.push_back(event);
123                            }
124                            if let Some(event) = this.pending.pop_front() {
125                                return Poll::Ready(Some(Ok(event)));
126                            }
127                        }
128                        Err(e) => return Poll::Ready(Some(Err(e))),
129                    }
130                }
131                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
132                Poll::Ready(None) => {
133                    *this.finished = true;
134                    let decoder = std::mem::take(this.decoder);
135                    for event in decoder.finish() {
136                        this.pending.push_back(event);
137                    }
138                    if let Some(event) = this.pending.pop_front() {
139                        return Poll::Ready(Some(Ok(event)));
140                    }
141                    return Poll::Ready(None);
142                }
143                Poll::Pending => return Poll::Pending,
144            }
145        }
146    }
147}
148
149/// Parses SSE events from a buffer (blocks separated by blank lines).
150pub fn parse_sse_events(buffer: &str) -> Vec<SseEvent> {
151    let mut events = Vec::new();
152    for block in buffer.split("\n\n") {
153        let block = block.trim();
154        if block.is_empty() {
155            continue;
156        }
157        if let Some(event) = parse_sse_block(block) {
158            events.push(event);
159        }
160    }
161    events
162}
163
164fn parse_sse_block(block: &str) -> Option<SseEvent> {
165    let mut event_name = None;
166    let mut id = None;
167    let mut data_lines = Vec::new();
168
169    for line in block.lines() {
170        if line.is_empty() || line.starts_with(':') {
171            continue;
172        }
173        if let Some(rest) = line.strip_prefix("event:") {
174            event_name = Some(rest.trim().to_string());
175        } else if let Some(rest) = line.strip_prefix("data:") {
176            data_lines.push(rest.trim_start().to_string());
177        } else if let Some(rest) = line.strip_prefix("id:") {
178            id = Some(rest.trim().to_string());
179        }
180    }
181
182    if data_lines.is_empty() && event_name.is_none() && id.is_none() {
183        return None;
184    }
185
186    Some(SseEvent {
187        event: event_name,
188        data: data_lines.join("\n"),
189        id,
190    })
191}
192
193/// Reads a streaming body to completion and parses SSE events.
194pub async fn read_sse_from_bytes(
195    body: crate::BodyStream,
196    max_bytes: Option<u64>,
197) -> Result<Vec<SseEvent>> {
198    let bytes = crate::streaming::accumulate_stream(body, max_bytes).await?;
199    let text = String::from_utf8_lossy(&bytes);
200    Ok(parse_sse_events(&text))
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_simple_event() {
        let expected = SseEvent {
            event: None,
            data: "hello".to_owned(),
            id: None,
        };
        assert_eq!(parse_sse_events("data: hello\n\n"), vec![expected]);
    }

    #[test]
    fn parses_event_name_and_multiline_data() {
        let raw = "event: ping\ndata: line1\ndata: line2\n\n";
        let first = parse_sse_events(raw)
            .into_iter()
            .next()
            .expect("one event should be parsed");
        assert_eq!(first.event.as_deref(), Some("ping"));
        assert_eq!(first.data, "line1\nline2");
    }

    #[test]
    fn decoder_splits_across_chunks() {
        let mut decoder = SseDecoder::new();
        let partial = decoder.push_chunk(b"data: hel").expect("valid UTF-8");
        assert!(partial.is_empty());

        let data: Vec<String> = decoder
            .push_chunk(b"lo\n\ndata: world\n\n")
            .expect("valid UTF-8")
            .into_iter()
            .map(|event| event.data)
            .collect();
        assert_eq!(data, ["hello", "world"]);
    }
}