1use alloc::boxed::Box;
2use alloc::format;
3use alloc::string::String;
4use alloc::vec::Vec;
5use core::fmt;
6use core::pin::Pin;
7use core::task::{Context, Poll};
8use core::time::Duration;
9
10use futures_core::Stream;
11
12use crate::SdkError;
13
14#[derive(Clone, Debug, PartialEq, Eq)]
16pub enum ConnectionState {
17 Connecting,
19 Connected,
21 Reconnecting {
23 attempt: u32,
25 },
26 Disconnected {
28 reason: DisconnectReason,
30 },
31}
32
33#[derive(Clone, Copy, Debug, PartialEq, Eq)]
35pub enum DisconnectReason {
36 Normal,
38 Error,
40 Timeout,
42}
43
44#[derive(Clone, Debug, PartialEq, Eq)]
46pub struct ConnectionEvent {
47 pub previous: ConnectionState,
49 pub current: ConnectionState,
51}
52
53impl ConnectionEvent {
54 #[must_use]
56 pub const fn new(previous: ConnectionState, current: ConnectionState) -> Self {
57 Self { previous, current }
58 }
59}
60
61pub struct ConnectionEvents<S> {
66 inner: S,
67}
68
69impl<S> ConnectionEvents<S> {
70 #[must_use]
72 pub const fn new(inner: S) -> Self {
73 Self { inner }
74 }
75
76 #[must_use]
78 pub const fn inner(&self) -> &S {
79 &self.inner
80 }
81
82 #[must_use]
84 pub fn into_inner(self) -> S {
85 self.inner
86 }
87}
88
89impl<S: Clone> Clone for ConnectionEvents<S> {
90 fn clone(&self) -> Self {
91 Self {
92 inner: self.inner.clone(),
93 }
94 }
95}
96
97impl<S> fmt::Debug for ConnectionEvents<S> {
98 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
99 formatter.debug_struct("ConnectionEvents").finish()
100 }
101}
102
103impl<S> Stream for ConnectionEvents<S>
104where
105 S: Stream<Item = ConnectionEvent> + Unpin,
106{
107 type Item = ConnectionEvent;
108
109 fn poll_next(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Option<Self::Item>> {
110 let stream = &mut self.as_mut().get_mut().inner;
111 Pin::new(stream).poll_next(context)
112 }
113}
114
115#[derive(Clone, Copy, Debug, PartialEq, Eq)]
117pub struct ReconnectConfig {
118 pub base_delay: Duration,
120 pub max_delay: Duration,
122}
123
124impl ReconnectConfig {
125 #[must_use]
127 pub const fn new(base_delay: Duration, max_delay: Duration) -> Self {
128 Self {
129 base_delay,
130 max_delay,
131 }
132 }
133
134 #[must_use]
136 pub fn capped_delay(self, attempt: u32) -> Duration {
137 let base_nanos = self.base_delay.as_nanos();
138 let max_nanos = self.max_delay.as_nanos();
139
140 if base_nanos == 0 || max_nanos == 0 {
141 return Duration::ZERO;
142 }
143
144 let multiplier = 1_u128.checked_shl(attempt).unwrap_or(u128::MAX);
145 let scaled_nanos = base_nanos.saturating_mul(multiplier);
146 duration_from_nanos(core::cmp::min(scaled_nanos, max_nanos))
147 }
148
149 pub fn retry_delay<J>(self, attempt: u32, jitter: &mut J) -> Result<Duration, SdkError>
156 where
157 J: ReconnectJitter + ?Sized,
158 {
159 let capped_delay = self.capped_delay(attempt);
160 let jitter_delay = jitter.jitter(attempt, capped_delay);
161 self.retry_delay_with_jitter(attempt, jitter_delay)
162 }
163
164 pub fn retry_delay_with_jitter(
174 self,
175 attempt: u32,
176 jitter: Duration,
177 ) -> Result<Duration, SdkError> {
178 let capped_delay = self.capped_delay(attempt);
179 let jitter_limit = capped_delay / 2;
180
181 if jitter > jitter_limit {
182 return Err(connection_error(format!(
183 "reconnect jitter {jitter:?} exceeds 50% limit {jitter_limit:?}"
184 )));
185 }
186
187 Ok(capped_delay.checked_add(jitter).unwrap_or(Duration::MAX))
188 }
189}
190
191impl Default for ReconnectConfig {
192 fn default() -> Self {
193 Self::new(Duration::from_millis(100), Duration::from_secs(30))
194 }
195}
196
197pub trait ReconnectJitter: fmt::Debug {
203 fn jitter(&mut self, attempt: u32, capped_delay: Duration) -> Duration;
205}
206
207type ConnectionObserver = Box<dyn FnMut(&ConnectionEvent) + Send>;
208
209pub struct ConnectionLifecycle {
211 state: ConnectionState,
212 reconnect_config: ReconnectConfig,
213 next_reconnect_attempt: u32,
214 observers: Vec<ConnectionObserver>,
215}
216
217impl ConnectionLifecycle {
218 #[must_use]
220 pub fn new(reconnect_config: ReconnectConfig) -> Self {
221 Self {
222 state: ConnectionState::Connecting,
223 reconnect_config,
224 next_reconnect_attempt: 0,
225 observers: Vec::new(),
226 }
227 }
228
229 #[must_use]
231 pub const fn state(&self) -> &ConnectionState {
232 &self.state
233 }
234
235 #[must_use]
237 pub const fn reconnect_config(&self) -> ReconnectConfig {
238 self.reconnect_config
239 }
240
241 pub fn observe(&mut self, observer: impl FnMut(&ConnectionEvent) + Send + 'static) {
243 self.observers.push(Box::new(observer));
244 }
245
246 pub fn connect(&mut self) -> Result<(), SdkError> {
252 match self.state {
253 ConnectionState::Disconnected { .. } => {
254 self.transition(ConnectionState::Connecting);
255 Ok(())
256 }
257 _ => Err(invalid_transition(&self.state, "Connecting")),
258 }
259 }
260
261 pub fn connected(&mut self) -> Result<(), SdkError> {
267 match self.state {
268 ConnectionState::Connecting | ConnectionState::Reconnecting { .. } => {
269 self.next_reconnect_attempt = 0;
270 self.transition(ConnectionState::Connected);
271 Ok(())
272 }
273 _ => Err(invalid_transition(&self.state, "Connected")),
274 }
275 }
276
277 pub fn reconnect<J>(&mut self, jitter: &mut J) -> Result<Duration, SdkError>
288 where
289 J: ReconnectJitter + ?Sized,
290 {
291 match self.state {
292 ConnectionState::Connecting
293 | ConnectionState::Connected
294 | ConnectionState::Reconnecting { .. } => {
295 let attempt = self.next_reconnect_attempt;
296 let delay = self.reconnect_config.retry_delay(attempt, jitter)?;
297 self.next_reconnect_attempt = attempt.saturating_add(1);
298 self.transition(ConnectionState::Reconnecting { attempt });
299 Ok(delay)
300 }
301 ConnectionState::Disconnected { .. } => {
302 Err(invalid_transition(&self.state, "Reconnecting"))
303 }
304 }
305 }
306
307 pub fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), SdkError> {
313 match self.state {
314 ConnectionState::Connecting
315 | ConnectionState::Connected
316 | ConnectionState::Reconnecting { .. } => {
317 self.transition(ConnectionState::Disconnected { reason });
318 Ok(())
319 }
320 ConnectionState::Disconnected { .. } => {
321 Err(invalid_transition(&self.state, "Disconnected"))
322 }
323 }
324 }
325
326 fn transition(&mut self, next: ConnectionState) {
327 let previous = core::mem::replace(&mut self.state, next);
328 let event = ConnectionEvent::new(previous, self.state.clone());
329
330 for observer in &mut self.observers {
331 observer(&event);
332 }
333 }
334}
335
336impl Default for ConnectionLifecycle {
337 fn default() -> Self {
338 Self::new(ReconnectConfig::default())
339 }
340}
341
342impl fmt::Debug for ConnectionLifecycle {
343 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
344 formatter
345 .debug_struct("ConnectionLifecycle")
346 .field("state", &self.state)
347 .field("reconnect_config", &self.reconnect_config)
348 .field("next_reconnect_attempt", &self.next_reconnect_attempt)
349 .field("observers", &self.observers.len())
350 .finish()
351 }
352}
353
354fn invalid_transition(previous: &ConnectionState, requested: &str) -> SdkError {
355 connection_error(format!(
356 "invalid connection transition from {previous:?} to {requested}"
357 ))
358}
359
360const fn connection_error(description: String) -> SdkError {
361 SdkError::Connection { description }
362}
363
364fn duration_from_nanos(nanos: u128) -> Duration {
365 const NANOS_PER_SECOND: u128 = 1_000_000_000;
366
367 let seconds = nanos / NANOS_PER_SECOND;
368 let subsecond_nanos = nanos % NANOS_PER_SECOND;
369
370 let Ok(seconds) = u64::try_from(seconds) else {
371 return Duration::MAX;
372 };
373 let Ok(subsecond_nanos) = u32::try_from(subsecond_nanos) else {
374 return Duration::MAX;
375 };
376
377 Duration::new(seconds, subsecond_nanos)
378}
379
380#[cfg(test)]
381mod tests;