1use std::collections::HashMap;
24use std::convert::Infallible;
25use std::future::Future;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::task::{Context, Poll};
29
30use tower::{Layer, Service};
31use tower_mcp::router::{RouterRequest, RouterResponse};
32use tower_mcp_types::protocol::{McpRequest, McpResponse};
33
/// Tower [`Layer`] that wraps a service in a [`ParamOverrideService`].
///
/// Holds the per-tool override list; every call to [`Layer::layer`] hands a
/// copy of that list to the newly built service.
#[derive(Clone)]
pub struct ParamOverrideLayer {
    /// Per-tool parameter overrides applied by each wrapped service.
    overrides: Vec<ToolOverride>,
}
50
51impl ParamOverrideLayer {
52 pub fn new(overrides: Vec<ToolOverride>) -> Self {
54 Self { overrides }
55 }
56}
57
58impl<S> Layer<S> for ParamOverrideLayer {
59 type Service = ParamOverrideService<S>;
60
61 fn layer(&self, inner: S) -> Self::Service {
62 ParamOverrideService::new(inner, self.overrides.clone())
63 }
64}
65
/// Runtime form of one parameter-override config entry, bound to a single
/// namespaced tool.
#[derive(Debug, Clone)]
pub struct ToolOverride {
    /// Full tool name as seen by clients (`namespace` followed by the
    /// configured tool name); compared against `CallTool` request names and
    /// `ListTools` entries.
    namespaced_tool: String,
    /// Upstream parameter names removed from the advertised input schema
    /// (both from `properties` and from `required`).
    hide: Vec<String>,
    /// Values injected into `CallTool` arguments when the caller did not
    /// supply that key; explicitly supplied arguments are never overwritten.
    defaults: serde_json::Map<String, serde_json::Value>,
    /// Upstream name -> client-facing name; used when rewriting schemas.
    rename_forward: HashMap<String, String>,
    /// Client-facing name -> upstream name; used when rewriting call
    /// arguments back to what the upstream tool expects.
    rename_reverse: HashMap<String, String>,
}
80
81impl ToolOverride {
82 pub fn new(namespace: &str, config: &crate::config::ParamOverrideConfig) -> Self {
84 let rename_forward: HashMap<String, String> = config.rename.clone();
85 let rename_reverse: HashMap<String, String> = config
86 .rename
87 .iter()
88 .map(|(orig, new)| (new.clone(), orig.clone()))
89 .collect();
90
91 Self {
92 namespaced_tool: format!("{namespace}{}", config.tool),
93 hide: config.hide.clone(),
94 defaults: config.defaults.clone(),
95 rename_forward,
96 rename_reverse,
97 }
98 }
99}
100
/// Tower [`Service`] that applies [`ToolOverride`]s around an inner router.
///
/// - For `CallTool` requests naming an overridden tool, it injects default
///   arguments and maps client-facing argument names back to upstream names
///   before forwarding.
/// - For `ListTools` responses, it rewrites each overridden tool's input
///   schema to hide and rename parameters.
#[derive(Clone)]
pub struct ParamOverrideService<S> {
    /// The wrapped service.
    inner: S,
    /// Immutable override list; the `Arc` lets clones of this service share
    /// it instead of copying it.
    overrides: Arc<Vec<ToolOverride>>,
}
111
112impl<S> ParamOverrideService<S> {
113 pub fn new(inner: S, overrides: Vec<ToolOverride>) -> Self {
115 Self {
116 inner,
117 overrides: Arc::new(overrides),
118 }
119 }
120}
121
122fn rewrite_schema(
124 schema: &mut serde_json::Value,
125 hide: &[String],
126 rename_forward: &HashMap<String, String>,
127) {
128 let Some(obj) = schema.as_object_mut() else {
129 return;
130 };
131
132 if let Some(props) = obj.get_mut("properties").and_then(|v| v.as_object_mut()) {
134 for param in hide {
135 props.remove(param);
136 }
137 for (original, renamed) in rename_forward {
138 if let Some(prop_schema) = props.remove(original) {
139 props.insert(renamed.clone(), prop_schema);
140 }
141 }
142 }
143
144 if let Some(required) = obj.get_mut("required").and_then(|v| v.as_array_mut()) {
146 required.retain(|v| {
147 v.as_str()
148 .map(|s| !hide.contains(&s.to_string()))
149 .unwrap_or(true)
150 });
151 for entry in required.iter_mut() {
152 if let Some(s) = entry.as_str()
153 && let Some(new_name) = rename_forward.get(s)
154 {
155 *entry = serde_json::Value::String(new_name.clone());
156 }
157 }
158 }
159}
160
161impl<S> Service<RouterRequest> for ParamOverrideService<S>
162where
163 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
164 + Clone
165 + Send
166 + 'static,
167 S::Future: Send,
168{
169 type Response = RouterResponse;
170 type Error = Infallible;
171 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
172
173 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
174 self.inner.poll_ready(cx)
175 }
176
177 fn call(&mut self, mut req: RouterRequest) -> Self::Future {
178 let overrides = Arc::clone(&self.overrides);
179
180 if let McpRequest::CallTool(ref mut params) = req.inner {
182 for tool_override in overrides.iter() {
183 if params.name != tool_override.namespaced_tool {
184 continue;
185 }
186
187 if let serde_json::Value::Object(ref mut args) = params.arguments {
189 for (key, value) in &tool_override.defaults {
190 if !args.contains_key(key) {
191 args.insert(key.clone(), value.clone());
192 }
193 }
194
195 let keys_to_rename: Vec<(String, String)> = args
197 .keys()
198 .filter_map(|k| {
199 tool_override
200 .rename_reverse
201 .get(k)
202 .map(|orig| (k.clone(), orig.clone()))
203 })
204 .collect();
205
206 for (new_name, original_name) in keys_to_rename {
207 if let Some(value) = args.remove(&new_name) {
208 args.insert(original_name, value);
209 }
210 }
211 }
212
213 break;
214 }
215 }
216
217 let fut = self.inner.call(req);
218
219 Box::pin(async move {
220 let mut resp = fut.await?;
221
222 if let Ok(McpResponse::ListTools(ref mut result)) = resp.inner {
224 for tool in &mut result.tools {
225 for tool_override in overrides.iter() {
226 if tool.name == tool_override.namespaced_tool {
227 rewrite_schema(
228 &mut tool.input_schema,
229 &tool_override.hide,
230 &tool_override.rename_forward,
231 );
232 break;
233 }
234 }
235 }
236 }
237
238 Ok(resp)
239 })
240 }
241}
242
#[cfg(test)]
mod tests {
    // Tests for the parameter-override layer. Most of them drive
    // `ParamOverrideService` end-to-end via `call_service` against a
    // `MockService` from `crate::test_util`.
    use super::*;
    use crate::config::ParamOverrideConfig;
    use crate::test_util::{MockService, call_service};
    use tower_mcp_types::protocol::{CallToolParams, McpRequest, McpResponse};

    /// Builds a `MockService` whose `tools` list holds a single tool `name`
    /// with the given input `schema`.
    fn mock_with_schema(name: &str, schema: serde_json::Value) -> MockService {
        use tower_mcp_types::protocol::ToolDefinition;
        MockService {
            tools: vec![ToolDefinition {
                name: name.to_string(),
                title: None,
                description: Some(format!("{name} tool")),
                input_schema: schema,
                output_schema: None,
                icons: None,
                annotations: None,
                execution: None,
                meta: None,
            }],
        }
    }

    /// Object schema with `path` (required), `recursive` and `pattern`.
    fn list_dir_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "path": { "type": "string" },
                "recursive": { "type": "boolean" },
                "pattern": { "type": "string" }
            },
            "required": ["path"]
        })
    }

    /// Converts each config into a `ToolOverride` scoped to `namespace`.
    fn make_overrides(namespace: &str, configs: Vec<ParamOverrideConfig>) -> Vec<ToolOverride> {
        configs
            .iter()
            .map(|c| ToolOverride::new(namespace, c))
            .collect()
    }

    // Hiding a parameter removes it from both `properties` and `required`
    // in the ListTools response, leaving the other properties intact.
    #[tokio::test]
    async fn test_hide_removes_param_from_schema() {
        let mock = mock_with_schema("fs/list_directory", list_dir_schema());
        let overrides = make_overrides(
            "fs/",
            vec![ParamOverrideConfig {
                tool: "list_directory".to_string(),
                hide: vec!["path".to_string()],
                defaults: {
                    let mut m = serde_json::Map::new();
                    m.insert("path".to_string(), serde_json::json!("/home/docs"));
                    m
                },
                rename: HashMap::new(),
            }],
        );
        let mut svc = ParamOverrideService::new(mock, overrides);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let tool = &result.tools[0];
                let props = tool.input_schema["properties"].as_object().unwrap();
                assert!(
                    !props.contains_key("path"),
                    "path should be hidden from schema"
                );
                assert!(props.contains_key("recursive"), "recursive should remain");
                assert!(props.contains_key("pattern"), "pattern should remain");
                let required = tool.input_schema["required"].as_array().unwrap();
                // `path` was the only required entry; hiding it must drop it.
                let req_strs: Vec<&str> = required.iter().map(|v| v.as_str().unwrap()).collect();
                assert!(!req_strs.contains(&"path"), "path should not be required");
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    // A CallTool for an overridden tool without the hidden `path` argument
    // goes through the service successfully.
    // NOTE(review): only `is_ok()` is asserted; the arguments actually
    // forwarded to the mock (with the injected default) are not inspected.
    // Confirm whether `MockService` can expose them.
    #[tokio::test]
    async fn test_hide_injects_defaults_on_call() {
        let mock = mock_with_schema("fs/list_directory", list_dir_schema());
        let overrides = make_overrides(
            "fs/",
            vec![ParamOverrideConfig {
                tool: "list_directory".to_string(),
                hide: vec!["path".to_string()],
                defaults: {
                    let mut m = serde_json::Map::new();
                    m.insert("path".to_string(), serde_json::json!("/home/docs"));
                    m
                },
                rename: HashMap::new(),
            }],
        );
        let mut svc = ParamOverrideService::new(mock, overrides);

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(CallToolParams {
                name: "fs/list_directory".to_string(),
                arguments: serde_json::json!({"recursive": true}),
                meta: None,
                task: None,
            }),
        )
        .await;

        assert!(resp.inner.is_ok(), "call should succeed");
    }

    // Renaming `recursive` -> `deep_search` re-keys the schema property.
    #[tokio::test]
    async fn test_rename_rewrites_schema() {
        let mock = mock_with_schema("fs/list_directory", list_dir_schema());
        let overrides = make_overrides(
            "fs/",
            vec![ParamOverrideConfig {
                tool: "list_directory".to_string(),
                hide: vec![],
                defaults: serde_json::Map::new(),
                rename: {
                    let mut m = HashMap::new();
                    m.insert("recursive".to_string(), "deep_search".to_string());
                    m
                },
            }],
        );
        let mut svc = ParamOverrideService::new(mock, overrides);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let tool = &result.tools[0];
                let props = tool.input_schema["properties"].as_object().unwrap();
                assert!(
                    !props.contains_key("recursive"),
                    "recursive should be renamed"
                );
                assert!(
                    props.contains_key("deep_search"),
                    "deep_search should appear"
                );
                assert!(props.contains_key("path"), "path should remain");
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    // A CallTool using the client-facing name `deep_search` succeeds.
    // NOTE(review): only `is_ok()` is asserted; the test does not check that
    // the inner service received `recursive` instead of `deep_search`.
    #[tokio::test]
    async fn test_rename_reverse_maps_on_call() {
        let mock = mock_with_schema("fs/list_directory", list_dir_schema());
        let overrides = make_overrides(
            "fs/",
            vec![ParamOverrideConfig {
                tool: "list_directory".to_string(),
                hide: vec![],
                defaults: serde_json::Map::new(),
                rename: {
                    let mut m = HashMap::new();
                    m.insert("recursive".to_string(), "deep_search".to_string());
                    m
                },
            }],
        );
        let mut svc = ParamOverrideService::new(mock, overrides);

        let resp = call_service(
            &mut svc,
            McpRequest::CallTool(CallToolParams {
                name: "fs/list_directory".to_string(),
                arguments: serde_json::json!({"path": "/tmp", "deep_search": true}),
                meta: None,
                task: None,
            }),
        )
        .await;

        assert!(resp.inner.is_ok(), "call should succeed");
    }

    // Hide and rename configured together on the same tool both apply to
    // the advertised schema.
    #[tokio::test]
    async fn test_hide_and_rename_combined() {
        let mock = mock_with_schema("fs/list_directory", list_dir_schema());
        let overrides = make_overrides(
            "fs/",
            vec![ParamOverrideConfig {
                tool: "list_directory".to_string(),
                hide: vec!["path".to_string()],
                defaults: {
                    let mut m = serde_json::Map::new();
                    m.insert("path".to_string(), serde_json::json!("/home/docs"));
                    m
                },
                rename: {
                    let mut m = HashMap::new();
                    m.insert("recursive".to_string(), "deep_search".to_string());
                    m
                },
            }],
        );
        let mut svc = ParamOverrideService::new(mock, overrides);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let props = result.tools[0].input_schema["properties"]
                    .as_object()
                    .unwrap();
                assert!(!props.contains_key("path"));
                assert!(!props.contains_key("recursive"));
                assert!(props.contains_key("deep_search"));
                assert!(props.contains_key("pattern"));
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    // An override for `fs/list_directory` must not touch `db/query`.
    #[tokio::test]
    async fn test_non_matching_tool_passes_through() {
        let mock = mock_with_schema("db/query", list_dir_schema());
        let overrides = make_overrides(
            "fs/",
            vec![ParamOverrideConfig {
                tool: "list_directory".to_string(),
                hide: vec!["path".to_string()],
                defaults: serde_json::Map::new(),
                rename: HashMap::new(),
            }],
        );
        let mut svc = ParamOverrideService::new(mock, overrides);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                // The schema must still be exactly what the mock advertised.
                let props = result.tools[0].input_schema["properties"]
                    .as_object()
                    .unwrap();
                assert!(props.contains_key("path"), "unmatched tool is untouched");
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    // A non-CallTool request (ListTools here) is forwarded successfully.
    // NOTE(review): only success is asserted; the schema produced by
    // `MockService::with_tools` is not inspected.
    #[tokio::test]
    async fn test_non_call_tool_passes_through() {
        let mock = MockService::with_tools(&["fs/list_directory"]);
        let overrides = make_overrides(
            "fs/",
            vec![ParamOverrideConfig {
                tool: "list_directory".to_string(),
                hide: vec!["path".to_string()],
                defaults: serde_json::Map::new(),
                rename: HashMap::new(),
            }],
        );
        let mut svc = ParamOverrideService::new(mock, overrides);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }

    // A renamed parameter that is required keeps being required under its
    // client-facing name.
    #[tokio::test]
    async fn test_rename_updates_required_array() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "path": { "type": "string" },
                "recursive": { "type": "boolean" }
            },
            "required": ["path", "recursive"]
        });
        let mock = mock_with_schema("fs/list_directory", schema);
        let overrides = make_overrides(
            "fs/",
            vec![ParamOverrideConfig {
                tool: "list_directory".to_string(),
                hide: vec![],
                defaults: serde_json::Map::new(),
                rename: {
                    let mut m = HashMap::new();
                    m.insert("recursive".to_string(), "deep_search".to_string());
                    m
                },
            }],
        );
        let mut svc = ParamOverrideService::new(mock, overrides);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => {
                let required = result.tools[0].input_schema["required"].as_array().unwrap();
                let req_strs: Vec<&str> = required.iter().map(|v| v.as_str().unwrap()).collect();
                assert!(req_strs.contains(&"path"));
                assert!(req_strs.contains(&"deep_search"));
                assert!(!req_strs.contains(&"recursive"));
            }
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    // A schema without `properties` is returned unchanged.
    #[test]
    fn test_rewrite_schema_no_properties() {
        let mut schema = serde_json::json!({"type": "object"});
        rewrite_schema(&mut schema, &["path".to_string()], &HashMap::new());
        assert_eq!(schema, serde_json::json!({"type": "object"}));
    }

    // A non-object schema is returned unchanged.
    #[test]
    fn test_rewrite_schema_non_object() {
        let mut schema = serde_json::json!("string");
        rewrite_schema(&mut schema, &["path".to_string()], &HashMap::new());
        assert_eq!(schema, serde_json::json!("string"));
    }

    // `ToolOverride::new` prefixes the namespace and builds both rename maps.
    #[test]
    fn test_tool_override_construction() {
        let config = ParamOverrideConfig {
            tool: "list_directory".to_string(),
            hide: vec!["path".to_string()],
            defaults: {
                let mut m = serde_json::Map::new();
                m.insert("path".to_string(), serde_json::json!("/home"));
                m
            },
            rename: {
                let mut m = HashMap::new();
                m.insert("recursive".to_string(), "deep_search".to_string());
                m
            },
        };
        let to = ToolOverride::new("fs/", &config);
        assert_eq!(to.namespaced_tool, "fs/list_directory");
        assert_eq!(to.hide, vec!["path"]);
        assert_eq!(to.rename_forward.get("recursive").unwrap(), "deep_search");
        assert_eq!(to.rename_reverse.get("deep_search").unwrap(), "recursive");
    }

    // An explicitly supplied argument must win over a configured default.
    // NOTE(review): this test re-implements the default-injection loop inline
    // instead of sending the request through `ParamOverrideService`, so it
    // would not catch a regression in `ParamOverrideService::call`. `_mock`
    // is built but never used.
    #[tokio::test]
    async fn test_hidden_default_does_not_overwrite_explicit_arg() {
        let _mock = mock_with_schema("fs/list_directory", list_dir_schema());
        let overrides = make_overrides(
            "fs/",
            vec![ParamOverrideConfig {
                tool: "list_directory".to_string(),
                hide: vec!["path".to_string()],
                defaults: {
                    let mut m = serde_json::Map::new();
                    m.insert("path".to_string(), serde_json::json!("/home/docs"));
                    m
                },
                rename: HashMap::new(),
            }],
        );

        let mut req = RouterRequest {
            id: tower_mcp::protocol::RequestId::Number(1),
            inner: McpRequest::CallTool(CallToolParams {
                name: "fs/list_directory".to_string(),
                arguments: serde_json::json!({"path": "/custom"}),
                meta: None,
                task: None,
            }),
            extensions: tower_mcp::router::Extensions::new(),
        };

        if let McpRequest::CallTool(ref mut params) = req.inner
            && let serde_json::Value::Object(ref mut args) = params.arguments
        {
            let defaults = &overrides[0].defaults;
            for (key, value) in defaults {
                if !args.contains_key(key) {
                    args.insert(key.clone(), value.clone());
                }
            }
            // The caller's "/custom" must survive default injection.
            assert_eq!(args.get("path").unwrap(), "/custom");
        }
    }
}