Skip to main content

mcp_proxy/
inject.rs

//! Argument injection middleware for tool calls.
//!
//! Merges default or per-tool arguments into tool call requests before they
//! reach the backend. Useful for injecting timeouts, safety caps, or
//! read-only flags without requiring clients to set them. Values the client
//! already supplied are kept unless a per-tool rule sets `overwrite = true`.
//!
//! # Configuration
//!
//! ```toml
//! [[backends]]
//! name = "db"
//! transport = "http"
//! url = "http://db.internal:8080"
//!
//! # Inject into all tool calls for this backend
//! [backends.default_args]
//! timeout = 30
//!
//! # Inject into a specific tool (overrides default_args for matching keys)
//! [[backends.inject_args]]
//! tool = "query"
//! args = { read_only = true, max_rows = 1000 }
//!
//! # Overwrite arguments the client already supplied
//! [[backends.inject_args]]
//! tool = "dangerous_op"
//! args = { dry_run = true }
//! overwrite = true
//! ```
30
31use std::collections::HashMap;
32use std::convert::Infallible;
33use std::future::Future;
34use std::pin::Pin;
35use std::sync::Arc;
36use std::task::{Context, Poll};
37
38use tower::Service;
39use tower_mcp::router::{RouterRequest, RouterResponse};
40use tower_mcp_types::protocol::McpRequest;
41
42/// Per-tool injection rule.
43#[derive(Debug, Clone)]
44struct ToolInjection {
45    args: serde_json::Map<String, serde_json::Value>,
46    overwrite: bool,
47}
48
49/// Resolved injection rules for a single backend namespace.
50#[derive(Debug, Clone)]
51pub struct InjectionRules {
52    /// Namespace prefix (e.g. "db/").
53    namespace: String,
54    /// Default args applied to all tools in this namespace.
55    default_args: serde_json::Map<String, serde_json::Value>,
56    /// Per-tool overrides keyed by namespaced tool name (e.g. "db/query").
57    tool_rules: HashMap<String, ToolInjection>,
58}
59
60impl InjectionRules {
61    /// Create injection rules for a backend.
62    pub fn new(
63        namespace: String,
64        default_args: serde_json::Map<String, serde_json::Value>,
65        tool_rules: Vec<crate::config::InjectArgsConfig>,
66    ) -> Self {
67        let tool_rules = tool_rules
68            .into_iter()
69            .map(|r| {
70                let namespaced = format!("{namespace}{}", r.tool);
71                (
72                    namespaced,
73                    ToolInjection {
74                        args: r.args,
75                        overwrite: r.overwrite,
76                    },
77                )
78            })
79            .collect();
80
81        Self {
82            namespace,
83            default_args,
84            tool_rules,
85        }
86    }
87}
88
89/// Argument injection middleware.
90///
91/// Intercepts `CallTool` requests and merges configured arguments into
92/// the tool call arguments before forwarding to the inner service.
93#[derive(Clone)]
94pub struct InjectArgsService<S> {
95    inner: S,
96    rules: Arc<Vec<InjectionRules>>,
97}
98
99impl<S> InjectArgsService<S> {
100    /// Create a new argument injection service.
101    pub fn new(inner: S, rules: Vec<InjectionRules>) -> Self {
102        Self {
103            inner,
104            rules: Arc::new(rules),
105        }
106    }
107}
108
109/// Merge source args into target. If `overwrite` is false, existing keys
110/// in target are preserved.
111fn merge_args(
112    target: &mut serde_json::Value,
113    source: &serde_json::Map<String, serde_json::Value>,
114    overwrite: bool,
115) {
116    if let serde_json::Value::Object(map) = target {
117        for (key, value) in source {
118            if overwrite || !map.contains_key(key) {
119                map.insert(key.clone(), value.clone());
120            }
121        }
122    }
123}
124
125impl<S> Service<RouterRequest> for InjectArgsService<S>
126where
127    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
128        + Clone
129        + Send
130        + 'static,
131    S::Future: Send,
132{
133    type Response = RouterResponse;
134    type Error = Infallible;
135    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
136
137    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
138        self.inner.poll_ready(cx)
139    }
140
141    fn call(&mut self, mut req: RouterRequest) -> Self::Future {
142        if let McpRequest::CallTool(ref mut params) = req.inner {
143            for rules in self.rules.iter() {
144                if !params.name.starts_with(&rules.namespace) {
145                    continue;
146                }
147
148                // Apply default args (never overwrite)
149                if !rules.default_args.is_empty() {
150                    merge_args(&mut params.arguments, &rules.default_args, false);
151                }
152
153                // Apply per-tool rules
154                if let Some(tool_rule) = rules.tool_rules.get(&params.name) {
155                    merge_args(&mut params.arguments, &tool_rule.args, tool_rule.overwrite);
156                }
157
158                break; // Only match one namespace
159            }
160        }
161
162        let fut = self.inner.call(req);
163        Box::pin(fut)
164    }
165}
166
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;
    use crate::config::InjectArgsConfig;
    use crate::test_util::{MockService, call_service};
    use serde_json::json;
    use tower_mcp_types::protocol::{CallToolParams, McpRequest};

    /// Backend wrapper that records the arguments of every `CallTool` request
    /// before delegating to `inner`. This lets tests assert on what the backend
    /// actually received instead of only checking that the call succeeded.
    #[derive(Clone)]
    struct Capture<S> {
        inner: S,
        seen: Arc<Mutex<Vec<serde_json::Value>>>,
    }

    impl<S> Service<RouterRequest> for Capture<S>
    where
        S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>,
    {
        type Response = RouterResponse;
        type Error = Infallible;
        type Future = S::Future;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, req: RouterRequest) -> Self::Future {
            if let McpRequest::CallTool(ref params) = req.inner {
                self.seen
                    .lock()
                    .expect("capture mutex poisoned")
                    .push(params.arguments.clone());
            }
            self.inner.call(req)
        }
    }

    /// Unwraps a `json!({...})` literal into its object map.
    fn obj(value: serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
        match value {
            serde_json::Value::Object(map) => map,
            other => panic!("expected a JSON object, got {other}"),
        }
    }

    fn make_rules(
        namespace: &str,
        default_args: serde_json::Map<String, serde_json::Value>,
        tool_rules: Vec<InjectArgsConfig>,
    ) -> Vec<InjectionRules> {
        vec![InjectionRules::new(
            namespace.to_string(),
            default_args,
            tool_rules,
        )]
    }

    fn tool_rule(tool: &str, args: serde_json::Value, overwrite: bool) -> InjectArgsConfig {
        InjectArgsConfig {
            tool: tool.to_string(),
            args: obj(args),
            overwrite,
        }
    }

    /// Calls `tool` with `arguments` through an `InjectArgsService` built from
    /// `rules` and returns the arguments the backend received.
    async fn backend_args(
        rules: Vec<InjectionRules>,
        tool: &'static str,
        arguments: serde_json::Value,
    ) -> serde_json::Value {
        let seen = Arc::new(Mutex::new(Vec::new()));
        let backend = Capture {
            inner: MockService::with_tools(&[tool]),
            seen: Arc::clone(&seen),
        };
        let mut svc = InjectArgsService::new(backend, rules);

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(CallToolParams {
                name: tool.to_string(),
                arguments,
                meta: None,
                task: None,
            }),
        )
        .await;
        assert!(resp.inner.is_ok());

        let mut calls = std::mem::take(&mut *seen.lock().expect("capture mutex poisoned"));
        assert_eq!(calls.len(), 1, "backend should receive exactly one call");
        calls.pop().expect("one captured call")
    }

    #[tokio::test]
    async fn test_injects_default_args() {
        let rules = make_rules("db/", obj(json!({"timeout": 30})), vec![]);
        let args = backend_args(rules, "db/query", json!({"sql": "SELECT 1"})).await;
        assert_eq!(args, json!({"sql": "SELECT 1", "timeout": 30}));
    }

    #[tokio::test]
    async fn test_default_args_dont_overwrite() {
        let rules = make_rules("db/", obj(json!({"timeout": 30})), vec![]);
        let args = backend_args(
            rules,
            "db/query",
            json!({"sql": "SELECT 1", "timeout": 60}),
        )
        .await;
        assert_eq!(args, json!({"sql": "SELECT 1", "timeout": 60}));
    }

    #[tokio::test]
    async fn test_per_tool_injection() {
        let rules = make_rules(
            "db/",
            serde_json::Map::new(),
            vec![tool_rule("query", json!({"read_only": true}), false)],
        );
        let args = backend_args(rules, "db/query", json!({"sql": "SELECT 1"})).await;
        assert_eq!(args, json!({"sql": "SELECT 1", "read_only": true}));
    }

    #[tokio::test]
    async fn test_per_tool_without_overwrite_keeps_client_value() {
        let rules = make_rules(
            "db/",
            serde_json::Map::new(),
            vec![tool_rule("query", json!({"read_only": true}), false)],
        );
        let args = backend_args(rules, "db/query", json!({"read_only": false})).await;
        assert_eq!(args, json!({"read_only": false}));
    }

    #[tokio::test]
    async fn test_per_tool_overwrite_replaces_client_value() {
        let rules = make_rules(
            "db/",
            serde_json::Map::new(),
            vec![tool_rule("dangerous_op", json!({"dry_run": true}), true)],
        );
        let args = backend_args(
            rules,
            "db/dangerous_op",
            json!({"dry_run": false, "data": "hello"}),
        )
        .await;
        assert_eq!(args, json!({"dry_run": true, "data": "hello"}));
    }

    #[tokio::test]
    async fn test_tool_rule_only_applies_to_its_tool() {
        let rules = make_rules(
            "db/",
            serde_json::Map::new(),
            vec![tool_rule("query", json!({"read_only": true}), false)],
        );
        let args = backend_args(rules, "db/schema", json!({})).await;
        assert_eq!(args, json!({}));
    }

    #[test]
    fn test_overwrite_mode() {
        let mut args = json!({"dry_run": false, "data": "hello"});
        let inject = obj(json!({"dry_run": true}));

        // Without overwrite
        merge_args(&mut args, &inject, false);
        assert_eq!(args["dry_run"], false); // preserved

        // With overwrite
        merge_args(&mut args, &inject, true);
        assert_eq!(args["dry_run"], true); // overwritten
        assert_eq!(args["data"], "hello"); // other fields preserved
    }

    #[tokio::test]
    async fn test_non_matching_namespace_passes_through() {
        let rules = make_rules("db/", obj(json!({"timeout": 30})), vec![]);
        let args = backend_args(rules, "other/tool", json!({})).await;
        assert_eq!(args, json!({}));
    }

    #[tokio::test]
    async fn test_non_call_tool_passes_through() {
        let mock = MockService::with_tools(&["db/query"]);
        let rules = make_rules("db/", obj(json!({"timeout": 30})), vec![]);
        let mut svc = InjectArgsService::new(mock, rules);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }

    #[test]
    fn test_merge_args_into_non_object() {
        // A non-object, non-null `arguments` value is left as-is.
        let mut args = json!("not an object");
        let inject = obj(json!({"key": "value"}));
        merge_args(&mut args, &inject, false);
        assert_eq!(args, json!("not an object"));
    }

    #[test]
    fn test_merge_args_adds_new_keys() {
        let mut args = json!({"existing": 1});
        let inject = obj(json!({"new_key": 42}));
        merge_args(&mut args, &inject, false);
        assert_eq!(args["existing"], 1);
        assert_eq!(args["new_key"], 42);
    }
}