use std::pin::Pin;
use std::task::{Context, Poll};
use axum::body::Body;
use http::header::USER_AGENT;
use http::{HeaderValue, Request};
use tower::{Layer, Service};
/// Length cap, in bytes, used when a layer is built without an explicit limit.
const DEFAULT_MAX_LEN: usize = 512;

/// Tower layer that wraps services in a `UserAgentMiddleware`.
///
/// Its only setting is the maximum length handed to the `User-Agent` sanitizer.
#[derive(Clone, Copy, Debug)]
pub struct UserAgentLayer {
    /// Upper bound, in bytes, for the sanitized header value.
    max_len: usize,
}

impl UserAgentLayer {
    /// Creates a layer that uses the default cap (`DEFAULT_MAX_LEN`).
    pub fn new() -> Self {
        Self::with_max_length(DEFAULT_MAX_LEN)
    }

    /// Creates a layer with a caller-chosen cap of `max_len` bytes.
    pub fn with_max_length(max_len: usize) -> Self {
        UserAgentLayer { max_len }
    }
}

impl Default for UserAgentLayer {
    /// Same configuration as [`UserAgentLayer::new`].
    fn default() -> Self {
        UserAgentLayer::new()
    }
}
/// Lets `UserAgentLayer` decorate any inner service `S`.
impl<S> Layer<S> for UserAgentLayer {
    type Service = UserAgentMiddleware<S>;

    /// Wraps `inner`, copying this layer's length cap into the new middleware.
    fn layer(&self, inner: S) -> Self::Service {
        // The layer is `Copy`, so its single field can be pulled out by value.
        let Self { max_len } = *self;
        UserAgentMiddleware { inner, max_len }
    }
}
/// Service wrapper that sanitizes the `User-Agent` request header before
/// handing the request to `inner`.
///
/// `Clone` and `Debug` are derived, so each is available exactly when the
/// wrapped service `S` implements it.
#[derive(Debug, Clone)]
pub struct UserAgentMiddleware<S> {
    /// The wrapped service that ultimately handles the request.
    inner: S,
    /// Length cap, in bytes, passed to the sanitizer.
    max_len: usize,
}
/// Sanitizes the `User-Agent` header of each request, then forwards the
/// request to the wrapped service.
impl<S, ReqBody> Service<Request<ReqBody>> for UserAgentMiddleware<S>
where
    S: Service<Request<ReqBody>, Response = http::Response<Body>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
    ReqBody: Send + 'static,
{
    type Response = http::Response<Body>;
    type Error = S::Error;
    // `Future` only joined the prelude in the 2024 edition; the full path keeps
    // this compiling on earlier editions without an extra `use`.
    type Future =
        Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Rewrites the `User-Agent` header and forwards the request.
    ///
    /// * Header absent: the request is forwarded untouched.
    /// * Header present: its first value is decoded (invalid UTF-8 bytes become
    ///   U+FFFD), passed through `sanitize_user_agent`, and written back as the
    ///   single `User-Agent` value.
    /// * Nothing survives sanitization: every `User-Agent` value is removed.
    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
        let max_len = self.max_len;
        // Drive the instance that `poll_ready` was called on and leave a fresh
        // clone behind for the next caller.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);

        Box::pin(async move {
            // `HeaderValue::to_str` rejects any byte >= 0x80, which would let a
            // non-ASCII value skip the length cap entirely. Decoding the raw
            // bytes lossily guarantees every present header is sanitized.
            let raw = request
                .headers()
                .get(USER_AGENT)
                .map(|value| String::from_utf8_lossy(value.as_bytes()).into_owned());

            if let Some(raw) = raw {
                // The sanitizer strips every byte `HeaderValue` refuses, so the
                // conversion is expected to succeed; if it ever does not, the
                // header is dropped instead of panicking on request data.
                let sanitized = sanitize_user_agent(&raw, max_len)
                    .and_then(|clean| HeaderValue::from_str(&clean).ok());
                match sanitized {
                    Some(value) => {
                        // `insert` replaces all existing values with this one.
                        request.headers_mut().insert(USER_AGENT, value);
                    }
                    None => {
                        request.headers_mut().remove(USER_AGENT);
                    }
                }
            }

            inner.call(request).await
        })
    }
}
/// Produces a cleaned-up copy of a `User-Agent` string, or `None` when nothing
/// meaningful is left.
///
/// Steps, in order:
/// 1. Keep at most `max_len` bytes of `raw`, backing off to the nearest
///    earlier UTF-8 character boundary.
/// 2. Collapse each run of ASCII whitespace into a single space and drop the
///    remaining ASCII control characters.
/// 3. Trim leading and trailing whitespace.
///
/// Returns `None` if the result is empty.
pub(crate) fn sanitize_user_agent(raw: &str, max_len: usize) -> Option<String> {
    // Index 0 is always a character boundary, so the search cannot fail.
    let cap = raw.len().min(max_len);
    let cut = (0..=cap)
        .rev()
        .find(|&idx| raw.is_char_boundary(idx))
        .unwrap_or(0);
    let head = &raw[..cut];

    let mut cleaned = String::with_capacity(head.len());
    // True while the most recently emitted character is a collapsed space.
    let mut in_gap = false;
    for ch in head.chars() {
        match ch {
            // Checked first: tab, newline, etc. are also ASCII control characters.
            _ if ch.is_ascii_whitespace() => {
                if !in_gap {
                    cleaned.push(' ');
                }
                in_gap = true;
            }
            // Dropped controls do not end a whitespace run.
            _ if ch.is_ascii_control() => {}
            _ => {
                cleaned.push(ch);
                in_gap = false;
            }
        }
    }

    let core = cleaned.trim();
    if core.is_empty() {
        return None;
    }
    if core.len() < cleaned.len() {
        return Some(core.to_string());
    }
    // Nothing was trimmed: hand back the buffer without another allocation.
    Some(cleaned)
}
// Unit tests for `sanitize_user_agent`, followed by end-to-end tests that push
// requests through `UserAgentLayer` wrapped around small in-memory services.
#[cfg(test)]
mod tests {
use super::*;
use axum::body::Body;
use http::{Request, Response, StatusCode};
use std::convert::Infallible;
use tower::ServiceExt;
// ---- sanitize_user_agent: direct tests ----
// A short, already-clean value comes back unchanged.
#[test]
fn passes_clean_short_ua() {
assert_eq!(
sanitize_user_agent("Mozilla/5.0", 512).as_deref(),
Some("Mozilla/5.0"),
);
}
// ASCII input longer than the cap is cut to exactly `max_len` bytes.
#[test]
fn truncates_to_max_len_ascii() {
let raw: String = "A".repeat(1024);
let out = sanitize_user_agent(&raw, 64).unwrap();
assert_eq!(out.len(), 64);
assert!(out.chars().all(|c| c == 'A'));
}
// 'ñ' is two bytes in UTF-8, so an odd cap must back off to an even length
// rather than split a character.
#[test]
fn truncates_at_char_boundary_multibyte() {
let raw: String = "ñ".repeat(20);
let out = sanitize_user_agent(&raw, 5).unwrap();
assert!(out.len() <= 5);
assert!(out.chars().all(|c| c == 'ñ'));
assert_eq!(out.len() % 2, 0);
}
// Non-whitespace ASCII control characters are removed outright.
#[test]
fn strips_ascii_control_chars() {
let out = sanitize_user_agent("Mozilla\x01/\x07X", 512).unwrap();
assert_eq!(out, "Mozilla/X");
}
// A mixed run of spaces and tabs becomes a single space.
#[test]
fn collapses_whitespace_runs() {
let out = sanitize_user_agent("Mozilla \t /5.0", 512).unwrap();
assert_eq!(out, "Mozilla /5.0");
}
// Whitespace at either end is trimmed away.
#[test]
fn trims_leading_and_trailing_whitespace() {
assert_eq!(
sanitize_user_agent(" UA-Test ", 512).as_deref(),
Some("UA-Test"),
);
}
// Non-ASCII characters are not treated as junk.
#[test]
fn keeps_non_ascii_chars() {
assert_eq!(
sanitize_user_agent("клиент/1.0", 512).as_deref(),
Some("клиент/1.0"),
);
}
// The next four tests cover inputs that leave nothing behind: `None` is expected.
#[test]
fn returns_none_for_empty_input() {
assert!(sanitize_user_agent("", 512).is_none());
}
#[test]
fn returns_none_for_only_whitespace() {
assert!(sanitize_user_agent(" \t ", 512).is_none());
}
#[test]
fn returns_none_for_only_controls() {
assert!(sanitize_user_agent("\x01\x02\x03", 512).is_none());
}
#[test]
fn zero_max_len_returns_none() {
assert!(sanitize_user_agent("Mozilla/5.0", 0).is_none());
}
// ---- middleware: end-to-end tests ----
// Inner test service: replies with the `User-Agent` value it received, or the
// literal "<absent>" when the header is missing or not readable as a string.
async fn echo_ua(req: Request<Body>) -> Result<Response<Body>, Infallible> {
let ua = req
.headers()
.get(USER_AGENT)
.and_then(|v| v.to_str().ok())
.map(str::to_string)
.unwrap_or_else(|| "<absent>".to_string());
Ok(Response::new(Body::from(ua)))
}
// Wraps `echo_ua` in the given layer, sends `req` through it once, checks for
// a 200 status and returns the response body as a `String`.
async fn run(svc_layer: UserAgentLayer, req: Request<Body>) -> String {
let svc = svc_layer.layer(tower::service_fn(echo_ua));
let resp = svc.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
String::from_utf8(body.to_vec()).unwrap()
}
// A clean header reaches the inner service unchanged.
#[tokio::test]
async fn passes_clean_ua_unchanged() {
let req = Request::builder()
.header(USER_AGENT, "Mozilla/5.0")
.body(Body::empty())
.unwrap();
assert_eq!(run(UserAgentLayer::new(), req).await, "Mozilla/5.0");
}
// A 2000-byte header is cut down to the configured 64 bytes.
#[tokio::test]
async fn truncates_long_ua() {
let long = "A".repeat(2000);
let req = Request::builder()
.header(USER_AGENT, long)
.body(Body::empty())
.unwrap();
let out = run(UserAgentLayer::with_max_length(64), req).await;
assert_eq!(out.len(), 64);
assert!(out.chars().all(|c| c == 'A'));
}
// Tabs followed by a space inside the header collapse to one space.
#[tokio::test]
async fn strips_controls_and_collapses_whitespace() {
let req = Request::builder()
.header(USER_AGENT, "Mozilla/5.0\t\t (foo) bar")
.body(Body::empty())
.unwrap();
assert_eq!(
run(UserAgentLayer::new(), req).await,
"Mozilla/5.0 (foo) bar",
);
}
// A whitespace-only header is removed, so the inner service sees none.
#[tokio::test]
async fn removes_header_when_only_whitespace() {
let req = Request::builder()
.header(USER_AGENT, " \t ")
.body(Body::empty())
.unwrap();
assert_eq!(run(UserAgentLayer::new(), req).await, "<absent>");
}
// The middleware does not invent a header when the request has none.
#[tokio::test]
async fn leaves_absent_header_alone() {
let req = Request::builder().body(Body::empty()).unwrap();
assert_eq!(run(UserAgentLayer::new(), req).await, "<absent>");
}
// A custom cap of 8 keeps only the first eight bytes.
#[tokio::test]
async fn respects_with_max_length() {
let req = Request::builder()
.header(USER_AGENT, "abcdefghijklmnop")
.body(Body::empty())
.unwrap();
assert_eq!(
run(UserAgentLayer::with_max_length(8), req).await,
"abcdefgh"
);
}
// Two `User-Agent` values are appended; the inner service is expected to see
// exactly one value, equal to the first one sent.
#[tokio::test]
async fn normalizes_duplicate_user_agent_headers() {
let mut req = Request::builder().body(Body::empty()).unwrap();
req.headers_mut()
.append(USER_AGENT, "Mozilla/5.0".parse().unwrap());
req.headers_mut()
.append(USER_AGENT, "Other/1.0".parse().unwrap());
// This inner service reports "<number of values>|<first value>".
let svc = UserAgentLayer::new().layer(tower::service_fn(|req: Request<Body>| async move {
let count = req.headers().get_all(USER_AGENT).iter().count();
let first = req
.headers()
.get(USER_AGENT)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
Ok::<_, Infallible>(Response::new(Body::from(format!("{count}|{first}"))))
}));
let resp = svc.oneshot(req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
assert_eq!(body.as_ref(), b"1|Mozilla/5.0");
}
}