sse_stream/
lib.rs

1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
2use std::{
3    collections::VecDeque,
4    fmt::Display,
5    num::ParseIntError,
6    str::Utf8Error,
7    task::{Context, Poll, ready},
8};
9
10use bytes::Buf;
11use futures_util::{Stream, TryStreamExt, stream::MapOk};
12use http_body::{Body, Frame};
13use http_body_util::{BodyDataStream, StreamBody};
14
15pin_project_lite::pin_project! {
16    pub struct SseStream<B: Body> {
17        #[pin]
18        body: BodyDataStream<B>,
19        parsed: VecDeque<Sse>,
20        current: Option<Sse>,
21        unfinished_line: Vec<u8>,
22    }
23}
24
25impl<E, S, D> SseStream<StreamBody<MapOk<S, fn(D) -> Frame<D>>>>
26where
27    S: Stream<Item = Result<D, E>>,
28    E: std::error::Error,
29    D: Buf,
30    StreamBody<MapOk<S, fn(D) -> Frame<D>>>: Body,
31{
32    /// Create a new [`SseStream`] from a stream of [`Bytes`](bytes::Bytes).
33    ///
34    /// This is useful when you interact with clients don't provide response body directly list reqwest.
35    pub fn from_byte_stream(stream: S) -> Self {
36        let stream = stream.map_ok(http_body::Frame::data as fn(D) -> Frame<D>);
37        let body = StreamBody::new(stream);
38        Self {
39            body: BodyDataStream::new(body),
40            parsed: VecDeque::new(),
41            current: None,
42            unfinished_line: Vec::new(),
43        }
44    }
45}
46
47impl<B: Body> SseStream<B> {
48    /// Create a new [`SseStream`] from a [`Body`].
49    pub fn new(body: B) -> Self {
50        Self {
51            body: BodyDataStream::new(body),
52            parsed: VecDeque::new(),
53            current: None,
54            unfinished_line: Vec::new(),
55        }
56    }
57}
58
/// A single server-sent event, assembled from the field lines that precede a blank line.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct Sse {
    /// Value of the `event:` field, if one was present.
    pub event: Option<String>,
    /// Value of the `data:` field; multiple `data:` lines are joined with `\n`.
    pub data: Option<String>,
    /// Value of the `id:` field, if one was present.
    pub id: Option<String>,
    /// Value of the `retry:` field parsed as an unsigned integer, if one was present.
    pub retry: Option<u64>,
}

impl Sse {
    /// Returns `true` if this event carries an explicit `event:` type.
    pub fn is_event(&self) -> bool {
        self.event.is_some()
    }

    /// Returns `true` if this event has no `event:` type, i.e. it is a plain message.
    ///
    /// This is always the negation of [`Sse::is_event`].
    pub fn is_message(&self) -> bool {
        !self.is_event()
    }
}
75
76pub enum Error<B: Body> {
77    Body(B::Error),
78    InvalidLine,
79    DuplicatedEventLine,
80    DuplicatedIdLine,
81    DuplicatedRetry,
82    Utf8Parse(Utf8Error),
83    IntParse(ParseIntError),
84}
85
86impl<B: Body> std::fmt::Display for Error<B>
87where
88    B::Error: Display,
89{
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        match self {
92            Error::Body(e) => write!(f, "body error: {}", e),
93            Error::InvalidLine => write!(f, "invalid line"),
94            Error::DuplicatedEventLine => write!(f, "duplicated event line"),
95            Error::DuplicatedIdLine => write!(f, "duplicated id line"),
96            Error::DuplicatedRetry => write!(f, "duplicated retry line"),
97            Error::Utf8Parse(e) => write!(f, "utf8 parse error: {}", e),
98            Error::IntParse(e) => write!(f, "int parse error: {}", e),
99        }
100    }
101}
102
103impl<B: Body> std::fmt::Debug for Error<B>
104where
105    B::Error: std::error::Error,
106{
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        match self {
109            Error::Body(e) => write!(f, "Body({:?})", e),
110            Error::InvalidLine => write!(f, "InvalidLine"),
111            Error::DuplicatedEventLine => write!(f, "DuplicatedEventLine"),
112            Error::DuplicatedIdLine => write!(f, "DuplicatedIdLine"),
113            Error::DuplicatedRetry => write!(f, "DuplicatedRetry"),
114            Error::Utf8Parse(e) => write!(f, "Utf8Parse({:?})", e),
115            Error::IntParse(e) => write!(f, "IntParse({:?})", e),
116        }
117    }
118}
119
120impl<B: Body> std::error::Error for Error<B>
121where
122    B::Error: std::error::Error + 'static,
123{
124    fn description(&self) -> &str {
125        match self {
126            Error::Body(_) => "body error",
127            Error::InvalidLine => "invalid line",
128            Error::DuplicatedEventLine => "duplicated event line",
129            Error::DuplicatedIdLine => "duplicated id line",
130            Error::DuplicatedRetry => "duplicated retry line",
131            Error::Utf8Parse(_) => "utf8 parse error",
132            Error::IntParse(_) => "int parse error",
133        }
134    }
135
136    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
137        match self {
138            Error::Body(e) => Some(e),
139            Error::Utf8Parse(e) => Some(e),
140            Error::IntParse(e) => Some(e),
141            _ => None,
142        }
143    }
144}
145
146impl<B: Body> Stream for SseStream<B> {
147    type Item = Result<Sse, Error<B>>;
148
149    fn poll_next(
150        mut self: std::pin::Pin<&mut Self>,
151        cx: &mut Context<'_>,
152    ) -> Poll<Option<Self::Item>> {
153        let this = self.as_mut().project();
154        if let Some(sse) = this.parsed.pop_front() {
155            return Poll::Ready(Some(Ok(sse)));
156        }
157        let next_data = ready!(this.body.poll_next(cx));
158        match next_data {
159            Some(Ok(data)) => {
160                let chunk = data.chunk();
161                if chunk.is_empty() {
162                    return self.poll_next(cx);
163                }
164                let mut lines = chunk.chunk_by(|maybe_nl, _| *maybe_nl != b'\n');
165                let first_line = lines.next().expect("frame is empty");
166                let mut new_unfinished_line = Vec::new();
167                let first_line = if !this.unfinished_line.is_empty() {
168                    this.unfinished_line.extend(first_line);
169                    std::mem::swap(&mut new_unfinished_line, this.unfinished_line);
170                    new_unfinished_line.as_ref()
171                } else {
172                    first_line
173                };
174                let mut lines = std::iter::once(first_line).chain(lines);
175                *this.unfinished_line = loop {
176                    let Some(line) = lines.next() else {
177                        break Vec::new();
178                    };
179                    if line.last().copied() != Some(b'\n') {
180                        break line.to_vec();
181                    }
182                    // remove the trailing \n
183                    let line = &line[..line.len() - 1];
184                    if line.is_empty() {
185                        if let Some(sse) = this.current.take() {
186                            this.parsed.push_back(sse);
187                        }
188                        continue;
189                    }
190                    // find comma
191                    let Some(comma_index) = line.iter().position(|b| *b == b':') else {
192                        return Poll::Ready(Some(Err(Error::InvalidLine)));
193                    };
194                    let field_name = &line[..comma_index];
195                    let field_value = &line[comma_index + 1..];
196                    match field_name {
197                        b"data" => {
198                            let data_line =
199                                std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
200                            // merge data lines
201                            if let Some(Sse { data, .. }) = this.current.as_mut() {
202                                if data.is_none() {
203                                    data.replace(data_line.to_owned());
204                                } else {
205                                    let data = data.as_mut().unwrap();
206                                    data.push('\n');
207                                    data.push_str(data_line);
208                                }
209                            } else {
210                                this.current.replace(Sse {
211                                    event: None,
212                                    data: Some(data_line.to_owned()),
213                                    id: None,
214                                    retry: None,
215                                });
216                            }
217                        }
218                        b"event" => {
219                            let event_value =
220                                std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
221                            if let Some(Sse { event, .. }) = this.current.as_mut() {
222                                if event.is_some() {
223                                    return Poll::Ready(Some(Err(Error::DuplicatedEventLine)));
224                                } else {
225                                    event.replace(event_value.to_owned());
226                                }
227                            } else {
228                                this.current.replace(Sse {
229                                    event: Some(event_value.to_owned()),
230                                    ..Default::default()
231                                });
232                            }
233                        }
234                        b"id" => {
235                            let id_value =
236                                std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
237                            if let Some(Sse { id, .. }) = this.current.as_mut() {
238                                if id.is_some() {
239                                    return Poll::Ready(Some(Err(Error::DuplicatedIdLine)));
240                                } else {
241                                    id.replace(id_value.to_owned());
242                                }
243                            } else {
244                                this.current.replace(Sse {
245                                    id: Some(id_value.to_owned()),
246                                    ..Default::default()
247                                });
248                            }
249                        }
250                        b"retry" => {
251                            let retry_value = std::str::from_utf8(field_value)
252                                .map_err(Error::Utf8Parse)?
253                                .trim_ascii();
254                            let retry_value =
255                                retry_value.parse::<u64>().map_err(Error::IntParse)?;
256                            if let Some(Sse { retry, .. }) = this.current.as_mut() {
257                                if retry.is_some() {
258                                    return Poll::Ready(Some(Err(Error::DuplicatedRetry)));
259                                } else {
260                                    retry.replace(retry_value);
261                                }
262                            } else {
263                                this.current.replace(Sse {
264                                    retry: Some(retry_value),
265                                    ..Default::default()
266                                });
267                            }
268                        }
269                        b"" => {
270                            #[cfg(feature = "tracing")]
271                            if tracing::enabled!(tracing::Level::DEBUG) {
272                                // a comment
273                                let comment =
274                                    std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
275                                tracing::debug!(?comment, "sse comment line");
276                            }
277                        }
278                        _ => {
279                            return Poll::Ready(Some(Err(Error::InvalidLine)));
280                        }
281                    }
282                };
283                self.poll_next(cx)
284            }
285            Some(Err(e)) => Poll::Ready(Some(Err(Error::Body(e)))),
286            None => {
287                if let Some(sse) = this.current.take() {
288                    Poll::Ready(Some(Ok(sse)))
289                } else {
290                    Poll::Ready(None)
291                }
292            }
293        }
294    }
295}