use std::{collections::HashMap, future::Future, pin::Pin, sync::Mutex, time::Duration};
use async_trait::async_trait;
use rmcp::model::Tool;
use serde_json::Value;
use crate::client::{CallOutcome, Client, ClientError};
/// Boxed, type-erased future produced by an asynchronous mock handler.
type CallFuture = Pin<Box<dyn Future<Output = CallOutcome> + Send + 'static>>;
/// Handler that computes a tool's outcome directly from the call arguments.
type SyncHandler = Box<dyn for<'args> Fn(&'args Value) -> CallOutcome + Send + Sync>;
/// Handler that, given the call arguments, returns a future yielding the outcome.
type AsyncHandler = Box<dyn for<'args> Fn(&'args Value) -> CallFuture + Send + Sync>;

/// A handler registered for one mock tool, in either calling convention.
enum MockHandler {
    /// The handler's return value is the call outcome.
    Sync(SyncHandler),
    /// The handler's returned future must be awaited to obtain the outcome.
    Async(AsyncHandler),
}
/// Operations needed from an MCP connection.
///
/// Implemented in this module both by the real [`Client`] and by
/// [`MockClient`], so code written against this trait can use either.
#[async_trait]
pub trait McpExec: Send + Sync {
    /// Retrieves the tool definitions available through this connection.
    ///
    /// # Errors
    ///
    /// Returns a [`ClientError`] when the tool list cannot be obtained.
    async fn list_tools(&self) -> Result<Vec<Tool>, ClientError>;

    /// Re-establishes the underlying connection.
    ///
    /// # Errors
    ///
    /// Returns a [`ClientError`] when reconnecting fails.
    async fn reconnect(&self) -> Result<(), ClientError>;

    /// Invokes the tool called `name` with JSON `arguments`.
    ///
    /// Failures are reported inside the returned [`CallOutcome`] rather than
    /// through a `Result`. `timeout` is advisory; implementations may ignore it.
    async fn call_tool(&self, name: &str, arguments: Value, timeout: Duration) -> CallOutcome;
}
/// Production [`McpExec`] implementation: every trait method forwards, with
/// unchanged arguments, to the `Client` method of the same name.
// NOTE(review): the fully-qualified `Client::…` paths are presumably meant to
// hit inherent methods on `Client` (inherent items win over trait items in path
// resolution). If an inherent method of the same name were removed, these
// calls would resolve back to the trait method and recurse — confirm.
#[async_trait]
impl McpExec for Client {
    // Forwards to `Client::list_tools`.
    async fn list_tools(&self) -> Result<Vec<Tool>, ClientError> {
        Client::list_tools(self).await
    }
    // Forwards to `Client::call_tool`, passing `timeout` through unchanged.
    async fn call_tool(&self, name: &str, arguments: Value, timeout: Duration) -> CallOutcome {
        Client::call_tool(self, name, arguments, timeout).await
    }
    // Forwards to `Client::reconnect`.
    async fn reconnect(&self) -> Result<(), ClientError> {
        Client::reconnect(self).await
    }
}
/// In-memory [`McpExec`] implementation whose tools and responses are supplied
/// up front through [`MockClient::register`] / [`MockClient::register_async`].
///
/// It performs no I/O:
/// - `list_tools` returns the registered definitions.
/// - `call_tool` dispatches to the registered handler.
/// - `reconnect` only increments a counter, readable via [`MockClient::reconnect_count`].
pub struct MockClient {
    /// Tool definitions returned by `list_tools`, in registration order.
    tools: Vec<Tool>,
    /// Handlers keyed by tool name; looked up by `call_tool`.
    handlers: Mutex<HashMap<String, MockHandler>>,
    /// Number of `reconnect` calls so far; incremented through `&self`.
    reconnect_count: Mutex<usize>,
}
impl MockClient {
#[must_use]
pub fn new() -> Self {
Self {
tools: Vec::new(),
handlers: Mutex::new(HashMap::new()),
reconnect_count: Mutex::new(0),
}
}
pub fn register<F>(mut self, tool: Tool, handler: F) -> Self
where
F: Fn(&Value) -> CallOutcome + Send + Sync + 'static,
{
let name = tool.name.to_string();
self.tools.push(tool);
self.handlers
.lock()
.unwrap_or_else(|p| p.into_inner())
.insert(name, MockHandler::Sync(Box::new(handler)));
self
}
pub fn register_async<F, Fut>(mut self, tool: Tool, handler: F) -> Self
where
F: Fn(&Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = CallOutcome> + Send + 'static,
{
let name = tool.name.to_string();
self.tools.push(tool);
let boxed: AsyncHandler = Box::new(move |args| Box::pin(handler(args)));
self.handlers
.lock()
.unwrap_or_else(|p| p.into_inner())
.insert(name, MockHandler::Async(boxed));
self
}
#[must_use]
pub fn reconnect_count(&self) -> usize {
*self
.reconnect_count
.lock()
.unwrap_or_else(|p| p.into_inner())
}
}
impl Default for MockClient {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl McpExec for MockClient {
async fn list_tools(&self) -> Result<Vec<Tool>, ClientError> {
Ok(self.tools.clone())
}
async fn call_tool(&self, name: &str, arguments: Value, _timeout: Duration) -> CallOutcome {
let future = {
let handlers = self.handlers.lock().unwrap_or_else(|p| p.into_inner());
match handlers.get(name) {
Some(MockHandler::Sync(handler)) => return handler(&arguments),
Some(MockHandler::Async(handler)) => handler(&arguments),
None => return CallOutcome::ProtocolError(format!("unknown tool `{name}`")),
}
};
future.await
}
async fn reconnect(&self) -> Result<(), ClientError> {
*self
.reconnect_count
.lock()
.unwrap_or_else(|p| p.into_inner()) += 1;
Ok(())
}
}