use super::transport::{HttpTransport, McpTransport, StdioTransport};
use super::types::*;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Client for a Model Context Protocol (MCP) server.
///
/// The boxed [`McpTransport`] sits behind an async mutex; every
/// request/response round-trip in `send_request` holds the lock for its full
/// duration, so requests issued through one client are serialized. The
/// server's identity and capabilities are cached after a successful
/// `initialize` handshake.
pub struct McpClient {
    /// Shared transport; locked for the duration of each round-trip.
    transport: Arc<Mutex<Box<dyn McpTransport>>>,
    /// Server identity from the last successful `initialize` call.
    server_info: Option<ServerInfo>,
    /// Server capabilities from the last successful `initialize` call.
    capabilities: Option<ServerCapabilities>,
}

impl McpClient {
    /// MCP protocol revision advertised in the `initialize` request.
    const PROTOCOL_VERSION: &'static str = "2024-11-05";

    /// Creates a [`StdioTransport`] for `command` (with `args` and optional
    /// `env`) and performs the `initialize` handshake over it.
    ///
    /// # Errors
    /// Returns an error if the transport cannot be created or the handshake
    /// fails. On handshake failure the transport is closed (best effort)
    /// before the error is returned.
    pub async fn connect_stdio(
        command: &str,
        args: &[&str],
        env: Option<HashMap<String, String>>,
    ) -> Result<Self, McpError> {
        let transport = StdioTransport::new(command, args, env).await?;
        Self::connect_with(Box::new(transport)).await
    }

    /// Creates an [`HttpTransport`] for `url` and performs the `initialize`
    /// handshake over it.
    ///
    /// # Errors
    /// Returns an error if the transport cannot be created or the handshake
    /// fails. On handshake failure the transport is closed (best effort)
    /// before the error is returned.
    pub async fn connect_http(url: &str) -> Result<Self, McpError> {
        let transport = HttpTransport::new(url)?;
        Self::connect_with(Box::new(transport)).await
    }

    /// Wraps an existing transport without performing any handshake.
    ///
    /// The caller must invoke [`McpClient::initialize`] before issuing other
    /// requests; until then `server_info()` and `capabilities()` return `None`.
    pub fn from_transport(transport: Box<dyn McpTransport>) -> Self {
        Self {
            transport: Arc::new(Mutex::new(transport)),
            server_info: None,
            capabilities: None,
        }
    }

    /// Wraps `transport` and runs the handshake. If the handshake fails, the
    /// transport is closed so that resources it owns are not leaked.
    async fn connect_with(transport: Box<dyn McpTransport>) -> Result<Self, McpError> {
        let mut client = Self::from_transport(transport);
        let handshake = client.initialize().await;
        if let Err(err) = handshake {
            // Best effort: the handshake error is what the caller needs to
            // see. A secondary failure while closing is deliberately dropped.
            let _ = client.close().await;
            return Err(err);
        }
        Ok(client)
    }

    /// Performs the MCP `initialize` handshake and caches the server's info
    /// and capabilities.
    ///
    /// After a successful `initialize` response, a `notifications/initialized`
    /// message is sent; any error from that send is ignored.
    ///
    /// # Errors
    /// Returns an error if the request fails, the server answers with a
    /// JSON-RPC error, or the result does not deserialize into
    /// `InitializeResult`.
    pub async fn initialize(&mut self) -> Result<ServerInfo, McpError> {
        let params = serde_json::json!({
            "protocolVersion": Self::PROTOCOL_VERSION,
            "capabilities": {},
            "clientInfo": ClientInfo::default()
        });
        let request = JsonRpcRequest::new("initialize", Some(params));
        let response = self.send_request(request).await?;
        let result: InitializeResult = serde_json::from_value(response)?;
        self.server_info = Some(result.server_info.clone());
        self.capabilities = Some(result.capabilities);

        // NOTE(review): `notifications/initialized` is a JSON-RPC notification
        // (no id, no reply expected), yet it goes through `send_request`,
        // which awaits a response while holding the transport lock. Whether
        // this returns promptly depends on the transport and on whether
        // `JsonRpcRequest::new` attaches an id. Confirm; a dedicated notify
        // path on `McpTransport` may be needed.
        let notify = JsonRpcRequest::new("notifications/initialized", None);
        let _ = self.send_request(notify).await;
        Ok(result.server_info)
    }

    /// Lists the tools the server exposes (`tools/list`).
    ///
    /// NOTE(review): only the single response is decoded. If the server
    /// paginates this list (MCP `nextCursor`), later pages are not fetched.
    /// Confirm against `ToolsListResult`.
    ///
    /// # Errors
    /// Returns transport, JSON-RPC, or deserialization errors.
    pub async fn list_tools(&self) -> Result<Vec<McpToolInfo>, McpError> {
        let request = JsonRpcRequest::new("tools/list", Some(serde_json::json!({})));
        let response = self.send_request(request).await?;
        let result: ToolsListResult = serde_json::from_value(response)?;
        Ok(result.tools)
    }

    /// Invokes tool `name` with `arguments` (`tools/call`) and returns the
    /// decoded result.
    ///
    /// # Errors
    /// Returns transport, JSON-RPC, or deserialization errors.
    pub async fn call_tool(
        &self,
        name: &str,
        arguments: serde_json::Value,
    ) -> Result<McpToolCallResult, McpError> {
        let params = serde_json::json!({
            "name": name,
            "arguments": arguments
        });
        let request = JsonRpcRequest::new("tools/call", Some(params));
        let response = self.send_request(request).await?;
        Ok(serde_json::from_value(response)?)
    }

    /// Closes the underlying transport, waiting for the transport lock first.
    ///
    /// # Errors
    /// Propagates whatever error the transport's `close` reports.
    pub async fn close(&self) -> Result<(), McpError> {
        self.transport.lock().await.close().await
    }

    /// Server identity reported by `initialize`, or `None` if the handshake
    /// has not completed.
    pub fn server_info(&self) -> Option<&ServerInfo> {
        self.server_info.as_ref()
    }

    /// Server capabilities reported by `initialize`, or `None` if the
    /// handshake has not completed.
    pub fn capabilities(&self) -> Option<&ServerCapabilities> {
        self.capabilities.as_ref()
    }

    /// Sends `request` and returns the JSON-RPC `result` payload.
    ///
    /// # Errors
    /// - `McpError::JsonRpc` if the response carries an `error` object.
    /// - `McpError::Protocol` if it carries neither `result` nor `error`.
    /// - Any error from the transport itself.
    async fn send_request(&self, request: JsonRpcRequest) -> Result<serde_json::Value, McpError> {
        // The guard is a temporary of this statement, so the lock is released
        // as soon as the round-trip finishes, before the response is inspected.
        let response = self.transport.lock().await.send(request).await?;
        if let Some(error) = response.error {
            return Err(McpError::JsonRpc {
                code: error.code,
                message: error.message,
            });
        }
        response
            .result
            .ok_or_else(|| McpError::Protocol("Response has neither result nor error".into()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The identity sent in `clientInfo` during the handshake defaults to
    /// the name `yoagent` and always carries some version string.
    #[test]
    fn test_client_info_default() {
        let defaults = ClientInfo::default();
        let has_version = !defaults.version.is_empty();
        assert_eq!(defaults.name, "yoagent");
        assert!(has_version);
    }
}