Skip to main content

better_fetch/
sse.rs

//! Server-Sent Events (`text/event-stream`) helpers for [`StreamingResponse`](crate::StreamingResponse).
//!
//! Enable with the Cargo feature `sse` on the `better-fetch` crate.
//!
//! Use [`SseDecoder`] to parse events incrementally as chunks arrive, [`parse_sse_events`] to
//! parse a complete buffer, or
//! [`StreamingResponse::read_sse_events`](crate::StreamingResponse::read_sse_events) to buffer
//! the body first and parse it afterwards.
7
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
11use futures_util::Stream;
12use pin_project_lite::pin_project;
13
14use crate::Result;
15
/// A single parsed Server-Sent Event.
///
/// Several `data:` lines belonging to the same event block are merged into one
/// [`data`](SseEvent::data) string.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SseEvent {
    /// Value of the `event:` field, or `None` when the block had no such field.
    pub event: Option<String>,
    /// Payload of the event: every `data:` line of the block, joined with `\n`.
    pub data: String,
    /// Value of the `id:` field, or `None` when the block had no such field.
    pub id: Option<String>,
}
26
27/// Incrementally parses SSE events from UTF-8 chunks (blocks delimited by `\n\n`).
28#[derive(Debug, Default)]
29pub struct SseDecoder {
30    buffer: String,
31}
32
33impl SseDecoder {
34    /// Creates an empty decoder.
35    pub fn new() -> Self {
36        Self::default()
37    }
38
39    /// Appends a chunk and returns any complete events parsed from the buffer.
40    pub fn push_chunk(&mut self, chunk: &[u8]) -> Result<Vec<SseEvent>> {
41        let text = std::str::from_utf8(chunk)
42            .map_err(|e| crate::Error::Config(format!("SSE chunk is not valid UTF-8: {e}")))?;
43        self.buffer.push_str(text);
44        Ok(self.drain_complete_events())
45    }
46
47    /// Parses any trailing bytes as a final event block (if non-empty).
48    pub fn finish(mut self) -> Vec<SseEvent> {
49        let tail = std::mem::take(&mut self.buffer);
50        if tail.trim().is_empty() {
51            return Vec::new();
52        }
53        parse_sse_block(&tail).into_iter().collect()
54    }
55
56    fn drain_complete_events(&mut self) -> Vec<SseEvent> {
57        let mut events = Vec::new();
58        while let Some(pos) = self.buffer.find("\n\n") {
59            let block: String = self.buffer.drain(..pos + 2).collect();
60            let block = block.trim();
61            if block.is_empty() {
62                continue;
63            }
64            if let Some(event) = parse_sse_block(block) {
65                events.push(event);
66            }
67        }
68        events
69    }
70}
71
72pin_project! {
73    /// Stream of [`SseEvent`] parsed incrementally from a response body.
74    pub struct SseEventStream {
75        #[pin]
76        body: crate::BodyStream,
77        decoder: SseDecoder,
78        pending: std::collections::VecDeque<SseEvent>,
79        max_bytes: Option<u64>,
80        accumulated: u64,
81        finished: bool,
82    }
83}
84
85impl SseEventStream {
86    pub(crate) fn new(body: crate::BodyStream, max_bytes: Option<u64>) -> Self {
87        Self {
88            body,
89            decoder: SseDecoder::new(),
90            pending: std::collections::VecDeque::new(),
91            max_bytes,
92            accumulated: 0,
93            finished: false,
94        }
95    }
96}
97
98impl Stream for SseEventStream {
99    type Item = Result<SseEvent>;
100
101    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
102        let mut this = self.project();
103
104        if let Some(event) = this.pending.pop_front() {
105            return Poll::Ready(Some(Ok(event)));
106        }
107
108        if *this.finished {
109            return Poll::Ready(None);
110        }
111
112        loop {
113            match this.body.as_mut().poll_next(cx) {
114                Poll::Ready(Some(Ok(chunk))) => {
115                    if let Some(limit) = *this.max_bytes {
116                        *this.accumulated += chunk.len() as u64;
117                        if *this.accumulated > limit {
118                            return Poll::Ready(Some(Err(crate::Error::BodyTooLarge { limit })));
119                        }
120                    }
121                    match this.decoder.push_chunk(&chunk) {
122                        Ok(events) => {
123                            for event in events {
124                                this.pending.push_back(event);
125                            }
126                            if let Some(event) = this.pending.pop_front() {
127                                return Poll::Ready(Some(Ok(event)));
128                            }
129                        }
130                        Err(e) => return Poll::Ready(Some(Err(e))),
131                    }
132                }
133                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
134                Poll::Ready(None) => {
135                    *this.finished = true;
136                    let decoder = std::mem::take(this.decoder);
137                    for event in decoder.finish() {
138                        this.pending.push_back(event);
139                    }
140                    if let Some(event) = this.pending.pop_front() {
141                        return Poll::Ready(Some(Ok(event)));
142                    }
143                    return Poll::Ready(None);
144                }
145                Poll::Pending => return Poll::Pending,
146            }
147        }
148    }
149}
150
151/// Parses SSE events from a buffer (blocks separated by blank lines).
152pub fn parse_sse_events(buffer: &str) -> Vec<SseEvent> {
153    let mut events = Vec::new();
154    for block in buffer.split("\n\n") {
155        let block = block.trim();
156        if block.is_empty() {
157            continue;
158        }
159        if let Some(event) = parse_sse_block(block) {
160            events.push(event);
161        }
162    }
163    events
164}
165
166fn parse_sse_block(block: &str) -> Option<SseEvent> {
167    let mut event_name = None;
168    let mut id = None;
169    let mut data_lines = Vec::new();
170
171    for line in block.lines() {
172        if line.is_empty() || line.starts_with(':') {
173            continue;
174        }
175        if let Some(rest) = line.strip_prefix("event:") {
176            event_name = Some(rest.trim().to_string());
177        } else if let Some(rest) = line.strip_prefix("data:") {
178            data_lines.push(rest.trim_start().to_string());
179        } else if let Some(rest) = line.strip_prefix("id:") {
180            id = Some(rest.trim().to_string());
181        }
182    }
183
184    if data_lines.is_empty() && event_name.is_none() && id.is_none() {
185        return None;
186    }
187
188    Some(SseEvent {
189        event: event_name,
190        data: data_lines.join("\n"),
191        id,
192    })
193}
194
195/// Reads a streaming body to completion and parses SSE events.
196pub async fn read_sse_from_bytes(
197    body: crate::BodyStream,
198    max_bytes: Option<u64>,
199) -> Result<Vec<SseEvent>> {
200    let bytes = crate::streaming::accumulate_stream(body, max_bytes).await?;
201    let text = String::from_utf8_lossy(&bytes);
202    Ok(parse_sse_events(&text))
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// A lone `data:` line followed by a blank line yields exactly one event.
    #[test]
    fn parses_simple_event() {
        let parsed = parse_sse_events("data: hello\n\n");
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].data, "hello");
    }

    /// The `event:` field is captured and consecutive `data:` lines are joined with `\n`.
    #[test]
    fn parses_event_name_and_multiline_data() {
        let parsed = parse_sse_events("event: ping\ndata: line1\ndata: line2\n\n");
        let first = &parsed[0];
        assert_eq!(first.event.as_deref(), Some("ping"));
        assert_eq!(first.data, "line1\nline2");
    }

    /// An event cut by a chunk boundary is only emitted once its terminator arrives.
    #[test]
    fn decoder_splits_across_chunks() {
        let mut decoder = SseDecoder::new();

        let after_first = decoder.push_chunk(b"data: hel").unwrap();
        assert!(after_first.is_empty());

        let after_second = decoder.push_chunk(b"lo\n\ndata: world\n\n").unwrap();
        assert_eq!(after_second.len(), 2);
        assert_eq!(after_second[0].data, "hello");
        assert_eq!(after_second[1].data, "world");
    }
}