mcpr_core/proxy/pipeline/middlewares/
target_extract.rs1use async_trait::async_trait;
21use serde::Deserialize;
22
23use crate::protocol::mcp::{ClientKind, ClientMethod, PromptsMethod, ResourcesMethod, ToolsMethod};
24use crate::proxy::pipeline::middleware::{Flow, RequestMiddleware};
25use crate::proxy::pipeline::values::{Context, Request};
26
/// Request middleware that records which tool, resource URI, or prompt an
/// MCP client request targets, writing it to the pipeline context's
/// `working` state. The request itself is passed through unchanged.
///
/// It is a stateless unit struct, so it is cheap to copy and has a trivial
/// `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TargetExtractMiddleware;
28
29#[derive(Deserialize)]
30struct NameParams {
31 name: String,
32}
33
34#[derive(Deserialize)]
35struct UriParams {
36 uri: String,
37}
38
39#[async_trait]
40impl RequestMiddleware for TargetExtractMiddleware {
41 fn name(&self) -> &'static str {
42 "target_extract"
43 }
44
45 async fn on_request(&self, req: Request, cx: &mut Context) -> Flow {
46 let Request::Mcp(ref mcp) = req else {
47 return Flow::Continue(req);
48 };
49 let ClientKind::Request(method) = &mcp.kind else {
50 return Flow::Continue(req);
51 };
52
53 match method {
54 ClientMethod::Tools(ToolsMethod::Call) => {
55 if let Some(p) = mcp.envelope.params_as::<NameParams>() {
56 cx.working.request_tool = Some(p.name);
57 }
58 }
59 ClientMethod::Resources(
60 ResourcesMethod::Read | ResourcesMethod::Subscribe | ResourcesMethod::Unsubscribe,
61 ) => {
62 if let Some(p) = mcp.envelope.params_as::<UriParams>() {
63 cx.working.request_resource_uri = Some(p.uri);
64 }
65 }
66 ClientMethod::Prompts(PromptsMethod::Get) => {
67 if let Some(p) = mcp.envelope.params_as::<NameParams>() {
68 cx.working.request_prompt_name = Some(p.name);
69 }
70 }
71 _ => {}
72 }
73
74 Flow::Continue(req)
75 }
76}
77
#[cfg(test)]
// Test names follow a `method__scenario` convention. The double underscore
// would otherwise trip rustc's `non_snake_case` lint.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::body::Body;
    use axum::http::{HeaderMap, Method};
    use serde_json::json;

    use crate::proxy::pipeline::middlewares::test_support::{
        mcp_request, test_context, test_proxy_state,
    };
    use crate::proxy::pipeline::values::RawRequest;

    /// Asserts that none of the target fields on `cx.working` were written.
    fn assert_no_targets(cx: &Context) {
        assert!(cx.working.request_tool.is_none());
        assert!(cx.working.request_resource_uri.is_none());
        assert!(cx.working.request_prompt_name.is_none());
    }

    // ── tools/call ──────────────────────────────────────────────────────

    #[tokio::test]
    async fn on_request__tools_call_stashes_tool_name() {
        let mut cx = test_context(test_proxy_state());
        let req = mcp_request(
            "tools/call",
            json!({"name": "weather", "arguments": {"city": "Paris"}}),
            None,
        );

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;

        // The middleware only observes; the MCP request must flow on.
        assert!(matches!(flow, Flow::Continue(Request::Mcp(_))));
        assert_eq!(cx.working.request_tool.as_deref(), Some("weather"));
    }

    #[tokio::test]
    async fn on_request__tools_call_missing_name_leaves_none() {
        let mut cx = test_context(test_proxy_state());
        let req = mcp_request("tools/call", json!({"arguments": {"city": "Paris"}}), None);

        TargetExtractMiddleware.on_request(req, &mut cx).await;

        assert!(cx.working.request_tool.is_none());
    }

    #[tokio::test]
    async fn on_request__tools_call_non_string_name_leaves_none() {
        let mut cx = test_context(test_proxy_state());
        // A wrong-typed `name` fails deserialization and is ignored, not stored.
        let req = mcp_request("tools/call", json!({"name": 42}), None);

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;

        assert!(matches!(flow, Flow::Continue(Request::Mcp(_))));
        assert!(cx.working.request_tool.is_none());
    }

    #[tokio::test]
    async fn on_request__tools_call_empty_name_stashes_empty_string() {
        let mut cx = test_context(test_proxy_state());
        let req = mcp_request("tools/call", json!({"name": ""}), None);

        TargetExtractMiddleware.on_request(req, &mut cx).await;

        assert_eq!(cx.working.request_tool.as_deref(), Some(""));
    }

    // ── resources/* ─────────────────────────────────────────────────────

    #[tokio::test]
    async fn on_request__resources_read_stashes_uri() {
        let mut cx = test_context(test_proxy_state());
        let req = mcp_request("resources/read", json!({"uri": "file:///etc/hosts"}), None);

        TargetExtractMiddleware.on_request(req, &mut cx).await;

        assert_eq!(
            cx.working.request_resource_uri.as_deref(),
            Some("file:///etc/hosts")
        );
        assert!(cx.working.request_tool.is_none());
        assert!(cx.working.request_prompt_name.is_none());
    }

    #[tokio::test]
    async fn on_request__resources_subscribe_stashes_uri() {
        let mut cx = test_context(test_proxy_state());
        let req = mcp_request("resources/subscribe", json!({"uri": "logs://stream"}), None);

        TargetExtractMiddleware.on_request(req, &mut cx).await;

        assert_eq!(
            cx.working.request_resource_uri.as_deref(),
            Some("logs://stream")
        );
    }

    #[tokio::test]
    async fn on_request__resources_unsubscribe_stashes_uri() {
        let mut cx = test_context(test_proxy_state());
        let req = mcp_request(
            "resources/unsubscribe",
            json!({"uri": "logs://stream"}),
            None,
        );

        TargetExtractMiddleware.on_request(req, &mut cx).await;

        assert_eq!(
            cx.working.request_resource_uri.as_deref(),
            Some("logs://stream")
        );
    }

    #[tokio::test]
    async fn on_request__resources_read_missing_uri_leaves_none() {
        let mut cx = test_context(test_proxy_state());
        let req = mcp_request("resources/read", json!({}), None);

        TargetExtractMiddleware.on_request(req, &mut cx).await;

        assert!(cx.working.request_resource_uri.is_none());
    }

    // ── prompts/get ─────────────────────────────────────────────────────

    #[tokio::test]
    async fn on_request__prompts_get_stashes_prompt_name() {
        let mut cx = test_context(test_proxy_state());
        let req = mcp_request(
            "prompts/get",
            json!({"name": "code_review", "arguments": {}}),
            None,
        );

        TargetExtractMiddleware.on_request(req, &mut cx).await;

        assert_eq!(
            cx.working.request_prompt_name.as_deref(),
            Some("code_review")
        );
        assert!(cx.working.request_tool.is_none());
        assert!(cx.working.request_resource_uri.is_none());
    }

    #[tokio::test]
    async fn on_request__prompts_get_missing_name_leaves_none() {
        let mut cx = test_context(test_proxy_state());
        let req = mcp_request("prompts/get", json!({"arguments": {}}), None);

        TargetExtractMiddleware.on_request(req, &mut cx).await;

        assert_no_targets(&cx);
    }

    // ── methods without a target ────────────────────────────────────────

    #[tokio::test]
    async fn on_request__tools_list_is_noop() {
        let mut cx = test_context(test_proxy_state());
        let req = mcp_request("tools/list", serde_json::Value::Null, None);

        TargetExtractMiddleware.on_request(req, &mut cx).await;

        assert_no_targets(&cx);
    }

    #[tokio::test]
    async fn on_request__resources_list_is_noop() {
        let mut cx = test_context(test_proxy_state());
        let req = mcp_request("resources/list", serde_json::Value::Null, None);

        TargetExtractMiddleware.on_request(req, &mut cx).await;

        assert!(cx.working.request_resource_uri.is_none());
    }

    #[tokio::test]
    async fn on_request__initialize_is_noop() {
        let mut cx = test_context(test_proxy_state());
        let req = mcp_request("initialize", json!({"protocolVersion": "2025-11-25"}), None);

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;

        assert!(matches!(flow, Flow::Continue(Request::Mcp(_))));
        assert_no_targets(&cx);
    }

    // ── non-MCP traffic ─────────────────────────────────────────────────

    #[tokio::test]
    async fn on_request__non_mcp_passthrough() {
        let mut cx = test_context(test_proxy_state());
        let req = Request::Raw(RawRequest {
            method: Method::GET,
            path: "/health".into(),
            body: Body::empty(),
            headers: HeaderMap::new(),
        });

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;

        // A raw request must come back out as the same raw request.
        assert!(matches!(flow, Flow::Continue(Request::Raw(_))));
        assert_no_targets(&cx);
    }
}