use std::collections::HashSet;
use std::hash::{BuildHasher, Hasher};
use std::marker::PhantomData;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;
/// Default anti-replay window width (in packet ids) used by `AntiReplayContainer::new`.
pub const HISTORY_LEN: u64 = 1 << 10;
/// Pass-through hasher: the `u64` written via `write_u64` becomes the hash value as-is.
///
/// Field `0` is the current hash state. The `PhantomData<T>` records the key type so
/// debug builds can check that `write_u64` is only used with `u64`-sized keys. Writes
/// through the generic byte-slice path are flagged rather than hashed (see the `Hasher`
/// impl). The type also serves as its own `BuildHasher`.
struct NoHashHasher<T>(u64, PhantomData<T>);
impl<T> Default for NoHashHasher<T> {
fn default() -> Self {
NoHashHasher(0, PhantomData)
}
}
impl<T> Hasher for NoHashHasher<T> {
fn finish(&self) -> u64 {
self.0
}
fn write(&mut self, _: &[u8]) {
debug_assert!(false, "NoHashHasher::write() called - this indicates incorrect usage with non-u64 types");
tracing::warn!("NoHashHasher::write() called with byte slice - potential collision risk from incorrect usage");
}
fn write_u64(&mut self, n: u64) {
#[cfg(debug_assertions)]
{
if std::mem::size_of::<T>() != std::mem::size_of::<u64>() {
panic!("NoHashHasher<T>::write_u64() called with T != u64 - potential collision risk");
}
}
self.0 = n
}
}
impl<T> BuildHasher for NoHashHasher<T> {
    type Hasher = Self;

    /// Hands out a fresh zero-state hasher; the builder's own state is not consulted.
    fn build_hasher(&self) -> Self::Hasher {
        Default::default()
    }
}
/// Sliding-window anti-replay state: issues outgoing packet ids and validates
/// incoming ones against a window of recently accepted ids.
pub struct AntiReplayContainer {
    /// `(base_counter, seen_pids)`: the lowest incoming pid that is still acceptable,
    /// and the set of accepted pids that have not yet slid out below `base_counter`.
    history: Mutex<(u64, HashSet<u64, NoHashHasher<u64>>)>,
    /// Next outgoing packet id, handed out (and advanced) by `get_next_pid`.
    counter_out: AtomicU64,
    /// How far above `base_counter` an incoming pid may be and still be accepted.
    window_size: u64,
}
impl AntiReplayContainer {
pub fn new() -> Self {
Self::with_window_size(HISTORY_LEN)
}
pub fn with_window_size(window_size: u64) -> Self {
Self {
history: Mutex::new((
0,
HashSet::with_capacity_and_hasher(window_size as usize, NoHashHasher::default()),
)),
counter_out: AtomicU64::new(0),
window_size,
}
}
#[inline]
pub fn get_next_pid(&self) -> u64 {
self.counter_out.fetch_add(1, Ordering::Relaxed)
}
pub fn validate_pid(&self, pid: u64) -> bool {
let mut queue = self.history.lock().unwrap_or_else(|e| e.into_inner());
let (ref mut base_counter, ref mut seen_pids) = *queue;
let min_acceptable = *base_counter;
let max_acceptable = base_counter.saturating_add(self.window_size);
let already_seen = seen_pids.contains(&pid);
let too_old = pid < min_acceptable;
let too_far_ahead = pid > max_acceptable;
let event_type = if already_seen {
"REPLAY_ATTACK"
} else if too_old {
"DELAYED_REPLAY"
} else if too_far_ahead {
"TOO_FAR_AHEAD"
} else {
"VALID"
};
if event_type != "VALID" {
tracing::warn!(
"🚨 SECURITY EVENT: PID {} status: {} (min: {}, max: {})",
pid,
event_type,
min_acceptable,
max_acceptable
);
}
let is_valid = !already_seen && !too_old && !too_far_ahead;
if is_valid {
seen_pids.insert(pid);
while pid >= base_counter.saturating_add(self.window_size) {
seen_pids.remove(base_counter);
*base_counter += 1;
}
while seen_pids.len() > self.window_size as usize {
seen_pids.remove(base_counter);
*base_counter += 1;
}
}
is_valid
}
pub fn has_tracked_packets(&self) -> bool {
let counter = self.counter_out.load(Ordering::Relaxed);
let queue = self.history.lock().unwrap();
counter > 0 || queue.0 > 0 || !queue.1.is_empty()
}
pub fn reset(&self) {
self.counter_out.store(0, Ordering::Relaxed);
let mut lock = self.history.lock().unwrap();
lock.0 = 0;
lock.1 = HashSet::with_capacity_and_hasher(
self.window_size as usize,
NoHashHasher::default(),
);
tracing::debug!("🔄 Anti-replay container reset");
}
pub fn current_counter(&self) -> u64 {
self.counter_out.load(Ordering::Relaxed)
}
pub fn history_size(&self) -> usize {
self.history.lock().unwrap().1.len()
}
}
impl Default for AntiReplayContainer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_next_pid() {
        let container = AntiReplayContainer::new();
        assert_eq!(container.get_next_pid(), 0);
        assert_eq!(container.get_next_pid(), 1);
        assert_eq!(container.get_next_pid(), 2);
        assert_eq!(container.current_counter(), 3);
    }

    #[test]
    fn test_validate_fresh_pid() {
        let container = AntiReplayContainer::new();
        assert!(container.validate_pid(0));
        assert!(container.validate_pid(1));
        assert!(container.validate_pid(2));
    }

    #[test]
    fn test_reject_duplicate_pid() {
        let container = AntiReplayContainer::new();
        assert!(container.validate_pid(5));
        // A second delivery of the same PID is a replay.
        assert!(!container.validate_pid(5));
    }

    #[test]
    fn test_out_of_order_packets() {
        let container = AntiReplayContainer::new();
        assert!(container.validate_pid(10));
        // Older-but-unseen PIDs inside the window are still accepted.
        assert!(container.validate_pid(8));
        assert!(container.validate_pid(12));
        assert!(container.validate_pid(9));
        assert_eq!(container.history_size(), 4);
    }

    #[test]
    fn test_delayed_replay_protection() {
        let container = AntiReplayContainer::with_window_size(10);
        for i in 0..20 {
            assert!(container.validate_pid(i), "in-order PID {} should be accepted", i);
        }
        // Below the slid window.
        assert!(!container.validate_pid(0));
        assert!(!container.validate_pid(5));
        // Inside the window but already seen.
        assert!(!container.validate_pid(15));
    }

    #[test]
    fn test_reject_too_far_ahead() {
        let container = AntiReplayContainer::with_window_size(5);
        // With base 0 the acceptable range is [0, 5].
        assert!(!container.validate_pid(6));
        assert!(container.validate_pid(5));
        // Accepting the top edge slid the range to [1, 6].
        assert!(container.validate_pid(6));
        assert!(!container.validate_pid(0));
    }

    #[test]
    fn test_zero_window_requires_strict_sequence() {
        let container = AntiReplayContainer::with_window_size(0);
        assert!(container.validate_pid(0));
        assert!(!container.validate_pid(0));
        assert!(!container.validate_pid(2));
        assert!(container.validate_pid(1));
        assert_eq!(container.history_size(), 0);
    }

    #[test]
    fn test_reset() {
        let container = AntiReplayContainer::new();
        assert!(!container.has_tracked_packets());
        container.get_next_pid();
        container.get_next_pid();
        assert!(container.validate_pid(100));
        assert!(container.has_tracked_packets());

        container.reset();
        assert_eq!(container.current_counter(), 0);
        assert_eq!(container.history_size(), 0);
        assert!(!container.has_tracked_packets());
        // The history was cleared, so the same PID is accepted again.
        assert!(container.validate_pid(100));
    }

    #[test]
    fn test_window_sliding() {
        let container = AntiReplayContainer::with_window_size(5);
        for i in 0..5 {
            assert!(container.validate_pid(i));
        }
        assert_eq!(container.history_size(), 5);
        assert!(container.validate_pid(5));
        assert_eq!(container.history_size(), 5);
        assert!(!container.validate_pid(0));
    }
}