mcpr_core/proxy/pipeline/middlewares/
schema_stale.rs1use async_trait::async_trait;
11
12use crate::protocol::mcp::{MessageKind, ServerKind, ServerNotifMethod};
13use crate::proxy::pipeline::middleware::ResponseMiddleware;
14use crate::proxy::pipeline::values::{Context, Response};
15
16pub struct SchemaStaleMiddleware;
17
18#[async_trait]
19impl ResponseMiddleware for SchemaStaleMiddleware {
20 fn name(&self) -> &'static str {
21 "schema_stale"
22 }
23
24 async fn on_response(&self, resp: Response, cx: &mut Context) -> Response {
25 let message = match &resp {
26 Response::McpBuffered { message, .. } => message,
27 _ => return resp,
28 };
29 let MessageKind::Server(ServerKind::Notification(n)) = &message.kind else {
30 return resp;
31 };
32 let method = match n {
33 ServerNotifMethod::ToolsListChanged => "tools/list",
34 ServerNotifMethod::ResourcesListChanged => "resources/list",
35 ServerNotifMethod::PromptsListChanged => "prompts/list",
36 _ => return resp,
37 };
38 cx.intake.proxy.schema_manager.mark_stale(method);
39 resp
40 }
41}
42
#[cfg(test)]
// Test names use a `subject__scenario` double-underscore convention.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::http::StatusCode;

    use crate::proxy::pipeline::middlewares::test_support::{
        mcp_buffered_response, test_context, test_proxy_state,
    };

    /// Builds a buffered MCP response with status 200 from a raw JSON-RPC body.
    fn ok_response(body: &'static str) -> Response {
        mcp_buffered_response(body, StatusCode::OK)
    }

    /// Runs the middleware over `resp` against fresh proxy state and reports
    /// whether `method` is stale afterwards.
    async fn stale_after(resp: Response, method: &'static str) -> bool {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());

        SchemaStaleMiddleware.on_response(resp, &mut cx).await;
        proxy.schema_manager.is_stale(method)
    }

    #[tokio::test]
    async fn on_response__tools_list_changed_marks_stale() {
        let resp = ok_response(r#"{"jsonrpc":"2.0","method":"notifications/tools/list_changed"}"#);

        assert!(stale_after(resp, "tools/list").await);
    }

    #[tokio::test]
    async fn on_response__resources_list_changed_marks_stale() {
        let resp =
            ok_response(r#"{"jsonrpc":"2.0","method":"notifications/resources/list_changed"}"#);

        assert!(stale_after(resp, "resources/list").await);
    }

    #[tokio::test]
    async fn on_response__prompts_list_changed_marks_stale() {
        let resp =
            ok_response(r#"{"jsonrpc":"2.0","method":"notifications/prompts/list_changed"}"#);

        assert!(stale_after(resp, "prompts/list").await);
    }

    #[tokio::test]
    async fn on_response__unrelated_notification_passthrough() {
        let resp = ok_response(
            r#"{"jsonrpc":"2.0","method":"notifications/progress","params":{"progressToken":"x","progress":1}}"#,
        );

        assert!(!stale_after(resp, "tools/list").await);
    }

    #[tokio::test]
    async fn on_response__result_passthrough() {
        let resp = ok_response(r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#);

        assert!(!stale_after(resp, "tools/list").await);
    }

    #[tokio::test]
    async fn on_response__non_buffered_passthrough() {
        let resp = Response::Upstream502 {
            reason: "boom".into(),
        };

        assert!(!stale_after(resp, "tools/list").await);
    }
}