1use crate::error::StreamError;
13use std::time::Duration;
14
/// Exponential-backoff settings used when reconnecting.
///
/// The delay before attempt `n` (0-based) is `initial_backoff * multiplier^n`,
/// capped at `max_backoff` (see `ReconnectPolicy::backoff_for_attempt`).
/// Prefer `ReconnectPolicy::new`, which validates the values; the fields are
/// public, so building the struct literally bypasses that validation.
#[derive(Debug, Clone, PartialEq)]
pub struct ReconnectPolicy {
    /// Maximum number of attempts allowed; `new` rejects 0.
    pub max_attempts: u32,
    /// Delay returned for attempt 0.
    pub initial_backoff: Duration,
    /// Upper bound applied to every computed delay.
    pub max_backoff: Duration,
    /// Growth factor applied once per attempt; `new` rejects values below 1.0.
    pub multiplier: f64,
}
30
31impl ReconnectPolicy {
32 pub fn new(
37 max_attempts: u32,
38 initial_backoff: Duration,
39 max_backoff: Duration,
40 multiplier: f64,
41 ) -> Result<Self, StreamError> {
42 if multiplier < 1.0 {
43 return Err(StreamError::ConnectionFailed {
44 url: String::new(),
45 reason: "reconnect multiplier must be >= 1.0".into(),
46 });
47 }
48 if max_attempts == 0 {
49 return Err(StreamError::ConnectionFailed {
50 url: String::new(),
51 reason: "max_attempts must be > 0".into(),
52 });
53 }
54 Ok(Self {
55 max_attempts,
56 initial_backoff,
57 max_backoff,
58 multiplier,
59 })
60 }
61
62 pub fn backoff_for_attempt(&self, attempt: u32) -> Duration {
64 let factor = self.multiplier.powi(attempt as i32);
65 let max_ms = self.max_backoff.as_millis() as f64;
71 let ms = (self.initial_backoff.as_millis() as f64 * factor).min(max_ms);
72 Duration::from_millis(ms as u64)
73 }
74}
75
76impl Default for ReconnectPolicy {
77 fn default() -> Self {
78 Self {
79 max_attempts: 10,
80 initial_backoff: Duration::from_millis(500),
81 max_backoff: Duration::from_secs(30),
82 multiplier: 2.0,
83 }
84 }
85}
86
87#[derive(Debug, Clone)]
89pub struct ConnectionConfig {
90 pub url: String,
92 pub channel_capacity: usize,
94 pub reconnect: ReconnectPolicy,
96 pub ping_interval: Duration,
98}
99
100impl ConnectionConfig {
101 pub fn new(url: impl Into<String>, channel_capacity: usize) -> Result<Self, StreamError> {
106 let url = url.into();
107 if url.is_empty() {
108 return Err(StreamError::ConnectionFailed {
109 url: url.clone(),
110 reason: "URL must not be empty".into(),
111 });
112 }
113 if channel_capacity == 0 {
114 return Err(StreamError::Backpressure {
115 channel: url.clone(),
116 depth: 0,
117 capacity: 0,
118 });
119 }
120 Ok(Self {
121 url,
122 channel_capacity,
123 reconnect: ReconnectPolicy::default(),
124 ping_interval: Duration::from_secs(20),
125 })
126 }
127
128 pub fn with_reconnect(mut self, policy: ReconnectPolicy) -> Self {
130 self.reconnect = policy;
131 self
132 }
133
134 pub fn with_ping_interval(mut self, interval: Duration) -> Self {
136 self.ping_interval = interval;
137 self
138 }
139}
140
141pub struct WsManager {
146 config: ConnectionConfig,
147 connect_attempts: u32,
148 is_connected: bool,
149}
150
151impl WsManager {
152 pub fn new(config: ConnectionConfig) -> Self {
154 Self {
155 config,
156 connect_attempts: 0,
157 is_connected: false,
158 }
159 }
160
161 pub fn connect_simulated(&mut self) {
164 self.connect_attempts += 1;
165 self.is_connected = true;
166 }
167
168 pub fn disconnect_simulated(&mut self) {
170 self.is_connected = false;
171 }
172
173 pub fn is_connected(&self) -> bool {
175 self.is_connected
176 }
177
178 pub fn connect_attempts(&self) -> u32 {
180 self.connect_attempts
181 }
182
183 pub fn config(&self) -> &ConnectionConfig {
185 &self.config
186 }
187
188 pub fn can_reconnect(&self) -> bool {
190 self.connect_attempts < self.config.reconnect.max_attempts
191 }
192
193 pub fn next_reconnect_backoff(&mut self) -> Result<Duration, StreamError> {
195 if !self.can_reconnect() {
196 return Err(StreamError::ReconnectExhausted {
197 url: self.config.url.clone(),
198 attempts: self.connect_attempts,
199 });
200 }
201 let backoff = self
202 .config
203 .reconnect
204 .backoff_for_attempt(self.connect_attempts);
205 self.connect_attempts += 1;
206 Ok(backoff)
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// Valid baseline configuration shared by most tests.
    fn default_config() -> ConnectionConfig {
        ConnectionConfig::new("wss://example.com/ws", 1024).unwrap()
    }

    #[test]
    fn test_reconnect_policy_default_values() {
        let p = ReconnectPolicy::default();
        assert_eq!(p.max_attempts, 10);
        assert_eq!(p.multiplier, 2.0);
        assert_eq!(p.initial_backoff, Duration::from_millis(500));
        assert_eq!(p.max_backoff, Duration::from_secs(30));
    }

    #[test]
    fn test_reconnect_policy_backoff_exponential() {
        let p = ReconnectPolicy::new(10, Duration::from_millis(100), Duration::from_secs(30), 2.0)
            .unwrap();
        assert_eq!(p.backoff_for_attempt(0), Duration::from_millis(100));
        assert_eq!(p.backoff_for_attempt(1), Duration::from_millis(200));
        assert_eq!(p.backoff_for_attempt(2), Duration::from_millis(400));
    }

    #[test]
    fn test_reconnect_policy_backoff_capped_at_max() {
        let p = ReconnectPolicy::new(10, Duration::from_millis(1000), Duration::from_secs(5), 2.0)
            .unwrap();
        // 1 s * 2^10 far exceeds the cap, so the result must be exactly the cap.
        assert_eq!(p.backoff_for_attempt(10), Duration::from_secs(5));
    }

    #[test]
    fn test_reconnect_policy_multiplier_below_1_rejected() {
        let result =
            ReconnectPolicy::new(10, Duration::from_millis(100), Duration::from_secs(30), 0.5);
        assert!(matches!(result, Err(StreamError::ConnectionFailed { .. })));
    }

    #[test]
    fn test_reconnect_policy_zero_attempts_rejected() {
        let result =
            ReconnectPolicy::new(0, Duration::from_millis(100), Duration::from_secs(30), 2.0);
        assert!(matches!(result, Err(StreamError::ConnectionFailed { .. })));
    }

    #[test]
    fn test_connection_config_empty_url_rejected() {
        let result = ConnectionConfig::new("", 1024);
        assert!(matches!(result, Err(StreamError::ConnectionFailed { .. })));
    }

    #[test]
    fn test_connection_config_zero_capacity_rejected() {
        let result = ConnectionConfig::new("wss://example.com", 0);
        assert!(matches!(result, Err(StreamError::Backpressure { .. })));
    }

    #[test]
    fn test_connection_config_defaults() {
        let config = default_config();
        assert_eq!(config.ping_interval, Duration::from_secs(20));
        assert_eq!(config.reconnect.max_attempts, 10);
    }

    #[test]
    fn test_connection_config_with_reconnect() {
        let policy =
            ReconnectPolicy::new(3, Duration::from_millis(200), Duration::from_secs(10), 2.0)
                .unwrap();
        let config = default_config().with_reconnect(policy);
        assert_eq!(config.reconnect.max_attempts, 3);
        assert_eq!(config.reconnect.initial_backoff, Duration::from_millis(200));
    }

    #[test]
    fn test_connection_config_with_ping_interval() {
        let config = default_config().with_ping_interval(Duration::from_secs(30));
        assert_eq!(config.ping_interval, Duration::from_secs(30));
    }

    #[test]
    fn test_ws_manager_initial_state() {
        let mgr = WsManager::new(default_config());
        assert!(!mgr.is_connected());
        assert_eq!(mgr.connect_attempts(), 0);
    }

    #[test]
    fn test_ws_manager_connect_simulated() {
        let mut mgr = WsManager::new(default_config());
        mgr.connect_simulated();
        assert!(mgr.is_connected());
        assert_eq!(mgr.connect_attempts(), 1);
    }

    #[test]
    fn test_ws_manager_disconnect_simulated() {
        let mut mgr = WsManager::new(default_config());
        mgr.connect_simulated();
        mgr.disconnect_simulated();
        assert!(!mgr.is_connected());
        assert_eq!(mgr.connect_attempts(), 1);
    }

    #[test]
    fn test_ws_manager_can_reconnect_within_limit() {
        let mut mgr = WsManager::new(
            default_config().with_reconnect(
                ReconnectPolicy::new(3, Duration::from_millis(10), Duration::from_secs(1), 2.0)
                    .unwrap(),
            ),
        );
        assert!(mgr.can_reconnect());
        mgr.next_reconnect_backoff().unwrap();
        mgr.next_reconnect_backoff().unwrap();
        mgr.next_reconnect_backoff().unwrap();
        assert!(!mgr.can_reconnect());
    }

    #[test]
    fn test_ws_manager_reconnect_exhausted_error() {
        let mut mgr = WsManager::new(
            default_config().with_reconnect(
                ReconnectPolicy::new(1, Duration::from_millis(10), Duration::from_secs(1), 2.0)
                    .unwrap(),
            ),
        );
        mgr.next_reconnect_backoff().unwrap();
        match mgr.next_reconnect_backoff() {
            Err(StreamError::ReconnectExhausted { url, attempts }) => {
                assert_eq!(url, "wss://example.com/ws");
                assert_eq!(attempts, 1);
            }
            other => panic!("expected ReconnectExhausted, got {other:?}"),
        }
        // A refused request must not consume another attempt.
        assert_eq!(mgr.connect_attempts(), 1);
    }

    #[test]
    fn test_ws_manager_backoff_increases() {
        let mut mgr = WsManager::new(
            default_config().with_reconnect(
                ReconnectPolicy::new(5, Duration::from_millis(100), Duration::from_secs(30), 2.0)
                    .unwrap(),
            ),
        );
        let b0 = mgr.next_reconnect_backoff().unwrap();
        let b1 = mgr.next_reconnect_backoff().unwrap();
        assert_eq!(b0, Duration::from_millis(100));
        assert_eq!(b1, Duration::from_millis(200));
    }

    #[test]
    fn test_ws_manager_config_accessor() {
        let mgr = WsManager::new(default_config());
        assert_eq!(mgr.config().url, "wss://example.com/ws");
        assert_eq!(mgr.config().channel_capacity, 1024);
    }
}