use crate::error::StreamError;
use crate::state::distributed_state::StateBackend;
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
use uuid::Uuid;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MessageId {
pub producer_id: String,
pub partition: u32,
pub sequence: u64,
}
impl MessageId {
pub fn new(producer: &str, partition: u32, seq: u64) -> Self {
Self {
producer_id: producer.to_string(),
partition,
sequence: seq,
}
}
pub fn serialize(&self) -> String {
format!("{}/{}/{}", self.producer_id, self.partition, self.sequence)
}
pub fn parse(s: &str) -> Result<Self, StreamError> {
let parts: Vec<&str> = s.splitn(3, '/').collect();
if parts.len() != 3 {
return Err(StreamError::InvalidInput(format!(
"MessageId must be '<producer>/<partition>/<seq>', got: {s}"
)));
}
let partition = parts[1].parse::<u32>().map_err(|e| {
StreamError::InvalidInput(format!("invalid partition in MessageId: {e}"))
})?;
let sequence = parts[2].parse::<u64>().map_err(|e| {
StreamError::InvalidInput(format!("invalid sequence in MessageId: {e}"))
})?;
Ok(Self {
producer_id: parts[0].to_string(),
partition,
sequence,
})
}
}
#[derive(Debug, Clone)]
pub struct DeduplicationConfig {
pub window_duration: Duration,
pub max_tracked: usize,
}
impl Default for DeduplicationConfig {
fn default() -> Self {
Self {
window_duration: Duration::from_secs(300), max_tracked: 100_000,
}
}
}
pub struct DeduplicationLog {
config: DeduplicationConfig,
processed: HashMap<MessageId, Instant>,
eviction_queue: VecDeque<(MessageId, Instant)>,
}
impl DeduplicationLog {
pub fn new(config: DeduplicationConfig) -> Self {
Self {
processed: HashMap::new(),
eviction_queue: VecDeque::new(),
config,
}
}
pub fn is_duplicate(&self, id: &MessageId) -> bool {
match self.processed.get(id) {
None => false,
Some(&processed_at) => {
processed_at.elapsed() < self.config.window_duration
}
}
}
pub fn mark_processed(&mut self, id: MessageId) {
let now = Instant::now();
while self.processed.len() >= self.config.max_tracked {
if let Some((oldest_id, _)) = self.eviction_queue.pop_front() {
self.processed.remove(&oldest_id);
} else {
break;
}
}
self.eviction_queue.push_back((id.clone(), now));
self.processed.insert(id, now);
}
pub fn evict_expired(&mut self) -> usize {
let deadline = self.config.window_duration;
let mut evicted = 0usize;
while let Some((id, ts)) = self.eviction_queue.front() {
if ts.elapsed() >= deadline {
let id = id.clone();
self.eviction_queue.pop_front();
self.processed.remove(&id);
evicted += 1;
} else {
break;
}
}
evicted
}
pub fn size(&self) -> usize {
self.processed.len()
}
}
pub struct ExactlyOnceTransaction {
pub transaction_id: String,
pub messages: Vec<MessageId>,
pub state_changes: Vec<(Vec<u8>, Vec<u8>)>,
pub started_at: Instant,
pub committed: bool,
}
impl Default for ExactlyOnceTransaction {
fn default() -> Self {
Self::new()
}
}
impl ExactlyOnceTransaction {
pub fn new() -> Self {
Self {
transaction_id: Uuid::new_v4().to_string(),
messages: Vec::new(),
state_changes: Vec::new(),
started_at: Instant::now(),
committed: false,
}
}
pub fn add_message(&mut self, id: MessageId) {
self.messages.push(id);
}
pub fn add_state_change(&mut self, key: Vec<u8>, value: Vec<u8>) {
self.state_changes.push((key, value));
}
pub fn commit(
mut self,
dedup_log: &mut DeduplicationLog,
backend: &dyn StateBackend,
) -> Result<(), StreamError> {
if self.committed {
return Err(StreamError::InvalidOperation(format!(
"transaction {} already committed",
self.transaction_id
)));
}
for (key, value) in &self.state_changes {
backend.put(key, value)?;
}
for id in self.messages.drain(..) {
dedup_log.mark_processed(id);
}
self.committed = true;
Ok(())
}
}
pub struct ExactlyOnceProcessor {
dedup_log: DeduplicationLog,
backend: std::sync::Arc<dyn StateBackend>,
duplicates_filtered: u64,
messages_processed: u64,
}
impl ExactlyOnceProcessor {
pub fn new(config: DeduplicationConfig, backend: std::sync::Arc<dyn StateBackend>) -> Self {
Self {
dedup_log: DeduplicationLog::new(config),
backend,
duplicates_filtered: 0,
messages_processed: 0,
}
}
pub fn process<R, F>(&mut self, id: MessageId, processor: F) -> Result<Option<R>, StreamError>
where
F: FnOnce(&mut ExactlyOnceTransaction) -> Result<R, StreamError>,
{
if self.dedup_log.is_duplicate(&id) {
self.duplicates_filtered += 1;
return Ok(None);
}
let mut txn = ExactlyOnceTransaction::new();
txn.add_message(id);
let result = processor(&mut txn)?;
txn.commit(&mut self.dedup_log, self.backend.as_ref())?;
self.messages_processed += 1;
Ok(Some(result))
}
pub fn maintenance(&mut self) -> usize {
self.dedup_log.evict_expired()
}
pub fn stats(&self) -> ExactlyOnceStats {
ExactlyOnceStats {
dedup_window_size: self.dedup_log.size(),
duplicates_filtered: self.duplicates_filtered,
messages_processed: self.messages_processed,
}
}
}
#[derive(Debug, Clone)]
pub struct ExactlyOnceStats {
pub dedup_window_size: usize,
pub duplicates_filtered: u64,
pub messages_processed: u64,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::state::distributed_state::InMemoryStateBackend;
use std::sync::Arc;
#[test]
fn test_message_id_round_trip() {
let id = MessageId::new("producer-1", 3, 42);
let serialized = id.serialize();
assert_eq!(serialized, "producer-1/3/42");
let parsed = MessageId::parse(&serialized).unwrap();
assert_eq!(parsed, id);
}
#[test]
fn test_message_id_parse_error() {
assert!(MessageId::parse("bad").is_err());
assert!(MessageId::parse("a/b").is_err());
assert!(MessageId::parse("a/notnum/1").is_err());
assert!(MessageId::parse("a/1/notnum").is_err());
}
#[test]
fn test_dedup_log_basic() {
let config = DeduplicationConfig {
window_duration: Duration::from_secs(60),
max_tracked: 1000,
};
let mut log = DeduplicationLog::new(config);
let id = MessageId::new("p", 0, 1);
assert!(!log.is_duplicate(&id));
log.mark_processed(id.clone());
assert!(log.is_duplicate(&id));
assert_eq!(log.size(), 1);
}
#[test]
fn test_dedup_log_capacity_eviction() {
let config = DeduplicationConfig {
window_duration: Duration::from_secs(60),
max_tracked: 3,
};
let mut log = DeduplicationLog::new(config);
for i in 0..5u64 {
log.mark_processed(MessageId::new("p", 0, i));
}
assert!(log.size() <= 3);
}
#[test]
fn test_dedup_log_expiry() {
let config = DeduplicationConfig {
window_duration: Duration::from_millis(50),
max_tracked: 1000,
};
let mut log = DeduplicationLog::new(config);
let id = MessageId::new("p", 0, 99);
log.mark_processed(id.clone());
assert!(log.is_duplicate(&id));
std::thread::sleep(Duration::from_millis(60));
assert!(!log.is_duplicate(&id));
let evicted = log.evict_expired();
assert_eq!(evicted, 1);
assert_eq!(log.size(), 0);
}
#[test]
fn test_transaction_commit_applies_state() {
let backend = InMemoryStateBackend::new();
let mut dedup = DeduplicationLog::new(DeduplicationConfig::default());
let mut txn = ExactlyOnceTransaction::new();
txn.add_message(MessageId::new("p", 0, 1));
txn.add_state_change(
b"counter".to_vec(),
b"\x01\x00\x00\x00\x00\x00\x00\x00".to_vec(),
);
txn.commit(&mut dedup, &backend).unwrap();
assert_eq!(
backend.get(b"counter").unwrap().as_deref(),
Some(b"\x01\x00\x00\x00\x00\x00\x00\x00".as_ref())
);
assert!(dedup.is_duplicate(&MessageId::new("p", 0, 1)));
}
#[test]
fn test_transaction_double_commit_fails() {
let backend = InMemoryStateBackend::new();
let mut dedup = DeduplicationLog::new(DeduplicationConfig::default());
let txn = ExactlyOnceTransaction::new();
txn.commit(&mut dedup, &backend).unwrap();
}
#[test]
fn test_processor_exactly_once() {
let backend = Arc::new(InMemoryStateBackend::new());
let mut processor = ExactlyOnceProcessor::new(DeduplicationConfig::default(), backend);
let id = MessageId::new("prod", 0, 1);
let mut call_count = 0u32;
let result = processor
.process(id.clone(), |_txn| {
call_count += 1;
Ok(42u32)
})
.unwrap();
assert_eq!(result, Some(42u32));
assert_eq!(call_count, 1);
let result = processor
.process(id, |_txn| {
call_count += 1;
Ok(99u32)
})
.unwrap();
assert_eq!(result, None);
assert_eq!(call_count, 1);
let stats = processor.stats();
assert_eq!(stats.messages_processed, 1);
assert_eq!(stats.duplicates_filtered, 1);
}
}