Skip to main content

foxtive_worker/middleware/
processing_timeout.rs

1use async_trait::async_trait;
2use std::time::Duration;
3
4use crate::error::WorkerError;
5use crate::message::ReceivedMessage;
6use crate::middleware::{MessageHandler, Middleware, MiddlewareResult};
7
/// Middleware that enforces a processing timeout on message handling.
///
/// This middleware wraps message processing with `tokio::time::timeout` so that a
/// single message cannot occupy the consumer for longer than a configured limit.
/// If processing exceeds the limit:
/// - the message is negatively acknowledged (nacked) with requeue, and
/// - a [`WorkerError::Timeout`] is returned.
///
/// With a suitably short limit, this happens before the broker's own consumer
/// timeout closes the channel.
///
/// # Why This Matters
///
/// Message brokers such as RabbitMQ enforce a delivery-acknowledgement timeout.
/// RabbitMQ's `consumer_timeout` defaults to 30 minutes and is often configured
/// lower. If a worker receives a message but does not ack/nack it within that
/// window, the broker treats the consumer as stuck. It then closes the channel with
/// a `PRECONDITION_FAILED` error.
///
/// Configure this middleware with a timeout **shorter than** the broker's. The
/// worker then nacks the message gracefully and reports the failure itself.
///
/// # Example
/// ```rust,no_run
/// use foxtive_worker::ProcessingTimeoutMiddleware;
/// use std::time::Duration;
///
/// // Enforce a 25-second limit; keep it below the broker's consumer timeout.
/// let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(25));
/// ```
///
/// # Architecture
///
/// `tokio::time::timeout` does **not** spawn a detached task. It polls the handler
/// future inline on the current task and drops (cancels) it once the deadline
/// passes. Concurrency therefore stays controlled by the caller, with no unbounded
/// task spawning.
///
/// ```text
/// Message → [Timeout Check] → [Next Handler] → Result
///              ↓ (if timeout)
///           Nack + Error
/// ```
///
/// Cancellation can only take effect when the handler future yields at an `.await`
/// point. A handler that blocks the thread (synchronous I/O, long CPU-bound work)
/// cannot be interrupted and will overrun the limit.
///
/// On timeout the message is nacked here *and* an error is returned.
/// NOTE(review): confirm that callers do not also nack on `WorkerError::Timeout`.
/// RabbitMQ treats a second ack/nack of the same delivery tag as a channel error.
#[derive(Debug, Clone)]
pub struct ProcessingTimeoutMiddleware {
    /// Maximum allowed processing time per message (non-zero; enforced by `new`).
    timeout: Duration,
}

51impl ProcessingTimeoutMiddleware {
52    /// Create a new processing timeout middleware.
53    ///
54    /// # Arguments
55    /// * `timeout` - Maximum time allowed for message processing
56    ///
57    /// # Panics
58    /// Panics if timeout is zero
59    pub fn new(timeout: Duration) -> Self {
60        assert!(
61            !timeout.is_zero(),
62            "Processing timeout must be greater than zero"
63        );
64        Self { timeout }
65    }
66
67    /// Get the configured timeout duration.
68    pub fn timeout(&self) -> Duration {
69        self.timeout
70    }
71}
72
73#[async_trait]
74impl Middleware for ProcessingTimeoutMiddleware {
75    fn name(&self) -> &str {
76        "processing-timeout"
77    }
78
79    async fn handle(
80        &self,
81        message: ReceivedMessage<serde_json::Value>,
82        next: Box<dyn MessageHandler>,
83    ) -> Result<MiddlewareResult, WorkerError> {
84        let message_id = message.message.id.clone();
85
86        tracing::debug!(
87            message_id = %message_id,
88            timeout_ms = self.timeout.as_millis(),
89            "Starting message processing with timeout"
90        );
91
92        // Use tokio::time::timeout to enforce the limit
93        // This does NOT spawn a detached task - it polls the future inline
94        // and cancels it if the timeout expires
95        match tokio::time::timeout(self.timeout, next.handle(message.clone())).await {
96            Ok(result) => {
97                // Processing completed within timeout
98                result
99            }
100            Err(_) => {
101                // Timeout expired! Nack the message before broker kills us
102                tracing::warn!(
103                    message_id = %message_id,
104                    timeout_ms = self.timeout.as_millis(),
105                    "Message processing timed out - nacking with requeue"
106                );
107
108                // Attempt to nack with requeue so the message can be retried
109                if let Err(nack_err) = message.nack(true).await {
110                    tracing::error!(
111                        message_id = %message_id,
112                        error = %nack_err,
113                        "Failed to nack timed-out message"
114                    );
115                }
116
117                // Return a timeout error
118                Err(WorkerError::Timeout(format!(
119                    "Message {} processing exceeded timeout of {:?}",
120                    message_id, self.timeout
121                )))
122            }
123        }
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Ack handle that records which acknowledgement calls were made on it.
    #[derive(Debug)]
    struct MockAckHandle {
        acked: Arc<AtomicBool>,
        nacked: Arc<AtomicBool>,
        requeued: Arc<AtomicBool>,
    }

    impl MockAckHandle {
        /// Returns the handle plus the shared `(acked, nacked, requeued)` flags.
        fn new() -> (Self, Arc<AtomicBool>, Arc<AtomicBool>, Arc<AtomicBool>) {
            let acked = Arc::new(AtomicBool::new(false));
            let nacked = Arc::new(AtomicBool::new(false));
            let requeued = Arc::new(AtomicBool::new(false));
            (
                Self {
                    acked: acked.clone(),
                    nacked: nacked.clone(),
                    requeued: requeued.clone(),
                },
                acked,
                nacked,
                requeued,
            )
        }
    }

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> crate::WorkerResult<()> {
            self.acked.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn nack(&self, requeue: bool) -> crate::WorkerResult<()> {
            self.nacked.store(true, Ordering::SeqCst);
            self.requeued.store(requeue, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Handler that completes immediately.
    struct FastHandler;

    #[async_trait]
    impl MessageHandler for FastHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            // Completes immediately
            Ok(MiddlewareResult::Continue)
        }
    }

    /// Handler that sleeps for `delay` and then succeeds.
    struct SlowHandler {
        delay: Duration,
    }

    #[async_trait]
    impl MessageHandler for SlowHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            tokio::time::sleep(self.delay).await;
            Ok(MiddlewareResult::Continue)
        }
    }

    /// Handler that fails immediately with `ProcessingFailed`.
    struct FailingHandler;

    #[async_trait]
    impl MessageHandler for FailingHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Err(WorkerError::ProcessingFailed(
                "intentional failure".to_string(),
            ))
        }
    }

    /// Sets the shared flag when dropped; used to observe future cancellation.
    struct DropFlag(Arc<AtomicBool>);

    impl Drop for DropFlag {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    /// Handler that holds a `DropFlag` across a long sleep and records completion.
    struct CancellationProbeHandler {
        dropped: Arc<AtomicBool>,
        finished: Arc<AtomicBool>,
    }

    #[async_trait]
    impl MessageHandler for CancellationProbeHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            // Lives inside the handler future; dropped only when that future is.
            let _guard = DropFlag(Arc::clone(&self.dropped));
            tokio::time::sleep(Duration::from_secs(10)).await;
            self.finished.store(true, Ordering::SeqCst);
            Ok(MiddlewareResult::Continue)
        }
    }

    fn create_test_message() -> (
        ReceivedMessage<serde_json::Value>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
    ) {
        let (ack_handle, acked, nacked, requeued) = MockAckHandle::new();
        let message = Message {
            id: "test-msg-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        (
            ReceivedMessage::new(message, Arc::new(ack_handle)),
            acked,
            nacked,
            requeued,
        )
    }

    #[tokio::test]
    async fn test_fast_processing_completes() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(5));
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(FastHandler)).await;

        assert!(result.is_ok());
        assert!(!acked.load(Ordering::SeqCst)); // Middleware doesn't auto-ack
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_slow_processing_times_out() {
        let timeout = Duration::from_millis(100);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        let (message, acked, nacked, requeued) = create_test_message();

        // Handler takes longer than timeout
        let slow_handler = SlowHandler {
            delay: Duration::from_secs(1),
        };

        let result = middleware.handle(message, Box::new(slow_handler)).await;

        let err = result.expect_err("handler slower than the timeout must fail");
        // The error names the offending message so it can be traced in logs.
        assert!(matches!(&err, WorkerError::Timeout(msg) if msg.contains("test-msg-1")));
        assert!(nacked.load(Ordering::SeqCst)); // Should nack on timeout
        assert!(requeued.load(Ordering::SeqCst)); // Should requeue
        assert!(!acked.load(Ordering::SeqCst)); // Must never ack a timed-out message
    }

    #[tokio::test]
    async fn test_processing_error_propagates() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(5));
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(FailingHandler)).await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            WorkerError::ProcessingFailed(_)
        ));
        // Only the timeout path nacks; handler errors are left to the caller.
        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_timeout_cancels_long_running_task() {
        let timeout = Duration::from_millis(50);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        let (message, _, nacked, _) = create_test_message();

        // Very slow handler
        let very_slow_handler = SlowHandler {
            delay: Duration::from_secs(10),
        };

        let start = std::time::Instant::now();
        let result = middleware
            .handle(message, Box::new(very_slow_handler))
            .await;
        let elapsed = start.elapsed();

        // Should timeout quickly, not wait for full 10 seconds
        assert!(result.is_err());
        assert!(elapsed < Duration::from_secs(1)); // Should complete in ~50ms, give some buffer
        assert!(nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_timeout_drops_inner_future() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_millis(20));
        let (message, _, _, _) = create_test_message();
        let dropped = Arc::new(AtomicBool::new(false));
        let finished = Arc::new(AtomicBool::new(false));
        let probe = CancellationProbeHandler {
            dropped: Arc::clone(&dropped),
            finished: Arc::clone(&finished),
        };

        let result = middleware.handle(message, Box::new(probe)).await;

        assert!(matches!(result, Err(WorkerError::Timeout(_))));
        // The handler future was dropped on expiry, not left running in the background.
        assert!(dropped.load(Ordering::SeqCst));
        assert!(!finished.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_completes_just_before_timeout() {
        let timeout = Duration::from_millis(100);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        let (message, _, nacked, _) = create_test_message();

        // Handler completes with 20ms to spare
        let almost_timeout_handler = SlowHandler {
            delay: Duration::from_millis(80),
        };

        let result = middleware
            .handle(message, Box::new(almost_timeout_handler))
            .await;

        // Should succeed (completed before timeout) and leave the message alone
        assert!(result.is_ok());
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[test]
    #[should_panic(expected = "Processing timeout must be greater than zero")]
    fn test_zero_timeout_panics() {
        let _ = ProcessingTimeoutMiddleware::new(Duration::ZERO);
    }

    #[test]
    fn test_timeout_getter() {
        let timeout = Duration::from_secs(30);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        assert_eq!(middleware.timeout(), timeout);
    }

    #[test]
    fn test_name() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(1));
        assert_eq!(middleware.name(), "processing-timeout");
    }
}