// foxtive_worker/middleware/processing_timeout.rs

1use async_trait::async_trait;
2use std::time::Duration;
3
4use crate::error::WorkerError;
5use crate::message::ReceivedMessage;
6use crate::middleware::{MessageHandler, Middleware, MiddlewareResult};
7
/// Middleware that enforces a processing timeout on message handling.
///
/// The downstream handler chain is wrapped in [`tokio::time::timeout`]. If it does not
/// finish within the configured duration, the in-flight handler future is dropped
/// (cancelled), the message is negatively acknowledged (nacked) **with requeue**, and a
/// `WorkerError::Timeout` is returned to the caller.
///
/// # Why This Matters
///
/// Message brokers such as RabbitMQ enforce a delivery-acknowledgement (consumer)
/// timeout. If a consumer holds an unacknowledged delivery for longer than that window,
/// the broker closes the consumer's **channel** with a `PRECONDITION_FAILED` error.
/// In current RabbitMQ releases `consumer_timeout` defaults to 30 minutes. It is
/// configurable, though, and operators can lower it considerably.
///
/// Configure this middleware with a timeout **shorter than** the broker's effective
/// consumer timeout. Slow messages are then nacked gracefully, with proper error
/// reporting, before the broker tears down the channel.
///
/// # Caveats
///
/// - Cancellation is cooperative. The handler future is only abandoned when it yields at
///   an `.await` point. A handler that blocks the executor thread (CPU-bound loops,
///   synchronous I/O, `std::thread::sleep`) cannot be interrupted until it yields.
/// - On timeout the message is always nacked with `requeue = true`. A message that
///   consistently takes longer than the timeout will therefore be redelivered again and
///   again. Consider a broker-side dead-letter policy or delivery limit.
/// - The timeout path nacks unconditionally. It does not check whether the handler
///   already acknowledged the message before it was cancelled. A second acknowledgement
///   of the same delivery may be rejected by the broker, so avoid acking early inside
///   handlers wrapped by this middleware.
///
/// # Example
/// ```rust,no_run
/// use foxtive_worker::ProcessingTimeoutMiddleware;
/// use std::time::Duration;
///
/// // Pick a limit comfortably below the broker's configured consumer timeout.
/// let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(25));
/// ```
///
/// # Architecture
///
/// [`tokio::time::timeout`] does **not** spawn a task. The handler future is polled
/// inline by the current task and simply dropped if the deadline passes first. This
/// keeps concurrency bounded by the worker's own scheduling, with no unbounded task
/// spawning.
///
/// ```text
/// Message → [Timeout Check] → [Next Handler] → Result
///              ↓ (if timeout)
///           Nack (requeue) + WorkerError::Timeout
/// ```
#[derive(Debug, Clone)]
pub struct ProcessingTimeoutMiddleware {
    /// Maximum allowed processing time per message; `new` rejects a zero duration.
    timeout: Duration,
}
50
51impl ProcessingTimeoutMiddleware {
52    /// Create a new processing timeout middleware.
53    ///
54    /// # Arguments
55    /// * `timeout` - Maximum time allowed for message processing
56    ///
57    /// # Panics
58    /// Panics if timeout is zero
59    pub fn new(timeout: Duration) -> Self {
60        assert!(
61            !timeout.is_zero(),
62            "Processing timeout must be greater than zero"
63        );
64        Self { timeout }
65    }
66
67    /// Get the configured timeout duration.
68    pub fn timeout(&self) -> Duration {
69        self.timeout
70    }
71}
72
73#[async_trait]
74impl Middleware for ProcessingTimeoutMiddleware {
75    fn name(&self) -> &str {
76        "processing-timeout"
77    }
78
79    async fn handle(
80        &self,
81        message: ReceivedMessage<serde_json::Value>,
82        next: Box<dyn MessageHandler>,
83    ) -> Result<MiddlewareResult, WorkerError> {
84        let message_id = message.message.id.clone();
85
86        tracing::debug!(
87            message_id = %message_id,
88            timeout_ms = self.timeout.as_millis(),
89            "Starting message processing with timeout"
90        );
91
92        // Use tokio::time::timeout to enforce the limit
93        // This does NOT spawn a detached task - it polls the future inline
94        // and cancels it if the timeout expires
95        match tokio::time::timeout(self.timeout, next.handle(message.clone())).await {
96            Ok(result) => {
97                // Processing completed within timeout
98                result
99            }
100            Err(_) => {
101                // Timeout expired! Nack the message before broker kills us
102                tracing::warn!(
103                    message_id = %message_id,
104                    timeout_ms = self.timeout.as_millis(),
105                    "Message processing timed out - nacking with requeue"
106                );
107
108                // Attempt to nack with requeue so the message can be retried
109                if let Err(nack_err) = message.nack(true).await {
110                    tracing::error!(
111                        message_id = %message_id,
112                        error = %nack_err,
113                        "Failed to nack timed-out message"
114                    );
115                }
116
117                // Return a timeout error
118                Err(WorkerError::Timeout(format!(
119                    "Message {} processing exceeded timeout of {:?}",
120                    message_id, self.timeout
121                )))
122            }
123        }
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Ack handle that records which acknowledgement calls were made, so tests can
    /// assert on them after the middleware returns.
    #[derive(Debug)]
    struct MockAckHandle {
        acked: Arc<AtomicBool>,
        nacked: Arc<AtomicBool>,
        requeued: Arc<AtomicBool>,
    }

    impl MockAckHandle {
        /// Builds a handle together with its shared `(acked, nacked, requeued)` flags.
        fn new() -> (Self, Arc<AtomicBool>, Arc<AtomicBool>, Arc<AtomicBool>) {
            let acked = Arc::new(AtomicBool::new(false));
            let nacked = Arc::new(AtomicBool::new(false));
            let requeued = Arc::new(AtomicBool::new(false));
            (
                Self {
                    acked: Arc::clone(&acked),
                    nacked: Arc::clone(&nacked),
                    requeued: Arc::clone(&requeued),
                },
                acked,
                nacked,
                requeued,
            )
        }
    }

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> crate::WorkerResult<()> {
            self.acked.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn nack(&self, requeue: bool) -> crate::WorkerResult<()> {
            self.nacked.store(true, Ordering::SeqCst);
            self.requeued.store(requeue, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Handler that completes immediately.
    struct FastHandler;

    #[async_trait]
    impl MessageHandler for FastHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Ok(MiddlewareResult::Continue)
        }
    }

    /// Handler that sleeps for `delay` before succeeding.
    struct SlowHandler {
        delay: Duration,
    }

    #[async_trait]
    impl MessageHandler for SlowHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            tokio::time::sleep(self.delay).await;
            Ok(MiddlewareResult::Continue)
        }
    }

    /// Handler that always fails with `ProcessingFailed`.
    struct FailingHandler;

    #[async_trait]
    impl MessageHandler for FailingHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Err(WorkerError::ProcessingFailed(
                "intentional failure".to_string(),
            ))
        }
    }

    /// Handler that sets `finished` only if it runs to completion after `delay`.
    /// Used to prove that a timed-out handler future is dropped, not left running.
    struct CompletionTrackingHandler {
        delay: Duration,
        finished: Arc<AtomicBool>,
    }

    #[async_trait]
    impl MessageHandler for CompletionTrackingHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            tokio::time::sleep(self.delay).await;
            self.finished.store(true, Ordering::SeqCst);
            Ok(MiddlewareResult::Continue)
        }
    }

    /// Builds a message backed by a `MockAckHandle` and returns it with its
    /// `(acked, nacked, requeued)` flags.
    fn create_test_message() -> (
        ReceivedMessage<serde_json::Value>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
    ) {
        let (ack_handle, acked, nacked, requeued) = MockAckHandle::new();
        let message = Message {
            id: "test-msg-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        (
            ReceivedMessage::new(message, Arc::new(ack_handle)),
            acked,
            nacked,
            requeued,
        )
    }

    #[tokio::test]
    async fn test_fast_processing_completes() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(5));
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(FastHandler)).await;

        assert!(result.is_ok());
        assert!(!acked.load(Ordering::SeqCst)); // Middleware doesn't auto-ack
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_slow_processing_times_out() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_millis(100));
        let (message, _, nacked, requeued) = create_test_message();

        // Handler takes longer than the timeout.
        let slow_handler = SlowHandler {
            delay: Duration::from_secs(1),
        };

        let result = middleware.handle(message, Box::new(slow_handler)).await;

        assert!(matches!(result, Err(WorkerError::Timeout(_))));
        assert!(nacked.load(Ordering::SeqCst)); // Should nack on timeout
        assert!(requeued.load(Ordering::SeqCst)); // Should requeue
    }

    #[tokio::test]
    async fn test_processing_error_propagates() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(5));
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(FailingHandler)).await;

        assert!(matches!(result, Err(WorkerError::ProcessingFailed(_))));
        // Handler errors pass through untouched; only the timeout path nacks.
        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_timeout_cancels_long_running_task() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_millis(50));
        let (message, _, nacked, _) = create_test_message();

        let very_slow_handler = SlowHandler {
            delay: Duration::from_secs(10),
        };

        let start = std::time::Instant::now();
        let result = middleware
            .handle(message, Box::new(very_slow_handler))
            .await;
        let elapsed = start.elapsed();

        // Should time out quickly rather than wait the full 10 seconds.
        assert!(result.is_err());
        assert!(elapsed < Duration::from_secs(1)); // ~50ms expected; generous buffer
        assert!(nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_timed_out_handler_is_dropped_before_completion() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_millis(20));
        let (message, _, nacked, _) = create_test_message();
        let finished = Arc::new(AtomicBool::new(false));
        let handler = CompletionTrackingHandler {
            delay: Duration::from_millis(100),
            finished: Arc::clone(&finished),
        };

        let result = middleware.handle(message, Box::new(handler)).await;
        assert!(matches!(result, Err(WorkerError::Timeout(_))));
        assert!(nacked.load(Ordering::SeqCst));

        // Outlive the handler's own delay: if its future were still being driven
        // anywhere, `finished` would flip to true.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert!(!finished.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_completes_before_timeout_without_nack() {
        // Wide margin (50ms of work vs a 500ms budget) keeps this stable on loaded CI;
        // the previous 80ms-vs-100ms split was prone to spurious timeouts.
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_millis(500));
        let (message, _, nacked, _) = create_test_message();
        let handler = SlowHandler {
            delay: Duration::from_millis(50),
        };

        let result = middleware.handle(message, Box::new(handler)).await;

        assert!(result.is_ok());
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[test]
    #[should_panic(expected = "Processing timeout must be greater than zero")]
    fn test_zero_timeout_panics() {
        let _ = ProcessingTimeoutMiddleware::new(Duration::ZERO);
    }

    #[test]
    fn test_timeout_getter() {
        let timeout = Duration::from_secs(30);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        assert_eq!(middleware.timeout(), timeout);
    }
}