foxtive_worker/middleware/
processing_timeout.rs1use async_trait::async_trait;
2use std::time::Duration;
3
4use crate::error::{WorkerError, WorkerResult};
5use crate::message::ReceivedMessage;
6use crate::middleware::{MessageHandler, Middleware};
7
/// Middleware that bounds how long the downstream handler chain may spend on a
/// single message.
///
/// When used as a `Middleware`, a chain that does not finish within
/// [`timeout`](Self::timeout) has its future dropped, the message is nacked
/// with requeue, and a timeout error is returned to the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProcessingTimeoutMiddleware {
    /// Maximum time allowed for the downstream chain per message.
    /// Guaranteed non-zero: the field is private and `new` rejects zero.
    timeout: Duration,
}

impl ProcessingTimeoutMiddleware {
    /// Creates a middleware that allows each message at most `timeout` of
    /// downstream processing.
    ///
    /// # Panics
    ///
    /// Panics with `"Processing timeout must be greater than zero"` if
    /// `timeout` is [`Duration::ZERO`].
    #[must_use]
    pub fn new(timeout: Duration) -> Self {
        assert!(
            !timeout.is_zero(),
            "Processing timeout must be greater than zero"
        );
        Self { timeout }
    }

    /// Returns the configured per-message processing limit.
    #[must_use]
    pub fn timeout(&self) -> Duration {
        self.timeout
    }
}
72
73#[async_trait]
74impl Middleware for ProcessingTimeoutMiddleware {
75 fn name(&self) -> &str {
76 "processing-timeout"
77 }
78
79 async fn handle(
80 &self,
81 message: ReceivedMessage<serde_json::Value>,
82 next: Box<dyn MessageHandler>,
83 ) -> WorkerResult<()> {
84 let message_id = message.message.id.clone();
85
86 tracing::debug!(
87 message_id = %message_id,
88 timeout_ms = self.timeout.as_millis(),
89 "Starting message processing with timeout"
90 );
91
92 match tokio::time::timeout(self.timeout, next.handle(message.clone())).await {
96 Ok(result) => {
97 match result {
99 Ok(()) => {
100 tracing::debug!(
101 message_id = %message_id,
102 "Message processing completed successfully within timeout"
103 );
104 Ok(())
105 }
106 Err(e) => {
107 tracing::warn!(
108 message_id = %message_id,
109 error = %e,
110 "Message processing failed (within timeout)"
111 );
112 Err(e)
113 }
114 }
115 }
116 Err(_) => {
117 tracing::warn!(
119 message_id = %message_id,
120 timeout_ms = self.timeout.as_millis(),
121 "Message processing timed out - nacking with requeue"
122 );
123
124 if let Err(nack_err) = message.nack(true).await {
126 tracing::error!(
127 message_id = %message_id,
128 error = %nack_err,
129 "Failed to nack timed-out message"
130 );
131 }
132
133 Err(WorkerError::Timeout(format!(
135 "Message {} processing exceeded timeout of {:?}",
136 message_id, self.timeout
137 )))
138 }
139 }
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Ack handle that records which settlement calls the code under test made.
    #[derive(Debug)]
    struct MockAckHandle {
        acked: Arc<AtomicBool>,
        nacked: Arc<AtomicBool>,
        requeued: Arc<AtomicBool>,
    }

    impl MockAckHandle {
        /// Returns the handle plus shared views of its `acked`, `nacked` and
        /// `requeued` flags, in that order.
        fn new() -> (Self, Arc<AtomicBool>, Arc<AtomicBool>, Arc<AtomicBool>) {
            let acked = Arc::new(AtomicBool::new(false));
            let nacked = Arc::new(AtomicBool::new(false));
            let requeued = Arc::new(AtomicBool::new(false));
            (
                Self {
                    acked: acked.clone(),
                    nacked: nacked.clone(),
                    requeued: requeued.clone(),
                },
                acked,
                nacked,
                requeued,
            )
        }
    }

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> WorkerResult<()> {
            self.acked.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn nack(&self, requeue: bool) -> WorkerResult<()> {
            self.nacked.store(true, Ordering::SeqCst);
            self.requeued.store(requeue, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Handler that succeeds immediately.
    struct FastHandler;

    #[async_trait]
    impl MessageHandler for FastHandler {
        async fn handle(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            Ok(())
        }
    }

    /// Handler that sleeps for `delay` and then succeeds.
    struct SlowHandler {
        delay: Duration,
    }

    #[async_trait]
    impl MessageHandler for SlowHandler {
        async fn handle(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            tokio::time::sleep(self.delay).await;
            Ok(())
        }
    }

    /// Handler that always fails with `ProcessingFailed`.
    struct FailingHandler;

    #[async_trait]
    impl MessageHandler for FailingHandler {
        async fn handle(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            Err(WorkerError::ProcessingFailed("intentional failure".to_string()))
        }
    }

    /// Sets its flag when dropped, so a test can observe that the future
    /// holding it was torn down.
    struct DropFlag(Arc<AtomicBool>);

    impl Drop for DropFlag {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    /// Slow handler that records whether it ran to completion and whether
    /// its future was dropped (cancelled or finished).
    struct ObservableSlowHandler {
        delay: Duration,
        completed: Arc<AtomicBool>,
        dropped: Arc<AtomicBool>,
    }

    #[async_trait]
    impl MessageHandler for ObservableSlowHandler {
        async fn handle(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            // Named binding (not `_`) so the guard lives across the await.
            let _guard = DropFlag(Arc::clone(&self.dropped));
            tokio::time::sleep(self.delay).await;
            self.completed.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Builds a message with a mock ack handle. Returns the message followed
    /// by its `acked`, `nacked` and `requeued` flags.
    fn create_test_message() -> (
        ReceivedMessage<serde_json::Value>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
    ) {
        let (ack_handle, acked, nacked, requeued) = MockAckHandle::new();
        let message = Message {
            id: "test-msg-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        (
            ReceivedMessage::new(message, Arc::new(ack_handle)),
            acked,
            nacked,
            requeued,
        )
    }

    #[tokio::test]
    async fn test_fast_processing_completes() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(5));
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(FastHandler)).await;

        assert!(result.is_ok());
        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_slow_processing_times_out() {
        let timeout = Duration::from_millis(100);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        let (message, acked, nacked, requeued) = create_test_message();

        let slow_handler = SlowHandler {
            delay: Duration::from_secs(1),
        };

        let result = middleware.handle(message, Box::new(slow_handler)).await;

        assert!(matches!(result, Err(WorkerError::Timeout(_))));
        assert!(!acked.load(Ordering::SeqCst));
        assert!(nacked.load(Ordering::SeqCst));
        assert!(requeued.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_processing_error_propagates() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(5));
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(FailingHandler)).await;

        assert!(matches!(result, Err(WorkerError::ProcessingFailed(_))));
        // Handler errors are only propagated; the middleware must not settle the message.
        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_timeout_cancels_long_running_task() {
        let timeout = Duration::from_millis(50);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        let (message, _, nacked, _) = create_test_message();

        let completed = Arc::new(AtomicBool::new(false));
        let dropped = Arc::new(AtomicBool::new(false));
        let very_slow_handler = ObservableSlowHandler {
            delay: Duration::from_secs(10),
            completed: Arc::clone(&completed),
            dropped: Arc::clone(&dropped),
        };

        let start = std::time::Instant::now();
        let result = middleware.handle(message, Box::new(very_slow_handler)).await;
        let elapsed = start.elapsed();

        assert!(matches!(result, Err(WorkerError::Timeout(_))));
        assert!(elapsed < Duration::from_secs(1));
        assert!(nacked.load(Ordering::SeqCst));
        // The handler future must have been dropped before it could finish.
        assert!(dropped.load(Ordering::SeqCst));
        assert!(!completed.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_completes_shortly_before_timeout() {
        let timeout = Duration::from_millis(100);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        let (message, _, nacked, _) = create_test_message();

        let almost_timeout_handler = SlowHandler {
            delay: Duration::from_millis(80),
        };

        let result = middleware.handle(message, Box::new(almost_timeout_handler)).await;

        assert!(result.is_ok());
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[test]
    #[should_panic(expected = "Processing timeout must be greater than zero")]
    fn test_zero_timeout_panics() {
        let _ = ProcessingTimeoutMiddleware::new(Duration::ZERO);
    }

    #[test]
    fn test_timeout_getter() {
        let timeout = Duration::from_secs(30);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        assert_eq!(middleware.timeout(), timeout);
    }
}