eiktyrner/
sse.rs

1use crate::{Error, ResponseContent};
2use async_trait::async_trait;
3use futures_util::Stream;
4use sse_agent::Sse as _;
5
6use std::{pin::Pin, task::Poll};
7
8impl From<sse_agent::Error<hyper::Error>> for Error {
9    fn from(err: sse_agent::Error<hyper::Error>) -> Self {
10        match err.kind() {
11            sse_agent::ErrorKind::<hyper::Error>::Inner(hyper_err) => Self::Net(hyper_err),
12            sse_agent::ErrorKind::<hyper::Error>::Sse(parse_err) => {
13                Self::InvalidBody(parse_err.to_string())
14            }
15        }
16    }
17}
18
/// Zero-sized marker type used to select server-sent-event handling for a
/// response whose event payloads deserialize into `T`.
///
/// `Clone` and `Copy` are implemented by hand rather than derived: the derive
/// macros would add `T: Clone` / `T: Copy` bounds even though no `T` value is
/// ever stored, making the marker needlessly non-copyable for most payloads.
pub struct Sse<T>(std::marker::PhantomData<T>);

impl<T> Clone for Sse<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Sse<T> {}
21
22#[async_trait]
23impl<T> ResponseContent for Sse<T>
24where
25    T: serde::de::DeserializeOwned + Unpin + Send + 'static,
26{
27    type Data = JsonStream<T>;
28
29    async fn convert_response(
30        response: hyper::Response<hyper::Body>,
31    ) -> Result<http::Response<Self::Data>, Error> {
32        let (parts, body) = response.into_parts();
33
34        if !parts.status.is_success() {
35            return Err(Error::non_2xx(parts.status, &[]));
36        }
37
38        let body = JsonStream {
39            inner: body.into_sse(),
40            _marker: std::marker::PhantomData,
41            last_event_id: None,
42        };
43
44        Ok(http::Response::from_parts(parts, body))
45    }
46}
47
/// A single server-sent event whose data payload has been decoded into `T`.
///
/// The standard derives are available whenever `T` supports them, so events
/// can be logged, cloned and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Event<T> {
    /// The event name, passed through unchanged from the underlying SSE event.
    pub event: String,
    /// The event's data payload, deserialized from JSON.
    pub data: T,
}
52
53pub struct JsonStream<T> {
54    inner: sse_agent::Body<hyper::Body>,
55    _marker: std::marker::PhantomData<T>,
56    last_event_id: Option<String>,
57}
58
59impl<T> Stream for JsonStream<T>
60where
61    T: serde::de::DeserializeOwned + Unpin,
62{
63    type Item = Result<Event<T>, crate::Error>;
64
65    fn poll_next(
66        mut self: std::pin::Pin<&mut Self>,
67        ctx: &mut std::task::Context<'_>,
68    ) -> Poll<Option<Self::Item>> {
69        match Pin::new(&mut self.inner).poll_next(ctx) {
70            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(Error::from(err)))),
71            Poll::Ready(None) => Poll::Ready(None),
72            Poll::Pending => Poll::Pending,
73            Poll::Ready(Some(Ok(sse_agent::Event {
74                event,
75                data,
76                last_event_id,
77            }))) => {
78                self.last_event_id = last_event_id;
79
80                let res = serde_json::from_str::<T>(&data)
81                    .map(|data| Event { event, data })
82                    .map_err(|err| Error::deserialization(err, data.as_bytes()));
83
84                Poll::Ready(Some(res))
85            }
86        }
87    }
88}