use super::{MCPConnectionParams, MCPSessionManager, MCPTool};
use crate::errors::{AgentError, AgentResult};
use crate::tools::{BaseTool, BaseToolset};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::OnceCell;
use tracing::{debug, error, info};
/// Selects which tools discovered on an MCP server are exposed by `MCPToolset`.
///
/// Names are compared by exact string equality against the tool name reported
/// by the server; see [`MCPToolFilter::matches`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MCPToolFilter {
    /// Keep every tool.
    All,
    /// Keep only tools whose name appears in the list (an empty list keeps nothing).
    Include(Vec<String>),
    /// Keep every tool except those whose name appears in the list
    /// (an empty list keeps everything).
    Exclude(Vec<String>),
}
impl MCPToolFilter {
    /// Reports whether a tool called `tool_name` passes this filter.
    ///
    /// * [`All`](Self::All) accepts every name.
    /// * [`Include`](Self::Include) accepts only names present in its list.
    /// * [`Exclude`](Self::Exclude) accepts every name absent from its list.
    ///
    /// Comparison is exact (case-sensitive, whole-string) equality.
    #[must_use]
    pub fn matches(&self, tool_name: &str) -> bool {
        // Shared membership test for the two list-based variants.
        fn is_listed(names: &[String], wanted: &str) -> bool {
            names.iter().any(|candidate| candidate.as_str() == wanted)
        }

        match self {
            Self::All => true,
            Self::Include(names) => is_listed(names, tool_name),
            Self::Exclude(names) => !is_listed(names, tool_name),
        }
    }
}
/// A `BaseToolset` that exposes the tools offered by the MCP server described by
/// one set of connection parameters.
///
/// Tools are discovered lazily by `get_tools` and cached in `tools`; every
/// discovered `MCPTool` holds a clone of `session_manager`.
pub struct MCPToolset {
/// Creates MCP sessions; shared via `Arc` with each tool built during discovery.
session_manager: Arc<MCPSessionManager>,
/// Optional name filter applied during discovery; `None` keeps every tool.
tool_filter: Option<MCPToolFilter>,
/// Discovered (and filtered) tools, populated by `get_tools`.
tools: OnceCell<Vec<Box<dyn BaseTool>>>,
}
impl MCPToolset {
    /// Builds a toolset around a new session manager for `connection_params`.
    ///
    /// No filter is installed, so every discovered tool is kept. Tool discovery
    /// is deferred until [`BaseToolset::get_tools`] is called.
    #[must_use]
    pub fn new(connection_params: MCPConnectionParams) -> Self {
        let session_manager = Arc::new(MCPSessionManager::new(connection_params));
        Self {
            session_manager,
            tool_filter: None,
            tools: OnceCell::new(),
        }
    }

    /// Returns this toolset with `filter` applied to tool discovery,
    /// replacing any previously configured filter.
    #[must_use]
    pub fn with_filter(self, filter: MCPToolFilter) -> Self {
        Self {
            tool_filter: Some(filter),
            ..self
        }
    }

    /// Checks connectivity by opening a session and listing the server's tools.
    ///
    /// # Errors
    ///
    /// Propagates any error from session creation. If listing tools fails,
    /// returns `AgentError::ToolSetupFailed` with `tool_name`
    /// `"mcp_connection_test"` and the listing error's `Debug` text in `reason`.
    pub async fn test_connection(&self) -> AgentResult<()> {
        info!("Testing MCP connection");
        let session = self.session_manager.create_session(None).await?;

        if let Err(e) = session.list_all_tools().await {
            error!("MCP connection test failed: {:?}", e);
            return Err(AgentError::ToolSetupFailed {
                tool_name: "mcp_connection_test".to_string(),
                reason: format!("MCP connection failed: {e:?}"),
            });
        }

        info!("MCP connection test successful");
        Ok(())
    }
}
#[cfg_attr(all(target_os = "wasi", target_env = "p1"), async_trait::async_trait(?Send))]
#[cfg_attr(
not(all(target_os = "wasi", target_env = "p1")),
async_trait::async_trait
)]
impl BaseToolset for MCPToolset {
async fn get_tools(&self) -> Vec<&dyn BaseTool> {
let tools = self
.tools
.get_or_init(|| async {
debug!("Discovering MCP tools");
let session = match self.session_manager.create_session(None).await {
Ok(s) => s,
Err(e) => {
error!("Failed to create MCP session for tool discovery: {}", e);
return Vec::new();
}
};
let tools_response = match session.list_all_tools().await {
Ok(response) => response,
Err(e) => {
error!("Failed to list MCP tools: {:?}", e);
return Vec::new();
}
};
info!("Discovered {} MCP tools", tools_response.len());
let mut radkit_tools: Vec<Box<dyn BaseTool>> = Vec::new();
for mcp_tool in tools_response {
if let Some(ref filter) = self.tool_filter {
if !filter.matches(&mcp_tool.name) {
debug!("Filtering out tool: {}", mcp_tool.name);
continue;
}
}
let tool = Box::new(MCPTool::new(
mcp_tool.name.to_string(),
mcp_tool
.description
.map(|d| d.to_string())
.unwrap_or_default(),
serde_json::Value::Object((*mcp_tool.input_schema).clone()),
Arc::clone(&self.session_manager),
));
debug!("Added MCP tool: {}", tool.name());
radkit_tools.push(tool);
}
info!(
"Created {} Radkit tools from MCP server",
radkit_tools.len()
);
radkit_tools
})
.await;
tools.iter().map(std::convert::AsRef::as_ref).collect()
}
async fn close(&self) {
info!("Closing MCPToolset");
self.session_manager.close().await;
}
}
/// Debug output shows only the configured filter; the session manager and
/// cached tools are elided (rendered as `..`).
impl std::fmt::Debug for MCPToolset {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("MCPToolset");
        builder.field("tool_filter", &self.tool_filter);
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Connection parameters that are never actually used to connect.
    fn dummy_params() -> MCPConnectionParams {
        MCPConnectionParams::Stdio {
            command: "echo".to_string(),
            args: vec!["test".to_string()],
            env: HashMap::new(),
            timeout: std::time::Duration::from_secs(5),
        }
    }

    #[test]
    fn test_tool_filter_all() {
        let filter = MCPToolFilter::All;
        assert!(filter.matches("any_tool"));
        assert!(filter.matches("another_tool"));
    }

    #[test]
    fn test_tool_filter_include() {
        let filter = MCPToolFilter::Include(vec!["tool1".to_string(), "tool2".to_string()]);
        assert!(filter.matches("tool1"));
        assert!(filter.matches("tool2"));
        assert!(!filter.matches("tool3"));
    }

    #[test]
    fn test_tool_filter_include_empty_matches_nothing() {
        let filter = MCPToolFilter::Include(Vec::new());
        assert!(!filter.matches("tool1"));
        assert!(!filter.matches(""));
    }

    #[test]
    fn test_tool_filter_exclude() {
        let filter = MCPToolFilter::Exclude(vec!["bad_tool".to_string()]);
        assert!(filter.matches("good_tool"));
        assert!(!filter.matches("bad_tool"));
    }

    #[test]
    fn test_tool_filter_exclude_empty_matches_everything() {
        let filter = MCPToolFilter::Exclude(Vec::new());
        assert!(filter.matches("tool1"));
        assert!(filter.matches(""));
    }

    #[test]
    fn test_tool_filter_requires_exact_name() {
        // Matching is whole-string and case-sensitive: no prefix/substring hits.
        let filter = MCPToolFilter::Include(vec!["tool".to_string()]);
        assert!(!filter.matches("Tool"));
        assert!(!filter.matches("tool_extra"));
        assert!(!filter.matches("to"));
    }

    #[test]
    fn test_mcp_toolset_creation() {
        let toolset = MCPToolset::new(dummy_params());
        assert!(toolset.tool_filter.is_none());
        let toolset = toolset.with_filter(MCPToolFilter::All);
        assert!(matches!(toolset.tool_filter, Some(MCPToolFilter::All)));
    }

    #[test]
    fn test_mcp_toolset_debug_shows_only_filter() {
        let toolset = MCPToolset::new(dummy_params()).with_filter(MCPToolFilter::All);
        assert_eq!(
            format!("{toolset:?}"),
            "MCPToolset { tool_filter: Some(All), .. }"
        );
    }
}