1use crate::{Error, ResponseContent};
2use async_trait::async_trait;
3use futures_util::Stream;
4use sse_agent::Sse as _;
5
6use std::{pin::Pin, task::Poll};
7
8impl From<sse_agent::Error<hyper::Error>> for Error {
9 fn from(err: sse_agent::Error<hyper::Error>) -> Self {
10 match err.kind() {
11 sse_agent::ErrorKind::<hyper::Error>::Inner(hyper_err) => Self::Net(hyper_err),
12 sse_agent::ErrorKind::<hyper::Error>::Sse(parse_err) => {
13 Self::InvalidBody(parse_err.to_string())
14 }
15 }
16 }
17}
18
/// Marker type selecting a server-sent-events response whose event payloads
/// are JSON-decoded into `T` (see its `ResponseContent` implementation).
///
/// `Clone`, `Copy`, `Default` and `Debug` are implemented by hand rather than
/// derived. `#[derive]` would add `T: Clone` / `T: Copy` / ... bounds, even though
/// no `T` is ever stored. That would make `Sse<T>` non-`Copy` for most payload types.
pub struct Sse<T>(std::marker::PhantomData<T>);

impl<T> Clone for Sse<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Sse<T> {}

impl<T> Default for Sse<T> {
    fn default() -> Self {
        Self(std::marker::PhantomData)
    }
}

impl<T> std::fmt::Debug for Sse<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Sse")
    }
}
21
22#[async_trait]
23impl<T> ResponseContent for Sse<T>
24where
25 T: serde::de::DeserializeOwned + Unpin + Send + 'static,
26{
27 type Data = JsonStream<T>;
28
29 async fn convert_response(
30 response: hyper::Response<hyper::Body>,
31 ) -> Result<http::Response<Self::Data>, Error> {
32 let (parts, body) = response.into_parts();
33
34 if !parts.status.is_success() {
35 return Err(Error::non_2xx(parts.status, &[]));
36 }
37
38 let body = JsonStream {
39 inner: body.into_sse(),
40 _marker: std::marker::PhantomData,
41 last_event_id: None,
42 };
43
44 Ok(http::Response::from_parts(parts, body))
45 }
46}
47
/// A single server-sent event whose `data` payload has been decoded from JSON.
///
/// The derived traits are only implemented when `T` implements them as well.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Event<T> {
    /// The event's `event` field (its type name) as reported by the SSE parser.
    pub event: String,
    /// The event's `data` field, deserialized from JSON into `T`.
    pub data: T,
}
52
53pub struct JsonStream<T> {
54 inner: sse_agent::Body<hyper::Body>,
55 _marker: std::marker::PhantomData<T>,
56 last_event_id: Option<String>,
57}
58
59impl<T> Stream for JsonStream<T>
60where
61 T: serde::de::DeserializeOwned + Unpin,
62{
63 type Item = Result<Event<T>, crate::Error>;
64
65 fn poll_next(
66 mut self: std::pin::Pin<&mut Self>,
67 ctx: &mut std::task::Context<'_>,
68 ) -> Poll<Option<Self::Item>> {
69 match Pin::new(&mut self.inner).poll_next(ctx) {
70 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(Error::from(err)))),
71 Poll::Ready(None) => Poll::Ready(None),
72 Poll::Pending => Poll::Pending,
73 Poll::Ready(Some(Ok(sse_agent::Event {
74 event,
75 data,
76 last_event_id,
77 }))) => {
78 self.last_event_id = last_event_id;
79
80 let res = serde_json::from_str::<T>(&data)
81 .map(|data| Event { event, data })
82 .map_err(|err| Error::deserialization(err, data.as_bytes()));
83
84 Poll::Ready(Some(res))
85 }
86 }
87 }
88}