use crate::mcp::client::auth::AuthConfig;
use crate::mcp::client::cache::{CacheConfig, ResourceCache};
use crate::mcp::client::error::{ClientError, Result};
use crate::mcp::client::registry::ToolRegistry;
use crate::mcp::client::result::ToolResult;
use crate::mcp::client::transport::Transport;
use std::time::Duration;
use rmcp::{
RoleClient,
model::{CallToolRequestParam, ReadResourceRequestParam},
service::{RunningService, ServiceExt},
transport::TokioChildProcess,
};
/// Client for an MCP server, backed by an `rmcp` running client service.
///
/// Holds the live session (once connected), a local tool registry used to
/// validate tool arguments, and optional authentication and resource-cache
/// configuration.
pub struct McpClient {
    /// Live rmcp client session; `None` until `connect_to_child_process` succeeds.
    service: Option<RunningService<RoleClient, ()>>,
    /// Tool metadata used to validate arguments before tool calls.
    registry: ToolRegistry,
    /// Optional credentials attached via `with_auth`.
    auth_config: Option<AuthConfig>,
    /// Optional cache consulted and populated by `get_resource`.
    resource_cache: Option<ResourceCache>,
    /// Request timeout, 5000 ms by default.
    /// NOTE(review): not read by any request method in this file — confirm intent.
    timeout: Duration,
}
impl McpClient {
    /// Creates a disconnected client with an empty tool registry, no
    /// authentication, no resource cache and a 5000 ms timeout.
    ///
    /// NOTE(review): `_transport` is accepted but neither stored nor used; the
    /// only way to establish a connection is
    /// [`McpClient::connect_to_child_process`].
    pub fn new(_transport: Box<dyn Transport>) -> Self {
        Self {
            service: None,
            registry: ToolRegistry::new(),
            auth_config: None,
            resource_cache: None,
            timeout: Duration::from_millis(5000),
        }
    }

    /// Replaces the stored timeout (builder style).
    ///
    /// NOTE(review): no method in this impl reads the timeout, so it is not
    /// currently enforced on any request.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Attaches an authentication configuration (builder style).
    ///
    /// NOTE(review): the configuration is only stored and exposed through
    /// [`McpClient::auth_config`]; `connect_to_child_process` does not read it.
    pub fn with_auth(mut self, auth_config: AuthConfig) -> Self {
        self.auth_config = Some(auth_config);
        self
    }

    /// Returns the attached authentication configuration, if any.
    pub fn auth_config(&self) -> Option<&AuthConfig> {
        self.auth_config.as_ref()
    }

    /// Opens a [`ResourceCache`] from `cache_config` and attaches it, replacing
    /// any existing cache (builder style).
    ///
    /// # Errors
    /// Propagates any error returned by `ResourceCache::new`.
    pub async fn with_cache(mut self, cache_config: CacheConfig) -> Result<Self> {
        let cache = ResourceCache::new(cache_config).await?;
        self.resource_cache = Some(cache);
        Ok(self)
    }

    /// Detaches the resource cache, if any (builder style).
    pub fn without_cache(mut self) -> Self {
        self.resource_cache = None;
        self
    }

    /// Returns the cache's analytics, or `None` when no cache is attached.
    pub fn cache_analytics(&self) -> Option<&crate::mcp::client::cache::CacheAnalytics> {
        self.resource_cache
            .as_ref()
            .map(|cache| cache.get_analytics())
    }

    /// Spawns `command` through rmcp's `TokioChildProcess` transport and
    /// performs the MCP client handshake with it.
    ///
    /// On success the new session replaces any previously stored one; on
    /// failure the existing state is left untouched.
    ///
    /// # Errors
    /// * `ClientError::Transport` if the child-process transport cannot be created.
    /// * `ClientError::Protocol` if the MCP handshake fails.
    pub async fn connect_to_child_process(
        &mut self,
        command: tokio::process::Command,
    ) -> Result<()> {
        let transport = TokioChildProcess::new(command).map_err(|e| {
            ClientError::Transport(format!("Failed to create child process: {}", e))
        })?;
        let service = ().serve(transport).await.map_err(|e| {
            ClientError::Protocol(format!("Failed to connect to MCP server: {}", e))
        })?;
        self.service = Some(service);
        Ok(())
    }

    /// Returns the connected service, or the standard "not connected" error
    /// shared by every method that needs a live session.
    fn connected_service(&self) -> Result<&RunningService<RoleClient, ()>> {
        self.service.as_ref().ok_or_else(|| {
            ClientError::Client(
                "Not connected to MCP server. Call connect_to_child_process() first.".to_string(),
            )
        })
    }

    /// Serializes a raw rmcp tool-call result into a JSON value.
    fn tool_response_to_json(
        response: &rmcp::model::CallToolResult,
    ) -> Result<serde_json::Value> {
        serde_json::to_value(response).map_err(|e| {
            ClientError::Client(format!("Failed to serialize tool response: {}", e))
        })
    }

    /// Checks that a session has been established.
    ///
    /// NOTE(review): no MCP `ping` request is sent; this only confirms that
    /// `connect_to_child_process` succeeded earlier, so it cannot notice a
    /// server that has gone away since.
    ///
    /// # Errors
    /// `ClientError::Client` when not connected.
    pub async fn ping(&mut self) -> Result<()> {
        let service = self.connected_service()?;
        let _info = service.peer_info();
        Ok(())
    }

    /// Fetches every tool the server advertises (all result pages, via
    /// `list_all_tools`, matching how [`McpClient::list_resources`] uses
    /// `list_all_resources`), refreshes the local registry with them, and
    /// returns their names.
    ///
    /// # Errors
    /// * `ClientError::Client` when not connected.
    /// * `ClientError::Protocol` when the server request fails.
    pub async fn list_tools(&mut self) -> Result<Vec<String>> {
        let service = self.connected_service()?;
        let tools = service
            .list_all_tools()
            .await
            .map_err(|e| ClientError::Protocol(format!("Failed to list tools: {}", e)))?;
        // Collect the names first so ownership of the tools can move into the
        // registry without cloning the whole list.
        let tool_names: Vec<String> = tools.iter().map(|tool| tool.name.to_string()).collect();
        self.registry.update_from_rmcp_tools(tools);
        Ok(tool_names)
    }

    /// Validates `arguments` against the registry, calls `tool_name` on the
    /// server and returns the raw result serialized as JSON.
    ///
    /// `arguments` must be a JSON object, or `null` for "no arguments".
    ///
    /// # Errors
    /// * `ClientError::Client` when not connected, or if the result cannot be
    ///   serialized.
    /// * Whatever `ToolRegistry::validate_parameters` returns for bad arguments.
    /// * `ClientError::Validation` when `arguments` is neither an object nor `null`.
    /// * `ClientError::Protocol` when the server call fails.
    pub async fn call_tool(
        &mut self,
        tool_name: &str,
        arguments: serde_json::Value,
    ) -> Result<serde_json::Value> {
        let service = self.connected_service()?;
        self.registry.validate_parameters(tool_name, &arguments)?;
        let tool_response = self
            .execute_tool_call(service, tool_name, arguments)
            .await?;
        Self::tool_response_to_json(&tool_response)
    }

    /// Like [`McpClient::call_tool`], but wraps the result in a stream.
    ///
    /// The call is fully awaited before the stream is built. If the first text
    /// content parses as JSON with a `streaming`, `progress` or `status` key,
    /// the stream yields two locally generated progress items before the
    /// result; otherwise it yields only the result.
    ///
    /// # Errors
    /// Same as [`McpClient::call_tool`].
    pub async fn call_tool_streaming(
        &mut self,
        tool_name: &str,
        arguments: serde_json::Value,
    ) -> Result<Box<dyn futures::Stream<Item = Result<serde_json::Value>> + Send + Unpin>> {
        let service = self.connected_service()?;
        self.registry.validate_parameters(tool_name, &arguments)?;
        let tool_response = self
            .execute_tool_call(service, tool_name, arguments)
            .await?;
        let response_json = Self::tool_response_to_json(&tool_response)?;
        let stream = if self.is_streaming_response(&response_json) {
            self.create_progress_stream(response_json)
        } else {
            self.create_single_item_stream(response_json)
        };
        Ok(stream)
    }

    /// Sends a `tools/call` request for `tool_name` on `service`.
    ///
    /// A JSON object is forwarded as the argument map and `null` means "no
    /// arguments". Any other JSON value is rejected rather than silently
    /// dropped, since the request can only carry an object.
    ///
    /// # Errors
    /// * `ClientError::Validation` for non-object, non-null `arguments`.
    /// * `ClientError::Protocol` when the server call fails.
    async fn execute_tool_call(
        &self,
        service: &RunningService<RoleClient, ()>,
        tool_name: &str,
        arguments: serde_json::Value,
    ) -> Result<rmcp::model::CallToolResult> {
        let arguments_object = match arguments {
            serde_json::Value::Object(map) => Some(map),
            serde_json::Value::Null => None,
            _ => {
                return Err(ClientError::Validation(format!(
                    "Arguments for tool '{}' must be a JSON object or null",
                    tool_name
                )));
            }
        };
        let request = CallToolRequestParam {
            name: tool_name.to_string().into(),
            arguments: arguments_object,
        };
        service.call_tool(request).await.map_err(|e| {
            ClientError::Protocol(format!("Failed to call tool '{}': {}", tool_name, e))
        })
    }

    /// Heuristic: returns `true` when the first `content` item's `text` parses
    /// as JSON containing a `streaming`, `progress` or `status` key.
    fn is_streaming_response(&self, response: &serde_json::Value) -> bool {
        const MARKER_KEYS: [&str; 3] = ["streaming", "progress", "status"];
        let first_text = response
            .get("content")
            .and_then(|content| content.as_array())
            .and_then(|items| items.first())
            .and_then(|item| item.get("text"))
            .and_then(|text| text.as_str());
        let parsed =
            first_text.and_then(|text| serde_json::from_str::<serde_json::Value>(text).ok());
        parsed.map_or(false, |value| {
            MARKER_KEYS.iter().any(|key| value.get(*key).is_some())
        })
    }

    /// Builds a three-item stream: two hard-coded progress updates
    /// (`started`/0, `processing`/50) followed by `final_response`.
    ///
    /// NOTE(review): the progress items are fabricated locally after the tool
    /// call has already completed; they do not reflect server-side progress.
    fn create_progress_stream(
        &self,
        final_response: serde_json::Value,
    ) -> Box<dyn futures::Stream<Item = Result<serde_json::Value>> + Send + Unpin> {
        use futures::stream;
        let progress_updates = vec![
            Ok(serde_json::json!({"status": "started", "progress": 0})),
            Ok(serde_json::json!({"status": "processing", "progress": 50})),
            Ok(final_response),
        ];
        Box::new(stream::iter(progress_updates))
    }

    /// Builds a stream that yields `response` once.
    fn create_single_item_stream(
        &self,
        response: serde_json::Value,
    ) -> Box<dyn futures::Stream<Item = Result<serde_json::Value>> + Send + Unpin> {
        use futures::stream;
        Box::new(stream::iter(vec![Ok(response)]))
    }

    /// Like [`McpClient::call_tool`], but converts the result into a
    /// [`ToolResult`] via `ToolResult::from_rmcp_result`.
    ///
    /// # Errors
    /// Same as [`McpClient::call_tool`], plus any conversion error.
    pub async fn call_tool_typed(
        &mut self,
        tool_name: &str,
        arguments: serde_json::Value,
    ) -> Result<ToolResult> {
        let service = self.connected_service()?;
        self.registry.validate_parameters(tool_name, &arguments)?;
        let tool_response = self
            .execute_tool_call(service, tool_name, arguments)
            .await?;
        ToolResult::from_rmcp_result(&tool_response)
    }

    /// Returns the local tool registry.
    pub fn registry(&self) -> &ToolRegistry {
        &self.registry
    }

    /// Mutable access to the registry, for tests only.
    #[cfg(test)]
    pub fn registry_mut(&mut self) -> &mut ToolRegistry {
        &mut self.registry
    }

    /// Validates `arguments` for `tool_name` against the local registry without
    /// contacting the server.
    pub async fn validate_parameters(
        &self,
        tool_name: &str,
        arguments: serde_json::Value,
    ) -> Result<()> {
        self.registry.validate_parameters(tool_name, &arguments)
    }

    /// Lists every resource the server exposes (all pages, via
    /// `list_all_resources`) as `ResourceInfo` values.
    ///
    /// A server-reported `size` is copied into `metadata["size"]`.
    ///
    /// # Errors
    /// * `ClientError::Client` when not connected.
    /// * `ClientError::Protocol` when the server request fails.
    pub async fn list_resources(
        &mut self,
    ) -> Result<Vec<crate::mcp::client::resource::ResourceInfo>> {
        let service = self.connected_service()?;
        let rmcp_resources = service
            .list_all_resources()
            .await
            .map_err(|e| ClientError::Protocol(format!("Failed to list resources: {}", e)))?;
        let resources = rmcp_resources
            .into_iter()
            .map(|rmcp_resource| {
                let mut metadata = std::collections::HashMap::new();
                if let Some(size) = rmcp_resource.size {
                    metadata.insert(
                        "size".to_string(),
                        serde_json::Value::Number(serde_json::Number::from(size)),
                    );
                }
                crate::mcp::client::resource::ResourceInfo {
                    uri: rmcp_resource.uri.clone(),
                    name: Some(rmcp_resource.name.clone()),
                    description: rmcp_resource.description.clone(),
                    mime_type: rmcp_resource.mime_type.clone(),
                    metadata,
                }
            })
            .collect();
        Ok(resources)
    }

    /// Reads the resource at `uri`, consulting the cache first when one is
    /// attached.
    ///
    /// On a cache miss the resource is read from the server and only the first
    /// entry of the returned contents is used. Text contents become UTF-8 bytes
    /// with `encoding = Some("utf-8")`. Blob contents are base64-decoded with
    /// `encoding = None`. The result is then stored in the cache.
    ///
    /// Caching is best-effort: both lookup and store failures are logged and
    /// never fail the request.
    ///
    /// # Errors
    /// * `ClientError::Client` when a server read is needed but not connected.
    /// * `ClientError::Protocol` when the read fails, returns no contents, or a
    ///   blob is not valid base64.
    pub async fn get_resource(
        &mut self,
        uri: &str,
    ) -> Result<crate::mcp::client::resource::ResourceContent> {
        if let Some(cache) = self.resource_cache.as_mut() {
            match cache.get_resource(uri).await {
                Ok(Some(cached_resource)) => {
                    log::debug!("Cache hit for resource: {}", uri);
                    return Ok(cached_resource);
                }
                Ok(None) => {
                    log::debug!("Cache miss for resource: {}", uri);
                }
                // Consistent with the best-effort store below: a broken cache
                // should not block a read the server can still answer.
                Err(e) => {
                    log::warn!("Failed to read cached resource '{}': {}", uri, e);
                }
            }
        }
        let service = self.connected_service()?;
        let read_result = service
            .read_resource(ReadResourceRequestParam {
                uri: uri.to_string(),
            })
            .await
            .map_err(|e| {
                ClientError::Protocol(format!("Failed to read resource '{}': {}", uri, e))
            })?;
        let content = read_result.contents.into_iter().next().ok_or_else(|| {
            ClientError::Protocol(format!("No content returned for resource '{}'", uri))
        })?;
        let (data, encoding, mime_type) = match content {
            rmcp::model::ResourceContents::TextResourceContents {
                text, mime_type, ..
            } => (text.into_bytes(), Some("utf-8".to_string()), mime_type),
            rmcp::model::ResourceContents::BlobResourceContents {
                blob, mime_type, ..
            } => {
                use base64::prelude::*;
                let decoded_data = BASE64_STANDARD.decode(&blob).map_err(|e| {
                    ClientError::Protocol(format!("Failed to decode base64 blob: {}", e))
                })?;
                (decoded_data, None, mime_type)
            }
        };
        let resource_content = crate::mcp::client::resource::ResourceContent {
            info: crate::mcp::client::resource::ResourceInfo {
                uri: uri.to_string(),
                name: None,
                description: None,
                mime_type,
                metadata: std::collections::HashMap::new(),
            },
            data,
            encoding,
        };
        if let Some(cache) = self.resource_cache.as_mut() {
            if let Err(e) = cache.store_resource(&resource_content).await {
                log::warn!("Failed to cache resource '{}': {}", uri, e);
            }
        }
        Ok(resource_content)
    }

    /// Removes one cached resource (`Some(uri)`) or clears the whole cache
    /// (`None`). Does nothing when no cache is attached.
    ///
    /// # Errors
    /// Propagates cache removal/clear errors.
    pub async fn invalidate_cache(&mut self, uri: Option<&str>) -> Result<()> {
        if let Some(cache) = self.resource_cache.as_mut() {
            match uri {
                Some(uri) => {
                    cache.remove_resource(uri).await?;
                    log::debug!("Invalidated cache for resource: {}", uri);
                }
                None => {
                    cache.clear().await?;
                    log::debug!("Cleared all cached resources");
                }
            }
        }
        Ok(())
    }

    /// Removes expired cache entries and returns how many were removed
    /// (0 when no cache is attached).
    ///
    /// # Errors
    /// Propagates cache cleanup errors.
    pub async fn cleanup_cache(&mut self) -> Result<u64> {
        match self.resource_cache.as_mut() {
            Some(cache) => {
                let removed_count = cache.cleanup_expired().await?;
                log::debug!("Cleaned up {} expired cache entries", removed_count);
                Ok(removed_count)
            }
            None => Ok(0),
        }
    }

    /// Lists cached resources; empty when no cache is attached.
    pub async fn list_cached_resources(
        &self,
    ) -> Result<Vec<crate::mcp::client::cache::CachedResource>> {
        match self.resource_cache.as_ref() {
            Some(cache) => cache.list_cached_resources().await,
            None => Ok(Vec::new()),
        }
    }

    /// Searches cached resources for `query`; empty when no cache is attached.
    pub async fn search_cached_resources(
        &self,
        query: &str,
    ) -> Result<Vec<crate::mcp::client::cache::CachedResource>> {
        match self.resource_cache.as_ref() {
            Some(cache) => cache.search_resources(query).await,
            None => Ok(Vec::new()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::client::registry::ToolInfo;
    use crate::mcp::client::result::ContentType;
    use crate::mcp::client::transport::MockTransport;
    use serde_json::json;
    use uuid::Uuid;

    /// Cache configuration pointing at a uniquely named database file in a new
    /// temporary directory. Keep the returned `TempDir` alive while the cache
    /// is in use.
    fn create_test_cache_config() -> (crate::mcp::client::cache::CacheConfig, tempfile::TempDir) {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_name = format!("test_{}.db", Uuid::new_v4());
        let database_path = temp_dir
            .path()
            .join(file_name)
            .to_string_lossy()
            .to_string();
        let config = crate::mcp::client::cache::CacheConfig {
            database_path,
            ..Default::default()
        };
        (config, temp_dir)
    }

    /// Client over a `MockTransport` with no canned responses.
    fn mock_client() -> McpClient {
        McpClient::new(Box::new(MockTransport::new(vec![])))
    }

    /// `get_pet_by_id` schema shared by the validation tests: one required
    /// integer `id`.
    fn get_pet_by_id_tool() -> ToolInfo {
        ToolInfo {
            name: "get_pet_by_id".to_string(),
            description: Some("Get a pet by ID".to_string()),
            input_schema: Some(json!({
                "type": "object",
                "properties": {
                    "id": {"type": "integer"}
                },
                "required": ["id"]
            })),
        }
    }

    /// Mock client whose registry already holds `get_pet_by_id`.
    fn client_with_pet_tool() -> McpClient {
        let mut client = mock_client();
        client.registry_mut().add_tool(get_pet_by_id_tool());
        client
    }

    /// Asserts `result` is the "not connected" client error; panics with
    /// `failure` for any other error variant.
    #[track_caller]
    fn assert_not_connected<T>(result: Result<T>, failure: &str) {
        assert!(result.is_err());
        match result {
            Err(ClientError::Client(msg)) => {
                assert!(msg.contains("Not connected to MCP server"));
            }
            _ => panic!("{}", failure),
        }
    }

    /// Asserts `result` is a validation error mentioning `expected`; panics
    /// with `failure` for any other error variant.
    #[track_caller]
    fn assert_validation_error<T>(result: Result<T>, expected: &str, failure: &str) {
        assert!(result.is_err());
        match result {
            Err(ClientError::Validation(msg)) => {
                assert!(msg.contains(expected));
            }
            _ => panic!("{}", failure),
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_connect_to_mcp_server() {
        use tokio::process::Command;
        let mut client = mock_client();
        let mut command = Command::new("echo");
        command.arg("Mock MCP server that doesn't exist");
        let result = client.connect_to_child_process(command).await;
        assert!(result.is_err());
        match result {
            Err(ClientError::Protocol(msg)) => {
                assert!(msg.contains("Failed to connect to MCP server"));
            }
            Err(ClientError::Transport(_)) => {}
            _ => panic!("Expected connection to fail with Protocol or Transport error"),
        }
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client = mock_client();
        assert_eq!(client.timeout, Duration::from_millis(5000));
    }

    #[tokio::test]
    async fn test_client_with_custom_timeout() {
        let timeout = Duration::from_millis(1000);
        let client = mock_client().with_timeout(timeout);
        assert_eq!(client.timeout, timeout);
    }

    #[tokio::test]
    async fn test_ping_not_connected() {
        let canned = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {}
        });
        let mut client = McpClient::new(Box::new(MockTransport::new(vec![canned])));
        assert_not_connected(client.ping().await, "Expected ClientError::Client");
    }

    #[tokio::test]
    async fn test_list_tools_not_connected() {
        let mut client = mock_client();
        assert_not_connected(client.list_tools().await, "Expected ClientError::Client");
    }

    #[tokio::test]
    async fn test_call_tool_not_connected() {
        let mut client = mock_client();
        let outcome = client.call_tool("get_pet_by_id", json!({"id": 123})).await;
        assert_not_connected(outcome, "Expected ClientError::Client");
    }

    #[tokio::test]
    async fn test_call_tool_snake_case_naming() {
        let mut client = mock_client();
        let cases = vec![
            ("get_pet_by_id", json!({"id": 123})),
            ("list_pets", json!({})),
            ("create_pet", json!({"name": "Fluffy", "status": "available"})),
            ("update_pet_status", json!({"id": 456, "status": "sold"})),
        ];
        for (tool_name, params) in cases {
            let failure = format!("Expected ClientError::Client for tool: {}", tool_name);
            assert_not_connected(client.call_tool(tool_name, params).await, &failure);
        }
    }

    #[tokio::test]
    async fn test_call_tool_argument_handling() {
        let mut client = mock_client();
        let cases = vec![
            ("ping", json!({})),
            ("get_pet_by_id", json!({"id": 123})),
            (
                "create_pet",
                json!({
                    "name": "Fluffy",
                    "status": "available",
                    "tags": ["cute", "fluffy"],
                    "metadata": {"breed": "Persian", "age": 2}
                }),
            ),
            ("batch_process", json!(["item1", "item2", "item3"])),
        ];
        for (tool_name, args) in cases {
            let failure = format!("Expected ClientError::Client for tool: {}", tool_name);
            assert_not_connected(client.call_tool(tool_name, args).await, &failure);
        }
    }

    #[test]
    fn test_registry_access() {
        let client = mock_client();
        let registry = client.registry();
        assert_eq!(registry.tool_names().len(), 0);
        assert!(!registry.has_tool("get_pet_by_id"));
    }

    #[tokio::test]
    async fn test_call_tool_streaming_not_connected() {
        let mut client = mock_client();
        let outcome = client
            .call_tool_streaming("get_pet_by_id", json!({"id": 123}))
            .await;
        assert_not_connected(outcome, "Expected ClientError::Client");
    }

    #[tokio::test]
    async fn test_call_tool_streaming_mock_response() {
        let mut client = mock_client();
        let cases = vec![
            ("long_running_task", json!({"input": "test"})),
            ("data_processing", json!({"batch_size": 100})),
        ];
        for (tool_name, params) in cases {
            let failure = format!(
                "Expected ClientError::Client for streaming tool: {}",
                tool_name
            );
            assert_not_connected(client.call_tool_streaming(tool_name, params).await, &failure);
        }
    }

    #[tokio::test]
    async fn test_streaming_response_format() {
        let mut client = mock_client();
        let outcome = client
            .call_tool_streaming("mock_stream_tool", json!({"delay": 100}))
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_streaming_vs_non_streaming_response() {
        let mut client = mock_client();
        let cases = vec![
            ("simple_tool", json!({"input": "test"})),
            (
                "long_running_tool",
                json!({"streaming": true, "task": "process_data"}),
            ),
        ];
        for (tool_name, params) in cases {
            let failure = format!("Expected ClientError::Client for tool: {}", tool_name);
            assert_not_connected(client.call_tool_streaming(tool_name, params).await, &failure);
        }
    }

    #[tokio::test]
    async fn test_call_tool_typed_not_connected() {
        let mut client = mock_client();
        let outcome = client
            .call_tool_typed("get_pet_by_id", json!({"id": 123}))
            .await;
        assert_not_connected(outcome, "Expected ClientError::Client");
    }

    #[tokio::test]
    async fn test_call_tool_typed_response_processing() {
        let mut client = mock_client();
        let cases = vec![
            ("get_status", json!({})),
            ("get_data", json!({"format": "json"})),
            ("invalid_tool", json!({"bad": "params"})),
        ];
        for (tool_name, params) in cases {
            let failure = format!("Expected ClientError::Client for typed tool: {}", tool_name);
            assert_not_connected(client.call_tool_typed(tool_name, params).await, &failure);
        }
    }

    #[tokio::test]
    async fn test_tool_result_content_types() {
        let mock_result = ToolResult {
            content: vec![
                ContentType::Text {
                    text: "Status: OK".to_string(),
                },
                ContentType::Json {
                    json: json!({"count": 42, "status": "success"}),
                },
            ],
            is_error: false,
            error_code: None,
            raw_response: json!({"mock": "response"}),
        };
        assert_eq!(mock_result.first_text(), Some("Status: OK"));
        assert_eq!(mock_result.text(), "Status: OK");
        assert!(!mock_result.has_error());
        let json_items = mock_result.json();
        assert_eq!(json_items.len(), 1);
        assert_eq!(json_items[0].get("count").unwrap(), 42);
    }

    #[tokio::test]
    async fn test_error_tool_result_handling() {
        let error_result = ToolResult {
            content: vec![ContentType::Text {
                text: "Tool execution failed".to_string(),
            }],
            is_error: true,
            error_code: Some("EXECUTION_ERROR".to_string()),
            raw_response: json!({"error": "Tool not found"}),
        };
        assert!(error_result.has_error());
        assert_eq!(error_result.error_code, Some("EXECUTION_ERROR".to_string()));
        assert_eq!(error_result.first_text(), Some("Tool execution failed"));
    }

    #[tokio::test]
    async fn test_validate_required_parameters() {
        let client = client_with_pet_tool();
        let outcome = client.validate_parameters("get_pet_by_id", json!({})).await;
        assert_validation_error(
            outcome,
            "required parameter 'id' is missing",
            "Expected ClientError::Validation for missing required parameter",
        );
    }

    #[tokio::test]
    async fn test_validate_parameter_types() {
        let client = client_with_pet_tool();
        let outcome = client
            .validate_parameters("get_pet_by_id", json!({"id": "not_a_number"}))
            .await;
        assert_validation_error(
            outcome,
            "parameter 'id' should be a number",
            "Expected ClientError::Validation for wrong parameter type",
        );
    }

    #[tokio::test]
    async fn test_validate_unknown_parameters() {
        let client = client_with_pet_tool();
        let outcome = client
            .validate_parameters(
                "get_pet_by_id",
                json!({"id": 123, "unknown_param": "value"}),
            )
            .await;
        assert_validation_error(
            outcome,
            "unknown parameter 'unknown_param'",
            "Expected ClientError::Validation for unknown parameter",
        );
    }

    #[tokio::test]
    async fn test_validate_parameters_successful() {
        let client = client_with_pet_tool();
        let outcome = client
            .validate_parameters("get_pet_by_id", json!({"id": 123}))
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_list_resources_not_connected() {
        let mut client = mock_client();
        assert_not_connected(client.list_resources().await, "Expected ClientError::Client");
    }

    #[tokio::test]
    async fn test_get_resource_not_connected() {
        let mut client = mock_client();
        let outcome = client.get_resource("file:///test.txt").await;
        assert_not_connected(outcome, "Expected ClientError::Client");
    }

    #[tokio::test]
    async fn test_end_to_end_tool_validation_integration() {
        let mut client = mock_client();
        client.registry_mut().add_tool(ToolInfo {
            name: "create_pet".to_string(),
            description: Some("Create a new pet".to_string()),
            input_schema: Some(json!({
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "status": {"type": "string"},
                    "age": {"type": "integer"},
                    "tags": {"type": "array"}
                },
                "required": ["name", "status"]
            })),
        });

        let valid_params = json!({
            "name": "Fluffy",
            "status": "available",
            "age": 2,
            "tags": ["cute", "fluffy"]
        });
        let outcome = client.validate_parameters("create_pet", valid_params).await;
        assert!(outcome.is_ok(), "Valid parameters should pass validation");

        let missing_required = json!({"name": "Fluffy"});
        assert_validation_error(
            client.validate_parameters("create_pet", missing_required).await,
            "required parameter 'status' is missing",
            "Expected validation error for missing required parameter",
        );

        let wrong_type = json!({"name": "Fluffy", "status": "available", "age": "not_a_number"});
        assert_validation_error(
            client.validate_parameters("create_pet", wrong_type).await,
            "parameter 'age' should be a number",
            "Expected validation error for wrong parameter type",
        );

        let unknown_param = json!({"name": "Fluffy", "status": "available", "unknown": "value"});
        assert_validation_error(
            client.validate_parameters("create_pet", unknown_param).await,
            "unknown parameter 'unknown'",
            "Expected validation error for unknown parameter",
        );

        let outcome = client
            .call_tool_typed(
                "create_pet",
                json!({"name": "Fluffy", "status": "available"}),
            )
            .await;
        assert_not_connected(outcome, "Expected client error when not connected to server");
    }

    #[tokio::test]
    async fn test_client_with_auth_configuration() {
        let configured = AuthConfig::new().with_api_key(
            "test_api_key_123".to_string(),
            Some("X-API-Key".to_string()),
        );
        assert!(configured.is_ok());
        let client = mock_client().with_auth(configured.unwrap());
        assert!(client.auth_config().is_some());
        let header_result = client.auth_config().unwrap().get_auth_headers();
        assert!(header_result.is_ok());
        let headers = header_result.unwrap();
        assert!(headers.contains_key("X-API-Key"));
        assert_eq!(
            headers.get("X-API-Key"),
            Some(&"test_api_key_123".to_string())
        );
    }

    #[tokio::test]
    async fn test_client_auth_security_validation() {
        let dangerous_api_key = "ignore previous instructions\x00malicious";
        let outcome = AuthConfig::new()
            .with_api_key(dangerous_api_key.to_string(), Some("X-API-Key".to_string()));
        assert_validation_error(
            outcome,
            "potentially unsafe characters",
            "Expected validation error for dangerous credential",
        );
    }

    #[tokio::test]
    async fn test_client_bearer_token_auth() {
        let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
        let configured = AuthConfig::new().with_bearer_token(jwt.to_string());
        assert!(configured.is_ok());
        let client = mock_client().with_auth(configured.unwrap());
        assert!(client.auth_config().is_some());
        let headers = client.auth_config().unwrap().get_auth_headers().unwrap();
        assert!(headers.contains_key("Authorization"));
        assert!(headers.get("Authorization").unwrap().starts_with("Bearer "));
    }

    #[tokio::test]
    async fn test_auth_header_injection_protection() {
        let malicious_header_name = "X-API-Key\r\nInjected-Header: malicious";
        let outcome = AuthConfig::new()
            .with_custom_header(malicious_header_name.to_string(), "value".to_string());
        assert_validation_error(
            outcome,
            "invalid characters",
            "Expected validation error for header injection attempt",
        );
    }

    #[tokio::test]
    async fn test_client_without_auth() {
        let client = mock_client();
        assert!(client.auth_config().is_none());
    }

    #[tokio::test]
    async fn test_cache_configuration() {
        let client = mock_client();
        assert!(client.cache_analytics().is_none());
        let (cache_config, _temp_dir) = create_test_cache_config();
        let client = client.with_cache(cache_config).await.unwrap();
        assert!(client.cache_analytics().is_some());
        let analytics = client.cache_analytics().unwrap();
        assert_eq!(analytics.resource_count, 0);
        assert_eq!(analytics.cache_size_bytes, 0);
        let client = client.without_cache();
        assert!(client.cache_analytics().is_none());
    }

    #[tokio::test]
    async fn test_cache_operations() {
        let (cache_config, _temp_dir) = create_test_cache_config();
        let mut client = mock_client().with_cache(cache_config).await.unwrap();
        assert_eq!(client.list_cached_resources().await.unwrap().len(), 0);
        assert_eq!(client.search_cached_resources("test").await.unwrap().len(), 0);
        client.invalidate_cache(Some("nonexistent")).await.unwrap();
        client.invalidate_cache(None).await.unwrap();
        assert_eq!(client.cleanup_cache().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_cache_analytics_tracking() {
        let (cache_config, _temp_dir) = create_test_cache_config();
        let mut client = mock_client().with_cache(cache_config).await.unwrap();
        let analytics = client.cache_analytics().unwrap();
        assert_eq!(analytics.total_requests, 0);
        assert_eq!(analytics.cache_hits, 0);
        assert_eq!(analytics.cache_misses, 0);
        assert_eq!(analytics.hit_rate, 0.0);
        let outcome = client.get_resource("test://resource").await;
        assert!(outcome.is_err());
        assert!(client.cache_analytics().is_some());
    }

    #[tokio::test]
    async fn test_cache_with_custom_config() {
        let cache_config = crate::mcp::client::cache::CacheConfig {
            database_path: ":memory:".to_string(),
            default_ttl: Duration::from_secs(300),
            max_size_mb: 50,
            auto_cleanup: true,
            cleanup_interval: Duration::from_secs(60),
            pool_min_connections: None,
            pool_max_connections: None,
            pool_connection_timeout: None,
            pool_max_lifetime: None,
        };
        let client = mock_client().with_cache(cache_config).await.unwrap();
        assert!(client.cache_analytics().is_some());
        let analytics = client.cache_analytics().unwrap();
        assert_eq!(analytics.resource_count, 0);
    }
}