use std::sync::Arc;
use axum::body::Body;
use axum::extract::{FromRequestParts, Request};
use axum::http::header::HeaderValue;
use axum::http::request::Parts;
use axum::http::Response;
use axum::middleware::Next;
use axum::Router;
/// Name of the header the middleware reads the inbound request ID from and
/// writes the effective request ID to on the response.
const HEADER_NAME: &str = "x-request-id";
/// Configuration for the request-ID middleware installed through
/// [`RequestIdRouterExt::request_id`].
///
/// Derives `Debug`, `PartialEq` and `Eq` (in addition to `Clone`) so the
/// configuration can be logged and compared, matching the other public type
/// in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestIdLayer {
    /// When `true`, a fresh ID is generated for every request and any inbound
    /// `x-request-id` header is not consulted. When `false`, an inbound value
    /// is reused if it passes validation; otherwise an ID is generated.
    pub always_generate: bool,
}

impl Default for RequestIdLayer {
    /// Equivalent to [`RequestIdLayer::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl RequestIdLayer {
    /// Creates a configuration that honours an acceptable inbound request ID
    /// and only generates one when none is usable.
    #[must_use]
    pub fn new() -> Self {
        Self {
            always_generate: false,
        }
    }

    /// Creates a configuration that generates a new ID for every request.
    #[must_use]
    pub fn always_generate() -> Self {
        Self {
            always_generate: true,
        }
    }
}
/// Extension trait adding a `request_id` method that installs request-ID
/// handling configured by a [`RequestIdLayer`].
pub trait RequestIdRouterExt {
/// Consumes `self` and returns it with request-ID handling configured by
/// `layer` applied.
#[must_use]
fn request_id(self, layer: RequestIdLayer) -> Self;
}
impl<S> RequestIdRouterExt for Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Wraps the router in a `from_fn` middleware that runs `handle` for
    /// every request, sharing one reference-counted copy of `layer`.
    fn request_id(self, layer: RequestIdLayer) -> Self {
        let shared = Arc::new(layer);
        self.layer(axum::middleware::from_fn(
            // Each invocation hands `handle` its own handle to the shared
            // configuration; the returned future owns everything it needs.
            move |request: Request<Body>, next: Next| handle(Arc::clone(&shared), request, next),
        ))
    }
}
/// The ID assigned to a request, wrapping the ID text as a `String`.
///
/// The middleware stores a value of this type in the request extensions, and
/// the extractor implementation for this type reads it back from there.
#[derive(Clone, Debug)]
pub struct RequestId(pub String);
impl<S> FromRequestParts<S> for RequestId
where
    S: Send + Sync,
{
    /// Extraction never fails: a missing ID yields an empty one instead.
    type Rejection = std::convert::Infallible;

    /// Returns a clone of the `RequestId` stored in the request extensions,
    /// or a `RequestId` holding an empty string when none has been stored.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let id = match parts.extensions.get::<RequestId>() {
            Some(stored) => stored.clone(),
            None => RequestId(String::new()),
        };
        Ok(id)
    }
}
/// Middleware body: picks the request ID, exposes it to downstream handlers
/// through the request extensions, and echoes it on the response.
///
/// The inbound `x-request-id` header is only considered when
/// `cfg.always_generate` is `false`, and only used when it converts with
/// `to_str`, is non-empty and passes `is_safe`; otherwise a new ID is
/// generated.
async fn handle(cfg: Arc<RequestIdLayer>, mut req: Request<Body>, next: Next) -> Response<Body> {
    let inbound = if cfg.always_generate {
        None
    } else {
        req.headers()
            .get(HEADER_NAME)
            .and_then(|value| value.to_str().ok())
            .filter(|candidate| !candidate.is_empty() && is_safe(candidate))
            .map(str::to_owned)
    };
    let id = inbound.unwrap_or_else(generate_id);

    // Build the response header value up front so the ID string itself can be
    // moved into the extensions without an extra copy.
    let echoed = HeaderValue::from_str(&id);
    req.extensions_mut().insert(RequestId(id));

    let mut response = next.run(req).await;
    // An ID that cannot be represented as a header value is simply not echoed.
    if let Ok(value) = echoed {
        response.headers_mut().insert(HEADER_NAME, value);
    }
    response
}
/// Produces a fresh request ID: 16 bytes obtained from `rand::thread_rng()`,
/// encoded as URL-safe base64 without padding.
fn generate_id() -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine as _;
    use rand::RngCore as _;

    let mut entropy = [0u8; 16];
    rand::thread_rng().fill_bytes(&mut entropy);
    URL_SAFE_NO_PAD.encode(entropy)
}
/// Reports whether `s` is acceptable as an inbound request ID.
///
/// A value is accepted when it is at most 128 bytes long and contains no
/// control characters. Newline, carriage return and NUL are control
/// characters, so they are rejected by the same check.
fn is_safe(s: &str) -> bool {
    // `len` counts bytes, not characters.
    if s.len() > 128 {
        return false;
    }
    !s.chars().any(char::is_control)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Sixteen bytes encode to 22 characters in unpadded base64.
    #[test]
    fn generated_id_is_22_chars() {
        assert_eq!(generate_id().len(), 22);
    }

    // Two consecutive IDs should never coincide.
    #[test]
    fn generated_ids_are_unique() {
        let first = generate_id();
        let second = generate_id();
        assert_ne!(first, second);
    }

    // Typical token-like and UUID-like IDs pass validation.
    #[test]
    fn is_safe_accepts_normal() {
        for candidate in ["abc-123_xyz", "550e8400-e29b-41d4-a716-446655440000"] {
            assert!(is_safe(candidate));
        }
    }

    // One byte over the 128-byte limit is rejected.
    #[test]
    fn is_safe_rejects_long_strings() {
        assert!(!is_safe(&"x".repeat(129)));
    }

    // Line breaks of either kind are rejected.
    #[test]
    fn is_safe_rejects_newlines() {
        for candidate in ["abc\ndef", "abc\rdef"] {
            assert!(!is_safe(candidate));
        }
    }

    // Embedded NUL is rejected.
    #[test]
    fn is_safe_rejects_null_bytes() {
        assert!(!is_safe("abc\0def"));
    }

    // The default configuration does not force generation.
    #[test]
    fn defaults_honor_inbound() {
        let RequestIdLayer { always_generate } = RequestIdLayer::default();
        assert!(!always_generate);
    }

    // The dedicated constructor switches forced generation on.
    #[test]
    fn always_generate_overrides() {
        assert!(RequestIdLayer::always_generate().always_generate);
    }
}