use super::*;
use axum::http::HeaderValue;
use tower::ServiceExt;
/// Minimal inner service wrapped by the rate-limit layer in these tests.
///
/// It is always ready and answers every request with an empty body built by
/// `Response::new`. Any rejection a test observes must therefore come from
/// the layer wrapping it.
#[derive(Clone)]
struct OkService;

impl Service<Request<Body>> for OkService {
    type Response = Response<Body>;
    type Error = std::convert::Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Always ready: this stub holds no resources and applies no
        // back-pressure. It deliberately makes no assertions about the waker.
        // `Waker::will_wake` is documented as best-effort, and checking std
        // invariants here would only add a panic path unrelated to the code
        // under test.
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        tracing::trace!(method = %req.method(), "OkService call");
        Box::pin(async { Ok(Response::new(Body::empty())) })
    }
}
/// Builds an empty-bodied request whose `x-forwarded-for` header is `ip`.
///
/// Panics if `ip` is not a valid header value (test-only helper).
fn make_request_with_ip(ip: &str) -> Request<Body> {
    let forwarded = HeaderValue::from_str(ip).unwrap();
    let mut request = Request::new(Body::empty());
    request.headers_mut().insert("x-forwarded-for", forwarded);
    request
}
/// Wraps `OkService` in a limiter keyed on the forwarded client IP that
/// allows `max` requests per `window`.
fn make_service(max: u32, window: Duration) -> RateLimitService<OkService> {
    RateLimitLayer::new(
        RateLimitConfig::builder()
            .max_requests(max)
            .window(window)
            .key(KeyExtractor::ForwardedIp)
            .build(),
    )
    .layer(OkService)
}
/// A single request, well under the configured limit, reaches the inner service.
#[tokio::test]
async fn request_within_limit_passes() {
    let response = ServiceExt::oneshot(
        make_service(5, Duration::from_secs(60)),
        make_request_with_ip("1.2.3.4"),
    )
    .await
    .unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// With a limit of two, the third request from the same IP gets `429`.
#[tokio::test]
async fn request_exceeding_limit_returns_429() {
    let svc = make_service(2, Duration::from_secs(60));
    let expected = [StatusCode::OK, StatusCode::OK, StatusCode::TOO_MANY_REQUESTS];
    for want in expected {
        let res = ServiceExt::oneshot(svc.clone(), make_request_with_ip("10.0.0.1"))
            .await
            .unwrap();
        assert_eq!(res.status(), want);
    }
}
/// Once the window has elapsed, a key that was limited is allowed again.
///
/// Uses explicit instants via `try_acquire_at`, so no real time passes.
#[tokio::test]
async fn window_reset_allows_new_requests() {
    let store = BucketStore::new(1, Duration::from_secs(1));
    let start = Instant::now();
    let after_window = start + Duration::from_secs(2);

    let allowed_at =
        |at: Instant| matches!(store.try_acquire_at("k", at), Acquire::Allowed { .. });
    let limited_at =
        |at: Instant| matches!(store.try_acquire_at("k", at), Acquire::Limited { .. });

    assert!(allowed_at(start));
    assert!(limited_at(start));
    assert!(allowed_at(after_window));
}
/// Exhausting one IP's budget leaves a different IP unaffected.
#[tokio::test]
async fn per_key_isolation() {
    let svc = make_service(1, Duration::from_secs(60));
    let steps = [
        ("10.0.0.1", StatusCode::OK),
        ("10.0.0.1", StatusCode::TOO_MANY_REQUESTS),
        ("10.0.0.2", StatusCode::OK),
    ];
    for (ip, want) in steps {
        let res = ServiceExt::oneshot(svc.clone(), make_request_with_ip(ip))
            .await
            .unwrap();
        assert_eq!(res.status(), want);
    }
}
/// With `retry_after(true)`, a `429` carries a `Retry-After` header that
/// parses as a positive integer number of seconds.
#[tokio::test]
async fn retry_after_header_present_when_configured() {
    let svc = RateLimitLayer::new(
        RateLimitConfig::builder()
            .max_requests(1)
            .window(Duration::from_secs(120))
            .key(KeyExtractor::ForwardedIp)
            .retry_after(true)
            .build(),
    )
    .layer(OkService);

    // Use up the single allowed request; its response is not inspected.
    let _ = ServiceExt::oneshot(svc.clone(), make_request_with_ip("5.5.5.5"))
        .await
        .unwrap();

    let limited = ServiceExt::oneshot(svc.clone(), make_request_with_ip("5.5.5.5"))
        .await
        .unwrap();
    assert_eq!(limited.status(), StatusCode::TOO_MANY_REQUESTS);

    let retry_after: u64 = limited
        .headers()
        .get("retry-after")
        .expect("missing Retry-After header")
        .to_str()
        .unwrap()
        .parse()
        .unwrap();
    assert!(retry_after > 0);
}
/// With `retry_after(false)`, a `429` response carries no `Retry-After` header.
#[tokio::test]
async fn retry_after_header_absent_when_disabled() {
    let config = RateLimitConfig::builder()
        .max_requests(1)
        .window(Duration::from_secs(60))
        .key(KeyExtractor::ForwardedIp)
        .retry_after(false)
        .build();
    let svc = RateLimitLayer::new(config).layer(OkService);

    // Each call builds a fresh request and a fresh clone of the service.
    let hit = || ServiceExt::oneshot(svc.clone(), make_request_with_ip("6.6.6.6"));

    let _ = hit().await.unwrap();
    let limited = hit().await.unwrap();
    assert_eq!(limited.status(), StatusCode::TOO_MANY_REQUESTS);
    assert!(limited.headers().get("retry-after").is_none());
}
/// Fifty spawned tasks, each with its own IP, share one limiter and all
/// get `200 OK`. A panic inside any task surfaces through its `JoinHandle`.
///
/// NOTE(review): `#[tokio::test]` defaults to a current-thread runtime, so
/// these tasks interleave rather than run in parallel. Consider
/// `flavor = "multi_thread"` if cross-thread contention is the goal.
#[tokio::test]
async fn concurrent_access_does_not_panic() {
    let svc = make_service(1000, Duration::from_secs(60));
    let mut handles = Vec::with_capacity(50);
    for i in 0..50 {
        let svc = svc.clone();
        let ip = format!("10.0.{}.{}", i / 256, i % 256);
        handles.push(tokio::spawn(async move {
            let res = ServiceExt::oneshot(svc, make_request_with_ip(&ip))
                .await
                .unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }));
    }
    for handle in handles {
        handle.await.unwrap();
    }
}
/// A bucket untouched for two full windows is evicted, while one touched at
/// the eviction instant survives.
#[tokio::test]
async fn evict_expired_removes_stale_buckets() {
    let window = Duration::from_secs(10);
    let store = BucketStore::new(1, window);
    let created = Instant::now();
    let later = created + window * 2;

    store.try_acquire_at("stale", created);
    store.try_acquire_at("fresh", later);
    store.evict_expired_at(later);

    assert!(!store.buckets.contains_key("stale"));
    assert!(store.buckets.contains_key("fresh"));
}
/// `KeyExtractor::Header` buckets requests by the value of the named header.
/// Two requests with `key-a` exhaust that bucket; `key-b` still passes.
#[tokio::test]
async fn header_key_extractor() {
    let config = RateLimitConfig::builder()
        .max_requests(1)
        .window(Duration::from_secs(60))
        .key(KeyExtractor::Header("x-api-key".into()))
        .build();
    let svc = RateLimitLayer::new(config).layer(OkService);

    let with_key = |value: &'static str| {
        let mut req = Request::new(Body::empty());
        req.headers_mut()
            .insert("x-api-key", HeaderValue::from_static(value));
        req
    };

    for (api_key, want) in [
        ("key-a", StatusCode::OK),
        ("key-a", StatusCode::TOO_MANY_REQUESTS),
        ("key-b", StatusCode::OK),
    ] {
        let res = ServiceExt::oneshot(svc.clone(), with_key(api_key))
            .await
            .unwrap();
        assert_eq!(res.status(), want);
    }
}
/// An expired bucket is still evicted after 1024 requests have been counted
/// against a different, active key. The active key's bucket is kept.
///
/// NOTE(review): `request_count` is bumped by hand and `evict_expired_at` is
/// then called directly. This alone does not show that the request counter
/// *triggers* eviction; confirm against the trigger logic in the parent
/// module if that is the intent of this test.
#[tokio::test]
async fn eviction_triggers_on_request_count_not_just_bucket_size() {
    let store = BucketStore::new(100, Duration::from_secs(1));
    let start = Instant::now();
    store.try_acquire_at("stale", start);

    let now = start + Duration::from_secs(5);
    (0..1024).for_each(|_| {
        store.try_acquire_at("active", now);
        store.request_count.fetch_add(1, Ordering::Relaxed);
    });

    store.evict_expired_at(now);
    assert!(
        !store.buckets.contains_key("stale"),
        "stale bucket should have been evicted"
    );
    assert!(store.buckets.contains_key("active"));
}
/// Builds an empty request. When `identifier` is `Some` and accepted by
/// `RateLimitLoginIdentifier::new`, the resulting identifier is attached as a
/// request extension.
///
/// `None`, or an identifier the constructor rejects, yields a request with no
/// such extension.
fn make_login_request(identifier: Option<&str>) -> Request<Body> {
    let mut req = Request::new(Body::empty());
    if let Some(ext) = identifier.and_then(RateLimitLoginIdentifier::new) {
        req.extensions_mut().insert(ext);
    }
    req
}
/// Wraps `OkService` in a limiter keyed on the login identifier that allows
/// `max` attempts per 60-second window.
fn make_login_service(max: u32) -> RateLimitService<OkService> {
    let layer = RateLimitLayer::new(
        RateLimitConfig::builder()
            .max_requests(max)
            .window(Duration::from_secs(60))
            .key(KeyExtractor::LoginIdentifier)
            .build(),
    );
    layer.layer(OkService)
}
/// Attempts are counted per username: once `alice` is throttled, `bob` still
/// has a separate budget.
#[tokio::test]
async fn login_identifier_buckets_attempts_per_username() {
    let svc = make_login_service(2);
    let attempts = [
        ("alice", StatusCode::OK),
        ("alice", StatusCode::OK),
        ("alice", StatusCode::TOO_MANY_REQUESTS),
        ("bob", StatusCode::OK),
    ];
    for (user, want) in attempts {
        let res = ServiceExt::oneshot(svc.clone(), make_login_request(Some(user)))
            .await
            .unwrap();
        assert_eq!(res.status(), want);
    }
}
/// `"Alice"` and `" ALICE "` differ only in case and surrounding whitespace,
/// and are expected to land in the same bucket.
#[tokio::test]
async fn login_identifier_normalises_case_and_whitespace() {
    let svc = make_login_service(1);
    for (raw, want) in [
        ("Alice", StatusCode::OK),
        (" ALICE ", StatusCode::TOO_MANY_REQUESTS),
    ] {
        let res = ServiceExt::oneshot(svc.clone(), make_login_request(Some(raw)))
            .await
            .unwrap();
        assert_eq!(res.status(), want);
    }
}
/// Requests without a login-identifier extension all share a single bucket,
/// so the second such request is already limited.
#[tokio::test]
async fn login_identifier_missing_extension_falls_to_anonymous_bucket() {
    let svc = make_login_service(1);

    let first = ServiceExt::oneshot(svc.clone(), make_login_request(None))
        .await
        .unwrap();
    assert_eq!(first.status(), StatusCode::OK);

    let second = ServiceExt::oneshot(svc.clone(), make_login_request(None))
        .await
        .unwrap();
    assert_eq!(
        second.status(),
        StatusCode::TOO_MANY_REQUESTS,
        "missing extension must funnel into the shared anonymous bucket"
    );
}
/// `RateLimitLoginIdentifier::new` returns `None` for empty and
/// whitespace-only input, and `Some` for a real name.
#[test]
fn login_identifier_rejects_empty_input() {
    for blank in ["", " "] {
        assert!(RateLimitLoginIdentifier::new(blank).is_none());
    }
    assert!(RateLimitLoginIdentifier::new("alice").is_some());
}
/// A key of exactly `MAX_KEY_LEN` bytes comes back unchanged.
#[test]
fn truncate_key_at_max_key_len_unchanged() {
    let original = "x".repeat(MAX_KEY_LEN);
    let truncated = truncate_key(original.clone());
    assert_eq!(
        truncated.len(),
        MAX_KEY_LEN,
        "key length at exactly MAX_KEY_LEN must pass through"
    );
    assert_eq!(truncated, original, "at-cap key must be byte-identical");
}
/// An all-ASCII key one byte over the cap is cut down to exactly `MAX_KEY_LEN`.
#[test]
fn truncate_key_one_byte_over_cap_is_truncated() {
    let truncated = truncate_key("x".repeat(MAX_KEY_LEN + 1));
    assert_eq!(
        truncated.len(),
        MAX_KEY_LEN,
        "one byte over MAX_KEY_LEN must truncate to MAX_KEY_LEN"
    );
}
/// A 4-byte char straddling the `MAX_KEY_LEN` cut is dropped whole. The
/// truncation walks back to the preceding UTF-8 boundary rather than slicing
/// mid-codepoint.
#[test]
fn truncate_key_walks_back_to_char_boundary_on_multibyte() {
    // `MAX_KEY_LEN - 2` ASCII bytes followed by a 4-byte emoji. The emoji
    // occupies bytes `MAX_KEY_LEN - 2 .. MAX_KEY_LEN + 2`, so the cap falls
    // inside it.
    let prefix = "a".repeat(MAX_KEY_LEN - 2);
    let key = format!("{prefix}🚀");
    assert_eq!(
        key.len(),
        MAX_KEY_LEN + 2,
        "test fixture: key must be MAX_KEY_LEN + 2 bytes"
    );

    let out = truncate_key(key);
    assert_eq!(
        out.len(),
        MAX_KEY_LEN - 2,
        "walk-back must land at the UTF-8 boundary (MAX_KEY_LEN - 2), \
         not at MAX_KEY_LEN or MAX_KEY_LEN + 2"
    );
    // `out.is_char_boundary(out.len())` would be vacuous, because the end of
    // a string is always a boundary. Pin the exact surviving content instead;
    // this also proves the straddling emoji was dropped.
    assert_eq!(
        out, prefix,
        "truncated key must be exactly the ASCII prefix, with the straddling \
         multi-byte char dropped"
    );
}
/// `evict_expired()` (the wall-clock entry point) removes a bucket once real
/// time has passed its 20 ms window.
#[tokio::test]
async fn evict_expired_delegates_to_at_helper() {
    let window = Duration::from_millis(20);
    let store = BucketStore::new(1, window);

    store.try_acquire("aged-out");
    assert!(store.buckets.contains_key("aged-out"));

    // Wait comfortably longer than the window before evicting.
    tokio::time::sleep(Duration::from_millis(50)).await;
    store.evict_expired();

    let still_present = store.buckets.contains_key("aged-out");
    assert!(
        !still_present,
        "evict_expired() must delegate to evict_expired_at(Instant::now()), \
         not be a no-op"
    );
}
/// A bucket whose age equals the window exactly is already evicted.
#[tokio::test]
async fn evict_expired_at_evicts_when_age_equals_window() {
    let window = Duration::from_secs(10);
    let store = BucketStore::new(1, window);
    let acquired = Instant::now();

    store.try_acquire_at("at-cutoff", acquired);
    store.evict_expired_at(acquired + window);

    assert!(
        !store.buckets.contains_key("at-cutoff"),
        "at age == window the bucket must evict (strict `<`, not `<=`)"
    );
}
/// `extract_peer_ip` returns the IP of a `SocketAddr` extension. Without one,
/// it returns the shared `"unknown"` key.
#[test]
fn extract_peer_ip_reads_socket_addr_from_extensions() {
    use std::net::SocketAddr;

    let peer: SocketAddr = "192.0.2.7:54321".parse().expect("test addr");
    let mut with_peer = Request::new(Body::empty());
    with_peer.extensions_mut().insert(peer);
    assert_eq!(
        extract_peer_ip(&with_peer),
        "192.0.2.7",
        "extract_peer_ip must return the SocketAddr's IP"
    );

    let without_peer = Request::new(Body::empty());
    assert_eq!(
        extract_peer_ip(&without_peer),
        "unknown",
        "missing SocketAddr must fall through to the shared 'unknown' bucket"
    );
}
/// The config-warning predicates flip exactly at their thresholds:
/// - `max_requests` below 5 warns, 5 and above does not;
/// - a window above 3600 s warns, 3600 s and below does not.
#[test]
fn warn_threshold_helpers_have_strict_bounds() {
    use std::time::Duration;

    for (max, expect_warn) in [(0, true), (4, true), (5, false), (100, false)] {
        assert_eq!(should_warn_very_low_max_requests(max), expect_warn);
    }

    for (secs, expect_warn) in [(3600, false), (3601, true), (60, false)] {
        assert_eq!(
            should_warn_very_long_window(Duration::from_secs(secs)),
            expect_warn
        );
    }
}