use crate::packet::publish::PublishPacket;
use crate::time::{Duration, Instant};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::debug;
/// Fixed-size content digest that identifies a publish message for loop detection.
///
/// Used as the key of the loop-prevention cache, hence the `Eq + Hash` derives.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MessageFingerprint {
    /// Raw 32-byte digest value.
    hash: [u8; 32],
}
/// Detects re-delivered (looping) publish messages by remembering content
/// fingerprints for a limited time.
///
/// The fingerprint cache sits behind an `Arc`, so clones of a
/// `LoopPrevention` share the same cache.
#[derive(Clone)]
pub struct LoopPrevention {
    /// How long a recorded fingerprint keeps suppressing identical messages.
    ttl: Duration,
    /// Cache-size limit that triggers pruning of the fingerprint cache.
    max_cache_size: usize,
    /// Fingerprint -> (time first recorded, whether a loop was already logged).
    seen_messages: Arc<RwLock<HashMap<MessageFingerprint, (Instant, bool)>>>,
}
impl Default for LoopPrevention {
fn default() -> Self {
Self::new(Duration::from_secs(60), 10000)
}
}
impl LoopPrevention {
    /// Creates a detector that remembers message fingerprints for `ttl` and
    /// bounds its fingerprint cache by `max_cache_size`.
    #[allow(clippy::must_use_candidate)]
    pub fn new(ttl: Duration, max_cache_size: usize) -> Self {
        Self {
            seen_messages: Arc::new(RwLock::new(HashMap::new())),
            ttl,
            max_cache_size,
        }
    }

    /// Returns how long a fingerprint suppresses identical messages after it
    /// is first recorded.
    #[must_use]
    pub fn ttl(&self) -> Duration {
        self.ttl
    }

    /// Returns the configured fingerprint-cache size limit.
    #[must_use]
    pub fn max_cache_size(&self) -> usize {
        self.max_cache_size
    }

    /// Decides whether `packet` should be forwarded.
    ///
    /// Returns `true` and records the packet's fingerprint when no entry for
    /// the same fingerprint was recorded within the last `ttl`. Returns
    /// `false` for an identical message seen within `ttl`. The first such
    /// repeat is logged at debug level; later repeats are not logged.
    ///
    /// Before a new fingerprint is recorded into a full cache, expired entries
    /// are pruned. If that is not enough, the oldest live entries are evicted.
    /// The cache therefore holds at most `max(max_cache_size, 1)` entries.
    pub async fn check_message(&self, packet: &PublishPacket) -> bool {
        let fingerprint = Self::calculate_fingerprint(packet);
        let mut cache = self.seen_messages.write().await;

        // Look up before any pruning/eviction so that making room can never
        // hide a loop that is still within its TTL.
        if let Some((first_seen, already_warned)) = cache.get_mut(&fingerprint) {
            if first_seen.elapsed() < self.ttl {
                if !*already_warned {
                    debug!("Message loop detected for topic: {}", packet.topic_name);
                    *already_warned = true;
                }
                return false;
            }
        }

        // We are about to record an entry; make room first if the cache is full.
        if cache.len() >= self.max_cache_size {
            self.cleanup_cache(&mut cache);
        }

        cache.insert(fingerprint, (Instant::now(), false));
        debug!(
            "Message fingerprint recorded for topic: {}, cache size: {}",
            packet.topic_name,
            cache.len()
        );
        true
    }

    /// Hashes topic, payload, QoS and retain flag with SHA-256.
    ///
    /// No other packet fields are hashed.
    fn calculate_fingerprint(packet: &PublishPacket) -> MessageFingerprint {
        let mut hasher = Sha256::new();
        // Length-prefix the topic so the topic/payload boundary is unambiguous.
        // Without it, ("sensor/1", "0…") and ("sensor/10", "…") hash the same.
        // The payload needs no prefix because only fixed-width fields follow it.
        hasher.update((packet.topic_name.len() as u64).to_le_bytes());
        hasher.update(packet.topic_name.as_bytes());
        hasher.update(&packet.payload);
        hasher.update([u8::from(packet.qos)]);
        hasher.update([u8::from(packet.retain)]);
        let result = hasher.finalize();
        let mut hash = [0u8; 32];
        hash.copy_from_slice(&result);
        MessageFingerprint { hash }
    }

    /// Prunes expired fingerprints. If the cache is still full, it then evicts
    /// the oldest live fingerprints until one more entry fits within
    /// `max_cache_size`.
    fn cleanup_cache(&self, cache: &mut HashMap<MessageFingerprint, (Instant, bool)>) {
        let now = Instant::now();
        let ttl = self.ttl;
        cache.retain(|_, (first_seen, _)| now.duration_since(*first_seen) < ttl);

        // TTL pruning alone cannot bound memory when every entry is still live
        // (e.g. a burst of distinct messages), so drop the oldest ones. Because
        // the cache is kept at or below its limit, this loop normally removes
        // at most one entry.
        let target_len = self.max_cache_size.saturating_sub(1);
        while cache.len() > target_len {
            let oldest = cache
                .iter()
                .max_by(|(_, (a, _)), (_, (b, _))| {
                    now.duration_since(*a)
                        .partial_cmp(&now.duration_since(*b))
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .map(|(fingerprint, _)| fingerprint.clone());
            match oldest {
                Some(fingerprint) => {
                    cache.remove(&fingerprint);
                }
                None => break,
            }
        }
        debug!("Loop prevention cache cleaned, size: {}", cache.len());
    }

    /// Forgets every recorded fingerprint.
    pub async fn clear_cache(&self) {
        self.seen_messages.write().await.clear();
        debug!("Loop prevention cache cleared");
    }

    /// Returns the number of fingerprints currently cached, including entries
    /// that have expired but not yet been pruned.
    pub async fn cache_size(&self) -> usize {
        self.seen_messages.read().await.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::QoS;

    /// Builds a publish packet on `topic` carrying the shared test payload.
    fn packet_for(topic: &str, qos: QoS) -> PublishPacket {
        PublishPacket::new(topic.to_string(), &b"test payload"[..], qos)
    }

    /// A repeated identical message is rejected; a different topic is not.
    #[tokio::test]
    async fn test_loop_detection() {
        let detector = LoopPrevention::new(Duration::from_secs(5), 100);
        let original = packet_for("test/topic", QoS::AtLeastOnce);

        assert!(detector.check_message(&original).await);
        assert!(!detector.check_message(&original).await);

        let other_topic = packet_for("test/topic2", QoS::AtLeastOnce);
        assert!(detector.check_message(&other_topic).await);
    }

    /// Once the TTL elapses, the same message is accepted again.
    #[tokio::test]
    async fn test_ttl_expiration() {
        let detector = LoopPrevention::new(Duration::from_millis(100), 100);
        let message = packet_for("test/topic", QoS::AtMostOnce);

        assert!(detector.check_message(&message).await);
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(detector.check_message(&message).await);
    }

    /// QoS is part of the fingerprint, so otherwise-identical packets differ.
    #[tokio::test]
    async fn test_different_qos_different_fingerprint() {
        let detector = LoopPrevention::new(Duration::from_secs(5), 100);
        let at_most_once = packet_for("test/topic", QoS::AtMostOnce);
        let at_least_once = packet_for("test/topic", QoS::AtLeastOnce);

        assert!(detector.check_message(&at_most_once).await);
        assert!(detector.check_message(&at_least_once).await);
    }
}