1use crate::tool::ToolDefinition;
2pub use mcp_client::Error;
3pub use mcp_client::McpService;
4pub use mcp_client::client::{ClientCapabilities, ClientInfo};
5use mcp_client::client::{McpClient, McpClientTrait};
6pub use mcp_client::transport::{SseTransport, StdioTransport, Transport};
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::ops::Deref;
10use std::path::Path;
11use std::time::Duration;
12use thiserror::Error;
13use tracing_subscriber::EnvFilter;
14
15#[derive(Deserialize)]
16pub struct MCPServerConfig {
17 pub command: String,
18 #[serde(default)]
19 pub args: Vec<String>,
20 #[serde(default)]
21 pub env: HashMap<String, String>,
22}
23
24#[derive(Deserialize)]
25pub struct MCPConfig {
26 #[serde(rename = "mcpServers")]
27 pub mcp_servers: HashMap<String, MCPServerConfig>,
28}
29
30pub struct MCPClient {
31 pub client: Box<dyn McpClientTrait>,
32 pub tools: HashMap<String, ToolDefinition>,
33}
34
35impl Deref for MCPClient {
36 type Target = Box<dyn McpClientTrait>;
37
38 fn deref(&self) -> &Self::Target {
39 &self.client
40 }
41}
42
43pub async fn setup_mcp_clients<P: AsRef<Path>>(
47 path: P,
48) -> Result<HashMap<String, MCPClient>, MCPError> {
49 let path = path.as_ref();
50 let config_str = tokio::fs::read_to_string(&path).await?;
51 let config: MCPConfig = serde_json::from_str(&config_str)?;
52
53 let mut mcp_clients_map = HashMap::new();
54
55 for (server_name, server_conf) in config.mcp_servers {
57 let client = stdio_client(server_conf.command, server_conf.args, server_conf.env).await?;
58 mcp_clients_map.insert(server_name, client);
59 }
60
61 Ok(mcp_clients_map)
62}
63
64#[derive(Error, Debug)]
65pub enum MCPError {
66 #[error("Failed to read config file: {0}")]
67 ConfigReadError(#[from] std::io::Error),
68 #[error("Failed to parse config file: {0}")]
69 ConfigParseError(#[from] serde_json::Error),
70 #[error("MCP error {0}")]
71 MCPError(#[from] Error),
72}
73
74pub async fn sse_client<S: AsRef<str>>(
76 sse_url: S,
77 env: HashMap<String, String>,
78) -> Result<MCPClient, MCPError> {
79 tracing_subscriber::fmt()
81 .with_env_filter(
82 EnvFilter::from_default_env()
83 .add_directive("mcp_client=debug".parse().unwrap())
84 .add_directive("eventsource_client=info".parse().unwrap()),
85 )
86 .init();
87 let transport = SseTransport::new(sse_url.as_ref(), env);
89 let handle = transport.start().await.map_err(Error::Transport)?;
91 let service = McpService::with_timeout(handle, Duration::from_secs(3));
93 let mut client = McpClient::new(service);
95 client
97 .initialize(
98 ClientInfo {
99 name: "alith-client".into(),
100 version: "1.0.0".into(),
101 },
102 ClientCapabilities::default(),
103 )
104 .await?;
105 tokio::time::sleep(Duration::from_millis(100)).await;
107 let tool_result = client.list_tools(None).await?;
108 let mut tools = HashMap::new();
109 for tool in tool_result.tools {
110 tools.insert(
111 tool.name.clone(),
112 ToolDefinition {
113 name: tool.name,
114 description: tool.description,
115 parameters: tool.input_schema,
116 },
117 );
118 }
119 Ok(MCPClient {
120 client: Box::new(client),
121 tools,
122 })
123}
124
125pub async fn stdio_client<S: AsRef<str>>(
126 command: S,
127 args: Vec<S>,
128 env: HashMap<String, String>,
129) -> Result<MCPClient, MCPError> {
130 tracing_subscriber::fmt()
132 .with_env_filter(
133 EnvFilter::from_default_env()
134 .add_directive("mcp_client=debug".parse().unwrap())
135 .add_directive("eventsource_client=info".parse().unwrap()),
136 )
137 .init();
138 let transport = StdioTransport::new(
140 command.as_ref().to_string(),
141 args.iter().map(|s| s.as_ref().to_string()).collect(),
142 env,
143 );
144 let handle = transport.start().await.map_err(Error::Transport)?;
146 let service = McpService::with_timeout(handle, Duration::from_secs(3));
148 let mut client = McpClient::new(service);
150 client
152 .initialize(
153 ClientInfo {
154 name: "alith-client".into(),
155 version: "1.0.0".into(),
156 },
157 ClientCapabilities::default(),
158 )
159 .await?;
160 tokio::time::sleep(Duration::from_millis(100)).await;
162 let tool_result = client.list_tools(None).await?;
163 let mut tools = HashMap::new();
164 for tool in tool_result.tools {
165 tools.insert(
166 tool.name.clone(),
167 ToolDefinition {
168 name: tool.name,
169 description: tool.description,
170 parameters: tool.input_schema,
171 },
172 );
173 }
174 Ok(MCPClient {
175 client: Box::new(client),
176 tools,
177 })
178}