Skip to main content

foxtive_worker/middleware/
ack_nack.rs

1use async_trait::async_trait;
2
3use crate::error::WorkerError;
4use crate::message::ReceivedMessage;
5use crate::middleware::{MessageHandler, Middleware, MiddlewareResult};
6
/// Middleware that automatically acknowledges or negative-acknowledges messages
/// based on processing results.
///
/// After the rest of the handler chain has run, this middleware:
/// - calls `ack()` when the chain returns `MiddlewareResult::Continue` and
///   `ack_on_success` is true;
/// - calls `nack(requeue_on_nack)` when the chain returns an error and
///   `nack_on_failure` is true, **except** for `WorkerError::RetryableFailure`
///   and `WorkerError::RetriesExhausted`, which are returned untouched (neither
///   acked nor nacked) so the worker pool can handle the delayed retry;
/// - returns `MiddlewareResult::Acknowledged` whenever it settled the message
///   itself, and `WorkerError::AcknowledgmentFailed` if the `ack`/`nack` call
///   failed.
///
/// Note that a failure which is successfully nacked is reported upstream as
/// `Ok(MiddlewareResult::Acknowledged)`: the original processing error is not
/// propagated past this middleware.
///
/// # Example
/// ```rust,no_run
/// use foxtive_worker::AckNackMiddleware;
///
/// // Auto-ack on success, auto-nack with requeue on failure
/// let middleware = AckNackMiddleware::default();
///
/// // Auto-ack on success, nack failures without requeueing them
/// let no_requeue = AckNackMiddleware::with_config(true, true, false);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AckNackMiddleware {
    /// Whether to acknowledge messages on successful processing
    pub ack_on_success: bool,

    /// Whether to negative-acknowledge messages on failed processing
    /// (retry errors are never nacked, regardless of this flag)
    pub nack_on_failure: bool,

    /// Whether to requeue messages when nacking (passed as `requeue` to `nack`)
    pub requeue_on_nack: bool,
}
32
33impl Default for AckNackMiddleware {
34    fn default() -> Self {
35        Self {
36            ack_on_success: true,
37            nack_on_failure: true,
38            requeue_on_nack: true,
39        }
40    }
41}
42
43impl AckNackMiddleware {
44    /// Create a new AckNackMiddleware with default settings.
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    /// Create a new AckNackMiddleware with custom settings.
50    pub fn with_config(ack_on_success: bool, nack_on_failure: bool, requeue_on_nack: bool) -> Self {
51        Self {
52            ack_on_success,
53            nack_on_failure,
54            requeue_on_nack,
55        }
56    }
57}
58
59#[async_trait]
60impl Middleware for AckNackMiddleware {
61    fn name(&self) -> &str {
62        "ack-nack"
63    }
64
65    async fn handle(
66        &self,
67        message: ReceivedMessage<serde_json::Value>,
68        next: Box<dyn MessageHandler>,
69    ) -> Result<MiddlewareResult, WorkerError> {
70        let result = next.handle(message.clone()).await;
71
72        match result {
73            Ok(MiddlewareResult::Continue) if self.ack_on_success => {
74                // Acknowledge successful processing
75                message.ack().await.map_err(|e| {
76                    tracing::error!("Failed to ack message {}: {}", message.message.id, e);
77                    WorkerError::AcknowledgmentFailed(format!(
78                        "Message {} processed successfully but ack failed: {}",
79                        message.message.id, e
80                    ))
81                })?;
82                // Signal that acknowledgment was handled
83                Ok(MiddlewareResult::Acknowledged)
84            }
85            Err(e) if self.nack_on_failure => {
86                // Don't nack retry-related errors - let the pool handle delayed retry
87                match &e {
88                    WorkerError::RetryableFailure { .. } | WorkerError::RetriesExhausted { .. } => {
89                        tracing::debug!(
90                            "[AckNackMiddleware] Passing through retry error for message {}: {:?}",
91                            message.message.id,
92                            e
93                        );
94                        return Err(e);
95                    }
96                    _ => {}
97                }
98                
99                // Negative-acknowledge other failed processing
100                if let Err(nack_err) = message.nack(self.requeue_on_nack).await {
101                    tracing::error!(
102                        "Failed to nack message {}: {} (original error: {})",
103                        message.message.id,
104                        nack_err,
105                        e
106                    );
107                    // Return combined error to inform upstream about both failures
108                    return Err(WorkerError::AcknowledgmentFailed(format!(
109                        "Message {} processing failed and nack also failed: {} (original: {})",
110                        message.message.id, nack_err, e
111                    )));
112                }
113                // Successfully nacked - signal that acknowledgment was handled
114                Ok(MiddlewareResult::Acknowledged)
115            }
116            other => other,
117        }
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Ack handle that records which settlement calls were made and can be
    /// configured to make them fail.
    #[derive(Debug)]
    struct MockAckHandle {
        acked: Arc<AtomicBool>,
        nacked: Arc<AtomicBool>,
        requeued: Arc<AtomicBool>,
        /// When true, `ack`/`nack` still record the call but then return an error.
        fail: bool,
    }

    impl MockAckHandle {
        /// Returns the handle plus shared `(acked, nacked, requeued)` flags.
        fn new(fail: bool) -> (Self, Arc<AtomicBool>, Arc<AtomicBool>, Arc<AtomicBool>) {
            let acked = Arc::new(AtomicBool::new(false));
            let nacked = Arc::new(AtomicBool::new(false));
            let requeued = Arc::new(AtomicBool::new(false));
            (
                Self {
                    acked: acked.clone(),
                    nacked: nacked.clone(),
                    requeued: requeued.clone(),
                    fail,
                },
                acked,
                nacked,
                requeued,
            )
        }

        /// Result reported by `ack`/`nack`, depending on `fail`.
        fn outcome(&self) -> crate::WorkerResult<()> {
            if self.fail {
                Err(WorkerError::ProcessingFailed("broker unavailable".to_string()))
            } else {
                Ok(())
            }
        }
    }

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> crate::WorkerResult<()> {
            self.acked.store(true, Ordering::SeqCst);
            self.outcome()
        }

        async fn nack(&self, requeue: bool) -> crate::WorkerResult<()> {
            self.nacked.store(true, Ordering::SeqCst);
            self.requeued.store(requeue, Ordering::SeqCst);
            self.outcome()
        }
    }

    struct SuccessHandler;

    #[async_trait]
    impl MessageHandler for SuccessHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Ok(MiddlewareResult::Continue)
        }
    }

    struct FailureHandler;

    #[async_trait]
    impl MessageHandler for FailureHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Err(crate::error::WorkerError::ProcessingFailed(
                "test error".to_string(),
            ))
        }
    }

    /// Builds a test message whose ack handle fails all calls when `fail_acks` is true.
    fn create_message(
        fail_acks: bool,
    ) -> (
        ReceivedMessage<serde_json::Value>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
    ) {
        let (ack_handle, acked, nacked, requeued) = MockAckHandle::new(fail_acks);
        let message = Message {
            id: "test-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        (
            ReceivedMessage::new(message, Arc::new(ack_handle)),
            acked,
            nacked,
            requeued,
        )
    }

    fn create_test_message() -> (
        ReceivedMessage<serde_json::Value>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
    ) {
        create_message(false)
    }

    #[tokio::test]
    async fn test_ack_on_success() {
        let middleware = AckNackMiddleware::new();
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(SuccessHandler)).await;
        // Middleware returns Acknowledged to signal it handled the ack
        assert!(matches!(result, Ok(MiddlewareResult::Acknowledged)));

        assert!(acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_nack_on_failure() {
        let middleware = AckNackMiddleware::new();
        let (message, acked, nacked, requeued) = create_test_message();

        let result = middleware.handle(message, Box::new(FailureHandler)).await;
        // Middleware returns Acknowledged after successfully nacking
        assert!(matches!(result, Ok(MiddlewareResult::Acknowledged)));

        assert!(!acked.load(Ordering::SeqCst));
        assert!(nacked.load(Ordering::SeqCst));
        assert!(requeued.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_no_ack_on_success_when_disabled() {
        let middleware = AckNackMiddleware::with_config(false, true, true);
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(SuccessHandler)).await;
        // The handler's own result is passed through untouched
        assert!(matches!(result, Ok(MiddlewareResult::Continue)));

        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_no_nack_on_failure_when_disabled() {
        let middleware = AckNackMiddleware::with_config(true, false, true);
        let (message, _, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(FailureHandler)).await;
        // The processing error is propagated instead of being settled here
        assert!(matches!(result, Err(WorkerError::ProcessingFailed(_))));

        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_nack_without_requeue() {
        let middleware = AckNackMiddleware::with_config(true, true, false);
        let (message, _, nacked, requeued) = create_test_message();

        let result = middleware.handle(message, Box::new(FailureHandler)).await;
        assert!(matches!(result, Ok(MiddlewareResult::Acknowledged)));

        assert!(nacked.load(Ordering::SeqCst));
        assert!(!requeued.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_ack_failure_is_reported() {
        let middleware = AckNackMiddleware::new();
        let (message, acked, nacked, _) = create_message(true);

        let result = middleware.handle(message, Box::new(SuccessHandler)).await;
        assert!(matches!(result, Err(WorkerError::AcknowledgmentFailed(_))));

        assert!(acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_nack_failure_is_reported() {
        let middleware = AckNackMiddleware::new();
        let (message, acked, nacked, _) = create_message(true);

        let result = middleware.handle(message, Box::new(FailureHandler)).await;
        assert!(matches!(result, Err(WorkerError::AcknowledgmentFailed(_))));

        assert!(!acked.load(Ordering::SeqCst));
        assert!(nacked.load(Ordering::SeqCst));
    }

    struct RetryFailureHandler;

    #[async_trait]
    impl MessageHandler for RetryFailureHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Err(WorkerError::RetryableFailure {
                source: Box::new(WorkerError::ProcessingFailed("retry error".to_string())),
                delay_ms: std::time::Duration::from_secs(1),
            })
        }
    }

    #[tokio::test]
    async fn test_passthrough_retry_failure() {
        let middleware = AckNackMiddleware::new();
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(RetryFailureHandler)).await;

        // Should return the RetryableFailure error without nacking
        assert!(
            matches!(result, Err(WorkerError::RetryableFailure { .. })),
            "Expected RetryableFailure to be passed through"
        );

        // Verify no ack or nack was called
        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    struct RetriesExhaustedHandler;

    #[async_trait]
    impl MessageHandler for RetriesExhaustedHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Err(WorkerError::RetriesExhausted {
                source: Box::new(WorkerError::ProcessingFailed("exhausted".to_string())),
            })
        }
    }

    #[tokio::test]
    async fn test_passthrough_retries_exhausted() {
        let middleware = AckNackMiddleware::new();
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(RetriesExhaustedHandler)).await;

        // Should return the RetriesExhausted error without nacking
        assert!(
            matches!(result, Err(WorkerError::RetriesExhausted { .. })),
            "Expected RetriesExhausted to be passed through"
        );

        // Verify no ack or nack was called
        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }
}