use std::collections::HashMap;
use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower::{Layer, Service};
use tower_mcp::router::{RouterRequest, RouterResponse};
use tower_mcp_types::protocol::McpRequest;
#[derive(Clone)]
pub struct InjectArgsLayer {
rules: Vec<InjectionRules>,
}
impl InjectArgsLayer {
pub fn new(rules: Vec<InjectionRules>) -> Self {
Self { rules }
}
}
impl<S> Layer<S> for InjectArgsLayer {
type Service = InjectArgsService<S>;
fn layer(&self, inner: S) -> Self::Service {
InjectArgsService::new(inner, self.rules.clone())
}
}
#[derive(Debug, Clone)]
struct ToolInjection {
args: serde_json::Map<String, serde_json::Value>,
overwrite: bool,
}
#[derive(Debug, Clone)]
pub struct InjectionRules {
namespace: String,
default_args: serde_json::Map<String, serde_json::Value>,
tool_rules: HashMap<String, ToolInjection>,
}
impl InjectionRules {
pub fn new(
namespace: String,
default_args: serde_json::Map<String, serde_json::Value>,
tool_rules: Vec<crate::config::InjectArgsConfig>,
) -> Self {
let tool_rules = tool_rules
.into_iter()
.map(|r| {
let namespaced = format!("{namespace}{}", r.tool);
(
namespaced,
ToolInjection {
args: r.args,
overwrite: r.overwrite,
},
)
})
.collect();
Self {
namespace,
default_args,
tool_rules,
}
}
}
#[derive(Clone)]
pub struct InjectArgsService<S> {
inner: S,
rules: Arc<Vec<InjectionRules>>,
}
impl<S> InjectArgsService<S> {
pub fn new(inner: S, rules: Vec<InjectionRules>) -> Self {
Self {
inner,
rules: Arc::new(rules),
}
}
}
fn merge_args(
target: &mut serde_json::Value,
source: &serde_json::Map<String, serde_json::Value>,
overwrite: bool,
) {
if let serde_json::Value::Object(map) = target {
for (key, value) in source {
if overwrite || !map.contains_key(key) {
map.insert(key.clone(), value.clone());
}
}
}
}
impl<S> Service<RouterRequest> for InjectArgsService<S>
where
S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
+ Clone
+ Send
+ 'static,
S::Future: Send,
{
type Response = RouterResponse;
type Error = Infallible;
type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: RouterRequest) -> Self::Future {
if let McpRequest::CallTool(ref mut params) = req.inner {
for rules in self.rules.iter() {
if !params.name.starts_with(&rules.namespace) {
continue;
}
if !rules.default_args.is_empty() {
merge_args(&mut params.arguments, &rules.default_args, false);
}
if let Some(tool_rule) = rules.tool_rules.get(¶ms.name) {
merge_args(&mut params.arguments, &tool_rule.args, tool_rule.overwrite);
}
break; }
}
let fut = self.inner.call(req);
Box::pin(fut)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::InjectArgsConfig;
use crate::test_util::{MockService, call_service};
use tower_mcp_types::protocol::{CallToolParams, McpRequest};
fn make_rules(
namespace: &str,
default_args: serde_json::Map<String, serde_json::Value>,
tool_rules: Vec<InjectArgsConfig>,
) -> Vec<InjectionRules> {
vec![InjectionRules::new(
namespace.to_string(),
default_args,
tool_rules,
)]
}
#[tokio::test]
async fn test_injects_default_args() {
let mock = MockService::with_tools(&["db/query"]);
let mut defaults = serde_json::Map::new();
defaults.insert("timeout".to_string(), serde_json::json!(30));
let rules = make_rules("db/", defaults, vec![]);
let mut svc = InjectArgsService::new(mock, rules);
let resp = call_service(
&mut svc,
McpRequest::CallTool(CallToolParams {
name: "db/query".to_string(),
arguments: serde_json::json!({"sql": "SELECT 1"}),
meta: None,
task: None,
}),
)
.await;
assert!(resp.inner.is_ok());
}
#[tokio::test]
async fn test_default_args_dont_overwrite() {
let mock = MockService::with_tools(&["db/query"]);
let mut defaults = serde_json::Map::new();
defaults.insert("timeout".to_string(), serde_json::json!(30));
let rules = make_rules("db/", defaults, vec![]);
let _svc = InjectArgsService::new(mock, rules);
let mut req = RouterRequest {
id: tower_mcp::protocol::RequestId::Number(1),
inner: McpRequest::CallTool(CallToolParams {
name: "db/query".to_string(),
arguments: serde_json::json!({"sql": "SELECT 1", "timeout": 60}),
meta: None,
task: None,
}),
extensions: tower_mcp::router::Extensions::new(),
};
if let McpRequest::CallTool(ref mut params) = req.inner {
let mut defaults = serde_json::Map::new();
defaults.insert("timeout".to_string(), serde_json::json!(30));
merge_args(&mut params.arguments, &defaults, false);
assert_eq!(params.arguments["timeout"], 60);
assert_eq!(params.arguments["sql"], "SELECT 1");
}
}
#[tokio::test]
async fn test_per_tool_injection() {
let mock = MockService::with_tools(&["db/query"]);
let tool_rules = vec![InjectArgsConfig {
tool: "query".to_string(),
args: {
let mut m = serde_json::Map::new();
m.insert("read_only".to_string(), serde_json::json!(true));
m
},
overwrite: false,
}];
let rules = make_rules("db/", serde_json::Map::new(), tool_rules);
let mut svc = InjectArgsService::new(mock, rules);
let resp = call_service(
&mut svc,
McpRequest::CallTool(CallToolParams {
name: "db/query".to_string(),
arguments: serde_json::json!({"sql": "SELECT 1"}),
meta: None,
task: None,
}),
)
.await;
assert!(resp.inner.is_ok());
}
#[tokio::test]
async fn test_overwrite_mode() {
let mut args = serde_json::json!({"dry_run": false, "data": "hello"});
let mut inject = serde_json::Map::new();
inject.insert("dry_run".to_string(), serde_json::json!(true));
merge_args(&mut args, &inject, false);
assert_eq!(args["dry_run"], false);
merge_args(&mut args, &inject, true);
assert_eq!(args["dry_run"], true); assert_eq!(args["data"], "hello"); }
#[tokio::test]
async fn test_non_matching_namespace_passes_through() {
let mock = MockService::with_tools(&["other/tool"]);
let mut defaults = serde_json::Map::new();
defaults.insert("timeout".to_string(), serde_json::json!(30));
let rules = make_rules("db/", defaults, vec![]);
let mut svc = InjectArgsService::new(mock, rules);
let resp = call_service(
&mut svc,
McpRequest::CallTool(CallToolParams {
name: "other/tool".to_string(),
arguments: serde_json::json!({}),
meta: None,
task: None,
}),
)
.await;
assert!(resp.inner.is_ok());
}
#[tokio::test]
async fn test_non_call_tool_passes_through() {
let mock = MockService::with_tools(&["db/query"]);
let mut defaults = serde_json::Map::new();
defaults.insert("timeout".to_string(), serde_json::json!(30));
let rules = make_rules("db/", defaults, vec![]);
let mut svc = InjectArgsService::new(mock, rules);
let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
assert!(resp.inner.is_ok());
}
#[test]
fn test_merge_args_into_non_object() {
let mut args = serde_json::json!("not an object");
let mut inject = serde_json::Map::new();
inject.insert("key".to_string(), serde_json::json!("value"));
merge_args(&mut args, &inject, false);
assert_eq!(args, serde_json::json!("not an object"));
}
#[test]
fn test_merge_args_adds_new_keys() {
let mut args = serde_json::json!({"existing": 1});
let mut inject = serde_json::Map::new();
inject.insert("new_key".to_string(), serde_json::json!(42));
merge_args(&mut args, &inject, false);
assert_eq!(args["existing"], 1);
assert_eq!(args["new_key"], 42);
}
}