Skip to main content

mcpr_core/proxy/pipeline/middlewares/
health_track.rs

1//! Response-side middleware: flip the shared proxy-health flag to
2//! "connected" on a successful `initialize` response.
3
4use async_trait::async_trait;
5
6use crate::protocol::mcp::{ClientMethod, LifecycleMethod};
7use crate::proxy::lock_health;
8use crate::proxy::pipeline::middleware::ResponseMiddleware;
9use crate::proxy::pipeline::values::{Context, Response};
10
/// Stateless response middleware that confirms the proxy's MCP connection
/// once an `initialize` request receives a successful MCP response.
///
/// It carries no data, so it is cheap to copy and can be constructed either
/// as the unit value `HealthTrackMiddleware` or via `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct HealthTrackMiddleware;
12
13#[async_trait]
14impl ResponseMiddleware for HealthTrackMiddleware {
15    fn name(&self) -> &'static str {
16        "health_track"
17    }
18
19    async fn on_response(&self, resp: Response, cx: &mut Context) -> Response {
20        let status = match &resp {
21            Response::McpBuffered { status, .. } | Response::McpStreamed { status, .. } => {
22                Some(status.as_u16())
23            }
24            Response::Upstream502 { .. } => None,
25            _ => return resp,
26        };
27        let is_init = matches!(
28            cx.working.request_method,
29            Some(ClientMethod::Lifecycle(LifecycleMethod::Initialize))
30        );
31        if let Some(code) = status
32            && code < 400
33            && is_init
34        {
35            lock_health(&cx.intake.proxy.health).confirm_mcp_connected();
36        }
37        // TODO: record per-request success / failure once the counter
38        // API lands on ProxyHealth.
39        resp
40    }
41}
42
#[cfg(test)]
// Test names follow the `method__scenario` convention, which uses a double
// underscore that the snake_case lint would otherwise flag.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::body::Body;
    use axum::http::{HeaderMap, StatusCode};

    use crate::protocol::mcp::{ClientMethod, ToolsMethod};
    use crate::proxy::lock_health;
    use crate::proxy::pipeline::middlewares::test_support::{
        mcp_buffered_response, set_request_method, test_context, test_proxy_state,
    };
    use crate::proxy::pipeline::values::Envelope;

    /// Reads whether the shared health state currently reports MCP as connected.
    fn mcp_connected(proxy: &crate::proxy::ProxyState) -> bool {
        matches!(
            lock_health(&proxy.health).mcp_status,
            crate::proxy::ConnectionStatus::Connected
        )
    }

    /// Builds an empty-bodied streamed MCP response with the given status.
    fn streamed_response(status: StatusCode) -> Response {
        Response::McpStreamed {
            envelope: Envelope::Json,
            body: Body::empty(),
            status,
            headers: HeaderMap::new(),
        }
    }

    #[test]
    fn name__is_health_track() {
        assert_eq!(HealthTrackMiddleware.name(), "health_track");
    }

    #[tokio::test]
    async fn on_response__init_200_confirms_connected() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(
            &mut cx,
            ClientMethod::Lifecycle(LifecycleMethod::Initialize),
        );
        let resp = mcp_buffered_response(r#"{"jsonrpc":"2.0","id":1,"result":{}}"#, StatusCode::OK);

        HealthTrackMiddleware.on_response(resp, &mut cx).await;
        assert!(mcp_connected(&proxy));
    }

    #[tokio::test]
    async fn on_response__streamed_init_200_confirms_connected() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(
            &mut cx,
            ClientMethod::Lifecycle(LifecycleMethod::Initialize),
        );

        HealthTrackMiddleware
            .on_response(streamed_response(StatusCode::OK), &mut cx)
            .await;
        assert!(mcp_connected(&proxy));
    }

    #[tokio::test]
    async fn on_response__non_init_200_does_not_confirm() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(&mut cx, ClientMethod::Tools(ToolsMethod::List));
        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#,
            StatusCode::OK,
        );

        HealthTrackMiddleware.on_response(resp, &mut cx).await;
        assert!(!mcp_connected(&proxy));
    }

    #[tokio::test]
    async fn on_response__init_4xx_does_not_confirm() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(
            &mut cx,
            ClientMethod::Lifecycle(LifecycleMethod::Initialize),
        );
        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"bad"}}"#,
            StatusCode::BAD_REQUEST,
        );

        HealthTrackMiddleware.on_response(resp, &mut cx).await;
        assert!(!mcp_connected(&proxy));
    }

    #[tokio::test]
    async fn on_response__streamed_200_does_not_confirm_non_init() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(&mut cx, ClientMethod::Tools(ToolsMethod::Call));

        HealthTrackMiddleware
            .on_response(streamed_response(StatusCode::OK), &mut cx)
            .await;
        assert!(!mcp_connected(&proxy));
    }

    #[tokio::test]
    async fn on_response__upstream502_passthrough() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(
            &mut cx,
            ClientMethod::Lifecycle(LifecycleMethod::Initialize),
        );
        let resp = Response::Upstream502 {
            reason: "down".into(),
        };

        let out = HealthTrackMiddleware.on_response(resp, &mut cx).await;
        assert!(matches!(out, Response::Upstream502 { .. }));
        assert!(!mcp_connected(&proxy));
    }

    #[tokio::test]
    async fn on_response__raw_passthrough() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        // Even an `initialize`-tagged request must not confirm health when
        // the response is not an MCP response.
        set_request_method(
            &mut cx,
            ClientMethod::Lifecycle(LifecycleMethod::Initialize),
        );
        let resp = Response::Raw {
            body: Body::empty(),
            status: StatusCode::OK,
            headers: HeaderMap::new(),
        };

        let out = HealthTrackMiddleware.on_response(resp, &mut cx).await;
        assert!(matches!(out, Response::Raw { .. }));
        assert!(!mcp_connected(&proxy));
    }
}