use std::sync::Arc;
use serde_json::Value;
use super::server_config::{McpServerConfig, TransportConfig};
use super::transport::McpTransport;
use super::transport::http::HttpTransport;
use super::transport::sse::SseTransport;
use super::transport::stdio::StdioTransport;
use super::types::{
ClientCapabilities, ClientInfo, ElicitationCapability, InitializeParams, InitializeResult,
JsonRpcNotification, JsonRpcRequest, MCP_PROTOCOL_VERSION, McpContent, McpPrompt,
McpPromptGetParams, McpPromptGetResult, McpPromptsListResult, McpResource,
McpResourceReadParams, McpResourceReadResult, McpResourcesListResult, McpTool,
McpToolCallParams, McpToolCallResult, McpToolsListResult, RootsCapability, SamplingCapability,
ServerCapabilities,
};
use echo_core::error::{McpError, ReactError, Result};
/// A connected MCP client bound to a single server.
///
/// Created by `McpClient::new`, which performs the `initialize` handshake and
/// caches whatever tools, resources and prompts the server advertised.
pub struct McpClient {
    /// Server name taken from the `McpServerConfig` used to connect.
    server_name: String,
    /// Protocol version returned by the server in its `initialize` result.
    negotiated_version: String,
    /// Capabilities the server declared during `initialize`.
    server_capabilities: ServerCapabilities,
    /// Transport used for every request and notification to this server.
    transport: Arc<dyn McpTransport>,
    /// Tools discovered via `tools/list` (empty if the capability is absent).
    tools: Vec<McpTool>,
    /// Resources discovered via `resources/list` (empty if the capability is absent).
    resources: Vec<McpResource>,
    /// Prompts discovered via `prompts/list` (empty if the capability is absent).
    prompts: Vec<McpPrompt>,
}
impl McpClient {
pub async fn new(config: McpServerConfig) -> Result<Arc<Self>> {
let transport: Arc<dyn McpTransport> = match config.transport {
TransportConfig::Stdio { command, args, env } => {
Arc::new(StdioTransport::new(&command, &args, &env).await?)
}
TransportConfig::Http { base_url, headers } => {
Arc::new(HttpTransport::new(base_url, headers))
}
TransportConfig::Sse { base_url, headers } => {
Arc::new(SseTransport::new(base_url, headers).await?)
}
};
tracing::info!("MCP: 正在连接服务端 '{}'", config.name);
let init_params = InitializeParams {
protocol_version: MCP_PROTOCOL_VERSION.to_string(),
capabilities: Self::build_client_capabilities(),
client_info: ClientInfo {
name: "echo-agent".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
title: Some("Echo Agent MCP Client".to_string()),
description: None,
icons: Vec::new(),
website_url: None,
},
};
let init_req = JsonRpcRequest::new("initialize", Some(serde_json::to_value(init_params)?));
let init_resp = transport.send(init_req).await?;
if let Some(err) = init_resp.error {
return Err(ReactError::Mcp(McpError::InitializationFailed(err.message)));
}
let init_result: InitializeResult =
serde_json::from_value(init_resp.result.ok_or_else(|| {
ReactError::Mcp(McpError::InitializationFailed(
"initialize 响应为空".to_string(),
))
})?)?;
let negotiated_version = init_result.protocol_version.clone();
tracing::info!(
"MCP: 已连接 '{}' (协议版本: {}, 请求版本: {})",
config.name,
negotiated_version,
MCP_PROTOCOL_VERSION
);
if let Some(info) = &init_result.server_info {
tracing::info!("MCP: 服务端信息: {} v{}", info.name, info.version);
}
if let Some(instructions) = &init_result.instructions {
tracing::info!(
"MCP: 服务端指令: {}",
instructions.chars().take(100).collect::<String>()
);
}
transport
.notify(JsonRpcNotification::new("notifications/initialized", None))
.await?;
let server_capabilities = init_result.capabilities;
let mut tools = Vec::new();
let mut resources = Vec::new();
let mut prompts = Vec::new();
if server_capabilities.tools.is_some() {
tools = Self::fetch_tools(&transport, &config.name).await?;
tracing::info!("MCP: 从 '{}' 发现 {} 个工具", config.name, tools.len());
}
if server_capabilities.resources.is_some() {
resources = Self::fetch_resources(&transport, &config.name).await?;
tracing::info!("MCP: 从 '{}' 发现 {} 个资源", config.name, resources.len());
}
if server_capabilities.prompts.is_some() {
prompts = Self::fetch_prompts(&transport, &config.name).await?;
tracing::info!("MCP: 从 '{}' 发现 {} 个提示词", config.name, prompts.len());
}
Ok(Arc::new(McpClient {
transport,
server_name: config.name,
negotiated_version,
server_capabilities,
tools,
resources,
prompts,
}))
}
fn build_client_capabilities() -> ClientCapabilities {
ClientCapabilities {
roots: Some(RootsCapability {
list_changed: Some(true),
}),
sampling: Some(SamplingCapability::default()),
elicitation: Some(ElicitationCapability::default()),
experimental: None,
}
}
async fn fetch_tools(
transport: &Arc<dyn McpTransport>,
server_name: &str,
) -> Result<Vec<McpTool>> {
let mut all_tools = Vec::new();
let mut cursor: Option<String> = None;
let mut iterations = 0;
const MAX_PAGINATION: usize = 100;
loop {
iterations += 1;
if iterations > MAX_PAGINATION {
tracing::warn!(
"MCP: '{}' tools/list 达到最大分页限制 ({}),停止获取",
server_name,
MAX_PAGINATION
);
break;
}
let params = cursor.as_ref().map(|c| serde_json::json!({ "cursor": c }));
let req = JsonRpcRequest::new("tools/list", params);
let resp =
tokio::time::timeout(std::time::Duration::from_secs(30), transport.send(req))
.await
.map_err(|_| {
ReactError::Mcp(McpError::ProtocolError("获取工具列表超时".to_string()))
})??;
if let Some(err) = resp.error {
tracing::warn!(
"MCP: '{}' tools/list 返回错误: {}",
server_name,
err.message
);
break;
}
let result: McpToolsListResult =
serde_json::from_value(resp.result.unwrap_or(Value::Null))?;
all_tools.extend(result.tools);
cursor = result.next_cursor;
if cursor.is_none() {
break;
}
}
Ok(all_tools)
}
pub async fn refresh_tools(&mut self) -> Result<()> {
self.tools = Self::fetch_tools(&self.transport, &self.server_name).await?;
tracing::info!(
"MCP: '{}' 工具列表已刷新,共 {} 个",
self.server_name,
self.tools.len()
);
Ok(())
}
pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<McpToolCallResult> {
let params = McpToolCallParams {
name: name.to_string(),
arguments: Some(arguments),
};
let req = JsonRpcRequest::new("tools/call", Some(serde_json::to_value(params)?));
let resp = self.transport.send(req).await?;
if let Some(err) = resp.error {
return Err(ReactError::Mcp(McpError::ToolCallFailed(format!(
"工具 '{}' 调用失败: {}",
name, err.message
))));
}
let result: McpToolCallResult = serde_json::from_value(resp.result.unwrap_or(Value::Null))?;
Ok(result)
}
pub fn tools(&self) -> &[McpTool] {
&self.tools
}
async fn fetch_resources(
transport: &Arc<dyn McpTransport>,
server_name: &str,
) -> Result<Vec<McpResource>> {
let mut all_resources = Vec::new();
let mut cursor: Option<String> = None;
let mut iterations = 0;
const MAX_PAGINATION: usize = 100;
loop {
iterations += 1;
if iterations > MAX_PAGINATION {
tracing::warn!(
"MCP: '{}' resources/list 达到最大分页限制 ({}),停止获取",
server_name,
MAX_PAGINATION
);
break;
}
let params = cursor.as_ref().map(|c| serde_json::json!({ "cursor": c }));
let req = JsonRpcRequest::new("resources/list", params);
let resp =
tokio::time::timeout(std::time::Duration::from_secs(30), transport.send(req))
.await
.map_err(|_| {
ReactError::Mcp(McpError::ProtocolError("获取资源列表超时".to_string()))
})??;
if let Some(err) = resp.error {
tracing::warn!(
"MCP: '{}' resources/list 返回错误: {}",
server_name,
err.message
);
break;
}
let result: McpResourcesListResult =
serde_json::from_value(resp.result.unwrap_or(Value::Null))?;
all_resources.extend(result.resources);
cursor = result.next_cursor;
if cursor.is_none() {
break;
}
}
Ok(all_resources)
}
pub async fn refresh_resources(&mut self) -> Result<()> {
self.resources = Self::fetch_resources(&self.transport, &self.server_name).await?;
tracing::info!(
"MCP: '{}' 资源列表已刷新,共 {} 个",
self.server_name,
self.resources.len()
);
Ok(())
}
pub async fn read_resource(&self, uri: &str) -> Result<McpResourceReadResult> {
let params = McpResourceReadParams {
uri: uri.to_string(),
};
let req = JsonRpcRequest::new("resources/read", Some(serde_json::to_value(params)?));
let resp = self.transport.send(req).await?;
if let Some(err) = resp.error {
return Err(ReactError::Mcp(McpError::ProtocolError(format!(
"读取资源 '{}' 失败: {}",
uri, err.message
))));
}
let result: McpResourceReadResult =
serde_json::from_value(resp.result.unwrap_or(Value::Null))?;
Ok(result)
}
pub fn resources(&self) -> &[McpResource] {
&self.resources
}
pub fn supports_resources(&self) -> bool {
self.server_capabilities.resources.is_some()
}
async fn fetch_prompts(
transport: &Arc<dyn McpTransport>,
server_name: &str,
) -> Result<Vec<McpPrompt>> {
let mut all_prompts = Vec::new();
let mut cursor: Option<String> = None;
let mut iterations = 0;
const MAX_PAGINATION: usize = 100;
loop {
iterations += 1;
if iterations > MAX_PAGINATION {
tracing::warn!(
"MCP: '{}' prompts/list 达到最大分页限制 ({}),停止获取",
server_name,
MAX_PAGINATION
);
break;
}
let params = cursor.as_ref().map(|c| serde_json::json!({ "cursor": c }));
let req = JsonRpcRequest::new("prompts/list", params);
let resp =
tokio::time::timeout(std::time::Duration::from_secs(30), transport.send(req))
.await
.map_err(|_| {
ReactError::Mcp(McpError::ProtocolError("获取提示词列表超时".to_string()))
})??;
if let Some(err) = resp.error {
tracing::warn!(
"MCP: '{}' prompts/list 返回错误: {}",
server_name,
err.message
);
break;
}
let result: McpPromptsListResult =
serde_json::from_value(resp.result.unwrap_or(Value::Null))?;
all_prompts.extend(result.prompts);
cursor = result.next_cursor;
if cursor.is_none() {
break;
}
}
Ok(all_prompts)
}
pub async fn refresh_prompts(&mut self) -> Result<()> {
self.prompts = Self::fetch_prompts(&self.transport, &self.server_name).await?;
tracing::info!(
"MCP: '{}' 提示词列表已刷新,共 {} 个",
self.server_name,
self.prompts.len()
);
Ok(())
}
pub async fn get_prompt(
&self,
name: &str,
arguments: Option<std::collections::HashMap<String, String>>,
) -> Result<McpPromptGetResult> {
let params = McpPromptGetParams {
name: name.to_string(),
arguments,
};
let req = JsonRpcRequest::new("prompts/get", Some(serde_json::to_value(params)?));
let resp = self.transport.send(req).await?;
if let Some(err) = resp.error {
return Err(ReactError::Mcp(McpError::ProtocolError(format!(
"获取提示词 '{}' 失败: {}",
name, err.message
))));
}
let result: McpPromptGetResult =
serde_json::from_value(resp.result.unwrap_or(Value::Null))?;
Ok(result)
}
pub fn prompts(&self) -> &[McpPrompt] {
&self.prompts
}
pub fn supports_prompts(&self) -> bool {
self.server_capabilities.prompts.is_some()
}
pub async fn ping(&self) -> Result<()> {
let req = JsonRpcRequest::new("ping", None);
let resp = self.transport.send(req).await?;
if let Some(err) = resp.error {
return Err(ReactError::Mcp(McpError::ProtocolError(format!(
"ping 失败: {}",
err.message
))));
}
Ok(())
}
pub fn server_name(&self) -> &str {
&self.server_name
}
pub fn protocol_version(&self) -> &str {
&self.negotiated_version
}
pub fn server_capabilities(&self) -> &ServerCapabilities {
&self.server_capabilities
}
pub async fn close(&self) {
self.transport.close().await;
}
pub fn content_to_text(content: &[McpContent]) -> String {
content
.iter()
.map(|c| match c {
McpContent::Text { text } => text.clone(),
McpContent::Image { mime_type, .. } => format!("[图片: {}]", mime_type),
McpContent::Resource { resource } => {
let name = resource.name.as_deref().unwrap_or("unnamed");
format!("[资源: {} ({})]", name, resource.uri)
}
McpContent::Audio { mime_type, .. } => format!("[音频: {}]", mime_type),
})
.collect::<Vec<_>>()
.join("\n")
}
}