use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{error, info, warn};
use crate::backends::{MessageBackend, ReceiveResult};
use crate::error::WorkerResult;
/// Policy that decides how long to pause between retries of a failed
/// backend operation.
#[derive(Debug, Clone)]
pub enum ReconnectStrategy {
    /// Pause for the same duration before every retry.
    Fixed(Duration),
    /// Grow the pause geometrically with the attempt index, capped at `max`.
    Exponential {
        /// Pause used for the first retry (attempt index 0).
        initial: Duration,
        /// Upper bound applied to the grown pause before jitter is added.
        max: Duration,
        /// Growth factor; it is raised to the power of the attempt index.
        multiplier: f64,
        /// Fraction of the capped pause that may be added on top as random
        /// jitter. Values that are not greater than zero disable jitter.
        jitter_factor: f64,
    },
}
impl ReconnectStrategy {
    /// Returns the pause to observe before the next retry.
    ///
    /// `attempt` is the zero-based retry index: `0` yields the initial delay.
    ///
    /// * `Fixed` always returns the configured duration.
    /// * `Exponential` returns `initial * multiplier^attempt`, capped at
    ///   `max`, plus a random jitter of up to `jitter_factor` times the capped
    ///   value when `jitter_factor` is greater than zero. The result may
    ///   therefore exceed `max` by the jitter amount.
    ///
    /// This function never panics: if the grown delay is negative, NaN or too
    /// large for a `Duration` (which happens after a few dozen attempts with a
    /// multiplier of 2), `max` is used instead.
    fn delay_for_attempt(&self, attempt: u32) -> Duration {
        match self {
            ReconnectStrategy::Fixed(d) => *d,
            ReconnectStrategy::Exponential {
                initial,
                max,
                multiplier,
                jitter_factor,
            } => {
                // `powi` takes an `i32`; saturate instead of wrapping to a
                // negative exponent for very large attempt counts.
                let exponent = attempt.min(i32::MAX as u32) as i32;
                let clamped = Self::checked_scale(*initial, multiplier.powi(exponent))
                    .map_or(*max, |grown| grown.min(*max));

                if *jitter_factor > 0.0 {
                    // An unrepresentable jitter is dropped rather than
                    // allowed to panic.
                    let jitter =
                        Self::checked_scale(clamped, *jitter_factor * rand::random::<f64>())
                            .unwrap_or_default();
                    clamped.saturating_add(jitter)
                } else {
                    clamped
                }
            }
        }
    }

    /// Multiplies `base` by `factor`, returning `None` when the product is
    /// negative, NaN or does not fit in a `Duration`.
    ///
    /// `Duration::mul_f64` panics in those cases, so the check is done on the
    /// floating-point seconds before a `Duration` is constructed.
    fn checked_scale(base: Duration, factor: f64) -> Option<Duration> {
        if base.is_zero() {
            // Zero stays zero even when `factor` has overflowed to infinity
            // (0 * inf would otherwise be NaN).
            return Some(Duration::ZERO);
        }
        let secs = base.as_secs_f64() * factor;
        // `Duration::MAX` as f64 rounds to 2^64, the first unrepresentable
        // value, so a strict `<` keeps `from_secs_f64` within range. The
        // comparison is false for NaN.
        if secs >= 0.0 && secs < Duration::MAX.as_secs_f64() {
            Some(Duration::from_secs_f64(secs))
        } else {
            None
        }
    }
}
impl Default for ReconnectStrategy {
fn default() -> Self {
ReconnectStrategy::Exponential {
initial: Duration::from_secs(1),
max: Duration::from_secs(60),
multiplier: 2.0,
jitter_factor: 0.1, }
}
}
/// Wrapper around a [`MessageBackend`] that retries failed operations and
/// tracks connection-health counters.
pub struct ResilientBackend {
    /// The wrapped backend that performs the actual work.
    inner: Arc<dyn MessageBackend>,
    /// Policy used to compute the pause between retries.
    strategy: ReconnectStrategy,
    /// Attempt number of the retry loop that most recently failed; reset to 0
    /// when a retried operation succeeds.
    reconnect_attempts: Arc<RwLock<u32>>,
    /// Time of the most recent successful retried operation (construction
    /// time until the first success).
    last_success: Arc<RwLock<Instant>>,
    /// `false` after a retried operation fails, `true` again once one
    /// succeeds. Starts as `true`.
    is_connected: Arc<RwLock<bool>>,
    /// Incremented on every failed attempt; reset to 0 on success or when a
    /// recovery health check passes.
    consecutive_failures: Arc<RwLock<u32>>,
}
impl ResilientBackend {
    /// Wraps `inner` using the default [`ReconnectStrategy`].
    pub fn new(inner: Arc<dyn MessageBackend>) -> Self {
        // Delegate so the initial state is defined in exactly one place.
        Self::with_strategy(inner, ReconnectStrategy::default())
    }

    /// Wraps `inner`, pausing between retries according to `strategy`.
    ///
    /// The wrapper starts in the "connected" state with all counters at zero.
    pub fn with_strategy(inner: Arc<dyn MessageBackend>, strategy: ReconnectStrategy) -> Self {
        Self {
            inner,
            strategy,
            reconnect_attempts: Arc::new(RwLock::new(0)),
            last_success: Arc::new(RwLock::new(Instant::now())),
            is_connected: Arc::new(RwLock::new(true)),
            consecutive_failures: Arc::new(RwLock::new(0)),
        }
    }

    /// Returns the wrapped backend.
    pub fn inner(&self) -> &Arc<dyn MessageBackend> {
        &self.inner
    }

    /// Returns `false` if the most recent retried attempt failed, and `true`
    /// otherwise.
    pub async fn is_connected(&self) -> bool {
        *self.is_connected.read().await
    }

    /// Returns the attempt number of the retry loop currently failing, or 0
    /// once a retried operation has succeeded.
    pub async fn reconnect_attempts(&self) -> u32 {
        *self.reconnect_attempts.read().await
    }

    /// Returns the current value of the consecutive-failure counter.
    pub async fn consecutive_failures(&self) -> u32 {
        *self.consecutive_failures.read().await
    }

    /// Runs `op` until it succeeds, pausing between attempts according to the
    /// configured strategy.
    ///
    /// There is no attempt limit: the loop only exits on `Ok`, so the returned
    /// `WorkerResult` is never `Err`. After every failure the health counters
    /// are updated, a recovery health check is attempted, and the task sleeps
    /// for the computed delay. `operation_name` is used only in log messages.
    async fn execute_with_retry<T, F, Fut>(&self, operation_name: &str, op: F) -> WorkerResult<T>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = WorkerResult<T>>,
    {
        // Number of failed attempts so far in this call.
        let mut attempt: u32 = 0;
        loop {
            match op().await {
                Ok(result) => {
                    if attempt > 0 {
                        // `attempt` counts failures; the call that just
                        // succeeded is one more attempt on top of those.
                        info!(
                            "{} succeeded after {} attempts",
                            operation_name,
                            attempt.saturating_add(1)
                        );
                    }
                    *self.reconnect_attempts.write().await = 0;
                    *self.consecutive_failures.write().await = 0;
                    *self.last_success.write().await = Instant::now();
                    *self.is_connected.write().await = true;
                    return Ok(result);
                }
                Err(e) => {
                    // The loop is unbounded, so saturate instead of
                    // overflowing the counters.
                    attempt = attempt.saturating_add(1);
                    *self.reconnect_attempts.write().await = attempt;
                    let failures = {
                        let mut f = self.consecutive_failures.write().await;
                        *f = f.saturating_add(1);
                        *f
                    };
                    *self.is_connected.write().await = false;
                    warn!(
                        "{} failed (attempt {}, consecutive failures: {}): {}. Retrying...",
                        operation_name, attempt, failures, e
                    );
                    // Best effort: a failed health check is logged and the
                    // retry proceeds regardless.
                    if let Err(recover_err) = self.try_recover().await {
                        error!("Recovery attempt failed: {}", recover_err);
                    }
                    // The strategy expects a zero-based index; `attempt` is
                    // at least 1 here.
                    let delay = self.strategy.delay_for_attempt(attempt - 1);
                    // Log the first three attempts, then every tenth, to
                    // avoid flooding the log during long outages.
                    if attempt % 10 == 0 || attempt <= 3 {
                        warn!(
                            "Still trying {} (attempt {}) - next retry in {:?}",
                            operation_name, attempt, delay
                        );
                    }
                    tokio::time::sleep(delay).await;
                }
            }
        }
    }

    /// Probes the wrapped backend with a health check.
    ///
    /// On success the consecutive-failure counter is reset; on failure the
    /// health-check error is returned to the caller.
    async fn try_recover(&self) -> WorkerResult<()> {
        match self.inner.health_check().await {
            Ok(_) => {
                info!("Connection recovered");
                // NOTE(review): this resets the counter even though the
                // operation that triggered recovery has not succeeded yet, so
                // the counter cannot accumulate while health checks pass.
                // `is_connected` is also left `false` here. Confirm both are
                // intended.
                *self.consecutive_failures.write().await = 0;
                Ok(())
            }
            Err(e) => {
                warn!("Health check failed during recovery: {}", e);
                Err(e)
            }
        }
    }
}
#[async_trait::async_trait]
impl MessageBackend for ResilientBackend {
    /// Receives from the wrapped backend, retrying until a call succeeds.
    async fn receive(&self) -> WorkerResult<ReceiveResult<serde_json::Value>> {
        let backend = &self.inner;
        let attempt_receive = || async move { backend.receive().await };
        self.execute_with_retry("receive", attempt_receive).await
    }

    /// Acknowledges `message_id` on the wrapped backend with a single call;
    /// errors are returned to the caller without retrying.
    async fn ack(&self, message_id: &str) -> WorkerResult<()> {
        let backend = &self.inner;
        backend.ack(message_id).await
    }

    /// Negatively acknowledges `message_id`, retrying until a call succeeds.
    async fn nack(&self, message_id: &str, requeue: bool) -> WorkerResult<()> {
        let backend = &self.inner;
        let attempt_nack = || async move { backend.nack(message_id, requeue).await };
        self.execute_with_retry("nack", attempt_nack).await
    }

    /// Forwards the health check to the wrapped backend without retrying.
    async fn health_check(&self) -> WorkerResult<()> {
        let backend = &self.inner;
        backend.health_check().await
    }

    /// Forwards the shutdown request to the wrapped backend without retrying.
    async fn shutdown(&self) -> WorkerResult<()> {
        let backend = &self.inner;
        backend.shutdown().await
    }
}
/// Builder that assembles a [`ResilientBackend`] from a backend and a
/// [`ReconnectStrategy`].
pub struct ResilientBackendBuilder {
    /// Backend that the built `ResilientBackend` will wrap.
    inner: Arc<dyn MessageBackend>,
    /// Retry policy handed to the built `ResilientBackend`.
    strategy: ReconnectStrategy,
}
impl ResilientBackendBuilder {
    /// Starts a builder for `inner` with the default reconnect strategy.
    pub fn new(inner: Arc<dyn MessageBackend>) -> Self {
        let strategy = ReconnectStrategy::default();
        Self { inner, strategy }
    }

    /// Replaces the reconnect strategy and returns the builder for chaining.
    pub fn with_strategy(self, strategy: ReconnectStrategy) -> Self {
        Self { strategy, ..self }
    }

    /// Consumes the builder and produces the configured [`ResilientBackend`].
    pub fn build(self) -> ResilientBackend {
        let Self { inner, strategy } = self;
        ResilientBackend::with_strategy(inner, strategy)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::{MemoryBackend, ReceiveResult};
    use crate::error::WorkerError;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Retry policy with a negligible pause, so tests that exercise the
    /// failure path do not sleep for seconds of wall-clock time.
    fn fast_strategy() -> ReconnectStrategy {
        ReconnectStrategy::Fixed(Duration::from_millis(1))
    }

    /// Backend whose `receive` fails for the first `succeed_after` calls and
    /// then reports shutdown. `health_check` fails while fewer than
    /// `succeed_after` receive calls have been made.
    struct FailingBackend {
        fail_count: Arc<AtomicUsize>,
        total_calls: Arc<AtomicUsize>,
        succeed_after: usize,
    }

    impl FailingBackend {
        /// Returns the backend together with shared handles to its failure
        /// counter and total-call counter (in that order).
        fn new(succeed_after: usize) -> (Arc<Self>, Arc<AtomicUsize>, Arc<AtomicUsize>) {
            let fail_count = Arc::new(AtomicUsize::new(0));
            let total_calls = Arc::new(AtomicUsize::new(0));
            (
                Arc::new(Self {
                    fail_count: fail_count.clone(),
                    total_calls: total_calls.clone(),
                    succeed_after,
                }),
                fail_count,
                total_calls,
            )
        }
    }

    #[async_trait::async_trait]
    impl MessageBackend for FailingBackend {
        async fn receive(&self) -> WorkerResult<ReceiveResult<serde_json::Value>> {
            let calls = self.total_calls.fetch_add(1, Ordering::SeqCst);
            if calls < self.succeed_after {
                self.fail_count.fetch_add(1, Ordering::SeqCst);
                Err(WorkerError::BackendError(
                    "Simulated network failure".to_string(),
                ))
            } else {
                Ok(ReceiveResult::Shutdown)
            }
        }

        async fn ack(&self, _message_id: &str) -> WorkerResult<()> {
            Ok(())
        }

        async fn nack(&self, _message_id: &str, _requeue: bool) -> WorkerResult<()> {
            Ok(())
        }

        async fn health_check(&self) -> WorkerResult<()> {
            let calls = self.total_calls.load(Ordering::SeqCst);
            if calls < self.succeed_after {
                Err(WorkerError::BackendError("Health check failed".to_string()))
            } else {
                Ok(())
            }
        }

        async fn shutdown(&self) -> WorkerResult<()> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_resilient_backend_wraps_successfully() {
        let inner = Arc::new(MemoryBackend::new());
        let resilient = ResilientBackend::new(inner.clone());
        assert!(resilient.is_connected().await);
        assert_eq!(resilient.reconnect_attempts().await, 0);
        assert_eq!(resilient.consecutive_failures().await, 0);
    }

    #[tokio::test]
    async fn test_resilient_backend_receive() {
        let inner = MemoryBackend::new();
        let backend_arc = Arc::new(inner);
        let resilient = ResilientBackend::new(backend_arc.clone());
        backend_arc.enqueue(serde_json::json!({"test": "data"}));
        let result = resilient.receive().await.unwrap();
        assert!(result.is_message());
        if let ReceiveResult::Message(msg) = result {
            assert_eq!(msg.message.payload["test"], "data");
        } else {
            panic!("Expected Message variant");
        }
    }

    #[tokio::test]
    async fn test_resilient_backend_with_custom_strategy() {
        let inner = Arc::new(MemoryBackend::new());
        let strategy = ReconnectStrategy::Fixed(Duration::from_secs(1));
        let resilient = ResilientBackend::with_strategy(inner, strategy);
        assert!(resilient.is_connected().await);
    }

    // Pure computation: no async runtime needed.
    #[test]
    fn test_exponential_backoff_calculation() {
        let strategy = ReconnectStrategy::Exponential {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(1),
            multiplier: 2.0,
            jitter_factor: 0.0,
        };
        assert_eq!(strategy.delay_for_attempt(0).as_millis(), 100);
        assert_eq!(strategy.delay_for_attempt(1).as_millis(), 200);
        assert_eq!(strategy.delay_for_attempt(2).as_millis(), 400);
        assert_eq!(strategy.delay_for_attempt(3).as_millis(), 800);
        // 1600 ms exceeds `max`, so the delay is capped from here on.
        assert_eq!(strategy.delay_for_attempt(4).as_millis(), 1000);
        assert_eq!(strategy.delay_for_attempt(5).as_millis(), 1000);
    }

    #[test]
    fn test_exponential_backoff_with_jitter() {
        let strategy = ReconnectStrategy::Exponential {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(1),
            multiplier: 2.0,
            jitter_factor: 0.5,
        };
        let base = Duration::from_millis(100);
        let upper = Duration::from_millis(150);
        // Jitter is random, so check the bounds over many samples rather
        // than a single draw.
        for _ in 0..100 {
            let delay = strategy.delay_for_attempt(0);
            assert!(delay >= base, "delay {:?} below base", delay);
            assert!(delay <= upper, "delay {:?} above base + 50%", delay);
        }
    }

    #[test]
    fn test_fixed_delay_strategy() {
        let strategy = ReconnectStrategy::Fixed(Duration::from_secs(2));
        assert_eq!(strategy.delay_for_attempt(0).as_secs(), 2);
        assert_eq!(strategy.delay_for_attempt(5).as_secs(), 2);
        assert_eq!(strategy.delay_for_attempt(100).as_secs(), 2);
    }

    #[tokio::test]
    async fn test_reconnection_on_failure() {
        let (backend, fail_count, total_calls) = FailingBackend::new(2);
        let resilient = ResilientBackend::with_strategy(backend, fast_strategy());
        let result = resilient.receive().await;
        assert!(result.is_ok());
        if let Ok(receive_result) = result {
            assert!(receive_result.is_shutdown());
        }
        // Two failures followed by one successful call.
        assert_eq!(fail_count.load(Ordering::SeqCst), 2);
        assert_eq!(total_calls.load(Ordering::SeqCst), 3);
        // Success resets all health counters.
        assert_eq!(resilient.reconnect_attempts().await, 0);
        assert_eq!(resilient.consecutive_failures().await, 0);
        assert!(resilient.is_connected().await);
    }

    #[tokio::test]
    async fn test_connection_state_tracking() {
        let (backend, _, _) = FailingBackend::new(1);
        let resilient = ResilientBackend::with_strategy(backend, fast_strategy());
        assert!(resilient.is_connected().await);
        assert_eq!(resilient.reconnect_attempts().await, 0);
        assert!(resilient.receive().await.is_ok());
        // After the eventual success the wrapper reports a healthy state.
        assert!(resilient.is_connected().await);
        assert_eq!(resilient.reconnect_attempts().await, 0);
    }

    #[tokio::test]
    async fn test_consecutive_failure_tracking() {
        let (backend, _, _) = FailingBackend::new(3);
        let resilient = ResilientBackend::with_strategy(backend, fast_strategy());
        assert!(resilient.receive().await.is_ok());
        assert_eq!(resilient.consecutive_failures().await, 0);
    }

    #[tokio::test]
    async fn test_ack_operations_dont_retry_indefinitely() {
        let inner = Arc::new(MemoryBackend::new());
        let resilient = ResilientBackend::new(inner.clone());
        let result = resilient.ack("non-existent-id").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_health_check_passthrough() {
        let inner = Arc::new(MemoryBackend::new());
        let resilient = ResilientBackend::new(inner.clone());
        let result = resilient.health_check().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_shutdown_passthrough() {
        let inner = Arc::new(MemoryBackend::new());
        let resilient = ResilientBackend::new(inner.clone());
        let result = resilient.shutdown().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_builder_pattern() {
        let inner = Arc::new(MemoryBackend::new());
        let strategy = ReconnectStrategy::Exponential {
            initial: Duration::from_millis(500),
            max: Duration::from_secs(30),
            multiplier: 2.5,
            jitter_factor: 0.2,
        };
        let resilient = ResilientBackendBuilder::new(inner)
            .with_strategy(strategy)
            .build();
        assert!(resilient.is_connected().await);
    }

    #[tokio::test]
    async fn test_multiple_receive_operations() {
        let inner = MemoryBackend::new();
        let backend_arc = Arc::new(inner);
        let resilient = ResilientBackend::new(backend_arc.clone());
        backend_arc.enqueue(serde_json::json!({"msg": 1}));
        backend_arc.enqueue(serde_json::json!({"msg": 2}));
        backend_arc.enqueue(serde_json::json!({"msg": 3}));
        for expected in 1..=3 {
            let result = resilient.receive().await.unwrap();
            if let ReceiveResult::Message(msg) = result {
                assert_eq!(msg.message.payload["msg"], expected);
            } else {
                panic!("Expected Message variant, got {:?}", result);
            }
        }
        assert_eq!(resilient.reconnect_attempts().await, 0);
    }

    #[test]
    fn test_default_reconnect_strategy() {
        let strategy = ReconnectStrategy::default();
        match strategy {
            ReconnectStrategy::Exponential {
                initial,
                max,
                multiplier,
                jitter_factor,
            } => {
                assert_eq!(initial, Duration::from_secs(1));
                assert_eq!(max, Duration::from_secs(60));
                assert_eq!(multiplier, 2.0);
                assert_eq!(jitter_factor, 0.1);
            }
            // Listed explicitly so a newly added variant forces this test
            // to be revisited.
            ReconnectStrategy::Fixed(_) => panic!("Default should be Exponential"),
        }
    }
}