// foxtive_worker/middleware/processing_timeout.rs
use std::time::Duration;

use async_trait::async_trait;

use crate::error::WorkerError;
use crate::message::ReceivedMessage;
use crate::middleware::{MessageHandler, Middleware, MiddlewareResult};

/// Middleware that bounds how long the downstream handler chain may spend on
/// a single message.
///
/// If the downstream handler does not complete within the configured timeout:
/// - its future is dropped;
/// - the message is nacked with `requeue = true`;
/// - a [`WorkerError::Timeout`] is returned.
///
/// Construct it with [`ProcessingTimeoutMiddleware::new`]; a zero timeout is
/// rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProcessingTimeoutMiddleware {
    /// Maximum time the downstream handler is given per message.
    timeout: Duration,
}

51impl ProcessingTimeoutMiddleware {
52 pub fn new(timeout: Duration) -> Self {
60 assert!(
61 !timeout.is_zero(),
62 "Processing timeout must be greater than zero"
63 );
64 Self { timeout }
65 }
66
67 pub fn timeout(&self) -> Duration {
69 self.timeout
70 }
71}
72
73#[async_trait]
74impl Middleware for ProcessingTimeoutMiddleware {
75 fn name(&self) -> &str {
76 "processing-timeout"
77 }
78
79 async fn handle(
80 &self,
81 message: ReceivedMessage<serde_json::Value>,
82 next: Box<dyn MessageHandler>,
83 ) -> Result<MiddlewareResult, WorkerError> {
84 let message_id = message.message.id.clone();
85
86 tracing::debug!(
87 message_id = %message_id,
88 timeout_ms = self.timeout.as_millis(),
89 "Starting message processing with timeout"
90 );
91
92 match tokio::time::timeout(self.timeout, next.handle(message.clone())).await {
96 Ok(result) => {
97 result
99 }
100 Err(_) => {
101 tracing::warn!(
103 message_id = %message_id,
104 timeout_ms = self.timeout.as_millis(),
105 "Message processing timed out - nacking with requeue"
106 );
107
108 if let Err(nack_err) = message.nack(true).await {
110 tracing::error!(
111 message_id = %message_id,
112 error = %nack_err,
113 "Failed to nack timed-out message"
114 );
115 }
116
117 Err(WorkerError::Timeout(format!(
119 "Message {} processing exceeded timeout of {:?}",
120 message_id, self.timeout
121 )))
122 }
123 }
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Ack handle that records which acknowledgement calls were made on it.
    #[derive(Debug)]
    struct MockAckHandle {
        acked: Arc<AtomicBool>,
        nacked: Arc<AtomicBool>,
        requeued: Arc<AtomicBool>,
    }

    impl MockAckHandle {
        /// Returns the handle plus shared `(acked, nacked, requeued)` flags.
        /// The test keeps these to inspect after the handle is moved into a
        /// message.
        fn new() -> (Self, Arc<AtomicBool>, Arc<AtomicBool>, Arc<AtomicBool>) {
            let acked = Arc::new(AtomicBool::new(false));
            let nacked = Arc::new(AtomicBool::new(false));
            let requeued = Arc::new(AtomicBool::new(false));
            (
                Self {
                    acked: acked.clone(),
                    nacked: nacked.clone(),
                    requeued: requeued.clone(),
                },
                acked,
                nacked,
                requeued,
            )
        }
    }

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> crate::WorkerResult<()> {
            self.acked.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn nack(&self, requeue: bool) -> crate::WorkerResult<()> {
            self.nacked.store(true, Ordering::SeqCst);
            self.requeued.store(requeue, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Handler that completes immediately with `Continue`.
    struct FastHandler;

    #[async_trait]
    impl MessageHandler for FastHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Ok(MiddlewareResult::Continue)
        }
    }

    /// Handler that sleeps for `delay` before returning `Continue`.
    struct SlowHandler {
        delay: Duration,
    }

    #[async_trait]
    impl MessageHandler for SlowHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            tokio::time::sleep(self.delay).await;
            Ok(MiddlewareResult::Continue)
        }
    }

    /// Handler that always fails with `ProcessingFailed`.
    struct FailingHandler;

    #[async_trait]
    impl MessageHandler for FailingHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Err(WorkerError::ProcessingFailed(
                "intentional failure".to_string(),
            ))
        }
    }

    /// Builds a message backed by a `MockAckHandle`.
    /// Returns the message and its `(acked, nacked, requeued)` flags.
    fn create_test_message() -> (
        ReceivedMessage<serde_json::Value>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
    ) {
        let (ack_handle, acked, nacked, requeued) = MockAckHandle::new();
        let message = Message {
            id: "test-msg-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        (
            ReceivedMessage::new(message, Arc::new(ack_handle)),
            acked,
            nacked,
            requeued,
        )
    }

    #[tokio::test]
    async fn test_fast_processing_completes() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(5));
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(FastHandler)).await;

        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
        // The middleware must not ack or nack a message that finished in time.
        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_slow_processing_times_out() {
        let timeout = Duration::from_millis(100);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        let (message, acked, nacked, requeued) = create_test_message();

        let slow_handler = SlowHandler {
            delay: Duration::from_secs(1),
        };

        let result = middleware.handle(message, Box::new(slow_handler)).await;

        assert!(matches!(result, Err(WorkerError::Timeout(_))));
        assert!(!acked.load(Ordering::SeqCst));
        assert!(nacked.load(Ordering::SeqCst));
        assert!(requeued.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_processing_error_propagates() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(5));
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(FailingHandler)).await;

        assert!(matches!(result, Err(WorkerError::ProcessingFailed(_))));
        // Only a timeout triggers a nack here; handler errors pass through untouched.
        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_timeout_cancels_long_running_task() {
        let timeout = Duration::from_millis(50);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        let (message, _, nacked, _) = create_test_message();

        let very_slow_handler = SlowHandler {
            delay: Duration::from_secs(10),
        };

        let start = std::time::Instant::now();
        let result = middleware
            .handle(message, Box::new(very_slow_handler))
            .await;
        let elapsed = start.elapsed();

        assert!(matches!(result, Err(WorkerError::Timeout(_))));
        // The 10s handler must be abandoned at the deadline, not awaited to completion.
        assert!(elapsed < Duration::from_secs(1));
        assert!(nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_completes_just_under_timeout() {
        let timeout = Duration::from_millis(100);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        let (message, _, nacked, _) = create_test_message();

        // 80ms of work against a 100ms budget: must finish, not time out.
        let almost_timeout_handler = SlowHandler {
            delay: Duration::from_millis(80),
        };

        let result = middleware
            .handle(message, Box::new(almost_timeout_handler))
            .await;

        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[test]
    #[should_panic(expected = "Processing timeout must be greater than zero")]
    fn test_zero_timeout_panics() {
        let _ = ProcessingTimeoutMiddleware::new(Duration::ZERO);
    }

    #[test]
    fn test_timeout_getter() {
        let timeout = Duration::from_secs(30);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        assert_eq!(middleware.timeout(), timeout);
    }

    #[test]
    fn test_name() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(1));
        assert_eq!(middleware.name(), "processing-timeout");
    }
}