use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use axum::body::Body;
use axum::extract::Request;
use axum::http::HeaderValue;
use axum::middleware::Next;
use axum::response::Response;
use futures::FutureExt;
/// Shared admission-control state consulted by `govern` for every request.
///
/// Combines a per-request deadline, a hard cap on concurrent requests, and a
/// latency-driven probabilistic shedder.
pub(crate) struct Governor {
/// Deadline applied to the inner handler; a zero duration disables the deadline.
pub timeout: Duration,
/// Maximum number of concurrently admitted requests; `0` disables the cap.
pub max_in_flight: usize,
/// Requests currently inside `govern` (incremented on entry, decremented when
/// the request's `Slot` guard is dropped).
in_flight: AtomicUsize,
/// Latency target in microseconds; `0` disables adaptive shedding.
adaptive_target_micros: u64,
/// Moving average of observed handler latency in microseconds; `0` is treated
/// as "no sample recorded yet".
ewma_micros: AtomicU64,
/// Counter state behind the pseudo-random roll used for adaptive shedding.
roll: AtomicU64,
}
impl Governor {
    /// Builds a shared governor.
    ///
    /// * `timeout` – deadline for the inner handler; zero disables it.
    /// * `max_in_flight` – hard cap on concurrent requests; `0` disables it.
    /// * `adaptive_target` – latency target for adaptive shedding; zero
    ///   disables adaptive shedding.
    pub(crate) fn new(
        timeout: Duration,
        max_in_flight: usize,
        adaptive_target: Duration,
    ) -> Arc<Self> {
        // `as_micros` yields a u128. Clamp instead of truncating so an
        // extremely large target cannot wrap around to a tiny (or zero) one.
        let adaptive_target_micros =
            adaptive_target.as_micros().min(u128::from(u64::MAX)) as u64;
        Arc::new(Self {
            timeout,
            max_in_flight,
            in_flight: AtomicUsize::new(0),
            adaptive_target_micros,
            ewma_micros: AtomicU64::new(0),
            roll: AtomicU64::new(0x9E37_79B9_7F4A_7C15),
        })
    }

    /// Folds one latency sample (microseconds) into the moving average.
    ///
    /// The first sample seeds the average; later samples are blended in with
    /// weight 1/8. The read-modify-write is a single atomic `fetch_update`, so
    /// samples recorded concurrently are not lost to a stale load/store pair.
    fn record_latency(&self, micros: u64) {
        // The closure always returns `Some`, so the update cannot fail; the
        // returned previous value is not needed.
        let _ = self
            .ewma_micros
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |cur| {
                Some(if cur == 0 {
                    micros
                } else {
                    cur - cur / 8 + micros / 8
                })
            });
    }

    /// Decides whether the current request should be shed because the moving
    /// average latency is above the configured target.
    ///
    /// Returns `false` when adaptive shedding is disabled (target of zero) or
    /// when `shed_threshold` reports no pressure. Otherwise a pseudo-random
    /// value in `0..1000` is compared against the threshold, so the threshold
    /// acts as a shed probability in thousandths.
    fn should_shed_adaptively(&self) -> bool {
        if self.adaptive_target_micros == 0 {
            return false;
        }
        let ewma = self.ewma_micros.load(Ordering::Relaxed);
        let Some(threshold) = shed_threshold(ewma, self.adaptive_target_micros) else {
            return false;
        };
        // Advance the shared counter by a fixed increment (wrapping), then
        // scramble the bits with xor-shift / multiply steps so consecutive
        // rolls are spread across the 0..1000 range.
        let mut x = self
            .roll
            .fetch_add(0x9E37_79B9_7F4A_7C15, Ordering::Relaxed);
        x ^= x >> 33;
        x = x.wrapping_mul(0xFF51_AFD7_ED55_8CCD);
        x ^= x >> 33;
        (x % 1000) < threshold
    }
}
/// Maps latency pressure to a shed probability expressed in thousandths.
///
/// Returns `None` when nothing should be shed: the average is at or below the
/// target, or the target is zero (adaptive shedding disabled). Otherwise
/// returns `Some(p)` where `p` is the relative overshoot
/// `(ewma - target) / target` scaled by 1000 and capped at 900, so at most
/// 90% of requests are ever shed.
fn shed_threshold(ewma_micros: u64, target_micros: u64) -> Option<u64> {
    // A zero target would otherwise divide by zero below.
    if target_micros == 0 || ewma_micros <= target_micros {
        return None;
    }
    // Widen to u128 so `excess * 1000` cannot overflow for any u64 inputs.
    let excess = u128::from(ewma_micros - target_micros);
    let permille = excess * 1000 / u128::from(target_micros);
    // Capped at 900, so narrowing back to u64 is lossless.
    Some(permille.min(900) as u64)
}
/// Guard for one in-flight request: holds the governor so that dropping the
/// guard decrements `Governor::in_flight`. The matching increment is done by
/// the code that creates the guard.
struct Slot(Arc<Governor>);
impl Drop for Slot {
fn drop(&mut self) {
self.0.in_flight.fetch_sub(1, Ordering::AcqRel);
}
}
/// Middleware that tags, admits, and supervises a single request.
///
/// Steps, in order:
/// 1. Request id: a non-empty incoming `x-request-id` is reused; otherwise an
///    id is generated and also written onto the request before forwarding.
/// 2. Hard cap: the in-flight counter is incremented, and when
///    `max_in_flight` is non-zero and was already reached, the request is
///    answered with `503`.
/// 3. Adaptive shedding: `Governor::should_shed_adaptively` may also answer
///    with `503`.
/// 4. The inner service runs with panics caught (answered `500`) and, when
///    `timeout` is non-zero, under a deadline (answered `504`).
///
/// Every response carries the request id in `x-request-id`. The time spent in
/// step 4 is handed to `Governor::record_latency` regardless of its outcome.
pub(crate) async fn govern(gov: Arc<Governor>, mut req: Request, next: Next) -> Response {
    /// Builds a `503` reply with `retry-after: 1` and the request id attached.
    fn refuse(body: &'static str, invariant: &str, rid: HeaderValue) -> Response {
        let mut reply = Response::builder()
            .status(503)
            .header("retry-after", "1")
            .body(Body::from(body))
            .expect(invariant);
        reply.headers_mut().insert("x-request-id", rid);
        reply
    }

    /// Builds a header-less reply with the given status and static body.
    fn plain(status: u16, body: &'static str, invariant: &str) -> Response {
        Response::builder()
            .status(status)
            .body(Body::from(body))
            .expect(invariant)
    }

    // Reuse the caller's id when it is present and non-empty; otherwise mint
    // one and attach it to the request so the inner service sees it too.
    let supplied = req
        .headers()
        .get("x-request-id")
        .filter(|value| !value.is_empty())
        .cloned();
    let rid: HeaderValue = match supplied {
        Some(value) => value,
        None => {
            let minted = HeaderValue::from_str(&gen_request_id()).expect("generated id is ASCII");
            req.headers_mut().insert("x-request-id", minted.clone());
            minted
        }
    };

    // Claim a slot first; the guard releases it on every exit path.
    let occupied = gov.in_flight.fetch_add(1, Ordering::AcqRel);
    let _slot = Slot(Arc::clone(&gov));

    let cap = gov.max_in_flight;
    if cap > 0 && occupied >= cap {
        metrics::counter!("http_requests_shed_total").increment(1);
        return refuse("server at capacity", "static shed response", rid);
    }

    if gov.should_shed_adaptively() {
        metrics::counter!("http_requests_adaptive_shed_total").increment(1);
        return refuse(
            "server shedding load (latency target exceeded)",
            "static adaptive shed response",
            rid,
        );
    }

    let started = std::time::Instant::now();
    let guarded = std::panic::AssertUnwindSafe(next.run(req)).catch_unwind();
    let deadline = gov.timeout;
    let outcome = if deadline.is_zero() {
        // No deadline configured: wait for the handler however long it takes.
        Ok(guarded.await)
    } else {
        tokio::time::timeout(deadline, guarded).await
    };

    let mut resp = match outcome {
        Err(_) => {
            metrics::counter!("http_requests_deadline_total").increment(1);
            plain(
                504,
                "request exceeded the server deadline",
                "static timeout response",
            )
        }
        Ok(Err(payload)) => {
            // Panic payloads are usually `&str` or `String`; anything else is
            // reported with a placeholder.
            let msg: &str = if let Some(text) = payload.downcast_ref::<&str>() {
                *text
            } else if let Some(text) = payload.downcast_ref::<String>() {
                text.as_str()
            } else {
                "<non-string panic payload>"
            };
            metrics::counter!("http_handler_panics_total").increment(1);
            tracing::error!(
                request_id = ?rid,
                panic = msg,
                "handler panicked — answered 500"
            );
            plain(500, "internal server error", "static panic response")
        }
        Ok(Ok(inner)) => inner,
    };

    gov.record_latency(started.elapsed().as_micros() as u64);
    resp.headers_mut().insert("x-request-id", rid);
    resp
}
/// Produces a 16-character lowercase hexadecimal request id.
///
/// The id is a 64-bit hash over the current wall-clock time in nanoseconds
/// (or `0` if the clock reads before the Unix epoch), a process-wide call
/// counter, and the process id.
fn gen_request_id() -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    static SEQUENCE: AtomicUsize = AtomicUsize::new(0);

    let nanos: u128 = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        Err(_) => 0,
    };
    let sequence: usize = SEQUENCE.fetch_add(1, Ordering::Relaxed);
    let pid: u32 = std::process::id();

    let mut hasher = DefaultHasher::new();
    nanos.hash(&mut hasher);
    sequence.hash(&mut hasher);
    pid.hash(&mut hasher);

    let digest = hasher.finish();
    format!("{:016x}", digest)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the threshold curve against a 10 ms target: nothing at or below
    /// the target, a linear ramp above it, and a cap at 900.
    #[test]
    fn shed_threshold_ramps_with_pressure_and_caps() {
        let target = 10_000;
        let cases: [(u64, Option<u64>, &str); 5] = [
            (0, None, "cold start never sheds"),
            (10_000, None, "at target: healthy"),
            (15_000, Some(500), "1.5x => 50%"),
            (20_000, Some(900), "2x caps at 90%"),
            (100_000, Some(900), "always capped"),
        ];
        for (ewma, expected, why) in cases {
            assert_eq!(shed_threshold(ewma, target), expected, "{why}");
        }
    }

    /// Feeds a constant 20 ms sample repeatedly and checks that the moving
    /// average ends up near it and above the 10 ms target.
    #[test]
    fn ewma_converges_toward_samples() {
        let gov = Governor::new(Duration::ZERO, 0, Duration::from_millis(10));
        (0..64).for_each(|_| gov.record_latency(20_000));
        let ewma = gov.ewma_micros.load(Ordering::Relaxed);
        assert!(
            (15_000..=20_000).contains(&ewma),
            "ewma {ewma} should approach 20ms"
        );
        assert!(gov.adaptive_target_micros < ewma);
    }
}