Skip to main content

foxtive_worker/middleware/
retry_handler.rs

1use async_trait::async_trait;
2use std::sync::Arc;
3use std::time::Duration;
4use tracing::{debug, warn, error, info};
5
6use crate::error::{WorkerError, WorkerResult};
7use crate::message::ReceivedMessage;
8use crate::middleware::{Middleware, MessageHandler};
9use crate::backends::{DeadLetterQueueBackend, create_dlq_message};
10use crate::dlq::PoisonPillTracker;
11
12/// Configuration for the RetryHandler middleware.
13#[derive(Clone)]
14pub struct RetryHandlerConfig {
15    /// The maximum number of times a message should be retried.
16    pub max_retries: u32,
17    /// The initial backoff duration for the first retry.
18    pub initial_backoff: Duration,
19    /// The maximum backoff duration.
20    pub max_backoff: Duration,
21    /// Multiplier for exponential backoff.
22    pub backoff_multiplier: f64,
23    /// Optional dead letter queue backend for permanently failed messages.
24    pub dead_letter_queue: Option<Arc<DeadLetterQueueBackend>>,
25    /// Optional poison pill tracker for detecting problematic messages.
26    pub poison_pill_tracker: Option<Arc<PoisonPillTracker>>,
27    /// Whether to add jitter to backoff delays (recommended for distributed systems).
28    pub use_jitter: bool,
29}
30
31impl std::fmt::Debug for RetryHandlerConfig {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("RetryHandlerConfig")
34            .field("max_retries", &self.max_retries)
35            .field("initial_backoff", &self.initial_backoff)
36            .field("max_backoff", &self.max_backoff)
37            .field("backoff_multiplier", &self.backoff_multiplier)
38            .field("dead_letter_queue", &self.dead_letter_queue.as_ref().map(|_| "<MessageBackend>"))
39            .field("use_jitter", &self.use_jitter)
40            .finish()
41    }
42}
43
44impl Default for RetryHandlerConfig {
45    fn default() -> Self {
46        Self {
47            max_retries: 5,
48            initial_backoff: Duration::from_secs(1),
49            max_backoff: Duration::from_secs(60),
50            backoff_multiplier: 2.0,
51            dead_letter_queue: None,
52            poison_pill_tracker: None,
53            use_jitter: true,
54        }
55    }
56}
57
58impl RetryHandlerConfig {
59    /// Create a new config with custom max retries.
60    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
61        self.max_retries = max_retries;
62        self
63    }
64
65    /// Set the dead letter queue for permanently failed messages.
66    pub fn with_dead_letter_queue(mut self, dlq: Arc<DeadLetterQueueBackend>) -> Self {
67        self.dead_letter_queue = Some(dlq);
68        self
69    }
70
71    /// Set the poison pill tracker for detecting problematic messages.
72    pub fn with_poison_pill_tracker(mut self, tracker: Arc<PoisonPillTracker>) -> Self {
73        self.poison_pill_tracker = Some(tracker);
74        self
75    }
76
77    /// Enable or disable jitter (enabled by default).
78    pub fn with_jitter(mut self, use_jitter: bool) -> Self {
79        self.use_jitter = use_jitter;
80        self
81    }
82}
83
84/// Middleware that handles automatic retries for failed messages with exponential backoff.
85pub struct RetryHandler {
86    config: RetryHandlerConfig,
87}
88
89impl RetryHandler {
90    /// Creates a new `RetryHandler` with the given configuration.
91    pub fn new(config: RetryHandlerConfig) -> Self {
92        Self { config }
93    }
94
95    /// Calculates the next backoff duration based on the current attempt count.
96    /// Uses exponential backoff with optional jitter.
97    fn calculate_backoff(&self, attempts: u32) -> Duration {
98        if attempts == 0 {
99            return self.config.initial_backoff;
100        }
101
102        let current_backoff = self.config.initial_backoff.as_secs_f64()
103            * self.config.backoff_multiplier.powf(attempts as f64 - 1.0);
104
105        let mut backoff = Duration::from_secs_f64(current_backoff);
106
107        // Add jitter (±25% random variation) to prevent thundering herd
108        if self.config.use_jitter {
109            // Generate a random factor between -0.25 and +0.25
110            let jitter_factor = rand::random::<f64>() * 0.5 - 0.25; // Range: -0.25 to +0.25
111            let jitter = backoff.as_secs_f64() * jitter_factor;
112            let new_backoff = backoff.as_secs_f64() + jitter;
113            
114            // Ensure we don't go below a minimum of 10ms
115            backoff = Duration::from_secs_f64(new_backoff.max(0.01));
116        }
117
118        std::cmp::min(backoff, self.config.max_backoff)
119    }
120
121    /// Send a message to the dead letter queue if configured.
122    async fn send_to_dlq(&self, message: &ReceivedMessage<serde_json::Value>, error: &WorkerError) {
123        if let Some(ref dlq) = self.config.dead_letter_queue {
124            // Check for poison pill first
125            let is_poison_pill = if let Some(ref tracker) = self.config.poison_pill_tracker {
126                tracker.record_failure(&message.message.id)
127            } else {
128                false
129            };
130
131            if is_poison_pill {
132                error!(
133                    "[{}] POISON PILL DETECTED: Message {} failed {} times - sending to DLQ",
134                    self.name(),
135                    message.message.id,
136                    message.message.metadata.attempt
137                );
138            }
139
140            // Create DLQ message with failure context
141            let mut dlq_message = create_dlq_message(
142                message.message.id.clone(),
143                message.message.payload.clone(),
144                message.message.metadata.source.clone(),
145                message.message.metadata.attempt,
146                error,
147                None, // Worker ID not available in middleware context
148            );
149
150            // Add poison pill flag if detected
151            if is_poison_pill {
152                dlq_message = dlq_message.with_context("poison_pill", serde_json::json!(true));
153            }
154
155            // Send to DLQ backend
156            match dlq.send_to_dlq(&dlq_message).await {
157                Ok(_) => {
158                    info!(
159                        "[{}] Successfully sent message {} to DLQ after {} attempts",
160                        self.name(),
161                        message.message.id,
162                        message.message.metadata.attempt
163                    );
164                }
165                Err(e) => {
166                    error!(
167                        "[{}] Failed to send message {} to DLQ: {:?}",
168                        self.name(),
169                        message.message.id,
170                        e
171                    );
172                }
173            }
174        }
175    }
176}
177
178#[async_trait]
179impl Middleware for RetryHandler {
180    fn name(&self) -> &str {
181        "RetryHandler"
182    }
183
184    async fn handle(
185        &self,
186        mut message: ReceivedMessage<serde_json::Value>,
187        next: Box<dyn MessageHandler>,
188    ) -> WorkerResult<()> {
189        // Increment attempt count before processing
190        message.message.metadata.increment_attempt();
191        let current_attempts = message.message.metadata.attempt;
192
193        debug!(
194            "[{}] Processing message {} (attempt {}/{})",
195            self.name(),
196            message.message.id,
197            current_attempts,
198            self.config.max_retries
199        );
200
201        let result = next.handle(message.clone()).await;
202
203        match result {
204            Ok(_) => {
205                debug!("[{}] Message {} processed successfully.", self.name(), message.message.id);
206                Ok(())
207            }
208            Err(e) => {
209                warn!(
210                    "[{}] Message {} failed on attempt {}: {:?}",
211                    self.name(),
212                    message.message.id,
213                    current_attempts,
214                    e
215                );
216
217                if current_attempts < self.config.max_retries {
218                    let delay = self.calculate_backoff(current_attempts);
219                    debug!(
220                        "[{}] Message {} will be retried in {:?}. Current attempts: {}",
221                        self.name(),
222                        message.message.id,
223                        delay,
224                        current_attempts
225                    );
226                    Err(WorkerError::RetryableFailure {
227                        source: Box::new(e),
228                        delay_ms: delay,
229                    })
230                } else {
231                    // Retries exhausted - send to DLQ if configured
232                    self.send_to_dlq(&message, &e).await;
233                    
234                    warn!(
235                        "[{}] Retries exhausted for message {} after {} attempts.",
236                        self.name(),
237                        message.message.id,
238                        current_attempts
239                    );
240                    Err(WorkerError::RetriesExhausted {
241                        source: Box::new(e),
242                    })
243                }
244            }
245        }
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata, ReceivedMessage};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Ack handle whose `ack` and `nack` always succeed without side effects.
    #[derive(Debug)]
    struct MockAckHandle;

    #[async_trait::async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> WorkerResult<()> {
            Ok(())
        }

        async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
            Ok(())
        }
    }

    /// Inner handler that fails its first `fail_until` calls, then succeeds.
    struct FailingHandler {
        fail_count: AtomicUsize,
        fail_until: usize,
    }

    impl FailingHandler {
        /// Builds a handler with a zeroed call counter.
        fn failing_first(fail_until: usize) -> Self {
            Self {
                fail_count: AtomicUsize::new(0),
                fail_until,
            }
        }
    }

    #[async_trait::async_trait]
    impl MessageHandler for FailingHandler {
        async fn handle(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            let calls_so_far = self.fail_count.fetch_add(1, Ordering::SeqCst);
            if calls_so_far >= self.fail_until {
                return Ok(());
            }
            Err(WorkerError::ProcessingError("Simulated failure".into()))
        }
    }

    /// Fresh message with an empty JSON payload and a mock ack handle.
    fn sample_message() -> ReceivedMessage<serde_json::Value> {
        let inner = Message {
            id: "test-id".to_string(),
            payload: serde_json::json!({}),
            metadata: MessageMetadata::new("test"),
        };
        ReceivedMessage::new(inner, Arc::new(MockAckHandle))
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let handler = RetryHandler::new(RetryHandlerConfig::default().with_max_retries(3));
        // Fail twice, succeed on third
        let inner_handler = FailingHandler::failing_first(2);

        let result = handler
            .handle(sample_message(), Box::new(inner_handler))
            .await;

        // The middleware does not loop by itself: re-delivery is left to the
        // caller, so a single call that hits a failure must surface it as a
        // RetryableFailure.
        assert!(result.is_err());
        assert!(
            matches!(result, Err(WorkerError::RetryableFailure { .. })),
            "Expected RetryableFailure"
        );
    }

    #[tokio::test]
    async fn test_retries_exhausted() {
        let handler = RetryHandler::new(RetryHandlerConfig::default().with_max_retries(1));
        // Always fail
        let inner_handler = FailingHandler::failing_first(10);

        // Pretend one attempt has already happened so the budget is spent.
        let mut message = sample_message();
        message.message.metadata.attempt = 1;

        let result = handler.handle(message, Box::new(inner_handler)).await;

        assert!(
            matches!(result, Err(WorkerError::RetriesExhausted { .. })),
            "Expected RetriesExhausted, got: {:?}",
            result
        );
    }
}
342}