Skip to main content

mcpr_core/proxy/pipeline/middlewares/
schema_stale.rs

//! Response-side middleware: marks the cached schema stale when the server
//! sends a `list_changed` notification. Pattern-matches on
//! `ServerNotifMethod` directly — no JSON tree walk.
//!
//! Known limitation: only `Response::McpBuffered` responses are inspected.
//! `list_changed` notifications that arrive mid-stream inside
//! `McpStreamed` bodies go unobserved until server-push observability
//! lands.
9
10use async_trait::async_trait;
11
12use crate::protocol::mcp::{MessageKind, ServerKind, ServerNotifMethod};
13use crate::proxy::pipeline::middleware::ResponseMiddleware;
14use crate::proxy::pipeline::values::{Context, Response};
15
/// Response middleware that marks the `tools/list`, `resources/list`, or
/// `prompts/list` schema stale when a buffered upstream response carries the
/// matching `list_changed` server notification.
///
/// The middleware itself holds no state (it is a unit struct); the staleness
/// flags it sets live in the proxy's schema manager, reached through the
/// per-request [`Context`].
#[derive(Debug, Default, Clone, Copy)]
pub struct SchemaStaleMiddleware;
17
18#[async_trait]
19impl ResponseMiddleware for SchemaStaleMiddleware {
20    fn name(&self) -> &'static str {
21        "schema_stale"
22    }
23
24    async fn on_response(&self, resp: Response, cx: &mut Context) -> Response {
25        let message = match &resp {
26            Response::McpBuffered { message, .. } => message,
27            _ => return resp,
28        };
29        let MessageKind::Server(ServerKind::Notification(n)) = &message.kind else {
30            return resp;
31        };
32        let method = match n {
33            ServerNotifMethod::ToolsListChanged => "tools/list",
34            ServerNotifMethod::ResourcesListChanged => "resources/list",
35            ServerNotifMethod::PromptsListChanged => "prompts/list",
36            _ => return resp,
37        };
38        cx.intake.proxy.schema_manager.mark_stale(method);
39        resp
40    }
41}
42
#[cfg(test)]
// Test names follow a `unit__scenario` convention; rustc's `non_snake_case`
// lint rejects consecutive underscores, so it is silenced for this module.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::http::StatusCode;

    use crate::proxy::pipeline::middlewares::test_support::{
        mcp_buffered_response, test_context, test_proxy_state,
    };

    /// Every list method this middleware is able to mark stale.
    const LIST_METHODS: [&str; 3] = ["tools/list", "resources/list", "prompts/list"];

    /// Runs the middleware once over a fresh proxy with a buffered 200 OK
    /// response built from `body`. Asserts the response is passed through,
    /// then returns the list methods that ended up stale (in `LIST_METHODS`
    /// order) so callers can pin the exact set.
    async fn stale_after(body: &'static str) -> Vec<&'static str> {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        let resp = mcp_buffered_response(body, StatusCode::OK);

        let out = SchemaStaleMiddleware.on_response(resp, &mut cx).await;
        assert!(
            matches!(out, Response::McpBuffered { .. }),
            "buffered response was not passed through"
        );

        LIST_METHODS
            .iter()
            .copied()
            .filter(|&m| proxy.schema_manager.is_stale(m))
            .collect()
    }

    #[tokio::test]
    async fn on_response__tools_list_changed_marks_stale() {
        let stale =
            stale_after(r#"{"jsonrpc":"2.0","method":"notifications/tools/list_changed"}"#).await;
        assert_eq!(stale, vec!["tools/list"]);
    }

    #[tokio::test]
    async fn on_response__resources_list_changed_marks_stale() {
        let stale =
            stale_after(r#"{"jsonrpc":"2.0","method":"notifications/resources/list_changed"}"#)
                .await;
        assert_eq!(stale, vec!["resources/list"]);
    }

    #[tokio::test]
    async fn on_response__prompts_list_changed_marks_stale() {
        let stale =
            stale_after(r#"{"jsonrpc":"2.0","method":"notifications/prompts/list_changed"}"#)
                .await;
        assert_eq!(stale, vec!["prompts/list"]);
    }

    #[tokio::test]
    async fn on_response__unrelated_notification_passthrough() {
        let stale = stale_after(
            r#"{"jsonrpc":"2.0","method":"notifications/progress","params":{"progressToken":"x","progress":1}}"#,
        )
        .await;
        assert!(stale.is_empty(), "unexpectedly stale: {stale:?}");
    }

    #[tokio::test]
    async fn on_response__result_passthrough() {
        let stale = stale_after(r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#).await;
        assert!(stale.is_empty(), "unexpectedly stale: {stale:?}");
    }

    #[tokio::test]
    async fn on_response__non_buffered_passthrough() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        let resp = Response::Upstream502 {
            reason: "boom".into(),
        };

        let out = SchemaStaleMiddleware.on_response(resp, &mut cx).await;
        assert!(
            matches!(out, Response::Upstream502 { .. }),
            "non-buffered response was not passed through"
        );
        for method in LIST_METHODS {
            assert!(
                !proxy.schema_manager.is_stale(method),
                "{method} unexpectedly stale"
            );
        }
    }
}
130}