alith_core/
mcp.rs

1use crate::tool::ToolDefinition;
2pub use mcp_client::Error;
3pub use mcp_client::McpService;
4pub use mcp_client::client::{ClientCapabilities, ClientInfo};
5use mcp_client::client::{McpClient, McpClientTrait};
6pub use mcp_client::transport::{SseTransport, StdioTransport, Transport};
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::ops::Deref;
10use std::path::Path;
11use std::time::Duration;
12use thiserror::Error;
13use tracing_subscriber::EnvFilter;
14
15#[derive(Deserialize)]
16pub struct MCPServerConfig {
17    pub command: String,
18    #[serde(default)]
19    pub args: Vec<String>,
20    #[serde(default)]
21    pub env: HashMap<String, String>,
22}
23
24#[derive(Deserialize)]
25pub struct MCPConfig {
26    #[serde(rename = "mcpServers")]
27    pub mcp_servers: HashMap<String, MCPServerConfig>,
28}
29
30pub struct MCPClient {
31    pub client: Box<dyn McpClientTrait>,
32    pub tools: HashMap<String, ToolDefinition>,
33}
34
35impl Deref for MCPClient {
36    type Target = Box<dyn McpClientTrait>;
37
38    fn deref(&self) -> &Self::Target {
39        &self.client
40    }
41}
42
43/// Set up MCP clients config from the path, spawning each server,
44/// and returning a HashMap<server_name -> Arc<Client>>.
45/// spawn a single MCP process per server, share references.
46pub async fn setup_mcp_clients<P: AsRef<Path>>(
47    path: P,
48) -> Result<HashMap<String, MCPClient>, MCPError> {
49    let path = path.as_ref();
50    let config_str = tokio::fs::read_to_string(&path).await?;
51    let config: MCPConfig = serde_json::from_str(&config_str)?;
52
53    let mut mcp_clients_map = HashMap::new();
54
55    // For each server in the config, spawn an MCP client
56    for (server_name, server_conf) in config.mcp_servers {
57        let client = stdio_client(server_conf.command, server_conf.args, server_conf.env).await?;
58        mcp_clients_map.insert(server_name, client);
59    }
60
61    Ok(mcp_clients_map)
62}
63
64#[derive(Error, Debug)]
65pub enum MCPError {
66    #[error("Failed to read config file: {0}")]
67    ConfigReadError(#[from] std::io::Error),
68    #[error("Failed to parse config file: {0}")]
69    ConfigParseError(#[from] serde_json::Error),
70    #[error("MCP error {0}")]
71    MCPError(#[from] Error),
72}
73
74/// Create a sse mcp client.
75pub async fn sse_client<S: AsRef<str>>(
76    sse_url: S,
77    env: HashMap<String, String>,
78) -> Result<MCPClient, MCPError> {
79    // Initialize logging
80    tracing_subscriber::fmt()
81        .with_env_filter(
82            EnvFilter::from_default_env()
83                .add_directive("mcp_client=debug".parse().unwrap())
84                .add_directive("eventsource_client=info".parse().unwrap()),
85        )
86        .init();
87    // Create the base transport
88    let transport = SseTransport::new(sse_url.as_ref(), env);
89    // Start transport
90    let handle = transport.start().await.map_err(Error::Transport)?;
91    // Create the service with timeout middleware
92    let service = McpService::with_timeout(handle, Duration::from_secs(3));
93    // Create client
94    let mut client = McpClient::new(service);
95    // Initialize
96    client
97        .initialize(
98            ClientInfo {
99                name: "alith-client".into(),
100                version: "1.0.0".into(),
101            },
102            ClientCapabilities::default(),
103        )
104        .await?;
105    // Sleep for 100ms to allow the server to start - surprisingly this is required!
106    tokio::time::sleep(Duration::from_millis(100)).await;
107    let tool_result = client.list_tools(None).await?;
108    let mut tools = HashMap::new();
109    for tool in tool_result.tools {
110        tools.insert(
111            tool.name.clone(),
112            ToolDefinition {
113                name: tool.name,
114                description: tool.description,
115                parameters: tool.input_schema,
116            },
117        );
118    }
119    Ok(MCPClient {
120        client: Box::new(client),
121        tools,
122    })
123}
124
125pub async fn stdio_client<S: AsRef<str>>(
126    command: S,
127    args: Vec<S>,
128    env: HashMap<String, String>,
129) -> Result<MCPClient, MCPError> {
130    // Initialize logging
131    tracing_subscriber::fmt()
132        .with_env_filter(
133            EnvFilter::from_default_env()
134                .add_directive("mcp_client=debug".parse().unwrap())
135                .add_directive("eventsource_client=info".parse().unwrap()),
136        )
137        .init();
138    // Create the base transport
139    let transport = StdioTransport::new(
140        command.as_ref().to_string(),
141        args.iter().map(|s| s.as_ref().to_string()).collect(),
142        env,
143    );
144    // Start transport
145    let handle = transport.start().await.map_err(Error::Transport)?;
146    // Create the service with timeout middleware
147    let service = McpService::with_timeout(handle, Duration::from_secs(3));
148    // Create client
149    let mut client = McpClient::new(service);
150    // Initialize
151    client
152        .initialize(
153            ClientInfo {
154                name: "alith-client".into(),
155                version: "1.0.0".into(),
156            },
157            ClientCapabilities::default(),
158        )
159        .await?;
160    // Sleep for 100ms to allow the server to start - surprisingly this is required!
161    tokio::time::sleep(Duration::from_millis(100)).await;
162    let tool_result = client.list_tools(None).await?;
163    let mut tools = HashMap::new();
164    for tool in tool_result.tools {
165        tools.insert(
166            tool.name.clone(),
167            ToolDefinition {
168                name: tool.name,
169                description: tool.description,
170                parameters: tool.input_schema,
171            },
172        );
173    }
174    Ok(MCPClient {
175        client: Box::new(client),
176        tools,
177    })
178}