Skip to main content

fin_stream/ws/
mod.rs

//! WebSocket connection management — auto-reconnect and backpressure.
//!
//! ## Responsibility
//! Manage the lifecycle of a WebSocket feed connection: connect, receive
//! messages, detect disconnections, apply exponential-backoff reconnects,
//! and propagate backpressure when the downstream channel is full.
//!
//! ## Guarantees
//! - Non-panicking: fallible operations report failure through `Result`
//! - Configurable: reconnect policy and buffer sizes are constructor params

12use crate::error::StreamError;
13use std::time::Duration;
14
/// Exponential-backoff reconnection policy for a WebSocket feed.
///
/// Prefer [`ReconnectPolicy::new`], which validates its arguments, or the
/// [`Default`] implementation for a ready-made policy. Because every field is
/// public, a value built with a struct literal skips that validation.
#[derive(Debug, Clone)]
pub struct ReconnectPolicy {
    /// Upper bound on reconnect attempts before giving up.
    pub max_attempts: u32,
    /// Delay used for the first reconnect attempt (attempt index 0).
    pub initial_backoff: Duration,
    /// Ceiling that the exponentially growing delay is clamped to.
    pub max_backoff: Duration,
    /// Growth factor applied per successive attempt; `new` requires `>= 1.0`.
    pub multiplier: f64,
}

31impl ReconnectPolicy {
32    /// Build a reconnect policy with explicit parameters.
33    ///
34    /// Returns an error if `multiplier < 1.0` (which would cause backoff to
35    /// shrink over time) or if `max_attempts == 0`.
36    pub fn new(
37        max_attempts: u32,
38        initial_backoff: Duration,
39        max_backoff: Duration,
40        multiplier: f64,
41    ) -> Result<Self, StreamError> {
42        if multiplier < 1.0 {
43            return Err(StreamError::ConnectionFailed {
44                url: String::new(),
45                reason: "reconnect multiplier must be >= 1.0".into(),
46            });
47        }
48        if max_attempts == 0 {
49            return Err(StreamError::ConnectionFailed {
50                url: String::new(),
51                reason: "max_attempts must be > 0".into(),
52            });
53        }
54        Ok(Self { max_attempts, initial_backoff, max_backoff, multiplier })
55    }
56
57    /// Backoff duration for attempt N (0-indexed).
58    pub fn backoff_for_attempt(&self, attempt: u32) -> Duration {
59        let factor = self.multiplier.powi(attempt as i32);
60        // Cap the f64 value *before* casting to u64.  When `attempt` is large
61        // (e.g. 63 with multiplier=2.0), `factor` becomes f64::INFINITY.
62        // Casting f64::INFINITY as u64 is undefined behaviour in Rust — it
63        // saturates to 0 on some targets and panics in debug builds.  Clamping
64        // to max_backoff in floating-point space first avoids the UB entirely.
65        let max_ms = self.max_backoff.as_millis() as f64;
66        let ms = (self.initial_backoff.as_millis() as f64 * factor).min(max_ms);
67        Duration::from_millis(ms as u64)
68    }
69}
70
71impl Default for ReconnectPolicy {
72    fn default() -> Self {
73        Self {
74            max_attempts: 10,
75            initial_backoff: Duration::from_millis(500),
76            max_backoff: Duration::from_secs(30),
77            multiplier: 2.0,
78        }
79    }
80}
81
82/// Configuration for a WebSocket feed connection.
83#[derive(Debug, Clone)]
84pub struct ConnectionConfig {
85    /// WebSocket URL to connect to (e.g. `"wss://stream.binance.com:9443/ws"`).
86    pub url: String,
87    /// Capacity of the downstream channel that receives incoming messages.
88    pub channel_capacity: usize,
89    /// Reconnect policy applied on disconnection.
90    pub reconnect: ReconnectPolicy,
91    /// Ping interval to keep the connection alive (default: 20 s).
92    pub ping_interval: Duration,
93}
94
95impl ConnectionConfig {
96    /// Build a connection configuration for `url` with the given downstream
97    /// channel capacity.
98    ///
99    /// Returns an error if `url` is empty or `channel_capacity` is zero.
100    pub fn new(url: impl Into<String>, channel_capacity: usize) -> Result<Self, StreamError> {
101        let url = url.into();
102        if url.is_empty() {
103            return Err(StreamError::ConnectionFailed {
104                url: url.clone(),
105                reason: "URL must not be empty".into(),
106            });
107        }
108        if channel_capacity == 0 {
109            return Err(StreamError::Backpressure {
110                channel: url.clone(),
111                depth: 0,
112                capacity: 0,
113            });
114        }
115        Ok(Self {
116            url,
117            channel_capacity,
118            reconnect: ReconnectPolicy::default(),
119            ping_interval: Duration::from_secs(20),
120        })
121    }
122
123    /// Override the default reconnect policy.
124    pub fn with_reconnect(mut self, policy: ReconnectPolicy) -> Self {
125        self.reconnect = policy;
126        self
127    }
128
129    /// Override the keepalive ping interval (default: 20 s).
130    pub fn with_ping_interval(mut self, interval: Duration) -> Self {
131        self.ping_interval = interval;
132        self
133    }
134}
135
136/// Manages a single WebSocket feed: connect, receive, reconnect.
137///
138/// In production, WsManager wraps tokio-tungstenite. In tests, it operates
139/// in simulation mode with injected messages.
140pub struct WsManager {
141    config: ConnectionConfig,
142    connect_attempts: u32,
143    is_connected: bool,
144}
145
146impl WsManager {
147    /// Create a new manager from a validated [`ConnectionConfig`].
148    pub fn new(config: ConnectionConfig) -> Self {
149        Self { config, connect_attempts: 0, is_connected: false }
150    }
151
152    /// Simulate a connection (for testing without live WebSocket).
153    /// Increments `connect_attempts` to reflect the initial connection slot.
154    pub fn connect_simulated(&mut self) {
155        self.connect_attempts += 1;
156        self.is_connected = true;
157    }
158
159    /// Simulate a disconnection.
160    pub fn disconnect_simulated(&mut self) {
161        self.is_connected = false;
162    }
163
164    /// Whether the managed connection is currently in the connected state.
165    pub fn is_connected(&self) -> bool { self.is_connected }
166
167    /// Total connection attempts made so far (including the initial connect).
168    pub fn connect_attempts(&self) -> u32 { self.connect_attempts }
169
170    /// The configuration this manager was created with.
171    pub fn config(&self) -> &ConnectionConfig { &self.config }
172
173    /// Check whether the next reconnect attempt is allowed.
174    pub fn can_reconnect(&self) -> bool {
175        self.connect_attempts < self.config.reconnect.max_attempts
176    }
177
178    /// Consume a reconnect slot and return the backoff duration to wait.
179    pub fn next_reconnect_backoff(&mut self) -> Result<Duration, StreamError> {
180        if !self.can_reconnect() {
181            return Err(StreamError::ReconnectExhausted {
182                url: self.config.url.clone(),
183                attempts: self.connect_attempts,
184            });
185        }
186        let backoff = self.config.reconnect.backoff_for_attempt(self.connect_attempts);
187        self.connect_attempts += 1;
188        Ok(backoff)
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Config pointing at a dummy endpoint with a 1024-slot channel.
    fn default_config() -> ConnectionConfig {
        ConnectionConfig::new("wss://example.com/ws", 1024).unwrap()
    }

    /// Validated policy built from millisecond / second shorthand values.
    fn policy(max_attempts: u32, initial_ms: u64, max_secs: u64, multiplier: f64) -> ReconnectPolicy {
        ReconnectPolicy::new(
            max_attempts,
            Duration::from_millis(initial_ms),
            Duration::from_secs(max_secs),
            multiplier,
        )
        .unwrap()
    }

    /// Manager over the default config with `policy` swapped in.
    fn manager_with(policy: ReconnectPolicy) -> WsManager {
        WsManager::new(default_config().with_reconnect(policy))
    }

    #[test]
    fn test_reconnect_policy_default_values() {
        let ReconnectPolicy { max_attempts, multiplier, .. } = ReconnectPolicy::default();
        assert_eq!(max_attempts, 10);
        assert_eq!(multiplier, 2.0);
    }

    #[test]
    fn test_reconnect_policy_backoff_exponential() {
        let p = policy(10, 100, 30, 2.0);
        for (attempt, expected_ms) in [(0, 100), (1, 200), (2, 400)] {
            assert_eq!(p.backoff_for_attempt(attempt), Duration::from_millis(expected_ms));
        }
    }

    #[test]
    fn test_reconnect_policy_backoff_capped_at_max() {
        let p = policy(10, 1000, 5, 2.0);
        assert!(p.backoff_for_attempt(10) <= Duration::from_secs(5));
    }

    #[test]
    fn test_reconnect_policy_multiplier_below_1_rejected() {
        let outcome =
            ReconnectPolicy::new(10, Duration::from_millis(100), Duration::from_secs(30), 0.5);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_reconnect_policy_zero_attempts_rejected() {
        let outcome =
            ReconnectPolicy::new(0, Duration::from_millis(100), Duration::from_secs(30), 2.0);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_connection_config_empty_url_rejected() {
        assert!(ConnectionConfig::new("", 1024).is_err());
    }

    #[test]
    fn test_connection_config_zero_capacity_rejected() {
        assert!(ConnectionConfig::new("wss://example.com", 0).is_err());
    }

    #[test]
    fn test_connection_config_with_reconnect() {
        let config = default_config().with_reconnect(policy(3, 200, 10, 2.0));
        assert_eq!(config.reconnect.max_attempts, 3);
    }

    #[test]
    fn test_connection_config_with_ping_interval() {
        let interval = Duration::from_secs(30);
        let config = default_config().with_ping_interval(interval);
        assert_eq!(config.ping_interval, interval);
    }

    #[test]
    fn test_ws_manager_initial_state() {
        let mgr = WsManager::new(default_config());
        assert!(!mgr.is_connected());
        assert_eq!(mgr.connect_attempts(), 0);
    }

    #[test]
    fn test_ws_manager_connect_simulated() {
        let mut mgr = WsManager::new(default_config());
        mgr.connect_simulated();
        assert!(mgr.is_connected());
        assert_eq!(mgr.connect_attempts(), 1);
    }

    #[test]
    fn test_ws_manager_disconnect_simulated() {
        let mut mgr = WsManager::new(default_config());
        mgr.connect_simulated();
        mgr.disconnect_simulated();
        assert!(!mgr.is_connected());
    }

    #[test]
    fn test_ws_manager_can_reconnect_within_limit() {
        let mut mgr = manager_with(policy(3, 10, 1, 2.0));
        assert!(mgr.can_reconnect());
        for _ in 0..3 {
            mgr.next_reconnect_backoff().unwrap();
        }
        assert!(!mgr.can_reconnect());
    }

    #[test]
    fn test_ws_manager_reconnect_exhausted_error() {
        let mut mgr = manager_with(policy(1, 10, 1, 2.0));
        mgr.next_reconnect_backoff().unwrap();
        assert!(matches!(
            mgr.next_reconnect_backoff(),
            Err(StreamError::ReconnectExhausted { .. })
        ));
    }

    #[test]
    fn test_ws_manager_backoff_increases() {
        let mut mgr = manager_with(policy(5, 100, 30, 2.0));
        let first = mgr.next_reconnect_backoff().unwrap();
        let second = mgr.next_reconnect_backoff().unwrap();
        assert!(second >= first);
    }

    #[test]
    fn test_ws_manager_config_accessor() {
        let mgr = WsManager::new(default_config());
        let cfg = mgr.config();
        assert_eq!(cfg.url, "wss://example.com/ws");
        assert_eq!(cfg.channel_capacity, 1024);
    }
}