1use std::{
2 pin::Pin,
3 task::{Context, Poll, ready},
4 time::Duration,
5};
6
7use bytes::Bytes;
8use futures_util::Stream;
9use http_body::{Body, Frame};
10
11use crate::Sse;
12
pin_project_lite::pin_project! {
    /// HTTP body that turns a stream of `Sse` events into data frames.
    ///
    /// When `keep_alive` is `Some`, a keep-alive payload is emitted whenever the event
    /// stream is pending and the keep-alive timer of type `T` has completed.
    pub struct SseBody<S, T = NeverTimer> {
        // Source of events; each `Ok` item is serialized into one data frame.
        #[pin]
        pub event_stream: S,
        // Optional keep-alive state; `None` means no keep-alive payloads are sent.
        // NOTE(review): `KeepAliveStream` is a private type, so this `pub` field exposes
        // a type external callers cannot name (`private_interfaces` lint) — confirm intent.
        #[pin]
        pub keep_alive: Option<KeepAliveStream<T>>,
    }
}
21
22impl<S, E> SseBody<S, NeverTimer>
23where
24 S: Stream<Item = Result<Sse, E>>,
25{
26 pub fn new(stream: S) -> Self {
27 Self {
28 event_stream: stream,
29 keep_alive: None,
30 }
31 }
32}
33
34impl<S, E, T> SseBody<S, T>
35where
36 S: Stream<Item = Result<Sse, E>>,
37 T: Timer,
38{
39 pub fn new_keep_alive(stream: S, keep_alive: KeepAlive) -> Self {
40 Self {
41 event_stream: stream,
42 keep_alive: Some(KeepAliveStream::new(keep_alive)),
43 }
44 }
45
46 pub fn with_keep_alive<T2: Timer>(self, keep_alive: KeepAlive) -> SseBody<S, T2> {
47 SseBody {
48 event_stream: self.event_stream,
49 keep_alive: Some(KeepAliveStream::new(keep_alive)),
50 }
51 }
52}
53
54impl<S, E, T> Body for SseBody<S, T>
55where
56 S: Stream<Item = Result<Sse, E>>,
57 T: Timer,
58{
59 type Data = Bytes;
60 type Error = E;
61
62 fn poll_frame(
63 self: Pin<&mut Self>,
64 cx: &mut Context<'_>,
65 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
66 let this = self.project();
67
68 match this.event_stream.poll_next(cx) {
69 Poll::Pending => {
70 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
71 keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
72 } else {
73 Poll::Pending
74 }
75 }
76 Poll::Ready(Some(Ok(event))) => {
77 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
78 keep_alive.reset();
79 }
80 Poll::Ready(Some(Ok(Frame::data(event.into()))))
81 }
82 Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
83 Poll::Ready(None) => Poll::Ready(None),
84 }
85 }
86}
87
/// Builder-style configuration for keep-alive payloads sent on an idle SSE body.
///
/// `#[must_use]` flags builder results that are dropped, e.g. calling `.interval(..)`
/// without using the returned value.
#[derive(Debug, Clone)]
#[must_use]
pub struct KeepAlive {
    /// Raw bytes written as the keep-alive frame.
    event: Bytes,
    /// Maximum time without any sent event before a keep-alive payload is emitted.
    max_interval: Duration,
}
96
97impl KeepAlive {
98 pub fn new() -> Self {
100 Self {
101 event: Bytes::from_static(b":\n\n"),
102 max_interval: Duration::from_secs(15),
103 }
104 }
105
106 pub fn interval(mut self, time: Duration) -> Self {
110 self.max_interval = time;
111 self
112 }
113
114 pub fn event(mut self, event: Sse) -> Self {
123 self.event = event.into();
124 self
125 }
126
127 pub fn comment(mut self, comment: &str) -> Self {
129 self.event = format!(": {}\n\n", comment).into();
130 self
131 }
132}
133
134impl Default for KeepAlive {
135 fn default() -> Self {
136 Self::new()
137 }
138}
139
/// Timer abstraction that schedules keep-alive payloads.
///
/// The timer is a `Future` that completes (with `()`) when the next keep-alive is due.
/// NOTE(review): as with any `Future`, an implementation that returns `Pending` must
/// arrange for the task to be woken. Otherwise an idle body is never re-polled, and no
/// keep-alive is sent.
pub trait Timer: Future<Output = ()> {
    /// Re-arms the timer so that it next completes at `instant`.
    ///
    /// Invoked by the keep-alive logic after each event and after each keep-alive payload.
    fn reset(self: Pin<&mut Self>, instant: std::time::Instant);
    /// Creates a timer expected to complete after `duration`; used with the configured
    /// keep-alive interval.
    fn from_duration(duration: Duration) -> Self;
}
144
/// A [`Timer`] that never completes.
///
/// Used as the default timer type of an [`SseBody`] created without keep-alives. It is a
/// stateless unit struct, so it cheaply derives the common traits public types should
/// provide.
#[derive(Debug, Clone, Copy, Default)]
pub struct NeverTimer;
146
147impl Future for NeverTimer {
148 type Output = ();
149
150 fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
151 Poll::Pending
152 }
153}
154
155impl Timer for NeverTimer {
156 fn from_duration(_: Duration) -> Self {
157 Self
158 }
159
160 fn reset(self: Pin<&mut Self>, _: std::time::Instant) {
161 }
163}
164
pin_project_lite::pin_project! {
    /// Pairs a `KeepAlive` configuration with the timer that decides when the next
    /// keep-alive payload is due.
    #[derive(Debug)]
    struct KeepAliveStream<S> {
        // Payload bytes and interval to use.
        keep_alive: KeepAlive,
        // Timer future completing when a keep-alive is due; structurally pinned because
        // it is polled through the pin projection.
        #[pin]
        alive_timer: S,
    }
}
173
174impl<S> KeepAliveStream<S>
175where
176 S: Timer,
177{
178 fn new(keep_alive: KeepAlive) -> Self {
179 Self {
180 alive_timer: S::from_duration(keep_alive.max_interval),
181 keep_alive,
182 }
183 }
184
185 fn reset(self: Pin<&mut Self>) {
186 let this = self.project();
187 this.alive_timer
188 .reset(std::time::Instant::now() + this.keep_alive.max_interval);
189 }
190
191 fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
192 let this = self.as_mut().project();
193
194 ready!(this.alive_timer.poll(cx));
195
196 let event = this.keep_alive.event.clone();
197
198 self.reset();
199
200 Poll::Ready(event)
201 }
202}