1use std::{
2 pin::Pin,
3 task::{ready, Context, Poll},
4 time::Duration,
5};
6
7use crate::Sse;
8use bytes::Bytes;
9use futures_util::Stream;
10use http_body::{Body, Frame};
11use std::future::Future;
pin_project_lite::pin_project! {
    /// An HTTP body that turns a stream of [`Sse`] events into data frames,
    /// optionally interleaving keep-alive messages driven by a timer of type `T`.
    ///
    /// With the default `T = NeverTimer` and `keep_alive: None`, no keep-alive
    /// messages are ever produced.
    pub struct SseBody<S, T = NeverTimer> {
        // Source of events; each `Ok` item is converted into one data frame.
        #[pin]
        pub event_stream: S,
        // Keep-alive state; `None` disables keep-alive messages entirely.
        #[pin]
        pub keep_alive: Option<KeepAliveStream<T>>,
    }
}
20
21impl<S, E> SseBody<S, NeverTimer>
22where
23 S: Stream<Item = Result<Sse, E>>,
24{
25 pub fn new(stream: S) -> Self {
26 Self {
27 event_stream: stream,
28 keep_alive: None,
29 }
30 }
31}
32
33impl<S, E, T> SseBody<S, T>
34where
35 S: Stream<Item = Result<Sse, E>>,
36 T: Timer,
37{
38 pub fn new_keep_alive(stream: S, keep_alive: KeepAlive) -> Self {
39 Self {
40 event_stream: stream,
41 keep_alive: Some(KeepAliveStream::new(keep_alive)),
42 }
43 }
44
45 pub fn with_keep_alive<T2: Timer>(self, keep_alive: KeepAlive) -> SseBody<S, T2> {
46 SseBody {
47 event_stream: self.event_stream,
48 keep_alive: Some(KeepAliveStream::new(keep_alive)),
49 }
50 }
51}
52
53impl<S, E, T> Body for SseBody<S, T>
54where
55 S: Stream<Item = Result<Sse, E>>,
56 T: Timer,
57{
58 type Data = Bytes;
59 type Error = E;
60
61 fn poll_frame(
62 self: Pin<&mut Self>,
63 cx: &mut Context<'_>,
64 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
65 let this = self.project();
66
67 match this.event_stream.poll_next(cx) {
68 Poll::Pending => {
69 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
70 keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
71 } else {
72 Poll::Pending
73 }
74 }
75 Poll::Ready(Some(Ok(event))) => {
76 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
77 keep_alive.reset();
78 }
79 Poll::Ready(Some(Ok(Frame::data(event.into()))))
80 }
81 Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
82 Poll::Ready(None) => Poll::Ready(None),
83 }
84 }
85}
86
/// Configuration for the keep-alive message sent while an SSE body is idle.
///
/// Build one with [`KeepAlive::new`] (or `Default`) and adjust it with the
/// builder methods.
#[derive(Debug, Clone)]
#[must_use]
pub struct KeepAlive {
    // Pre-encoded bytes emitted as a data frame each time the keep-alive timer fires.
    event: Bytes,
    // How long the body may go without forwarding an event before a keep-alive is emitted.
    max_interval: Duration,
}
95
96impl KeepAlive {
97 pub fn new() -> Self {
99 Self {
100 event: Bytes::from_static(b":\n\n"),
101 max_interval: Duration::from_secs(15),
102 }
103 }
104
105 pub fn interval(mut self, time: Duration) -> Self {
109 self.max_interval = time;
110 self
111 }
112
113 pub fn event(mut self, event: Sse) -> Self {
122 self.event = event.into();
123 self
124 }
125
126 pub fn comment(mut self, comment: &str) -> Self {
128 self.event = format!(": {}\n\n", comment).into();
129 self
130 }
131}
132
133impl Default for KeepAlive {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
/// A resettable timer that schedules keep-alive messages.
///
/// `KeepAliveStream` polls the timer as a future and treats completion as "the
/// keep-alive interval has elapsed". After each keep-alive or forwarded event,
/// it re-arms the timer through [`Timer::reset`].
pub trait Timer: Future<Output = ()> {
    /// Re-arms the timer. `KeepAliveStream` passes the absolute deadline
    /// (now + keep-alive interval); implementations are expected to complete
    /// at or after `instant`.
    fn reset(self: Pin<&mut Self>, instant: std::time::Instant);
    /// Creates a timer. `KeepAliveStream` calls this with the keep-alive
    /// interval; implementations are expected to complete once `duration`
    /// has elapsed.
    fn from_duration(duration: Duration) -> Self;
}
143
/// A [`Timer`] that never fires.
///
/// This is the default timer type of [`SseBody`], used when keep-alive messages
/// are disabled. The derives allow containing types (e.g. `KeepAliveStream`'s
/// derived `Debug`) to be instantiated with it.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NeverTimer;
145
146impl Future for NeverTimer {
147 type Output = ();
148
149 fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
150 Poll::Pending
151 }
152}
153
154impl Timer for NeverTimer {
155 fn from_duration(_: Duration) -> Self {
156 Self
157 }
158
159 fn reset(self: Pin<&mut Self>, _: std::time::Instant) {
160 }
162}
163
164pin_project_lite::pin_project! {
165 #[derive(Debug)]
166 struct KeepAliveStream<S> {
167 keep_alive: KeepAlive,
168 #[pin]
169 alive_timer: S,
170 }
171}
172
173impl<S> KeepAliveStream<S>
174where
175 S: Timer,
176{
177 fn new(keep_alive: KeepAlive) -> Self {
178 Self {
179 alive_timer: S::from_duration(keep_alive.max_interval),
180 keep_alive,
181 }
182 }
183
184 fn reset(self: Pin<&mut Self>) {
185 let this = self.project();
186 this.alive_timer
187 .reset(std::time::Instant::now() + this.keep_alive.max_interval);
188 }
189
190 fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
191 let this = self.as_mut().project();
192
193 ready!(this.alive_timer.poll(cx));
194
195 let event = this.keep_alive.event.clone();
196
197 self.reset();
198
199 Poll::Ready(event)
200 }
201}