use async_trait::async_trait;
use mcp_runner::McpClient; use mcp_runner::error::{Error, Result};
use mcp_runner::transport::Transport;
use mockall::mock;
use mockall::predicate::*;
use serde::{Deserialize, Serialize};
use serde_json::Value;
// Mock implementation of the `Transport` trait, generated by mockall's `mock!` macro.
// The tests in this file refer to the generated type as `MockTransportMock` and
// configure it through its `expect_<method>()` builders.
mock! {
pub TransportMock {}
// `async_trait` is applied so the async method signatures below can be declared here.
#[async_trait]
impl Transport for TransportMock {
// Stubbed by `test_initialize`.
async fn initialize(&self) -> Result<()>;
// Stubbed by `test_list_tools` and `test_transport_error_propagation`.
async fn list_tools(&self) -> Result<Vec<Value>>;
// Stubbed by `test_call_tool` and `test_deserialization_error`.
async fn call_tool(&self, name: &str, args: Value) -> Result<Value>;
// Stubbed by `test_list_resources`.
async fn list_resources(&self) -> Result<Vec<Value>>;
// Stubbed by `test_get_resource`.
async fn get_resource(&self, uri: &str) -> Result<Value>;
}
}
/// Builds an [`McpClient`] named `"test"` that is backed by the supplied mock transport.
///
/// All expectations must be registered on `mock_transport` before calling this,
/// because the mock is moved into the client.
fn create_test_client(mock_transport: MockTransportMock) -> McpClient {
    let client_name = String::from("test");
    McpClient::new(client_name, mock_transport)
}
/// `list_tools` should convert the single raw JSON tool entry returned by the
/// transport into one typed tool with its name, description and both schemas set.
#[tokio::test]
async fn test_list_tools() -> Result<()> {
    // Raw tool descriptor the mocked transport hands back.
    let tool_json = serde_json::json!({
        "name": "test_tool",
        "description": "A test tool",
        "inputSchema": {
            "type": "object",
            "properties": {
                "input": {"type": "string"}
            }
        },
        "outputSchema": {
            "type": "object",
            "properties": {
                "output": {"type": "string"}
            }
        }
    });

    let mut transport = MockTransportMock::new();
    transport
        .expect_list_tools()
        .times(1)
        .returning(move || Ok(vec![tool_json.clone()]));

    let client = create_test_client(transport);
    let tools = client.list_tools().await?;

    assert_eq!(tools.len(), 1);

    // Exactly one tool came back; inspect it through a single borrow.
    let tool = &tools[0];
    assert_eq!(tool.name, "test_tool");
    assert_eq!(tool.description, "A test tool");
    assert!(tool.input_schema.is_some());
    assert!(tool.output_schema.is_some());
    Ok(())
}
/// `call_tool` should serialize the typed input into the JSON arguments passed to
/// the transport and deserialize the transport's JSON reply into the typed output.
#[tokio::test]
async fn test_call_tool() -> Result<()> {
    /// Typed request; expected to reach the transport as `{"message": "hello"}`.
    #[derive(Serialize, Debug)]
    struct TestInput {
        message: String,
    }

    /// Typed reply decoded from the transport's JSON result.
    #[derive(Deserialize, Debug, PartialEq)]
    struct TestOutput {
        response: String,
    }

    // The mock only matches a call to tool "echo" carrying exactly these arguments.
    let expected_args = serde_json::json!({"message": "hello"});

    let mut transport = MockTransportMock::new();
    transport
        .expect_call_tool()
        .with(eq("echo"), eq(expected_args))
        .times(1)
        .returning(|_, _| {
            let reply = serde_json::json!({
                "response": "hello from echo"
            });
            Ok(reply)
        });

    let client = create_test_client(transport);
    let request = TestInput {
        message: String::from("hello"),
    };
    let reply = client.call_tool::<_, TestOutput>("echo", &request).await?;

    let expected = TestOutput {
        response: String::from("hello from echo"),
    };
    assert_eq!(reply, expected);
    Ok(())
}
/// `list_resources` should convert the single raw JSON resource entry returned by
/// the transport into one typed resource with uri, name, description and type set.
#[tokio::test]
async fn test_list_resources() -> Result<()> {
    // Raw resource descriptor the mocked transport hands back.
    let resource_json = serde_json::json!({
        "uri": "resource:test",
        "name": "test_resource",
        "description": "A test resource",
        "type": "text"
    });

    let mut transport = MockTransportMock::new();
    transport
        .expect_list_resources()
        .times(1)
        .returning(move || Ok(vec![resource_json.clone()]));

    let client = create_test_client(transport);
    let resources = client.list_resources().await?;

    assert_eq!(resources.len(), 1);

    let first = &resources[0];
    assert_eq!(first.uri, "resource:test");
    assert_eq!(first.name, "test_resource");
    // The optional fields are compared as borrowed strings to avoid building owned ones.
    assert_eq!(first.description.as_deref(), Some("A test resource"));
    assert_eq!(first.resource_type.as_deref(), Some("text"));
    Ok(())
}
/// `get_resource` should deserialize the transport's JSON payload, including the
/// nested `metadata` object, into the caller-chosen type.
#[tokio::test]
async fn test_get_resource() -> Result<()> {
    /// Nested part of the payload.
    #[derive(Deserialize, Debug, PartialEq)]
    struct TestResourceMetadata {
        created: String,
    }

    /// Shape the resource payload is decoded into.
    #[derive(Deserialize, Debug, PartialEq)]
    struct TestResource {
        content: String,
        metadata: TestResourceMetadata,
    }

    // JSON the mocked transport returns for the requested URI.
    let payload = serde_json::json!({
        "content": "Sample content",
        "metadata": {
            "created": "2025-04-23T12:00:00Z"
        }
    });

    let mut transport = MockTransportMock::new();
    transport
        .expect_get_resource()
        .with(eq("resource:test"))
        .times(1)
        .returning(move |_| Ok(payload.clone()));

    let client = create_test_client(transport);
    let fetched: TestResource = client.get_resource("resource:test").await?;

    let expected = TestResource {
        content: String::from("Sample content"),
        metadata: TestResourceMetadata {
            created: String::from("2025-04-23T12:00:00Z"),
        },
    };
    assert_eq!(fetched, expected);
    Ok(())
}
/// `initialize` should call the transport's `initialize` exactly once and
/// surface its success to the caller.
#[tokio::test]
async fn test_initialize() -> Result<()> {
    let mut transport = MockTransportMock::new();
    // The transport reports success; `times(1)` makes the mock check the call count.
    transport.expect_initialize().times(1).returning(|| Ok(()));

    let client = create_test_client(transport);
    client.initialize().await?;

    Ok(())
}
#[tokio::test]
async fn test_serialization_error() -> Result<()> {
use std::collections::HashMap;
#[derive(Debug)] struct Unserializable;
impl Serialize for Unserializable {
fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
Err(serde::ser::Error::custom("Test serialization error"))
}
}
let mock_transport = MockTransportMock::new();
let client = create_test_client(mock_transport);
let result = client
.call_tool::<_, HashMap<String, String>>("test", &Unserializable)
.await;
assert!(result.is_err());
if let Err(e) = result {
assert!(
e.to_string().contains("serialization"),
"Expected serialization error, got: {}",
e
);
}
Ok(())
}
/// `call_tool` should return an error whose message mentions "deserialize" when
/// the transport's reply does not match the requested output type.
///
/// The mocked reply lacks `required_field`, so decoding into `ExpectedOutput`
/// cannot succeed.
#[tokio::test]
async fn test_deserialization_error() -> Result<()> {
    /// Output type the reply is decoded into; the mocked reply has no such field.
    #[derive(Deserialize)]
    // The field exists only to make deserialization fail; it is never read.
    #[allow(dead_code)]
    struct ExpectedOutput {
        required_field: String,
    }

    let mut mock_transport = MockTransportMock::new();
    mock_transport
        .expect_call_tool()
        // Require exactly one transport call, as the other tests in this file do, so the
        // test fails if the error is produced without the transport ever being reached.
        .times(1)
        .returning(|_, _| {
            Ok(serde_json::json!({
                "invalid": "structure"
            }))
        });
    let client = create_test_client(mock_transport);

    let result: std::result::Result<ExpectedOutput, _> =
        client.call_tool("test", &serde_json::json!({})).await;

    // A single match covers both the "must fail" check and the message check.
    match result {
        Ok(_) => panic!("Expected deserialization error, but call_tool succeeded"),
        Err(e) => assert!(
            e.to_string().contains("deserialize"),
            "Expected deserialization error, got: {}",
            e
        ),
    }
    Ok(())
}
#[tokio::test]
async fn test_transport_error_propagation() -> Result<()> {
let mut mock_transport = MockTransportMock::new();
mock_transport
.expect_list_tools()
.returning(|| Err(Error::Communication("Transport error".to_string())));
let client = create_test_client(mock_transport);
let result = client.list_tools().await;
assert!(result.is_err());
if let Err(e) = result {
if let Error::Communication(msg) = &e {
assert!(msg.contains("Transport error"));
} else {
panic!("Expected Communication error, got: {:?}", e);
}
}
Ok(())
}