use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use axum::body::Body;
use axum::extract::Request;
use axum::http::HeaderValue;
use axum::middleware::Next;
use axum::response::Response;
use futures::FutureExt;
/// Configuration plus live counter shared by every request that passes
/// through [`govern`].
///
/// A zero `timeout` disables the per-request deadline, and a zero
/// `max_in_flight` disables load shedding.
pub(crate) struct Governor {
    /// Deadline applied to each admitted request; `Duration::ZERO` means none.
    pub timeout: Duration,
    /// Concurrency ceiling above which requests are answered `503`; 0 means unlimited.
    pub max_in_flight: usize,
    /// Live request counter maintained by `govern` and released by `Slot`'s `Drop`.
    in_flight: AtomicUsize,
}

impl Governor {
    /// Builds a governor with an empty in-flight counter, wrapped in an `Arc`
    /// so it can be shared with the middleware.
    pub(crate) fn new(timeout: Duration, max_in_flight: usize) -> Arc<Self> {
        let governor = Governor {
            in_flight: AtomicUsize::new(0),
            max_in_flight,
            timeout,
        };
        Arc::new(governor)
    }
}
/// RAII guard for one unit of `Governor::in_flight`.
///
/// Dropping the guard decrements the counter, so the slot is released on every
/// exit path of the code that holds it.
struct Slot(Arc<Governor>);

impl Drop for Slot {
    fn drop(&mut self) {
        let Slot(governor) = self;
        governor.in_flight.fetch_sub(1, Ordering::AcqRel);
    }
}
pub(crate) async fn govern(gov: Arc<Governor>, mut req: Request, next: Next) -> Response {
let rid: HeaderValue = match req.headers().get("x-request-id") {
Some(v) if !v.is_empty() => v.clone(),
_ => {
let v = HeaderValue::from_str(&gen_request_id()).expect("generated id is ASCII");
req.headers_mut().insert("x-request-id", v.clone());
v
}
};
let occupied = gov.in_flight.fetch_add(1, Ordering::AcqRel);
let _slot = Slot(Arc::clone(&gov));
if gov.max_in_flight > 0 && occupied >= gov.max_in_flight {
metrics::counter!("http_requests_shed_total").increment(1);
let mut resp = Response::builder()
.status(503)
.header("retry-after", "1")
.body(Body::from("server at capacity"))
.expect("static shed response");
resp.headers_mut().insert("x-request-id", rid);
return resp;
}
let guarded = std::panic::AssertUnwindSafe(next.run(req)).catch_unwind();
let outcome = if gov.timeout.is_zero() {
Ok(guarded.await)
} else {
tokio::time::timeout(gov.timeout, guarded).await
};
let mut resp = match outcome {
Ok(Ok(resp)) => resp,
Ok(Err(panic_payload)) => {
let msg = panic_payload
.downcast_ref::<&str>()
.copied()
.or_else(|| panic_payload.downcast_ref::<String>().map(String::as_str))
.unwrap_or("<non-string panic payload>");
metrics::counter!("http_handler_panics_total").increment(1);
tracing::error!(
request_id = ?rid,
panic = msg,
"handler panicked — answered 500"
);
Response::builder()
.status(500)
.body(Body::from("internal server error"))
.expect("static panic response")
}
Err(_) => {
metrics::counter!("http_requests_deadline_total").increment(1);
Response::builder()
.status(504)
.body(Body::from("request exceeded the server deadline"))
.expect("static timeout response")
}
};
resp.headers_mut().insert("x-request-id", rid);
resp
}
/// Generates a request id (16 lowercase hex digits) for requests that arrive
/// without one.
///
/// The id is a 64-bit SipHash of three inputs: the wall-clock time in
/// nanoseconds (0 if the clock is before the Unix epoch), a process-wide call
/// counter, and the process id.
///
/// The hasher is keyed through `RandomState`, which is randomly seeded. With
/// the fixed-key `DefaultHasher::new()` the id would be a pure function of
/// (time, counter, pid). Counters start at 0 in every process, and the pid can
/// be identical across replicas (e.g. pid 1 in containers), so distinct
/// processes could then collide whenever their clock readings coincide.
///
/// Uniqueness is probabilistic (64 bits of hash output).
fn gen_request_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    static COUNTER: AtomicUsize = AtomicUsize::new(0);

    let mut h = RandomState::new().build_hasher();
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0)
        .hash(&mut h);
    COUNTER.fetch_add(1, Ordering::Relaxed).hash(&mut h);
    std::process::id().hash(&mut h);
    format!("{:016x}", h.finish())
}