Skip to main content

sse_stream/
stream.rs

1use std::{
2    collections::VecDeque,
3    num::ParseIntError,
4    str::Utf8Error,
5    task::{ready, Context, Poll},
6};
7
8use crate::Sse;
9use bytes::Buf;
10use futures_util::{stream::MapOk, Stream, TryStreamExt};
11use http_body::{Body, Frame};
12use http_body_util::{BodyDataStream, StreamBody};
13
/// Tracks whether the optional UTF-8 byte-order mark at the very start of the
/// body has been dealt with yet.
#[derive(Debug)]
enum BomHeaderState {
    /// Still deciding: holds the leading bytes received so far, which are not
    /// yet enough to tell whether the body starts with a BOM.
    Parsing(Vec<u8>),
    /// The decision has been made; incoming chunks are used as they arrive.
    Consumed,
}
19
/// The UTF-8 encoded byte-order mark.
const BOM_HEADER: &[u8] = b"\xEF\xBB\xBF";

/// Strips a leading BOM from `buf` when one is present.
///
/// Returns:
/// * `Some(rest)` with the BOM removed when `buf` starts with the full BOM;
/// * `Some(buf)` unchanged when `buf` definitely does not start with a BOM;
/// * `None` when `buf` is shorter than the BOM and is a prefix of it (this
///   includes the empty slice), so more bytes are needed to decide.
fn try_consume_bom_header(buf: &[u8]) -> Option<&[u8]> {
    if let Some(rest) = buf.strip_prefix(BOM_HEADER) {
        return Some(rest);
    }
    // No full BOM at the front. The only undecidable case is a short buffer
    // that could still grow into one.
    let undecided = buf.len() < BOM_HEADER.len() && BOM_HEADER.starts_with(buf);
    if undecided {
        None
    } else {
        Some(buf)
    }
}
38
39pin_project_lite::pin_project! {
40    pub struct SseStream<B: Body> {
41        #[pin]
42        body: BodyDataStream<B>,
43        parsed: VecDeque<Sse>,
44        current: Option<Sse>,
45        unfinished_line: Vec<u8>,
46        bom_header_state: BomHeaderState,
47    }
48}
49
50pub type ByteStreamBody<S, D> = StreamBody<MapOk<S, fn(D) -> Frame<D>>>;
51impl<E, S, D> SseStream<ByteStreamBody<S, D>>
52where
53    S: Stream<Item = Result<D, E>>,
54    E: std::error::Error,
55    D: Buf,
56    StreamBody<ByteStreamBody<S, D>>: Body,
57{
58    /// Create a new [`SseStream`] from a stream of [`Bytes`](bytes::Bytes).
59    ///
60    /// This is useful when you interact with clients don't provide response body directly list reqwest.
61    pub fn from_byte_stream(stream: S) -> Self {
62        let stream = stream.map_ok(http_body::Frame::data as fn(D) -> Frame<D>);
63        let body = StreamBody::new(stream);
64        Self {
65            body: BodyDataStream::new(body),
66            parsed: VecDeque::new(),
67            current: None,
68            unfinished_line: Vec::new(),
69            bom_header_state: BomHeaderState::Parsing(Vec::new()),
70        }
71    }
72}
73
74impl<B: Body> SseStream<B> {
75    /// Create a new [`SseStream`] from a [`Body`].
76    pub fn new(body: B) -> Self {
77        Self {
78            body: BodyDataStream::new(body),
79            parsed: VecDeque::new(),
80            current: None,
81            unfinished_line: Vec::new(),
82            bom_header_state: BomHeaderState::Parsing(Vec::new()),
83        }
84    }
85}
86
/// Errors produced while reading and parsing an SSE body.
pub enum Error {
    /// The underlying body reported an error.
    Body(Box<dyn std::error::Error + Send + Sync>),
    /// A line could not be interpreted as a field line.
    InvalidLine,
    /// A second `event` line was seen for the same event.
    DuplicatedEventLine,
    /// A second `id` line was seen for the same event.
    DuplicatedIdLine,
    /// A second `retry` line was seen for the same event.
    DuplicatedRetry,
    /// A field value was not valid UTF-8.
    Utf8Parse(Utf8Error),
    /// The value of a `retry` line was not a valid integer.
    IntParse(ParseIntError),
}
96
97impl std::fmt::Display for Error {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        match self {
100            Error::Body(e) => write!(f, "body error: {}", e),
101            Error::InvalidLine => write!(f, "invalid line"),
102            Error::DuplicatedEventLine => write!(f, "duplicated event line"),
103            Error::DuplicatedIdLine => write!(f, "duplicated id line"),
104            Error::DuplicatedRetry => write!(f, "duplicated retry line"),
105            Error::Utf8Parse(e) => write!(f, "utf8 parse error: {}", e),
106            Error::IntParse(e) => write!(f, "int parse error: {}", e),
107        }
108    }
109}
110
111impl std::fmt::Debug for Error {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        match self {
114            Error::Body(e) => write!(f, "Body({:?})", e),
115            Error::InvalidLine => write!(f, "InvalidLine"),
116            Error::DuplicatedEventLine => write!(f, "DuplicatedEventLine"),
117            Error::DuplicatedIdLine => write!(f, "DuplicatedIdLine"),
118            Error::DuplicatedRetry => write!(f, "DuplicatedRetry"),
119            Error::Utf8Parse(e) => write!(f, "Utf8Parse({:?})", e),
120            Error::IntParse(e) => write!(f, "IntParse({:?})", e),
121        }
122    }
123}
124
125impl std::error::Error for Error {
126    fn description(&self) -> &str {
127        match self {
128            Error::Body(_) => "body error",
129            Error::InvalidLine => "invalid line",
130            Error::DuplicatedEventLine => "duplicated event line",
131            Error::DuplicatedIdLine => "duplicated id line",
132            Error::DuplicatedRetry => "duplicated retry line",
133            Error::Utf8Parse(_) => "utf8 parse error",
134            Error::IntParse(_) => "int parse error",
135        }
136    }
137
138    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
139        match self {
140            Error::Body(e) => Some(e.as_ref()),
141            Error::Utf8Parse(e) => Some(e),
142            Error::IntParse(e) => Some(e),
143            _ => None,
144        }
145    }
146}
147
148impl<B: Body> Stream for SseStream<B>
149where
150    B::Error: std::error::Error + Send + Sync + 'static,
151{
152    type Item = Result<Sse, Error>;
153
154    fn poll_next(
155        mut self: std::pin::Pin<&mut Self>,
156        cx: &mut Context<'_>,
157    ) -> Poll<Option<Self::Item>> {
158        let this = self.as_mut().project();
159        if let Some(sse) = this.parsed.pop_front() {
160            return Poll::Ready(Some(Ok(sse)));
161        }
162        let next_data = ready!(this.body.poll_next(cx));
163        match next_data {
164            Some(Ok(data)) => {
165                let stripped_vec = if let BomHeaderState::Parsing(buf) = this.bom_header_state {
166                    buf.extend_from_slice(data.chunk());
167                    if let Some(stripped) = try_consume_bom_header(buf) {
168                        let stripped_vec = stripped.to_vec();
169                        *this.bom_header_state = BomHeaderState::Consumed;
170                        Some(stripped_vec)
171                    } else {
172                        return self.poll_next(cx);
173                    }
174                } else {
175                    None
176                };
177
178                let chunk = stripped_vec.as_deref().unwrap_or(data.chunk());
179
180                if chunk.is_empty() {
181                    return self.poll_next(cx);
182                }
183                let mut lines = chunk.chunk_by(|maybe_nl, _| *maybe_nl != b'\n');
184                let first_line = lines.next().expect("frame is empty");
185                let mut new_unfinished_line = Vec::new();
186                let first_line = if !this.unfinished_line.is_empty() {
187                    this.unfinished_line.extend(first_line);
188                    std::mem::swap(&mut new_unfinished_line, this.unfinished_line);
189                    new_unfinished_line.as_ref()
190                } else {
191                    first_line
192                };
193                let mut lines = std::iter::once(first_line).chain(lines);
194                *this.unfinished_line = loop {
195                    let Some(line) = lines.next() else {
196                        break Vec::new();
197                    };
198                    let line = if line.ends_with(b"\r\n") {
199                        &line[..line.len() - 2]
200                    } else if line.ends_with(b"\n") || line.ends_with(b"\r") {
201                        &line[..line.len() - 1]
202                    } else {
203                        break line.to_vec();
204                    };
205
206                    if line.is_empty() {
207                        if let Some(sse) = this.current.take() {
208                            this.parsed.push_back(sse);
209                        }
210                        continue;
211                    }
212                    // find comma
213                    let Some(comma_index) = line.iter().position(|b| *b == b':') else {
214                        #[cfg(feature = "tracing")]
215                        tracing::warn!(?line, "invalid line, missing `:`");
216                        return Poll::Ready(Some(Err(Error::InvalidLine)));
217                    };
218                    let field_name = &line[..comma_index];
219                    let field_value = if line.len() > comma_index + 1 {
220                        let field_value = &line[comma_index + 1..];
221                        if field_value.starts_with(b" ") {
222                            &field_value[1..]
223                        } else {
224                            field_value
225                        }
226                    } else {
227                        b""
228                    };
229                    match field_name {
230                        b"data" => {
231                            let data_line =
232                                std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
233                            // merge data lines
234                            if let Some(Sse { data, .. }) = this.current.as_mut() {
235                                if data.is_none() {
236                                    data.replace(data_line.to_owned());
237                                } else {
238                                    let data = data.as_mut().unwrap();
239                                    data.push('\n');
240                                    data.push_str(data_line);
241                                }
242                            } else {
243                                this.current.replace(Sse {
244                                    event: None,
245                                    data: Some(data_line.to_owned()),
246                                    id: None,
247                                    retry: None,
248                                });
249                            }
250                        }
251                        b"event" => {
252                            let event_value =
253                                std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
254                            if let Some(Sse { event, .. }) = this.current.as_mut() {
255                                if event.is_some() {
256                                    return Poll::Ready(Some(Err(Error::DuplicatedEventLine)));
257                                } else {
258                                    event.replace(event_value.to_owned());
259                                }
260                            } else {
261                                this.current.replace(Sse {
262                                    event: Some(event_value.to_owned()),
263                                    ..Default::default()
264                                });
265                            }
266                        }
267                        b"id" => {
268                            let id_value =
269                                std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
270                            if let Some(Sse { id, .. }) = this.current.as_mut() {
271                                if id.is_some() {
272                                    return Poll::Ready(Some(Err(Error::DuplicatedIdLine)));
273                                } else {
274                                    id.replace(id_value.to_owned());
275                                }
276                            } else {
277                                this.current.replace(Sse {
278                                    id: Some(id_value.to_owned()),
279                                    ..Default::default()
280                                });
281                            }
282                        }
283                        b"retry" => {
284                            let retry_value = std::str::from_utf8(field_value)
285                                .map_err(Error::Utf8Parse)?
286                                .trim_ascii();
287                            let retry_value =
288                                retry_value.parse::<u64>().map_err(Error::IntParse)?;
289                            if let Some(Sse { retry, .. }) = this.current.as_mut() {
290                                if retry.is_some() {
291                                    return Poll::Ready(Some(Err(Error::DuplicatedRetry)));
292                                } else {
293                                    retry.replace(retry_value);
294                                }
295                            } else {
296                                this.current.replace(Sse {
297                                    retry: Some(retry_value),
298                                    ..Default::default()
299                                });
300                            }
301                        }
302                        b"" => {
303                            #[cfg(feature = "tracing")]
304                            if tracing::enabled!(tracing::Level::DEBUG) {
305                                // a comment
306                                let comment =
307                                    std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
308                                tracing::debug!(?comment, "sse comment line");
309                            }
310                        }
311                        _line => {
312                            #[cfg(feature = "tracing")]
313                            if tracing::enabled!(tracing::Level::WARN) {
314                                tracing::warn!(line = ?_line, "invalid line: unknown field");
315                            }
316                            return Poll::Ready(Some(Err(Error::InvalidLine)));
317                        }
318                    }
319                };
320                self.poll_next(cx)
321            }
322            Some(Err(e)) => Poll::Ready(Some(Err(Error::Body(Box::new(e))))),
323            None => {
324                if let Some(sse) = this.current.take() {
325                    Poll::Ready(Some(Ok(sse)))
326                } else {
327                    Poll::Ready(None)
328                }
329            }
330        }
331    }
332}