use axum::{
body::Body,
extract::Extension,
http::{Request, StatusCode},
middleware::Next,
response::{IntoResponse, Json, Response},
};
use std::sync::Arc;
/// Fallback query deadline, in whole seconds, used when
/// `TRUSTY_QUERY_TIMEOUT_SECS` is unset or invalid.
const DEFAULT_QUERY_TIMEOUT_SECS: u64 = 30;

/// Deadline applied to each interactive query request by the
/// [`apply_query_timeout`] middleware.
#[derive(Debug, Clone)]
pub struct QueryTimeoutConfig {
    /// Maximum time a wrapped handler may run before the middleware gives up
    /// and answers with a timeout response.
    pub timeout: std::time::Duration,
}
impl QueryTimeoutConfig {
    /// Reads the query deadline (whole seconds) from the
    /// `TRUSTY_QUERY_TIMEOUT_SECS` environment variable and wraps the config in
    /// an [`Arc`] so it can be shared through an axum `Extension`.
    ///
    /// Falls back to [`DEFAULT_QUERY_TIMEOUT_SECS`] when the variable is unset.
    /// When it is set but unusable (not valid Unicode, not a non-negative
    /// integer, or `0`), a warning is logged and the default is used instead.
    /// Leading and trailing whitespace in the value is ignored.
    ///
    /// Zero is rejected on purpose: a zero `tokio::time::timeout` deadline
    /// expires essentially immediately, so any query that has to wait on I/O
    /// would be cut off with a timeout response.
    pub fn from_env() -> Arc<Self> {
        const VAR: &str = "TRUSTY_QUERY_TIMEOUT_SECS";

        let secs = match std::env::var(VAR) {
            // Unset is the normal case; no need to warn.
            Err(std::env::VarError::NotPresent) => DEFAULT_QUERY_TIMEOUT_SECS,
            // Set but not valid Unicode: surface it instead of silently ignoring.
            Err(err) => {
                tracing::warn!(
                    "ignoring {}: {}; using default {}s",
                    VAR,
                    err,
                    DEFAULT_QUERY_TIMEOUT_SECS
                );
                DEFAULT_QUERY_TIMEOUT_SECS
            }
            Ok(raw) => match raw.trim().parse::<u64>() {
                Ok(secs) if secs > 0 => secs,
                // Unparseable or zero: both would be operator mistakes.
                _ => {
                    tracing::warn!(
                        "ignoring {}={:?}: expected a positive whole number of seconds; using default {}s",
                        VAR,
                        raw,
                        DEFAULT_QUERY_TIMEOUT_SECS
                    );
                    DEFAULT_QUERY_TIMEOUT_SECS
                }
            },
        };

        tracing::info!("query timeout: {}s (TRUSTY_QUERY_TIMEOUT_SECS)", secs);
        Arc::new(Self {
            timeout: std::time::Duration::from_secs(secs),
        })
    }

    /// Test-only constructor that accepts an arbitrary (including sub-second)
    /// duration, bypassing the environment.
    #[cfg(test)]
    pub fn from_duration(duration: std::time::Duration) -> Arc<Self> {
        Arc::new(Self { timeout: duration })
    }
}
/// Builds the response sent when a wrapped query handler overruns its deadline.
///
/// Status is `408 Request Timeout`, and the body is the JSON object
/// `{"error": "query_timeout", "message": "..."}`.
///
/// NOTE(review): RFC 9110 defines 408 as the server giving up on receiving the
/// client's request, not on its own processing (503/504 describe that more
/// precisely). The status is kept as-is because the tests and the middleware's
/// log message both pin 408.
fn timeout_response() -> Response {
    let status = StatusCode::REQUEST_TIMEOUT;
    let payload = serde_json::json!({
        "error": "query_timeout",
        "message": "Query exceeded the configured time limit — try a narrower query or retry",
    });
    (status, Json(payload)).into_response()
}
/// Axum middleware (for use with `axum::middleware::from_fn`) that bounds the
/// wrapped handler by the deadline in [`QueryTimeoutConfig`].
///
/// The config must be installed as an `Extension<Arc<QueryTimeoutConfig>>`
/// layer. If the handler finishes in time, its response is returned unchanged.
/// Otherwise the in-flight handler future is dropped, a warning is logged, the
/// `trusty_query_timeouts_total` counter is incremented, and a 408 JSON
/// response is returned instead.
///
/// Work the handler has already handed off to other tasks is not cancelled by
/// dropping its future.
pub async fn apply_query_timeout(
    Extension(cfg): Extension<Arc<QueryTimeoutConfig>>,
    request: Request<Body>,
    next: Next,
) -> Response {
    let deadline = cfg.timeout;
    let outcome = tokio::time::timeout(deadline, next.run(request)).await;

    outcome.unwrap_or_else(|_elapsed| {
        tracing::warn!(
            timeout_secs = deadline.as_secs(),
            "query_timeout: interactive query exceeded deadline, returning 408 (issue #907)"
        );
        metrics::counter!("trusty_query_timeouts_total").increment(1);
        timeout_response()
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::post,
        Router,
    };
    use std::time::Duration;
    use tower::ServiceExt;

    /// Router with two POST routes behind `apply_query_timeout`: `/search`
    /// answers immediately, and `/search_slow` never completes.
    fn query_router_with_timeout(cfg: Arc<QueryTimeoutConfig>) -> Router {
        let fast = post(|| async { "ok" });
        let stalled = post(|| async { std::future::pending::<&str>().await });

        Router::new()
            .route("/search", fast)
            .route("/search_slow", stalled)
            .route_layer(axum::middleware::from_fn(apply_query_timeout))
            .layer(Extension(cfg))
    }

    /// Drives a single empty-bodied POST to `uri` through `app`.
    async fn post_empty(app: Router, uri: &str) -> axum::response::Response {
        let request = Request::builder()
            .method("POST")
            .uri(uri)
            .body(Body::empty())
            .unwrap();
        app.oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn query_timeout_passes_through_fast_response() {
        let app = query_router_with_timeout(QueryTimeoutConfig::from_duration(
            Duration::from_millis(100),
        ));

        let resp = post_empty(app, "/search").await;

        assert_eq!(
            resp.status(),
            StatusCode::OK,
            "fast query must return 200, not be cut off by timeout"
        );
    }

    #[tokio::test]
    async fn query_timeout_returns_408_when_handler_stalls() {
        let app = query_router_with_timeout(QueryTimeoutConfig::from_duration(
            Duration::from_millis(50),
        ));

        let started_at = std::time::Instant::now();
        let resp = post_empty(app, "/search_slow").await;
        let waited = started_at.elapsed();

        assert_eq!(
            resp.status(),
            StatusCode::REQUEST_TIMEOUT,
            "stalled query must receive 408, not hang (elapsed: {:?})",
            waited,
        );
        assert!(
            waited < Duration::from_secs(2),
            "408 must arrive before the 2 s wall-clock guard (elapsed: {:?})",
            waited,
        );
    }
}