use std::sync::Arc;
use tokio::sync::Mutex;
use url::Url;
use crate::error::{Error, Result};
use crate::mcp::types::*;
use crate::utils::security::create_safe_error_message;
/// Client for an MCP server that exchanges JSON-RPC 2.0 messages over HTTP POST.
///
/// Clones share the cached server capabilities and the request semaphore (both are held
/// behind `Arc`), so the concurrency limit applies across all clones of one client.
#[derive(Clone)]
pub struct MCPClient {
    /// Endpoint that every JSON-RPC message is POSTed to.
    server_url: Url,
    /// HTTP client used for all exchanges with `server_url`.
    client: reqwest::Client,
    /// Timeouts and size/concurrency limits applied to each exchange.
    config: McpConfig,
    /// Bounds the number of HTTP exchanges in flight at once.
    semaphore: Arc<tokio::sync::Semaphore>,
    /// Capabilities returned by the server's `initialize` reply; `None` until then.
    capabilities: Arc<Mutex<Option<ServerCapabilities>>>,
}
impl MCPClient {
pub fn new(server_url: impl AsRef<str>) -> Result<Self> {
Self::new_with_config(server_url, McpConfig::default())
}
pub fn new_with_config(server_url: impl AsRef<str>, config: McpConfig) -> Result<Self> {
let server_url = Url::parse(server_url.as_ref())
.map_err(|e| Error::ConfigError(format!("Invalid server URL: {e}")))?;
let client = reqwest::Client::builder()
.timeout(config.request_timeout)
.build()
.map_err(|e| Error::ConfigError(format!("Failed to create HTTP client: {e}")))?;
Ok(Self {
client,
server_url,
capabilities: Arc::new(Mutex::new(None)),
config: config.clone(),
semaphore: Arc::new(tokio::sync::Semaphore::new(config.max_concurrent_requests)),
})
}
fn generate_id() -> String {
uuid::Uuid::new_v4().to_string()
}
fn should_skip_id_validation() -> bool {
cfg!(test)
}
pub async fn initialize(
&self,
client_capabilities: ClientCapabilities,
) -> Result<ServerCapabilities> {
let request_id = Self::generate_id();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: request_id.clone(),
method: "initialize".to_string(),
params: Some(
serde_json::to_value(InitializeParams {
capabilities: client_capabilities,
})
.map_err(Error::SerializationError)?,
),
protocol_version: Some(MCP_PROTOCOL_VERSION.to_string()),
};
let response = self.send_request(request).await?;
let capabilities = self.parse_response::<ServerCapabilities>(response, request_id)?;
let mut caps = self.capabilities.lock().await;
*caps = Some(capabilities.clone());
Ok(capabilities)
}
pub async fn get_resource(&self, params: GetResourceParams) -> Result<ResourceResponse> {
self.ensure_initialized().await?;
let request_id = Self::generate_id();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: request_id.clone(),
method: "getResource".to_string(),
params: Some(serde_json::to_value(params).map_err(Error::SerializationError)?),
protocol_version: Some(MCP_PROTOCOL_VERSION.to_string()),
};
let response = self.send_request(request).await?;
self.parse_response::<ResourceResponse>(response, request_id)
}
pub async fn tool_call(&self, params: ToolCallParams) -> Result<ToolCallResponse> {
self.ensure_initialized().await?;
let request_id = Self::generate_id();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: request_id.clone(),
method: "toolCall".to_string(),
params: Some(serde_json::to_value(params).map_err(Error::SerializationError)?),
protocol_version: Some(MCP_PROTOCOL_VERSION.to_string()),
};
let response = self.send_request(request).await?;
self.parse_response::<ToolCallResponse>(response, request_id)
}
pub async fn execute_prompt(
&self,
params: ExecutePromptParams,
) -> Result<ExecutePromptResponse> {
self.ensure_initialized().await?;
let request_id = Self::generate_id();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: request_id.clone(),
method: "executePrompt".to_string(),
params: Some(serde_json::to_value(params).map_err(Error::SerializationError)?),
protocol_version: Some(MCP_PROTOCOL_VERSION.to_string()),
};
let response = self.send_request(request).await?;
self.parse_response::<ExecutePromptResponse>(response, request_id)
}
pub async fn respond_to_sampling(&self, id: String, result: SamplingResponse) -> Result<()> {
self.ensure_initialized().await?;
let response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id,
result: Some(serde_json::to_value(result).map_err(Error::SerializationError)?),
error: None,
};
self.send_response(response).await
}
pub async fn capabilities(&self) -> Option<ServerCapabilities> {
self.capabilities.lock().await.clone()
}
async fn send_request(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
let _permit = self.semaphore.acquire().await.map_err(|_| {
Error::ResourceExhausted("Too many concurrent MCP requests".to_string())
})?;
let request_json = serde_json::to_string(&request).map_err(Error::SerializationError)?;
if request_json.len() > self.config.max_request_size {
return Err(Error::ResourceExhausted(format!(
"Request too large: {} bytes (max: {})",
request_json.len(),
self.config.max_request_size
)));
}
let response = tokio::time::timeout(
self.config.request_timeout,
self.client
.post(self.server_url.clone())
.header("Content-Type", "application/json")
.body(request_json)
.send(),
)
.await
.map_err(|_| {
Error::TimeoutError(format!(
"MCP request timeout after {:?}",
self.config.request_timeout
))
})?
.map_err(Error::HttpError)?;
if !response.status().is_success() {
let status_code = response.status().as_u16();
let raw_body = response.text().await.unwrap_or_default();
return Err(Error::ApiError {
code: status_code,
message: create_safe_error_message(&raw_body, "MCP server error"),
metadata: None,
});
}
let content_length = response.content_length().unwrap_or(0);
if content_length > self.config.max_response_size as u64 {
return Err(Error::ResourceExhausted(format!(
"Response too large: {} bytes (max: {})",
content_length, self.config.max_response_size
)));
}
use futures::StreamExt;
let mut stream = response.bytes_stream();
let mut body_bytes = Vec::new();
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(Error::HttpError)?;
if body_bytes.len() + chunk.len() > self.config.max_response_size {
return Err(Error::ResourceExhausted(format!(
"Response body exceeded maximum size of {} bytes",
self.config.max_response_size
)));
}
body_bytes.extend_from_slice(&chunk);
}
let response_body = String::from_utf8(body_bytes)
.map_err(|e| Error::ConfigError(format!("Invalid UTF-8 in response: {}", e)))?;
let response: JsonRpcResponse =
serde_json::from_str(&response_body).map_err(Error::SerializationError)?;
Ok(response)
}
async fn send_response(&self, response: JsonRpcResponse) -> Result<()> {
let _permit = self.semaphore.acquire().await.map_err(|_| {
Error::ResourceExhausted("Too many concurrent MCP requests".to_string())
})?;
let response_json = serde_json::to_string(&response).map_err(Error::SerializationError)?;
if response_json.len() > self.config.max_request_size {
return Err(Error::ResourceExhausted(format!(
"Response too large: {} bytes (max: {})",
response_json.len(),
self.config.max_request_size
)));
}
let _response = tokio::time::timeout(
self.config.request_timeout,
self.client
.post(self.server_url.clone())
.header("Content-Type", "application/json")
.body(response_json)
.send(),
)
.await
.map_err(|_| Error::TimeoutError("MCP response timed out".to_string()))?
.map_err(Error::HttpError)?;
Ok(())
}
fn parse_response<T: serde::de::DeserializeOwned>(
&self,
response: JsonRpcResponse,
expected_id: String,
) -> Result<T> {
if !Self::should_skip_id_validation() && response.id != expected_id {
return Err(Error::ConfigError(format!(
"JSON-RPC response ID mismatch: expected {}, got {}",
expected_id, response.id
)));
}
if let Some(error) = response.error {
return Err(Error::ApiError {
code: error.code.try_into().unwrap_or(500),
message: error.message,
metadata: error.data,
});
}
match response.result {
Some(result) => serde_json::from_value(result).map_err(Error::SerializationError),
None => Err(Error::ConfigError("Response contains no result".into())),
}
}
async fn ensure_initialized(&self) -> Result<()> {
if self.capabilities.lock().await.is_none() {
return Err(Error::ConfigError("MCP client not initialized".into()));
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    // Tests drive MCPClient against a local wiremock server. Because this module is compiled
    // with cfg(test), the client skips JSON-RPC id validation, which is why every mock can
    // answer with the fixed id "test".
    use super::*;
    use reqwest::StatusCode;
    use std::time::Duration;
    use wiremock::{matchers, Mock, MockServer, ResponseTemplate};

    /// Small limits so tests can exceed them cheaply: 1 s timeout, 1024-byte responses,
    /// 512-byte requests, at most 2 concurrent exchanges.
    fn create_test_config() -> McpConfig {
        McpConfig {
            request_timeout: Duration::from_secs(1),
            max_response_size: 1024, max_request_size: 512, max_concurrent_requests: 2,
        }
    }

    // The explicit config is stored verbatim on the client.
    #[tokio::test]
    async fn test_mcp_client_with_config() {
        let mock_server = MockServer::start().await;
        let client = MCPClient::new_with_config(mock_server.uri(), create_test_config()).unwrap();
        assert_eq!(client.config.request_timeout, Duration::from_secs(1));
        assert_eq!(client.config.max_response_size, 1024);
        assert_eq!(client.config.max_request_size, 512);
        assert_eq!(client.config.max_concurrent_requests, 2);
    }

    // Pins the values produced by McpConfig::default().
    #[tokio::test]
    async fn test_mcp_client_with_default_config() {
        let mock_server = MockServer::start().await;
        let client = MCPClient::new(mock_server.uri()).unwrap();
        assert_eq!(client.config.request_timeout, Duration::from_secs(30));
        assert_eq!(client.config.max_response_size, 10 * 1024 * 1024);
        assert_eq!(client.config.max_request_size, 1024 * 1024);
        assert_eq!(client.config.max_concurrent_requests, 10);
    }

    // The mock waits 3 s against a 1 s client timeout. Two timers are configured with the
    // same duration: the reqwest client timeout (surfaces as HttpError) and the
    // tokio::time::timeout wrapper (TimeoutError). Either may fire first, so both arms are
    // accepted.
    // NOTE(review): no timeout path in the client produces a ConfigError mentioning
    // "timed out"; that arm looks tolerant rather than required.
    #[tokio::test]
    async fn test_request_timeout() {
        let mock_server = MockServer::start().await;
        Mock::given(matchers::method("POST"))
            .respond_with(
                ResponseTemplate::new(StatusCode::OK)
                    .set_delay(Duration::from_secs(3))
                    .set_body_json(serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": "test",
                        "result": {"protocolVersion": "2025-03-26"}
                    })),
            )
            .mount(&mock_server)
            .await;
        let client = MCPClient::new_with_config(mock_server.uri(), create_test_config()).unwrap();
        let capabilities = ClientCapabilities {
            protocol_version: "2025-03-26".to_string(),
            supports_sampling: None,
        };
        let result = client.initialize(capabilities).await;
        assert!(result.is_err());
        let error = result.unwrap_err();
        match &error {
            Error::TimeoutError(msg) => assert!(msg.contains("timeout")),
            Error::ConfigError(msg) => assert!(msg.contains("timed out")),
            Error::HttpError(_) => {}
            _ => panic!("Expected timeout error, got: {:?}", error),
        }
    }

    // A ~2 KB body against the 1024-byte limit. This presumably relies on the mock sending
    // Content-Length, so the up-front "Response too large" check is the one that fires.
    #[tokio::test]
    async fn test_response_size_limit() {
        let mock_server = MockServer::start().await;
        let large_result = "x".repeat(2048);
        Mock::given(matchers::method("POST"))
            .respond_with(
                ResponseTemplate::new(StatusCode::OK).set_body_json(serde_json::json!({
                    "jsonrpc": "2.0",
                    "id": "test",
                    "result": {"data": large_result}
                })),
            )
            .mount(&mock_server)
            .await;
        let client = MCPClient::new_with_config(mock_server.uri(), create_test_config()).unwrap();
        let capabilities = ClientCapabilities {
            protocol_version: "2025-03-26".to_string(),
            supports_sampling: None,
        };
        let result = client.initialize(capabilities).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            Error::ResourceExhausted(msg) => assert!(msg.contains("too large")),
            _ => panic!("Expected ResourceExhausted error"),
        }
    }

    // Calls send_request directly, bypassing initialize, with a serialized request
    // > 512 bytes. The request must be rejected before anything is sent.
    #[tokio::test]
    async fn test_request_size_limit() {
        let mock_server = MockServer::start().await;
        Mock::given(matchers::method("POST"))
            .respond_with(
                ResponseTemplate::new(StatusCode::OK).set_body_json(serde_json::json!({
                    "jsonrpc": "2.0",
                    "id": "test",
                    "result": {"protocolVersion": "2025-03-26"}
                })),
            )
            .mount(&mock_server)
            .await;
        let client = MCPClient::new_with_config(mock_server.uri(), create_test_config()).unwrap();
        let large_protocol = "x".repeat(600);
        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: "test".to_string(),
            method: "initialize".to_string(),
            params: Some(serde_json::json!({
                "protocolVersion": large_protocol,
                "capabilities": {
                    "sampling": {}
                }
            })),
            protocol_version: Some("2025-03-26".to_string()),
        };
        // Sanity check that the fixture really exceeds the configured limit.
        let request_json = serde_json::to_string(&request).unwrap();
        assert!(request_json.len() > 512, "Request should exceed 512B limit");
        let result = client.send_request(request).await;
        assert!(result.is_err());
        let error = result.unwrap_err();
        match &error {
            Error::ResourceExhausted(msg) => assert!(msg.contains("Request too large")),
            _ => panic!("Expected ResourceExhausted error, got: {:?}", error),
        }
    }

    // Four concurrent initialize calls, 2 permits, 300 ms per reply. The semaphore makes
    // extra callers wait rather than fail, so all four succeed in at least two batches
    // (~600 ms nominal). The test asserts >= 500 ms, leaving slack for scheduling jitter.
    // The "protocol_version" key presumably matches ServerCapabilities' serde field name;
    // confirm against the type definition.
    #[tokio::test]
    async fn test_concurrent_request_limiting() {
        let mock_server = MockServer::start().await;
        Mock::given(matchers::method("POST"))
            .respond_with(
                ResponseTemplate::new(StatusCode::OK)
                    .set_delay(Duration::from_millis(300))
                    .set_body_json(serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": "test",
                        "result": {"protocol_version": "2025-03-26"}
                    })),
            )
            .mount(&mock_server)
            .await;
        let client = MCPClient::new_with_config(mock_server.uri(), create_test_config()).unwrap();
        let capabilities = ClientCapabilities {
            protocol_version: "2025-03-26".to_string(),
            supports_sampling: None,
        };
        let start = std::time::Instant::now();
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let client = client.clone();
                let caps = capabilities.clone();
                tokio::spawn(async move { client.initialize(caps).await })
            })
            .collect();
        let results: Vec<_> = futures::future::join_all(handles).await;
        let elapsed = start.elapsed();
        let mut successes = 0;
        let mut errors: Vec<String> = Vec::new();
        for r in &results {
            match r {
                Ok(Ok(_)) => successes += 1,
                Ok(Err(e)) => errors.push(format!("Request error: {:?}", e)),
                Err(e) => errors.push(format!("Join error: {:?}", e)),
            }
        }
        assert_eq!(
            successes, 4,
            "All 4 requests should succeed (semaphore queues, doesn't reject). Got {}, errors: {:?}",
            successes, errors
        );
        assert!(
            elapsed >= Duration::from_millis(500),
            "Expected >= 500ms (2 batches of 300ms), got {:?}. Semaphore may not be limiting.",
            elapsed
        );
    }

    // Happy path: a small, well-formed reply is decoded into ServerCapabilities.
    #[tokio::test]
    async fn test_successful_request_within_limits() {
        let mock_server = MockServer::start().await;
        Mock::given(matchers::method("POST"))
            .respond_with(
                ResponseTemplate::new(StatusCode::OK).set_body_json(serde_json::json!({
                    "jsonrpc": "2.0",
                    "id": "test",
                    "result": {
                        "protocol_version": "2025-03-26"
                    }
                })),
            )
            .mount(&mock_server)
            .await;
        let client = MCPClient::new_with_config(mock_server.uri(), create_test_config()).unwrap();
        let capabilities = ClientCapabilities {
            protocol_version: "2025-03-26".to_string(),
            supports_sampling: None,
        };
        let result = client.initialize(capabilities).await;
        assert!(result.is_ok());
        let server_caps = result.unwrap();
        assert_eq!(server_caps.protocol_version, "2025-03-26");
    }

    // A ~1.5 KB body (just over the 1024-byte limit). This presumably relies on the
    // Content-Length pre-check, whose message contains "too large".
    #[tokio::test]
    async fn test_content_length_header_validation() {
        let mock_server = MockServer::start().await;
        let large_result = "x".repeat(1500);
        Mock::given(matchers::method("POST"))
            .respond_with(
                ResponseTemplate::new(StatusCode::OK).set_body_json(serde_json::json!({
                    "jsonrpc": "2.0",
                    "id": "test",
                    "result": {"protocolVersion": "2025-03-26", "data": large_result}
                })),
            )
            .mount(&mock_server)
            .await;
        let client = MCPClient::new_with_config(mock_server.uri(), create_test_config()).unwrap();
        let capabilities = ClientCapabilities {
            protocol_version: "2025-03-26".to_string(),
            supports_sampling: None,
        };
        let result = client.initialize(capabilities).await;
        assert!(result.is_err());
        let error = result.unwrap_err();
        match &error {
            Error::ResourceExhausted(msg) => assert!(msg.contains("too large")),
            _ => panic!("Expected ResourceExhausted error, got: {:?}", error),
        }
    }

    // Forces chunked transfer encoding, presumably so no Content-Length reaches the client.
    // The oversized body must then be caught by the streaming check ("exceeded maximum size")
    // rather than the header check.
    #[tokio::test]
    async fn test_response_size_limit_with_chunked_encoding() {
        let mock_server = MockServer::start().await;
        let large_result = "x".repeat(2048);
        Mock::given(matchers::method("POST"))
            .respond_with(
                ResponseTemplate::new(StatusCode::OK)
                    .set_body_json(serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": "test",
                        "result": {"data": large_result}
                    }))
                    .append_header("Transfer-Encoding", "chunked"),
            )
            .mount(&mock_server)
            .await;
        let client = MCPClient::new_with_config(mock_server.uri(), create_test_config()).unwrap();
        let capabilities = ClientCapabilities {
            protocol_version: "2025-03-26".to_string(),
            supports_sampling: None,
        };
        let result = client.initialize(capabilities).await;
        assert!(result.is_err());
        let error = result.unwrap_err();
        match &error {
            Error::ResourceExhausted(msg) => assert!(msg.contains("exceeded maximum size")),
            _ => panic!("Expected ResourceExhausted error, got: {:?}", error),
        }
    }
}