predawn/response/sse/
event_stream.rs

1use std::{
2    collections::BTreeMap,
3    marker::PhantomData,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use bytes::Bytes;
9use futures_core::{Stream, TryStream};
10use futures_util::TryStreamExt;
11use http::{
12    header::{CACHE_CONTROL, CONTENT_TYPE},
13    StatusCode,
14};
15use pin_project_lite::pin_project;
16use predawn_core::{
17    api_response::ApiResponse,
18    body::ResponseBody,
19    error::BoxError,
20    into_response::IntoResponse,
21    media_type::{MediaType, MultiResponseMediaType, ResponseMediaType, SingleMediaType},
22    openapi::{self, AnySchema, ReferenceOr, Schema, SchemaKind},
23    response::{MultiResponse, Response, SingleResponse},
24};
25use predawn_schema::ToSchema;
26use serde::Serialize;
27
28use super::{event::Event, keep_alive::KeepAlive};
29use crate::{response::sse::keep_alive::KeepAliveStream, response_error::EventStreamError};
30
31pub struct EventStream<T> {
32    result: Result<Response, EventStreamError>,
33    _marker: PhantomData<T>,
34}
35
36impl<T> EventStream<T> {
37    pub fn new<S>(stream: S) -> Self
38    where
39        T: Serialize + Send + 'static,
40        S: TryStream<Ok = T> + Send + 'static,
41        S::Error: Into<BoxError>,
42    {
43        Self::builder().build(stream)
44    }
45
46    pub fn builder() -> EventStreamBuilder<DefaultOnCreateEvent<T>> {
47        EventStreamBuilder {
48            keep_alive: None,
49            _marker: PhantomData,
50        }
51    }
52}
53
54impl<T> IntoResponse for EventStream<T> {
55    type Error = EventStreamError;
56
57    fn into_response(self) -> Result<Response, Self::Error> {
58        self.result
59    }
60}
61
62impl<T> MediaType for EventStream<T> {
63    const MEDIA_TYPE: &'static str = "text/event-stream";
64}
65
66impl<T> ResponseMediaType for EventStream<T> {}
67
68impl<T: ToSchema> SingleMediaType for EventStream<T> {
69    fn media_type(
70        schemas: &mut BTreeMap<String, openapi::Schema>,
71        schemas_in_progress: &mut Vec<String>,
72    ) -> openapi::MediaType {
73        let schema = Schema {
74            schema_data: Default::default(),
75            schema_kind: SchemaKind::Any(AnySchema {
76                typ: Some("array".into()),
77                items: Some(T::schema_ref_box(schemas, schemas_in_progress)),
78                format: Some("event-stream".into()),
79                ..Default::default()
80            }),
81        };
82
83        openapi::MediaType {
84            schema: Some(ReferenceOr::Item(schema)),
85            ..Default::default()
86        }
87    }
88}
89
90impl<T: ToSchema> SingleResponse for EventStream<T> {
91    fn response(
92        schemas: &mut BTreeMap<String, Schema>,
93        schemas_in_progress: &mut Vec<String>,
94    ) -> openapi::Response {
95        openapi::Response {
96            content: <Self as MultiResponseMediaType>::content(schemas, schemas_in_progress),
97            ..Default::default()
98        }
99    }
100}
101
102impl<T: ToSchema> ApiResponse for EventStream<T> {
103    fn responses(
104        schemas: &mut BTreeMap<String, Schema>,
105        schemas_in_progress: &mut Vec<String>,
106    ) -> Option<BTreeMap<StatusCode, openapi::Response>> {
107        Some(<Self as MultiResponse>::responses(
108            schemas,
109            schemas_in_progress,
110        ))
111    }
112}
113
114pub struct EventStreamBuilder<F> {
115    keep_alive: Option<KeepAlive>,
116    _marker: PhantomData<F>,
117}
118
119impl<F> EventStreamBuilder<F> {
120    pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
121        self.keep_alive = Some(keep_alive);
122        self
123    }
124}
125
126impl<F: OnCreateEvent> EventStreamBuilder<F> {
127    pub fn on_create_event<C>(self) -> EventStreamBuilder<C>
128    where
129        C: OnCreateEvent<Item = F::Item>,
130    {
131        EventStreamBuilder {
132            keep_alive: self.keep_alive,
133            _marker: PhantomData,
134        }
135    }
136
137    pub fn build<S>(self, stream: S) -> EventStream<F::Data>
138    where
139        S: TryStream<Ok = F::Item> + Send + 'static,
140        S::Error: Into<BoxError>,
141    {
142        EventStream {
143            result: inner_build(self, stream),
144            _marker: PhantomData,
145        }
146    }
147}
148
149fn inner_build<C, S>(
150    builder: EventStreamBuilder<C>,
151    stream: S,
152) -> Result<Response, EventStreamError>
153where
154    C: OnCreateEvent,
155    S: TryStream<Ok = C::Item> + Send + 'static,
156    S::Error: Into<BoxError>,
157{
158    pin_project! {
159        struct SseStream<S> {
160            #[pin]
161            stream: S,
162            #[pin]
163            keep_alive: Option<KeepAliveStream>,
164        }
165    }
166
167    impl<S> Stream for SseStream<S>
168    where
169        S: Stream<Item = Result<Bytes, BoxError>> + Send + 'static,
170    {
171        type Item = S::Item;
172
173        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
174            let mut this = self.project();
175
176            match this.stream.try_poll_next_unpin(cx) {
177                Poll::Pending => {
178                    if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
179                        keep_alive.poll_event(cx).map(|e| Some(Ok(e)))
180                    } else {
181                        Poll::Pending
182                    }
183                }
184                ok @ Poll::Ready(Some(Ok(_))) => {
185                    if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
186                        keep_alive.reset();
187                    }
188
189                    ok
190                }
191                other => other,
192            }
193        }
194    }
195
196    let stream = SseStream {
197        stream: stream.map_err(Into::into).and_then(|item| async move {
198            let data = C::data(&item);
199            let data = serde_json::to_string(data).map_err(Box::new)?;
200
201            let mut evt = Event::data(data);
202            C::modify_event(&item, &mut evt);
203
204            let bytes = evt.as_bytes().map_err(Box::new)?;
205
206            Ok::<_, BoxError>(bytes)
207        }),
208        keep_alive: builder.keep_alive.map(KeepAliveStream::new).transpose()?,
209    };
210
211    let body = ResponseBody::from_stream(stream);
212
213    let response = http::Response::builder()
214        .header(CONTENT_TYPE, EventStream::<()>::MEDIA_TYPE)
215        .header(CACHE_CONTROL, "no-cache")
216        .header("X-Accel-Buffering", "no")
217        .body(body)
218        .unwrap();
219
220    Ok(response)
221}
222
223pub trait OnCreateEvent {
224    type Item: Send + 'static;
225    type Data: Serialize;
226
227    fn data(item: &Self::Item) -> &Self::Data;
228
229    fn modify_event(item: &Self::Item, event: &mut Event);
230}
231
/// The [`OnCreateEvent`] strategy used by [`EventStream::new`] and
/// [`EventStream::builder`].
///
/// Each item is itself the payload, and the event is not otherwise modified.
#[derive(Debug)]
pub struct DefaultOnCreateEvent<T> {
    _marker: std::marker::PhantomData<T>,
}
236
237impl<T> OnCreateEvent for DefaultOnCreateEvent<T>
238where
239    T: Serialize + Send + 'static,
240{
241    type Data = T;
242    type Item = T;
243
244    fn data(item: &Self::Item) -> &Self::Data {
245        item
246    }
247
248    fn modify_event(_: &Self::Item, _: &mut Event) {}
249}