Skip to main content

openapi_contract/
sse.rs

1use bytes::Bytes;
2use futures_core::Stream;
3use pin_project_lite::pin_project;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
/// One event decoded from a `text/event-stream` body.
///
/// `data` is every `data:` line of the event joined with `\n`; the optional fields are
/// `None` when the corresponding field was absent from the event block.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SseEvent {
    /// Value of the `event:` field, if one was present.
    pub event: Option<String>,
    /// Payload assembled from the event's `data:` lines.
    pub data: String,
    /// Value of the `id:` field, if one was present.
    pub id: Option<String>,
    /// Value of the `retry:` field parsed as an integer (the SSE spec defines it as a
    /// reconnection time in milliseconds).
    pub retry: Option<u64>,
}
15
16/// Parse a raw SSE text chunk into events.
17///
18/// SSE spec: fields separated by newlines, events separated by blank lines.
19pub fn parse_sse_chunk(text: &str) -> Vec<SseEvent> {
20    let normalized = text.replace("\r\n", "\n");
21    let mut events = Vec::new();
22    let mut event_type: Option<String> = None;
23    let mut data_lines: Vec<&str> = Vec::new();
24    let mut id: Option<String> = None;
25    let mut retry: Option<u64> = None;
26
27    for line in normalized.lines() {
28        if line.is_empty() {
29            if !data_lines.is_empty() {
30                events.push(SseEvent {
31                    event: event_type.take(),
32                    data: data_lines.join("\n"),
33                    id: id.take(),
34                    retry: retry.take(),
35                });
36                data_lines.clear();
37            } else {
38                // Empty event block should clear any accumulated metadata.
39                event_type = None;
40                id = None;
41                retry = None;
42            }
43            continue;
44        }
45
46        if let Some(value) = line.strip_prefix("data:") {
47            data_lines.push(value.strip_prefix(' ').unwrap_or(value));
48        } else if let Some(value) = line.strip_prefix("event:") {
49            event_type = Some(value.strip_prefix(' ').unwrap_or(value).to_string());
50        } else if let Some(value) = line.strip_prefix("id:") {
51            id = Some(value.strip_prefix(' ').unwrap_or(value).to_string());
52        } else if let Some(value) = line.strip_prefix("retry:") {
53            let v = value.strip_prefix(' ').unwrap_or(value);
54            retry = v.parse().ok();
55        }
56    }
57
58    if !data_lines.is_empty() {
59        events.push(SseEvent {
60            event: event_type,
61            data: data_lines.join("\n"),
62            id,
63            retry,
64        });
65    }
66
67    events
68}
69
70pin_project! {
71    /// A stream of SSE events from a `reqwest::Response`.
72    pub struct SseStream {
73        #[pin]
74        inner: Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
75        buffer: String,
76        pending: std::collections::VecDeque<SseEvent>,
77    }
78}
79
80impl SseStream {
81    pub fn new(
82        byte_stream: Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
83    ) -> Self {
84        Self {
85            inner: byte_stream,
86            buffer: String::new(),
87            pending: std::collections::VecDeque::new(),
88        }
89    }
90}
91
92impl Stream for SseStream {
93    type Item = Result<SseEvent, crate::ApiError>;
94
95    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
96        let mut this = self.project();
97
98        if let Some(ev) = this.pending.pop_front() {
99            return Poll::Ready(Some(Ok(ev)));
100        }
101
102        loop {
103            if let Some(first_blank) = this.buffer.find("\n\n") {
104                let chunk = this.buffer[..first_blank].to_string();
105                *this.buffer = this.buffer[first_blank + 2..].to_string();
106                let mut parsed = parse_sse_chunk(&chunk);
107                if !parsed.is_empty() {
108                    let first = parsed.remove(0);
109                    this.pending.extend(parsed);
110                    return Poll::Ready(Some(Ok(first)));
111                }
112                continue;
113            }
114
115            match this.inner.as_mut().poll_next(cx) {
116                Poll::Ready(Some(Ok(bytes))) => {
117                    let text = String::from_utf8_lossy(&bytes);
118                    let appended = text.replace("\r\n", "\n");
119                    this.buffer.push_str(&appended);
120                }
121                Poll::Ready(Some(Err(e))) => {
122                    return Poll::Ready(Some(Err(crate::ApiError::Http(e))));
123                }
124                Poll::Ready(None) => {
125                    if this.buffer.is_empty() {
126                        return Poll::Ready(None);
127                    }
128                    let remaining = std::mem::take(this.buffer);
129                    let mut parsed = parse_sse_chunk(&remaining);
130                    if parsed.is_empty() {
131                        return Poll::Ready(None);
132                    }
133                    let first = parsed.remove(0);
134                    this.pending.extend(parsed);
135                    return Poll::Ready(Some(Ok(first)));
136                }
137                Poll::Pending => return Poll::Pending,
138            }
139        }
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    // ── parse_sse_chunk ─────────────────────────────────────────

    #[test]
    fn parse_all_fields() {
        // Plain data-only event.
        let plain = parse_sse_chunk("data: hello\n\n");
        assert_eq!(plain.len(), 1);
        assert_eq!(plain[0].data, "hello");
        assert_eq!(plain[0].event, None);

        // Named event carrying a JSON payload.
        let named = parse_sse_chunk("event: update\ndata: {\"foo\":1}\n\n");
        assert_eq!(named[0].event.as_deref(), Some("update"));

        // Several data lines are joined with '\n'.
        let joined = parse_sse_chunk("data: line1\ndata: line2\n\n");
        assert_eq!(joined[0].data, "line1\nline2");

        // Metadata fields.
        let with_meta = parse_sse_chunk("id: 42\nretry: 3000\ndata: ping\n\n");
        let SseEvent { id, retry, .. } = &with_meta[0];
        assert_eq!(id.as_deref(), Some("42"));
        assert_eq!(*retry, Some(3000));
    }

    #[test]
    fn parse_multiple_events_and_blanks() {
        let pair = parse_sse_chunk("data: first\n\ndata: second\n\n");
        assert_eq!(pair.len(), 2);
        assert_eq!(pair[0].data, "first");
        assert_eq!(pair[1].data, "second");

        // Runs of blank lines between events.
        assert_eq!(parse_sse_chunk("data: a\n\n\n\ndata: b\n\n").len(), 2);

        // CRLF line endings.
        let crlf = parse_sse_chunk("data: hello\r\n\r\ndata: world\r\n\r\n");
        assert_eq!(crlf.len(), 2);
        assert_eq!(crlf[0].data, "hello");
    }

    #[test]
    fn parse_edge_cases() {
        assert!(parse_sse_chunk("").is_empty());

        // Missing final blank line.
        assert_eq!(parse_sse_chunk("data: unterminated")[0].data, "unterminated");

        // Value directly after the colon.
        assert_eq!(parse_sse_chunk("data:nospace\n\n")[0].data, "nospace");

        // Comment lines are skipped.
        let commented = parse_sse_chunk(": comment\ndata: real\n\n");
        assert_eq!(commented.len(), 1);
        assert_eq!(commented[0].data, "real");

        // Non-numeric retry.
        assert_eq!(parse_sse_chunk("retry: notanumber\ndata: x\n\n")[0].retry, None);
    }

    #[test]
    fn parse_state_reset_on_empty_block() {
        // A data-less block must not leak its metadata into the next event.
        let after_stale = parse_sse_chunk("event: stale\n\ndata: clean\n\n");
        assert_eq!(after_stale.len(), 1);
        assert_eq!(after_stale[0].event, None);

        let after_meta = parse_sse_chunk("id: old\nretry: 5000\n\ndata: fresh\n\n");
        let SseEvent { id, retry, .. } = &after_meta[0];
        assert_eq!(*id, None);
        assert_eq!(*retry, None);
    }

    // ── SseStream ───────────────────────────────────────────────

    /// Boxed byte stream that yields `chunks` in order, then ends.
    fn mock_byte_stream(
        chunks: Vec<Result<Bytes, reqwest::Error>>,
    ) -> Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>> {
        Box::pin(futures_util::stream::iter(chunks))
    }

    /// `SseStream` over in-memory chunks that all arrive successfully.
    fn stream_of(chunks: &[&'static str]) -> SseStream {
        let items: Vec<Result<Bytes, reqwest::Error>> =
            chunks.iter().map(|chunk| Ok(Bytes::from(*chunk))).collect();
        SseStream::new(mock_byte_stream(items))
    }

    /// Data of the next event; panics if the stream ended or yielded an error.
    async fn next_data(stream: &mut SseStream) -> String {
        stream.next().await.unwrap().unwrap().data
    }

    #[tokio::test]
    async fn stream_chunked_delivery() {
        // Event split across chunks; first chunk uses CRLF.
        let mut events = stream_of(&["data: first\r\n\r\ndata: sec", "ond\n\n"]);
        assert_eq!(next_data(&mut events).await, "first");
        assert_eq!(next_data(&mut events).await, "second");
        assert!(events.next().await.is_none());
    }

    #[tokio::test]
    async fn stream_multiple_events_single_chunk() {
        // Several events in one chunk go through the pending queue.
        let mut events = stream_of(&["data: a\n\ndata: b\n\ndata: c\n\n"]);
        for expected in ["a", "b", "c"] {
            assert_eq!(next_data(&mut events).await, expected);
        }
        assert!(events.next().await.is_none());

        // Empty blocks sitting between real events.
        let mut events = stream_of(&["data: x\n\n\n\ndata: y\n\n"]);
        for expected in ["x", "y"] {
            assert_eq!(next_data(&mut events).await, expected);
        }
    }

    #[tokio::test]
    async fn stream_end_of_inner() {
        let mut empty = stream_of(&[]);
        assert!(empty.next().await.is_none());

        // Leftover data is flushed when the inner stream ends.
        let mut trailing = stream_of(&["data: trailing"]);
        assert_eq!(next_data(&mut trailing).await, "trailing");
        assert!(trailing.next().await.is_none());

        // Leftover comment produces nothing.
        let mut comment_only = stream_of(&[": comment only"]);
        assert!(comment_only.next().await.is_none());
    }

    #[tokio::test]
    async fn stream_error_from_inner() {
        let build_error = reqwest::Client::new()
            .get("http://localhost:1/x")
            .header("bad\0header", "v")
            .build()
            .unwrap_err();
        let mut events = SseStream::new(mock_byte_stream(vec![Err(build_error)]));
        let first = events.next().await.unwrap();
        assert!(first.is_err());
    }

    #[tokio::test]
    async fn stream_pending_then_data() {
        let (sender, receiver) = tokio::sync::mpsc::channel::<Result<Bytes, reqwest::Error>>(2);
        let mut events =
            SseStream::new(Box::pin(tokio_stream::wrappers::ReceiverStream::new(receiver)));

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            sender.send(Ok(Bytes::from("data: delayed\n\n"))).await.unwrap();
            drop(sender);
        });

        assert_eq!(next_data(&mut events).await, "delayed");
        assert!(events.next().await.is_none());
    }
}