use async_trait::async_trait;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, warn, error, info};
use crate::error::{WorkerError, WorkerResult};
use crate::message::ReceivedMessage;
use crate::middleware::{Middleware, MessageHandler};
use crate::backends::{DeadLetterQueueBackend, create_dlq_message};
use crate::dlq::PoisonPillTracker;
/// Tuning knobs for [`RetryHandler`].
#[derive(Clone)]
pub struct RetryHandlerConfig {
    /// Number of delivery attempts (including the first) after which a failing message is
    /// treated as exhausted instead of being scheduled for another retry.
    pub max_retries: u32,
    /// Delay used for the first retry.
    pub initial_backoff: Duration,
    /// Upper bound applied to the exponential retry delay.
    pub max_backoff: Duration,
    /// Factor the delay is multiplied by for each additional attempt.
    pub backoff_multiplier: f64,
    /// Destination for messages whose retries are exhausted; `None` disables dead-lettering.
    pub dead_letter_queue: Option<Arc<DeadLetterQueueBackend>>,
    /// Optional tracker asked to record each dead-lettered failure; when it reports a
    /// poison pill the DLQ entry is tagged accordingly.
    pub poison_pill_tracker: Option<Arc<PoisonPillTracker>>,
    /// Whether each computed delay is randomized by up to ±25 %.
    pub use_jitter: bool,
}
impl std::fmt::Debug for RetryHandlerConfig {
    /// Formats every configuration field.
    ///
    /// The backend handles (`dead_letter_queue`, `poison_pill_tracker`) are rendered only as
    /// `Some("<...>")` / `None` placeholders. The output therefore shows whether each one is
    /// configured, without requiring the backend types to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RetryHandlerConfig")
            .field("max_retries", &self.max_retries)
            .field("initial_backoff", &self.initial_backoff)
            .field("max_backoff", &self.max_backoff)
            .field("backoff_multiplier", &self.backoff_multiplier)
            .field(
                "dead_letter_queue",
                &self.dead_letter_queue.as_ref().map(|_| "<MessageBackend>"),
            )
            .field(
                "poison_pill_tracker",
                &self
                    .poison_pill_tracker
                    .as_ref()
                    .map(|_| "<PoisonPillTracker>"),
            )
            .field("use_jitter", &self.use_jitter)
            .finish()
    }
}
impl Default for RetryHandlerConfig {
fn default() -> Self {
Self {
max_retries: 5,
initial_backoff: Duration::from_secs(1),
max_backoff: Duration::from_secs(60),
backoff_multiplier: 2.0,
dead_letter_queue: None,
poison_pill_tracker: None,
use_jitter: true,
}
}
}
impl RetryHandlerConfig {
pub fn with_max_retries(mut self, max_retries: u32) -> Self {
self.max_retries = max_retries;
self
}
pub fn with_dead_letter_queue(mut self, dlq: Arc<DeadLetterQueueBackend>) -> Self {
self.dead_letter_queue = Some(dlq);
self
}
pub fn with_poison_pill_tracker(mut self, tracker: Arc<PoisonPillTracker>) -> Self {
self.poison_pill_tracker = Some(tracker);
self
}
pub fn with_jitter(mut self, use_jitter: bool) -> Self {
self.use_jitter = use_jitter;
self
}
}
pub struct RetryHandler {
config: RetryHandlerConfig,
}
impl RetryHandler {
pub fn new(config: RetryHandlerConfig) -> Self {
Self { config }
}
fn calculate_backoff(&self, attempts: u32) -> Duration {
if attempts == 0 {
return self.config.initial_backoff;
}
let current_backoff = self.config.initial_backoff.as_secs_f64()
* self.config.backoff_multiplier.powf(attempts as f64 - 1.0);
let mut backoff = Duration::from_secs_f64(current_backoff);
if self.config.use_jitter {
let jitter_factor = rand::random::<f64>() * 0.5 - 0.25; let jitter = backoff.as_secs_f64() * jitter_factor;
let new_backoff = backoff.as_secs_f64() + jitter;
backoff = Duration::from_secs_f64(new_backoff.max(0.01));
}
std::cmp::min(backoff, self.config.max_backoff)
}
async fn send_to_dlq(&self, message: &ReceivedMessage<serde_json::Value>, error: &WorkerError) {
if let Some(ref dlq) = self.config.dead_letter_queue {
let is_poison_pill = if let Some(ref tracker) = self.config.poison_pill_tracker {
tracker.record_failure(&message.message.id)
} else {
false
};
if is_poison_pill {
error!(
"[{}] POISON PILL DETECTED: Message {} failed {} times - sending to DLQ",
self.name(),
message.message.id,
message.message.metadata.attempt
);
}
let mut dlq_message = create_dlq_message(
message.message.id.clone(),
message.message.payload.clone(),
message.message.metadata.source.clone(),
message.message.metadata.attempt,
error,
None, );
if is_poison_pill {
dlq_message = dlq_message.with_context("poison_pill", serde_json::json!(true));
}
match dlq.send_to_dlq(&dlq_message).await {
Ok(_) => {
info!(
"[{}] Successfully sent message {} to DLQ after {} attempts",
self.name(),
message.message.id,
message.message.metadata.attempt
);
}
Err(e) => {
error!(
"[{}] Failed to send message {} to DLQ: {:?}",
self.name(),
message.message.id,
e
);
}
}
}
}
}
#[async_trait]
impl Middleware for RetryHandler {
    fn name(&self) -> &str {
        "RetryHandler"
    }

    /// Dispatches one delivery of `message` to `next` and classifies the outcome.
    ///
    /// The attempt counter is incremented before dispatch.
    /// - **Success:** returns `Ok(())`.
    /// - **Failure below the limit:** if the incremented attempt is below `max_retries`, the
    ///   error is wrapped in `WorkerError::RetryableFailure` together with the computed
    ///   backoff.
    /// - **Failure at the limit:** otherwise the message is handed to `send_to_dlq` and the
    ///   error is wrapped in `WorkerError::RetriesExhausted`.
    async fn handle(
        &self,
        mut message: ReceivedMessage<serde_json::Value>,
        next: Box<dyn MessageHandler>,
    ) -> WorkerResult<()> {
        message.message.metadata.increment_attempt();
        let attempt = message.message.metadata.attempt;
        let attempt_limit = self.config.max_retries;

        debug!(
            "[{}] Processing message {} (attempt {}/{})",
            self.name(),
            message.message.id,
            attempt,
            attempt_limit
        );

        // `next` takes ownership, so hand it a copy and keep ours for logging / dead-lettering.
        let outcome = next.handle(message.clone()).await;
        let failure = match outcome {
            Ok(()) => {
                debug!("[{}] Message {} processed successfully.", self.name(), message.message.id);
                return Ok(());
            }
            Err(failure) => failure,
        };

        warn!(
            "[{}] Message {} failed on attempt {}: {:?}",
            self.name(),
            message.message.id,
            attempt,
            failure
        );

        if attempt >= attempt_limit {
            self.send_to_dlq(&message, &failure).await;
            warn!(
                "[{}] Retries exhausted for message {} after {} attempts.",
                self.name(),
                message.message.id,
                attempt
            );
            return Err(WorkerError::RetriesExhausted {
                source: Box::new(failure),
            });
        }

        let delay = self.calculate_backoff(attempt);
        debug!(
            "[{}] Message {} will be retried in {:?}. Current attempts: {}",
            self.name(),
            message.message.id,
            delay,
            attempt
        );
        Err(WorkerError::RetryableFailure {
            source: Box::new(failure),
            delay_ms: delay,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata, ReceivedMessage};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    #[derive(Debug)]
    struct MockAckHandle;

    #[async_trait::async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> WorkerResult<()> {
            Ok(())
        }
        async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
            Ok(())
        }
    }

    /// Inner handler that fails its first `fail_until` invocations and succeeds afterwards.
    struct FailingHandler {
        fail_count: AtomicUsize,
        fail_until: usize,
    }

    impl FailingHandler {
        fn new(fail_until: usize) -> Self {
            Self {
                fail_count: AtomicUsize::new(0),
                fail_until,
            }
        }
    }

    #[async_trait::async_trait]
    impl MessageHandler for FailingHandler {
        async fn handle(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            let count = self.fail_count.fetch_add(1, Ordering::SeqCst);
            if count < self.fail_until {
                Err(WorkerError::ProcessingError("Simulated failure".into()))
            } else {
                Ok(())
            }
        }
    }

    /// Builds a minimal message whose attempt counter is explicitly set to `attempt`.
    fn test_message(attempt: u32) -> ReceivedMessage<serde_json::Value> {
        let mut message = ReceivedMessage::new(
            Message {
                id: "test-id".to_string(),
                payload: serde_json::json!({}),
                metadata: MessageMetadata::new("test"),
            },
            Arc::new(MockAckHandle),
        );
        message.message.metadata.attempt = attempt;
        message
    }

    #[tokio::test]
    async fn test_first_failure_is_retryable_with_initial_backoff() {
        let config = RetryHandlerConfig::default()
            .with_max_retries(3)
            .with_jitter(false);
        let handler = RetryHandler::new(config);

        let result = handler
            .handle(test_message(0), Box::new(FailingHandler::new(1)))
            .await;

        match result {
            Err(WorkerError::RetryableFailure { delay_ms, .. }) => {
                assert_eq!(delay_ms, Duration::from_secs(1));
            }
            other => panic!("Expected RetryableFailure, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_success_after_failed_deliveries() {
        let config = RetryHandlerConfig::default()
            .with_max_retries(3)
            .with_jitter(false);
        let handler = RetryHandler::new(config);
        let mut message = test_message(0);

        // The first two deliveries fail and must be scheduled for retry.
        for _ in 0..2 {
            let result = handler
                .handle(message.clone(), Box::new(FailingHandler::new(1)))
                .await;
            assert!(
                matches!(result, Err(WorkerError::RetryableFailure { .. })),
                "Expected RetryableFailure, got: {:?}",
                result
            );
            // Simulate the broker redelivering the message with the bumped attempt count.
            message.message.metadata.attempt += 1;
        }

        // The third delivery succeeds.
        let result = handler
            .handle(message, Box::new(FailingHandler::new(0)))
            .await;
        assert!(result.is_ok(), "Expected success, got: {:?}", result);
    }

    #[tokio::test]
    async fn test_retries_exhausted() {
        let config = RetryHandlerConfig::default().with_max_retries(1);
        let handler = RetryHandler::new(config);

        let result = handler
            .handle(test_message(1), Box::new(FailingHandler::new(10)))
            .await;

        assert!(
            matches!(result, Err(WorkerError::RetriesExhausted { .. })),
            "Expected RetriesExhausted, got: {:?}",
            result
        );
    }

    #[test]
    fn test_backoff_grows_exponentially_and_is_capped() {
        let handler = RetryHandler::new(RetryHandlerConfig::default().with_jitter(false));

        assert_eq!(handler.calculate_backoff(1), Duration::from_secs(1));
        assert_eq!(handler.calculate_backoff(2), Duration::from_secs(2));
        assert_eq!(handler.calculate_backoff(3), Duration::from_secs(4));
        // 1 s * 2^9 = 512 s, capped by the default 60 s max_backoff.
        assert_eq!(handler.calculate_backoff(10), Duration::from_secs(60));
    }

    #[test]
    fn test_jitter_stays_within_25_percent() {
        let handler = RetryHandler::new(RetryHandlerConfig::default());

        for _ in 0..100 {
            // Base delay for attempt 3 with the defaults is 4 s.
            let delay = handler.calculate_backoff(3);
            assert!(
                delay >= Duration::from_secs(3) && delay <= Duration::from_secs(5),
                "jittered delay out of range: {:?}",
                delay
            );
        }
    }
}