#[cfg(test)]
mod destination_filter_tests {
use derusted::destination_filter::{DestinationError, DestinationFilter};
#[tokio::test]
async fn test_ssrf_localhost_blocked() {
let filter = DestinationFilter::new(100, 60, 5).unwrap();
let blocked_hosts = vec![
"localhost",
"127.0.0.1",
"127.0.0.2",
"127.255.255.255",
"::1",
"0.0.0.0",
];
for host in blocked_hosts {
let result = filter.check_and_resolve(host).await;
assert!(
matches!(
result,
Err(DestinationError::BlockedHostname(_))
| Err(DestinationError::BlockedIpRange(_))
),
"Host {} should be blocked, got: {:?}",
host,
result
);
}
}
#[tokio::test]
async fn test_ssrf_rfc1918_blocked() {
let filter = DestinationFilter::new(100, 60, 5).unwrap();
let blocked_ips = vec![
"10.0.0.1", "10.255.255.255",
"172.16.0.1", "172.31.255.255",
"192.168.0.1", "192.168.255.255",
];
for ip in blocked_ips {
let result = filter.check_and_resolve(ip).await;
assert!(
matches!(result, Err(DestinationError::BlockedIpRange(_))),
"IP {} should be blocked, got: {:?}",
ip,
result
);
}
}
#[tokio::test]
async fn test_ssrf_metadata_endpoint_blocked() {
let filter = DestinationFilter::new(100, 60, 5).unwrap();
let result = filter.check_and_resolve("169.254.169.254").await;
assert!(
matches!(result, Err(DestinationError::BlockedIpRange(_))),
"AWS metadata IP should be blocked"
);
let result = filter.check_and_resolve("169.254.1.1").await;
assert!(
matches!(result, Err(DestinationError::BlockedIpRange(_))),
"Link-local IP should be blocked"
);
}
#[tokio::test]
async fn test_ssrf_internal_domains_blocked() {
let filter = DestinationFilter::new(100, 60, 5).unwrap();
let blocked_domains = vec![
"server.local",
"api.internal",
"db.lan",
"localhost.localdomain",
];
for domain in blocked_domains {
let result = filter.check_and_resolve(domain).await;
assert!(
matches!(result, Err(DestinationError::BlockedHostname(_))),
"Domain {} should be blocked, got: {:?}",
domain,
result
);
}
}
#[tokio::test]
async fn test_public_domains_allowed() {
let filter = DestinationFilter::new(100, 60, 5).unwrap();
let allowed_domains = vec!["google.com", "example.com", "httpbin.org"];
for domain in allowed_domains {
let result = filter.check_and_resolve(domain).await;
assert!(
result.is_ok(),
"Domain {} should be allowed, got: {:?}",
domain,
result
);
}
}
#[tokio::test]
async fn test_dns_cache_functionality() {
let filter = DestinationFilter::new(100, 60, 5).unwrap();
let result1 = filter.check_and_resolve("google.com").await;
assert!(result1.is_ok(), "First resolution should succeed");
let result2 = filter.check_and_resolve("google.com").await;
assert!(result2.is_ok(), "Cached resolution should succeed");
assert_eq!(
result1.unwrap(),
result2.unwrap(),
"Cached result should match original"
);
}
#[tokio::test]
async fn test_invalid_hostname_error() {
let filter = DestinationFilter::new(100, 60, 5).unwrap();
let result = filter.check_and_resolve("invalid..hostname..com").await;
assert!(
matches!(result, Err(DestinationError::ResolutionFailed(_))),
"Invalid hostname should fail resolution"
);
}
}
#[cfg(test)]
mod ip_tracker_tests {
use derusted::ip_tracker::{IpTracker, IpTrackerError};
use std::net::IpAddr;
#[tokio::test]
async fn test_ip_limit_enforcement() {
let tracker = IpTracker::new(5, 1000, 3600);
let token_id = "test-token-1";
let ips: Vec<IpAddr> = vec![
"1.1.1.1".parse().unwrap(),
"2.2.2.2".parse().unwrap(),
"3.3.3.3".parse().unwrap(),
"4.4.4.4".parse().unwrap(),
"5.5.5.5".parse().unwrap(),
];
for (i, ip) in ips.iter().enumerate() {
let result = tracker.check_and_track(token_id, *ip).await;
println!("After adding IP #{}: {:?}", i + 1, result);
assert!(
result.is_ok(),
"Adding IP {} (#{}) should succeed, got: {:?}",
ip,
i + 1,
result
);
}
let count = tracker.get_ip_count(token_id).await;
println!("Current IP count before 6th: {}", count);
assert_eq!(count, 5, "Should have exactly 5 IPs tracked");
let sixth_ip: IpAddr = "6.6.6.6".parse().unwrap();
let result = tracker.check_and_track(token_id, sixth_ip).await;
println!("Result of adding 6th IP: {:?}", result);
assert!(
matches!(result, Err(IpTrackerError::LimitExceeded { .. })),
"6th IP should be rejected, got: {:?}",
result
);
}
#[tokio::test]
async fn test_ip_reuse_allowed() {
let tracker = IpTracker::new(5, 1000, 3600);
let token_id = "test-token-2";
let ip: IpAddr = "1.1.1.1".parse().unwrap();
for _ in 0..10 {
let result = tracker.check_and_track(token_id, ip).await;
assert!(result.is_ok(), "Reusing same IP should always succeed");
}
}
#[tokio::test]
async fn test_dual_stack_normalization() {
let tracker = IpTracker::new(5, 1000, 3600);
let token_id = "test-token-3";
let ipv4: IpAddr = "192.0.2.1".parse().unwrap();
let ipv6_mapped: IpAddr = "::ffff:192.0.2.1".parse().unwrap();
tracker.check_and_track(token_id, ipv4).await.unwrap();
let result = tracker.check_and_track(token_id, ipv6_mapped).await;
assert!(
result.is_ok(),
"IPv6-mapped IPv4 should be normalized to IPv4"
);
}
#[tokio::test]
async fn test_different_tokens_independent() {
let tracker = IpTracker::new(2, 1000, 3600);
let ip1: IpAddr = "1.1.1.1".parse().unwrap();
let ip2: IpAddr = "2.2.2.2".parse().unwrap();
let ip3: IpAddr = "3.3.3.3".parse().unwrap();
tracker.check_and_track("token-1", ip1).await.unwrap();
tracker.check_and_track("token-1", ip2).await.unwrap();
let result = tracker.check_and_track("token-1", ip3).await;
assert!(result.is_err(), "Token 1 should hit limit");
tracker.check_and_track("token-2", ip1).await.unwrap();
tracker.check_and_track("token-2", ip2).await.unwrap();
let result = tracker.check_and_track("token-2", ip3).await;
assert!(result.is_err(), "Token 2 should also hit its limit");
tracker.check_and_track("token-2", ip1).await.unwrap();
}
}
#[cfg(test)]
mod body_limiter_tests {
#[tokio::test]
async fn test_body_under_limit_allowed() {
}
#[tokio::test]
async fn test_body_over_limit_rejected() {
}
#[tokio::test]
async fn test_empty_body_allowed() {
}
#[tokio::test]
async fn test_exact_limit_allowed() {
}
}
#[cfg(test)]
mod rate_limiter_integration_tests {
}