use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use super::protocol::*;
/// Client for one MCP server reached over HTTP.
///
/// Requests are POSTed as JSON to `url`; the list of tools returned by the
/// server is cached after the first successful `tools/list` call.
pub struct McpClient {
/// Endpoint every request is POSTed to.
url: String,
/// Underlying HTTP client (built in `new` with the requested timeout).
http: reqwest::Client,
/// Counter handing out request ids; starts at 1 and grows by 1 per request.
next_id: AtomicU64,
/// Cached tool list; `None` until `list_tools` succeeds or after `invalidate_cache`.
tools_cache: Arc<RwLock<Option<Vec<McpTool>>>>,
/// Name given to this server at construction, exposed via `server_name()`.
server_name: String,
}
impl McpClient {
/// Creates a client for the MCP server called `name` reachable at `url`.
///
/// `timeout_secs` is applied as the timeout of the underlying HTTP client.
/// The request-id counter starts at 1 and the tools cache starts empty.
pub fn new(name: &str, url: &str, timeout_secs: u64) -> Self {
    let timeout = std::time::Duration::from_secs(timeout_secs);

    // Best effort: if the builder fails we fall back to
    // `reqwest::Client::default()` instead of propagating the error.
    let http = reqwest::Client::builder()
        .timeout(timeout)
        .build()
        .unwrap_or_default();

    Self {
        server_name: name.to_owned(),
        url: url.to_owned(),
        http,
        next_id: AtomicU64::new(1),
        tools_cache: Arc::new(RwLock::new(None)),
    }
}
fn next_request_id(&self) -> u64 {
self.next_id.fetch_add(1, Ordering::Relaxed)
}
/// Name this client was constructed with.
pub fn server_name(&self) -> &str {
    self.server_name.as_str()
}
/// Endpoint URL this client sends its requests to.
pub fn url(&self) -> &str {
    self.url.as_str()
}
/// POSTs `request` as a JSON body to the server URL and decodes the reply.
///
/// # Errors
/// Returns a descriptive `String` when the HTTP call itself fails, when the
/// server answers with a non-success status (the response body is appended,
/// or empty text if it cannot be read), or when the body cannot be
/// deserialized into an `McpResponse`.
async fn send_request(&self, request: &McpRequest) -> Result<McpResponse, String> {
    let sent = self.http.post(&self.url).json(request).send().await;
    let reply = match sent {
        Ok(reply) => reply,
        Err(e) => return Err(format!("HTTP request failed: {}", e)),
    };

    let status = reply.status();
    if status.is_success() {
        return reply
            .json::<McpResponse>()
            .await
            .map_err(|e| format!("Failed to parse MCP response: {}", e));
    }

    // Best effort: an unreadable error body is reported as empty text.
    let body = reply.text().await.unwrap_or_default();
    Err(format!("HTTP {} from MCP server: {}", status, body))
}
pub async fn initialize(&self) -> Result<serde_json::Value, String> {
let params = InitializeParams::default();
let request = McpRequest::new(
self.next_request_id(),
"initialize",
Some(serde_json::to_value(¶ms).map_err(|e| e.to_string())?),
);
let response = self.send_request(&request).await?;
if let Some(error) = response.error {
return Err(format!("MCP initialize error: {}", error.message));
}
Ok(response.result.unwrap_or(serde_json::Value::Null))
}
pub async fn list_tools(&self) -> Result<Vec<McpTool>, String> {
{
let cache = self.tools_cache.read().await;
if let Some(ref tools) = *cache {
return Ok(tools.clone());
}
}
let request = McpRequest::new(self.next_request_id(), "tools/list", None);
let response = self.send_request(&request).await?;
if let Some(error) = response.error {
return Err(format!("MCP tools/list error: {}", error.message));
}
let result: ListToolsResult =
serde_json::from_value(response.result.ok_or("No result in tools/list response")?)
.map_err(|e| format!("Failed to parse tools list: {}", e))?;
let tools = result.tools;
{
let mut cache = self.tools_cache.write().await;
*cache = Some(tools.clone());
}
Ok(tools)
}
/// Invokes the tool `name` on the server with the given JSON `arguments`.
///
/// Sends a `tools/call` request whose params are
/// `{"name": name, "arguments": arguments}` and parses the `result` of the
/// response as a `CallToolResult`.
///
/// # Errors
/// Returns a `String` when the request fails (see `send_request`), when the
/// server replies with an error object, when the response has no `result`,
/// or when the result cannot be parsed.
pub async fn call_tool(
    &self,
    name: &str,
    arguments: serde_json::Value,
) -> Result<CallToolResult, String> {
    let call_params = serde_json::json!({
        "name": name,
        "arguments": arguments,
    });
    let request = McpRequest::new(self.next_request_id(), "tools/call", Some(call_params));
    let response = self.send_request(&request).await?;

    match response.error {
        Some(failure) => Err(format!("MCP tools/call error: {}", failure.message)),
        None => {
            let payload = response
                .result
                .ok_or_else(|| String::from("No result in tools/call response"))?;
            serde_json::from_value(payload)
                .map_err(|e| format!("Failed to parse tool call result: {}", e))
        }
    }
}
/// Clears the cached tool list so the next `list_tools` call queries the
/// server again.
pub async fn invalidate_cache(&self) {
    // The write guard is dropped as soon as the assignment completes.
    *self.tools_cache.write().await = None;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared check for the "no server" tests: the call must have failed and
    /// the error text must come from the HTTP transport step.
    fn expect_connection_failure<T: std::fmt::Debug>(result: Result<T, String>) {
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.contains("HTTP request failed"),
            "Expected connection error, got: {}",
            err
        );
    }

    /// Constructor stores both the name and the URL it was given.
    #[test]
    fn test_client_creation() {
        let client = McpClient::new("test-server", "http://localhost:8080", 30);
        assert_eq!(client.server_name(), "test-server");
        assert_eq!(client.url(), "http://localhost:8080");
    }

    /// Request ids start at 1 and grow by one per call.
    #[test]
    fn test_request_id_increments() {
        let client = McpClient::new("test", "http://localhost:8080", 30);
        let first = client.next_request_id();
        let second = client.next_request_id();
        let third = client.next_request_id();
        assert_eq!(first, 1);
        assert_eq!(second, 2);
        assert_eq!(third, 3);
    }

    /// `invalidate_cache` empties a cache that was populated by hand.
    #[tokio::test]
    async fn test_invalidate_cache() {
        let client = McpClient::new("test", "http://localhost:8080", 30);

        let seeded = vec![McpTool {
            name: "test_tool".to_string(),
            description: Some("A test tool".to_string()),
            input_schema: serde_json::json!({"type": "object"}),
        }];
        *client.tools_cache.write().await = Some(seeded);
        assert!(client.tools_cache.read().await.is_some());

        client.invalidate_cache().await;
        assert!(client.tools_cache.read().await.is_none());
    }

    /// Construction succeeds for a range of timeout values.
    #[test]
    fn test_client_default_timeout() {
        let _fast = McpClient::new("fast", "http://localhost:8080", 5);
        let _slow = McpClient::new("slow", "http://localhost:8080", 120);
        let _very_slow = McpClient::new("very-slow", "http://localhost:8080", 600);
    }

    /// `server_name()` returns the constructor's `name` argument.
    #[test]
    fn test_server_name_accessor() {
        let client = McpClient::new("my-mcp-server", "http://example.com", 30);
        assert_eq!(client.server_name(), "my-mcp-server");
    }

    /// `url()` returns the constructor's `url` argument.
    #[test]
    fn test_url_accessor() {
        let client = McpClient::new("test", "https://mcp.example.com/rpc", 30);
        assert_eq!(client.url(), "https://mcp.example.com/rpc");
    }

    /// `call_tool` reports a transport error when the request cannot be sent.
    #[tokio::test]
    async fn test_call_tool_no_server() {
        let client = McpClient::new("test", "http://127.0.0.1:1", 5);
        let outcome = client
            .call_tool("some_tool", serde_json::json!({"key": "value"}))
            .await;
        expect_connection_failure(outcome);
    }

    /// `initialize` reports a transport error when the request cannot be sent.
    #[tokio::test]
    async fn test_initialize_no_server() {
        let client = McpClient::new("test", "http://127.0.0.1:1", 5);
        expect_connection_failure(client.initialize().await);
    }

    /// `list_tools` reports a transport error when the request cannot be sent.
    #[tokio::test]
    async fn test_list_tools_no_server() {
        let client = McpClient::new("test", "http://127.0.0.1:1", 5);
        expect_connection_failure(client.list_tools().await);
    }

    /// A freshly built client has nothing cached.
    #[tokio::test]
    async fn test_cache_starts_empty() {
        let client = McpClient::new("test", "http://localhost:8080", 30);
        assert!(
            client.tools_cache.read().await.is_none(),
            "Cache should start as None"
        );
    }
}