use anyhow::{bail, Result};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::time::{Duration, Instant};
/// Rejects reuse of `u32` nonces that were already recorded within a time window.
///
/// Each accepted nonce is remembered together with the instant it was recorded; the
/// number of remembered nonces is capped by `max_entries`.
pub struct ReplayProtection {
    /// How long a recorded nonce blocks a repeat of the same value.
    window: Duration,
    /// Upper bound on how many nonces are tracked at once.
    max_entries: usize,
    /// Nonce -> instant it was recorded. Kept behind a lock so `&self` methods can mutate it.
    seen_nonces: RwLock<HashMap<u32, Instant>>,
}
impl ReplayProtection {
    /// Creates a replay guard with a 120-second window and a 100 000-entry cap.
    pub fn new() -> Self {
        Self::with_config(Duration::from_secs(120), 100_000)
    }

    /// Creates a replay guard with a custom window and capacity.
    ///
    /// * `window` — how long after first being recorded a nonce is treated as a replay.
    /// * `max_entries` — maximum number of nonces tracked at once. With `0`, every call to
    ///   [`check_and_record`](Self::check_and_record) fails with the "cache full" error.
    pub fn with_config(window: Duration, max_entries: usize) -> Self {
        Self {
            seen_nonces: RwLock::new(HashMap::new()),
            window,
            max_entries,
        }
    }

    /// Accepts `nonce` if it has not been recorded within the window, and records it.
    ///
    /// A previous sighting older than the window is forgotten and the nonce is accepted
    /// again. When the cache is at capacity, expired entries are evicted before inserting.
    ///
    /// # Errors
    ///
    /// * `nonce` was recorded less than `window` ago (a replay);
    /// * the cache still holds `max_entries` nonces after expired entries were evicted.
    pub fn check_and_record(&self, nonce: u32) -> Result<()> {
        let mut seen = self.seen_nonces.write();
        // Take a single timestamp *after* acquiring the lock. Every stored instant was
        // taken under this same lock, so none is later than `now`. Using one value keeps
        // the age we compare, the age we report, the eviction cutoff and the recorded
        // time consistent with each other.
        let now = Instant::now();

        if let Some(&first_seen) = seen.get(&nonce) {
            let age = now.saturating_duration_since(first_seen);
            if age < self.window {
                bail!(
                    "Replay attack detected: duplicate nonce {} (first seen {:?} ago)",
                    nonce,
                    age
                );
            }
            // Expired sighting: drop it so it no longer counts toward capacity.
            seen.remove(&nonce);
        }

        if seen.len() >= self.max_entries {
            // NOTE(review): while the cache is full of live entries, every call pays an
            // O(n) scan here before being rejected; consider amortizing if that matters.
            self.cleanup_old_nonces(&mut seen, now);
            if seen.len() >= self.max_entries {
                bail!("Replay protection cache full (possible DoS attack)");
            }
        }

        seen.insert(nonce, now);
        Ok(())
    }

    /// Removes every entry whose age, measured at `now`, has reached the window.
    fn cleanup_old_nonces(&self, seen: &mut HashMap<u32, Instant>, now: Instant) {
        seen.retain(|_, first_seen| now.saturating_duration_since(*first_seen) < self.window);
    }

    /// Returns a snapshot of the current cache size and configuration.
    ///
    /// `total_nonces` may include expired entries: eviction only runs when the cache
    /// reaches capacity. `window_seconds` is the window truncated to whole seconds, so
    /// sub-second windows report `0`.
    pub fn stats(&self) -> ReplayProtectionStats {
        let seen = self.seen_nonces.read();
        ReplayProtectionStats {
            total_nonces: seen.len(),
            window_seconds: self.window.as_secs(),
            max_entries: self.max_entries,
        }
    }

    /// Forgets every recorded nonce (test-only helper).
    #[cfg(test)]
    pub fn clear(&self) {
        self.seen_nonces.write().clear();
    }
}
impl Default for ReplayProtection {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time snapshot of a replay-protection cache, returned by `ReplayProtection::stats`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReplayProtectionStats {
    /// Number of nonces currently held (may include expired ones not yet evicted).
    pub total_nonces: usize,
    /// Replay window in whole seconds (any sub-second part is truncated).
    pub window_seconds: u64,
    /// Configured maximum number of tracked nonces.
    pub max_entries: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_duplicate_nonce_rejected() {
        let rp = ReplayProtection::new();
        assert!(rp.check_and_record(12345).is_ok());
        assert!(rp.check_and_record(12345).is_err());
    }

    #[test]
    fn test_duplicate_error_message() {
        let rp = ReplayProtection::new();
        rp.check_and_record(7).unwrap();
        let err = rp.check_and_record(7).unwrap_err();
        assert!(err.to_string().contains("Replay attack detected: duplicate nonce 7"));
    }

    #[test]
    fn test_different_nonces_accepted() {
        let rp = ReplayProtection::new();
        assert!(rp.check_and_record(1).is_ok());
        assert!(rp.check_and_record(2).is_ok());
        assert!(rp.check_and_record(3).is_ok());
    }

    #[test]
    fn test_nonce_expires() {
        let rp = ReplayProtection::with_config(Duration::from_millis(100), 1000);
        assert!(rp.check_and_record(999).is_ok());
        std::thread::sleep(Duration::from_millis(150));
        assert!(rp.check_and_record(999).is_ok());
    }

    #[test]
    fn test_cleanup() {
        let rp = ReplayProtection::with_config(Duration::from_millis(50), 10);
        for i in 0..10 {
            assert!(rp.check_and_record(i).is_ok());
        }
        std::thread::sleep(Duration::from_millis(100));
        assert!(rp.check_and_record(999).is_ok());
        let stats = rp.stats();
        assert_eq!(stats.total_nonces, 1);
    }

    #[test]
    fn test_cache_full_rejects_new_nonces() {
        let rp = ReplayProtection::with_config(Duration::from_secs(60), 3);
        for i in 0..3 {
            assert!(rp.check_and_record(i).is_ok());
        }
        // Nothing has expired, so cleanup frees no room and the new nonce is refused.
        let err = rp.check_and_record(3).unwrap_err();
        assert!(err.to_string().contains("cache full"));
        // The rejected nonce must not have been recorded.
        assert_eq!(rp.stats().total_nonces, 3);
    }

    #[test]
    fn test_stats_reflect_config() {
        let rp = ReplayProtection::with_config(Duration::from_secs(30), 42);
        let stats = rp.stats();
        assert_eq!(stats.total_nonces, 0);
        assert_eq!(stats.window_seconds, 30);
        assert_eq!(stats.max_entries, 42);
    }

    #[test]
    fn test_clear_allows_reuse() {
        let rp = ReplayProtection::new();
        assert!(rp.check_and_record(5).is_ok());
        rp.clear();
        assert_eq!(rp.stats().total_nonces, 0);
        assert!(rp.check_and_record(5).is_ok());
    }
}