use std::convert::Infallible;
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::Duration;
use crate::combinator::bulkhead::{Bulkhead, BulkheadPolicy};
use crate::combinator::circuit_breaker::{CircuitBreaker, CircuitBreakerPolicy};
use crate::combinator::rate_limit::{RateLimitPolicy, RateLimiter};
use crate::combinator::retry::RetryPolicy;
use crate::cx::Cx;
use crate::http::compress::{
ContentEncoding, DEFAULT_MAX_COMPRESSED_SIZE, make_compressor_with_output_limit,
negotiate_encoding,
};
use crate::tracing_compat::{debug, warn};
use crate::types::Time;
use futures_lite::FutureExt;
use super::extract::Request;
use super::handler::Handler;
use super::response::{IntoResponse, Redirect, Response, StatusCode};
/// Origin matching mode used by a [`CorsPolicy`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CorsAllowOrigin {
/// Every origin is accepted. `CorsMiddleware` answers with the `*` wildcard,
/// or with no CORS headers at all when credentials are enabled.
Any,
/// Only the listed origins are accepted (compared ASCII case-insensitively).
Exact(Vec<String>),
}
/// Configuration consumed by `CorsMiddleware`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CorsPolicy {
/// Which request origins receive CORS response headers.
pub allow_origin: CorsAllowOrigin,
/// Methods advertised in `access-control-allow-methods` on preflight responses.
pub allow_methods: Vec<String>,
/// Header names advertised in `access-control-allow-headers` on preflight responses.
pub allow_headers: Vec<String>,
/// Header names sent in `access-control-expose-headers`; nothing is sent when empty.
pub expose_headers: Vec<String>,
/// Preflight cache lifetime, emitted in whole seconds as `access-control-max-age`.
pub max_age: Option<Duration>,
/// When `true`, `access-control-allow-credentials: true` is added to CORS responses.
pub allow_credentials: bool,
}
/// Request header names that `CorsPolicy::default()` places in `allow_headers`.
const DEFAULT_ALLOW_HEADERS: &[&str] = &[
"Accept",
"Accept-Language",
"Content-Type",
"Authorization",
"X-Requested-With",
];
impl Default for CorsPolicy {
    /// Permissive, credential-less defaults: any origin, the common HTTP
    /// methods, the headers in `DEFAULT_ALLOW_HEADERS`, no exposed headers and
    /// a ten-minute preflight cache.
    fn default() -> Self {
        let allow_methods = ["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]
            .into_iter()
            .map(String::from)
            .collect();
        let allow_headers = DEFAULT_ALLOW_HEADERS
            .iter()
            .copied()
            .map(String::from)
            .collect();
        Self {
            allow_origin: CorsAllowOrigin::Any,
            allow_methods,
            allow_headers,
            expose_headers: Vec::new(),
            max_age: Some(Duration::from_mins(10)),
            allow_credentials: false,
        }
    }
}
impl CorsPolicy {
    /// Default policy restricted to the given exact origins.
    #[must_use]
    pub fn with_exact_origins(origins: impl IntoIterator<Item = String>) -> Self {
        let allow_origin = CorsAllowOrigin::Exact(Vec::from_iter(origins));
        Self {
            allow_origin,
            ..Self::default()
        }
    }

    /// Default policy whose allowed-headers list is the single wildcard `*`.
    #[must_use]
    pub fn with_any_headers() -> Self {
        let allow_headers = vec![String::from("*")];
        Self {
            allow_headers,
            ..Self::default()
        }
    }
}
/// Middleware that adds CORS response headers around an inner handler.
pub struct CorsMiddleware<H> {
/// Wrapped handler.
inner: H,
/// Policy deciding which origins, methods and headers are advertised.
policy: CorsPolicy,
}
impl<H: Handler> CorsMiddleware<H> {
    /// Wraps `inner` with the given CORS `policy`.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics when the policy combines
    /// `allow_origin = Any` with `allow_credentials = true`. In other builds the
    /// combination only logs a warning and the middleware then emits no CORS
    /// headers at all (see `allowed_origin_value`).
    #[must_use]
    pub fn new(inner: H, policy: CorsPolicy) -> Self {
        if matches!(policy.allow_origin, CorsAllowOrigin::Any) && policy.allow_credentials {
            debug_assert!(
                false,
                "CorsPolicy violates Fetch §3.2.5: allow_origin = Any with \
                 allow_credentials = true is a credential-reflection vulnerability. \
                 Use CorsPolicy::with_exact_origins(...) when allow_credentials is true."
            );
            // The message describes what this middleware actually does with such
            // a policy: the allow-origin header is withheld, never echoed.
            crate::tracing_compat::warn!(
                "CorsPolicy: allow_origin=Any with allow_credentials=true — \
                 forbidden by Fetch §3.2.5; Access-Control-Allow-Origin will be \
                 omitted from every response. Use exact-origin allow-list instead."
            );
        }
        Self { inner, policy }
    }

    /// Returns `true` for an `OPTIONS` request carrying both `Origin` and
    /// `Access-Control-Request-Method` headers.
    fn is_preflight(req: &Request) -> bool {
        req.method.eq_ignore_ascii_case("OPTIONS")
            && header_value(req, "origin").is_some()
            && header_value(req, "access-control-request-method").is_some()
    }

    /// Treats any `Origin` value containing a comma as malformed.
    fn is_malformed_origin_value(origin: &str) -> bool {
        origin.contains(',')
    }

    /// Computes the `Access-Control-Allow-Origin` value for `origin`.
    ///
    /// Returns `None` when the origin is not allowed, and also when the policy
    /// pairs `Any` with credentials (that case is logged on every request).
    /// For `Exact` policies the configured spelling of the origin is returned.
    fn allowed_origin_value(&self, origin: &str) -> Option<String> {
        match &self.policy.allow_origin {
            CorsAllowOrigin::Any => {
                if self.policy.allow_credentials {
                    crate::tracing_compat::warn!(
                        origin = %origin,
                        "CorsMiddleware: dropping Access-Control-Allow-Origin \
                         header — Allow-Origin = Any with Allow-Credentials = \
                         true is forbidden by Fetch §3.2.5 (credential \
                         reflection). Configure CorsPolicy::with_exact_origins \
                         when credentials are enabled."
                    );
                    None
                } else {
                    Some("*".to_string())
                }
            }
            CorsAllowOrigin::Exact(origins) => origins
                .iter()
                .find(|candidate| candidate.eq_ignore_ascii_case(origin))
                .cloned(),
        }
    }

    /// Adds the headers shared by preflight and actual responses: the allowed
    /// origin, `Vary: origin`, and the optional credentials / expose-headers
    /// entries.
    fn apply_common_headers(&self, mut resp: Response, allow_origin: &str) -> Response {
        resp.set_header("access-control-allow-origin", allow_origin);
        append_vary_header(&mut resp, "origin");
        if self.policy.allow_credentials {
            resp.set_header("access-control-allow-credentials", "true");
        }
        if !self.policy.expose_headers.is_empty() {
            resp.set_header(
                "access-control-expose-headers",
                self.policy.expose_headers.join(", "),
            );
        }
        resp
    }
}
impl<H: Handler> Handler for CorsMiddleware<H> {
    /// Answers preflight requests directly and decorates other responses with
    /// CORS headers. Requests without an `Origin`, with a malformed `Origin`,
    /// or from an origin the policy rejects are forwarded untouched.
    fn call(
        &self,
        cx: &crate::Cx,
        req: Request,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
        let cx = cx.clone();
        Box::pin(async move {
            // Resolve the allow-origin value; `None` means "no CORS handling".
            let decision = match header_value(&req, "origin") {
                None => None,
                Some(origin) if Self::is_malformed_origin_value(&origin) => {
                    crate::tracing_compat::warn!(
                        origin = %origin,
                        "CorsMiddleware: dropping malformed multi-origin request header"
                    );
                    None
                }
                Some(origin) => self.allowed_origin_value(&origin),
            };
            let Some(allow_origin) = decision else {
                return self.inner.call(&cx, req).await;
            };

            if !Self::is_preflight(&req) {
                let resp = self.inner.call(&cx, req).await;
                return self.apply_common_headers(resp, &allow_origin);
            }

            // Preflight: reply with 204 and never invoke the inner handler.
            let mut preflight =
                self.apply_common_headers(Response::empty(StatusCode::NO_CONTENT), &allow_origin);
            preflight.headers.insert(
                "access-control-allow-methods".to_string(),
                self.policy.allow_methods.join(", "),
            );
            preflight.headers.insert(
                "access-control-allow-headers".to_string(),
                self.policy.allow_headers.join(", "),
            );
            if let Some(max_age) = self.policy.max_age {
                preflight.headers.insert(
                    "access-control-max-age".to_string(),
                    max_age.as_secs().to_string(),
                );
            }
            append_vary_header(&mut preflight, "origin");
            append_vary_header(&mut preflight, "access-control-request-method");
            append_vary_header(&mut preflight, "access-control-request-headers");
            preflight
        })
    }
}
/// Returns a copy of the first request header whose name matches
/// `header_name` ASCII case-insensitively, or `None` when absent.
fn header_value(req: &Request, header_name: &str) -> Option<String> {
    for (name, value) in req.headers.iter() {
        if name.eq_ignore_ascii_case(header_name) {
            return Some(value.clone());
        }
    }
    None
}
/// Merges `token` into the response's `Vary` header.
///
/// All existing `Vary` entries are split on commas, trimmed, lower-cased and
/// de-duplicated; the result replaces whatever `Vary` headers were present.
fn append_vary_header(resp: &mut Response, token: &str) {
    /// Adds the normalized form of `raw` unless it is empty or already listed.
    fn push_unique(list: &mut Vec<String>, raw: &str) {
        let candidate = raw.trim().to_ascii_lowercase();
        if candidate.is_empty() {
            return;
        }
        let already_listed = list
            .iter()
            .any(|seen| seen.eq_ignore_ascii_case(&candidate));
        if !already_listed {
            list.push(candidate);
        }
    }

    let mut merged: Vec<String> = Vec::new();
    for (name, value) in &resp.headers {
        if name.eq_ignore_ascii_case("vary") {
            for part in value.split(',') {
                push_unique(&mut merged, part);
            }
        }
    }
    push_unique(&mut merged, token);

    // The old header is always dropped; it is rewritten only when something remains.
    resp.remove_header("vary");
    if !merged.is_empty() {
        resp.set_header("vary", merged.join(", "));
    }
}
/// Converts a header name to its ASCII-lowercase form.
fn normalize_header_name(name: impl Into<String>) -> String {
    let mut lowered: String = name.into();
    lowered.make_ascii_lowercase();
    lowered
}
/// Default clock for the middlewares in this file; delegates to `crate::time::wall_now`.
fn wall_clock_now() -> Time {
crate::time::wall_now()
}
/// Middleware that replaces slow responses with `504 Gateway Timeout`.
///
/// NOTE: the elapsed time is checked only after the inner handler has
/// finished (see the `Handler` impl), so a running handler is not interrupted.
pub struct TimeoutMiddleware<H> {
/// Wrapped handler.
inner: H,
/// Maximum elapsed time before the response is replaced.
timeout: Duration,
/// Clock used to measure the elapsed time.
time_getter: fn() -> Time,
}
impl<H: Handler> TimeoutMiddleware<H> {
/// Wraps `inner` with `timeout`, measured with `wall_clock_now`.
#[must_use]
pub fn new(inner: H, timeout: Duration) -> Self {
Self::with_time_getter(inner, timeout, wall_clock_now)
}
/// Same as [`Self::new`] but with a caller-supplied clock.
#[must_use]
pub fn with_time_getter(inner: H, timeout: Duration, time_getter: fn() -> Time) -> Self {
Self {
inner,
timeout,
time_getter,
}
}
}
impl<H: Handler> Handler for TimeoutMiddleware<H> {
    /// Runs the inner handler to completion, then swaps its response for a
    /// `504 Gateway Timeout` if more than `timeout` elapsed on the configured
    /// clock. The handler itself is never interrupted.
    fn call(
        &self,
        cx: &crate::Cx,
        req: Request,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
        let cx = cx.clone();
        Box::pin(async move {
            let started_at = (self.time_getter)();
            let resp = self.inner.call(&cx, req).await;
            let finished_at = (self.time_getter)();
            let elapsed = Duration::from_nanos(finished_at.duration_since(started_at));
            if elapsed <= self.timeout {
                return resp;
            }
            let body = format!("Request timed out after {elapsed:?}");
            Response::new(StatusCode::GATEWAY_TIMEOUT, body.into_bytes())
        })
    }
}
/// Middleware that guards an inner handler with a [`CircuitBreaker`].
pub struct CircuitBreakerMiddleware<H> {
/// Wrapped handler.
inner: H,
/// Breaker state; an `Arc` so it can be shared between middleware instances.
breaker: Arc<CircuitBreaker>,
/// Clock passed to the breaker.
time_getter: fn() -> Time,
}
impl<H: Handler> CircuitBreakerMiddleware<H> {
/// Creates a middleware with its own breaker built from `policy`.
#[must_use]
pub fn new(inner: H, policy: CircuitBreakerPolicy) -> Self {
Self::with_time_getter(inner, policy, wall_clock_now)
}
/// Same as [`Self::new`] but with a caller-supplied clock.
#[must_use]
pub fn with_time_getter(
inner: H,
policy: CircuitBreakerPolicy,
time_getter: fn() -> Time,
) -> Self {
Self {
inner,
breaker: Arc::new(CircuitBreaker::new(policy)),
time_getter,
}
}
/// Creates a middleware that uses an existing, shared breaker.
#[must_use]
pub fn shared(inner: H, breaker: Arc<CircuitBreaker>) -> Self {
Self::shared_with_time_getter(inner, breaker, wall_clock_now)
}
/// Same as [`Self::shared`] but with a caller-supplied clock.
#[must_use]
pub fn shared_with_time_getter(
inner: H,
breaker: Arc<CircuitBreaker>,
time_getter: fn() -> Time,
) -> Self {
Self {
inner,
breaker,
time_getter,
}
}
/// Borrows the underlying breaker.
#[must_use]
pub fn breaker(&self) -> &CircuitBreaker {
&self.breaker
}
}
impl<H: Handler> Handler for CircuitBreakerMiddleware<H> {
    /// Asks the breaker for a permit before running the inner handler.
    ///
    /// A refused permit yields `503 Service Unavailable` (with `retry-after`
    /// when the breaker is open). Otherwise the outcome is reported back to the
    /// breaker: 5xx responses count as failures, everything else as success.
    /// Both reports use the timestamp taken before the handler ran.
    fn call(
        &self,
        cx: &crate::Cx,
        req: Request,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
        use crate::combinator::circuit_breaker::CircuitBreakerError;

        let cx = cx.clone();
        Box::pin(async move {
            let now = (self.time_getter)();
            let permit = match self.breaker.should_allow(now) {
                Ok(permit) => permit,
                Err(CircuitBreakerError::Open { remaining }) => {
                    let body = format!(
                        "Service Unavailable: circuit breaker open, retry after {remaining:?}"
                    );
                    // Advertise at least one second so the header is never zero.
                    let retry_after_secs = remaining.as_secs().max(1);
                    return Response::new(StatusCode::SERVICE_UNAVAILABLE, body.into_bytes())
                        .header("retry-after", retry_after_secs.to_string());
                }
                Err(CircuitBreakerError::HalfOpenFull) => {
                    return Response::new(
                        StatusCode::SERVICE_UNAVAILABLE,
                        b"Service Unavailable: circuit breaker half-open, max probes active"
                            .to_vec(),
                    );
                }
                Err(CircuitBreakerError::Inner(())) => {
                    unreachable!("should_allow cannot produce inner operation errors")
                }
            };

            let resp = self.inner.call(&cx, req).await;
            if !resp.status.is_server_error() {
                self.breaker.record_success(permit, now);
            } else {
                self.breaker.record_failure(permit, "server_error", now);
            }
            resp
        })
    }
}
/// Middleware that admits requests through a [`RateLimiter`].
pub struct RateLimitMiddleware<H> {
/// Wrapped handler.
inner: H,
/// Limiter state; an `Arc` so it can be shared between middleware instances.
limiter: Arc<RateLimiter>,
/// Clock passed to the limiter.
time_getter: fn() -> Time,
}
impl<H: Handler> RateLimitMiddleware<H> {
/// Creates a middleware with its own limiter built from `policy`.
#[must_use]
pub fn new(inner: H, policy: RateLimitPolicy) -> Self {
Self::with_time_getter(inner, policy, wall_clock_now)
}
/// Same as [`Self::new`] but with a caller-supplied clock.
#[must_use]
pub fn with_time_getter(inner: H, policy: RateLimitPolicy, time_getter: fn() -> Time) -> Self {
Self {
inner,
limiter: Arc::new(RateLimiter::new(policy)),
time_getter,
}
}
/// Creates a middleware that uses an existing, shared limiter.
#[must_use]
pub fn shared(inner: H, limiter: Arc<RateLimiter>) -> Self {
Self::shared_with_time_getter(inner, limiter, wall_clock_now)
}
/// Same as [`Self::shared`] but with a caller-supplied clock.
#[must_use]
pub fn shared_with_time_getter(
inner: H,
limiter: Arc<RateLimiter>,
time_getter: fn() -> Time,
) -> Self {
Self {
inner,
limiter,
time_getter,
}
}
/// Borrows the underlying limiter.
#[must_use]
pub fn limiter(&self) -> &RateLimiter {
&self.limiter
}
}
impl<H: Handler> Handler for RateLimitMiddleware<H> {
    /// Forwards the request when the limiter admits it; otherwise answers
    /// `429 Too Many Requests` with a `retry-after` hint, or `503` when the
    /// limiter reports queue-id exhaustion.
    ///
    /// The timestamp is sampled when `call` is invoked, not when the returned
    /// future is first polled.
    fn call(
        &self,
        cx: &Cx,
        req: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>> {
        use crate::combinator::rate_limit::RateLimitError;

        let cx = cx.clone();
        let now = (self.time_getter)();
        Box::pin(async move {
            match self.limiter.call(now, || Ok::<_, Infallible>(())) {
                Ok(()) => self.inner.call(&cx, req).await,
                // The no-op operation is infallible, so this arm cannot be reached.
                Err(RateLimitError::Inner(never)) => match never {},
                Err(RateLimitError::QueueIdExhausted) => Response::new(
                    StatusCode::SERVICE_UNAVAILABLE,
                    b"Service Unavailable: rate limiter queue exhausted".to_vec(),
                ),
                Err(
                    RateLimitError::RateLimitExceeded
                    | RateLimitError::Timeout { .. }
                    | RateLimitError::Cancelled,
                ) => {
                    // Advertise at least one second so clients never see `0`.
                    let secs = self.limiter.retry_after(1, now).as_secs().max(1);
                    let body =
                        format!("Too Many Requests: rate limit exceeded, retry after {secs}s");
                    Response::new(StatusCode::TOO_MANY_REQUESTS, body.into_bytes())
                        .header("retry-after", secs.to_string())
                }
            }
        })
    }
}
/// Middleware that limits concurrent requests using a [`Bulkhead`].
pub struct BulkheadMiddleware<H> {
/// Wrapped handler.
inner: H,
/// Bulkhead state; an `Arc` so it can be shared between middleware instances.
bulkhead: Arc<Bulkhead>,
}
impl<H: Handler> BulkheadMiddleware<H> {
/// Creates a middleware with its own bulkhead built from `policy`.
#[must_use]
pub fn new(inner: H, policy: BulkheadPolicy) -> Self {
Self {
inner,
bulkhead: Arc::new(Bulkhead::new(policy)),
}
}
/// Creates a middleware that uses an existing, shared bulkhead.
#[must_use]
pub fn shared(inner: H, bulkhead: Arc<Bulkhead>) -> Self {
Self { inner, bulkhead }
}
/// Borrows the underlying bulkhead.
#[must_use]
pub fn bulkhead(&self) -> &Bulkhead {
&self.bulkhead
}
}
impl<H: Handler> Handler for BulkheadMiddleware<H> {
    /// Runs the inner handler while holding one bulkhead permit, or answers
    /// `503 Service Unavailable` when no permit is free.
    fn call(
        &self,
        cx: &Cx,
        req: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>> {
        let cx = cx.clone();
        Box::pin(async move {
            let Some(permit) = self.bulkhead.try_acquire(1) else {
                return Response::new(
                    StatusCode::SERVICE_UNAVAILABLE,
                    b"Service Unavailable: concurrency limit reached".to_vec(),
                );
            };
            let resp = self.inner.call(&cx, req).await;
            // NOTE(review): the explicit release only runs on normal completion;
            // confirm the permit type also releases on drop for cancelled futures.
            permit.release();
            resp
        })
    }
}
/// Middleware that re-invokes the inner handler when it returns a 5xx response.
pub struct RetryMiddleware<H> {
/// Wrapped handler.
inner: H,
/// Attempt count and back-off parameters.
policy: RetryPolicy,
/// When `true`, only methods accepted by `is_idempotent` are retried.
idempotent_only: bool,
}
impl<H: Handler> RetryMiddleware<H> {
/// Wraps `inner`; retries are limited to idempotent methods by default.
#[must_use]
pub fn new(inner: H, policy: RetryPolicy) -> Self {
Self {
inner,
policy,
idempotent_only: true,
}
}
/// Opts in to retrying every method, including non-idempotent ones.
#[must_use]
pub fn retry_all_methods(mut self) -> Self {
self.idempotent_only = false;
self
}
}
/// Returns `true` when `method` names one of the idempotent HTTP methods
/// (`GET`, `HEAD`, `OPTIONS`, `PUT`, `DELETE`, `TRACE`).
///
/// The comparison is ASCII case-insensitive and allocation-free. Unicode case
/// folding is deliberately not applied, so look-alikes such as `"optıons"`
/// (dotless i) are not accepted.
fn is_idempotent(method: &str) -> bool {
    const IDEMPOTENT_METHODS: [&str; 6] = ["GET", "HEAD", "OPTIONS", "PUT", "DELETE", "TRACE"];
    IDEMPOTENT_METHODS
        .iter()
        .any(|candidate| method.eq_ignore_ascii_case(candidate))
}
impl<H: Handler> Handler for RetryMiddleware<H> {
    /// Calls the inner handler up to `policy.max_attempts` times (at least
    /// once), sleeping between attempts with exponential back-off capped at
    /// `policy.max_delay`.
    ///
    /// The first non-5xx response is returned immediately; if every attempt
    /// fails, the last 5xx response is returned. Non-idempotent methods are
    /// passed through without retries unless `retry_all_methods` was used.
    fn call(
        &self,
        cx: &crate::Cx,
        req: Request,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
        let cx = cx.clone();
        Box::pin(async move {
            if self.idempotent_only && !is_idempotent(&req.method) {
                return self.inner.call(&cx, req).await;
            }
            let max = self.policy.max_attempts.max(1);
            let mut delay = self.policy.initial_delay;
            let mut last_resp = None;
            for attempt in 0..max {
                if attempt != 0 {
                    if !delay.is_zero() {
                        crate::time::sleep(wall_clock_now(), delay).await;
                    }
                    let scaled_secs = (delay.as_secs_f64() * self.policy.multiplier)
                        .min(self.policy.max_delay.as_secs_f64());
                    // `Duration::from_secs_f64` panics on negative input (e.g. a
                    // negative multiplier). A misconfigured policy must not crash
                    // the request path, so fall back to the configured cap.
                    delay = Duration::try_from_secs_f64(scaled_secs)
                        .unwrap_or(self.policy.max_delay);
                }
                // Each attempt needs its own copy because the handler consumes the request.
                let try_req = req.clone();
                let resp = self.inner.call(&cx, try_req).await;
                if !resp.status.is_server_error() {
                    return resp;
                }
                last_resp = Some(resp);
            }
            last_resp.unwrap_or_else(|| {
                Response::new(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    b"Internal Server Error: all retry attempts exhausted".to_vec(),
                )
            })
        })
    }
}
/// Settings for `CompressionMiddleware`.
#[derive(Debug, Clone)]
pub struct CompressionConfig {
/// Encodings offered during `Accept-Encoding` negotiation.
pub supported: Vec<ContentEncoding>,
/// Bodies shorter than this many bytes are left uncompressed when the client
/// accepts the identity encoding.
pub min_body_size: usize,
/// Output limit handed to the compressor.
pub max_compressed_size: usize,
}
impl Default for CompressionConfig {
/// Offers Brotli, Gzip, Deflate and Identity, with a 256-byte minimum body
/// size and `DEFAULT_MAX_COMPRESSED_SIZE` as the output limit.
fn default() -> Self {
Self {
supported: vec![
ContentEncoding::Brotli,
ContentEncoding::Gzip,
ContentEncoding::Deflate,
ContentEncoding::Identity,
],
min_body_size: 256,
max_compressed_size: DEFAULT_MAX_COMPRESSED_SIZE,
}
}
}
/// Middleware that compresses response bodies according to `Accept-Encoding`.
pub struct CompressionMiddleware<H> {
/// Wrapped handler.
inner: H,
/// Negotiation and size settings.
config: CompressionConfig,
}
impl<H: Handler> CompressionMiddleware<H> {
/// Wraps `inner` with the given compression `config`.
#[must_use]
pub fn new(inner: H, config: CompressionConfig) -> Self {
Self { inner, config }
}
}
impl<H: Handler> Handler for CompressionMiddleware<H> {
/// Runs the inner handler, then negotiates a content encoding for its body.
///
/// The response is returned untouched for `204`/`304` statuses, when the
/// handler already set `content-encoding`, or when the body is below
/// `min_body_size` and the client accepts identity. A `406` is produced when
/// no encoding acceptable to the client can be delivered.
fn call(
&self,
cx: &Cx,
req: Request,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>> {
let cx = cx.clone();
Box::pin(async move {
// Read the header first: `req` is consumed by the inner handler.
let accept_encoding = header_value(&req, "accept-encoding");
let mut resp = self.inner.call(&cx, req).await;
if resp.status == StatusCode::NO_CONTENT || resp.status == StatusCode::NOT_MODIFIED {
return resp;
}
// A response that already carries an encoding is passed through; the
// header is taken out and written back with the same value.
if let Some(existing_encoding) = resp.remove_header("content-encoding") {
resp.set_header("content-encoding", existing_encoding);
return resp;
}
// Keep only the configured encodings this build can actually produce.
let available_encodings: Vec<_> = self
.config
.supported
.iter()
.copied()
.filter(|encoding| compression_encoding_available(*encoding))
.collect();
// Does the client accept an uncompressed body?
let identity_acceptable =
negotiate_encoding(accept_encoding.as_deref(), &[ContentEncoding::Identity])
== Some(ContentEncoding::Identity);
let body_below_minimum = resp.body.len() < self.config.min_body_size;
if body_below_minimum && identity_acceptable {
return resp;
}
// Reaching this point with a small body means identity was refused, so
// identity is removed from the candidates.
let candidate_encodings = if body_below_minimum {
available_encodings
.iter()
.copied()
.filter(|encoding| *encoding != ContentEncoding::Identity)
.collect::<Vec<_>>()
} else {
available_encodings
};
let Some(encoding) =
negotiate_encoding(accept_encoding.as_deref(), &candidate_encodings)
else {
// Negotiation failed: 406 only if the client sent a header at all.
if accept_encoding.is_some() {
return Response::new(
StatusCode::from_u16(406),
b"No acceptable response encoding".to_vec(),
);
}
return resp;
};
if encoding == ContentEncoding::Identity {
append_vary_header(&mut resp, "accept-encoding");
return resp;
}
let Some(mut compressor) =
make_compressor_with_output_limit(encoding, Some(self.config.max_compressed_size))
else {
// No compressor for the negotiated encoding: fall back to identity
// when the client allows it, otherwise refuse.
if !identity_acceptable {
return Response::new(
StatusCode::from_u16(406),
b"No acceptable response encoding".to_vec(),
);
}
return resp;
};
let mut compressed = Vec::new();
// Compression errors use the same identity-or-406 fallback.
if compressor.compress(&resp.body, &mut compressed).is_err() {
if !identity_acceptable {
return Response::new(
StatusCode::from_u16(406),
b"No acceptable response encoding".to_vec(),
);
}
append_vary_header(&mut resp, "accept-encoding");
return resp;
}
if compressor.finish(&mut compressed).is_err() {
if !identity_acceptable {
return Response::new(
StatusCode::from_u16(406),
b"No acceptable response encoding".to_vec(),
);
}
append_vary_header(&mut resp, "accept-encoding");
return resp;
}
// Compression did not shrink the body: send the original if allowed.
if compressed.len() >= resp.body.len() && identity_acceptable {
append_vary_header(&mut resp, "accept-encoding");
return resp;
}
resp.body = compressed.into();
// The old length no longer matches the new body, so the header is dropped.
resp.remove_header("content-length");
resp.set_header("content-encoding", encoding.as_token().to_string());
append_vary_header(&mut resp, "accept-encoding");
resp
})
}
}
/// Reports whether this build can produce `encoding`: identity always, the
/// compressed encodings only when the `compression` feature is enabled.
fn compression_encoding_available(encoding: ContentEncoding) -> bool {
match encoding {
ContentEncoding::Identity => true,
#[cfg(feature = "compression")]
ContentEncoding::Brotli | ContentEncoding::Gzip | ContentEncoding::Deflate => true,
#[cfg(not(feature = "compression"))]
ContentEncoding::Brotli | ContentEncoding::Gzip | ContentEncoding::Deflate => false,
}
}
/// Middleware that rejects requests whose body exceeds a byte limit.
pub struct RequestBodyLimitMiddleware<H> {
/// Wrapped handler.
inner: H,
/// Largest accepted body size, in bytes.
max_bytes: usize,
}
impl<H: Handler> RequestBodyLimitMiddleware<H> {
/// Wraps `inner`, accepting bodies of at most `max_bytes` bytes.
#[must_use]
pub fn new(inner: H, max_bytes: usize) -> Self {
Self { inner, max_bytes }
}
}
impl<H: Handler> Handler for RequestBodyLimitMiddleware<H> {
    /// Answers `413 Payload Too Large` when either the declared
    /// `Content-Length` or the actual body length exceeds the limit; otherwise
    /// forwards the request. An unparsable `Content-Length` is ignored.
    fn call(
        &self,
        cx: &Cx,
        req: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>> {
        let cx = cx.clone();
        Box::pin(async move {
            let limit = self.max_bytes;

            // First check: the size the client claims to be sending.
            let declared = super::extract::header_value_ci(&req, "content-length")
                .and_then(|raw| super::extract::parse_content_length(raw).ok());
            match declared {
                Some(declared_length) if declared_length > limit => {
                    let body = format!(
                        "Payload Too Large: Content-Length {} bytes exceeds limit {} bytes",
                        declared_length, limit
                    );
                    return Response::new(StatusCode::PAYLOAD_TOO_LARGE, body.into_bytes());
                }
                _ => {}
            }

            // Second check: the size that actually arrived.
            let actual_length = req.body.len();
            if actual_length > limit {
                let body = format!(
                    "Payload Too Large: body is {} bytes, limit is {} bytes",
                    actual_length, limit
                );
                return Response::new(StatusCode::PAYLOAD_TOO_LARGE, body.into_bytes());
            }

            self.inner.call(&cx, req).await
        })
    }
}
/// Default maximum number of characters kept from a request id.
pub const DEFAULT_REQUEST_ID_MAX_LENGTH: usize = 128;
/// Maximum number of characters kept from an `x-request-id` header when it is
/// used as a trace id.
pub const DEFAULT_TRACE_ID_MAX_LENGTH: usize = 128;
/// Middleware that propagates or generates a per-request identifier.
pub struct RequestIdMiddleware<H> {
/// Wrapped handler.
inner: H,
/// Lower-cased name of the header carrying the id.
header_name: String,
/// Source of generated ids; an `Arc` so it can be shared between instances.
counter: Arc<AtomicU64>,
/// Maximum number of characters kept from an id.
max_id_length: usize,
}
impl<H: Handler> RequestIdMiddleware<H> {
/// Wraps `inner` with a private counter starting at 1.
#[must_use]
pub fn new(inner: H, header_name: impl Into<String>) -> Self {
Self {
inner,
header_name: normalize_header_name(header_name),
counter: Arc::new(AtomicU64::new(1)),
max_id_length: DEFAULT_REQUEST_ID_MAX_LENGTH,
}
}
/// Wraps `inner` using a caller-supplied, shared counter.
#[must_use]
pub fn shared(inner: H, header_name: impl Into<String>, counter: Arc<AtomicU64>) -> Self {
Self {
inner,
header_name: normalize_header_name(header_name),
counter,
max_id_length: DEFAULT_REQUEST_ID_MAX_LENGTH,
}
}
/// Overrides the maximum id length; `0` selects the default instead.
#[must_use]
pub fn with_max_length(mut self, max: usize) -> Self {
self.max_id_length = if max == 0 {
DEFAULT_REQUEST_ID_MAX_LENGTH
} else {
max
};
self
}
}
/// Returns `id` limited to at most `max` characters.
///
/// The limit counts Unicode scalar values, not bytes, so the cut always falls
/// on a character boundary. Ids that already fit are returned unchanged.
fn truncate_request_id(id: &str, max: usize) -> String {
    // `nth(max)` yields the byte offset of the first character past the limit,
    // or `None` when the id has at most `max` characters.
    match id.char_indices().nth(max) {
        Some((cutoff, _)) => id[..cutoff].to_string(),
        None => id.to_string(),
    }
}
/// Removes every `\r` and `\n` from `id`, then limits the result to
/// `max_length` characters.
fn sanitize_and_truncate_id(id: &str, max_length: usize) -> String {
    let without_line_breaks: String = id
        .chars()
        .filter(|ch| !matches!(ch, '\r' | '\n'))
        .collect();
    truncate_request_id(&without_line_breaks, max_length)
}
impl<H: Handler> Handler for RequestIdMiddleware<H> {
    /// Reuses the id from the configured request header, or generates
    /// `req-<n>` from the counter, sanitizes it, stores it in the request
    /// extensions as both `request_id` and `trace_id`, and echoes it on the
    /// response.
    fn call(
        &self,
        cx: &crate::Cx,
        mut req: Request,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
        let cx = cx.clone();
        Box::pin(async move {
            let raw_id = match header_value(&req, &self.header_name) {
                Some(supplied) => supplied,
                None => {
                    let id = self.counter.fetch_add(1, Ordering::Relaxed);
                    format!("req-{id}")
                }
            };
            let request_id = sanitize_and_truncate_id(&raw_id, self.max_id_length);

            // Downstream layers can read the id under either key.
            req.extensions.insert("request_id", request_id.clone());
            req.extensions.insert("trace_id", request_id.clone());

            let mut resp = self.inner.call(&cx, req).await;
            resp.set_header(&self.header_name, request_id);
            resp
        })
    }
}
/// Chooses which headers `RequestTraceMiddleware` writes on responses.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestTracePolicy {
/// Header receiving the request duration in milliseconds; `None` disables it.
pub duration_header: Option<String>,
/// Header receiving the trace id; `None` disables it.
pub trace_header: Option<String>,
}
impl Default for RequestTracePolicy {
/// Uses `x-response-time-ms` and `x-trace-id`.
fn default() -> Self {
Self {
duration_header: Some("x-response-time-ms".to_string()),
trace_header: Some("x-trace-id".to_string()),
}
}
}
/// Middleware that logs each request and annotates the response with timing
/// and trace headers.
pub struct RequestTraceMiddleware<H> {
/// Wrapped handler.
inner: H,
/// Header configuration, with names already lower-cased.
policy: RequestTracePolicy,
/// Clock used to measure request duration.
time_getter: fn() -> Time,
}
impl<H: Handler> RequestTraceMiddleware<H> {
/// Wraps `inner`, measuring time with `wall_clock_now`.
#[must_use]
pub fn new(inner: H, policy: RequestTracePolicy) -> Self {
Self::with_time_getter(inner, policy, wall_clock_now)
}
/// Same as [`Self::new`] but with a caller-supplied clock.
#[must_use]
pub fn with_time_getter(
inner: H,
policy: RequestTracePolicy,
time_getter: fn() -> Time,
) -> Self {
// Normalize the configured header names to ASCII lowercase once, up front.
let policy = RequestTracePolicy {
duration_header: policy.duration_header.map(normalize_header_name),
trace_header: policy.trace_header.map(normalize_header_name),
};
Self {
inner,
policy,
time_getter,
}
}
/// Delegates to the free function `resolve_trace_id`.
fn resolve_trace_id(req: &Request) -> Option<String> {
resolve_trace_id(req)
}
}
/// Finds a trace id for `req`.
///
/// Lookup order: the `trace_id` extension, the `request_id` extension, then
/// the `x-request-id` header (sanitized and limited to
/// `DEFAULT_TRACE_ID_MAX_LENGTH` characters).
fn resolve_trace_id(req: &Request) -> Option<String> {
    for key in ["trace_id", "request_id"] {
        if let Some(id) = req.extensions.get(key) {
            return Some(id.to_string());
        }
    }
    let from_header = header_value(req, "x-request-id")?;
    Some(sanitize_and_truncate_id(
        &from_header,
        DEFAULT_TRACE_ID_MAX_LENGTH,
    ))
}
impl<H: Handler> Handler for RequestTraceMiddleware<H> {
    /// Logs the start and completion of the request, sets the duration header,
    /// and sets the trace header unless the response already has one.
    /// Responses with status 500 or above are logged at warn level.
    fn call(
        &self,
        cx: &Cx,
        req: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>> {
        let cx = cx.clone();
        Box::pin(async move {
            // Capture what the log lines need before `req` is consumed.
            let trace_id = Self::resolve_trace_id(&req);
            let method = req.method.clone();
            let path = req.path.clone();
            let start = (self.time_getter)();
            debug!(
                method = %method,
                path = %path,
                trace_id = ?trace_id,
                "http request start"
            );

            let mut resp = self.inner.call(&cx, req).await;

            let finished = (self.time_getter)();
            let duration_ms = Duration::from_nanos(finished.duration_since(start)).as_millis();
            let status_code = resp.status.as_u16();

            if let Some(header_name) = self.policy.duration_header.as_ref() {
                resp.set_header(header_name, duration_ms.to_string());
            }
            let trace_target = self.policy.trace_header.as_ref().zip(trace_id.as_ref());
            if let Some((header_name, id)) = trace_target {
                // A trace header set by an inner layer takes precedence.
                if !resp.has_header(header_name) {
                    resp.set_header(header_name, id.clone());
                }
            }

            let is_server_error = status_code >= 500;
            if is_server_error {
                warn!(
                    method = %method,
                    path = %path,
                    status = status_code,
                    duration_ms = duration_ms,
                    trace_id = ?trace_id,
                    "http request completed with server error"
                );
            } else {
                debug!(
                    method = %method,
                    path = %path,
                    status = status_code,
                    duration_ms = duration_ms,
                    trace_id = ?trace_id,
                    "http request completed"
                );
            }
            // Keeps the bindings "used" when the logging macros compile to nothing.
            #[cfg(not(feature = "tracing-integration"))]
            let _ = (&method, &path);
            resp
        })
    }
}
/// How `AuthMiddleware` validates the `Authorization` header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthPolicy {
/// Accept any non-empty bearer token.
AnyBearer,
/// Accept only bearer tokens that match a listed, unexpired token.
ExactBearer(Vec<BearerToken>),
}
/// A bearer token together with its expiry time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BearerToken {
/// The secret token text.
token: String,
/// Instant from which the token is no longer valid.
expires_at: Time,
}
impl BearerToken {
/// Creates a token that stops being valid at `expires_at`.
#[must_use]
pub fn new(token: impl Into<String>, expires_at: Time) -> Self {
Self {
token: token.into(),
expires_at,
}
}
/// Creates a token whose expiry is `Time::MAX`.
#[must_use]
pub fn non_expiring(token: impl Into<String>) -> Self {
Self::new(token, Time::MAX)
}
/// Returns the token text.
#[must_use]
pub fn token(&self) -> &str {
&self.token
}
/// Returns the expiry time.
#[must_use]
pub const fn expires_at(&self) -> Time {
self.expires_at
}
/// A token is valid strictly before its expiry time.
fn is_valid_at(&self, now: Time) -> bool {
now < self.expires_at
}
}
impl Default for AuthPolicy {
/// Defaults to `ExactBearer` with no tokens, which matches no request.
fn default() -> Self {
Self::ExactBearer(Vec::new())
}
}
impl AuthPolicy {
    /// Policy accepting exactly one token that never expires.
    #[must_use]
    pub fn exact_bearer(token: impl Into<String>) -> Self {
        let only = BearerToken::non_expiring(token);
        Self::ExactBearer(vec![only])
    }

    /// Policy accepting exactly one token until `expires_at`.
    #[must_use]
    pub fn exact_bearer_until(token: impl Into<String>, expires_at: Time) -> Self {
        let only = BearerToken::new(token, expires_at);
        Self::ExactBearer(vec![only])
    }

    /// Policy accepting exactly one token for `ttl` after `issued_at`.
    #[must_use]
    pub fn exact_bearer_for(token: impl Into<String>, issued_at: Time, ttl: Duration) -> Self {
        let expires_at = issued_at + ttl;
        Self::exact_bearer_until(token, expires_at)
    }

    /// Drops tokens expired at `now` and adds a new one. Does nothing for
    /// `AnyBearer`.
    pub fn rotate_exact_bearer(&mut self, token: impl Into<String>, expires_at: Time, now: Time) {
        self.prune_expired(now);
        if let Self::ExactBearer(tokens) = self {
            tokens.push(BearerToken::new(token, expires_at));
        }
    }

    /// Drops tokens expired at `now`. Does nothing for `AnyBearer`.
    pub fn prune_expired(&mut self, now: Time) {
        match self {
            Self::ExactBearer(tokens) => tokens.retain(|token| token.is_valid_at(now)),
            Self::AnyBearer => {}
        }
    }

    /// Decides whether `req` carries an acceptable bearer token at `now`.
    fn allows_at(&self, req: &Request, now: Time) -> bool {
        let Some(header) = header_value(req, "authorization") else {
            return false;
        };
        let Some(presented) = parse_bearer_token(&header) else {
            return false;
        };
        match self {
            Self::AnyBearer => !presented.is_empty(),
            Self::ExactBearer(tokens) => {
                // Every configured token is checked, with non-short-circuiting
                // operators, so a match does not end the scan early.
                let mut matched = false;
                for expected in tokens {
                    let token_matches = constant_time_str_eq(expected.token(), presented);
                    let token_active = expected.is_valid_at(now);
                    #[allow(clippy::needless_bitwise_bool)]
                    {
                        matched |= token_active & token_matches;
                    }
                }
                matched
            }
        }
    }
}
/// Compares two strings without returning early on the first mismatch.
///
/// Every byte of `expected` is visited; positions missing from `token` are
/// treated as `0`, and a length difference always yields `false`.
fn constant_time_str_eq(expected: &str, token: &str) -> bool {
    let presented = token.as_bytes();
    let length_diff = u8::from(expected.len() != token.len());
    let byte_diff = expected
        .bytes()
        .enumerate()
        .fold(0u8, |acc, (idx, byte)| {
            acc | (byte ^ presented.get(idx).copied().unwrap_or(0))
        });
    (length_diff | byte_diff) == 0
}
/// Extracts the token from an `Authorization: Bearer <token>` header value.
///
/// The scheme is matched ASCII case-insensitively and must be followed by a
/// space; surrounding whitespace is trimmed from the header and the token.
fn parse_bearer_token(header: &str) -> Option<&str> {
    header
        .trim()
        .split_once(' ')
        .filter(|(scheme, _)| scheme.eq_ignore_ascii_case("bearer"))
        .map(|(_, token)| token.trim())
}
/// Middleware that requires a bearer token accepted by an [`AuthPolicy`].
pub struct AuthMiddleware<H> {
/// Wrapped handler.
inner: H,
/// Policy deciding which tokens are accepted.
policy: AuthPolicy,
/// Clock used to evaluate token expiry.
time_getter: fn() -> Time,
}
impl<H: Handler> AuthMiddleware<H> {
/// Wraps `inner`, evaluating expiry with `wall_clock_now`.
#[must_use]
pub fn new(inner: H, policy: AuthPolicy) -> Self {
Self::with_time_getter(inner, policy, wall_clock_now)
}
/// Same as [`Self::new`] but with a caller-supplied clock.
#[must_use]
pub fn with_time_getter(inner: H, policy: AuthPolicy, time_getter: fn() -> Time) -> Self {
Self {
inner,
policy,
time_getter,
}
}
}
impl<H: Handler> Handler for AuthMiddleware<H> {
    /// Forwards requests the policy accepts; all others receive
    /// `401 Unauthorized` with a `www-authenticate: Bearer` challenge.
    fn call(
        &self,
        cx: &Cx,
        req: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>> {
        let cx = cx.clone();
        Box::pin(async move {
            let now = (self.time_getter)();
            if self.policy.allows_at(&req, now) {
                return self.inner.call(&cx, req).await;
            }
            let rejection = Response::new(StatusCode::UNAUTHORIZED, b"Unauthorized".to_vec());
            rejection.header("www-authenticate", "Bearer")
        })
    }
}
/// Settings for `LoadShedMiddleware`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LoadShedPolicy {
/// Number of requests allowed in flight before new ones are rejected.
pub max_in_flight: usize,
}
impl Default for LoadShedPolicy {
/// Allows 1024 concurrent requests.
fn default() -> Self {
Self {
max_in_flight: 1024,
}
}
}
/// Guard that decrements an in-flight counter when dropped.
struct InFlightGuard<'a> {
/// Counter to decrement on drop.
counter: &'a AtomicUsize,
}
impl Drop for InFlightGuard<'_> {
/// Subtracts one from the counter.
fn drop(&mut self) {
self.counter.fetch_sub(1, Ordering::AcqRel);
}
}
/// Middleware that rejects requests once too many are already in flight.
pub struct LoadShedMiddleware<H> {
/// Wrapped handler.
inner: H,
/// Concurrency limit.
policy: LoadShedPolicy,
/// Number of requests currently inside the inner handler.
in_flight: Arc<AtomicUsize>,
}
impl<H: Handler> LoadShedMiddleware<H> {
/// Wraps `inner` with a fresh in-flight counter starting at zero.
#[must_use]
pub fn new(inner: H, policy: LoadShedPolicy) -> Self {
Self {
inner,
policy,
in_flight: Arc::new(AtomicUsize::new(0)),
}
}
}
impl<H: Handler> Handler for LoadShedMiddleware<H> {
    /// Counts the request as in flight and forwards it, or answers
    /// `503 Service Unavailable` when the limit was already reached.
    fn call(
        &self,
        cx: &Cx,
        req: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>> {
        let cx = cx.clone();
        Box::pin(async move {
            // Reserve a slot optimistically; `previous` is the count before this request.
            let previous = self.in_flight.fetch_add(1, Ordering::AcqRel);
            if previous < self.policy.max_in_flight {
                // The guard gives the slot back when this future completes or is dropped.
                let _guard = InFlightGuard {
                    counter: &self.in_flight,
                };
                return self.inner.call(&cx, req).await;
            }
            // Over the limit: undo the reservation and shed the request.
            self.in_flight.fetch_sub(1, Ordering::AcqRel);
            Response::new(
                StatusCode::SERVICE_UNAVAILABLE,
                b"Service Unavailable: overloaded".to_vec(),
            )
        })
    }
}
/// Middleware that converts a panic in the inner handler into a `500` response.
pub struct CatchPanicMiddleware<H> {
/// Wrapped handler.
inner: H,
}
impl<H: Handler> CatchPanicMiddleware<H> {
/// Wraps `inner`.
#[must_use]
pub fn new(inner: H) -> Self {
Self { inner }
}
}
/// Renders a panic payload as text.
///
/// `&'static str` and `String` payloads are returned as-is; any other payload
/// type yields the placeholder `<non-string panic payload>`.
#[allow(dead_code)]
fn panic_payload_message(payload: &(dyn std::any::Any + Send)) -> String {
    payload
        .downcast_ref::<&'static str>()
        .map(|text| (*text).to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "<non-string panic payload>".to_string())
}
impl<H: Handler> Handler for CatchPanicMiddleware<H> {
    /// Runs the inner handler and turns any panic into a logged
    /// `500 Internal Server Error` response.
    ///
    /// Both panics raised while polling the inner future and panics raised
    /// synchronously by `inner.call` itself (before it returns a future) are
    /// recovered.
    fn call(
        &self,
        cx: &Cx,
        req: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>> {
        let cx = cx.clone();
        Box::pin(async move {
            // Capture log context before `req` moves into the guarded future.
            let method = req.method.clone();
            let path = req.path.clone();
            let trace_id = resolve_trace_id(&req);
            // `inner.call` is invoked inside the guarded future so that a panic
            // raised while constructing the inner future is caught as well.
            let guarded = AssertUnwindSafe(async move { self.inner.call(&cx, req).await });
            match guarded.catch_unwind().await {
                Ok(response) => response,
                Err(payload) => {
                    let panic_message = panic_payload_message(payload.as_ref());
                    // Keeps the bindings "used" when `warn!` compiles to nothing.
                    let _panic_log_fields = (&method, &path, &trace_id, &panic_message);
                    warn!(
                        method = %method,
                        path = %path,
                        trace_id = trace_id.as_deref().unwrap_or(""),
                        panic = %panic_message,
                        "web handler panic recovered"
                    );
                    Response::new(
                        StatusCode::INTERNAL_SERVER_ERROR,
                        b"Internal Server Error".to_vec(),
                    )
                }
            }
        })
    }
}
/// Trailing-slash strategy applied by `NormalizePathMiddleware`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrailingSlash {
/// Rewrite the request path in place, removing trailing slashes.
Trim,
/// Rewrite the request path in place, appending a slash unless the path
/// already ends with one or contains a `.`.
Always,
/// Like `Trim`, but answer with a permanent redirect instead of rewriting.
RedirectTrim,
/// Like `Always`, but answer with a permanent redirect instead of rewriting.
RedirectAlways,
}
/// Middleware that normalizes trailing slashes on request paths.
pub struct NormalizePathMiddleware<H> {
/// Wrapped handler.
inner: H,
/// How trailing slashes are handled.
strategy: TrailingSlash,
}
impl<H: Handler> NormalizePathMiddleware<H> {
/// Wraps `inner` with the given `strategy`.
#[must_use]
pub fn new(inner: H, strategy: TrailingSlash) -> Self {
Self { inner, strategy }
}
}
/// Builds the permanent redirect to a normalized `path`.
///
/// `\r` and `\n` are stripped from the target first. If `Redirect::permanent`
/// rejects the target, a `400 Bad Request` is returned and the refusal is
/// logged.
fn normalization_redirect_response(path: &str) -> Response {
    let candidate: String = path
        .chars()
        .filter(|ch| !matches!(ch, '\r' | '\n'))
        .collect();
    let err = match Redirect::permanent(candidate.clone()) {
        Ok(redirect) => return redirect.into_response(),
        Err(err) => err,
    };
    // Keeps `err` "used" when `warn!` compiles to nothing.
    let _ = &err;
    warn!(
        path = %candidate,
        error = %err,
        "NormalizePathMiddleware: refusing unsafe redirect candidate"
    );
    Response::new(
        StatusCode::BAD_REQUEST,
        b"Bad Request: invalid normalized redirect target".to_vec(),
    )
}
impl<H: Handler> Handler for NormalizePathMiddleware<H> {
    /// Applies the configured trailing-slash strategy, either rewriting the
    /// request path before forwarding or answering with a redirect.
    ///
    /// A slash is only appended to paths that contain no `.`; the root path
    /// `/` is never trimmed.
    fn call(
        &self,
        cx: &Cx,
        mut req: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>> {
        /// Removes all trailing slashes, falling back to `/` if nothing is left.
        fn without_trailing_slashes(path: &str) -> String {
            let stripped = path.trim_end_matches('/');
            if stripped.is_empty() {
                "/".to_string()
            } else {
                stripped.to_string()
            }
        }

        let cx = cx.clone();
        Box::pin(async move {
            let has_extra_slash = req.path.len() > 1 && req.path.ends_with('/');
            let lacks_slash = !req.path.ends_with('/') && !req.path.contains('.');

            match self.strategy {
                TrailingSlash::Trim => {
                    if has_extra_slash {
                        req.path = without_trailing_slashes(&req.path);
                    }
                }
                TrailingSlash::Always => {
                    if lacks_slash {
                        req.path.push('/');
                    }
                }
                TrailingSlash::RedirectTrim => {
                    if has_extra_slash {
                        let trimmed = without_trailing_slashes(&req.path);
                        return normalization_redirect_response(&trimmed);
                    }
                }
                TrailingSlash::RedirectAlways => {
                    if lacks_slash {
                        let with_slash = format!("{}/", req.path);
                        return normalization_redirect_response(&with_slash);
                    }
                }
            }
            self.inner.call(&cx, req).await
        })
    }
}
/// Controls whether `SetResponseHeaderMiddleware` replaces an existing header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HeaderOverwrite {
/// Set the header on every response via `Response::set_header`.
Always,
/// Add the header via `Response::ensure_header`.
IfMissing,
}
/// Middleware that adds one fixed header to every response.
pub struct SetResponseHeaderMiddleware<H> {
/// Wrapped handler.
inner: H,
/// Lower-cased header name.
name: String,
/// Header value.
value: String,
/// Whether an existing header is replaced.
mode: HeaderOverwrite,
}
impl<H: Handler> SetResponseHeaderMiddleware<H> {
/// Wraps `inner`; `name` is normalized to ASCII lowercase.
#[must_use]
pub fn new(
inner: H,
name: impl Into<String>,
value: impl Into<String>,
mode: HeaderOverwrite,
) -> Self {
Self {
inner,
name: normalize_header_name(name),
value: value.into(),
mode,
}
}
/// Shorthand for [`Self::new`] with `HeaderOverwrite::Always`.
#[must_use]
pub fn always(inner: H, name: impl Into<String>, value: impl Into<String>) -> Self {
Self::new(inner, name, value, HeaderOverwrite::Always)
}
/// Shorthand for [`Self::new`] with `HeaderOverwrite::IfMissing`.
#[must_use]
pub fn if_missing(inner: H, name: impl Into<String>, value: impl Into<String>) -> Self {
Self::new(inner, name, value, HeaderOverwrite::IfMissing)
}
}
impl<H: Handler> Handler for SetResponseHeaderMiddleware<H> {
    /// Runs the inner handler, then applies the configured header using
    /// `set_header` (mode `Always`) or `ensure_header` (mode `IfMissing`).
    fn call(
        &self,
        cx: &Cx,
        req: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>> {
        let cx = cx.clone();
        Box::pin(async move {
            let mut resp = self.inner.call(&cx, req).await;
            let value = self.value.clone();
            match self.mode {
                HeaderOverwrite::IfMissing => {
                    resp.ensure_header(&self.name, value);
                }
                HeaderOverwrite::Always => {
                    resp.set_header(&self.name, value);
                }
            }
            resp
        })
    }
}
/// Builder that accumulates middleware layers around a handler.
///
/// Each `with_*` method wraps the current `inner` in one more middleware, so
/// the layer added last ends up outermost. `build` returns the composed handler.
pub struct MiddlewareStack<H> {
// The handler composed so far (the original handler plus every layer added).
inner: H,
}
impl<H: Handler> MiddlewareStack<H> {
#[must_use]
pub fn new(inner: H) -> Self {
Self { inner }
}
#[must_use]
pub fn with_timeout(self, timeout: Duration) -> MiddlewareStack<TimeoutMiddleware<H>> {
MiddlewareStack {
inner: TimeoutMiddleware::new(self.inner, timeout),
}
}
#[must_use]
pub fn with_cors(self, policy: CorsPolicy) -> MiddlewareStack<CorsMiddleware<H>> {
MiddlewareStack {
inner: CorsMiddleware::new(self.inner, policy),
}
}
#[must_use]
pub fn with_circuit_breaker(
self,
policy: CircuitBreakerPolicy,
) -> MiddlewareStack<CircuitBreakerMiddleware<H>> {
MiddlewareStack {
inner: CircuitBreakerMiddleware::new(self.inner, policy),
}
}
#[must_use]
pub fn with_shared_circuit_breaker(
self,
breaker: Arc<CircuitBreaker>,
) -> MiddlewareStack<CircuitBreakerMiddleware<H>> {
MiddlewareStack {
inner: CircuitBreakerMiddleware::shared(self.inner, breaker),
}
}
#[must_use]
pub fn with_rate_limit(
self,
policy: RateLimitPolicy,
) -> MiddlewareStack<RateLimitMiddleware<H>> {
MiddlewareStack {
inner: RateLimitMiddleware::new(self.inner, policy),
}
}
#[must_use]
pub fn with_shared_rate_limit(
self,
limiter: Arc<RateLimiter>,
) -> MiddlewareStack<RateLimitMiddleware<H>> {
MiddlewareStack {
inner: RateLimitMiddleware::shared(self.inner, limiter),
}
}
#[must_use]
pub fn with_bulkhead(self, policy: BulkheadPolicy) -> MiddlewareStack<BulkheadMiddleware<H>> {
MiddlewareStack {
inner: BulkheadMiddleware::new(self.inner, policy),
}
}
#[must_use]
pub fn with_shared_bulkhead(
self,
bulkhead: Arc<Bulkhead>,
) -> MiddlewareStack<BulkheadMiddleware<H>> {
MiddlewareStack {
inner: BulkheadMiddleware::shared(self.inner, bulkhead),
}
}
#[must_use]
pub fn with_retry(self, policy: RetryPolicy) -> MiddlewareStack<RetryMiddleware<H>> {
MiddlewareStack {
inner: RetryMiddleware::new(self.inner, policy),
}
}
#[must_use]
pub fn with_compression(
self,
config: CompressionConfig,
) -> MiddlewareStack<CompressionMiddleware<H>> {
MiddlewareStack {
inner: CompressionMiddleware::new(self.inner, config),
}
}
#[must_use]
pub fn with_body_limit(
self,
max_bytes: usize,
) -> MiddlewareStack<RequestBodyLimitMiddleware<H>> {
MiddlewareStack {
inner: RequestBodyLimitMiddleware::new(self.inner, max_bytes),
}
}
#[must_use]
pub fn with_auth(self, policy: AuthPolicy) -> MiddlewareStack<AuthMiddleware<H>> {
MiddlewareStack {
inner: AuthMiddleware::new(self.inner, policy),
}
}
#[must_use]
pub fn with_load_shed(self, policy: LoadShedPolicy) -> MiddlewareStack<LoadShedMiddleware<H>> {
MiddlewareStack {
inner: LoadShedMiddleware::new(self.inner, policy),
}
}
#[must_use]
pub fn with_request_id(
self,
header_name: impl Into<String>,
) -> MiddlewareStack<RequestIdMiddleware<H>> {
MiddlewareStack {
inner: RequestIdMiddleware::new(self.inner, header_name),
}
}
#[must_use]
pub fn with_request_trace(
self,
policy: RequestTracePolicy,
) -> MiddlewareStack<RequestTraceMiddleware<H>> {
MiddlewareStack {
inner: RequestTraceMiddleware::new(self.inner, policy),
}
}
#[must_use]
pub fn with_catch_panic(self) -> MiddlewareStack<CatchPanicMiddleware<H>> {
MiddlewareStack {
inner: CatchPanicMiddleware::new(self.inner),
}
}
#[must_use]
pub fn with_normalize_path(
self,
strategy: TrailingSlash,
) -> MiddlewareStack<NormalizePathMiddleware<H>> {
MiddlewareStack {
inner: NormalizePathMiddleware::new(self.inner, strategy),
}
}
#[must_use]
pub fn with_response_header(
self,
name: impl Into<String>,
value: impl Into<String>,
mode: HeaderOverwrite,
) -> MiddlewareStack<SetResponseHeaderMiddleware<H>> {
MiddlewareStack {
inner: SetResponseHeaderMiddleware::new(self.inner, name, value, mode),
}
}
#[must_use]
pub fn build(self) -> H {
self.inner
}
}
#[cfg(test)]
mod tests {
#![allow(
clippy::pedantic,
clippy::nursery,
clippy::expect_fun_call,
clippy::map_unwrap_or,
clippy::cast_possible_wrap,
clippy::future_not_send
)]
use std::panic::{self, AssertUnwindSafe};
use super::*;
use crate::web::handler::FnHandler;
// Per-thread fake clocks (in milliseconds) for tests that inject a time
// getter. One cell per middleware family, each starting at 0.
thread_local! {
// Clock read by `timeout_test_time`.
static TIMEOUT_TEST_TIME_MS: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
// Clock read by `circuit_test_time`.
static CIRCUIT_TEST_TIME_MS: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
// Clock read by `request_trace_test_time`.
static REQUEST_TRACE_TEST_TIME_MS: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
// Clock read by `rate_limit_test_time`.
static RATE_LIMIT_TEST_TIME_MS: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
}
/// Moves this thread's fake timeout clock to `ms` milliseconds.
fn set_timeout_test_time(ms: u64) {
    TIMEOUT_TEST_TIME_MS.set(ms);
}
/// Reads this thread's fake timeout clock as a `Time`.
fn timeout_test_time() -> Time {
    let now_ms = TIMEOUT_TEST_TIME_MS.get();
    Time::from_millis(now_ms)
}
/// Moves this thread's fake circuit-breaker clock to `ms` milliseconds.
fn set_circuit_test_time(ms: u64) {
    CIRCUIT_TEST_TIME_MS.set(ms);
}
/// Reads this thread's fake circuit-breaker clock as a `Time`.
fn circuit_test_time() -> Time {
    let now_ms = CIRCUIT_TEST_TIME_MS.get();
    Time::from_millis(now_ms)
}
/// Moves this thread's fake request-trace clock to `ms` milliseconds.
fn set_request_trace_test_time(ms: u64) {
    REQUEST_TRACE_TEST_TIME_MS.set(ms);
}
/// Reads this thread's fake request-trace clock as a `Time`.
fn request_trace_test_time() -> Time {
    let now_ms = REQUEST_TRACE_TEST_TIME_MS.get();
    Time::from_millis(now_ms)
}
/// Moves this thread's fake rate-limit clock to `ms` milliseconds.
fn set_rate_limit_test_time(ms: u64) {
    RATE_LIMIT_TEST_TIME_MS.set(ms);
}
/// Reads this thread's fake rate-limit clock as a `Time`.
fn rate_limit_test_time() -> Time {
    let now_ms = RATE_LIMIT_TEST_TIME_MS.get();
    Time::from_millis(now_ms)
}
// Fixed clock for auth tests: always reports t = 10 s.
fn auth_test_time_10s() -> Time {
Time::from_secs(10)
}
// Fixed clock for auth tests: always reports t = 20 s.
fn auth_test_time_20s() -> Time {
Time::from_secs(20)
}
// Trivial handler body returning the literal "ok"; tests wrap it in `FnHandler`
// and expect a 200 response from it.
fn ok_handler() -> &'static str {
"ok"
}
/// Handler body that always answers 500 with the body `fail`.
fn error_handler() -> Response {
    let body = b"fail".to_vec();
    Response::new(StatusCode::INTERNAL_SERVER_ERROR, body)
}
/// Handler body that blocks the calling thread for 50 ms before answering
/// `slow`; used to exceed short wall-clock timeouts.
fn slow_handler() -> &'static str {
    let pause = Duration::from_millis(50);
    std::thread::sleep(pause);
    "slow"
}
// Default request used by most tests: `GET /test`.
fn make_request() -> Request {
Request::new("GET", "/test")
}
/// Drives `handler` to completion on the current thread with a testing `Cx`
/// and returns its response.
fn call_sync<H: Handler + ?Sized>(handler: &H, req: Request) -> Response {
    let cx = crate::Cx::for_testing();
    let pending = Handler::call(handler, &cx, req);
    futures_lite::future::block_on(pending)
}
// Gives a middleware type an inherent, synchronous `call(&self, req)` that
// forwards to `call_sync`, so tests can write `mw.call(req)` without supplying
// a `Cx` or awaiting.
macro_rules! impl_test_sync_call {
($ty:ident) => {
impl<H: Handler> $ty<H> {
fn call(&self, req: Request) -> Response {
call_sync(self, req)
}
}
};
}
// One invocation per middleware type exercised by the tests below.
impl_test_sync_call!(CorsMiddleware);
impl_test_sync_call!(TimeoutMiddleware);
impl_test_sync_call!(CircuitBreakerMiddleware);
impl_test_sync_call!(RateLimitMiddleware);
impl_test_sync_call!(BulkheadMiddleware);
impl_test_sync_call!(RetryMiddleware);
impl_test_sync_call!(CompressionMiddleware);
impl_test_sync_call!(RequestBodyLimitMiddleware);
impl_test_sync_call!(RequestIdMiddleware);
impl_test_sync_call!(RequestTraceMiddleware);
impl_test_sync_call!(AuthMiddleware);
impl_test_sync_call!(LoadShedMiddleware);
impl_test_sync_call!(CatchPanicMiddleware);
impl_test_sync_call!(NormalizePathMiddleware);
impl_test_sync_call!(SetResponseHeaderMiddleware);
/// Test handler that records how many times it was invoked, optionally blocks
/// for `delay`, and answers with `status` and the body `counted`.
struct CountingHandler {
    calls: Arc<std::sync::atomic::AtomicU32>,
    delay: Duration,
    status: StatusCode,
}

impl Handler for CountingHandler {
    fn call(
        &self,
        _cx: &crate::Cx,
        _req: Request,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
        Box::pin(async move {
            // Count first so callers can observe the invocation even if the
            // handler then stalls.
            self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            let pause = self.delay;
            if pause > Duration::ZERO {
                std::thread::sleep(pause);
            }
            let body = b"counted".to_vec();
            Response::new(self.status, body)
        })
    }
}
/// Test handler that echoes the `trace_id` request extension as the response
/// body, or answers 400 when the extension is absent.
struct InspectHandler;

impl Handler for InspectHandler {
    fn call(
        &self,
        _cx: &crate::Cx,
        req: Request,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
        Box::pin(async move {
            match req.extensions.get("trace_id") {
                Some(trace_id) => Response::new(StatusCode::OK, trace_id.as_bytes().to_vec()),
                None => Response::new(StatusCode::BAD_REQUEST, b"missing trace_id".to_vec()),
            }
        })
    }
}
/// Test handler for short-circuit checks: if a middleware lets the request
/// through, the caller sees a 500 with the body `inner-called`.
struct FailingIfCalled;

impl Handler for FailingIfCalled {
    fn call(
        &self,
        _cx: &crate::Cx,
        _req: Request,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
        Box::pin(async {
            let body = b"inner-called".to_vec();
            Response::new(StatusCode::INTERNAL_SERVER_ERROR, body)
        })
    }
}
/// Test handler that answers 200 with the request path it received as the body.
struct InspectPathHandler;

impl Handler for InspectPathHandler {
    fn call(
        &self,
        _cx: &crate::Cx,
        req: Request,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
        Box::pin(async move {
            let seen_path = req.path.into_bytes();
            Response::new(StatusCode::OK, seen_path)
        })
    }
}
// Test handler whose future panics with "boom" as soon as it is polled.
struct PanicHandler;
impl Handler for PanicHandler {
fn call(
&self,
_cx: &crate::Cx,
_req: Request,
) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
// The panic happens inside the future, not while constructing it.
Box::pin(async { panic!("boom") })
}
}
/// Test handler that simulates elapsed time by moving the fake timeout clock
/// to `next_time_ms`, then answers with `status` and the body `advanced`.
struct AdvanceTimeHandler {
    next_time_ms: u64,
    status: StatusCode,
}

impl Handler for AdvanceTimeHandler {
    fn call(
        &self,
        _cx: &crate::Cx,
        _req: Request,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
        Box::pin(async move {
            set_timeout_test_time(self.next_time_ms);
            let body = b"advanced".to_vec();
            Response::new(self.status, body)
        })
    }
}
/// Test handler that moves the fake request-trace clock to `next_time_ms`
/// and answers 200 with a copy of `body`.
struct AdvanceRequestTraceTimeHandler {
    next_time_ms: u64,
    body: &'static [u8],
}

impl Handler for AdvanceRequestTraceTimeHandler {
    fn call(
        &self,
        _cx: &crate::Cx,
        _req: Request,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
        Box::pin(async move {
            set_request_trace_test_time(self.next_time_ms);
            let payload = self.body.to_vec();
            Response::new(StatusCode::OK, payload)
        })
    }
}
/// A handler that answers immediately still yields 200 under a 5 s timeout.
#[test]
fn timeout_passes_when_fast() {
    let budget = Duration::from_secs(5);
    let middleware = TimeoutMiddleware::new(FnHandler::new(ok_handler), budget);
    let response = middleware.call(make_request());
    assert_eq!(response.status, StatusCode::OK);
}
/// A handler that takes 50 ms against a 1 ms budget is reported as 504.
#[test]
fn timeout_triggers_when_slow() {
    let budget = Duration::from_millis(1);
    let middleware = TimeoutMiddleware::new(FnHandler::new(slow_handler), budget);
    let response = middleware.call(make_request());
    assert_eq!(response.status, StatusCode::GATEWAY_TIMEOUT);
}
/// With an injected clock, a handler that advances time by 25 ms against a
/// 10 ms budget is reported as 504 without any real waiting.
#[test]
fn timeout_time_getter_can_trigger_without_sleep() {
    set_timeout_test_time(0);
    let handler = AdvanceTimeHandler {
        next_time_ms: 25,
        status: StatusCode::OK,
    };
    let middleware =
        TimeoutMiddleware::with_time_getter(handler, Duration::from_millis(10), timeout_test_time);
    let response = middleware.call(make_request());
    assert_eq!(response.status, StatusCode::GATEWAY_TIMEOUT);
}
/// With an injected clock, a handler that advances time by only 5 ms against
/// a 10 ms budget keeps its own status and body.
#[test]
fn timeout_time_getter_preserves_fast_response() {
    set_timeout_test_time(0);
    let handler = AdvanceTimeHandler {
        next_time_ms: 5,
        status: StatusCode::CREATED,
    };
    let middleware =
        TimeoutMiddleware::with_time_getter(handler, Duration::from_millis(10), timeout_test_time);
    let response = middleware.call(make_request());
    assert_eq!(response.status, StatusCode::CREATED);
    assert_eq!(response.body.as_ref(), b"advanced");
}
/// A successful handler passes through a default-configured breaker.
#[test]
fn circuit_breaker_passes_success() {
    let middleware = CircuitBreakerMiddleware::new(
        FnHandler::new(ok_handler),
        CircuitBreakerPolicy::default(),
    );
    let response = middleware.call(make_request());
    assert_eq!(response.status, StatusCode::OK);
}
/// After `failure_threshold` (2) failing calls, the next call is rejected
/// with 503 instead of reaching the handler's 500.
#[test]
fn circuit_breaker_opens_after_failures() {
    let tripping_policy = CircuitBreakerPolicy {
        failure_threshold: 2,
        ..Default::default()
    };
    let middleware =
        CircuitBreakerMiddleware::new(FnHandler::new(error_handler), tripping_policy);
    // Burn through the failure budget.
    for _ in 0..2 {
        let _ = middleware.call(make_request());
    }
    let rejected = middleware.call(make_request());
    assert_eq!(rejected.status, StatusCode::SERVICE_UNAVAILABLE);
}
// Two middlewares built from the same `Arc<CircuitBreaker>` report the same
// `total_success` after a call goes through the first one.
// NOTE(review): both sides presumably read the very same breaker, so this
// equality looks like it always holds; pinning the expected count would be a
// stronger check — confirm against `metrics()` semantics.
#[test]
fn circuit_breaker_shared_state() {
let policy = CircuitBreakerPolicy::default();
let breaker = Arc::new(CircuitBreaker::new(policy));
let mw1 =
CircuitBreakerMiddleware::shared(FnHandler::new(ok_handler), Arc::clone(&breaker));
let mw2 =
CircuitBreakerMiddleware::shared(FnHandler::new(ok_handler), Arc::clone(&breaker));
let _ = mw1.call(make_request());
assert_eq!(
mw1.breaker().metrics().total_success,
mw2.breaker().metrics().total_success
);
}
/// While the breaker is still closed (threshold 10), the handler's own 500
/// status and body reach the caller unchanged.
#[test]
fn circuit_breaker_surfaces_handler_error() {
    let lenient_policy = CircuitBreakerPolicy {
        failure_threshold: 10,
        ..Default::default()
    };
    let middleware = CircuitBreakerMiddleware::new(FnHandler::new(error_handler), lenient_policy);
    let response = middleware.call(make_request());
    assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR);
    assert_eq!(response.body.as_ref(), b"fail");
}
/// A 502 from the handler is surfaced with its original status and body
/// while the breaker is still closed.
#[test]
fn circuit_breaker_preserves_original_server_error_status_and_body() {
    fn upstream_failure() -> Response {
        Response::new(StatusCode::BAD_GATEWAY, b"upstream gateway failed".to_vec())
    }

    let lenient_policy = CircuitBreakerPolicy {
        failure_threshold: 10,
        ..Default::default()
    };
    let middleware =
        CircuitBreakerMiddleware::new(FnHandler::new(upstream_failure), lenient_policy);
    let response = middleware.call(make_request());
    assert_eq!(response.status, StatusCode::BAD_GATEWAY);
    assert_eq!(response.body.as_ref(), b"upstream gateway failed");
}
// Drives a shared breaker with the fake circuit clock: one failure at t = 1 s
// opens it, a call inside the 10 s open window is rejected, and a call at
// t = 11 s succeeds again. Statement order matters because both middlewares
// share one breaker and one clock.
#[test]
fn circuit_breaker_time_getter_controls_open_window() {
let policy = CircuitBreakerPolicy {
failure_threshold: 1,
success_threshold: 1,
open_duration: Duration::from_secs(10),
..Default::default()
};
let breaker = Arc::new(CircuitBreaker::new(policy));
// Same breaker, two handlers: one that fails and one that succeeds.
let fail_mw = CircuitBreakerMiddleware::shared_with_time_getter(
FnHandler::new(error_handler),
Arc::clone(&breaker),
circuit_test_time,
);
let ok_mw = CircuitBreakerMiddleware::shared_with_time_getter(
FnHandler::new(ok_handler),
Arc::clone(&breaker),
circuit_test_time,
);
set_circuit_test_time(1_000);
// The single allowed failure: the handler's own 500 is still surfaced.
let first = fail_mw.call(make_request());
assert_eq!(first.status, StatusCode::INTERNAL_SERVER_ERROR);
// Still at t = 1 s: the breaker is open, so even the healthy handler is blocked.
let open = ok_mw.call(make_request());
assert_eq!(open.status, StatusCode::SERVICE_UNAVAILABLE);
// 10 s after the failure the open window has elapsed.
set_circuit_test_time(11_000);
let recovered = ok_mw.call(make_request());
assert_eq!(recovered.status, StatusCode::OK);
}
/// A single request against a generous limit (rate 100, burst 10) succeeds.
#[test]
fn rate_limit_allows_within_limit() {
    let generous = RateLimitPolicy {
        rate: 100,
        burst: 10,
        ..Default::default()
    };
    let middleware = RateLimitMiddleware::new(FnHandler::new(ok_handler), generous);
    let response = middleware.call(make_request());
    assert_eq!(response.status, StatusCode::OK);
}
/// With one token per minute, the second request gets 429 and a
/// `retry-after` header.
#[test]
fn rate_limit_rejects_over_limit() {
    let one_per_minute = RateLimitPolicy {
        rate: 1,
        burst: 1,
        period: Duration::from_secs(60),
        ..Default::default()
    };
    let middleware = RateLimitMiddleware::new(FnHandler::new(ok_handler), one_per_minute);

    let admitted = middleware.call(make_request());
    assert_eq!(admitted.status, StatusCode::OK);

    let throttled = middleware.call(make_request());
    assert_eq!(throttled.status, StatusCode::TOO_MANY_REQUESTS);
    assert!(throttled.headers.contains_key("retry-after"));
}
/// A throttled request must not reach the inner handler: after two calls
/// against a one-token limit, the handler has run exactly once.
#[test]
fn rate_limit_short_circuits_inner_handler() {
    let invocations = Arc::new(std::sync::atomic::AtomicU32::new(0));
    let counting = CountingHandler {
        calls: Arc::clone(&invocations),
        delay: Duration::from_millis(0),
        status: StatusCode::OK,
    };
    let one_per_minute = RateLimitPolicy {
        rate: 1,
        burst: 1,
        period: Duration::from_secs(60),
        ..Default::default()
    };
    let middleware = RateLimitMiddleware::new(counting, one_per_minute);

    let _ = middleware.call(make_request());
    let throttled = middleware.call(make_request());

    assert_eq!(throttled.status, StatusCode::TOO_MANY_REQUESTS);
    assert_eq!(invocations.load(std::sync::atomic::Ordering::SeqCst), 1);
}
// A panic inside the inner handler must hand the consumed token back to the
// shared limiter, so a following request through the same limiter still
// succeeds. Statement order matters: the token count is inspected between calls.
#[test]
fn rate_limit_panic_restores_consumed_token() {
let limiter = Arc::new(RateLimiter::new(RateLimitPolicy {
rate: 1,
burst: 1,
period: Duration::from_secs(60),
..Default::default()
}));
// Both middlewares draw from the same single-token limiter.
let panic_mw = RateLimitMiddleware::shared(PanicHandler, Arc::clone(&limiter));
let ok_mw = RateLimitMiddleware::shared(FnHandler::new(ok_handler), Arc::clone(&limiter));
// Contain the handler's panic so the test itself can continue.
let panic = panic::catch_unwind(AssertUnwindSafe(|| {
let _ = panic_mw.call(make_request());
}));
assert!(panic.is_err(), "inner handler should panic");
assert_eq!(
limiter.available_tokens(),
1,
"panic path must refund the consumed token"
);
// The refunded token is spent by the healthy request.
let resp = ok_mw.call(make_request());
assert_eq!(resp.status, StatusCode::OK);
assert_eq!(limiter.available_tokens(), 0);
}
// Drives a one-token-per-60 s limiter with the fake rate-limit clock and
// checks that `retry-after` counts down (60 -> 30) and that the token is back
// once a full period has passed. Statement order matters: each step depends
// on the clock value set just before it.
#[test]
fn rate_limit_time_getter_controls_retry_after_and_refill() {
let policy = RateLimitPolicy {
rate: 1,
burst: 1,
period: Duration::from_secs(60),
..Default::default()
};
let mw = RateLimitMiddleware::with_time_getter(
FnHandler::new(ok_handler),
policy,
rate_limit_test_time,
);
// t = 10 s: first request consumes the only token.
set_rate_limit_test_time(10_000);
let first = mw.call(make_request());
assert_eq!(first.status, StatusCode::OK);
// Same instant: rejected, a full period remains.
let rejected = mw.call(make_request());
assert_eq!(rejected.status, StatusCode::TOO_MANY_REQUESTS);
assert_eq!(
rejected.headers.get("retry-after").map(String::as_str),
Some("60")
);
// t = 40 s: halfway through the period, 30 s remain.
set_rate_limit_test_time(40_000);
let still_limited = mw.call(make_request());
assert_eq!(still_limited.status, StatusCode::TOO_MANY_REQUESTS);
assert_eq!(
still_limited.headers.get("retry-after").map(String::as_str),
Some("30")
);
// t = 70 s: one full period after the first request.
set_rate_limit_test_time(70_000);
let recovered = mw.call(make_request());
assert_eq!(recovered.status, StatusCode::OK);
}
/// With a 120 s period, a request rejected at the same instant the token was
/// spent carries `retry-after: 120` (plain delay-seconds form).
#[test]
fn rate_limit_retry_after_matches_rfc9110_delay_seconds_example() {
    let one_per_two_minutes = RateLimitPolicy {
        rate: 1,
        burst: 1,
        period: Duration::from_secs(120),
        ..Default::default()
    };
    let middleware = RateLimitMiddleware::with_time_getter(
        FnHandler::new(ok_handler),
        one_per_two_minutes,
        rate_limit_test_time,
    );
    set_rate_limit_test_time(5_000);

    let admitted = middleware.call(make_request());
    assert_eq!(admitted.status, StatusCode::OK);

    let throttled = middleware.call(make_request());
    assert_eq!(throttled.status, StatusCode::TOO_MANY_REQUESTS);
    assert_eq!(
        throttled.headers.get("retry-after").map(String::as_str),
        Some("120")
    );
}
/// A single request against a bulkhead of 10 concurrent slots succeeds.
#[test]
fn bulkhead_allows_within_limit() {
    let roomy = BulkheadPolicy {
        max_concurrent: 10,
        ..Default::default()
    };
    let middleware = BulkheadMiddleware::new(FnHandler::new(ok_handler), roomy);
    let response = middleware.call(make_request());
    assert_eq!(response.status, StatusCode::OK);
}
/// With a single slot, five sequential requests all succeed, which requires
/// the permit to be released after each call.
#[test]
fn bulkhead_releases_permit_after_call() {
    let single_slot = BulkheadPolicy {
        max_concurrent: 1,
        ..Default::default()
    };
    let middleware = BulkheadMiddleware::new(FnHandler::new(ok_handler), single_slot);
    for _attempt in 0..5 {
        let response = middleware.call(make_request());
        assert_eq!(response.status, StatusCode::OK);
    }
}
/// A handler that succeeds immediately yields 200 through the retry layer.
#[test]
fn retry_succeeds_on_first_try() {
    let middleware = RetryMiddleware::new(FnHandler::new(ok_handler), RetryPolicy::immediate(3));
    let response = middleware.call(make_request());
    assert_eq!(response.status, StatusCode::OK);
}
/// A persistently failing handler on an idempotent method (GET) is invoked
/// once per allowed attempt, and the final 500 is surfaced.
///
/// The status alone is identical with or without retries, so the handler
/// invocations are counted to prove the attempt budget was actually spent.
#[test]
fn retry_exhausts_attempts_on_server_error() {
    let calls = Arc::new(std::sync::atomic::AtomicU32::new(0));
    let handler = CountingHandler {
        calls: Arc::clone(&calls),
        delay: Duration::ZERO,
        status: StatusCode::INTERNAL_SERVER_ERROR,
    };
    let policy = RetryPolicy::immediate(3);
    let mw = RetryMiddleware::new(handler, policy);
    let resp = mw.call(make_request());
    assert_eq!(resp.status, StatusCode::INTERNAL_SERVER_ERROR);
    // Same budget as `retry_all_methods_retries_post`: immediate(3) means
    // three handler invocations in total.
    assert_eq!(
        calls.load(std::sync::atomic::Ordering::SeqCst),
        3,
        "every attempt allowed by the policy should reach the handler"
    );
}
/// By default a failing POST is not retried: the handler runs exactly once
/// and its 500 is surfaced.
///
/// The status alone cannot distinguish "retried and still failed" from
/// "never retried", so the handler invocations are counted.
#[test]
fn retry_skips_non_idempotent_by_default() {
    let calls = Arc::new(std::sync::atomic::AtomicU32::new(0));
    let handler = CountingHandler {
        calls: Arc::clone(&calls),
        delay: Duration::ZERO,
        status: StatusCode::INTERNAL_SERVER_ERROR,
    };
    let policy = RetryPolicy::immediate(3);
    let mw = RetryMiddleware::new(handler, policy);
    let resp = mw.call(Request::new("POST", "/create"));
    assert_eq!(resp.status, StatusCode::INTERNAL_SERVER_ERROR);
    assert_eq!(
        calls.load(std::sync::atomic::Ordering::SeqCst),
        1,
        "non-idempotent request must not be retried without retry_all_methods"
    );
}
/// After opting in with `retry_all_methods`, a failing POST is attempted
/// three times under `RetryPolicy::immediate(3)`.
#[test]
fn retry_all_methods_retries_post() {
    use std::sync::atomic::{AtomicU32, Ordering};

    // Function-local counter; only this test touches it.
    static ATTEMPTS: AtomicU32 = AtomicU32::new(0);

    fn always_failing() -> Response {
        ATTEMPTS.fetch_add(1, Ordering::SeqCst);
        Response::new(StatusCode::INTERNAL_SERVER_ERROR, b"fail".to_vec())
    }

    ATTEMPTS.store(0, Ordering::SeqCst);
    let middleware =
        RetryMiddleware::new(FnHandler::new(always_failing), RetryPolicy::immediate(3))
            .retry_all_methods();
    let _response = middleware.call(Request::new("POST", "/create"));
    assert_eq!(ATTEMPTS.load(Ordering::SeqCst), 3);
}
// Pins the idempotency classification used by the retry layer: GET, HEAD,
// OPTIONS, PUT, DELETE and TRACE are idempotent; POST and PATCH are not.
#[test]
fn idempotent_methods() {
assert!(is_idempotent("GET"));
assert!(is_idempotent("HEAD"));
assert!(is_idempotent("OPTIONS"));
assert!(is_idempotent("PUT"));
assert!(is_idempotent("DELETE"));
assert!(is_idempotent("TRACE"));
assert!(!is_idempotent("POST"));
assert!(!is_idempotent("PATCH"));
}
/// When only `identity` is negotiated, the response still carries
/// `vary: accept-encoding` but no `content-encoding`.
#[test]
fn compression_identity_sets_vary_header() {
    let identity_only = CompressionConfig {
        supported: vec![ContentEncoding::Identity],
        min_body_size: 0,
        ..CompressionConfig::default()
    };
    let middleware = CompressionMiddleware::new(FnHandler::new(ok_handler), identity_only);
    let request = Request::new("GET", "/compress").with_header("accept-encoding", "identity");

    let response = middleware.call(request);

    assert_eq!(response.status, StatusCode::OK);
    assert_eq!(
        response.headers.get("vary"),
        Some(&"accept-encoding".to_string())
    );
    assert!(!response.headers.contains_key("content-encoding"));
}
/// A mixed-case `Vary` header set by the handler is merged into a single
/// lowercase `vary` entry rather than duplicated.
#[test]
fn compression_merges_mixed_case_vary_header() {
    fn vary_setting_handler() -> Response {
        let mut response = Response::new(StatusCode::OK, b"ok".to_vec());
        response
            .headers
            .insert("Vary".to_string(), "Accept-Language".to_string());
        response
    }

    let identity_only = CompressionConfig {
        supported: vec![ContentEncoding::Identity],
        min_body_size: 0,
        ..CompressionConfig::default()
    };
    let middleware =
        CompressionMiddleware::new(FnHandler::new(vary_setting_handler), identity_only);
    let request = Request::new("GET", "/compress").with_header("accept-encoding", "identity");

    let response = middleware.call(request);

    assert_eq!(
        response.headers.get("vary"),
        Some(&"accept-language, accept-encoding".to_string())
    );
    assert!(!response.headers.contains_key("Vary"));
}
/// If the client forbids `identity` and the server supports nothing else,
/// the response is 406.
#[test]
fn compression_rejects_not_acceptable_encodings() {
    let identity_only = CompressionConfig {
        supported: vec![ContentEncoding::Identity],
        min_body_size: 0,
        ..CompressionConfig::default()
    };
    let middleware = CompressionMiddleware::new(FnHandler::new(ok_handler), identity_only);
    let request = Request::new("GET", "/compress")
        .with_header("accept-encoding", "gzip;q=1, identity;q=0");

    let response = middleware.call(request);

    assert_eq!(response.status.as_u16(), 406);
}
/// A 6-byte body against a 3-byte limit is rejected with 413 before the
/// inner handler (which would answer 500) is reached.
#[test]
fn body_limit_short_circuits_large_payload() {
    let middleware = RequestBodyLimitMiddleware::new(FailingIfCalled, 3);
    let oversized = Request::new("POST", "/upload").with_body(b"abcdef".to_vec());
    let response = middleware.call(oversized);
    assert_eq!(response.status, StatusCode::PAYLOAD_TOO_LARGE);
}
/// Without an incoming id, the response carries a generated one that starts
/// with `req-`.
#[test]
fn request_id_generates_when_missing() {
    let middleware = RequestIdMiddleware::new(FnHandler::new(ok_handler), "x-request-id");
    let response = middleware.call(Request::new("GET", "/req-id"));
    let generated = response
        .headers
        .get("x-request-id")
        .expect("request id header should be present");
    assert!(generated.starts_with("req-"));
}
/// An id supplied by the client is echoed back unchanged.
#[test]
fn request_id_preserves_incoming_header_value() {
    let middleware = RequestIdMiddleware::new(FnHandler::new(ok_handler), "x-request-id");
    let request = Request::new("GET", "/req-id").with_header("x-request-id", "abc-123");
    let response = middleware.call(request);
    assert_eq!(
        response.headers.get("x-request-id"),
        Some(&"abc-123".to_string())
    );
}
/// A mixed-case configured header name is emitted in lowercase on the
/// response, with no mixed-case key left behind.
#[test]
fn request_id_normalizes_mixed_case_response_header_name() {
    let middleware = RequestIdMiddleware::new(FnHandler::new(ok_handler), "X-Request-Id");
    let request = Request::new("GET", "/req-id").with_header("x-request-id", "abc-123");
    let response = middleware.call(request);
    assert_eq!(
        response.headers.get("x-request-id"),
        Some(&"abc-123".to_string())
    );
    assert!(!response.headers.contains_key("X-Request-Id"));
}
/// If the inner handler already set a mixed-case `X-Request-Id`, the
/// middleware replaces it with the incoming id and leaves exactly one header.
#[test]
fn request_id_overwrites_mixed_case_inner_header_without_duplication() {
    fn id_setting_handler() -> Response {
        let mut response = Response::new(StatusCode::OK, b"ok".to_vec());
        response
            .headers
            .insert("X-Request-Id".to_string(), "inner".to_string());
        response
    }

    let middleware =
        RequestIdMiddleware::new(FnHandler::new(id_setting_handler), "x-request-id");
    let request = Request::new("GET", "/req-id").with_header("x-request-id", "outer");

    let response = middleware.call(request);

    assert_eq!(response.header_value("x-request-id"), Some("outer"));
    assert_eq!(
        response.headers.len(),
        1,
        "response should not carry duplicate request-id headers"
    );
    assert!(!response.headers.contains_key("X-Request-Id"));
}
/// A request without `authorization` gets 401 plus a `Bearer` challenge.
#[test]
fn auth_rejects_missing_authorization_header() {
    let middleware = AuthMiddleware::new(FnHandler::new(ok_handler), AuthPolicy::AnyBearer);
    let response = middleware.call(Request::new("GET", "/auth"));
    assert_eq!(response.status, StatusCode::UNAUTHORIZED);
    assert_eq!(
        response.headers.get("www-authenticate"),
        Some(&"Bearer".to_string())
    );
}
/// A bearer token equal to the configured one is admitted.
#[test]
fn auth_accepts_matching_bearer_token() {
    let exact_policy = AuthPolicy::exact_bearer("token-123");
    let middleware = AuthMiddleware::new(FnHandler::new(ok_handler), exact_policy);
    let request = Request::new("GET", "/auth").with_header("authorization", "Bearer token-123");
    let response = middleware.call(request);
    assert_eq!(response.status, StatusCode::OK);
}
/// The default policy does not accept a token merely because one is present:
/// the request gets 401 and a `Bearer` challenge.
#[test]
fn auth_default_policy_fails_closed_on_presence_only_bearer() {
    let middleware = AuthMiddleware::new(FnHandler::new(ok_handler), AuthPolicy::default());
    let request = Request::new("GET", "/auth").with_header("authorization", "Bearer token-123");
    let response = middleware.call(request);
    assert_eq!(response.status, StatusCode::UNAUTHORIZED);
    assert_eq!(
        response.headers.get("www-authenticate"),
        Some(&"Bearer".to_string())
    );
}
/// A compact JWS with an empty payload segment (two consecutive dots) is
/// accepted when it matches the configured token exactly.
#[test]
fn auth_accepts_rfc7515_detached_compact_jws_bearer_token() {
    let token =
        "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9..dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
    let middleware = AuthMiddleware::new(
        FnHandler::new(ok_handler),
        AuthPolicy::exact_bearer(token),
    );
    let request =
        Request::new("GET", "/auth").with_header("authorization", format!("Bearer {token}"));
    let response = middleware.call(request);
    assert_eq!(response.status, StatusCode::OK);
}
/// A bearer token that differs from the configured one gets 401.
#[test]
fn auth_rejects_non_matching_bearer_token() {
    let exact_policy = AuthPolicy::exact_bearer("token-123");
    let middleware = AuthMiddleware::new(FnHandler::new(ok_handler), exact_policy);
    let request = Request::new("GET", "/auth").with_header("authorization", "Bearer nope");
    let response = middleware.call(request);
    assert_eq!(response.status, StatusCode::UNAUTHORIZED);
}
/// A token valid until t = 10 s is rejected when the clock reads exactly 10 s.
#[test]
fn auth_rejects_expired_exact_bearer_token() {
    let expiring_policy = AuthPolicy::exact_bearer_until("token-123", Time::from_secs(10));
    let middleware = AuthMiddleware::with_time_getter(
        FnHandler::new(ok_handler),
        expiring_policy,
        auth_test_time_10s,
    );
    let request = Request::new("GET", "/auth").with_header("authorization", "Bearer token-123");
    let response = middleware.call(request);
    assert_eq!(response.status, StatusCode::UNAUTHORIZED);
}
/// A token valid until t = 11 s is accepted when the clock reads 10 s.
#[test]
fn auth_accepts_unexpired_exact_bearer_token() {
    let expiring_policy = AuthPolicy::exact_bearer_until("token-123", Time::from_secs(11));
    let middleware = AuthMiddleware::with_time_getter(
        FnHandler::new(ok_handler),
        expiring_policy,
        auth_test_time_10s,
    );
    let request = Request::new("GET", "/auth").with_header("authorization", "Bearer token-123");
    let response = middleware.call(request);
    assert_eq!(response.status, StatusCode::OK);
}
// Token rotation: at t = 10 s both the old token (valid until 20 s) and its
// replacement (valid until 40 s) are accepted; at t = 20 s only the
// replacement is.
#[test]
fn auth_rotation_accepts_new_token_and_rejects_expired_old_token() {
let mut policy = AuthPolicy::exact_bearer_until("old-token", Time::from_secs(20));
policy.rotate_exact_bearer("new-token", Time::from_secs(40), Time::from_secs(10));
// Phase 1: clock fixed at 10 s, before the old token's expiry.
let before_expiry = AuthMiddleware::with_time_getter(
FnHandler::new(ok_handler),
policy.clone(),
auth_test_time_10s,
);
let old_req = Request::new("GET", "/auth").with_header("authorization", "Bearer old-token");
let new_req = Request::new("GET", "/auth").with_header("authorization", "Bearer new-token");
assert_eq!(before_expiry.call(old_req).status, StatusCode::OK);
assert_eq!(before_expiry.call(new_req).status, StatusCode::OK);
// Phase 2: clock fixed at 20 s, the old token's expiry instant.
let after_expiry = AuthMiddleware::with_time_getter(
FnHandler::new(ok_handler),
policy,
auth_test_time_20s,
);
let old_req = Request::new("GET", "/auth").with_header("authorization", "Bearer old-token");
let new_req = Request::new("GET", "/auth").with_header("authorization", "Bearer new-token");
assert_eq!(after_expiry.call(old_req).status, StatusCode::UNAUTHORIZED);
assert_eq!(after_expiry.call(new_req).status, StatusCode::OK);
}
// After rotating and then pruning at t = 20 s, the policy holds only the
// replacement token, and a middleware built from it rejects the old token
// while accepting the new one.
#[test]
fn auth_rotation_prune_removes_expired_tokens_without_dropping_replacement() {
let mut policy = AuthPolicy::exact_bearer_until("old-token", Time::from_secs(20));
policy.rotate_exact_bearer("new-token", Time::from_secs(40), Time::from_secs(10));
policy.prune_expired(Time::from_secs(20));
// Inspect the policy directly before wiring it into a middleware.
let AuthPolicy::ExactBearer(tokens) = &policy else {
panic!("rotation must preserve exact-bearer policy");
};
assert_eq!(
tokens.iter().map(BearerToken::token).collect::<Vec<_>>(),
vec!["new-token"],
"prune should remove expired old tokens and keep active replacements"
);
let mw = AuthMiddleware::with_time_getter(
FnHandler::new(ok_handler),
policy,
auth_test_time_20s,
);
let old_req = Request::new("GET", "/auth").with_header("authorization", "Bearer old-token");
let new_req = Request::new("GET", "/auth").with_header("authorization", "Bearer new-token");
assert_eq!(mw.call(old_req).status, StatusCode::UNAUTHORIZED);
assert_eq!(mw.call(new_req).status, StatusCode::OK);
}
/// With `max_in_flight` set to 0, every request is shed with 503.
#[test]
fn load_shed_rejects_when_capacity_zero() {
    let no_capacity = LoadShedPolicy { max_in_flight: 0 };
    let middleware = LoadShedMiddleware::new(FnHandler::new(ok_handler), no_capacity);
    let response = middleware.call(Request::new("GET", "/shed"));
    assert_eq!(response.status, StatusCode::SERVICE_UNAVAILABLE);
}
/// A panicking handler is converted into a 500 response instead of unwinding
/// into the caller.
#[test]
fn catch_panic_returns_internal_server_error() {
    let middleware = CatchPanicMiddleware::new(PanicHandler);
    let request = Request::new("GET", "/panic");
    let response = middleware.call(request);
    assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR);
}
/// Under `Trim`, the inner handler sees `/users` for a request to `/users/`.
#[test]
fn normalize_path_trim_rewrites_trailing_slash() {
    let middleware = NormalizePathMiddleware::new(InspectPathHandler, TrailingSlash::Trim);
    let request = Request::new("GET", "/users/");
    let response = middleware.call(request);
    assert_eq!(&response.body[..], b"/users");
}
/// Under `RedirectAlways`, `/users` is answered with a 301 to `/users/`.
#[test]
fn normalize_path_redirect_always_redirects_without_slash() {
    let middleware =
        NormalizePathMiddleware::new(InspectPathHandler, TrailingSlash::RedirectAlways);
    let request = Request::new("GET", "/users");
    let response = middleware.call(request);
    assert_eq!(response.status, StatusCode::MOVED_PERMANENTLY);
    assert_eq!(
        response.headers.get("location"),
        Some(&"/users/".to_string())
    );
}
/// In `if_missing` mode, a header already set by the handler keeps its value.
#[test]
fn set_response_header_if_missing_preserves_existing() {
    let handler_with_header = FnHandler::new(|| {
        Response::new(StatusCode::OK, b"ok".to_vec()).header("x-env", "existing")
    });
    let middleware = SetResponseHeaderMiddleware::if_missing(handler_with_header, "x-env", "new");
    let response = middleware.call(Request::new("GET", "/"));
    assert_eq!(
        response.headers.get("x-env"),
        Some(&"existing".to_string())
    );
}
/// With the default policy, a simple cross-origin GET gets
/// `access-control-allow-origin: *` and `vary: origin`.
#[test]
fn cors_adds_headers_for_simple_request() {
    let middleware = CorsMiddleware::new(FnHandler::new(ok_handler), CorsPolicy::default());
    let request = Request::new("GET", "/cors").with_header("Origin", "https://example.com");

    let response = middleware.call(request);

    assert_eq!(response.status, StatusCode::OK);
    assert_eq!(
        response.headers.get("access-control-allow-origin"),
        Some(&"*".to_string())
    );
    assert_eq!(response.headers.get("vary"), Some(&"origin".to_string()));
}
/// A handler-supplied mixed-case `Vary` that already lists `Origin` is
/// folded into one lowercase `vary` entry without repeating `origin`.
#[test]
fn cors_merges_mixed_case_vary_header_without_duplicates() {
    fn vary_setting_handler() -> Response {
        let mut response = Response::new(StatusCode::OK, b"ok".to_vec());
        response
            .headers
            .insert("Vary".to_string(), "Accept-Language, Origin".to_string());
        response
    }

    let middleware =
        CorsMiddleware::new(FnHandler::new(vary_setting_handler), CorsPolicy::default());
    let request = Request::new("GET", "/cors").with_header("Origin", "https://example.com");

    let response = middleware.call(request);

    assert_eq!(
        response.headers.get("vary"),
        Some(&"accept-language, origin".to_string())
    );
    assert!(!response.headers.contains_key("Vary"));
}
/// A preflight `OPTIONS` request is answered by the middleware itself (204
/// with allow-* headers) and never reaches the inner handler.
#[test]
fn cors_preflight_short_circuits_inner_handler() {
    let middleware = CorsMiddleware::new(FailingIfCalled, CorsPolicy::default());
    let preflight = Request::new("OPTIONS", "/cors")
        .with_header("Origin", "https://example.com")
        .with_header("Access-Control-Request-Method", "POST")
        .with_header("Access-Control-Request-Headers", "content-type");

    let response = middleware.call(preflight);

    assert_eq!(response.status, StatusCode::NO_CONTENT);
    assert_eq!(
        response.headers.get("access-control-allow-origin"),
        Some(&"*".to_string())
    );
    assert!(response.headers.contains_key("access-control-allow-methods"));
    assert!(response.headers.contains_key("access-control-allow-headers"));
}
// The default CORS policy must list a fixed set of allowed request headers
// and must not contain the `*` wildcard.
#[test]
fn _0qb0bf_default_allow_headers_is_narrow_safe_list_not_wildcard() {
let policy = CorsPolicy::default();
assert!(
!policy.allow_headers.iter().any(|h| h == "*"),
"default must NOT contain wildcard; got {:?}",
policy.allow_headers
);
// Every entry of the expected allowlist must be present.
for expected in [
"Accept",
"Accept-Language",
"Content-Type",
"Authorization",
"X-Requested-With",
] {
assert!(
policy.allow_headers.iter().any(|h| h == expected),
"default allowlist missing {expected:?}; got {:?}",
policy.allow_headers
);
}
}
// A default-policy preflight must answer with the static allowlist: it must
// not echo headers the client merely asked for, and must not advertise `*`.
#[test]
fn _0qb0bf_default_preflight_does_not_echo_arbitrary_requested_headers() {
let mw = CorsMiddleware::new(FailingIfCalled, CorsPolicy::default());
// The preflight asks for three headers that are not on the allowlist.
let req = Request::new("OPTIONS", "/cors")
.with_header("Origin", "https://example.com")
.with_header("Access-Control-Request-Method", "POST")
.with_header(
"Access-Control-Request-Headers",
"X-Evil-Internal, X-Internal-Auth, X-Backend-Token",
);
let resp = mw.call(req);
assert_eq!(resp.status, StatusCode::NO_CONTENT);
let allow = resp
.headers
.get("access-control-allow-headers")
.expect("preflight must set Access-Control-Allow-Headers");
// None of the requested-but-unlisted headers may appear.
for forbidden in ["X-Evil-Internal", "X-Internal-Auth", "X-Backend-Token"] {
assert!(
!allow.contains(forbidden),
"default preflight must NOT echo arbitrary requested header \
{forbidden:?}; Allow-Headers was {allow:?}"
);
}
// Entries of the static allowlist must still be advertised.
for expected in ["Authorization", "Content-Type", "X-Requested-With"] {
assert!(
allow.contains(expected),
"static allowlist entry {expected:?} must be in Allow-Headers; \
got {allow:?}"
);
}
assert!(
!allow.contains('*'),
"default preflight must NOT advertise wildcard; got {allow:?}"
);
}
/// Opting in via `CorsPolicy::with_any_headers` stores `*` in the policy and
/// sends `access-control-allow-headers: *` on a preflight.
#[test]
fn _0qb0bf_with_any_headers_opt_in_restores_wildcard() {
    let wildcard_policy = CorsPolicy::with_any_headers();
    assert_eq!(wildcard_policy.allow_headers, vec!["*".to_string()]);

    let middleware = CorsMiddleware::new(FailingIfCalled, wildcard_policy);
    let preflight = Request::new("OPTIONS", "/cors")
        .with_header("Origin", "https://example.com")
        .with_header("Access-Control-Request-Method", "POST")
        .with_header("Access-Control-Request-Headers", "X-Any-Header");

    let response = middleware.call(preflight);

    assert_eq!(
        response.headers.get("access-control-allow-headers"),
        Some(&"*".to_string()),
        "with_any_headers must produce wildcard on the wire"
    );
}
/// With an exact-origin allowlist, an unlisted origin gets a normal response
/// without `Access-Control-Allow-Origin`, while a listed origin is echoed.
#[test]
fn cors_exact_origins_blocks_unknown_origin() {
    let allowlist = vec![
        "https://allowed.example".to_string(),
        "https://another.example".to_string(),
    ];
    let cors = CorsMiddleware::new(
        FnHandler::new(ok_handler),
        CorsPolicy::with_exact_origins(allowlist),
    );

    let unknown = Request::new("GET", "/cors").with_header("Origin", "https://blocked.example");
    let blocked = cors.call(unknown);
    assert_eq!(blocked.status, StatusCode::OK);
    assert!(!blocked.headers.contains_key("access-control-allow-origin"));

    let known = Request::new("GET", "/cors").with_header("Origin", "https://allowed.example");
    let allowed = cors.call(known);
    assert_eq!(allowed.status, StatusCode::OK);
    let expected_origin = "https://allowed.example".to_string();
    assert_eq!(
        allowed.headers.get("access-control-allow-origin"),
        Some(&expected_origin)
    );
}
/// A credentialed policy echoes the exact allowlisted origin and sets
/// `Access-Control-Allow-Credentials: true`.
#[test]
fn cors_credentials_with_allowlisted_origin_echoes_exact_origin() {
    let mut policy = CorsPolicy::with_exact_origins(vec!["https://cred.example".to_string()]);
    policy.allow_credentials = true;
    let cors = CorsMiddleware::new(FnHandler::new(ok_handler), policy);

    let request = Request::new("GET", "/cors").with_header("Origin", "https://cred.example");
    let resp = cors.call(request);

    assert_eq!(resp.status, StatusCode::OK);
    let origin = "https://cred.example".to_string();
    assert_eq!(
        resp.headers.get("access-control-allow-origin"),
        Some(&origin)
    );
    let truthy = "true".to_string();
    assert_eq!(
        resp.headers.get("access-control-allow-credentials"),
        Some(&truthy)
    );
}
/// For an origin outside the allowlist, a credentialed policy emits neither
/// `Access-Control-Allow-Origin` nor `Access-Control-Allow-Credentials`; the
/// wrapped handler still produces its response.
#[test]
fn cors_credentials_with_non_allowlisted_origin_emits_no_allow_origin() {
    let mut policy = CorsPolicy::with_exact_origins(vec!["https://allowed.example".to_string()]);
    policy.allow_credentials = true;
    let cors = CorsMiddleware::new(FnHandler::new(ok_handler), policy);

    let request = Request::new("GET", "/cors").with_header("Origin", "https://attacker.example");
    let resp = cors.call(request);

    assert_eq!(resp.status, StatusCode::OK, "inner handler still runs");
    let headers = &resp.headers;
    assert!(
        !headers.contains_key("access-control-allow-origin"),
        "non-allowlisted origin must not receive Allow-Origin"
    );
    assert!(
        !headers.contains_key("access-control-allow-credentials"),
        "non-allowlisted origin must not receive Allow-Credentials"
    );
}
/// An `Origin` header holding a comma-separated list is not matched against
/// the allowlist, even when one of its entries is allowlisted: no CORS
/// grant headers are emitted.
#[test]
fn cors_multi_origin_header_fails_closed_for_exact_allowlist() {
    let mut policy = CorsPolicy::with_exact_origins(vec!["https://allowed.example".to_string()]);
    policy.allow_credentials = true;
    let cors = CorsMiddleware::new(FnHandler::new(ok_handler), policy);

    let multi_origin = "https://allowed.example, https://attacker.example";
    let request = Request::new("GET", "/cors").with_header("Origin", multi_origin);
    let resp = cors.call(request);

    assert_eq!(resp.status, StatusCode::OK, "inner handler still runs");
    let headers = &resp.headers;
    assert!(
        !headers.contains_key("access-control-allow-origin"),
        "malformed multi-origin header must not be reflected"
    );
    assert!(
        !headers.contains_key("access-control-allow-credentials"),
        "malformed multi-origin header must not receive Allow-Credentials"
    );
}
/// Even under the default (any-origin) policy, a comma-separated `Origin`
/// value must not be answered with `Access-Control-Allow-Origin`.
#[test]
fn cors_multi_origin_header_fails_closed_for_any_policy() {
    let cors = CorsMiddleware::new(FnHandler::new(ok_handler), CorsPolicy::default());

    let multi_origin = "https://allowed.example, https://attacker.example";
    let request = Request::new("GET", "/cors").with_header("Origin", multi_origin);
    let resp = cors.call(request);

    assert_eq!(resp.status, StatusCode::OK, "inner handler still runs");
    let granted = resp.headers.contains_key("access-control-allow-origin");
    assert!(
        !granted,
        "malformed multi-origin header must not receive wildcard Allow-Origin"
    );
}
/// Release-only check: combining the any-origin policy with
/// `allow_credentials` must emit no origin or credentials grant for any
/// requesting origin.
#[cfg(not(debug_assertions))]
#[test]
fn cors_credentials_with_any_policy_fails_closed_in_release() {
    const ORIGINS: [&str; 3] = [
        "https://attacker.example",
        "https://anything-else.example",
        "https://cred.example",
    ];

    let policy = CorsPolicy {
        allow_credentials: true,
        ..CorsPolicy::default()
    };
    let cors = CorsMiddleware::new(FnHandler::new(ok_handler), policy);

    for origin in ORIGINS {
        let request = Request::new("GET", "/cors").with_header("Origin", origin);
        let resp = cors.call(request);
        assert_eq!(resp.status, StatusCode::OK);

        let headers = &resp.headers;
        assert!(
            !headers.contains_key("access-control-allow-origin"),
            "Any+credentials must not echo any origin (saw {origin})"
        );
        assert!(
            !headers.contains_key("access-control-allow-credentials"),
            "Any+credentials must not emit Allow-Credentials (saw {origin})"
        );
    }
}
/// An `OPTIONS` request without `Origin` is not treated as a preflight: it
/// is forwarded to the inner handler and receives no
/// `Access-Control-Allow-Origin`.
#[test]
fn fetch_3_2_preflight_requires_origin_header() {
    let hits = Arc::new(std::sync::atomic::AtomicU32::new(0));
    let counting = CountingHandler {
        calls: Arc::clone(&hits),
        delay: Duration::ZERO,
        status: StatusCode::OK,
    };
    let cors = CorsMiddleware::new(counting, CorsPolicy::default());

    let options_without_origin =
        Request::new("OPTIONS", "/cors").with_header("Access-Control-Request-Method", "POST");
    let resp = cors.call(options_without_origin);

    let observed_calls = hits.load(std::sync::atomic::Ordering::SeqCst);
    assert_eq!(
        observed_calls, 1,
        "OPTIONS without Origin must reach the inner handler",
    );
    assert_eq!(resp.status, StatusCode::OK);
    assert!(
        !resp.headers.contains_key("access-control-allow-origin"),
        "non-CORS OPTIONS must not emit ACAO",
    );
}
/// An `OPTIONS` request with `Origin` but no `Access-Control-Request-Method`
/// is handled as an ordinary CORS request: the inner handler runs, ACAO is
/// set, and the preflight-only headers are absent.
#[test]
fn fetch_3_2_preflight_requires_acrm_header() {
    let hits = Arc::new(std::sync::atomic::AtomicU32::new(0));
    let counting = CountingHandler {
        calls: Arc::clone(&hits),
        delay: Duration::ZERO,
        status: StatusCode::OK,
    };
    let cors = CorsMiddleware::new(counting, CorsPolicy::default());

    let options_without_acrm =
        Request::new("OPTIONS", "/cors").with_header("Origin", "https://example.com");
    let resp = cors.call(options_without_acrm);

    let observed_calls = hits.load(std::sync::atomic::Ordering::SeqCst);
    assert_eq!(
        observed_calls, 1,
        "OPTIONS without ACRM must reach the inner handler (not a preflight)",
    );

    let wildcard = "*".to_string();
    assert_eq!(
        resp.headers.get("access-control-allow-origin"),
        Some(&wildcard),
        "non-preflight CORS request still gets ACAO",
    );

    let headers = &resp.headers;
    assert!(
        !headers.contains_key("access-control-allow-methods"),
        "Allow-Methods is preflight-only",
    );
    assert!(
        !headers.contains_key("access-control-max-age"),
        "Max-Age is preflight-only",
    );
}
/// A preflight is answered with `204 No Content` and an empty body.
#[test]
fn fetch_3_2_preflight_response_status_is_204() {
    let preflight = Request::new("OPTIONS", "/cors")
        .with_header("Origin", "https://example.com")
        .with_header("Access-Control-Request-Method", "POST");
    let cors = CorsMiddleware::new(FailingIfCalled, CorsPolicy::default());

    let resp = cors.call(preflight);

    assert_eq!(resp.status, StatusCode::NO_CONTENT);
    assert!(resp.body.is_empty(), "preflight body must be empty");
}
/// `Access-Control-Allow-Methods` lists the policy's methods; the method
/// named in `Access-Control-Request-Method` is not reflected when the
/// policy does not contain it.
#[test]
fn fetch_3_2_preflight_allow_methods_comes_from_policy_not_request() {
    let get_and_post = vec!["GET".to_string(), "POST".to_string()];
    let policy = CorsPolicy {
        allow_methods: get_and_post,
        ..CorsPolicy::default()
    };
    let cors = CorsMiddleware::new(FailingIfCalled, policy);

    let preflight = Request::new("OPTIONS", "/cors")
        .with_header("Origin", "https://example.com")
        .with_header("Access-Control-Request-Method", "DELETE");
    let resp = cors.call(preflight);

    let allow = resp
        .headers
        .get("access-control-allow-methods")
        .expect("preflight must set Allow-Methods");
    assert!(allow.contains("GET"));
    assert!(allow.contains("POST"));
    assert!(
        !allow.contains("DELETE"),
        "DELETE was requested but is not in the policy — it must NOT be echoed; got {allow:?}",
    );
}
/// `Access-Control-Max-Age` carries the policy's `max_age` rendered as whole
/// seconds.
#[test]
fn fetch_3_2_preflight_max_age_emitted_from_policy() {
    let two_hours = Duration::from_secs(7200);
    let policy = CorsPolicy {
        max_age: Some(two_hours),
        ..CorsPolicy::default()
    };
    let cors = CorsMiddleware::new(FailingIfCalled, policy);

    let preflight = Request::new("OPTIONS", "/cors")
        .with_header("Origin", "https://example.com")
        .with_header("Access-Control-Request-Method", "POST");
    let resp = cors.call(preflight);

    let expected = "7200".to_string();
    assert_eq!(
        resp.headers.get("access-control-max-age"),
        Some(&expected),
        "Max-Age must reflect the policy duration in seconds",
    );
}
/// When the policy has no `max_age`, the preflight response carries no
/// `Access-Control-Max-Age` header.
#[test]
fn fetch_3_2_preflight_max_age_omitted_when_none() {
    let preflight = Request::new("OPTIONS", "/cors")
        .with_header("Origin", "https://example.com")
        .with_header("Access-Control-Request-Method", "POST");

    let policy = CorsPolicy {
        max_age: None,
        ..CorsPolicy::default()
    };
    let resp = {
        let cors = CorsMiddleware::new(FailingIfCalled, policy);
        cors.call(preflight)
    };

    let has_max_age = resp.headers.contains_key("access-control-max-age");
    assert!(
        !has_max_age,
        "Max-Age must be omitted when policy.max_age is None",
    );
}
/// The preflight response's `Vary` header names `Origin`,
/// `Access-Control-Request-Method` and `Access-Control-Request-Headers`
/// (compared case-insensitively, ignoring surrounding whitespace).
#[test]
fn fetch_3_2_preflight_vary_includes_origin_acrm_acrh() {
    const REQUIRED_TOKENS: [&str; 3] = [
        "origin",
        "access-control-request-method",
        "access-control-request-headers",
    ];

    let cors = CorsMiddleware::new(FailingIfCalled, CorsPolicy::default());
    let preflight = Request::new("OPTIONS", "/cors")
        .with_header("Origin", "https://example.com")
        .with_header("Access-Control-Request-Method", "POST")
        .with_header("Access-Control-Request-Headers", "content-type");
    let resp = cors.call(preflight);

    let vary = resp
        .headers
        .get("vary")
        .expect("preflight must emit a Vary header");
    let listed: Vec<&str> = vary.split(',').map(str::trim).collect();

    for token in REQUIRED_TOKENS {
        let found = listed.iter().any(|entry| entry.eq_ignore_ascii_case(token));
        assert!(found, "Vary must include {token:?}; got {vary:?}",);
    }
}
/// A non-preflight CORS response advertises the policy's `expose_headers`
/// through `Access-Control-Expose-Headers`.
#[test]
fn fetch_3_2_simple_request_emits_expose_headers_from_policy() {
    let exposed = vec!["X-Request-Id".to_string(), "ETag".to_string()];
    let policy = CorsPolicy {
        expose_headers: exposed,
        ..CorsPolicy::default()
    };
    let cors = CorsMiddleware::new(FnHandler::new(ok_handler), policy);

    let request = Request::new("GET", "/cors").with_header("Origin", "https://example.com");
    let resp = cors.call(request);

    let expose = resp
        .headers
        .get("access-control-expose-headers")
        .expect("Expose-Headers must be set when policy lists them");
    assert!(expose.contains("X-Request-Id"));
    assert!(expose.contains("ETag"));
}
/// With the default policy (empty `expose_headers`), no
/// `Access-Control-Expose-Headers` header is emitted.
#[test]
fn fetch_3_2_simple_request_omits_expose_headers_when_policy_empty() {
    let request = Request::new("GET", "/cors").with_header("Origin", "https://example.com");
    let cors = CorsMiddleware::new(FnHandler::new(ok_handler), CorsPolicy::default());

    let resp = cors.call(request);

    let has_expose = resp.headers.contains_key("access-control-expose-headers");
    assert!(
        !has_expose,
        "Expose-Headers must be omitted when policy.expose_headers is empty",
    );
}
/// `Origin: null` is accepted as a well-formed origin: the default policy
/// answers it with `Access-Control-Allow-Origin: *`.
#[test]
fn fetch_3_2_origin_null_is_not_malformed() {
    let cors = CorsMiddleware::new(FnHandler::new(ok_handler), CorsPolicy::default());
    let opaque_origin_request = Request::new("GET", "/cors").with_header("Origin", "null");

    let resp = cors.call(opaque_origin_request);

    assert_eq!(resp.status, StatusCode::OK);
    let wildcard = "*".to_string();
    assert_eq!(
        resp.headers.get("access-control-allow-origin"),
        Some(&wildcard),
        "Any policy must echo `*` for an opaque (`null`) origin on a non-credentialed request",
    );
}
/// `Origin: null` does not match an allowlist of named origins, so neither
/// the origin nor the credentials grant header is emitted.
#[test]
fn fetch_3_2_origin_null_not_in_exact_allowlist_emits_no_acao() {
    let mut policy = CorsPolicy::with_exact_origins(vec!["https://app.example.com".to_string()]);
    policy.allow_credentials = true;
    let cors = CorsMiddleware::new(FnHandler::new(ok_handler), policy);

    let opaque_origin_request = Request::new("GET", "/cors").with_header("Origin", "null");
    let resp = cors.call(opaque_origin_request);

    assert_eq!(resp.status, StatusCode::OK);
    let headers = &resp.headers;
    assert!(
        !headers.contains_key("access-control-allow-origin"),
        "Origin: null must not match an exact-origin allow-list of named origins",
    );
    assert!(!headers.contains_key("access-control-allow-credentials"),);
}
/// A stack with a single timeout layer builds and serves a request.
#[test]
fn middleware_stack_builds() {
    let five_seconds = Duration::from_secs(5);
    let stack = MiddlewareStack::new(FnHandler::new(ok_handler)).with_timeout(five_seconds);
    let handler = stack.build();

    let resp = handler.call(make_request());

    assert_eq!(resp.status, StatusCode::OK);
}
/// CORS, auth, load-shed, bulkhead, rate-limit and timeout layers compose
/// into one handler that serves an authorized request.
#[test]
fn middleware_stack_composition() {
    let load_shed = LoadShedPolicy { max_in_flight: 16 };
    let bulkhead = BulkheadPolicy {
        max_concurrent: 10,
        ..Default::default()
    };
    let rate_limit = RateLimitPolicy {
        rate: 100,
        burst: 50,
        ..Default::default()
    };

    // Layer order is part of the scenario and is kept as-is.
    let handler = MiddlewareStack::new(FnHandler::new(ok_handler))
        .with_cors(CorsPolicy::default())
        .with_auth(AuthPolicy::AnyBearer)
        .with_load_shed(load_shed)
        .with_bulkhead(bulkhead)
        .with_rate_limit(rate_limit)
        .with_timeout(Duration::from_secs(30))
        .build();

    let authorized = make_request().with_header("authorization", "Bearer token");
    let resp = handler.call(authorized);
    assert_eq!(resp.status, StatusCode::OK);
}
/// A stack combining retry and timeout layers serves a succeeding handler.
#[test]
fn middleware_stack_with_retry() {
    let retry = RetryPolicy::immediate(3);
    let timeout = Duration::from_secs(5);
    let handler = MiddlewareStack::new(FnHandler::new(ok_handler))
        .with_retry(retry)
        .with_timeout(timeout)
        .build();

    let resp = handler.call(make_request());

    assert_eq!(resp.status, StatusCode::OK);
}
/// Request extensions set by the caller survive the timeout and rate-limit
/// layers: the response body is the `trace_id` value stored on the request.
#[test]
fn middleware_stack_preserves_request_extensions() {
    let rate_limit = RateLimitPolicy {
        rate: 100,
        burst: 100,
        period: Duration::from_secs(1),
        ..Default::default()
    };
    let handler = MiddlewareStack::new(InspectHandler)
        .with_timeout(Duration::from_secs(1))
        .with_rate_limit(rate_limit)
        .build();

    let mut request = Request::new("GET", "/ctx");
    request.extensions.insert("trace_id", "trace-123");

    let resp = handler.call(request);

    assert_eq!(resp.status, StatusCode::OK);
    assert_eq!(&resp.body[..], b"trace-123");
}
/// With a 1 ms timeout added before a three-attempt retry, a handler that
/// takes 10 ms is invoked three times and the result is `504`.
#[test]
fn middleware_stack_retry_wraps_timeout() {
    let attempts = Arc::new(std::sync::atomic::AtomicU32::new(0));
    let slow_handler = CountingHandler {
        calls: Arc::clone(&attempts),
        delay: Duration::from_millis(10),
        status: StatusCode::OK,
    };

    let stacked = MiddlewareStack::new(slow_handler)
        .with_timeout(Duration::from_millis(1))
        .with_retry(RetryPolicy::immediate(3))
        .build();
    let resp = stacked.call(make_request());

    assert_eq!(resp.status, StatusCode::GATEWAY_TIMEOUT);
    let observed = attempts.load(std::sync::atomic::Ordering::SeqCst);
    assert_eq!(observed, 3);
}
/// A response-header layer only decorates rate-limit rejections when it is
/// added after the rate limiter: added before, the `429` lacks the header;
/// added after, the `429` carries it.
#[test]
fn middleware_stack_last_added_header_covers_rate_limit_short_circuit() {
    /// One request per minute, so the second call in a row is rejected.
    fn single_request_limit() -> RateLimitPolicy {
        RateLimitPolicy {
            rate: 1,
            burst: 1,
            period: Duration::from_secs(60),
            ..Default::default()
        }
    }

    let header_then_limit = MiddlewareStack::new(FnHandler::new(ok_handler))
        .with_response_header(
            "content-security-policy",
            "default-src 'none'",
            HeaderOverwrite::IfMissing,
        )
        .with_rate_limit(single_request_limit())
        .build();
    let limit_then_header = MiddlewareStack::new(FnHandler::new(ok_handler))
        .with_rate_limit(single_request_limit())
        .with_response_header(
            "content-security-policy",
            "default-src 'none'",
            HeaderOverwrite::IfMissing,
        )
        .build();

    // First call consumes the only token; second call is short-circuited.
    assert_eq!(
        header_then_limit.call(make_request()).status,
        StatusCode::OK
    );
    let inner_limited = header_then_limit.call(make_request());
    assert_eq!(inner_limited.status, StatusCode::TOO_MANY_REQUESTS);
    assert!(
        !inner_limited
            .headers
            .contains_key("content-security-policy")
    );

    assert_eq!(
        limit_then_header.call(make_request()).status,
        StatusCode::OK
    );
    let outer_limited = limit_then_header.call(make_request());
    assert_eq!(outer_limited.status, StatusCode::TOO_MANY_REQUESTS);
    let csp = "default-src 'none'".to_string();
    assert_eq!(
        outer_limited.headers.get("content-security-policy"),
        Some(&csp)
    );
}
/// The breaker behind the middleware is reachable and records one success
/// after a single successful call.
#[test]
fn circuit_breaker_metrics_accessible() {
    let breaker_mw = CircuitBreakerMiddleware::new(
        FnHandler::new(ok_handler),
        CircuitBreakerPolicy::default(),
    );

    let _ = breaker_mw.call(make_request());

    let snapshot = breaker_mw.breaker().metrics();
    assert_eq!(snapshot.total_success, 1);
}
/// The limiter behind the middleware is reachable: after one call it
/// reports at least one allowed request and no more tokens than the burst.
#[test]
fn rate_limit_metrics_accessible() {
    let policy = RateLimitPolicy::default();
    // Read the burst before the policy is moved into the middleware.
    let burst_cap = policy.burst;
    let limiter_mw = RateLimitMiddleware::new(FnHandler::new(ok_handler), policy);

    let _ = limiter_mw.call(make_request());

    let snapshot = limiter_mw.limiter().metrics();
    assert!(snapshot.total_allowed > 0);
    assert!(snapshot.available_tokens <= burst_cap);
}
/// The bulkhead behind the middleware is reachable and holds no active
/// permits once the call has completed.
#[test]
fn bulkhead_metrics_accessible() {
    let bulkhead_mw = BulkheadMiddleware::new(
        FnHandler::new(ok_handler),
        BulkheadPolicy {
            max_concurrent: 5,
            ..Default::default()
        },
    );

    let _ = bulkhead_mw.call(make_request());

    let snapshot = bulkhead_mw.bulkhead().metrics();
    assert_eq!(snapshot.active_permits, 0);
}
/// With `min_body_size` set to 1000, the `ok_handler` response is returned
/// without a `Content-Encoding` header even though gzip is accepted.
#[test]
fn compression_skips_small_bodies() {
    let compression = CompressionMiddleware::new(
        FnHandler::new(ok_handler),
        CompressionConfig {
            min_body_size: 1000,
            ..Default::default()
        },
    );

    let gzip_request = make_request().with_header("Accept-Encoding", "gzip");
    let resp = compression.call(gzip_request);

    assert_eq!(resp.status, StatusCode::OK);
    assert!(!resp.headers.contains_key("content-encoding"));
}
/// When the client rules out every encoding (`identity;q=0, *;q=0`), the
/// response is `406` with a fixed explanatory body, even with
/// `min_body_size` set to 1000.
#[test]
fn compression_rejects_small_body_when_identity_is_unacceptable() {
    let compression = CompressionMiddleware::new(
        FnHandler::new(ok_handler),
        CompressionConfig {
            min_body_size: 1000,
            ..Default::default()
        },
    );

    let nothing_acceptable = make_request().with_header("Accept-Encoding", "identity;q=0, *;q=0");
    let resp = compression.call(nothing_acceptable);

    assert_eq!(resp.status.as_u16(), 406);
    assert_eq!(resp.body.as_ref(), b"No acceptable response encoding");
}
/// A 512-byte body above the 256-byte threshold is negotiated against
/// `Accept-Encoding: gzip`: `Vary` is always set, and `Content-Encoding: gzip`
/// appears only when the `compression` feature is enabled.
#[test]
fn compression_negotiates_encoding() {
    fn large_handler() -> Response {
        let payload = vec![b'x'; 512];
        Response::new(StatusCode::OK, payload)
    }

    let config = CompressionConfig {
        min_body_size: 256,
        supported: vec![ContentEncoding::Gzip, ContentEncoding::Identity],
        ..CompressionConfig::default()
    };
    let compression = CompressionMiddleware::new(FnHandler::new(large_handler), config);

    let gzip_request = make_request().with_header("Accept-Encoding", "gzip");
    let resp = compression.call(gzip_request);

    assert_eq!(resp.status, StatusCode::OK);
    let vary = "accept-encoding".to_string();
    assert_eq!(resp.headers.get("vary"), Some(&vary));

    #[cfg(feature = "compression")]
    assert_eq!(
        resp.headers.get("content-encoding"),
        Some(&"gzip".to_string())
    );
    #[cfg(not(feature = "compression"))]
    assert!(!resp.headers.contains_key("content-encoding"));
}
/// After the body is gzip-compressed, the handler's original
/// `Content-Length` header is no longer present on the response.
#[cfg(feature = "compression")]
#[test]
fn compression_removes_stale_content_length_after_body_rewrite() {
    fn large_handler() -> Response {
        let payload = vec![b'a'; 4096];
        Response::new(StatusCode::OK, payload).header("content-length", "4096")
    }

    let config = CompressionConfig {
        min_body_size: 0,
        supported: vec![ContentEncoding::Gzip, ContentEncoding::Identity],
        ..CompressionConfig::default()
    };
    let compression = CompressionMiddleware::new(FnHandler::new(large_handler), config);

    let gzip_request = make_request().with_header("Accept-Encoding", "gzip");
    let resp = compression.call(gzip_request);

    assert_eq!(resp.status, StatusCode::OK);
    let gzip = "gzip".to_string();
    assert_eq!(resp.headers.get("content-encoding"), Some(&gzip));
    let keeps_length = resp.headers.contains_key("content-length");
    assert!(
        !keeps_length,
        "compressed responses must not retain stale content-length after body rewrite"
    );
}
/// With `max_compressed_size` set to 1, the response is served unencoded
/// with its original body when the client has not excluded identity;
/// `Vary` is still set.
#[cfg(feature = "compression")]
#[test]
fn compression_cap_falls_back_to_identity_when_allowed() {
    fn large_handler() -> Response {
        let payload = vec![b'a'; 4096];
        Response::new(StatusCode::OK, payload)
    }

    let tiny_cap = CompressionConfig {
        min_body_size: 0,
        max_compressed_size: 1,
        supported: vec![ContentEncoding::Gzip, ContentEncoding::Identity],
    };
    let compression = CompressionMiddleware::new(FnHandler::new(large_handler), tiny_cap);

    let gzip_request = make_request().with_header("Accept-Encoding", "gzip");
    let resp = compression.call(gzip_request);

    assert_eq!(resp.status, StatusCode::OK);
    assert!(!resp.headers.contains_key("content-encoding"));
    assert_eq!(resp.body.as_ref(), &[b'a'; 4096]);
    let vary = "accept-encoding".to_string();
    assert_eq!(resp.headers.get("vary"), Some(&vary));
}
/// With `max_compressed_size` set to 1 and identity excluded by the client
/// (`identity;q=0`), the response is `406`.
#[cfg(feature = "compression")]
#[test]
fn compression_cap_rejects_when_identity_disallowed() {
    fn large_handler() -> Response {
        let payload = vec![b'a'; 4096];
        Response::new(StatusCode::OK, payload)
    }

    let tiny_cap = CompressionConfig {
        min_body_size: 0,
        max_compressed_size: 1,
        supported: vec![ContentEncoding::Gzip, ContentEncoding::Identity],
    };
    let compression = CompressionMiddleware::new(FnHandler::new(large_handler), tiny_cap);

    let no_identity = make_request().with_header("Accept-Encoding", "gzip, identity;q=0");
    let resp = compression.call(no_identity);

    assert_eq!(resp.status.as_u16(), 406);
    assert_eq!(resp.body.as_ref(), b"No acceptable response encoding");
}
/// A request with no `Accept-Encoding` header is served with `200` and a
/// `Vary: accept-encoding` header.
#[test]
fn compression_absent_accept_encoding_remains_permissive() {
    fn large_handler() -> Response {
        let payload = vec![b'x'; 512];
        Response::new(StatusCode::OK, payload)
    }

    let config = CompressionConfig {
        min_body_size: 256,
        supported: vec![ContentEncoding::Gzip, ContentEncoding::Identity],
        ..CompressionConfig::default()
    };
    let compression = CompressionMiddleware::new(FnHandler::new(large_handler), config);

    let resp = compression.call(make_request());

    assert_eq!(resp.status, StatusCode::OK);
    let vary = "accept-encoding".to_string();
    assert_eq!(resp.headers.get("vary"), Some(&vary));
}
/// An empty `Accept-Encoding` value is not handled like a missing header:
/// with only gzip supported, the response is `406`.
#[test]
fn compression_empty_accept_encoding_is_not_treated_as_absent() {
    fn large_handler() -> Response {
        let payload = vec![b'x'; 512];
        Response::new(StatusCode::OK, payload)
    }

    let gzip_only = CompressionConfig {
        min_body_size: 256,
        supported: vec![ContentEncoding::Gzip],
        ..CompressionConfig::default()
    };
    let compression = CompressionMiddleware::new(FnHandler::new(large_handler), gzip_only);

    let empty_header = make_request().with_header("Accept-Encoding", "");
    let resp = compression.call(empty_header);

    assert_eq!(resp.status.as_u16(), 406);
    assert_eq!(resp.body.as_ref(), b"No acceptable response encoding");
}
/// When only identity is supported and requested, the response carries no
/// `Content-Encoding` header.
#[test]
fn compression_identity_passthrough() {
    fn large_handler() -> Response {
        let payload = vec![b'x'; 512];
        Response::new(StatusCode::OK, payload)
    }

    let identity_only = CompressionConfig {
        min_body_size: 256,
        supported: vec![ContentEncoding::Identity],
        ..CompressionConfig::default()
    };
    let compression = CompressionMiddleware::new(FnHandler::new(large_handler), identity_only);

    let identity_request = make_request().with_header("Accept-Encoding", "identity");
    let resp = compression.call(identity_request);

    assert!(!resp.headers.contains_key("content-encoding"));
}
/// A brotli-encoded response decompresses back to the handler's original
/// payload.
#[cfg(feature = "compression")]
#[test]
fn compression_brotli_roundtrip() {
    use crate::http::compress::{BrotliDecompressor, Decompressor};

    const CHUNK: &str = "brotli me";
    const REPEATS: usize = 128;

    fn large_handler() -> Response {
        Response::new(StatusCode::OK, CHUNK.repeat(REPEATS).into_bytes())
    }

    let config = CompressionConfig {
        min_body_size: 0,
        supported: vec![ContentEncoding::Brotli, ContentEncoding::Identity],
        ..CompressionConfig::default()
    };
    let compression = CompressionMiddleware::new(FnHandler::new(large_handler), config);

    let br_request = make_request().with_header("Accept-Encoding", "br");
    let resp = compression.call(br_request);
    let br = "br".to_string();
    assert_eq!(resp.headers.get("content-encoding"), Some(&br));

    // Decode the wire body and compare it with the original payload.
    let mut decoder = BrotliDecompressor::new(None);
    let mut plain = Vec::new();
    decoder.decompress(&resp.body, &mut plain).unwrap();
    decoder.finish(&mut plain).unwrap();
    assert_eq!(plain, CHUNK.repeat(REPEATS).into_bytes());
}
/// A 512-byte body passes a 1024-byte limit.
#[test]
fn body_limit_allows_within_limit() {
    let limiter = RequestBodyLimitMiddleware::new(FnHandler::new(ok_handler), 1024);

    let mut request = make_request();
    request.body = vec![0u8; 512].into();

    let resp = limiter.call(request);
    assert_eq!(resp.status, StatusCode::OK);
}
/// A 200-byte body against a 100-byte limit is rejected with `413`, and the
/// error body mentions both sizes.
#[test]
fn body_limit_rejects_over_limit() {
    let limiter = RequestBodyLimitMiddleware::new(FnHandler::new(ok_handler), 100);

    let mut request = make_request();
    request.body = vec![0u8; 200].into();
    let resp = limiter.call(request);

    assert_eq!(resp.status, StatusCode::PAYLOAD_TOO_LARGE);
    let message = String::from_utf8_lossy(&resp.body);
    assert!(message.contains("200 bytes"));
    assert!(message.contains("100 bytes"));
}
/// A body whose size equals the limit (100 bytes) is accepted.
#[test]
fn body_limit_allows_exact_limit() {
    let limiter = RequestBodyLimitMiddleware::new(FnHandler::new(ok_handler), 100);

    let mut request = make_request();
    request.body = vec![0u8; 100].into();

    let resp = limiter.call(request);
    assert_eq!(resp.status, StatusCode::OK);
}
/// An over-limit request never reaches the wrapped handler: its call
/// counter stays at zero.
#[test]
fn body_limit_short_circuits_handler() {
    let hits = Arc::new(std::sync::atomic::AtomicU32::new(0));
    let counting = CountingHandler {
        calls: Arc::clone(&hits),
        delay: Duration::ZERO,
        status: StatusCode::OK,
    };
    let limiter = RequestBodyLimitMiddleware::new(counting, 10);

    let mut oversized = make_request();
    oversized.body = vec![0u8; 20].into();
    let _ = limiter.call(oversized);

    let observed = hits.load(std::sync::atomic::Ordering::SeqCst);
    assert_eq!(observed, 0);
}
/// A declared `Content-Length` above the limit is rejected with `413` even
/// though the actual body is small; the error body names the header and
/// the declared value.
#[test]
fn body_limit_middleware_checks_content_length_early_dos_prevention() {
    use crate::bytes::Bytes;

    let limiter = RequestBodyLimitMiddleware::new(FnHandler::new(ok_handler), 1024);
    let lying_upload = Request::new("POST", "/upload")
        .with_header("content-length", "2097152")
        .with_body(Bytes::from_static(b"small"));

    let resp = limiter.call(lying_upload);

    assert_eq!(resp.status, StatusCode::PAYLOAD_TOO_LARGE);
    let body_str = String::from_utf8_lossy(&resp.body);
    let mentions_header = body_str.contains("Content-Length");
    assert!(
        mentions_header,
        "Error should mention Content-Length check, got: {}",
        body_str
    );
    let mentions_value = body_str.contains("2097152");
    assert!(
        mentions_value,
        "Error should mention declared Content-Length value"
    );
}
/// When the request has no id header, the middleware adds one that starts
/// with `req-`.
#[test]
fn request_id_generates_id() {
    let id_mw = RequestIdMiddleware::new(FnHandler::new(ok_handler), "x-request-id");

    let resp = id_mw.call(make_request());

    assert_eq!(resp.status, StatusCode::OK);
    let generated = resp.headers.get("x-request-id").unwrap();
    assert!(generated.starts_with("req-"));
}
/// A client-supplied request id is echoed on the response unchanged.
#[test]
fn request_id_propagates_existing() {
    let id_mw = RequestIdMiddleware::new(FnHandler::new(ok_handler), "x-request-id");
    let tagged = make_request().with_header("x-request-id", "custom-42");

    let resp = id_mw.call(tagged);

    let expected = "custom-42".to_string();
    assert_eq!(resp.headers.get("x-request-id"), Some(&expected));
}
/// With a shared counter starting at 100, two consecutive requests receive
/// `req-100` and `req-101`.
#[test]
fn request_id_monotonic_counter() {
    let shared_counter = Arc::new(AtomicU64::new(100));
    let id_mw = RequestIdMiddleware::shared(
        FnHandler::new(ok_handler),
        "x-request-id",
        Arc::clone(&shared_counter),
    );

    let first = id_mw.call(make_request());
    let second = id_mw.call(make_request());

    let first_expected = "req-100".to_string();
    let second_expected = "req-101".to_string();
    assert_eq!(first.headers.get("x-request-id"), Some(&first_expected));
    assert_eq!(second.headers.get("x-request-id"), Some(&second_expected));
}
/// The generated request id is visible to the wrapped handler through the
/// `request_id` request extension.
#[test]
fn request_id_stores_in_extensions() {
    /// Replies with the `request_id` extension as the body, or `400` with
    /// `no id` when the extension is missing.
    struct RequestIdEchoHandler;

    impl Handler for RequestIdEchoHandler {
        fn call(
            &self,
            _cx: &crate::Cx,
            req: Request,
        ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
            let reply = async move {
                req.extensions.get("request_id").map_or_else(
                    || Response::new(StatusCode::BAD_REQUEST, b"no id".to_vec()),
                    |id| Response::new(StatusCode::OK, id.as_bytes().to_vec()),
                )
            };
            Box::pin(reply)
        }
    }

    let id_mw = RequestIdMiddleware::new(RequestIdEchoHandler, "x-request-id");
    let resp = id_mw.call(make_request());

    assert_eq!(resp.status, StatusCode::OK);
    let echoed_id = String::from_utf8_lossy(&resp.body);
    assert!(echoed_id.starts_with("req-"));
}
/// A 4 KiB client-supplied id is cut down to
/// `DEFAULT_REQUEST_ID_MAX_LENGTH` characters, keeping the leading content.
#[test]
fn request_id_truncates_oversize_client_supplied_value_to_default_128() {
    let oversized = "A".repeat(4 * 1024);
    let id_mw = RequestIdMiddleware::new(FnHandler::new(ok_handler), "x-request-id");

    let resp = id_mw.call(make_request().with_header("x-request-id", &oversized));

    let echoed = resp.headers.get("x-request-id").unwrap();
    let echoed_len = echoed.chars().count();
    assert_eq!(
        echoed_len, DEFAULT_REQUEST_ID_MAX_LENGTH,
        "echo header must be truncated to DEFAULT_REQUEST_ID_MAX_LENGTH (128 chars), \
         got {} chars",
        echoed_len
    );
    assert!(echoed.chars().all(|c| c == 'A'));
}
/// `with_max_length(16)` caps the echoed id at 16 characters.
#[test]
fn request_id_with_max_length_overrides_default() {
    let oversized = "B".repeat(1024);
    let id_mw =
        RequestIdMiddleware::new(FnHandler::new(ok_handler), "x-request-id").with_max_length(16);

    let resp = id_mw.call(make_request().with_header("x-request-id", &oversized));

    let echoed = resp.headers.get("x-request-id").unwrap();
    assert_eq!(echoed.chars().count(), 16);
}
/// `with_max_length(0)` does not disable the cap: the echoed id is
/// truncated to `DEFAULT_REQUEST_ID_MAX_LENGTH` characters.
#[test]
fn request_id_with_max_length_zero_falls_back_to_default() {
    let oversized = "C".repeat(4 * 1024);
    let id_mw =
        RequestIdMiddleware::new(FnHandler::new(ok_handler), "x-request-id").with_max_length(0);

    let resp = id_mw.call(make_request().with_header("x-request-id", &oversized));

    let echoed = resp.headers.get("x-request-id").unwrap();
    assert_eq!(echoed.chars().count(), DEFAULT_REQUEST_ID_MAX_LENGTH);
}
/// Truncating an id made of multi-byte characters keeps whole characters:
/// fifty crabs capped at 10 yield exactly ten crabs.
#[test]
fn request_id_truncate_respects_utf8_char_boundary() {
    let crabs = "🦀".repeat(50);
    let id_mw =
        RequestIdMiddleware::new(FnHandler::new(ok_handler), "x-request-id").with_max_length(10);

    let resp = id_mw.call(make_request().with_header("x-request-id", &crabs));

    let echoed = resp.headers.get("x-request-id").unwrap();
    assert_eq!(echoed.chars().count(), 10);
    let crab_count = echoed.chars().filter(|c| *c == '🦀').count();
    assert_eq!(crab_count, 10);
    let _ = echoed.as_bytes();
}
/// An id shorter than the cap is echoed back verbatim.
#[test]
fn request_id_passes_through_short_client_value_unchanged() {
    let short_id = make_request().with_header("x-request-id", "abc-123");
    let id_mw = RequestIdMiddleware::new(FnHandler::new(ok_handler), "x-request-id");

    let resp = id_mw.call(short_id);

    let expected = "abc-123".to_string();
    assert_eq!(
        resp.headers.get("x-request-id"),
        Some(&expected),
        "values under the cap must pass through verbatim"
    );
}
/// Under the default trace policy, the response gains `x-trace-id` (copied
/// from the request's `x-request-id`) and a numeric `x-response-time-ms`.
#[test]
fn request_trace_injects_duration_and_trace_headers() {
    let tracer =
        RequestTraceMiddleware::new(FnHandler::new(ok_handler), RequestTracePolicy::default());

    let traced_request = make_request().with_header("x-request-id", "trace-42");
    let resp = tracer.call(traced_request);

    assert_eq!(resp.status, StatusCode::OK);
    let expected_trace = "trace-42".to_string();
    assert_eq!(resp.headers.get("x-trace-id"), Some(&expected_trace));

    let duration = resp
        .headers
        .get("x-response-time-ms")
        .expect("duration header should be present");
    let is_numeric = duration.parse::<u128>().is_ok();
    assert!(
        is_numeric,
        "duration header should be numeric: {duration}"
    );
}
/// With an injected time getter, the duration header is derived from the
/// test clock: a handler configured with `next_time_ms: 25` yields
/// `x-response-time-ms: 25` when the clock starts at 0.
#[test]
fn request_trace_time_getter_can_drive_duration_header_without_sleep() {
    // Reset the test clock before the middleware is constructed.
    set_request_trace_test_time(0);

    let clock_advancing_handler = AdvanceRequestTraceTimeHandler {
        next_time_ms: 25,
        body: b"traced",
    };
    let tracer = RequestTraceMiddleware::with_time_getter(
        clock_advancing_handler,
        RequestTracePolicy::default(),
        request_trace_test_time,
    );

    let resp = tracer.call(make_request().with_header("x-request-id", "trace-99"));

    assert_eq!(resp.status, StatusCode::OK);
    let elapsed = "25".to_string();
    assert_eq!(resp.headers.get("x-response-time-ms"), Some(&elapsed));
    let trace = "trace-99".to_string();
    assert_eq!(resp.headers.get("x-trace-id"), Some(&trace));
    assert_eq!(resp.body.as_ref(), b"traced");
}
/// Setting `duration_header` to `None` suppresses `x-response-time-ms`
/// while the trace header is still emitted.
#[test]
fn request_trace_can_disable_duration_header() {
    let trace_only = RequestTracePolicy {
        duration_header: None,
        trace_header: Some("x-trace-id".to_string()),
    };
    let tracer = RequestTraceMiddleware::new(FnHandler::new(ok_handler), trace_only);

    let traced_request = make_request().with_header("x-request-id", "trace-7");
    let resp = tracer.call(traced_request);

    assert_eq!(resp.status, StatusCode::OK);
    assert!(!resp.headers.contains_key("x-response-time-ms"));
    let expected_trace = "trace-7".to_string();
    assert_eq!(resp.headers.get("x-trace-id"), Some(&expected_trace));
}
/// A trace header already set by the inner handler wins over the value
/// derived from the request's `x-request-id`.
#[test]
fn request_trace_preserves_existing_trace_header() {
    fn header_handler() -> Response {
        let base = Response::new(StatusCode::OK, b"ok".to_vec());
        base.header("x-trace-id", "inner-trace")
    }

    let tracer = RequestTraceMiddleware::new(
        FnHandler::new(header_handler),
        RequestTracePolicy::default(),
    );
    let traced_request = make_request().with_header("x-request-id", "outer-trace");

    let resp = tracer.call(traced_request);

    assert_eq!(resp.status, StatusCode::OK);
    let inner = "inner-trace".to_string();
    assert_eq!(resp.headers.get("x-trace-id"), Some(&inner));
}
/// When the handler stored the trace header under a mixed-case key, the
/// middleware keeps that value and does not add a second lowercase entry.
#[test]
fn request_trace_preserves_mixed_case_existing_trace_header_without_duplication() {
    fn header_handler() -> Response {
        let mut with_mixed_case_key = Response::new(StatusCode::OK, b"ok".to_vec());
        // Inserted directly into the map so the key keeps its casing.
        with_mixed_case_key
            .headers
            .insert("X-Trace-Id".to_string(), "inner-trace".to_string());
        with_mixed_case_key
    }

    let tracer = RequestTraceMiddleware::new(
        FnHandler::new(header_handler),
        RequestTracePolicy::default(),
    );
    let traced_request = make_request().with_header("x-request-id", "outer-trace");

    let resp = tracer.call(traced_request);

    assert_eq!(resp.header_value("x-trace-id"), Some("inner-trace"));
    assert_eq!(
        resp.headers.len(),
        2,
        "only duration and trace headers should be present"
    );
    // No lowercase duplicate was added next to the mixed-case key.
    assert!(!resp.headers.contains_key("x-trace-id"));
}
/// Header names configured in mixed case on the policy are emitted in
/// lowercase, and an existing lowercase trace header from the handler is
/// kept.
#[test]
fn request_trace_normalizes_mixed_case_policy_headers() {
    fn header_handler() -> Response {
        let base = Response::new(StatusCode::OK, b"ok".to_vec());
        base.header("x-trace-id", "inner-trace")
    }

    let mixed_case_policy = RequestTracePolicy {
        duration_header: Some("X-Response-Time-Ms".to_string()),
        trace_header: Some("X-Trace-Id".to_string()),
    };
    let tracer = RequestTraceMiddleware::new(FnHandler::new(header_handler), mixed_case_policy);

    let resp = tracer.call(make_request().with_header("x-request-id", "outer-trace"));

    let headers = &resp.headers;
    assert!(headers.contains_key("x-response-time-ms"));
    assert!(!headers.contains_key("X-Response-Time-Ms"));
    let inner = "inner-trace".to_string();
    assert_eq!(headers.get("x-trace-id"), Some(&inner));
    assert!(!headers.contains_key("X-Trace-Id"));
}
/// A 4 MiB `x-request-id` is truncated to `DEFAULT_TRACE_ID_MAX_LENGTH`
/// characters before being copied into `x-trace-id`.
#[test]
fn request_trace_truncates_giant_x_request_id_header_dos_attack() {
    let giant_id = "A".repeat(4 * 1024 * 1024);
    let tracer =
        RequestTraceMiddleware::new(FnHandler::new(ok_handler), RequestTracePolicy::default());

    let resp = tracer.call(make_request().with_header("x-request-id", &giant_id));

    assert_eq!(resp.status, StatusCode::OK);
    let trace_id = resp.headers.get("x-trace-id").unwrap();
    let trace_len = trace_id.chars().count();
    assert_eq!(
        trace_len, DEFAULT_TRACE_ID_MAX_LENGTH,
        "giant x-request-id must be truncated to prevent DoS amplification"
    );
    let expected_prefix = "A".repeat(DEFAULT_TRACE_ID_MAX_LENGTH);
    assert_eq!(trace_id, &expected_prefix);
    assert!(trace_id.chars().all(|c| c == 'A'));
}
/// Carriage returns and line feeds in `x-request-id` are stripped before
/// the value is copied into `x-trace-id`.
#[test]
fn request_trace_sanitizes_crlf_in_x_request_id_header() {
    let injected = "trace\r\nX-Injected: malicious\r\n";
    let tracer =
        RequestTraceMiddleware::new(FnHandler::new(ok_handler), RequestTracePolicy::default());

    let resp = tracer.call(make_request().with_header("x-request-id", injected));

    assert_eq!(resp.status, StatusCode::OK);
    let trace_id = resp.headers.get("x-trace-id").unwrap();
    assert_eq!(trace_id, "traceX-Injected: malicious");
    assert!(!trace_id.contains('\r'));
    assert!(!trace_id.contains('\n'));
}
/// A panicking handler is turned into a `500` response whose body is the
/// fixed text `Internal Server Error`.
#[test]
fn catch_panic_recovers() {
    let guard = CatchPanicMiddleware::new(PanicHandler);

    let resp = guard.call(make_request());

    assert_eq!(resp.status, StatusCode::INTERNAL_SERVER_ERROR);
    let text = String::from_utf8_lossy(&resp.body);
    assert_eq!(text, "Internal Server Error");
}
/// A handler that does not panic has its response passed through.
#[test]
fn catch_panic_passes_normal_responses() {
    let guard = CatchPanicMiddleware::new(FnHandler::new(ok_handler));

    let resp = guard.call(make_request());

    assert_eq!(resp.status, StatusCode::OK);
}
/// `TrailingSlash::Trim` must hand the inner handler the path without its
/// trailing slash.
///
/// The inner handler echoes the path it receives so the test observes the
/// rewrite itself; asserting only on the status would also pass if the
/// middleware forwarded the path untouched.
#[test]
fn normalize_path_trim_trailing_slash() {
    /// Test-only handler that returns the request path as the response body.
    struct PathEchoHandler;
    impl Handler for PathEchoHandler {
        fn call(
            &self,
            _cx: &crate::Cx,
            req: Request,
        ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
            Box::pin(async move { Response::new(StatusCode::OK, req.path.into_bytes()) })
        }
    }

    let mw = NormalizePathMiddleware::new(PathEchoHandler, TrailingSlash::Trim);
    let resp = mw.call(Request::new("GET", "/api/users/"));

    assert_eq!(resp.status, StatusCode::OK);
    assert_eq!(String::from_utf8_lossy(&resp.body), "/api/users");
}
/// Trimming must leave the root path `/` as it is.
#[test]
fn normalize_path_trim_preserves_root() {
    /// Test-only handler that returns the request path as the response body.
    struct EchoPath;
    impl Handler for EchoPath {
        fn call(
            &self,
            _cx: &crate::Cx,
            req: Request,
        ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
            let echoed = req.path.into_bytes();
            Box::pin(async move { Response::new(StatusCode::OK, echoed) })
        }
    }

    let trimming = NormalizePathMiddleware::new(EchoPath, TrailingSlash::Trim);
    let resp = trimming.call(Request::new("GET", "/"));

    assert_eq!(resp.status, StatusCode::OK);
    assert_eq!(String::from_utf8_lossy(&resp.body), "/");
}
/// `TrailingSlash::Always` appends a slash to a path that lacks one.
#[test]
fn normalize_path_always_adds_slash() {
    /// Test-only handler that returns the request path as the response body.
    struct EchoPath;
    impl Handler for EchoPath {
        fn call(
            &self,
            _cx: &crate::Cx,
            req: Request,
        ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
            let echoed = req.path.into_bytes();
            Box::pin(async move { Response::new(StatusCode::OK, echoed) })
        }
    }

    let appending = NormalizePathMiddleware::new(EchoPath, TrailingSlash::Always);
    let resp = appending.call(Request::new("GET", "/api/users"));

    assert_eq!(&resp.body[..], b"/api/users/");
}
/// `TrailingSlash::Always` leaves `/style.css` unchanged.
#[test]
fn normalize_path_always_skips_dotfiles() {
    /// Test-only handler that returns the request path as the response body.
    struct EchoPath;
    impl Handler for EchoPath {
        fn call(
            &self,
            _cx: &crate::Cx,
            req: Request,
        ) -> Pin<Box<dyn Future<Output = Response> + Send + '_>> {
            let echoed = req.path.into_bytes();
            Box::pin(async move { Response::new(StatusCode::OK, echoed) })
        }
    }

    let appending = NormalizePathMiddleware::new(EchoPath, TrailingSlash::Always);
    let resp = appending.call(Request::new("GET", "/style.css"));

    assert_eq!(&resp.body[..], b"/style.css");
}
/// `RedirectTrim` answers a trailing-slash path with a 301 pointing at the
/// trimmed path.
#[test]
fn normalize_path_redirect_trim() {
    let redirecting =
        NormalizePathMiddleware::new(FnHandler::new(ok_handler), TrailingSlash::RedirectTrim);

    let resp = redirecting.call(Request::new("GET", "/api/users/"));

    assert_eq!(resp.status, StatusCode::MOVED_PERMANENTLY);
    let location = resp.headers.get("location").map(String::as_str);
    assert_eq!(location, Some("/api/users"));
}
/// `RedirectAlways` answers a slash-less path with a 301 pointing at the
/// path with a trailing slash appended.
#[test]
fn normalize_path_redirect_always() {
    let redirecting =
        NormalizePathMiddleware::new(FnHandler::new(ok_handler), TrailingSlash::RedirectAlways);

    let resp = redirecting.call(Request::new("GET", "/api/users"));

    assert_eq!(resp.status, StatusCode::MOVED_PERMANENTLY);
    let location = resp.headers.get("location").map(String::as_str);
    assert_eq!(location, Some("/api/users/"));
}
/// A `//host/` path under `RedirectTrim` is rejected with 400 and no
/// `location` header instead of being redirected.
#[test]
fn normalize_path_redirect_trim_prevents_protocol_relative_open_redirect() {
    let redirecting =
        NormalizePathMiddleware::new(FnHandler::new(ok_handler), TrailingSlash::RedirectTrim);

    let rejected = redirecting.call(Request::new("GET", "//evil.com/"));

    assert_eq!(rejected.status, StatusCode::BAD_REQUEST);
    assert!(rejected.headers.get("location").is_none());
}
/// A `//host` path under `RedirectAlways` is rejected with 400 and no
/// `location` header instead of being redirected.
#[test]
fn normalize_path_redirect_always_prevents_protocol_relative_open_redirect() {
    let redirecting =
        NormalizePathMiddleware::new(FnHandler::new(ok_handler), TrailingSlash::RedirectAlways);

    let rejected = redirecting.call(Request::new("GET", "//evil"));

    assert_eq!(rejected.status, StatusCode::BAD_REQUEST);
    assert!(rejected.headers.get("location").is_none());
}
/// Several multi-slash path shapes must be answered with 400 and no
/// `location` header in both redirect modes.
#[test]
fn normalize_path_redirect_handles_complex_protocol_relative_attacks() {
    /// Shared expectation: request refused, no redirect target emitted.
    fn assert_rejected(resp: &Response) {
        assert_eq!(resp.status, StatusCode::BAD_REQUEST);
        assert!(!resp.headers.contains_key("location"));
    }

    let trim_mode =
        NormalizePathMiddleware::new(FnHandler::new(ok_handler), TrailingSlash::RedirectTrim);
    // Double slash followed by a host-like segment and a further path.
    assert_rejected(&trim_mode.call(Request::new("GET", "//attacker-host/path/")));
    // Three leading slashes.
    assert_rejected(&trim_mode.call(Request::new("GET", "///triple-slash.example/")));

    let always_mode =
        NormalizePathMiddleware::new(FnHandler::new(ok_handler), TrailingSlash::RedirectAlways);
    assert_rejected(&always_mode.call(Request::new("GET", "//evilhost")));
}
/// A path starting with `/\\` under `RedirectAlways` is refused with 400
/// and produces no `location` header.
#[test]
fn normalize_path_redirect_rejects_backslash_host_pivot() {
    let redirecting =
        NormalizePathMiddleware::new(FnHandler::new(ok_handler), TrailingSlash::RedirectAlways);

    let rejected = redirecting.call(Request::new("GET", "/\\\\attacker.com"));

    assert_eq!(rejected.status, StatusCode::BAD_REQUEST);
    assert!(rejected.headers.get("location").is_none());
}
/// A path whose second character is a percent-encoded slash (`%2f`) is
/// refused with 400 under `RedirectTrim` and produces no `location` header.
#[test]
fn normalize_path_redirect_rejects_percent_encoded_slash_host_pivot() {
    let redirecting =
        NormalizePathMiddleware::new(FnHandler::new(ok_handler), TrailingSlash::RedirectTrim);

    let rejected = redirecting.call(Request::new("GET", "/%2fevil.example/"));

    assert_eq!(rejected.status, StatusCode::BAD_REQUEST);
    assert!(rejected.headers.get("location").is_none());
}
/// `SetResponseHeaderMiddleware::always` replaces a value the handler
/// already set for the same header.
#[test]
fn set_header_always_overwrites() {
    /// Handler that pre-populates `x-custom`.
    fn respond_with_custom_header() -> Response {
        let base = Response::new(StatusCode::OK, b"ok".to_vec());
        base.header("x-custom", "original")
    }

    let overriding = SetResponseHeaderMiddleware::always(
        FnHandler::new(respond_with_custom_header),
        "x-custom",
        "overwritten",
    );
    let resp = overriding.call(make_request());

    let value = resp.headers.get("x-custom").map(String::as_str);
    assert_eq!(value, Some("overwritten"));
}
/// `SetResponseHeaderMiddleware::if_missing` keeps the handler's own value
/// when the header is already present.
#[test]
fn set_header_if_missing_preserves_existing() {
    /// Handler that pre-populates `x-custom`.
    fn respond_with_custom_header() -> Response {
        let base = Response::new(StatusCode::OK, b"ok".to_vec());
        base.header("x-custom", "original")
    }

    let defaulting = SetResponseHeaderMiddleware::if_missing(
        FnHandler::new(respond_with_custom_header),
        "x-custom",
        "default",
    );
    let resp = defaulting.call(make_request());

    let value = resp.headers.get("x-custom").map(String::as_str);
    assert_eq!(value, Some("original"));
}
/// `SetResponseHeaderMiddleware::if_missing` adds the header when the
/// handler did not set it.
#[test]
fn set_header_if_missing_adds_when_absent() {
    let defaulting = SetResponseHeaderMiddleware::if_missing(
        FnHandler::new(ok_handler),
        "x-content-type-options",
        "nosniff",
    );

    let resp = defaulting.call(make_request());

    let value = resp
        .headers
        .get("x-content-type-options")
        .map(String::as_str);
    assert_eq!(value, Some("nosniff"));
}
/// Configuring the middleware with `X-Custom` must match the handler's
/// lowercase `x-custom` and must not add a second, mixed-case key.
#[test]
fn set_header_if_missing_normalizes_mixed_case_name() {
    /// Handler that pre-populates the lowercase header.
    fn respond_with_custom_header() -> Response {
        let base = Response::new(StatusCode::OK, b"ok".to_vec());
        base.header("x-custom", "original")
    }

    let defaulting = SetResponseHeaderMiddleware::if_missing(
        FnHandler::new(respond_with_custom_header),
        "X-Custom",
        "new",
    );
    let resp = defaulting.call(make_request());

    let value = resp.headers.get("x-custom").map(String::as_str);
    assert_eq!(value, Some("original"));
    assert!(resp.headers.get("X-Custom").is_none());
}
/// When the handler inserts `X-Custom` directly into the header map, the
/// if-missing middleware configured with `x-custom` must keep the original
/// value under a single lowercase key.
#[test]
fn set_header_if_missing_respects_mixed_case_existing_header() {
    /// Handler that writes a mixed-case key straight into the header map.
    fn respond_with_mixed_case_header() -> Response {
        let mut out = Response::new(StatusCode::OK, b"ok".to_vec());
        out.headers
            .insert(String::from("X-Custom"), String::from("original"));
        out
    }

    let defaulting = SetResponseHeaderMiddleware::if_missing(
        FnHandler::new(respond_with_mixed_case_header),
        "x-custom",
        "new",
    );
    let resp = defaulting.call(make_request());

    assert_eq!(resp.header_value("x-custom"), Some("original"));
    assert_eq!(
        resp.headers.len(),
        1,
        "if-missing should not create a duplicate logical header"
    );
    let lowercase = resp.headers.get("x-custom").map(String::as_str);
    assert_eq!(lowercase, Some("original"));
    assert!(resp.headers.get("X-Custom").is_none());
}
/// A stack built with a 1024-byte body limit still serves the default test
/// request with `200 OK`.
#[test]
fn middleware_stack_with_body_limit() {
    let builder = MiddlewareStack::new(FnHandler::new(ok_handler)).with_body_limit(1024);
    let stack = builder.build();

    let resp = stack.call(make_request());

    assert_eq!(resp.status, StatusCode::OK);
}
/// A stack built with `with_request_id("x-request-id")` puts that header on
/// the response.
#[test]
fn middleware_stack_with_request_id() {
    let builder =
        MiddlewareStack::new(FnHandler::new(ok_handler)).with_request_id("x-request-id");
    let stack = builder.build();

    let resp = stack.call(make_request());

    assert!(resp.headers.get("x-request-id").is_some());
}
/// A stack built with the default trace policy emits `x-response-time-ms`
/// and echoes the incoming `x-request-id` as `x-trace-id`.
#[test]
fn middleware_stack_with_request_trace() {
    let builder = MiddlewareStack::new(FnHandler::new(ok_handler))
        .with_request_trace(RequestTracePolicy::default());
    let stack = builder.build();

    let incoming = make_request().with_header("x-request-id", "trace-55");
    let resp = stack.call(incoming);

    assert_eq!(resp.status, StatusCode::OK);
    assert!(resp.headers.contains_key("x-response-time-ms"));
    let trace_id = resp.headers.get("x-trace-id").map(String::as_str);
    assert_eq!(trace_id, Some("trace-55"));
}
/// A stack built with `with_catch_panic` turns a panicking handler into a
/// 500 response.
#[test]
fn middleware_stack_with_catch_panic() {
    let builder = MiddlewareStack::new(PanicHandler).with_catch_panic();
    let stack = builder.build();

    let resp = stack.call(make_request());

    assert_eq!(resp.status, StatusCode::INTERNAL_SERVER_ERROR);
}
/// Builds a stack with every middleware layer enabled and checks that a
/// single GET with an `Origin` header succeeds and carries the headers
/// contributed by the request-id, trace, CORS and response-header layers.
#[test]
fn middleware_stack_full_production_composition() {
    let rate_limit = RateLimitPolicy {
        rate: 100,
        burst: 50,
        ..Default::default()
    };

    // Layer order is part of the scenario; keep the builder calls in sequence.
    let stack = MiddlewareStack::new(FnHandler::new(ok_handler))
        .with_catch_panic()
        .with_body_limit(10 * 1024 * 1024)
        .with_request_id("x-request-id")
        .with_request_trace(RequestTracePolicy::default())
        .with_normalize_path(TrailingSlash::Trim)
        .with_timeout(Duration::from_secs(30))
        .with_cors(CorsPolicy::default())
        .with_rate_limit(rate_limit)
        .with_response_header(
            "x-content-type-options",
            "nosniff",
            HeaderOverwrite::IfMissing,
        )
        .build();

    let incoming = Request::new("GET", "/api/test/").with_header("Origin", "https://example.com");
    let resp = stack.call(incoming);

    assert_eq!(resp.status, StatusCode::OK);
    for present in [
        "x-request-id",
        "x-response-time-ms",
        "access-control-allow-origin",
    ] {
        assert!(resp.headers.contains_key(present));
    }
    let nosniff = resp
        .headers
        .get("x-content-type-options")
        .map(String::as_str);
    assert_eq!(nosniff, Some("nosniff"));
}
/// Audit-style tests for `RateLimitMiddleware` rejections: the 429 status,
/// the `Retry-After` header, and the response body text.
///
/// Spec pointers used in the assertion messages: the 429 status code is
/// defined in RFC 6585 §4 (which says the response MAY carry `Retry-After`);
/// the `Retry-After` field and its `delay-seconds` form are defined in
/// RFC 9110 §10.2.3.
mod rate_limiting_compliance_audit {
    use super::*;

    /// With `rate = 1, burst = 1`, the first request passes and the second is
    /// answered with 429.
    #[test]
    fn audit_rate_limit_returns_429_too_many_requests() {
        let policy = RateLimitPolicy {
            rate: 1,
            burst: 1,
            period: Duration::from_secs(60),
            ..Default::default()
        };
        let mw = RateLimitMiddleware::new(FnHandler::new(ok_handler), policy);

        let first = mw.call(make_request());
        assert_eq!(first.status, StatusCode::OK, "First request should succeed");

        let second = mw.call(make_request());
        assert_eq!(
            second.status,
            StatusCode::TOO_MANY_REQUESTS,
            "Rate limited request must return 429 Too Many Requests per RFC 6585 §4"
        );
    }

    /// A 429 from the middleware carries a numeric `Retry-After` between 1 and
    /// 60 seconds.
    #[test]
    fn audit_retry_after_header_compliance() {
        let policy = RateLimitPolicy {
            rate: 2,
            burst: 1,
            period: Duration::from_secs(30),
            ..Default::default()
        };
        let mw = RateLimitMiddleware::new(FnHandler::new(ok_handler), policy);

        // The first call is made only to use up the burst allowance.
        let _ = mw.call(make_request());
        let rate_limited = mw.call(make_request());

        assert_eq!(rate_limited.status, StatusCode::TOO_MANY_REQUESTS);
        assert!(
            rate_limited.headers.contains_key("retry-after"),
            "429 response should include a Retry-After header (RFC 6585 §4, RFC 9110 §10.2.3)"
        );
        let retry_after = rate_limited.headers.get("retry-after").unwrap();
        let seconds: u64 = retry_after
            .parse()
            .expect("Retry-After should be numeric seconds");
        assert!(
            (1..=60).contains(&seconds),
            "Retry-After should specify reasonable delay: {} seconds",
            seconds
        );
    }

    /// A rate-limit rejection is reported as 429 and includes `Retry-After`.
    #[test]
    fn audit_rate_limit_vs_queue_exhaustion_status_codes() {
        let policy = RateLimitPolicy {
            rate: 1,
            burst: 1,
            period: Duration::from_secs(60),
            ..Default::default()
        };
        let mw = RateLimitMiddleware::new(FnHandler::new(ok_handler), policy);

        let _ = mw.call(make_request());
        let rate_limited = mw.call(make_request());

        assert_eq!(
            rate_limited.status,
            StatusCode::TOO_MANY_REQUESTS,
            "Rate limit exceeded should return 429"
        );
        assert!(
            rate_limited.headers.contains_key("retry-after"),
            "Rate limit 429 should include Retry-After"
        );
    }

    /// The 429 body is UTF-8 and mentions the error type, retry guidance, and
    /// rate limiting as the cause.
    #[test]
    fn audit_rate_limit_error_message_format() {
        let policy = RateLimitPolicy {
            rate: 1,
            burst: 1,
            period: Duration::from_secs(45),
            ..Default::default()
        };
        let mw = RateLimitMiddleware::new(FnHandler::new(ok_handler), policy);

        let _ = mw.call(make_request());
        let rate_limited = mw.call(make_request());

        let body = String::from_utf8(rate_limited.body.to_vec())
            .expect("Response body should be UTF-8");
        assert!(
            body.contains("Too Many Requests"),
            "Error message should identify the 429 error type: {}",
            body
        );
        assert!(
            body.contains("retry after") || body.contains("Retry-After"),
            "Error message should provide retry guidance: {}",
            body
        );
        assert!(
            body.contains("rate limit"),
            "Error message should identify rate limiting as the cause: {}",
            body
        );
    }

    /// Drives the middleware with a thread-local fake clock and pins the
    /// `Retry-After` values reported at two instants.
    #[test]
    fn audit_retry_after_calculation_accuracy() {
        thread_local! {
            // Fake clock in milliseconds, read through `test_time`.
            static TEST_TIME: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
        }
        fn set_test_time(ms: u64) {
            TEST_TIME.with(|t| t.set(ms));
        }
        fn test_time() -> Time {
            Time::from_millis(TEST_TIME.with(std::cell::Cell::get))
        }

        let policy = RateLimitPolicy {
            rate: 1,
            burst: 1,
            period: Duration::from_secs(120),
            ..Default::default()
        };
        let mw = RateLimitMiddleware::with_time_getter(
            FnHandler::new(ok_handler),
            policy,
            test_time,
        );

        set_test_time(10_000);
        let _ = mw.call(make_request());
        let rate_limited = mw.call(make_request());
        assert_eq!(rate_limited.status, StatusCode::TOO_MANY_REQUESTS);
        let retry_after = rate_limited.headers.get("retry-after").unwrap();
        assert_eq!(
            retry_after, "120",
            "Retry-After should match rate limiter refill period"
        );

        // NOTE(review): the clock moves forward by 60 s of a 120 s period, yet
        // the pinned value is "70" rather than "60" — confirm against the
        // limiter's refill arithmetic, which is not visible from this test.
        set_test_time(70_000);
        let still_limited = mw.call(make_request());
        let updated_retry_after = still_limited.headers.get("retry-after").unwrap();
        assert_eq!(
            updated_retry_after, "70",
            "Retry-After should decrease as time progresses"
        );
    }

    /// `Retry-After` must parse as an unsigned integer (the `delay-seconds`
    /// form) and fall within `1..=300` for a 300-second period.
    #[test]
    fn audit_retry_after_format_rfc9110_compliance() {
        let policy = RateLimitPolicy {
            rate: 1,
            burst: 1,
            period: Duration::from_secs(300),
            ..Default::default()
        };
        let mw = RateLimitMiddleware::new(FnHandler::new(ok_handler), policy);

        let _ = mw.call(make_request());
        let rate_limited = mw.call(make_request());

        let retry_after = rate_limited.headers.get("retry-after").unwrap();
        let seconds: Result<u64, _> = retry_after.parse();
        assert!(
            seconds.is_ok(),
            "Retry-After must be valid delay-seconds format per RFC 9110, got: {}",
            retry_after
        );
        let seconds = seconds.unwrap();
        assert!(
            seconds <= 300,
            "Retry-After should not exceed rate limit period"
        );
        assert!(
            seconds >= 1,
            "Retry-After should be at least 1 second (minimum meaningful delay)"
        );
    }

    /// With `burst = 0`, the very first request is rejected with 429 and a
    /// `Retry-After` header.
    #[test]
    fn audit_edge_case_zero_burst_handling() {
        let policy = RateLimitPolicy {
            rate: 1,
            burst: 0,
            period: Duration::from_secs(60),
            ..Default::default()
        };
        let mw = RateLimitMiddleware::new(FnHandler::new(ok_handler), policy);

        let resp = mw.call(make_request());

        assert_eq!(
            resp.status,
            StatusCode::TOO_MANY_REQUESTS,
            "Zero burst should rate limit immediately"
        );
        assert!(
            resp.headers.contains_key("retry-after"),
            "Zero burst 429 should include Retry-After"
        );
    }
}
}