Skip to main content

mcp_proxy/
inject.rs

1//! Argument injection middleware for tool calls.
2//!
3//! Merges default or per-tool arguments into tool call requests before they
4//! reach the backend. Useful for injecting timeouts, safety caps, or
5//! read-only flags without requiring clients to set them.
6//!
7//! # Configuration
8//!
9//! ```toml
10//! [[backends]]
11//! name = "db"
12//! transport = "http"
13//! url = "http://db.internal:8080"
14//!
15//! # Inject into all tool calls for this backend
16//! [backends.default_args]
17//! timeout = 30
18//!
19//! # Inject into a specific tool (overrides default_args for matching keys)
20//! [[backends.inject_args]]
21//! tool = "query"
22//! args = { read_only = true, max_rows = 1000 }
23//!
24//! # Overwrite existing arguments
25//! [[backends.inject_args]]
26//! tool = "dangerous_op"
27//! args = { dry_run = true }
28//! overwrite = true
29//! ```
30
31use std::collections::HashMap;
32use std::convert::Infallible;
33use std::future::Future;
34use std::pin::Pin;
35use std::sync::Arc;
36use std::task::{Context, Poll};
37
38use tower::{Layer, Service};
39use tower_mcp::router::{RouterRequest, RouterResponse};
40use tower_mcp_types::protocol::McpRequest;
41
42/// Tower layer that produces an [`InjectArgsService`].
43#[derive(Clone)]
44pub struct InjectArgsLayer {
45    rules: Vec<InjectionRules>,
46}
47
48impl InjectArgsLayer {
49    /// Create a new argument injection layer.
50    pub fn new(rules: Vec<InjectionRules>) -> Self {
51        Self { rules }
52    }
53}
54
55impl<S> Layer<S> for InjectArgsLayer {
56    type Service = InjectArgsService<S>;
57
58    fn layer(&self, inner: S) -> Self::Service {
59        InjectArgsService::new(inner, self.rules.clone())
60    }
61}
62
63/// Per-tool injection rule.
64#[derive(Debug, Clone)]
65struct ToolInjection {
66    args: serde_json::Map<String, serde_json::Value>,
67    overwrite: bool,
68}
69
70/// Resolved injection rules for a single backend namespace.
71#[derive(Debug, Clone)]
72pub struct InjectionRules {
73    /// Namespace prefix (e.g. "db/").
74    namespace: String,
75    /// Default args applied to all tools in this namespace.
76    default_args: serde_json::Map<String, serde_json::Value>,
77    /// Per-tool overrides keyed by namespaced tool name (e.g. "db/query").
78    tool_rules: HashMap<String, ToolInjection>,
79}
80
81impl InjectionRules {
82    /// Create injection rules for a backend.
83    pub fn new(
84        namespace: String,
85        default_args: serde_json::Map<String, serde_json::Value>,
86        tool_rules: Vec<crate::config::InjectArgsConfig>,
87    ) -> Self {
88        let tool_rules = tool_rules
89            .into_iter()
90            .map(|r| {
91                let namespaced = format!("{namespace}{}", r.tool);
92                (
93                    namespaced,
94                    ToolInjection {
95                        args: r.args,
96                        overwrite: r.overwrite,
97                    },
98                )
99            })
100            .collect();
101
102        Self {
103            namespace,
104            default_args,
105            tool_rules,
106        }
107    }
108}
109
110/// Argument injection middleware.
111///
112/// Intercepts `CallTool` requests and merges configured arguments into
113/// the tool call arguments before forwarding to the inner service.
114#[derive(Clone)]
115pub struct InjectArgsService<S> {
116    inner: S,
117    rules: Arc<Vec<InjectionRules>>,
118}
119
120impl<S> InjectArgsService<S> {
121    /// Create a new argument injection service.
122    pub fn new(inner: S, rules: Vec<InjectionRules>) -> Self {
123        Self {
124            inner,
125            rules: Arc::new(rules),
126        }
127    }
128}
129
130/// Merge source args into target. If `overwrite` is false, existing keys
131/// in target are preserved.
132fn merge_args(
133    target: &mut serde_json::Value,
134    source: &serde_json::Map<String, serde_json::Value>,
135    overwrite: bool,
136) {
137    if let serde_json::Value::Object(map) = target {
138        for (key, value) in source {
139            if overwrite || !map.contains_key(key) {
140                map.insert(key.clone(), value.clone());
141            }
142        }
143    }
144}
145
146impl<S> Service<RouterRequest> for InjectArgsService<S>
147where
148    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
149        + Clone
150        + Send
151        + 'static,
152    S::Future: Send,
153{
154    type Response = RouterResponse;
155    type Error = Infallible;
156    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
157
158    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
159        self.inner.poll_ready(cx)
160    }
161
162    fn call(&mut self, mut req: RouterRequest) -> Self::Future {
163        if let McpRequest::CallTool(ref mut params) = req.inner {
164            for rules in self.rules.iter() {
165                if !params.name.starts_with(&rules.namespace) {
166                    continue;
167                }
168
169                // Apply default args (never overwrite)
170                if !rules.default_args.is_empty() {
171                    merge_args(&mut params.arguments, &rules.default_args, false);
172                }
173
174                // Apply per-tool rules
175                if let Some(tool_rule) = rules.tool_rules.get(&params.name) {
176                    merge_args(&mut params.arguments, &tool_rule.args, tool_rule.overwrite);
177                }
178
179                break; // Only match one namespace
180            }
181        }
182
183        let fut = self.inner.call(req);
184        Box::pin(fut)
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::InjectArgsConfig;
    use crate::test_util::{MockService, call_service};
    use tower_mcp_types::protocol::{CallToolParams, McpRequest};

    type ArgMap = serde_json::Map<String, serde_json::Value>;

    /// Build an argument map holding exactly one key/value pair.
    fn single_arg(key: &str, value: serde_json::Value) -> ArgMap {
        let mut map = ArgMap::new();
        map.insert(key.to_string(), value);
        map
    }

    /// Build a one-element rule list for `namespace`.
    fn make_rules(
        namespace: &str,
        default_args: ArgMap,
        tool_rules: Vec<InjectArgsConfig>,
    ) -> Vec<InjectionRules> {
        vec![InjectionRules::new(
            namespace.to_string(),
            default_args,
            tool_rules,
        )]
    }

    /// Build a `CallTool` request for `name` with the given arguments.
    fn call_tool(name: &str, arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    #[tokio::test]
    async fn test_injects_default_args() {
        let mock = MockService::with_tools(&["db/query"]);
        let rules = make_rules("db/", single_arg("timeout", serde_json::json!(30)), vec![]);
        let mut svc = InjectArgsService::new(mock, rules);

        let resp = call_service(
            &mut svc,
            call_tool("db/query", serde_json::json!({"sql": "SELECT 1"})),
        )
        .await;

        // Only asserts that the call succeeds; the merged values themselves
        // are covered by the `merge_args` tests below.
        assert!(resp.inner.is_ok());
    }

    #[test]
    fn test_default_args_dont_overwrite() {
        // The caller already supplied timeout=60.
        let mut arguments = serde_json::json!({"sql": "SELECT 1", "timeout": 60});
        let defaults = single_arg("timeout", serde_json::json!(30));

        // Defaults are merged without overwrite.
        merge_args(&mut arguments, &defaults, false);

        // timeout should still be 60 (not overwritten)
        assert_eq!(arguments["timeout"], 60);
        // sql should be preserved
        assert_eq!(arguments["sql"], "SELECT 1");
    }

    #[tokio::test]
    async fn test_per_tool_injection() {
        let mock = MockService::with_tools(&["db/query"]);
        let tool_rules = vec![InjectArgsConfig {
            tool: "query".to_string(),
            args: single_arg("read_only", serde_json::json!(true)),
            overwrite: false,
        }];

        let rules = make_rules("db/", ArgMap::new(), tool_rules);
        let mut svc = InjectArgsService::new(mock, rules);

        let resp = call_service(
            &mut svc,
            call_tool("db/query", serde_json::json!({"sql": "SELECT 1"})),
        )
        .await;

        assert!(resp.inner.is_ok());
    }

    #[test]
    fn test_overwrite_mode() {
        let mut args = serde_json::json!({"dry_run": false, "data": "hello"});
        let inject = single_arg("dry_run", serde_json::json!(true));

        // Without overwrite
        merge_args(&mut args, &inject, false);
        assert_eq!(args["dry_run"], false); // preserved

        // With overwrite
        merge_args(&mut args, &inject, true);
        assert_eq!(args["dry_run"], true); // overwritten
        assert_eq!(args["data"], "hello"); // other fields preserved
    }

    #[tokio::test]
    async fn test_non_matching_namespace_passes_through() {
        let mock = MockService::with_tools(&["other/tool"]);
        let rules = make_rules("db/", single_arg("timeout", serde_json::json!(30)), vec![]);
        let mut svc = InjectArgsService::new(mock, rules);

        let resp = call_service(&mut svc, call_tool("other/tool", serde_json::json!({}))).await;

        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_non_call_tool_passes_through() {
        let mock = MockService::with_tools(&["db/query"]);
        let rules = make_rules("db/", single_arg("timeout", serde_json::json!(30)), vec![]);
        let mut svc = InjectArgsService::new(mock, rules);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }

    #[test]
    fn test_merge_args_into_non_object() {
        // A string is not an object, so the merge leaves it unchanged.
        let mut args = serde_json::json!("not an object");
        let inject = single_arg("key", serde_json::json!("value"));
        merge_args(&mut args, &inject, false);
        assert_eq!(args, serde_json::json!("not an object"));
    }

    #[test]
    fn test_merge_args_adds_new_keys() {
        let mut args = serde_json::json!({"existing": 1});
        let inject = single_arg("new_key", serde_json::json!(42));
        merge_args(&mut args, &inject, false);
        assert_eq!(args["existing"], 1);
        assert_eq!(args["new_key"], 42);
    }
}