Skip to main content

mcp_proxy/
param_override.rs

//! Parameter override middleware for tool customization.
//!
//! Modifies tool schemas and call arguments to hide parameters (injecting
//! defaults) and rename parameters. This turns generic tools into
//! domain-specific ones via config.
//!
//! # Configuration
//!
//! ```toml
//! [[backends.param_overrides]]
//! tool = "list_directory"
//! hide = ["path"]
//! defaults = { path = "/home/docs" }
//! rename = { recursive = "deep_search" }
//! ```
//!
//! On `ListTools`: hidden parameters are removed from the matching tool's
//! `input_schema` (both its `properties` and its `required` list), and
//! renamed parameters are exposed under their new names.
//!
//! On `CallTool`: configured defaults are injected for any argument the
//! caller did not supply (a value the caller does supply is left as is), and
//! renamed parameters are mapped back to their original names before
//! forwarding.
22
23use std::collections::HashMap;
24use std::convert::Infallible;
25use std::future::Future;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::task::{Context, Poll};
29
30use tower::{Layer, Service};
31use tower_mcp::router::{RouterRequest, RouterResponse};
32use tower_mcp_types::protocol::{McpRequest, McpResponse};
33
34/// Tower layer that produces a [`ParamOverrideService`].
35///
36/// # Example
37///
38/// ```rust,ignore
39/// use tower::ServiceBuilder;
40/// use mcp_proxy::param_override::ParamOverrideLayer;
41///
42/// let service = ServiceBuilder::new()
43///     .layer(ParamOverrideLayer::new(overrides))
44///     .service(proxy);
45/// ```
46#[derive(Clone)]
47pub struct ParamOverrideLayer {
48    overrides: Vec<ToolOverride>,
49}
50
51impl ParamOverrideLayer {
52    /// Create a new parameter override layer.
53    pub fn new(overrides: Vec<ToolOverride>) -> Self {
54        Self { overrides }
55    }
56}
57
58impl<S> Layer<S> for ParamOverrideLayer {
59    type Service = ParamOverrideService<S>;
60
61    fn layer(&self, inner: S) -> Self::Service {
62        ParamOverrideService::new(inner, self.overrides.clone())
63    }
64}
65
66/// Resolved parameter override rules for a single tool.
67#[derive(Debug, Clone)]
68pub struct ToolOverride {
69    /// Namespaced tool name (e.g. "fs/list_directory").
70    namespaced_tool: String,
71    /// Parameters to hide from the schema.
72    hide: Vec<String>,
73    /// Default values for hidden parameters.
74    defaults: serde_json::Map<String, serde_json::Value>,
75    /// Forward rename map: original_name -> new_name (for schema rewriting).
76    rename_forward: HashMap<String, String>,
77    /// Reverse rename map: new_name -> original_name (for call rewriting).
78    rename_reverse: HashMap<String, String>,
79}
80
81impl ToolOverride {
82    /// Create a new tool override from config.
83    pub fn new(namespace: &str, config: &crate::config::ParamOverrideConfig) -> Self {
84        let rename_forward: HashMap<String, String> = config.rename.clone();
85        let rename_reverse: HashMap<String, String> = config
86            .rename
87            .iter()
88            .map(|(orig, new)| (new.clone(), orig.clone()))
89            .collect();
90
91        Self {
92            namespaced_tool: format!("{namespace}{}", config.tool),
93            hide: config.hide.clone(),
94            defaults: config.defaults.clone(),
95            rename_forward,
96            rename_reverse,
97        }
98    }
99}
100
101/// Parameter override middleware.
102///
103/// Intercepts `ListTools` responses to modify tool schemas (hiding and
104/// renaming parameters) and `CallTool` requests to inject hidden defaults
105/// and reverse-map renamed parameters.
106#[derive(Clone)]
107pub struct ParamOverrideService<S> {
108    inner: S,
109    overrides: Arc<Vec<ToolOverride>>,
110}
111
112impl<S> ParamOverrideService<S> {
113    /// Create a new parameter override service.
114    pub fn new(inner: S, overrides: Vec<ToolOverride>) -> Self {
115        Self {
116            inner,
117            overrides: Arc::new(overrides),
118        }
119    }
120}
121
122/// Remove hidden properties from a JSON Schema object and apply renames.
123fn rewrite_schema(
124    schema: &mut serde_json::Value,
125    hide: &[String],
126    rename_forward: &HashMap<String, String>,
127) {
128    let Some(obj) = schema.as_object_mut() else {
129        return;
130    };
131
132    // Rewrite "properties" object: remove hidden, rename others
133    if let Some(props) = obj.get_mut("properties").and_then(|v| v.as_object_mut()) {
134        for param in hide {
135            props.remove(param);
136        }
137        for (original, renamed) in rename_forward {
138            if let Some(prop_schema) = props.remove(original) {
139                props.insert(renamed.clone(), prop_schema);
140            }
141        }
142    }
143
144    // Rewrite "required" array: remove hidden, rename others
145    if let Some(required) = obj.get_mut("required").and_then(|v| v.as_array_mut()) {
146        required.retain(|v| {
147            v.as_str()
148                .map(|s| !hide.contains(&s.to_string()))
149                .unwrap_or(true)
150        });
151        for entry in required.iter_mut() {
152            if let Some(s) = entry.as_str()
153                && let Some(new_name) = rename_forward.get(s)
154            {
155                *entry = serde_json::Value::String(new_name.clone());
156            }
157        }
158    }
159}
160
161impl<S> Service<RouterRequest> for ParamOverrideService<S>
162where
163    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
164        + Clone
165        + Send
166        + 'static,
167    S::Future: Send,
168{
169    type Response = RouterResponse;
170    type Error = Infallible;
171    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
172
173    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
174        self.inner.poll_ready(cx)
175    }
176
177    fn call(&mut self, mut req: RouterRequest) -> Self::Future {
178        let overrides = Arc::clone(&self.overrides);
179
180        // On CallTool: inject hidden defaults, reverse-map renamed params
181        if let McpRequest::CallTool(ref mut params) = req.inner {
182            for tool_override in overrides.iter() {
183                if params.name != tool_override.namespaced_tool {
184                    continue;
185                }
186
187                // Inject defaults for hidden parameters
188                if let serde_json::Value::Object(ref mut args) = params.arguments {
189                    for (key, value) in &tool_override.defaults {
190                        if !args.contains_key(key) {
191                            args.insert(key.clone(), value.clone());
192                        }
193                    }
194
195                    // Reverse-map renamed parameters back to originals
196                    let keys_to_rename: Vec<(String, String)> = args
197                        .keys()
198                        .filter_map(|k| {
199                            tool_override
200                                .rename_reverse
201                                .get(k)
202                                .map(|orig| (k.clone(), orig.clone()))
203                        })
204                        .collect();
205
206                    for (new_name, original_name) in keys_to_rename {
207                        if let Some(value) = args.remove(&new_name) {
208                            args.insert(original_name, value);
209                        }
210                    }
211                }
212
213                break;
214            }
215        }
216
217        let fut = self.inner.call(req);
218
219        Box::pin(async move {
220            let mut resp = fut.await?;
221
222            // On ListTools: rewrite schemas
223            if let Ok(McpResponse::ListTools(ref mut result)) = resp.inner {
224                for tool in &mut result.tools {
225                    for tool_override in overrides.iter() {
226                        if tool.name == tool_override.namespaced_tool {
227                            rewrite_schema(
228                                &mut tool.input_schema,
229                                &tool_override.hide,
230                                &tool_override.rename_forward,
231                            );
232                            break;
233                        }
234                    }
235                }
236            }
237
238            Ok(resp)
239        })
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ParamOverrideConfig;
    use crate::test_util::{MockService, call_service};
    use serde_json::json;
    use std::sync::Mutex;
    use tower_mcp_types::protocol::{CallToolParams, McpRequest, McpResponse};

    /// Create a MockService with a tool that has a rich input_schema.
    fn mock_with_schema(name: &str, schema: serde_json::Value) -> MockService {
        use tower_mcp_types::protocol::ToolDefinition;
        MockService {
            tools: vec![ToolDefinition {
                name: name.to_string(),
                title: None,
                description: Some(format!("{name} tool")),
                input_schema: schema,
                output_schema: None,
                icons: None,
                annotations: None,
                execution: None,
                meta: None,
            }],
        }
    }

    /// Schema for `list_directory`: three properties, only `path` required.
    fn list_dir_schema() -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "path": { "type": "string" },
                "recursive": { "type": "boolean" },
                "pattern": { "type": "string" }
            },
            "required": ["path"]
        })
    }

    fn make_overrides(namespace: &str, configs: Vec<ParamOverrideConfig>) -> Vec<ToolOverride> {
        configs
            .iter()
            .map(|c| ToolOverride::new(namespace, c))
            .collect()
    }

    /// Build a `list_directory` override config. `defaults` must be a JSON object.
    fn list_dir_config(
        hide: &[&str],
        defaults: serde_json::Value,
        rename: &[(&str, &str)],
    ) -> ParamOverrideConfig {
        let serde_json::Value::Object(defaults) = defaults else {
            panic!("defaults must be a JSON object");
        };
        ParamOverrideConfig {
            tool: "list_directory".to_string(),
            hide: hide.iter().map(|name| (*name).to_string()).collect(),
            defaults,
            rename: rename
                .iter()
                .map(|(from, to)| ((*from).to_string(), (*to).to_string()))
                .collect(),
        }
    }

    /// Hide `path`, defaulting it to `/home/docs`.
    fn hide_path_config() -> ParamOverrideConfig {
        list_dir_config(&["path"], json!({ "path": "/home/docs" }), &[])
    }

    /// Rename `recursive` to `deep_search`.
    fn rename_recursive_config() -> ParamOverrideConfig {
        list_dir_config(&[], json!({}), &[("recursive", "deep_search")])
    }

    /// `CallTool` request for `name` with the given raw arguments.
    fn call_tool(name: &str, arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    /// Send `ListTools` through `svc` and return the first tool's input schema.
    async fn list_first_schema(svc: &mut ParamOverrideService<MockService>) -> serde_json::Value {
        let resp = call_service(svc, McpRequest::ListTools(Default::default())).await;
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => result.tools[0].input_schema.clone(),
            other => panic!("expected ListTools, got: {:?}", other),
        }
    }

    /// The string entries of a schema's `required` array.
    fn required_names(schema: &serde_json::Value) -> Vec<&str> {
        schema["required"]
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_str().unwrap())
            .collect()
    }

    /// Inner service that remembers the arguments of the last `CallTool`
    /// request it receives, then delegates to a [`MockService`]. Lets tests
    /// assert on what the override layer actually forwards.
    #[derive(Clone)]
    struct RecordingService {
        inner: MockService,
        last_args: Arc<Mutex<Option<serde_json::Value>>>,
    }

    impl RecordingService {
        fn new(inner: MockService) -> Self {
            Self {
                inner,
                last_args: Arc::new(Mutex::new(None)),
            }
        }

        /// Arguments of the most recent `CallTool` request that reached `inner`.
        fn forwarded_args(&self) -> serde_json::Value {
            self.last_args
                .lock()
                .unwrap()
                .clone()
                .expect("no CallTool request reached the inner service")
        }
    }

    impl Service<RouterRequest> for RecordingService {
        type Response = RouterResponse;
        type Error = Infallible;
        type Future = <MockService as Service<RouterRequest>>::Future;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Infallible>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, req: RouterRequest) -> Self::Future {
            if let McpRequest::CallTool(ref params) = req.inner {
                *self.last_args.lock().unwrap() = Some(params.arguments.clone());
            }
            self.inner.call(req)
        }
    }

    #[tokio::test]
    async fn test_hide_removes_param_from_schema() {
        let mock = mock_with_schema("fs/list_directory", list_dir_schema());
        let mut svc =
            ParamOverrideService::new(mock, make_overrides("fs/", vec![hide_path_config()]));

        let schema = list_first_schema(&mut svc).await;
        let props = schema["properties"].as_object().unwrap();
        assert!(
            !props.contains_key("path"),
            "path should be hidden from schema"
        );
        assert!(props.contains_key("recursive"), "recursive should remain");
        assert!(props.contains_key("pattern"), "pattern should remain");
        assert!(
            !required_names(&schema).contains(&"path"),
            "path should not be required"
        );
    }

    #[tokio::test]
    async fn test_hide_injects_defaults_on_call() {
        let recorder =
            RecordingService::new(mock_with_schema("fs/list_directory", list_dir_schema()));
        let mut svc = ParamOverrideService::new(
            recorder.clone(),
            make_overrides("fs/", vec![hide_path_config()]),
        );

        let resp = call_service(
            &mut svc,
            call_tool("fs/list_directory", json!({ "recursive": true })),
        )
        .await;

        assert!(resp.inner.is_ok(), "call should succeed");
        assert_eq!(
            recorder.forwarded_args(),
            json!({ "recursive": true, "path": "/home/docs" })
        );
    }

    #[tokio::test]
    async fn test_rename_rewrites_schema() {
        let mock = mock_with_schema("fs/list_directory", list_dir_schema());
        let mut svc =
            ParamOverrideService::new(mock, make_overrides("fs/", vec![rename_recursive_config()]));

        let schema = list_first_schema(&mut svc).await;
        let props = schema["properties"].as_object().unwrap();
        assert!(
            !props.contains_key("recursive"),
            "recursive should be renamed"
        );
        assert!(
            props.contains_key("deep_search"),
            "deep_search should appear"
        );
        assert!(props.contains_key("path"), "path should remain");
    }

    #[tokio::test]
    async fn test_rename_reverse_maps_on_call() {
        let recorder =
            RecordingService::new(mock_with_schema("fs/list_directory", list_dir_schema()));
        let mut svc = ParamOverrideService::new(
            recorder.clone(),
            make_overrides("fs/", vec![rename_recursive_config()]),
        );

        // Client sends "deep_search" (the renamed param).
        let resp = call_service(
            &mut svc,
            call_tool(
                "fs/list_directory",
                json!({ "path": "/tmp", "deep_search": true }),
            ),
        )
        .await;

        assert!(resp.inner.is_ok(), "call should succeed");
        assert_eq!(
            recorder.forwarded_args(),
            json!({ "path": "/tmp", "recursive": true })
        );
    }

    #[tokio::test]
    async fn test_hide_and_rename_combined() {
        let mock = mock_with_schema("fs/list_directory", list_dir_schema());
        let config = list_dir_config(
            &["path"],
            json!({ "path": "/home/docs" }),
            &[("recursive", "deep_search")],
        );
        let mut svc = ParamOverrideService::new(mock, make_overrides("fs/", vec![config]));

        // Verify schema: path hidden, recursive renamed.
        let schema = list_first_schema(&mut svc).await;
        let props = schema["properties"].as_object().unwrap();
        assert!(!props.contains_key("path"));
        assert!(!props.contains_key("recursive"));
        assert!(props.contains_key("deep_search"));
        assert!(props.contains_key("pattern"));
    }

    #[tokio::test]
    async fn test_hide_and_rename_combined_on_call() {
        let recorder =
            RecordingService::new(mock_with_schema("fs/list_directory", list_dir_schema()));
        let config = list_dir_config(
            &["path"],
            json!({ "path": "/home/docs" }),
            &[("recursive", "deep_search")],
        );
        let mut svc =
            ParamOverrideService::new(recorder.clone(), make_overrides("fs/", vec![config]));

        let resp = call_service(
            &mut svc,
            call_tool(
                "fs/list_directory",
                json!({ "deep_search": true, "pattern": "*.md" }),
            ),
        )
        .await;

        assert!(resp.inner.is_ok(), "call should succeed");
        assert_eq!(
            recorder.forwarded_args(),
            json!({ "path": "/home/docs", "recursive": true, "pattern": "*.md" })
        );
    }

    #[tokio::test]
    async fn test_non_matching_tool_passes_through() {
        let mock = mock_with_schema("db/query", list_dir_schema());
        let config = list_dir_config(&["path"], json!({}), &[]);
        let mut svc = ParamOverrideService::new(mock, make_overrides("fs/", vec![config]));

        // db/query schema should be untouched.
        let schema = list_first_schema(&mut svc).await;
        assert_eq!(schema, list_dir_schema(), "unmatched tool is untouched");
    }

    #[tokio::test]
    async fn test_non_call_tool_passes_through() {
        let mock = MockService::with_tools(&["fs/list_directory"]);
        let config = list_dir_config(&["path"], json!({}), &[]);
        let mut svc = ParamOverrideService::new(mock, make_overrides("fs/", vec![config]));

        // ListTools should pass through without error.
        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_rename_updates_required_array() {
        let schema = json!({
            "type": "object",
            "properties": {
                "path": { "type": "string" },
                "recursive": { "type": "boolean" }
            },
            "required": ["path", "recursive"]
        });
        let mock = mock_with_schema("fs/list_directory", schema);
        let mut svc =
            ParamOverrideService::new(mock, make_overrides("fs/", vec![rename_recursive_config()]));

        let schema = list_first_schema(&mut svc).await;
        let required = required_names(&schema);
        assert!(required.contains(&"path"));
        assert!(required.contains(&"deep_search"));
        assert!(!required.contains(&"recursive"));
    }

    #[test]
    fn test_rewrite_schema_no_properties() {
        // Schema without properties should be a no-op.
        let mut schema = json!({ "type": "object" });
        rewrite_schema(&mut schema, &["path".to_string()], &HashMap::new());
        assert_eq!(schema, json!({ "type": "object" }));
    }

    #[test]
    fn test_rewrite_schema_non_object() {
        // Non-object schema is a no-op.
        let mut schema = json!("string");
        rewrite_schema(&mut schema, &["path".to_string()], &HashMap::new());
        assert_eq!(schema, json!("string"));
    }

    #[test]
    fn test_tool_override_construction() {
        let config = list_dir_config(
            &["path"],
            json!({ "path": "/home" }),
            &[("recursive", "deep_search")],
        );
        let to = ToolOverride::new("fs/", &config);
        assert_eq!(to.namespaced_tool, "fs/list_directory");
        assert_eq!(to.hide, vec!["path"]);
        assert_eq!(to.defaults.get("path"), Some(&json!("/home")));
        assert_eq!(to.rename_forward.get("recursive").unwrap(), "deep_search");
        assert_eq!(to.rename_reverse.get("deep_search").unwrap(), "recursive");
    }

    #[tokio::test]
    async fn test_hidden_default_does_not_overwrite_explicit_arg() {
        let recorder =
            RecordingService::new(mock_with_schema("fs/list_directory", list_dir_schema()));
        let mut svc = ParamOverrideService::new(
            recorder.clone(),
            make_overrides("fs/", vec![hide_path_config()]),
        );

        // Even though "path" is hidden, an explicitly supplied value must not
        // be replaced by the configured default.
        let resp = call_service(
            &mut svc,
            call_tool("fs/list_directory", json!({ "path": "/custom" })),
        )
        .await;

        assert!(resp.inner.is_ok(), "call should succeed");
        assert_eq!(recorder.forwarded_args(), json!({ "path": "/custom" }));
    }
}