Skip to main content

mockforge_http/middleware/
keepalive_hint.rs

//! Connection: keep-alive hint middleware.
//!
//! Adds explicit `Connection: keep-alive` and `Keep-Alive: timeout=N, max=M`
//! response headers when enabled via the `MOCKFORGE_HTTP_KEEPALIVE_HINT`
//! environment variable (or the `--http-keepalive-hint` CLI flag, when wired).
//!
//! This is a workaround for proxies that:
//! - Speak HTTP/1.0 upstream by default (hyper closes the connection after one
//!   response unless the request carried `Connection: keep-alive`).
//! - Cache the keep-alive policy from the response headers rather than the
//!   HTTP version. F5/Avi/HAProxy in some configurations look at the
//!   `Keep-Alive` response header to decide whether to pool the upstream socket.
//!
//! Issue #79 — Srikanth's round-3 reply: the proxy observed a FIN from MockForge
//! after every 200 response, then an RST when it reused the socket. The root
//! cause is that upstream HTTP/1.1 is not being negotiated. We can't force hyper
//! to keep the connection alive after an HTTP/1.0 request, but we can advertise
//! our preferred policy in the response so that proxies which read it adjust.
19
20use axum::{
21    body::Body,
22    http::{HeaderValue, Request},
23    middleware::Next,
24    response::Response,
25};
26
/// Default idle timeout advertised by the `Keep-Alive` header, in seconds.
///
/// The value is only advertised; nothing in this module enforces an idle
/// timeout on the socket. NOTE(review): confirm it does not exceed any
/// server-side idle timeout configured elsewhere. Advertising a longer timeout
/// than the server honours invites proxies to reuse a socket that has already
/// been closed.
const DEFAULT_TIMEOUT_SECS: u64 = 120;

/// Default max requests per connection advertised by the `Keep-Alive` header.
const DEFAULT_MAX_REQUESTS: u64 = 1000;

/// Values of `MOCKFORGE_HTTP_KEEPALIVE_HINT` that enable the hint. They are
/// matched case-insensitively after trimming surrounding whitespace.
const TRUTHY_VALUES: [&str; 4] = ["1", "true", "yes", "on"];

/// Returns whether the keep-alive hint is enabled via
/// `MOCKFORGE_HTTP_KEEPALIVE_HINT`.
///
/// The variable is read on every call; nothing is cached here. Callers that
/// need a stable answer should evaluate this once at startup.
///
/// Truthy values are `1`, `true`, `yes` and `on`, case-insensitive, with
/// surrounding whitespace ignored. An unset or non-Unicode variable counts as
/// disabled.
pub fn is_keepalive_hint_enabled() -> bool {
    std::env::var("MOCKFORGE_HTTP_KEEPALIVE_HINT")
        .map(|raw| parse_truthy_flag(&raw))
        .unwrap_or(false)
}

/// Reads the advertised timeout in seconds from
/// `MOCKFORGE_HTTP_KEEPALIVE_TIMEOUT_SECS`. Falls back to
/// [`DEFAULT_TIMEOUT_SECS`] when the variable is unset or is not a
/// non-negative integer.
fn keepalive_timeout_secs() -> u64 {
    parse_u64_or_default(
        std::env::var("MOCKFORGE_HTTP_KEEPALIVE_TIMEOUT_SECS").ok().as_deref(),
        DEFAULT_TIMEOUT_SECS,
    )
}

/// Reads the advertised max requests per connection from
/// `MOCKFORGE_HTTP_KEEPALIVE_MAX_REQUESTS`. Falls back to
/// [`DEFAULT_MAX_REQUESTS`] when the variable is unset or is not a
/// non-negative integer.
fn keepalive_max_requests() -> u64 {
    parse_u64_or_default(
        std::env::var("MOCKFORGE_HTTP_KEEPALIVE_MAX_REQUESTS").ok().as_deref(),
        DEFAULT_MAX_REQUESTS,
    )
}

/// Returns `true` when `raw`, once trimmed, matches one of [`TRUTHY_VALUES`],
/// ignoring ASCII case.
fn parse_truthy_flag(raw: &str) -> bool {
    let raw = raw.trim();
    TRUTHY_VALUES
        .iter()
        .any(|truthy| raw.eq_ignore_ascii_case(truthy))
}

/// Parses a `u64` from `raw` after trimming whitespace. Returns `default`
/// when `raw` is `None` or does not parse.
fn parse_u64_or_default(raw: Option<&str>, default: u64) -> u64 {
    raw.and_then(|value| value.trim().parse::<u64>().ok())
        .unwrap_or(default)
}
58
59/// Middleware: stamp `Connection: keep-alive` and `Keep-Alive: timeout=…,
60/// max=…` on every response. Does NOT override an upstream-set
61/// `Connection: close` header.
62pub async fn keepalive_hint_middleware(req: Request<Body>, next: Next) -> Response<Body> {
63    let mut response = next.run(req).await;
64
65    // Don't undo an explicit close; if downstream code already decided to
66    // close, leave it alone.
67    let already_close = response
68        .headers()
69        .get(http::header::CONNECTION)
70        .and_then(|v| v.to_str().ok())
71        .map(|s| s.to_ascii_lowercase().contains("close"))
72        .unwrap_or(false);
73    if already_close {
74        return response;
75    }
76
77    response
78        .headers_mut()
79        .insert(http::header::CONNECTION, HeaderValue::from_static("keep-alive"));
80
81    let header_value =
82        format!("timeout={}, max={}", keepalive_timeout_secs(), keepalive_max_requests());
83    if let Ok(v) = HeaderValue::from_str(&header_value) {
84        // `Keep-Alive` (case-sensitive in HeaderName) is a hop-by-hop header
85        // some intermediaries strip, but the ones we care about (F5, Avi,
86        // nginx) preserve it for their pool decisions.
87        response.headers_mut().insert("keep-alive", v);
88    }
89
90    response
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{routing::get, Router};
    use tower::ServiceExt;

    /// Layers the keep-alive hint middleware onto `router` and sends it a
    /// single `GET /` request, returning the response.
    async fn send_root(router: Router) -> Response<Body> {
        let request = Request::builder()
            .uri("/")
            .body(Body::empty())
            .expect("static request parts are valid");
        router
            .layer(axum::middleware::from_fn(keepalive_hint_middleware))
            .oneshot(request)
            .await
            .expect("router service error type is infallible")
    }

    #[tokio::test]
    async fn middleware_adds_keepalive_headers() {
        let response = send_root(Router::new().route("/", get(|| async { "ok" }))).await;
        let headers = response.headers();

        assert_eq!(headers.get(http::header::CONNECTION).unwrap(), "keep-alive");
        let advertised = headers
            .get("keep-alive")
            .and_then(|value| value.to_str().ok())
            .expect("Keep-Alive header should be present and ASCII");
        assert!(advertised.contains("timeout="));
        assert!(advertised.contains("max="));
    }

    #[tokio::test]
    async fn middleware_respects_existing_close_header() {
        let closing_handler = || async {
            let mut reply = Response::new(Body::from("bye"));
            reply
                .headers_mut()
                .insert(http::header::CONNECTION, HeaderValue::from_static("close"));
            reply
        };
        let response = send_root(Router::new().route("/", get(closing_handler))).await;
        let headers = response.headers();

        assert_eq!(headers.get(http::header::CONNECTION).unwrap(), "close");
        assert!(headers.get("keep-alive").is_none());
    }
}