1use std::collections::HashMap;
32use std::convert::Infallible;
33use std::future::Future;
34use std::pin::Pin;
35use std::sync::Arc;
36use std::task::{Context, Poll};
37
38use tower::{Layer, Service};
39use tower_mcp::router::{RouterRequest, RouterResponse};
40use tower_mcp_types::protocol::McpRequest;
41
42#[derive(Clone)]
44pub struct InjectArgsLayer {
45 rules: Vec<InjectionRules>,
46}
47
48impl InjectArgsLayer {
49 pub fn new(rules: Vec<InjectionRules>) -> Self {
51 Self { rules }
52 }
53}
54
55impl<S> Layer<S> for InjectArgsLayer {
56 type Service = InjectArgsService<S>;
57
58 fn layer(&self, inner: S) -> Self::Service {
59 InjectArgsService::new(inner, self.rules.clone())
60 }
61}
62
63#[derive(Debug, Clone)]
65struct ToolInjection {
66 args: serde_json::Map<String, serde_json::Value>,
67 overwrite: bool,
68}
69
70#[derive(Debug, Clone)]
72pub struct InjectionRules {
73 namespace: String,
75 default_args: serde_json::Map<String, serde_json::Value>,
77 tool_rules: HashMap<String, ToolInjection>,
79}
80
81impl InjectionRules {
82 pub fn new(
84 namespace: String,
85 default_args: serde_json::Map<String, serde_json::Value>,
86 tool_rules: Vec<crate::config::InjectArgsConfig>,
87 ) -> Self {
88 let tool_rules = tool_rules
89 .into_iter()
90 .map(|r| {
91 let namespaced = format!("{namespace}{}", r.tool);
92 (
93 namespaced,
94 ToolInjection {
95 args: r.args,
96 overwrite: r.overwrite,
97 },
98 )
99 })
100 .collect();
101
102 Self {
103 namespace,
104 default_args,
105 tool_rules,
106 }
107 }
108}
109
110#[derive(Clone)]
115pub struct InjectArgsService<S> {
116 inner: S,
117 rules: Arc<Vec<InjectionRules>>,
118}
119
120impl<S> InjectArgsService<S> {
121 pub fn new(inner: S, rules: Vec<InjectionRules>) -> Self {
123 Self {
124 inner,
125 rules: Arc::new(rules),
126 }
127 }
128}
129
130fn merge_args(
133 target: &mut serde_json::Value,
134 source: &serde_json::Map<String, serde_json::Value>,
135 overwrite: bool,
136) {
137 if let serde_json::Value::Object(map) = target {
138 for (key, value) in source {
139 if overwrite || !map.contains_key(key) {
140 map.insert(key.clone(), value.clone());
141 }
142 }
143 }
144}
145
146impl<S> Service<RouterRequest> for InjectArgsService<S>
147where
148 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
149 + Clone
150 + Send
151 + 'static,
152 S::Future: Send,
153{
154 type Response = RouterResponse;
155 type Error = Infallible;
156 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
157
158 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
159 self.inner.poll_ready(cx)
160 }
161
162 fn call(&mut self, mut req: RouterRequest) -> Self::Future {
163 if let McpRequest::CallTool(ref mut params) = req.inner {
164 for rules in self.rules.iter() {
165 if !params.name.starts_with(&rules.namespace) {
166 continue;
167 }
168
169 if !rules.default_args.is_empty() {
171 merge_args(&mut params.arguments, &rules.default_args, false);
172 }
173
174 if let Some(tool_rule) = rules.tool_rules.get(¶ms.name) {
176 merge_args(&mut params.arguments, &tool_rule.args, tool_rule.overwrite);
177 }
178
179 break; }
181 }
182
183 let fut = self.inner.call(req);
184 Box::pin(fut)
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::InjectArgsConfig;
    use crate::test_util::{MockService, call_service};
    use tower_mcp_types::protocol::{CallToolParams, McpRequest};

    /// Builds a single rule set for `namespace`.
    fn make_rules(
        namespace: &str,
        default_args: serde_json::Map<String, serde_json::Value>,
        tool_rules: Vec<InjectArgsConfig>,
    ) -> Vec<InjectionRules> {
        vec![InjectionRules::new(
            namespace.to_string(),
            default_args,
            tool_rules,
        )]
    }

    #[tokio::test]
    async fn test_injects_default_args() {
        let mock = MockService::with_tools(&["db/query"]);
        let mut defaults = serde_json::Map::new();
        defaults.insert("timeout".to_string(), serde_json::json!(30));

        let rules = make_rules("db/", defaults, vec![]);
        let mut svc = InjectArgsService::new(mock, rules);

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(CallToolParams {
                name: "db/query".to_string(),
                arguments: serde_json::json!({"sql": "SELECT 1"}),
                meta: None,
                task: None,
            }),
        )
        .await;

        assert!(resp.inner.is_ok());
    }

    #[test]
    fn test_default_args_dont_overwrite() {
        let mut params = CallToolParams {
            name: "db/query".to_string(),
            arguments: serde_json::json!({"sql": "SELECT 1", "timeout": 60}),
            meta: None,
            task: None,
        };
        let mut defaults = serde_json::Map::new();
        defaults.insert("timeout".to_string(), serde_json::json!(30));

        merge_args(&mut params.arguments, &defaults, false);

        // The caller-supplied value survives; unrelated keys are untouched.
        assert_eq!(params.arguments["timeout"], 60);
        assert_eq!(params.arguments["sql"], "SELECT 1");
    }

    #[tokio::test]
    async fn test_per_tool_injection() {
        let mock = MockService::with_tools(&["db/query"]);
        let tool_rules = vec![InjectArgsConfig {
            tool: "query".to_string(),
            args: {
                let mut m = serde_json::Map::new();
                m.insert("read_only".to_string(), serde_json::json!(true));
                m
            },
            overwrite: false,
        }];

        let rules = make_rules("db/", serde_json::Map::new(), tool_rules);
        let mut svc = InjectArgsService::new(mock, rules);

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(CallToolParams {
                name: "db/query".to_string(),
                arguments: serde_json::json!({"sql": "SELECT 1"}),
                meta: None,
                task: None,
            }),
        )
        .await;

        assert!(resp.inner.is_ok());
    }

    #[test]
    fn test_overwrite_mode() {
        let mut args = serde_json::json!({"dry_run": false, "data": "hello"});
        let mut inject = serde_json::Map::new();
        inject.insert("dry_run".to_string(), serde_json::json!(true));

        // Without overwrite, the caller's value wins.
        merge_args(&mut args, &inject, false);
        assert_eq!(args["dry_run"], false);

        // With overwrite, the injected value wins and other keys are untouched.
        merge_args(&mut args, &inject, true);
        assert_eq!(args["dry_run"], true);
        assert_eq!(args["data"], "hello");
    }

    #[tokio::test]
    async fn test_non_matching_namespace_passes_through() {
        let mock = MockService::with_tools(&["other/tool"]);
        let mut defaults = serde_json::Map::new();
        defaults.insert("timeout".to_string(), serde_json::json!(30));

        let rules = make_rules("db/", defaults, vec![]);
        let mut svc = InjectArgsService::new(mock, rules);

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(CallToolParams {
                name: "other/tool".to_string(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
        )
        .await;

        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_non_call_tool_passes_through() {
        let mock = MockService::with_tools(&["db/query"]);
        let mut defaults = serde_json::Map::new();
        defaults.insert("timeout".to_string(), serde_json::json!(30));

        let rules = make_rules("db/", defaults, vec![]);
        let mut svc = InjectArgsService::new(mock, rules);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }

    #[test]
    fn test_merge_args_into_non_object() {
        // A non-object, non-null target is left exactly as it was.
        let mut args = serde_json::json!("not an object");
        let mut inject = serde_json::Map::new();
        inject.insert("key".to_string(), serde_json::json!("value"));
        merge_args(&mut args, &inject, false);
        assert_eq!(args, serde_json::json!("not an object"));
    }

    #[test]
    fn test_merge_args_adds_new_keys() {
        let mut args = serde_json::json!({"existing": 1});
        let mut inject = serde_json::Map::new();
        inject.insert("new_key".to_string(), serde_json::json!(42));
        merge_args(&mut args, &inject, false);
        assert_eq!(args["existing"], 1);
        assert_eq!(args["new_key"], 42);
    }
}