foxtive_worker/middleware/
ack_nack.rs1use async_trait::async_trait;
2
3use crate::error::WorkerError;
4use crate::message::ReceivedMessage;
5use crate::middleware::{MessageHandler, Middleware, MiddlewareResult};
6
/// Middleware that acknowledges or rejects a message according to the outcome
/// of the rest of the handler chain.
///
/// The three flags are read by the [`Middleware`] implementation in this file:
/// they decide whether a successful result triggers an ack, whether a failed
/// result triggers a nack, and which `requeue` value is passed along with the
/// nack.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AckNackMiddleware {
    /// Ack the message when the downstream handler returns
    /// `Ok(MiddlewareResult::Continue)`.
    pub ack_on_success: bool,

    /// Nack the message when the downstream handler returns an error.
    pub nack_on_failure: bool,

    /// Value forwarded as the `requeue` argument of the nack call.
    pub requeue_on_nack: bool,
}

impl Default for AckNackMiddleware {
    /// Same configuration as [`AckNackMiddleware::new`]: every flag enabled.
    fn default() -> Self {
        Self::new()
    }
}

impl AckNackMiddleware {
    /// Creates a middleware with ack-on-success, nack-on-failure and
    /// requeue-on-nack all enabled.
    #[must_use]
    pub const fn new() -> Self {
        // Single source of truth for the default flag values; `Default`
        // delegates here so the two can never drift apart.
        Self::with_config(true, true, true)
    }

    /// Creates a middleware with each flag set explicitly.
    ///
    /// Arguments are, in order: `ack_on_success`, `nack_on_failure`,
    /// `requeue_on_nack`.
    #[must_use]
    pub const fn with_config(
        ack_on_success: bool,
        nack_on_failure: bool,
        requeue_on_nack: bool,
    ) -> Self {
        Self {
            ack_on_success,
            nack_on_failure,
            requeue_on_nack,
        }
    }
}
58
59#[async_trait]
60impl Middleware for AckNackMiddleware {
61 fn name(&self) -> &str {
62 "ack-nack"
63 }
64
65 async fn handle(
66 &self,
67 message: ReceivedMessage<serde_json::Value>,
68 next: Box<dyn MessageHandler>,
69 ) -> Result<MiddlewareResult, WorkerError> {
70 let result = next.handle(message.clone()).await;
71
72 match result {
73 Ok(MiddlewareResult::Continue) if self.ack_on_success => {
74 message.ack().await.map_err(|e| {
76 tracing::error!("Failed to ack message {}: {}", message.message.id, e);
77 WorkerError::AcknowledgmentFailed(format!(
78 "Message {} processed successfully but ack failed: {}",
79 message.message.id, e
80 ))
81 })?;
82 Ok(MiddlewareResult::Acknowledged)
84 }
85 Err(e) if self.nack_on_failure => {
86 if let Err(nack_err) = message.nack(self.requeue_on_nack).await {
88 tracing::error!(
89 "Failed to nack message {}: {} (original error: {})",
90 message.message.id,
91 nack_err,
92 e
93 );
94 return Err(WorkerError::AcknowledgmentFailed(format!(
96 "Message {} processing failed and nack also failed: {} (original: {})",
97 message.message.id, nack_err, e
98 )));
99 }
100 Ok(MiddlewareResult::Acknowledged)
102 }
103 other => other,
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Ack handle that records which acknowledgement calls were made instead
    /// of talking to a broker.
    #[derive(Debug)]
    struct MockAckHandle {
        acked: Arc<AtomicBool>,
        nacked: Arc<AtomicBool>,
        requeued: Arc<AtomicBool>,
    }

    impl MockAckHandle {
        /// Returns the handle together with the shared `(acked, nacked,
        /// requeued)` flags so a test can inspect them after the handle has
        /// been moved into a message.
        fn new() -> (Self, Arc<AtomicBool>, Arc<AtomicBool>, Arc<AtomicBool>) {
            let acked = Arc::new(AtomicBool::new(false));
            let nacked = Arc::new(AtomicBool::new(false));
            let requeued = Arc::new(AtomicBool::new(false));
            (
                Self {
                    acked: acked.clone(),
                    nacked: nacked.clone(),
                    requeued: requeued.clone(),
                },
                acked,
                nacked,
                requeued,
            )
        }
    }

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> crate::WorkerResult<()> {
            self.acked.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn nack(&self, requeue: bool) -> crate::WorkerResult<()> {
            self.nacked.store(true, Ordering::SeqCst);
            // Mirrors the flag the middleware passed in.
            self.requeued.store(requeue, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Downstream handler that always reports `Continue`.
    struct SuccessHandler;

    #[async_trait]
    impl MessageHandler for SuccessHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Ok(MiddlewareResult::Continue)
        }
    }

    /// Downstream handler that always fails with `ProcessingFailed`.
    struct FailureHandler;

    #[async_trait]
    impl MessageHandler for FailureHandler {
        async fn handle(
            &self,
            _message: ReceivedMessage<serde_json::Value>,
        ) -> Result<MiddlewareResult, WorkerError> {
            Err(WorkerError::ProcessingFailed("test error".to_string()))
        }
    }

    /// Builds a message backed by a `MockAckHandle` and returns it with the
    /// handle's `(acked, nacked, requeued)` flags.
    fn create_test_message() -> (
        ReceivedMessage<serde_json::Value>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
        Arc<AtomicBool>,
    ) {
        let (ack_handle, acked, nacked, requeued) = MockAckHandle::new();
        let message = Message {
            id: "test-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        (
            ReceivedMessage::new(message, Arc::new(ack_handle)),
            acked,
            nacked,
            requeued,
        )
    }

    #[tokio::test]
    async fn test_ack_on_success() {
        let middleware = AckNackMiddleware::new();
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(SuccessHandler)).await;
        assert!(matches!(result, Ok(MiddlewareResult::Acknowledged)));

        assert!(acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_nack_on_failure() {
        let middleware = AckNackMiddleware::new();
        let (message, acked, nacked, requeued) = create_test_message();

        let result = middleware.handle(message, Box::new(FailureHandler)).await;
        assert!(matches!(result, Ok(MiddlewareResult::Acknowledged)));

        assert!(!acked.load(Ordering::SeqCst));
        assert!(nacked.load(Ordering::SeqCst));
        assert!(requeued.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_no_ack_on_success_when_disabled() {
        let middleware = AckNackMiddleware::with_config(false, true, true);
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(SuccessHandler)).await;
        // With acking disabled the downstream result must pass through as-is.
        assert!(matches!(result, Ok(MiddlewareResult::Continue)));

        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_no_nack_on_failure_when_disabled() {
        let middleware = AckNackMiddleware::with_config(true, false, true);
        let (message, acked, nacked, _) = create_test_message();

        let result = middleware.handle(message, Box::new(FailureHandler)).await;
        // With nacking disabled the downstream error must surface unchanged.
        assert!(matches!(result, Err(WorkerError::ProcessingFailed(_))));

        assert!(!acked.load(Ordering::SeqCst));
        assert!(!nacked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_nack_without_requeue() {
        let middleware = AckNackMiddleware::with_config(true, true, false);
        let (message, acked, nacked, requeued) = create_test_message();

        let result = middleware.handle(message, Box::new(FailureHandler)).await;
        assert!(matches!(result, Ok(MiddlewareResult::Acknowledged)));

        assert!(!acked.load(Ordering::SeqCst));
        assert!(nacked.load(Ordering::SeqCst));
        assert!(!requeued.load(Ordering::SeqCst));
    }
}