1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
2use std::{
3 collections::VecDeque,
4 num::ParseIntError,
5 str::Utf8Error,
6 task::{Context, Poll, ready},
7};
8
9use bytes::Buf;
10use futures_util::{Stream, TryStreamExt, stream::MapOk};
11use http_body::{Body, Frame};
12use http_body_util::{BodyDataStream, StreamBody};
13
14pin_project_lite::pin_project! {
15 pub struct SseStream<B: Body> {
16 #[pin]
17 body: BodyDataStream<B>,
18 parsed: VecDeque<Sse>,
19 current: Option<Sse>,
20 unfinished_line: Vec<u8>,
21 }
22}
23
24pub type ByteStreamBody<S, D> = StreamBody<MapOk<S, fn(D) -> Frame<D>>>;
25impl<E, S, D> SseStream<ByteStreamBody<S, D>>
26where
27 S: Stream<Item = Result<D, E>>,
28 E: std::error::Error,
29 D: Buf,
30 StreamBody<ByteStreamBody<S, D>>: Body,
31{
32 pub fn from_byte_stream(stream: S) -> Self {
36 let stream = stream.map_ok(http_body::Frame::data as fn(D) -> Frame<D>);
37 let body = StreamBody::new(stream);
38 Self {
39 body: BodyDataStream::new(body),
40 parsed: VecDeque::new(),
41 current: None,
42 unfinished_line: Vec::new(),
43 }
44 }
45}
46
47impl<B: Body> SseStream<B> {
48 pub fn new(body: B) -> Self {
50 Self {
51 body: BodyDataStream::new(body),
52 parsed: VecDeque::new(),
53 current: None,
54 unfinished_line: Vec::new(),
55 }
56 }
57}
58
/// One server-sent event, assembled from consecutive field lines.
///
/// An event is terminated by a blank line, or by the end of the body. Each
/// field is `None` when the event contained no line for it.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct Sse {
    /// Value of the `event:` line (the event name).
    pub event: Option<String>,
    /// Values of all `data:` lines of the event, joined with `\n`.
    pub data: Option<String>,
    /// Value of the `id:` line.
    pub id: Option<String>,
    /// Value of the `retry:` line, parsed as an unsigned integer.
    pub retry: Option<u64>,
}
66
67impl Sse {
68 pub fn is_event(&self) -> bool {
69 self.event.is_some()
70 }
71 pub fn is_message(&self) -> bool {
72 self.event.is_none()
73 }
74}
75
/// Errors yielded by `SseStream`.
pub enum Error {
    /// The underlying body returned an error.
    Body(Box<dyn std::error::Error + 'static>),
    /// A line had no `:` separator, or named an unrecognised field.
    InvalidLine,
    /// An event contained more than one `event:` line.
    DuplicatedEventLine,
    /// An event contained more than one `id:` line.
    DuplicatedIdLine,
    /// An event contained more than one `retry:` line.
    DuplicatedRetry,
    /// A field value was not valid UTF-8.
    Utf8Parse(Utf8Error),
    /// A `retry:` value could not be parsed as an integer.
    IntParse(ParseIntError),
}
85
86impl std::fmt::Display for Error {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 match self {
89 Error::Body(e) => write!(f, "body error: {}", e),
90 Error::InvalidLine => write!(f, "invalid line"),
91 Error::DuplicatedEventLine => write!(f, "duplicated event line"),
92 Error::DuplicatedIdLine => write!(f, "duplicated id line"),
93 Error::DuplicatedRetry => write!(f, "duplicated retry line"),
94 Error::Utf8Parse(e) => write!(f, "utf8 parse error: {}", e),
95 Error::IntParse(e) => write!(f, "int parse error: {}", e),
96 }
97 }
98}
99
100impl std::fmt::Debug for Error {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 match self {
103 Error::Body(e) => write!(f, "Body({:?})", e),
104 Error::InvalidLine => write!(f, "InvalidLine"),
105 Error::DuplicatedEventLine => write!(f, "DuplicatedEventLine"),
106 Error::DuplicatedIdLine => write!(f, "DuplicatedIdLine"),
107 Error::DuplicatedRetry => write!(f, "DuplicatedRetry"),
108 Error::Utf8Parse(e) => write!(f, "Utf8Parse({:?})", e),
109 Error::IntParse(e) => write!(f, "IntParse({:?})", e),
110 }
111 }
112}
113
114impl std::error::Error for Error {
115 fn description(&self) -> &str {
116 match self {
117 Error::Body(_) => "body error",
118 Error::InvalidLine => "invalid line",
119 Error::DuplicatedEventLine => "duplicated event line",
120 Error::DuplicatedIdLine => "duplicated id line",
121 Error::DuplicatedRetry => "duplicated retry line",
122 Error::Utf8Parse(_) => "utf8 parse error",
123 Error::IntParse(_) => "int parse error",
124 }
125 }
126
127 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
128 match self {
129 Error::Body(e) => Some(e.as_ref()),
130 Error::Utf8Parse(e) => Some(e),
131 Error::IntParse(e) => Some(e),
132 _ => None,
133 }
134 }
135}
136
137impl<B: Body> Stream for SseStream<B>
138where
139 B::Error: std::error::Error + 'static,
140{
141 type Item = Result<Sse, Error>;
142
143 fn poll_next(
144 mut self: std::pin::Pin<&mut Self>,
145 cx: &mut Context<'_>,
146 ) -> Poll<Option<Self::Item>> {
147 let this = self.as_mut().project();
148 if let Some(sse) = this.parsed.pop_front() {
149 return Poll::Ready(Some(Ok(sse)));
150 }
151 let next_data = ready!(this.body.poll_next(cx));
152 match next_data {
153 Some(Ok(data)) => {
154 let chunk = data.chunk();
155 if chunk.is_empty() {
156 return self.poll_next(cx);
157 }
158 let mut lines = chunk.chunk_by(|maybe_nl, _| *maybe_nl != b'\n');
159 let first_line = lines.next().expect("frame is empty");
160 let mut new_unfinished_line = Vec::new();
161 let first_line = if !this.unfinished_line.is_empty() {
162 this.unfinished_line.extend(first_line);
163 std::mem::swap(&mut new_unfinished_line, this.unfinished_line);
164 new_unfinished_line.as_ref()
165 } else {
166 first_line
167 };
168 let mut lines = std::iter::once(first_line).chain(lines);
169 *this.unfinished_line = loop {
170 let Some(line) = lines.next() else {
171 break Vec::new();
172 };
173 if line.last().copied() != Some(b'\n') {
174 break line.to_vec();
175 }
176 let line = &line[..line.len() - 1];
178 if line.is_empty() {
179 if let Some(sse) = this.current.take() {
180 this.parsed.push_back(sse);
181 }
182 continue;
183 }
184 let Some(comma_index) = line.iter().position(|b| *b == b':') else {
186 return Poll::Ready(Some(Err(Error::InvalidLine)));
187 };
188 let field_name = &line[..comma_index];
189 let field_value = &line[comma_index + 1..];
190 match field_name {
191 b"data" => {
192 let data_line =
193 std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
194 if let Some(Sse { data, .. }) = this.current.as_mut() {
196 if data.is_none() {
197 data.replace(data_line.to_owned());
198 } else {
199 let data = data.as_mut().unwrap();
200 data.push('\n');
201 data.push_str(data_line);
202 }
203 } else {
204 this.current.replace(Sse {
205 event: None,
206 data: Some(data_line.to_owned()),
207 id: None,
208 retry: None,
209 });
210 }
211 }
212 b"event" => {
213 let event_value =
214 std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
215 if let Some(Sse { event, .. }) = this.current.as_mut() {
216 if event.is_some() {
217 return Poll::Ready(Some(Err(Error::DuplicatedEventLine)));
218 } else {
219 event.replace(event_value.to_owned());
220 }
221 } else {
222 this.current.replace(Sse {
223 event: Some(event_value.to_owned()),
224 ..Default::default()
225 });
226 }
227 }
228 b"id" => {
229 let id_value =
230 std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
231 if let Some(Sse { id, .. }) = this.current.as_mut() {
232 if id.is_some() {
233 return Poll::Ready(Some(Err(Error::DuplicatedIdLine)));
234 } else {
235 id.replace(id_value.to_owned());
236 }
237 } else {
238 this.current.replace(Sse {
239 id: Some(id_value.to_owned()),
240 ..Default::default()
241 });
242 }
243 }
244 b"retry" => {
245 let retry_value = std::str::from_utf8(field_value)
246 .map_err(Error::Utf8Parse)?
247 .trim_ascii();
248 let retry_value =
249 retry_value.parse::<u64>().map_err(Error::IntParse)?;
250 if let Some(Sse { retry, .. }) = this.current.as_mut() {
251 if retry.is_some() {
252 return Poll::Ready(Some(Err(Error::DuplicatedRetry)));
253 } else {
254 retry.replace(retry_value);
255 }
256 } else {
257 this.current.replace(Sse {
258 retry: Some(retry_value),
259 ..Default::default()
260 });
261 }
262 }
263 b"" => {
264 #[cfg(feature = "tracing")]
265 if tracing::enabled!(tracing::Level::DEBUG) {
266 let comment =
268 std::str::from_utf8(field_value).map_err(Error::Utf8Parse)?;
269 tracing::debug!(?comment, "sse comment line");
270 }
271 }
272 _ => {
273 return Poll::Ready(Some(Err(Error::InvalidLine)));
274 }
275 }
276 };
277 self.poll_next(cx)
278 }
279 Some(Err(e)) => Poll::Ready(Some(Err(Error::Body(Box::new(e))))),
280 None => {
281 if let Some(sse) = this.current.take() {
282 Poll::Ready(Some(Ok(sse)))
283 } else {
284 Poll::Ready(None)
285 }
286 }
287 }
288 }
289}