use std::sync::{Arc, Mutex};
use std::time::Instant;
use axum::body::Body;
use axum::extract::Request;
use axum::http::{HeaderName, HeaderValue, Response};
use axum::middleware::Next;
use axum::Router;
/// Name of the response header the middleware sets (the `Server-Timing` header, spelled in lowercase).
const HEADER: &str = "server-timing";
/// Per-request collection of named stage durations, rendered into a
/// `Server-Timing` response header by the middleware.
///
/// Cloning is cheap: every clone shares the same entries through an
/// `Arc<Mutex<..>>`. A handler can therefore take a clone from the request
/// extensions and record stages that the middleware renders afterwards.
#[derive(Clone, Debug)]
pub struct Timings {
    inner: Arc<Mutex<TimingsInner>>,
}

/// Mutable state shared by all clones of a [`Timings`].
#[derive(Debug)]
struct TimingsInner {
    /// Start instant; the `total` metric is measured from here.
    request_started: Instant,
    /// End of the most recent [`Timings::measure`] call (initially `request_started`).
    last_mark: Instant,
    /// Recorded `(stage name, duration in milliseconds)` pairs, in insertion order.
    entries: Vec<(String, f64)>,
}

impl Timings {
    /// Creates an empty collection whose `total` and first `measure` are both
    /// measured from `request_started`.
    fn new(request_started: Instant) -> Self {
        Self {
            inner: Arc::new(Mutex::new(TimingsInner {
                request_started,
                last_mark: request_started,
                entries: Vec::new(),
            })),
        }
    }

    /// Locks the shared state, recovering it if a previous holder panicked.
    ///
    /// The state is only appended to or read, so a poisoned lock still holds
    /// usable data. Recovering keeps one misbehaving caller from making the
    /// middleware itself panic when it renders the header.
    fn lock(&self) -> std::sync::MutexGuard<'_, TimingsInner> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Records a stage named `stage`. Its duration is the time elapsed since
    /// the previous `measure` call, or since the start instant for the first
    /// call. The mark then moves to now.
    ///
    /// Durations recorded with [`Timings::add`] do not move the mark.
    pub fn measure(&self, stage: impl Into<String>) {
        let now = Instant::now();
        // Convert before locking so a panicking `Into` impl cannot poison the mutex.
        let stage = stage.into();
        let mut g = self.lock();
        let elapsed_ms = now.duration_since(g.last_mark).as_secs_f64() * 1000.0;
        g.entries.push((stage, elapsed_ms));
        g.last_mark = now;
    }

    /// Records a stage named `stage` with an explicitly supplied duration in
    /// milliseconds. Does not affect the mark used by [`Timings::measure`].
    pub fn add(&self, stage: impl Into<String>, ms: f64) {
        // Convert before locking so a panicking `Into` impl cannot poison the mutex.
        let stage = stage.into();
        self.lock().entries.push((stage, ms));
    }

    /// Currently a no-op.
    ///
    /// Stages are recorded by `measure`/`add`, and the middleware renders the
    /// header itself after the handler returns.
    pub fn finish(&self) {}

    /// Renders the `Server-Timing` header value.
    ///
    /// The value starts with a `total` metric: the time since the start
    /// instant, taken at the moment of rendering. It is followed by each
    /// recorded stage in insertion order, with names passed through
    /// [`sanitize_token`]. All durations are formatted with one decimal place
    /// and entries are separated by `", "`.
    fn render(&self) -> String {
        let g = self.lock();
        let total_ms = g.request_started.elapsed().as_secs_f64() * 1000.0;
        let mut parts = Vec::with_capacity(g.entries.len() + 1);
        parts.push(format!("total;dur={total_ms:.1}"));
        parts.extend(
            g.entries
                .iter()
                .map(|(name, ms)| format!("{};dur={ms:.1}", sanitize_token(name))),
        );
        parts.join(", ")
    }
}

/// Makes `name` safe to use as a `Server-Timing` metric name.
///
/// Every character other than an ASCII letter, an ASCII digit, `-`, `_` or `.`
/// is replaced by a single `_`. Each `char` maps to exactly one output
/// character, so a non-ASCII character becomes one underscore. An empty name
/// becomes `"_"` so the metric is never nameless.
fn sanitize_token(name: &str) -> String {
    if name.is_empty() {
        return String::from("_");
    }
    name.chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.') {
                c
            } else {
                '_'
            }
        })
        .collect()
}
/// Configuration marker for the `Server-Timing` middleware, passed to
/// [`ServerTimingRouterExt::server_timing`].
///
/// It currently carries no options.
#[derive(Clone, Default, Debug)]
pub struct ServerTimingLayer;

impl ServerTimingLayer {
    /// Creates the layer; equivalent to [`ServerTimingLayer::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
/// Extension trait that installs the `Server-Timing` middleware on a router.
pub trait ServerTimingRouterExt {
    /// Returns `self` with the timing middleware attached as a layer.
    #[must_use]
    fn server_timing(self, layer: ServerTimingLayer) -> Self;
}

impl<S> ServerTimingRouterExt for Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    fn server_timing(self, _layer: ServerTimingLayer) -> Self {
        // The layer carries no configuration yet, so it is not consulted.
        let middleware = axum::middleware::from_fn(handle);
        self.layer(middleware)
    }
}
/// Middleware body that emits the `Server-Timing` header.
///
/// It attaches a fresh [`Timings`] to the request extensions and runs the rest
/// of the stack. It then appends a `Server-Timing` header rendered from the
/// stages the handler recorded.
///
/// The header is appended rather than inserted, so `Server-Timing` values
/// already set by inner handlers or layers are preserved. The header may
/// appear several times in one response. If the rendered value is not a valid
/// header value, the response is returned without the extra header.
async fn handle(mut req: Request<Body>, next: Next) -> Response<Body> {
    let timings = Timings::new(Instant::now());
    req.extensions_mut().insert(timings.clone());
    let mut response = next.run(req).await;
    // `total` is computed here, once the inner service has returned the
    // response. Work done later while the body is streamed is not counted.
    if let Ok(value) = HeaderValue::from_str(&timings.render()) {
        // `HEADER` is a lowercase constant, so `from_static` cannot fail at runtime.
        response
            .headers_mut()
            .append(HeaderName::from_static(HEADER), value);
    }
    response
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Request;
    use axum::routing::get;
    use axum::Extension;
    use tower::ServiceExt;

    /// Applies the timing middleware to `app`, sends it one `GET /` request
    /// and returns the response.
    async fn send_root(app: Router) -> Response<Body> {
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        app.server_timing(ServerTimingLayer::new())
            .oneshot(request)
            .await
            .unwrap()
    }

    /// Extracts the `server-timing` header of `resp` as an owned string.
    ///
    /// Panics if the header is missing or is not visible ASCII.
    fn timing_header(resp: &Response<Body>) -> String {
        resp.headers()
            .get("server-timing")
            .unwrap()
            .to_str()
            .unwrap()
            .to_owned()
    }

    #[tokio::test]
    async fn header_is_set_with_total_only_when_no_stages() {
        async fn h() -> &'static str {
            "ok"
        }
        let resp = send_root(Router::new().route("/", get(h))).await;
        let v = timing_header(&resp);
        assert!(v.starts_with("total;dur="), "got: {v}");
    }

    #[tokio::test]
    async fn handler_can_record_stages() {
        async fn h(Extension(t): Extension<Timings>) -> &'static str {
            t.measure("db");
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
            t.measure("render");
            "ok"
        }
        let resp = send_root(Router::new().route("/", get(h))).await;
        let v = timing_header(&resp);
        for expected in ["total;dur=", "db;dur=", "render;dur="] {
            assert!(v.contains(expected));
        }
    }

    #[tokio::test]
    async fn add_records_stage_with_explicit_duration() {
        async fn h(Extension(t): Extension<Timings>) -> &'static str {
            t.add("synthetic", 12.5);
            "ok"
        }
        let resp = send_root(Router::new().route("/", get(h))).await;
        assert!(timing_header(&resp).contains("synthetic;dur=12.5"));
    }

    #[tokio::test]
    async fn stage_names_are_sanitized() {
        async fn h(Extension(t): Extension<Timings>) -> &'static str {
            t.add("db query (selects)", 1.0);
            "ok"
        }
        let resp = send_root(Router::new().route("/", get(h))).await;
        assert!(timing_header(&resp).contains("db_query__selects_;dur=1.0"));
    }

    #[tokio::test]
    async fn finish_is_a_noop_marker() {
        async fn h(Extension(t): Extension<Timings>) -> &'static str {
            t.measure("a");
            t.finish();
            "ok"
        }
        let resp = send_root(Router::new().route("/", get(h))).await;
        assert_eq!(resp.status(), 200);
    }

    #[test]
    fn sanitize_token_keeps_ascii_word_chars() {
        for name in ["db_query", "db-query", "db.query", "db123"] {
            assert_eq!(sanitize_token(name), name);
        }
    }

    #[test]
    fn sanitize_token_replaces_invalid_chars() {
        let cases = [("db query", "db_query"), ("ñ", "_")];
        for (input, expected) in cases {
            assert_eq!(sanitize_token(input), expected);
        }
    }

    #[test]
    fn sanitize_token_empty_input_returns_underscore() {
        assert_eq!(sanitize_token(""), "_");
    }

    #[test]
    fn render_orders_total_first() {
        let timings = Timings::new(Instant::now());
        timings.add("db", 5.0);
        timings.add("render", 3.0);
        let rendered = timings.render();
        let mut parts = rendered.split(',').map(str::trim);
        assert!(parts.next().unwrap().starts_with("total;dur="));
        assert!(parts.next().unwrap().starts_with("db;dur=5.0"));
        assert!(parts.next().unwrap().starts_with("render;dur=3.0"));
    }
}