Skip to main content

ranvier_http/
sse.rs

1use bytes::Bytes;
2use http::{header::CONTENT_TYPE, Response};
3use http_body_util::{StreamBody, BodyExt};
4use ranvier_core::event::EventSource;
5use std::convert::Infallible;
6use std::time::Duration;
7use futures_util::stream::{Stream, StreamExt};
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use crate::response::{HttpResponse, IntoResponse};
11
/// A single Server-Sent Events message.
///
/// Every field is optional and written as `field: value` lines, terminated by a
/// blank line; an event with no fields set serializes to just that blank line.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SseEvent {
    /// Payload, sent as one `data:` line per line of text.
    pub(crate) data: Option<String>,
    /// Sent as the `id:` field.
    pub(crate) id: Option<String>,
    /// Event type, sent as the `event:` field.
    pub(crate) event: Option<String>,
    /// Sent as the `retry:` field in whole milliseconds (the client's reconnection
    /// delay, per the SSE spec).
    pub(crate) retry: Option<Duration>,
    /// Sent as `:`-prefixed comment lines, which clients ignore.
    pub(crate) comment: Option<String>,
}
20
21impl SseEvent {
22    pub fn default() -> Self {
23        Self {
24            data: None,
25            id: None,
26            event: None,
27            retry: None,
28            comment: None,
29        }
30    }
31
32    pub fn data(mut self, data: impl Into<String>) -> Self {
33        self.data = Some(data.into());
34        self
35    }
36
37    pub fn id(mut self, id: impl Into<String>) -> Self {
38        self.id = Some(id.into());
39        self
40    }
41
42    pub fn event(mut self, event: impl Into<String>) -> Self {
43        self.event = Some(event.into());
44        self
45    }
46
47    pub fn retry(mut self, duration: Duration) -> Self {
48        self.retry = Some(duration);
49        self
50    }
51
52    pub fn comment(mut self, comment: impl Into<String>) -> Self {
53        self.comment = Some(comment.into());
54        self
55    }
56
57    fn serialize(&self) -> String {
58        let mut out = String::new();
59        if let Some(comment) = &self.comment {
60            for line in comment.lines() {
61                out.push_str(&format!(": {}\n", line));
62            }
63        }
64        if let Some(event) = &self.event {
65            out.push_str(&format!("event: {}\n", event));
66        }
67        if let Some(id) = &self.id {
68            out.push_str(&format!("id: {}\n", id));
69        }
70        if let Some(retry) = &self.retry {
71            out.push_str(&format!("retry: {}\n", retry.as_millis()));
72        }
73        if let Some(data) = &self.data {
74            for line in data.lines() {
75                out.push_str(&format!("data: {}\n", line));
76            }
77        }
78        out.push('\n');
79        out
80    }
81}
82
/// A streaming Server-Sent Events response built from a stream of
/// `Result<SseEvent, E>` items.
///
/// Construct it with `Sse::new` and turn it into an HTTP response through its
/// `IntoResponse` impl, which sets `Content-Type: text/event-stream`.
pub struct Sse<S> {
    /// The event stream that feeds the response body.
    stream: S,
}
86
87impl<S, E> Sse<S>
88where
89    S: Stream<Item = Result<SseEvent, E>> + Send + 'static,
90    E: Into<Box<dyn std::error::Error + Send + Sync>>,
91{
92    pub fn new(stream: S) -> Self {
93        Self { stream }
94    }
95}
96
/// Stream adapter that turns `Result<SseEvent, E>` items into `http_body` data frames.
///
/// `E` appears only in the item type, so it is recorded with a marker instead of
/// being stored.
pub struct FrameStream<S, E> {
    /// The wrapped event stream.
    inner: S,
    /// Records the error type. `PhantomData<fn() -> E>` is always `Send`, `Sync`
    /// and `Unpin`, so those properties of `FrameStream` depend only on `S`.
    _marker: std::marker::PhantomData<fn() -> E>,
}
101
102impl<S, E> Stream for FrameStream<S, E>
103where
104    S: Stream<Item = Result<SseEvent, E>> + Unpin,
105    E: Into<Box<dyn std::error::Error + Send + Sync>>,
106{
107    type Item = Result<http_body::Frame<Bytes>, E>;
108
109    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<<Self as Stream>::Item>> {
110        match Pin::new(&mut self.inner).poll_next(cx) {
111            Poll::Ready(Some(Ok(event))) => {
112                let serialized = event.serialize();
113                let frame = http_body::Frame::data(Bytes::from(serialized));
114                Poll::Ready(Some(Ok(frame)))
115            }
116            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
117            Poll::Ready(None) => Poll::Ready(None),
118            Poll::Pending => Poll::Pending,
119        }
120    }
121}
122
123impl<S, E> IntoResponse for Sse<S>
124where
125    S: Stream<Item = Result<SseEvent, E>> + Send + Sync + Unpin + 'static,
126    E: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync + 'static,
127{
128    fn into_response(self) -> HttpResponse {
129        let frame_stream = FrameStream {
130            inner: self.stream,
131            _marker: std::marker::PhantomData,
132        };
133
134        let mut frame_stream = Box::pin(frame_stream);
135        let infallible_stream = async_stream::stream! {
136            while let Some(res) = futures_util::StreamExt::next(&mut frame_stream).await {
137                match res {
138                    Ok(frame) => yield Ok::<_, std::convert::Infallible>(frame),
139                    Err(e) => {
140                        let err: Box<dyn std::error::Error + Send + Sync> = e.into();
141                        tracing::error!("SSE stream terminated with error: {:?}", err);
142                        break;
143                    }
144                }
145            }
146        };
147
148        let body = http_body_util::StreamBody::new(infallible_stream);
149
150        http::Response::builder()
151            .status(http::StatusCode::OK)
152            .header(http::header::CONTENT_TYPE, "text/event-stream")
153            .header(http::header::CACHE_CONTROL, "no-cache")
154            .header(http::header::CONNECTION, "keep-alive")
155            .body(http_body_util::BodyExt::boxed(body))
156            .expect("Valid builder")
157    }
158}
159
160pub fn from_event_source<E, S, F>(mut source: S, mut mapper: F) -> impl Stream<Item = Result<SseEvent, Infallible>> + Send + Sync
161where
162    S: EventSource<E> + Send + 'static,
163    E: Send + 'static,
164    F: FnMut(E) -> SseEvent + Send + 'static,
165{
166    let (tx, mut rx) = tokio::sync::mpsc::channel(16);
167    tokio::spawn(async move {
168        while let Some(event) = source.next_event().await {
169            if tx.send(mapper(event)).await.is_err() {
170                break;
171            }
172        }
173    });
174
175    let stream = async_stream::stream! {
176        while let Some(event) = rx.recv().await {
177            yield Ok(event);
178        }
179    };
180    Box::pin(stream)
181}