foxtive_worker/middleware/
processing_timeout.rs1use async_trait::async_trait;
2use std::time::Duration;
3
4use crate::error::WorkerError;
5use crate::message::ReceivedMessage;
6use crate::middleware::{MessageHandler, Middleware, MiddlewareResult};
7
/// Middleware that caps how long the rest of the handler chain may spend on a
/// single message.
///
/// If the downstream handler does not finish within `timeout`, its future is
/// dropped, the message is nacked with `requeue = true`, and
/// `WorkerError::Timeout` is returned to the caller.
#[derive(Debug, Clone)]
pub struct ProcessingTimeoutMiddleware {
    /// Upper bound on downstream processing time; never zero (enforced by `new`).
    timeout: Duration,
}
50
51impl ProcessingTimeoutMiddleware {
52 pub fn new(timeout: Duration) -> Self {
60 assert!(
61 !timeout.is_zero(),
62 "Processing timeout must be greater than zero"
63 );
64 Self { timeout }
65 }
66
67 pub fn timeout(&self) -> Duration {
69 self.timeout
70 }
71}
72
73#[async_trait]
74impl Middleware for ProcessingTimeoutMiddleware {
75 fn name(&self) -> &str {
76 "processing-timeout"
77 }
78
79 async fn handle(
80 &self,
81 message: ReceivedMessage<serde_json::Value>,
82 next: Box<dyn MessageHandler>,
83 ) -> Result<MiddlewareResult, WorkerError> {
84 let message_id = message.message.id.clone();
85
86 tracing::debug!(
87 message_id = %message_id,
88 timeout_ms = self.timeout.as_millis(),
89 "Starting message processing with timeout"
90 );
91
92 match tokio::time::timeout(self.timeout, next.handle(message.clone())).await {
96 Ok(result) => {
97 result
99 }
100 Err(_) => {
101 tracing::warn!(
103 message_id = %message_id,
104 timeout_ms = self.timeout.as_millis(),
105 "Message processing timed out - nacking with requeue"
106 );
107
108 if let Err(nack_err) = message.nack(true).await {
110 tracing::error!(
111 message_id = %message_id,
112 error = %nack_err,
113 "Failed to nack timed-out message"
114 );
115 }
116
117 Err(WorkerError::Timeout(format!(
119 "Message {} processing exceeded timeout of {:?}",
120 message_id, self.timeout
121 )))
122 }
123 }
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Ack handle that records which acknowledgement calls were made.
    #[derive(Debug)]
    struct MockAckHandle {
        acked: Arc<AtomicBool>,
        nacked: Arc<AtomicBool>,
        requeued: Arc<AtomicBool>,
    }

    impl MockAckHandle {
        /// Returns the handle plus shared `(acked, nacked, requeued)` flags the
        /// test can still inspect after the handle is moved into a message.
        fn new() -> (Self, Arc<AtomicBool>, Arc<AtomicBool>, Arc<AtomicBool>) {
            let acked = Arc::new(AtomicBool::new(false));
            let nacked = Arc::new(AtomicBool::new(false));
            let requeued = Arc::new(AtomicBool::new(false));
            (
                Self {
                    acked: acked.clone(),
                    nacked: nacked.clone(),
                    requeued: requeued.clone(),
                },
                acked,
                nacked,
                requeued,
            )
        }
    }

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> crate::WorkerResult<()> {
            self.acked.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn nack(&self, requeue: bool) -> crate::WorkerResult<()> {
            self.nacked.store(true, Ordering::SeqCst);
            self.requeued.store(requeue, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Handler that succeeds immediately.
    struct FastHandler;

    #[async_trait]
    impl MessageHandler for FastHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Ok(MiddlewareResult::Continue)
        }
    }

    /// Handler that sleeps for `delay` and then succeeds.
    struct SlowHandler {
        delay: Duration,
    }

    #[async_trait]
    impl MessageHandler for SlowHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            tokio::time::sleep(self.delay).await;
            Ok(MiddlewareResult::Continue)
        }
    }

    /// Handler that fails immediately with `ProcessingFailed`.
    struct FailingHandler;

    #[async_trait]
    impl MessageHandler for FailingHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Err(WorkerError::ProcessingFailed(
                "intentional failure".to_string(),
            ))
        }
    }

    /// Builds a message backed by a `MockAckHandle`, returning the message and
    /// its `(acked, nacked, requeued)` flags.
    fn create_test_message() -> (
        ReceivedMessage<serde_json::Value>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
    ) {
        let (ack_handle, acked, nacked, requeued) = MockAckHandle::new();
        let message = Message {
            id: "test-msg-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        (
            ReceivedMessage::new(message, Arc::new(ack_handle)),
            acked,
            nacked,
            requeued,
        )
    }

    #[tokio::test]
    async fn test_fast_processing_completes() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(5));
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(FastHandler)).await;

        // The downstream result must be passed through unchanged.
        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
        // Acknowledgement of an in-time message is not this middleware's job.
        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_slow_processing_times_out() {
        let timeout = Duration::from_millis(100);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        let (message, _, nacked, requeued) = create_test_message();

        let slow_handler = SlowHandler {
            delay: Duration::from_secs(1),
        };

        let result = middleware.handle(message, Box::new(slow_handler)).await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), WorkerError::Timeout(_)));
        assert!(nacked.load(Ordering::SeqCst));
        assert!(requeued.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_processing_error_propagates() {
        let middleware = ProcessingTimeoutMiddleware::new(Duration::from_secs(5));
        let (message, _, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(FailingHandler)).await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            WorkerError::ProcessingFailed(_)
        ));
        // Only a timeout triggers a nack here; handler errors are left to the caller.
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_timeout_cancels_long_running_task() {
        let timeout = Duration::from_millis(50);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        let (message, _, nacked, _) = create_test_message();

        let very_slow_handler = SlowHandler {
            delay: Duration::from_secs(10),
        };

        let start = std::time::Instant::now();
        let result = middleware
            .handle(message, Box::new(very_slow_handler))
            .await;
        let elapsed = start.elapsed();

        assert!(result.is_err());
        // Returning long before the handler's 10s sleep shows it was abandoned.
        assert!(elapsed < Duration::from_secs(1));
        assert!(nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_processing_within_timeout_completes() {
        // The margin between handler delay and timeout is deliberately wide:
        // the previous 80ms-vs-100ms setup was prone to flaking on loaded CI
        // machines where timer wake-ups can be late by tens of milliseconds.
        let timeout = Duration::from_millis(300);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        let (message, _, nacked, _) = create_test_message();

        let within_budget_handler = SlowHandler {
            delay: Duration::from_millis(100),
        };

        let result = middleware
            .handle(message, Box::new(within_budget_handler))
            .await;

        assert!(result.is_ok());
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[test]
    #[should_panic(expected = "Processing timeout must be greater than zero")]
    fn test_zero_timeout_panics() {
        let _ = ProcessingTimeoutMiddleware::new(Duration::ZERO);
    }

    #[test]
    fn test_timeout_getter() {
        let timeout = Duration::from_secs(30);
        let middleware = ProcessingTimeoutMiddleware::new(timeout);
        assert_eq!(middleware.timeout(), timeout);
    }
}