use axum::{
body::Body,
extract::Extension,
http::{Request, StatusCode},
middleware::Next,
response::{IntoResponse, Json, Response},
};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::Semaphore;
/// Number of requests allowed to run concurrently when
/// `TRUSTY_MAX_CONCURRENT_REQUESTS` is unset or does not parse as a `usize`.
const DEFAULT_MAX_CONCURRENT: usize = 8;
/// Number of requests allowed to wait for a permit when `TRUSTY_QUEUE_DEPTH`
/// is unset or does not parse as a `usize`.
const DEFAULT_QUEUE_DEPTH: usize = 32;
/// State shared by every invocation of the `apply_limiter` middleware.
///
/// A semaphore bounds how many requests run at once; an atomic counter tracks
/// how many requests are currently waiting for a permit so that the wait queue
/// can be bounded as well.
pub struct ConcurrencyLimiter {
/// Holds one permit per request that may be running at the same time.
semaphore: Arc<Semaphore>,
/// Maximum number of requests allowed to wait for a permit; `0` means a
/// request that cannot get a permit immediately is rejected.
queue_depth: usize,
/// Number of requests currently waiting for a permit.
waiting: Arc<AtomicUsize>,
/// Permit count the semaphore was created with, kept for reporting.
max_concurrent: usize,
}
impl ConcurrencyLimiter {
    /// Builds a limiter from the process environment.
    ///
    /// * `TRUSTY_MAX_CONCURRENT_REQUESTS` sets the number of requests that may
    ///   run at once. It is clamped to `1..=Semaphore::MAX_PERMITS` and falls
    ///   back to `DEFAULT_MAX_CONCURRENT` when unset or invalid.
    /// * `TRUSTY_QUEUE_DEPTH` sets the number of requests that may wait for a
    ///   permit. It falls back to `DEFAULT_QUEUE_DEPTH` when unset or invalid.
    ///
    /// A value that is present but does not parse as a `usize` is reported
    /// with a warning instead of being silently ignored.
    pub fn from_env() -> Arc<Self> {
        let max_concurrent =
            Self::env_usize("TRUSTY_MAX_CONCURRENT_REQUESTS").unwrap_or(DEFAULT_MAX_CONCURRENT);
        let queue_depth = Self::env_usize("TRUSTY_QUEUE_DEPTH").unwrap_or(DEFAULT_QUEUE_DEPTH);

        let limiter = Self::with_clamped_limits(max_concurrent, queue_depth);
        // Log the effective values, i.e. after clamping.
        tracing::info!(
            "concurrency limiter: max_concurrent={} queue_depth={}",
            limiter.max_concurrent,
            limiter.queue_depth
        );
        Arc::new(limiter)
    }

    /// Builds a limiter with explicit limits (test helper).
    ///
    /// `max_concurrent` is clamped the same way as in [`Self::from_env`].
    #[cfg(test)]
    pub fn with_limits(max_concurrent: usize, queue_depth: usize) -> Arc<Self> {
        Arc::new(Self::with_clamped_limits(max_concurrent, queue_depth))
    }

    /// Number of requests currently waiting for a permit.
    pub fn waiting(&self) -> usize {
        self.waiting.load(Ordering::Relaxed)
    }

    /// Number of requests allowed to run at the same time.
    pub fn max_concurrent(&self) -> usize {
        self.max_concurrent
    }

    /// Single construction path shared by all constructors.
    ///
    /// `Semaphore::new` panics when asked for more than
    /// `Semaphore::MAX_PERMITS` permits, and zero permits would block every
    /// request forever, so the permit count is clamped into the valid range.
    fn with_clamped_limits(max_concurrent: usize, queue_depth: usize) -> Self {
        let max_concurrent = max_concurrent.clamp(1, Semaphore::MAX_PERMITS);
        Self {
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
            queue_depth,
            waiting: Arc::new(AtomicUsize::new(0)),
            max_concurrent,
        }
    }

    /// Reads `name` from the environment and parses it as a `usize`.
    ///
    /// Returns `None` when the variable is unset, not valid Unicode, or not a
    /// valid `usize`; the parse failure is logged so a typo in the deployment
    /// configuration does not go unnoticed.
    fn env_usize(name: &str) -> Option<usize> {
        let raw = std::env::var(name).ok()?;
        match raw.parse::<usize>() {
            Ok(value) => Some(value),
            Err(err) => {
                tracing::warn!(
                    "concurrency limiter: ignoring invalid {}={:?}: {}",
                    name,
                    raw,
                    err
                );
                None
            }
        }
    }
}
/// Builds the response sent to clients that could not be admitted:
/// HTTP 503 with a JSON error payload and a `Retry-After: 2` header.
fn busy_response() -> Response {
    use axum::http::{header::RETRY_AFTER, HeaderValue};

    let payload = serde_json::json!({
        "error": "server_busy",
        "message": "Request queue full, retry shortly",
    });

    // Start from the JSON response, then stamp the status and the retry hint.
    let mut response = Json(payload).into_response();
    *response.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
    response
        .headers_mut()
        .insert(RETRY_AFTER, HeaderValue::from_static("2"));
    response
}
pub async fn apply_limiter(
Extension(limiter): Extension<Arc<ConcurrencyLimiter>>,
request: Request<Body>,
next: Next,
) -> Response {
let permit = limiter.semaphore.clone().try_acquire_owned().ok();
let permit = match permit {
Some(p) => p,
None => {
let current_waiters = limiter.waiting.fetch_add(1, Ordering::Relaxed);
metrics::gauge!("trusty_queue_depth").set((current_waiters + 1) as f64);
if current_waiters >= limiter.queue_depth {
limiter.waiting.fetch_sub(1, Ordering::Relaxed);
metrics::gauge!("trusty_queue_depth")
.set(limiter.waiting.load(Ordering::Relaxed) as f64);
metrics::counter!("trusty_requests_rejected_total").increment(1);
tracing::warn!("concurrency limiter: queue full, returning 503");
return busy_response();
}
let acquired = limiter.semaphore.clone().acquire_owned().await;
limiter.waiting.fetch_sub(1, Ordering::Relaxed);
metrics::gauge!("trusty_queue_depth")
.set(limiter.waiting.load(Ordering::Relaxed) as f64);
match acquired {
Ok(p) => p,
Err(_) => {
return busy_response();
}
}
}
};
let response = next.run(request).await;
drop(permit);
response
}
// Tests for the limiter constructors and the `apply_limiter` middleware,
// driven through an in-process router via `tower::ServiceExt::oneshot`.
#[cfg(test)]
mod tests {
use super::*;
use axum::{
body::Body,
http::{Request, StatusCode},
routing::get,
Router,
};
use std::time::Duration;
use tower::ServiceExt;
/// Builds a router with a single `/slow` route that sleeps for 100 ms before
/// answering `"ok"`, wrapped in `apply_limiter` and given `limiter` as an
/// extension.
fn limited_router(limiter: Arc<ConcurrencyLimiter>) -> Router {
Router::new()
.route(
"/slow",
get(|| async {
tokio::time::sleep(Duration::from_millis(100)).await;
"ok"
}),
)
.route_layer(axum::middleware::from_fn(apply_limiter))
.layer(Extension(limiter))
}
/// With both environment variables removed, `from_env` must report the
/// built-in default concurrency.
// NOTE(review): this mutates process-wide environment variables; confirm no
// other test in the binary relies on them.
#[tokio::test]
async fn from_env_uses_defaults_when_unset() {
std::env::remove_var("TRUSTY_MAX_CONCURRENT_REQUESTS");
std::env::remove_var("TRUSTY_QUEUE_DEPTH");
let limiter = ConcurrencyLimiter::from_env();
assert_eq!(limiter.max_concurrent(), DEFAULT_MAX_CONCURRENT);
}
/// Two simultaneous requests against a limiter with two permits must both
/// succeed.
#[tokio::test]
async fn limiter_admits_up_to_concurrency() {
let limiter = ConcurrencyLimiter::with_limits(2, 4);
let app = limited_router(limiter);
let req = || {
Request::builder()
.uri("/slow")
.body(Body::empty())
.expect("valid request")
};
let r1 = app.clone().oneshot(req());
let r2 = app.clone().oneshot(req());
let (res1, res2) = tokio::join!(r1, r2);
assert_eq!(res1.unwrap().status(), StatusCode::OK);
assert_eq!(res2.unwrap().status(), StatusCode::OK);
}
/// With one permit and a queue depth of zero, a second request issued while
/// the first is still running must get a 503 carrying `Retry-After: 2`.
#[tokio::test]
async fn limiter_returns_503_when_queue_full() {
let limiter = ConcurrencyLimiter::with_limits(1, 0);
let app = limited_router(limiter);
let req = || {
Request::builder()
.uri("/slow")
.body(Body::empty())
.expect("valid request")
};
let in_flight = tokio::spawn(app.clone().oneshot(req()));
// Give the spawned request time to start; presumably it holds the only
// permit by now, since the handler sleeps for 100 ms.
tokio::time::sleep(Duration::from_millis(10)).await;
let rejected = app.oneshot(req()).await.expect("oneshot returns");
assert_eq!(rejected.status(), StatusCode::SERVICE_UNAVAILABLE);
assert_eq!(
rejected
.headers()
.get(axum::http::header::RETRY_AFTER)
.map(|v| v.to_str().unwrap()),
Some("2")
);
// Let the first request finish before the test ends.
let _ = in_flight.await;
}
}