Skip to main content

foxtive_worker/middleware/
retry_handler.rs

1use async_trait::async_trait;
2use std::sync::Arc;
3use std::time::Duration;
4use tracing::{debug, error, info, warn};
5
6use crate::backends::{DeadLetterQueueBackend, create_dlq_message};
7use crate::dlq::PoisonPillTracker;
8use crate::error::WorkerError;
9use crate::message::ReceivedMessage;
10use crate::middleware::{MessageHandler, Middleware, MiddlewareResult};
11
12/// Configuration for the RetryHandler middleware.
13#[derive(Clone)]
14pub struct RetryHandlerConfig {
15    /// The maximum number of times a message should be retried.
16    pub max_retries: u32,
17    /// The initial backoff duration for the first retry.
18    pub initial_backoff: Duration,
19    /// The maximum backoff duration.
20    pub max_backoff: Duration,
21    /// Multiplier for exponential backoff.
22    pub backoff_multiplier: f64,
23    /// Optional dead letter queue backend for permanently failed messages.
24    pub dead_letter_queue: Option<Arc<DeadLetterQueueBackend>>,
25    /// Optional poison pill tracker for detecting problematic messages.
26    pub poison_pill_tracker: Option<Arc<PoisonPillTracker>>,
27    /// Whether to add jitter to backoff delays (recommended for distributed systems).
28    pub use_jitter: bool,
29}
30
31impl std::fmt::Debug for RetryHandlerConfig {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("RetryHandlerConfig")
34            .field("max_retries", &self.max_retries)
35            .field("initial_backoff", &self.initial_backoff)
36            .field("max_backoff", &self.max_backoff)
37            .field("backoff_multiplier", &self.backoff_multiplier)
38            .field(
39                "dead_letter_queue",
40                &self.dead_letter_queue.as_ref().map(|_| "<MessageBackend>"),
41            )
42            .field("use_jitter", &self.use_jitter)
43            .finish()
44    }
45}
46
47impl Default for RetryHandlerConfig {
48    fn default() -> Self {
49        Self {
50            max_retries: 5,
51            initial_backoff: Duration::from_secs(1),
52            max_backoff: Duration::from_secs(60),
53            backoff_multiplier: 2.0,
54            dead_letter_queue: None,
55            poison_pill_tracker: None,
56            use_jitter: true,
57        }
58    }
59}
60
61impl RetryHandlerConfig {
62    /// Create a new config with custom max retries.
63    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
64        self.max_retries = max_retries;
65        self
66    }
67
68    /// Set the dead letter queue for permanently failed messages.
69    pub fn with_dead_letter_queue(mut self, dlq: Arc<DeadLetterQueueBackend>) -> Self {
70        self.dead_letter_queue = Some(dlq);
71        self
72    }
73
74    /// Set the poison pill tracker for detecting problematic messages.
75    pub fn with_poison_pill_tracker(mut self, tracker: Arc<PoisonPillTracker>) -> Self {
76        self.poison_pill_tracker = Some(tracker);
77        self
78    }
79
80    /// Enable or disable jitter (enabled by default).
81    pub fn with_jitter(mut self, use_jitter: bool) -> Self {
82        self.use_jitter = use_jitter;
83        self
84    }
85}
86
87/// Middleware that handles automatic retries for failed messages with exponential backoff.
88pub struct RetryHandler {
89    config: RetryHandlerConfig,
90}
91
92impl RetryHandler {
93    /// Creates a new `RetryHandler` with the given configuration.
94    pub fn new(config: RetryHandlerConfig) -> Self {
95        Self { config }
96    }
97
98    /// Calculates the next backoff duration based on the current attempt count.
99    /// Uses exponential backoff with optional jitter.
100    fn calculate_backoff(&self, attempts: u32) -> Duration {
101        if attempts == 0 {
102            return self.config.initial_backoff;
103        }
104
105        let current_backoff = self.config.initial_backoff.as_secs_f64()
106            * self.config.backoff_multiplier.powf(attempts as f64 - 1.0);
107
108        let mut backoff = Duration::from_secs_f64(current_backoff);
109
110        // Add jitter (±25% random variation) to prevent thundering herd
111        if self.config.use_jitter {
112            // Generate a random factor between -0.25 and +0.25
113            let jitter_factor = rand::random::<f64>() * 0.5 - 0.25; // Range: -0.25 to +0.25
114            let jitter = backoff.as_secs_f64() * jitter_factor;
115            let new_backoff = backoff.as_secs_f64() + jitter;
116
117            // Ensure we don't go below a minimum of 10ms
118            backoff = Duration::from_secs_f64(new_backoff.max(0.01));
119        }
120
121        std::cmp::min(backoff, self.config.max_backoff)
122    }
123
124    /// Send a message to the dead letter queue if configured.
125    async fn send_to_dlq(&self, message: &ReceivedMessage<serde_json::Value>, error: &WorkerError) {
126        if let Some(ref dlq) = self.config.dead_letter_queue {
127            // Check for poison pill first
128            let is_poison_pill = if let Some(ref tracker) = self.config.poison_pill_tracker {
129                tracker.record_failure(&message.message.id)
130            } else {
131                false
132            };
133
134            if is_poison_pill {
135                error!(
136                    "[{}] POISON PILL DETECTED: Message {} failed {} times - sending to DLQ",
137                    self.name(),
138                    message.message.id,
139                    message.message.metadata.attempt
140                );
141            }
142
143            // Create DLQ message with failure context
144            let mut dlq_message = create_dlq_message(
145                message.message.id.clone(),
146                message.message.payload.clone(),
147                message.message.metadata.source.clone(),
148                message.message.metadata.attempt,
149                error,
150                None, // Worker ID not available in middleware context
151            );
152
153            // Add poison pill flag if detected
154            if is_poison_pill {
155                dlq_message = dlq_message.with_context("poison_pill", serde_json::json!(true));
156            }
157
158            // Send to DLQ backend
159            match dlq.send_to_dlq(&dlq_message).await {
160                Ok(_) => {
161                    info!(
162                        "[{}] Successfully sent message {} to DLQ after {} attempts",
163                        self.name(),
164                        message.message.id,
165                        message.message.metadata.attempt
166                    );
167                }
168                Err(e) => {
169                    error!(
170                        "[{}] Failed to send message {} to DLQ: {:?}",
171                        self.name(),
172                        message.message.id,
173                        e
174                    );
175                }
176            }
177        }
178    }
179}
180
181#[async_trait]
182impl Middleware for RetryHandler {
183    fn name(&self) -> &str {
184        "RetryHandler"
185    }
186
187    async fn handle(
188        &self,
189        mut message: ReceivedMessage<serde_json::Value>,
190        next: Box<dyn MessageHandler>,
191    ) -> Result<MiddlewareResult, WorkerError> {
192        // Increment attempt count before processing
193        message.message.metadata.increment_attempt();
194        let current_attempts = message.message.metadata.attempt;
195
196        debug!(
197            "[{}] Processing message {} (attempt {}/{})",
198            self.name(),
199            message.message.id,
200            current_attempts,
201            self.config.max_retries
202        );
203
204        let result = next.handle(message.clone()).await;
205
206        match result {
207            Ok(middleware_result) => {
208                debug!(
209                    "[{}] Message {} processed successfully.",
210                    self.name(),
211                    message.message.id
212                );
213                Ok(middleware_result)
214            }
215            Err(e) => {
216                warn!(
217                    "[{}] Message {} failed on attempt {}: {:?}",
218                    self.name(),
219                    message.message.id,
220                    current_attempts,
221                    e
222                );
223
224                if current_attempts < self.config.max_retries {
225                    let delay = self.calculate_backoff(current_attempts);
226                    debug!(
227                        "[{}] Message {} will be retried in {:?}. Current attempts: {}",
228                        self.name(),
229                        message.message.id,
230                        delay,
231                        current_attempts
232                    );
233                    Err(WorkerError::RetryableFailure {
234                        source: Box::new(e),
235                        delay_ms: delay,
236                    })
237                } else {
238                    // Retries exhausted - send to DLQ if configured
239                    self.send_to_dlq(&message, &e).await;
240
241                    warn!(
242                        "[{}] Retries exhausted for message {} after {} attempts.",
243                        self.name(),
244                        message.message.id,
245                        current_attempts
246                    );
247                    Err(WorkerError::RetriesExhausted {
248                        source: Box::new(e),
249                    })
250                }
251            }
252        }
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata, ReceivedMessage};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Ack handle that accepts every ack/nack and does nothing.
    #[derive(Debug)]
    struct MockAckHandle;

    #[async_trait::async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> crate::WorkerResult<()> {
            Ok(())
        }
        async fn nack(&self, _requeue: bool) -> crate::WorkerResult<()> {
            Ok(())
        }
    }

    /// Inner handler that fails on its first `fail_until` calls, then succeeds.
    struct FailingHandler {
        fail_count: AtomicUsize,
        fail_until: usize,
    }

    impl FailingHandler {
        /// Builds a handler that fails the first `fail_until` invocations.
        fn failing_first(fail_until: usize) -> Self {
            Self {
                fail_count: AtomicUsize::new(0),
                fail_until,
            }
        }
    }

    #[async_trait::async_trait]
    impl MessageHandler for FailingHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            let calls_so_far = self.fail_count.fetch_add(1, Ordering::SeqCst);
            if calls_so_far >= self.fail_until {
                return Ok(MiddlewareResult::Continue);
            }
            Err(WorkerError::ProcessingError("Simulated failure".into()))
        }
    }

    /// Builds a fresh message with id `test-id` and an empty JSON object payload.
    fn test_message() -> ReceivedMessage<serde_json::Value> {
        let message = Message {
            id: "test-id".to_string(),
            payload: serde_json::json!({}),
            metadata: MessageMetadata::new("test"),
        };
        ReceivedMessage::new(message, Arc::new(MockAckHandle))
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let handler = RetryHandler::new(RetryHandlerConfig::default().with_max_retries(3));
        // Fails twice, would succeed on the third call.
        let inner = FailingHandler::failing_first(2);

        let result = handler.handle(test_message(), Box::new(inner)).await;

        // The middleware calls the inner handler only once; re-delivery is the
        // pool/dispatcher's job, so the first failure must surface as retryable.
        assert!(result.is_err());
        assert!(
            matches!(result, Err(WorkerError::RetryableFailure { .. })),
            "Expected RetryableFailure"
        );
    }

    #[tokio::test]
    async fn test_retries_exhausted() {
        let handler = RetryHandler::new(RetryHandlerConfig::default().with_max_retries(1));
        // Never gets far enough to succeed.
        let inner = FailingHandler::failing_first(10);

        let mut message = test_message();
        // Pretend a previous attempt already happened so this one exceeds the budget.
        message.message.metadata.attempt = 1;

        match handler.handle(message, Box::new(inner)).await {
            Err(WorkerError::RetriesExhausted { .. }) => {}
            other => panic!("Expected RetriesExhausted, got: {:?}", other),
        }
    }
}