use axum::{
body::Body,
extract::Extension,
http::{Request, StatusCode},
middleware::Next,
response::{IntoResponse, Json, Response},
};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::Semaphore;
/// Fallback number of concurrently running requests, used when
/// `TRUSTY_MAX_CONCURRENT_REQUESTS` is unset or does not parse as a `usize`.
const DEFAULT_MAX_CONCURRENT: usize = 8;
/// Fallback number of requests allowed to wait for a permit, used when
/// `TRUSTY_QUEUE_DEPTH` is unset or does not parse as a `usize`.
const DEFAULT_QUEUE_DEPTH: usize = 32;
/// Fallback queue-wait timeout in seconds, used when
/// `TRUSTY_QUEUE_TIMEOUT_SECS` is unset or does not parse as a `u64`.
const DEFAULT_QUEUE_TIMEOUT_SECS: u64 = 30;
fn queue_timeout() -> std::time::Duration {
static CACHED: std::sync::OnceLock<std::time::Duration> = std::sync::OnceLock::new();
*CACHED.get_or_init(|| {
let secs = std::env::var("TRUSTY_QUEUE_TIMEOUT_SECS")
.ok()
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(DEFAULT_QUEUE_TIMEOUT_SECS);
std::time::Duration::from_secs(secs)
})
}
/// Shared state for the request concurrency limiter applied by [`apply_limiter`].
///
/// A request first tries to take a semaphore permit. If none is free it joins
/// a bounded wait queue, and is answered with a 503 when the queue is already
/// full or when its wait exceeds `queue_timeout`.
pub struct ConcurrencyLimiter {
    /// Permit pool; one permit is held for the duration of each admitted request.
    semaphore: Arc<Semaphore>,
    /// Maximum number of requests allowed to wait for a permit at once.
    queue_depth: usize,
    /// Number of requests currently waiting for a permit.
    waiting: Arc<AtomicUsize>,
    /// Number of permits the semaphore was created with.
    max_concurrent: usize,
    /// Longest time a queued request waits for a permit before it is rejected.
    queue_timeout: std::time::Duration,
}
impl ConcurrencyLimiter {
    /// Builds a limiter from the process environment.
    ///
    /// * `TRUSTY_MAX_CONCURRENT_REQUESTS` — number of permits; falls back to
    ///   `DEFAULT_MAX_CONCURRENT` when unset or unparsable, and is clamped to
    ///   `1..=Semaphore::MAX_PERMITS`.
    /// * `TRUSTY_QUEUE_DEPTH` — maximum number of waiting requests; falls back
    ///   to `DEFAULT_QUEUE_DEPTH` when unset or unparsable.
    /// * The queue-wait timeout comes from `queue_timeout()`.
    ///
    /// The effective limits are logged at `info` level.
    pub fn from_env() -> Arc<Self> {
        let requested_permits =
            Self::env_usize("TRUSTY_MAX_CONCURRENT_REQUESTS").unwrap_or(DEFAULT_MAX_CONCURRENT);
        let queue_depth = Self::env_usize("TRUSTY_QUEUE_DEPTH").unwrap_or(DEFAULT_QUEUE_DEPTH);

        let limiter = Self::build(requested_permits, queue_depth, queue_timeout());
        // Log the value actually in force (after clamping), not the raw request.
        tracing::info!(
            "concurrency limiter: max_concurrent={} queue_depth={}",
            limiter.max_concurrent,
            queue_depth
        );
        limiter
    }

    /// Test constructor with explicit limits and the default queue timeout.
    #[cfg(test)]
    pub fn with_limits(max_concurrent: usize, queue_depth: usize) -> Arc<Self> {
        Self::with_limits_and_timeout(
            max_concurrent,
            queue_depth,
            std::time::Duration::from_secs(DEFAULT_QUEUE_TIMEOUT_SECS),
        )
    }

    /// Test constructor with explicit limits and an explicit queue timeout.
    ///
    /// `max_concurrent` is raised to at least 1.
    #[cfg(test)]
    pub fn with_limits_and_timeout(
        max_concurrent: usize,
        queue_depth: usize,
        queue_timeout: std::time::Duration,
    ) -> Arc<Self> {
        Self::build(max_concurrent, queue_depth, queue_timeout)
    }

    /// Returns the number of requests currently waiting for a permit.
    ///
    /// This is a relaxed atomic load, so the value is a snapshot that may be
    /// stale by the time the caller inspects it.
    pub fn waiting(&self) -> usize {
        self.waiting.load(Ordering::Relaxed)
    }

    /// Returns the number of permits the limiter was created with.
    pub fn max_concurrent(&self) -> usize {
        self.max_concurrent
    }

    /// Reads `name` from the environment and parses it as a `usize`.
    ///
    /// Returns `None` when the variable is unset, not valid Unicode, or not a
    /// valid `usize`.
    fn env_usize(name: &str) -> Option<usize> {
        std::env::var(name).ok()?.parse::<usize>().ok()
    }

    /// Single construction path shared by every constructor.
    ///
    /// The permit count is clamped to `1..=Semaphore::MAX_PERMITS`: zero
    /// permits would make every request queue forever, and `Semaphore::new`
    /// panics when asked for more than `MAX_PERMITS`.
    fn build(
        max_concurrent: usize,
        queue_depth: usize,
        queue_timeout: std::time::Duration,
    ) -> Arc<Self> {
        let max_concurrent = max_concurrent.clamp(1, Semaphore::MAX_PERMITS);
        Arc::new(Self {
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
            queue_depth,
            waiting: Arc::new(AtomicUsize::new(0)),
            max_concurrent,
            queue_timeout,
        })
    }
}
/// Builds the reply sent when a request cannot be admitted.
///
/// The response has status 503 Service Unavailable, a JSON body with the
/// `error` and `message` fields, and a `Retry-After: 2` header.
fn busy_response() -> Response {
    let payload = serde_json::json!({
        "error": "server_busy",
        "message": "Request queue full, retry shortly",
    });

    // Start from the JSON response, then stamp the status and retry hint on it.
    let mut reply = Json(payload).into_response();
    *reply.status_mut() = StatusCode::SERVICE_UNAVAILABLE;

    let retry_hint = axum::http::HeaderValue::from_static("2");
    reply
        .headers_mut()
        .insert(axum::http::header::RETRY_AFTER, retry_hint);

    reply
}
pub async fn apply_limiter(
Extension(limiter): Extension<Arc<ConcurrencyLimiter>>,
request: Request<Body>,
next: Next,
) -> Response {
let permit = limiter.semaphore.clone().try_acquire_owned().ok();
let permit = match permit {
Some(p) => p,
None => {
let current_waiters = limiter.waiting.fetch_add(1, Ordering::Relaxed);
metrics::gauge!("trusty_queue_depth").set((current_waiters + 1) as f64);
if current_waiters >= limiter.queue_depth {
limiter.waiting.fetch_sub(1, Ordering::Relaxed);
metrics::gauge!("trusty_queue_depth")
.set(limiter.waiting.load(Ordering::Relaxed) as f64);
metrics::counter!("trusty_requests_rejected_total").increment(1);
tracing::warn!("concurrency limiter: queue full, returning 503");
return busy_response();
}
let deadline = limiter.queue_timeout;
let acquired =
tokio::time::timeout(deadline, limiter.semaphore.clone().acquire_owned()).await;
limiter.waiting.fetch_sub(1, Ordering::Relaxed);
metrics::gauge!("trusty_queue_depth")
.set(limiter.waiting.load(Ordering::Relaxed) as f64);
match acquired {
Err(_elapsed) => {
metrics::counter!("trusty_requests_rejected_total").increment(1);
tracing::warn!(
timeout_secs = deadline.as_secs(),
"concurrency limiter: queue-wait timed out, returning 503 (issue #907)"
);
return busy_response();
}
Ok(Ok(p)) => p,
Ok(Err(_)) => {
return busy_response();
}
}
}
};
let response = next.run(request).await;
drop(permit);
response
}
// Tests for the concurrency-limiting middleware: default configuration,
// admission up to the permit count, rejection when the queue is full, and
// rejection when the queue wait times out.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::get,
        Router,
    };
    use std::time::Duration;
    use tower::ServiceExt;
    /// Router with a `/slow` route that sleeps 100 ms and then returns "ok",
    /// wrapped in `apply_limiter` and given `limiter` through an `Extension` layer.
    fn limited_router(limiter: Arc<ConcurrencyLimiter>) -> Router {
        Router::new()
            .route(
                "/slow",
                get(|| async {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    "ok"
                }),
            )
            .route_layer(axum::middleware::from_fn(apply_limiter))
            .layer(Extension(limiter))
    }
    /// Router with a `/forever` route whose handler never completes.
    ///
    /// The first handler invocation sends `()` on the returned oneshot channel
    /// before it starts pending, so a test can wait until that request is
    /// inside the handler (and therefore past the limiter) before continuing.
    fn forever_router_with_signal(
        limiter: Arc<ConcurrencyLimiter>,
    ) -> (Router, tokio::sync::oneshot::Receiver<()>) {
        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        // The sender sits in an `Option` behind a mutex so that only the first
        // handler call can take it and send the signal.
        let tx = std::sync::Arc::new(tokio::sync::Mutex::new(Some(tx)));
        let router = Router::new()
            .route(
                "/forever",
                get(move || {
                    let tx = std::sync::Arc::clone(&tx);
                    async move {
                        if let Some(sender) = tx.lock().await.take() {
                            let _ = sender.send(());
                        }
                        // Never resolves, so the request stays in flight.
                        std::future::pending::<&str>().await
                    }
                }),
            )
            .route_layer(axum::middleware::from_fn(apply_limiter))
            .layer(Extension(limiter));
        (router, rx)
    }
    /// With the limit variables removed from the environment, `from_env`
    /// reports the default permit count.
    #[tokio::test]
    async fn from_env_uses_defaults_when_unset() {
        std::env::remove_var("TRUSTY_MAX_CONCURRENT_REQUESTS");
        std::env::remove_var("TRUSTY_QUEUE_DEPTH");
        let limiter = ConcurrencyLimiter::from_env();
        assert_eq!(limiter.max_concurrent(), DEFAULT_MAX_CONCURRENT);
    }
    /// Two concurrent requests against a limiter with two permits both succeed.
    #[tokio::test]
    async fn limiter_admits_up_to_concurrency() {
        let limiter = ConcurrencyLimiter::with_limits(2, 4);
        let app = limited_router(limiter);
        let req = || {
            Request::builder()
                .uri("/slow")
                .body(Body::empty())
                .expect("valid request")
        };
        let r1 = app.clone().oneshot(req());
        let r2 = app.clone().oneshot(req());
        let (res1, res2) = tokio::join!(r1, r2);
        assert_eq!(res1.unwrap().status(), StatusCode::OK);
        assert_eq!(res2.unwrap().status(), StatusCode::OK);
    }
    /// With one permit and a queue depth of zero, a second request issued while
    /// the first is still running gets a 503 carrying `Retry-After: 2`.
    #[tokio::test]
    async fn limiter_returns_503_when_queue_full() {
        let limiter = ConcurrencyLimiter::with_limits(1, 0);
        let app = limited_router(limiter);
        let req = || {
            Request::builder()
                .uri("/slow")
                .body(Body::empty())
                .expect("valid request")
        };
        let in_flight = tokio::spawn(app.clone().oneshot(req()));
        // Give the spawned request time to start and take the only permit.
        tokio::time::sleep(Duration::from_millis(10)).await;
        let rejected = app.oneshot(req()).await.expect("oneshot returns");
        assert_eq!(rejected.status(), StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(
            rejected
                .headers()
                .get(axum::http::header::RETRY_AFTER)
                .map(|v| v.to_str().unwrap()),
            Some("2")
        );
        let _ = in_flight.await;
    }
    /// With one permit held by a request that never finishes, a second request
    /// is queued and then rejected with a 503 once the 50 ms queue timeout
    /// passes, instead of blocking indefinitely.
    #[tokio::test]
    async fn queue_wait_returns_503_on_timeout() {
        let limiter = ConcurrencyLimiter::with_limits_and_timeout(1, 1, Duration::from_millis(50));
        let (app, permit_acquired) = forever_router_with_signal(limiter);
        let req = || {
            Request::builder()
                .uri("/forever")
                .body(Body::empty())
                .expect("valid request")
        };
        let _in_flight = tokio::spawn(app.clone().oneshot(req()));
        // Wait until the first request is inside the handler, i.e. it holds the permit.
        permit_acquired
            .await
            .expect("in-flight handler must send the permit-acquired signal");
        let start = std::time::Instant::now();
        let waiting = app.oneshot(req()).await.expect("oneshot returns");
        let elapsed = start.elapsed();
        assert_eq!(
            waiting.status(),
            StatusCode::SERVICE_UNAVAILABLE,
            "queue-wait timeout must return 503, got {} (elapsed: {:?})",
            waiting.status(),
            elapsed,
        );
        // Generous upper bound: the configured timeout is 50 ms.
        assert!(
            elapsed < Duration::from_secs(2),
            "queue-wait timeout must not block indefinitely (elapsed: {:?})",
            elapsed,
        );
        assert_eq!(
            waiting
                .headers()
                .get(axum::http::header::RETRY_AFTER)
                .map(|v| v.to_str().unwrap()),
            Some("2"),
            "503 response must include Retry-After header"
        );
    }
}