use crate::utils::current_timestamp;
use anyhow::{Context, Result};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::{interval, Instant};
use tracing::{debug, warn};
/// Reasons a message or request is rejected by the replay-protection checks.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum ReplayError {
/// A message ID that is already recorded was presented again.
/// Carries the offending message ID.
#[error("Duplicate message ID: {0}")]
DuplicateMessageId(String),
/// The timestamp lies further in the past than allowed.
/// Fields: (observed age in seconds, maximum permitted age in seconds).
#[error("Timestamp too old: {0} seconds old (max: {1} seconds)")]
TimestampTooOld(i64, u64),
/// The timestamp lies further in the future than allowed.
/// Fields: (seconds ahead of the local clock, permitted future tolerance in seconds).
#[error("Timestamp too far in future: {0} seconds ahead (max: {1} seconds)")]
TimestampTooFuture(i64, u64),
/// A request ID that is already recorded was presented again.
/// Carries the offending request ID.
#[error("Duplicate request ID: {0}")]
DuplicateRequestId(u64),
}
/// In-memory replay guard that remembers which message IDs and request IDs
/// have already been seen.
///
/// Both tables are shared with a background cleanup task (hence the
/// `Arc<Mutex<..>>` wrappers), which is spawned when the value is constructed.
pub struct ReplayProtection {
/// Seen message IDs, each mapped to the `current_timestamp()` value taken
/// when the ID was first recorded.
message_ids: Arc<Mutex<HashMap<String, u64>>>,
/// Seen request IDs, each mapped to the `current_timestamp()` value taken
/// when the ID was first recorded.
request_ids: Arc<Mutex<HashMap<u64, u64>>>,
/// Period between two passes of the background cleanup task.
cleanup_interval: Duration,
/// How long a message ID is remembered before the cleanup task may drop it.
message_id_expiration: Duration,
/// How long a request ID is remembered before the cleanup task may drop it.
request_id_expiration: Duration,
/// Configured allowance for timestamps ahead of the local clock.
/// NOTE(review): this value is stored by the constructor but is not read by
/// any method of the inherent impl in this file; the timestamp validators
/// are associated functions that take their tolerance as an argument.
/// Presumably seconds, like the validators' tolerance - confirm.
future_tolerance: u64,
}
impl ReplayProtection {
pub fn new() -> Self {
Self::with_config(
Duration::from_secs(300), Duration::from_secs(3600), Duration::from_secs(300), 300, )
}
pub fn with_config(
cleanup_interval: Duration,
message_id_expiration: Duration,
request_id_expiration: Duration,
future_tolerance: u64,
) -> Self {
let protection = Self {
message_ids: Arc::new(Mutex::new(HashMap::new())),
request_ids: Arc::new(Mutex::new(HashMap::new())),
cleanup_interval,
message_id_expiration,
request_id_expiration,
future_tolerance,
};
protection.start_cleanup_task();
protection
}
pub async fn check_message_id(&self, message_id: &str, timestamp: i64) -> Result<()> {
let mut message_ids = self.message_ids.lock().await;
if let Some(first_seen) = message_ids.get(message_id) {
return Err(ReplayError::DuplicateMessageId(message_id.to_string()).into());
}
let now = current_timestamp();
message_ids.insert(message_id.to_string(), now);
Ok(())
}
pub async fn check_request_id(&self, request_id: u64) -> Result<()> {
let mut request_ids = self.request_ids.lock().await;
if let Some(_first_seen) = request_ids.get(&request_id) {
return Err(ReplayError::DuplicateRequestId(request_id).into());
}
let now = current_timestamp();
request_ids.insert(request_id, now);
Ok(())
}
pub fn validate_timestamp(timestamp: i64, max_age_seconds: u64) -> Result<()> {
let now = current_timestamp() as i64;
let future_tolerance = 300;
if timestamp > (now + future_tolerance as i64) {
return Err(ReplayError::TimestampTooFuture(timestamp - now, future_tolerance).into());
}
if timestamp < (now - max_age_seconds as i64) {
return Err(ReplayError::TimestampTooOld(now - timestamp, max_age_seconds).into());
}
Ok(())
}
pub fn validate_timestamp_with_tolerance(
timestamp: i64,
max_age_seconds: u64,
future_tolerance: u64,
) -> Result<()> {
let now = current_timestamp() as i64;
if timestamp > (now + future_tolerance as i64) {
return Err(ReplayError::TimestampTooFuture(timestamp - now, future_tolerance).into());
}
if timestamp < (now - max_age_seconds as i64) {
return Err(ReplayError::TimestampTooOld(now - timestamp, max_age_seconds).into());
}
Ok(())
}
fn start_cleanup_task(&self) {
let message_ids = Arc::clone(&self.message_ids);
let request_ids = Arc::clone(&self.request_ids);
let cleanup_interval = self.cleanup_interval;
let message_id_expiration = self.message_id_expiration;
let request_id_expiration = self.request_id_expiration;
tokio::spawn(async move {
let mut interval = interval(cleanup_interval);
loop {
interval.tick().await;
let now = current_timestamp();
let message_id_cutoff = now.saturating_sub(message_id_expiration.as_secs());
let request_id_cutoff = now.saturating_sub(request_id_expiration.as_secs());
let mut msg_ids = message_ids.lock().await;
let before = msg_ids.len();
msg_ids.retain(|_, &mut timestamp| timestamp > message_id_cutoff);
let after = msg_ids.len();
if before > after {
debug!(
"Replay protection cleanup: removed {} old message IDs ({} remaining)",
before - after,
after
);
}
drop(msg_ids);
let mut req_ids = request_ids.lock().await;
let before = req_ids.len();
req_ids.retain(|_, &mut timestamp| timestamp > request_id_cutoff);
let after = req_ids.len();
if before > after {
debug!(
"Replay protection cleanup: removed {} old request IDs ({} remaining)",
before - after,
after
);
}
}
});
}
pub async fn stats(&self) -> (usize, usize) {
let message_ids = self.message_ids.lock().await;
let request_ids = self.request_ids.lock().await;
(message_ids.len(), request_ids.len())
}
}
impl Default for ReplayProtection {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A message ID is accepted the first time and rejected when replayed.
    #[tokio::test]
    async fn test_message_id_deduplication() {
        let guard = ReplayProtection::new();
        let id = "test-message-id-123";

        let first_attempt = guard.check_message_id(id, current_timestamp() as i64).await;
        assert!(first_attempt.is_ok());

        let replayed = guard.check_message_id(id, current_timestamp() as i64).await;
        assert!(replayed.is_err());
    }

    /// A request ID is accepted the first time and rejected when replayed.
    #[tokio::test]
    async fn test_request_id_deduplication() {
        let guard = ReplayProtection::new();
        let id = 12345u64;

        let first_attempt = guard.check_request_id(id).await;
        assert!(first_attempt.is_ok());

        let replayed = guard.check_request_id(id).await;
        assert!(replayed.is_err());
    }

    /// Timestamps inside the window pass; ones too old or too far ahead fail.
    #[test]
    fn test_timestamp_validation() {
        let now = current_timestamp() as i64;

        // (offset from `now` in seconds, whether validation should accept it)
        let cases: [(i64, bool); 4] = [(0, true), (-60, true), (-4000, false), (400, false)];
        for &(offset, accepted) in cases.iter() {
            let verdict = ReplayProtection::validate_timestamp(now + offset, 3600);
            assert_eq!(verdict.is_ok(), accepted, "offset {} s", offset);
        }
    }

    /// Entries disappear once the background cleanup task has run.
    ///
    /// The 200 ms expirations are compared in whole seconds by the cleanup
    /// task, i.e. as zero, so every recorded entry is already eligible for
    /// removal on the next cleanup pass.
    #[tokio::test]
    async fn test_cleanup() {
        let tick = Duration::from_millis(100);
        let ttl = Duration::from_millis(200);
        let guard = ReplayProtection::with_config(tick, ttl, ttl, 300);

        guard
            .check_message_id("msg1", current_timestamp() as i64)
            .await
            .unwrap();
        guard.check_request_id(1).await.unwrap();

        // Give the cleanup task time to run at least once.
        tokio::time::sleep(Duration::from_millis(300)).await;

        let (remaining_messages, remaining_requests) = guard.stats().await;
        assert_eq!(remaining_messages, 0);
        assert_eq!(remaining_requests, 0);
    }
}