1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
// mq-bridge
// © Copyright 2025, by Marco Mengelkoch
// Licensed under MIT License, see License file for more details
// git clone https://github.com/marcomq/mq-bridge
use crate::models::DeduplicationMiddleware;
use crate::traits::{
into_batch_commit_func, BoxFuture, ConsumerError, MessageConsumer, MessageDisposition,
Received, ReceivedBatch,
};
use anyhow::Context;
use async_trait::async_trait;
use sled::Db;
use std::any::Any;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tracing::{debug, error, info, instrument, trace, warn};
/// Consumer middleware that skips messages whose `message_id` has already been seen.
///
/// Seen IDs are tracked in a sled database; once an entry is older than its TTL it no
/// longer counts as a duplicate.
pub struct DeduplicationConsumer {
    /// The wrapped consumer that actually receives messages.
    inner: Box<dyn MessageConsumer>,
    /// Deduplication store, keyed by the big-endian bytes of the message ID.
    db: Arc<Db>,
    /// How long, in seconds, a processed message ID suppresses later duplicates.
    ttl_seconds: u64,
}
impl DeduplicationConsumer {
    /// Wraps `inner` with message-ID deduplication for the route `route_name`.
    ///
    /// Opens the sled database at `config.sled_path`. Message IDs seen within
    /// `config.ttl_seconds` are skipped by [`MessageConsumer::receive`].
    ///
    /// # Errors
    ///
    /// Returns an error if the database at `config.sled_path` cannot be opened. The error
    /// carries the path and route name as context.
    pub fn new(
        inner: Box<dyn MessageConsumer>,
        config: &DeduplicationMiddleware,
        route_name: &str,
    ) -> anyhow::Result<Self> {
        let db = sled::open(&config.sled_path).with_context(|| {
            format!(
                "Failed to open deduplication database at '{}' for route '{}'",
                config.sled_path, route_name
            )
        })?;
        // Only announce the middleware once its backing store is actually available.
        info!(
            "Deduplication Middleware enabled for route '{}' with TTL {}s",
            route_name, config.ttl_seconds
        );
        Ok(Self {
            inner,
            db: Arc::new(db),
            ttl_seconds: config.ttl_seconds,
        })
    }
}
#[async_trait]
impl MessageConsumer for DeduplicationConsumer {
fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
self.inner.on_connect_hook()
}
fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
let inner_hook = self.inner.on_disconnect_hook();
let db = self.db.clone();
Some(Box::pin(async move {
let mut first_error = None;
if let Some(hook) = inner_hook {
if let Err(err) = hook.await {
first_error = Some(err);
}
}
if let Err(err) = db.flush_async().await {
first_error.get_or_insert_with(|| anyhow::anyhow!(err));
}
match first_error {
Some(err) => Err(err),
None => Ok(()),
}
}))
}
#[instrument(skip_all)]
async fn receive(&mut self) -> Result<Received, ConsumerError> {
loop {
let received = self.inner.receive().await?;
let message = received.message;
let original_commit = received.commit;
let key = message.message_id.to_be_bytes().to_vec();
let message_id_hex = format!("{:032x}", message.message_id);
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.context("System time is before UNIX EPOCH")?
.as_secs();
let now_bytes = now.to_be_bytes();
// Use a prefix to distinguish between pending (0) and processed (1) states.
// Pending state has a short TTL to allow recovery from crashes.
const STATE_PENDING: u8 = 0;
const STATE_PROCESSED: u8 = 1;
const PENDING_TTL: u64 = 5;
let mut pending_val = Vec::with_capacity(9);
pending_val.push(STATE_PENDING);
pending_val.extend_from_slice(&now_bytes);
let mut processed_val = Vec::with_capacity(9);
processed_val.push(STATE_PROCESSED);
processed_val.extend_from_slice(&now_bytes);
// Attempt atomic insert-if-absent to reserve the message ID
let mut is_duplicate = false;
let mut yield_counter = 0;
let mut total_attempts = 0;
const MAX_TOTAL_ATTEMPTS: usize = 1000;
loop {
if total_attempts >= MAX_TOTAL_ATTEMPTS {
return Err(ConsumerError::Connection(anyhow::anyhow!(
"Deduplication CAS exceeded max attempts for message ID {}",
message_id_hex
)));
}
if yield_counter > 10 {
tokio::task::yield_now().await;
yield_counter = 0;
}
yield_counter += 1;
total_attempts += 1;
match self
.db
.compare_and_swap(&key, None::<&[u8]>, Some(pending_val.as_slice()))
{
Ok(Ok(())) => break,
Ok(Err(cas_error)) => {
if let Some(current_bytes) = cas_error.current.as_deref() {
// Key exists. Check if it is within TTL.
let (ts, ttl) = if current_bytes.len() == 9 {
let state = current_bytes[0];
let ts_bytes: [u8; 8] = current_bytes[1..9].try_into().unwrap();
(
u64::from_be_bytes(ts_bytes),
if state == STATE_PENDING {
PENDING_TTL
} else {
self.ttl_seconds
},
)
} else if current_bytes.len() == 8 {
let ts_bytes: [u8; 8] = current_bytes.try_into().unwrap();
(u64::from_be_bytes(ts_bytes), self.ttl_seconds)
} else {
(0, 0) // Invalid length, treat as expired
};
if now.saturating_sub(ts) < ttl {
is_duplicate = true;
break;
}
// Expired or invalid, try to overwrite
match self.db.compare_and_swap(
&key,
Some(current_bytes),
Some(pending_val.as_slice()),
) {
Ok(Ok(())) => break,
Ok(Err(_)) => continue, // Retry
Err(e) => {
return Err(ConsumerError::Connection(anyhow::anyhow!(
"Deduplication DB error: {}",
e
)))
}
}
} else {
continue;
}
}
Err(e) => {
return Err(ConsumerError::Connection(anyhow::anyhow!(
"Deduplication DB error: {}",
e
)))
}
}
}
if is_duplicate {
info!(message_id = %message_id_hex, "Duplicate message detected and skipped");
if let Err(e) = original_commit(MessageDisposition::Ack).await {
warn!("Failed to commit skipped duplicate message: {}", e);
}
continue;
}
let db = self.db.clone();
let key_clone = key.clone();
// Wrap commit to update DB to "processed" state
let commit = Box::new(move |disposition: MessageDisposition| {
Box::pin(async move {
original_commit(disposition).await?;
// Update the pending marker to the final processed value
if let Err(e) = db.insert(&key_clone, processed_val) {
error!(
"Failed to update message as processed in deduplication DB: {}",
e
);
} else {
trace!("Updated message as processed in deduplication DB");
}
Ok(())
}) as crate::traits::BoxFuture<'static, anyhow::Result<()>>
});
// remove outdated
if rand::random::<u8>() < 5 {
// ~2% chance
let db = self.db.clone();
let ttl = self.ttl_seconds;
tokio::spawn(async move {
let now_duration = match SystemTime::now().duration_since(UNIX_EPOCH) {
Ok(duration) => duration,
Err(e) => {
error!("Failed to get system time duration since UNIX_EPOCH for deduplication cleanup: {}", e);
return; // Exit the spawned task if we can't get the current time
}
};
// Use saturating_sub to prevent underflow if ttl is very large, though unlikely for timestamps.
let cutoff = now_duration.as_secs().saturating_sub(ttl);
for item_result in db.iter() {
match item_result {
Ok((key, val)) => {
let len = val.as_ref().len();
let ts_offset = if len == 9 {
1
} else if len == 8 {
0
} else {
warn!("Deduplication DB entry for key {:?} has invalid timestamp length (expected 8 or 9 bytes, got {}). Skipping entry.", key, len);
continue; // Move to the next item
};
// After checking the length, `try_into()` from `&[u8]` to `&[u8; 8]` is infallible.
// However, using `match` explicitly handles the `Err` case for robustness and clarity.
let timestamp_bytes: [u8; 8] = match val.as_ref()
[ts_offset..ts_offset + 8]
.try_into()
{
Ok(bytes) => bytes,
Err(e) => {
error!("Internal error: Failed to convert DB value to [u8; 8] after length check for key {:?}: {}", key, e);
continue; // Move to the next item
}
};
let timestamp = u64::from_be_bytes(timestamp_bytes);
// If the timestamp is older than the cutoff, remove it.
if timestamp < cutoff {
match db.remove(&key) {
Ok(_) => debug!("Removed expired deduplication entry for key: {:?}", key),
Err(e) => error!("Failed to remove expired deduplication entry for key {:?}: {}", key, e),
}
}
}
Err(e) => {
error!("Error iterating deduplication DB during cleanup: {}", e);
continue; // Continue to the next item if iteration itself yields an error
}
}
}
});
}
return Ok(Received { message, commit });
}
}
/// Note: This implementation ignores `_max_messages` and always fetches a single message
/// to ensure correct deduplication logic per message.
async fn receive_batch(
&mut self,
_max_messages: usize,
) -> Result<ReceivedBatch, ConsumerError> {
let received = self.receive().await?;
let commit = into_batch_commit_func(received.commit);
Ok(ReceivedBatch {
messages: vec![received.message],
commit,
})
}
fn as_any(&self) -> &dyn Any {
self
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::endpoints::memory::MemoryConsumer;
    use crate::models::DeduplicationMiddleware;
    use crate::CanonicalMessage;
    use tempfile::tempdir;

    /// Two deliveries share ID 100. The consumer must surface ID 100 once and then ID 101.
    #[tokio::test]
    async fn test_deduplication_logic() {
        // Keep the temp dir bound for the whole test so the sled files outlive the consumer.
        let temp = tempdir().unwrap();
        let sled_path = temp.path().join("dedup_test").to_str().unwrap().to_string();
        let config = DeduplicationMiddleware {
            sled_path,
            ttl_seconds: 60,
        };

        let source = MemoryConsumer::new_local("dedup_topic", 10);
        let sender = source.channel();
        // Original, a duplicate of it (same ID, different payload), then a fresh ID.
        let inputs = [
            (b"data1".to_vec(), 100),
            (b"data1_dup".to_vec(), 100),
            (b"data2".to_vec(), 101),
        ];
        for (payload, id) in inputs {
            sender
                .send_message(CanonicalMessage::new(payload, Some(id)))
                .await
                .unwrap();
        }

        let mut consumer =
            DeduplicationConsumer::new(Box::new(source), &config, "test_route").unwrap();

        // The first delivery of ID 100 comes through.
        let first = consumer.receive().await.unwrap();
        assert_eq!(first.message.message_id, 100);
        let _ = (first.commit)(crate::traits::MessageDisposition::Ack).await;

        // The repeated ID 100 is skipped internally, so ID 101 is next.
        let second = consumer.receive().await.unwrap();
        assert_eq!(second.message.message_id, 101);
        let _ = (second.commit)(crate::traits::MessageDisposition::Ack).await;
    }
}