use anyhow::{Context, Result, bail};
use serde_json::{Value, json};
use std::sync::Arc;
use super::protocol::JsonRpcRequest;
use super::protocol::{
ClientCapabilities, ClientInfo, InitializeParams, InitializeResult, McpToolCallResult,
McpToolDefinition, ToolCallParams, ToolsListResult,
};
use super::transport::McpTransport;
/// Protocol revision this client advertises in its `initialize` request.
pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05";

/// Client for a single MCP server, exchanging JSON-RPC messages over a shared
/// transport `T`.
pub struct McpClient<T>
where
    T: McpTransport,
{
    /// Transport used for requests, notifications, and shutdown.
    transport: Arc<T>,
    /// Caller-supplied label for the server, stored and exposed as-is.
    server_name: String,
    /// The server's `initialize` reply; `None` until the handshake has completed.
    server_info: Option<InitializeResult>,
}
impl<T: McpTransport> McpClient<T> {
pub async fn new(transport: Arc<T>, server_name: String) -> Result<Self> {
let mut client = Self {
transport,
server_name,
server_info: None,
};
client.initialize().await?;
Ok(client)
}
#[must_use]
pub const fn new_uninitialized(transport: Arc<T>, server_name: String) -> Self {
Self {
transport,
server_name,
server_info: None,
}
}
pub async fn initialize(&mut self) -> Result<&InitializeResult> {
let params = InitializeParams {
protocol_version: MCP_PROTOCOL_VERSION.to_string(),
capabilities: ClientCapabilities::default(),
client_info: ClientInfo {
name: "agent-sdk".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
},
};
let request = JsonRpcRequest::new("initialize", Some(serde_json::to_value(¶ms)?), 0);
let response = self.transport.send(request).await?;
let result: InitializeResult = response
.result
.map(serde_json::from_value)
.transpose()
.context("Failed to parse initialize response")?
.context("Initialize response missing result")?;
let notification = JsonRpcRequest::new("notifications/initialized", None, 0);
let _ = self.transport.send_notification(notification).await;
self.server_info = Some(result);
self.server_info
.as_ref()
.context("Server info not available")
}
#[must_use]
pub fn server_name(&self) -> &str {
&self.server_name
}
#[must_use]
pub const fn server_info(&self) -> Option<&InitializeResult> {
self.server_info.as_ref()
}
pub async fn list_tools(&self) -> Result<Vec<McpToolDefinition>> {
let request = JsonRpcRequest::new("tools/list", None, 0);
let response = self.transport.send(request).await?;
let result: ToolsListResult = response
.result
.map(serde_json::from_value)
.transpose()
.context("Failed to parse tools/list response")?
.context("tools/list response missing result")?;
Ok(result.tools)
}
pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<McpToolCallResult> {
let params = ToolCallParams {
name: name.to_string(),
arguments: Some(arguments),
};
let request = JsonRpcRequest::new("tools/call", Some(serde_json::to_value(¶ms)?), 0);
let response = self.transport.send(request).await?;
if let Some(ref error) = response.error {
bail!("Tool call failed: {} (code {})", error.message, error.code);
}
let result: McpToolCallResult = response
.result
.map(serde_json::from_value)
.transpose()
.context("Failed to parse tools/call response")?
.context("tools/call response missing result")?;
Ok(result)
}
pub async fn call_tool_raw(
&self,
name: &str,
arguments: Option<Value>,
) -> Result<McpToolCallResult> {
let args = arguments.unwrap_or_else(|| json!({}));
self.call_tool(name, args).await
}
pub async fn close(&self) -> Result<()> {
self.transport.close().await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The advertised protocol revision must never be blank.
    #[test]
    fn test_mcp_protocol_version() {
        let version: &str = MCP_PROTOCOL_VERSION;
        assert!(!version.is_empty());
    }

    /// Serialized `ClientInfo` must carry both the name and the version text.
    #[test]
    fn test_client_info() {
        let info = ClientInfo {
            name: String::from("test"),
            version: String::from("1.0.0"),
        };
        let serialized = serde_json::to_string(&info).expect("serialize");
        for expected in ["test", "1.0.0"] {
            assert!(serialized.contains(expected));
        }
    }
}