Skip to main content

fin_stream/ws/
mod.rs

1//! WebSocket connection management — auto-reconnect and backpressure.
2//!
3//! ## Responsibility
4//! Manage the lifecycle of a WebSocket feed connection: connect, receive
5//! messages, detect disconnections, apply exponential backoff reconnect,
6//! and propagate backpressure when the downstream channel is full.
7//!
8//! ## Guarantees
//! - Non-panicking: fallible operations return `Result` instead of panicking
10//! - Configurable: reconnect policy and buffer sizes are constructor params
11
12use crate::error::StreamError;
13use std::time::Duration;
14
/// Exponential-backoff reconnection policy for a WebSocket feed.
///
/// Construct it with [`ReconnectPolicy::new`], which validates its arguments,
/// or take the [`Default`] values. Every field is public, so a value built
/// directly with a struct literal skips that validation.
#[derive(Clone, Debug)]
pub struct ReconnectPolicy {
    /// Upper bound on reconnect attempts; once reached, reconnecting gives up.
    pub max_attempts: u32,
    /// Delay used before the very first reconnect attempt.
    pub initial_backoff: Duration,
    /// Ceiling that the exponentially growing delay is clamped to.
    pub max_backoff: Duration,
    /// Growth factor applied per successive attempt (expected to be `>= 1.0`).
    pub multiplier: f64,
}
30
31impl ReconnectPolicy {
32    /// Build a reconnect policy with explicit parameters.
33    ///
34    /// Returns an error if `multiplier < 1.0` (which would cause backoff to
35    /// shrink over time) or if `max_attempts == 0`.
36    pub fn new(
37        max_attempts: u32,
38        initial_backoff: Duration,
39        max_backoff: Duration,
40        multiplier: f64,
41    ) -> Result<Self, StreamError> {
42        if multiplier < 1.0 {
43            return Err(StreamError::ConnectionFailed {
44                url: String::new(),
45                reason: "reconnect multiplier must be >= 1.0".into(),
46            });
47        }
48        if max_attempts == 0 {
49            return Err(StreamError::ConnectionFailed {
50                url: String::new(),
51                reason: "max_attempts must be > 0".into(),
52            });
53        }
54        Ok(Self {
55            max_attempts,
56            initial_backoff,
57            max_backoff,
58            multiplier,
59        })
60    }
61
62    /// Backoff duration for attempt N (0-indexed).
63    pub fn backoff_for_attempt(&self, attempt: u32) -> Duration {
64        let factor = self.multiplier.powi(attempt as i32);
65        // Cap the f64 value *before* casting to u64.  When `attempt` is large
66        // (e.g. 63 with multiplier=2.0), `factor` becomes f64::INFINITY.
67        // Casting f64::INFINITY as u64 is undefined behaviour in Rust — it
68        // saturates to 0 on some targets and panics in debug builds.  Clamping
69        // to max_backoff in floating-point space first avoids the UB entirely.
70        let max_ms = self.max_backoff.as_millis() as f64;
71        let ms = (self.initial_backoff.as_millis() as f64 * factor).min(max_ms);
72        Duration::from_millis(ms as u64)
73    }
74}
75
76impl Default for ReconnectPolicy {
77    fn default() -> Self {
78        Self {
79            max_attempts: 10,
80            initial_backoff: Duration::from_millis(500),
81            max_backoff: Duration::from_secs(30),
82            multiplier: 2.0,
83        }
84    }
85}
86
87/// Configuration for a WebSocket feed connection.
88#[derive(Debug, Clone)]
89pub struct ConnectionConfig {
90    /// WebSocket URL to connect to (e.g. `"wss://stream.binance.com:9443/ws"`).
91    pub url: String,
92    /// Capacity of the downstream channel that receives incoming messages.
93    pub channel_capacity: usize,
94    /// Reconnect policy applied on disconnection.
95    pub reconnect: ReconnectPolicy,
96    /// Ping interval to keep the connection alive (default: 20 s).
97    pub ping_interval: Duration,
98}
99
100impl ConnectionConfig {
101    /// Build a connection configuration for `url` with the given downstream
102    /// channel capacity.
103    ///
104    /// Returns an error if `url` is empty or `channel_capacity` is zero.
105    pub fn new(url: impl Into<String>, channel_capacity: usize) -> Result<Self, StreamError> {
106        let url = url.into();
107        if url.is_empty() {
108            return Err(StreamError::ConnectionFailed {
109                url: url.clone(),
110                reason: "URL must not be empty".into(),
111            });
112        }
113        if channel_capacity == 0 {
114            return Err(StreamError::Backpressure {
115                channel: url.clone(),
116                depth: 0,
117                capacity: 0,
118            });
119        }
120        Ok(Self {
121            url,
122            channel_capacity,
123            reconnect: ReconnectPolicy::default(),
124            ping_interval: Duration::from_secs(20),
125        })
126    }
127
128    /// Override the default reconnect policy.
129    pub fn with_reconnect(mut self, policy: ReconnectPolicy) -> Self {
130        self.reconnect = policy;
131        self
132    }
133
134    /// Override the keepalive ping interval (default: 20 s).
135    pub fn with_ping_interval(mut self, interval: Duration) -> Self {
136        self.ping_interval = interval;
137        self
138    }
139}
140
141/// Manages a single WebSocket feed: connect, receive, reconnect.
142///
143/// In production, WsManager wraps tokio-tungstenite. In tests, it operates
144/// in simulation mode with injected messages.
145pub struct WsManager {
146    config: ConnectionConfig,
147    connect_attempts: u32,
148    is_connected: bool,
149}
150
151impl WsManager {
152    /// Create a new manager from a validated [`ConnectionConfig`].
153    pub fn new(config: ConnectionConfig) -> Self {
154        Self {
155            config,
156            connect_attempts: 0,
157            is_connected: false,
158        }
159    }
160
161    /// Simulate a connection (for testing without live WebSocket).
162    /// Increments `connect_attempts` to reflect the initial connection slot.
163    pub fn connect_simulated(&mut self) {
164        self.connect_attempts += 1;
165        self.is_connected = true;
166    }
167
168    /// Simulate a disconnection.
169    pub fn disconnect_simulated(&mut self) {
170        self.is_connected = false;
171    }
172
173    /// Whether the managed connection is currently in the connected state.
174    pub fn is_connected(&self) -> bool {
175        self.is_connected
176    }
177
178    /// Total connection attempts made so far (including the initial connect).
179    pub fn connect_attempts(&self) -> u32 {
180        self.connect_attempts
181    }
182
183    /// The configuration this manager was created with.
184    pub fn config(&self) -> &ConnectionConfig {
185        &self.config
186    }
187
188    /// Check whether the next reconnect attempt is allowed.
189    pub fn can_reconnect(&self) -> bool {
190        self.connect_attempts < self.config.reconnect.max_attempts
191    }
192
193    /// Consume a reconnect slot and return the backoff duration to wait.
194    pub fn next_reconnect_backoff(&mut self) -> Result<Duration, StreamError> {
195        if !self.can_reconnect() {
196            return Err(StreamError::ReconnectExhausted {
197                url: self.config.url.clone(),
198                attempts: self.connect_attempts,
199            });
200        }
201        let backoff = self
202            .config
203            .reconnect
204            .backoff_for_attempt(self.connect_attempts);
205        self.connect_attempts += 1;
206        Ok(backoff)
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    fn default_config() -> ConnectionConfig {
        ConnectionConfig::new("wss://example.com/ws", 1024).unwrap()
    }

    /// Default config with a reconnect policy allowing `max_attempts`
    /// attempts, starting at 10 ms, doubling, capped at 1 s.
    fn config_with_attempts(max_attempts: u32) -> ConnectionConfig {
        default_config().with_reconnect(
            ReconnectPolicy::new(max_attempts, Duration::from_millis(10), Duration::from_secs(1), 2.0)
                .unwrap(),
        )
    }

    #[test]
    fn test_reconnect_policy_default_values() {
        let p = ReconnectPolicy::default();
        assert_eq!(p.max_attempts, 10);
        assert_eq!(p.initial_backoff, Duration::from_millis(500));
        assert_eq!(p.max_backoff, Duration::from_secs(30));
        assert_eq!(p.multiplier, 2.0);
    }

    #[test]
    fn test_reconnect_policy_backoff_exponential() {
        let p = ReconnectPolicy::new(10, Duration::from_millis(100), Duration::from_secs(30), 2.0)
            .unwrap();
        assert_eq!(p.backoff_for_attempt(0), Duration::from_millis(100));
        assert_eq!(p.backoff_for_attempt(1), Duration::from_millis(200));
        assert_eq!(p.backoff_for_attempt(2), Duration::from_millis(400));
    }

    #[test]
    fn test_reconnect_policy_backoff_capped_at_max() {
        let p = ReconnectPolicy::new(10, Duration::from_millis(1000), Duration::from_secs(5), 2.0)
            .unwrap();
        // 1000 ms * 2^10 far exceeds the cap, so the result must be exactly the
        // cap; a `<=` check would also accept an implementation returning zero.
        assert_eq!(p.backoff_for_attempt(10), Duration::from_secs(5));
    }

    #[test]
    fn test_reconnect_policy_backoff_saturates_for_huge_attempt() {
        let p = ReconnectPolicy::new(10, Duration::from_millis(100), Duration::from_secs(30), 2.0)
            .unwrap();
        // 2^1100 overflows f64 to infinity; the result must still be the cap.
        assert_eq!(p.backoff_for_attempt(1100), Duration::from_secs(30));
    }

    #[test]
    fn test_reconnect_policy_multiplier_of_one_keeps_backoff_constant() {
        let p = ReconnectPolicy::new(10, Duration::from_millis(250), Duration::from_secs(30), 1.0)
            .unwrap();
        assert_eq!(p.backoff_for_attempt(0), Duration::from_millis(250));
        assert_eq!(p.backoff_for_attempt(7), Duration::from_millis(250));
    }

    #[test]
    fn test_reconnect_policy_multiplier_below_1_rejected() {
        let result =
            ReconnectPolicy::new(10, Duration::from_millis(100), Duration::from_secs(30), 0.5);
        assert!(result.is_err());
    }

    #[test]
    fn test_reconnect_policy_zero_attempts_rejected() {
        let result =
            ReconnectPolicy::new(0, Duration::from_millis(100), Duration::from_secs(30), 2.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_connection_config_defaults() {
        let config = default_config();
        assert_eq!(config.ping_interval, Duration::from_secs(20));
        assert_eq!(
            config.reconnect.max_attempts,
            ReconnectPolicy::default().max_attempts
        );
    }

    #[test]
    fn test_connection_config_empty_url_rejected() {
        let result = ConnectionConfig::new("", 1024);
        assert!(matches!(result, Err(StreamError::ConnectionFailed { .. })));
    }

    #[test]
    fn test_connection_config_zero_capacity_rejected() {
        let result = ConnectionConfig::new("wss://example.com", 0);
        assert!(matches!(result, Err(StreamError::Backpressure { .. })));
    }

    #[test]
    fn test_connection_config_with_reconnect() {
        let policy =
            ReconnectPolicy::new(3, Duration::from_millis(200), Duration::from_secs(10), 2.0)
                .unwrap();
        let config = default_config().with_reconnect(policy);
        assert_eq!(config.reconnect.max_attempts, 3);
    }

    #[test]
    fn test_connection_config_with_ping_interval() {
        let config = default_config().with_ping_interval(Duration::from_secs(30));
        assert_eq!(config.ping_interval, Duration::from_secs(30));
    }

    #[test]
    fn test_ws_manager_initial_state() {
        let mgr = WsManager::new(default_config());
        assert!(!mgr.is_connected());
        assert_eq!(mgr.connect_attempts(), 0);
    }

    #[test]
    fn test_ws_manager_connect_simulated() {
        let mut mgr = WsManager::new(default_config());
        mgr.connect_simulated();
        assert!(mgr.is_connected());
        assert_eq!(mgr.connect_attempts(), 1);
    }

    #[test]
    fn test_ws_manager_disconnect_simulated() {
        let mut mgr = WsManager::new(default_config());
        mgr.connect_simulated();
        mgr.disconnect_simulated();
        assert!(!mgr.is_connected());
    }

    #[test]
    fn test_ws_manager_can_reconnect_within_limit() {
        let mut mgr = WsManager::new(config_with_attempts(3));
        assert!(mgr.can_reconnect());
        mgr.next_reconnect_backoff().unwrap();
        mgr.next_reconnect_backoff().unwrap();
        mgr.next_reconnect_backoff().unwrap();
        assert!(!mgr.can_reconnect());
    }

    #[test]
    fn test_ws_manager_reconnect_exhausted_error() {
        let mut mgr = WsManager::new(config_with_attempts(1));
        mgr.next_reconnect_backoff().unwrap();
        let result = mgr.next_reconnect_backoff();
        assert!(matches!(
            result,
            Err(StreamError::ReconnectExhausted { attempts: 1, .. })
        ));
    }

    #[test]
    fn test_ws_manager_initial_connect_consumes_a_slot() {
        let mut mgr = WsManager::new(config_with_attempts(3));
        mgr.connect_simulated();
        // The initial connect took attempt index 0, so the first reconnect
        // waits backoff_for_attempt(1) = 20 ms.
        assert_eq!(mgr.next_reconnect_backoff().unwrap(), Duration::from_millis(20));
        assert_eq!(mgr.next_reconnect_backoff().unwrap(), Duration::from_millis(40));
        assert!(!mgr.can_reconnect());
    }

    #[test]
    fn test_ws_manager_backoff_increases() {
        let mut mgr = WsManager::new(
            default_config().with_reconnect(
                ReconnectPolicy::new(5, Duration::from_millis(100), Duration::from_secs(30), 2.0)
                    .unwrap(),
            ),
        );
        let b0 = mgr.next_reconnect_backoff().unwrap();
        let b1 = mgr.next_reconnect_backoff().unwrap();
        assert_eq!(b0, Duration::from_millis(100));
        assert_eq!(b1, Duration::from_millis(200));
    }

    #[test]
    fn test_ws_manager_config_accessor() {
        let mgr = WsManager::new(default_config());
        assert_eq!(mgr.config().url, "wss://example.com/ws");
        assert_eq!(mgr.config().channel_capacity, 1024);
    }
}