Skip to main content

foxtive_worker/middleware/
processing_timeout.rs

1use async_trait::async_trait;
2use std::time::Duration;
3
4use crate::error::{WorkerError, WorkerResult};
5use crate::message::ReceivedMessage;
6use crate::middleware::{MessageHandler, Middleware};
7
/// Middleware that enforces a processing timeout on message handling.
///
/// This middleware wraps message processing with `tokio::time::timeout` to ensure
/// that messages don't exceed a maximum processing time. If processing exceeds the
/// timeout, the message is negative-acknowledged (nacked) with requeue before the
/// broker's consumer timeout can close the channel.
///
/// # Why This Matters
///
/// Message brokers like RabbitMQ enforce a consumer (delivery acknowledgement)
/// timeout. In RabbitMQ this is the `consumer_timeout` setting, which defaults to
/// 30 minutes and is often lowered by operators. If a worker receives a message but
/// doesn't ack/nack within that window, the broker assumes the consumer is stuck and
/// closes the channel with a PRECONDITION_FAILED error.
///
/// This middleware prevents that by enforcing a timeout **shorter than** the
/// broker's configured timeout, ensuring a graceful nack with proper error handling.
/// Choose the value relative to the `consumer_timeout` your broker actually uses.
///
/// # Cancellation semantics
///
/// When the deadline passes, the inner handler future is dropped. A future can only
/// be cancelled at an `.await` point. A handler that blocks the thread (synchronous I/O,
/// or a CPU-bound loop that never yields) is therefore not interrupted until it next
/// yields. Because timed-out messages are nacked *with requeue*, a message that always
/// exceeds the timeout will keep being redelivered.
///
/// # Example
/// ```rust,no_run
/// use foxtive_worker::ProcessingTimeoutMiddleware;
/// use std::time::Duration;
///
/// // Enforce a 25-second processing limit; keep it below the broker's
/// // configured consumer timeout.
/// let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(25));
/// ```
///
/// # Architecture
///
/// The middleware uses `tokio::time::timeout` which does **NOT** spawn detached tasks.
/// It polls the future inline and drops (cancels) it if the timeout expires. This keeps
/// concurrency controlled, with no unbounded task spawning.
///
/// ```text
/// Message → [Timeout Check] → [Next Handler] → Result
///              ↓ (if timeout)
///           Nack + Error
/// ```
#[derive(Debug, Clone)]
pub struct ProcessingTimeoutMiddleware {
    /// Maximum allowed processing time per message
    timeout: Duration,
}
50
51impl ProcessingTimeoutMiddleware {
52    /// Create a new processing timeout middleware.
53    ///
54    /// # Arguments
55    /// * `timeout` - Maximum time allowed for message processing
56    ///
57    /// # Panics
58    /// Panics if timeout is zero
59    pub fn new(timeout: Duration) -> Self {
60        assert!(
61            !timeout.is_zero(),
62            "Processing timeout must be greater than zero"
63        );
64        Self { timeout }
65    }
66
67    /// Get the configured timeout duration.
68    pub fn timeout(&self) -> Duration {
69        self.timeout
70    }
71}
72
73#[async_trait]
74impl Middleware for ProcessingTimeoutMiddleware {
75    fn name(&self) -> &str {
76        "processing-timeout"
77    }
78
79    async fn handle(
80        &self,
81        message: ReceivedMessage<serde_json::Value>,
82        next: Box<dyn MessageHandler>,
83    ) -> WorkerResult<()> {
84        let message_id = message.message.id.clone();
85        
86        tracing::debug!(
87            message_id = %message_id,
88            timeout_ms = self.timeout.as_millis(),
89            "Starting message processing with timeout"
90        );
91
92        // Use tokio::time::timeout to enforce the limit
93        // This does NOT spawn a detached task - it polls the future inline
94        // and cancels it if the timeout expires
95        match tokio::time::timeout(self.timeout, next.handle(message.clone())).await {
96            Ok(result) => {
97                // Processing completed within timeout
98                match result {
99                    Ok(()) => {
100                        tracing::debug!(
101                            message_id = %message_id,
102                            "Message processing completed successfully within timeout"
103                        );
104                        Ok(())
105                    }
106                    Err(e) => {
107                        tracing::warn!(
108                            message_id = %message_id,
109                            error = %e,
110                            "Message processing failed (within timeout)"
111                        );
112                        Err(e)
113                    }
114                }
115            }
116            Err(_) => {
117                // Timeout expired! Nack the message before broker kills us
118                tracing::warn!(
119                    message_id = %message_id,
120                    timeout_ms = self.timeout.as_millis(),
121                    "Message processing timed out - nacking with requeue"
122                );
123
124                // Attempt to nack with requeue so the message can be retried
125                if let Err(nack_err) = message.nack(true).await {
126                    tracing::error!(
127                        message_id = %message_id,
128                        error = %nack_err,
129                        "Failed to nack timed-out message"
130                    );
131                }
132
133                // Return a timeout error
134                Err(WorkerError::Timeout(format!(
135                    "Message {} processing exceeded timeout of {:?}",
136                    message_id, self.timeout
137                )))
138            }
139        }
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Shared flags recording what was done to a message's ack handle.
    #[derive(Debug, Clone, Default)]
    struct AckFlags {
        acked: Arc<AtomicBool>,
        nacked: Arc<AtomicBool>,
        requeued: Arc<AtomicBool>,
    }

    impl AckFlags {
        fn acked(&self) -> bool {
            self.acked.load(Ordering::SeqCst)
        }

        fn nacked(&self) -> bool {
            self.nacked.load(Ordering::SeqCst)
        }

        fn requeued(&self) -> bool {
            self.requeued.load(Ordering::SeqCst)
        }
    }

    /// Ack handle that only records calls, so tests can observe ack/nack decisions.
    #[derive(Debug)]
    struct MockAckHandle {
        flags: AckFlags,
    }

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> WorkerResult<()> {
            self.flags.acked.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn nack(&self, requeue: bool) -> WorkerResult<()> {
            self.flags.nacked.store(true, Ordering::SeqCst);
            self.flags.requeued.store(requeue, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Handler that completes immediately.
    struct FastHandler;

    #[async_trait]
    impl MessageHandler for FastHandler {
        async fn handle(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            Ok(())
        }
    }

    /// Handler that sleeps for `delay` before succeeding.
    struct SlowHandler {
        delay: Duration,
    }

    #[async_trait]
    impl MessageHandler for SlowHandler {
        async fn handle(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            tokio::time::sleep(self.delay).await;
            Ok(())
        }
    }

    /// Handler that fails immediately.
    struct FailingHandler;

    #[async_trait]
    impl MessageHandler for FailingHandler {
        async fn handle(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            Err(WorkerError::ProcessingFailed("intentional failure".to_string()))
        }
    }

    /// Builds a test message whose ack handle reports into the returned flags.
    fn create_test_message() -> (ReceivedMessage<serde_json::Value>, AckFlags) {
        let flags = AckFlags::default();
        let ack_handle = MockAckHandle {
            flags: flags.clone(),
        };
        let message = Message {
            id: "test-msg-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        (ReceivedMessage::new(message, Arc::new(ack_handle)), flags)
    }

    #[tokio::test]
    async fn test_fast_processing_completes() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(5));
        let (message, flags) = create_test_message();

        let result = middleware.handle(message, Box::new(FastHandler)).await;

        assert!(result.is_ok());
        assert!(!flags.acked()); // Middleware doesn't auto-ack
        assert!(!flags.nacked());
    }

    #[tokio::test]
    async fn test_slow_processing_times_out() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_millis(100));
        let (message, flags) = create_test_message();

        // Handler takes far longer than the timeout.
        let slow_handler = SlowHandler {
            delay: Duration::from_secs(1),
        };

        let result = middleware.handle(message, Box::new(slow_handler)).await;

        assert!(matches!(result, Err(WorkerError::Timeout(_))));
        assert!(flags.nacked()); // Should nack on timeout
        assert!(flags.requeued()); // Should requeue
        assert!(!flags.acked());
    }

    #[tokio::test]
    async fn test_processing_error_propagates() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(5));
        let (message, flags) = create_test_message();

        let result = middleware.handle(message, Box::new(FailingHandler)).await;

        assert!(matches!(result, Err(WorkerError::ProcessingFailed(_))));
        // In-time failures are passed through; settling them is the caller's job.
        assert!(!flags.nacked());
        assert!(!flags.acked());
    }

    #[tokio::test]
    async fn test_timeout_cancels_long_running_task() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_millis(50));
        let (message, flags) = create_test_message();

        let very_slow_handler = SlowHandler {
            delay: Duration::from_secs(10),
        };

        let start = std::time::Instant::now();
        let result = middleware.handle(message, Box::new(very_slow_handler)).await;
        let elapsed = start.elapsed();

        // Should time out after ~50ms, not wait for the full 10 seconds.
        assert!(result.is_err());
        assert!(elapsed < Duration::from_secs(1));
        assert!(flags.nacked());
    }

    #[tokio::test]
    async fn test_slow_but_within_timeout_completes() {
        // Generous margin between handler duration and timeout so the test is not
        // flaky on a loaded CI machine.
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_millis(500));
        let (message, flags) = create_test_message();

        let within_limit_handler = SlowHandler {
            delay: Duration::from_millis(50),
        };

        let result = middleware.handle(message, Box::new(within_limit_handler)).await;

        assert!(result.is_ok());
        assert!(!flags.nacked());
    }

    #[test]
    #[should_panic(expected = "Processing timeout must be greater than zero")]
    fn test_zero_timeout_panics() {
        let _ = ProcessingTimeoutMiddleware::new(Duration::ZERO);
    }

    #[test]
    fn test_timeout_getter() {
        let timeout = Duration::from_secs(30);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        assert_eq!(middleware.timeout(), timeout);
    }
}
322}