use std::collections::HashSet;
use std::future::Future;
use std::ops::ControlFlow as StdControlFlow;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Instant;
use crate::context::RequestContext;
use crate::dependency::DependencyOverrides;
use crate::logging::{LogConfig, RequestLogger};
use crate::request::{Body, Request};
use crate::response::Response;
/// Boxed, pinned, `Send` future that may borrow data for the lifetime `'a`.
///
/// Returned by the [`Middleware`] hooks and [`Handler::call`] so both traits can be used
/// behind pointers such as `Arc<dyn Middleware>`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Outcome of a [`Middleware::before`] hook.
///
/// `Continue` lets the request proceed to the next middleware (or the handler); `Break`
/// stops processing and uses the carried [`Response`] instead.
#[derive(Debug)]
pub enum ControlFlow {
    /// Keep processing the request.
    Continue,
    /// Short-circuit with this response.
    Break(Response),
}

impl ControlFlow {
    /// Returns `true` for [`ControlFlow::Continue`].
    #[must_use]
    pub fn is_continue(&self) -> bool {
        // Only two variants exist, so "not a break" is exactly "continue".
        !self.is_break()
    }

    /// Returns `true` for [`ControlFlow::Break`].
    #[must_use]
    pub fn is_break(&self) -> bool {
        match self {
            Self::Break(_) => true,
            Self::Continue => false,
        }
    }
}

/// Converts into the standard library's `ControlFlow`, carrying the short-circuit response
/// as the `Break` payload.
impl From<ControlFlow> for StdControlFlow<Response, ()> {
    fn from(flow: ControlFlow) -> Self {
        match flow {
            ControlFlow::Break(response) => Self::Break(response),
            ControlFlow::Continue => Self::Continue(()),
        }
    }
}
/// A request/response interceptor that runs around a [`Handler`].
///
/// Both hooks have pass-through defaults, so implementors override only what they need.
pub trait Middleware: Send + Sync {
    /// Runs before the handler. Returning [`ControlFlow::Break`] short-circuits the request
    /// with the carried response. The default always continues.
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        Box::pin(async { ControlFlow::Continue })
    }
    /// Runs on the way out and may transform the response. The default returns it unchanged.
    ///
    /// On a short-circuit, `MiddlewareStack::execute` calls `after` only for middleware that
    /// ran before the breaking one, while `Layered` also calls the breaking middleware's own
    /// `after`.
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async move { response })
    }
    /// Name used for diagnostics; defaults to the implementing type's full path as reported
    /// by `std::any::type_name`.
    fn name(&self) -> &'static str {
        std::any::type_name::<Self>()
    }
}
/// An endpoint that turns a request into a response.
pub trait Handler: Send + Sync {
    /// Handles the request. The returned future may borrow `self`, `ctx` and `req` for `'a`.
    fn call<'a>(&'a self, ctx: &'a RequestContext, req: &'a mut Request)
    -> BoxFuture<'a, Response>;
    /// Optional per-handler dependency overrides. Returns `None` by default.
    fn dependency_overrides(&self) -> Option<Arc<DependencyOverrides>> {
        None
    }
}
/// Lets plain closures and functions act as handlers.
///
/// The returned future must be `'static`, so it cannot borrow `ctx` or `req`; copy what it
/// needs out of them before the `async` block.
impl<F, Fut> Handler for F
where
    F: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync,
    Fut: Future<Output = Response> + Send + 'static,
{
    fn call<'a>(
        &'a self,
        ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, Response> {
        Box::pin(self(ctx, req))
    }
}
/// Shared handlers (including `Arc<dyn Handler>`) delegate to the pointee.
impl<H: Handler + ?Sized> Handler for Arc<H> {
    fn call<'a>(
        &'a self,
        ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, Response> {
        // `&Arc<H>` deref-coerces to `&H` at the argument position.
        H::call(self, ctx, req)
    }

    fn dependency_overrides(&self) -> Option<Arc<DependencyOverrides>> {
        H::dependency_overrides(self)
    }
}
/// An ordered list of middleware executed around a handler.
///
/// `before` hooks run in insertion order. `after` hooks run in reverse order, and only for
/// the middleware whose `before` returned [`ControlFlow::Continue`].
#[derive(Default)]
pub struct MiddlewareStack {
    middleware: Vec<Arc<dyn Middleware>>,
}

impl MiddlewareStack {
    /// Creates an empty stack.
    #[must_use]
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an empty stack with room for `capacity` middleware.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            middleware: Vec::with_capacity(capacity),
        }
    }

    /// Appends `middleware` to the end of the stack.
    pub fn push<M: Middleware + 'static>(&mut self, middleware: M) {
        self.push_arc(Arc::new(middleware));
    }

    /// Appends an already shared middleware to the end of the stack.
    pub fn push_arc(&mut self, middleware: Arc<dyn Middleware>) {
        self.middleware.push(middleware);
    }

    /// Number of middleware in the stack.
    #[must_use]
    pub fn len(&self) -> usize {
        self.middleware.len()
    }

    /// Whether the stack holds no middleware.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.middleware.is_empty()
    }

    /// Runs every `before` hook, then `handler`, then the `after` hooks in reverse.
    ///
    /// If the middleware at position `i` breaks, the handler is skipped. The break response
    /// then flows through the `after` hooks of middleware `0..i` only; the breaking
    /// middleware's own `after` is not called.
    pub async fn execute<H: Handler>(
        &self,
        handler: &H,
        ctx: &RequestContext,
        req: &mut Request,
    ) -> Response {
        for (entered, mw) in self.middleware.iter().enumerate() {
            // Result intentionally ignored, as everywhere in this module; presumably the
            // checkpoint only observes cancellation (NOTE(review): confirm).
            let _ = ctx.checkpoint();
            if let ControlFlow::Break(early) = mw.before(ctx, req).await {
                return self.run_after_hooks(ctx, req, early, entered).await;
            }
        }
        let _ = ctx.checkpoint();
        let response = handler.call(ctx, req).await;
        self.run_after_hooks(ctx, req, response, self.middleware.len())
            .await
    }

    /// Feeds `response` through the `after` hooks of the first `count` middleware, last first.
    async fn run_after_hooks(
        &self,
        ctx: &RequestContext,
        req: &Request,
        mut response: Response,
        count: usize,
    ) -> Response {
        let entered = &self.middleware[..count];
        for mw in entered.iter().rev() {
            let _ = ctx.checkpoint();
            response = mw.after(ctx, req, response).await;
        }
        response
    }
}
/// Factory that wraps handlers with a clone of one middleware, producing [`Layered`].
pub struct Layer<M> {
    middleware: M,
}

impl<M: Middleware + Clone> Layer<M> {
    /// Creates a layer around `middleware`.
    pub fn new(middleware: M) -> Self {
        Self { middleware }
    }

    /// Wraps `handler` so that it runs inside a fresh clone of this layer's middleware.
    pub fn wrap<H: Handler>(&self, handler: H) -> Layered<M, H> {
        let middleware = self.middleware.clone();
        Layered {
            middleware,
            inner: handler,
        }
    }
}

/// A handler decorated with a single middleware; see [`Layer::wrap`].
pub struct Layered<M, H> {
    middleware: M,
    inner: H,
}
/// Runs `middleware.before`, then (unless it broke) the inner handler, and then
/// `middleware.after`.
///
/// Unlike `MiddlewareStack`, the middleware's `after` also runs on the response it produced
/// by breaking.
impl<M: Middleware, H: Handler> Handler for Layered<M, H> {
    fn call<'a>(
        &'a self,
        ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async move {
            let _ = ctx.checkpoint();
            let response = match self.middleware.before(ctx, req).await {
                ControlFlow::Continue => {
                    let _ = ctx.checkpoint();
                    self.inner.call(ctx, req).await
                }
                ControlFlow::Break(early) => early,
            };
            // Both paths pass one checkpoint before the `after` hook.
            let _ = ctx.checkpoint();
            self.middleware.after(ctx, req, response).await
        })
    }
}
/// Middleware that does nothing. It keeps the pass-through default `before`/`after` hooks
/// and only overrides `name`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopMiddleware;
impl Middleware for NoopMiddleware {
    fn name(&self) -> &'static str {
        "Noop"
    }
}
/// Middleware that adds a fixed header to every response.
#[derive(Debug, Clone)]
pub struct AddResponseHeader {
    name: String,
    value: Vec<u8>,
}

impl AddResponseHeader {
    /// Creates the middleware for header `name` with the raw `value` bytes.
    pub fn new(name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
        let name = name.into();
        let value = value.into();
        Self { name, value }
    }
}

impl Middleware for AddResponseHeader {
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async move { response.header(self.name.clone(), self.value.clone()) })
    }

    fn name(&self) -> &'static str {
        "AddResponseHeader"
    }
}
/// Middleware that rejects requests lacking a given header with `400 Bad Request`.
#[derive(Debug, Clone)]
pub struct RequireHeader {
    name: String,
}

impl RequireHeader {
    /// Requires the header called `name`.
    pub fn new(name: impl Into<String>) -> Self {
        Self { name: name.into() }
    }
}

impl Middleware for RequireHeader {
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        let present = req.headers().get(&self.name).is_some();
        Box::pin(async move {
            if present {
                return ControlFlow::Continue;
            }
            // Plain-text explanation naming the missing header.
            let message = format!("Missing required header: {}", self.name);
            let rejection = Response::with_status(crate::response::StatusCode::BAD_REQUEST)
                .header("content-type", b"text/plain".to_vec())
                .body(crate::response::ResponseBody::Bytes(message.into_bytes()));
            ControlFlow::Break(rejection)
        })
    }

    fn name(&self) -> &'static str {
        "RequireHeader"
    }
}
/// Middleware that answers `404 Not Found` for paths outside a prefix.
///
/// This is a plain string-prefix test, so a prefix of `/api` also admits `/apiary`.
#[derive(Debug, Clone)]
pub struct PathPrefixFilter {
    prefix: String,
}

impl PathPrefixFilter {
    /// Admits only request paths starting with `prefix`.
    pub fn new(prefix: impl Into<String>) -> Self {
        let prefix = prefix.into();
        Self { prefix }
    }
}

impl Middleware for PathPrefixFilter {
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        let outside_prefix = !req.path().starts_with(self.prefix.as_str());
        Box::pin(async move {
            if outside_prefix {
                return ControlFlow::Break(Response::with_status(
                    crate::response::StatusCode::NOT_FOUND,
                ));
            }
            ControlFlow::Continue
        })
    }

    fn name(&self) -> &'static str {
        "PathPrefixFilter"
    }
}
/// Middleware that replaces the response status based on a predicate over the request.
///
/// `condition` is evaluated against the request when the response passes through the
/// `after` hook. `status_if_true` is used when it returns `true`, otherwise
/// `status_if_false`.
#[derive(Debug, Clone)]
pub struct ConditionalStatus<F>
where
    F: Fn(&Request) -> bool + Send + Sync,
{
    /// Predicate evaluated against the request.
    condition: F,
    /// Status used when `condition` returns `true`.
    status_if_true: crate::response::StatusCode,
    /// Status used when `condition` returns `false`.
    status_if_false: crate::response::StatusCode,
}
impl<F> ConditionalStatus<F>
where
    F: Fn(&Request) -> bool + Send + Sync,
{
    /// Creates the middleware from a predicate and the two candidate statuses.
    pub fn new(
        condition: F,
        status_if_true: crate::response::StatusCode,
        status_if_false: crate::response::StatusCode,
    ) -> Self {
        Self {
            condition,
            status_if_true,
            status_if_false,
        }
    }
}
impl<F> Middleware for ConditionalStatus<F>
where
    F: Fn(&Request) -> bool + Send + Sync,
{
    /// Rebuilds the response with `status_if_true` or `status_if_false`, chosen by
    /// `condition(req)`. Every header and the buffered body are kept.
    ///
    /// The body is copied through `From<&ResponseBody>`, which turns a streaming body into
    /// an empty one.
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        let status = if (self.condition)(req) {
            self.status_if_true
        } else {
            self.status_if_false
        };
        Box::pin(async move {
            // Carry over every header set by the handler and inner middleware (e.g.
            // `content-type`); only the status is supposed to change here.
            let mut rebuilt = Response::with_status(status);
            for (name, value) in response.headers() {
                rebuilt = rebuilt.header(name.clone(), value.clone());
            }
            rebuilt.body(response.body_ref().into())
        })
    }
    fn name(&self) -> &'static str {
        "ConditionalStatus"
    }
}
/// One rule for deciding whether a CORS `Origin` is allowed.
#[derive(Debug, Clone)]
pub enum OriginPattern {
    /// Every origin matches.
    Any,
    /// Case-sensitive exact string comparison.
    Exact(String),
    /// Shell-style glob in which `*` stands for any run of characters (see `wildcard_match`).
    Wildcard(String),
    /// Minimal regex dialect handled by `regex_match`: `.`, `x*`, a leading `^` and a
    /// trailing `$`. It has no escapes or classes, and an unanchored pattern may match
    /// anywhere in the origin.
    Regex(String),
}

impl OriginPattern {
    /// Returns whether `origin` satisfies this rule.
    fn matches(&self, origin: &str) -> bool {
        match self {
            Self::Any => true,
            Self::Exact(expected) => expected.as_str() == origin,
            Self::Wildcard(glob) => wildcard_match(glob, origin),
            Self::Regex(expr) => regex_match(expr, origin),
        }
    }
}
/// Settings used by the [`Cors`] middleware.
///
/// The default configuration allows no origin at all, so cross-origin requests are refused
/// until `allow_any_origin` is set or `origins` is populated.
#[derive(Debug, Clone)]
pub struct CorsConfig {
    /// Accept every origin. Emitted as `*` unless credentials are allowed, in which case
    /// the request's origin is reflected instead.
    allow_any_origin: bool,
    /// Emit `access-control-allow-credentials: true`.
    allow_credentials: bool,
    /// Methods listed in `access-control-allow-methods` on preflight responses.
    allowed_methods: Vec<crate::request::Method>,
    /// Headers for `access-control-allow-headers`; empty means the header is omitted. A `"*"`
    /// entry is emitted as `*`, or, with credentials, echoes the preflight's
    /// `access-control-request-headers`.
    allowed_headers: Vec<String>,
    /// Headers listed in `access-control-expose-headers`; empty means the header is omitted.
    expose_headers: Vec<String>,
    /// Value of `access-control-max-age` (seconds) on preflight responses, if set.
    max_age: Option<u32>,
    /// Origin rules checked when `allow_any_origin` is false.
    origins: Vec<OriginPattern>,
}
impl Default for CorsConfig {
    fn default() -> Self {
        Self {
            allow_any_origin: false,
            allow_credentials: false,
            allowed_methods: vec![
                crate::request::Method::Get,
                crate::request::Method::Post,
                crate::request::Method::Put,
                crate::request::Method::Patch,
                crate::request::Method::Delete,
                crate::request::Method::Options,
                crate::request::Method::Head,
            ],
            allowed_headers: Vec::new(),
            expose_headers: Vec::new(),
            max_age: None,
            origins: Vec::new(),
        }
    }
}
#[derive(Debug, Clone)]
pub struct Cors {
config: CorsConfig,
}
impl Cors {
#[must_use]
pub fn new() -> Self {
Self {
config: CorsConfig::default(),
}
}
#[must_use]
pub fn config(mut self, config: CorsConfig) -> Self {
self.config = config;
self
}
#[must_use]
pub fn allow_any_origin(mut self) -> Self {
self.config.allow_any_origin = true;
self
}
#[must_use]
pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
self.config
.origins
.push(OriginPattern::Exact(origin.into()));
self
}
#[must_use]
pub fn allow_origin_wildcard(mut self, pattern: impl Into<String>) -> Self {
self.config
.origins
.push(OriginPattern::Wildcard(pattern.into()));
self
}
#[must_use]
pub fn allow_origin_regex(mut self, pattern: impl Into<String>) -> Self {
self.config
.origins
.push(OriginPattern::Regex(pattern.into()));
self
}
#[must_use]
pub fn allow_credentials(mut self, allow: bool) -> Self {
self.config.allow_credentials = allow;
self
}
#[must_use]
pub fn allow_methods<I>(mut self, methods: I) -> Self
where
I: IntoIterator<Item = crate::request::Method>,
{
self.config.allowed_methods = methods.into_iter().collect();
self
}
#[must_use]
pub fn allow_headers<I, S>(mut self, headers: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.config.allowed_headers = headers.into_iter().map(Into::into).collect();
self
}
#[must_use]
pub fn expose_headers<I, S>(mut self, headers: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.config.expose_headers = headers.into_iter().map(Into::into).collect();
self
}
#[must_use]
pub fn max_age(mut self, seconds: u32) -> Self {
self.config.max_age = Some(seconds);
self
}
fn is_origin_allowed(&self, origin: &str) -> bool {
if self.config.allow_any_origin {
return true;
}
self.config
.origins
.iter()
.any(|pattern| pattern.matches(origin))
}
fn allow_origin_value(&self, origin: &str) -> Option<String> {
if !self.is_origin_allowed(origin) {
return None;
}
if self.config.allow_any_origin && !self.config.allow_credentials {
Some("*".to_string())
} else {
Some(origin.to_string())
}
}
fn allow_methods_value(&self) -> String {
self.config
.allowed_methods
.iter()
.map(|method| method.as_str())
.collect::<Vec<_>>()
.join(", ")
}
fn allow_headers_value(&self, request: &Request) -> Option<String> {
if self.config.allowed_headers.is_empty() {
return None;
}
if self.config.allowed_headers.iter().any(|h| h == "*") {
if self.config.allow_credentials {
return request
.headers()
.get("access-control-request-headers")
.and_then(|value| std::str::from_utf8(value).ok())
.map(ToString::to_string);
}
return Some("*".to_string());
}
Some(self.config.allowed_headers.join(", "))
}
fn apply_common_headers(&self, mut response: Response, origin: &str) -> Response {
if let Some(allow_origin) = self.allow_origin_value(origin) {
let is_wildcard = allow_origin == "*";
response = response.header("access-control-allow-origin", allow_origin.into_bytes());
if !is_wildcard {
response = response.header("vary", b"Origin".to_vec());
}
if self.config.allow_credentials {
response = response.header("access-control-allow-credentials", b"true".to_vec());
}
if !self.config.expose_headers.is_empty() {
response = response.header(
"access-control-expose-headers",
self.config.expose_headers.join(", ").into_bytes(),
);
}
}
response
}
}
/// Same as [`Cors::new`].
impl Default for Cors {
    fn default() -> Self {
        Self::new()
    }
}
/// Request extension holding the allowed `Origin` of a non-preflight CORS request. Its
/// presence tells the `after` hook to add CORS headers to the response.
#[derive(Debug, Clone)]
struct CorsOrigin(String);
impl Middleware for Cors {
    /// Incoming side of CORS:
    /// - no (UTF-8) `Origin` header: not a CORS request, so continue untouched;
    /// - origin not allowed: reject preflights with 403 and let other requests continue
    ///   without CORS headers;
    /// - allowed preflight (`OPTIONS` with `access-control-request-method`): answer directly
    ///   with a `no_content` response carrying the allow-* headers;
    /// - allowed actual request: remember the origin so `after` can decorate the response.
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        let Some(origin) = req
            .headers()
            .get("origin")
            .and_then(|value| std::str::from_utf8(value).ok())
            .map(ToString::to_string)
        else {
            return Box::pin(async { ControlFlow::Continue });
        };

        let is_preflight = req.method() == crate::request::Method::Options
            && req.headers().get("access-control-request-method").is_some();

        if !self.is_origin_allowed(&origin) {
            return Box::pin(async move {
                if is_preflight {
                    ControlFlow::Break(Response::with_status(
                        crate::response::StatusCode::FORBIDDEN,
                    ))
                } else {
                    ControlFlow::Continue
                }
            });
        }

        if !is_preflight {
            req.insert_extension(CorsOrigin(origin));
            return Box::pin(async { ControlFlow::Continue });
        }

        let mut preflight = self.apply_common_headers(Response::no_content(), &origin);
        preflight = preflight.header(
            "access-control-allow-methods",
            self.allow_methods_value().into_bytes(),
        );
        if let Some(allowed) = self.allow_headers_value(req) {
            preflight = preflight.header("access-control-allow-headers", allowed.into_bytes());
        }
        if let Some(seconds) = self.config.max_age {
            preflight =
                preflight.header("access-control-max-age", seconds.to_string().into_bytes());
        }
        Box::pin(async move { ControlFlow::Break(preflight) })
    }

    /// Adds CORS headers when `before` recorded an allowed origin for this request.
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        let remembered = req
            .get_extension::<CorsOrigin>()
            .map(|stored| stored.0.clone());
        Box::pin(async move {
            match remembered {
                Some(origin) => self.apply_common_headers(response, &origin),
                None => response,
            }
        })
    }

    fn name(&self) -> &'static str {
        "Cors"
    }
}
/// Matches `value` against a shell-style glob `pattern`.
///
/// In the pattern, `*` matches any run of characters, including none. Every other
/// character must match literally, and the comparison is case-sensitive.
///
/// Uses the classic greedy algorithm with single-star backtracking. On a mismatch, or when
/// the pattern runs out while input remains, the most recent `*` absorbs one more byte and
/// matching resumes right after it. Worst case is O(pattern.len() * value.len()) and the
/// function does not allocate.
///
/// Matching on UTF-8 bytes gives the same result as matching on `char`s. The literal
/// segments between `*`s are valid UTF-8, so they can only line up with character
/// boundaries in `value`.
fn wildcard_match(pattern: &str, value: &str) -> bool {
    let pat = pattern.as_bytes();
    let val = value.as_bytes();
    let (mut p, mut v) = (0, 0);
    // (index of the last `*` in `pat`, index in `val` up to which that star has matched)
    let mut last_star: Option<(usize, usize)> = None;
    while v < val.len() {
        match pat.get(p) {
            Some(b'*') => {
                last_star = Some((p, v));
                p += 1;
            }
            Some(&expected) if expected == val[v] => {
                p += 1;
                v += 1;
            }
            _ => match last_star {
                Some((star, matched_to)) => {
                    // Let the last `*` swallow one more byte and retry what follows it.
                    last_star = Some((star, matched_to + 1));
                    p = star + 1;
                    v = matched_to + 1;
                }
                None => return false,
            },
        }
    }
    // The value is exhausted; only trailing stars may remain in the pattern.
    pat[p..].iter().all(|&c| c == b'*')
}
/// Tests `value` against a tiny, byte-oriented regular-expression dialect.
///
/// Supported syntax:
/// - `.` matches any single byte;
/// - `x*` matches zero or more of the preceding byte (or of any byte, for `.*`);
/// - a leading `^` anchors at the start, and a trailing `$` anchors at the end.
///
/// Every other byte, including `\`, `[`, `(`, `+` and `?`, is a literal; there are no
/// escapes. Without `^` the pattern may match starting at any offset, including the end.
fn regex_match(pattern: &str, value: &str) -> bool {
    let pat = pattern.as_bytes();
    let text = value.as_bytes();
    if let Some((b'^', anchored)) = pat.split_first() {
        return regex_match_here(anchored, text);
    }
    (0..=text.len()).any(|start| regex_match_here(pat, &text[start..]))
}

/// Matches `pattern` against a prefix of `text`.
fn regex_match_here(pattern: &[u8], text: &[u8]) -> bool {
    match pattern {
        [] => true,
        [b'$'] => text.is_empty(),
        [repeated, b'*', rest @ ..] => regex_match_star(*repeated, rest, text),
        [expected, rest @ ..] => match text {
            [actual, remaining @ ..] if *expected == b'.' || expected == actual => {
                regex_match_here(rest, remaining)
            }
            _ => false,
        },
    }
}

/// Matches `ch*` followed by `pattern`, trying the shortest run of `ch` first.
fn regex_match_star(ch: u8, pattern: &[u8], text: &[u8]) -> bool {
    let mut consumed = 0;
    while !regex_match_here(pattern, &text[consumed..]) {
        match text.get(consumed) {
            Some(&byte) if ch == b'.' || byte == ch => consumed += 1,
            _ => return false,
        }
    }
    true
}
/// Middleware that logs each request and its response as structured entries.
///
/// Header values whose lowercased names appear in `redact_headers` are logged as
/// `<redacted>`. Body previews are off by default and are limited to `max_body_bytes`.
#[derive(Debug, Clone)]
pub struct RequestResponseLogger {
    /// Configuration handed to `RequestLogger::new` for every entry.
    log_config: LogConfig,
    /// Lowercased header names whose values must never be logged.
    redact_headers: HashSet<String>,
    /// Include request headers in the request entry.
    log_request_headers: bool,
    /// Include response headers in the response entry.
    log_response_headers: bool,
    /// Include a body preview (buffered bodies only).
    log_body: bool,
    /// Maximum number of body bytes shown in a preview; 0 disables previews.
    max_body_bytes: usize,
}
impl Default for RequestResponseLogger {
    fn default() -> Self {
        Self {
            log_config: LogConfig::production(),
            redact_headers: default_redacted_headers(),
            log_request_headers: true,
            log_response_headers: true,
            // Opt-in, presumably because bodies may carry sensitive data.
            log_body: false,
            max_body_bytes: 1024,
        }
    }
}
impl RequestResponseLogger {
    /// Creates a logger with the default settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the logger configuration.
    #[must_use]
    pub fn log_config(self, config: LogConfig) -> Self {
        Self {
            log_config: config,
            ..self
        }
    }

    /// Enables or disables logging of request headers.
    #[must_use]
    pub fn log_request_headers(self, enabled: bool) -> Self {
        Self {
            log_request_headers: enabled,
            ..self
        }
    }

    /// Enables or disables logging of response headers.
    #[must_use]
    pub fn log_response_headers(self, enabled: bool) -> Self {
        Self {
            log_response_headers: enabled,
            ..self
        }
    }

    /// Enables or disables body previews.
    #[must_use]
    pub fn log_body(self, enabled: bool) -> Self {
        Self {
            log_body: enabled,
            ..self
        }
    }

    /// Sets the maximum number of body bytes included in a preview.
    #[must_use]
    pub fn max_body_bytes(self, max: usize) -> Self {
        Self {
            max_body_bytes: max,
            ..self
        }
    }

    /// Adds a header whose value must be redacted. Names are stored lowercased so the
    /// lookup is case-insensitive.
    #[must_use]
    pub fn redact_header(mut self, name: impl Into<String>) -> Self {
        let lowered = name.into().to_ascii_lowercase();
        self.redact_headers.insert(lowered);
        self
    }
}
/// Request extension storing when the logger's `before` hook ran, so `after` can compute
/// the request duration.
#[derive(Debug, Clone)]
struct RequestStart(Instant);
/// Logs a `"request"` entry in `before` and a `"response"` entry in `after`.
impl Middleware for RequestResponseLogger {
    /// Stores the start time as a request extension and logs the incoming request.
    ///
    /// Logging happens synchronously; the returned future is immediately ready with
    /// `Continue`.
    fn before<'a>(
        &'a self,
        ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        let logger = RequestLogger::new(ctx, self.log_config.clone());
        // Read back in `after` to compute the request duration.
        req.insert_extension(RequestStart(Instant::now()));
        let method = req.method();
        let path = req.path();
        let query = req.query();
        // For streaming bodies this is the declared length, or 0 when unknown.
        let body_bytes = body_len(req.body());
        logger.info_with_fields("request", |entry| {
            let mut entry = entry
                .field("method", method)
                .field("path", path)
                .field("body_bytes", body_bytes);
            if let Some(q) = query {
                entry = entry.field("query", q);
            }
            if self.log_request_headers {
                // Values of headers named in `redact_headers` are replaced by `<redacted>`.
                let headers = format_headers(req.headers().iter(), &self.redact_headers);
                entry = entry.field("headers", headers);
            }
            if self.log_body {
                // Only non-empty buffered bodies are previewed, up to `max_body_bytes`.
                if let Some(body) = preview_body(req.body(), self.max_body_bytes) {
                    entry = entry.field("body", body);
                }
            }
            entry
        });
        Box::pin(async { ControlFlow::Continue })
    }
    /// Logs the status, the time elapsed since `before`, the body size and, optionally,
    /// headers and a body preview. The response is returned unchanged.
    ///
    /// The duration is zero when no `RequestStart` extension is present.
    fn after<'a>(
        &'a self,
        ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        let logger = RequestLogger::new(ctx, self.log_config.clone());
        let duration = req
            .get_extension::<RequestStart>()
            .map(|start| start.0.elapsed())
            .unwrap_or_default();
        let status = response.status();
        let body_bytes = response.body_ref().len();
        logger.info_with_fields("response", |entry| {
            let mut entry = entry
                .field("status", status.as_u16())
                .field("duration_us", duration.as_micros())
                .field("body_bytes", body_bytes);
            if self.log_response_headers {
                let headers = format_response_headers(response.headers(), &self.redact_headers);
                entry = entry.field("headers", headers);
            }
            if self.log_body {
                if let Some(body) = preview_response_body(response.body_ref(), self.max_body_bytes)
                {
                    entry = entry.field("body", body);
                }
            }
            entry
        });
        Box::pin(async move { response })
    }
    fn name(&self) -> &'static str {
        "RequestResponseLogger"
    }
}
/// Header names redacted from logs by default. They are credential- and session-bearing
/// headers, stored lowercased.
fn default_redacted_headers() -> HashSet<String> {
    const SENSITIVE: [&str; 4] = ["authorization", "proxy-authorization", "cookie", "set-cookie"];
    SENSITIVE.iter().map(|name| (*name).to_owned()).collect()
}
/// Size of a request body in bytes, for logging.
///
/// Streaming bodies report their declared length, or 0 when none was declared.
fn body_len(body: &Body) -> usize {
    match body {
        Body::Bytes(bytes) => bytes.len(),
        Body::Stream { content_length, .. } => content_length.unwrap_or_default(),
        Body::Empty => 0,
    }
}
/// Text preview of a buffered, non-empty request body, limited to `max_bytes`.
///
/// Returns `None` for empty or streaming bodies, or when `max_bytes` is 0.
fn preview_body(body: &Body, max_bytes: usize) -> Option<String> {
    if max_bytes == 0 {
        return None;
    }
    let bytes = match body {
        Body::Bytes(bytes) => bytes,
        // Streams are not buffered here, so there is nothing to preview.
        Body::Empty | Body::Stream { .. } => return None,
    };
    (!bytes.is_empty()).then(|| format_bytes(bytes, max_bytes))
}
fn preview_response_body(body: &crate::response::ResponseBody, max_bytes: usize) -> Option<String> {
if max_bytes == 0 {
return None;
}
match body {
crate::response::ResponseBody::Empty => None,
crate::response::ResponseBody::Bytes(bytes) => {
if bytes.is_empty() {
None
} else {
Some(format_bytes(bytes, max_bytes))
}
}
crate::response::ResponseBody::Stream(_) => None,
}
}
/// Renders headers as `name=value` pairs separated by `", "`.
///
/// Values of headers whose lowercased name is in `redacted` are shown as `<redacted>`, and
/// non-UTF-8 values as `<binary>`. Header names are printed as given.
fn format_headers<'a>(
    headers: impl Iterator<Item = (&'a str, &'a [u8])>,
    redacted: &HashSet<String>,
) -> String {
    headers
        .map(|(name, value)| {
            let shown = if redacted.contains(&name.to_ascii_lowercase()) {
                "<redacted>"
            } else {
                std::str::from_utf8(value).unwrap_or("<binary>")
            };
            format!("{name}={shown}")
        })
        .collect::<Vec<_>>()
        .join(", ")
}
/// Adapter that renders owned response headers with `format_headers`.
fn format_response_headers(headers: &[(String, Vec<u8>)], redacted: &HashSet<String>) -> String {
    let borrowed = headers
        .iter()
        .map(|(name, value)| (&name[..], &value[..]));
    format_headers(borrowed, redacted)
}
/// Renders at most `max_bytes` of `bytes` as text for a log preview.
///
/// - Valid UTF-8 is returned as-is, with `...` appended when the input was truncated.
/// - If truncation cuts through a multi-byte character, the incomplete tail is dropped
///   instead of misreporting the whole (valid text) body as binary.
/// - Anything else that is not UTF-8 is summarised as `<N bytes binary>`, where N is the
///   full input length.
fn format_bytes(bytes: &[u8], max_bytes: usize) -> String {
    let truncated = bytes.len() > max_bytes;
    let head = &bytes[..max_bytes.min(bytes.len())];
    let text = match std::str::from_utf8(head) {
        Ok(text) => text,
        // `error_len() == None` means the input merely ended mid-character, which is exactly
        // what cutting at `max_bytes` can produce; everything before `valid_up_to()` is valid.
        Err(err) if truncated && err.error_len().is_none() => {
            std::str::from_utf8(&head[..err.valid_up_to()])
                .expect("bytes before valid_up_to() are valid UTF-8")
        }
        Err(_) => return format!("<{} bytes binary>", bytes.len()),
    };
    let mut output = String::with_capacity(text.len() + 3);
    output.push_str(text);
    if truncated {
        output.push_str("...");
    }
    output
}
/// Copies a response body through a shared reference.
///
/// Byte bodies are cloned. Stream bodies are not copied and become `Empty`.
impl From<&crate::response::ResponseBody> for crate::response::ResponseBody {
    fn from(body: &crate::response::ResponseBody) -> Self {
        match body {
            Self::Bytes(bytes) => Self::Bytes(bytes.clone()),
            Self::Empty | Self::Stream(_) => Self::Empty,
        }
    }
}
/// Identifier attached to a request, either supplied by the client or generated.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);

impl RequestId {
    /// Wraps an existing identifier.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        Self(id.into())
    }

    /// Borrows the identifier text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Generates `<micros-since-epoch>-<counter>`, both parts in lowercase hex.
    ///
    /// IDs are unique within a process thanks to the counter. They are predictable, so they
    /// must not be used as secrets.
    #[must_use]
    pub fn generate() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};
        use std::time::{SystemTime, UNIX_EPOCH};
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let micros = match SystemTime::now().duration_since(UNIX_EPOCH) {
            // Narrowing is harmless: 64-bit microseconds cover roughly 584,000 years.
            Ok(elapsed) => elapsed.as_micros() as u64,
            // Clock set before 1970.
            Err(_) => 0,
        };
        let sequence = COUNTER.fetch_add(1, Ordering::Relaxed);
        Self(format!("{micros:x}-{sequence:x}"))
    }
}

impl std::fmt::Display for RequestId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<String> for RequestId {
    fn from(s: String) -> Self {
        Self(s)
    }
}

impl From<&str> for RequestId {
    fn from(s: &str) -> Self {
        Self::new(s)
    }
}
/// Settings for `RequestIdMiddleware`.
#[derive(Debug, Clone)]
pub struct RequestIdConfig {
    /// Header carrying the ID, both on requests and on responses.
    pub header_name: String,
    /// Reuse a well-formed client-supplied ID instead of generating one.
    pub accept_from_client: bool,
    /// Echo the ID back in the response.
    pub add_to_response: bool,
    /// Longest client-supplied ID (in bytes) that is accepted.
    pub max_client_id_length: usize,
}

impl Default for RequestIdConfig {
    fn default() -> Self {
        Self {
            header_name: String::from("x-request-id"),
            accept_from_client: true,
            add_to_response: true,
            max_client_id_length: 128,
        }
    }
}

impl RequestIdConfig {
    /// Same as `RequestIdConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the header name.
    #[must_use]
    pub fn header_name(self, name: impl Into<String>) -> Self {
        Self {
            header_name: name.into(),
            ..self
        }
    }

    /// Sets whether client-supplied IDs are accepted.
    #[must_use]
    pub fn accept_from_client(self, accept: bool) -> Self {
        Self {
            accept_from_client: accept,
            ..self
        }
    }

    /// Sets whether the ID is echoed in responses.
    #[must_use]
    pub fn add_to_response(self, add: bool) -> Self {
        Self {
            add_to_response: add,
            ..self
        }
    }

    /// Sets the maximum accepted client ID length.
    #[must_use]
    pub fn max_client_id_length(self, max: usize) -> Self {
        Self {
            max_client_id_length: max,
            ..self
        }
    }
}
/// Middleware that assigns every request a [`RequestId`] and optionally echoes it back.
#[derive(Debug, Clone)]
pub struct RequestIdMiddleware {
    config: RequestIdConfig,
}

impl Default for RequestIdMiddleware {
    fn default() -> Self {
        Self::new()
    }
}

impl RequestIdMiddleware {
    /// Creates the middleware with the default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(RequestIdConfig::default())
    }

    /// Creates the middleware with a custom configuration.
    #[must_use]
    pub fn with_config(config: RequestIdConfig) -> Self {
        Self { config }
    }

    /// Uses the client's ID when accepted by configuration and well-formed: UTF-8,
    /// non-empty, within `max_client_id_length` and limited to `[A-Za-z0-9._-]`.
    /// Otherwise a fresh ID is generated.
    fn get_or_generate_id(&self, req: &Request) -> RequestId {
        if !self.config.accept_from_client {
            return RequestId::generate();
        }
        let client_id = req
            .headers()
            .get(&self.config.header_name)
            .and_then(|raw| std::str::from_utf8(raw).ok())
            .filter(|id| {
                !id.is_empty()
                    && id.len() <= self.config.max_client_id_length
                    && is_valid_request_id(id)
            });
        match client_id {
            Some(id) => RequestId::new(id),
            None => RequestId::generate(),
        }
    }
}
/// Whether `id` is non-empty and consists only of ASCII letters, digits, `-`, `_` and `.`.
/// This keeps client-supplied IDs safe to echo into headers and logs.
fn is_valid_request_id(id: &str) -> bool {
    // Any non-ASCII character yields bytes >= 0x80, which fail the check below.
    let allowed = |b: u8| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.');
    !id.is_empty() && id.bytes().all(allowed)
}
impl Middleware for RequestIdMiddleware {
fn before<'a>(
&'a self,
_ctx: &'a RequestContext,
req: &'a mut Request,
) -> BoxFuture<'a, ControlFlow> {
let request_id = self.get_or_generate_id(req);
req.insert_extension(request_id);
Box::pin(async { ControlFlow::Continue })
}
fn after<'a>(
&'a self,
_ctx: &'a RequestContext,
req: &'a Request,
response: Response,
) -> BoxFuture<'a, Response> {
if !self.config.add_to_response {
return Box::pin(async move { response });
}
let request_id = req.get_extension::<RequestId>().cloned();
let header_name = self.config.header_name.clone();
Box::pin(async move {
if let Some(id) = request_id {
response.header(header_name, id.0.into_bytes())
} else {
response
}
})
}
fn name(&self) -> &'static str {
"RequestId"
}
}
/// Values for the `X-Frame-Options` response header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum XFrameOptions {
    /// Never allow framing.
    Deny,
    /// Allow framing by same-origin pages only.
    SameOrigin,
}

impl XFrameOptions {
    /// Header value bytes.
    fn as_bytes(self) -> &'static [u8] {
        let text: &'static str = match self {
            Self::Deny => "DENY",
            Self::SameOrigin => "SAMEORIGIN",
        };
        text.as_bytes()
    }
}
/// Values for the `Referrer-Policy` response header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReferrerPolicy {
    NoReferrer,
    NoReferrerWhenDowngrade,
    Origin,
    OriginWhenCrossOrigin,
    SameOrigin,
    StrictOrigin,
    StrictOriginWhenCrossOrigin,
    UnsafeUrl,
}

impl ReferrerPolicy {
    /// Header value bytes (the policy's kebab-case token).
    fn as_bytes(self) -> &'static [u8] {
        let token: &'static str = match self {
            Self::NoReferrer => "no-referrer",
            Self::NoReferrerWhenDowngrade => "no-referrer-when-downgrade",
            Self::Origin => "origin",
            Self::OriginWhenCrossOrigin => "origin-when-cross-origin",
            Self::SameOrigin => "same-origin",
            Self::StrictOrigin => "strict-origin",
            Self::StrictOriginWhenCrossOrigin => "strict-origin-when-cross-origin",
            Self::UnsafeUrl => "unsafe-url",
        };
        token.as_bytes()
    }
}
/// Which security-related response headers `SecurityHeaders` adds. `None` means the
/// header is not sent.
#[derive(Debug, Clone)]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options` value.
    pub x_content_type_options: Option<&'static str>,
    /// `X-Frame-Options` value.
    pub x_frame_options: Option<XFrameOptions>,
    /// `X-XSS-Protection` value.
    pub x_xss_protection: Option<&'static str>,
    /// `Content-Security-Policy` value.
    pub content_security_policy: Option<String>,
    /// `Strict-Transport-Security` as `(max_age_seconds, include_sub_domains, preload)`.
    pub hsts: Option<(u64, bool, bool)>,
    /// `Referrer-Policy` value.
    pub referrer_policy: Option<ReferrerPolicy>,
    /// `Permissions-Policy` value.
    pub permissions_policy: Option<String>,
}
impl Default for SecurityHeadersConfig {
    /// The defaults are `nosniff`, `DENY` framing, `X-XSS-Protection: 0` and a
    /// `strict-origin-when-cross-origin` referrer policy. No CSP, HSTS or Permissions-Policy
    /// is sent.
    fn default() -> Self {
        Self {
            x_content_type_options: Some("nosniff"),
            x_frame_options: Some(XFrameOptions::Deny),
            x_xss_protection: Some("0"),
            content_security_policy: None,
            hsts: None,
            referrer_policy: Some(ReferrerPolicy::StrictOriginWhenCrossOrigin),
            permissions_policy: None,
        }
    }
}
impl SecurityHeadersConfig {
    /// Same as `SecurityHeadersConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sends no security headers at all.
    #[must_use]
    pub fn none() -> Self {
        Self {
            x_content_type_options: None,
            x_frame_options: None,
            x_xss_protection: None,
            content_security_policy: None,
            hsts: None,
            referrer_policy: None,
            permissions_policy: None,
        }
    }

    /// The defaults plus a `default-src 'self'` CSP, one year of HSTS with subdomains,
    /// `no-referrer`, and a Permissions-Policy that disables geolocation, camera and
    /// microphone.
    #[must_use]
    pub fn strict() -> Self {
        Self {
            content_security_policy: Some("default-src 'self'".to_string()),
            hsts: Some((31_536_000, true, false)),
            referrer_policy: Some(ReferrerPolicy::NoReferrer),
            permissions_policy: Some("geolocation=(), camera=(), microphone=()".to_string()),
            ..Self::default()
        }
    }

    /// Sets or clears `X-Content-Type-Options`.
    #[must_use]
    pub fn x_content_type_options(self, value: Option<&'static str>) -> Self {
        Self {
            x_content_type_options: value,
            ..self
        }
    }

    /// Sets or clears `X-Frame-Options`.
    #[must_use]
    pub fn x_frame_options(self, value: Option<XFrameOptions>) -> Self {
        Self {
            x_frame_options: value,
            ..self
        }
    }

    /// Sets or clears `X-XSS-Protection`.
    #[must_use]
    pub fn x_xss_protection(self, value: Option<&'static str>) -> Self {
        Self {
            x_xss_protection: value,
            ..self
        }
    }

    /// Sets `Content-Security-Policy`.
    #[must_use]
    pub fn content_security_policy(self, value: impl Into<String>) -> Self {
        Self {
            content_security_policy: Some(value.into()),
            ..self
        }
    }

    /// Clears `Content-Security-Policy`.
    #[must_use]
    pub fn no_content_security_policy(self) -> Self {
        Self {
            content_security_policy: None,
            ..self
        }
    }

    /// Sets `Strict-Transport-Security`.
    #[must_use]
    pub fn hsts(self, max_age: u64, include_sub_domains: bool, preload: bool) -> Self {
        Self {
            hsts: Some((max_age, include_sub_domains, preload)),
            ..self
        }
    }

    /// Clears `Strict-Transport-Security`.
    #[must_use]
    pub fn no_hsts(self) -> Self {
        Self { hsts: None, ..self }
    }

    /// Sets or clears `Referrer-Policy`.
    #[must_use]
    pub fn referrer_policy(self, value: Option<ReferrerPolicy>) -> Self {
        Self {
            referrer_policy: value,
            ..self
        }
    }

    /// Sets `Permissions-Policy`.
    #[must_use]
    pub fn permissions_policy(self, value: impl Into<String>) -> Self {
        Self {
            permissions_policy: Some(value.into()),
            ..self
        }
    }

    /// Clears `Permissions-Policy`.
    #[must_use]
    pub fn no_permissions_policy(self) -> Self {
        Self {
            permissions_policy: None,
            ..self
        }
    }

    /// Renders the HSTS header value, e.g. `max-age=31536000; includeSubDomains`.
    fn build_hsts_value(&self) -> Option<String> {
        let (max_age, include_sub, preload) = self.hsts?;
        let mut directives = vec![format!("max-age={max_age}")];
        if include_sub {
            directives.push("includeSubDomains".to_string());
        }
        if preload {
            directives.push("preload".to_string());
        }
        Some(directives.join("; "))
    }
}
/// Middleware that adds the headers described by a [`SecurityHeadersConfig`] to every
/// response.
#[derive(Debug, Clone)]
pub struct SecurityHeaders {
    config: SecurityHeadersConfig,
}

impl Default for SecurityHeaders {
    fn default() -> Self {
        Self::new()
    }
}

impl SecurityHeaders {
    /// Uses `SecurityHeadersConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(SecurityHeadersConfig::default())
    }

    /// Uses a custom configuration.
    #[must_use]
    pub fn with_config(config: SecurityHeadersConfig) -> Self {
        Self { config }
    }

    /// Uses `SecurityHeadersConfig::strict()`.
    #[must_use]
    pub fn strict() -> Self {
        Self::with_config(SecurityHeadersConfig::strict())
    }
}
impl Middleware for SecurityHeaders {
    /// Adds each configured header, in a fixed order, to the response.
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        // Borrow instead of cloning the whole config for every response.
        let config = &self.config;
        Box::pin(async move {
            let headers: [(&str, Option<Vec<u8>>); 7] = [
                (
                    "X-Content-Type-Options",
                    config.x_content_type_options.map(|v| v.as_bytes().to_vec()),
                ),
                (
                    "X-Frame-Options",
                    config.x_frame_options.map(|v| v.as_bytes().to_vec()),
                ),
                (
                    "X-XSS-Protection",
                    config.x_xss_protection.map(|v| v.as_bytes().to_vec()),
                ),
                (
                    "Content-Security-Policy",
                    config
                        .content_security_policy
                        .as_ref()
                        .map(|v| v.as_bytes().to_vec()),
                ),
                (
                    "Strict-Transport-Security",
                    config.build_hsts_value().map(String::into_bytes),
                ),
                (
                    "Referrer-Policy",
                    config.referrer_policy.map(|v| v.as_bytes().to_vec()),
                ),
                (
                    "Permissions-Policy",
                    config
                        .permissions_policy
                        .as_ref()
                        .map(|v| v.as_bytes().to_vec()),
                ),
            ];
            let mut resp = response;
            for (name, value) in headers {
                if let Some(value) = value {
                    resp = resp.header(name, value);
                }
            }
            resp
        })
    }

    fn name(&self) -> &'static str {
        "SecurityHeaders"
    }
}
/// An anti-CSRF token. Tokens from [`CsrfToken::generate`] are 64 lowercase hex characters.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CsrfToken(pub String);

impl CsrfToken {
    /// Wraps an existing token string.
    #[must_use]
    pub fn new(token: impl Into<String>) -> Self {
        Self(token.into())
    }

    /// Borrows the token text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Generates a token from 32 bytes read from `/dev/urandom`.
    ///
    /// # Panics
    ///
    /// Panics if `/dev/urandom` cannot be opened or read, for example on platforms that
    /// lack it, because a token without cryptographic entropy would be unsafe.
    #[must_use]
    pub fn generate() -> Self {
        match Self::read_urandom(32) {
            Ok(entropy) => Self(Self::bytes_to_hex(&entropy)),
            Err(_) => panic!(
                "FATAL: Cryptographically secure random source (/dev/urandom) is unavailable. \
                 CSRF token generation requires a CSPRNG. Cannot safely generate CSRF tokens \
                 without cryptographic entropy."
            ),
        }
    }

    /// Reads exactly `len` bytes from the OS CSPRNG device.
    fn read_urandom(len: usize) -> std::io::Result<Vec<u8>> {
        use std::io::Read;
        let mut source = std::fs::File::open("/dev/urandom")?;
        let mut entropy = vec![0u8; len];
        source.read_exact(&mut entropy).map(|()| entropy)
    }

    /// Lowercase hex encoding, two characters per byte.
    fn bytes_to_hex(bytes: &[u8]) -> String {
        const DIGITS: &[u8; 16] = b"0123456789abcdef";
        let mut hex = String::with_capacity(bytes.len() * 2);
        for &byte in bytes {
            hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
            hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        hex
    }
}

impl std::fmt::Display for CsrfToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl From<&str> for CsrfToken {
    fn from(s: &str) -> Self {
        Self::new(s)
    }
}
/// How the CSRF token is expected to be presented. The validation logic lives in
/// `CsrfMiddleware` (NOTE(review): confirm exact semantics there).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CsrfMode {
    /// Presumably: the header token must match the token in the `cookie_name` cookie.
    #[default]
    DoubleSubmit,
    /// Presumably: only the header token is checked.
    HeaderOnly,
}

/// Settings for `CsrfMiddleware`.
#[derive(Debug, Clone)]
pub struct CsrfConfig {
    /// Name of the cookie holding the token.
    pub cookie_name: String,
    /// Name of the request header holding the token.
    pub header_name: String,
    /// Validation mode.
    pub mode: CsrfMode,
    /// Whether to rotate the token (presumably after successful validation; confirm).
    pub rotate_token: bool,
    /// Production-mode flag (its effect is defined elsewhere; confirm).
    pub production: bool,
    /// Custom rejection message; `None` presumably means a built-in default.
    pub error_message: Option<String>,
}

impl Default for CsrfConfig {
    fn default() -> Self {
        Self {
            cookie_name: String::from("csrf_token"),
            header_name: String::from("x-csrf-token"),
            mode: CsrfMode::default(),
            rotate_token: false,
            production: true,
            error_message: None,
        }
    }
}

impl CsrfConfig {
    /// Same as `CsrfConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the cookie name.
    #[must_use]
    pub fn cookie_name(self, name: impl Into<String>) -> Self {
        Self {
            cookie_name: name.into(),
            ..self
        }
    }

    /// Sets the header name.
    #[must_use]
    pub fn header_name(self, name: impl Into<String>) -> Self {
        Self {
            header_name: name.into(),
            ..self
        }
    }

    /// Sets the validation mode.
    #[must_use]
    pub fn mode(self, mode: CsrfMode) -> Self {
        Self { mode, ..self }
    }

    /// Sets whether tokens are rotated.
    #[must_use]
    pub fn rotate_token(self, rotate: bool) -> Self {
        Self {
            rotate_token: rotate,
            ..self
        }
    }

    /// Sets the production-mode flag.
    #[must_use]
    pub fn production(self, production: bool) -> Self {
        Self { production, ..self }
    }

    /// Sets a custom rejection message.
    #[must_use]
    pub fn error_message(self, message: impl Into<String>) -> Self {
        Self {
            error_message: Some(message.into()),
            ..self
        }
    }
}
/// Middleware that issues CSRF tokens on safe requests and validates them on unsafe ones,
/// according to its [`CsrfConfig`].
#[derive(Debug, Clone)]
pub struct CsrfMiddleware {
    config: CsrfConfig,
}
impl Default for CsrfMiddleware {
    /// Delegates to [`CsrfMiddleware::new`].
    fn default() -> Self {
        CsrfMiddleware::new()
    }
}
impl CsrfMiddleware {
#[must_use]
pub fn new() -> Self {
Self {
config: CsrfConfig::default(),
}
}
#[must_use]
pub fn with_config(config: CsrfConfig) -> Self {
Self { config }
}
fn is_safe_method(method: crate::request::Method) -> bool {
matches!(
method,
crate::request::Method::Get
| crate::request::Method::Head
| crate::request::Method::Options
| crate::request::Method::Trace
)
}
fn get_cookie_token(&self, req: &Request) -> Option<String> {
let cookie_header = req.headers().get("cookie")?;
let cookie_str = std::str::from_utf8(cookie_header).ok()?;
for part in cookie_str.split(';') {
let part = part.trim();
if let Some((name, value)) = part.split_once('=') {
if name.trim() == self.config.cookie_name {
return Some(value.trim().to_string());
}
}
}
None
}
fn get_header_token(&self, req: &Request) -> Option<String> {
let header_value = req.headers().get(&self.config.header_name)?;
std::str::from_utf8(header_value)
.ok()
.map(|s| s.trim().to_string())
}
fn validate_token(&self, req: &Request) -> Result<Option<CsrfToken>, Response> {
let header_token = self.get_header_token(req);
match self.config.mode {
CsrfMode::DoubleSubmit => {
let cookie_token = self.get_cookie_token(req);
match (header_token, cookie_token) {
(Some(header), Some(cookie))
if !header.is_empty()
&& crate::password::constant_time_eq(
header.as_bytes(),
cookie.as_bytes(),
) =>
{
Ok(Some(CsrfToken::new(header)))
}
(None, _) | (_, None) => Err(self.csrf_error_response("CSRF token missing")),
_ => Err(self.csrf_error_response("CSRF token mismatch")),
}
}
CsrfMode::HeaderOnly => match header_token {
Some(token) if !token.is_empty() => Ok(Some(CsrfToken::new(token))),
_ => Err(self.csrf_error_response("CSRF token missing in header")),
},
}
}
fn csrf_error_response(&self, default_message: &str) -> Response {
let message = self
.config
.error_message
.as_deref()
.unwrap_or(default_message);
let detail = serde_json::json!({
"detail": [{
"type": "csrf_error",
"loc": ["header", self.config.header_name],
"msg": message,
}]
});
let body = detail.to_string();
Response::with_status(crate::response::StatusCode::FORBIDDEN)
.header("content-type", b"application/json".to_vec())
.body(crate::response::ResponseBody::Bytes(body.into_bytes()))
}
fn make_set_cookie_header_value(cookie_name: &str, token: &str, production: bool) -> Vec<u8> {
let mut cookie = format!("{}={}; Path=/; SameSite=Strict", cookie_name, token);
if production {
cookie.push_str("; Secure");
}
cookie.into_bytes()
}
}
impl Middleware for CsrfMiddleware {
    /// Safe methods: makes a token available to the handler as a [`CsrfToken`] extension.
    /// The client's cookie token is reused unless `rotate_token` is set or the cookie is
    /// empty; otherwise a fresh token is generated.
    ///
    /// Unsafe methods: validates the submitted token and short-circuits with 403 on failure.
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        Box::pin(async move {
            if Self::is_safe_method(req.method()) {
                // Rotation must actually mint a new token. An empty cookie value can never
                // pass double-submit validation (the header must be non-empty and equal), so
                // reusing it would lock the client out permanently.
                let token = match self.get_cookie_token(req) {
                    Some(existing) if !existing.is_empty() && !self.config.rotate_token => {
                        CsrfToken::new(existing)
                    }
                    _ => CsrfToken::generate(),
                };
                req.insert_extension(token);
                ControlFlow::Continue
            } else {
                match self.validate_token(req) {
                    Ok(Some(token)) => {
                        req.insert_extension(token);
                        ControlFlow::Continue
                    }
                    Ok(None) => ControlFlow::Continue,
                    Err(response) => ControlFlow::Break(response),
                }
            }
        })
    }
    /// On safe methods, sends `Set-Cookie` whenever the request's token differs from the
    /// cookie the client sent. This covers first visits, empty cookies and rotated tokens.
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        let is_safe = Self::is_safe_method(req.method());
        let existing_cookie_token = self.get_cookie_token(req);
        let token = req.get_extension::<CsrfToken>().cloned();
        let cookie_name = self.config.cookie_name.clone();
        let production = self.config.production;
        Box::pin(async move {
            if !is_safe {
                return response;
            }
            let Some(token) = token else {
                return response;
            };
            if existing_cookie_token.as_deref() == Some(token.as_str()) {
                // The client already holds exactly this token; nothing to (re)issue.
                return response;
            }
            let cookie_value =
                Self::make_set_cookie_header_value(&cookie_name, token.as_str(), production);
            response.header("set-cookie", cookie_value)
        })
    }
    fn name(&self) -> &'static str {
        "CSRF"
    }
}
/// Tuning knobs for [`CompressionMiddleware`].
#[cfg(feature = "compression")]
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Bodies shorter than this many bytes are sent uncompressed.
    pub min_size: usize,
    /// gzip level passed to the encoder; the `level` builder clamps it to `1..=9`.
    pub level: u32,
    /// Media types never compressed. Entries ending in `/` match every subtype
    /// (e.g. `"video/"`); other entries must equal the media type.
    pub skip_content_types: Vec<&'static str>,
}
#[cfg(feature = "compression")]
impl Default for CompressionConfig {
    /// 1 KiB minimum, level 6, and a skip list of already-compressed formats.
    fn default() -> Self {
        Self {
            min_size: 1024,
            level: 6,
            skip_content_types: vec![
                "image/jpeg",
                "image/png",
                "image/gif",
                "image/webp",
                "image/avif",
                "video/",
                "audio/",
                "application/zip",
                "application/gzip",
                "application/x-gzip",
                "application/x-bzip2",
                "application/x-xz",
                "application/x-7z-compressed",
                "application/x-rar-compressed",
                "application/pdf",
                "application/woff",
                "application/woff2",
                "font/woff",
                "font/woff2",
            ],
        }
    }
}
#[cfg(feature = "compression")]
impl CompressionConfig {
    /// Same as [`CompressionConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
    /// Sets the minimum body size eligible for compression.
    #[must_use]
    pub fn min_size(mut self, size: usize) -> Self {
        self.min_size = size;
        self
    }
    /// Sets the gzip level, clamped to `1..=9`.
    #[must_use]
    pub fn level(mut self, level: u32) -> Self {
        self.level = level.clamp(1, 9);
        self
    }
    /// Adds a media type (or `type/` prefix) to the skip list.
    #[must_use]
    pub fn skip_content_type(mut self, content_type: &'static str) -> Self {
        self.skip_content_types.push(content_type);
        self
    }
    /// Returns `true` when `content_type` is on the skip list.
    ///
    /// Matching is ASCII case-insensitive on both sides. It ignores surrounding whitespace
    /// and compares only the media type before any `;` parameters, so
    /// `"IMAGE/PNG ; x=y"` matches `"image/png"`. A skip entry equal to the whole header
    /// value still matches, too.
    fn should_skip_content_type(&self, content_type: &str) -> bool {
        let full = content_type.trim().to_ascii_lowercase();
        let media_type = full.split(';').next().unwrap_or("").trim_end();
        self.skip_content_types.iter().any(|skip| {
            let skip = skip.trim().to_ascii_lowercase();
            if skip.ends_with('/') {
                media_type.starts_with(skip.as_str())
            } else {
                media_type == skip || full == skip
            }
        })
    }
}
/// Middleware that gzip-compresses eligible buffered response bodies.
#[cfg(feature = "compression")]
#[derive(Debug, Clone)]
pub struct CompressionMiddleware {
    config: CompressionConfig,
}
#[cfg(feature = "compression")]
impl Default for CompressionMiddleware {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(feature = "compression")]
impl CompressionMiddleware {
    /// Builds the middleware with [`CompressionConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: CompressionConfig::default(),
        }
    }
    /// Builds the middleware from an explicit configuration.
    #[must_use]
    pub fn with_config(config: CompressionConfig) -> Self {
        Self { config }
    }
    /// Returns `true` when the request's `Accept-Encoding` permits gzip.
    ///
    /// Honours quality values (RFC 9110): `gzip;q=0` explicitly refuses gzip, and an
    /// explicit `gzip` entry takes precedence over the `*` wildcard. A missing or
    /// unparseable `q` counts as 1.
    fn accepts_gzip(req: &Request) -> bool {
        let Some(raw) = req.headers().get("accept-encoding") else {
            return false;
        };
        let Ok(value) = std::str::from_utf8(raw) else {
            return false;
        };
        let mut gzip_q = None;
        let mut wildcard_q = None;
        for entry in value.split(',') {
            let mut fields = entry.split(';');
            let coding = fields.next().unwrap_or("").trim();
            let quality = Self::quality_value(fields);
            if coding.eq_ignore_ascii_case("gzip") {
                gzip_q = Some(quality);
            } else if coding == "*" {
                wildcard_q = Some(quality);
            }
        }
        match gzip_q {
            Some(q) => q > 0.0,
            None => wildcard_q.is_some_and(|q| q > 0.0),
        }
    }
    /// Extracts the `q` parameter from an encoding's `;`-separated parameters (default 1.0).
    fn quality_value<'p>(params: impl Iterator<Item = &'p str>) -> f32 {
        params
            .filter_map(|param| param.split_once('='))
            .find(|(name, _)| name.trim().eq_ignore_ascii_case("q"))
            .and_then(|(_, value)| value.trim().parse::<f32>().ok())
            .unwrap_or(1.0)
    }
    /// Returns the first `Content-Type` header value, if present and UTF-8.
    fn get_content_type(headers: &[(String, Vec<u8>)]) -> Option<String> {
        for (name, value) in headers {
            if name.eq_ignore_ascii_case("content-type") {
                return std::str::from_utf8(value).ok().map(String::from);
            }
        }
        None
    }
    /// Returns `true` when the response already declares a `Content-Encoding`.
    fn has_content_encoding(headers: &[(String, Vec<u8>)]) -> bool {
        headers
            .iter()
            .any(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
    }
    /// Compresses `data` into a complete gzip stream at the given level.
    fn compress_gzip(data: &[u8], level: u32) -> Result<Vec<u8>, std::io::Error> {
        use flate2::Compression;
        use flate2::write::GzEncoder;
        use std::io::Write;
        let mut encoder = GzEncoder::new(Vec::new(), Compression::new(level));
        encoder.write_all(data)?;
        encoder.finish()
    }
}
#[cfg(feature = "compression")]
impl Middleware for CompressionMiddleware {
    /// Gzip-encodes the response body when the client accepts gzip.
    ///
    /// The response passes through with headers and body preserved when any of these hold:
    /// - it already has `Content-Encoding`;
    /// - its body is not buffered bytes;
    /// - it is below `min_size` or has a skipped content type;
    /// - compression fails or would not shrink it.
    ///
    /// Otherwise the gzip body is sent without `Content-Length`, with
    /// `Content-Encoding: gzip` and `Vary: Accept-Encoding` appended.
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        let config = self.config.clone();
        Box::pin(async move {
            if !Self::accepts_gzip(req) {
                return response;
            }
            let (status, headers, body) = response.into_parts();
            // `Ok` carries the gzip bytes to send; `Err` hands back the body to pass through.
            let outcome = if Self::has_content_encoding(&headers) {
                Err(body)
            } else {
                match body {
                    crate::response::ResponseBody::Bytes(raw) => {
                        let eligible = raw.len() >= config.min_size
                            && !Self::get_content_type(&headers)
                                .is_some_and(|ct| config.should_skip_content_type(&ct));
                        if eligible {
                            match Self::compress_gzip(&raw, config.level) {
                                Ok(packed) if packed.len() < raw.len() => Ok(packed),
                                _ => Err(crate::response::ResponseBody::Bytes(raw)),
                            }
                        } else {
                            Err(crate::response::ResponseBody::Bytes(raw))
                        }
                    }
                    other => Err(other),
                }
            };
            let compressed = match outcome {
                Ok(compressed) => compressed,
                Err(passthrough) => {
                    return Response::with_status(status)
                        .body(passthrough)
                        .rebuild_with_headers(headers);
                }
            };
            let mut encoded = Response::with_status(status)
                .body(crate::response::ResponseBody::Bytes(compressed));
            // The original length no longer describes the compressed body.
            for (name, value) in headers
                .into_iter()
                .filter(|(name, _)| !name.eq_ignore_ascii_case("content-length"))
            {
                encoded = encoded.header(name, value);
            }
            encoded
                .header("Content-Encoding", b"gzip".to_vec())
                .header("Vary", b"Accept-Encoding".to_vec())
        })
    }
    fn name(&self) -> &'static str {
        "Compression"
    }
}
use parking_lot::Mutex;
use std::collections::HashMap as StdHashMap;
use std::time::Duration;
/// Counting strategy used by the rate-limit store.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitAlgorithm {
    /// Continuous refill at `max_requests / window` tokens per second.
    TokenBucket,
    /// Counter reset at each window boundary.
    FixedWindow,
    /// Current window count plus a time-weighted share of the previous window's count.
    SlidingWindow,
}
/// Outcome of a single rate-limit check.
#[derive(Debug, Clone)]
pub struct RateLimitResult {
    /// Whether the request may proceed.
    pub allowed: bool,
    /// Configured maximum for the key.
    pub limit: u64,
    /// Requests still available after this one (0 when rejected).
    pub remaining: u64,
    /// Seconds until the limit resets or a retry may succeed.
    pub reset_after_secs: u64,
}
/// Derives the bucket key under which a request is rate-limited.
///
/// `RateLimitMiddleware` lets a request through without counting it when no key is
/// produced.
pub trait KeyExtractor: Send + Sync {
    /// Returns the rate-limit key for `req`, or `None` to exempt it from limiting.
    fn extract_key(&self, req: &Request) -> Option<String>;
}
/// Peer IP address of the connection, read from the request extensions by the IP-based
/// key extractors (presumably inserted by the server layer — confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RemoteAddr(pub std::net::IpAddr);
impl std::fmt::Display for RemoteAddr {
    /// Writes the address in its standard textual form; outer width/fill flags are not
    /// forwarded.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let RemoteAddr(addr) = self;
        write!(f, "{addr}")
    }
}
/// Keys requests by the connection's peer address ([`RemoteAddr`] extension) only,
/// ignoring forwarding headers. Requests without that extension are not rate-limited.
#[derive(Debug, Clone)]
pub struct ConnectedIpKeyExtractor;
impl KeyExtractor for ConnectedIpKeyExtractor {
    fn extract_key(&self, req: &Request) -> Option<String> {
        let peer = req.get_extension::<RemoteAddr>()?;
        Some(peer.to_string())
    }
}
/// Keys requests by the first `X-Forwarded-For` entry, else `X-Real-IP`, else the literal
/// `"unknown"` (which puts every such request into one shared bucket).
///
/// Both headers are taken verbatim from the request, so any client can forge them. Prefer
/// [`TrustedProxyIpKeyExtractor`] or [`ConnectedIpKeyExtractor`] when clients can reach
/// the server directly.
#[derive(Debug, Clone)]
pub struct IpKeyExtractor;
impl KeyExtractor for IpKeyExtractor {
    fn extract_key(&self, req: &Request) -> Option<String> {
        let forwarded = req
            .headers()
            .get("x-forwarded-for")
            .and_then(|raw| std::str::from_utf8(raw).ok());
        if let Some(list) = forwarded {
            // `split` always yields at least one item, so this returns `Some`.
            return list.split(',').next().map(|first| first.trim().to_string());
        }
        req.headers()
            .get("x-real-ip")
            .and_then(|raw| std::str::from_utf8(raw).ok())
            .map(|ip| ip.trim().to_string())
            .or_else(|| Some("unknown".to_string()))
    }
}
/// Keys requests by client IP, consulting forwarding headers only when the connection's
/// peer ([`RemoteAddr`]) lies inside a trusted proxy CIDR.
///
/// Requests without a `RemoteAddr` extension produce no key and so are not rate-limited.
#[derive(Debug, Clone)]
pub struct TrustedProxyIpKeyExtractor {
    trusted_cidrs: Vec<(std::net::IpAddr, u8)>,
}
impl TrustedProxyIpKeyExtractor {
    /// Creates an extractor that trusts no proxies (always keys by the peer address).
    #[must_use]
    pub fn new() -> Self {
        Self {
            trusted_cidrs: Vec::new(),
        }
    }
    /// Trusts peers inside `cidr` (e.g. `"10.0.0.0/8"`).
    ///
    /// # Panics
    ///
    /// Panics if `cidr` is not `ip/prefix` notation with a prefix valid for its family.
    #[must_use]
    pub fn trust_cidr(mut self, cidr: &str) -> Self {
        let (ip, prefix) = parse_cidr(cidr).expect("invalid CIDR notation");
        self.trusted_cidrs.push((ip, prefix));
        self
    }
    /// Trusts `127.0.0.0/8` and `::1`.
    #[must_use]
    pub fn trust_loopback(mut self) -> Self {
        self.trusted_cidrs.push((
            std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 0)),
            8,
        ));
        self.trusted_cidrs
            .push((std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST), 128));
        self
    }
    fn is_trusted(&self, ip: std::net::IpAddr) -> bool {
        self.trusted_cidrs
            .iter()
            .any(|(cidr_ip, prefix)| ip_in_cidr(ip, *cidr_ip, *prefix))
    }
    /// Parses a forwarded address, accepting a bare IP or an `ip:port` / `[v6]:port` form.
    fn parse_forwarded_ip(entry: &str) -> Option<std::net::IpAddr> {
        let entry = entry.trim();
        entry.parse::<std::net::IpAddr>().ok().or_else(|| {
            entry
                .parse::<std::net::SocketAddr>()
                .ok()
                .map(|sock| sock.ip())
        })
    }
    /// Picks the client address from forwarding headers.
    ///
    /// Each proxy appends the address it received the request from, so only the entries
    /// added by trusted proxies are reliable. Anything to their left may have been forged
    /// by the client. `X-Forwarded-For` is therefore walked right to left, skipping trusted
    /// proxies; the first untrusted address is the client. If every entry is trusted, the
    /// leftmost parsed one is used. `X-Real-IP` is the fallback. Non-IP values are
    /// rejected rather than used as keys.
    fn extract_from_header(&self, req: &Request) -> Option<String> {
        if let Some(forwarded) = req.headers().get("x-forwarded-for") {
            if let Ok(list) = std::str::from_utf8(forwarded) {
                let mut client = None;
                for entry in list.rsplit(',') {
                    let Some(ip) = Self::parse_forwarded_ip(entry) else {
                        break;
                    };
                    client = Some(ip);
                    if !self.is_trusted(ip) {
                        break;
                    }
                }
                if let Some(ip) = client {
                    return Some(ip.to_string());
                }
            }
        }
        if let Some(real_ip) = req.headers().get("x-real-ip") {
            if let Ok(text) = std::str::from_utf8(real_ip) {
                if let Some(ip) = Self::parse_forwarded_ip(text) {
                    return Some(ip.to_string());
                }
            }
        }
        None
    }
}
impl Default for TrustedProxyIpKeyExtractor {
    fn default() -> Self {
        Self::new()
    }
}
impl KeyExtractor for TrustedProxyIpKeyExtractor {
    /// Uses the forwarded client address when the peer is trusted; otherwise, or when no
    /// usable header exists, the peer address itself.
    fn extract_key(&self, req: &Request) -> Option<String> {
        let remote = req.get_extension::<RemoteAddr>()?;
        if self.is_trusted(remote.0) {
            self.extract_from_header(req)
                .or_else(|| Some(remote.to_string()))
        } else {
            Some(remote.to_string())
        }
    }
}
/// Parses `ip/prefix` notation, returning `None` for malformed input or a prefix longer
/// than the address family allows (32 for IPv4, 128 for IPv6).
fn parse_cidr(cidr: &str) -> Option<(std::net::IpAddr, u8)> {
    let (addr_part, len_part) = cidr.split_once('/')?;
    let addr = addr_part.parse::<std::net::IpAddr>().ok()?;
    let len = len_part.parse::<u8>().ok()?;
    let limit: u8 = if addr.is_ipv4() { 32 } else { 128 };
    (len <= limit).then_some((addr, len))
}
/// Returns `true` when `ip` lies inside the network `cidr_ip/prefix`.
///
/// Mixed address families are compared through the IPv4-mapped IPv6 form
/// (`::ffff:a.b.c.d`), which is how dual-stack sockets report IPv4 peers. An IPv4 CIDR
/// therefore matches such peers, and a mapped IPv6 CIDR matches plain IPv4 addresses.
/// A prefix longer than the family's bit width never matches, instead of overflowing
/// the mask shift.
fn ip_in_cidr(ip: std::net::IpAddr, cidr_ip: std::net::IpAddr, prefix: u8) -> bool {
    /// Compares the top `prefix` bits of two IPv4 addresses.
    fn v4_match(ip: std::net::Ipv4Addr, net: std::net::Ipv4Addr, prefix: u8) -> bool {
        if prefix > 32 {
            return false;
        }
        // A shift of 32 (prefix 0) is out of range for `checked_shl`: empty mask, match all.
        let mask = u32::MAX.checked_shl(32 - u32::from(prefix)).unwrap_or(0);
        (u32::from(ip) & mask) == (u32::from(net) & mask)
    }
    /// Compares the top `prefix` bits of two IPv6 addresses.
    fn v6_match(ip: std::net::Ipv6Addr, net: std::net::Ipv6Addr, prefix: u8) -> bool {
        if prefix > 128 {
            return false;
        }
        let mask = u128::MAX.checked_shl(128 - u32::from(prefix)).unwrap_or(0);
        (u128::from(ip) & mask) == (u128::from(net) & mask)
    }
    match (ip, cidr_ip) {
        (std::net::IpAddr::V4(ip), std::net::IpAddr::V4(net)) => v4_match(ip, net, prefix),
        (std::net::IpAddr::V6(ip), std::net::IpAddr::V6(net)) => v6_match(ip, net, prefix),
        (std::net::IpAddr::V6(ip), std::net::IpAddr::V4(net)) => ip
            .to_ipv4_mapped()
            .is_some_and(|v4| v4_match(v4, net, prefix)),
        (std::net::IpAddr::V4(ip), std::net::IpAddr::V6(net)) => {
            v6_match(ip.to_ipv6_mapped(), net, prefix)
        }
    }
}
/// Keys requests by the raw (untrimmed) value of a chosen header, e.g. an API key.
/// Requests lacking the header, or with a non-UTF-8 value, are not rate-limited.
#[derive(Debug, Clone)]
pub struct HeaderKeyExtractor {
    header_name: String,
}
impl HeaderKeyExtractor {
    /// Creates an extractor reading `header_name`.
    #[must_use]
    pub fn new(header_name: impl Into<String>) -> Self {
        let header_name = header_name.into();
        Self { header_name }
    }
}
impl KeyExtractor for HeaderKeyExtractor {
    fn extract_key(&self, req: &Request) -> Option<String> {
        let raw = req.headers().get(&self.header_name)?;
        std::str::from_utf8(raw).ok().map(ToOwned::to_owned)
    }
}
/// Keys requests by their path, so each path has its own shared budget across all clients.
#[derive(Debug, Clone)]
pub struct PathKeyExtractor;
impl KeyExtractor for PathKeyExtractor {
    fn extract_key(&self, req: &Request) -> Option<String> {
        Some(req.path()).map(ToString::to_string)
    }
}
/// Combines several extractors by joining their keys with `:`.
///
/// Extractors that yield `None` are skipped. The composite yields `None` only when all of
/// them do.
pub struct CompositeKeyExtractor {
    extractors: Vec<Box<dyn KeyExtractor>>,
}
impl CompositeKeyExtractor {
    /// Creates a composite that evaluates `extractors` in order.
    #[must_use]
    pub fn new(extractors: Vec<Box<dyn KeyExtractor>>) -> Self {
        Self { extractors }
    }
}
impl KeyExtractor for CompositeKeyExtractor {
    fn extract_key(&self, req: &Request) -> Option<String> {
        let pieces: Vec<String> = self
            .extractors
            .iter()
            .filter_map(|extractor| extractor.extract_key(req))
            .collect();
        (!pieces.is_empty()).then(|| pieces.join(":"))
    }
}
/// Per-key token bucket: fractional tokens left and the last refill instant.
#[derive(Debug, Clone)]
struct TokenBucketState {
    tokens: f64,
    last_refill: Instant,
}
/// Per-key fixed window: requests counted since `window_start`.
#[derive(Debug, Clone)]
struct FixedWindowState {
    count: u64,
    window_start: Instant,
}
/// Per-key sliding-window counter: counts for the current and the preceding window.
#[derive(Debug, Clone)]
struct SlidingWindowState {
    current_count: u64,
    previous_count: u64,
    current_window_start: Instant,
}
pub struct InMemoryRateLimitStore {
token_buckets: Mutex<StdHashMap<String, TokenBucketState>>,
fixed_windows: Mutex<StdHashMap<String, FixedWindowState>>,
sliding_windows: Mutex<StdHashMap<String, SlidingWindowState>>,
}
impl InMemoryRateLimitStore {
#[must_use]
pub fn new() -> Self {
Self {
token_buckets: Mutex::new(StdHashMap::new()),
fixed_windows: Mutex::new(StdHashMap::new()),
sliding_windows: Mutex::new(StdHashMap::new()),
}
}
#[allow(clippy::cast_precision_loss, clippy::cast_sign_loss)]
fn check_token_bucket(
&self,
key: &str,
max_tokens: u64,
refill_rate: f64,
window: Duration,
) -> RateLimitResult {
let mut buckets = self.token_buckets.lock();
let now = Instant::now();
let state = buckets
.entry(key.to_string())
.or_insert_with(|| TokenBucketState {
tokens: max_tokens as f64,
last_refill: now,
});
let elapsed = now.duration_since(state.last_refill);
let refill = elapsed.as_secs_f64() * refill_rate;
state.tokens = (state.tokens + refill).min(max_tokens as f64);
state.last_refill = now;
if state.tokens >= 1.0 {
state.tokens -= 1.0;
RateLimitResult {
allowed: true,
limit: max_tokens,
remaining: state.tokens as u64,
reset_after_secs: if state.tokens < max_tokens as f64 {
((max_tokens as f64 - state.tokens) / refill_rate).ceil() as u64
} else {
window.as_secs()
},
}
} else {
let wait_secs = ((1.0 - state.tokens) / refill_rate).ceil() as u64;
RateLimitResult {
allowed: false,
limit: max_tokens,
remaining: 0,
reset_after_secs: wait_secs,
}
}
}
fn check_fixed_window(
&self,
key: &str,
max_requests: u64,
window: Duration,
) -> RateLimitResult {
let mut windows = self.fixed_windows.lock();
let now = Instant::now();
let state = windows
.entry(key.to_string())
.or_insert_with(|| FixedWindowState {
count: 0,
window_start: now,
});
let elapsed = now.duration_since(state.window_start);
if elapsed >= window {
state.count = 0;
state.window_start = now;
}
let remaining_time = window
.checked_sub(now.duration_since(state.window_start))
.unwrap_or(Duration::ZERO);
if state.count < max_requests {
state.count += 1;
RateLimitResult {
allowed: true,
limit: max_requests,
remaining: max_requests - state.count,
reset_after_secs: remaining_time.as_secs(),
}
} else {
RateLimitResult {
allowed: false,
limit: max_requests,
remaining: 0,
reset_after_secs: remaining_time.as_secs(),
}
}
}
#[allow(clippy::cast_precision_loss, clippy::cast_sign_loss)]
fn check_sliding_window(
&self,
key: &str,
max_requests: u64,
window: Duration,
) -> RateLimitResult {
let mut windows = self.sliding_windows.lock();
let now = Instant::now();
let state = windows
.entry(key.to_string())
.or_insert_with(|| SlidingWindowState {
current_count: 0,
previous_count: 0,
current_window_start: now,
});
let elapsed = now.duration_since(state.current_window_start);
if elapsed >= window {
state.previous_count = state.current_count;
state.current_count = 0;
state.current_window_start = now;
}
let window_elapsed = now.duration_since(state.current_window_start);
let window_fraction = window_elapsed.as_secs_f64() / window.as_secs_f64();
let previous_weight = 1.0 - window_fraction;
let weighted_count =
(state.previous_count as f64 * previous_weight) + state.current_count as f64;
let remaining_time = window.checked_sub(window_elapsed).unwrap_or(Duration::ZERO);
if weighted_count < max_requests as f64 {
state.current_count += 1;
let new_weighted =
(state.previous_count as f64 * previous_weight) + state.current_count as f64;
let remaining = (max_requests as f64 - new_weighted).max(0.0) as u64;
RateLimitResult {
allowed: true,
limit: max_requests,
remaining,
reset_after_secs: remaining_time.as_secs(),
}
} else {
RateLimitResult {
allowed: false,
limit: max_requests,
remaining: 0,
reset_after_secs: remaining_time.as_secs(),
}
}
}
#[allow(clippy::cast_precision_loss)]
pub fn check(
&self,
key: &str,
algorithm: RateLimitAlgorithm,
max_requests: u64,
window: Duration,
) -> RateLimitResult {
match algorithm {
RateLimitAlgorithm::TokenBucket => {
let refill_rate = max_requests as f64 / window.as_secs_f64();
self.check_token_bucket(key, max_requests, refill_rate, window)
}
RateLimitAlgorithm::FixedWindow => self.check_fixed_window(key, max_requests, window),
RateLimitAlgorithm::SlidingWindow => {
self.check_sliding_window(key, max_requests, window)
}
}
}
}
impl Default for InMemoryRateLimitStore {
fn default() -> Self {
Self::new()
}
}
/// Settings applied by a `RateLimitMiddleware` to every request.
#[derive(Clone)]
pub struct RateLimitConfig {
    /// Requests permitted per `window` for one key.
    pub max_requests: u64,
    /// Length of the rate-limit window.
    pub window: Duration,
    /// Counting strategy.
    pub algorithm: RateLimitAlgorithm,
    /// Whether `X-RateLimit-*` headers are added to responses.
    pub include_headers: bool,
    /// Text placed in the `detail` field of 429 responses.
    pub retry_message: String,
}
impl Default for RateLimitConfig {
    /// 100 requests per 60 s, token-bucket algorithm, headers enabled.
    fn default() -> Self {
        let window = Duration::from_secs(60);
        Self {
            algorithm: RateLimitAlgorithm::TokenBucket,
            max_requests: 100,
            window,
            retry_message: String::from("Rate limit exceeded. Please retry later."),
            include_headers: true,
        }
    }
}
/// Fluent builder for `RateLimitMiddleware`.
pub struct RateLimitBuilder {
    config: RateLimitConfig,
    key_extractor: Option<Box<dyn KeyExtractor>>,
}
impl RateLimitBuilder {
    /// Starts from [`RateLimitConfig::default`] with no custom key extractor.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: RateLimitConfig::default(),
            key_extractor: None,
        }
    }
    /// Sets the number of requests allowed per window.
    #[must_use]
    pub fn requests(mut self, max: u64) -> Self {
        self.config.max_requests = max;
        self
    }
    /// Sets the window length.
    #[must_use]
    pub fn per(mut self, window: Duration) -> Self {
        self.config.window = window;
        self
    }
    /// Sets the window to `secs` seconds.
    #[must_use]
    pub fn per_second(self, secs: u64) -> Self {
        self.per(Duration::from_secs(secs))
    }
    /// Sets the window to `minutes` minutes.
    ///
    /// The conversion to seconds saturates. A plain multiply panics in debug builds and
    /// wraps in release builds, which would silently produce a tiny window.
    #[must_use]
    pub fn per_minute(self, minutes: u64) -> Self {
        self.per(Duration::from_secs(minutes.saturating_mul(60)))
    }
    /// Sets the window to `hours` hours (saturating, like [`Self::per_minute`]).
    #[must_use]
    pub fn per_hour(self, hours: u64) -> Self {
        self.per(Duration::from_secs(hours.saturating_mul(3600)))
    }
    /// Selects the counting algorithm.
    #[must_use]
    pub fn algorithm(mut self, algo: RateLimitAlgorithm) -> Self {
        self.config.algorithm = algo;
        self
    }
    /// Sets how requests are mapped to rate-limit keys.
    #[must_use]
    pub fn key_extractor(mut self, extractor: impl KeyExtractor + 'static) -> Self {
        self.key_extractor = Some(Box::new(extractor));
        self
    }
    /// Enables or disables the `X-RateLimit-*` response headers.
    #[must_use]
    pub fn include_headers(mut self, include: bool) -> Self {
        self.config.include_headers = include;
        self
    }
    /// Sets the `detail` message of 429 responses.
    #[must_use]
    pub fn retry_message(mut self, msg: impl Into<String>) -> Self {
        self.config.retry_message = msg.into();
        self
    }
    /// Builds the middleware with a fresh in-memory store.
    ///
    /// Without an explicit extractor, [`IpKeyExtractor`] is used. It trusts
    /// client-supplied forwarding headers.
    #[must_use]
    pub fn build(self) -> RateLimitMiddleware {
        let key_extractor = self
            .key_extractor
            .unwrap_or_else(|| Box::new(IpKeyExtractor));
        RateLimitMiddleware {
            config: self.config,
            store: Arc::new(InMemoryRateLimitStore::new()),
            key_extractor: Arc::from(key_extractor),
        }
    }
}
impl Default for RateLimitBuilder {
    fn default() -> Self {
        Self::new()
    }
}
/// Request extension holding the result of an allowed rate-limit check, so that the
/// `after` hook can echo it in response headers.
#[derive(Debug, Clone)]
struct RateLimitInfo {
    result: RateLimitResult,
}
pub struct RateLimitMiddleware {
config: RateLimitConfig,
store: Arc<InMemoryRateLimitStore>,
key_extractor: Arc<dyn KeyExtractor>,
}
impl RateLimitMiddleware {
#[must_use]
pub fn new() -> Self {
Self::builder().build()
}
#[must_use]
pub fn builder() -> RateLimitBuilder {
RateLimitBuilder::new()
}
fn too_many_requests_body(&self, result: &RateLimitResult) -> Vec<u8> {
format!(
r#"{{"detail":"{}","retry_after_secs":{}}}"#,
self.config.retry_message, result.reset_after_secs
)
.into_bytes()
}
fn add_headers(&self, response: Response, result: &RateLimitResult) -> Response {
response
.header("X-RateLimit-Limit", result.limit.to_string().into_bytes())
.header(
"X-RateLimit-Remaining",
result.remaining.to_string().into_bytes(),
)
.header(
"X-RateLimit-Reset",
result.reset_after_secs.to_string().into_bytes(),
)
}
}
impl Default for RateLimitMiddleware {
fn default() -> Self {
Self::new()
}
}
impl Middleware for RateLimitMiddleware {
    /// Counts the request against its key. Requests without a key pass unchecked; denied
    /// requests are answered with 429, `Retry-After` and, optionally, `X-RateLimit-*`.
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        Box::pin(async move {
            let key = match self.key_extractor.extract_key(req) {
                Some(key) => key,
                None => return ControlFlow::Continue,
            };
            let outcome = self.store.check(
                &key,
                self.config.algorithm,
                self.config.max_requests,
                self.config.window,
            );
            if !outcome.allowed {
                let payload = self.too_many_requests_body(&outcome);
                let retry_after = outcome.reset_after_secs.to_string().into_bytes();
                let rejection =
                    Response::with_status(crate::response::StatusCode::TOO_MANY_REQUESTS)
                        .header("Content-Type", b"application/json".to_vec())
                        .header("Retry-After", retry_after)
                        .body(crate::response::ResponseBody::Bytes(payload));
                let rejection = if self.config.include_headers {
                    self.add_headers(rejection, &outcome)
                } else {
                    rejection
                };
                return ControlFlow::Break(rejection);
            }
            // Stash the result so `after` can report it on the handler's response.
            req.insert_extension(RateLimitInfo { result: outcome });
            ControlFlow::Continue
        })
    }
    /// Adds `X-RateLimit-*` headers for requests that were counted and allowed.
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async move {
            match req.get_extension::<RateLimitInfo>() {
                Some(info) if self.config.include_headers => {
                    self.add_headers(response, &info.result)
                }
                _ => response,
            }
        })
    }
    fn name(&self) -> &'static str {
        "RateLimit"
    }
}
/// Detail level of `RequestInspectionMiddleware` logs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InspectionVerbosity {
    /// Request/response summary line only.
    Minimal,
    /// Summary line plus headers (redacted where configured).
    Normal,
    /// Summary, headers and a body preview.
    Verbose,
}
/// Development middleware that logs requests and responses at a configurable verbosity.
pub struct RequestInspectionMiddleware {
    log_config: LogConfig,
    verbosity: InspectionVerbosity,
    /// Lower-cased header names whose values are replaced by `[REDACTED]`.
    redact_headers: HashSet<String>,
    /// Responses taking at least this long are tagged `[SLOW]`.
    slow_threshold_ms: u64,
    /// Maximum number of body bytes shown in previews (0 disables previews).
    max_body_preview: usize,
}
impl Default for RequestInspectionMiddleware {
    fn default() -> Self {
        Self {
            log_config: LogConfig::development(),
            verbosity: InspectionVerbosity::Normal,
            redact_headers: default_redacted_headers(),
            slow_threshold_ms: 1000,
            max_body_preview: 2048,
        }
    }
}
impl RequestInspectionMiddleware {
    /// Same as [`RequestInspectionMiddleware::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
    /// Sets the logger configuration.
    #[must_use]
    pub fn log_config(mut self, config: LogConfig) -> Self {
        self.log_config = config;
        self
    }
    /// Sets the verbosity level.
    #[must_use]
    pub fn verbosity(mut self, level: InspectionVerbosity) -> Self {
        self.verbosity = level;
        self
    }
    /// Sets the slow-response threshold in milliseconds.
    #[must_use]
    pub fn slow_threshold_ms(mut self, ms: u64) -> Self {
        self.slow_threshold_ms = ms;
        self
    }
    /// Sets the body preview byte limit.
    #[must_use]
    pub fn max_body_preview(mut self, max: usize) -> Self {
        self.max_body_preview = max;
        self
    }
    /// Adds a header (case-insensitive) to the redaction list.
    #[must_use]
    pub fn redact_header(mut self, name: impl Into<String>) -> Self {
        self.redact_headers.insert(name.into().to_ascii_lowercase());
        self
    }
    /// Renders up to `max_body_preview` bytes of a body for logging.
    ///
    /// JSON bodies (by `Content-Type`) are pretty-printed when they parse. Truncated text
    /// gets a marker. Non-UTF-8 bodies are summarised as `<N bytes binary>`. A truncation
    /// that splits a multi-byte UTF-8 character keeps the valid prefix instead of
    /// misreporting the body as binary.
    fn format_body_preview(&self, bytes: &[u8], content_type: Option<&[u8]>) -> Option<String> {
        if bytes.is_empty() || self.max_body_preview == 0 {
            return None;
        }
        let is_json = content_type
            .and_then(|ct| std::str::from_utf8(ct).ok())
            .is_some_and(|ct| ct.contains("application/json"));
        let limit = self.max_body_preview.min(bytes.len());
        let truncated = bytes.len() > self.max_body_preview;
        let head = &bytes[..limit];
        let text = match std::str::from_utf8(head) {
            Ok(text) => text,
            // `error_len() == None` means the input ended mid-character, i.e. our cut did it.
            Err(err) if truncated && err.error_len().is_none() => {
                std::str::from_utf8(&head[..err.valid_up_to()]).unwrap_or_default()
            }
            Err(_) => return Some(format!("<{} bytes binary>", bytes.len())),
        };
        if is_json {
            if let Some(mut pretty) = try_pretty_json(text) {
                if truncated {
                    pretty.push_str("\n ... (truncated)");
                }
                return Some(pretty);
            }
        }
        let mut output = text.to_string();
        if truncated {
            output.push_str("...");
        }
        Some(output)
    }
    /// Preview for a response body; streams are not consumed, only labelled.
    fn format_response_preview(
        &self,
        body: &crate::response::ResponseBody,
        content_type: Option<&[u8]>,
    ) -> Option<String> {
        match body {
            crate::response::ResponseBody::Empty => None,
            crate::response::ResponseBody::Bytes(bytes) => {
                self.format_body_preview(bytes, content_type)
            }
            crate::response::ResponseBody::Stream(_) => Some("<streaming body>".to_string()),
        }
    }
    /// One `name: value` line per header, redacting configured names.
    fn format_inspection_headers<'a>(
        &self,
        headers: impl Iterator<Item = (&'a str, &'a [u8])>,
    ) -> String {
        let mut out = String::new();
        for (name, value) in headers {
            out.push_str("\n ");
            out.push_str(name);
            out.push_str(": ");
            let lowered = name.to_ascii_lowercase();
            if self.redact_headers.contains(&lowered) {
                out.push_str("[REDACTED]");
            } else {
                match std::str::from_utf8(value) {
                    Ok(text) => out.push_str(text),
                    Err(_) => out.push_str("<binary>"),
                }
            }
        }
        out
    }
    /// Adapter for response headers stored as owned `(String, Vec<u8>)` pairs.
    fn format_response_inspection_headers(&self, headers: &[(String, Vec<u8>)]) -> String {
        self.format_inspection_headers(
            headers
                .iter()
                .map(|(name, value)| (name.as_str(), value.as_slice())),
        )
    }
}
/// Request extension recording when inspection began, used to time the response.
#[derive(Debug, Clone)]
struct InspectionStart(Instant);
impl Middleware for RequestInspectionMiddleware {
    /// Stamps the start time and logs the request line (plus headers, and for `Verbose` a
    /// body preview) at `info` level.
    fn before<'a>(
        &'a self,
        ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        let logger = RequestLogger::new(ctx, self.log_config.clone());
        req.insert_extension(InspectionStart(Instant::now()));
        let mut summary = format!("--> {} {}", req.method(), req.path());
        if let Some(query) = req.query() {
            summary.push('?');
            summary.push_str(query);
        }
        let size = body_len(req.body());
        if size > 0 {
            summary.push_str(&format!(" ({size} bytes)"));
        }
        let message = match self.verbosity {
            InspectionVerbosity::Minimal => summary,
            InspectionVerbosity::Normal => {
                let header_dump = self.format_inspection_headers(req.headers().iter());
                format!("{summary}{header_dump}")
            }
            InspectionVerbosity::Verbose => {
                let header_dump = self.format_inspection_headers(req.headers().iter());
                let content_type = req.headers().get("content-type");
                let preview = match req.body() {
                    Body::Bytes(bytes) => self.format_body_preview(bytes, content_type),
                    Body::Empty | Body::Stream { .. } => None,
                };
                let mut text = format!("{summary}{header_dump}");
                if let Some(body) = preview {
                    text.push_str("\n ");
                    text.push_str(&body.replace('\n', "\n "));
                }
                text
            }
        };
        logger.info(message);
        Box::pin(async { ControlFlow::Continue })
    }
    /// Logs status, duration and (per verbosity) headers/body preview; responses at or
    /// above the slow threshold are tagged `[SLOW]` and logged at `warn` level.
    fn after<'a>(
        &'a self,
        ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        let logger = RequestLogger::new(ctx, self.log_config.clone());
        let elapsed = req
            .get_extension::<InspectionStart>()
            .map(|start| start.0.elapsed())
            .unwrap_or_default();
        let status = response.status();
        let elapsed_ms = elapsed.as_millis();
        let is_slow = elapsed_ms >= u128::from(self.slow_threshold_ms);
        let mut summary = format!(
            "<-- {} {} ({elapsed_ms}ms)",
            status.as_u16(),
            status.canonical_reason(),
        );
        if is_slow {
            summary.push_str(" [SLOW]");
        }
        let message = match self.verbosity {
            InspectionVerbosity::Minimal => summary,
            InspectionVerbosity::Normal => {
                let header_dump = self.format_response_inspection_headers(response.headers());
                format!("{summary}{header_dump}")
            }
            InspectionVerbosity::Verbose => {
                let header_dump = self.format_response_inspection_headers(response.headers());
                let content_type: Option<&[u8]> = response
                    .headers()
                    .iter()
                    .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
                    .map(|(_, value)| value.as_slice());
                let preview = self.format_response_preview(response.body_ref(), content_type);
                let mut text = format!("{summary}{header_dump}");
                if let Some(body) = preview {
                    text.push_str("\n ");
                    text.push_str(&body.replace('\n', "\n "));
                }
                text
            }
        };
        if is_slow {
            logger.warn(message);
        } else {
            logger.info(message);
        }
        Box::pin(async move { response })
    }
    fn name(&self) -> &'static str {
        "RequestInspection"
    }
}
/// Pretty-prints `input` if it looks like a complete JSON object or array.
///
/// Returns `None` when the trimmed input does not start with `{`/`[`, or when strings or
/// brackets are left unbalanced (e.g. a truncated body).
fn try_pretty_json(input: &str) -> Option<String> {
    let trimmed = input.trim();
    if !trimmed.starts_with('{') && !trimmed.starts_with('[') {
        return None;
    }
    let mut output = String::with_capacity(trimmed.len() * 2);
    json_pretty_format(trimmed, &mut output)
        .ok()
        .map(|()| output)
}
/// Lightweight re-indenter for JSON text; it does not validate the grammar.
///
/// Whitespace outside strings is dropped, every `{`/`[` and `,` starts a new indented line,
/// and empty containers stay on one line. Text is walked character by character, so
/// non-ASCII content (e.g. `é`) is copied intact. The previous byte-as-`char` cast turned
/// every UTF-8 continuation byte into a Latin-1 character.
///
/// Returns `Err(())` when the input ends inside a string or with unclosed brackets.
fn json_pretty_format(input: &str, output: &mut String) -> Result<(), ()> {
    let bytes = input.as_bytes();
    let mut pos = 0;
    let mut indent: usize = 0;
    let mut in_string = false;
    let mut escape_next = false;
    while pos < bytes.len() {
        // `pos` always sits on a char boundary: it advances by whole characters, or lands
        // just past an ASCII byte found by `skip_whitespace`.
        let Some(ch) = input[pos..].chars().next() else {
            break;
        };
        let step = ch.len_utf8();
        if escape_next {
            output.push(ch);
            escape_next = false;
            pos += step;
            continue;
        }
        if in_string {
            output.push(ch);
            if ch == '\\' {
                escape_next = true;
            } else if ch == '"' {
                in_string = false;
            }
            pos += step;
            continue;
        }
        match ch {
            '"' => {
                in_string = true;
                output.push('"');
            }
            '{' | '[' => {
                output.push(ch);
                let peek = skip_whitespace(bytes, pos + 1);
                let closing = if ch == '{' { '}' } else { ']' };
                if peek < bytes.len() && bytes[peek] as char == closing {
                    // Keep empty containers compact: `{}` / `[]`.
                    output.push(closing);
                    pos = peek + 1;
                    continue;
                }
                indent += 1;
                output.push('\n');
                push_indent(output, indent);
            }
            '}' | ']' => {
                indent = indent.saturating_sub(1);
                output.push('\n');
                push_indent(output, indent);
                output.push(ch);
            }
            ':' => {
                output.push_str(": ");
            }
            ',' => {
                output.push(',');
                output.push('\n');
                push_indent(output, indent);
            }
            c if c.is_ascii_whitespace() => {}
            _ => {
                output.push(ch);
            }
        }
        pos += step;
    }
    if in_string || indent != 0 {
        return Err(());
    }
    Ok(())
}
/// Index of the first non-ASCII-whitespace byte at or after `start` (or `bytes.len()`).
fn skip_whitespace(bytes: &[u8], start: usize) -> usize {
    let mut i = start;
    while i < bytes.len() && (bytes[i] as char).is_ascii_whitespace() {
        i += 1;
    }
    i
}
/// Appends `level` indentation units to `output`.
fn push_indent(output: &mut String, level: usize) {
    for _ in 0..level {
        output.push_str(" ");
    }
}
/// ETag handling mode for `ETagMiddleware`.
///
/// `Disabled` makes the middleware return responses untouched. NOTE(review): the exact
/// behaviour of `Auto` vs `Manual` lives in `ETagMiddleware::after` — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ETagMode {
    #[default]
    Auto,
    Manual,
    Disabled,
}
/// Settings for `ETagMiddleware`.
#[derive(Debug, Clone)]
pub struct ETagConfig {
    /// Handling mode; defaults to [`ETagMode::Auto`].
    pub mode: ETagMode,
    /// Presumably selects weak (`W/"…"`) tags when generating — confirm in `after`.
    pub weak: bool,
    /// Presumably the minimum body size for tagging — confirm in `after`.
    pub min_size: usize,
}
impl Default for ETagConfig {
    /// `Auto` mode, strong tags, no minimum size.
    fn default() -> Self {
        Self {
            mode: ETagMode::Auto,
            weak: false,
            min_size: 0,
        }
    }
}
impl ETagConfig {
    /// Same as [`ETagConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
    /// Sets the handling mode.
    #[must_use]
    pub fn mode(mut self, mode: ETagMode) -> Self {
        self.mode = mode;
        self
    }
    /// Selects weak or strong tags.
    #[must_use]
    pub fn weak(mut self, weak: bool) -> Self {
        self.weak = weak;
        self
    }
    /// Sets the minimum size threshold.
    #[must_use]
    pub fn min_size(mut self, size: usize) -> Self {
        self.min_size = size;
        self
    }
}
pub struct ETagMiddleware {
config: ETagConfig,
}
impl Default for ETagMiddleware {
fn default() -> Self {
Self::new()
}
}
impl ETagMiddleware {
#[must_use]
pub fn new() -> Self {
Self {
config: ETagConfig::default(),
}
}
#[must_use]
pub fn with_config(config: ETagConfig) -> Self {
Self { config }
}
fn generate_etag(data: &[u8], weak: bool) -> String {
const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
const FNV_PRIME: u64 = 0x100000001b3;
let mut hash = FNV_OFFSET_BASIS;
for &byte in data {
hash ^= u64::from(byte);
hash = hash.wrapping_mul(FNV_PRIME);
}
if weak {
format!("W/\"{:016x}\"", hash)
} else {
format!("\"{:016x}\"", hash)
}
}
fn parse_if_none_match(value: &str) -> Vec<String> {
let trimmed = value.trim();
if trimmed == "*" {
return vec!["*".to_string()];
}
let mut etags = Vec::new();
let mut current = String::new();
let mut in_quote = false;
let mut prev_char = '\0';
for ch in trimmed.chars() {
match ch {
'"' if prev_char != '\\' => {
current.push(ch);
if in_quote {
let etag = current.trim().to_string();
if !etag.is_empty() {
etags.push(etag);
}
current.clear();
}
in_quote = !in_quote;
}
',' if !in_quote => {
current.clear();
}
_ => {
current.push(ch);
}
}
prev_char = ch;
}
etags
}
fn etags_match_weak(etag1: &str, etag2: &str) -> bool {
let e1 = Self::strip_weak_prefix(etag1);
let e2 = Self::strip_weak_prefix(etag2);
e1 == e2
}
fn strip_weak_prefix(s: &str) -> &str {
if s.starts_with("W/") || s.starts_with("w/") {
&s[2..]
} else {
s
}
}
fn is_cacheable_method(method: crate::request::Method) -> bool {
matches!(
method,
crate::request::Method::Get | crate::request::Method::Head
)
}
fn get_existing_etag(headers: &[(String, Vec<u8>)]) -> Option<String> {
for (name, value) in headers {
if name.eq_ignore_ascii_case("etag") {
return std::str::from_utf8(value).ok().map(String::from);
}
}
None
}
}
impl Middleware for ETagMiddleware {
    /// Tags `GET`/`HEAD` responses and converts them into `304 Not Modified`
    /// when the request's `If-None-Match` matches.
    ///
    /// * A handler-supplied `ETag` header is reused as-is and is never duplicated.
    /// * In [`ETagMode::Auto`], a tag is generated only for buffered bodies
    ///   (`Bytes`/`Empty`) of at least `min_size` bytes. Streamed bodies are
    ///   never hashed.
    /// * `If-None-Match` is evaluated only for 2xx responses (RFC 9110 §13.2.1),
    ///   so error pages are never turned into `304`.
    /// * The `304` repeats the `ETag` plus the `Cache-Control`,
    ///   `Content-Location`, `Date`, `Expires` and `Vary` headers the full
    ///   response carried (RFC 9110 §15.4.5).
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        // Headers a 304 must repeat from the response it replaces; the ETag
        // itself is added separately.
        const NOT_MODIFIED_HEADERS: [&str; 5] =
            ["cache-control", "content-location", "date", "expires", "vary"];

        Box::pin(async move {
            // `self` outlives the future, so borrow the config instead of cloning it.
            let config = &self.config;
            if config.mode == ETagMode::Disabled || !Self::is_cacheable_method(req.method()) {
                return response;
            }
            let (status, headers, body) = response.into_parts();
            let existing_etag = Self::get_existing_etag(&headers);
            // Any ETag header (even a non-UTF-8 one) means the handler owns the
            // validator, so nothing is generated or appended.
            let handler_set_etag = headers
                .iter()
                .any(|(name, _)| name.eq_ignore_ascii_case("etag"));
            let generated_etag = if !handler_set_etag && config.mode == ETagMode::Auto {
                // Hash a borrowed view of the body; no copy of the payload is made.
                let buffered: Option<&[u8]> = match &body {
                    crate::response::ResponseBody::Bytes(bytes) => Some(&bytes[..]),
                    crate::response::ResponseBody::Empty => Some(&[0u8; 0][..]),
                    crate::response::ResponseBody::Stream(_) => None,
                };
                buffered
                    .filter(|bytes| bytes.len() >= config.min_size)
                    .map(|bytes| Self::generate_etag(bytes, config.weak))
            } else {
                None
            };
            if let Some(etag_value) = existing_etag.as_deref().or(generated_etag.as_deref()) {
                let is_success = (200..300).contains(&status.as_u16());
                let client_has_current = is_success
                    && req
                        .headers()
                        .get("if-none-match")
                        .and_then(|raw| std::str::from_utf8(raw).ok())
                        .map_or(false, |value| {
                            Self::parse_if_none_match(value).iter().any(|client_etag| {
                                client_etag == "*"
                                    || Self::etags_match_weak(client_etag, etag_value)
                            })
                        });
                if client_has_current {
                    let mut not_modified =
                        Response::with_status(crate::response::StatusCode::NOT_MODIFIED);
                    for (name, value) in headers {
                        if NOT_MODIFIED_HEADERS
                            .iter()
                            .any(|required| name.eq_ignore_ascii_case(required))
                        {
                            not_modified = not_modified.header(name, value);
                        }
                    }
                    return not_modified.header("etag", etag_value.as_bytes().to_vec());
                }
            }
            let mut rebuilt = Response::with_status(status)
                .body(body)
                .rebuild_with_headers(headers);
            // Only a tag we generated is added; a handler-set ETag is already in `headers`.
            if let Some(etag_value) = generated_etag {
                rebuilt = rebuilt.header("etag", etag_value.into_bytes());
            }
            rebuilt
        })
    }
    fn name(&self) -> &'static str {
        "ETagMiddleware"
    }
}
/// A single value-less `Cache-Control` directive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheDirective {
    Public,
    Private,
    NoStore,
    NoCache,
    NoTransform,
    MustRevalidate,
    ProxyRevalidate,
    StaleIfError,
    StaleWhileRevalidate,
    SMaxAge,
    OnlyIfCached,
    Immutable,
}
impl CacheDirective {
    /// Wire name of the directive as written in the header.
    fn as_str(self) -> &'static str {
        match self {
            Self::Public => "public",
            Self::Private => "private",
            Self::NoStore => "no-store",
            Self::NoCache => "no-cache",
            Self::NoTransform => "no-transform",
            Self::MustRevalidate => "must-revalidate",
            Self::ProxyRevalidate => "proxy-revalidate",
            Self::StaleIfError => "stale-if-error",
            Self::StaleWhileRevalidate => "stale-while-revalidate",
            Self::SMaxAge => "s-maxage",
            Self::OnlyIfCached => "only-if-cached",
            Self::Immutable => "immutable",
        }
    }
}

/// Fluent builder for a `Cache-Control` header value.
///
/// [`build`](Self::build) renders flag directives first, in the order they were
/// first added. It then appends `max-age`, `s-maxage`, `stale-while-revalidate`
/// and `stale-if-error` when set. Adding the same flag twice has no further
/// effect.
#[derive(Debug, Clone, Default)]
pub struct CacheControlBuilder {
    directives: Vec<CacheDirective>,
    max_age: Option<u32>,
    s_maxage: Option<u32>,
    stale_while_revalidate: Option<u32>,
    stale_if_error: Option<u32>,
}
impl CacheControlBuilder {
    /// Creates an empty builder.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
    // Adds `directive` unless it is already present, so repeated calls never
    // render duplicates such as `public, public`.
    fn with_directive(mut self, directive: CacheDirective) -> Self {
        if !self.directives.contains(&directive) {
            self.directives.push(directive);
        }
        self
    }
    // Whole seconds of `duration`, saturating at `u32::MAX`. A plain `as u32`
    // cast would silently wrap durations longer than ~136 years.
    fn duration_secs(duration: std::time::Duration) -> u32 {
        duration.as_secs().min(u64::from(u32::MAX)) as u32
    }
    /// Adds `public`.
    #[must_use]
    pub fn public(self) -> Self {
        self.with_directive(CacheDirective::Public)
    }
    /// Adds `private`.
    #[must_use]
    pub fn private(self) -> Self {
        self.with_directive(CacheDirective::Private)
    }
    /// Adds `no-store`.
    #[must_use]
    pub fn no_store(self) -> Self {
        self.with_directive(CacheDirective::NoStore)
    }
    /// Adds `no-cache`.
    #[must_use]
    pub fn no_cache(self) -> Self {
        self.with_directive(CacheDirective::NoCache)
    }
    /// Adds `no-transform`.
    #[must_use]
    pub fn no_transform(self) -> Self {
        self.with_directive(CacheDirective::NoTransform)
    }
    /// Adds `must-revalidate`.
    #[must_use]
    pub fn must_revalidate(self) -> Self {
        self.with_directive(CacheDirective::MustRevalidate)
    }
    /// Adds `proxy-revalidate`.
    #[must_use]
    pub fn proxy_revalidate(self) -> Self {
        self.with_directive(CacheDirective::ProxyRevalidate)
    }
    /// Adds `immutable`.
    #[must_use]
    pub fn immutable(self) -> Self {
        self.with_directive(CacheDirective::Immutable)
    }
    /// Sets `max-age` in seconds.
    #[must_use]
    pub fn max_age_secs(mut self, seconds: u32) -> Self {
        self.max_age = Some(seconds);
        self
    }
    /// Sets `max-age` from a duration (whole seconds, saturating at `u32::MAX`).
    #[must_use]
    pub fn max_age(self, duration: std::time::Duration) -> Self {
        self.max_age_secs(Self::duration_secs(duration))
    }
    /// Sets `s-maxage` in seconds.
    #[must_use]
    pub fn s_maxage_secs(mut self, seconds: u32) -> Self {
        self.s_maxage = Some(seconds);
        self
    }
    /// Sets `s-maxage` from a duration (whole seconds, saturating at `u32::MAX`).
    #[must_use]
    pub fn s_maxage(self, duration: std::time::Duration) -> Self {
        self.s_maxage_secs(Self::duration_secs(duration))
    }
    /// Sets `stale-while-revalidate` in seconds.
    #[must_use]
    pub fn stale_while_revalidate_secs(mut self, seconds: u32) -> Self {
        self.stale_while_revalidate = Some(seconds);
        self
    }
    /// Sets `stale-if-error` in seconds.
    #[must_use]
    pub fn stale_if_error_secs(mut self, seconds: u32) -> Self {
        self.stale_if_error = Some(seconds);
        self
    }
    /// Renders the header value with directives separated by `", "`.
    #[must_use]
    pub fn build(&self) -> String {
        let mut parts: Vec<String> = self
            .directives
            .iter()
            .map(|directive| directive.as_str().to_string())
            .collect();
        if let Some(age) = self.max_age {
            parts.push(format!("max-age={age}"));
        }
        if let Some(age) = self.s_maxage {
            parts.push(format!("s-maxage={age}"));
        }
        if let Some(seconds) = self.stale_while_revalidate {
            parts.push(format!("stale-while-revalidate={seconds}"));
        }
        if let Some(seconds) = self.stale_if_error {
            parts.push(format!("stale-if-error={seconds}"));
        }
        parts.join(", ")
    }
    /// Returns true when `no-store` or `no-cache` has been added.
    #[must_use]
    pub fn is_no_cache(&self) -> bool {
        self.directives.contains(&CacheDirective::NoStore)
            || self.directives.contains(&CacheDirective::NoCache)
    }
}
/// Ready-made `Cache-Control` policies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CachePreset {
    NoCache,
    PrivateNoCache,
    PublicOneHour,
    Immutable,
    CdnFriendly,
    StaticAssets,
}
impl CachePreset {
    /// Returns the preset as a literal `Cache-Control` header value.
    #[must_use]
    pub fn to_header_value(&self) -> String {
        let value: &'static str = match *self {
            Self::NoCache => "no-store, no-cache, must-revalidate",
            Self::PrivateNoCache => "private, max-age=0, must-revalidate",
            Self::PublicOneHour => "public, max-age=3600",
            Self::Immutable => "public, max-age=31536000, immutable",
            Self::CdnFriendly => "public, max-age=60, s-maxage=3600",
            Self::StaticAssets => "public, max-age=86400",
        };
        value.to_owned()
    }
    /// Returns a [`CacheControlBuilder`] with the same directives as the preset.
    ///
    /// The builder renders numeric directives after flag directives, so its
    /// output may order directives differently from
    /// [`to_header_value`](Self::to_header_value).
    #[must_use]
    pub fn to_builder(&self) -> CacheControlBuilder {
        let base = CacheControlBuilder::new();
        match *self {
            Self::NoCache => base.no_store().no_cache().must_revalidate(),
            Self::PrivateNoCache => base.private().max_age_secs(0).must_revalidate(),
            Self::PublicOneHour => base.public().max_age_secs(3_600),
            Self::Immutable => base.public().max_age_secs(31_536_000).immutable(),
            Self::CdnFriendly => base.public().max_age_secs(60).s_maxage_secs(3_600),
            Self::StaticAssets => base.public().max_age_secs(86_400),
        }
    }
}
#[derive(Debug, Clone)]
pub struct CacheControlConfig {
pub cache_control: String,
pub vary: Vec<String>,
pub set_expires: bool,
pub preserve_existing: bool,
pub methods: Vec<crate::request::Method>,
pub path_patterns: Vec<String>,
pub cacheable_statuses: Vec<u16>,
}
impl Default for CacheControlConfig {
fn default() -> Self {
Self {
cache_control: CachePreset::NoCache.to_header_value(),
vary: Vec::new(),
set_expires: false,
preserve_existing: true,
methods: vec![crate::request::Method::Get, crate::request::Method::Head],
path_patterns: Vec::new(),
cacheable_statuses: (200..300).collect(),
}
}
}
impl CacheControlConfig {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn from_preset(preset: CachePreset) -> Self {
Self {
cache_control: preset.to_header_value(),
..Self::default()
}
}
#[must_use]
pub fn from_builder(builder: CacheControlBuilder) -> Self {
Self {
cache_control: builder.build(),
..Self::default()
}
}
#[must_use]
pub fn cache_control(mut self, value: impl Into<String>) -> Self {
self.cache_control = value.into();
self
}
#[must_use]
pub fn vary(mut self, header: impl Into<String>) -> Self {
self.vary.push(header.into());
self
}
#[must_use]
pub fn vary_headers(mut self, headers: Vec<String>) -> Self {
self.vary.extend(headers);
self
}
#[must_use]
pub fn with_expires(mut self, enable: bool) -> Self {
self.set_expires = enable;
self
}
#[must_use]
pub fn preserve_existing(mut self, preserve: bool) -> Self {
self.preserve_existing = preserve;
self
}
#[must_use]
pub fn methods(mut self, methods: Vec<crate::request::Method>) -> Self {
self.methods = methods;
self
}
#[must_use]
pub fn path_patterns(mut self, patterns: Vec<String>) -> Self {
self.path_patterns = patterns;
self
}
#[must_use]
pub fn cacheable_statuses(mut self, statuses: Vec<u16>) -> Self {
self.cacheable_statuses = statuses;
self
}
}
pub struct CacheControlMiddleware {
config: CacheControlConfig,
}
impl Default for CacheControlMiddleware {
fn default() -> Self {
Self::new()
}
}
impl CacheControlMiddleware {
#[must_use]
pub fn new() -> Self {
Self {
config: CacheControlConfig::default(),
}
}
#[must_use]
pub fn with_preset(preset: CachePreset) -> Self {
Self {
config: CacheControlConfig::from_preset(preset),
}
}
#[must_use]
pub fn with_config(config: CacheControlConfig) -> Self {
Self { config }
}
fn is_cacheable_method(&self, method: crate::request::Method) -> bool {
self.config.methods.contains(&method)
}
fn is_cacheable_status(&self, status: u16) -> bool {
self.config.cacheable_statuses.contains(&status)
}
fn matches_path(&self, path: &str) -> bool {
if self.config.path_patterns.is_empty() {
return true; }
for pattern in &self.config.path_patterns {
if path_matches_pattern(path, pattern) {
return true;
}
}
false
}
fn has_cache_control(headers: &[(String, Vec<u8>)]) -> bool {
headers
.iter()
.any(|(name, _)| name.eq_ignore_ascii_case("cache-control"))
}
fn calculate_expires(cache_control: &str) -> Option<String> {
for directive in cache_control.split(',') {
let directive = directive.trim();
if directive.starts_with("max-age=") {
if let Ok(seconds) = directive[8..].parse::<u64>() {
let now = std::time::SystemTime::now();
if let Some(expires) = now.checked_add(std::time::Duration::from_secs(seconds))
{
return Some(format_http_date(expires));
}
}
}
}
None
}
}
/// Returns whether `path` matches `pattern`, where `*` matches any (possibly
/// empty) run of characters and every other character matches literally.
///
/// The match is anchored at both ends. A pattern without a leading `*` must
/// match from the start of `path`, and one without a trailing `*` must match
/// through its end. Literal segments never overlap, so `/ab*ba` does not match
/// `/aba` and `/api/*/users` does not match `/api/users`.
fn path_matches_pattern(path: &str, pattern: &str) -> bool {
    if !pattern.contains('*') {
        return path == pattern;
    }
    let mut segments: Vec<&str> = pattern.split('*').collect();
    // `pattern` contains at least one `*`, so there are at least two segments.
    let last = segments.pop().unwrap_or("");
    let first = segments.first().copied().unwrap_or("");
    let mut rest = match path.strip_prefix(first) {
        Some(rest) => rest,
        None => return false,
    };
    // Middle literals: consume the earliest occurrence of each. That leaves the
    // longest possible remainder for the segments that follow, which is optimal
    // for `*`-only globs.
    for &segment in segments.iter().skip(1).filter(|s| !s.is_empty()) {
        match rest.find(segment) {
            Some(pos) => rest = &rest[pos + segment.len()..],
            None => return false,
        }
    }
    rest.ends_with(last)
}
/// Formats `time` as an IMF-fixdate (`Sun, 06 Nov 1994 08:49:37 GMT`).
///
/// Instants before the Unix epoch are rendered as the epoch itself.
fn format_http_date(time: std::time::SystemTime) -> String {
    const WEEKDAYS: [&str; 7] = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"];
    const MONTHS: [&str; 12] = [
        "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
    ];
    let since_epoch = match time.duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => return "Thu, 01 Jan 1970 00:00:00 GMT".to_string(),
    };
    let whole_days = since_epoch / 86_400;
    let secs_of_day = since_epoch % 86_400;
    let (year, month, day) = days_to_date(whole_days);
    // 1970-01-01 was a Thursday, which is index 4 in `WEEKDAYS`.
    let weekday = WEEKDAYS[((whole_days + 4) % 7) as usize];
    format!(
        "{}, {:02} {} {} {:02}:{:02}:{:02} GMT",
        weekday,
        day,
        MONTHS[(month - 1) as usize],
        year,
        secs_of_day / 3_600,
        secs_of_day % 3_600 / 60,
        secs_of_day % 60
    )
}
/// Converts a count of days since 1970-01-01 into a proleptic Gregorian
/// `(year, month 1..=12, day 1..=31)` triple.
fn days_to_date(days: u64) -> (u64, u64, u64) {
    let year_length = |year: u64| -> u64 {
        if is_leap_year(year) {
            366
        } else {
            365
        }
    };
    let mut year = 1970;
    let mut day_of_year = days;
    while day_of_year >= year_length(year) {
        day_of_year -= year_length(year);
        year += 1;
    }
    let february = if is_leap_year(year) { 29 } else { 28 };
    let month_lengths: [u64; 12] = [31, february, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
    let mut month = 1;
    for &length in &month_lengths {
        if day_of_year < length {
            break;
        }
        day_of_year -= length;
        month += 1;
    }
    (year, month, day_of_year + 1)
}
/// Gregorian leap-year rule: every 400th year, or every 4th year that is not a century.
fn is_leap_year(year: u64) -> bool {
    year % 400 == 0 || (year % 100 != 0 && year % 4 == 0)
}
impl Middleware for CacheControlMiddleware {
    /// Writes the configured `Cache-Control` (plus optional `Vary` and
    /// `Expires`) onto responses whose method, status and path are all eligible.
    ///
    /// When `preserve_existing` is set and the handler already supplied
    /// `Cache-Control`, the headers pass through untouched. Otherwise the
    /// configured policy replaces any handler-supplied `Cache-Control`, and
    /// `Expires` when a new one is computed. Appending alongside them would
    /// leave two conflicting caching policies on the response.
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async move {
            // `self` outlives the future, so borrow the config instead of cloning it per request.
            let config = &self.config;
            if !self.is_cacheable_method(req.method())
                || !self.is_cacheable_status(response.status().as_u16())
                || !self.matches_path(req.path())
            {
                return response;
            }
            let (status, mut headers, body) = response.into_parts();
            if !(config.preserve_existing && Self::has_cache_control(&headers)) {
                let expires = if config.set_expires {
                    Self::calculate_expires(&config.cache_control)
                } else {
                    None
                };
                let replacing_expires = expires.is_some();
                headers.retain(|(name, _)| {
                    !(name.eq_ignore_ascii_case("cache-control")
                        || (replacing_expires && name.eq_ignore_ascii_case("expires")))
                });
                headers.push((
                    "Cache-Control".to_string(),
                    config.cache_control.as_bytes().to_vec(),
                ));
                if !config.vary.is_empty() {
                    let vary_value = config.vary.join(", ");
                    headers.push(("Vary".to_string(), vary_value.into_bytes()));
                }
                if let Some(expires) = expires {
                    headers.push(("Expires".to_string(), expires.into_bytes()));
                }
            }
            let mut resp = Response::with_status(status);
            for (name, value) in headers {
                resp = resp.header(name, value);
            }
            resp.body(body)
        })
    }
    fn name(&self) -> &'static str {
        "CacheControlMiddleware"
    }
}
/// Rejects HTTP `TRACE` requests with `405 Method Not Allowed` and a JSON body,
/// optionally reporting each blocked attempt on stderr.
#[derive(Debug, Clone)]
pub struct TraceRejectionMiddleware {
    // Whether blocked attempts are written to stderr.
    log_attempts: bool,
}
impl Default for TraceRejectionMiddleware {
    fn default() -> Self {
        Self::new()
    }
}
impl TraceRejectionMiddleware {
    /// Creates the middleware with logging enabled.
    #[must_use]
    pub fn new() -> Self {
        Self { log_attempts: true }
    }
    /// Enables or disables logging of blocked attempts.
    #[must_use]
    pub fn log_attempts(mut self, log: bool) -> Self {
        self.log_attempts = log;
        self
    }
    // Escapes `s` for use inside a JSON string literal (RFC 8259 §7): quotes,
    // backslashes and every control character below U+0020.
    fn escape_json_string(s: &str) -> String {
        let mut escaped = String::with_capacity(s.len());
        for c in s.chars() {
            match c {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                c if u32::from(c) < 0x20 => {
                    escaped.push_str(&format!("\\u{:04x}", u32::from(c)));
                }
                c => escaped.push(c),
            }
        }
        escaped
    }
    // Builds the 405 response. `path` is client-controlled, so it is fully
    // JSON-escaped; escaping only `"` would let a backslash or control
    // character break the body.
    fn rejection_response(path: &str) -> Response {
        let body = format!(
            r#"{{"detail":"HTTP TRACE method is not allowed","path":"{}"}}"#,
            Self::escape_json_string(path)
        );
        Response::with_status(crate::response::StatusCode::METHOD_NOT_ALLOWED)
            .header("Content-Type", b"application/json".to_vec())
            .header(
                "Allow",
                b"GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD".to_vec(),
            )
            .body(crate::response::ResponseBody::Bytes(body.into_bytes()))
    }
}
impl Middleware for TraceRejectionMiddleware {
    /// Short-circuits `TRACE` requests with the rejection response. All other
    /// methods continue down the chain.
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        Box::pin(async move {
            if req.method() != crate::request::Method::Trace {
                return ControlFlow::Continue;
            }
            if self.log_attempts {
                let headers = req.headers();
                // Client address as reported by proxy headers; these are
                // client-controllable, so the value is informational only.
                let remote_ip = match headers
                    .get("X-Forwarded-For")
                    .or_else(|| headers.get("X-Real-IP"))
                {
                    Some(raw) => String::from_utf8_lossy(raw).to_string(),
                    None => "unknown".to_string(),
                };
                eprintln!(
                    "[SECURITY] TRACE request blocked: path={}, remote_ip={}",
                    req.path(),
                    remote_ip
                );
            }
            ControlFlow::Break(Self::rejection_response(req.path()))
        })
    }
    fn name(&self) -> &'static str {
        "TraceRejection"
    }
}
/// Settings for [`HttpsRedirectMiddleware`].
#[derive(Debug, Clone)]
// Each flag is an independent on/off switch, so an enum would not model them better.
#[allow(clippy::struct_excessive_bools)]
pub struct HttpsRedirectConfig {
    /// Redirect plain-HTTP requests at all.
    pub redirect_enabled: bool,
    /// Use `301 Moved Permanently` (`true`) or `307 Temporary Redirect` (`false`).
    pub permanent_redirect: bool,
    /// `max-age` of the `Strict-Transport-Security` header; `0` omits the header.
    pub hsts_max_age_secs: u64,
    /// Append `includeSubDomains` to the HSTS header.
    pub hsts_include_subdomains: bool,
    /// Append `preload` to the HSTS header.
    pub hsts_preload: bool,
    /// Path prefixes that are never redirected.
    pub exclude_paths: Vec<String>,
    /// Port placed in redirect targets; `443` is left out of the URL.
    pub https_port: u16,
}
impl Default for HttpsRedirectConfig {
    /// Permanent redirects to port 443 with a one-year HSTS `max-age`.
    fn default() -> Self {
        const ONE_YEAR_SECS: u64 = 365 * 24 * 60 * 60;
        Self {
            redirect_enabled: true,
            permanent_redirect: true,
            hsts_max_age_secs: ONE_YEAR_SECS,
            hsts_include_subdomains: false,
            hsts_preload: false,
            exclude_paths: Vec::new(),
            https_port: 443,
        }
    }
}
/// Redirects plain-HTTP requests to HTTPS. On requests judged secure, it adds a
/// `Strict-Transport-Security` header to the response.
#[derive(Debug, Clone)]
pub struct HttpsRedirectMiddleware {
    config: HttpsRedirectConfig,
}
impl Default for HttpsRedirectMiddleware {
    fn default() -> Self {
        Self::new()
    }
}
impl HttpsRedirectMiddleware {
    /// Creates the middleware with [`HttpsRedirectConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: HttpsRedirectConfig::default(),
        }
    }
    /// Enables or disables the HTTP → HTTPS redirect.
    #[must_use]
    pub fn redirect_enabled(mut self, enabled: bool) -> Self {
        self.config.redirect_enabled = enabled;
        self
    }
    /// Chooses `301 Moved Permanently` (`true`) or `307 Temporary Redirect` (`false`).
    #[must_use]
    pub fn permanent_redirect(mut self, permanent: bool) -> Self {
        self.config.permanent_redirect = permanent;
        self
    }
    /// Sets the HSTS `max-age` in seconds; `0` suppresses the header.
    #[must_use]
    pub fn hsts_max_age_secs(mut self, secs: u64) -> Self {
        self.config.hsts_max_age_secs = secs;
        self
    }
    /// Adds `includeSubDomains` to the HSTS header.
    #[must_use]
    pub fn include_subdomains(mut self, include: bool) -> Self {
        self.config.hsts_include_subdomains = include;
        self
    }
    /// Adds `preload` to the HSTS header.
    #[must_use]
    pub fn preload(mut self, preload: bool) -> Self {
        self.config.hsts_preload = preload;
        self
    }
    /// Adds a path prefix that is never redirected.
    #[must_use]
    pub fn exclude_path(mut self, path: impl Into<String>) -> Self {
        self.config.exclude_paths.push(path.into());
        self
    }
    /// Replaces the list of excluded path prefixes.
    #[must_use]
    pub fn exclude_paths(mut self, paths: Vec<String>) -> Self {
        self.config.exclude_paths = paths;
        self
    }
    /// Sets the port used in redirect targets; `443` is omitted from the URL.
    #[must_use]
    pub fn https_port(mut self, port: u16) -> Self {
        self.config.https_port = port;
        self
    }
    /// Returns whether the request counts as secure.
    ///
    /// A TLS connection always counts. Otherwise these proxy headers are
    /// consulted in order: `Forwarded: proto=https`, the first
    /// `X-Forwarded-Proto` entry, `X-Forwarded-Ssl: on`, `Front-End-Https: on`.
    ///
    /// NOTE(review): proxy headers are trusted unconditionally. This is only
    /// reliable behind a proxy that overwrites them.
    fn is_secure(&self, req: &Request) -> bool {
        // Strips leading/trailing spaces and tabs (optional whitespace in header values).
        fn trim_ascii(mut bytes: &[u8]) -> &[u8] {
            while matches!(bytes.first(), Some(b' ' | b'\t')) {
                bytes = &bytes[1..];
            }
            while matches!(bytes.last(), Some(b' ' | b'\t')) {
                bytes = &bytes[..bytes.len() - 1];
            }
            bytes
        }
        if let Some(info) = req.get_extension::<crate::request::ConnectionInfo>() {
            if info.is_tls {
                return true;
            }
        }
        if let Some(forwarded) = req.headers().get("Forwarded") {
            if let Ok(s) = std::str::from_utf8(forwarded) {
                for entry in s.split(',') {
                    for param in entry.split(';') {
                        let param = param.trim();
                        if let Some((k, v)) = param.split_once('=') {
                            if k.trim().eq_ignore_ascii_case("proto") {
                                let proto = v.trim().trim_matches('"');
                                if proto.eq_ignore_ascii_case("https") {
                                    return true;
                                }
                            }
                        }
                    }
                }
            }
        }
        if let Some(proto) = req.headers().get("X-Forwarded-Proto") {
            // Only the first (client-nearest) entry of a comma-separated chain is used.
            let first = proto.split(|&b| b == b',').next().unwrap_or(proto);
            return trim_ascii(first).eq_ignore_ascii_case(b"https");
        }
        if let Some(ssl) = req.headers().get("X-Forwarded-Ssl") {
            return trim_ascii(ssl).eq_ignore_ascii_case(b"on");
        }
        if let Some(https) = req.headers().get("Front-End-Https") {
            return trim_ascii(https).eq_ignore_ascii_case(b"on");
        }
        false
    }
    /// Whether `path` starts with any excluded prefix.
    fn is_excluded(&self, path: &str) -> bool {
        self.config
            .exclude_paths
            .iter()
            .any(|p| path.starts_with(p))
    }
    /// Builds the `Strict-Transport-Security` value, or `None` when `max-age` is 0.
    fn build_hsts_header(&self) -> Option<Vec<u8>> {
        if self.config.hsts_max_age_secs == 0 {
            return None;
        }
        let mut value = format!("max-age={}", self.config.hsts_max_age_secs);
        if self.config.hsts_include_subdomains {
            value.push_str("; includeSubDomains");
        }
        if self.config.hsts_preload {
            value.push_str("; preload");
        }
        Some(value.into_bytes())
    }
    /// Builds `https://host[:port]/path[?query]` from the request's `Host`
    /// header, falling back to `localhost`. The incoming port is dropped and
    /// replaced by the configured HTTPS port.
    fn build_redirect_url(&self, req: &Request) -> String {
        // Removes an optional `:port` suffix. Bracketed IPv6 literals such as
        // `[::1]:8080` keep everything through the closing bracket; splitting
        // on the first ':' would cut them to `[`.
        fn strip_port(host: &str) -> &str {
            if host.starts_with('[') {
                return host.find(']').map_or(host, |end| &host[..=end]);
            }
            host.split(':').next().unwrap_or(host)
        }
        let host = req
            .headers()
            .get("Host")
            .map(|h| String::from_utf8_lossy(h).to_string())
            .unwrap_or_else(|| "localhost".to_string());
        let host = strip_port(host.trim());
        let authority = if self.config.https_port == 443 {
            host.to_string()
        } else {
            format!("{}:{}", host, self.config.https_port)
        };
        let path = req.path();
        match req.query() {
            Some(q) => format!("https://{}{}?{}", authority, path, q),
            None => format!("https://{}{}", authority, path),
        }
    }
}
impl Middleware for HttpsRedirectMiddleware {
    /// Answers insecure, non-excluded requests with a redirect to the HTTPS URL
    /// while redirects are enabled. All other requests continue.
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        Box::pin(async move {
            let must_redirect = self.config.redirect_enabled
                && !self.is_secure(req)
                && !self.is_excluded(req.path());
            if !must_redirect {
                return ControlFlow::Continue;
            }
            let location = self.build_redirect_url(req);
            let redirect_status = if self.config.permanent_redirect {
                crate::response::StatusCode::MOVED_PERMANENTLY
            } else {
                crate::response::StatusCode::TEMPORARY_REDIRECT
            };
            ControlFlow::Break(
                Response::with_status(redirect_status)
                    .header("Location", location.into_bytes())
                    .header("Content-Type", b"text/plain".to_vec())
                    .body(crate::response::ResponseBody::Bytes(
                        b"Redirecting to HTTPS...".to_vec(),
                    )),
            )
        })
    }
    /// Adds `Strict-Transport-Security` to responses for secure requests when
    /// HSTS is configured.
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async move {
            match self.build_hsts_header() {
                Some(hsts_value) if self.is_secure(req) => {
                    response.header("Strict-Transport-Security", hsts_value)
                }
                _ => response,
            }
        })
    }
    fn name(&self) -> &'static str {
        "HttpsRedirect"
    }
}
/// A post-processing step that receives a response and returns a (possibly
/// modified) replacement.
pub trait ResponseInterceptor: Send + Sync {
    /// Transforms `response`. `ctx` exposes the originating request, its
    /// request context and the start instant.
    fn intercept<'a>(
        &'a self,
        ctx: &'a ResponseInterceptorContext<'a>,
        response: Response,
    ) -> BoxFuture<'a, Response>;

    /// Identifier for diagnostics; defaults to the implementing type's name.
    fn name(&self) -> &'static str {
        std::any::type_name::<Self>()
    }
}
/// Data handed to each [`ResponseInterceptor`] call.
#[derive(Debug)]
pub struct ResponseInterceptorContext<'a> {
    /// The request the response answers.
    pub request: &'a Request,
    /// Instant from which elapsed time is measured.
    pub start_time: Instant,
    /// Per-request framework context.
    pub request_ctx: &'a RequestContext,
}
impl<'a> ResponseInterceptorContext<'a> {
    /// Bundles the request, its context and the start instant.
    pub fn new(request: &'a Request, request_ctx: &'a RequestContext, start_time: Instant) -> Self {
        Self {
            request,
            request_ctx,
            start_time,
        }
    }
    /// Time elapsed since `start_time`.
    pub fn elapsed(&self) -> std::time::Duration {
        Instant::elapsed(&self.start_time)
    }
    /// Time elapsed since `start_time`, in whole milliseconds.
    pub fn elapsed_ms(&self) -> u128 {
        self.elapsed().as_millis()
    }
}
/// An ordered list of interceptors applied one after another.
#[derive(Default)]
pub struct ResponseInterceptorStack {
    interceptors: Vec<Arc<dyn ResponseInterceptor>>,
}
impl ResponseInterceptorStack {
    /// Creates an empty stack.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
    /// Creates an empty stack with room for `capacity` interceptors.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            interceptors: Vec::with_capacity(capacity),
        }
    }
    /// Appends an interceptor, wrapping it in an `Arc`.
    pub fn push<I: ResponseInterceptor + 'static>(&mut self, interceptor: I) {
        self.push_arc(Arc::new(interceptor));
    }
    /// Appends an already shared interceptor.
    pub fn push_arc(&mut self, interceptor: Arc<dyn ResponseInterceptor>) {
        self.interceptors.push(interceptor);
    }
    /// Number of registered interceptors.
    #[must_use]
    pub fn len(&self) -> usize {
        self.interceptors.len()
    }
    /// Whether no interceptors are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.interceptors.is_empty()
    }
    /// Threads `response` through every interceptor in insertion order.
    pub async fn process(
        &self,
        ctx: &ResponseInterceptorContext<'_>,
        response: Response,
    ) -> Response {
        let mut current = response;
        for interceptor in self.interceptors.iter() {
            // NOTE(review): the checkpoint result is discarded, so a cancelled
            // request still runs every interceptor — confirm this is intended.
            let _ = ctx.request_ctx.checkpoint();
            current = interceptor.intercept(ctx, current).await;
        }
        current
    }
}
/// Interceptor that reports elapsed request time in a response header, and
/// optionally in `Server-Timing`.
#[derive(Debug, Clone)]
pub struct TimingInterceptor {
    header_name: String,
    include_server_timing: bool,
    server_timing_name: String,
}
impl Default for TimingInterceptor {
    fn default() -> Self {
        Self::new()
    }
}
impl TimingInterceptor {
    /// Writes `X-Response-Time`; `Server-Timing` is off (metric name `total`).
    #[must_use]
    pub fn new() -> Self {
        Self {
            header_name: String::from("X-Response-Time"),
            include_server_timing: false,
            server_timing_name: String::from("total"),
        }
    }
    /// Enables `Server-Timing` under the given metric name.
    #[must_use]
    pub fn with_server_timing(self, metric_name: impl Into<String>) -> Self {
        Self {
            include_server_timing: true,
            server_timing_name: metric_name.into(),
            ..self
        }
    }
    /// Changes the header that carries the `<n>ms` value.
    #[must_use]
    pub fn header_name(self, name: impl Into<String>) -> Self {
        Self {
            header_name: name.into(),
            ..self
        }
    }
}
impl ResponseInterceptor for TimingInterceptor {
    /// Adds `<header_name>: <n>ms` and, when enabled,
    /// `Server-Timing: <metric>;dur=<n>`, with `n` in whole milliseconds.
    fn intercept<'a>(
        &'a self,
        ctx: &'a ResponseInterceptorContext<'a>,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async move {
            // Read the clock once so both headers report the same figure.
            let elapsed_ms = ctx.elapsed_ms();
            // The formatted value is consumed directly; no intermediate clone is needed.
            let response = response.header(
                &self.header_name,
                format!("{}ms", elapsed_ms).into_bytes(),
            );
            if !self.include_server_timing {
                return response;
            }
            let server_timing = format!("{};dur={}", self.server_timing_name, elapsed_ms);
            response.header("Server-Timing", server_timing.into_bytes())
        })
    }
    fn name(&self) -> &'static str {
        "TimingInterceptor"
    }
}
/// Interceptor that copies request details (path, method, request id, elapsed
/// time) into prefixed response headers.
#[derive(Debug, Clone)]
// Each flag toggles one independent header, so separate bools are the clearest model.
#[allow(clippy::struct_excessive_bools)]
pub struct DebugInfoInterceptor {
    include_path: bool,
    include_method: bool,
    include_request_id: bool,
    include_timing: bool,
    header_prefix: String,
}
impl Default for DebugInfoInterceptor {
    fn default() -> Self {
        Self::new()
    }
}
impl DebugInfoInterceptor {
    /// Enables every header, using the `X-Debug-` prefix.
    #[must_use]
    pub fn new() -> Self {
        Self {
            include_path: true,
            include_method: true,
            include_request_id: true,
            include_timing: true,
            header_prefix: String::from("X-Debug-"),
        }
    }
    /// Toggles the `<prefix>Path` header.
    #[must_use]
    pub fn include_path(self, include: bool) -> Self {
        Self {
            include_path: include,
            ..self
        }
    }
    /// Toggles the `<prefix>Method` header.
    #[must_use]
    pub fn include_method(self, include: bool) -> Self {
        Self {
            include_method: include,
            ..self
        }
    }
    /// Toggles the `<prefix>Request-Id` header.
    #[must_use]
    pub fn include_request_id(self, include: bool) -> Self {
        Self {
            include_request_id: include,
            ..self
        }
    }
    /// Toggles the `<prefix>Handler-Time` header.
    #[must_use]
    pub fn include_timing(self, include: bool) -> Self {
        Self {
            include_timing: include,
            ..self
        }
    }
    /// Sets the prefix prepended to every header name.
    #[must_use]
    pub fn header_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            header_prefix: prefix.into(),
            ..self
        }
    }
}
impl ResponseInterceptor for DebugInfoInterceptor {
    /// Appends the enabled `<prefix>*` headers. `Request-Id` is written only
    /// when a `RequestId` extension is present on the request.
    fn intercept<'a>(
        &'a self,
        ctx: &'a ResponseInterceptorContext<'a>,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async move {
            let prefix = &self.header_prefix;
            let mut out = response;
            if self.include_path {
                out = out.header(
                    format!("{}Path", prefix),
                    ctx.request.path().as_bytes().to_vec(),
                );
            }
            if self.include_method {
                out = out.header(
                    format!("{}Method", prefix),
                    ctx.request.method().as_str().as_bytes().to_vec(),
                );
            }
            if self.include_request_id {
                if let Some(id) = ctx.request.get_extension::<RequestId>() {
                    out = out.header(format!("{}Request-Id", prefix), id.0.as_bytes().to_vec());
                }
            }
            if self.include_timing {
                let millis = format!("{}ms", ctx.elapsed_ms());
                out = out.header(format!("{}Handler-Time", prefix), millis.into_bytes());
            }
            out
        })
    }
    fn name(&self) -> &'static str {
        "DebugInfoInterceptor"
    }
}
/// Interceptor that rewrites buffered response bodies with a user function,
/// optionally only for a given content type.
pub struct ResponseBodyTransform<F>
where
    F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync,
{
    transform_fn: F,
    content_type_filter: Option<String>,
}
impl<F> ResponseBodyTransform<F>
where
    F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync,
{
    /// Applies `transform_fn` to every buffered body.
    pub fn new(transform_fn: F) -> Self {
        Self {
            transform_fn,
            content_type_filter: None,
        }
    }
    /// Restricts the transform to responses whose `Content-Type` starts with
    /// `content_type`, ignoring ASCII case.
    #[must_use]
    pub fn for_content_type(mut self, content_type: impl Into<String>) -> Self {
        self.content_type_filter = Some(content_type.into());
        self
    }
    // With no filter every response qualifies. With a filter, the first
    // `Content-Type` header must be UTF-8 and start with the filter. Media types
    // are case-insensitive (RFC 9110 §8.3.1), so `Application/JSON` matches an
    // `application/json` filter.
    fn should_transform(&self, response: &Response) -> bool {
        match &self.content_type_filter {
            Some(filter) => response
                .headers()
                .iter()
                .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
                .and_then(|(_, ct)| std::str::from_utf8(ct).ok())
                .map_or(false, |ct| {
                    ct.trim_start()
                        .get(..filter.len())
                        .map_or(false, |prefix| prefix.eq_ignore_ascii_case(filter))
                }),
            None => true,
        }
    }
}
impl<F> ResponseInterceptor for ResponseBodyTransform<F>
where
    F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync,
{
    /// Replaces a buffered (or empty) body with `transform_fn(body)`.
    /// Streaming bodies and filtered-out responses are returned unchanged.
    fn intercept<'a>(
        &'a self,
        _ctx: &'a ResponseInterceptorContext<'a>,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async move {
            if !self.should_transform(&response) {
                return response;
            }
            let original = match response.body_ref() {
                crate::response::ResponseBody::Bytes(bytes) => bytes.clone(),
                crate::response::ResponseBody::Empty => Vec::new(),
                crate::response::ResponseBody::Stream(_) => return response,
            };
            let rewritten = (self.transform_fn)(original);
            response.body(crate::response::ResponseBody::Bytes(rewritten))
        })
    }
    fn name(&self) -> &'static str {
        "ResponseBodyTransform"
    }
}
/// Interceptor that renames, adds and removes response headers, in that order.
#[derive(Debug, Clone, Default)]
pub struct HeaderTransformInterceptor {
    add_headers: Vec<(String, Vec<u8>)>,
    remove_headers: Vec<String>,
    rename_headers: Vec<(String, String)>,
}
impl HeaderTransformInterceptor {
    /// Creates an interceptor with no operations.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
    /// Queues a header to append.
    #[must_use]
    pub fn add(mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
        let entry = (name.into(), value.into());
        self.add_headers.push(entry);
        self
    }
    /// Queues a header name to remove.
    #[must_use]
    pub fn remove(mut self, name: impl Into<String>) -> Self {
        let header = name.into();
        self.remove_headers.push(header);
        self
    }
    /// Queues a rename from `old_name` to `new_name`.
    #[must_use]
    pub fn rename(mut self, old_name: impl Into<String>, new_name: impl Into<String>) -> Self {
        let pair = (old_name.into(), new_name.into());
        self.rename_headers.push(pair);
        self
    }
}
impl ResponseInterceptor for HeaderTransformInterceptor {
    /// Applies renames first, moving every value of a matching header (compared
    /// case-insensitively) to the new name. It then appends the added headers and
    /// finally removes the listed ones.
    fn intercept<'a>(
        &'a self,
        _ctx: &'a ResponseInterceptorContext<'a>,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        // `self` outlives the returned future, so the configuration is borrowed
        // rather than cloning all three vectors on every request.
        Box::pin(async move {
            let mut resp = response;
            for (old_name, new_name) in &self.rename_headers {
                let values: Vec<Vec<u8>> = resp
                    .headers()
                    .iter()
                    .filter(|(name, _)| name.eq_ignore_ascii_case(old_name))
                    .map(|(_, v)| v.clone())
                    .collect();
                if values.is_empty() {
                    continue;
                }
                resp = resp.remove_header(old_name);
                for value in values {
                    resp = resp.header(new_name, value);
                }
            }
            for (name, value) in &self.add_headers {
                resp = resp.header(name.clone(), value.clone());
            }
            for name in &self.remove_headers {
                resp = resp.remove_header(name);
            }
            resp
        })
    }
    fn name(&self) -> &'static str {
        "HeaderTransformInterceptor"
    }
}
/// Runs `inner` only when `condition(ctx, &response)` returns true.
pub struct ConditionalInterceptor<I, F>
where
    I: ResponseInterceptor,
    F: Fn(&ResponseInterceptorContext, &Response) -> bool + Send + Sync,
{
    inner: I,
    condition: F,
}
impl<I, F> ConditionalInterceptor<I, F>
where
    I: ResponseInterceptor,
    F: Fn(&ResponseInterceptorContext, &Response) -> bool + Send + Sync,
{
    /// Wraps `inner` behind `condition`.
    pub fn new(inner: I, condition: F) -> Self {
        Self { inner, condition }
    }
}
impl<I, F> ResponseInterceptor for ConditionalInterceptor<I, F>
where
    I: ResponseInterceptor,
    F: Fn(&ResponseInterceptorContext, &Response) -> bool + Send + Sync,
{
    /// Delegates to the wrapped interceptor when the predicate holds, and
    /// otherwise returns the response unchanged.
    fn intercept<'a>(
        &'a self,
        ctx: &'a ResponseInterceptorContext<'a>,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async move {
            let applies = (self.condition)(ctx, &response);
            if !applies {
                return response;
            }
            self.inner.intercept(ctx, response).await
        })
    }
    fn name(&self) -> &'static str {
        "ConditionalInterceptor"
    }
}
/// Interceptor that, for selected status codes, can replace the response body
/// and/or add an `X-Error-Id` header.
#[derive(Debug, Clone)]
pub struct ErrorResponseTransformer {
    status_codes: HashSet<u16>,
    replacement_body: Option<Vec<u8>>,
    add_error_id: bool,
}
impl Default for ErrorResponseTransformer {
    fn default() -> Self {
        Self::new()
    }
}
impl ErrorResponseTransformer {
    /// Creates a transformer that matches no status codes.
    #[must_use]
    pub fn new() -> Self {
        Self {
            status_codes: HashSet::new(),
            replacement_body: None,
            add_error_id: false,
        }
    }
    /// Adds a status code whose responses should be transformed.
    #[must_use]
    pub fn hide_details_for_status(mut self, status: crate::response::StatusCode) -> Self {
        let code = status.as_u16();
        self.status_codes.insert(code);
        self
    }
    /// Sets the body that replaces the original one on matching responses.
    #[must_use]
    pub fn with_replacement_body(self, body: impl Into<Vec<u8>>) -> Self {
        Self {
            replacement_body: Some(body.into()),
            ..self
        }
    }
    /// Enables or disables the `X-Error-Id` header on matching responses.
    #[must_use]
    pub fn add_error_id(self, enable: bool) -> Self {
        Self {
            add_error_id: enable,
            ..self
        }
    }
}
impl ResponseInterceptor for ErrorResponseTransformer {
    /// For configured status codes, swaps in the replacement body (if any) and,
    /// when enabled, adds `X-Error-Id`.
    ///
    /// The id is the request's `RequestId` extension when present. Otherwise it
    /// is a generated `err-<unix-nanos hex>-<counter hex>` value that is unique
    /// within the process. The previous fallback, `err-<elapsed ms>`, repeated
    /// across requests (e.g. `err-0`), which made it useless for correlation.
    fn intercept<'a>(
        &'a self,
        ctx: &'a ResponseInterceptorContext<'a>,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        // Generates a process-unique fallback error id from wall-clock time plus
        // a monotonically increasing counter.
        fn fallback_error_id() -> String {
            static NEXT_ID: std::sync::atomic::AtomicUsize =
                std::sync::atomic::AtomicUsize::new(0);
            let seq = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_or(0, |d| d.as_nanos());
            format!("err-{:x}-{:x}", nanos, seq)
        }

        Box::pin(async move {
            if !self.status_codes.contains(&response.status().as_u16()) {
                return response;
            }
            let mut resp = match &self.replacement_body {
                Some(replacement) => {
                    response.body(crate::response::ResponseBody::Bytes(replacement.clone()))
                }
                None => response,
            };
            if self.add_error_id {
                let error_id = match ctx.request.get_extension::<RequestId>() {
                    Some(request_id) => request_id.0.clone(),
                    None => fallback_error_id(),
                };
                resp = resp.header("X-Error-Id", error_id.into_bytes());
            }
            resp
        })
    }
    fn name(&self) -> &'static str {
        "ErrorResponseTransformer"
    }
}
/// Adapts a [`ResponseInterceptor`] so it can be registered as a
/// [`Middleware`].
///
/// `before` records the current instant on the request. `after` builds a
/// [`ResponseInterceptorContext`] from that instant and hands the response
/// to the wrapped interceptor.
// The `ResponseInterceptor` bound lives on the impl blocks rather than on the
// struct definition. Bounds on a struct must be repeated at every use site
// and add nothing here.
pub struct ResponseInterceptorMiddleware<I> {
    interceptor: I,
}
impl<I> ResponseInterceptorMiddleware<I>
where
    I: ResponseInterceptor,
{
    /// Wraps `interceptor` for use in a middleware stack.
    #[must_use]
    pub fn new(interceptor: I) -> Self {
        Self { interceptor }
    }
}
impl<I> Middleware for ResponseInterceptorMiddleware<I>
where
    I: ResponseInterceptor,
{
    /// Stores an [`InterceptorStartTime`] extension holding `Instant::now()`.
    ///
    /// NOTE(review): presumably extensions are keyed by type. If so, stacked
    /// instances overwrite each other's start time. Confirm this is intended.
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        req.insert_extension(InterceptorStartTime(Instant::now()));
        Box::pin(async { ControlFlow::Continue })
    }

    /// Runs the wrapped interceptor over `response`.
    ///
    /// If `before` left no start time on the request, "now" is used instead,
    /// so the elapsed time reported to the interceptor is ~0.
    fn after<'a>(
        &'a self,
        ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async move {
            let start_time = req
                .get_extension::<InterceptorStartTime>()
                .map(|t| t.0)
                .unwrap_or_else(Instant::now);
            let interceptor_ctx = ResponseInterceptorContext::new(req, ctx, start_time);
            self.interceptor.intercept(&interceptor_ctx, response).await
        })
    }

    /// Reports the wrapped interceptor's name.
    fn name(&self) -> &'static str {
        self.interceptor.name()
    }
}
/// Request extension recording when `ResponseInterceptorMiddleware::before` ran.
#[derive(Debug, Clone, Copy)]
struct InterceptorStartTime(Instant);
/// A single metric of a `Server-Timing` response header.
///
/// Rendered by [`ServerTimingEntry::to_header_value`] as `name;dur=<ms>`,
/// followed by `;desc="<text>"` when a description is set.
#[derive(Debug, Clone)]
pub struct ServerTimingEntry {
    /// Metric name, emitted verbatim (not validated or escaped).
    name: String,
    /// Duration in milliseconds, formatted with three decimal places.
    duration_ms: f64,
    /// Optional human-readable description, emitted as a quoted string.
    description: Option<String>,
}
impl ServerTimingEntry {
    /// Creates an entry with the given name and duration (milliseconds) and
    /// no description.
    #[must_use]
    pub fn new(name: impl Into<String>, duration_ms: f64) -> Self {
        Self {
            name: name.into(),
            duration_ms,
            description: None,
        }
    }

    /// Attaches a human-readable description to the entry.
    #[must_use]
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }

    /// Renders this entry as one `Server-Timing` metric.
    ///
    /// The description is written as an HTTP quoted-string:
    /// - `"` and `\` are backslash-escaped;
    /// - control characters other than horizontal tab (CR, LF, NUL, ...) are
    ///   dropped.
    ///
    /// As a result, an arbitrary description can neither end the quoted
    /// string early nor split the header line.
    #[must_use]
    pub fn to_header_value(&self) -> String {
        match &self.description {
            Some(desc) => format!(
                "{};dur={:.3};desc=\"{}\"",
                self.name,
                self.duration_ms,
                Self::escape_quoted_string(desc)
            ),
            None => format!("{};dur={:.3}", self.name, self.duration_ms),
        }
    }

    /// Escapes `raw` for the inside of an HTTP quoted-string (RFC 9110 §5.6.4).
    fn escape_quoted_string(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' | '\\' => {
                    escaped.push('\\');
                    escaped.push(ch);
                }
                // HTAB is legal inside a quoted-string; other controls are not,
                // even as a quoted-pair.
                '\t' => escaped.push(ch),
                c if c.is_control() => {}
                c => escaped.push(c),
            }
        }
        escaped
    }
}
/// Collects [`ServerTimingEntry`] values and joins them into one
/// `Server-Timing` header value, in insertion order.
#[derive(Debug, Clone, Default)]
pub struct ServerTimingBuilder {
    entries: Vec<ServerTimingEntry>,
}
impl ServerTimingBuilder {
    /// Creates an empty builder.
    #[must_use]
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
        }
    }

    /// Appends a metric without a description.
    #[must_use]
    pub fn add(self, name: impl Into<String>, duration_ms: f64) -> Self {
        self.add_entry(ServerTimingEntry::new(name, duration_ms))
    }

    /// Appends a metric with a description.
    #[must_use]
    pub fn add_with_desc(
        self,
        name: impl Into<String>,
        duration_ms: f64,
        description: impl Into<String>,
    ) -> Self {
        let entry = ServerTimingEntry::new(name, duration_ms).with_description(description);
        self.add_entry(entry)
    }

    /// Appends a pre-built entry.
    #[must_use]
    pub fn add_entry(mut self, entry: ServerTimingEntry) -> Self {
        self.entries.push(entry);
        self
    }

    /// Renders all entries separated by `", "`; empty when there are none.
    #[must_use]
    pub fn build(&self) -> String {
        let mut header = String::new();
        for (index, entry) in self.entries.iter().enumerate() {
            if index > 0 {
                header.push_str(", ");
            }
            header.push_str(&entry.to_header_value());
        }
        header
    }

    /// Returns `true` when no entries have been added.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Number of entries added so far.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }
}
/// Per-request timing data.
///
/// `TimingMetricsMiddleware` stores one of these on each request as an
/// extension.
#[derive(Debug, Clone)]
pub struct TimingMetrics {
    /// Instant from which `total_ms` and `ttfb_ms` are measured.
    pub start_time: Instant,
    /// Set by [`TimingMetrics::mark_first_byte`]; `None` until then.
    pub first_byte_time: Option<Instant>,
    /// Extra `(name, duration_ms, description)` metrics, in insertion order.
    pub custom_metrics: Vec<(String, f64, Option<String>)>,
}
impl TimingMetrics {
    /// Starts timing now.
    #[must_use]
    pub fn new() -> Self {
        Self::with_start_time(Instant::now())
    }

    /// Starts timing from an explicit instant.
    #[must_use]
    pub fn with_start_time(start_time: Instant) -> Self {
        Self {
            start_time,
            first_byte_time: None,
            custom_metrics: Vec::new(),
        }
    }

    /// Records "now" as the time-to-first-byte point (overwriting any earlier mark).
    pub fn mark_first_byte(&mut self) {
        self.first_byte_time = Some(Instant::now());
    }

    /// Adds a custom metric without a description.
    pub fn add_metric(&mut self, name: impl Into<String>, duration_ms: f64) {
        self.custom_metrics.push((name.into(), duration_ms, None));
    }

    /// Adds a custom metric with a description.
    pub fn add_metric_with_desc(
        &mut self,
        name: impl Into<String>,
        duration_ms: f64,
        desc: impl Into<String>,
    ) {
        let entry = (name.into(), duration_ms, Some(desc.into()));
        self.custom_metrics.push(entry);
    }

    /// Milliseconds elapsed since `start_time`, measured at call time.
    #[must_use]
    pub fn total_ms(&self) -> f64 {
        let elapsed = self.start_time.elapsed();
        elapsed.as_secs_f64() * 1000.0
    }

    /// Milliseconds from `start_time` to the first-byte mark, if one was recorded.
    #[must_use]
    pub fn ttfb_ms(&self) -> Option<f64> {
        let first_byte = self.first_byte_time?;
        Some(first_byte.duration_since(self.start_time).as_secs_f64() * 1000.0)
    }

    /// Builds a Server-Timing header from these metrics.
    ///
    /// The entries are, in order:
    /// - `total`;
    /// - `ttfb`, when it was marked;
    /// - every custom metric.
    #[must_use]
    pub fn to_server_timing(&self) -> ServerTimingBuilder {
        let builder = ServerTimingBuilder::new().add_with_desc(
            "total",
            self.total_ms(),
            "Total request time",
        );
        let builder = match self.ttfb_ms() {
            Some(ttfb) => builder.add_with_desc("ttfb", ttfb, "Time to first byte"),
            None => builder,
        };
        self.custom_metrics
            .iter()
            .fold(builder, |acc, (name, duration, desc)| match desc {
                Some(d) => acc.add_with_desc(name, *duration, d),
                None => acc.add(name, *duration),
            })
    }
}
impl Default for TimingMetrics {
    /// Equivalent to [`TimingMetrics::new`]: starts timing now.
    fn default() -> Self {
        Self::new()
    }
}
/// Selects which timing headers `TimingMetricsMiddleware` emits.
#[derive(Debug, Clone)]
// Each flag independently toggles one header or Server-Timing section, so a
// bool per option is the natural shape here.
#[allow(clippy::struct_excessive_bools)]
pub struct TimingMetricsConfig {
    /// Emit a `Server-Timing` header.
    pub add_server_timing_header: bool,
    /// Emit the response-time header named by `response_time_header_name`.
    pub add_response_time_header: bool,
    /// Name of the response-time header (default `X-Response-Time`).
    pub response_time_header_name: String,
    /// Include custom metrics in `Server-Timing`.
    pub include_custom_metrics: bool,
    /// Include the `ttfb` entry in `Server-Timing` when available.
    pub include_ttfb: bool,
}
impl Default for TimingMetricsConfig {
    /// Everything enabled, with the response-time header named `X-Response-Time`.
    fn default() -> Self {
        Self {
            add_server_timing_header: true,
            add_response_time_header: true,
            response_time_header_name: "X-Response-Time".to_string(),
            include_custom_metrics: true,
            include_ttfb: true,
        }
    }
}
impl TimingMetricsConfig {
    /// Same as [`TimingMetricsConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables or disables the `Server-Timing` header.
    #[must_use]
    pub fn server_timing(self, enabled: bool) -> Self {
        Self {
            add_server_timing_header: enabled,
            ..self
        }
    }

    /// Enables or disables the response-time header.
    #[must_use]
    pub fn response_time(self, enabled: bool) -> Self {
        Self {
            add_response_time_header: enabled,
            ..self
        }
    }

    /// Renames the response-time header.
    #[must_use]
    pub fn response_time_header(self, name: impl Into<String>) -> Self {
        Self {
            response_time_header_name: name.into(),
            ..self
        }
    }

    /// Enables or disables custom metrics in `Server-Timing`.
    #[must_use]
    pub fn custom_metrics(self, enabled: bool) -> Self {
        Self {
            include_custom_metrics: enabled,
            ..self
        }
    }

    /// Enables or disables the `ttfb` entry in `Server-Timing`.
    #[must_use]
    pub fn ttfb(self, enabled: bool) -> Self {
        Self {
            include_ttfb: enabled,
            ..self
        }
    }

    /// Only the response-time header; `Server-Timing` and its optional
    /// sections are turned off.
    #[must_use]
    pub fn production() -> Self {
        Self {
            add_server_timing_header: false,
            include_custom_metrics: false,
            include_ttfb: false,
            ..Self::default()
        }
    }

    /// Same as [`TimingMetricsConfig::default`] (everything enabled).
    #[must_use]
    pub fn development() -> Self {
        Self::default()
    }
}
/// Middleware that attaches [`TimingMetrics`] to each request and reports
/// them through response headers, as selected by its [`TimingMetricsConfig`].
///
/// `Default` uses [`TimingMetricsConfig::default`] (all headers enabled).
#[derive(Debug, Clone, Default)]
pub struct TimingMetricsMiddleware {
    config: TimingMetricsConfig,
}
impl TimingMetricsMiddleware {
    /// Creates the middleware with the default (all-enabled) configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates the middleware with an explicit configuration.
    #[must_use]
    pub fn with_config(config: TimingMetricsConfig) -> Self {
        Self { config }
    }

    /// Uses [`TimingMetricsConfig::production`].
    #[must_use]
    pub fn production() -> Self {
        Self::with_config(TimingMetricsConfig::production())
    }

    /// Uses [`TimingMetricsConfig::development`].
    #[must_use]
    pub fn development() -> Self {
        Self::with_config(TimingMetricsConfig::development())
    }
}
impl Middleware for TimingMetricsMiddleware {
    /// Stores a fresh [`TimingMetrics`] (start time = now) on the request.
    ///
    /// Handlers can attach custom metrics to it, and `after` reports it.
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        req.insert_extension(TimingMetrics::new());
        Box::pin(async { ControlFlow::Continue })
    }

    /// Adds the timing headers enabled in the configuration.
    ///
    /// When the request carries [`TimingMetrics`]:
    /// - the response-time header is set to the elapsed time as `"<ms>.3ms"`;
    /// - `Server-Timing` lists `total`, then optionally `ttfb`, then
    ///   optionally the custom metrics.
    ///
    /// The clock is sampled once, so both headers report the same total.
    ///
    /// Without metrics (presumably because `before` did not run for this
    /// request), only the response-time header is added, with the
    /// placeholder `0.000ms`.
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        // Borrow the config for the future's lifetime instead of cloning it
        // (and its header-name `String`) on every request.
        let config = &self.config;
        Box::pin(async move {
            let mut resp = response;
            match req.get_extension::<TimingMetrics>() {
                Some(metrics) => {
                    // Single clock sample shared by both headers, so that
                    // X-Response-Time and Server-Timing `total` agree.
                    let total_ms = metrics.total_ms();
                    if config.add_response_time_header {
                        let timing = format!("{:.3}ms", total_ms);
                        resp = resp.header(&config.response_time_header_name, timing.into_bytes());
                    }
                    if config.add_server_timing_header {
                        let mut builder = ServerTimingBuilder::new().add_with_desc(
                            "total",
                            total_ms,
                            "Total request time",
                        );
                        if config.include_ttfb {
                            if let Some(ttfb) = metrics.ttfb_ms() {
                                builder = builder.add_with_desc("ttfb", ttfb, "Time to first byte");
                            }
                        }
                        if config.include_custom_metrics {
                            for (name, duration, desc) in &metrics.custom_metrics {
                                builder = match desc {
                                    Some(d) => builder.add_with_desc(name, *duration, d),
                                    None => builder.add(name, *duration),
                                };
                            }
                        }
                        let header_value = builder.build();
                        resp = resp.header("Server-Timing", header_value.into_bytes());
                    }
                }
                None => {
                    if config.add_response_time_header {
                        resp = resp.header(&config.response_time_header_name, b"0.000ms".to_vec());
                    }
                }
            }
            resp
        })
    }

    fn name(&self) -> &'static str {
        "TimingMetrics"
    }
}
/// One cumulative bucket: `count` observations were `<= le` milliseconds.
#[derive(Debug, Clone)]
pub struct TimingHistogramBucket {
    pub le: f64,
    pub count: u64,
}
/// Fixed-bucket latency histogram (milliseconds) with cumulative buckets,
/// plus a running sum and count.
///
/// Observations above every bound are reflected only in `count()` and
/// `sum()`.
#[derive(Debug, Clone)]
pub struct TimingHistogram {
    bucket_bounds: Vec<f64>,
    bucket_counts: Vec<u64>,
    sum: f64,
    count: u64,
}
impl TimingHistogram {
    /// Creates an empty histogram with the given upper bounds.
    ///
    /// Bounds are used in the given order and are not sorted.
    #[must_use]
    pub fn with_buckets(bucket_bounds: Vec<f64>) -> Self {
        Self {
            bucket_counts: vec![0; bucket_bounds.len()],
            bucket_bounds,
            sum: 0.0,
            count: 0,
        }
    }

    /// Bounds of 1 ms to 10 s, suited to HTTP request latency.
    #[must_use]
    pub fn http_latency() -> Self {
        Self::with_buckets(vec![
            1.0, 5.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 1000.0, 2500.0, 5000.0, 10000.0,
        ])
    }

    /// Records one observation.
    ///
    /// The observation increments every bucket whose bound is `>= value_ms`.
    pub fn observe(&mut self, value_ms: f64) {
        self.sum += value_ms;
        self.count += 1;
        for (bound, slot) in self.bucket_bounds.iter().zip(&mut self.bucket_counts) {
            if value_ms <= *bound {
                *slot += 1;
            }
        }
    }

    /// Total number of observations.
    #[must_use]
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Sum of all observed values.
    #[must_use]
    pub fn sum(&self) -> f64 {
        self.sum
    }

    /// Arithmetic mean of the observations, or `0.0` when there are none.
    #[must_use]
    // u64 -> f64 may round once counts exceed 2^53; acceptable for a mean.
    #[allow(clippy::cast_precision_loss)]
    pub fn mean(&self) -> f64 {
        match self.count {
            0 => 0.0,
            n => self.sum / n as f64,
        }
    }

    /// Snapshot of every bucket, in bound order.
    #[must_use]
    pub fn buckets(&self) -> Vec<TimingHistogramBucket> {
        let mut out = Vec::with_capacity(self.bucket_bounds.len());
        for (&le, &count) in self.bucket_bounds.iter().zip(self.bucket_counts.iter()) {
            out.push(TimingHistogramBucket { le, count });
        }
        out
    }

    /// Clears all observations.
    ///
    /// The bucket bounds are kept.
    pub fn reset(&mut self) {
        self.sum = 0.0;
        self.count = 0;
        self.bucket_counts.fill(0);
    }
}
impl Default for TimingHistogram {
    /// Same as [`TimingHistogram::http_latency`].
    fn default() -> Self {
        Self::http_latency()
    }
}
// Unit tests for ServerTimingEntry/ServerTimingBuilder, TimingMetrics,
// TimingMetricsConfig, TimingMetricsMiddleware and TimingHistogram.
#[cfg(test)]
mod timing_metrics_tests {
use super::*;
use crate::request::Method;
use crate::response::StatusCode;
// Builds a RequestContext from a testing Cx. The `1` is passed straight to
// RequestContext::new (presumably a request id — confirm).
fn test_context() -> RequestContext {
RequestContext::new(asupersync::Cx::for_testing(), 1)
}
// Minimal `GET /test` request shared by the middleware tests.
fn test_request() -> Request {
Request::new(Method::Get, "/test")
}
// Drives the middleware's `before` hook to completion on the current thread.
fn run_middleware_before(mw: &impl Middleware, req: &mut Request) -> ControlFlow {
let ctx = test_context();
futures_executor::block_on(mw.before(&ctx, req))
}
// Drives the middleware's `after` hook to completion on the current thread.
fn run_middleware_after(mw: &impl Middleware, req: &Request, resp: Response) -> Response {
let ctx = test_context();
futures_executor::block_on(mw.after(&ctx, req, resp))
}
#[test]
fn server_timing_entry_basic() {
let entry = ServerTimingEntry::new("db", 42.5);
assert_eq!(entry.to_header_value(), "db;dur=42.500");
}
#[test]
fn server_timing_entry_with_description() {
let entry = ServerTimingEntry::new("db", 42.5).with_description("Database query");
assert_eq!(
entry.to_header_value(),
"db;dur=42.500;desc=\"Database query\""
);
}
#[test]
fn server_timing_builder_single_entry() {
let timing = ServerTimingBuilder::new().add("total", 150.0).build();
assert_eq!(timing, "total;dur=150.000");
}
#[test]
fn server_timing_builder_multiple_entries() {
let timing = ServerTimingBuilder::new()
.add("total", 150.0)
.add_with_desc("db", 42.0, "Database")
.add("cache", 5.0)
.build();
assert!(timing.contains("total;dur=150.000"));
assert!(timing.contains("db;dur=42.000;desc=\"Database\""));
assert!(timing.contains("cache;dur=5.000"));
// Entries are joined with ", ".
assert!(timing.contains(", ")); }
#[test]
fn server_timing_builder_empty() {
let builder = ServerTimingBuilder::new();
assert!(builder.is_empty());
assert_eq!(builder.len(), 0);
assert_eq!(builder.build(), "");
}
#[test]
fn timing_metrics_basic() {
let metrics = TimingMetrics::new();
// The sleep puts a lower bound on the elapsed time.
std::thread::sleep(std::time::Duration::from_millis(5));
let total = metrics.total_ms();
assert!(total >= 5.0, "Total should be at least 5ms");
assert!(metrics.ttfb_ms().is_none(), "TTFB should not be set");
}
#[test]
fn timing_metrics_custom_metrics() {
let mut metrics = TimingMetrics::new();
metrics.add_metric("db", 42.5);
metrics.add_metric_with_desc("cache", 5.0, "Cache lookup");
let timing = metrics.to_server_timing();
// "total" plus the two custom metrics; there is no "ttfb" because
// mark_first_byte was never called.
assert_eq!(timing.len(), 3);
let header = timing.build();
assert!(header.contains("total"));
assert!(header.contains("db;dur=42.500"));
assert!(header.contains("cache;dur=5.000;desc=\"Cache lookup\""));
}
#[test]
fn timing_metrics_ttfb() {
let mut metrics = TimingMetrics::new();
std::thread::sleep(std::time::Duration::from_millis(5));
metrics.mark_first_byte();
let ttfb = metrics.ttfb_ms().unwrap();
assert!(ttfb >= 5.0, "TTFB should be at least 5ms");
}
#[test]
fn timing_metrics_config_default() {
let config = TimingMetricsConfig::default();
assert!(config.add_server_timing_header);
assert!(config.add_response_time_header);
assert!(config.include_custom_metrics);
assert!(config.include_ttfb);
}
#[test]
fn timing_metrics_config_production() {
let config = TimingMetricsConfig::production();
assert!(!config.add_server_timing_header);
assert!(config.add_response_time_header);
assert!(!config.include_custom_metrics);
}
#[test]
fn timing_middleware_adds_metrics_to_request() {
let mw = TimingMetricsMiddleware::new();
let mut req = test_request();
let result = run_middleware_before(&mw, &mut req);
assert!(result.is_continue());
let metrics = req.get_extension::<TimingMetrics>();
assert!(metrics.is_some(), "TimingMetrics should be in extensions");
}
#[test]
fn timing_middleware_adds_response_time_header() {
let mw = TimingMetricsMiddleware::new();
let mut req = test_request();
run_middleware_before(&mw, &mut req);
let resp = Response::with_status(StatusCode::OK);
let result = run_middleware_after(&mw, &req, resp);
let has_timing = result
.headers()
.iter()
.any(|(name, _)| name == "X-Response-Time");
assert!(has_timing, "Should have X-Response-Time header");
}
#[test]
fn timing_middleware_adds_server_timing_header() {
let mw = TimingMetricsMiddleware::new();
let mut req = test_request();
run_middleware_before(&mw, &mut req);
let resp = Response::with_status(StatusCode::OK);
let result = run_middleware_after(&mw, &req, resp);
let server_timing = result
.headers()
.iter()
.find(|(name, _)| name == "Server-Timing")
.map(|(_, v)| String::from_utf8_lossy(v).to_string());
assert!(server_timing.is_some(), "Should have Server-Timing header");
let header = server_timing.unwrap();
assert!(header.contains("total"), "Should have total timing");
}
// Production config: response-time header only, no Server-Timing.
#[test]
fn timing_middleware_production_mode() {
let mw = TimingMetricsMiddleware::production();
let mut req = test_request();
run_middleware_before(&mw, &mut req);
let resp = Response::with_status(StatusCode::OK);
let result = run_middleware_after(&mw, &req, resp);
let has_response_time = result
.headers()
.iter()
.any(|(name, _)| name == "X-Response-Time");
assert!(has_response_time);
let has_server_timing = result
.headers()
.iter()
.any(|(name, _)| name == "Server-Timing");
assert!(!has_server_timing);
}
// Exact float comparisons are intended: sums of these small integers are exact in f64.
#[test]
#[allow(clippy::float_cmp)]
fn timing_histogram_basic() {
let mut histogram = TimingHistogram::http_latency();
assert_eq!(histogram.count(), 0);
assert_eq!(histogram.sum(), 0.0);
histogram.observe(42.0);
histogram.observe(150.0);
histogram.observe(5.0);
assert_eq!(histogram.count(), 3);
assert_eq!(histogram.sum(), 197.0);
assert!((histogram.mean() - 65.666).abs() < 0.01);
}
#[test]
fn timing_histogram_buckets() {
let mut histogram = TimingHistogram::with_buckets(vec![10.0, 50.0, 100.0]);
// Buckets are cumulative: each one counts every observation <= its bound.
// 150.0 exceeds all bounds, so it only shows up in count()/sum().
histogram.observe(5.0); histogram.observe(25.0); histogram.observe(75.0); histogram.observe(150.0);
let buckets = histogram.buckets();
assert_eq!(buckets.len(), 3);
// <=10: {5}; <=50: {5, 25}; <=100: {5, 25, 75}.
assert_eq!(buckets[0].count, 1); assert_eq!(buckets[1].count, 2); assert_eq!(buckets[2].count, 3); }
// Exact float comparisons are intended: reset sets sum to exactly 0.0.
#[test]
#[allow(clippy::float_cmp)]
fn timing_histogram_reset() {
let mut histogram = TimingHistogram::http_latency();
histogram.observe(100.0);
histogram.observe(200.0);
assert_eq!(histogram.count(), 2);
histogram.reset();
assert_eq!(histogram.count(), 0);
assert_eq!(histogram.sum(), 0.0);
}
}
// Unit tests for the response interceptors, ResponseInterceptorStack and
// ResponseInterceptorContext.
#[cfg(test)]
mod response_interceptor_tests {
use super::*;
use crate::request::Method;
use crate::response::StatusCode;
// Builds a RequestContext from a testing Cx. The `1` is passed straight to
// RequestContext::new (presumably a request id — confirm).
fn test_context() -> RequestContext {
RequestContext::new(asupersync::Cx::for_testing(), 1)
}
// Minimal `GET /test` request shared by the interceptor tests.
fn test_request() -> Request {
Request::new(Method::Get, "/test")
}
// Runs one interceptor synchronously.
// The context's start time is "now", so the elapsed time it sees is ~0.
fn run_interceptor<I: ResponseInterceptor>(
interceptor: &I,
req: &Request,
resp: Response,
) -> Response {
let ctx = test_context();
let start_time = Instant::now();
let interceptor_ctx = ResponseInterceptorContext::new(req, &ctx, start_time);
futures_executor::block_on(interceptor.intercept(&interceptor_ctx, resp))
}
#[test]
fn timing_interceptor_adds_header() {
let interceptor = TimingInterceptor::new();
let req = test_request();
let resp = Response::with_status(StatusCode::OK);
let result = run_interceptor(&interceptor, &req, resp);
let has_timing = result
.headers()
.iter()
.any(|(name, _)| name == "X-Response-Time");
assert!(has_timing, "Should have X-Response-Time header");
}
#[test]
fn timing_interceptor_with_server_timing() {
let interceptor = TimingInterceptor::new().with_server_timing("app");
let req = test_request();
let resp = Response::with_status(StatusCode::OK);
let result = run_interceptor(&interceptor, &req, resp);
let has_server_timing = result
.headers()
.iter()
.any(|(name, _)| name == "Server-Timing");
assert!(has_server_timing, "Should have Server-Timing header");
}
#[test]
fn timing_interceptor_custom_header_name() {
let interceptor = TimingInterceptor::new().header_name("X-Custom-Time");
let req = test_request();
let resp = Response::with_status(StatusCode::OK);
let result = run_interceptor(&interceptor, &req, resp);
let has_custom = result
.headers()
.iter()
.any(|(name, _)| name == "X-Custom-Time");
assert!(has_custom, "Should have X-Custom-Time header");
}
// The default DebugInfoInterceptor emits path, method and handler-time headers.
#[test]
fn debug_info_interceptor_adds_headers() {
let interceptor = DebugInfoInterceptor::new();
let req = test_request();
let resp = Response::with_status(StatusCode::OK);
let result = run_interceptor(&interceptor, &req, resp);
let has_path = result
.headers()
.iter()
.any(|(name, _)| name == "X-Debug-Path");
let has_method = result
.headers()
.iter()
.any(|(name, _)| name == "X-Debug-Method");
let has_timing = result
.headers()
.iter()
.any(|(name, _)| name == "X-Debug-Handler-Time");
assert!(has_path, "Should have X-Debug-Path header");
assert!(has_method, "Should have X-Debug-Method header");
assert!(has_timing, "Should have X-Debug-Handler-Time header");
}
#[test]
fn debug_info_interceptor_custom_prefix() {
let interceptor = DebugInfoInterceptor::new().header_prefix("X-Trace-");
let req = test_request();
let resp = Response::with_status(StatusCode::OK);
let result = run_interceptor(&interceptor, &req, resp);
let has_trace_path = result
.headers()
.iter()
.any(|(name, _)| name == "X-Trace-Path");
assert!(has_trace_path, "Should have X-Trace-Path header");
}
// Only the path header is enabled; the method header must be absent.
#[test]
fn debug_info_interceptor_selective_options() {
let interceptor = DebugInfoInterceptor::new()
.include_path(true)
.include_method(false)
.include_timing(false)
.include_request_id(false);
let req = test_request();
let resp = Response::with_status(StatusCode::OK);
let result = run_interceptor(&interceptor, &req, resp);
let has_path = result
.headers()
.iter()
.any(|(name, _)| name == "X-Debug-Path");
let has_method = result
.headers()
.iter()
.any(|(name, _)| name == "X-Debug-Method");
assert!(has_path, "Should have X-Debug-Path header");
assert!(!has_method, "Should NOT have X-Debug-Method header");
}
#[test]
fn header_transform_adds_headers() {
let interceptor = HeaderTransformInterceptor::new()
.add("X-Powered-By", b"fastapi_rust".to_vec())
.add("X-Version", b"1.0".to_vec());
let req = test_request();
let resp = Response::with_status(StatusCode::OK);
let result = run_interceptor(&interceptor, &req, resp);
let has_powered_by = result
.headers()
.iter()
.any(|(name, _)| name == "X-Powered-By");
let has_version = result.headers().iter().any(|(name, _)| name == "X-Version");
assert!(has_powered_by, "Should have X-Powered-By header");
assert!(has_version, "Should have X-Version header");
}
// The transform closure wraps the body in brackets: "hello" -> "[hello]".
#[test]
fn response_body_transform_modifies_body() {
let transformer = ResponseBodyTransform::new(|body| {
let mut result = b"[".to_vec();
result.extend_from_slice(&body);
result.extend_from_slice(b"]");
result
});
let req = test_request();
let resp = Response::with_status(StatusCode::OK)
.body(crate::response::ResponseBody::Bytes(b"hello".to_vec()));
let result = run_interceptor(&transformer, &req, resp);
match result.body_ref() {
crate::response::ResponseBody::Bytes(b) => {
assert_eq!(b, b"[hello]");
}
_ => panic!("Expected bytes body"),
}
}
// With a content-type filter, only matching responses (text/plain here) are transformed.
#[test]
fn response_body_transform_with_content_type_filter() {
let transformer =
ResponseBodyTransform::new(|_| b"transformed".to_vec()).for_content_type("text/plain");
let req = test_request();
let json_resp = Response::with_status(StatusCode::OK)
.header("content-type", b"application/json".to_vec())
.body(crate::response::ResponseBody::Bytes(b"original".to_vec()));
let result = run_interceptor(&transformer, &req, json_resp);
match result.body_ref() {
crate::response::ResponseBody::Bytes(b) => {
assert_eq!(b, b"original", "JSON should not be transformed");
}
_ => panic!("Expected bytes body"),
}
let text_resp = Response::with_status(StatusCode::OK)
.header("content-type", b"text/plain".to_vec())
.body(crate::response::ResponseBody::Bytes(b"original".to_vec()));
let result = run_interceptor(&transformer, &req, text_resp);
match result.body_ref() {
crate::response::ResponseBody::Bytes(b) => {
assert_eq!(b, b"transformed", "Text should be transformed");
}
_ => panic!("Expected bytes body"),
}
}
// 500 bodies are replaced; the 200 body is left untouched.
#[test]
fn error_response_transformer_hides_details() {
let transformer = ErrorResponseTransformer::new()
.hide_details_for_status(StatusCode::INTERNAL_SERVER_ERROR)
.with_replacement_body(b"An error occurred");
let req = test_request();
let error_resp = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR).body(
crate::response::ResponseBody::Bytes(b"Sensitive error details".to_vec()),
);
let result = run_interceptor(&transformer, &req, error_resp);
match result.body_ref() {
crate::response::ResponseBody::Bytes(b) => {
assert_eq!(b, b"An error occurred");
}
_ => panic!("Expected bytes body"),
}
let ok_resp = Response::with_status(StatusCode::OK)
.body(crate::response::ResponseBody::Bytes(b"Success".to_vec()));
let result = run_interceptor(&transformer, &req, ok_resp);
match result.body_ref() {
crate::response::ResponseBody::Bytes(b) => {
assert_eq!(b, b"Success");
}
_ => panic!("Expected bytes body"),
}
}
// Both interceptors in the stack contribute their headers to the final response.
#[test]
fn response_interceptor_stack_chains_interceptors() {
let mut stack = ResponseInterceptorStack::new();
stack.push(TimingInterceptor::new());
stack.push(HeaderTransformInterceptor::new().add("X-Extra", b"value".to_vec()));
let req = test_request();
let resp = Response::with_status(StatusCode::OK);
let ctx = test_context();
let start_time = Instant::now();
let interceptor_ctx = ResponseInterceptorContext::new(&req, &ctx, start_time);
let result = futures_executor::block_on(stack.process(&interceptor_ctx, resp));
let has_timing = result
.headers()
.iter()
.any(|(name, _)| name == "X-Response-Time");
let has_extra = result.headers().iter().any(|(name, _)| name == "X-Extra");
assert!(
has_timing,
"Should have timing header from first interceptor"
);
assert!(
has_extra,
"Should have extra header from second interceptor"
);
}
#[test]
fn response_interceptor_stack_empty_is_noop() {
let stack = ResponseInterceptorStack::new();
assert!(stack.is_empty());
assert_eq!(stack.len(), 0);
let req = test_request();
let resp = Response::with_status(StatusCode::OK)
.body(crate::response::ResponseBody::Bytes(b"unchanged".to_vec()));
let ctx = test_context();
let start_time = Instant::now();
let interceptor_ctx = ResponseInterceptorContext::new(&req, &ctx, start_time);
let result = futures_executor::block_on(stack.process(&interceptor_ctx, resp));
match result.body_ref() {
crate::response::ResponseBody::Bytes(b) => {
assert_eq!(b, b"unchanged");
}
_ => panic!("Expected bytes body"),
}
}
// The start time is taken before a 5 ms sleep, so the elapsed time must be >= 5 ms.
#[test]
fn interceptor_context_provides_timing() {
let ctx = test_context();
let req = test_request();
let start_time = Instant::now();
std::thread::sleep(std::time::Duration::from_millis(5));
let interceptor_ctx = ResponseInterceptorContext::new(&req, &ctx, start_time);
assert!(
interceptor_ctx.elapsed_ms() >= 5,
"Elapsed time should be at least 5ms"
);
assert!(interceptor_ctx.elapsed().as_millis() >= 5);
}
// The predicate admits only status 200, so the 404 response is left untouched.
#[test]
fn conditional_interceptor_applies_conditionally() {
let inner = HeaderTransformInterceptor::new().add("X-Success", b"true".to_vec());
let conditional =
ConditionalInterceptor::new(inner, |_ctx, resp| resp.status().as_u16() == 200);
let req = test_request();
let ok_resp = Response::with_status(StatusCode::OK);
let result = run_interceptor(&conditional, &req, ok_resp);
let has_success = result.headers().iter().any(|(name, _)| name == "X-Success");
assert!(has_success, "200 response should get X-Success header");
let not_found = Response::with_status(StatusCode::NOT_FOUND);
let result = run_interceptor(&conditional, &req, not_found);
let has_success = result.headers().iter().any(|(name, _)| name == "X-Success");
assert!(!has_success, "404 response should NOT get X-Success header");
}
}
// Unit tests for CacheDirective, CacheControlBuilder, CachePreset,
// CacheControlMiddleware and the path/date helpers they rely on.
#[cfg(test)]
mod cache_control_tests {
use super::*;
use crate::request::Method;
use crate::response::StatusCode;
// Builds a RequestContext from a testing Cx. The `1` is passed straight to
// RequestContext::new (presumably a request id — confirm).
fn test_context() -> RequestContext {
RequestContext::new(asupersync::Cx::for_testing(), 1)
}
// Drives the middleware's `after` hook to completion on the current thread.
fn run_after(mw: &CacheControlMiddleware, req: &Request, resp: Response) -> Response {
let ctx = test_context();
let fut = mw.after(&ctx, req, resp);
futures_executor::block_on(fut)
}
#[test]
fn cache_directive_as_str_works() {
assert_eq!(CacheDirective::Public.as_str(), "public");
assert_eq!(CacheDirective::Private.as_str(), "private");
assert_eq!(CacheDirective::NoStore.as_str(), "no-store");
assert_eq!(CacheDirective::NoCache.as_str(), "no-cache");
assert_eq!(CacheDirective::MustRevalidate.as_str(), "must-revalidate");
assert_eq!(CacheDirective::Immutable.as_str(), "immutable");
}
#[test]
fn cache_control_builder_basic() {
let cc = CacheControlBuilder::new()
.public()
.max_age_secs(3600)
.build();
assert!(cc.contains("public"));
assert!(cc.contains("max-age=3600"));
}
#[test]
fn cache_control_builder_complex() {
let cc = CacheControlBuilder::new()
.public()
.max_age_secs(60)
.s_maxage_secs(3600)
.stale_while_revalidate_secs(86400)
.build();
assert!(cc.contains("public"));
assert!(cc.contains("max-age=60"));
assert!(cc.contains("s-maxage=3600"));
assert!(cc.contains("stale-while-revalidate=86400"));
}
#[test]
fn cache_control_builder_no_cache() {
let cc = CacheControlBuilder::new()
.no_store()
.no_cache()
.must_revalidate()
.build();
assert!(cc.contains("no-store"));
assert!(cc.contains("no-cache"));
assert!(cc.contains("must-revalidate"));
}
#[test]
fn cache_preset_no_cache() {
let value = CachePreset::NoCache.to_header_value();
assert!(value.contains("no-store"));
assert!(value.contains("no-cache"));
assert!(value.contains("must-revalidate"));
}
#[test]
fn cache_preset_immutable() {
let value = CachePreset::Immutable.to_header_value();
assert!(value.contains("public"));
assert!(value.contains("max-age=31536000"));
assert!(value.contains("immutable"));
}
#[test]
fn cache_preset_static_assets() {
let value = CachePreset::StaticAssets.to_header_value();
assert!(value.contains("public"));
assert!(value.contains("max-age=86400"));
}
#[test]
fn middleware_adds_cache_control_header() {
let mw = CacheControlMiddleware::with_preset(CachePreset::PublicOneHour);
let req = Request::new(Method::Get, "/api/test");
let resp = Response::with_status(StatusCode::OK);
let result = run_after(&mw, &req, resp);
let headers = result.headers();
let cc_header = headers
.iter()
.find(|(name, _)| name.eq_ignore_ascii_case("cache-control"));
assert!(
cc_header.is_some(),
"Cache-Control header should be present"
);
let (_, value) = cc_header.unwrap();
let value_str = String::from_utf8_lossy(value);
assert!(value_str.contains("public"));
assert!(value_str.contains("max-age=3600"));
}
// POST responses must not receive a Cache-Control header from the middleware.
#[test]
fn middleware_skips_post_requests() {
let mw = CacheControlMiddleware::with_preset(CachePreset::PublicOneHour);
let req = Request::new(Method::Post, "/api/test");
let resp = Response::with_status(StatusCode::OK);
let result = run_after(&mw, &req, resp);
let headers = result.headers();
let cc_header = headers
.iter()
.find(|(name, _)| name.eq_ignore_ascii_case("cache-control"));
assert!(
cc_header.is_none(),
"Cache-Control should not be added for POST"
);
}
// A 500 response must not receive a Cache-Control header from the middleware.
#[test]
fn middleware_skips_error_responses() {
let mw = CacheControlMiddleware::with_preset(CachePreset::PublicOneHour);
let req = Request::new(Method::Get, "/api/test");
let resp = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR);
let result = run_after(&mw, &req, resp);
let headers = result.headers();
let cc_header = headers
.iter()
.find(|(name, _)| name.eq_ignore_ascii_case("cache-control"));
assert!(
cc_header.is_none(),
"Cache-Control should not be added for error responses"
);
}
#[test]
fn middleware_with_vary_header() {
let mw = CacheControlMiddleware::with_config(
CacheControlConfig::from_preset(CachePreset::PublicOneHour)
.vary("Accept-Encoding")
.vary("Accept-Language"),
);
let req = Request::new(Method::Get, "/api/test");
let resp = Response::with_status(StatusCode::OK);
let result = run_after(&mw, &req, resp);
let headers = result.headers();
let vary_header = headers
.iter()
.find(|(name, _)| name.eq_ignore_ascii_case("vary"));
assert!(vary_header.is_some(), "Vary header should be present");
let (_, value) = vary_header.unwrap();
let value_str = String::from_utf8_lossy(value);
assert!(value_str.contains("Accept-Encoding"));
assert!(value_str.contains("Accept-Language"));
}
// With preserve_existing(true), the handler's own Cache-Control value is kept
// and not duplicated.
#[test]
fn middleware_preserves_existing_cache_control() {
let mw = CacheControlMiddleware::with_config(
CacheControlConfig::from_preset(CachePreset::PublicOneHour).preserve_existing(true),
);
let req = Request::new(Method::Get, "/api/test");
let resp =
Response::with_status(StatusCode::OK).header("Cache-Control", b"max-age=60".to_vec());
let result = run_after(&mw, &req, resp);
let headers = result.headers();
let cc_headers: Vec<_> = headers
.iter()
.filter(|(name, _)| name.eq_ignore_ascii_case("cache-control"))
.collect();
assert_eq!(cc_headers.len(), 1);
let (_, value) = cc_headers[0];
let value_str = String::from_utf8_lossy(value);
assert_eq!(value_str, "max-age=60");
}
#[test]
fn path_pattern_matching_exact() {
assert!(path_matches_pattern("/api/users", "/api/users"));
assert!(!path_matches_pattern("/api/users", "/api/items"));
}
#[test]
fn path_pattern_matching_wildcard() {
assert!(path_matches_pattern("/api/users/123", "/api/users/*"));
assert!(path_matches_pattern("/static/css/style.css", "/static/*"));
assert!(path_matches_pattern("/anything", "*"));
}
// Only checks the shape of the output: a weekday prefix and a " GMT" suffix.
#[test]
fn date_formatting_works() {
let now = std::time::SystemTime::now();
let formatted = format_http_date(now);
assert!(formatted.ends_with(" GMT"));
let days = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"];
assert!(days.iter().any(|d| formatted.starts_with(d)));
}
// Gregorian rule:
// - century years are leap years only when divisible by 400
//   (1900 is not, 2000 is);
// - other years divisible by 4 are leap years (2024 is, 2023 is not).
#[test]
fn leap_year_detection() {
assert!(!is_leap_year(1900)); assert!(is_leap_year(2000)); assert!(is_leap_year(2024)); assert!(!is_leap_year(2023)); }
}
#[cfg(test)]
mod trace_rejection_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    /// Creates a fresh request context for driving middleware hooks in tests.
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 1)
    }

    /// Runs `mw.before` to completion on the current thread and returns its verdict.
    fn run_before(mw: &TraceRejectionMiddleware, req: &mut Request) -> ControlFlow {
        let ctx = test_context();
        futures_executor::block_on(mw.before(&ctx, req))
    }

    /// Case-insensitive lookup of the first header called `name`.
    fn find_header<'a>(headers: &'a [(String, Vec<u8>)], name: &str) -> Option<&'a [u8]> {
        headers
            .iter()
            .find_map(|(key, value)| key.eq_ignore_ascii_case(name).then_some(value.as_slice()))
    }

    /// Sends a `method` request for `path` through a default middleware and
    /// returns the response it short-circuits with; panics if the request is
    /// let through instead.
    fn rejection_for(method: Method, path: &'static str) -> Response {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(method, path);
        let ControlFlow::Break(response) = run_before(&mw, &mut req) else {
            panic!("TRACE request should have been rejected");
        };
        response
    }

    /// Sends a `method` request for `path` and asserts the middleware lets it
    /// through. `label` names the method in the failure message.
    fn assert_allowed(method: Method, path: &'static str, label: &str) {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(method, path);
        let verdict = run_before(&mw, &mut req);
        assert!(verdict.is_continue(), "{label} request should be allowed");
    }

    #[test]
    fn trace_request_rejected() {
        let response = rejection_for(Method::Trace, "/");
        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn trace_request_with_path() {
        let response = rejection_for(Method::Trace, "/api/users/123");
        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn get_request_allowed() {
        assert_allowed(Method::Get, "/", "GET");
    }

    #[test]
    fn post_request_allowed() {
        assert_allowed(Method::Post, "/api/users", "POST");
    }

    #[test]
    fn put_request_allowed() {
        assert_allowed(Method::Put, "/api/users/1", "PUT");
    }

    #[test]
    fn delete_request_allowed() {
        assert_allowed(Method::Delete, "/api/users/1", "DELETE");
    }

    #[test]
    fn patch_request_allowed() {
        assert_allowed(Method::Patch, "/api/users/1", "PATCH");
    }

    #[test]
    fn options_request_allowed() {
        assert_allowed(Method::Options, "/api/users", "OPTIONS");
    }

    #[test]
    fn head_request_allowed() {
        assert_allowed(Method::Head, "/", "HEAD");
    }

    #[test]
    fn response_includes_allow_header() {
        let response = rejection_for(Method::Trace, "/");
        assert!(
            find_header(response.headers(), "Allow").is_some(),
            "Response should include Allow header"
        );
    }

    #[test]
    fn response_has_json_content_type() {
        let response = rejection_for(Method::Trace, "/");
        assert_eq!(
            find_header(response.headers(), "Content-Type"),
            Some(b"application/json".as_slice())
        );
    }

    #[test]
    fn default_enables_logging() {
        assert!(TraceRejectionMiddleware::new().log_attempts);
    }

    #[test]
    fn log_attempts_can_be_disabled() {
        let quiet = TraceRejectionMiddleware::new().log_attempts(false);
        assert!(!quiet.log_attempts);
    }

    #[test]
    fn middleware_name() {
        assert_eq!(TraceRejectionMiddleware::new().name(), "TraceRejection");
    }

    #[test]
    fn default_impl() {
        let mw: TraceRejectionMiddleware = Default::default();
        assert!(mw.log_attempts);
    }
}
#[cfg(test)]
mod https_redirect_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    /// Creates a fresh request context for driving middleware hooks in tests.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// Runs `mw.before` to completion on the current thread.
    fn run_before(mw: &HttpsRedirectMiddleware, req: &mut Request) -> ControlFlow {
        let ctx = test_context();
        let fut = mw.before(&ctx, req);
        futures_executor::block_on(fut)
    }

    /// Runs `mw.after` to completion on the current thread.
    fn run_after(mw: &HttpsRedirectMiddleware, req: &Request, resp: Response) -> Response {
        let ctx = test_context();
        let fut = mw.after(&ctx, req, resp);
        futures_executor::block_on(fut)
    }

    /// Case-insensitive lookup of the first header called `name`.
    fn find_header<'a>(headers: &'a [(String, Vec<u8>)], name: &str) -> Option<&'a [u8]> {
        headers
            .iter()
            .find(|(n, _)| n.eq_ignore_ascii_case(name))
            .map(|(_, v)| v.as_slice())
    }

    /// Builds a `GET` request for `path`, inserting `headers` in the given order.
    fn get_request(path: &'static str, headers: &[(&'static str, &str)]) -> Request {
        let mut req = Request::new(Method::Get, path);
        for (name, value) in headers {
            req.headers_mut().insert(*name, value.as_bytes().to_vec());
        }
        req
    }

    /// Runs `before` and returns the redirect response; panics if the
    /// middleware lets the request through instead.
    fn expect_redirect(mw: &HttpsRedirectMiddleware, req: &mut Request) -> Response {
        match run_before(mw, req) {
            ControlFlow::Break(response) => response,
            ControlFlow::Continue => panic!("HTTP request should be redirected"),
        }
    }

    /// Runs `before` and asserts the request is passed through unchanged,
    /// failing with `reason` otherwise.
    fn assert_not_redirected(mw: &HttpsRedirectMiddleware, req: &mut Request, reason: &str) {
        assert!(run_before(mw, req).is_continue(), "{reason}");
    }

    /// Runs `after` on a plain `200 OK` and returns the
    /// `Strict-Transport-Security` header as text, if the middleware set one.
    fn hsts_header(mw: &HttpsRedirectMiddleware, req: &Request) -> Option<String> {
        let response = run_after(mw, req, Response::with_status(StatusCode::OK));
        find_header(response.headers(), "Strict-Transport-Security")
            .map(|raw| String::from_utf8_lossy(raw).into_owned())
    }

    #[test]
    fn http_request_redirected() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = get_request("/", &[("Host", "example.com")]);
        let response = expect_redirect(&mw, &mut req);
        assert_eq!(response.status(), StatusCode::MOVED_PERMANENTLY);
        assert_eq!(
            find_header(response.headers(), "Location"),
            Some(b"https://example.com/".as_slice())
        );
    }

    #[test]
    fn http_request_with_path_and_query() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = get_request("/api/users?page=1", &[("Host", "example.com")]);
        let response = expect_redirect(&mw, &mut req);
        // Path and query string must survive the redirect untouched.
        assert_eq!(
            find_header(response.headers(), "Location"),
            Some(b"https://example.com/api/users?page=1".as_slice())
        );
    }

    #[test]
    fn https_request_not_redirected() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = get_request(
            "/",
            &[("Host", "example.com"), ("X-Forwarded-Proto", "https")],
        );
        assert_not_redirected(&mw, &mut req, "HTTPS request should not be redirected");
    }

    #[test]
    fn x_forwarded_ssl_recognized() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = get_request("/", &[("Host", "example.com"), ("X-Forwarded-Ssl", "on")]);
        assert_not_redirected(
            &mw,
            &mut req,
            "Request with X-Forwarded-Ssl=on should not redirect",
        );
    }

    #[test]
    fn excluded_path_not_redirected() {
        let mw = HttpsRedirectMiddleware::new().exclude_path("/health");
        let mut req = get_request("/health", &[("Host", "example.com")]);
        assert_not_redirected(&mw, &mut req, "Excluded path should not be redirected");
    }

    #[test]
    fn excluded_path_prefix_matches() {
        let mw = HttpsRedirectMiddleware::new().exclude_path("/health");
        let mut req = get_request("/health/live", &[("Host", "example.com")]);
        assert_not_redirected(
            &mw,
            &mut req,
            "Path with excluded prefix should not be redirected",
        );
    }

    #[test]
    fn temporary_redirect_option() {
        let mw = HttpsRedirectMiddleware::new().permanent_redirect(false);
        let mut req = get_request("/", &[("Host", "example.com")]);
        let response = expect_redirect(&mw, &mut req);
        assert_eq!(response.status(), StatusCode::TEMPORARY_REDIRECT);
    }

    #[test]
    fn redirect_disabled() {
        let mw = HttpsRedirectMiddleware::new().redirect_enabled(false);
        let mut req = get_request("/", &[("Host", "example.com")]);
        assert_not_redirected(&mw, &mut req, "Redirects are disabled, should continue");
    }

    #[test]
    fn hsts_header_on_https_response() {
        let mw = HttpsRedirectMiddleware::new();
        let req = get_request("/", &[("X-Forwarded-Proto", "https")]);
        let hsts = hsts_header(&mw, &req)
            .expect("HSTS header should be present on HTTPS response");
        assert!(hsts.contains("max-age=31536000"));
    }

    #[test]
    fn hsts_header_not_on_http_response() {
        let mw = HttpsRedirectMiddleware::new().redirect_enabled(false);
        let req = get_request("/", &[]);
        assert!(
            hsts_header(&mw, &req).is_none(),
            "HSTS header should not be on HTTP response"
        );
    }

    #[test]
    fn hsts_with_include_subdomains() {
        let mw = HttpsRedirectMiddleware::new().include_subdomains(true);
        let req = get_request("/", &[("X-Forwarded-Proto", "https")]);
        let hsts = hsts_header(&mw, &req)
            .expect("HSTS header should be present on HTTPS response");
        assert!(hsts.contains("includeSubDomains"));
    }

    #[test]
    fn hsts_with_preload() {
        let mw = HttpsRedirectMiddleware::new().preload(true);
        let req = get_request("/", &[("X-Forwarded-Proto", "https")]);
        let hsts = hsts_header(&mw, &req)
            .expect("HSTS header should be present on HTTPS response");
        assert!(hsts.contains("preload"));
    }

    #[test]
    fn hsts_disabled_with_zero_max_age() {
        let mw = HttpsRedirectMiddleware::new().hsts_max_age_secs(0);
        let req = get_request("/", &[("X-Forwarded-Proto", "https")]);
        assert!(
            hsts_header(&mw, &req).is_none(),
            "HSTS should be disabled with max-age=0"
        );
    }

    #[test]
    fn custom_https_port() {
        let mw = HttpsRedirectMiddleware::new().https_port(8443);
        let mut req = get_request("/", &[("Host", "example.com")]);
        let response = expect_redirect(&mw, &mut req);
        assert_eq!(
            find_header(response.headers(), "Location"),
            Some(b"https://example.com:8443/".as_slice())
        );
    }

    #[test]
    fn host_with_port_stripped() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = get_request("/", &[("Host", "example.com:8080")]);
        let response = expect_redirect(&mw, &mut req);
        // The plain-HTTP port from the Host header must not leak into the HTTPS URL.
        assert_eq!(
            find_header(response.headers(), "Location"),
            Some(b"https://example.com/".as_slice())
        );
    }

    #[test]
    fn middleware_name() {
        let mw = HttpsRedirectMiddleware::new();
        assert_eq!(mw.name(), "HttpsRedirect");
    }

    #[test]
    fn default_impl() {
        let mw = HttpsRedirectMiddleware::default();
        assert!(mw.config.redirect_enabled);
        assert!(mw.config.permanent_redirect);
        assert_eq!(mw.config.hsts_max_age_secs, 31_536_000);
    }

    #[test]
    fn config_builder() {
        let mw = HttpsRedirectMiddleware::new()
            .redirect_enabled(false)
            .permanent_redirect(false)
            .hsts_max_age_secs(86400)
            .include_subdomains(true)
            .preload(true)
            .https_port(8443);
        assert!(!mw.config.redirect_enabled);
        assert!(!mw.config.permanent_redirect);
        assert_eq!(mw.config.hsts_max_age_secs, 86400);
        assert!(mw.config.hsts_include_subdomains);
        assert!(mw.config.hsts_preload);
        assert_eq!(mw.config.https_port, 8443);
    }

    #[test]
    fn exclude_paths_method() {
        let mw = HttpsRedirectMiddleware::new()
            .exclude_paths(vec!["/health".to_string(), "/ready".to_string()]);
        assert_eq!(mw.config.exclude_paths.len(), 2);
        assert!(mw.config.exclude_paths.contains(&"/health".to_string()));
        assert!(mw.config.exclude_paths.contains(&"/ready".to_string()));
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::response::{ResponseBody, StatusCode};
/// Test fixture whose `after` hook adds the fixed header `name: value` to the
/// outgoing response via `Response::header`.
// NOTE(review): `dead_code` is allowed, presumably because this fixture is not
// referenced by every test — confirm it is still needed.
#[allow(dead_code)]
struct AddHeaderMiddleware {
    /// Header name to add.
    name: &'static str,
    /// Raw header value to add.
    value: &'static [u8],
}
impl Middleware for AddHeaderMiddleware {
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        let (header_name, header_value) = (self.name, self.value);
        Box::pin(async move { response.header(header_name, header_value.to_vec()) })
    }
}
/// Test fixture whose `before` hook always short-circuits the chain with a
/// `403 Forbidden` response carrying the body `blocked`.
// NOTE(review): `dead_code` is allowed, presumably because this fixture is not
// referenced by every test — confirm it is still needed.
#[allow(dead_code)]
struct BlockingMiddleware;
impl Middleware for BlockingMiddleware {
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        Box::pin(async {
            let forbidden = Response::with_status(StatusCode::FORBIDDEN)
                .body(ResponseBody::Bytes(b"blocked".to_vec()));
            ControlFlow::Break(forbidden)
        })
    }
}
/// Test fixture that counts how many times its `before` and `after` hooks are
/// invoked.
// NOTE(review): `dead_code` is allowed, presumably because this fixture is not
// referenced by every test — confirm it is still needed.
#[allow(dead_code)]
struct TrackingMiddleware {
    /// Number of `before` invocations so far.
    before_count: std::sync::atomic::AtomicUsize,
    /// Number of `after` invocations so far.
    after_count: std::sync::atomic::AtomicUsize,
}
#[allow(dead_code)]
impl TrackingMiddleware {
    /// Creates a tracker with both counters at zero.
    fn new() -> Self {
        Self {
            before_count: Default::default(),
            after_count: Default::default(),
        }
    }
    /// Current number of `before` invocations.
    fn before_count(&self) -> usize {
        use std::sync::atomic::Ordering;
        self.before_count.load(Ordering::SeqCst)
    }
    /// Current number of `after` invocations.
    fn after_count(&self) -> usize {
        use std::sync::atomic::Ordering;
        self.after_count.load(Ordering::SeqCst)
    }
}
impl Middleware for TrackingMiddleware {
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        use std::sync::atomic::Ordering;
        // Counted when the hook is called, not when the returned future is polled.
        self.before_count.fetch_add(1, Ordering::SeqCst);
        Box::pin(async { ControlFlow::Continue })
    }
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        use std::sync::atomic::Ordering;
        // Counted when the hook is called, not when the returned future is polled.
        self.after_count.fetch_add(1, Ordering::SeqCst);
        Box::pin(async move { response })
    }
}
#[test]
fn control_flow_variants() {
    // For each variant, `is_continue` and `is_break` must agree and be mutually exclusive.
    let flows = [
        (ControlFlow::Continue, true),
        (ControlFlow::Break(Response::ok()), false),
    ];
    for (flow, continues) in flows {
        assert_eq!(flow.is_continue(), continues);
        assert_eq!(flow.is_break(), !continues);
    }
}
#[test]
fn middleware_stack_empty() {
    // A freshly constructed stack holds no middleware.
    let empty = MiddlewareStack::new();
    assert_eq!(empty.len(), 0);
    assert!(empty.is_empty());
}
#[test]
fn middleware_stack_push() {
    let mut stack = MiddlewareStack::new();
    // Each push adds exactly one entry.
    for _ in 0..2 {
        stack.push(NoopMiddleware);
    }
    assert!(!stack.is_empty());
    assert_eq!(stack.len(), 2);
}
#[test]
fn noop_middleware_name() {
    // Reports a short name rather than the trait's default `type_name` path.
    assert_eq!(NoopMiddleware.name(), "Noop");
}
#[test]
fn logging_redacts_sensitive_headers() {
    let redacted = super::default_redacted_headers();
    let mut headers = crate::request::Headers::new();
    headers.insert("Authorization", b"secret".to_vec());
    headers.insert("X-Request-Id", b"abc123".to_vec());
    let rendered = super::format_headers(headers.iter(), &redacted);
    // Sensitive values are masked; other headers are printed verbatim
    // (names appear lower-cased in the output).
    for expected in ["authorization=<redacted>", "x-request-id=abc123"] {
        assert!(rendered.contains(expected));
    }
}
#[test]
fn logging_body_truncation() {
    let body = b"abcdef";
    // Longer than the limit: cut to the limit and suffixed with "...".
    assert_eq!(super::format_bytes(body, 4), "abcd...");
    // Within the limit: rendered unchanged.
    assert_eq!(super::format_bytes(body, 10), "abcdef");
}
/// Creates a request context backed by a testing `Cx`.
// NOTE(review): the second argument `1` is presumably the request id — confirm.
fn test_context() -> RequestContext {
    RequestContext::new(asupersync::Cx::for_testing(), 1)
}
/// Returns the value of the first header on `response` whose name matches
/// `name` case-insensitively, as an owned string.
///
/// Returns `None` when no such header exists, or when that first match is not
/// valid UTF-8. Later duplicates are not consulted.
fn header_value(response: &Response, name: &str) -> Option<String> {
    let (_, raw) = response
        .headers()
        .iter()
        .find(|(key, _)| key.eq_ignore_ascii_case(name))?;
    std::str::from_utf8(raw).ok().map(str::to_owned)
}
#[test]
fn cors_exact_origin_allows() {
    use crate::request::Method;

    let origin = "https://example.com";
    let cors = Cors::new().allow_origin(origin);
    let ctx = test_context();
    let mut req = Request::new(Method::Get, "/");
    req.headers_mut().insert("origin", origin.as_bytes().to_vec());

    let flow = futures_executor::block_on(cors.before(&ctx, &mut req));
    assert!(flow.is_continue());

    let handler_response = Response::ok().body(ResponseBody::Bytes(b"ok".to_vec()));
    let response = futures_executor::block_on(cors.after(&ctx, &req, handler_response));
    // An exactly matching origin is echoed back, with `Vary: Origin` so caches
    // key on the request origin.
    assert_eq!(
        header_value(&response, "access-control-allow-origin").as_deref(),
        Some("https://example.com")
    );
    assert_eq!(header_value(&response, "vary").as_deref(), Some("Origin"));
}
#[test]
fn cors_wildcard_origin_allows() {
    use crate::request::Method;

    // The `https://*.example.com` pattern must admit `https://api.example.com`.
    let cors = Cors::new().allow_origin_wildcard("https://*.example.com");
    let ctx = test_context();
    let mut req = Request::new(Method::Get, "/");
    req.headers_mut()
        .insert("origin", "https://api.example.com".as_bytes().to_vec());
    let flow = futures_executor::block_on(cors.before(&ctx, &mut req));
    assert!(flow.is_continue());
}
#[test]
fn cors_regex_origin_allows() {
    use crate::request::Method;

    // The anchored regex must admit any `https://<sub>.example.com` origin.
    let cors = Cors::new().allow_origin_regex(r"^https://.*\.example\.com$");
    let ctx = test_context();
    let mut req = Request::new(Method::Get, "/");
    req.headers_mut()
        .insert("origin", "https://svc.example.com".as_bytes().to_vec());
    let flow = futures_executor::block_on(cors.before(&ctx, &mut req));
    assert!(flow.is_continue());
}
#[test]
fn cors_preflight_handled() {
    use crate::request::Method;

    let cors = Cors::new()
        .allow_any_origin()
        .allow_headers(["x-test", "content-type"])
        .max_age(600);
    let ctx = test_context();

    // An OPTIONS request carrying `Access-Control-Request-*` headers is a preflight.
    let mut preflight = Request::new(Method::Options, "/");
    for (name, value) in [
        ("origin", b"https://example.com".as_slice()),
        ("access-control-request-method", b"POST".as_slice()),
        ("access-control-request-headers", b"x-test, content-type".as_slice()),
    ] {
        preflight.headers_mut().insert(name, value.to_vec());
    }

    let ControlFlow::Break(response) =
        futures_executor::block_on(cors.before(&ctx, &mut preflight))
    else {
        panic!("expected preflight break");
    };
    assert_eq!(response.status().as_u16(), 204);

    let expected_headers = [
        ("access-control-allow-origin", "*"),
        (
            "access-control-allow-methods",
            "GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD",
        ),
        ("access-control-allow-headers", "x-test, content-type"),
        ("access-control-max-age", "600"),
    ];
    for (name, expected) in expected_headers {
        assert_eq!(
            header_value(&response, name).as_deref(),
            Some(expected),
            "unexpected value for {name}"
        );
    }
}
#[test]
fn cors_credentials_echo_origin() {
    use crate::request::Method;

    let cors = Cors::new().allow_any_origin().allow_credentials(true);
    let ctx = test_context();
    let mut req = Request::new(Method::Get, "/");
    req.headers_mut()
        .insert("origin", "https://example.com".as_bytes().to_vec());

    assert!(futures_executor::block_on(cors.before(&ctx, &mut req)).is_continue());
    let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));

    // With credentials allowed, the concrete origin is echoed instead of `*`.
    assert_eq!(
        header_value(&response, "access-control-allow-origin").as_deref(),
        Some("https://example.com")
    );
    assert_eq!(
        header_value(&response, "access-control-allow-credentials").as_deref(),
        Some("true")
    );
}
/// The CORS protocol forbids `Access-Control-Allow-Origin: *` together with
/// credentials, so with `allow_any_origin` + credentials every request origin
/// must be echoed back verbatim.
#[test]
fn cors_spec_compliance_credentials_never_wildcard_origin() {
    let cors = Cors::new().allow_any_origin().allow_credentials(true);
    let ctx = test_context();
    for origin in [
        "https://example.com",
        "https://api.example.com",
        "http://localhost:3000",
    ] {
        let mut req = Request::new(crate::request::Method::Get, "/");
        req.headers_mut()
            .insert("origin", origin.as_bytes().to_vec());
        // Do not discard the `before` verdict: if it short-circuited, the
        // `after` assertions below would be checking a path that never runs.
        let flow = futures_executor::block_on(cors.before(&ctx, &mut req));
        assert!(
            flow.is_continue(),
            "simple request from {origin} should not be short-circuited"
        );
        let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
        let allow_origin = header_value(&response, "access-control-allow-origin");
        assert_eq!(
            allow_origin.as_deref(),
            Some(origin),
            "With credentials enabled, Access-Control-Allow-Origin must echo '{origin}', not '*'"
        );
        assert_ne!(
            allow_origin.as_deref(),
            Some("*"),
            "CORS spec violation: credentials + wildcard origin is forbidden"
        );
    }
}
#[test]
fn cors_spec_compliance_preflight_with_credentials() {
    use crate::request::Method;

    let cors = Cors::new()
        .allow_any_origin()
        .allow_credentials(true)
        .allow_headers(["content-type", "x-custom-header"]);
    let ctx = test_context();

    let mut preflight = Request::new(Method::Options, "/");
    for (name, value) in [
        ("origin", b"https://example.com".as_slice()),
        ("access-control-request-method", b"POST".as_slice()),
        ("access-control-request-headers", b"content-type".as_slice()),
    ] {
        preflight.headers_mut().insert(name, value.to_vec());
    }

    let ControlFlow::Break(response) =
        futures_executor::block_on(cors.before(&ctx, &mut preflight))
    else {
        panic!("expected preflight break");
    };

    // Preflights with credentials must also echo the origin instead of `*`.
    let allow_origin = header_value(&response, "access-control-allow-origin");
    assert_eq!(allow_origin.as_deref(), Some("https://example.com"));
    assert_ne!(
        allow_origin.as_deref(),
        Some("*"),
        "CORS spec violation: preflight with credentials must not use wildcard origin"
    );
    assert_eq!(
        header_value(&response, "access-control-allow-credentials").as_deref(),
        Some("true")
    );
}
/// Without credentials, `allow_any_origin` may (and does) answer with the
/// literal wildcard `*`, and no credentials header is emitted.
#[test]
fn cors_spec_without_credentials_allows_wildcard() {
    let cors = Cors::new().allow_any_origin();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    req.headers_mut()
        .insert("origin", b"https://example.com".to_vec());
    // Check the `before` verdict rather than discarding it; a short-circuit
    // here would make the `after` assertions meaningless.
    let flow = futures_executor::block_on(cors.before(&ctx, &mut req));
    assert!(flow.is_continue());
    let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
    assert_eq!(
        header_value(&response, "access-control-allow-origin"),
        Some("*".to_string())
    );
    assert!(header_value(&response, "access-control-allow-credentials").is_none());
}
#[test]
fn cors_disallowed_preflight_forbidden() {
    use crate::request::Method;

    let cors = Cors::new().allow_origin("https://good.example");
    let ctx = test_context();

    // A preflight from an origin outside the allow-list is answered with 403.
    let mut preflight = Request::new(Method::Options, "/");
    preflight
        .headers_mut()
        .insert("origin", "https://evil.example".as_bytes().to_vec());
    preflight
        .headers_mut()
        .insert("access-control-request-method", "GET".as_bytes().to_vec());

    let verdict = futures_executor::block_on(cors.before(&ctx, &mut preflight));
    let ControlFlow::Break(response) = verdict else {
        panic!("expected forbidden preflight");
    };
    assert_eq!(response.status().as_u16(), 403);
}
#[test]
fn cors_simple_request_disallowed_origin_no_headers() {
    use crate::request::Method;

    let cors = Cors::new().allow_origin("https://good.example");
    let ctx = test_context();
    let mut req = Request::new(Method::Get, "/");
    req.headers_mut()
        .insert("origin", "https://evil.example".as_bytes().to_vec());

    // A simple request from a disallowed origin is not short-circuited; it
    // just receives no `Access-Control-Allow-Origin` header.
    let flow = futures_executor::block_on(cors.before(&ctx, &mut req));
    assert!(flow.is_continue());
    let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
    assert_eq!(header_value(&response, "access-control-allow-origin"), None);
}
#[test]
fn cors_expose_headers_configuration() {
    use crate::request::Method;

    let exposed = ["x-custom-header", "x-another-header"];
    let cors = Cors::new().allow_any_origin().expose_headers(exposed);
    let ctx = test_context();
    let mut req = Request::new(Method::Get, "/");
    req.headers_mut()
        .insert("origin", "https://example.com".as_bytes().to_vec());

    assert!(futures_executor::block_on(cors.before(&ctx, &mut req)).is_continue());
    let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
    // Configured names are joined with ", " in configuration order.
    assert_eq!(
        header_value(&response, "access-control-expose-headers").as_deref(),
        Some("x-custom-header, x-another-header")
    );
}
#[test]
fn cors_any_origin_sets_wildcard() {
    use crate::request::Method;

    let cors = Cors::new().allow_any_origin();
    let ctx = test_context();
    let mut req = Request::new(Method::Get, "/");
    req.headers_mut()
        .insert("origin", "https://any-site.com".as_bytes().to_vec());

    let flow = futures_executor::block_on(cors.before(&ctx, &mut req));
    assert!(flow.is_continue());
    let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
    assert_eq!(
        header_value(&response, "access-control-allow-origin").as_deref(),
        Some("*")
    );
}
#[test]
fn cors_config_allows_method_override() {
    use crate::request::Method;

    // An explicit method list replaces the default set advertised in preflights.
    let cors = Cors::new()
        .allow_any_origin()
        .allow_methods([Method::Get, Method::Post]);
    let ctx = test_context();

    let mut preflight = Request::new(Method::Options, "/");
    preflight
        .headers_mut()
        .insert("origin", "https://example.com".as_bytes().to_vec());
    preflight
        .headers_mut()
        .insert("access-control-request-method", "POST".as_bytes().to_vec());

    let verdict = futures_executor::block_on(cors.before(&ctx, &mut preflight));
    let ControlFlow::Break(response) = verdict else {
        panic!("expected preflight break");
    };
    assert_eq!(
        header_value(&response, "access-control-allow-methods").as_deref(),
        Some("GET, POST")
    );
}
#[test]
fn cors_no_origin_header_skips_cors() {
    use crate::request::Method;

    // No `Origin` header: the middleware must add no CORS response headers.
    let cors = Cors::new().allow_any_origin();
    let ctx = test_context();
    let mut req = Request::new(Method::Get, "/");

    let flow = futures_executor::block_on(cors.before(&ctx, &mut req));
    assert!(flow.is_continue());
    let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
    assert_eq!(header_value(&response, "access-control-allow-origin"), None);
}
#[test]
fn cors_middleware_name() {
    assert_eq!(Cors::new().name(), "Cors");
}
#[test]
fn cors_empty_allowed_headers_does_not_reflect_request_headers() {
    use crate::request::Method;

    // No `allow_headers` configured: the preflight must not blindly echo
    // whatever the client asked for in `Access-Control-Request-Headers`.
    let cors = Cors::new().allow_any_origin();
    let ctx = test_context();

    let mut preflight = Request::new(Method::Options, "/api");
    for (name, value) in [
        ("origin", b"https://example.com".as_slice()),
        ("access-control-request-method", b"GET".as_slice()),
        (
            "access-control-request-headers",
            b"x-evil-custom, authorization".as_slice(),
        ),
    ] {
        preflight.headers_mut().insert(name, value.to_vec());
    }

    let verdict = futures_executor::block_on(cors.before(&ctx, &mut preflight));
    let ControlFlow::Break(response) = verdict else {
        panic!("Preflight should have been handled (Break)");
    };
    assert_eq!(
        header_value(&response, "access-control-allow-headers"),
        None,
        "Empty allowed_headers must not reflect request headers"
    );
}
/// Explicitly configured allowed headers are advertised in the preflight
/// response, even when the request does not list them.
#[test]
fn cors_explicit_allowed_headers_returned_in_preflight() {
    let cors = Cors::new()
        .allow_any_origin()
        .allow_headers(["x-token", "content-type"]);
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Options, "/api");
    req.headers_mut()
        .insert("origin", b"https://example.com".to_vec());
    req.headers_mut()
        .insert("access-control-request-method", b"POST".to_vec());
    let result = futures_executor::block_on(cors.before(&ctx, &mut req));
    let ControlFlow::Break(response) = result else {
        panic!("Preflight should have been handled (Break)");
    };
    // `expect` replaces the `is_some()` + `unwrap()` pair so a missing header
    // fails with a descriptive message.
    let allowed = header_value(&response, "access-control-allow-headers")
        .expect("preflight should advertise access-control-allow-headers");
    assert!(allowed.contains("x-token"));
    assert!(allowed.contains("content-type"));
}
#[test]
fn request_id_generates_unique_ids() {
    let ids = [
        RequestId::generate(),
        RequestId::generate(),
        RequestId::generate(),
    ];
    // Every generated id is non-empty and distinct from every other one.
    for (index, current) in ids.iter().enumerate() {
        assert!(!current.as_str().is_empty());
        for later in &ids[index + 1..] {
            assert_ne!(current, later);
        }
    }
}
#[test]
fn request_id_display() {
    // `Display` renders the raw id with no decoration.
    let rendered = RequestId::new("test-request-123").to_string();
    assert_eq!(rendered, "test-request-123");
}
#[test]
fn request_id_from_string() {
    // Both borrowed and owned strings convert into a `RequestId` unchanged.
    let from_str: RequestId = "my-id".into();
    let from_owned: RequestId = String::from("my-id-2").into();
    assert_eq!(from_str.as_str(), "my-id");
    assert_eq!(from_owned.as_str(), "my-id-2");
}
#[test]
fn request_id_config_defaults() {
    let RequestIdConfig {
        header_name,
        accept_from_client,
        add_to_response,
        max_client_id_length,
        ..
    } = RequestIdConfig::default();
    assert_eq!(header_name, "x-request-id");
    assert!(accept_from_client);
    assert!(add_to_response);
    assert_eq!(max_client_id_length, 128);
}
#[test]
fn request_id_config_builder() {
    // Every builder setter overrides its corresponding default.
    let RequestIdConfig {
        header_name,
        accept_from_client,
        add_to_response,
        max_client_id_length,
        ..
    } = RequestIdConfig::new()
        .header_name("X-Trace-ID")
        .accept_from_client(false)
        .add_to_response(false)
        .max_client_id_length(64);
    assert_eq!(header_name, "X-Trace-ID");
    assert!(!accept_from_client);
    assert!(!add_to_response);
    assert_eq!(max_client_id_length, 64);
}
/// With no client-supplied id, the middleware generates one and stores it as
/// a request extension.
#[test]
fn request_id_middleware_generates_id() {
    let middleware = RequestIdMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let result = futures_executor::block_on(middleware.before(&ctx, &mut req));
    assert!(result.is_continue());
    // `expect` replaces the `is_some()` + `unwrap()` pair.
    let stored_id = req
        .get_extension::<RequestId>()
        .expect("middleware should store a RequestId extension");
    assert!(!stored_id.as_str().is_empty());
}
/// A well-formed client-supplied `x-request-id` is adopted as-is.
#[test]
fn request_id_middleware_accepts_client_id() {
    let middleware = RequestIdMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    req.headers_mut()
        .insert("x-request-id", b"client-provided-id-123".to_vec());
    // Check the verdict instead of discarding it (matches `..._generates_id`).
    let flow = futures_executor::block_on(middleware.before(&ctx, &mut req));
    assert!(flow.is_continue());
    let stored_id = req
        .get_extension::<RequestId>()
        .expect("middleware should store a RequestId extension");
    assert_eq!(stored_id.as_str(), "client-provided-id-123");
}
/// A client id containing disallowed characters is replaced, not propagated.
#[test]
fn request_id_middleware_rejects_invalid_client_id() {
    let middleware = RequestIdMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    req.headers_mut()
        .insert("x-request-id", b"invalid<script>id".to_vec());
    // The request itself still proceeds; only the id is replaced.
    let flow = futures_executor::block_on(middleware.before(&ctx, &mut req));
    assert!(flow.is_continue());
    let stored_id = req
        .get_extension::<RequestId>()
        .expect("middleware should store a RequestId extension");
    assert_ne!(stored_id.as_str(), "invalid<script>id");
}
/// A client id longer than `max_client_id_length` is replaced, not propagated.
#[test]
fn request_id_middleware_rejects_too_long_client_id() {
    let config = RequestIdConfig::new().max_client_id_length(10);
    let middleware = RequestIdMiddleware::with_config(config);
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    req.headers_mut()
        .insert("x-request-id", b"this-id-is-way-too-long".to_vec());
    // The request itself still proceeds; only the id is replaced.
    let flow = futures_executor::block_on(middleware.before(&ctx, &mut req));
    assert!(flow.is_continue());
    let stored_id = req
        .get_extension::<RequestId>()
        .expect("middleware should store a RequestId extension");
    assert_ne!(stored_id.as_str(), "this-id-is-way-too-long");
}
/// The id stored during `before` is written to the response header in `after`.
#[test]
fn request_id_middleware_adds_to_response() {
    let middleware = RequestIdMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let flow = futures_executor::block_on(middleware.before(&ctx, &mut req));
    assert!(flow.is_continue());
    // Go through the public `as_str` accessor rather than the tuple field `.0`,
    // and copy only the string instead of cloning the whole id.
    let expected = req
        .get_extension::<RequestId>()
        .expect("middleware should store a RequestId extension")
        .as_str()
        .to_owned();
    let response = futures_executor::block_on(middleware.after(&ctx, &req, Response::ok()));
    assert_eq!(header_value(&response, "x-request-id"), Some(expected));
}
/// With `add_to_response(false)`, the id is not echoed on the response.
#[test]
fn request_id_middleware_respects_add_to_response_false() {
    let config = RequestIdConfig::new().add_to_response(false);
    let middleware = RequestIdMiddleware::with_config(config);
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let flow = futures_executor::block_on(middleware.before(&ctx, &mut req));
    assert!(flow.is_continue());
    let response = Response::ok();
    let response = futures_executor::block_on(middleware.after(&ctx, &req, response));
    let header = header_value(&response, "x-request-id");
    assert!(header.is_none());
}
/// With `accept_from_client(false)`, even a well-formed client id is ignored.
#[test]
fn request_id_middleware_respects_accept_from_client_false() {
    let config = RequestIdConfig::new().accept_from_client(false);
    let middleware = RequestIdMiddleware::with_config(config);
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    req.headers_mut()
        .insert("x-request-id", b"client-id".to_vec());
    let flow = futures_executor::block_on(middleware.before(&ctx, &mut req));
    assert!(flow.is_continue());
    let stored_id = req
        .get_extension::<RequestId>()
        .expect("middleware should store a RequestId extension");
    assert_ne!(stored_id.as_str(), "client-id");
}
/// A custom header name is used both for reading the client id and for
/// writing it back on the response.
#[test]
fn request_id_middleware_custom_header_name() {
    let config = RequestIdConfig::new().header_name("X-Trace-ID");
    let middleware = RequestIdMiddleware::with_config(config);
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    req.headers_mut()
        .insert("X-Trace-ID", b"trace-123".to_vec());
    let flow = futures_executor::block_on(middleware.before(&ctx, &mut req));
    assert!(flow.is_continue());
    let stored_id = req
        .get_extension::<RequestId>()
        .expect("middleware should store a RequestId extension");
    assert_eq!(stored_id.as_str(), "trace-123");
    let response = futures_executor::block_on(middleware.after(&ctx, &req, Response::ok()));
    let header = header_value(&response, "X-Trace-ID");
    assert_eq!(header, Some("trace-123".to_string()));
}
#[test]
fn is_valid_request_id_accepts_valid() {
    // ASCII letters, digits, '-', '_' and '.' are all permitted.
    let accepted = [
        "abc123",
        "request-id-123",
        "request_id_123",
        "request.id.123",
        "ABC123",
        "a-b_c.D",
    ];
    for candidate in accepted {
        assert!(
            super::is_valid_request_id(candidate),
            "{candidate:?} should be accepted"
        );
    }
}
#[test]
fn is_valid_request_id_rejects_invalid() {
    // Empty ids and ids containing whitespace, markup, separators or control
    // characters are refused.
    let rejected = [
        "",
        "id with spaces",
        "id<script>",
        "id\nwith\nnewlines",
        "id;with;semicolons",
        "id/with/slashes",
    ];
    for candidate in rejected {
        assert!(
            !super::is_valid_request_id(candidate),
            "{candidate:?} should be rejected"
        );
    }
}
#[test]
fn request_id_middleware_name() {
let middleware = RequestIdMiddleware::new();
assert_eq!(middleware.name(), "RequestId");
}
/// Test middleware labelled by `id` whose hooks append entries to a shared log,
/// so tests can assert the order in which the hooks ran.
struct OrderTrackingMiddleware {
    id: &'static str,
    log: Arc<std::sync::Mutex<Vec<String>>>,
}
impl OrderTrackingMiddleware {
    /// Creates a tracker labelled `id` that records into `log`.
    fn new(id: &'static str, log: Arc<std::sync::Mutex<Vec<String>>>) -> Self {
        OrderTrackingMiddleware { log, id }
    }
}
impl Middleware for OrderTrackingMiddleware {
    /// Records `"<id>.before"` when the hook is called (before the returned future
    /// is polled) and lets the request continue.
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        let entry = format!("{}.before", self.id);
        {
            let mut entries = self.log.lock().unwrap();
            entries.push(entry);
        }
        Box::pin(async { ControlFlow::Continue })
    }

    /// Records `"<id>.after"` when the hook is called and returns the response
    /// unchanged.
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        let entry = format!("{}.after", self.id);
        {
            let mut entries = self.log.lock().unwrap();
            entries.push(entry);
        }
        Box::pin(async move { response })
    }
}
/// Test middleware that logs its hooks like `OrderTrackingMiddleware` and, when
/// `should_break` is set, short-circuits `before` with a `FORBIDDEN` response.
struct ConditionalBreakMiddleware {
    id: &'static str,
    should_break: bool,
    log: Arc<std::sync::Mutex<Vec<String>>>,
}
impl ConditionalBreakMiddleware {
    /// Creates a middleware labelled `id`; `should_break` selects short-circuiting.
    fn new(
        id: &'static str,
        should_break: bool,
        log: Arc<std::sync::Mutex<Vec<String>>>,
    ) -> Self {
        ConditionalBreakMiddleware {
            log,
            should_break,
            id,
        }
    }
}
impl Middleware for ConditionalBreakMiddleware {
    /// Records `"<id>.before"` immediately, then either continues or breaks with a
    /// `FORBIDDEN` response whose body is `blocked`.
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        let entry = format!("{}.before", self.id);
        self.log.lock().unwrap().push(entry);
        let breaking = self.should_break;
        Box::pin(async move {
            if !breaking {
                return ControlFlow::Continue;
            }
            ControlFlow::Break(
                Response::with_status(StatusCode::FORBIDDEN)
                    .body(ResponseBody::Bytes(b"blocked".to_vec())),
            )
        })
    }

    /// Records `"<id>.after"` immediately and passes the response through.
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        let entry = format!("{}.after", self.id);
        self.log.lock().unwrap().push(entry);
        Box::pin(async move { response })
    }
}
/// Handler that always answers `Response::ok()` with body `handler`.
struct OkHandler;
impl Handler for OkHandler {
    fn call<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a mut Request,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async {
            let body = ResponseBody::Bytes(b"handler".to_vec());
            Response::ok().body(body)
        })
    }
}
/// Handler that answers `Response::ok()` with body `header-present` when the request
/// carries an `X-Modified-By` header, and `BAD_REQUEST` otherwise.
struct CheckHeaderHandler;
impl Handler for CheckHeaderHandler {
    fn call<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, Response> {
        // The header is inspected before the future is built.
        let missing = req.headers().get("X-Modified-By").is_none();
        Box::pin(async move {
            if missing {
                return Response::with_status(StatusCode::BAD_REQUEST);
            }
            Response::ok().body(ResponseBody::Bytes(b"header-present".to_vec()))
        })
    }
}
/// Handler that always answers `INTERNAL_SERVER_ERROR`.
struct ErrorHandler;
impl Handler for ErrorHandler {
    fn call<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a mut Request,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async { Response::with_status(StatusCode::INTERNAL_SERVER_ERROR) })
    }
}
/// `before` hooks run in registration order and `after` hooks unwind in reverse.
#[test]
fn middleware_stack_executes_in_correct_order() {
    let log = Arc::new(std::sync::Mutex::new(Vec::new()));
    let mut stack = MiddlewareStack::new();
    for id in ["mw1", "mw2", "mw3"].iter().copied() {
        stack.push(OrderTrackingMiddleware::new(id, Arc::clone(&log)));
    }
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    futures_executor::block_on(stack.execute(&OkHandler, &context, &mut request));
    let recorded = log.lock().unwrap().clone();
    assert_eq!(
        recorded,
        [
            "mw1.before",
            "mw2.before",
            "mw3.before",
            "mw3.after",
            "mw2.after",
            "mw1.after",
        ]
    );
}
/// A break in the middle skips later middleware and the breaker's own `after`,
/// while earlier middleware still unwind.
#[test]
fn middleware_stack_short_circuit_skips_later_middleware() {
    let log = Arc::new(std::sync::Mutex::new(Vec::new()));
    let mut stack = MiddlewareStack::new();
    stack.push(OrderTrackingMiddleware::new("mw1", Arc::clone(&log)));
    stack.push(ConditionalBreakMiddleware::new("mw2", true, Arc::clone(&log)));
    stack.push(OrderTrackingMiddleware::new("mw3", Arc::clone(&log)));
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    let response =
        futures_executor::block_on(stack.execute(&OkHandler, &context, &mut request));
    assert_eq!(response.status().as_u16(), 403);
    let recorded = log.lock().unwrap().clone();
    assert_eq!(recorded, ["mw1.before", "mw2.before", "mw1.after"]);
}
/// When the first middleware breaks, nothing else is logged.
#[test]
fn middleware_stack_first_middleware_breaks() {
    let log = Arc::new(std::sync::Mutex::new(Vec::new()));
    let mut stack = MiddlewareStack::new();
    stack.push(ConditionalBreakMiddleware::new("mw1", true, Arc::clone(&log)));
    stack.push(OrderTrackingMiddleware::new("mw2", Arc::clone(&log)));
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    let response =
        futures_executor::block_on(stack.execute(&OkHandler, &context, &mut request));
    assert_eq!(response.status().as_u16(), 403);
    let recorded = log.lock().unwrap().clone();
    assert_eq!(recorded, ["mw1.before"]);
}
/// When the last middleware breaks, every earlier middleware still unwinds.
#[test]
fn middleware_stack_last_middleware_breaks() {
    let log = Arc::new(std::sync::Mutex::new(Vec::new()));
    let mut stack = MiddlewareStack::new();
    stack.push(OrderTrackingMiddleware::new("mw1", Arc::clone(&log)));
    stack.push(OrderTrackingMiddleware::new("mw2", Arc::clone(&log)));
    stack.push(ConditionalBreakMiddleware::new("mw3", true, Arc::clone(&log)));
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    let response =
        futures_executor::block_on(stack.execute(&OkHandler, &context, &mut request));
    assert_eq!(response.status().as_u16(), 403);
    let recorded = log.lock().unwrap().clone();
    assert_eq!(
        recorded,
        [
            "mw1.before",
            "mw2.before",
            "mw3.before",
            "mw2.after",
            "mw1.after",
        ]
    );
}
/// An empty stack simply returns the handler's response.
#[test]
fn middleware_stack_empty_executes_handler_directly() {
    let empty_stack = MiddlewareStack::new();
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    let outcome =
        futures_executor::block_on(empty_stack.execute(&OkHandler, &context, &mut request));
    assert_eq!(outcome.status().as_u16(), 200);
}
/// Pre-allocating capacity does not add any middleware.
#[test]
fn middleware_stack_with_capacity() {
    let preallocated = MiddlewareStack::with_capacity(10);
    assert_eq!(preallocated.len(), 0);
    assert!(preallocated.is_empty());
}
/// An already type-erased `Arc<dyn Middleware>` can be pushed directly.
#[test]
fn middleware_stack_push_arc() {
    let shared: Arc<dyn Middleware> = Arc::new(NoopMiddleware);
    let mut stack = MiddlewareStack::new();
    stack.push_arc(shared);
    assert_eq!(stack.len(), 1);
}
/// `AddResponseHeader` sets its header on the outgoing response.
#[test]
fn add_response_header_adds_header() {
    let middleware = AddResponseHeader::new("X-Custom", b"custom-value".to_vec());
    let context = test_context();
    let request = Request::new(crate::request::Method::Get, "/");
    let decorated =
        futures_executor::block_on(middleware.after(&context, &request, Response::ok()));
    assert_eq!(
        header_value(&decorated, "X-Custom").as_deref(),
        Some("custom-value")
    );
}
/// Headers already present on the response survive alongside the added one.
#[test]
fn add_response_header_preserves_existing_headers() {
    let middleware = AddResponseHeader::new("X-New", b"new".to_vec());
    let context = test_context();
    let request = Request::new(crate::request::Method::Get, "/");
    let original = Response::ok().header("X-Existing", b"existing".to_vec());
    let decorated = futures_executor::block_on(middleware.after(&context, &request, original));
    assert_eq!(
        header_value(&decorated, "X-Existing").as_deref(),
        Some("existing")
    );
    assert_eq!(header_value(&decorated, "X-New").as_deref(), Some("new"));
}
/// The middleware reports the short name `AddResponseHeader`.
#[test]
fn add_response_header_name() {
    assert_eq!(
        AddResponseHeader::new("X-Test", b"test".to_vec()).name(),
        "AddResponseHeader"
    );
}
/// A request carrying the required header is allowed through.
#[test]
fn require_header_allows_with_header() {
    let guard = RequireHeader::new("X-Api-Key");
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    request
        .headers_mut()
        .insert("X-Api-Key", b"secret-key".to_vec());
    let outcome = futures_executor::block_on(guard.before(&context, &mut request));
    assert!(outcome.is_continue());
}
/// A request missing the required header is rejected with status 400.
#[test]
fn require_header_blocks_without_header() {
    let guard = RequireHeader::new("X-Api-Key");
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    let outcome = futures_executor::block_on(guard.before(&context, &mut request));
    if let ControlFlow::Break(rejection) = outcome {
        assert_eq!(rejection.status().as_u16(), 400);
    } else {
        panic!("Expected Break, got Continue");
    }
}
/// The middleware reports the short name `RequireHeader`.
#[test]
fn require_header_name() {
    assert_eq!(RequireHeader::new("X-Test").name(), "RequireHeader");
}
/// A path below the configured prefix is allowed.
#[test]
fn path_prefix_filter_allows_matching_path() {
    let filter = PathPrefixFilter::new("/api");
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/api/users");
    let outcome = futures_executor::block_on(filter.before(&context, &mut request));
    assert!(outcome.is_continue());
}
/// A path equal to the prefix itself is allowed.
#[test]
fn path_prefix_filter_allows_exact_prefix() {
    let filter = PathPrefixFilter::new("/api");
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/api");
    let outcome = futures_executor::block_on(filter.before(&context, &mut request));
    assert!(outcome.is_continue());
}
/// A path outside the prefix is rejected with status 404.
#[test]
fn path_prefix_filter_blocks_non_matching_path() {
    let filter = PathPrefixFilter::new("/api");
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/admin/users");
    let outcome = futures_executor::block_on(filter.before(&context, &mut request));
    if let ControlFlow::Break(rejection) = outcome {
        assert_eq!(rejection.status().as_u16(), 404);
    } else {
        panic!("Expected Break, got Continue");
    }
}
/// The middleware reports the short name `PathPrefixFilter`.
#[test]
fn path_prefix_filter_name() {
    assert_eq!(PathPrefixFilter::new("/api").name(), "PathPrefixFilter");
}
/// When the predicate holds, the "true" status overrides the response status.
#[test]
fn conditional_status_applies_true_status() {
    let middleware = ConditionalStatus::new(
        |req| req.path() == "/health",
        StatusCode::OK,
        StatusCode::NOT_FOUND,
    );
    let context = test_context();
    let request = Request::new(crate::request::Method::Get, "/health");
    let failing = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR);
    let rewritten = futures_executor::block_on(middleware.after(&context, &request, failing));
    assert_eq!(rewritten.status().as_u16(), 200);
}
/// When the predicate fails, the "false" status overrides the response status.
#[test]
fn conditional_status_applies_false_status() {
    let middleware = ConditionalStatus::new(
        |req| req.path() == "/health",
        StatusCode::OK,
        StatusCode::NOT_FOUND,
    );
    let context = test_context();
    let request = Request::new(crate::request::Method::Get, "/other");
    let failing = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR);
    let rewritten = futures_executor::block_on(middleware.after(&context, &request, failing));
    assert_eq!(rewritten.status().as_u16(), 404);
}
/// The middleware reports the short name `ConditionalStatus`.
#[test]
fn conditional_status_name() {
    let middleware = ConditionalStatus::new(|_| true, StatusCode::OK, StatusCode::NOT_FOUND);
    assert_eq!(middleware.name(), "ConditionalStatus");
}
/// Cloneable test middleware whose `after` hook sets `X-Layer: <prefix>`.
#[derive(Clone)]
struct LayerTestMiddleware {
    prefix: String,
}
impl LayerTestMiddleware {
    fn new(prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        Self { prefix }
    }
}
impl Middleware for LayerTestMiddleware {
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        let value: Vec<u8> = self.prefix.as_bytes().to_vec();
        Box::pin(async move { response.header("X-Layer", value) })
    }
}
/// A `Layer` applies its middleware's `after` hook to the wrapped handler's response.
#[test]
fn layer_wraps_handler() {
    let layer = Layer::new(LayerTestMiddleware::new("wrapped"));
    let wrapped = layer.wrap(OkHandler);
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    let outcome = futures_executor::block_on(wrapped.call(&context, &mut request));
    assert_eq!(outcome.status().as_u16(), 200);
    assert_eq!(header_value(&outcome, "X-Layer").as_deref(), Some("wrapped"));
}
/// When the layered middleware breaks, its `after` hook still runs on the break response.
#[test]
fn layered_handles_break() {
    #[derive(Clone)]
    struct BreakingMiddleware;
    impl Middleware for BreakingMiddleware {
        fn before<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            _req: &'a mut Request,
        ) -> BoxFuture<'a, ControlFlow> {
            Box::pin(async {
                let denied = Response::with_status(StatusCode::UNAUTHORIZED);
                ControlFlow::Break(denied)
            })
        }
        fn after<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            _req: &'a Request,
            response: Response,
        ) -> BoxFuture<'a, Response> {
            Box::pin(async move { response.header("X-After", b"ran".to_vec()) })
        }
    }
    let layer = Layer::new(BreakingMiddleware);
    let wrapped = layer.wrap(OkHandler);
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    let outcome = futures_executor::block_on(wrapped.call(&context, &mut request));
    assert_eq!(outcome.status().as_u16(), 401);
    assert_eq!(header_value(&outcome, "X-After").as_deref(), Some("ran"));
}
/// Default logger settings: headers logged, bodies not, 1024-byte body cap.
#[test]
fn request_response_logger_default() {
    let defaults = RequestResponseLogger::default();
    assert_eq!(defaults.max_body_bytes, 1024);
    assert!(!defaults.log_body);
    assert!(defaults.log_request_headers);
    assert!(defaults.log_response_headers);
}
/// Every builder setter is reflected in the resulting logger.
#[test]
fn request_response_logger_builder() {
    let configured = RequestResponseLogger::new()
        .log_request_headers(false)
        .log_response_headers(false)
        .log_body(true)
        .max_body_bytes(2048)
        .redact_header("x-secret");
    assert!(configured.redact_headers.contains("x-secret"));
    assert_eq!(configured.max_body_bytes, 2048);
    assert!(configured.log_body);
    assert!(!configured.log_response_headers);
    assert!(!configured.log_request_headers);
}
/// The logger reports the short name `RequestResponseLogger`.
#[test]
fn request_response_logger_name() {
    assert_eq!(RequestResponseLogger::new().name(), "RequestResponseLogger");
}
/// A header added by a `before` hook is visible to the handler.
#[test]
fn middleware_stack_modifies_request_for_handler() {
    struct RequestModifier;
    impl Middleware for RequestModifier {
        fn before<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            req: &'a mut Request,
        ) -> BoxFuture<'a, ControlFlow> {
            let headers = req.headers_mut();
            headers.insert("X-Modified-By", b"middleware".to_vec());
            Box::pin(async { ControlFlow::Continue })
        }
    }
    let mut stack = MiddlewareStack::new();
    stack.push(RequestModifier);
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    let outcome =
        futures_executor::block_on(stack.execute(&CheckHeaderHandler, &context, &mut request));
    assert_eq!(outcome.status().as_u16(), 200);
}
/// Several response-modifying middleware all leave their mark on one response.
#[test]
fn middleware_stack_multiple_response_modifications() {
    let expected = [("X-First", "1"), ("X-Second", "2"), ("X-Third", "3")];
    let mut stack = MiddlewareStack::new();
    for &(name, value) in expected.iter() {
        stack.push(AddResponseHeader::new(name, value.as_bytes().to_vec()));
    }
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    let outcome = futures_executor::block_on(stack.execute(&OkHandler, &context, &mut request));
    for &(name, value) in expected.iter() {
        assert_eq!(header_value(&outcome, name), Some(value.to_string()));
    }
}
/// The short-circuit response from a breaking middleware (status 403, body
/// `blocked`) is what `execute` returns, not the handler's response.
#[test]
fn middleware_stack_handler_receives_response_after_break() {
    let mut stack = MiddlewareStack::new();
    let unused_log = Arc::new(std::sync::Mutex::new(Vec::new()));
    stack.push(ConditionalBreakMiddleware::new("breaker", true, unused_log));
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    let outcome = futures_executor::block_on(stack.execute(&OkHandler, &context, &mut request));
    assert_eq!(outcome.status().as_u16(), 403);
    if let ResponseBody::Bytes(bytes) = outcome.body_ref() {
        assert_eq!(bytes, b"blocked");
    } else {
        panic!("Expected Bytes body");
    }
}
/// An `after` hook may replace the response entirely, including its status.
#[test]
fn middleware_after_can_change_status() {
    struct StatusChanger;
    impl Middleware for StatusChanger {
        fn after<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            _req: &'a Request,
            _response: Response,
        ) -> BoxFuture<'a, Response> {
            // The incoming response is discarded and replaced.
            Box::pin(async { Response::with_status(StatusCode::SERVICE_UNAVAILABLE) })
        }
    }
    let mut stack = MiddlewareStack::new();
    stack.push(StatusChanger);
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    let outcome = futures_executor::block_on(stack.execute(&OkHandler, &context, &mut request));
    assert_eq!(outcome.status().as_u16(), 503);
}
/// `after` hooks still run when the handler returns an error status.
#[test]
fn middleware_after_runs_even_on_error_status() {
    let log = Arc::new(std::sync::Mutex::new(Vec::new()));
    let mut stack = MiddlewareStack::new();
    stack.push(OrderTrackingMiddleware::new("mw1", Arc::clone(&log)));
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    let outcome =
        futures_executor::block_on(stack.execute(&ErrorHandler, &context, &mut request));
    assert_eq!(outcome.status().as_u16(), 500);
    let recorded = log.lock().unwrap().clone();
    assert_eq!(recorded, ["mw1.before", "mw1.after"]);
}
/// A leading `*.` matches any sub-domain but not the bare domain.
#[test]
fn wildcard_match_simple() {
    for &(pattern, text, expected) in &[
        ("*.example.com", "api.example.com", true),
        ("*.example.com", "www.example.com", true),
        ("*.example.com", "example.com", false),
    ] {
        assert_eq!(super::wildcard_match(pattern, text), expected);
    }
}
/// A leading `*` acts as a suffix match.
#[test]
fn wildcard_match_suffix_pattern() {
    for &(pattern, text, expected) in &[
        ("*.txt", "file.txt", true),
        ("*.txt", "document.txt", true),
        ("*.txt", "file.doc", false),
        ("*-suffix", "any-suffix", true),
    ] {
        assert_eq!(super::wildcard_match(pattern, text), expected);
    }
}
/// Without `*` the pattern must equal the input exactly.
#[test]
fn wildcard_match_no_wildcard() {
    for &(pattern, text, expected) in &[("exact", "exact", true), ("exact", "different", false)] {
        assert_eq!(super::wildcard_match(pattern, text), expected);
    }
}
/// `^...$` anchors the pattern to the whole input.
#[test]
fn regex_match_anchored() {
    for &(text, expected) in &[("hello", true), ("hello world", false), ("say hello", false)] {
        assert_eq!(super::regex_match("^hello$", text), expected);
    }
}
/// `.` matches any single character.
#[test]
fn regex_match_dot_wildcard() {
    for &text in &["hello", "hallo"] {
        assert!(super::regex_match("h.llo", text));
    }
}
/// `*` repeats the preceding character zero or more times.
#[test]
fn regex_match_star() {
    for &text in &["hello", "helo", "hellllllo"] {
        assert!(super::regex_match("hel*o", text));
    }
}
/// The trait's default `before` lets the request continue.
#[test]
fn middleware_default_before_continues() {
    struct DefaultBefore;
    impl Middleware for DefaultBefore {}
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    let outcome = futures_executor::block_on(DefaultBefore.before(&context, &mut request));
    assert!(outcome.is_continue());
}
/// The trait's default `after` returns the response untouched.
#[test]
fn middleware_default_after_passes_through() {
    struct DefaultAfter;
    impl Middleware for DefaultAfter {}
    let context = test_context();
    let request = Request::new(crate::request::Method::Get, "/");
    let created = Response::with_status(StatusCode::CREATED);
    let passed = futures_executor::block_on(DefaultAfter.after(&context, &request, created));
    assert_eq!(passed.status().as_u16(), 201);
}
/// The trait's default `name` is derived from the implementing type's name.
#[test]
fn middleware_default_name_is_type_name() {
    struct MyCustomMiddleware;
    impl Middleware for MyCustomMiddleware {}
    assert!(MyCustomMiddleware.name().contains("MyCustomMiddleware"));
}
/// The default config enables the baseline headers and leaves CSP, HSTS and
/// Permissions-Policy unset.
#[test]
fn security_headers_default_config() {
    let defaults = SecurityHeadersConfig::default();
    assert_eq!(defaults.x_content_type_options, Some("nosniff"));
    assert_eq!(defaults.x_frame_options, Some(XFrameOptions::Deny));
    assert_eq!(defaults.x_xss_protection, Some("0"));
    assert_eq!(
        defaults.referrer_policy,
        Some(ReferrerPolicy::StrictOriginWhenCrossOrigin)
    );
    assert!(defaults.content_security_policy.is_none());
    assert!(defaults.hsts.is_none());
    assert!(defaults.permissions_policy.is_none());
}
/// `none()` disables every header.
#[test]
fn security_headers_none_config() {
    let empty = SecurityHeadersConfig::none();
    assert!(empty.x_content_type_options.is_none());
    assert!(empty.x_frame_options.is_none());
    assert!(empty.x_xss_protection.is_none());
    assert!(empty.content_security_policy.is_none());
    assert!(empty.hsts.is_none());
    assert!(empty.referrer_policy.is_none());
    assert!(empty.permissions_policy.is_none());
}
/// `strict()` adds a self-only CSP, a one-year HSTS and no-referrer.
#[test]
fn security_headers_strict_config() {
    let strict = SecurityHeadersConfig::strict();
    assert_eq!(strict.x_content_type_options, Some("nosniff"));
    assert_eq!(strict.x_frame_options, Some(XFrameOptions::Deny));
    assert_eq!(
        strict.content_security_policy.as_deref(),
        Some("default-src 'self'")
    );
    assert_eq!(strict.hsts, Some((31_536_000, true, false)));
    assert_eq!(strict.referrer_policy, Some(ReferrerPolicy::NoReferrer));
    assert!(strict.permissions_policy.is_some());
}
/// Builder setters overwrite the corresponding fields.
#[test]
fn security_headers_config_builder() {
    let built = SecurityHeadersConfig::new()
        .x_frame_options(Some(XFrameOptions::SameOrigin))
        .content_security_policy("default-src 'self'")
        .hsts(86400, false, false)
        .referrer_policy(Some(ReferrerPolicy::Origin));
    assert_eq!(built.x_frame_options, Some(XFrameOptions::SameOrigin));
    assert_eq!(
        built.content_security_policy.as_deref(),
        Some("default-src 'self'")
    );
    assert_eq!(built.hsts, Some((86400, false, false)));
    assert_eq!(built.referrer_policy, Some(ReferrerPolicy::Origin));
}
/// The HSTS value gains `includeSubDomains` and `preload` directives per flag.
#[test]
fn security_headers_hsts_value_format() {
    let cases = [
        (false, false, "max-age=3600"),
        (true, false, "max-age=3600; includeSubDomains"),
        (false, true, "max-age=3600; preload"),
        (true, true, "max-age=3600; includeSubDomains; preload"),
    ];
    for &(include_subdomains, preload, expected) in cases.iter() {
        let config = SecurityHeadersConfig::none().hsts(3600, include_subdomains, preload);
        assert_eq!(config.build_hsts_value(), Some(expected.to_string()));
    }
}
/// With the default config the baseline headers are emitted and the opt-in ones are not.
#[test]
fn security_headers_middleware_adds_default_headers() {
    let middleware = SecurityHeaders::new();
    let context = test_context();
    let request = Request::new(crate::request::Method::Get, "/");
    let result =
        futures_executor::block_on(middleware.after(&context, &request, Response::ok()));
    let present = |name: &str| header_value(&result, name).is_some();
    for &name in &[
        "X-Content-Type-Options",
        "X-Frame-Options",
        "X-XSS-Protection",
        "Referrer-Policy",
    ] {
        assert!(present(name));
    }
    for &name in &[
        "Content-Security-Policy",
        "Strict-Transport-Security",
        "Permissions-Policy",
    ] {
        assert!(!present(name));
    }
}
/// A configured Content-Security-Policy is emitted verbatim.
#[test]
fn security_headers_middleware_with_csp() {
    let config = SecurityHeadersConfig::new()
        .content_security_policy("default-src 'self'; script-src 'self' 'unsafe-inline'");
    let mw = SecurityHeaders::with_config(config);
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");
    let response = Response::ok();
    let result = futures_executor::block_on(mw.after(&ctx, &req, response));
    // One `assert_eq!` on the `Option` covers both "header missing" and "wrong
    // value", and the failure message shows the actual header.
    assert_eq!(
        header_value(&result, "Content-Security-Policy").as_deref(),
        Some("default-src 'self'; script-src 'self' 'unsafe-inline'")
    );
}
/// HSTS with `includeSubDomains` produces the expected header value.
#[test]
fn security_headers_middleware_with_hsts() {
    let config = SecurityHeadersConfig::new().hsts(31536000, true, false);
    let mw = SecurityHeaders::with_config(config);
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/");
    let response = Response::ok();
    let result = futures_executor::block_on(mw.after(&ctx, &req, response));
    assert_eq!(
        header_value(&result, "Strict-Transport-Security").as_deref(),
        Some("max-age=31536000; includeSubDomains")
    );
}
/// The middleware reports the short name `SecurityHeaders`.
#[test]
fn security_headers_middleware_name() {
    assert_eq!(SecurityHeaders::new().name(), "SecurityHeaders");
}
/// `X-Frame-Options` variants serialize to their header tokens.
#[test]
fn x_frame_options_values() {
    assert_eq!(XFrameOptions::SameOrigin.as_bytes(), b"SAMEORIGIN");
    assert_eq!(XFrameOptions::Deny.as_bytes(), b"DENY");
}
/// Every `Referrer-Policy` variant serializes to its header token.
#[test]
fn referrer_policy_values() {
    assert_eq!(ReferrerPolicy::UnsafeUrl.as_bytes(), b"unsafe-url");
    assert_eq!(
        ReferrerPolicy::StrictOriginWhenCrossOrigin.as_bytes(),
        b"strict-origin-when-cross-origin"
    );
    assert_eq!(ReferrerPolicy::StrictOrigin.as_bytes(), b"strict-origin");
    assert_eq!(ReferrerPolicy::SameOrigin.as_bytes(), b"same-origin");
    assert_eq!(
        ReferrerPolicy::OriginWhenCrossOrigin.as_bytes(),
        b"origin-when-cross-origin"
    );
    assert_eq!(ReferrerPolicy::Origin.as_bytes(), b"origin");
    assert_eq!(
        ReferrerPolicy::NoReferrerWhenDowngrade.as_bytes(),
        b"no-referrer-when-downgrade"
    );
    assert_eq!(ReferrerPolicy::NoReferrer.as_bytes(), b"no-referrer");
}
/// The strict preset emits every header it configures.
#[test]
fn security_headers_strict_preset() {
    let middleware = SecurityHeaders::strict();
    let context = test_context();
    let request = Request::new(crate::request::Method::Get, "/");
    let result =
        futures_executor::block_on(middleware.after(&context, &request, Response::ok()));
    for &name in &[
        "X-Content-Type-Options",
        "X-Frame-Options",
        "Content-Security-Policy",
        "Strict-Transport-Security",
        "Referrer-Policy",
        "Permissions-Policy",
    ] {
        assert!(header_value(&result, name).is_some());
    }
}
/// The `no_*` builder methods clear settings inherited from a preset.
#[test]
fn security_headers_config_clearing_methods() {
    let cleared = SecurityHeadersConfig::strict()
        .no_content_security_policy()
        .no_hsts()
        .no_permissions_policy();
    assert!(cleared.permissions_policy.is_none());
    assert!(cleared.hsts.is_none());
    assert!(cleared.content_security_policy.is_none());
}
/// Two generated tokens are non-empty and differ from each other.
#[test]
fn csrf_token_generate_produces_unique_tokens() {
    let first = CsrfToken::generate();
    let second = CsrfToken::generate();
    assert!(!first.as_str().is_empty());
    assert!(!second.as_str().is_empty());
    assert_ne!(first, second);
}
/// `Display` renders the raw token string.
#[test]
fn csrf_token_display() {
    let rendered = format!("{}", CsrfToken::new("test-token-123"));
    assert_eq!(rendered, "test-token-123");
}
/// Default CSRF configuration values.
#[test]
fn csrf_config_defaults() {
    let defaults = CsrfConfig::default();
    assert_eq!(defaults.mode, CsrfMode::DoubleSubmit);
    assert_eq!(defaults.header_name, "x-csrf-token");
    assert_eq!(defaults.cookie_name, "csrf_token");
    assert!(defaults.production);
    assert!(!defaults.rotate_token);
    assert!(defaults.error_message.is_none());
}
/// Every builder setter is reflected in the resulting config.
#[test]
fn csrf_config_builder() {
    let built = CsrfConfig::new()
        .cookie_name("XSRF-TOKEN")
        .header_name("X-XSRF-Token")
        .mode(CsrfMode::HeaderOnly)
        .rotate_token(true)
        .production(false)
        .error_message("Custom CSRF error");
    assert_eq!(built.error_message.as_deref(), Some("Custom CSRF error"));
    assert!(!built.production);
    assert!(built.rotate_token);
    assert_eq!(built.mode, CsrfMode::HeaderOnly);
    assert_eq!(built.header_name, "X-XSRF-Token");
    assert_eq!(built.cookie_name, "XSRF-TOKEN");
}
/// A GET passes without a token and gets a `CsrfToken` extension attached.
#[test]
fn csrf_middleware_allows_get_without_token() {
    let csrf = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Get, "/");
    let outcome = futures_executor::block_on(csrf.before(&context, &mut request));
    assert!(outcome.is_continue());
    assert!(request.get_extension::<CsrfToken>().is_some());
}
/// A HEAD passes without a token.
#[test]
fn csrf_middleware_allows_head_without_token() {
    let csrf = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Head, "/");
    assert!(futures_executor::block_on(csrf.before(&context, &mut request)).is_continue());
}
/// An OPTIONS passes without a token.
#[test]
fn csrf_middleware_allows_options_without_token() {
    let csrf = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Options, "/");
    assert!(futures_executor::block_on(csrf.before(&context, &mut request)).is_continue());
}
/// A POST without any token is rejected with `FORBIDDEN`.
#[test]
fn csrf_middleware_blocks_post_without_token() {
    let csrf = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Post, "/");
    let outcome = futures_executor::block_on(csrf.before(&context, &mut request));
    assert!(outcome.is_break());
    if let ControlFlow::Break(rejection) = outcome {
        assert_eq!(rejection.status(), StatusCode::FORBIDDEN);
    }
}
/// A PUT without any token is rejected.
#[test]
fn csrf_middleware_blocks_put_without_token() {
    let csrf = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Put, "/");
    assert!(futures_executor::block_on(csrf.before(&context, &mut request)).is_break());
}
/// A DELETE without any token is rejected.
#[test]
fn csrf_middleware_blocks_delete_without_token() {
    let csrf = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Delete, "/");
    assert!(futures_executor::block_on(csrf.before(&context, &mut request)).is_break());
}
/// A PATCH without any token is rejected.
#[test]
fn csrf_middleware_blocks_patch_without_token() {
    let csrf = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Patch, "/");
    assert!(futures_executor::block_on(csrf.before(&context, &mut request)).is_break());
}
/// Matching cookie and header tokens let a POST through, and that token is stored.
#[test]
fn csrf_middleware_allows_post_with_matching_tokens() {
    let token = "valid-csrf-token-12345";
    let csrf = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Post, "/");
    {
        let headers = request.headers_mut();
        headers.insert("cookie", format!("csrf_token={}", token).into_bytes());
        headers.insert("x-csrf-token", token.as_bytes().to_vec());
    }
    let outcome = futures_executor::block_on(csrf.before(&context, &mut request));
    assert!(outcome.is_continue());
    let stored = request.get_extension::<CsrfToken>().unwrap();
    assert_eq!(stored.as_str(), token);
}
/// Differing cookie and header tokens cause a `FORBIDDEN` rejection.
#[test]
fn csrf_middleware_blocks_post_with_mismatched_tokens() {
    let csrf = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Post, "/");
    {
        let headers = request.headers_mut();
        headers.insert("cookie", b"csrf_token=token-in-cookie".to_vec());
        headers.insert("x-csrf-token", b"different-token".to_vec());
    }
    let outcome = futures_executor::block_on(csrf.before(&context, &mut request));
    assert!(outcome.is_break());
    if let ControlFlow::Break(rejection) = outcome {
        assert_eq!(rejection.status(), StatusCode::FORBIDDEN);
    }
}
/// In double-submit mode a header token without the cookie is rejected.
#[test]
fn csrf_middleware_blocks_post_with_header_only_in_double_submit_mode() {
    let csrf = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Post, "/");
    request
        .headers_mut()
        .insert("x-csrf-token", b"some-token".to_vec());
    assert!(futures_executor::block_on(csrf.before(&context, &mut request)).is_break());
}
/// In double-submit mode a cookie token without the header is rejected.
#[test]
fn csrf_middleware_blocks_post_with_cookie_only_in_double_submit_mode() {
    let csrf = CsrfMiddleware::new();
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Post, "/");
    request
        .headers_mut()
        .insert("cookie", b"csrf_token=some-token".to_vec());
    assert!(futures_executor::block_on(csrf.before(&context, &mut request)).is_break());
}
/// In header-only mode a non-empty header token is enough.
#[test]
fn csrf_middleware_header_only_mode_accepts_header_token() {
    let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Post, "/");
    request
        .headers_mut()
        .insert("x-csrf-token", b"valid-token".to_vec());
    assert!(futures_executor::block_on(csrf.before(&context, &mut request)).is_continue());
}
/// In header-only mode an empty header token is rejected.
#[test]
fn csrf_middleware_header_only_mode_rejects_empty_header() {
    let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
    let context = test_context();
    let mut request = Request::new(crate::request::Method::Post, "/");
    request.headers_mut().insert("x-csrf-token", b"".to_vec());
    assert!(futures_executor::block_on(csrf.before(&context, &mut request)).is_break());
}
/// A GET without a CSRF cookie gets one issued, with `SameSite=Strict` and, in
/// production mode (the default), `Secure`.
#[test]
fn csrf_middleware_sets_cookie_on_get() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
    let response = Response::ok();
    let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
    let cookie_value = header_value(&result, "set-cookie")
        .expect("CSRF middleware should issue a cookie on a GET without one");
    assert!(cookie_value.starts_with("csrf_token="));
    // Match whole `;`-separated cookie attributes rather than substrings, so a
    // random token value that happens to contain e.g. "Secure" cannot
    // satisfy the check.
    let has_attr = |attr: &str| cookie_value.split(';').any(|part| part.trim() == attr);
    assert!(has_attr("SameSite=Strict"));
    assert!(has_attr("Secure"));
}
/// With `production(false)` the issued cookie must not carry the `Secure` attribute.
#[test]
fn csrf_middleware_no_secure_in_dev_mode() {
    let csrf = CsrfMiddleware::with_config(CsrfConfig::new().production(false));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
    let response = Response::ok();
    let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
    let cookie_value = header_value(&result, "set-cookie")
        .expect("CSRF middleware should issue a cookie in dev mode too");
    // Attribute-level check: a token containing "Secure" must not cause a false failure.
    assert!(!cookie_value
        .split(';')
        .any(|part| part.trim() == "Secure"));
}
/// If the browser already holds a CSRF cookie, the default config must not re-issue it.
#[test]
fn csrf_middleware_does_not_set_cookie_if_already_present() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    req.headers_mut().insert("cookie", b"csrf_token=existing-token".to_vec());
    let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
    let result = futures_executor::block_on(csrf.after(&ctx, &req, Response::ok()));
    let set_cookie = header_value(&result, "set-cookie");
    assert_eq!(set_cookie, None);
}
/// With rotation enabled, a Set-Cookie header is emitted even though the
/// request already carries a CSRF cookie.
#[test]
fn csrf_middleware_rotates_token_when_configured() {
    let rotating = CsrfMiddleware::with_config(CsrfConfig::new().rotate_token(true));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    req.headers_mut().insert("cookie", b"csrf_token=old-token".to_vec());
    let _ = futures_executor::block_on(rotating.before(&ctx, &mut req));
    let result = futures_executor::block_on(rotating.after(&ctx, &req, Response::ok()));
    assert!(header_value(&result, "set-cookie").is_some());
}
/// Custom header and cookie names are honoured during double-submit validation.
#[test]
fn csrf_middleware_custom_header_name() {
    let config = CsrfConfig::new()
        .header_name("X-XSRF-Token")
        .cookie_name("XSRF-TOKEN");
    let csrf = CsrfMiddleware::with_config(config);
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");
    let token = "custom-token-value";
    let cookie = format!("XSRF-TOKEN={token}");
    let headers = req.headers_mut();
    headers.insert("cookie", cookie.into_bytes());
    // The header is inserted in lowercase while the config uses mixed case.
    headers.insert("x-xsrf-token", token.as_bytes().to_vec());
    let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(result.is_continue());
}
/// A POST with neither cookie nor header is rejected with a JSON error body
/// that names the error type and the expected header.
#[test]
fn csrf_middleware_error_response_is_json() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");
    let response = match futures_executor::block_on(csrf.before(&ctx, &mut req)) {
        ControlFlow::Break(response) => response,
        ControlFlow::Continue => panic!("Expected Break"),
    };
    assert_eq!(
        header_value(&response, "content-type"),
        Some("application/json".to_string())
    );
    let text = match response.body_ref() {
        ResponseBody::Bytes(bytes) => std::str::from_utf8(bytes).unwrap(),
        _ => panic!("Expected Bytes body"),
    };
    assert!(text.contains("csrf_error"));
    assert!(text.contains("x-csrf-token"));
}
/// A configured `error_message` must appear in the rejection body.
///
/// The original version only asserted inside nested `if let`s. It passed
/// silently when the request was not rejected or the body was not `Bytes`.
/// Both cases now fail the test explicitly.
#[test]
fn csrf_middleware_custom_error_message() {
    let csrf = CsrfMiddleware::with_config(
        CsrfConfig::new().error_message("Access denied: invalid security token"),
    );
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");
    let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
    let response = match result {
        ControlFlow::Break(response) => response,
        ControlFlow::Continue => panic!("Expected Break"),
    };
    match response.body_ref() {
        ResponseBody::Bytes(body) => {
            let body_str = std::str::from_utf8(body).unwrap();
            assert!(
                body_str.contains("Access denied: invalid security token"),
                "Custom error message missing from body: {body_str}"
            );
        }
        _ => panic!("Expected Bytes body"),
    }
}
/// The middleware overrides `name()` with a short fixed label instead of the
/// default type path.
#[test]
fn csrf_middleware_name() {
    let name = CsrfMiddleware::new().name();
    assert_eq!(name, "CSRF");
}
/// The CSRF cookie is found even when it sits between unrelated cookies.
#[test]
fn csrf_middleware_parses_cookie_with_multiple_cookies() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");
    let token = "the-csrf-token";
    let cookie_header = format!("session=abc123; csrf_token={token}; user=test");
    req.headers_mut().insert("cookie", cookie_header.into_bytes());
    req.headers_mut().insert("x-csrf-token", token.as_bytes().to_vec());
    let outcome = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(matches!(outcome, ControlFlow::Continue));
}
/// An empty cookie value paired with an empty header is not a valid token.
#[test]
fn csrf_middleware_handles_empty_token_value() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");
    {
        let headers = req.headers_mut();
        headers.insert("cookie", b"csrf_token=".to_vec());
        headers.insert("x-csrf-token", b"".to_vec());
    }
    let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(result.is_break());
}
/// One hundred freshly generated tokens must all be distinct.
#[test]
fn csrf_token_generate_many_unique() {
    const COUNT: usize = 100;
    let mut seen = std::collections::HashSet::with_capacity(COUNT);
    for _ in 0..COUNT {
        let token = CsrfToken::generate();
        let fresh = seen.insert(token.0.clone());
        assert!(fresh, "Duplicate token generated: {}", token.0);
    }
    assert_eq!(seen.len(), COUNT);
}
/// Generated tokens are at least 64 characters long and consist only of hex digits.
#[test]
fn csrf_token_generate_format_is_hex() {
    let token = CsrfToken::generate();
    let hex = token.as_str();
    let len = hex.len();
    assert!(
        len >= 64,
        "Expected at least 64 hex characters, got {len} in '{hex}'"
    );
    let has_non_hex = hex.chars().any(|c| !c.is_ascii_hexdigit());
    assert!(!has_non_hex, "Non-hex character in token: {hex}");
}
/// Generated tokens must be at least 64 bytes long.
#[test]
fn csrf_token_generate_minimum_length() {
    let generated = CsrfToken::generate();
    let text = generated.as_str();
    assert!(
        text.len() >= 64,
        "Token too short: {} (len={})",
        text,
        text.len()
    );
}
/// Converting a `&str` into a `CsrfToken` keeps the text unchanged.
#[test]
fn csrf_token_from_str() {
    let expected = "my-token";
    let token: CsrfToken = expected.into();
    assert_eq!(token.as_str(), expected);
    assert_eq!(token.0, expected);
}
/// A cloned token compares equal to, and has the same text as, its source.
#[test]
fn csrf_token_clone_eq() {
    let original = CsrfToken::new("abc");
    let copy = original.clone();
    assert_eq!(original, copy);
    assert_eq!(original.as_str(), copy.as_str());
}
/// TRACE passes without a token and still gets a token stored in the request extensions.
#[test]
fn csrf_middleware_allows_trace_without_token() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Trace, "/");
    let flow = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(flow.is_continue());
    let stored = req.get_extension::<CsrfToken>();
    assert!(stored.is_some());
}
/// Every safe method passes and stores a non-empty token in the request extensions.
#[test]
fn csrf_safe_method_generates_token_into_extension() {
    use crate::request::Method;
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let safe_methods = [Method::Get, Method::Head, Method::Options, Method::Trace];
    for method in safe_methods {
        let mut req = Request::new(method, "/test");
        let flow = futures_executor::block_on(csrf.before(&ctx, &mut req));
        assert!(flow.is_continue());
        let token = req.get_extension::<CsrfToken>().expect("token missing");
        assert!(!token.as_str().is_empty());
    }
}
/// With the default config, a safe request reuses the token from its incoming cookie.
#[test]
fn csrf_safe_method_preserves_existing_cookie_token() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    req.headers_mut().insert("cookie", b"csrf_token=my-existing-token".to_vec());
    let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
    let token_str = req.get_extension::<CsrfToken>().unwrap().as_str();
    assert_eq!(token_str, "my-existing-token");
}
/// A POST that passes double-submit validation stores the validated token in
/// the request extensions.
#[test]
fn csrf_valid_post_stores_token_in_extension() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let tk = "valid-token-xyz";
    let mut req = Request::new(crate::request::Method::Post, "/submit");
    req.headers_mut().insert("cookie", format!("csrf_token={tk}").into_bytes());
    req.headers_mut().insert("x-csrf-token", tk.as_bytes().to_vec());
    let verdict = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(verdict.is_continue());
    let stored = req.get_extension::<CsrfToken>().unwrap();
    assert_eq!(stored.as_str(), tk);
}
/// Double-submit rejects a request whose cookie value and header are both empty.
#[test]
fn csrf_double_submit_both_empty_strings_rejected() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");
    let empty_cookie = b"csrf_token=".to_vec();
    let empty_header: Vec<u8> = Vec::new();
    req.headers_mut().insert("cookie", empty_cookie);
    req.headers_mut().insert("x-csrf-token", empty_header);
    let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(matches!(result, ControlFlow::Break(_)));
}
/// Two empty values "match" byte-for-byte but must never count as a valid token.
#[test]
fn csrf_double_submit_matching_empty_rejected() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");
    for (name, value) in [("cookie", "csrf_token="), ("x-csrf-token", "")] {
        req.headers_mut().insert(name, value.as_bytes().to_vec());
    }
    let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(
        result.is_break(),
        "Empty matching tokens should be rejected"
    );
}
/// In HeaderOnly mode a header alone is enough, and its value becomes the stored token.
#[test]
fn csrf_header_only_mode_does_not_need_cookie() {
    let header_only = CsrfConfig::new().mode(CsrfMode::HeaderOnly);
    let csrf = CsrfMiddleware::with_config(header_only);
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");
    req.headers_mut().insert("x-csrf-token", b"header-only-token".to_vec());
    let flow = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(flow.is_continue());
    let token = req.get_extension::<CsrfToken>().unwrap();
    assert_eq!(token.as_str(), "header-only-token");
}
/// In HeaderOnly mode a cookie that disagrees with the header is not compared.
#[test]
fn csrf_header_only_mode_ignores_mismatched_cookie() {
    let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");
    let headers = req.headers_mut();
    headers.insert("cookie", b"csrf_token=different-value".to_vec());
    headers.insert("x-csrf-token", b"header-value".to_vec());
    let flow = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(flow.is_continue(), "HeaderOnly should ignore cookie");
}
/// In HeaderOnly mode a POST without the token header is rejected.
#[test]
fn csrf_header_only_mode_rejects_no_header() {
    let mode = CsrfMode::HeaderOnly;
    let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(mode));
    let ctx = test_context();
    let mut bare_post = Request::new(crate::request::Method::Post, "/");
    let flow = futures_executor::block_on(csrf.before(&ctx, &mut bare_post));
    assert!(matches!(flow, ControlFlow::Break(_)));
}
/// The HeaderOnly rejection body must say the token is "missing in header".
///
/// Any non-`Bytes` body fails the test. Otherwise the message check would be
/// skipped and the test would pass vacuously.
#[test]
fn csrf_header_only_error_message_mentions_header() {
    let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");
    let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
    let response = match result {
        ControlFlow::Break(response) => response,
        ControlFlow::Continue => panic!("Expected Break"),
    };
    let body_str = match response.body_ref() {
        ResponseBody::Bytes(body) => std::str::from_utf8(body).unwrap(),
        _ => panic!("Expected Bytes body"),
    };
    assert!(
        body_str.contains("missing in header"),
        "Expected 'missing in header' in: {}",
        body_str
    );
}
/// A missing token and a mismatched token must produce different error messages.
#[test]
fn csrf_mismatch_error_differs_from_missing_error() {
    // Unwraps a rejected `before` outcome into its UTF-8 body text.
    fn rejection_body(flow: ControlFlow) -> String {
        match flow {
            ControlFlow::Break(r) => match r.body_ref() {
                ResponseBody::Bytes(b) => std::str::from_utf8(b).unwrap().to_string(),
                ResponseBody::Empty | ResponseBody::Stream(_) => panic!("Expected Bytes"),
            },
            ControlFlow::Continue => panic!("Expected Break"),
        }
    }

    let csrf = CsrfMiddleware::new();
    let ctx = test_context();

    // Case 1: no cookie and no header at all.
    let mut no_token = Request::new(crate::request::Method::Post, "/");
    let missing_body = rejection_body(futures_executor::block_on(csrf.before(&ctx, &mut no_token)));

    // Case 2: cookie and header are both present but disagree.
    let mut wrong_token = Request::new(crate::request::Method::Post, "/");
    wrong_token.headers_mut().insert("cookie", b"csrf_token=aaa".to_vec());
    wrong_token.headers_mut().insert("x-csrf-token", b"bbb".to_vec());
    let mismatch_body =
        rejection_body(futures_executor::block_on(csrf.before(&ctx, &mut wrong_token)));

    assert_ne!(
        missing_body, mismatch_body,
        "Missing vs mismatch should have different error messages"
    );
    assert!(missing_body.contains("missing"));
    assert!(mismatch_body.contains("mismatch"));
}
/// The issued cookie must not be HttpOnly (checked case-insensitively).
#[test]
fn csrf_cookie_not_httponly() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
    let response = futures_executor::block_on(csrf.after(&ctx, &req, Response::ok()));
    let set_cookie = header_value(&response, "set-cookie").unwrap();
    let lowered = set_cookie.to_lowercase();
    assert!(
        !lowered.contains("httponly"),
        "CSRF cookie must NOT be HttpOnly (needs JS access), got: {}",
        set_cookie
    );
}
/// The issued cookie is scoped to the whole site with `Path=/`.
#[test]
fn csrf_cookie_has_path_slash() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
    let response = futures_executor::block_on(csrf.after(&ctx, &req, Response::ok()));
    let set_cookie = header_value(&response, "set-cookie").unwrap();
    let has_root_path = set_cookie.contains("Path=/");
    assert!(has_root_path, "Cookie should have Path=/, got: {}", set_cookie);
}
/// The issued cookie carries `SameSite=Strict`.
#[test]
fn csrf_cookie_has_samesite_strict() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
    let response = futures_executor::block_on(csrf.after(&ctx, &req, Response::ok()));
    let set_cookie = header_value(&response, "set-cookie").unwrap();
    let strict = set_cookie.contains("SameSite=Strict");
    assert!(strict, "Cookie should have SameSite=Strict, got: {}", set_cookie);
}
/// With `production(true)` the issued cookie carries the `Secure` attribute.
#[test]
fn csrf_production_mode_sets_secure_flag() {
    let prod = CsrfMiddleware::with_config(CsrfConfig::new().production(true));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let _ = futures_executor::block_on(prod.before(&ctx, &mut req));
    let result = futures_executor::block_on(prod.after(&ctx, &req, Response::ok()));
    let cookie_value = header_value(&result, "set-cookie").unwrap();
    let has_secure = cookie_value.contains("Secure");
    assert!(
        has_secure,
        "Production cookie must have Secure flag, got: {}",
        cookie_value
    );
}
/// A POST that passes validation must not re-issue the CSRF cookie.
#[test]
fn csrf_no_set_cookie_on_post_response() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let token = "valid-token";
    let mut req = Request::new(crate::request::Method::Post, "/");
    req.headers_mut().insert("cookie", format!("csrf_token={token}").into_bytes());
    req.headers_mut().insert("x-csrf-token", token.as_bytes().to_vec());
    let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
    let result = futures_executor::block_on(csrf.after(&ctx, &req, Response::ok()));
    let set_cookie = header_value(&result, "set-cookie");
    assert!(
        set_cookie.is_none(),
        "POST response should not set CSRF cookie"
    );
}
/// HEAD is treated like GET: the response issues a CSRF cookie.
#[test]
fn csrf_head_method_sets_cookie() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut head_req = Request::new(crate::request::Method::Head, "/");
    let _ = futures_executor::block_on(csrf.before(&ctx, &mut head_req));
    let result = futures_executor::block_on(csrf.after(&ctx, &head_req, Response::ok()));
    let has_cookie = header_value(&result, "set-cookie").is_some();
    assert!(has_cookie, "HEAD response should set CSRF cookie");
}
/// OPTIONS is a safe method, so its response issues a CSRF cookie.
#[test]
fn csrf_options_method_sets_cookie() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut options_req = Request::new(crate::request::Method::Options, "/");
    let _ = futures_executor::block_on(csrf.before(&ctx, &mut options_req));
    let result = futures_executor::block_on(csrf.after(&ctx, &options_req, Response::ok()));
    let has_cookie = header_value(&result, "set-cookie").is_some();
    assert!(has_cookie, "OPTIONS response should set CSRF cookie");
}
/// With rotation enabled, the Set-Cookie on a safe request must carry a token
/// different from the one the client sent.
///
/// The original test only checked the cookie-name prefix, so it did not verify
/// what its name promises. The token is extracted the same way as in
/// `csrf_before_after_full_cycle_get_then_post` (`csrf_token=<value>; ...`).
#[test]
fn csrf_rotation_produces_different_token_in_cookie() {
    let csrf = CsrfMiddleware::with_config(CsrfConfig::new().rotate_token(true));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let old_token = "old-token-value";
    req.headers_mut()
        .insert("cookie", format!("csrf_token={}", old_token).into_bytes());
    let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
    let response = Response::ok();
    let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
    let cookie_value =
        header_value(&result, "set-cookie").expect("rotation should set the CSRF cookie");
    let new_token = cookie_value
        .strip_prefix("csrf_token=")
        .expect("Set-Cookie should start with the CSRF cookie name")
        .split(';')
        .next()
        .unwrap_or("");
    assert!(!new_token.is_empty(), "Rotated token must not be empty");
    assert_ne!(
        new_token, old_token,
        "Rotation must replace the old token, got: {}",
        cookie_value
    );
}
/// With rotation explicitly disabled, an existing cookie is not re-issued.
#[test]
fn csrf_no_rotation_skips_set_cookie_when_present() {
    let config = CsrfConfig::new().rotate_token(false);
    let csrf = CsrfMiddleware::with_config(config);
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    req.headers_mut().insert("cookie", b"csrf_token=existing".to_vec());
    let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
    let result = futures_executor::block_on(csrf.after(&ctx, &req, Response::ok()));
    assert!(
        header_value(&result, "set-cookie").is_none(),
        "Without rotation, should not re-set existing cookie"
    );
}
/// A custom cookie name is used as the name in the issued Set-Cookie header.
#[test]
fn csrf_custom_cookie_name_in_set_cookie_response() {
    let config = CsrfConfig::new().cookie_name("XSRF-TOKEN");
    let csrf = CsrfMiddleware::with_config(config);
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
    let result = futures_executor::block_on(csrf.after(&ctx, &req, Response::ok()));
    let cookie_value = header_value(&result, "set-cookie").unwrap();
    let uses_custom_name = cookie_value.starts_with("XSRF-TOKEN=");
    assert!(
        uses_custom_name,
        "Custom cookie name should appear in Set-Cookie, got: {}",
        cookie_value
    );
}
/// Validation passes when the token arrives under the configured custom header
/// and cookie names.
#[test]
fn csrf_custom_header_name_validated() {
    let config = CsrfConfig::new()
        .header_name("X-Custom-CSRF")
        .cookie_name("my_csrf");
    let csrf = CsrfMiddleware::with_config(config);
    let ctx = test_context();
    let token = "custom-tok";
    let mut req = Request::new(crate::request::Method::Post, "/");
    req.headers_mut().insert("cookie", format!("my_csrf={token}").into_bytes());
    req.headers_mut().insert("x-custom-csrf", token.as_bytes().to_vec());
    let outcome = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(matches!(outcome, ControlFlow::Continue));
}
/// When the header name is customised, a token sent under the default
/// `x-csrf-token` header is rejected.
#[test]
fn csrf_custom_header_name_wrong_header_rejected() {
    let csrf = CsrfMiddleware::with_config(CsrfConfig::new().header_name("X-Custom-CSRF"));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");
    let token = "some-token";
    let headers = req.headers_mut();
    headers.insert("cookie", format!("csrf_token={token}").into_bytes());
    headers.insert("x-csrf-token", token.as_bytes().to_vec());
    let flow = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(flow.is_break(), "Wrong header name should be rejected");
}
/// Among four cookies, only the `csrf_token` entry is used for validation.
#[test]
fn csrf_cookie_parsing_multiple_cookies_picks_correct() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");
    let token = "correct-csrf";
    let cookie_header = [
        "session=abc".to_string(),
        "other=xyz".to_string(),
        format!("csrf_token={token}"),
        "tracking=123".to_string(),
    ]
    .join("; ");
    req.headers_mut().insert("cookie", cookie_header.into_bytes());
    req.headers_mut().insert("x-csrf-token", token.as_bytes().to_vec());
    let outcome = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(outcome.is_continue());
}
/// Extra spaces on both sides of the `;` separators must not break cookie lookup.
#[test]
fn csrf_cookie_parsing_spaces_around_semicolons() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");
    let token = "spaced-token";
    let padded = format!("session=abc ; csrf_token={token} ; other=xyz");
    req.headers_mut().insert("cookie", padded.into_bytes());
    req.headers_mut().insert("x-csrf-token", token.as_bytes().to_vec());
    let outcome = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(matches!(outcome, ControlFlow::Continue));
}
/// Every state-changing method without a token is rejected with 403 Forbidden.
#[test]
fn csrf_error_response_status_is_403() {
    use crate::request::Method;
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    for method in [Method::Post, Method::Put, Method::Delete, Method::Patch] {
        let mut req = Request::new(method, "/");
        let flow = futures_executor::block_on(csrf.before(&ctx, &mut req));
        if let ControlFlow::Break(response) = flow {
            assert_eq!(
                response.status(),
                StatusCode::FORBIDDEN,
                "Expected 403 for {:?}",
                method
            );
        } else {
            panic!("Expected Break for {:?}", method);
        }
    }
}
/// The rejection body parses as JSON shaped like
/// `{"detail": [{"type": "csrf_error", "loc": ["header", "x-csrf-token"], "msg": "..."}]}`.
#[test]
fn csrf_error_body_json_structure() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");
    let response = match futures_executor::block_on(csrf.before(&ctx, &mut req)) {
        ControlFlow::Break(response) => response,
        ControlFlow::Continue => panic!("Expected Break"),
    };
    let body_str = match response.body_ref() {
        ResponseBody::Bytes(body) => std::str::from_utf8(body).unwrap(),
        _ => panic!("Expected Bytes body"),
    };
    let parsed: serde_json::Value = serde_json::from_str(body_str)
        .unwrap_or_else(|e| panic!("Invalid JSON: {}: {}", body_str, e));
    assert!(parsed["detail"].is_array());
    let detail = &parsed["detail"][0];
    assert_eq!(detail["type"], "csrf_error");
    let loc = &detail["loc"];
    assert!(loc.is_array());
    assert_eq!(loc[0], "header");
    assert_eq!(loc[1], "x-csrf-token");
    assert!(detail["msg"].is_string());
}
/// `Default` builds a working middleware that lets a plain GET through.
#[test]
fn csrf_default_trait() {
    let csrf: CsrfMiddleware = Default::default();
    assert_eq!(csrf.name(), "CSRF");
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let flow = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(flow.is_continue());
}
/// The default CSRF mode is the double-submit cookie pattern.
#[test]
fn csrf_mode_default_is_double_submit() {
    let mode = CsrfMode::default();
    assert_eq!(mode, CsrfMode::DoubleSubmit);
}
/// A DELETE with an identical non-empty token in cookie and header passes.
#[test]
fn csrf_double_submit_both_present_same_non_empty_passes() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let token = "a1b2c3d4e5f6";
    let mut req = Request::new(crate::request::Method::Delete, "/resource/1");
    let cookie = format!("csrf_token={token}");
    req.headers_mut().insert("cookie", cookie.into_bytes());
    req.headers_mut().insert("x-csrf-token", token.as_bytes().to_vec());
    let outcome = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(matches!(outcome, ControlFlow::Continue));
}
/// Tokens that differ only in letter case are treated as a mismatch.
#[test]
fn csrf_double_submit_case_sensitive() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Post, "/");
    let (cookie_token, header_token) = ("AbCdEf", "abcdef");
    req.headers_mut()
        .insert("cookie", format!("csrf_token={cookie_token}").into_bytes());
    req.headers_mut().insert("x-csrf-token", header_token.as_bytes().to_vec());
    let outcome = futures_executor::block_on(csrf.before(&ctx, &mut req));
    assert!(
        outcome.is_break(),
        "Token comparison should be case-sensitive"
    );
}
/// The `CsrfTokenCookie` extractor's cookie name matches the middleware's
/// default `csrf_token` cookie. Only the constant is checked here.
#[test]
fn csrf_token_cookie_extractor_reads_csrf_cookie() {
    use crate::extract::{CookieName, CsrfTokenCookie};
    let expected = "csrf_token";
    assert_eq!(CsrfTokenCookie::NAME, expected);
}
/// The production Set-Cookie value includes name=value, `Path=/`,
/// `SameSite=Strict` and `Secure`, and omits HttpOnly.
#[test]
fn csrf_make_set_cookie_header_value_production() {
    let raw = CsrfMiddleware::make_set_cookie_header_value("csrf_token", "tok123", true);
    let header = std::str::from_utf8(&raw).unwrap();
    for fragment in ["csrf_token=tok123", "Path=/", "SameSite=Strict", "Secure"] {
        assert!(header.contains(fragment));
    }
    assert!(!header.to_lowercase().contains("httponly"));
}
/// The development Set-Cookie value keeps name=value, `Path=/` and
/// `SameSite=Strict` but omits `Secure`.
#[test]
fn csrf_make_set_cookie_header_value_development() {
    let raw = CsrfMiddleware::make_set_cookie_header_value("csrf_token", "tok123", false);
    let header = std::str::from_utf8(&raw).unwrap();
    for fragment in ["csrf_token=tok123", "Path=/", "SameSite=Strict"] {
        assert!(header.contains(fragment));
    }
    assert!(!header.contains("Secure"));
}
/// End to end: a GET issues a token, and echoing it back in the cookie and
/// header of a POST passes validation.
#[test]
fn csrf_before_after_full_cycle_get_then_post() {
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();

    // Step 1: the GET receives a token via Set-Cookie.
    let mut get_req = Request::new(crate::request::Method::Get, "/form");
    let _ = futures_executor::block_on(csrf.before(&ctx, &mut get_req));
    let get_result = futures_executor::block_on(csrf.after(&ctx, &get_req, Response::ok()));
    let set_cookie = header_value(&get_result, "set-cookie").expect("GET should set cookie");
    let token_value = set_cookie
        .strip_prefix("csrf_token=")
        .and_then(|rest| rest.split(';').next())
        .unwrap();
    assert!(!token_value.is_empty());

    // Step 2: the POST echoes that token in both places.
    let mut post_req = Request::new(crate::request::Method::Post, "/form");
    let headers = post_req.headers_mut();
    headers.insert("cookie", format!("csrf_token={token_value}").into_bytes());
    headers.insert("x-csrf-token", token_value.as_bytes().to_vec());
    let result = futures_executor::block_on(csrf.before(&ctx, &mut post_req));
    assert!(result.is_continue(), "POST with valid token should pass");
}
/// POST, PUT, DELETE and PATCH are all rejected when no token is supplied.
#[test]
fn csrf_all_state_changing_methods_require_token() {
    use crate::request::Method;
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let state_changing = [Method::Post, Method::Put, Method::Delete, Method::Patch];
    for method in state_changing {
        let mut req = Request::new(method, "/resource");
        let rejected = futures_executor::block_on(csrf.before(&ctx, &mut req)).is_break();
        assert!(rejected, "{:?} without token should be rejected", method);
    }
}
/// GET, HEAD, OPTIONS and TRACE all pass without any token.
#[test]
fn csrf_all_safe_methods_pass_without_token() {
    use crate::request::Method;
    let csrf = CsrfMiddleware::new();
    let ctx = test_context();
    let safe = [Method::Get, Method::Head, Method::Options, Method::Trace];
    for method in safe {
        let mut req = Request::new(method, "/resource");
        let allowed = futures_executor::block_on(csrf.before(&ctx, &mut req)).is_continue();
        assert!(allowed, "{:?} should be allowed without token", method);
    }
}
/// Test middleware that appends `"<id>:before"` and `"<id>:after"` to a shared
/// log, so tests can assert the order in which hooks run.
#[derive(Clone)]
struct OrderRecordingMiddleware {
    id: &'static str,
    log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
}
impl OrderRecordingMiddleware {
    fn new(id: &'static str, log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
        Self { id, log }
    }

    /// Appends `"<id>:<phase>"` to the shared log.
    fn record(&self, phase: &str) {
        let entry = format!("{}:{}", self.id, phase);
        self.log.lock().unwrap().push(entry);
    }
}
impl Middleware for OrderRecordingMiddleware {
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        Box::pin(async move {
            self.record("before");
            ControlFlow::Continue
        })
    }
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async move {
            self.record("after");
            response
        })
    }
    fn name(&self) -> &'static str {
        "OrderRecording"
    }
}
/// Test middleware whose `before` hook logs `"<id>:before:break"` and rejects
/// with 403 and body `short-circuited`. Its `after` hook logs `"<id>:after"`.
struct ShortCircuitMiddleware {
    id: &'static str,
    log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
}
impl ShortCircuitMiddleware {
    fn new(id: &'static str, log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
        Self { id, log }
    }
}
impl Middleware for ShortCircuitMiddleware {
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        Box::pin(async move {
            let entry = format!("{}:before:break", self.id);
            self.log.lock().unwrap().push(entry);
            let rejection = Response::with_status(StatusCode::FORBIDDEN)
                .body(ResponseBody::Bytes(b"short-circuited".to_vec()));
            ControlFlow::Break(rejection)
        })
    }
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async move {
            let entry = format!("{}:after", self.id);
            self.log.lock().unwrap().push(entry);
            response
        })
    }
    fn name(&self) -> &'static str {
        "ShortCircuit"
    }
}
/// Test handler that logs `"handler"` and returns 200 with body `ok`.
struct RecordingHandler {
    log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
}
impl RecordingHandler {
    fn new(log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
        Self { log }
    }
}
impl Handler for RecordingHandler {
    fn call<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a mut Request,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async move {
            self.log.lock().unwrap().push(String::from("handler"));
            let body = ResponseBody::Bytes(b"ok".to_vec());
            Response::ok().body(body)
        })
    }
}
/// Before-hooks run in registration order and after-hooks in reverse, wrapped
/// around the handler.
#[test]
fn middleware_stack_three_middleware_onion_order() {
    let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
    let mut stack = MiddlewareStack::new();
    for id in ["mw1", "mw2", "mw3"] {
        stack.push(OrderRecordingMiddleware::new(id, log.clone()));
    }
    let handler = RecordingHandler::new(log.clone());
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let _response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
    let execution_log = log.lock().unwrap().clone();
    let expected = [
        "mw1:before",
        "mw2:before",
        "mw3:before",
        "handler",
        "mw3:after",
        "mw2:after",
        "mw1:after",
    ];
    assert_eq!(execution_log, expected);
}
/// When mw2 breaks, mw3 and the handler never run, but the already-entered mw1
/// still runs its after-hook.
#[test]
fn middleware_stack_short_circuit_runs_prior_after_hooks() {
    let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
    let mut stack = MiddlewareStack::new();
    stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
    stack.push(ShortCircuitMiddleware::new("mw2", log.clone()));
    stack.push(OrderRecordingMiddleware::new("mw3", log.clone()));
    let handler = RecordingHandler::new(log.clone());
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
    let status = response.status().as_u16();
    assert_eq!(status, 403);
    let execution_log = log.lock().unwrap().clone();
    assert_eq!(execution_log, ["mw1:before", "mw2:before:break", "mw1:after"]);
}
/// If the very first middleware breaks, only its before-hook is logged.
#[test]
fn middleware_stack_first_middleware_short_circuits() {
    let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
    let mut stack = MiddlewareStack::new();
    stack.push(ShortCircuitMiddleware::new("mw1", log.clone()));
    stack.push(OrderRecordingMiddleware::new("mw2", log.clone()));
    let handler = RecordingHandler::new(log.clone());
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
    let status = response.status().as_u16();
    assert_eq!(status, 403);
    let execution_log = log.lock().unwrap().clone();
    assert_eq!(execution_log, ["mw1:before:break"]);
}
/// An empty stack simply calls the handler and returns its 200 response.
#[test]
fn middleware_stack_empty_runs_handler_only() {
    let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
    let handler = RecordingHandler::new(log.clone());
    let stack = MiddlewareStack::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
    assert_eq!(response.status().as_u16(), 200);
    let execution_log = log.lock().unwrap().clone();
    assert_eq!(execution_log, ["handler"]);
}
/// A single middleware wraps the handler: before, handler, after.
#[test]
fn middleware_stack_single_middleware_ordering() {
    let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
    let mut stack = MiddlewareStack::new();
    stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
    let handler = RecordingHandler::new(log.clone());
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let _response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
    let recorded: Vec<String> = log.lock().unwrap().to_vec();
    assert_eq!(recorded, ["mw1:before", "handler", "mw1:after"]);
}
/// Onion ordering holds for five layers. The expected log is built as all
/// befores in order, then the handler, then all afters in reverse.
#[test]
fn middleware_stack_five_middleware_onion_order() {
    let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
    let names = ["a", "b", "c", "d", "e"];
    let mut stack = MiddlewareStack::new();
    for name in names {
        stack.push(OrderRecordingMiddleware::new(name, log.clone()));
    }
    let handler = RecordingHandler::new(log.clone());
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let _response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
    let execution_log = log.lock().unwrap().clone();
    let mut expected: Vec<String> = names.iter().map(|n| format!("{n}:before")).collect();
    expected.push("handler".to_string());
    expected.extend(names.iter().rev().map(|n| format!("{n}:after")));
    assert_eq!(execution_log, expected);
}
/// When the last middleware breaks, the handler is skipped and the earlier
/// middleware unwind in reverse order.
#[test]
fn middleware_stack_short_circuit_at_end_runs_prior_afters() {
    let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
    let mut stack = MiddlewareStack::new();
    for id in ["mw1", "mw2"] {
        stack.push(OrderRecordingMiddleware::new(id, log.clone()));
    }
    stack.push(ShortCircuitMiddleware::new("mw3", log.clone()));
    let handler = RecordingHandler::new(log.clone());
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
    assert_eq!(response.status().as_u16(), 403);
    let execution_log = log.lock().unwrap().clone();
    let expected = [
        "mw1:before",
        "mw2:before",
        "mw3:before:break",
        "mw2:after",
        "mw1:after",
    ];
    assert_eq!(execution_log, expected);
}
/// Test middleware that marks the request with an `x-<id>-before: true` header
/// and the response with an `x-<id>-after: true` header, logging each step.
struct ModifyingMiddleware {
    id: &'static str,
    log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
}
impl ModifyingMiddleware {
    fn new(id: &'static str, log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
        Self { id, log }
    }
}
impl Middleware for ModifyingMiddleware {
    fn before<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, ControlFlow> {
        Box::pin(async move {
            let marker = format!("x-{}-before", self.id);
            req.headers_mut().insert(marker, b"true".to_vec());
            self.log.lock().unwrap().push(format!("{}:before", self.id));
            ControlFlow::Continue
        })
    }
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        _req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        Box::pin(async move {
            self.log.lock().unwrap().push(format!("{}:after", self.id));
            let marker = format!("x-{}-after", self.id);
            response.header(marker, b"true".to_vec())
        })
    }
    fn name(&self) -> &'static str {
        "Modifying"
    }
}
/// Request and response modifications from every middleware are all visible
/// after the stack runs.
#[test]
fn middleware_stack_modifications_accumulate_correctly() {
    let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
    let mut stack = MiddlewareStack::new();
    for id in ["mw1", "mw2", "mw3"] {
        stack.push(ModifyingMiddleware::new(id, log.clone()));
    }
    let handler = RecordingHandler::new(log.clone());
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
    for name in ["x-mw1-after", "x-mw2-after", "x-mw3-after"] {
        assert!(header_value(&response, name).is_some());
    }
    for name in ["x-mw1-before", "x-mw2-before", "x-mw3-before"] {
        assert!(req.headers().contains(name));
    }
}
#[test]
fn layer_wrap_maintains_middleware_order() {
    // Wrapping a handler in a single layer must run the layer's `before`, then the
    // handler, then the layer's `after`.
    let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
    let layer = Layer::new(OrderRecordingMiddleware::new("layer", log.clone()));
    let layered_handler = layer.wrap(RecordingHandler::new(log.clone()));

    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/");
    let _response = futures_executor::block_on(layered_handler.call(&ctx, &mut req));

    let execution_log = log.lock().unwrap().clone();
    assert_eq!(execution_log, ["layer:before", "handler", "layer:after"]);
}
}
#[cfg(all(test, feature = "compression"))]
mod compression_tests {
    // Tests for `CompressionConfig` and the `after` hook of `CompressionMiddleware`.
    use super::*;
    use crate::request::Method;
    use crate::response::ResponseBody;

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// Builds a `GET /` request whose `accept-encoding` header is `encoding`.
    fn request_accepting(encoding: &[u8]) -> Request {
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("accept-encoding", encoding.to_vec());
        req
    }

    /// Returns `true` if `response` has at least one header named `name`
    /// (ASCII case-insensitive).
    fn has_header(response: &Response, name: &str) -> bool {
        response
            .headers()
            .iter()
            .any(|(header, _)| header.eq_ignore_ascii_case(name))
    }

    #[test]
    fn compression_config_defaults() {
        let config = CompressionConfig::default();
        assert_eq!(config.min_size, 1024);
        assert_eq!(config.level, 6);
        assert!(!config.skip_content_types.is_empty());
    }

    #[test]
    fn compression_config_builder() {
        let config = CompressionConfig::new().min_size(512).level(9);
        assert_eq!(config.min_size, 512);
        assert_eq!(config.level, 9);
    }

    #[test]
    fn compression_level_clamped() {
        // Out-of-range levels are clamped into 1..=9.
        let config = CompressionConfig::new().level(100);
        assert_eq!(config.level, 9);
        let config = CompressionConfig::new().level(0);
        assert_eq!(config.level, 1);
    }

    #[test]
    fn skip_content_type_exact_match() {
        let config = CompressionConfig::default();
        assert!(config.should_skip_content_type("image/jpeg"));
        // Parameters after `;` must not defeat the match.
        assert!(config.should_skip_content_type("image/jpeg; charset=utf-8"));
        assert!(!config.should_skip_content_type("text/html"));
    }

    #[test]
    fn skip_content_type_prefix_match() {
        let config = CompressionConfig::default();
        assert!(config.should_skip_content_type("video/mp4"));
        assert!(config.should_skip_content_type("video/webm"));
        assert!(config.should_skip_content_type("audio/mpeg"));
    }

    #[test]
    fn compression_skips_small_responses() {
        // Default config: a 13-byte body is below `min_size` (1024).
        let middleware = CompressionMiddleware::new();
        let ctx = test_context();
        let req = request_accepting(b"gzip");
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .body(ResponseBody::Bytes(b"Hello, World!".to_vec()));
        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));
        assert!(
            !has_header(&result, "content-encoding"),
            "Small response should not be compressed"
        );
    }

    #[test]
    fn compression_works_for_large_responses() {
        let config = CompressionConfig::new().min_size(10);
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();
        let req = request_accepting(b"gzip");
        let body = "Hello, World! ".repeat(100);
        let original_size = body.len();
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .body(ResponseBody::Bytes(body.into_bytes()));
        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        let (_, value) = result
            .headers()
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
            .expect("Large response should be compressed");
        assert_eq!(value, b"gzip");
        assert!(has_header(&result, "vary"), "Should have Vary header");

        if let ResponseBody::Bytes(compressed) = result.body_ref() {
            assert!(
                compressed.len() < original_size,
                "Compressed size should be smaller"
            );
        } else {
            panic!("Expected Bytes body");
        }
    }

    #[test]
    fn compression_skips_without_accept_encoding() {
        let config = CompressionConfig::new().min_size(10);
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();
        let req = Request::new(Method::Get, "/");
        let body = "Hello, World! ".repeat(100);
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .body(ResponseBody::Bytes(body.into_bytes()));
        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));
        assert!(
            !has_header(&result, "content-encoding"),
            "Should not compress without Accept-Encoding"
        );
    }

    #[test]
    fn compression_skips_already_compressed_content() {
        let config = CompressionConfig::new().min_size(10);
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();
        let req = request_accepting(b"gzip");
        let body = "Some image data".repeat(100);
        let response = Response::ok()
            .header("content-type", b"image/jpeg".to_vec())
            .body(ResponseBody::Bytes(body.into_bytes()));
        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));
        assert!(
            !has_header(&result, "content-encoding"),
            "Should not compress already-compressed content types"
        );
    }

    #[test]
    fn compression_skips_if_already_has_content_encoding() {
        let config = CompressionConfig::new().min_size(10);
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();
        let req = request_accepting(b"gzip");
        let body = "Hello, World! ".repeat(100);
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .header("content-encoding", b"br".to_vec())
            .body(ResponseBody::Bytes(body.into_bytes()));
        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));
        // The existing encoding must be kept as-is, not duplicated or replaced.
        let encodings: Vec<_> = result
            .headers()
            .iter()
            .filter(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
            .collect();
        assert_eq!(encodings.len(), 1);
        assert_eq!(encodings[0].1, b"br");
    }

    #[test]
    fn accepts_gzip_parses_header_correctly() {
        let accepted: [&[u8]; 4] = [
            b"gzip",
            b"deflate, gzip, br",
            b"gzip;q=1.0, identity;q=0.5",
            b"*",
        ];
        for encoding in accepted {
            assert!(CompressionMiddleware::accepts_gzip(&request_accepting(encoding)));
        }
        assert!(!CompressionMiddleware::accepts_gzip(&request_accepting(
            b"deflate, br"
        )));
        let req_no_header = Request::new(Method::Get, "/");
        assert!(!CompressionMiddleware::accepts_gzip(&req_no_header));
    }

    #[test]
    fn compression_middleware_name() {
        let middleware = CompressionMiddleware::new();
        assert_eq!(middleware.name(), "Compression");
    }
}
#[cfg(test)]
mod request_inspection_tests {
    // Tests for `RequestInspectionMiddleware` and the `try_pretty_json` helper.
    use super::*;
    use crate::request::Method;
    use crate::response::ResponseBody;

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    #[test]
    fn inspection_middleware_default_creates_normal_verbosity() {
        let mw = RequestInspectionMiddleware::new();
        assert_eq!(mw.verbosity, InspectionVerbosity::Normal);
        assert_eq!(mw.slow_threshold_ms, 1000);
        assert_eq!(mw.max_body_preview, 2048);
        assert_eq!(mw.name(), "RequestInspection");
    }

    #[test]
    fn inspection_middleware_builder_methods() {
        let mw = RequestInspectionMiddleware::new()
            .verbosity(InspectionVerbosity::Verbose)
            .slow_threshold_ms(500)
            .max_body_preview(4096)
            .log_config(LogConfig::development())
            .redact_header("x-api-key");
        assert_eq!(mw.verbosity, InspectionVerbosity::Verbose);
        assert_eq!(mw.slow_threshold_ms, 500);
        assert_eq!(mw.max_body_preview, 4096);
        // The custom header must be present alongside the built-in sensitive ones.
        assert!(mw.redact_headers.contains("x-api-key"));
        assert!(mw.redact_headers.contains("authorization"));
        assert!(mw.redact_headers.contains("cookie"));
    }

    #[test]
    fn inspection_before_continues_processing() {
        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Minimal);
        let ctx = test_context();
        let mut req = Request::new(Method::Post, "/api/users");
        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
        assert!(result.is_continue());
    }

    #[test]
    fn inspection_after_returns_response_unchanged() {
        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Minimal);
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/health");
        let _ = futures_executor::block_on(mw.before(&ctx, &mut req));
        let response = Response::ok().body(ResponseBody::Bytes(b"OK".to_vec()));
        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
        assert_eq!(result.status().as_u16(), 200);
        assert_eq!(result.body_ref().len(), 2);
    }

    #[test]
    fn inspection_stores_start_extension() {
        let mw = RequestInspectionMiddleware::new();
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");
        let _ = futures_executor::block_on(mw.before(&ctx, &mut req));
        assert!(req.get_extension::<InspectionStart>().is_some());
    }

    #[test]
    fn inspection_all_verbosity_levels_continue() {
        for verbosity in [
            InspectionVerbosity::Minimal,
            InspectionVerbosity::Normal,
            InspectionVerbosity::Verbose,
        ] {
            let mw = RequestInspectionMiddleware::new().verbosity(verbosity);
            let ctx = test_context();
            let mut req = Request::new(Method::Get, "/test");
            req.headers_mut()
                .insert("content-type", b"text/plain".to_vec());
            let result = futures_executor::block_on(mw.before(&ctx, &mut req));
            assert!(
                result.is_continue(),
                "Verbosity {verbosity:?} should continue"
            );
        }
    }

    #[test]
    fn inspection_verbose_with_json_body() {
        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Verbose);
        let ctx = test_context();
        let body = br#"{"name":"Alice","age":30}"#;
        let mut req = Request::new(Method::Post, "/api/users");
        req.headers_mut()
            .insert("content-type", b"application/json".to_vec());
        req.set_body(Body::Bytes(body.to_vec()));
        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
        assert!(result.is_continue());
    }

    #[test]
    fn inspection_verbose_after_with_json_response() {
        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Verbose);
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/users/1");
        let _ = futures_executor::block_on(mw.before(&ctx, &mut req));
        let response = Response::ok()
            .header("content-type", b"application/json".to_vec())
            .body(ResponseBody::Bytes(br#"{"id":1,"name":"Alice"}"#.to_vec()));
        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
        assert_eq!(result.status().as_u16(), 200);
    }

    #[test]
    fn inspection_redacts_sensitive_headers() {
        let mw = RequestInspectionMiddleware::new();
        assert!(mw.redact_headers.contains("authorization"));
        assert!(mw.redact_headers.contains("proxy-authorization"));
        assert!(mw.redact_headers.contains("cookie"));
        assert!(mw.redact_headers.contains("set-cookie"));
    }

    #[test]
    fn inspection_format_headers_redacts() {
        let mw = RequestInspectionMiddleware::new().redact_header("x-secret");
        let headers = vec![
            ("content-type", b"text/plain".as_slice()),
            ("x-secret", b"my-secret-value".as_slice()),
            ("x-normal", b"visible".as_slice()),
        ];
        let output = mw.format_inspection_headers(headers.into_iter());
        assert!(output.contains("content-type: text/plain"));
        assert!(output.contains("x-secret: [REDACTED]"));
        assert!(output.contains("x-normal: visible"));
        assert!(!output.contains("my-secret-value"));
    }

    #[test]
    fn inspection_format_body_preview_truncates() {
        let mw = RequestInspectionMiddleware::new().max_body_preview(10);
        let body = b"Hello, World! This is a long body.";
        let text = mw
            .format_body_preview(body, None)
            .expect("non-empty body should produce a preview");
        assert!(text.ends_with("..."));
        // 10 preview bytes plus the 3-byte "..." marker is 13; the bound leaves
        // 2 bytes of slack. NOTE(review): tighten once the cut-point rule is pinned down.
        assert!(text.len() <= 15);
    }

    #[test]
    fn inspection_format_body_preview_empty() {
        let mw = RequestInspectionMiddleware::new();
        assert!(mw.format_body_preview(b"", None).is_none());
    }

    #[test]
    fn inspection_format_body_preview_zero_max() {
        // A zero preview budget disables previews entirely.
        let mw = RequestInspectionMiddleware::new().max_body_preview(0);
        assert!(mw.format_body_preview(b"hello", None).is_none());
    }

    #[test]
    fn inspection_format_body_preview_json_pretty() {
        let mw = RequestInspectionMiddleware::new();
        let body = br#"{"key":"value","num":42}"#;
        let ct = b"application/json".as_slice();
        let text = mw
            .format_body_preview(body, Some(ct))
            .expect("JSON body should produce a preview");
        assert!(text.contains('\n'));
        assert!(text.contains("\"key\": \"value\""));
    }

    #[test]
    fn inspection_format_body_preview_non_json() {
        let mw = RequestInspectionMiddleware::new();
        let body = b"Hello, World!";
        let ct = b"text/plain".as_slice();
        let result = mw.format_body_preview(body, Some(ct));
        assert_eq!(result.unwrap(), "Hello, World!");
    }

    #[test]
    fn inspection_format_body_preview_binary() {
        let mw = RequestInspectionMiddleware::new();
        let body: &[u8] = &[0xFF, 0xFE, 0xFD, 0x00];
        let text = mw
            .format_body_preview(body, None)
            .expect("binary body should produce a placeholder preview");
        assert!(text.contains("binary"));
    }

    #[test]
    fn try_pretty_json_valid_object() {
        let pretty = try_pretty_json(r#"{"a":"b","c":1}"#).expect("valid JSON object");
        assert!(pretty.contains('\n'));
        assert!(pretty.contains(" \"a\": \"b\""));
    }

    #[test]
    fn try_pretty_json_valid_array() {
        let pretty = try_pretty_json(r"[1,2,3]").expect("valid JSON array");
        assert!(pretty.contains('\n'));
    }

    #[test]
    fn try_pretty_json_empty_object() {
        let result = try_pretty_json("{}");
        assert!(result.is_some());
        assert_eq!(result.unwrap(), "{}");
    }

    #[test]
    fn try_pretty_json_empty_array() {
        let result = try_pretty_json("[]");
        assert!(result.is_some());
        assert_eq!(result.unwrap(), "[]");
    }

    #[test]
    fn try_pretty_json_not_json() {
        // Bare scalars are rejected as well as arbitrary text.
        assert!(try_pretty_json("hello world").is_none());
        assert!(try_pretty_json("12345").is_none());
    }

    #[test]
    fn try_pretty_json_nested() {
        let input = r#"{"user":{"name":"Alice","roles":["admin","user"]}}"#;
        let pretty = try_pretty_json(input).expect("valid nested JSON");
        assert!(pretty.contains("\"user\":"));
        assert!(pretty.contains("\"name\": \"Alice\""));
        assert!(pretty.contains("\"roles\":"));
    }

    #[test]
    fn try_pretty_json_with_escapes() {
        // Escaped quotes inside string values must be preserved verbatim.
        let input = r#"{"msg":"hello \"world\""}"#;
        let pretty = try_pretty_json(input).expect("valid JSON with escapes");
        assert!(pretty.contains(r#"\"world\""#));
    }

    #[test]
    fn inspection_name() {
        let mw = RequestInspectionMiddleware::new();
        assert_eq!(mw.name(), "RequestInspection");
    }

    #[test]
    fn inspection_default_via_default_trait() {
        let mw = RequestInspectionMiddleware::default();
        assert_eq!(mw.verbosity, InspectionVerbosity::Normal);
        assert_eq!(mw.slow_threshold_ms, 1000);
    }

    #[test]
    fn inspection_with_query_string() {
        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Minimal);
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/search");
        req.set_query(Some("q=rust&page=1".to_string()));
        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
        assert!(result.is_continue());
    }

    #[test]
    fn inspection_response_body_stream() {
        // Despite the name, this exercises `ResponseBody::Empty`: a response with no
        // body content yields no preview.
        let mw = RequestInspectionMiddleware::new();
        let result = mw.format_response_preview(&ResponseBody::Empty, None);
        assert!(result.is_none());
    }
}
#[cfg(test)]
mod rate_limit_tests {
use super::*;
use crate::request::Method;
use crate::response::{ResponseBody, StatusCode};
use std::time::Duration;
/// Builds a bare request context for driving middleware futures in tests.
fn test_context() -> RequestContext {
    RequestContext::new(asupersync::Cx::for_testing(), 1)
}
/// Runs the limiter's `before` hook to completion on a fresh context.
fn run_rate_limit_before(mw: &RateLimitMiddleware, req: &mut Request) -> ControlFlow {
    let ctx = test_context();
    futures_executor::block_on(mw.before(&ctx, req))
}
/// Runs the limiter's `after` hook to completion on a fresh context.
fn run_rate_limit_after(mw: &RateLimitMiddleware, req: &Request, resp: Response) -> Response {
    let ctx = test_context();
    futures_executor::block_on(mw.after(&ctx, req, resp))
}
#[test]
fn rate_limit_default_allows_requests() {
    // A default-configured limiter must admit a new client's first request.
    let mw = RateLimitMiddleware::new();
    let mut req = Request::new(Method::Get, "/api/test");
    req.headers_mut()
        .insert("x-forwarded-for", b"192.168.1.1".to_vec());
    assert!(
        run_rate_limit_before(&mw, &mut req).is_continue(),
        "first request should be allowed"
    );
}
#[test]
fn rate_limit_fixed_window_blocks_after_limit() {
    let mw = RateLimitMiddleware::builder()
        .requests(3)
        .per(Duration::from_secs(60))
        .algorithm(RateLimitAlgorithm::FixedWindow)
        .key_extractor(IpKeyExtractor)
        .build();
    // All requests originate from the same forwarded client address.
    let client_request = || {
        let mut req = Request::new(Method::Get, "/api/test");
        req.headers_mut()
            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
        req
    };
    for i in 0..3 {
        let result = run_rate_limit_before(&mw, &mut client_request());
        assert!(
            result.is_continue(),
            "request {i} should be allowed within limit"
        );
    }
    let result = run_rate_limit_before(&mw, &mut client_request());
    assert!(result.is_break(), "fourth request should be blocked");
    if let ControlFlow::Break(resp) = result {
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    }
}
#[test]
fn rate_limit_different_keys_independent() {
    let mw = RateLimitMiddleware::builder()
        .requests(2)
        .per(Duration::from_secs(60))
        .algorithm(RateLimitAlgorithm::FixedWindow)
        .key_extractor(IpKeyExtractor)
        .build();
    let from = |ip: &str| {
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("x-forwarded-for", ip.as_bytes().to_vec());
        req
    };
    // Exhaust the first client's budget...
    for _ in 0..2 {
        assert!(run_rate_limit_before(&mw, &mut from("1.1.1.1")).is_continue());
    }
    assert!(run_rate_limit_before(&mw, &mut from("1.1.1.1")).is_break());
    // ...which must leave a different client unaffected.
    assert!(run_rate_limit_before(&mw, &mut from("2.2.2.2")).is_continue());
}
#[test]
fn rate_limit_token_bucket_allows_burst() {
    let mw = RateLimitMiddleware::builder()
        .requests(5)
        .per(Duration::from_secs(60))
        .algorithm(RateLimitAlgorithm::TokenBucket)
        .key_extractor(IpKeyExtractor)
        .build();
    let client_request = || {
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
        req
    };
    // The whole 5-request budget is usable back-to-back; the sixth is rejected.
    for i in 0..5 {
        let result = run_rate_limit_before(&mw, &mut client_request());
        assert!(result.is_continue(), "burst request {i} should be allowed");
    }
    assert!(run_rate_limit_before(&mw, &mut client_request()).is_break());
}
#[test]
fn rate_limit_sliding_window_basic() {
    let mw = RateLimitMiddleware::builder()
        .requests(3)
        .per(Duration::from_secs(60))
        .algorithm(RateLimitAlgorithm::SlidingWindow)
        .key_extractor(IpKeyExtractor)
        .build();
    let client_request = || {
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
        req
    };
    for i in 0..3 {
        assert!(
            run_rate_limit_before(&mw, &mut client_request()).is_continue(),
            "sliding window request {i} should be allowed"
        );
    }
    assert!(run_rate_limit_before(&mw, &mut client_request()).is_break());
}
#[test]
fn rate_limit_header_key_extractor() {
    let mw = RateLimitMiddleware::builder()
        .requests(2)
        .per(Duration::from_secs(60))
        .algorithm(RateLimitAlgorithm::FixedWindow)
        .key_extractor(HeaderKeyExtractor::new("x-api-key"))
        .build();
    let with_key = |key: &str| {
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("x-api-key", key.as_bytes().to_vec());
        req
    };
    for _ in 0..2 {
        assert!(run_rate_limit_before(&mw, &mut with_key("key-abc")).is_continue());
    }
    assert!(run_rate_limit_before(&mw, &mut with_key("key-abc")).is_break());
    // A different API key has its own, untouched budget.
    assert!(run_rate_limit_before(&mw, &mut with_key("key-xyz")).is_continue());
}
#[test]
fn rate_limit_path_key_extractor() {
    let mw = RateLimitMiddleware::builder()
        .requests(1)
        .per(Duration::from_secs(60))
        .algorithm(RateLimitAlgorithm::FixedWindow)
        .key_extractor(PathKeyExtractor)
        .build();
    let hit = |path: &'static str| {
        let mut req = Request::new(Method::Get, path);
        run_rate_limit_before(&mw, &mut req)
    };
    assert!(hit("/api/a").is_continue());
    assert!(hit("/api/a").is_break());
    assert!(hit("/api/b").is_continue());
}
#[test]
fn rate_limit_no_key_skips_limiting() {
    let mw = RateLimitMiddleware::builder()
        .requests(1)
        .per(Duration::from_secs(60))
        .algorithm(RateLimitAlgorithm::FixedWindow)
        .key_extractor(HeaderKeyExtractor::new("x-api-key"))
        .build();
    // None of these requests carry `x-api-key`, so all 11 must pass even though
    // the configured budget is a single request.
    for _ in 0..11 {
        let mut req = Request::new(Method::Get, "/");
        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
    }
}
#[test]
fn rate_limit_response_headers_on_success() {
    let mw = RateLimitMiddleware::builder()
        .requests(10)
        .per(Duration::from_secs(60))
        .algorithm(RateLimitAlgorithm::FixedWindow)
        .key_extractor(IpKeyExtractor)
        .build();
    let mut req = Request::new(Method::Get, "/");
    req.headers_mut()
        .insert("x-forwarded-for", b"10.0.0.1".to_vec());
    assert!(run_rate_limit_before(&mw, &mut req).is_continue());

    // `after` decorates an admitted response with the quota headers.
    let resp = run_rate_limit_after(&mw, &req, Response::with_status(StatusCode::OK));
    let headers = resp.headers();
    let present = |wanted: &str| headers.iter().any(|(n, _)| n.eq_ignore_ascii_case(wanted));
    assert!(present("x-ratelimit-limit"), "should have X-RateLimit-Limit header");
    assert!(present("x-ratelimit-remaining"), "should have X-RateLimit-Remaining header");
    assert!(present("x-ratelimit-reset"), "should have X-RateLimit-Reset header");

    let limit_val = headers
        .iter()
        .find(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-limit"))
        .map(|(_, v)| std::str::from_utf8(v).unwrap().to_string())
        .unwrap();
    assert_eq!(limit_val, "10");
}
#[test]
fn rate_limit_429_response_has_retry_after() {
    let mw = RateLimitMiddleware::builder()
        .requests(1)
        .per(Duration::from_secs(60))
        .algorithm(RateLimitAlgorithm::FixedWindow)
        .key_extractor(IpKeyExtractor)
        .build();
    let client_request = || {
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
        req
    };
    assert!(run_rate_limit_before(&mw, &mut client_request()).is_continue());

    let ControlFlow::Break(resp) = run_rate_limit_before(&mw, &mut client_request()) else {
        panic!("expected Break(429)");
    };
    assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    let headers = resp.headers();
    assert!(
        headers.iter().any(|(n, _)| n.eq_ignore_ascii_case("retry-after")),
        "429 response should have Retry-After header"
    );
    assert!(
        headers
            .iter()
            .any(|(n, v)| n.eq_ignore_ascii_case("content-type") && v == b"application/json"),
        "429 response should have JSON content type"
    );
}
#[test]
fn rate_limit_no_headers_when_disabled() {
    let mw = RateLimitMiddleware::builder()
        .requests(10)
        .per(Duration::from_secs(60))
        .algorithm(RateLimitAlgorithm::FixedWindow)
        .key_extractor(IpKeyExtractor)
        .include_headers(false)
        .build();
    let mut req = Request::new(Method::Get, "/");
    req.headers_mut()
        .insert("x-forwarded-for", b"10.0.0.1".to_vec());
    assert!(run_rate_limit_before(&mw, &mut req).is_continue());

    let resp = run_rate_limit_after(&mw, &req, Response::with_status(StatusCode::OK));
    let leaked = resp
        .headers()
        .iter()
        .any(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-limit"));
    assert!(
        !leaked,
        "should NOT have rate limit headers when disabled"
    );
}
#[test]
fn rate_limit_custom_retry_message() {
    let mw = RateLimitMiddleware::builder()
        .requests(1)
        .per(Duration::from_secs(60))
        .algorithm(RateLimitAlgorithm::FixedWindow)
        .key_extractor(IpKeyExtractor)
        .retry_message("Slow down, partner!")
        .build();
    let client_request = || {
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
        req
    };
    // Spend the single allowed request. Asserting it (rather than discarding the
    // result) makes a limiter that rejects too early fail here with a clear message.
    assert!(
        run_rate_limit_before(&mw, &mut client_request()).is_continue(),
        "first request should be allowed"
    );

    let ControlFlow::Break(resp) = run_rate_limit_before(&mw, &mut client_request()) else {
        panic!("expected Break(429)");
    };
    assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    let ResponseBody::Bytes(body) = resp.body_ref() else {
        panic!("expected Bytes body");
    };
    let body_str = std::str::from_utf8(body).unwrap();
    assert!(
        body_str.contains("Slow down, partner!"),
        "expected custom message in body, got: {body_str}"
    );
}
#[test]
fn rate_limit_ip_extractor_x_forwarded_for() {
    // Only the left-most entry of the forwarding chain is used as the key.
    let mut req = Request::new(Method::Get, "/");
    req.headers_mut()
        .insert("x-forwarded-for", b"1.2.3.4, 5.6.7.8".to_vec());
    let key = IpKeyExtractor.extract_key(&req);
    assert_eq!(key.as_deref(), Some("1.2.3.4"));
}
#[test]
fn rate_limit_ip_extractor_x_real_ip() {
    let mut req = Request::new(Method::Get, "/");
    req.headers_mut().insert("x-real-ip", b"9.8.7.6".to_vec());
    let key = IpKeyExtractor.extract_key(&req);
    assert_eq!(key.as_deref(), Some("9.8.7.6"));
}
#[test]
fn rate_limit_ip_extractor_fallback() {
    // With no forwarding headers at all, every client shares the "unknown" key.
    let req = Request::new(Method::Get, "/");
    let key = IpKeyExtractor.extract_key(&req);
    assert_eq!(key.as_deref(), Some("unknown"));
}
#[test]
fn connected_ip_extractor_with_remote_addr() {
    use std::net::IpAddr;
    let peer: IpAddr = "192.168.1.100".parse().expect("valid IPv4 literal");
    let mut req = Request::new(Method::Get, "/");
    req.insert_extension(RemoteAddr(peer));
    let key = ConnectedIpKeyExtractor.extract_key(&req);
    assert_eq!(key.as_deref(), Some("192.168.1.100"));
}
#[test]
fn connected_ip_extractor_without_remote_addr() {
    let req = Request::new(Method::Get, "/");
    assert!(ConnectedIpKeyExtractor.extract_key(&req).is_none());
}
#[test]
fn connected_ip_extractor_ignores_headers() {
    use std::net::IpAddr;
    // The socket peer wins over a (spoofable) X-Forwarded-For header.
    let peer: IpAddr = "10.0.0.1".parse().expect("valid IPv4 literal");
    let mut req = Request::new(Method::Get, "/");
    req.insert_extension(RemoteAddr(peer));
    req.headers_mut()
        .insert("x-forwarded-for", b"1.2.3.4".to_vec());
    let key = ConnectedIpKeyExtractor.extract_key(&req);
    assert_eq!(key.as_deref(), Some("10.0.0.1"));
}
#[test]
fn trusted_proxy_extractor_from_trusted_proxy() {
    use std::net::IpAddr;
    // The peer is inside 10.0.0.0/8, so its X-Forwarded-For is believed.
    let extractor = TrustedProxyIpKeyExtractor::new().trust_cidr("10.0.0.0/8");
    let proxy: IpAddr = "10.0.0.1".parse().expect("valid IPv4 literal");
    let mut req = Request::new(Method::Get, "/");
    req.insert_extension(RemoteAddr(proxy));
    req.headers_mut()
        .insert("x-forwarded-for", b"203.0.113.50".to_vec());
    assert_eq!(extractor.extract_key(&req).as_deref(), Some("203.0.113.50"));
}
#[test]
fn trusted_proxy_extractor_from_untrusted_direct() {
    use std::net::IpAddr;
    // The peer is outside the trusted range, so the header is ignored and the
    // peer address itself is the key.
    let extractor = TrustedProxyIpKeyExtractor::new().trust_cidr("10.0.0.0/8");
    let client: IpAddr = "203.0.113.50".parse().expect("valid IPv4 literal");
    let mut req = Request::new(Method::Get, "/");
    req.insert_extension(RemoteAddr(client));
    req.headers_mut()
        .insert("x-forwarded-for", b"1.2.3.4".to_vec());
    assert_eq!(extractor.extract_key(&req).as_deref(), Some("203.0.113.50"));
}
#[test]
fn trusted_proxy_extractor_no_remote_addr() {
    // Without a known peer there is nothing to trust, so no key is produced.
    let extractor = TrustedProxyIpKeyExtractor::new().trust_loopback();
    let mut req = Request::new(Method::Get, "/");
    req.headers_mut()
        .insert("x-forwarded-for", b"1.2.3.4".to_vec());
    assert!(extractor.extract_key(&req).is_none());
}
#[test]
fn trusted_proxy_extractor_loopback_ipv4() {
    use std::net::{IpAddr, Ipv4Addr};
    let extractor = TrustedProxyIpKeyExtractor::new().trust_loopback();
    let mut req = Request::new(Method::Get, "/");
    req.insert_extension(RemoteAddr(IpAddr::from(Ipv4Addr::LOCALHOST)));
    req.headers_mut()
        .insert("x-forwarded-for", b"8.8.8.8".to_vec());
    assert_eq!(extractor.extract_key(&req).as_deref(), Some("8.8.8.8"));
}
#[test]
fn trusted_proxy_extractor_loopback_ipv6() {
    use std::net::{IpAddr, Ipv6Addr};
    let extractor = TrustedProxyIpKeyExtractor::new().trust_loopback();
    let mut req = Request::new(Method::Get, "/");
    req.insert_extension(RemoteAddr(IpAddr::from(Ipv6Addr::LOCALHOST)));
    req.headers_mut()
        .insert("x-forwarded-for", b"8.8.8.8".to_vec());
    assert_eq!(extractor.extract_key(&req).as_deref(), Some("8.8.8.8"));
}
#[test]
fn cidr_parsing() {
    // Well-formed IPv4 and IPv6 networks, including zero-length prefixes.
    assert!(parse_cidr("10.0.0.0/8").is_some());
    assert!(parse_cidr("192.168.1.0/24").is_some());
    assert!(parse_cidr("0.0.0.0/0").is_some());
    assert!(parse_cidr("::1/128").is_some());
    assert!(parse_cidr("::/0").is_some());
    // An IPv4 prefix longer than 32 bits is rejected.
    assert!(parse_cidr("10.0.0.0/33").is_none());
    // Not an address at all.
    assert!(parse_cidr("invalid").is_none());
    // A bare address without a `/prefix` is not accepted as a CIDR.
    assert!(parse_cidr("10.0.0.0").is_none());
}
#[test]
fn ip_in_cidr_matching() {
    use std::net::IpAddr;
    // Table of (candidate, expected membership in 10.0.0.0/8).
    let network: IpAddr = "10.0.0.0".parse().expect("valid IPv4 literal");
    let cases = [
        ("10.0.0.1", true),
        ("10.255.255.255", true),
        ("11.0.0.1", false),
        ("192.168.1.1", false),
    ];
    for (text, expected) in cases {
        let candidate: IpAddr = text.parse().expect("valid IPv4 literal");
        assert_eq!(
            ip_in_cidr(candidate, network, 8),
            expected,
            "{candidate} membership in 10.0.0.0/8"
        );
    }
}
#[test]
fn rate_limit_composite_key_extractor() {
    // Component keys are joined with ':' in the order the extractors were given.
    let parts: Vec<Box<dyn KeyExtractor>> = vec![Box::new(IpKeyExtractor), Box::new(PathKeyExtractor)];
    let extractor = CompositeKeyExtractor::new(parts);
    let mut req = Request::new(Method::Get, "/api/users");
    req.headers_mut()
        .insert("x-forwarded-for", b"10.0.0.1".to_vec());
    assert_eq!(
        extractor.extract_key(&req).as_deref(),
        Some("10.0.0.1:/api/users")
    );
}
#[test]
fn rate_limit_builder_defaults() {
    let mw = RateLimitMiddleware::builder().build();
    let cfg = &mw.config;
    assert_eq!(cfg.max_requests, 100);
    assert_eq!(cfg.window, Duration::from_secs(60));
    assert_eq!(cfg.algorithm, RateLimitAlgorithm::TokenBucket);
    assert!(cfg.include_headers);
}
#[test]
fn rate_limit_builder_per_minute() {
    // `per_minute(2)` yields a two-minute window.
    let mw = RateLimitMiddleware::builder()
        .requests(50)
        .per_minute(2)
        .algorithm(RateLimitAlgorithm::SlidingWindow)
        .build();
    let cfg = &mw.config;
    assert_eq!(cfg.max_requests, 50);
    assert_eq!(cfg.window, Duration::from_secs(120));
    assert_eq!(cfg.algorithm, RateLimitAlgorithm::SlidingWindow);
}
#[test]
fn rate_limit_builder_per_hour() {
    let mw = RateLimitMiddleware::builder()
        .requests(1000)
        .per_hour(1)
        .build();
    assert_eq!(mw.config.window, Duration::from_secs(3600));
}
#[test]
fn rate_limit_middleware_name() {
    assert_eq!(RateLimitMiddleware::new().name(), "RateLimit");
}
#[test]
fn rate_limit_default_via_default_trait() {
    let mw = RateLimitMiddleware::default();
    assert_eq!(mw.config.max_requests, 100);
}
/// A GET response passed through `after` receives an `ETag` header whose
/// value is wrapped in double quotes.
#[test]
fn etag_middleware_generates_etag_for_get() {
    let mw = ETagMiddleware::new();
    let ctx = test_context();
    let req = Request::new(crate::request::Method::Get, "/resource");
    let response = Response::ok()
        .header("content-type", b"application/json".to_vec())
        .body(ResponseBody::Bytes(br#"{"status":"ok"}"#.to_vec()));
    let response = futures_executor::block_on(mw.after(&ctx, &req, response));
    // A single `expect` both asserts presence and yields the header. This
    // replaces the separate `is_some()` check followed by `unwrap()`.
    let (_, etag_bytes) = response
        .headers()
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case("etag"))
        .expect("Response should have ETag header");
    let etag_value =
        std::str::from_utf8(etag_bytes).expect("ETag header should be valid UTF-8");
    assert!(etag_value.starts_with('"'), "ETag should start with quote");
    assert!(etag_value.ends_with('"'), "ETag should end with quote");
}
/// Sending the ETag from a first response back in `If-None-Match` turns a
/// second, identical response into a 304 with an empty body.
#[test]
fn etag_middleware_returns_304_on_match() {
    let middleware = ETagMiddleware::new();
    let context = test_context();
    let payload = br#"{"status":"ok"}"#.to_vec();

    // First round-trip: capture the ETag the middleware assigns.
    let first_request = Request::new(crate::request::Method::Get, "/resource");
    let first_response = futures_executor::block_on(middleware.after(
        &context,
        &first_request,
        Response::ok().body(ResponseBody::Bytes(payload.clone())),
    ));
    let issued_etag = first_response
        .headers()
        .iter()
        .find_map(|(name, value)| {
            name.eq_ignore_ascii_case("etag")
                .then(|| std::str::from_utf8(value).unwrap().to_string())
        })
        .unwrap();

    // Second round-trip: present that ETag as a conditional request.
    let mut conditional_request = Request::new(crate::request::Method::Get, "/resource");
    conditional_request
        .headers_mut()
        .insert("if-none-match", issued_etag.as_bytes().to_vec());
    let second_response = futures_executor::block_on(middleware.after(
        &context,
        &conditional_request,
        Response::ok().body(ResponseBody::Bytes(payload)),
    ));

    assert_eq!(second_response.status().as_u16(), 304);
    assert!(second_response.body_ref().is_empty());
}
/// When `If-None-Match` names a different ETag, the response passes through
/// as a 200 with its body intact.
#[test]
fn etag_middleware_returns_full_response_on_mismatch() {
    let mw = ETagMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/resource");
    req.headers_mut()
        .insert("if-none-match", b"\"old-etag\"".to_vec());
    let body = br#"{"status":"updated"}"#.to_vec();
    // `body` is not used again, so move it instead of cloning.
    let response = Response::ok().body(ResponseBody::Bytes(body));
    let response = futures_executor::block_on(mw.after(&ctx, &req, response));
    assert_eq!(response.status().as_u16(), 200);
    assert!(!response.body_ref().is_empty());
}
/// With `weak(true)` in the config, the generated ETag carries the `W/`
/// prefix.
#[test]
fn etag_middleware_weak_etag_generation() {
    let middleware = ETagMiddleware::with_config(ETagConfig::new().weak(true));
    let context = test_context();
    let request = Request::new(crate::request::Method::Get, "/resource");
    let response = futures_executor::block_on(middleware.after(
        &context,
        &request,
        Response::ok().body(ResponseBody::Bytes(b"data".to_vec())),
    ));
    let etag = response
        .headers()
        .iter()
        .find_map(|(name, value)| {
            name.eq_ignore_ascii_case("etag")
                .then(|| std::str::from_utf8(value).unwrap().to_string())
        })
        .unwrap();
    assert!(etag.starts_with("W/"), "Weak ETag should start with W/");
}
/// Responses to POST requests leave the middleware without an `ETag` header.
#[test]
fn etag_middleware_skips_post_requests() {
    let mw = ETagMiddleware::new();
    let ctx = test_context();
    let post = Request::new(crate::request::Method::Post, "/resource");
    let original = Response::ok().body(ResponseBody::Bytes(b"created".to_vec()));
    let processed = futures_executor::block_on(mw.after(&ctx, &post, original));
    let has_etag = processed
        .headers()
        .iter()
        .any(|(name, _)| name.eq_ignore_ascii_case("etag"));
    assert!(!has_etag, "POST should not have ETag");
}
/// HEAD requests are treated like GET: the response gains an `ETag` header.
#[test]
fn etag_middleware_handles_head_requests() {
    let mw = ETagMiddleware::new();
    let ctx = test_context();
    let head = Request::new(crate::request::Method::Head, "/resource");
    let processed = futures_executor::block_on(mw.after(
        &ctx,
        &head,
        Response::ok().body(ResponseBody::Bytes(b"data".to_vec())),
    ));
    let has_etag = processed
        .headers()
        .iter()
        .any(|(name, _)| name.eq_ignore_ascii_case("etag"));
    assert!(has_etag, "HEAD should have ETag");
}
/// In `ETagMode::Disabled` the middleware adds no `ETag`, even to a GET.
#[test]
fn etag_middleware_disabled_mode() {
    let mw = ETagMiddleware::with_config(ETagConfig::new().mode(ETagMode::Disabled));
    let ctx = test_context();
    let get = Request::new(crate::request::Method::Get, "/resource");
    let processed = futures_executor::block_on(mw.after(
        &ctx,
        &get,
        Response::ok().body(ResponseBody::Bytes(b"data".to_vec())),
    ));
    let has_etag = processed
        .headers()
        .iter()
        .any(|(name, _)| name.eq_ignore_ascii_case("etag"));
    assert!(!has_etag, "Disabled mode should not add ETag");
}
/// Bodies smaller than the configured `min_size` (1000 bytes here) get no
/// `ETag`.
#[test]
fn etag_middleware_min_size_filter() {
    let mw = ETagMiddleware::with_config(ETagConfig::new().min_size(1000));
    let ctx = test_context();
    let get = Request::new(crate::request::Method::Get, "/resource");
    let processed = futures_executor::block_on(mw.after(
        &ctx,
        &get,
        Response::ok().body(ResponseBody::Bytes(b"small".to_vec())),
    ));
    let has_etag = processed
        .headers()
        .iter()
        .any(|(name, _)| name.eq_ignore_ascii_case("etag"));
    assert!(!has_etag, "Small body should not get ETag");
}
/// In `ETagMode::Manual`, an ETag already set on the response is compared
/// against `If-None-Match`. A matching request therefore yields 304.
#[test]
fn etag_middleware_preserves_existing_etag() {
    const CUSTOM_ETAG: &[u8] = b"\"custom-etag\"";
    let mw = ETagMiddleware::with_config(ETagConfig::new().mode(ETagMode::Manual));
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/resource");
    req.headers_mut().insert("if-none-match", CUSTOM_ETAG.to_vec());
    let handler_response = Response::ok()
        .header("etag", CUSTOM_ETAG.to_vec())
        .body(ResponseBody::Bytes(b"data".to_vec()));
    let result = futures_executor::block_on(mw.after(&ctx, &req, handler_response));
    assert_eq!(result.status().as_u16(), 304);
}
/// `If-None-Match: *` is expected to match, so the middleware answers 304.
#[test]
fn etag_middleware_wildcard_if_none_match() {
    let mw = ETagMiddleware::new();
    let ctx = test_context();
    let mut req = Request::new(crate::request::Method::Get, "/resource");
    req.headers_mut().insert("if-none-match", b"*".to_vec());
    let status = futures_executor::block_on(mw.after(
        &ctx,
        &req,
        Response::ok().body(ResponseBody::Bytes(b"data".to_vec())),
    ))
    .status()
    .as_u16();
    assert_eq!(status, 304);
}
/// `If-None-Match` uses weak comparison: a `W/`-prefixed copy of the issued
/// ETag still matches and yields 304.
#[test]
fn etag_middleware_weak_comparison_matches() {
    let mw = ETagMiddleware::new();
    let ctx = test_context();
    let payload = b"test data".to_vec();

    let plain_request = Request::new(crate::request::Method::Get, "/resource");
    let first = futures_executor::block_on(mw.after(
        &ctx,
        &plain_request,
        Response::ok().body(ResponseBody::Bytes(payload.clone())),
    ));
    let strong_etag = first
        .headers()
        .iter()
        .find_map(|(name, value)| {
            name.eq_ignore_ascii_case("etag")
                .then(|| std::str::from_utf8(value).unwrap().to_string())
        })
        .unwrap();

    // Re-send the same validator, but marked as weak.
    let weak_etag = format!("W/{}", strong_etag);
    let mut conditional_request = Request::new(crate::request::Method::Get, "/resource");
    conditional_request
        .headers_mut()
        .insert("if-none-match", weak_etag.as_bytes().to_vec());
    let second = futures_executor::block_on(mw.after(
        &ctx,
        &conditional_request,
        Response::ok().body(ResponseBody::Bytes(payload)),
    ));
    assert_eq!(second.status().as_u16(), 304);
}
/// The ETag middleware reports itself as `"ETagMiddleware"` through `name()`.
#[test]
fn etag_middleware_name() {
    assert_eq!(ETagMiddleware::new().name(), "ETagMiddleware");
}
/// Each `ETagConfig` builder setter stores its value in the matching field.
#[test]
fn etag_config_builder() {
    let cfg = ETagConfig::new()
        .mode(ETagMode::Auto)
        .weak(true)
        .min_size(512);
    // Borrow the fields rather than moving them out of `cfg`.
    let ETagConfig {
        mode,
        weak,
        min_size,
        ..
    } = &cfg;
    assert_eq!(*mode, ETagMode::Auto);
    assert!(*weak);
    assert_eq!(*min_size, 512);
}
/// `generate_etag` is deterministic: identical bytes produce identical
/// ETags. Appending a single byte produces a different ETag.
#[test]
fn etag_generates_consistent_hash() {
    let baseline = ETagMiddleware::generate_etag(b"hello world", false);
    assert_eq!(
        baseline,
        ETagMiddleware::generate_etag(b"hello world", false)
    );
    let changed = ETagMiddleware::generate_etag(b"hello world!", false);
    assert_ne!(baseline, changed);
}
}