1use std::{
2 convert::Infallible,
3 pin::Pin,
4 task::{Context, Poll},
5 time::Duration,
6};
7
8use bytes::{BufMut, Bytes, BytesMut};
9use futures_util::{
10 ready,
11 stream::{BoxStream, Stream},
12 Future, StreamExt,
13};
14use http::{header, HeaderValue, StatusCode};
15use hyper::body::Frame;
16use mincat_core::{
17 body::Body,
18 response::{IntoResponse, Response},
19};
20use pin_project_lite::pin_project;
21use tokio::time::Sleep;
22
23pub struct Sse {
24 stream: BoxStream<'static, Event>,
25 keep_alive: Option<KeepAlive>,
26}
27
28impl IntoResponse for Sse {
29 fn into_response(self) -> Response {
30 (
31 StatusCode::OK,
32 [
33 (
34 header::CONTENT_TYPE,
35 HeaderValue::from_static(mime::TEXT_EVENT_STREAM.as_ref()),
36 ),
37 (header::CACHE_CONTROL, HeaderValue::from_static("no-cache")),
38 ],
39 Body::new(SseBody {
40 event_stream: self.stream,
41 keep_alive: self.keep_alive.map(KeepAliveStream::new),
42 }),
43 )
44 .into_response()
45 }
46}
47
48impl Sse {
49 pub fn new<S>(stream: S) -> Self
50 where
51 S: Stream<Item = Event> + Send + 'static,
52 {
53 Sse {
54 stream: stream.boxed(),
55 keep_alive: None,
56 }
57 }
58
59 pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
60 self.keep_alive = Some(keep_alive);
61 self
62 }
63}
64
65#[derive(Debug, Default, Clone)]
66pub struct Event {
67 buffer: BytesMut,
68 flags: EventFlags,
69}
70
71impl Event {
72 pub fn data<T>(mut self, data: T) -> Event
73 where
74 T: AsRef<str>,
75 {
76 if self.flags.contains(EventFlags::HAS_DATA) {
77 panic!("Called `EventBuilder::data` multiple times");
78 }
79
80 for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
81 self.field("data", line);
82 }
83
84 self.flags.insert(EventFlags::HAS_DATA);
85
86 self
87 }
88
89 pub fn comment<T>(mut self, comment: T) -> Event
90 where
91 T: AsRef<str>,
92 {
93 self.field("", comment.as_ref());
94 self
95 }
96
97 pub fn event<T>(mut self, event: T) -> Event
98 where
99 T: AsRef<str>,
100 {
101 if self.flags.contains(EventFlags::HAS_EVENT) {
102 panic!("Called `EventBuilder::event` multiple times");
103 }
104 self.flags.insert(EventFlags::HAS_EVENT);
105
106 self.field("event", event.as_ref());
107
108 self
109 }
110
111 pub fn retry(mut self, duration: Duration) -> Event {
112 if self.flags.contains(EventFlags::HAS_RETRY) {
113 panic!("Called `EventBuilder::retry` multiple times");
114 }
115 self.flags.insert(EventFlags::HAS_RETRY);
116
117 self.buffer.extend_from_slice(b"retry:");
118
119 let secs = duration.as_secs();
120 let millis = duration.subsec_millis();
121
122 if secs > 0 {
123 self.buffer
124 .extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
125
126 if millis < 10 {
127 self.buffer.extend_from_slice(b"00");
128 } else if millis < 100 {
129 self.buffer.extend_from_slice(b"0");
130 }
131 }
132
133 self.buffer
134 .extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
135
136 self.buffer.put_u8(b'\n');
137
138 self
139 }
140
141 pub fn id<T>(mut self, id: T) -> Event
142 where
143 T: AsRef<str>,
144 {
145 if self.flags.contains(EventFlags::HAS_ID) {
146 panic!("Called `EventBuilder::id` multiple times");
147 }
148 self.flags.insert(EventFlags::HAS_ID);
149
150 let id = id.as_ref().as_bytes();
151 assert_eq!(
152 memchr::memchr(b'\0', id),
153 None,
154 "Event ID cannot contain null characters",
155 );
156
157 self.field("id", id);
158 self
159 }
160
161 fn field(&mut self, name: &str, value: impl AsRef<[u8]>) {
162 let value = value.as_ref();
163 assert_eq!(
164 memchr::memchr2(b'\r', b'\n', value),
165 None,
166 "SSE field value cannot contain newlines or carriage returns",
167 );
168 self.buffer.extend_from_slice(name.as_bytes());
169 self.buffer.put_u8(b':');
170 self.buffer.put_u8(b' ');
171 self.buffer.extend_from_slice(value);
172 self.buffer.put_u8(b'\n');
173 }
174
175 fn finalize(mut self) -> Bytes {
176 self.buffer.put_u8(b'\n');
177 self.buffer.freeze()
178 }
179}
180
/// Small bit set recording which single-use `Event` fields have already been written.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
struct EventFlags(u8);

impl EventFlags {
    const HAS_DATA: Self = Self::from_bits(1 << 0);
    const HAS_EVENT: Self = Self::from_bits(1 << 1);
    const HAS_RETRY: Self = Self::from_bits(1 << 2);
    const HAS_ID: Self = Self::from_bits(1 << 3);

    /// Raw bit pattern.
    const fn bits(&self) -> u8 {
        self.0
    }

    /// Wraps a raw bit pattern without validation.
    const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    /// True when every bit set in `other` is also set in `self`.
    const fn contains(&self, other: Self) -> bool {
        let wanted = other.bits();
        (self.bits() & wanted) == wanted
    }

    /// Sets every bit that is set in `other`.
    fn insert(&mut self, other: Self) {
        self.0 |= other.bits();
    }
}
206
207fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> {
208 MemchrSplit {
209 needle,
210 haystack: Some(haystack),
211 }
212}
213
214struct MemchrSplit<'a> {
215 needle: u8,
216 haystack: Option<&'a [u8]>,
217}
218
219impl<'a> Iterator for MemchrSplit<'a> {
220 type Item = &'a [u8];
221 fn next(&mut self) -> Option<Self::Item> {
222 let haystack = self.haystack?;
223 if let Some(pos) = memchr::memchr(self.needle, haystack) {
224 let (front, back) = haystack.split_at(pos);
225 self.haystack = Some(&back[1..]);
226 Some(front)
227 } else {
228 self.haystack.take()
229 }
230 }
231}
232
233pin_project! {
234 struct SseBody {
235 event_stream: BoxStream<'static, Event>,
236 #[pin]
237 keep_alive: Option<KeepAliveStream>,
238 }
239}
240
241impl http_body::Body for SseBody {
242 type Data = Bytes;
243 type Error = Infallible;
244
245 fn poll_frame(
246 self: Pin<&mut Self>,
247 cx: &mut Context<'_>,
248 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
249 let this = self.project();
250
251 match this.event_stream.as_mut().poll_next(cx) {
252 Poll::Pending => {
253 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
254 keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
255 } else {
256 Poll::Pending
257 }
258 }
259 Poll::Ready(Some(event)) => {
260 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
261 keep_alive.reset();
262 }
263 Poll::Ready(Some(Ok(Frame::data(event.finalize()))))
264 }
265 Poll::Ready(None) => Poll::Ready(None),
266 }
267 }
268}
269
270#[derive(Debug, Clone)]
271#[must_use]
272pub struct KeepAlive {
273 event: Bytes,
274 max_interval: Duration,
275}
276
277impl Default for KeepAlive {
278 fn default() -> Self {
279 Self::new()
280 }
281}
282
283impl KeepAlive {
284 pub fn new() -> Self {
285 Self {
286 event: Bytes::from_static(b":\n\n"),
287 max_interval: Duration::from_secs(15),
288 }
289 }
290
291 pub fn interval(mut self, time: Duration) -> Self {
292 self.max_interval = time;
293 self
294 }
295
296 pub fn text<I>(self, text: I) -> Self
297 where
298 I: AsRef<str>,
299 {
300 self.event(Event::default().comment(text))
301 }
302
303 pub fn event(mut self, event: Event) -> Self {
304 self.event = event.finalize();
305 self
306 }
307}
308
309pin_project! {
310 #[derive(Debug)]
311 struct KeepAliveStream {
312 keep_alive: KeepAlive,
313 #[pin]
314 alive_timer: Sleep,
315 }
316}
317
318impl KeepAliveStream {
319 fn new(keep_alive: KeepAlive) -> Self {
320 Self {
321 alive_timer: tokio::time::sleep(keep_alive.max_interval),
322 keep_alive,
323 }
324 }
325
326 fn reset(self: Pin<&mut Self>) {
327 let this = self.project();
328 this.alive_timer
329 .reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
330 }
331
332 fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
333 let this = self.as_mut().project();
334
335 ready!(this.alive_timer.poll(cx));
336
337 let event = this.keep_alive.event.clone();
338
339 self.reset();
340
341 Poll::Ready(event)
342 }
343}