Skip to main content

mcpr_core/proxy/pipeline/middlewares/
target_extract.rs

1//! Request-side middleware: extract the per-method target identifier
2//! from JSON-RPC `params` and stash it on `Working`.
3//!
4//! Each MCP request kind that operates on a named entity carries an
5//! identifier in `params`. Capturing them here lets `emit` populate
6//! `RequestEvent` so downstream consumers (SQLite log, cloud dashboard)
7//! can group, filter, and surface per-target metrics:
8//!
9//! | Method                  | Identifier in `params` | Stashed on `Working` as |
10//! |-------------------------|------------------------|--------------------------|
11//! | `tools/call`            | `name`                 | `request_tool`           |
12//! | `resources/read`        | `uri`                  | `request_resource_uri`   |
13//! | `resources/subscribe`   | `uri`                  | `request_resource_uri`   |
14//! | `resources/unsubscribe` | `uri`                  | `request_resource_uri`   |
15//! | `prompts/get`           | `name`                 | `request_prompt_name`    |
16//!
17//! Methods without a useful identifier (`tools/list`, `resources/list`,
18//! `initialize`, …) are no-ops.
19
20use async_trait::async_trait;
21use serde::Deserialize;
22
23use crate::protocol::mcp::{ClientKind, ClientMethod, PromptsMethod, ResourcesMethod, ToolsMethod};
24use crate::proxy::pipeline::middleware::{Flow, RequestMiddleware};
25use crate::proxy::pipeline::values::{Context, Request};
26
/// Request-side middleware that records the target identifier of an MCP
/// request (tool name, resource URI, or prompt name) on `Working`.
///
/// The type carries no fields, so one value can be reused for every request.
#[derive(Debug, Clone, Copy, Default)]
pub struct TargetExtractMiddleware;
28
29#[derive(Deserialize)]
30struct NameParams {
31    name: String,
32}
33
34#[derive(Deserialize)]
35struct UriParams {
36    uri: String,
37}
38
39#[async_trait]
40impl RequestMiddleware for TargetExtractMiddleware {
41    fn name(&self) -> &'static str {
42        "target_extract"
43    }
44
45    async fn on_request(&self, req: Request, cx: &mut Context) -> Flow {
46        let Request::Mcp(ref mcp) = req else {
47            return Flow::Continue(req);
48        };
49        let ClientKind::Request(method) = &mcp.kind else {
50            return Flow::Continue(req);
51        };
52
53        match method {
54            ClientMethod::Tools(ToolsMethod::Call) => {
55                if let Some(p) = mcp.envelope.params_as::<NameParams>() {
56                    cx.working.request_tool = Some(p.name);
57                }
58            }
59            ClientMethod::Resources(
60                ResourcesMethod::Read | ResourcesMethod::Subscribe | ResourcesMethod::Unsubscribe,
61            ) => {
62                if let Some(p) = mcp.envelope.params_as::<UriParams>() {
63                    cx.working.request_resource_uri = Some(p.uri);
64                }
65            }
66            ClientMethod::Prompts(PromptsMethod::Get) => {
67                if let Some(p) = mcp.envelope.params_as::<NameParams>() {
68                    cx.working.request_prompt_name = Some(p.name);
69                }
70            }
71            _ => {}
72        }
73
74        Flow::Continue(req)
75    }
76}
77
#[cfg(test)]
// Test names follow a `function__scenario` convention; the double underscore
// trips the `non_snake_case` lint, so it is silenced for this module only.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::body::Body;
    use axum::http::{HeaderMap, Method};
    use serde_json::json;

    use crate::proxy::pipeline::middlewares::test_support::{
        mcp_request, test_context, test_proxy_state,
    };
    use crate::proxy::pipeline::values::RawRequest;

    /// The middleware must never short-circuit: it always continues with the
    /// same MCP request it was given.
    fn assert_continues_mcp(flow: Flow) {
        assert!(
            matches!(flow, Flow::Continue(Request::Mcp(_))),
            "target_extract must continue with the MCP request"
        );
    }

    // ── tools/call → request_tool ──────────────────────────────

    #[tokio::test]
    async fn on_request__tools_call_stashes_tool_name() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "tools/call",
            json!({"name": "weather", "arguments": {"city": "Paris"}}),
            None,
        );

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;
        assert_continues_mcp(flow);
        assert_eq!(cx.working.request_tool.as_deref(), Some("weather"));
        assert!(cx.working.request_resource_uri.is_none());
        assert!(cx.working.request_prompt_name.is_none());
    }

    #[tokio::test]
    async fn on_request__tools_call_missing_name_leaves_none() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = mcp_request("tools/call", json!({"arguments": {"city": "Paris"}}), None);

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;
        assert_continues_mcp(flow);
        assert!(cx.working.request_tool.is_none());
    }

    #[tokio::test]
    async fn on_request__tools_call_non_string_name_leaves_none() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = mcp_request("tools/call", json!({"name": 42}), None);

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;
        assert_continues_mcp(flow);
        assert!(cx.working.request_tool.is_none());
    }

    #[tokio::test]
    async fn on_request__tools_call_empty_name_stashes_empty_string() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = mcp_request("tools/call", json!({"name": ""}), None);

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;
        assert_continues_mcp(flow);
        assert_eq!(cx.working.request_tool.as_deref(), Some(""));
    }

    // ── resources/* → request_resource_uri ─────────────────────

    #[tokio::test]
    async fn on_request__resources_read_stashes_uri() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = mcp_request("resources/read", json!({"uri": "file:///etc/hosts"}), None);

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;
        assert_continues_mcp(flow);
        assert_eq!(
            cx.working.request_resource_uri.as_deref(),
            Some("file:///etc/hosts")
        );
        assert!(cx.working.request_tool.is_none());
        assert!(cx.working.request_prompt_name.is_none());
    }

    #[tokio::test]
    async fn on_request__resources_subscribe_stashes_uri() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = mcp_request("resources/subscribe", json!({"uri": "logs://stream"}), None);

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;
        assert_continues_mcp(flow);
        assert_eq!(
            cx.working.request_resource_uri.as_deref(),
            Some("logs://stream")
        );
        assert!(cx.working.request_tool.is_none());
        assert!(cx.working.request_prompt_name.is_none());
    }

    #[tokio::test]
    async fn on_request__resources_unsubscribe_stashes_uri() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "resources/unsubscribe",
            json!({"uri": "logs://stream"}),
            None,
        );

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;
        assert_continues_mcp(flow);
        assert_eq!(
            cx.working.request_resource_uri.as_deref(),
            Some("logs://stream")
        );
        assert!(cx.working.request_tool.is_none());
        assert!(cx.working.request_prompt_name.is_none());
    }

    #[tokio::test]
    async fn on_request__resources_read_missing_uri_leaves_none() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = mcp_request("resources/read", json!({}), None);

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;
        assert_continues_mcp(flow);
        assert!(cx.working.request_resource_uri.is_none());
    }

    // ── prompts/get → request_prompt_name ──────────────────────

    #[tokio::test]
    async fn on_request__prompts_get_stashes_prompt_name() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "prompts/get",
            json!({"name": "code_review", "arguments": {}}),
            None,
        );

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;
        assert_continues_mcp(flow);
        assert_eq!(
            cx.working.request_prompt_name.as_deref(),
            Some("code_review")
        );
        assert!(cx.working.request_tool.is_none());
        assert!(cx.working.request_resource_uri.is_none());
    }

    #[tokio::test]
    async fn on_request__prompts_get_missing_name_leaves_none() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = mcp_request("prompts/get", json!({"arguments": {}}), None);

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;
        assert_continues_mcp(flow);
        assert!(cx.working.request_prompt_name.is_none());
    }

    // ── methods without identifiers ────────────────────────────

    #[tokio::test]
    async fn on_request__tools_list_is_noop() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = mcp_request("tools/list", serde_json::Value::Null, None);

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;
        assert_continues_mcp(flow);
        assert!(cx.working.request_tool.is_none());
        assert!(cx.working.request_resource_uri.is_none());
        assert!(cx.working.request_prompt_name.is_none());
    }

    #[tokio::test]
    async fn on_request__resources_list_is_noop() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = mcp_request("resources/list", serde_json::Value::Null, None);

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;
        assert_continues_mcp(flow);
        assert!(cx.working.request_tool.is_none());
        assert!(cx.working.request_resource_uri.is_none());
        assert!(cx.working.request_prompt_name.is_none());
    }

    #[tokio::test]
    async fn on_request__initialize_is_noop() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = mcp_request("initialize", json!({"protocolVersion": "2025-11-25"}), None);

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;
        assert_continues_mcp(flow);
        assert!(cx.working.request_tool.is_none());
        assert!(cx.working.request_resource_uri.is_none());
        assert!(cx.working.request_prompt_name.is_none());
    }

    #[tokio::test]
    async fn on_request__non_mcp_passthrough() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = Request::Raw(RawRequest {
            method: Method::GET,
            path: "/health".into(),
            body: Body::empty(),
            headers: HeaderMap::new(),
        });

        let flow = TargetExtractMiddleware.on_request(req, &mut cx).await;
        assert!(
            matches!(flow, Flow::Continue(Request::Raw(_))),
            "non-MCP requests must be forwarded unchanged"
        );
        assert!(cx.working.request_tool.is_none());
        assert!(cx.working.request_resource_uri.is_none());
        assert!(cx.working.request_prompt_name.is_none());
    }
}