use anyhow::Result;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicI64, Ordering};
use std::time::Duration;
use tokio::sync::{Mutex, RwLock, mpsc};
use tokio::time::timeout;
use tracing::{debug, error, info, warn};
use super::jsonrpc::{JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, RequestId};
use super::transport::Transport;
/// A tool advertised by an MCP server in its `tools/list` response.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct McpTool {
    /// Tool name; the client caches tools by this name and sends it as
    /// `name` in `tools/call`.
    pub name: String,
    /// Human-readable description. The MCP schema marks it optional, so a
    /// missing field deserializes to an empty string instead of causing the
    /// whole tool definition to be rejected.
    #[serde(default)]
    pub description: String,
    /// JSON Schema for the tool's arguments (`inputSchema` on the wire).
    #[serde(rename = "inputSchema")]
    pub input_schema: Value,
}
/// Decoded payload of a `tools/call` result.
#[derive(Clone, Debug)]
pub struct McpToolResult {
    /// Content items returned by the tool, in the order the server sent them.
    pub content: Vec<McpContent>,
    /// The result's `isError` flag; treated as `false` when the server omits it.
    pub is_error: bool,
}
/// One content item from a tool result.
#[derive(Clone, Debug)]
pub enum McpContent {
    /// Plain text content.
    Text { text: String },
    /// Binary content carried as a string `data` payload plus its MIME type.
    Image { data: String, mime_type: String },
}
/// JSON-RPC client for a Model Context Protocol (MCP) server.
///
/// Outgoing requests are correlated with incoming responses by request id via
/// `pending_requests`; tools discovered during start-up are cached in `tools`.
pub struct McpClient {
    /// Transport shared by request senders and the background receive loop.
    transport: Arc<Mutex<Box<dyn Transport>>>,
    /// Counter used to mint request ids.
    request_id: AtomicI64,
    /// Requests awaiting a response, keyed by id; each sender is completed
    /// with the matching response.
    pending_requests:
        Arc<RwLock<HashMap<RequestId, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
    /// Tools reported by the server, keyed by tool name.
    tools: Arc<RwLock<HashMap<String, McpTool>>>,
    /// Client settings (timeouts, identity sent during `initialize`).
    config: McpClientConfig,
    /// Stops the receive loop when signalled; `None` until the loop is started.
    shutdown_tx: Option<mpsc::Sender<()>>,
}
/// Tunables for [`McpClient`].
#[derive(Clone, Debug)]
pub struct McpClientConfig {
    /// Upper bound on how long a single request waits for its response.
    pub request_timeout: Duration,
    /// Intended cap on in-flight requests.
    /// NOTE(review): no code in this module reads this field — confirm whether
    /// it is enforced elsewhere or should be wired into request sending.
    pub max_concurrent_requests: usize,
    /// Client name reported in `clientInfo` during `initialize`.
    pub client_name: String,
    /// Client version reported in `clientInfo` during `initialize`.
    pub client_version: String,
}
impl Default for McpClientConfig {
fn default() -> Self {
Self {
request_timeout: Duration::from_secs(30),
max_concurrent_requests: 100,
client_name: "ccswarm".to_string(),
client_version: env!("CARGO_PKG_VERSION").to_string(),
}
}
}
impl McpClient {
pub fn new(transport: Box<dyn Transport>) -> Self {
Self::with_config(transport, McpClientConfig::default())
}
pub fn with_config(transport: Box<dyn Transport>, config: McpClientConfig) -> Self {
Self {
transport: Arc::new(Mutex::new(transport)),
request_id: AtomicI64::new(1),
pending_requests: Arc::new(RwLock::new(HashMap::new())),
tools: Arc::new(RwLock::new(HashMap::new())),
config,
shutdown_tx: None,
}
}
pub async fn start(&mut self) -> Result<()> {
info!(
"Starting MCP client: {}/{}",
self.config.client_name, self.config.client_version
);
self.initialize().await?;
self.discover_tools().await?;
let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1);
self.shutdown_tx = Some(shutdown_tx);
let transport = self.transport.clone();
let pending_requests = self.pending_requests.clone();
tokio::spawn(async move {
loop {
tokio::select! {
_ = shutdown_rx.recv() => {
info!("MCP client shutting down");
break;
}
message_result = async {
let mut transport = transport.lock().await;
transport.receive().await
} => {
match message_result {
Ok(Some(message)) => {
if let Err(e) = Self::handle_message(message, &pending_requests).await {
error!("Error handling message: {}", e);
}
}
Ok(None) => {
warn!("Transport closed");
break;
}
Err(e) => {
error!("Error receiving message: {}", e);
}
}
}
}
}
});
Ok(())
}
pub async fn shutdown(&mut self) -> Result<()> {
if let Some(shutdown_tx) = self.shutdown_tx.take() {
let _ = shutdown_tx.send(()).await;
}
let mut transport = self.transport.lock().await;
transport.close().await?;
Ok(())
}
async fn initialize(&self) -> Result<()> {
let request = JsonRpcRequest::new(
self.next_request_id(),
"initialize".to_string(),
Some(json!({
"protocolVersion": "2024-11-05",
"capabilities": {
"tools": {}
},
"clientInfo": {
"name": self.config.client_name,
"version": self.config.client_version
}
})),
);
let response = self.send_request(request).await?;
if response.error.is_some() {
return Err(anyhow::anyhow!("Initialize failed: {:?}", response.error));
}
info!("MCP client initialized successfully");
Ok(())
}
async fn discover_tools(&self) -> Result<()> {
let request = JsonRpcRequest::new(self.next_request_id(), "tools/list".to_string(), None);
let response = self.send_request(request).await?;
if let Some(error) = response.error {
return Err(anyhow::anyhow!("Tools discovery failed: {:?}", error));
}
if let Some(result) = response.result
&& let Some(tools_array) = result.get("tools").and_then(|v| v.as_array())
{
let mut tools = self.tools.write().await;
for tool_value in tools_array {
if let Ok(tool) = serde_json::from_value::<McpTool>(tool_value.clone()) {
debug!("Discovered tool: {}", tool.name);
tools.insert(tool.name.clone(), tool);
}
}
info!("Discovered {} tools", tools.len());
}
Ok(())
}
pub async fn execute_tool(&self, tool_name: &str, arguments: Value) -> Result<McpToolResult> {
{
let tools = self.tools.read().await;
if !tools.contains_key(tool_name) {
return Err(anyhow::anyhow!("Tool '{}' not found", tool_name));
}
}
let request = JsonRpcRequest::new(
self.next_request_id(),
"tools/call".to_string(),
Some(json!({
"name": tool_name,
"arguments": arguments
})),
);
let response = self.send_request(request).await?;
if let Some(error) = response.error {
return Err(anyhow::anyhow!("Tool execution failed: {:?}", error));
}
let result = response
.result
.ok_or_else(|| anyhow::anyhow!("No result from tool execution"))?;
let content = result
.get("content")
.and_then(|v| v.as_array())
.unwrap_or(&vec![])
.iter()
.filter_map(|v| {
if let Some(text) = v.get("text").and_then(|t| t.as_str()) {
Some(McpContent::Text {
text: text.to_string(),
})
} else if let (Some(data), Some(mime_type)) = (
v.get("data").and_then(|d| d.as_str()),
v.get("mimeType").and_then(|m| m.as_str()),
) {
Some(McpContent::Image {
data: data.to_string(),
mime_type: mime_type.to_string(),
})
} else {
None
}
})
.collect();
let is_error = result
.get("isError")
.and_then(|v| v.as_bool())
.unwrap_or(false);
Ok(McpToolResult { content, is_error })
}
pub async fn list_tools(&self) -> Vec<McpTool> {
let tools = self.tools.read().await;
tools.values().cloned().collect()
}
async fn send_request(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
{
let mut pending = self.pending_requests.write().await;
pending.insert(request.id.clone(), response_tx);
}
{
let mut transport = self.transport.lock().await;
transport.send(&JsonRpcMessage::Request(request)).await?;
}
let response = timeout(self.config.request_timeout, response_rx)
.await
.map_err(|_| anyhow::anyhow!("Request timeout"))?
.map_err(|_| anyhow::anyhow!("Response channel closed"))?;
Ok(response)
}
async fn handle_message(
message: JsonRpcMessage,
pending_requests: &Arc<
RwLock<HashMap<RequestId, tokio::sync::oneshot::Sender<JsonRpcResponse>>>,
>,
) -> Result<()> {
match message {
JsonRpcMessage::Response(response) => {
let mut pending = pending_requests.write().await;
if let Some(sender) = pending.remove(&response.id) {
let _ = sender.send(response);
} else {
warn!("Received response for unknown request: {}", response.id);
}
}
JsonRpcMessage::Notification(notification) => {
debug!("Received notification: {}", notification.method);
}
JsonRpcMessage::Request(request) => {
warn!("Client received unexpected request: {}", request.method);
}
}
Ok(())
}
pub fn next_request_id(&self) -> RequestId {
RequestId::Number(self.request_id.fetch_add(1, Ordering::SeqCst))
}
}
/// Convenience wrapper over [`McpClient`] exposing a server's session tools
/// (`create_session`, `execute_command`, `get_session_info`).
pub struct AiSessionClient {
    /// Underlying MCP client that performs the tool calls.
    mcp_client: McpClient,
}
impl AiSessionClient {
pub fn new(transport: Box<dyn Transport>) -> Self {
Self {
mcp_client: McpClient::new(transport),
}
}
pub async fn start(&mut self) -> Result<()> {
self.mcp_client.start().await
}
pub async fn shutdown(&mut self) -> Result<()> {
self.mcp_client.shutdown().await
}
pub async fn create_session(
&self,
name: &str,
working_directory: Option<&str>,
) -> Result<String> {
let mut args = json!({ "name": name });
if let Some(wd) = working_directory {
args["working_directory"] = json!(wd);
}
let result = self.mcp_client.execute_tool("create_session", args).await?;
if result.is_error {
return Err(anyhow::anyhow!("Failed to create session"));
}
if let Some(McpContent::Text { text }) = result.content.first() {
if let Some(start) = text.find("ID: ") {
let id_start = start + 4;
if let Some(end) = text[id_start..].find(char::is_whitespace) {
return Ok(text[id_start..id_start + end].to_string());
} else {
return Ok(text[id_start..].to_string());
}
}
}
Err(anyhow::anyhow!(
"Could not extract session ID from response"
))
}
pub async fn execute_command(&self, session_id: &str, command: &str) -> Result<String> {
let args = json!({
"session_id": session_id,
"command": command
});
let result = self
.mcp_client
.execute_tool("execute_command", args)
.await?;
if result.is_error {
return Err(anyhow::anyhow!("Command execution failed"));
}
if let Some(McpContent::Text { text }) = result.content.first() {
Ok(text.clone())
} else {
Ok(String::new())
}
}
pub async fn get_session_info(&self, session_id: &str) -> Result<Value> {
let args = json!({ "session_id": session_id });
let result = self
.mcp_client
.execute_tool("get_session_info", args)
.await?;
if result.is_error {
return Err(anyhow::anyhow!("Failed to get session info"));
}
if let Some(McpContent::Text { text }) = result.content.first() {
serde_json::from_str(text)
.map_err(|e| anyhow::anyhow!("Failed to parse session info: {}", e))
} else {
Ok(json!({}))
}
}
}