mcpr_core/proxy/pipeline/middlewares/
health_track.rs1use async_trait::async_trait;
5
6use crate::protocol::mcp::{ClientMethod, LifecycleMethod};
7use crate::proxy::lock_health;
8use crate::proxy::pipeline::middleware::ResponseMiddleware;
9use crate::proxy::pipeline::values::{Context, Response};
10
11pub struct HealthTrackMiddleware;
12
13#[async_trait]
14impl ResponseMiddleware for HealthTrackMiddleware {
15 fn name(&self) -> &'static str {
16 "health_track"
17 }
18
19 async fn on_response(&self, resp: Response, cx: &mut Context) -> Response {
20 let status = match &resp {
21 Response::McpBuffered { status, .. } | Response::McpStreamed { status, .. } => {
22 Some(status.as_u16())
23 }
24 Response::Upstream502 { .. } => None,
25 _ => return resp,
26 };
27 let is_init = matches!(
28 cx.working.request_method,
29 Some(ClientMethod::Lifecycle(LifecycleMethod::Initialize))
30 );
31 if let Some(code) = status
32 && code < 400
33 && is_init
34 {
35 lock_health(&cx.intake.proxy.health).confirm_mcp_connected();
36 }
37 resp
40 }
41}
42
#[cfg(test)]
// Test names follow a `method__scenario` convention; the double underscore
// trips rustc's `non_snake_case` lint, so it is silenced for this module only.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::body::Body;
    use axum::http::{HeaderMap, StatusCode};

    use crate::protocol::mcp::{ClientMethod, ToolsMethod};
    use crate::proxy::lock_health;
    use crate::proxy::pipeline::middlewares::test_support::{
        mcp_buffered_response, set_request_method, test_context, test_proxy_state,
    };
    use crate::proxy::pipeline::values::Envelope;

    /// Returns `true` when the proxy's health record reports the MCP upstream
    /// as `Connected`.
    fn mcp_connected(proxy: &crate::proxy::ProxyState) -> bool {
        matches!(
            lock_health(&proxy.health).mcp_status,
            crate::proxy::ConnectionStatus::Connected
        )
    }

    /// Builds an empty, JSON-enveloped streamed MCP response with `status`.
    fn mcp_streamed_response(status: StatusCode) -> Response {
        Response::McpStreamed {
            envelope: Envelope::Json,
            body: Body::empty(),
            status,
            headers: HeaderMap::new(),
        }
    }

    #[tokio::test]
    async fn on_response__init_200_confirms_connected() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(
            &mut cx,
            ClientMethod::Lifecycle(LifecycleMethod::Initialize),
        );
        let resp = mcp_buffered_response(r#"{"jsonrpc":"2.0","id":1,"result":{}}"#, StatusCode::OK);

        HealthTrackMiddleware.on_response(resp, &mut cx).await;
        assert!(mcp_connected(&proxy));
    }

    #[tokio::test]
    async fn on_response__non_init_200_does_not_confirm() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(&mut cx, ClientMethod::Tools(ToolsMethod::List));
        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#,
            StatusCode::OK,
        );

        HealthTrackMiddleware.on_response(resp, &mut cx).await;
        assert!(!mcp_connected(&proxy));
    }

    #[tokio::test]
    async fn on_response__init_4xx_does_not_confirm() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(
            &mut cx,
            ClientMethod::Lifecycle(LifecycleMethod::Initialize),
        );
        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"bad"}}"#,
            StatusCode::BAD_REQUEST,
        );

        HealthTrackMiddleware.on_response(resp, &mut cx).await;
        assert!(!mcp_connected(&proxy));
    }

    #[tokio::test]
    async fn on_response__init_5xx_does_not_confirm() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(
            &mut cx,
            ClientMethod::Lifecycle(LifecycleMethod::Initialize),
        );
        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32603,"message":"boom"}}"#,
            StatusCode::INTERNAL_SERVER_ERROR,
        );

        HealthTrackMiddleware.on_response(resp, &mut cx).await;
        assert!(!mcp_connected(&proxy));
    }

    #[tokio::test]
    async fn on_response__streamed_init_200_confirms_connected() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(
            &mut cx,
            ClientMethod::Lifecycle(LifecycleMethod::Initialize),
        );

        HealthTrackMiddleware
            .on_response(mcp_streamed_response(StatusCode::OK), &mut cx)
            .await;
        assert!(mcp_connected(&proxy));
    }

    #[tokio::test]
    async fn on_response__streamed_200_does_not_confirm_non_init() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(&mut cx, ClientMethod::Tools(ToolsMethod::Call));

        HealthTrackMiddleware
            .on_response(mcp_streamed_response(StatusCode::OK), &mut cx)
            .await;
        assert!(!mcp_connected(&proxy));
    }

    #[tokio::test]
    async fn on_response__upstream502_passthrough() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(
            &mut cx,
            ClientMethod::Lifecycle(LifecycleMethod::Initialize),
        );
        let resp = Response::Upstream502 {
            reason: "down".into(),
        };

        let out = HealthTrackMiddleware.on_response(resp, &mut cx).await;
        assert!(matches!(out, Response::Upstream502 { .. }));
        assert!(!mcp_connected(&proxy));
    }

    #[tokio::test]
    async fn on_response__raw_passthrough() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        // Use `initialize` so this test actually proves a `Raw` 200 is ignored
        // for health purposes, rather than being skipped for a non-init method.
        set_request_method(
            &mut cx,
            ClientMethod::Lifecycle(LifecycleMethod::Initialize),
        );
        let resp = Response::Raw {
            body: Body::empty(),
            status: StatusCode::OK,
            headers: HeaderMap::new(),
        };

        let out = HealthTrackMiddleware.on_response(resp, &mut cx).await;
        assert!(matches!(out, Response::Raw { .. }));
        assert!(!mcp_connected(&proxy));
    }
}