// eventsrc_client/replayable/mod.rs

1use std::{
2    fmt,
3    fmt::Debug,
4    mem,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8    time::Duration,
9};
10
11use eventsrc::{Event, Frame, FrameStream, StreamError};
12use futures_core::Stream;
13use tokio::time;
14
15mod connector;
16mod retry;
17
18pub use self::{
19    connector::{BoxBodyStream, ConnectFuture, Connector},
20    retry::*,
21};
22use crate::error::{Error, ErrorKind};
23
24type BoxFrameStream = Pin<Box<FrameStream<BoxBodyStream>>>;
25type SleepFuture = Pin<Box<time::Sleep>>;
26
27enum ConnectionState {
28    Idle,
29    Connecting(ConnectFuture),
30    Streaming(BoxFrameStream),
31    Waiting(SleepFuture),
32    Closed,
33}
34
35/// Reconnecting SSE event source backed by a backend-neutral connector.
36pub struct EventSource {
37    connector: Arc<dyn Connector>,
38    retry_policy: Arc<dyn RetryPolicy>,
39    last_event_id: Option<String>,
40    server_retry_delay: Option<Duration>,
41    consecutive_failures: usize,
42    state: ConnectionState,
43}
44
45impl Clone for EventSource {
46    fn clone(&self) -> Self {
47        Self {
48            connector: self.connector.clone(),
49            retry_policy: self.retry_policy.clone(),
50            last_event_id: self.last_event_id.clone(),
51            server_retry_delay: self.server_retry_delay,
52            consecutive_failures: self.consecutive_failures,
53            state: ConnectionState::Idle,
54        }
55    }
56}
57
58impl Debug for EventSource {
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        f.debug_struct("replayable::EventSource")
61            .field("connector", &self.connector)
62            .field("retry_policy", &self.retry_policy)
63            .field("last_event_id", &self.last_event_id)
64            .field("server_retry_delay", &self.server_retry_delay)
65            .field("consecutive_failures", &self.consecutive_failures)
66            .finish()
67    }
68}
69
70impl EventSource {
71    /// Creates a reconnecting event source from a backend-neutral connector.
72    pub fn new<C>(connector: C) -> Self
73    where
74        C: Connector,
75    {
76        let retry_policy = Arc::new(ConstantBackoff::default());
77
78        Self {
79            connector: Arc::new(connector),
80            retry_policy,
81            last_event_id: None,
82            server_retry_delay: None,
83            consecutive_failures: 0,
84            state: ConnectionState::Idle,
85        }
86    }
87
88    /// Replaces the reconnect timing policy.
89    ///
90    /// Built-in policies include [`ConstantBackoff`], [`crate::replayable::ExponentialBackoff`],
91    /// and [`crate::replayable::NeverRetry`].
92    /// Custom policies may implement [`crate::replayable::RetryPolicy`] directly.
93    pub fn with_retry_policy<P>(mut self, retry_policy: P) -> Self
94    where
95        P: RetryPolicy,
96    {
97        self.retry_policy = Arc::new(retry_policy);
98        self
99    }
100
101    /// Returns the current effective `Last-Event-ID`, if one is stored.
102    ///
103    /// This value is updated from the SSE stream and reused on future reconnect
104    /// attempts through the underlying connector.
105    pub fn last_event_id(&self) -> Option<&str> {
106        self.last_event_id.as_deref()
107    }
108
109    /// Replaces the stored `Last-Event-ID`.
110    ///
111    /// This affects the next reconnect attempt and overrides any previously
112    /// remembered value until the stream updates it again.
113    pub fn set_last_event_id(&mut self, last_event_id: impl Into<String>) {
114        self.last_event_id = Some(last_event_id.into());
115    }
116
117    /// Clears the stored `Last-Event-ID`.
118    ///
119    /// Future reconnect attempts will omit the header until the stream observes
120    /// a new event id.
121    pub fn clear_last_event_id(&mut self) {
122        self.last_event_id = None;
123    }
124
125    fn connect(&self) -> Result<ConnectFuture, Error> {
126        self.connector.connect(self.last_event_id.as_deref())
127    }
128
129    fn update_last_event_id_from_stream(&mut self, stream: &BoxFrameStream) {
130        let last_event_id = stream.as_ref().get_ref().last_event_id();
131
132        if last_event_id.is_empty() {
133            self.last_event_id = None;
134        } else {
135            self.last_event_id = Some(last_event_id.to_owned());
136        }
137    }
138
139    fn schedule_reconnect(&mut self, cause: RetryCause) -> bool {
140        let failure_streak = match cause {
141            RetryCause::Disconnect => 0,
142            RetryCause::ConnectError | RetryCause::StreamError => self.consecutive_failures + 1,
143        };
144
145        let Some(delay) = self.retry_policy.next_delay(RetryContext {
146            cause,
147            failure_streak,
148            server_retry: self.server_retry_delay,
149        }) else {
150            self.state = ConnectionState::Closed;
151            return false;
152        };
153
154        if matches!(cause, RetryCause::ConnectError | RetryCause::StreamError) {
155            self.consecutive_failures += 1;
156        }
157
158        self.state = ConnectionState::Waiting(Box::pin(tokio::time::sleep(delay)));
159        true
160    }
161}
162
163impl Stream for EventSource {
164    type Item = Result<Event, Error>;
165
166    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
167        let this = self.get_mut();
168        let mut scheduled_reconnect = false;
169
170        loop {
171            match mem::replace(&mut this.state, ConnectionState::Closed) {
172                ConnectionState::Idle => match this.connect() {
173                    Ok(connect) => {
174                        this.state = ConnectionState::Connecting(connect);
175                    },
176                    Err(error) => return Poll::Ready(Some(Err(error))),
177                },
178                ConnectionState::Connecting(mut connect) => match connect.as_mut().poll(cx) {
179                    Poll::Pending => {
180                        this.state = ConnectionState::Connecting(connect);
181                        return Poll::Pending;
182                    },
183                    Poll::Ready(Ok(body)) => {
184                        let stream = match this.last_event_id.as_deref() {
185                            Some(last_event_id) => {
186                                eventsrc::FrameStream::new(body).with_last_event_id(last_event_id)
187                            },
188                            None => eventsrc::FrameStream::new(body),
189                        };
190
191                        this.consecutive_failures = 0;
192                        this.state = ConnectionState::Streaming(Box::pin(stream));
193                    },
194                    Poll::Ready(Err(err)) => {
195                        if err.kind() == ErrorKind::Transport
196                            && this.schedule_reconnect(RetryCause::ConnectError)
197                        {
198                            scheduled_reconnect = true;
199                            continue;
200                        }
201                        return Poll::Ready(Some(Err(err)));
202                    },
203                },
204                ConnectionState::Streaming(mut stream) => match stream.as_mut().poll_next(cx) {
205                    Poll::Pending => {
206                        this.state = ConnectionState::Streaming(stream);
207                        return Poll::Pending;
208                    },
209                    Poll::Ready(Some(Ok(Frame::Retry(delay)))) => {
210                        this.server_retry_delay = Some(delay);
211                        this.state = ConnectionState::Streaming(stream);
212                    },
213                    Poll::Ready(Some(Ok(Frame::Event(event)))) => {
214                        this.update_last_event_id_from_stream(&stream);
215                        this.state = ConnectionState::Streaming(stream);
216                        return Poll::Ready(Some(Ok(event)));
217                    },
218                    Poll::Ready(Some(Err(StreamError::Protocol(error)))) => {
219                        this.update_last_event_id_from_stream(&stream);
220                        return Poll::Ready(Some(Err(error.into())));
221                    },
222                    Poll::Ready(Some(Err(StreamError::Source(error)))) => {
223                        this.update_last_event_id_from_stream(&stream);
224
225                        if this.schedule_reconnect(RetryCause::StreamError) {
226                            scheduled_reconnect = true;
227                            continue;
228                        }
229
230                        return Poll::Ready(Some(Err(error)));
231                    },
232                    Poll::Ready(None) => {
233                        this.update_last_event_id_from_stream(&stream);
234                        let _ = this.schedule_reconnect(RetryCause::Disconnect);
235                        scheduled_reconnect = true;
236                        continue;
237                    },
238                },
239                ConnectionState::Waiting(mut sleep) => match sleep.as_mut().poll(cx) {
240                    Poll::Pending => {
241                        this.state = ConnectionState::Waiting(sleep);
242                        return Poll::Pending;
243                    },
244                    Poll::Ready(()) => {
245                        this.state = ConnectionState::Idle;
246
247                        if scheduled_reconnect {
248                            cx.waker().wake_by_ref();
249                            return Poll::Pending;
250                        }
251                    },
252                },
253                ConnectionState::Closed => return Poll::Ready(None),
254            }
255        }
256    }
257}
258
259/// Extension methods for building reconnecting SSE event sources from backend-specific clients.
260pub trait EventSourceExt {
261    /// Converts this backend-specific request source into a reconnecting [`EventSource`].
262    fn event_source(self) -> Result<EventSource, Error>;
263}