1use std::sync::Arc;
2use std::time::{Duration, Instant};
3use tokio::sync::RwLock;
4use tracing::{error, info, warn};
5
6use crate::backends::{MessageBackend, ReceiveResult};
7use crate::error::WorkerResult;
8
/// Policy that decides how long to wait between retries of a failed backend operation.
#[derive(Debug, Clone)]
pub enum ReconnectStrategy {
    /// Wait the same duration before every retry.
    Fixed(Duration),

    /// Grow the wait geometrically with the attempt number.
    Exponential {
        /// Delay used for attempt 0.
        initial: Duration,
        /// Upper bound applied to the delay before any jitter is added.
        max: Duration,
        /// Factor raised to the attempt number and multiplied into `initial`.
        multiplier: f64,
        /// When greater than zero, a random extra of up to this fraction of the
        /// capped delay is added on top.
        jitter_factor: f64,
    },
}
24
25impl ReconnectStrategy {
26 fn delay_for_attempt(&self, attempt: u32) -> Duration {
27 match self {
28 ReconnectStrategy::Fixed(d) => *d,
29 ReconnectStrategy::Exponential {
30 initial,
31 max,
32 multiplier,
33 jitter_factor,
34 } => {
35 let base_delay = initial.mul_f64(multiplier.powi(attempt as i32));
37 let clamped = base_delay.min(*max);
38
39 if *jitter_factor > 0.0 {
41 let jitter_range = clamped.mul_f64(*jitter_factor);
42 let jitter = jitter_range.mul_f64(rand::random::<f64>());
43 clamped + jitter
44 } else {
45 clamped
46 }
47 }
48 }
49 }
50}
51
52impl Default for ReconnectStrategy {
53 fn default() -> Self {
54 ReconnectStrategy::Exponential {
55 initial: Duration::from_secs(1),
56 max: Duration::from_secs(60),
57 multiplier: 2.0,
58 jitter_factor: 0.1, }
60 }
61}
62
63pub struct ResilientBackend {
85 inner: Arc<dyn MessageBackend>,
86 strategy: ReconnectStrategy,
87 reconnect_attempts: Arc<RwLock<u32>>,
88 last_success: Arc<RwLock<Instant>>,
89 is_connected: Arc<RwLock<bool>>,
90 consecutive_failures: Arc<RwLock<u32>>,
92}
93
94impl ResilientBackend {
95 pub fn new(inner: Arc<dyn MessageBackend>) -> Self {
97 Self {
98 inner,
99 strategy: ReconnectStrategy::default(),
100 reconnect_attempts: Arc::new(RwLock::new(0)),
101 last_success: Arc::new(RwLock::new(Instant::now())),
102 is_connected: Arc::new(RwLock::new(true)),
103 consecutive_failures: Arc::new(RwLock::new(0)),
104 }
105 }
106
107 pub fn with_strategy(inner: Arc<dyn MessageBackend>, strategy: ReconnectStrategy) -> Self {
109 Self {
110 inner,
111 strategy,
112 reconnect_attempts: Arc::new(RwLock::new(0)),
113 last_success: Arc::new(RwLock::new(Instant::now())),
114 is_connected: Arc::new(RwLock::new(true)),
115 consecutive_failures: Arc::new(RwLock::new(0)),
116 }
117 }
118
119 pub fn inner(&self) -> &Arc<dyn MessageBackend> {
121 &self.inner
122 }
123
124 pub async fn is_connected(&self) -> bool {
126 *self.is_connected.read().await
127 }
128
129 pub async fn reconnect_attempts(&self) -> u32 {
131 *self.reconnect_attempts.read().await
132 }
133
134 pub async fn consecutive_failures(&self) -> u32 {
136 *self.consecutive_failures.read().await
137 }
138
139 async fn execute_with_retry<T, F, Fut>(&self, operation_name: &str, op: F) -> WorkerResult<T>
144 where
145 F: Fn() -> Fut,
146 Fut: std::future::Future<Output = WorkerResult<T>>,
147 {
148 let mut attempt = 0;
149
150 loop {
151 match op().await {
152 Ok(result) => {
153 if attempt > 0 {
155 info!("{} succeeded after {} attempts", operation_name, attempt);
156 }
157 *self.reconnect_attempts.write().await = 0;
158 *self.consecutive_failures.write().await = 0;
159 *self.last_success.write().await = Instant::now();
160 *self.is_connected.write().await = true;
161 return Ok(result);
162 }
163 Err(e) => {
164 attempt += 1;
165 *self.reconnect_attempts.write().await = attempt;
166 let failures = {
167 let mut f = self.consecutive_failures.write().await;
168 *f += 1;
169 *f
170 };
171 *self.is_connected.write().await = false;
172
173 warn!(
174 "{} failed (attempt {}, consecutive failures: {}): {}. Retrying...",
175 operation_name, attempt, failures, e
176 );
177
178 if let Err(recover_err) = self.try_recover().await {
180 error!("Recovery attempt failed: {}", recover_err);
181 }
182
183 let delay = self.strategy.delay_for_attempt(attempt - 1);
185
186 if attempt % 10 == 0 || attempt <= 3 {
188 warn!(
189 "Still trying {} (attempt {}) - next retry in {:?}",
190 operation_name, attempt, delay
191 );
192 }
193
194 tokio::time::sleep(delay).await;
195
196 }
199 }
200 }
201 }
202
203 async fn try_recover(&self) -> WorkerResult<()> {
205 match self.inner.health_check().await {
207 Ok(_) => {
208 info!("Connection recovered");
209 *self.consecutive_failures.write().await = 0;
210 Ok(())
211 }
212 Err(e) => {
213 warn!("Health check failed during recovery: {}", e);
214 Err(e)
217 }
218 }
219 }
220}
221
222#[async_trait::async_trait]
223impl MessageBackend for ResilientBackend {
224 async fn receive(&self) -> WorkerResult<ReceiveResult<serde_json::Value>> {
225 self.execute_with_retry("receive", || async { self.inner.receive().await })
226 .await
227 }
228
229 async fn ack(&self, message_id: &str) -> WorkerResult<()> {
230 self.inner.ack(message_id).await
233 }
234
235 async fn nack(&self, message_id: &str, requeue: bool) -> WorkerResult<()> {
236 self.execute_with_retry("nack", || async {
238 self.inner.nack(message_id, requeue).await
239 })
240 .await
241 }
242
243 async fn health_check(&self) -> WorkerResult<()> {
244 self.inner.health_check().await
245 }
246
247 async fn shutdown(&self) -> WorkerResult<()> {
248 self.inner.shutdown().await
249 }
250}
251
252pub struct ResilientBackendBuilder {
254 inner: Arc<dyn MessageBackend>,
255 strategy: ReconnectStrategy,
256}
257
258impl ResilientBackendBuilder {
259 pub fn new(inner: Arc<dyn MessageBackend>) -> Self {
261 Self {
262 inner,
263 strategy: ReconnectStrategy::default(),
264 }
265 }
266
267 pub fn with_strategy(mut self, strategy: ReconnectStrategy) -> Self {
269 self.strategy = strategy;
270 self
271 }
272
273 pub fn build(self) -> ResilientBackend {
275 ResilientBackend::with_strategy(self.inner, self.strategy)
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::{MemoryBackend, ReceiveResult};
    use crate::error::WorkerError;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Wraps `backend` with a 1 ms fixed retry delay so that failure-path tests do not
    /// sleep for whole seconds under the default exponential backoff.
    fn fast_resilient(backend: Arc<dyn MessageBackend>) -> ResilientBackend {
        ResilientBackend::with_strategy(
            backend,
            ReconnectStrategy::Fixed(Duration::from_millis(1)),
        )
    }

    /// Test backend whose `receive` fails for the first `succeed_after` calls and then
    /// reports shutdown; `health_check` fails until that many `receive` calls were made.
    struct FailingBackend {
        fail_count: Arc<AtomicUsize>,
        total_calls: Arc<AtomicUsize>,
        succeed_after: usize,
    }

    impl FailingBackend {
        /// Returns the backend together with handles to its failure and call counters.
        fn new(succeed_after: usize) -> (Arc<Self>, Arc<AtomicUsize>, Arc<AtomicUsize>) {
            let fail_count = Arc::new(AtomicUsize::new(0));
            let total_calls = Arc::new(AtomicUsize::new(0));
            (
                Arc::new(Self {
                    fail_count: fail_count.clone(),
                    total_calls: total_calls.clone(),
                    succeed_after,
                }),
                fail_count,
                total_calls,
            )
        }
    }

    #[async_trait::async_trait]
    impl MessageBackend for FailingBackend {
        async fn receive(&self) -> WorkerResult<ReceiveResult<serde_json::Value>> {
            let calls = self.total_calls.fetch_add(1, Ordering::SeqCst);
            if calls < self.succeed_after {
                self.fail_count.fetch_add(1, Ordering::SeqCst);
                Err(WorkerError::BackendError(
                    "Simulated network failure".to_string(),
                ))
            } else {
                Ok(ReceiveResult::Shutdown)
            }
        }

        async fn ack(&self, _message_id: &str) -> WorkerResult<()> {
            Ok(())
        }

        async fn nack(&self, _message_id: &str, _requeue: bool) -> WorkerResult<()> {
            Ok(())
        }

        async fn health_check(&self) -> WorkerResult<()> {
            let calls = self.total_calls.load(Ordering::SeqCst);
            if calls < self.succeed_after {
                Err(WorkerError::BackendError("Health check failed".to_string()))
            } else {
                Ok(())
            }
        }

        async fn shutdown(&self) -> WorkerResult<()> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_resilient_backend_wraps_successfully() {
        let inner = Arc::new(MemoryBackend::new());
        let resilient = ResilientBackend::new(inner);

        assert!(resilient.is_connected().await);
        assert_eq!(resilient.reconnect_attempts().await, 0);
        assert_eq!(resilient.consecutive_failures().await, 0);
    }

    #[tokio::test]
    async fn test_resilient_backend_receive() {
        let backend_arc = Arc::new(MemoryBackend::new());
        let resilient = ResilientBackend::new(backend_arc.clone());

        backend_arc.enqueue(serde_json::json!({"test": "data"}));

        let result = resilient.receive().await.unwrap();
        assert!(result.is_message());
        if let ReceiveResult::Message(msg) = result {
            assert_eq!(msg.message.payload["test"], "data");
        } else {
            panic!("Expected Message variant");
        }
    }

    #[tokio::test]
    async fn test_resilient_backend_with_custom_strategy() {
        let inner = Arc::new(MemoryBackend::new());
        let strategy = ReconnectStrategy::Fixed(Duration::from_secs(1));
        let resilient = ResilientBackend::with_strategy(inner, strategy);

        assert!(resilient.is_connected().await);
    }

    // The strategy tests below never await, so they run as plain synchronous tests.

    #[test]
    fn test_exponential_backoff_calculation() {
        let strategy = ReconnectStrategy::Exponential {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(1),
            multiplier: 2.0,
            jitter_factor: 0.0,
        };

        assert_eq!(strategy.delay_for_attempt(0).as_millis(), 100);
        assert_eq!(strategy.delay_for_attempt(1).as_millis(), 200);
        assert_eq!(strategy.delay_for_attempt(2).as_millis(), 400);
        assert_eq!(strategy.delay_for_attempt(3).as_millis(), 800);
        // From here on the delay is capped at `max`.
        assert_eq!(strategy.delay_for_attempt(4).as_millis(), 1000);
        assert_eq!(strategy.delay_for_attempt(5).as_millis(), 1000);
    }

    #[test]
    fn test_exponential_backoff_with_jitter() {
        let strategy = ReconnectStrategy::Exponential {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(1),
            multiplier: 2.0,
            jitter_factor: 0.5,
        };

        let delay = strategy.delay_for_attempt(0);

        // Jitter only ever adds, by at most `jitter_factor` times the base delay.
        let base_ms: u128 = 100;
        assert!(delay.as_millis() >= base_ms);
        assert!(delay.as_millis() <= 150);
    }

    #[test]
    fn test_fixed_delay_strategy() {
        let strategy = ReconnectStrategy::Fixed(Duration::from_secs(2));

        assert_eq!(strategy.delay_for_attempt(0).as_secs(), 2);
        assert_eq!(strategy.delay_for_attempt(5).as_secs(), 2);
        assert_eq!(strategy.delay_for_attempt(100).as_secs(), 2);
    }

    #[tokio::test]
    async fn test_reconnection_on_failure() {
        let (backend, fail_count, total_calls) = FailingBackend::new(2);
        let resilient = fast_resilient(backend);

        let result = resilient.receive().await;

        assert!(result.is_ok());
        if let Ok(receive_result) = result {
            assert!(receive_result.is_shutdown());
        }
        assert_eq!(fail_count.load(Ordering::SeqCst), 2);
        assert_eq!(total_calls.load(Ordering::SeqCst), 3);
        assert_eq!(resilient.reconnect_attempts().await, 0);
        assert_eq!(resilient.consecutive_failures().await, 0);
        assert!(resilient.is_connected().await);
    }

    #[tokio::test]
    async fn test_connection_state_tracking() {
        let (backend, _, _) = FailingBackend::new(1);
        let resilient = fast_resilient(backend);

        assert!(resilient.is_connected().await);
        assert_eq!(resilient.reconnect_attempts().await, 0);

        let _ = resilient.receive().await;

        // State is restored once the retried receive finally succeeds.
        assert!(resilient.is_connected().await);
        assert_eq!(resilient.reconnect_attempts().await, 0);
    }

    #[tokio::test]
    async fn test_consecutive_failure_tracking() {
        let (backend, _, _) = FailingBackend::new(3);
        let resilient = fast_resilient(backend);

        let _ = resilient.receive().await;

        assert_eq!(resilient.consecutive_failures().await, 0);
    }

    #[tokio::test]
    async fn test_ack_operations_dont_retry_indefinitely() {
        let inner = Arc::new(MemoryBackend::new());
        let resilient = ResilientBackend::new(inner);

        let result = resilient.ack("non-existent-id").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_health_check_passthrough() {
        let inner = Arc::new(MemoryBackend::new());
        let resilient = ResilientBackend::new(inner);

        let result = resilient.health_check().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_shutdown_passthrough() {
        let inner = Arc::new(MemoryBackend::new());
        let resilient = ResilientBackend::new(inner);

        let result = resilient.shutdown().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_builder_pattern() {
        let inner = Arc::new(MemoryBackend::new());
        let strategy = ReconnectStrategy::Exponential {
            initial: Duration::from_millis(500),
            max: Duration::from_secs(30),
            multiplier: 2.5,
            jitter_factor: 0.2,
        };

        let resilient = ResilientBackendBuilder::new(inner)
            .with_strategy(strategy)
            .build();

        assert!(resilient.is_connected().await);
    }

    #[tokio::test]
    async fn test_multiple_receive_operations() {
        let backend_arc = Arc::new(MemoryBackend::new());
        let resilient = ResilientBackend::new(backend_arc.clone());

        backend_arc.enqueue(serde_json::json!({"msg": 1}));
        backend_arc.enqueue(serde_json::json!({"msg": 2}));
        backend_arc.enqueue(serde_json::json!({"msg": 3}));

        for expected in 1..=3 {
            let result = resilient.receive().await.unwrap();
            if let ReceiveResult::Message(msg) = result {
                assert_eq!(msg.message.payload["msg"], expected);
            } else {
                panic!("Expected Message variant, got {:?}", result);
            }
        }

        assert_eq!(resilient.reconnect_attempts().await, 0);
    }

    #[test]
    fn test_default_reconnect_strategy() {
        let strategy = ReconnectStrategy::default();

        match strategy {
            ReconnectStrategy::Exponential {
                initial,
                max,
                multiplier,
                jitter_factor,
            } => {
                assert_eq!(initial, Duration::from_secs(1));
                assert_eq!(max, Duration::from_secs(60));
                assert_eq!(multiplier, 2.0);
                assert_eq!(jitter_factor, 0.1);
            }
            ReconnectStrategy::Fixed(_) => panic!("Default should be Exponential"),
        }
    }
}