Skip to main content

better_fetch/
sse.rs

1//! Server-Sent Events (`text/event-stream`) helpers for [`StreamingResponse`](crate::StreamingResponse).
2//!
3//! Enable with Cargo feature `sse` on the `better-fetch` crate.
4//!
5//! Use [`SseDecoder`] for incremental parsing, [`parse_sse_events`] for complete buffers, or
6//! [`StreamingResponse::read_sse_events`](crate::StreamingResponse::read_sse_events) to buffer first.
7
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
11use futures_util::Stream;
12use pin_project_lite::pin_project;
13
14use crate::Result;
15
/// One parsed SSE event (may aggregate multiple `data:` lines).
///
/// Only the `event`, `data`, and `id` fields of a block are captured.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseEvent {
    /// The block's `event:` field, or `None` when the block has no `event:` line.
    pub event: Option<String>,
    /// The block's `data:` values joined with `\n` (empty when the block had no `data:` lines).
    pub data: String,
    /// The block's `id:` field, or `None` when absent. The value is taken from this block only;
    /// it is not carried over from earlier events.
    pub id: Option<String>,
}
26
27/// Incrementally parses SSE events from UTF-8 chunks (blocks delimited by a blank line).
28///
29/// Line terminators may be LF, CR, or CRLF (normalized to LF), including a CRLF split across chunks.
30#[derive(Debug, Default)]
31pub struct SseDecoder {
32    buffer: String,
33    /// Previous chunk ended with a lone `\r` that may pair with a leading `\n` of the next chunk.
34    pending_cr: bool,
35}
36
37impl SseDecoder {
38    /// Creates an empty decoder.
39    pub fn new() -> Self {
40        Self::default()
41    }
42
43    /// Appends a chunk and returns any complete events parsed from the buffer.
44    pub fn push_chunk(&mut self, chunk: &[u8]) -> Result<Vec<SseEvent>> {
45        let text = std::str::from_utf8(chunk)
46            .map_err(|e| crate::Error::Config(format!("SSE chunk is not valid UTF-8: {e}")))?;
47        self.push_normalized(text);
48        Ok(self.drain_complete_events())
49    }
50
51    /// Appends `text` to the buffer, normalizing CR / CRLF line endings to LF.
52    fn push_normalized(&mut self, text: &str) {
53        let mut chars = text.chars().peekable();
54        if self.pending_cr {
55            self.pending_cr = false;
56            if chars.peek() == Some(&'\n') {
57                chars.next();
58            }
59            self.buffer.push('\n');
60        }
61        while let Some(c) = chars.next() {
62            if c == '\r' {
63                match chars.peek() {
64                    Some('\n') => {
65                        chars.next();
66                        self.buffer.push('\n');
67                    }
68                    Some(_) => self.buffer.push('\n'),
69                    None => self.pending_cr = true,
70                }
71            } else {
72                self.buffer.push(c);
73            }
74        }
75    }
76
77    /// Parses any trailing bytes as a final event block (if non-empty).
78    pub fn finish(mut self) -> Vec<SseEvent> {
79        if self.pending_cr {
80            self.buffer.push('\n');
81        }
82        let tail = std::mem::take(&mut self.buffer);
83        if tail.trim().is_empty() {
84            return Vec::new();
85        }
86        parse_sse_block(&tail).into_iter().collect()
87    }
88
89    fn drain_complete_events(&mut self) -> Vec<SseEvent> {
90        let mut events = Vec::new();
91        while let Some(pos) = self.buffer.find("\n\n") {
92            let block: String = self.buffer.drain(..pos + 2).collect();
93            let block = block.trim();
94            if block.is_empty() {
95                continue;
96            }
97            if let Some(event) = parse_sse_block(block) {
98                events.push(event);
99            }
100        }
101        events
102    }
103}
104
105pin_project! {
106    /// Stream of [`SseEvent`] parsed incrementally from a response body.
107    pub struct SseEventStream {
108        #[pin]
109        body: crate::BodyStream,
110        decoder: SseDecoder,
111        pending: std::collections::VecDeque<SseEvent>,
112        max_bytes: Option<u64>,
113        accumulated: u64,
114        finished: bool,
115    }
116}
117
118impl SseEventStream {
119    pub(crate) fn new(body: crate::BodyStream, max_bytes: Option<u64>) -> Self {
120        Self {
121            body,
122            decoder: SseDecoder::new(),
123            pending: std::collections::VecDeque::new(),
124            max_bytes,
125            accumulated: 0,
126            finished: false,
127        }
128    }
129}
130
131impl Stream for SseEventStream {
132    type Item = Result<SseEvent>;
133
134    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
135        let mut this = self.project();
136
137        if let Some(event) = this.pending.pop_front() {
138            return Poll::Ready(Some(Ok(event)));
139        }
140
141        if *this.finished {
142            return Poll::Ready(None);
143        }
144
145        loop {
146            match this.body.as_mut().poll_next(cx) {
147                Poll::Ready(Some(Ok(chunk))) => {
148                    if let Some(limit) = *this.max_bytes {
149                        *this.accumulated += chunk.len() as u64;
150                        if *this.accumulated > limit {
151                            return Poll::Ready(Some(Err(crate::Error::BodyTooLarge { limit })));
152                        }
153                    }
154                    match this.decoder.push_chunk(&chunk) {
155                        Ok(events) => {
156                            for event in events {
157                                this.pending.push_back(event);
158                            }
159                            if let Some(event) = this.pending.pop_front() {
160                                return Poll::Ready(Some(Ok(event)));
161                            }
162                        }
163                        Err(e) => return Poll::Ready(Some(Err(e))),
164                    }
165                }
166                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
167                Poll::Ready(None) => {
168                    *this.finished = true;
169                    let decoder = std::mem::take(this.decoder);
170                    for event in decoder.finish() {
171                        this.pending.push_back(event);
172                    }
173                    if let Some(event) = this.pending.pop_front() {
174                        return Poll::Ready(Some(Ok(event)));
175                    }
176                    return Poll::Ready(None);
177                }
178                Poll::Pending => return Poll::Pending,
179            }
180        }
181    }
182}
183
184/// Parses SSE events from a buffer (blocks separated by blank lines; LF, CR, or CRLF).
185pub fn parse_sse_events(buffer: &str) -> Vec<SseEvent> {
186    let normalized = buffer.replace("\r\n", "\n").replace('\r', "\n");
187    let mut events = Vec::new();
188    for block in normalized.split("\n\n") {
189        let block = block.trim();
190        if block.is_empty() {
191            continue;
192        }
193        if let Some(event) = parse_sse_block(block) {
194            events.push(event);
195        }
196    }
197    events
198}
199
/// Drops at most one leading U+0020 SPACE from a field value; the SSE spec treats any further
/// leading spaces (and any other leading whitespace) as part of the value.
fn strip_one_space(value: &str) -> &str {
    match value.as_bytes().first() {
        // The space is one ASCII byte, so slicing at 1 stays on a char boundary.
        Some(b' ') => &value[1..],
        _ => value,
    }
}
204
205fn parse_sse_block(block: &str) -> Option<SseEvent> {
206    let mut event_name = None;
207    let mut id = None;
208    let mut data_lines = Vec::new();
209
210    for line in block.lines() {
211        if line.is_empty() || line.starts_with(':') {
212            continue;
213        }
214        if let Some(rest) = line.strip_prefix("event:") {
215            event_name = Some(strip_one_space(rest).to_string());
216        } else if let Some(rest) = line.strip_prefix("data:") {
217            data_lines.push(strip_one_space(rest).to_string());
218        } else if let Some(rest) = line.strip_prefix("id:") {
219            id = Some(strip_one_space(rest).to_string());
220        }
221    }
222
223    if data_lines.is_empty() && event_name.is_none() && id.is_none() {
224        return None;
225    }
226
227    Some(SseEvent {
228        event: event_name,
229        data: data_lines.join("\n"),
230        id,
231    })
232}
233
234/// Reads a streaming body to completion and parses SSE events.
235pub async fn read_sse_from_bytes(
236    body: crate::BodyStream,
237    max_bytes: Option<u64>,
238) -> Result<Vec<SseEvent>> {
239    let bytes = crate::streaming::accumulate_stream(body, max_bytes).await?;
240    let text = String::from_utf8_lossy(&bytes);
241    Ok(parse_sse_events(&text))
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows the `data` payload of each event, preserving order.
    fn payloads(events: &[SseEvent]) -> Vec<&str> {
        events.iter().map(|ev| ev.data.as_str()).collect()
    }

    #[test]
    fn parses_simple_event() {
        let parsed = parse_sse_events("data: hello\n\n");
        assert_eq!(payloads(&parsed), ["hello"]);
    }

    #[test]
    fn parses_event_name_and_multiline_data() {
        let parsed = parse_sse_events("event: ping\ndata: line1\ndata: line2\n\n");
        let first = &parsed[0];
        assert_eq!(first.event.as_deref(), Some("ping"));
        assert_eq!(first.data, "line1\nline2");
    }

    #[test]
    fn decoder_splits_across_chunks() {
        let mut decoder = SseDecoder::new();
        // The first chunk stops mid-field, so nothing can be emitted yet.
        let partial = decoder.push_chunk(b"data: hel").unwrap();
        assert!(partial.is_empty());
        let completed = decoder.push_chunk(b"lo\n\ndata: world\n\n").unwrap();
        assert_eq!(payloads(&completed), ["hello", "world"]);
    }

    #[test]
    fn parses_crlf_delimited_events() {
        let parsed = parse_sse_events("data: a\r\n\r\ndata: b\r\n\r\n");
        assert_eq!(payloads(&parsed), ["a", "b"]);
    }

    #[test]
    fn decoder_handles_crlf_split_across_chunks() {
        let mut decoder = SseDecoder::new();
        // Chunk ends mid-CRLF (lone `\r`); the next chunk supplies the `\n`.
        let partial = decoder.push_chunk(b"data: hello\r").unwrap();
        assert!(partial.is_empty());
        let completed = decoder.push_chunk(b"\n\r\n").unwrap();
        assert_eq!(payloads(&completed), ["hello"]);
    }

    #[test]
    fn keeps_significant_leading_space_after_single_strip() {
        let parsed = parse_sse_events("data:  two\n\n");
        assert_eq!(parsed[0].data, " two");
    }
}