1use crate::response::{HttpResponse, IntoResponse};
2use bytes::Bytes;
3use futures_util::stream::Stream;
4
5use ranvier_core::event::EventSource;
6use serde::de::{self, Deserializer};
7use serde::ser::Serializer;
8use serde::{Deserialize, Serialize};
9use std::convert::Infallible;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use std::time::Duration;
13
/// A single Server-Sent Events message, assembled with the chained setter methods.
///
/// Every field is optional; `SseEvent::default()` yields an event with no fields set, which
/// serializes to a lone blank line.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct SseEvent {
    /// Payload; written as one `data:` line per line of text.
    pub(crate) data: Option<String>,
    /// Written as the `id:` field (line breaks are stripped).
    pub(crate) id: Option<String>,
    /// Written as the `event:` field (line breaks are stripped).
    pub(crate) event: Option<String>,
    /// Written as the `retry:` field, in whole milliseconds.
    pub(crate) retry: Option<Duration>,
    /// Written as one `:` comment line per line of text.
    pub(crate) comment: Option<String>,
}

impl SseEvent {
    /// Sets the event payload.
    ///
    /// Text containing `\n`, `\r\n` or `\r` is split into several `data:` lines. The client
    /// re-joins them with `\n`. An empty string still produces a `data:` line so that the
    /// event is dispatched.
    #[must_use]
    pub fn data(mut self, data: impl Into<String>) -> Self {
        self.data = Some(data.into());
        self
    }

    /// Sets the event id. Any CR/LF characters are removed when the event is serialized.
    #[must_use]
    pub fn id(mut self, id: impl Into<String>) -> Self {
        self.id = Some(id.into());
        self
    }

    /// Sets the event type. Any CR/LF characters are removed when the event is serialized.
    #[must_use]
    pub fn event(mut self, event: impl Into<String>) -> Self {
        self.event = Some(event.into());
        self
    }

    /// Sets the client reconnection delay, written as whole milliseconds.
    #[must_use]
    pub fn retry(mut self, duration: Duration) -> Self {
        self.retry = Some(duration);
        self
    }

    /// Sets a comment; multi-line text is written as several `:` comment lines.
    #[must_use]
    pub fn comment(mut self, comment: impl Into<String>) -> Self {
        self.comment = Some(comment.into());
        self
    }

    /// Renders the event in `text/event-stream` wire format, terminated by a blank line.
    ///
    /// Fields are written in the order: comment, event, id, retry, data.
    fn serialize(&self) -> String {
        let mut out = String::new();
        if let Some(comment) = &self.comment {
            // An empty field name yields `: text`, the SSE comment syntax.
            Self::push_multi_line_field(&mut out, "", comment);
        }
        if let Some(event) = &self.event {
            Self::push_single_line_field(&mut out, "event", event);
        }
        if let Some(id) = &self.id {
            Self::push_single_line_field(&mut out, "id", id);
        }
        if let Some(retry) = &self.retry {
            Self::push_single_line_field(&mut out, "retry", &retry.as_millis().to_string());
        }
        if let Some(data) = &self.data {
            Self::push_multi_line_field(&mut out, "data", data);
        }
        out.push('\n');
        out
    }

    /// Appends `name: value\n`, dropping CR/LF from `value` so it cannot end the line early
    /// and inject additional fields.
    fn push_single_line_field(out: &mut String, name: &str, value: &str) {
        out.push_str(name);
        out.push_str(": ");
        out.extend(value.chars().filter(|c| !matches!(c, '\r' | '\n')));
        out.push('\n');
    }

    /// Appends one `name: line\n` per line of `value`, splitting on the same terminators an
    /// SSE client recognises (`\r\n`, `\r`, `\n`).
    fn push_multi_line_field(out: &mut String, name: &str, value: &str) {
        for line in Self::sse_lines(value) {
            out.push_str(name);
            out.push_str(": ");
            out.push_str(line);
            out.push('\n');
        }
    }

    /// Splits `text` on `\r\n`, `\r` or `\n`.
    ///
    /// Unlike `str::lines`, this always yields at least one item (`""` for empty input) and
    /// keeps a trailing empty line. Both are needed to reproduce the payload exactly on the
    /// client side.
    fn sse_lines(text: &str) -> impl Iterator<Item = &str> + '_ {
        let mut rest = Some(text);
        std::iter::from_fn(move || {
            let current = rest?;
            match current.find(|c: char| c == '\r' || c == '\n') {
                Some(idx) => {
                    let terminator_len = if current[idx..].starts_with("\r\n") { 2 } else { 1 };
                    rest = Some(&current[idx + terminator_len..]);
                    Some(&current[..idx])
                }
                None => {
                    rest = None;
                    Some(current)
                }
            }
        })
    }
}
84
/// A Server-Sent Events response wrapping a stream of `Result<SseEvent, E>` items.
pub struct Sse<S> {
    /// The underlying event stream, consumed when the response is built.
    stream: S,
}
88
89impl<S> Serialize for Sse<S> {
94 fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
95 serializer.serialize_str("<<Sse stream>>")
96 }
97}
98
99impl<'de, S> Deserialize<'de> for Sse<S> {
100 fn deserialize<D: Deserializer<'de>>(_deserializer: D) -> Result<Self, D::Error> {
101 Err(de::Error::custom(
102 "Sse<S> cannot be deserialized; it wraps a live stream",
103 ))
104 }
105}
106
107impl<S, E> Sse<S>
108where
109 S: Stream<Item = Result<SseEvent, E>> + Send + 'static,
110 E: Into<Box<dyn std::error::Error + Send + Sync>>,
111{
112 pub fn new(stream: S) -> Self {
113 Self { stream }
114 }
115}
116
/// Adapter that turns a stream of `Result<SseEvent, E>` into HTTP body frames.
///
/// `E` is tracked only through `PhantomData<fn() -> E>`; no `E` value is stored.
pub struct FrameStream<S, E> {
    /// The wrapped event stream.
    inner: S,
    /// Ties the error type to the struct without owning an `E`.
    _marker: std::marker::PhantomData<fn() -> E>,
}
121
122impl<S, E> Stream for FrameStream<S, E>
123where
124 S: Stream<Item = Result<SseEvent, E>> + Unpin,
125 E: Into<Box<dyn std::error::Error + Send + Sync>>,
126{
127 type Item = Result<http_body::Frame<Bytes>, E>;
128
129 fn poll_next(
130 mut self: Pin<&mut Self>,
131 cx: &mut Context<'_>,
132 ) -> Poll<Option<<Self as Stream>::Item>> {
133 match Pin::new(&mut self.inner).poll_next(cx) {
134 Poll::Ready(Some(Ok(event))) => {
135 let serialized = event.serialize();
136 let frame = http_body::Frame::data(Bytes::from(serialized));
137 Poll::Ready(Some(Ok(frame)))
138 }
139 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
140 Poll::Ready(None) => Poll::Ready(None),
141 Poll::Pending => Poll::Pending,
142 }
143 }
144}
145
146impl<S, E> IntoResponse for Sse<S>
147where
148 S: Stream<Item = Result<SseEvent, E>> + Send + Sync + Unpin + 'static,
149 E: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync + 'static,
150{
151 fn into_response(self) -> HttpResponse {
152 let frame_stream = FrameStream {
153 inner: self.stream,
154 _marker: std::marker::PhantomData,
155 };
156
157 let mut frame_stream = Box::pin(frame_stream);
158 let infallible_stream = async_stream::stream! {
159 while let Some(res) = futures_util::StreamExt::next(&mut frame_stream).await {
160 match res {
161 Ok(frame) => yield Ok::<_, std::convert::Infallible>(frame),
162 Err(e) => {
163 let err: Box<dyn std::error::Error + Send + Sync> = e.into();
164 tracing::error!("SSE stream terminated with error: {:?}", err);
165 break;
166 }
167 }
168 }
169 };
170
171 let body = http_body_util::StreamBody::new(infallible_stream);
172
173 http::Response::builder()
174 .status(http::StatusCode::OK)
175 .header(http::header::CONTENT_TYPE, "text/event-stream")
176 .header(http::header::CACHE_CONTROL, "no-cache")
177 .header(http::header::CONNECTION, "keep-alive")
178 .body(http_body_util::BodyExt::boxed(body))
179 .expect("Valid builder")
180 }
181}
182
183pub fn from_event_source<E, S, F>(
184 mut source: S,
185 mut mapper: F,
186) -> impl Stream<Item = Result<SseEvent, Infallible>> + Send + Sync
187where
188 S: EventSource<E> + Send + 'static,
189 E: Send + 'static,
190 F: FnMut(E) -> SseEvent + Send + 'static,
191{
192 let (tx, mut rx) = tokio::sync::mpsc::channel(16);
193 tokio::spawn(async move {
194 while let Some(event) = source.next_event().await {
195 if tx.send(mapper(event)).await.is_err() {
196 break;
197 }
198 }
199 });
200
201 let stream = async_stream::stream! {
202 while let Some(event) = rx.recv().await {
203 yield Ok(event);
204 }
205 };
206 Box::pin(stream)
207}