1use std::sync::Arc;
2use std::time::{Duration, Instant};
3use tokio::sync::RwLock;
4use tracing::{warn, info, error};
5
6use crate::error::WorkerResult;
7use crate::backends::{MessageBackend, ReceiveResult};
8
9#[derive(Debug, Clone)]
11pub enum ReconnectStrategy {
12 Fixed(Duration),
14
15 Exponential {
17 initial: Duration,
18 max: Duration,
19 multiplier: f64,
20 jitter_factor: f64,
22 },
23}
24
25impl ReconnectStrategy {
26 fn delay_for_attempt(&self, attempt: u32) -> Duration {
27 match self {
28 ReconnectStrategy::Fixed(d) => *d,
29 ReconnectStrategy::Exponential { initial, max, multiplier, jitter_factor } => {
30 let base_delay = initial.mul_f64(multiplier.powi(attempt as i32));
32 let clamped = base_delay.min(*max);
33
34 if *jitter_factor > 0.0 {
36 let jitter_range = clamped.mul_f64(*jitter_factor);
37 let jitter = jitter_range.mul_f64(rand::random::<f64>());
38 clamped + jitter
39 } else {
40 clamped
41 }
42 }
43 }
44 }
45}
46
47impl Default for ReconnectStrategy {
48 fn default() -> Self {
49 ReconnectStrategy::Exponential {
50 initial: Duration::from_secs(1),
51 max: Duration::from_secs(60),
52 multiplier: 2.0,
53 jitter_factor: 0.1, }
55 }
56}
57
58pub struct ResilientBackend {
80 inner: Arc<dyn MessageBackend>,
81 strategy: ReconnectStrategy,
82 reconnect_attempts: Arc<RwLock<u32>>,
83 last_success: Arc<RwLock<Instant>>,
84 is_connected: Arc<RwLock<bool>>,
85 consecutive_failures: Arc<RwLock<u32>>,
87}
88
89impl ResilientBackend {
90 pub fn new(inner: Arc<dyn MessageBackend>) -> Self {
92 Self {
93 inner,
94 strategy: ReconnectStrategy::default(),
95 reconnect_attempts: Arc::new(RwLock::new(0)),
96 last_success: Arc::new(RwLock::new(Instant::now())),
97 is_connected: Arc::new(RwLock::new(true)),
98 consecutive_failures: Arc::new(RwLock::new(0)),
99 }
100 }
101
102 pub fn with_strategy(inner: Arc<dyn MessageBackend>, strategy: ReconnectStrategy) -> Self {
104 Self {
105 inner,
106 strategy,
107 reconnect_attempts: Arc::new(RwLock::new(0)),
108 last_success: Arc::new(RwLock::new(Instant::now())),
109 is_connected: Arc::new(RwLock::new(true)),
110 consecutive_failures: Arc::new(RwLock::new(0)),
111 }
112 }
113
114 pub fn inner(&self) -> &Arc<dyn MessageBackend> {
116 &self.inner
117 }
118
119 pub async fn is_connected(&self) -> bool {
121 *self.is_connected.read().await
122 }
123
124 pub async fn reconnect_attempts(&self) -> u32 {
126 *self.reconnect_attempts.read().await
127 }
128
129 pub async fn consecutive_failures(&self) -> u32 {
131 *self.consecutive_failures.read().await
132 }
133
134 async fn execute_with_retry<T, F, Fut>(&self, operation_name: &str, op: F) -> WorkerResult<T>
139 where
140 F: Fn() -> Fut,
141 Fut: std::future::Future<Output = WorkerResult<T>>,
142 {
143 let mut attempt = 0;
144
145 loop {
146 match op().await {
147 Ok(result) => {
148 if attempt > 0 {
150 info!("{} succeeded after {} attempts", operation_name, attempt);
151 }
152 *self.reconnect_attempts.write().await = 0;
153 *self.consecutive_failures.write().await = 0;
154 *self.last_success.write().await = Instant::now();
155 *self.is_connected.write().await = true;
156 return Ok(result);
157 }
158 Err(e) => {
159 attempt += 1;
160 *self.reconnect_attempts.write().await = attempt;
161 let failures = {
162 let mut f = self.consecutive_failures.write().await;
163 *f += 1;
164 *f
165 };
166 *self.is_connected.write().await = false;
167
168 warn!(
169 "{} failed (attempt {}, consecutive failures: {}): {}. Retrying...",
170 operation_name, attempt, failures, e
171 );
172
173 if let Err(recover_err) = self.try_recover().await {
175 error!("Recovery attempt failed: {}", recover_err);
176 }
177
178 let delay = self.strategy.delay_for_attempt(attempt - 1);
180
181 if attempt % 10 == 0 || attempt <= 3 {
183 warn!(
184 "Still trying {} (attempt {}) - next retry in {:?}",
185 operation_name, attempt, delay
186 );
187 }
188
189 tokio::time::sleep(delay).await;
190
191 }
194 }
195 }
196 }
197
198 async fn try_recover(&self) -> WorkerResult<()> {
200 match self.inner.health_check().await {
202 Ok(_) => {
203 info!("Connection recovered");
204 *self.consecutive_failures.write().await = 0;
205 Ok(())
206 }
207 Err(e) => {
208 warn!("Health check failed during recovery: {}", e);
209 Err(e)
212 }
213 }
214 }
215}
216
217#[async_trait::async_trait]
218impl MessageBackend for ResilientBackend {
219 async fn receive(&self) -> WorkerResult<ReceiveResult<serde_json::Value>> {
220 self.execute_with_retry("receive", || async {
221 self.inner.receive().await
222 }).await
223 }
224
225 async fn ack(&self, message_id: &str) -> WorkerResult<()> {
226 self.inner.ack(message_id).await
229 }
230
231 async fn nack(&self, message_id: &str, requeue: bool) -> WorkerResult<()> {
232 self.execute_with_retry("nack", || async {
234 self.inner.nack(message_id, requeue).await
235 }).await
236 }
237
238 async fn health_check(&self) -> WorkerResult<()> {
239 self.inner.health_check().await
240 }
241
242 async fn shutdown(&self) -> WorkerResult<()> {
243 self.inner.shutdown().await
244 }
245}
246
247pub struct ResilientBackendBuilder {
249 inner: Arc<dyn MessageBackend>,
250 strategy: ReconnectStrategy,
251}
252
253impl ResilientBackendBuilder {
254 pub fn new(inner: Arc<dyn MessageBackend>) -> Self {
256 Self {
257 inner,
258 strategy: ReconnectStrategy::default(),
259 }
260 }
261
262 pub fn with_strategy(mut self, strategy: ReconnectStrategy) -> Self {
264 self.strategy = strategy;
265 self
266 }
267
268 pub fn build(self) -> ResilientBackend {
270 ResilientBackend::with_strategy(self.inner, self.strategy)
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::{MemoryBackend, ReceiveResult};
    use crate::error::WorkerError;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Millisecond-scale fixed backoff so the retry tests do not sleep for real seconds,
    /// as they would with the default 1s/2s/4s exponential strategy.
    fn fast_retry() -> ReconnectStrategy {
        ReconnectStrategy::Fixed(Duration::from_millis(1))
    }

    /// Test double whose `receive` fails for the first `succeed_after` calls and then
    /// reports `Shutdown`. Its health check fails until that many calls have been made.
    struct FailingBackend {
        fail_count: Arc<AtomicUsize>,
        total_calls: Arc<AtomicUsize>,
        succeed_after: usize,
    }

    impl FailingBackend {
        /// Returns the backend together with handles to its failure and call counters.
        fn new(succeed_after: usize) -> (Arc<Self>, Arc<AtomicUsize>, Arc<AtomicUsize>) {
            let fail_count = Arc::new(AtomicUsize::new(0));
            let total_calls = Arc::new(AtomicUsize::new(0));
            (
                Arc::new(Self {
                    fail_count: fail_count.clone(),
                    total_calls: total_calls.clone(),
                    succeed_after,
                }),
                fail_count,
                total_calls,
            )
        }
    }

    #[async_trait::async_trait]
    impl MessageBackend for FailingBackend {
        async fn receive(&self) -> WorkerResult<ReceiveResult<serde_json::Value>> {
            let calls = self.total_calls.fetch_add(1, Ordering::SeqCst);
            if calls < self.succeed_after {
                self.fail_count.fetch_add(1, Ordering::SeqCst);
                Err(WorkerError::BackendError("Simulated network failure".to_string()))
            } else {
                Ok(ReceiveResult::Shutdown)
            }
        }

        async fn ack(&self, _message_id: &str) -> WorkerResult<()> {
            Ok(())
        }

        async fn nack(&self, _message_id: &str, _requeue: bool) -> WorkerResult<()> {
            Ok(())
        }

        async fn health_check(&self) -> WorkerResult<()> {
            let calls = self.total_calls.load(Ordering::SeqCst);
            if calls < self.succeed_after {
                Err(WorkerError::BackendError("Health check failed".to_string()))
            } else {
                Ok(())
            }
        }

        async fn shutdown(&self) -> WorkerResult<()> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_resilient_backend_wraps_successfully() {
        let inner = Arc::new(MemoryBackend::new());
        let resilient = ResilientBackend::new(inner.clone());

        assert!(resilient.is_connected().await);
        assert_eq!(resilient.reconnect_attempts().await, 0);
        assert_eq!(resilient.consecutive_failures().await, 0);
    }

    #[tokio::test]
    async fn test_resilient_backend_receive() {
        let inner = MemoryBackend::new();
        let backend_arc = Arc::new(inner);
        let resilient = ResilientBackend::new(backend_arc.clone());

        backend_arc.enqueue(serde_json::json!({"test": "data"}));

        let result = resilient.receive().await.unwrap();
        assert!(result.is_message());
        if let ReceiveResult::Message(msg) = result {
            assert_eq!(msg.message.payload["test"], "data");
        } else {
            panic!("Expected Message variant");
        }
    }

    #[tokio::test]
    async fn test_resilient_backend_with_custom_strategy() {
        let inner = Arc::new(MemoryBackend::new());
        let strategy = ReconnectStrategy::Fixed(Duration::from_secs(1));
        let resilient = ResilientBackend::with_strategy(inner, strategy);

        assert!(resilient.is_connected().await);
    }

    #[test]
    fn test_exponential_backoff_calculation() {
        let strategy = ReconnectStrategy::Exponential {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(1),
            multiplier: 2.0,
            jitter_factor: 0.0,
        };

        assert_eq!(strategy.delay_for_attempt(0).as_millis(), 100);
        assert_eq!(strategy.delay_for_attempt(1).as_millis(), 200);
        assert_eq!(strategy.delay_for_attempt(2).as_millis(), 400);
        assert_eq!(strategy.delay_for_attempt(3).as_millis(), 800);
        // From here on the delay is capped at `max`.
        assert_eq!(strategy.delay_for_attempt(4).as_millis(), 1000);
        assert_eq!(strategy.delay_for_attempt(5).as_millis(), 1000);
    }

    #[test]
    fn test_exponential_backoff_with_jitter() {
        let strategy = ReconnectStrategy::Exponential {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(1),
            multiplier: 2.0,
            jitter_factor: 0.5,
        };

        let delay = strategy.delay_for_attempt(0);
        let base = 100;
        // Jitter only ever adds to the base delay, by at most `jitter_factor * base`.
        assert!(delay.as_millis() >= base as u128);
        assert!(delay.as_millis() <= (base as f64 * 1.5) as u128);
    }

    #[test]
    fn test_fixed_delay_strategy() {
        let strategy = ReconnectStrategy::Fixed(Duration::from_secs(2));

        assert_eq!(strategy.delay_for_attempt(0).as_secs(), 2);
        assert_eq!(strategy.delay_for_attempt(5).as_secs(), 2);
        assert_eq!(strategy.delay_for_attempt(100).as_secs(), 2);
    }

    #[tokio::test]
    async fn test_reconnection_on_failure() {
        let (backend, fail_count, total_calls) = FailingBackend::new(2);
        let resilient = ResilientBackend::with_strategy(backend, fast_retry());

        let result = resilient.receive().await;

        assert!(result.is_ok());
        if let Ok(receive_result) = result {
            assert!(receive_result.is_shutdown());
        }
        // Two failures followed by one successful call.
        assert_eq!(fail_count.load(Ordering::SeqCst), 2);
        assert_eq!(total_calls.load(Ordering::SeqCst), 3);
        // Success resets all bookkeeping.
        assert_eq!(resilient.reconnect_attempts().await, 0);
        assert_eq!(resilient.consecutive_failures().await, 0);
        assert!(resilient.is_connected().await);
    }

    #[tokio::test]
    async fn test_connection_state_tracking() {
        let (backend, _, _) = FailingBackend::new(1);
        let resilient = ResilientBackend::with_strategy(backend, fast_retry());

        assert!(resilient.is_connected().await);
        assert_eq!(resilient.reconnect_attempts().await, 0);

        let _ = resilient.receive().await;

        // The retry eventually succeeded, so the state is back to healthy.
        assert!(resilient.is_connected().await);
        assert_eq!(resilient.reconnect_attempts().await, 0);
    }

    #[tokio::test]
    async fn test_consecutive_failure_tracking() {
        let (backend, _, _) = FailingBackend::new(3);
        let resilient = ResilientBackend::with_strategy(backend, fast_retry());

        let _ = resilient.receive().await;

        assert_eq!(resilient.consecutive_failures().await, 0);
    }

    #[tokio::test]
    async fn test_ack_operations_dont_retry_indefinitely() {
        let inner = Arc::new(MemoryBackend::new());
        let resilient = ResilientBackend::new(inner.clone());

        let result = resilient.ack("non-existent-id").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_health_check_passthrough() {
        let inner = Arc::new(MemoryBackend::new());
        let resilient = ResilientBackend::new(inner.clone());

        let result = resilient.health_check().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_shutdown_passthrough() {
        let inner = Arc::new(MemoryBackend::new());
        let resilient = ResilientBackend::new(inner.clone());

        let result = resilient.shutdown().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_builder_pattern() {
        let inner = Arc::new(MemoryBackend::new());
        let strategy = ReconnectStrategy::Exponential {
            initial: Duration::from_millis(500),
            max: Duration::from_secs(30),
            multiplier: 2.5,
            jitter_factor: 0.2,
        };

        let resilient = ResilientBackendBuilder::new(inner)
            .with_strategy(strategy)
            .build();

        assert!(resilient.is_connected().await);
    }

    #[tokio::test]
    async fn test_multiple_receive_operations() {
        let inner = MemoryBackend::new();
        let backend_arc = Arc::new(inner);
        let resilient = ResilientBackend::new(backend_arc.clone());

        backend_arc.enqueue(serde_json::json!({"msg": 1}));
        backend_arc.enqueue(serde_json::json!({"msg": 2}));
        backend_arc.enqueue(serde_json::json!({"msg": 3}));

        for expected in 1..=3 {
            let result = resilient.receive().await.unwrap();
            if let ReceiveResult::Message(msg) = result {
                assert_eq!(msg.message.payload["msg"], expected);
            } else {
                panic!("Expected Message variant, got {:?}", result);
            }
        }

        assert_eq!(resilient.reconnect_attempts().await, 0);
    }

    #[test]
    fn test_default_reconnect_strategy() {
        let strategy = ReconnectStrategy::default();

        match strategy {
            ReconnectStrategy::Exponential { initial, max, multiplier, jitter_factor } => {
                assert_eq!(initial, Duration::from_secs(1));
                assert_eq!(max, Duration::from_secs(60));
                assert_eq!(multiplier, 2.0);
                assert_eq!(jitter_factor, 0.1);
            }
            ReconnectStrategy::Fixed(_) => panic!("Default should be Exponential"),
        }
    }
}