Skip to main content

fin_stream/ws/
mod.rs

//! WebSocket connection management — auto-reconnect and backpressure.
//!
//! ## Responsibility
//! Manage the lifecycle of a WebSocket feed connection: connect, receive
//! messages, detect disconnections, apply exponential backoff reconnect,
//! and propagate backpressure when the downstream channel is full.
//!
//! ## Guarantees
//! - Non-panicking: fallible operations (validation, reconnect budgeting) return
//!   `Result` instead of panicking
//! - Configurable: buffer size is a constructor parameter; the reconnect policy and
//!   ping interval are set through `ConnectionConfig`'s `with_*` builders
11
12use crate::error::StreamError;
13use std::time::Duration;
14
15/// Reconnection policy for a WebSocket feed.
16#[derive(Debug, Clone)]
17pub struct ReconnectPolicy {
18    /// Maximum number of reconnect attempts before giving up.
19    pub max_attempts: u32,
20    /// Initial backoff delay.
21    pub initial_backoff: Duration,
22    /// Maximum backoff delay (cap for exponential growth).
23    pub max_backoff: Duration,
24    /// Backoff multiplier.
25    pub multiplier: f64,
26}
27
28impl ReconnectPolicy {
29    pub fn new(
30        max_attempts: u32,
31        initial_backoff: Duration,
32        max_backoff: Duration,
33        multiplier: f64,
34    ) -> Result<Self, StreamError> {
35        if multiplier < 1.0 {
36            return Err(StreamError::ConnectionFailed {
37                url: String::new(),
38                reason: "reconnect multiplier must be >= 1.0".into(),
39            });
40        }
41        if max_attempts == 0 {
42            return Err(StreamError::ConnectionFailed {
43                url: String::new(),
44                reason: "max_attempts must be > 0".into(),
45            });
46        }
47        Ok(Self { max_attempts, initial_backoff, max_backoff, multiplier })
48    }
49
50    /// Backoff duration for attempt N (0-indexed).
51    pub fn backoff_for_attempt(&self, attempt: u32) -> Duration {
52        let factor = self.multiplier.powi(attempt as i32);
53        let ms = (self.initial_backoff.as_millis() as f64 * factor) as u64;
54        let capped = ms.min(self.max_backoff.as_millis() as u64);
55        Duration::from_millis(capped)
56    }
57}
58
59impl Default for ReconnectPolicy {
60    fn default() -> Self {
61        Self {
62            max_attempts: 10,
63            initial_backoff: Duration::from_millis(500),
64            max_backoff: Duration::from_secs(30),
65            multiplier: 2.0,
66        }
67    }
68}
69
70/// Configuration for a WebSocket feed connection.
71#[derive(Debug, Clone)]
72pub struct ConnectionConfig {
73    pub url: String,
74    pub channel_capacity: usize,
75    pub reconnect: ReconnectPolicy,
76    /// Ping interval to keep the connection alive.
77    pub ping_interval: Duration,
78}
79
80impl ConnectionConfig {
81    pub fn new(url: impl Into<String>, channel_capacity: usize) -> Result<Self, StreamError> {
82        let url = url.into();
83        if url.is_empty() {
84            return Err(StreamError::ConnectionFailed {
85                url: url.clone(),
86                reason: "URL must not be empty".into(),
87            });
88        }
89        if channel_capacity == 0 {
90            return Err(StreamError::Backpressure {
91                channel: url.clone(),
92                depth: 0,
93                capacity: 0,
94            });
95        }
96        Ok(Self {
97            url,
98            channel_capacity,
99            reconnect: ReconnectPolicy::default(),
100            ping_interval: Duration::from_secs(20),
101        })
102    }
103
104    pub fn with_reconnect(mut self, policy: ReconnectPolicy) -> Self {
105        self.reconnect = policy;
106        self
107    }
108
109    pub fn with_ping_interval(mut self, interval: Duration) -> Self {
110        self.ping_interval = interval;
111        self
112    }
113}
114
115/// Manages a single WebSocket feed: connect, receive, reconnect.
116///
117/// In production, WsManager wraps tokio-tungstenite. In tests, it operates
118/// in simulation mode with injected messages.
119pub struct WsManager {
120    config: ConnectionConfig,
121    connect_attempts: u32,
122    is_connected: bool,
123}
124
125impl WsManager {
126    pub fn new(config: ConnectionConfig) -> Self {
127        Self { config, connect_attempts: 0, is_connected: false }
128    }
129
130    /// Simulate a connection (for testing without live WebSocket).
131    /// Increments `connect_attempts` to reflect the initial connection slot.
132    pub fn connect_simulated(&mut self) {
133        self.connect_attempts += 1;
134        self.is_connected = true;
135    }
136
137    /// Simulate a disconnection.
138    pub fn disconnect_simulated(&mut self) {
139        self.is_connected = false;
140    }
141
142    pub fn is_connected(&self) -> bool { self.is_connected }
143    pub fn connect_attempts(&self) -> u32 { self.connect_attempts }
144    pub fn config(&self) -> &ConnectionConfig { &self.config }
145
146    /// Check whether the next reconnect attempt is allowed.
147    pub fn can_reconnect(&self) -> bool {
148        self.connect_attempts < self.config.reconnect.max_attempts
149    }
150
151    /// Consume a reconnect slot and return the backoff duration to wait.
152    pub fn next_reconnect_backoff(&mut self) -> Result<Duration, StreamError> {
153        if !self.can_reconnect() {
154            return Err(StreamError::ReconnectExhausted {
155                url: self.config.url.clone(),
156                attempts: self.connect_attempts,
157            });
158        }
159        let backoff = self.config.reconnect.backoff_for_attempt(self.connect_attempts);
160        self.connect_attempts += 1;
161        Ok(backoff)
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Valid baseline config used by most tests.
    fn default_config() -> ConnectionConfig {
        ConnectionConfig::new("wss://example.com/ws", 1024).unwrap()
    }

    #[test]
    fn test_reconnect_policy_default_values() {
        let p = ReconnectPolicy::default();
        assert_eq!(p.max_attempts, 10);
        assert_eq!(p.initial_backoff, Duration::from_millis(500));
        assert_eq!(p.max_backoff, Duration::from_secs(30));
        assert_eq!(p.multiplier, 2.0);
    }

    #[test]
    fn test_reconnect_policy_backoff_exponential() {
        let p = ReconnectPolicy::new(10, Duration::from_millis(100), Duration::from_secs(30), 2.0).unwrap();
        assert_eq!(p.backoff_for_attempt(0), Duration::from_millis(100));
        assert_eq!(p.backoff_for_attempt(1), Duration::from_millis(200));
        assert_eq!(p.backoff_for_attempt(2), Duration::from_millis(400));
    }

    #[test]
    fn test_reconnect_policy_backoff_capped_at_max() {
        let p = ReconnectPolicy::new(10, Duration::from_millis(1000), Duration::from_secs(5), 2.0).unwrap();
        // 1000 ms * 2^10 far exceeds 5 s, so the result must be exactly the cap.
        let backoff = p.backoff_for_attempt(10);
        assert_eq!(backoff, Duration::from_secs(5));
    }

    #[test]
    fn test_reconnect_policy_multiplier_one_is_constant() {
        let p = ReconnectPolicy::new(3, Duration::from_millis(250), Duration::from_secs(1), 1.0).unwrap();
        assert_eq!(p.backoff_for_attempt(0), Duration::from_millis(250));
        assert_eq!(p.backoff_for_attempt(5), Duration::from_millis(250));
    }

    #[test]
    fn test_reconnect_policy_multiplier_below_1_rejected() {
        let result = ReconnectPolicy::new(10, Duration::from_millis(100), Duration::from_secs(30), 0.5);
        assert!(matches!(result, Err(StreamError::ConnectionFailed { .. })));
    }

    #[test]
    fn test_reconnect_policy_zero_attempts_rejected() {
        let result = ReconnectPolicy::new(0, Duration::from_millis(100), Duration::from_secs(30), 2.0);
        assert!(matches!(result, Err(StreamError::ConnectionFailed { .. })));
    }

    #[test]
    fn test_connection_config_empty_url_rejected() {
        let result = ConnectionConfig::new("", 1024);
        assert!(matches!(result, Err(StreamError::ConnectionFailed { .. })));
    }

    #[test]
    fn test_connection_config_zero_capacity_rejected() {
        let result = ConnectionConfig::new("wss://example.com", 0);
        assert!(matches!(result, Err(StreamError::Backpressure { .. })));
    }

    #[test]
    fn test_connection_config_with_reconnect() {
        let policy = ReconnectPolicy::new(3, Duration::from_millis(200), Duration::from_secs(10), 2.0).unwrap();
        let config = default_config().with_reconnect(policy);
        assert_eq!(config.reconnect.max_attempts, 3);
    }

    #[test]
    fn test_connection_config_with_ping_interval() {
        let config = default_config().with_ping_interval(Duration::from_secs(30));
        assert_eq!(config.ping_interval, Duration::from_secs(30));
    }

    #[test]
    fn test_ws_manager_initial_state() {
        let mgr = WsManager::new(default_config());
        assert!(!mgr.is_connected());
        assert_eq!(mgr.connect_attempts(), 0);
    }

    #[test]
    fn test_ws_manager_connect_simulated() {
        let mut mgr = WsManager::new(default_config());
        mgr.connect_simulated();
        assert!(mgr.is_connected());
        assert_eq!(mgr.connect_attempts(), 1);
    }

    #[test]
    fn test_ws_manager_disconnect_simulated() {
        let mut mgr = WsManager::new(default_config());
        mgr.connect_simulated();
        mgr.disconnect_simulated();
        assert!(!mgr.is_connected());
    }

    #[test]
    fn test_ws_manager_can_reconnect_within_limit() {
        let mut mgr = WsManager::new(
            default_config().with_reconnect(
                ReconnectPolicy::new(3, Duration::from_millis(10), Duration::from_secs(1), 2.0).unwrap()
            )
        );
        assert!(mgr.can_reconnect());
        mgr.next_reconnect_backoff().unwrap();
        mgr.next_reconnect_backoff().unwrap();
        mgr.next_reconnect_backoff().unwrap();
        assert!(!mgr.can_reconnect());
    }

    #[test]
    fn test_ws_manager_initial_connect_consumes_reconnect_budget() {
        // connect_simulated counts toward max_attempts, so only
        // max_attempts - 1 reconnects remain after the first connect.
        let mut mgr = WsManager::new(
            default_config().with_reconnect(
                ReconnectPolicy::new(3, Duration::from_millis(10), Duration::from_secs(1), 2.0).unwrap()
            )
        );
        mgr.connect_simulated();
        // The first reconnect uses attempt index 1, i.e. one doubling.
        assert_eq!(mgr.next_reconnect_backoff().unwrap(), Duration::from_millis(20));
        assert_eq!(mgr.next_reconnect_backoff().unwrap(), Duration::from_millis(40));
        assert!(!mgr.can_reconnect());
    }

    #[test]
    fn test_ws_manager_reconnect_exhausted_error() {
        let mut mgr = WsManager::new(
            default_config().with_reconnect(
                ReconnectPolicy::new(1, Duration::from_millis(10), Duration::from_secs(1), 2.0).unwrap()
            )
        );
        mgr.next_reconnect_backoff().unwrap();
        match mgr.next_reconnect_backoff() {
            Err(StreamError::ReconnectExhausted { url, attempts }) => {
                assert_eq!(url, "wss://example.com/ws");
                assert_eq!(attempts, 1);
            }
            _ => panic!("expected ReconnectExhausted"),
        }
    }

    #[test]
    fn test_ws_manager_backoff_increases() {
        let mut mgr = WsManager::new(
            default_config().with_reconnect(
                ReconnectPolicy::new(5, Duration::from_millis(100), Duration::from_secs(30), 2.0).unwrap()
            )
        );
        let b0 = mgr.next_reconnect_backoff().unwrap();
        let b1 = mgr.next_reconnect_backoff().unwrap();
        // Pin exact values so a regression in the growth rate is caught.
        assert_eq!(b0, Duration::from_millis(100));
        assert_eq!(b1, Duration::from_millis(200));
        assert!(b1 >= b0);
    }

    #[test]
    fn test_ws_manager_config_accessor() {
        let mgr = WsManager::new(default_config());
        assert_eq!(mgr.config().url, "wss://example.com/ws");
        assert_eq!(mgr.config().channel_capacity, 1024);
    }
}