mockforge_http/middleware/
keepalive_hint.rs1use axum::{
21 body::Body,
22 http::{HeaderValue, Request},
23 middleware::Next,
24 response::Response,
25};
26
/// Default `timeout` (seconds) advertised in the `Keep-Alive` response header.
const DEFAULT_TIMEOUT_SECS: u64 = 120;

/// Default `max` (requests per connection) advertised in the `Keep-Alive` response header.
const DEFAULT_MAX_REQUESTS: u64 = 1000;

/// Returns `true` when `MOCKFORGE_HTTP_KEEPALIVE_HINT` is set to a truthy value.
///
/// Truthy values are `1`, `true`, `yes` and `on`, compared ASCII case-insensitively after
/// trimming surrounding whitespace. An unset, non-UTF-8, empty or any other value yields
/// `false`.
pub fn is_keepalive_hint_enabled() -> bool {
    parse_flag(std::env::var("MOCKFORGE_HTTP_KEEPALIVE_HINT").ok().as_deref())
}

/// Keep-alive idle timeout in seconds, from `MOCKFORGE_HTTP_KEEPALIVE_TIMEOUT_SECS`.
///
/// Falls back to [`DEFAULT_TIMEOUT_SECS`] when the variable is unset or not a valid `u64`.
fn keepalive_timeout_secs() -> u64 {
    env_u64("MOCKFORGE_HTTP_KEEPALIVE_TIMEOUT_SECS", DEFAULT_TIMEOUT_SECS)
}

/// Maximum requests per connection, from `MOCKFORGE_HTTP_KEEPALIVE_MAX_REQUESTS`.
///
/// Falls back to [`DEFAULT_MAX_REQUESTS`] when the variable is unset or not a valid `u64`.
fn keepalive_max_requests() -> u64 {
    env_u64("MOCKFORGE_HTTP_KEEPALIVE_MAX_REQUESTS", DEFAULT_MAX_REQUESTS)
}

/// Reads environment variable `name` as a `u64`, returning `default` when it is unset,
/// not valid UTF-8, or does not parse.
fn env_u64(name: &str, default: u64) -> u64 {
    parse_u64_or(std::env::var(name).ok().as_deref(), default)
}

/// Interprets an optional raw setting as a boolean flag (see [`is_keepalive_hint_enabled`]).
fn parse_flag(raw: Option<&str>) -> bool {
    let Some(value) = raw else {
        return false;
    };
    // Trim so values copied from .env / compose files with stray whitespace still count.
    let value = value.trim();
    ["1", "true", "yes", "on"]
        .iter()
        .any(|truthy| value.eq_ignore_ascii_case(truthy))
}

/// Parses an optional raw setting as a `u64` (whitespace-trimmed), or returns `default`.
fn parse_u64_or(raw: Option<&str>, default: u64) -> u64 {
    raw.and_then(|v| v.trim().parse::<u64>().ok())
        .unwrap_or(default)
}
58
59pub async fn keepalive_hint_middleware(req: Request<Body>, next: Next) -> Response<Body> {
63 let mut response = next.run(req).await;
64
65 let already_close = response
68 .headers()
69 .get(http::header::CONNECTION)
70 .and_then(|v| v.to_str().ok())
71 .map(|s| s.to_ascii_lowercase().contains("close"))
72 .unwrap_or(false);
73 if already_close {
74 return response;
75 }
76
77 response
78 .headers_mut()
79 .insert(http::header::CONNECTION, HeaderValue::from_static("keep-alive"));
80
81 let header_value =
82 format!("timeout={}, max={}", keepalive_timeout_secs(), keepalive_max_requests());
83 if let Ok(v) = HeaderValue::from_str(&header_value) {
84 response.headers_mut().insert("keep-alive", v);
88 }
89
90 response
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{routing::get, Router};
    use tower::ServiceExt;

    /// Wraps `router` in the keep-alive hint middleware and issues a single `GET /`.
    async fn get_root_through_middleware(router: Router) -> Response<Body> {
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        router
            .layer(axum::middleware::from_fn(keepalive_hint_middleware))
            .oneshot(request)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn middleware_adds_keepalive_headers() {
        let router = Router::new().route("/", get(|| async { "ok" }));
        let res = get_root_through_middleware(router).await;

        let connection = res.headers().get(http::header::CONNECTION).unwrap();
        assert_eq!(connection, "keep-alive");

        let keep_alive = res.headers().get("keep-alive").unwrap().to_str().unwrap();
        assert!(keep_alive.contains("timeout="));
        assert!(keep_alive.contains("max="));
    }

    #[tokio::test]
    async fn middleware_respects_existing_close_header() {
        // Handler that explicitly asks for the connection to be closed.
        let closing_handler = || async {
            let mut res = Response::new(Body::from("bye"));
            res.headers_mut()
                .insert(http::header::CONNECTION, HeaderValue::from_static("close"));
            res
        };
        let router = Router::new().route("/", get(closing_handler));
        let res = get_root_through_middleware(router).await;

        let connection = res.headers().get(http::header::CONNECTION).unwrap();
        assert_eq!(connection, "close");
        assert!(res.headers().get("keep-alive").is_none());
    }
}