use axum::{
body::Body,
http::{HeaderValue, Request},
middleware::Next,
response::Response,
};
/// Fallback used when `MOCKFORGE_HTTP_KEEPALIVE_TIMEOUT_SECS` is unset or unparsable (two minutes).
const DEFAULT_TIMEOUT_SECS: u64 = 2 * 60;
/// Fallback used when `MOCKFORGE_HTTP_KEEPALIVE_MAX_REQUESTS` is unset or unparsable.
const DEFAULT_MAX_REQUESTS: u64 = 1_000;
/// Reports whether the keep-alive hint has been switched on via the environment.
///
/// Reads `MOCKFORGE_HTTP_KEEPALIVE_HINT`. The values `1`, `true`, `yes` and `on`
/// (ASCII case-insensitive) enable it. Surrounding whitespace is ignored, so values
/// such as `true\r` coming from CRLF `.env` files are still honoured.
/// An unset variable, a non-UTF-8 value, or any other value leaves it disabled.
pub fn is_keepalive_hint_enabled() -> bool {
    std::env::var("MOCKFORGE_HTTP_KEEPALIVE_HINT")
        .map(|raw| {
            let value = raw.trim();
            // `eq_ignore_ascii_case` avoids allocating a lowercased copy.
            ["1", "true", "yes", "on"]
                .iter()
                .any(|accepted| value.eq_ignore_ascii_case(accepted))
        })
        .unwrap_or(false)
}
/// Reads environment variable `name` as a `u64`, ignoring surrounding whitespace.
///
/// Returns `default` when the variable is unset, is not valid UTF-8, or does not
/// parse as an unsigned integer. Invalid values fall back silently.
fn env_u64_or(name: &str, default: u64) -> u64 {
    std::env::var(name)
        .ok()
        .and_then(|raw| raw.trim().parse::<u64>().ok())
        .unwrap_or(default)
}

/// Idle timeout in seconds, emitted as `timeout=` in the `Keep-Alive` response header.
///
/// Taken from `MOCKFORGE_HTTP_KEEPALIVE_TIMEOUT_SECS`, falling back to
/// `DEFAULT_TIMEOUT_SECS`.
fn keepalive_timeout_secs() -> u64 {
    env_u64_or("MOCKFORGE_HTTP_KEEPALIVE_TIMEOUT_SECS", DEFAULT_TIMEOUT_SECS)
}

/// Request cap per connection, emitted as `max=` in the `Keep-Alive` response header.
///
/// Taken from `MOCKFORGE_HTTP_KEEPALIVE_MAX_REQUESTS`, falling back to
/// `DEFAULT_MAX_REQUESTS`.
fn keepalive_max_requests() -> u64 {
    env_u64_or("MOCKFORGE_HTTP_KEEPALIVE_MAX_REQUESTS", DEFAULT_MAX_REQUESTS)
}
/// Returns `true` when any `Connection` header in `headers` lists `token`.
///
/// Each header value is treated as a comma-separated token list and compared
/// ASCII case-insensitively. A substring check would also match unrelated tokens.
fn connection_has_token(headers: &http::HeaderMap, token: &str) -> bool {
    headers
        .get_all(http::header::CONNECTION)
        .iter()
        .filter_map(|v| v.to_str().ok())
        .flat_map(|s| s.split(','))
        .any(|t| t.trim().eq_ignore_ascii_case(token))
}

/// Axum middleware that advertises HTTP/1.x persistent-connection parameters.
///
/// After the inner service responds, this middleware:
/// - sets `Connection: keep-alive` if the response has no `Connection` header;
/// - adds `Keep-Alive: timeout=<secs>, max=<n>`, built from
///   `keepalive_timeout_secs()` and `keepalive_max_requests()` on every request.
///
/// The response is returned untouched when:
/// - the request is not HTTP/1.x. Connection-specific headers such as `Connection`
///   and `Keep-Alive` are forbidden in HTTP/2 and HTTP/3.
/// - the request is HTTP/1.0 without `Connection: keep-alive`. Persistence is
///   opt-in for HTTP/1.0 clients.
/// - the client sent `Connection: close`.
/// - the response is `101 Switching Protocols`, or its `Connection` header lists
///   `close` or `upgrade`. Overwriting `Connection: upgrade` would break WebSocket
///   and other protocol-upgrade handshakes.
pub async fn keepalive_hint_middleware(req: Request<Body>, next: Next) -> Response<Body> {
    // Capture what we need from the request before `next` consumes it.
    let version = req.version();
    let client_wants_close = connection_has_token(req.headers(), "close");
    let client_wants_keepalive = connection_has_token(req.headers(), "keep-alive");

    let mut response = next.run(req).await;

    let connection_persists = if version == http::Version::HTTP_11 {
        !client_wants_close
    } else if version == http::Version::HTTP_10 {
        client_wants_keepalive && !client_wants_close
    } else {
        // HTTP/0.9 has no headers; HTTP/2+ forbids connection-specific headers.
        false
    };
    if !connection_persists
        || response.status() == http::StatusCode::SWITCHING_PROTOCOLS
        || connection_has_token(response.headers(), "close")
        || connection_has_token(response.headers(), "upgrade")
    {
        return response;
    }

    let headers = response.headers_mut();
    // Respect a Connection header the handler set explicitly; only fill in the default.
    if !headers.contains_key(http::header::CONNECTION) {
        headers.insert(http::header::CONNECTION, HeaderValue::from_static("keep-alive"));
    }
    let header_value =
        format!("timeout={}, max={}", keepalive_timeout_secs(), keepalive_max_requests());
    // Digits, '=', ',' and spaces are always valid header bytes; the guard is defensive.
    if let Ok(v) = HeaderValue::from_str(&header_value) {
        headers.insert("keep-alive", v);
    }
    response
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{routing::get, Router};
    use tower::ServiceExt;

    /// Layers the keep-alive middleware on top of `routes` and sends a single `GET /`.
    async fn get_root_through_middleware(routes: Router) -> Response<Body> {
        let app = routes.layer(axum::middleware::from_fn(keepalive_hint_middleware));
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        app.oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn middleware_adds_keepalive_headers() {
        let routes = Router::new().route("/", get(|| async { "ok" }));
        let res = get_root_through_middleware(routes).await;

        let connection = res.headers().get(http::header::CONNECTION).unwrap();
        assert_eq!(connection, "keep-alive");

        let keep_alive = res.headers().get("keep-alive").unwrap().to_str().unwrap();
        for param in ["timeout=", "max="] {
            assert!(keep_alive.contains(param));
        }
    }

    #[tokio::test]
    async fn middleware_respects_existing_close_header() {
        // Handler that explicitly asks for the connection to be closed.
        let closing_handler = || async {
            let mut res = Response::new(Body::from("bye"));
            res.headers_mut()
                .insert(http::header::CONNECTION, HeaderValue::from_static("close"));
            res
        };
        let routes = Router::new().route("/", get(closing_handler));
        let res = get_root_through_middleware(routes).await;

        assert_eq!(res.headers().get(http::header::CONNECTION).unwrap(), "close");
        assert!(!res.headers().contains_key("keep-alive"));
    }
}