// mincat/response/sse.rs

1use std::{
2    convert::Infallible,
3    pin::Pin,
4    task::{Context, Poll},
5    time::Duration,
6};
7
8use bytes::{BufMut, Bytes, BytesMut};
9use futures_util::{
10    ready,
11    stream::{BoxStream, Stream},
12    Future, StreamExt,
13};
14use http::{header, HeaderValue, StatusCode};
15use hyper::body::Frame;
16use mincat_core::{
17    body::Body,
18    response::{IntoResponse, Response},
19};
20use pin_project_lite::pin_project;
21use tokio::time::Sleep;
22
23pub struct Sse {
24    stream: BoxStream<'static, Event>,
25    keep_alive: Option<KeepAlive>,
26}
27
28impl IntoResponse for Sse {
29    fn into_response(self) -> Response {
30        (
31            StatusCode::OK,
32            [
33                (
34                    header::CONTENT_TYPE,
35                    HeaderValue::from_static(mime::TEXT_EVENT_STREAM.as_ref()),
36                ),
37                (header::CACHE_CONTROL, HeaderValue::from_static("no-cache")),
38            ],
39            Body::new(SseBody {
40                event_stream: self.stream,
41                keep_alive: self.keep_alive.map(KeepAliveStream::new),
42            }),
43        )
44            .into_response()
45    }
46}
47
48impl Sse {
49    pub fn new<S>(stream: S) -> Self
50    where
51        S: Stream<Item = Event> + Send + 'static,
52    {
53        Sse {
54            stream: stream.boxed(),
55            keep_alive: None,
56        }
57    }
58
59    pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
60        self.keep_alive = Some(keep_alive);
61        self
62    }
63}
64
65#[derive(Debug, Default, Clone)]
66pub struct Event {
67    buffer: BytesMut,
68    flags: EventFlags,
69}
70
71impl Event {
72    pub fn data<T>(mut self, data: T) -> Event
73    where
74        T: AsRef<str>,
75    {
76        if self.flags.contains(EventFlags::HAS_DATA) {
77            panic!("Called `EventBuilder::data` multiple times");
78        }
79
80        for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
81            self.field("data", line);
82        }
83
84        self.flags.insert(EventFlags::HAS_DATA);
85
86        self
87    }
88
89    pub fn comment<T>(mut self, comment: T) -> Event
90    where
91        T: AsRef<str>,
92    {
93        self.field("", comment.as_ref());
94        self
95    }
96
97    pub fn event<T>(mut self, event: T) -> Event
98    where
99        T: AsRef<str>,
100    {
101        if self.flags.contains(EventFlags::HAS_EVENT) {
102            panic!("Called `EventBuilder::event` multiple times");
103        }
104        self.flags.insert(EventFlags::HAS_EVENT);
105
106        self.field("event", event.as_ref());
107
108        self
109    }
110
111    pub fn retry(mut self, duration: Duration) -> Event {
112        if self.flags.contains(EventFlags::HAS_RETRY) {
113            panic!("Called `EventBuilder::retry` multiple times");
114        }
115        self.flags.insert(EventFlags::HAS_RETRY);
116
117        self.buffer.extend_from_slice(b"retry:");
118
119        let secs = duration.as_secs();
120        let millis = duration.subsec_millis();
121
122        if secs > 0 {
123            self.buffer
124                .extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
125
126            if millis < 10 {
127                self.buffer.extend_from_slice(b"00");
128            } else if millis < 100 {
129                self.buffer.extend_from_slice(b"0");
130            }
131        }
132
133        self.buffer
134            .extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
135
136        self.buffer.put_u8(b'\n');
137
138        self
139    }
140
141    pub fn id<T>(mut self, id: T) -> Event
142    where
143        T: AsRef<str>,
144    {
145        if self.flags.contains(EventFlags::HAS_ID) {
146            panic!("Called `EventBuilder::id` multiple times");
147        }
148        self.flags.insert(EventFlags::HAS_ID);
149
150        let id = id.as_ref().as_bytes();
151        assert_eq!(
152            memchr::memchr(b'\0', id),
153            None,
154            "Event ID cannot contain null characters",
155        );
156
157        self.field("id", id);
158        self
159    }
160
161    fn field(&mut self, name: &str, value: impl AsRef<[u8]>) {
162        let value = value.as_ref();
163        assert_eq!(
164            memchr::memchr2(b'\r', b'\n', value),
165            None,
166            "SSE field value cannot contain newlines or carriage returns",
167        );
168        self.buffer.extend_from_slice(name.as_bytes());
169        self.buffer.put_u8(b':');
170        self.buffer.put_u8(b' ');
171        self.buffer.extend_from_slice(value);
172        self.buffer.put_u8(b'\n');
173    }
174
175    fn finalize(mut self) -> Bytes {
176        self.buffer.put_u8(b'\n');
177        self.buffer.freeze()
178    }
179}
180
/// Bit set recording which once-only fields an [`Event`] has already written.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
struct EventFlags(u8);
183
184impl EventFlags {
185    const HAS_DATA: Self = Self::from_bits(0b0001);
186    const HAS_EVENT: Self = Self::from_bits(0b0010);
187    const HAS_RETRY: Self = Self::from_bits(0b0100);
188    const HAS_ID: Self = Self::from_bits(0b1000);
189
190    const fn bits(&self) -> u8 {
191        self.0
192    }
193
194    const fn from_bits(bits: u8) -> Self {
195        Self(bits)
196    }
197
198    const fn contains(&self, other: Self) -> bool {
199        self.bits() & other.bits() == other.bits()
200    }
201
202    fn insert(&mut self, other: Self) {
203        *self = Self::from_bits(self.bits() | other.bits());
204    }
205}
206
207fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> {
208    MemchrSplit {
209        needle,
210        haystack: Some(haystack),
211    }
212}
213
/// Iterator returned by [`memchr_split`].
struct MemchrSplit<'a> {
    // Byte that separates segments.
    needle: u8,
    // Unconsumed input; becomes `None` once the final segment is yielded.
    haystack: Option<&'a [u8]>,
}
218
219impl<'a> Iterator for MemchrSplit<'a> {
220    type Item = &'a [u8];
221    fn next(&mut self) -> Option<Self::Item> {
222        let haystack = self.haystack?;
223        if let Some(pos) = memchr::memchr(self.needle, haystack) {
224            let (front, back) = haystack.split_at(pos);
225            self.haystack = Some(&back[1..]);
226            Some(front)
227        } else {
228            self.haystack.take()
229        }
230    }
231}
232
233pin_project! {
234    struct SseBody {
235        event_stream: BoxStream<'static, Event>,
236        #[pin]
237        keep_alive: Option<KeepAliveStream>,
238    }
239}
240
241impl http_body::Body for SseBody {
242    type Data = Bytes;
243    type Error = Infallible;
244
245    fn poll_frame(
246        self: Pin<&mut Self>,
247        cx: &mut Context<'_>,
248    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
249        let this = self.project();
250
251        match this.event_stream.as_mut().poll_next(cx) {
252            Poll::Pending => {
253                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
254                    keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
255                } else {
256                    Poll::Pending
257                }
258            }
259            Poll::Ready(Some(event)) => {
260                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
261                    keep_alive.reset();
262                }
263                Poll::Ready(Some(Ok(Frame::data(event.finalize()))))
264            }
265            Poll::Ready(None) => Poll::Ready(None),
266        }
267    }
268}
269
270#[derive(Debug, Clone)]
271#[must_use]
272pub struct KeepAlive {
273    event: Bytes,
274    max_interval: Duration,
275}
276
277impl Default for KeepAlive {
278    fn default() -> Self {
279        Self::new()
280    }
281}
282
283impl KeepAlive {
284    pub fn new() -> Self {
285        Self {
286            event: Bytes::from_static(b":\n\n"),
287            max_interval: Duration::from_secs(15),
288        }
289    }
290
291    pub fn interval(mut self, time: Duration) -> Self {
292        self.max_interval = time;
293        self
294    }
295
296    pub fn text<I>(self, text: I) -> Self
297    where
298        I: AsRef<str>,
299    {
300        self.event(Event::default().comment(text))
301    }
302
303    pub fn event(mut self, event: Event) -> Self {
304        self.event = event.finalize();
305        self
306    }
307}
308
309pin_project! {
310    #[derive(Debug)]
311    struct KeepAliveStream {
312        keep_alive: KeepAlive,
313        #[pin]
314        alive_timer: Sleep,
315    }
316}
317
318impl KeepAliveStream {
319    fn new(keep_alive: KeepAlive) -> Self {
320        Self {
321            alive_timer: tokio::time::sleep(keep_alive.max_interval),
322            keep_alive,
323        }
324    }
325
326    fn reset(self: Pin<&mut Self>) {
327        let this = self.project();
328        this.alive_timer
329            .reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
330    }
331
332    fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
333        let this = self.as_mut().project();
334
335        ready!(this.alive_timer.poll(cx));
336
337        let event = this.keep_alive.event.clone();
338
339        self.reset();
340
341        Poll::Ready(event)
342    }
343}