1use std::{collections::HashMap, ffi::OsStr, path::Path};
2use abu_mcp::{client::McpClient, transport::process::McpProcessTransport, McpToolCall, McpToolCallResult, McpToolCallResultContent};
3use abu_tool::{ToolCallResult, ToolError};
4use thiserrorctx::Context;
5use serde::Deserialize;
6use tracing::{debug, warn};
7use crate::AgentResult;
8
9pub struct McpManager {
10 pub default_protocol_version: String,
11 pub stdio_servers: Vec<McpClient<McpProcessTransport>>
12}
13
14#[derive(Debug, Clone, Deserialize)]
15pub struct McpConfig {
16 #[serde(default = "default_protocol_version", alias = "defaultProtocolVersion")]
17 pub default_protocol_version: String,
18 #[serde(alias = "mcpServers")]
19 pub mcp_servers: HashMap<String, McpServerConfig>,
20}
21
22#[derive(Debug, Clone, Deserialize)]
23pub struct McpServerConfig {
24 pub transport: String,
25 pub command: String,
26 #[serde(default)]
27 pub args: Vec<String>,
28 #[serde(default)]
29 pub env: HashMap<String, String>,
30}
31
32impl McpManager {
33 pub fn new() -> Self {
34 Self {
35 default_protocol_version: default_protocol_version(),
36 stdio_servers: vec![],
37 }
38 }
39
40 pub async fn load_config(path: impl AsRef<Path>) -> AgentResult<Self> {
41 debug!("load mcp config from {}", path.as_ref().display());
42 let context = std::fs::read_to_string(path).context("read config file")?;
43 let config: McpConfig = serde_json::from_str(&context).context("parse config file")?;
44
45 let mut mcp_manager = McpManager { default_protocol_version: config.default_protocol_version, stdio_servers: vec![],};
46 for (name, server_config) in config.mcp_servers {
47 debug!("add mcp server {}", name);
48 match server_config.transport.as_str() {
49 "stdio" => {
50 mcp_manager.add_stdio_server(server_config.command, server_config.args)
51 .await.with_context(|| format!("init client {}", name))?;
52 }
53 transport => warn!("unsupport transport '{}' in mcpserver {}", transport, name),
54 };
55 }
56
57 Ok(mcp_manager)
58 }
59
60 pub async fn add_stdio_server<S, I>(&mut self, cmd: S, args: I) -> AgentResult<&McpClient<McpProcessTransport>>
61 where
62 I: IntoIterator<Item = S>,
63 S: AsRef<OsStr>,
64 {
65 let client = Self::init_stdio_clinet(cmd, args).await?;
66 self.stdio_servers.push(client);
67 Ok(self.stdio_servers.last().unwrap())
68 }
69
70 pub async fn execute_toolcall(&mut self, name: String, arguments: serde_json::Value) -> AgentResult<ToolCallResult> {
71 for client in self.stdio_servers.iter_mut() {
72 if client.has_tool(&name) {
73 let mcp_tool_call = McpToolCall {
74 name, arguments: Some(arguments)
75 };
76 let mcp_tool_call_result = client.tools_call(mcp_tool_call).await?;
77 let tool_call_result = mcp_tool_call_result_to_tool_call_result(mcp_tool_call_result);
78 return Ok(tool_call_result)
79 }
80 }
81 Err(ToolError::ToolNotFound(name))?
82 }
83
84 pub fn has_tool(&self, tool_name: &str) -> bool {
85 for client in self.stdio_servers.iter() {
86 if client.has_tool(tool_name) {
87 return true;
88 }
89 }
90 false
91 }
92
93 pub async fn init_stdio_clinet<I, S>(cmd: S, args: I) -> AgentResult<McpClient<McpProcessTransport>>
94 where
95 I: IntoIterator<Item = S>,
96 S: AsRef<OsStr>,
97 {
98 let transport = McpProcessTransport::new(cmd, args)
99 .context("new process transport")?;
100 let mut client = McpClient::new(transport);
101 client.initialize().await.context("initialize mcpserver")?;
102 client.tools_list().await.context("tools_list mcpserver")?;
103 Ok(client)
104 }
105}
106
107fn default_protocol_version() -> String {
108 abu_mcp::LATEST_PROTOCOL_VERSION.to_string()
109}
110
111fn mcp_tool_call_result_to_tool_call_result(result: McpToolCallResult) -> ToolCallResult {
112 let is_error = result.is_error.unwrap_or(false);
113 let context = result
114 .content
115 .iter()
116 .map(|content| {
117 match content {
118 McpToolCallResultContent::Text { text } => text.as_str(),
119 }
120 })
121 .collect::<Vec<&str>>()
122 .join("\n");
123 ToolCallResult { is_error, context }
124}