1use std::{
2 collections::VecDeque,
3 num::ParseIntError,
4 str::Utf8Error,
5 task::{ready, Context, Poll},
6};
7
8use crate::Sse;
9use bytes::Buf;
10use futures_util::{stream::MapOk, Stream, TryStreamExt};
11use http_body::{Body, Frame};
12use http_body_util::{BodyDataStream, StreamBody};
13
/// Progress of the check for a UTF-8 byte-order mark at the very start of the stream.
#[derive(Debug)]
enum BomHeaderState {
    /// Initial state of a freshly created stream.
    NotFoundYet,
    /// A chunk started with the first byte of the BOM; the first line is still being
    /// collected so it can be compared against the full three-byte marker.
    Parsing,
    /// The check on the first line is finished (the BOM was stripped if it was present).
    Consumed,
}
20
/// UTF-8 encoding of the byte-order mark (U+FEFF) that may prefix the first line of a stream.
const BOM_HEADER: &[u8] = b"\xEF\xBB\xBF";
22
pin_project_lite::pin_project! {
    /// A stream of [`Sse`] events parsed incrementally from the data frames of a [`Body`].
    pub struct SseStream<B: Body> {
        // The wrapped body, polled for its data frames.
        #[pin]
        body: BodyDataStream<B>,
        // Completed events that have not been handed to the caller yet.
        parsed: VecDeque<Sse>,
        // The event currently being assembled from field lines; it is moved into
        // `parsed` when an empty line is read.
        current: Option<Sse>,
        // Bytes of a line whose terminator has not arrived yet; the start of the
        // next chunk is appended to it.
        unfinished_line: Vec<u8>,
        // Set when a chunk ended with `\r`, so that a `\n` opening the next chunk is
        // skipped instead of being read as an additional empty line.
        mark_last_chunk_ending_with_cr: bool,
        // Progress of the byte-order-mark check on the first line.
        bom_header_state: BomHeaderState,
    }
}
34
35pub type ByteStreamBody<S, D> = StreamBody<MapOk<S, fn(D) -> Frame<D>>>;
36impl<E, S, D> SseStream<ByteStreamBody<S, D>>
37where
38 S: Stream<Item = Result<D, E>>,
39 E: std::error::Error,
40 D: Buf,
41 StreamBody<ByteStreamBody<S, D>>: Body,
42{
43 pub fn from_byte_stream(stream: S) -> Self {
47 let stream = stream.map_ok(http_body::Frame::data as fn(D) -> Frame<D>);
48 let body = StreamBody::new(stream);
49 Self {
50 body: BodyDataStream::new(body),
51 parsed: VecDeque::new(),
52 current: None,
53 unfinished_line: Vec::new(),
54 mark_last_chunk_ending_with_cr: false,
55 bom_header_state: BomHeaderState::NotFoundYet,
56 }
57 }
58}
59
60impl<B: Body> SseStream<B> {
61 pub fn new(body: B) -> Self {
63 Self {
64 body: BodyDataStream::new(body),
65 parsed: VecDeque::new(),
66 current: None,
67 unfinished_line: Vec::new(),
68 mark_last_chunk_ending_with_cr: false,
69 bom_header_state: BomHeaderState::NotFoundYet,
70 }
71 }
72}
73
/// Errors yielded by [`SseStream`] while reading and parsing an event stream.
pub enum Error {
    /// The underlying body returned an error while being polled.
    Body(Box<dyn std::error::Error + Send + Sync>),
    /// A non-empty line had no `:` separator or used an unknown field name.
    InvalidLine,
    /// A second `event` field was found within a single event.
    DuplicatedEventLine,
    /// A second `id` field was found within a single event.
    DuplicatedIdLine,
    /// A second `retry` field was found within a single event.
    DuplicatedRetry,
    /// A field value was not valid UTF-8.
    Utf8Parse(Utf8Error),
    /// The value of a `retry` field could not be parsed as a `u64`.
    IntParse(ParseIntError),
}
83
84impl std::fmt::Display for Error {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 match self {
87 Error::Body(e) => write!(f, "body error: {}", e),
88 Error::InvalidLine => write!(f, "invalid line"),
89 Error::DuplicatedEventLine => write!(f, "duplicated event line"),
90 Error::DuplicatedIdLine => write!(f, "duplicated id line"),
91 Error::DuplicatedRetry => write!(f, "duplicated retry line"),
92 Error::Utf8Parse(e) => write!(f, "utf8 parse error: {}", e),
93 Error::IntParse(e) => write!(f, "int parse error: {}", e),
94 }
95 }
96}
97
98impl std::fmt::Debug for Error {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 match self {
101 Error::Body(e) => write!(f, "Body({:?})", e),
102 Error::InvalidLine => write!(f, "InvalidLine"),
103 Error::DuplicatedEventLine => write!(f, "DuplicatedEventLine"),
104 Error::DuplicatedIdLine => write!(f, "DuplicatedIdLine"),
105 Error::DuplicatedRetry => write!(f, "DuplicatedRetry"),
106 Error::Utf8Parse(e) => write!(f, "Utf8Parse({:?})", e),
107 Error::IntParse(e) => write!(f, "IntParse({:?})", e),
108 }
109 }
110}
111
112impl std::error::Error for Error {
113 fn description(&self) -> &str {
114 match self {
115 Error::Body(_) => "body error",
116 Error::InvalidLine => "invalid line",
117 Error::DuplicatedEventLine => "duplicated event line",
118 Error::DuplicatedIdLine => "duplicated id line",
119 Error::DuplicatedRetry => "duplicated retry line",
120 Error::Utf8Parse(_) => "utf8 parse error",
121 Error::IntParse(_) => "int parse error",
122 }
123 }
124
125 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
126 match self {
127 Error::Body(e) => Some(e.as_ref()),
128 Error::Utf8Parse(e) => Some(e),
129 Error::IntParse(e) => Some(e),
130 _ => None,
131 }
132 }
133}
134
135impl<B: Body> Stream for SseStream<B>
136where
137 B::Error: std::error::Error + Send + Sync + 'static,
138{
139 type Item = Result<Sse, Error>;
140
141 fn poll_next(
142 mut self: std::pin::Pin<&mut Self>,
143 cx: &mut Context<'_>,
144 ) -> Poll<Option<Self::Item>> {
145 let this = self.as_mut().project();
146 if let Some(sse) = this.parsed.pop_front() {
147 return Poll::Ready(Some(Ok(sse)));
148 }
149 let next_data = ready!(this.body.poll_next(cx));
150 match next_data {
151 Some(Ok(mut data)) => {
152 loop {
153 let mut bytes = data.chunk();
154 let chunk_size = bytes.len();
155
156 if *this.mark_last_chunk_ending_with_cr {
157 if !bytes.is_empty() && bytes[0] == b'\n' {
158 bytes = &bytes[1..];
159 }
160 *this.mark_last_chunk_ending_with_cr = false;
161 }
162
163 if bytes.is_empty() {
164 return self.poll_next(cx);
165 }
166 if let BomHeaderState::NotFoundYet = this.bom_header_state {
167 if bytes[0] == BOM_HEADER[0] {
168 *this.bom_header_state = BomHeaderState::Parsing;
169 }
170 }
171 if bytes.last().is_some_and(|b| *b == b'\r') {
173 *this.mark_last_chunk_ending_with_cr = true;
174 }
175 let mut lines = bytes.chunk_by(|line_end, line_start| {
176 !(
177 *line_end == b'\n' ||
179 (*line_end == b'\r' && *line_start != b'\n')
181 )
182 });
183 let first_line = lines.next().expect("frame is empty");
184
185 let mut new_unfinished_line = Vec::new();
186 let mut first_line = if !this.unfinished_line.is_empty() {
187 this.unfinished_line.extend(first_line);
188 std::mem::swap(&mut new_unfinished_line, this.unfinished_line);
189 new_unfinished_line.as_ref()
190 } else {
191 first_line
192 };
193
194 if let BomHeaderState::Parsing = this.bom_header_state {
195 if first_line.len() > BOM_HEADER.len() {
196 if let Some(stripped) = first_line.strip_prefix(BOM_HEADER) {
197 first_line = stripped
198 }
199 *this.bom_header_state = BomHeaderState::Consumed;
201 } else {
202 this.unfinished_line.extend(first_line);
203 return self.poll_next(cx);
204 }
205 }
206
207 let mut lines = std::iter::once(first_line).chain(lines);
208 *this.unfinished_line = loop {
209 let Some(line) = lines.next() else {
210 break Vec::new();
211 };
212 let line = if line.ends_with(b"\r\n") {
213 &line[..line.len() - 2]
214 } else if line.ends_with(b"\n") || line.ends_with(b"\r") {
215 &line[..line.len() - 1]
216 } else {
217 break line.to_vec();
218 };
219
220 if line.is_empty() {
221 if let Some(sse) = this.current.take() {
222 this.parsed.push_back(sse);
223 }
224 continue;
225 }
226 let Some(comma_index) = line.iter().position(|b| *b == b':') else {
228 #[cfg(feature = "tracing")]
229 tracing::warn!(?line, "invalid line, missing `:`");
230 return Poll::Ready(Some(Err(Error::InvalidLine)));
231 };
232 let field_name = &line[..comma_index];
233 let field_value = if line.len() > comma_index + 1 {
234 let field_value = &line[comma_index + 1..];
235 if field_value.starts_with(b" ") {
236 &field_value[1..]
237 } else {
238 field_value
239 }
240 } else {
241 b""
242 };
243 match field_name {
244 b"data" => {
245 let data_line =
246 std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
247 if let Some(Sse { data, .. }) = this.current.as_mut() {
249 if data.is_none() {
250 data.replace(data_line.to_owned());
251 } else {
252 let data = data.as_mut().unwrap();
253 data.push('\n');
254 data.push_str(data_line);
255 }
256 } else {
257 this.current.replace(Sse {
258 event: None,
259 data: Some(data_line.to_owned()),
260 id: None,
261 retry: None,
262 });
263 }
264 }
265 b"event" => {
266 let event_value =
267 std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
268 if let Some(Sse { event, .. }) = this.current.as_mut() {
269 if event.is_some() {
270 return Poll::Ready(Some(Err(Error::DuplicatedEventLine)));
271 } else {
272 event.replace(event_value.to_owned());
273 }
274 } else {
275 this.current.replace(Sse {
276 event: Some(event_value.to_owned()),
277 ..Default::default()
278 });
279 }
280 }
281 b"id" => {
282 if field_value.contains(&0u8) {
285 #[cfg(feature = "tracing")]
286 tracing::warn!(
287 ?line,
288 "id field contains NULL byte, ignoring per spec"
289 );
290 continue;
291 }
292 let id_value =
293 std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
294 if let Some(Sse { id, .. }) = this.current.as_mut() {
295 if id.is_some() {
296 return Poll::Ready(Some(Err(Error::DuplicatedIdLine)));
297 } else {
298 id.replace(id_value.to_owned());
299 }
300 } else {
301 this.current.replace(Sse {
302 id: Some(id_value.to_owned()),
303 ..Default::default()
304 });
305 }
306 }
307 b"retry" => {
308 let retry_value = std::str::from_utf8(field_value)
309 .map_err(Error::Utf8Parse)?
310 .trim_ascii();
311 let retry_value =
312 retry_value.parse::<u64>().map_err(Error::IntParse)?;
313 if let Some(Sse { retry, .. }) = this.current.as_mut() {
314 if retry.is_some() {
315 return Poll::Ready(Some(Err(Error::DuplicatedRetry)));
316 } else {
317 retry.replace(retry_value);
318 }
319 } else {
320 this.current.replace(Sse {
321 retry: Some(retry_value),
322 ..Default::default()
323 });
324 }
325 }
326 b"" => {
327 #[cfg(feature = "tracing")]
328 if tracing::enabled!(tracing::Level::DEBUG) {
329 let comment = std::str::from_utf8(field_value)
331 .map_err(Error::Utf8Parse)?;
332 tracing::debug!(?comment, "sse comment line");
333 }
334 }
335 _line => {
336 #[cfg(feature = "tracing")]
337 if tracing::enabled!(tracing::Level::WARN) {
338 tracing::warn!(line = ?_line, "invalid line: unknown field");
339 }
340 return Poll::Ready(Some(Err(Error::InvalidLine)));
341 }
342 }
343 };
344 data.advance(chunk_size);
345 if !data.has_remaining() {
346 break;
347 }
348 }
349 self.poll_next(cx)
350 }
351 Some(Err(e)) => Poll::Ready(Some(Err(Error::Body(Box::new(e))))),
352 None => {
353 Poll::Ready(None)
355 }
356 }
357 }
358}