1use bytes::Bytes;
2use futures_core::Stream;
3use pin_project_lite::pin_project;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
/// A single event extracted from Server-Sent Events (SSE) text.
///
/// Values are produced by [`parse_sse_chunk`], one per event block that
/// contains at least one `data:` line.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseEvent {
    /// Value of the block's `event:` field, or `None` when the block had none.
    pub event: Option<String>,
    /// Payload: the values of the block's `data:` lines joined with `\n`.
    pub data: String,
    /// Value of the block's `id:` field, or `None` when the block had none.
    pub id: Option<String>,
    /// Value of the block's `retry:` field when it parsed as a `u64`.
    /// NOTE(review): the SSE specification defines this as a reconnection
    /// delay in milliseconds; nothing in this file interprets the unit.
    pub retry: Option<u64>,
}
15
16pub fn parse_sse_chunk(text: &str) -> Vec<SseEvent> {
20 let normalized = text.replace("\r\n", "\n");
21 let mut events = Vec::new();
22 let mut event_type: Option<String> = None;
23 let mut data_lines: Vec<&str> = Vec::new();
24 let mut id: Option<String> = None;
25 let mut retry: Option<u64> = None;
26
27 for line in normalized.lines() {
28 if line.is_empty() {
29 if !data_lines.is_empty() {
30 events.push(SseEvent {
31 event: event_type.take(),
32 data: data_lines.join("\n"),
33 id: id.take(),
34 retry: retry.take(),
35 });
36 data_lines.clear();
37 } else {
38 event_type = None;
40 id = None;
41 retry = None;
42 }
43 continue;
44 }
45
46 if let Some(value) = line.strip_prefix("data:") {
47 data_lines.push(value.strip_prefix(' ').unwrap_or(value));
48 } else if let Some(value) = line.strip_prefix("event:") {
49 event_type = Some(value.strip_prefix(' ').unwrap_or(value).to_string());
50 } else if let Some(value) = line.strip_prefix("id:") {
51 id = Some(value.strip_prefix(' ').unwrap_or(value).to_string());
52 } else if let Some(value) = line.strip_prefix("retry:") {
53 let v = value.strip_prefix(' ').unwrap_or(value);
54 retry = v.parse().ok();
55 }
56 }
57
58 if !data_lines.is_empty() {
59 events.push(SseEvent {
60 event: event_type,
61 data: data_lines.join("\n"),
62 id,
63 retry,
64 });
65 }
66
67 events
68}
69
70pin_project! {
71 pub struct SseStream {
73 #[pin]
74 inner: Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
75 buffer: String,
76 pending: std::collections::VecDeque<SseEvent>,
77 }
78}
79
80impl SseStream {
81 pub fn new(
82 byte_stream: Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
83 ) -> Self {
84 Self {
85 inner: byte_stream,
86 buffer: String::new(),
87 pending: std::collections::VecDeque::new(),
88 }
89 }
90}
91
92impl Stream for SseStream {
93 type Item = Result<SseEvent, crate::ApiError>;
94
95 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
96 let mut this = self.project();
97
98 if let Some(ev) = this.pending.pop_front() {
99 return Poll::Ready(Some(Ok(ev)));
100 }
101
102 loop {
103 if let Some(first_blank) = this.buffer.find("\n\n") {
104 let chunk = this.buffer[..first_blank].to_string();
105 *this.buffer = this.buffer[first_blank + 2..].to_string();
106 let mut parsed = parse_sse_chunk(&chunk);
107 if !parsed.is_empty() {
108 let first = parsed.remove(0);
109 this.pending.extend(parsed);
110 return Poll::Ready(Some(Ok(first)));
111 }
112 continue;
113 }
114
115 match this.inner.as_mut().poll_next(cx) {
116 Poll::Ready(Some(Ok(bytes))) => {
117 let text = String::from_utf8_lossy(&bytes);
118 let appended = text.replace("\r\n", "\n");
119 this.buffer.push_str(&appended);
120 }
121 Poll::Ready(Some(Err(e))) => {
122 return Poll::Ready(Some(Err(crate::ApiError::Http(e))));
123 }
124 Poll::Ready(None) => {
125 if this.buffer.is_empty() {
126 return Poll::Ready(None);
127 }
128 let remaining = std::mem::take(this.buffer);
129 let mut parsed = parse_sse_chunk(&remaining);
130 if parsed.is_empty() {
131 return Poll::Ready(None);
132 }
133 let first = parsed.remove(0);
134 this.pending.extend(parsed);
135 return Poll::Ready(Some(Ok(first)));
136 }
137 Poll::Pending => return Poll::Pending,
138 }
139 }
140 }
141}
142
// Unit tests for `parse_sse_chunk` (synchronous parsing) and `SseStream`
// (parsing driven by an asynchronous byte stream).
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    // Each supported field (`data`, `event`, `id`, `retry`) is picked up, and
    // multiple `data:` lines are joined with a newline.
    #[test]
    fn parse_all_fields() {
        let events = parse_sse_chunk("data: hello\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "hello");
        assert_eq!(events[0].event, None);

        let events = parse_sse_chunk("event: update\ndata: {\"foo\":1}\n\n");
        assert_eq!(events[0].event.as_deref(), Some("update"));

        let events = parse_sse_chunk("data: line1\ndata: line2\n\n");
        assert_eq!(events[0].data, "line1\nline2");

        let events = parse_sse_chunk("id: 42\nretry: 3000\ndata: ping\n\n");
        assert_eq!(events[0].id.as_deref(), Some("42"));
        assert_eq!(events[0].retry, Some(3000));
    }

    // Several events in one input, extra blank lines between events, and
    // CRLF line endings.
    #[test]
    fn parse_multiple_events_and_blanks() {
        let events = parse_sse_chunk("data: first\n\ndata: second\n\n");
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].data, "first");
        assert_eq!(events[1].data, "second");

        // Repeated blank lines must not produce empty events.
        let events = parse_sse_chunk("data: a\n\n\n\ndata: b\n\n");
        assert_eq!(events.len(), 2);

        let events = parse_sse_chunk("data: hello\r\n\r\ndata: world\r\n\r\n");
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].data, "hello");
    }

    // Empty input, a block without a terminating blank line, a field without
    // the optional space after the colon, comment lines, and a non-numeric
    // `retry` value.
    #[test]
    fn parse_edge_cases() {
        assert!(parse_sse_chunk("").is_empty());

        let events = parse_sse_chunk("data: unterminated");
        assert_eq!(events[0].data, "unterminated");

        let events = parse_sse_chunk("data:nospace\n\n");
        assert_eq!(events[0].data, "nospace");

        let events = parse_sse_chunk(": comment\ndata: real\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "real");

        let events = parse_sse_chunk("retry: notanumber\ndata: x\n\n");
        assert_eq!(events[0].retry, None);
    }

    // Metadata set in a block that has no `data:` line is discarded and does
    // not carry over into the following event.
    #[test]
    fn parse_state_reset_on_empty_block() {
        let events = parse_sse_chunk("event: stale\n\ndata: clean\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event, None);

        let events = parse_sse_chunk("id: old\nretry: 5000\n\ndata: fresh\n\n");
        assert_eq!(events[0].id, None);
        assert_eq!(events[0].retry, None);
    }

    // Builds an already-complete in-memory byte stream with the boxed type
    // that `SseStream::new` expects.
    fn mock_byte_stream(
        chunks: Vec<Result<Bytes, reqwest::Error>>,
    ) -> Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>> {
        Box::pin(futures_util::stream::iter(chunks))
    }

    // An event whose text is split across two network chunks is reassembled;
    // the first chunk also mixes CRLF and LF line endings.
    #[tokio::test]
    async fn stream_chunked_delivery() {
        let chunks = vec![
            Ok(Bytes::from("data: first\r\n\r\ndata: sec")),
            Ok(Bytes::from("ond\n\n")),
        ];
        let mut stream = SseStream::new(mock_byte_stream(chunks));
        assert_eq!(stream.next().await.unwrap().unwrap().data, "first");
        assert_eq!(stream.next().await.unwrap().unwrap().data, "second");
        assert!(stream.next().await.is_none());
    }

    // A single chunk carrying several events yields them one by one, in
    // order, including when extra blank lines separate them.
    #[tokio::test]
    async fn stream_multiple_events_single_chunk() {
        let chunks = vec![Ok(Bytes::from("data: a\n\ndata: b\n\ndata: c\n\n"))];
        let mut stream = SseStream::new(mock_byte_stream(chunks));
        assert_eq!(stream.next().await.unwrap().unwrap().data, "a");
        assert_eq!(stream.next().await.unwrap().unwrap().data, "b");
        assert_eq!(stream.next().await.unwrap().unwrap().data, "c");
        assert!(stream.next().await.is_none());

        let chunks = vec![Ok(Bytes::from("data: x\n\n\n\ndata: y\n\n"))];
        let mut stream = SseStream::new(mock_byte_stream(chunks));
        assert_eq!(stream.next().await.unwrap().unwrap().data, "x");
        assert_eq!(stream.next().await.unwrap().unwrap().data, "y");
    }

    // Behaviour when the inner stream ends: immediately, with an
    // unterminated event still buffered, and with only a comment buffered.
    #[tokio::test]
    async fn stream_end_of_inner() {
        let mut stream = SseStream::new(mock_byte_stream(vec![]));
        assert!(stream.next().await.is_none());

        // Buffered text without a trailing blank line is still emitted.
        let mut stream = SseStream::new(mock_byte_stream(vec![Ok(Bytes::from("data: trailing"))]));
        assert_eq!(stream.next().await.unwrap().unwrap().data, "trailing");
        assert!(stream.next().await.is_none());

        // Leftover text that parses to no event ends the stream cleanly.
        let mut stream = SseStream::new(mock_byte_stream(vec![Ok(Bytes::from(": comment only"))]));
        assert!(stream.next().await.is_none());
    }

    // An error item from the inner stream is surfaced as an `Err` item.
    #[tokio::test]
    async fn stream_error_from_inner() {
        // Obtain a `reqwest::Error` without sending anything: building this
        // request is expected to fail (presumably because the header name
        // contains a NUL byte), which `unwrap_err` relies on.
        let err = reqwest::Client::new()
            .get("http://localhost:1/x")
            .header("bad\0header", "v")
            .build()
            .unwrap_err();
        let mut stream = SseStream::new(mock_byte_stream(vec![Err(err)]));
        assert!(stream.next().await.unwrap().is_err());
    }

    // The stream copes with the inner stream having no data ready at first:
    // the sender only delivers its chunk after a short delay.
    #[tokio::test]
    async fn stream_pending_then_data() {
        let (tx, rx) = tokio::sync::mpsc::channel::<Result<Bytes, reqwest::Error>>(2);
        let rx_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
        let mut stream = SseStream::new(Box::pin(rx_stream));

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            tx.send(Ok(Bytes::from("data: delayed\n\n"))).await.unwrap();
            // Dropping the sender closes the channel, which ends the stream.
            drop(tx);
        });

        assert_eq!(stream.next().await.unwrap().unwrap().data, "delayed");
        assert!(stream.next().await.is_none());
    }
}