1use bytes::Bytes;
2use http::{header::CONTENT_TYPE, Response};
3use http_body_util::{StreamBody, BodyExt};
4use ranvier_core::event::EventSource;
5use std::convert::Infallible;
6use std::time::Duration;
7use futures_util::stream::{Stream, StreamExt};
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use crate::response::{HttpResponse, IntoResponse};
11
/// A single Server-Sent Event, built with the chained setters on `SseEvent`.
///
/// Every field is optional; unset fields are simply not written when the event is
/// serialized into the `text/event-stream` body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseEvent {
    /// Payload text, sent as one `data:` line per line of text.
    pub(crate) data: Option<String>,
    /// Event id, sent as the `id:` field.
    pub(crate) id: Option<String>,
    /// Event type name, sent as the `event:` field.
    pub(crate) event: Option<String>,
    /// Client reconnection delay, sent as the `retry:` field in milliseconds.
    pub(crate) retry: Option<Duration>,
    /// Comment text, sent as lines starting with `:`.
    pub(crate) comment: Option<String>,
}
20
21impl SseEvent {
22 pub fn default() -> Self {
23 Self {
24 data: None,
25 id: None,
26 event: None,
27 retry: None,
28 comment: None,
29 }
30 }
31
32 pub fn data(mut self, data: impl Into<String>) -> Self {
33 self.data = Some(data.into());
34 self
35 }
36
37 pub fn id(mut self, id: impl Into<String>) -> Self {
38 self.id = Some(id.into());
39 self
40 }
41
42 pub fn event(mut self, event: impl Into<String>) -> Self {
43 self.event = Some(event.into());
44 self
45 }
46
47 pub fn retry(mut self, duration: Duration) -> Self {
48 self.retry = Some(duration);
49 self
50 }
51
52 pub fn comment(mut self, comment: impl Into<String>) -> Self {
53 self.comment = Some(comment.into());
54 self
55 }
56
57 fn serialize(&self) -> String {
58 let mut out = String::new();
59 if let Some(comment) = &self.comment {
60 for line in comment.lines() {
61 out.push_str(&format!(": {}\n", line));
62 }
63 }
64 if let Some(event) = &self.event {
65 out.push_str(&format!("event: {}\n", event));
66 }
67 if let Some(id) = &self.id {
68 out.push_str(&format!("id: {}\n", id));
69 }
70 if let Some(retry) = &self.retry {
71 out.push_str(&format!("retry: {}\n", retry.as_millis()));
72 }
73 if let Some(data) = &self.data {
74 for line in data.lines() {
75 out.push_str(&format!("data: {}\n", line));
76 }
77 }
78 out.push('\n');
79 out
80 }
81}
82
/// A streaming `text/event-stream` response body backed by a stream of
/// `Result<SseEvent, E>` items; build it with `Sse::new`.
pub struct Sse<S> {
    /// The event source that is polled while the response body is sent.
    stream: S,
}
86
87impl<S, E> Sse<S>
88where
89 S: Stream<Item = Result<SseEvent, E>> + Send + 'static,
90 E: Into<Box<dyn std::error::Error + Send + Sync>>,
91{
92 pub fn new(stream: S) -> Self {
93 Self { stream }
94 }
95}
96
/// Adapter that encodes each `SseEvent` from `inner` into an HTTP body data frame.
pub struct FrameStream<S, E> {
    /// The wrapped stream of `Result<SseEvent, E>` items.
    inner: S,
    /// Ties the error type `E` to the adapter without storing one. The `fn() -> E` form
    /// keeps the adapter's auto traits (`Send`, `Sync`) independent of `E`.
    _marker: std::marker::PhantomData<fn() -> E>,
}
101
102impl<S, E> Stream for FrameStream<S, E>
103where
104 S: Stream<Item = Result<SseEvent, E>> + Unpin,
105 E: Into<Box<dyn std::error::Error + Send + Sync>>,
106{
107 type Item = Result<http_body::Frame<Bytes>, E>;
108
109 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<<Self as Stream>::Item>> {
110 match Pin::new(&mut self.inner).poll_next(cx) {
111 Poll::Ready(Some(Ok(event))) => {
112 let serialized = event.serialize();
113 let frame = http_body::Frame::data(Bytes::from(serialized));
114 Poll::Ready(Some(Ok(frame)))
115 }
116 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
117 Poll::Ready(None) => Poll::Ready(None),
118 Poll::Pending => Poll::Pending,
119 }
120 }
121}
122
123impl<S, E> IntoResponse for Sse<S>
124where
125 S: Stream<Item = Result<SseEvent, E>> + Send + Sync + Unpin + 'static,
126 E: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync + 'static,
127{
128 fn into_response(self) -> HttpResponse {
129 let frame_stream = FrameStream {
130 inner: self.stream,
131 _marker: std::marker::PhantomData,
132 };
133
134 let mut frame_stream = Box::pin(frame_stream);
135 let infallible_stream = async_stream::stream! {
136 while let Some(res) = futures_util::StreamExt::next(&mut frame_stream).await {
137 match res {
138 Ok(frame) => yield Ok::<_, std::convert::Infallible>(frame),
139 Err(e) => {
140 let err: Box<dyn std::error::Error + Send + Sync> = e.into();
141 tracing::error!("SSE stream terminated with error: {:?}", err);
142 break;
143 }
144 }
145 }
146 };
147
148 let body = http_body_util::StreamBody::new(infallible_stream);
149
150 http::Response::builder()
151 .status(http::StatusCode::OK)
152 .header(http::header::CONTENT_TYPE, "text/event-stream")
153 .header(http::header::CACHE_CONTROL, "no-cache")
154 .header(http::header::CONNECTION, "keep-alive")
155 .body(http_body_util::BodyExt::boxed(body))
156 .expect("Valid builder")
157 }
158}
159
160pub fn from_event_source<E, S, F>(mut source: S, mut mapper: F) -> impl Stream<Item = Result<SseEvent, Infallible>> + Send + Sync
161where
162 S: EventSource<E> + Send + 'static,
163 E: Send + 'static,
164 F: FnMut(E) -> SseEvent + Send + 'static,
165{
166 let (tx, mut rx) = tokio::sync::mpsc::channel(16);
167 tokio::spawn(async move {
168 while let Some(event) = source.next_event().await {
169 if tx.send(mapper(event)).await.is_err() {
170 break;
171 }
172 }
173 });
174
175 let stream = async_stream::stream! {
176 while let Some(event) = rx.recv().await {
177 yield Ok(event);
178 }
179 };
180 Box::pin(stream)
181}