//! Ack/nack middleware for `foxtive_worker` (`foxtive_worker/middleware/ack_nack.rs`).
//!
//! Settles each received message after the rest of the middleware chain has handled it.

use async_trait::async_trait;

use crate::error::WorkerError;
use crate::message::ReceivedMessage;
use crate::middleware::{MessageHandler, Middleware, MiddlewareResult};

/// Middleware that acks or nacks a message after the downstream handler has run.
///
/// Each flag independently switches one part of that behavior on or off. The `Default`
/// configuration sets all three flags to `true`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AckNackMiddleware {
    /// Ack the message when the downstream handler returns `MiddlewareResult::Continue`.
    pub ack_on_success: bool,

    /// Nack the message when the downstream handler returns an error. The exceptions are
    /// `WorkerError::RetryableFailure` and `WorkerError::RetriesExhausted`, which are passed
    /// through untouched.
    pub nack_on_failure: bool,

    /// Value forwarded as the `requeue` argument when the message is nacked.
    pub requeue_on_nack: bool,
}

33impl Default for AckNackMiddleware {
34 fn default() -> Self {
35 Self {
36 ack_on_success: true,
37 nack_on_failure: true,
38 requeue_on_nack: true,
39 }
40 }
41}
42
43impl AckNackMiddleware {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn with_config(ack_on_success: bool, nack_on_failure: bool, requeue_on_nack: bool) -> Self {
51 Self {
52 ack_on_success,
53 nack_on_failure,
54 requeue_on_nack,
55 }
56 }
57}
58
59#[async_trait]
60impl Middleware for AckNackMiddleware {
61 fn name(&self) -> &str {
62 "ack-nack"
63 }
64
65 async fn handle(
66 &self,
67 message: ReceivedMessage<serde_json::Value>,
68 next: Box<dyn MessageHandler>,
69 ) -> Result<MiddlewareResult, WorkerError> {
70 let result = next.handle(message.clone()).await;
71
72 match result {
73 Ok(MiddlewareResult::Continue) if self.ack_on_success => {
74 message.ack().await.map_err(|e| {
76 tracing::error!("Failed to ack message {}: {}", message.message.id, e);
77 WorkerError::AcknowledgmentFailed(format!(
78 "Message {} processed successfully but ack failed: {}",
79 message.message.id, e
80 ))
81 })?;
82 Ok(MiddlewareResult::Acknowledged)
84 }
85 Err(e) if self.nack_on_failure => {
86 match &e {
88 WorkerError::RetryableFailure { .. } | WorkerError::RetriesExhausted { .. } => {
89 tracing::debug!(
90 "[AckNackMiddleware] Passing through retry error for message {}: {:?}",
91 message.message.id,
92 e
93 );
94 return Err(e);
95 }
96 _ => {}
97 }
98
99 if let Err(nack_err) = message.nack(self.requeue_on_nack).await {
101 tracing::error!(
102 "Failed to nack message {}: {} (original error: {})",
103 message.message.id,
104 nack_err,
105 e
106 );
107 return Err(WorkerError::AcknowledgmentFailed(format!(
109 "Message {} processing failed and nack also failed: {} (original: {})",
110 message.message.id, nack_err, e
111 )));
112 }
113 Ok(MiddlewareResult::Acknowledged)
115 }
116 other => other,
117 }
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Ack handle that records which settlement calls were made and the requeue flag it saw.
    #[derive(Debug)]
    struct MockAckHandle {
        acked: Arc<AtomicBool>,
        nacked: Arc<AtomicBool>,
        requeued: Arc<AtomicBool>,
    }

    impl MockAckHandle {
        /// Returns the handle plus shared views of its `acked`, `nacked` and `requeued` flags.
        fn new() -> (Self, Arc<AtomicBool>, Arc<AtomicBool>, Arc<AtomicBool>) {
            let acked = Arc::new(AtomicBool::new(false));
            let nacked = Arc::new(AtomicBool::new(false));
            let requeued = Arc::new(AtomicBool::new(false));
            (
                Self {
                    acked: acked.clone(),
                    nacked: nacked.clone(),
                    requeued: requeued.clone(),
                },
                acked,
                nacked,
                requeued,
            )
        }
    }

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> crate::WorkerResult<()> {
            self.acked.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn nack(&self, requeue: bool) -> crate::WorkerResult<()> {
            self.nacked.store(true, Ordering::SeqCst);
            self.requeued.store(requeue, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Ack handle whose `ack` and `nack` always fail, used to exercise the error-mapping paths.
    #[derive(Debug)]
    struct FailingAckHandle;

    #[async_trait]
    impl AckHandle for FailingAckHandle {
        async fn ack(&self) -> crate::WorkerResult<()> {
            Err(WorkerError::ProcessingFailed("ack refused".to_string()))
        }

        async fn nack(&self, _requeue: bool) -> crate::WorkerResult<()> {
            Err(WorkerError::ProcessingFailed("nack refused".to_string()))
        }
    }

    struct SuccessHandler;

    #[async_trait]
    impl MessageHandler for SuccessHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Ok(MiddlewareResult::Continue)
        }
    }

    struct FailureHandler;

    #[async_trait]
    impl MessageHandler for FailureHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Err(crate::error::WorkerError::ProcessingFailed(
                "test error".to_string(),
            ))
        }
    }

    struct RetryFailureHandler;

    #[async_trait]
    impl MessageHandler for RetryFailureHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Err(WorkerError::RetryableFailure {
                source: Box::new(WorkerError::ProcessingFailed("retry error".to_string())),
                delay_ms: std::time::Duration::from_secs(1),
            })
        }
    }

    struct RetriesExhaustedHandler;

    #[async_trait]
    impl MessageHandler for RetriesExhaustedHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Err(WorkerError::RetriesExhausted {
                source: Box::new(WorkerError::ProcessingFailed("exhausted".to_string())),
            })
        }
    }

    /// Builds a message backed by `MockAckHandle`, returning its acked/nacked/requeued flags.
    fn create_test_message() -> (
        ReceivedMessage<serde_json::Value>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
    ) {
        let (ack_handle, acked, nacked, requeued) = MockAckHandle::new();
        let message = Message {
            id: "test-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        (
            ReceivedMessage::new(message, Arc::new(ack_handle)),
            acked,
            nacked,
            requeued,
        )
    }

    /// Builds a message whose ack/nack calls always fail.
    fn create_failing_message() -> ReceivedMessage<serde_json::Value> {
        let message = Message {
            id: "test-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        ReceivedMessage::new(message, Arc::new(FailingAckHandle))
    }

    #[test]
    fn test_name() {
        assert_eq!(AckNackMiddleware::new().name(), "ack-nack");
    }

    #[tokio::test]
    async fn test_ack_on_success() {
        let middleware = AckNackMiddleware::new();
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(SuccessHandler)).await;
        assert!(matches!(result, Ok(MiddlewareResult::Acknowledged)));

        assert!(acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_nack_on_failure() {
        let middleware = AckNackMiddleware::new();
        let (message, acked, nacked, requeued) = create_test_message();

        let result = middleware.handle(message, Box::new(FailureHandler)).await;
        assert!(matches!(result, Ok(MiddlewareResult::Acknowledged)));

        assert!(!acked.load(Ordering::SeqCst));
        assert!(nacked.load(Ordering::SeqCst));
        assert!(requeued.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_no_ack_on_success_when_disabled() {
        let middleware = AckNackMiddleware::with_config(false, true, true);
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(SuccessHandler)).await;
        // With acking disabled the downstream result must be passed through unchanged.
        assert!(matches!(result, Ok(MiddlewareResult::Continue)));

        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_no_nack_on_failure_when_disabled() {
        let middleware = AckNackMiddleware::with_config(true, false, true);
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(FailureHandler)).await;
        // With nacking disabled the original error must surface instead of being swallowed.
        assert!(matches!(result, Err(WorkerError::ProcessingFailed(_))));

        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_nack_without_requeue() {
        let middleware = AckNackMiddleware::with_config(true, true, false);
        let (message, _, nacked, requeued) = create_test_message();

        let result = middleware.handle(message, Box::new(FailureHandler)).await;
        assert!(matches!(result, Ok(MiddlewareResult::Acknowledged)));

        assert!(nacked.load(Ordering::SeqCst));
        assert!(!requeued.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_ack_failure_is_reported() {
        let middleware = AckNackMiddleware::new();

        let result = middleware
            .handle(create_failing_message(), Box::new(SuccessHandler))
            .await;

        assert!(matches!(result, Err(WorkerError::AcknowledgmentFailed(_))));
    }

    #[tokio::test]
    async fn test_nack_failure_is_reported() {
        let middleware = AckNackMiddleware::new();

        let result = middleware
            .handle(create_failing_message(), Box::new(FailureHandler))
            .await;

        assert!(matches!(result, Err(WorkerError::AcknowledgmentFailed(_))));
    }

    #[tokio::test]
    async fn test_passthrough_retry_failure() {
        let middleware = AckNackMiddleware::new();
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(RetryFailureHandler)).await;

        assert!(
            matches!(result, Err(WorkerError::RetryableFailure { .. })),
            "Expected RetryableFailure to be passed through"
        );
        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_passthrough_retry_failure_when_nack_disabled() {
        let middleware = AckNackMiddleware::with_config(true, false, true);
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(RetryFailureHandler)).await;

        assert!(
            matches!(result, Err(WorkerError::RetryableFailure { .. })),
            "Expected RetryableFailure to be passed through"
        );
        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_passthrough_retries_exhausted() {
        let middleware = AckNackMiddleware::new();
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware
            .handle(message, Box::new(RetriesExhaustedHandler))
            .await;

        assert!(
            matches!(result, Err(WorkerError::RetriesExhausted { .. })),
            "Expected RetriesExhausted to be passed through"
        );
        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }
}