use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use crate::protocol::v2::manifest::McpConfig;
use crate::types::tool::{FunctionDefinition, ToolCall, ToolDefinition, ToolResult};
/// Description of a single tool in MCP form.
///
/// On the wire `input_schema` travels under the camelCase key `inputSchema`.
/// Both optional fields become `None` when they are missing from the input.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct McpTool {
    /// Tool name as reported by the server (no bridge namespace prefix).
    pub name: String,
    /// Optional human-readable description of the tool.
    #[serde(default)]
    pub description: Option<String>,
    /// Optional JSON value describing the tool's arguments; the bridge
    /// forwards it unchanged as the function's `parameters`.
    #[serde(rename = "inputSchema", default)]
    pub input_schema: Option<Value>,
}
/// A request to run one MCP tool with a set of arguments.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct McpToolInvocation {
    /// Name of the tool to run, without any bridge namespace prefix.
    pub name: String,
    /// Arguments for the tool as an arbitrary JSON value.
    pub arguments: Value,
}
/// Outcome of an MCP tool invocation.
///
/// The error flag uses the camelCase wire key `isError`, matching how
/// `McpTool` maps `input_schema` to `inputSchema`. Without the rename, an
/// `isError: true` sent by a server would be ignored and the field would
/// silently default to `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolResult {
    /// Content parts returned by the tool, in the order they were produced.
    pub content: Vec<McpContent>,
    /// `true` when the tool reported a failure; defaults to `false` when the
    /// key is absent. The snake_case spelling `is_error` is still accepted on
    /// input so data serialized before the rename keeps deserializing.
    #[serde(default, rename = "isError", alias = "is_error")]
    pub is_error: bool,
}
/// One content part inside an [`McpToolResult`].
///
/// Only `type` and `text` are modelled explicitly; every other key found on
/// the part is captured in `extra` and written back out on serialization.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct McpContent {
    /// Content discriminator, stored under the JSON key `type`.
    #[serde(rename = "type")]
    pub content_type: String,
    /// Text payload, if the part carries one; `None` when absent.
    #[serde(default)]
    pub text: Option<String>,
    /// All remaining keys of the part, flattened into this map.
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
/// Connection details for one MCP server.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct McpServerSpec {
    /// Name identifying the server.
    pub name: String,
    /// Transport identifier; the accepted values are not constrained by this type.
    pub transport: String,
    /// Address of the server; its format is not validated here.
    pub uri: String,
    /// Optional authentication settings; `None` when the key is missing.
    #[serde(default)]
    pub auth: Option<McpAuth>,
}
/// Authentication settings for an MCP server.
///
/// `Debug` is implemented by hand so that the credential in `token` never
/// ends up in logs or panic messages; serialization is unaffected and still
/// emits the real value.
#[derive(Clone, Serialize, Deserialize)]
pub struct McpAuth {
    /// Authentication method identifier; accepted values are not constrained here.
    pub method: String,
    /// Inline credential, if one is configured.
    #[serde(default)]
    pub token: Option<String>,
    /// NOTE(review): presumably the name of an environment variable holding
    /// the credential -- confirm against the code that consumes this field.
    #[serde(default)]
    pub token_env: Option<String>,
}

impl std::fmt::Debug for McpAuth {
    /// Formats like the derived implementation, except that a present
    /// `token` is shown as `Some("<redacted>")` instead of its real value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("McpAuth")
            .field("method", &self.method)
            // Keep the Some/None distinction visible without exposing the secret.
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .field("token_env", &self.token_env)
            .finish()
    }
}
/// Provider-specific MCP settings, as produced by `extract_provider_config`
/// from a manifest's provider mapping.
#[derive(Clone, Debug)]
pub struct McpProviderConfig {
    /// Value of the `tool_type` mapping entry; `"mcp"` when the entry is absent.
    pub tool_type: String,
    /// Value of the `beta_header` mapping entry, if present.
    pub beta_header: Option<String>,
    /// Value of the `api_endpoint` mapping entry, if present.
    pub api_endpoint: Option<String>,
    /// How the MCP configuration is handed to the provider.
    pub config_method: McpConfigMethod,
}
/// Ways an MCP configuration can be handed to a provider.
///
/// `extract_provider_config` selects a variant from the `config_method`
/// mapping entry and falls back to [`McpConfigMethod::ToolParameter`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum McpConfigMethod {
    /// Fallback variant; used for any `config_method` value that is missing
    /// or not one of the strings listed on the other variants.
    ToolParameter,
    /// Selected by the string `"sdk_config"`.
    SdkConfig,
    /// Selected by the string `"cli_flag"`.
    CliFlag,
}
/// Translates between MCP tool descriptions and the internal tool protocol
/// for a single server.
///
/// Tool names exposed through the bridge carry the `namespace` prefix; the
/// two filter sets hold un-prefixed tool names.
#[derive(Debug)]
pub struct McpToolBridge {
    /// Prefix put in front of every exposed tool name (`mcp__<server>__`).
    namespace: String,
    /// When non-empty, only these tool names are allowed.
    allow_filter: HashSet<String>,
    /// Tool names that are always rejected.
    deny_filter: HashSet<String>,
}
impl McpToolBridge {
    /// Creates a bridge for the server called `server_name`.
    ///
    /// Every tool exposed through the bridge is prefixed with
    /// `mcp__<server_name>__`. Both filters start empty, which means no
    /// filtering is applied.
    pub fn new(server_name: &str) -> Self {
        Self {
            namespace: format!("mcp__{}__", server_name),
            allow_filter: HashSet::new(),
            deny_filter: HashSet::new(),
        }
    }

    /// Replaces the allow filter with `tools` (un-prefixed tool names).
    ///
    /// An empty collection disables allow-filtering: every tool that is not
    /// denied is then permitted.
    pub fn with_allow_filter(mut self, tools: impl IntoIterator<Item = String>) -> Self {
        self.allow_filter = tools.into_iter().collect();
        self
    }

    /// Replaces the deny filter with `tools` (un-prefixed tool names).
    ///
    /// Denied tools are rejected even when they also appear in the allow filter.
    pub fn with_deny_filter(mut self, tools: impl IntoIterator<Item = String>) -> Self {
        self.deny_filter = tools.into_iter().collect();
        self
    }

    /// Converts the MCP tools that pass the filters into protocol tool
    /// definitions with namespaced names, preserving input order.
    pub fn mcp_tools_to_protocol(&self, mcp_tools: &[McpTool]) -> Vec<ToolDefinition> {
        mcp_tools
            .iter()
            .filter(|t| self.is_tool_allowed(&t.name))
            .map(|t| self.convert_tool(t))
            .collect()
    }

    /// Builds a `"function"`-typed definition for one MCP tool, copying its
    /// description and input schema unchanged.
    fn convert_tool(&self, tool: &McpTool) -> ToolDefinition {
        ToolDefinition {
            tool_type: "function".to_string(),
            function: FunctionDefinition {
                name: self.namespaced_name(&tool.name),
                description: tool.description.clone(),
                parameters: tool.input_schema.clone(),
            },
        }
    }

    /// Converts a protocol tool call back into an MCP invocation.
    ///
    /// Returns `None` when the call's name does not start with this bridge's
    /// namespace, or when the un-prefixed tool name is rejected by the
    /// allow/deny filters.
    pub fn protocol_call_to_mcp(&self, call: &ToolCall) -> Option<McpToolInvocation> {
        let original_name = self.strip_namespace(&call.name)?;
        // Enforce the filters on invocation too: hiding a tool from the
        // listing does not stop a caller from naming it directly.
        if !self.is_tool_allowed(&original_name) {
            return None;
        }
        Some(McpToolInvocation {
            name: original_name,
            arguments: call.arguments.clone(),
        })
    }

    /// Converts an MCP result into a protocol tool result for `tool_call_id`.
    ///
    /// The `text` of every content part that has one is joined with `"\n"`;
    /// parts without text are skipped. On error the joined text is wrapped as
    /// `{"error": <text>}`, otherwise it is emitted as a plain JSON string.
    pub fn mcp_result_to_protocol(&self, tool_call_id: &str, result: &McpToolResult) -> ToolResult {
        let content = result
            .content
            .iter()
            .filter_map(|c| c.text.clone())
            .collect::<Vec<_>>()
            .join("\n");
        ToolResult {
            tool_use_id: tool_call_id.to_string(),
            content: if result.is_error {
                serde_json::json!({ "error": content })
            } else {
                serde_json::json!(content)
            },
            is_error: result.is_error,
        }
    }

    /// Reports whether the un-prefixed tool `name` passes the filters.
    ///
    /// The deny filter wins; after that, a non-empty allow filter must
    /// contain the name, and an empty allow filter permits everything.
    fn is_tool_allowed(&self, name: &str) -> bool {
        if self.deny_filter.contains(name) {
            return false;
        }
        self.allow_filter.is_empty() || self.allow_filter.contains(name)
    }

    /// Prefixes `name` with this bridge's namespace.
    fn namespaced_name(&self, name: &str) -> String {
        format!("{}{}", self.namespace, name)
    }

    /// Removes this bridge's namespace prefix from `namespaced`, or returns
    /// `None` when the prefix is not present.
    fn strip_namespace(&self, namespaced: &str) -> Option<String> {
        namespaced
            .strip_prefix(self.namespace.as_str())
            .map(String::from)
    }
}
/// Derives the provider-specific MCP settings from a manifest's MCP section.
///
/// Returns `None` when the manifest has no client section or the client is
/// not marked as supported. Otherwise the string entries `tool_type`,
/// `beta_header`, `api_endpoint` and `config_method` are read from the
/// client's provider mapping. Missing or non-string entries fall back to
/// `"mcp"` for the tool type, `None` for the header and endpoint, and
/// [`McpConfigMethod::ToolParameter`] for the config method (which is also
/// used for any unrecognized method string).
pub fn extract_provider_config(mcp_config: &McpConfig) -> Option<McpProviderConfig> {
    let client = mcp_config.client.as_ref().filter(|c| c.supported)?;
    let mapping = client.provider_mapping.as_ref();

    // Fetches a mapping entry as a string slice; `None` covers a missing
    // mapping, a missing key, and a value that is not a JSON string.
    let lookup = |key: &str| mapping.and_then(|m| m.get(key)).and_then(|v| v.as_str());

    let config_method = match lookup("config_method") {
        Some("sdk_config") => McpConfigMethod::SdkConfig,
        Some("cli_flag") => McpConfigMethod::CliFlag,
        _ => McpConfigMethod::ToolParameter,
    };

    Some(McpProviderConfig {
        tool_type: lookup("tool_type").unwrap_or("mcp").to_string(),
        beta_header: lookup("beta_header").map(String::from),
        api_endpoint: lookup("api_endpoint").map(String::from),
        config_method,
    })
}
/// Unit tests for the MCP tool bridge and the provider-config extraction.
#[cfg(test)]
mod tests {
    use super::*;

    /// Three sample tools: two harmless ones with schemas and one
    /// "dangerous" tool without a schema, used to exercise the filters.
    fn sample_mcp_tools() -> Vec<McpTool> {
        vec![
            McpTool {
                name: "read_file".into(),
                description: Some("Read a file from disk".into()),
                input_schema: Some(serde_json::json!({
                    "type": "object",
                    "properties": { "path": { "type": "string" } },
                    "required": ["path"]
                })),
            },
            McpTool {
                name: "search".into(),
                description: Some("Search the web".into()),
                input_schema: Some(serde_json::json!({
                    "type": "object",
                    "properties": { "query": { "type": "string" } },
                    "required": ["query"]
                })),
            },
            McpTool {
                name: "exec_dangerous".into(),
                description: Some("Execute shell command".into()),
                input_schema: None,
            },
        ]
    }

    #[test]
    fn test_mcp_to_protocol_conversion() {
        let bridge = McpToolBridge::new("fileserver");
        let tools = bridge.mcp_tools_to_protocol(&sample_mcp_tools());
        assert_eq!(tools.len(), 3);
        assert_eq!(tools[0].function.name, "mcp__fileserver__read_file");
        assert_eq!(tools[0].tool_type, "function");
        assert!(tools[0].function.parameters.is_some());
        // Input order is preserved and a missing schema stays missing.
        assert_eq!(tools[1].function.name, "mcp__fileserver__search");
        assert_eq!(tools[2].function.name, "mcp__fileserver__exec_dangerous");
        assert!(tools[2].function.parameters.is_none());
    }

    #[test]
    fn test_tool_filtering_allow() {
        let bridge =
            McpToolBridge::new("srv").with_allow_filter(vec!["read_file".into(), "search".into()]);
        let tools = bridge.mcp_tools_to_protocol(&sample_mcp_tools());
        assert_eq!(tools.len(), 2);
        assert!(tools
            .iter()
            .all(|t| !t.function.name.contains("exec_dangerous")));
    }

    #[test]
    fn test_tool_filtering_deny() {
        let bridge = McpToolBridge::new("srv").with_deny_filter(vec!["exec_dangerous".into()]);
        let tools = bridge.mcp_tools_to_protocol(&sample_mcp_tools());
        assert_eq!(tools.len(), 2);
        // Check that the denied tool is the one that was removed.
        assert!(tools
            .iter()
            .all(|t| !t.function.name.contains("exec_dangerous")));
    }

    #[test]
    fn test_tool_filtering_deny_overrides_allow() {
        let bridge = McpToolBridge::new("srv")
            .with_allow_filter(vec!["read_file".into(), "exec_dangerous".into()])
            .with_deny_filter(vec!["exec_dangerous".into()]);
        let tools = bridge.mcp_tools_to_protocol(&sample_mcp_tools());
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "mcp__srv__read_file");
    }

    #[test]
    fn test_protocol_call_to_mcp() {
        let bridge = McpToolBridge::new("srv");
        let call = ToolCall {
            id: "call_123".into(),
            name: "mcp__srv__read_file".into(),
            arguments: serde_json::json!({"path": "/tmp/test.txt"}),
        };
        let invocation = bridge.protocol_call_to_mcp(&call).unwrap();
        assert_eq!(invocation.name, "read_file");
        assert_eq!(invocation.arguments["path"], "/tmp/test.txt");
    }

    #[test]
    fn test_protocol_call_wrong_namespace() {
        let bridge = McpToolBridge::new("srv");
        let call = ToolCall {
            id: "call_1".into(),
            name: "mcp__other__read_file".into(),
            arguments: Value::Null,
        };
        assert!(bridge.protocol_call_to_mcp(&call).is_none());
    }

    #[test]
    fn test_mcp_result_to_protocol() {
        let bridge = McpToolBridge::new("srv");
        let result = McpToolResult {
            content: vec![McpContent {
                content_type: "text".into(),
                text: Some("file contents here".into()),
                extra: HashMap::new(),
            }],
            is_error: false,
        };
        let proto = bridge.mcp_result_to_protocol("call_123", &result);
        assert_eq!(proto.tool_use_id, "call_123");
        assert!(!proto.is_error);
        // A successful result is emitted as a plain JSON string.
        assert_eq!(proto.content.as_str(), Some("file contents here"));
    }

    #[test]
    fn test_mcp_result_joins_text_and_skips_non_text() {
        let bridge = McpToolBridge::new("srv");
        let result = McpToolResult {
            content: vec![
                McpContent {
                    content_type: "text".into(),
                    text: Some("first".into()),
                    extra: HashMap::new(),
                },
                McpContent {
                    content_type: "image".into(),
                    text: None,
                    extra: HashMap::new(),
                },
                McpContent {
                    content_type: "text".into(),
                    text: Some("second".into()),
                    extra: HashMap::new(),
                },
            ],
            is_error: false,
        };
        let proto = bridge.mcp_result_to_protocol("call_2", &result);
        assert_eq!(proto.content.as_str(), Some("first\nsecond"));
    }

    #[test]
    fn test_mcp_result_error() {
        let bridge = McpToolBridge::new("srv");
        let result = McpToolResult {
            content: vec![McpContent {
                content_type: "text".into(),
                text: Some("file not found".into()),
                extra: HashMap::new(),
            }],
            is_error: true,
        };
        let proto = bridge.mcp_result_to_protocol("call_1", &result);
        assert!(proto.is_error);
        assert!(proto.content["error"]
            .as_str()
            .unwrap()
            .contains("file not found"));
    }

    #[test]
    fn test_extract_provider_config() {
        use crate::protocol::v2::manifest::McpClientConfig;
        let config = McpConfig {
            client: Some(McpClientConfig {
                supported: true,
                protocol_version: Some("2025-11-25".into()),
                transports: vec!["sse".into()],
                auth_methods: vec![],
                capabilities: None,
                tool_filtering: None,
                approval_modes: vec![],
                provider_mapping: Some(HashMap::from([
                    ("tool_type".into(), Value::String("mcp".into())),
                    (
                        "beta_header".into(),
                        Value::String("mcp-client-2025-11-20".into()),
                    ),
                    (
                        "config_method".into(),
                        Value::String("tool_parameter".into()),
                    ),
                ])),
            }),
            server: None,
        };
        let prov = extract_provider_config(&config).unwrap();
        assert_eq!(prov.tool_type, "mcp");
        assert_eq!(prov.beta_header.as_deref(), Some("mcp-client-2025-11-20"));
        assert_eq!(prov.config_method, McpConfigMethod::ToolParameter);
        // No `api_endpoint` entry was supplied in the mapping.
        assert!(prov.api_endpoint.is_none());
    }

    #[test]
    fn test_extract_provider_config_without_client() {
        let config = McpConfig {
            client: None,
            server: None,
        };
        assert!(extract_provider_config(&config).is_none());
    }
}