use blvm_node::network::replay_protection::{ReplayError, ReplayProtection};
use blvm_node::utils::current_timestamp;
use std::sync::Arc;
use std::time::Duration;
#[tokio::test]
async fn test_message_id_deduplication_basic() {
let protection = ReplayProtection::new();
let message_id = "test-message-id-123";
assert!(protection
.check_message_id(message_id, current_timestamp() as i64)
.await
.is_ok());
let result = protection
.check_message_id(message_id, current_timestamp() as i64)
.await;
assert!(result.is_err());
assert!(matches!(
result.unwrap_err().downcast_ref::<ReplayError>(),
Some(ReplayError::DuplicateMessageId(_))
));
}
#[tokio::test]
async fn test_message_id_deduplication_multiple() {
let protection = ReplayProtection::new();
for i in 0..10 {
let msg_id = format!("msg-{i}");
assert!(protection
.check_message_id(&msg_id, current_timestamp() as i64)
.await
.is_ok());
}
assert!(protection
.check_message_id("msg-5", current_timestamp() as i64)
.await
.is_err());
assert!(protection
.check_message_id("msg-new", current_timestamp() as i64)
.await
.is_ok());
}
#[tokio::test]
async fn test_request_id_deduplication_basic() {
let protection = ReplayProtection::new();
let request_id = 12345u64;
assert!(protection.check_request_id(request_id).await.is_ok());
let result = protection.check_request_id(request_id).await;
assert!(result.is_err());
assert!(matches!(
result.unwrap_err().downcast_ref::<ReplayError>(),
Some(ReplayError::DuplicateRequestId(_))
));
}
#[tokio::test]
async fn test_request_id_deduplication_multiple() {
let protection = ReplayProtection::new();
for i in 1000..1010 {
assert!(protection.check_request_id(i).await.is_ok());
}
assert!(protection.check_request_id(1005).await.is_err());
assert!(protection.check_request_id(2000).await.is_ok());
}
#[test]
fn test_timestamp_validation_current() {
let now = current_timestamp() as i64;
assert!(ReplayProtection::validate_timestamp(now, 3600).is_ok());
}
#[test]
fn test_timestamp_validation_recent_past() {
let now = current_timestamp() as i64;
assert!(ReplayProtection::validate_timestamp(now - 60, 3600).is_ok());
assert!(ReplayProtection::validate_timestamp(now - 1800, 3600).is_ok());
}
#[test]
fn test_timestamp_validation_too_old() {
let now = current_timestamp() as i64;
let result = ReplayProtection::validate_timestamp(now - 7200, 3600);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err().downcast_ref::<ReplayError>(),
Some(ReplayError::TimestampTooOld(_, _))
));
}
#[test]
fn test_timestamp_validation_too_future() {
let now = current_timestamp() as i64;
let result = ReplayProtection::validate_timestamp(now + 600, 3600);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err().downcast_ref::<ReplayError>(),
Some(ReplayError::TimestampTooFuture(_, _))
));
}
#[test]
fn test_timestamp_validation_with_custom_tolerance() {
let now = current_timestamp() as i64;
assert!(ReplayProtection::validate_timestamp_with_tolerance(now + 30, 3600, 60).is_ok());
assert!(ReplayProtection::validate_timestamp_with_tolerance(now + 90, 3600, 60).is_err());
}
#[tokio::test]
async fn test_cleanup_message_ids() {
let protection = ReplayProtection::with_config(
Duration::from_millis(100), Duration::from_millis(200), Duration::from_millis(500), 300,
);
protection
.check_message_id("msg1", current_timestamp() as i64)
.await
.unwrap();
protection
.check_message_id("msg2", current_timestamp() as i64)
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(300)).await;
let (msg_count, req_count) = protection.stats().await;
assert_eq!(msg_count, 0);
assert_eq!(req_count, 0);
}
#[tokio::test]
async fn test_cleanup_request_ids() {
let protection = ReplayProtection::with_config(
Duration::from_millis(100), Duration::from_millis(500), Duration::from_millis(200), 300,
);
protection.check_request_id(1).await.unwrap();
protection.check_request_id(2).await.unwrap();
tokio::time::sleep(Duration::from_millis(300)).await;
let (msg_count, req_count) = protection.stats().await;
assert_eq!(msg_count, 0);
assert_eq!(req_count, 0);
}
#[tokio::test]
async fn test_stats_tracking() {
let protection = ReplayProtection::new();
let (msg_count, req_count) = protection.stats().await;
assert_eq!(msg_count, 0);
assert_eq!(req_count, 0);
protection
.check_message_id("msg1", current_timestamp() as i64)
.await
.unwrap();
protection
.check_message_id("msg2", current_timestamp() as i64)
.await
.unwrap();
protection.check_request_id(100).await.unwrap();
protection.check_request_id(200).await.unwrap();
let (msg_count, req_count) = protection.stats().await;
assert_eq!(msg_count, 2);
assert_eq!(req_count, 2);
}
#[tokio::test]
async fn test_concurrent_message_ids() {
let protection = ReplayProtection::new();
let protection = std::sync::Arc::new(protection);
let mut handles = vec![];
for i in 0..10 {
let protection = Arc::clone(&protection);
let handle = tokio::spawn(async move {
protection
.check_message_id("concurrent-msg", current_timestamp() as i64)
.await
});
handles.push(handle);
}
let mut results = vec![];
for handle in handles {
results.push(handle.await);
}
let successes = results.iter().filter(|r| matches!(r, Ok(Ok(_)))).count();
let failures = results.iter().filter(|r| matches!(r, Ok(Err(_)))).count();
assert_eq!(successes, 1, "Only one concurrent check should succeed");
assert_eq!(failures, 9, "Nine concurrent checks should fail");
}
#[tokio::test]
async fn test_concurrent_request_ids() {
let protection = ReplayProtection::new();
let protection = std::sync::Arc::new(protection);
let mut handles = vec![];
for _ in 0..10 {
let protection = Arc::clone(&protection);
let handle = tokio::spawn(async move { protection.check_request_id(9999).await });
handles.push(handle);
}
let mut results = vec![];
for handle in handles {
results.push(handle.await);
}
let successes = results.iter().filter(|r| matches!(r, Ok(Ok(_)))).count();
let failures = results.iter().filter(|r| matches!(r, Ok(Err(_)))).count();
assert_eq!(successes, 1, "Only one concurrent check should succeed");
assert_eq!(failures, 9, "Nine concurrent checks should fail");
}
#[tokio::test]
async fn test_message_id_expires_after_cleanup() {
let protection = ReplayProtection::with_config(
Duration::from_millis(100), Duration::from_millis(200), Duration::from_millis(500), 300,
);
let message_id = "expiring-msg";
protection
.check_message_id(message_id, current_timestamp() as i64)
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(protection
.check_message_id(message_id, current_timestamp() as i64)
.await
.is_ok());
}
#[tokio::test]
async fn test_request_id_expires_after_cleanup() {
let protection = ReplayProtection::with_config(
Duration::from_millis(100), Duration::from_millis(500), Duration::from_millis(200), 300,
);
let request_id = 8888u64;
protection.check_request_id(request_id).await.unwrap();
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(protection.check_request_id(request_id).await.is_ok());
}