Skip to main content

abu_agent/toolbox/
mcp.rs

1use std::{collections::HashMap, ffi::OsStr, path::Path};
2use abu_mcp::{client::McpClient, transport::process::McpProcessTransport, McpToolCall, McpToolCallResult, McpToolCallResultContent};
3use abu_tool::{ToolCallResult, ToolError};
4use thiserrorctx::Context;
5use serde::Deserialize;
6use tracing::{debug, warn};
7use crate::AgentResult;
8
9pub struct McpManager {
10    pub default_protocol_version: String,
11    pub stdio_servers: Vec<McpClient<McpProcessTransport>>
12}
13
14#[derive(Debug, Clone, Deserialize)]
15pub struct McpConfig {
16    #[serde(default = "default_protocol_version", alias = "defaultProtocolVersion")]
17    pub default_protocol_version: String,
18    #[serde(alias = "mcpServers")]
19    pub mcp_servers: HashMap<String, McpServerConfig>,
20}
21
22#[derive(Debug, Clone, Deserialize)]
23pub struct McpServerConfig {
24    pub transport: String,
25    pub command: String,
26    #[serde(default)]
27    pub args: Vec<String>,
28    #[serde(default)]
29    pub env: HashMap<String, String>,
30}
31
32impl McpManager {
33    pub fn new() -> Self {
34        Self {
35            default_protocol_version: default_protocol_version(),
36            stdio_servers: vec![],
37        }
38    }
39
40    pub async fn load_config(path: impl AsRef<Path>) -> AgentResult<Self> {
41        debug!("load mcp config from {}", path.as_ref().display());
42        let context = std::fs::read_to_string(path).context("read config file")?;
43        let config: McpConfig = serde_json::from_str(&context).context("parse config file")?;
44
45        let mut mcp_manager = McpManager { default_protocol_version: config.default_protocol_version, stdio_servers: vec![],};
46        for (name, server_config) in config.mcp_servers {
47            debug!("add mcp server {}", name);
48            match server_config.transport.as_str() {
49                "stdio" => {
50                    mcp_manager.add_stdio_server(server_config.command, server_config.args) 
51                        .await.with_context(|| format!("init client {}", name))?;
52                }
53                transport => warn!("unsupport transport '{}' in mcpserver {}", transport, name),
54            };
55        }
56
57        Ok(mcp_manager)
58    }
59
60    pub async fn add_stdio_server<S, I>(&mut self, cmd: S, args: I) -> AgentResult<&McpClient<McpProcessTransport>> 
61    where 
62        I: IntoIterator<Item = S>,
63        S: AsRef<OsStr>,
64    {
65        let client = Self::init_stdio_clinet(cmd, args).await?;
66        self.stdio_servers.push(client);
67        Ok(self.stdio_servers.last().unwrap())
68    }
69
70    pub async fn execute_toolcall(&mut self, name: &str, arguments: serde_json::Value) -> AgentResult<ToolCallResult> {
71        for client in self.stdio_servers.iter_mut() {
72            if client.has_tool(&name) {
73                let mcp_tool_call = McpToolCall {
74                    name: name.to_string(), arguments: Some(arguments)
75                };
76                let mcp_tool_call_result = client.tools_call(mcp_tool_call).await?;
77                let tool_call_result = mcp_tool_call_result_to_tool_call_result(mcp_tool_call_result);
78                return Ok(tool_call_result)
79            }
80        }
81        Err(ToolError::ToolNotFound(name.to_string()))?
82    }
83
84    pub fn has_tool(&self, tool_name: &str) -> bool {
85        for client in self.stdio_servers.iter() {
86            if client.has_tool(tool_name) {
87                return true;
88            }
89        }
90        false
91    }
92
93    pub async fn init_stdio_clinet<I, S>(cmd: S, args: I) -> AgentResult<McpClient<McpProcessTransport>> 
94    where 
95        I: IntoIterator<Item = S>,
96        S: AsRef<OsStr>,
97    {
98        let transport = McpProcessTransport::new(cmd, args)
99            .context("new process transport")?;
100        let mut client = McpClient::new(transport);
101        client.initialize().await.context("initialize mcpserver")?;
102        client.tools_list().await.context("tools_list mcpserver")?;
103        Ok(client)
104    }
105}
106
107fn default_protocol_version() -> String {
108    abu_mcp::LATEST_PROTOCOL_VERSION.to_string()
109}
110
111fn mcp_tool_call_result_to_tool_call_result(result: McpToolCallResult) -> ToolCallResult {
112    let is_error = result.is_error.unwrap_or(false);
113    let context = result
114        .content
115        .iter()
116        .map(|content| {
117            match content {
118                McpToolCallResultContent::Text { text } => text.as_str(),
119            }
120        })
121        .collect::<Vec<&str>>()
122        .join("\n");
123    ToolCallResult { is_error, context }
124}