Skip to main content

foxtive_worker/middleware/
ack_nack.rs

1use async_trait::async_trait;
2
3use crate::error::WorkerError;
4use crate::message::ReceivedMessage;
5use crate::middleware::{MessageHandler, Middleware, MiddlewareResult};
6
/// Middleware that automatically acknowledges or negative-acknowledges messages
/// based on processing results.
///
/// The middleware wraps the downstream handler and, once it has finished:
/// - calls `ack()` if the handler returned `MiddlewareResult::Continue`
///   (when `ack_on_success` is true)
/// - calls `nack()` if the handler returned an error
///   (when `nack_on_failure` is true), passing `requeue_on_nack` as the
///   requeue flag
///
/// Any other outcome is handed back to the caller untouched.
///
/// # Examples
/// ```rust,no_run
/// use foxtive_worker::AckNackMiddleware;
///
/// // Auto-ack on success, auto-nack with requeue on failure
/// let middleware = AckNackMiddleware::default();
/// ```
// `PartialEq`/`Eq` let callers and tests compare configurations directly;
// every field is a plain `bool`, so the derives are free.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AckNackMiddleware {
    /// Whether to acknowledge messages on successful processing
    pub ack_on_success: bool,

    /// Whether to negative-acknowledge messages on failed processing
    pub nack_on_failure: bool,

    /// Whether to requeue messages when nacking
    pub requeue_on_nack: bool,
}
32
33impl Default for AckNackMiddleware {
34    fn default() -> Self {
35        Self {
36            ack_on_success: true,
37            nack_on_failure: true,
38            requeue_on_nack: true,
39        }
40    }
41}
42
43impl AckNackMiddleware {
44    /// Create a new AckNackMiddleware with default settings.
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    /// Create a new AckNackMiddleware with custom settings.
50    pub fn with_config(ack_on_success: bool, nack_on_failure: bool, requeue_on_nack: bool) -> Self {
51        Self {
52            ack_on_success,
53            nack_on_failure,
54            requeue_on_nack,
55        }
56    }
57}
58
59#[async_trait]
60impl Middleware for AckNackMiddleware {
61    fn name(&self) -> &str {
62        "ack-nack"
63    }
64
65    async fn handle(
66        &self,
67        message: ReceivedMessage<serde_json::Value>,
68        next: Box<dyn MessageHandler>,
69    ) -> Result<MiddlewareResult, WorkerError> {
70        let result = next.handle(message.clone()).await;
71
72        match result {
73            Ok(MiddlewareResult::Continue) if self.ack_on_success => {
74                // Acknowledge successful processing
75                message.ack().await.map_err(|e| {
76                    tracing::error!("Failed to ack message {}: {}", message.message.id, e);
77                    WorkerError::AcknowledgmentFailed(format!(
78                        "Message {} processed successfully but ack failed: {}",
79                        message.message.id, e
80                    ))
81                })?;
82                // Signal that acknowledgment was handled
83                Ok(MiddlewareResult::Acknowledged)
84            }
85            Err(e) if self.nack_on_failure => {
86                // Negative-acknowledge failed processing
87                if let Err(nack_err) = message.nack(self.requeue_on_nack).await {
88                    tracing::error!(
89                        "Failed to nack message {}: {} (original error: {})",
90                        message.message.id,
91                        nack_err,
92                        e
93                    );
94                    // Return combined error to inform upstream about both failures
95                    return Err(WorkerError::AcknowledgmentFailed(format!(
96                        "Message {} processing failed and nack also failed: {} (original: {})",
97                        message.message.id, nack_err, e
98                    )));
99                }
100                // Successfully nacked - signal that acknowledgment was handled
101                Ok(MiddlewareResult::Acknowledged)
102            }
103            other => other,
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Ack handle that records which calls were made in shared flags.
    #[derive(Debug)]
    struct MockAckHandle {
        acked: Arc<AtomicBool>,
        nacked: Arc<AtomicBool>,
        requeued: Arc<AtomicBool>,
    }

    impl MockAckHandle {
        /// Returns the handle together with its `(acked, nacked, requeued)` flags.
        fn new() -> (Self, Arc<AtomicBool>, Arc<AtomicBool>, Arc<AtomicBool>) {
            let acked = Arc::new(AtomicBool::new(false));
            let nacked = Arc::new(AtomicBool::new(false));
            let requeued = Arc::new(AtomicBool::new(false));
            (
                Self {
                    acked: acked.clone(),
                    nacked: nacked.clone(),
                    requeued: requeued.clone(),
                },
                acked,
                nacked,
                requeued,
            )
        }
    }

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> crate::WorkerResult<()> {
            self.acked.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn nack(&self, requeue: bool) -> crate::WorkerResult<()> {
            self.nacked.store(true, Ordering::SeqCst);
            self.requeued.store(requeue, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Handler that always reports `MiddlewareResult::Continue`.
    struct SuccessHandler;

    #[async_trait]
    impl MessageHandler for SuccessHandler {
        async fn handle(&self, _message: ReceivedMessage<serde_json::Value>) -> Result<MiddlewareResult, WorkerError> {
            Ok(MiddlewareResult::Continue)
        }
    }

    /// Handler that always fails with `WorkerError::ProcessingFailed`.
    struct FailureHandler;

    #[async_trait]
    impl MessageHandler for FailureHandler {
        async fn handle(&self, _message: ReceivedMessage<serde_json::Value>) -> Result<MiddlewareResult, WorkerError> {
            Err(crate::error::WorkerError::ProcessingFailed(
                "test error".to_string(),
            ))
        }
    }

    /// Builds a message backed by a `MockAckHandle` and returns it with the
    /// handle's `(acked, nacked, requeued)` flags.
    fn create_test_message() -> (
        ReceivedMessage<serde_json::Value>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
    ) {
        let (ack_handle, acked, nacked, requeued) = MockAckHandle::new();
        let message = Message {
            id: "test-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        (
            ReceivedMessage::new(message, Arc::new(ack_handle)),
            acked,
            nacked,
            requeued,
        )
    }

    #[tokio::test]
    async fn test_ack_on_success() {
        let middleware = AckNackMiddleware::new();
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(SuccessHandler)).await;
        // Middleware returns Acknowledged to signal it handled the ack
        assert!(matches!(result, Ok(MiddlewareResult::Acknowledged)));

        assert!(acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_nack_on_failure() {
        let middleware = AckNackMiddleware::new();
        let (message, acked, nacked, requeued) = create_test_message();

        let result = middleware.handle(message, Box::new(FailureHandler)).await;
        // Middleware returns Acknowledged after successfully nacking
        assert!(matches!(result, Ok(MiddlewareResult::Acknowledged)));

        assert!(!acked.load(Ordering::SeqCst));
        assert!(nacked.load(Ordering::SeqCst));
        assert!(requeued.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_no_ack_on_success_when_disabled() {
        let middleware = AckNackMiddleware::with_config(false, true, true);
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(SuccessHandler)).await;
        // With auto-ack disabled the handler's result is passed through untouched
        assert!(matches!(result, Ok(MiddlewareResult::Continue)));

        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_no_nack_on_failure_when_disabled() {
        let middleware = AckNackMiddleware::with_config(true, false, true);
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(FailureHandler)).await;
        // With auto-nack disabled the handler's error reaches the caller
        assert!(matches!(result, Err(WorkerError::ProcessingFailed(_))));

        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_nack_without_requeue() {
        let middleware = AckNackMiddleware::with_config(true, true, false);
        let (message, acked, nacked, requeued) = create_test_message();

        let result = middleware.handle(message, Box::new(FailureHandler)).await;
        // A successful nack is reported as Acknowledged regardless of requeue
        assert!(matches!(result, Ok(MiddlewareResult::Acknowledged)));

        assert!(!acked.load(Ordering::SeqCst));
        assert!(nacked.load(Ordering::SeqCst));
        assert!(!requeued.load(Ordering::SeqCst));
    }
}