radkit 0.0.5

Rust AI Agent Development Kit
Documentation
use super::{MCPConnectionParams, MCPSessionManager, MCPTool};
use crate::errors::{AgentError, AgentResult};
use crate::tools::{BaseTool, BaseToolset};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::OnceCell;
use tracing::{debug, error, info};

/// Filter for selecting which MCP tools to include.
///
/// The filter is evaluated per tool name via [`MCPToolFilter::matches`].
/// Name comparison is an exact, case-sensitive string equality check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MCPToolFilter {
    /// Include all tools.
    All,
    /// Include only tools whose name exactly equals one of these entries.
    /// An empty list therefore includes no tools at all.
    Include(Vec<String>),
    /// Exclude tools whose name exactly equals one of these entries; every
    /// other tool is included. An empty list therefore excludes nothing.
    Exclude(Vec<String>),
}

impl MCPToolFilter {
    /// Decide whether the tool named `tool_name` passes this filter.
    ///
    /// * `All` accepts every name.
    /// * `Include(list)` accepts a name only if it appears in `list`.
    /// * `Exclude(list)` accepts a name only if it does NOT appear in `list`.
    ///
    /// Membership is an exact, case-sensitive comparison against each entry.
    #[must_use]
    pub fn matches(&self, tool_name: &str) -> bool {
        // Shared membership test used by both list-based variants.
        let is_listed = |list: &[String]| list.iter().any(|entry| entry.as_str() == tool_name);

        match self {
            Self::All => true,
            Self::Include(allowed) => is_listed(allowed),
            Self::Exclude(denied) => !is_listed(denied),
        }
    }
}

/// `MCPToolset` that lazily discovers tools and manages connections.
///
/// Tools are discovered from the MCP server on demand (not in `new`) and the
/// resulting list is cached in `tools`. The session manager is shared via
/// `Arc` with every `MCPTool` created from this set.
pub struct MCPToolset {
    /// Creates sessions to the MCP server; an `Arc` clone is handed to each tool.
    session_manager: Arc<MCPSessionManager>,
    /// Optional name filter applied during discovery; `None` keeps every tool.
    tool_filter: Option<MCPToolFilter>,
    /// Cached tools discovered from the MCP server (lazily initialized).
    tools: OnceCell<Vec<Box<dyn BaseTool>>>,
}

impl MCPToolset {
    /// Build a toolset for the MCP server described by `connection_params`.
    ///
    /// This constructor does not create a session or discover tools; discovery
    /// is deferred until tools are first requested. No filter is set, so every
    /// discovered tool is exposed until [`MCPToolset::with_filter`] is used.
    #[must_use]
    pub fn new(connection_params: MCPConnectionParams) -> Self {
        let session_manager = Arc::new(MCPSessionManager::new(connection_params));
        Self {
            session_manager,
            tool_filter: None,
            tools: OnceCell::new(),
        }
    }

    /// Restrict the exposed tools to those accepted by `filter`.
    ///
    /// Replaces any previously configured filter.
    #[must_use]
    pub fn with_filter(self, filter: MCPToolFilter) -> Self {
        Self {
            tool_filter: Some(filter),
            ..self
        }
    }

    /// Test the connection to the MCP server.
    ///
    /// Opens a session and asks the server for its tool list; the list itself
    /// is discarded, only success or failure matters.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection test fails or tools cannot be listed.
    pub async fn test_connection(&self) -> AgentResult<()> {
        info!("Testing MCP connection");
        let session = self.session_manager.create_session(None).await?;

        // Listing tools serves as the connection probe.
        if let Err(e) = session.list_all_tools().await {
            error!("MCP connection test failed: {:?}", e);
            return Err(AgentError::ToolSetupFailed {
                tool_name: "mcp_connection_test".to_string(),
                reason: format!("MCP connection failed: {e:?}"),
            });
        }

        info!("MCP connection test successful");
        Ok(())
    }
}

#[cfg_attr(all(target_os = "wasi", target_env = "p1"), async_trait::async_trait(?Send))]
#[cfg_attr(
    not(all(target_os = "wasi", target_env = "p1")),
    async_trait::async_trait
)]
impl BaseToolset for MCPToolset {
    /// Return the tools exposed by the MCP server, discovering them on first use.
    ///
    /// The first *successful* discovery is cached in `self.tools` and reused by
    /// every later call. If discovery fails (the session cannot be created or the
    /// tool list cannot be fetched), the failure is logged, an empty list is
    /// returned, and nothing is cached, so the next call retries discovery
    /// instead of being stuck with no tools permanently.
    async fn get_tools(&self) -> Vec<&dyn BaseTool> {
        let discovered = self
            .tools
            // `get_or_try_init` leaves the cell uninitialized when the future yields
            // `Err`; `get_or_init` would have cached the empty fallback forever.
            .get_or_try_init(|| async {
                debug!("Discovering MCP tools");

                // Create session for tool discovery
                let session = self
                    .session_manager
                    .create_session(None)
                    .await
                    .map_err(|e| {
                        error!("Failed to create MCP session for tool discovery: {}", e);
                    })?;

                // Discover tools from server
                let tools_response = session.list_all_tools().await.map_err(|e| {
                    error!("Failed to list MCP tools: {:?}", e);
                })?;

                info!("Discovered {} MCP tools", tools_response.len());

                // Convert to Radkit tools, applying the optional name filter.
                let mut radkit_tools: Vec<Box<dyn BaseTool>> = Vec::new();
                for mcp_tool in tools_response {
                    if let Some(ref filter) = self.tool_filter {
                        if !filter.matches(&mcp_tool.name) {
                            debug!("Filtering out tool: {}", mcp_tool.name);
                            continue;
                        }
                    }

                    // Each tool shares the same session manager for later invocations.
                    let tool = Box::new(MCPTool::new(
                        mcp_tool.name.to_string(),
                        mcp_tool
                            .description
                            .map(|d| d.to_string())
                            .unwrap_or_default(),
                        serde_json::Value::Object((*mcp_tool.input_schema).clone()),
                        Arc::clone(&self.session_manager),
                    ));

                    debug!("Added MCP tool: {}", tool.name());
                    radkit_tools.push(tool);
                }

                info!(
                    "Created {} Radkit tools from MCP server",
                    radkit_tools.len()
                );
                Ok::<Vec<Box<dyn BaseTool>>, ()>(radkit_tools)
            })
            .await;

        match discovered {
            // Return references into the cached tools.
            Ok(tools) => tools.iter().map(std::convert::AsRef::as_ref).collect(),
            // The failure was already logged inside the initializer.
            Err(()) => Vec::new(),
        }
    }

    /// Close the underlying MCP session manager.
    async fn close(&self) {
        info!("Closing MCPToolset");
        self.session_manager.close().await;
    }
}

impl std::fmt::Debug for MCPToolset {
    /// Format the toolset for diagnostics.
    ///
    /// Only `tool_filter` is printed; the remaining fields are elided and
    /// signalled with a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("MCPToolset");
        builder.field("tool_filter", &self.tool_filter);
        builder.finish_non_exhaustive()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::time::Duration;

    /// Turn a slice of string literals into the owned names the filters store.
    fn names(list: &[&str]) -> Vec<String> {
        list.iter().map(|s| (*s).to_string()).collect()
    }

    #[test]
    fn test_tool_filter_all() {
        let filter = MCPToolFilter::All;
        for candidate in ["any_tool", "another_tool"] {
            assert!(filter.matches(candidate));
        }
    }

    #[test]
    fn test_tool_filter_include() {
        let filter = MCPToolFilter::Include(names(&["tool1", "tool2"]));
        assert!(filter.matches("tool1"));
        assert!(filter.matches("tool2"));
        assert!(!filter.matches("tool3"));
    }

    #[test]
    fn test_tool_filter_exclude() {
        let filter = MCPToolFilter::Exclude(names(&["bad_tool"]));
        assert!(filter.matches("good_tool"));
        assert!(!filter.matches("bad_tool"));
    }

    #[test]
    fn test_mcp_toolset_creation() {
        let stdio = MCPConnectionParams::Stdio {
            command: "echo".to_string(),
            args: names(&["test"]),
            env: HashMap::new(),
            timeout: Duration::from_secs(5),
        };

        let unfiltered = MCPToolset::new(stdio);
        assert!(unfiltered.tool_filter.is_none());

        let filtered = unfiltered.with_filter(MCPToolFilter::All);
        assert!(matches!(filtered.tool_filter, Some(MCPToolFilter::All)));
    }
}