1use std::{
2 collections::BTreeMap,
3 marker::PhantomData,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use bytes::Bytes;
9use futures_core::{Stream, TryStream};
10use futures_util::TryStreamExt;
11use http::{
12 header::{CACHE_CONTROL, CONTENT_TYPE},
13 StatusCode,
14};
15use pin_project_lite::pin_project;
16use predawn_core::{
17 api_response::ApiResponse,
18 body::ResponseBody,
19 error::BoxError,
20 into_response::IntoResponse,
21 media_type::{MediaType, MultiResponseMediaType, ResponseMediaType, SingleMediaType},
22 openapi::{self, AnySchema, ReferenceOr, Schema, SchemaKind},
23 response::{MultiResponse, Response, SingleResponse},
24};
25use predawn_schema::ToSchema;
26use serde::Serialize;
27
28use super::{event::Event, keep_alive::KeepAlive};
29use crate::{response::sse::keep_alive::KeepAliveStream, response_error::EventStreamError};
30
31pub struct EventStream<T> {
32 result: Result<Response, EventStreamError>,
33 _marker: PhantomData<T>,
34}
35
36impl<T> EventStream<T> {
37 pub fn new<S>(stream: S) -> Self
38 where
39 T: Serialize + Send + 'static,
40 S: TryStream<Ok = T> + Send + 'static,
41 S::Error: Into<BoxError>,
42 {
43 Self::builder().build(stream)
44 }
45
46 pub fn builder() -> EventStreamBuilder<DefaultOnCreateEvent<T>> {
47 EventStreamBuilder {
48 keep_alive: None,
49 _marker: PhantomData,
50 }
51 }
52}
53
54impl<T> IntoResponse for EventStream<T> {
55 type Error = EventStreamError;
56
57 fn into_response(self) -> Result<Response, Self::Error> {
58 self.result
59 }
60}
61
62impl<T> MediaType for EventStream<T> {
63 const MEDIA_TYPE: &'static str = "text/event-stream";
64}
65
66impl<T> ResponseMediaType for EventStream<T> {}
67
68impl<T: ToSchema> SingleMediaType for EventStream<T> {
69 fn media_type(
70 schemas: &mut BTreeMap<String, openapi::Schema>,
71 schemas_in_progress: &mut Vec<String>,
72 ) -> openapi::MediaType {
73 let schema = Schema {
74 schema_data: Default::default(),
75 schema_kind: SchemaKind::Any(AnySchema {
76 typ: Some("array".into()),
77 items: Some(T::schema_ref_box(schemas, schemas_in_progress)),
78 format: Some("event-stream".into()),
79 ..Default::default()
80 }),
81 };
82
83 openapi::MediaType {
84 schema: Some(ReferenceOr::Item(schema)),
85 ..Default::default()
86 }
87 }
88}
89
90impl<T: ToSchema> SingleResponse for EventStream<T> {
91 fn response(
92 schemas: &mut BTreeMap<String, Schema>,
93 schemas_in_progress: &mut Vec<String>,
94 ) -> openapi::Response {
95 openapi::Response {
96 content: <Self as MultiResponseMediaType>::content(schemas, schemas_in_progress),
97 ..Default::default()
98 }
99 }
100}
101
102impl<T: ToSchema> ApiResponse for EventStream<T> {
103 fn responses(
104 schemas: &mut BTreeMap<String, Schema>,
105 schemas_in_progress: &mut Vec<String>,
106 ) -> Option<BTreeMap<StatusCode, openapi::Response>> {
107 Some(<Self as MultiResponse>::responses(
108 schemas,
109 schemas_in_progress,
110 ))
111 }
112}
113
114pub struct EventStreamBuilder<F> {
115 keep_alive: Option<KeepAlive>,
116 _marker: PhantomData<F>,
117}
118
119impl<F> EventStreamBuilder<F> {
120 pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
121 self.keep_alive = Some(keep_alive);
122 self
123 }
124}
125
126impl<F: OnCreateEvent> EventStreamBuilder<F> {
127 pub fn on_create_event<C>(self) -> EventStreamBuilder<C>
128 where
129 C: OnCreateEvent<Item = F::Item>,
130 {
131 EventStreamBuilder {
132 keep_alive: self.keep_alive,
133 _marker: PhantomData,
134 }
135 }
136
137 pub fn build<S>(self, stream: S) -> EventStream<F::Data>
138 where
139 S: TryStream<Ok = F::Item> + Send + 'static,
140 S::Error: Into<BoxError>,
141 {
142 EventStream {
143 result: inner_build(self, stream),
144 _marker: PhantomData,
145 }
146 }
147}
148
149fn inner_build<C, S>(
150 builder: EventStreamBuilder<C>,
151 stream: S,
152) -> Result<Response, EventStreamError>
153where
154 C: OnCreateEvent,
155 S: TryStream<Ok = C::Item> + Send + 'static,
156 S::Error: Into<BoxError>,
157{
158 pin_project! {
159 struct SseStream<S> {
160 #[pin]
161 stream: S,
162 #[pin]
163 keep_alive: Option<KeepAliveStream>,
164 }
165 }
166
167 impl<S> Stream for SseStream<S>
168 where
169 S: Stream<Item = Result<Bytes, BoxError>> + Send + 'static,
170 {
171 type Item = S::Item;
172
173 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
174 let mut this = self.project();
175
176 match this.stream.try_poll_next_unpin(cx) {
177 Poll::Pending => {
178 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
179 keep_alive.poll_event(cx).map(|e| Some(Ok(e)))
180 } else {
181 Poll::Pending
182 }
183 }
184 ok @ Poll::Ready(Some(Ok(_))) => {
185 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
186 keep_alive.reset();
187 }
188
189 ok
190 }
191 other => other,
192 }
193 }
194 }
195
196 let stream = SseStream {
197 stream: stream.map_err(Into::into).and_then(|item| async move {
198 let data = C::data(&item);
199 let data = serde_json::to_string(data).map_err(Box::new)?;
200
201 let mut evt = Event::data(data);
202 C::modify_event(&item, &mut evt);
203
204 let bytes = evt.as_bytes().map_err(Box::new)?;
205
206 Ok::<_, BoxError>(bytes)
207 }),
208 keep_alive: builder.keep_alive.map(KeepAliveStream::new).transpose()?,
209 };
210
211 let body = ResponseBody::from_stream(stream);
212
213 let response = http::Response::builder()
214 .header(CONTENT_TYPE, EventStream::<()>::MEDIA_TYPE)
215 .header(CACHE_CONTROL, "no-cache")
216 .header("X-Accel-Buffering", "no")
217 .body(body)
218 .unwrap();
219
220 Ok(response)
221}
222
223pub trait OnCreateEvent {
224 type Item: Send + 'static;
225 type Data: Serialize;
226
227 fn data(item: &Self::Item) -> &Self::Data;
228
229 fn modify_event(item: &Self::Item, event: &mut Event);
230}
231
/// The [`OnCreateEvent`] strategy used by [`EventStream::builder`].
///
/// The whole stream item is serialized as the event data, and the event is
/// not modified further. It is a pure type-level marker and carries no data.
#[derive(Debug)]
pub struct DefaultOnCreateEvent<T> {
    _marker: PhantomData<T>,
}
236
237impl<T> OnCreateEvent for DefaultOnCreateEvent<T>
238where
239 T: Serialize + Send + 'static,
240{
241 type Data = T;
242 type Item = T;
243
244 fn data(item: &Self::Item) -> &Self::Data {
245 item
246 }
247
248 fn modify_event(_: &Self::Item, _: &mut Event) {}
249}