1use std::collections::HashMap;
32use std::convert::Infallible;
33use std::future::Future;
34use std::pin::Pin;
35use std::sync::Arc;
36use std::task::{Context, Poll};
37
38use tower::Service;
39use tower_mcp::router::{RouterRequest, RouterResponse};
40use tower_mcp_types::protocol::McpRequest;
41
42#[derive(Debug, Clone)]
44struct ToolInjection {
45 args: serde_json::Map<String, serde_json::Value>,
46 overwrite: bool,
47}
48
49#[derive(Debug, Clone)]
51pub struct InjectionRules {
52 namespace: String,
54 default_args: serde_json::Map<String, serde_json::Value>,
56 tool_rules: HashMap<String, ToolInjection>,
58}
59
60impl InjectionRules {
61 pub fn new(
63 namespace: String,
64 default_args: serde_json::Map<String, serde_json::Value>,
65 tool_rules: Vec<crate::config::InjectArgsConfig>,
66 ) -> Self {
67 let tool_rules = tool_rules
68 .into_iter()
69 .map(|r| {
70 let namespaced = format!("{namespace}{}", r.tool);
71 (
72 namespaced,
73 ToolInjection {
74 args: r.args,
75 overwrite: r.overwrite,
76 },
77 )
78 })
79 .collect();
80
81 Self {
82 namespace,
83 default_args,
84 tool_rules,
85 }
86 }
87}
88
89#[derive(Clone)]
94pub struct InjectArgsService<S> {
95 inner: S,
96 rules: Arc<Vec<InjectionRules>>,
97}
98
99impl<S> InjectArgsService<S> {
100 pub fn new(inner: S, rules: Vec<InjectionRules>) -> Self {
102 Self {
103 inner,
104 rules: Arc::new(rules),
105 }
106 }
107}
108
109fn merge_args(
112 target: &mut serde_json::Value,
113 source: &serde_json::Map<String, serde_json::Value>,
114 overwrite: bool,
115) {
116 if let serde_json::Value::Object(map) = target {
117 for (key, value) in source {
118 if overwrite || !map.contains_key(key) {
119 map.insert(key.clone(), value.clone());
120 }
121 }
122 }
123}
124
125impl<S> Service<RouterRequest> for InjectArgsService<S>
126where
127 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
128 + Clone
129 + Send
130 + 'static,
131 S::Future: Send,
132{
133 type Response = RouterResponse;
134 type Error = Infallible;
135 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
136
137 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
138 self.inner.poll_ready(cx)
139 }
140
141 fn call(&mut self, mut req: RouterRequest) -> Self::Future {
142 if let McpRequest::CallTool(ref mut params) = req.inner {
143 for rules in self.rules.iter() {
144 if !params.name.starts_with(&rules.namespace) {
145 continue;
146 }
147
148 if !rules.default_args.is_empty() {
150 merge_args(&mut params.arguments, &rules.default_args, false);
151 }
152
153 if let Some(tool_rule) = rules.tool_rules.get(¶ms.name) {
155 merge_args(&mut params.arguments, &tool_rule.args, tool_rule.overwrite);
156 }
157
158 break; }
160 }
161
162 let fut = self.inner.call(req);
163 Box::pin(fut)
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::InjectArgsConfig;
    use crate::test_util::{MockService, call_service};
    use tower_mcp_types::protocol::{CallToolParams, McpRequest};

    /// Wraps a single namespace's configuration in the `Vec` the service expects.
    fn make_rules(
        namespace: &str,
        default_args: serde_json::Map<String, serde_json::Value>,
        tool_rules: Vec<InjectArgsConfig>,
    ) -> Vec<InjectionRules> {
        vec![InjectionRules::new(
            namespace.to_string(),
            default_args,
            tool_rules,
        )]
    }

    /// Builds an argument map containing exactly one entry.
    fn single_arg(
        key: &str,
        value: serde_json::Value,
    ) -> serde_json::Map<String, serde_json::Value> {
        let mut map = serde_json::Map::new();
        map.insert(key.to_string(), value);
        map
    }

    /// Builds a `CallTool` request for `name` with the given arguments.
    fn call_tool(name: &str, arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    // The service-level tests below only check that the request still reaches
    // the inner service and succeeds; the merge semantics themselves are
    // pinned by the direct `merge_args` tests.

    #[tokio::test]
    async fn test_injects_default_args() {
        let mock = MockService::with_tools(&["db/query"]);
        let defaults = single_arg("timeout", serde_json::json!(30));

        let rules = make_rules("db/", defaults, vec![]);
        let mut svc = InjectArgsService::new(mock, rules);

        let resp = call_service(
            &mut svc,
            call_tool("db/query", serde_json::json!({"sql": "SELECT 1"})),
        )
        .await;

        assert!(resp.inner.is_ok());
    }

    // Exercises `merge_args` directly, so no service, request wrapper or
    // async runtime is needed.
    #[test]
    fn test_default_args_dont_overwrite() {
        let defaults = single_arg("timeout", serde_json::json!(30));
        let mut arguments = serde_json::json!({"sql": "SELECT 1", "timeout": 60});

        merge_args(&mut arguments, &defaults, false);

        // The caller's value wins over the default.
        assert_eq!(arguments["timeout"], 60);
        // Unrelated keys are preserved.
        assert_eq!(arguments["sql"], "SELECT 1");
    }

    #[tokio::test]
    async fn test_per_tool_injection() {
        let mock = MockService::with_tools(&["db/query"]);
        let tool_rules = vec![InjectArgsConfig {
            tool: "query".to_string(),
            args: single_arg("read_only", serde_json::json!(true)),
            overwrite: false,
        }];

        let rules = make_rules("db/", serde_json::Map::new(), tool_rules);
        let mut svc = InjectArgsService::new(mock, rules);

        let resp = call_service(
            &mut svc,
            call_tool("db/query", serde_json::json!({"sql": "SELECT 1"})),
        )
        .await;

        assert!(resp.inner.is_ok());
    }

    #[test]
    fn test_overwrite_mode() {
        let mut args = serde_json::json!({"dry_run": false, "data": "hello"});
        let inject = single_arg("dry_run", serde_json::json!(true));

        // Without overwrite the existing value is kept.
        merge_args(&mut args, &inject, false);
        assert_eq!(args["dry_run"], false);

        // With overwrite the injected value replaces it.
        merge_args(&mut args, &inject, true);
        assert_eq!(args["dry_run"], true);
        // Keys that are not injected are never touched.
        assert_eq!(args["data"], "hello");
    }

    #[tokio::test]
    async fn test_non_matching_namespace_passes_through() {
        let mock = MockService::with_tools(&["other/tool"]);
        let defaults = single_arg("timeout", serde_json::json!(30));

        let rules = make_rules("db/", defaults, vec![]);
        let mut svc = InjectArgsService::new(mock, rules);

        let resp = call_service(&mut svc, call_tool("other/tool", serde_json::json!({}))).await;

        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_non_call_tool_passes_through() {
        let mock = MockService::with_tools(&["db/query"]);
        let defaults = single_arg("timeout", serde_json::json!(30));

        let rules = make_rules("db/", defaults, vec![]);
        let mut svc = InjectArgsService::new(mock, rules);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }

    #[test]
    fn test_merge_args_into_non_object() {
        let mut args = serde_json::json!("not an object");
        let inject = single_arg("key", serde_json::json!("value"));

        merge_args(&mut args, &inject, false);

        // A non-object target is left exactly as it was.
        assert_eq!(args, serde_json::json!("not an object"));
    }

    #[test]
    fn test_merge_args_adds_new_keys() {
        let mut args = serde_json::json!({"existing": 1});
        let inject = single_arg("new_key", serde_json::json!(42));

        merge_args(&mut args, &inject, false);

        assert_eq!(args["existing"], 1);
        assert_eq!(args["new_key"], 42);
    }
}