use crate::mcp::protocol::{
CallToolParams, CallToolResult, ClientCapabilities, ClientInfo, InitializeParams,
InitializeResult, JsonRpcNotification, JsonRpcRequest, ListResourcesResult, ListToolsResult,
McpNotification, McpResource, McpTool, ReadResourceParams, ReadResourceResult,
ServerCapabilities, PROTOCOL_VERSION,
};
use crate::mcp::transport::McpTransport;
use anyhow::{anyhow, Result};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
pub struct McpClient {
pub name: String,
transport: Arc<dyn McpTransport>,
capabilities: RwLock<ServerCapabilities>,
tools: RwLock<Vec<McpTool>>,
resources: RwLock<Vec<McpResource>>,
request_id: AtomicU64,
initialized: RwLock<bool>,
}
impl McpClient {
pub fn new(name: String, transport: Arc<dyn McpTransport>) -> Self {
Self {
name,
transport,
capabilities: RwLock::new(ServerCapabilities::default()),
tools: RwLock::new(Vec::new()),
resources: RwLock::new(Vec::new()),
request_id: AtomicU64::new(1),
initialized: RwLock::new(false),
}
}
fn next_id(&self) -> u64 {
self.request_id.fetch_add(1, Ordering::SeqCst)
}
pub async fn initialize(&self) -> Result<InitializeResult> {
let params = InitializeParams {
protocol_version: PROTOCOL_VERSION.to_string(),
capabilities: ClientCapabilities::default(),
client_info: ClientInfo {
name: "a3s-code".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
},
};
let request = JsonRpcRequest::new(
self.next_id(),
"initialize",
Some(serde_json::to_value(¶ms)?),
);
let response = self.transport.request(request).await?;
if let Some(error) = response.error {
return Err(anyhow!(
"MCP initialize error: {} ({})",
error.message,
error.code
));
}
let result: InitializeResult = serde_json::from_value(
response
.result
.ok_or_else(|| anyhow!("No result in response"))?,
)?;
{
let mut caps = self.capabilities.write().await;
*caps = result.capabilities.clone();
}
let notification = JsonRpcNotification::new("notifications/initialized", None);
self.transport.notify(notification).await?;
{
let mut init = self.initialized.write().await;
*init = true;
}
tracing::info!(
"MCP client '{}' initialized with server '{}' v{}",
self.name,
result.server_info.name,
result.server_info.version
);
Ok(result)
}
pub async fn is_initialized(&self) -> bool {
*self.initialized.read().await
}
pub async fn capabilities(&self) -> ServerCapabilities {
self.capabilities.read().await.clone()
}
pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
let request = JsonRpcRequest::new(self.next_id(), "tools/list", None);
let response = self.transport.request(request).await?;
if let Some(error) = response.error {
return Err(anyhow!(
"MCP list_tools error: {} ({})",
error.message,
error.code
));
}
let result: ListToolsResult =
serde_json::from_value(response.result.ok_or_else(|| anyhow!("No result"))?)?;
{
let mut tools = self.tools.write().await;
*tools = result.tools.clone();
}
Ok(result.tools)
}
pub async fn get_cached_tools(&self) -> Vec<McpTool> {
self.tools.read().await.clone()
}
pub async fn call_tool(
&self,
name: &str,
arguments: Option<serde_json::Value>,
) -> Result<CallToolResult> {
let params = CallToolParams {
name: name.to_string(),
arguments,
};
let request = JsonRpcRequest::new(
self.next_id(),
"tools/call",
Some(serde_json::to_value(¶ms)?),
);
let response = self.transport.request(request).await?;
if let Some(error) = response.error {
return Err(anyhow!(
"MCP call_tool error: {} ({})",
error.message,
error.code
));
}
let result: CallToolResult =
serde_json::from_value(response.result.ok_or_else(|| anyhow!("No result"))?)?;
Ok(result)
}
pub async fn list_resources(&self) -> Result<Vec<McpResource>> {
let request = JsonRpcRequest::new(self.next_id(), "resources/list", None);
let response = self.transport.request(request).await?;
if let Some(error) = response.error {
return Err(anyhow!(
"MCP list_resources error: {} ({})",
error.message,
error.code
));
}
let result: ListResourcesResult =
serde_json::from_value(response.result.ok_or_else(|| anyhow!("No result"))?)?;
{
let mut resources = self.resources.write().await;
*resources = result.resources.clone();
}
Ok(result.resources)
}
pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult> {
let params = ReadResourceParams {
uri: uri.to_string(),
};
let request = JsonRpcRequest::new(
self.next_id(),
"resources/read",
Some(serde_json::to_value(¶ms)?),
);
let response = self.transport.request(request).await?;
if let Some(error) = response.error {
return Err(anyhow!(
"MCP read_resource error: {} ({})",
error.message,
error.code
));
}
let result: ReadResourceResult =
serde_json::from_value(response.result.ok_or_else(|| anyhow!("No result"))?)?;
Ok(result)
}
pub fn notifications(&self) -> tokio::sync::mpsc::Receiver<McpNotification> {
self.transport.notifications()
}
pub async fn close(&self) -> Result<()> {
self.transport.close().await
}
pub fn is_connected(&self) -> bool {
self.transport.is_connected()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_client_info() {
let info = ClientInfo {
name: "test".to_string(),
version: "1.0.0".to_string(),
};
let json = serde_json::to_string(&info).unwrap();
assert!(json.contains("test"));
}
#[test]
fn test_initialize_params() {
let params = InitializeParams {
protocol_version: PROTOCOL_VERSION.to_string(),
capabilities: ClientCapabilities::default(),
client_info: ClientInfo {
name: "a3s-code".to_string(),
version: "0.1.0".to_string(),
},
};
let json = serde_json::to_string(¶ms).unwrap();
assert!(json.contains("protocolVersion"));
assert!(json.contains("clientInfo"));
}
#[test]
fn test_client_info_serialize() {
let info = ClientInfo {
name: "test-client".to_string(),
version: "2.0.0".to_string(),
};
let json = serde_json::to_string(&info).unwrap();
assert!(json.contains("test-client"));
assert!(json.contains("2.0.0"));
}
#[test]
fn test_client_info_deserialize() {
let json = r#"{"name":"my-client","version":"1.2.3"}"#;
let info: ClientInfo = serde_json::from_str(json).unwrap();
assert_eq!(info.name, "my-client");
assert_eq!(info.version, "1.2.3");
}
#[test]
fn test_initialize_params_serialize() {
let params = InitializeParams {
protocol_version: "2024-11-05".to_string(),
capabilities: ClientCapabilities::default(),
client_info: ClientInfo {
name: "test".to_string(),
version: "1.0.0".to_string(),
},
};
let json = serde_json::to_string(¶ms).unwrap();
assert!(json.contains("2024-11-05"));
assert!(json.contains("capabilities"));
}
#[test]
fn test_call_tool_params_serialize() {
let params = CallToolParams {
name: "test_tool".to_string(),
arguments: Some(serde_json::json!({"key": "value"})),
};
let json = serde_json::to_string(¶ms).unwrap();
assert!(json.contains("test_tool"));
assert!(json.contains("key"));
}
#[test]
fn test_call_tool_params_no_arguments() {
let params = CallToolParams {
name: "simple_tool".to_string(),
arguments: None,
};
let json = serde_json::to_string(¶ms).unwrap();
assert!(json.contains("simple_tool"));
}
#[test]
fn test_read_resource_params_serialize() {
let params = ReadResourceParams {
uri: "file:///test.txt".to_string(),
};
let json = serde_json::to_string(¶ms).unwrap();
assert!(json.contains("file:///test.txt"));
}
#[test]
fn test_read_resource_params_deserialize() {
let json = r#"{"uri":"http://example.com/resource"}"#;
let params: ReadResourceParams = serde_json::from_str(json).unwrap();
assert_eq!(params.uri, "http://example.com/resource");
}
#[test]
fn test_server_capabilities_default() {
let caps = ServerCapabilities::default();
let json = serde_json::to_string(&caps).unwrap();
assert!(!json.is_empty());
}
#[test]
fn test_client_capabilities_default() {
let caps = ClientCapabilities::default();
let json = serde_json::to_string(&caps).unwrap();
assert!(!json.is_empty());
}
#[test]
fn test_protocol_version_constant() {
assert!(!PROTOCOL_VERSION.is_empty());
assert!(PROTOCOL_VERSION.contains("-"));
}
}