Skip to main content

liminal_sdk/connection/lifecycle.rs

1use alloc::boxed::Box;
2use alloc::format;
3use alloc::string::String;
4use alloc::vec::Vec;
5use core::fmt;
6use core::pin::Pin;
7use core::task::{Context, Poll};
8use core::time::Duration;
9
10use futures_core::Stream;
11
12use crate::SdkError;
13
14/// Application-visible lifecycle state for a remote SDK connection.
15#[derive(Clone, Debug, PartialEq, Eq)]
16pub enum ConnectionState {
17    /// The SDK is establishing a connection.
18    Connecting,
19    /// The SDK has an active connection.
20    Connected,
21    /// The SDK is attempting to reconnect after a disruption.
22    Reconnecting {
23        /// Zero-based reconnect attempt counter for this disruption.
24        attempt: u32,
25    },
26    /// The SDK is disconnected and will not become connected without reconnecting.
27    Disconnected {
28        /// Reason the SDK entered the disconnected state.
29        reason: DisconnectReason,
30    },
31}
32
/// Why a connection ended up in the disconnected state.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DisconnectReason {
    /// The connection was shut down on purpose.
    Normal,
    /// The connection was lost due to an error.
    Error,
    /// The connection was dropped after a timeout expired.
    Timeout,
}
43
44/// Event emitted after a connection lifecycle transition succeeds.
45#[derive(Clone, Debug, PartialEq, Eq)]
46pub struct ConnectionEvent {
47    /// State before the transition.
48    pub previous: ConnectionState,
49    /// State after the transition.
50    pub current: ConnectionState,
51}
52
53impl ConnectionEvent {
54    /// Creates a connection transition event.
55    #[must_use]
56    pub const fn new(previous: ConnectionState, current: ConnectionState) -> Self {
57        Self { previous, current }
58    }
59}
60
/// Adapter that exposes a runtime-specific event stream as a stream of
/// [`ConnectionEvent`] values.
///
/// SDK clients wrap whatever event stream their runtime provides, so
/// applications see a stable item type regardless of the runtime in use.
pub struct ConnectionEvents<S> {
    inner: S,
}

impl<S> ConnectionEvents<S> {
    /// Wraps `inner`, a stream yielding connection transition events.
    #[must_use]
    pub const fn new(inner: S) -> Self {
        Self { inner }
    }

    /// Borrows the wrapped stream.
    #[must_use]
    pub const fn inner(&self) -> &S {
        &self.inner
    }

    /// Consumes the wrapper and hands back the wrapped stream.
    #[must_use]
    pub fn into_inner(self) -> S {
        let Self { inner } = self;
        inner
    }
}

impl<S> Clone for ConnectionEvents<S>
where
    S: Clone,
{
    fn clone(&self) -> Self {
        Self::new(self.inner.clone())
    }
}

impl<S> fmt::Debug for ConnectionEvents<S> {
    // `S` is not required to implement `Debug`, so the wrapped stream is never printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ConnectionEvents").finish()
    }
}
102
103impl<S> Stream for ConnectionEvents<S>
104where
105    S: Stream<Item = ConnectionEvent> + Unpin,
106{
107    type Item = ConnectionEvent;
108
109    fn poll_next(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Option<Self::Item>> {
110        let stream = &mut self.as_mut().get_mut().inner;
111        Pin::new(stream).poll_next(context)
112    }
113}
114
/// Exponential backoff settings for reconnect attempts.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ReconnectConfig {
    /// Delay used for attempt zero; each later attempt doubles it.
    pub base_delay: Duration,
    /// Upper bound on the exponential delay, applied before jitter is added.
    pub max_delay: Duration,
}
123
124impl ReconnectConfig {
125    /// Creates a reconnect configuration from base and maximum delays.
126    #[must_use]
127    pub const fn new(base_delay: Duration, max_delay: Duration) -> Self {
128        Self {
129            base_delay,
130            max_delay,
131        }
132    }
133
134    /// Computes `min(base_delay * 2^attempt, max_delay)` before jitter.
135    #[must_use]
136    pub fn capped_delay(self, attempt: u32) -> Duration {
137        let base_nanos = self.base_delay.as_nanos();
138        let max_nanos = self.max_delay.as_nanos();
139
140        if base_nanos == 0 || max_nanos == 0 {
141            return Duration::ZERO;
142        }
143
144        let multiplier = 1_u128.checked_shl(attempt).unwrap_or(u128::MAX);
145        let scaled_nanos = base_nanos.saturating_mul(multiplier);
146        duration_from_nanos(core::cmp::min(scaled_nanos, max_nanos))
147    }
148
149    /// Computes the retry delay for an attempt using an injected random jitter source.
150    ///
151    /// # Errors
152    ///
153    /// Returns [`SdkError`] if the jitter source returns a value above 50% of the
154    /// capped exponential delay.
155    pub fn retry_delay<J>(self, attempt: u32, jitter: &mut J) -> Result<Duration, SdkError>
156    where
157        J: ReconnectJitter + ?Sized,
158    {
159        let capped_delay = self.capped_delay(attempt);
160        let jitter_delay = jitter.jitter(attempt, capped_delay);
161        self.retry_delay_with_jitter(attempt, jitter_delay)
162    }
163
164    /// Computes the retry delay for an attempt using a precomputed jitter value.
165    ///
166    /// This helper is useful for deterministic tests and for transport layers
167    /// that produce randomness externally. The jitter value must be between zero
168    /// and 50% of the capped exponential delay.
169    ///
170    /// # Errors
171    ///
172    /// Returns [`SdkError`] if `jitter` is greater than 50% of the capped delay.
173    pub fn retry_delay_with_jitter(
174        self,
175        attempt: u32,
176        jitter: Duration,
177    ) -> Result<Duration, SdkError> {
178        let capped_delay = self.capped_delay(attempt);
179        let jitter_limit = capped_delay / 2;
180
181        if jitter > jitter_limit {
182            return Err(connection_error(format!(
183                "reconnect jitter {jitter:?} exceeds 50% limit {jitter_limit:?}"
184            )));
185        }
186
187        Ok(capped_delay.checked_add(jitter).unwrap_or(Duration::MAX))
188    }
189}
190
191impl Default for ReconnectConfig {
192    fn default() -> Self {
193        Self::new(Duration::from_millis(100), Duration::from_secs(30))
194    }
195}
196
/// Supplies random jitter for reconnect delays.
///
/// Implementations should return a random value from zero up to and including
/// half of the capped delay they are given. The SDK checks that upper bound and
/// rejects larger values with an error instead of clamping them.
pub trait ReconnectJitter: fmt::Debug {
    /// Produces jitter for reconnect `attempt`, given the capped exponential delay.
    fn jitter(&mut self, attempt: u32, capped_delay: Duration) -> Duration;
}
206
207type ConnectionObserver = Box<dyn FnMut(&ConnectionEvent) + Send>;
208
209/// Owns the SDK connection lifecycle state and emits validated transitions.
210pub struct ConnectionLifecycle {
211    state: ConnectionState,
212    reconnect_config: ReconnectConfig,
213    next_reconnect_attempt: u32,
214    observers: Vec<ConnectionObserver>,
215}
216
217impl ConnectionLifecycle {
218    /// Creates a lifecycle in the [`ConnectionState::Connecting`] state.
219    #[must_use]
220    pub fn new(reconnect_config: ReconnectConfig) -> Self {
221        Self {
222            state: ConnectionState::Connecting,
223            reconnect_config,
224            next_reconnect_attempt: 0,
225            observers: Vec::new(),
226        }
227    }
228
229    /// Returns the current connection state.
230    #[must_use]
231    pub const fn state(&self) -> &ConnectionState {
232        &self.state
233    }
234
235    /// Returns the reconnect backoff configuration.
236    #[must_use]
237    pub const fn reconnect_config(&self) -> ReconnectConfig {
238        self.reconnect_config
239    }
240
241    /// Registers an observer that is called after each successful transition.
242    pub fn observe(&mut self, observer: impl FnMut(&ConnectionEvent) + Send + 'static) {
243        self.observers.push(Box::new(observer));
244    }
245
246    /// Transitions from [`ConnectionState::Disconnected`] to connecting.
247    ///
248    /// # Errors
249    ///
250    /// Returns [`SdkError`] when the lifecycle is not disconnected.
251    pub fn connect(&mut self) -> Result<(), SdkError> {
252        match self.state {
253            ConnectionState::Disconnected { .. } => {
254                self.transition(ConnectionState::Connecting);
255                Ok(())
256            }
257            _ => Err(invalid_transition(&self.state, "Connecting")),
258        }
259    }
260
261    /// Transitions from connecting or reconnecting to connected.
262    ///
263    /// # Errors
264    ///
265    /// Returns [`SdkError`] when the lifecycle is disconnected or already connected.
266    pub fn connected(&mut self) -> Result<(), SdkError> {
267        match self.state {
268            ConnectionState::Connecting | ConnectionState::Reconnecting { .. } => {
269                self.next_reconnect_attempt = 0;
270                self.transition(ConnectionState::Connected);
271                Ok(())
272            }
273            _ => Err(invalid_transition(&self.state, "Connected")),
274        }
275    }
276
277    /// Transitions to reconnecting and returns the next retry delay.
278    ///
279    /// The first reconnect attempt after a successful connection uses attempt
280    /// zero. Each subsequent reconnect attempt increments the counter until a
281    /// successful [`Self::connected`] transition resets it.
282    ///
283    /// # Errors
284    ///
285    /// Returns [`SdkError`] when reconnecting from the current state is invalid or
286    /// when the jitter source exceeds the allowed jitter range.
287    pub fn reconnect<J>(&mut self, jitter: &mut J) -> Result<Duration, SdkError>
288    where
289        J: ReconnectJitter + ?Sized,
290    {
291        match self.state {
292            ConnectionState::Connecting
293            | ConnectionState::Connected
294            | ConnectionState::Reconnecting { .. } => {
295                let attempt = self.next_reconnect_attempt;
296                let delay = self.reconnect_config.retry_delay(attempt, jitter)?;
297                self.next_reconnect_attempt = attempt.saturating_add(1);
298                self.transition(ConnectionState::Reconnecting { attempt });
299                Ok(delay)
300            }
301            ConnectionState::Disconnected { .. } => {
302                Err(invalid_transition(&self.state, "Reconnecting"))
303            }
304        }
305    }
306
307    /// Transitions to disconnected with the supplied reason.
308    ///
309    /// # Errors
310    ///
311    /// Returns [`SdkError`] when the lifecycle is already disconnected.
312    pub fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), SdkError> {
313        match self.state {
314            ConnectionState::Connecting
315            | ConnectionState::Connected
316            | ConnectionState::Reconnecting { .. } => {
317                self.transition(ConnectionState::Disconnected { reason });
318                Ok(())
319            }
320            ConnectionState::Disconnected { .. } => {
321                Err(invalid_transition(&self.state, "Disconnected"))
322            }
323        }
324    }
325
326    fn transition(&mut self, next: ConnectionState) {
327        let previous = core::mem::replace(&mut self.state, next);
328        let event = ConnectionEvent::new(previous, self.state.clone());
329
330        for observer in &mut self.observers {
331            observer(&event);
332        }
333    }
334}
335
336impl Default for ConnectionLifecycle {
337    fn default() -> Self {
338        Self::new(ReconnectConfig::default())
339    }
340}
341
342impl fmt::Debug for ConnectionLifecycle {
343    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
344        formatter
345            .debug_struct("ConnectionLifecycle")
346            .field("state", &self.state)
347            .field("reconnect_config", &self.reconnect_config)
348            .field("next_reconnect_attempt", &self.next_reconnect_attempt)
349            .field("observers", &self.observers.len())
350            .finish()
351    }
352}
353
354fn invalid_transition(previous: &ConnectionState, requested: &str) -> SdkError {
355    connection_error(format!(
356        "invalid connection transition from {previous:?} to {requested}"
357    ))
358}
359
360const fn connection_error(description: String) -> SdkError {
361    SdkError::Connection { description }
362}
363
/// Converts a nanosecond count into a [`Duration`], saturating at
/// [`Duration::MAX`] when the whole-second part does not fit in a `u64`.
fn duration_from_nanos(nanos: u128) -> Duration {
    const NANOS_PER_SECOND: u128 = 1_000_000_000;

    let whole_seconds = u64::try_from(nanos / NANOS_PER_SECOND);
    // The remainder is always below one billion, so this conversion never fails in practice.
    let leftover_nanos = u32::try_from(nanos % NANOS_PER_SECOND);

    match (whole_seconds, leftover_nanos) {
        (Ok(secs), Ok(subsec)) => Duration::new(secs, subsec),
        _ => Duration::MAX,
    }
}
379
380#[cfg(test)]
381mod tests;