mod channel;
mod handler;
#[cfg(feature = "http-client")]
mod http;
#[cfg(feature = "oauth-client")]
mod oauth;
#[cfg(feature = "oauth-client")]
mod oauth_authcode;
mod stdio;
mod transport;
pub use channel::ChannelTransport;
pub use handler::{ClientHandler, NotificationHandler, ServerNotification};
#[cfg(feature = "http-client")]
pub use http::{HttpClientConfig, HttpClientTransport};
#[cfg(feature = "oauth-client")]
pub use oauth::{
OAuthClientCredentials, OAuthClientCredentialsBuilder, OAuthClientError, TokenProvider,
};
#[cfg(feature = "oauth-client")]
pub use oauth_authcode::{OAuthAuthCodeConfig, OAuthAuthorizationCode};
pub use stdio::StdioClientTransport;
pub use transport::ClientTransport;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use tokio::sync::{Mutex, RwLock, mpsc, oneshot};
use tokio::task::JoinHandle;
use crate::error::{Error, Result};
use crate::protocol::{
CallToolParams, CallToolResult, ClientCapabilities, CompleteParams, CompleteResult,
CompletionArgument, CompletionReference, ElicitationCapability, GetPromptParams,
GetPromptResult, Implementation, InitializeParams, InitializeResult, JsonRpcNotification,
JsonRpcRequest, ListPromptsParams, ListPromptsResult, ListResourceTemplatesParams,
ListResourceTemplatesResult, ListResourcesParams, ListResourcesResult, ListRootsResult,
ListToolsParams, ListToolsResult, PromptDefinition, ReadResourceParams, ReadResourceResult,
RequestId, ResourceDefinition, ResourceTemplateDefinition, Root, RootsCapability,
SamplingCapability, ToolDefinition, notifications,
};
use tower_mcp_types::JsonRpcError;
/// Instructions sent from an [`McpClient`] handle to its background message loop.
///
/// The loop owns the transport; every outbound message and session-level
/// operation is funnelled through one of these commands.
enum LoopCommand {
    /// Send a JSON-RPC request. The loop assigns the request id and later
    /// delivers the server's `result` value (or an error) on `response_tx`.
    Request {
        method: String,
        params: serde_json::Value,
        response_tx: oneshot::Sender<Result<serde_json::Value>>,
    },
    /// Send a JSON-RPC notification; no reply is tracked.
    Notify {
        method: String,
        params: serde_json::Value,
    },
    /// Reset the transport session, fail every pending request with
    /// `Error::SessionExpired`, then signal completion on `done_tx`.
    ResetSession {
        done_tx: oneshot::Sender<()>,
    },
    /// Stop the message loop.
    Shutdown,
}
/// A Model Context Protocol client bound to a single transport.
///
/// The transport is owned by a background task spawned in `connect_inner`
/// (running `message_loop`); this handle talks to it over a bounded command
/// channel. Call [`McpClient::initialize`] before using the feature methods
/// (tools, resources, prompts, completion), which otherwise fail with
/// "Client not initialized".
pub struct McpClient {
    /// Sends commands to the background message loop.
    command_tx: mpsc::Sender<LoopCommand>,
    /// Handle of the background message loop; taken by `shutdown` or aborted on drop.
    task: Option<JoinHandle<()>>,
    /// Set once the `initialize` request and `notifications/initialized` have been sent.
    initialized: AtomicBool,
    /// Result of the most recent successful `initialize` request.
    server_info: RwLock<Option<InitializeResult>>,
    /// Capabilities advertised to the server during initialization.
    capabilities: ClientCapabilities,
    /// Client roots, shared with the message loop so it can answer `roots/list`.
    roots: Arc<RwLock<Vec<Root>>>,
    /// Cleared by the message loop when it exits.
    connected: Arc<AtomicBool>,
    /// Whether the transport reports that it can re-establish an expired session.
    supports_session_recovery: bool,
    /// `(client_name, client_version)` from the last `initialize`, reused for recovery.
    init_params: RwLock<Option<(String, String)>>,
    /// Serializes session recovery attempts.
    recovery_lock: Mutex<()>,
    /// Incremented after every successful session recovery. Callers snapshot it
    /// before sending so that, after waiting on `recovery_lock`, they can tell
    /// whether the session they failed on has already been replaced.
    session_generation: std::sync::atomic::AtomicU64,
}

/// Builder for configuring an [`McpClient`] before connecting.
pub struct McpClientBuilder {
    capabilities: ClientCapabilities,
    roots: Vec<Root>,
}

impl McpClientBuilder {
    /// Creates a builder with default capabilities and no roots.
    pub fn new() -> Self {
        Self {
            capabilities: ClientCapabilities::default(),
            roots: Vec::new(),
        }
    }

    /// Sets the initial roots and advertises the roots capability with
    /// `list_changed: true`.
    pub fn with_roots(mut self, roots: Vec<Root>) -> Self {
        self.roots = roots;
        self.capabilities.roots = Some(RootsCapability { list_changed: true });
        self
    }

    /// Replaces the advertised capabilities wholesale.
    ///
    /// This overwrites anything set by earlier `with_roots`, `with_sampling` or
    /// `with_elicitation` calls, so call it first when combining them.
    pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
        self.capabilities = capabilities;
        self
    }

    /// Advertises the sampling capability (default settings).
    pub fn with_sampling(mut self) -> Self {
        self.capabilities.sampling = Some(SamplingCapability::default());
        self
    }

    /// Advertises the elicitation capability (default settings).
    pub fn with_elicitation(mut self) -> Self {
        self.capabilities.elicitation = Some(ElicitationCapability::default());
        self
    }

    /// Connects over `transport`, routing server-initiated requests and
    /// notifications to `handler`. The returned client is not yet initialized.
    pub async fn connect<T, H>(self, transport: T, handler: H) -> Result<McpClient>
    where
        T: ClientTransport,
        H: ClientHandler,
    {
        McpClient::connect_inner(transport, handler, self.capabilities, self.roots).await
    }

    /// Connects over `transport` using the unit `()` handler.
    pub async fn connect_simple<T: ClientTransport>(self, transport: T) -> Result<McpClient> {
        self.connect(transport, ()).await
    }
}

impl Default for McpClientBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl McpClient {
    /// Connects over `transport` with default capabilities and the unit `()` handler.
    ///
    /// The returned client is connected but not yet initialized.
    pub async fn connect<T: ClientTransport>(transport: T) -> Result<Self> {
        McpClientBuilder::new().connect_simple(transport).await
    }

    /// Connects over `transport` with default capabilities and a custom handler
    /// for server-initiated requests and notifications.
    pub async fn connect_with_handler<T, H>(transport: T, handler: H) -> Result<Self>
    where
        T: ClientTransport,
        H: ClientHandler,
    {
        McpClientBuilder::new().connect(transport, handler).await
    }

    /// Returns a builder for configuring capabilities and roots.
    pub fn builder() -> McpClientBuilder {
        McpClientBuilder::new()
    }

    /// Spawns the background message loop and returns a handle to it.
    async fn connect_inner<T, H>(
        transport: T,
        handler: H,
        capabilities: ClientCapabilities,
        roots: Vec<Root>,
    ) -> Result<Self>
    where
        T: ClientTransport,
        H: ClientHandler,
    {
        // Queried before the transport is moved into the loop task.
        let supports_session_recovery = transport.supports_session_recovery();
        let (command_tx, command_rx) = mpsc::channel::<LoopCommand>(64);
        let connected = Arc::new(AtomicBool::new(true));
        let roots = Arc::new(RwLock::new(roots));
        let loop_connected = connected.clone();
        let loop_roots = roots.clone();
        let task = tokio::spawn(async move {
            message_loop(transport, handler, command_rx, loop_connected, loop_roots).await;
        });
        Ok(Self {
            command_tx,
            task: Some(task),
            initialized: AtomicBool::new(false),
            server_info: RwLock::new(None),
            capabilities,
            roots,
            connected,
            supports_session_recovery,
            init_params: RwLock::new(None),
            recovery_lock: Mutex::new(()),
            session_generation: std::sync::atomic::AtomicU64::new(0),
        })
    }

    /// Returns `true` once the initialize handshake has completed.
    pub fn is_initialized(&self) -> bool {
        self.initialized.load(Ordering::Acquire)
    }

    /// Returns `false` once the background message loop has exited.
    pub fn is_connected(&self) -> bool {
        self.connected.load(Ordering::Acquire)
    }

    /// Returns the server's `initialize` result, if initialization succeeded.
    pub async fn server_info(&self) -> Option<InitializeResult> {
        self.server_info.read().await.clone()
    }

    /// Non-async variant of [`McpClient::server_info`]. Returns `None` if the
    /// lock is currently held for writing or no server info is stored.
    pub fn server_info_blocking(&self) -> Option<InitializeResult> {
        self.server_info.try_read().ok()?.clone()
    }

    /// Performs the MCP initialize handshake.
    ///
    /// Sends `initialize` with the latest protocol version and this client's
    /// capabilities, stores the result, then sends `notifications/initialized`.
    /// The client name and version are remembered for session recovery.
    ///
    /// # Errors
    /// Returns any transport, JSON-RPC or deserialization error from the
    /// request or the follow-up notification.
    pub async fn initialize(
        &self,
        client_name: &str,
        client_version: &str,
    ) -> Result<InitializeResult> {
        let params = InitializeParams {
            protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
            capabilities: self.capabilities.clone(),
            client_info: Implementation {
                name: client_name.to_string(),
                version: client_version.to_string(),
                ..Default::default()
            },
            meta: None,
        };
        let result: InitializeResult = self.send_request("initialize", &params).await?;
        *self.server_info.write().await = Some(result.clone());
        *self.init_params.write().await =
            Some((client_name.to_string(), client_version.to_string()));
        self.send_notification("notifications/initialized", &serde_json::json!({}))
            .await?;
        self.initialized.store(true, Ordering::Release);
        Ok(result)
    }

    /// Fetches the first page of `tools/list`.
    pub async fn list_tools(&self) -> Result<ListToolsResult> {
        self.ensure_initialized()?;
        self.send_request(
            "tools/list",
            &ListToolsParams {
                cursor: None,
                meta: None,
            },
        )
        .await
    }

    /// Invokes the tool `name` with `arguments` via `tools/call`.
    pub async fn call_tool(
        &self,
        name: &str,
        arguments: serde_json::Value,
    ) -> Result<CallToolResult> {
        self.ensure_initialized()?;
        let params = CallToolParams {
            name: name.to_string(),
            arguments,
            meta: None,
            task: None,
        };
        self.send_request("tools/call", &params).await
    }

    /// Fetches the first page of `resources/list`.
    pub async fn list_resources(&self) -> Result<ListResourcesResult> {
        self.ensure_initialized()?;
        self.send_request(
            "resources/list",
            &ListResourcesParams {
                cursor: None,
                meta: None,
            },
        )
        .await
    }

    /// Reads the resource at `uri` via `resources/read`.
    pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult> {
        self.ensure_initialized()?;
        let params = ReadResourceParams {
            uri: uri.to_string(),
            meta: None,
        };
        self.send_request("resources/read", &params).await
    }

    /// Fetches the first page of `prompts/list`.
    pub async fn list_prompts(&self) -> Result<ListPromptsResult> {
        self.ensure_initialized()?;
        self.send_request(
            "prompts/list",
            &ListPromptsParams {
                cursor: None,
                meta: None,
            },
        )
        .await
    }

    /// Fetches one page of `tools/list` starting at `cursor`.
    pub async fn list_tools_with_cursor(&self, cursor: Option<String>) -> Result<ListToolsResult> {
        self.ensure_initialized()?;
        self.send_request("tools/list", &ListToolsParams { cursor, meta: None })
            .await
    }

    /// Fetches one page of `resources/list` starting at `cursor`.
    pub async fn list_resources_with_cursor(
        &self,
        cursor: Option<String>,
    ) -> Result<ListResourcesResult> {
        self.ensure_initialized()?;
        self.send_request(
            "resources/list",
            &ListResourcesParams { cursor, meta: None },
        )
        .await
    }

    /// Fetches the first page of `resources/templates/list`.
    pub async fn list_resource_templates(&self) -> Result<ListResourceTemplatesResult> {
        self.ensure_initialized()?;
        self.send_request(
            "resources/templates/list",
            &ListResourceTemplatesParams {
                cursor: None,
                meta: None,
            },
        )
        .await
    }

    /// Fetches one page of `resources/templates/list` starting at `cursor`.
    pub async fn list_resource_templates_with_cursor(
        &self,
        cursor: Option<String>,
    ) -> Result<ListResourceTemplatesResult> {
        self.ensure_initialized()?;
        self.send_request(
            "resources/templates/list",
            &ListResourceTemplatesParams { cursor, meta: None },
        )
        .await
    }

    /// Fetches one page of `prompts/list` starting at `cursor`.
    pub async fn list_prompts_with_cursor(
        &self,
        cursor: Option<String>,
    ) -> Result<ListPromptsResult> {
        self.ensure_initialized()?;
        self.send_request("prompts/list", &ListPromptsParams { cursor, meta: None })
            .await
    }

    /// Validates a page's `next_cursor` before following it.
    ///
    /// Returns `Ok(None)` when pagination is finished and `Ok(Some(cursor))`
    /// when another page should be fetched. A cursor that was already followed
    /// would make the `list_all_*` helpers loop forever, so it is an error.
    fn next_page_cursor(
        seen: &mut std::collections::HashSet<String>,
        next: Option<String>,
    ) -> Result<Option<String>> {
        let Some(cursor) = next else {
            return Ok(None);
        };
        if seen.insert(cursor.clone()) {
            Ok(Some(cursor))
        } else {
            Err(Error::Transport(format!(
                "Server returned an already-visited pagination cursor: {cursor}"
            )))
        }
    }

    /// Follows `tools/list` pagination and returns every tool.
    ///
    /// # Errors
    /// Fails on any request error, or if the server repeats a cursor.
    pub async fn list_all_tools(&self) -> Result<Vec<ToolDefinition>> {
        let mut all = Vec::new();
        let mut seen = std::collections::HashSet::new();
        let mut cursor = None;
        loop {
            let result = self.list_tools_with_cursor(cursor).await?;
            all.extend(result.tools);
            match Self::next_page_cursor(&mut seen, result.next_cursor)? {
                Some(c) => cursor = Some(c),
                None => break,
            }
        }
        Ok(all)
    }

    /// Follows `resources/list` pagination and returns every resource.
    ///
    /// # Errors
    /// Fails on any request error, or if the server repeats a cursor.
    pub async fn list_all_resources(&self) -> Result<Vec<ResourceDefinition>> {
        let mut all = Vec::new();
        let mut seen = std::collections::HashSet::new();
        let mut cursor = None;
        loop {
            let result = self.list_resources_with_cursor(cursor).await?;
            all.extend(result.resources);
            match Self::next_page_cursor(&mut seen, result.next_cursor)? {
                Some(c) => cursor = Some(c),
                None => break,
            }
        }
        Ok(all)
    }

    /// Follows `resources/templates/list` pagination and returns every template.
    ///
    /// # Errors
    /// Fails on any request error, or if the server repeats a cursor.
    pub async fn list_all_resource_templates(&self) -> Result<Vec<ResourceTemplateDefinition>> {
        let mut all = Vec::new();
        let mut seen = std::collections::HashSet::new();
        let mut cursor = None;
        loop {
            let result = self.list_resource_templates_with_cursor(cursor).await?;
            all.extend(result.resource_templates);
            match Self::next_page_cursor(&mut seen, result.next_cursor)? {
                Some(c) => cursor = Some(c),
                None => break,
            }
        }
        Ok(all)
    }

    /// Follows `prompts/list` pagination and returns every prompt.
    ///
    /// # Errors
    /// Fails on any request error, or if the server repeats a cursor.
    pub async fn list_all_prompts(&self) -> Result<Vec<PromptDefinition>> {
        let mut all = Vec::new();
        let mut seen = std::collections::HashSet::new();
        let mut cursor = None;
        loop {
            let result = self.list_prompts_with_cursor(cursor).await?;
            all.extend(result.prompts);
            match Self::next_page_cursor(&mut seen, result.next_cursor)? {
                Some(c) => cursor = Some(c),
                None => break,
            }
        }
        Ok(all)
    }

    /// Calls a tool and returns its concatenated text output.
    ///
    /// # Errors
    /// If the tool result has `is_error` set, its text is returned as
    /// `Error::Internal`.
    pub async fn call_tool_text(&self, name: &str, arguments: serde_json::Value) -> Result<String> {
        let result = self.call_tool(name, arguments).await?;
        if result.is_error {
            return Err(Error::Internal(result.all_text()));
        }
        Ok(result.all_text())
    }

    /// Renders the prompt `name` via `prompts/get`; `None` arguments are sent
    /// as an empty map.
    pub async fn get_prompt(
        &self,
        name: &str,
        arguments: Option<std::collections::HashMap<String, String>>,
    ) -> Result<GetPromptResult> {
        self.ensure_initialized()?;
        let params = GetPromptParams {
            name: name.to_string(),
            arguments: arguments.unwrap_or_default(),
            meta: None,
        };
        self.send_request("prompts/get", &params).await
    }

    /// Sends `ping`. Unlike the feature methods this does not require initialization.
    pub async fn ping(&self) -> Result<()> {
        let _: serde_json::Value = self.send_request("ping", &serde_json::json!({})).await?;
        Ok(())
    }

    /// Requests argument completions via `completion/complete`.
    pub async fn complete(
        &self,
        reference: CompletionReference,
        argument_name: &str,
        argument_value: &str,
    ) -> Result<CompleteResult> {
        self.ensure_initialized()?;
        let params = CompleteParams {
            reference,
            argument: CompletionArgument::new(argument_name, argument_value),
            context: None,
            meta: None,
        };
        self.send_request("completion/complete", &params).await
    }

    /// Completes an argument of the prompt `prompt_name`.
    pub async fn complete_prompt_arg(
        &self,
        prompt_name: &str,
        argument_name: &str,
        argument_value: &str,
    ) -> Result<CompleteResult> {
        self.complete(
            CompletionReference::prompt(prompt_name),
            argument_name,
            argument_value,
        )
        .await
    }

    /// Completes an argument of the resource template identified by `resource_uri`.
    pub async fn complete_resource_uri(
        &self,
        resource_uri: &str,
        argument_name: &str,
        argument_value: &str,
    ) -> Result<CompleteResult> {
        self.complete(
            CompletionReference::resource(resource_uri),
            argument_name,
            argument_value,
        )
        .await
    }

    /// Sends an arbitrary request and deserializes its result.
    ///
    /// Does not check initialization; session recovery still applies.
    pub async fn request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
        &self,
        method: &str,
        params: &P,
    ) -> Result<R> {
        self.send_request(method, params).await
    }

    /// Sends an arbitrary notification. Does not check initialization.
    pub async fn notify<P: serde::Serialize>(&self, method: &str, params: &P) -> Result<()> {
        self.send_notification(method, params).await
    }

    /// Returns a snapshot of the client's current roots.
    pub async fn roots(&self) -> Vec<Root> {
        self.roots.read().await.clone()
    }

    /// Replaces all roots; if initialized, sends the `ROOTS_LIST_CHANGED` notification.
    pub async fn set_roots(&self, roots: Vec<Root>) -> Result<()> {
        *self.roots.write().await = roots;
        if self.is_initialized() {
            self.send_notification(notifications::ROOTS_LIST_CHANGED, &serde_json::json!({}))
                .await?;
        }
        Ok(())
    }

    /// Appends a root; if initialized, sends the `ROOTS_LIST_CHANGED` notification.
    pub async fn add_root(&self, root: Root) -> Result<()> {
        self.roots.write().await.push(root);
        if self.is_initialized() {
            self.send_notification(notifications::ROOTS_LIST_CHANGED, &serde_json::json!({}))
                .await?;
        }
        Ok(())
    }

    /// Removes every root whose URI equals `uri`.
    ///
    /// Returns whether anything was removed; the `ROOTS_LIST_CHANGED`
    /// notification is sent only if something was removed and the client is
    /// initialized.
    pub async fn remove_root(&self, uri: &str) -> Result<bool> {
        let mut roots = self.roots.write().await;
        let initial_len = roots.len();
        roots.retain(|r| r.uri != uri);
        let removed = roots.len() < initial_len;
        // Release the write lock before awaiting the notification send.
        drop(roots);
        if removed && self.is_initialized() {
            self.send_notification(notifications::ROOTS_LIST_CHANGED, &serde_json::json!({}))
                .await?;
        }
        Ok(removed)
    }

    /// Returns the current roots wrapped as a `ListRootsResult`.
    pub async fn list_roots(&self) -> ListRootsResult {
        ListRootsResult {
            roots: self.roots.read().await.clone(),
            meta: None,
        }
    }

    /// Asks the message loop to stop and waits for the task to finish.
    ///
    /// Failures to deliver the command or join the task are ignored.
    pub async fn shutdown(mut self) -> Result<()> {
        let _ = self.command_tx.send(LoopCommand::Shutdown).await;
        if let Some(task) = self.task.take() {
            let _ = task.await;
        }
        Ok(())
    }

    /// Sends a request, transparently recovering once from an expired session
    /// when the transport supports it (never for `initialize` itself).
    async fn send_request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
        &self,
        method: &str,
        params: &P,
    ) -> Result<R> {
        // Snapshot the generation *before* sending so recovery can tell whether
        // the session this request failed on has already been replaced.
        let generation = self.session_generation.load(Ordering::Acquire);
        match self.send_request_once(method, params).await {
            Err(Error::SessionExpired)
                if self.supports_session_recovery && method != "initialize" =>
            {
                tracing::info!(method = %method, "Session expired, attempting recovery");
                self.recover_session(generation).await?;
                self.send_request_once(method, params).await
            }
            other => other,
        }
    }

    /// Sends a single request through the message loop and awaits its result.
    async fn send_request_once<P: serde::Serialize, R: serde::de::DeserializeOwned>(
        &self,
        method: &str,
        params: &P,
    ) -> Result<R> {
        self.ensure_connected()?;
        let params_value = serde_json::to_value(params)
            .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
        let (response_tx, response_rx) = oneshot::channel();
        self.command_tx
            .send(LoopCommand::Request {
                method: method.to_string(),
                params: params_value,
                response_tx,
            })
            .await
            .map_err(|_| Error::Transport("Connection closed".to_string()))?;
        // Outer error: the loop dropped the sender; inner error: the request failed.
        let result = response_rx
            .await
            .map_err(|_| Error::Transport("Connection closed".to_string()))??;
        serde_json::from_value(result)
            .map_err(|e| Error::Transport(format!("Failed to deserialize response: {}", e)))
    }

    /// Resets the transport session and repeats the initialize handshake.
    ///
    /// `failed_generation` is the session generation observed before the
    /// failing request was sent. Recovery is serialized by `recovery_lock`;
    /// when a concurrent caller has already recovered (the generation moved
    /// on), this returns immediately. Without that check, every waiter would
    /// reset the freshly recovered session again, failing the retries of the
    /// callers that went before it.
    async fn recover_session(&self, failed_generation: u64) -> Result<()> {
        let _guard = self.recovery_lock.lock().await;
        if self.session_generation.load(Ordering::Acquire) != failed_generation {
            tracing::debug!("Session already recovered by a concurrent request");
            return Ok(());
        }
        let init_params = self.init_params.read().await.clone();
        let (client_name, client_version) = match init_params {
            Some(params) => params,
            None => {
                return Err(Error::Transport(
                    "Cannot recover: never initialized".to_string(),
                ));
            }
        };
        let (done_tx, done_rx) = oneshot::channel();
        self.command_tx
            .send(LoopCommand::ResetSession { done_tx })
            .await
            .map_err(|_| Error::Transport("Connection closed".to_string()))?;
        done_rx
            .await
            .map_err(|_| Error::Transport("Connection closed during recovery".to_string()))?;
        self.initialized.store(false, Ordering::Release);
        *self.server_info.write().await = None;
        tracing::info!("Re-initializing session after expiry");
        let params = InitializeParams {
            protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
            capabilities: self.capabilities.clone(),
            client_info: Implementation {
                name: client_name,
                version: client_version,
                ..Default::default()
            },
            meta: None,
        };
        let result: InitializeResult = self.send_request_once("initialize", &params).await?;
        *self.server_info.write().await = Some(result);
        self.send_notification("notifications/initialized", &serde_json::json!({}))
            .await?;
        self.initialized.store(true, Ordering::Release);
        // Only a fully successful recovery advances the generation, so waiters
        // retry recovery themselves if this attempt failed part-way.
        self.session_generation.fetch_add(1, Ordering::AcqRel);
        Ok(())
    }

    /// Queues a notification for the message loop (fire-and-forget).
    async fn send_notification<P: serde::Serialize>(&self, method: &str, params: &P) -> Result<()> {
        self.ensure_connected()?;
        let params_value = serde_json::to_value(params)
            .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
        self.command_tx
            .send(LoopCommand::Notify {
                method: method.to_string(),
                params: params_value,
            })
            .await
            .map_err(|_| Error::Transport("Connection closed".to_string()))?;
        Ok(())
    }

    /// Fails with "Connection closed" once the message loop has exited.
    fn ensure_connected(&self) -> Result<()> {
        if !self.connected.load(Ordering::Acquire) {
            return Err(Error::Transport("Connection closed".to_string()));
        }
        Ok(())
    }

    /// Fails with "Client not initialized" until the handshake has completed.
    fn ensure_initialized(&self) -> Result<()> {
        if !self.initialized.load(Ordering::Acquire) {
            return Err(Error::Transport("Client not initialized".to_string()));
        }
        Ok(())
    }
}
impl Drop for McpClient {
    /// Aborts the background message loop unless `shutdown` already took
    /// ownership of its join handle.
    fn drop(&mut self) {
        let Some(loop_handle) = self.task.take() else {
            return;
        };
        loop_handle.abort();
    }
}
/// A client-originated request that is awaiting the server's response.
struct PendingRequest {
    /// Completed exactly once with the response `result` or an error.
    response_tx: oneshot::Sender<Result<serde_json::Value>>,
}
/// Background task that owns the transport and multiplexes all traffic.
///
/// Commands from [`McpClient`] arrive on `command_rx`; lines received from the
/// transport are routed by `handle_incoming`. The loop ends on `Shutdown`, when
/// every command sender has been dropped, on transport EOF, or on a transport
/// send/receive error. On exit it clears `connected`, fails all outstanding
/// requests with "Connection closed", and closes the transport.
async fn message_loop<T: ClientTransport, H: ClientHandler>(
    mut transport: T,
    handler: H,
    mut command_rx: mpsc::Receiver<LoopCommand>,
    connected: Arc<AtomicBool>,
    roots: Arc<RwLock<Vec<Root>>>,
) {
    let handler = Arc::new(handler);
    let mut pending_requests: HashMap<RequestId, PendingRequest> = HashMap::new();
    // Source of JSON-RPC ids for client-originated requests.
    let next_id = AtomicI64::new(1);
    loop {
        tokio::select! {
            command = command_rx.recv() => {
                match command {
                    Some(LoopCommand::Request { method, params, response_tx }) => {
                        let id = RequestId::Number(next_id.fetch_add(1, Ordering::Relaxed));
                        let request = JsonRpcRequest::new(id.clone(), &method)
                            .with_params(params);
                        let json = match serde_json::to_string(&request) {
                            Ok(j) => j,
                            Err(e) => {
                                let _ = response_tx.send(Err(Error::Transport(
                                    format!("Serialization failed: {}", e)
                                )));
                                continue;
                            }
                        };
                        tracing::debug!(method = %method, id = ?id, "Sending request");
                        // Register before sending so a fast response always finds its waiter.
                        pending_requests.insert(id, PendingRequest { response_tx });
                        if let Err(e) = transport.send(&json).await {
                            tracing::error!(error = %e, "Transport send error");
                            fail_all_pending(&mut pending_requests, &format!("Transport error: {}", e));
                            break;
                        }
                    }
                    Some(LoopCommand::Notify { method, params }) => {
                        let notification = JsonRpcNotification::new(&method)
                            .with_params(params);
                        // Notifications have no reply channel back to the caller, so
                        // failures can only be surfaced through logging.
                        match serde_json::to_string(&notification) {
                            Ok(json) => {
                                tracing::debug!(method = %method, "Sending notification");
                                if let Err(e) = transport.send(&json).await {
                                    tracing::warn!(
                                        method = %method,
                                        error = %e,
                                        "Failed to send notification"
                                    );
                                }
                            }
                            Err(e) => {
                                tracing::warn!(
                                    method = %method,
                                    error = %e,
                                    "Failed to serialize notification"
                                );
                            }
                        }
                    }
                    Some(LoopCommand::ResetSession { done_tx }) => {
                        tracing::info!("Resetting transport session for re-initialization");
                        transport.reset_session().await;
                        // Requests issued on the old session can never be answered now.
                        for (_, pending) in pending_requests.drain() {
                            let _ = pending.response_tx.send(Err(Error::SessionExpired));
                        }
                        let _ = done_tx.send(());
                    }
                    Some(LoopCommand::Shutdown) | None => {
                        tracing::debug!("Message loop shutting down");
                        break;
                    }
                }
            }
            result = transport.recv() => {
                match result {
                    Ok(Some(line)) => {
                        handle_incoming(
                            &line,
                            &mut pending_requests,
                            &handler,
                            &roots,
                            &mut transport,
                        ).await;
                    }
                    Ok(None) => {
                        tracing::info!("Transport closed (EOF)");
                        break;
                    }
                    Err(e) => {
                        tracing::error!(error = %e, "Transport receive error");
                        break;
                    }
                }
            }
        }
    }
    connected.store(false, Ordering::Release);
    fail_all_pending(&mut pending_requests, "Connection closed");
    let _ = transport.close().await;
}
/// Routes one raw line received from the transport.
///
/// * Responses (no `method`, with `result` or `error`) complete the matching
///   pending request via `handle_response`. A `-32005` error whose `id` is
///   missing or `null` is treated as a session-wide expiry and fails every
///   pending request with `Error::SessionExpired`.
/// * Server-to-client requests (`id` and `method`) are dispatched through
///   `dispatch_server_request`; the reply echoes the request's `id` exactly as
///   received (JSON-RPC requires the response id to equal the request id) and
///   carries the handler error's `data` when present.
/// * Notifications (`method`, no `id`) are parsed and passed to the handler.
///
/// Unparseable or unclassifiable messages are logged and dropped.
async fn handle_incoming<T: ClientTransport, H: ClientHandler>(
    line: &str,
    pending_requests: &mut HashMap<RequestId, PendingRequest>,
    handler: &Arc<H>,
    roots: &Arc<RwLock<Vec<Root>>>,
    transport: &mut T,
) {
    let parsed: serde_json::Value = match serde_json::from_str(line) {
        Ok(v) => v,
        Err(e) => {
            tracing::warn!(error = %e, "Failed to parse incoming message");
            return;
        }
    };
    let method = parsed.get("method");
    let raw_id = parsed.get("id");

    if method.is_none() && (parsed.get("result").is_some() || parsed.get("error").is_some()) {
        if let Some(error) = parsed.get("error") {
            // Compare the raw i64 so out-of-range codes cannot truncate into -32005.
            let code = error.get("code").and_then(|c| c.as_i64());
            let id_missing_or_null = raw_id.is_none_or(|id| id.is_null());
            if code == Some(-32005) && id_missing_or_null {
                tracing::warn!(
                    "Session expired (-32005 with null id), failing all pending requests"
                );
                for (_, pending) in pending_requests.drain() {
                    let _ = pending.response_tx.send(Err(Error::SessionExpired));
                }
                return;
            }
        }
        handle_response(&parsed, pending_requests);
        return;
    }

    match (method, raw_id) {
        (Some(method), Some(raw_id)) => {
            let method = method.as_str().unwrap_or("");
            let params = parsed.get("params").cloned();
            let response = match dispatch_server_request(handler, roots, method, params).await {
                Ok(value) => serde_json::json!({
                    "jsonrpc": "2.0",
                    "id": raw_id,
                    "result": value
                }),
                Err(error) => {
                    let mut error_body = serde_json::json!({
                        "code": error.code,
                        "message": error.message
                    });
                    if let Some(data) = &error.data {
                        error_body["data"] = data.clone();
                    }
                    serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": raw_id,
                        "error": error_body
                    })
                }
            };
            match serde_json::to_string(&response) {
                Ok(json) => {
                    if let Err(e) = transport.send(&json).await {
                        tracing::warn!(
                            method = %method,
                            error = %e,
                            "Failed to send response to server request"
                        );
                    }
                }
                Err(e) => {
                    tracing::warn!(
                        method = %method,
                        error = %e,
                        "Failed to serialize response to server request"
                    );
                }
            }
        }
        (Some(method), None) => {
            let method = method.as_str().unwrap_or("");
            let params = parsed.get("params").cloned();
            let notification = parse_server_notification(method, params);
            handler.on_notification(notification).await;
        }
        (None, _) => {
            tracing::warn!(
                "Ignoring incoming message that is not a request, response, or notification"
            );
        }
    }
}
fn handle_response(
parsed: &serde_json::Value,
pending_requests: &mut HashMap<RequestId, PendingRequest>,
) {
let id = match parse_request_id(parsed) {
Some(id) => id,
None => {
tracing::warn!("Response without id");
return;
}
};
let pending = match pending_requests.remove(&id) {
Some(p) => p,
None => {
tracing::warn!(id = ?id, "Response for unknown request");
return;
}
};
tracing::debug!(id = ?id, "Received response");
if let Some(error) = parsed.get("error") {
let code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1) as i32;
let message = error
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("Unknown error")
.to_string();
let data = error.get("data").cloned();
if code == -32005 {
let _ = pending.response_tx.send(Err(Error::SessionExpired));
return;
}
let json_rpc_error = JsonRpcError {
code,
message,
data,
};
let _ = pending
.response_tx
.send(Err(Error::JsonRpc(json_rpc_error)));
} else if let Some(result) = parsed.get("result") {
let _ = pending.response_tx.send(Ok(result.clone()));
} else {
let _ = pending
.response_tx
.send(Err(Error::Transport("Invalid response".to_string())));
}
}
async fn dispatch_server_request<H: ClientHandler>(
handler: &Arc<H>,
roots: &Arc<RwLock<Vec<Root>>>,
method: &str,
params: Option<serde_json::Value>,
) -> std::result::Result<serde_json::Value, JsonRpcError> {
match method {
"sampling/createMessage" => {
let p = serde_json::from_value(params.unwrap_or_default())
.map_err(|e| JsonRpcError::invalid_params(e.to_string()))?;
let result = handler.handle_create_message(p).await?;
serde_json::to_value(result).map_err(|e| JsonRpcError::internal_error(e.to_string()))
}
"elicitation/create" => {
let p = serde_json::from_value(params.unwrap_or_default())
.map_err(|e| JsonRpcError::invalid_params(e.to_string()))?;
let result = handler.handle_elicit(p).await?;
serde_json::to_value(result).map_err(|e| JsonRpcError::internal_error(e.to_string()))
}
"roots/list" => {
let roots_list = roots.read().await;
if !roots_list.is_empty() {
let result = ListRootsResult {
roots: roots_list.clone(),
meta: None,
};
return serde_json::to_value(result)
.map_err(|e| JsonRpcError::internal_error(e.to_string()));
}
drop(roots_list);
let result = handler.handle_list_roots().await?;
serde_json::to_value(result).map_err(|e| JsonRpcError::internal_error(e.to_string()))
}
"ping" => Ok(serde_json::json!({})),
_ => Err(JsonRpcError::method_not_found(method)),
}
}
/// Extracts the message `id` as a [`RequestId`].
///
/// Integers representable as `i64` become `RequestId::Number`, strings become
/// `RequestId::String`; a missing id or any other JSON type yields `None`.
fn parse_request_id(parsed: &serde_json::Value) -> Option<RequestId> {
    let raw = parsed.get("id")?;
    match raw.as_i64() {
        Some(number) => Some(RequestId::Number(number)),
        None => raw.as_str().map(|text| RequestId::String(text.to_owned())),
    }
}
/// Converts a server notification into a typed [`ServerNotification`].
///
/// Known methods with well-formed params map to their dedicated variants.
/// Unknown methods, and known methods whose params fail to parse, become
/// `ServerNotification::Unknown` carrying the original raw params so the
/// handler can still inspect them (previously the progress and log-message
/// arms discarded the params on a parse failure).
fn parse_server_notification(
    method: &str,
    params: Option<serde_json::Value>,
) -> ServerNotification {
    match method {
        notifications::PROGRESS => {
            // Parse a clone so the raw params survive a failed parse.
            if let Some(value) = &params
                && let Ok(p) = serde_json::from_value(value.clone())
            {
                return ServerNotification::Progress(p);
            }
            ServerNotification::Unknown {
                method: method.to_string(),
                params,
            }
        }
        notifications::MESSAGE => {
            if let Some(value) = &params
                && let Ok(p) = serde_json::from_value(value.clone())
            {
                return ServerNotification::LogMessage(p);
            }
            ServerNotification::Unknown {
                method: method.to_string(),
                params,
            }
        }
        notifications::RESOURCE_UPDATED => {
            if let Some(value) = &params
                && let Some(uri) = value.get("uri").and_then(|u| u.as_str())
            {
                return ServerNotification::ResourceUpdated {
                    uri: uri.to_string(),
                };
            }
            ServerNotification::Unknown {
                method: method.to_string(),
                params,
            }
        }
        notifications::RESOURCES_LIST_CHANGED => ServerNotification::ResourcesListChanged,
        notifications::TOOLS_LIST_CHANGED => ServerNotification::ToolsListChanged,
        notifications::PROMPTS_LIST_CHANGED => ServerNotification::PromptsListChanged,
        _ => ServerNotification::Unknown {
            method: method.to_string(),
            params,
        },
    }
}
fn fail_all_pending(pending: &mut HashMap<RequestId, PendingRequest>, reason: &str) {
for (_, req) in pending.drain() {
let _ = req
.response_tx
.send(Err(Error::Transport(reason.to_string())));
}
}
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use std::sync::Mutex;
struct MockTransport {
responses: Arc<Mutex<Vec<serde_json::Value>>>,
response_idx: Arc<std::sync::atomic::AtomicUsize>,
incoming_tx: mpsc::Sender<String>,
incoming_rx: mpsc::Receiver<String>,
outgoing: Arc<Mutex<Vec<String>>>,
connected: Arc<AtomicBool>,
}
#[allow(dead_code)]
impl MockTransport {
fn new() -> Self {
let (tx, rx) = mpsc::channel(32);
Self {
responses: Arc::new(Mutex::new(Vec::new())),
response_idx: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
incoming_tx: tx,
incoming_rx: rx,
outgoing: Arc::new(Mutex::new(Vec::new())),
connected: Arc::new(AtomicBool::new(true)),
}
}
fn with_responses(responses: Vec<serde_json::Value>) -> Self {
let (tx, rx) = mpsc::channel(32);
Self {
responses: Arc::new(Mutex::new(responses)),
response_idx: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
incoming_tx: tx,
incoming_rx: rx,
outgoing: Arc::new(Mutex::new(Vec::new())),
connected: Arc::new(AtomicBool::new(true)),
}
}
}
#[async_trait]
impl ClientTransport for MockTransport {
async fn send(&mut self, message: &str) -> Result<()> {
self.outgoing.lock().unwrap().push(message.to_string());
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(message) {
if let Some(id) = parsed.get("id") {
let idx = self
.response_idx
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let responses = self.responses.lock().unwrap();
if let Some(result) = responses.get(idx) {
let response = serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": result
});
let _ = self.incoming_tx.try_send(response.to_string());
}
}
}
Ok(())
}
async fn recv(&mut self) -> Result<Option<String>> {
match self.incoming_rx.recv().await {
Some(msg) => Ok(Some(msg)),
None => Ok(None),
}
}
fn is_connected(&self) -> bool {
self.connected.load(Ordering::Relaxed)
}
async fn close(&mut self) -> Result<()> {
self.connected.store(false, Ordering::Relaxed);
Ok(())
}
}
fn mock_initialize_response() -> serde_json::Value {
serde_json::json!({
"protocolVersion": "2025-11-25",
"serverInfo": {
"name": "test-server",
"version": "1.0.0"
},
"capabilities": {
"tools": {}
}
})
}
#[tokio::test]
async fn test_client_not_initialized() {
let client = McpClient::connect(MockTransport::with_responses(vec![]))
.await
.unwrap();
let result = client.list_tools().await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("not initialized"));
}
#[tokio::test]
async fn test_client_initialize() {
let client = McpClient::connect(MockTransport::with_responses(vec![
mock_initialize_response(),
]))
.await
.unwrap();
assert!(!client.is_initialized());
let result = client.initialize("test-client", "1.0.0").await;
assert!(result.is_ok());
assert!(client.is_initialized());
let server_info = client.server_info().await.unwrap();
assert_eq!(server_info.server_info.name, "test-server");
}
#[tokio::test]
async fn test_list_tools() {
let client = McpClient::connect(MockTransport::with_responses(vec![
mock_initialize_response(),
serde_json::json!({
"tools": [
{
"name": "test_tool",
"description": "A test tool",
"inputSchema": {
"type": "object",
"properties": {}
}
}
]
}),
]))
.await
.unwrap();
client.initialize("test-client", "1.0.0").await.unwrap();
let tools = client.list_tools().await.unwrap();
assert_eq!(tools.tools.len(), 1);
assert_eq!(tools.tools[0].name, "test_tool");
}
#[tokio::test]
async fn test_call_tool() {
let client = McpClient::connect(MockTransport::with_responses(vec![
mock_initialize_response(),
serde_json::json!({
"content": [
{
"type": "text",
"text": "Tool result"
}
]
}),
]))
.await
.unwrap();
client.initialize("test-client", "1.0.0").await.unwrap();
let result = client
.call_tool("test_tool", serde_json::json!({"arg": "value"}))
.await
.unwrap();
assert!(!result.content.is_empty());
}
#[tokio::test]
async fn test_list_resources() {
let client = McpClient::connect(MockTransport::with_responses(vec![
mock_initialize_response(),
serde_json::json!({
"resources": [
{
"uri": "file://test.txt",
"name": "Test File"
}
]
}),
]))
.await
.unwrap();
client.initialize("test-client", "1.0.0").await.unwrap();
let resources = client.list_resources().await.unwrap();
assert_eq!(resources.resources.len(), 1);
assert_eq!(resources.resources[0].uri, "file://test.txt");
}
#[tokio::test]
async fn test_read_resource() {
let client = McpClient::connect(MockTransport::with_responses(vec![
mock_initialize_response(),
serde_json::json!({
"contents": [
{
"uri": "file://test.txt",
"text": "File contents"
}
]
}),
]))
.await
.unwrap();
client.initialize("test-client", "1.0.0").await.unwrap();
let result = client.read_resource("file://test.txt").await.unwrap();
assert_eq!(result.contents.len(), 1);
assert_eq!(result.contents[0].text.as_deref(), Some("File contents"));
}
#[tokio::test]
async fn test_list_prompts() {
let client = McpClient::connect(MockTransport::with_responses(vec![
mock_initialize_response(),
serde_json::json!({
"prompts": [
{
"name": "test_prompt",
"description": "A test prompt"
}
]
}),
]))
.await
.unwrap();
client.initialize("test-client", "1.0.0").await.unwrap();
let prompts = client.list_prompts().await.unwrap();
assert_eq!(prompts.prompts.len(), 1);
assert_eq!(prompts.prompts[0].name, "test_prompt");
}
#[tokio::test]
async fn test_get_prompt() {
let client = McpClient::connect(MockTransport::with_responses(vec![
mock_initialize_response(),
serde_json::json!({
"messages": [
{
"role": "user",
"content": {
"type": "text",
"text": "Prompt message"
}
}
]
}),
]))
.await
.unwrap();
client.initialize("test-client", "1.0.0").await.unwrap();
let result = client.get_prompt("test_prompt", None).await.unwrap();
assert_eq!(result.messages.len(), 1);
}
#[tokio::test]
async fn test_ping() {
let client = McpClient::connect(MockTransport::with_responses(vec![
mock_initialize_response(),
serde_json::json!({}),
]))
.await
.unwrap();
client.initialize("test-client", "1.0.0").await.unwrap();
let result = client.ping().await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_with_roots() {
let roots = vec![Root::new("file:///test")];
let client = McpClient::builder()
.with_roots(roots)
.connect_simple(MockTransport::with_responses(vec![]))
.await
.unwrap();
let current_roots = client.roots().await;
assert_eq!(current_roots.len(), 1);
}
#[tokio::test]
async fn test_roots_management() {
let client = McpClient::connect(MockTransport::with_responses(vec![
mock_initialize_response(),
]))
.await
.unwrap();
assert!(client.roots().await.is_empty());
client.add_root(Root::new("file:///project")).await.unwrap();
assert_eq!(client.roots().await.len(), 1);
client.initialize("test-client", "1.0.0").await.unwrap();
let removed = client.remove_root("file:///project").await.unwrap();
assert!(removed);
assert!(client.roots().await.is_empty());
let not_removed = client.remove_root("file:///nonexistent").await.unwrap();
assert!(!not_removed);
}
#[tokio::test]
async fn test_list_roots() {
let roots = vec![
Root::new("file:///project1"),
Root::with_name("file:///project2", "Project 2"),
];
let client = McpClient::builder()
.with_roots(roots)
.connect_simple(MockTransport::with_responses(vec![]))
.await
.unwrap();
let result = client.list_roots().await;
assert_eq!(result.roots.len(), 2);
assert_eq!(result.roots[1].name, Some("Project 2".to_string()));
}
#[test]
fn test_builder_with_sampling() {
let builder = McpClientBuilder::new().with_sampling();
assert!(builder.capabilities.sampling.is_some());
}
#[test]
fn test_builder_with_elicitation() {
let builder = McpClientBuilder::new().with_elicitation();
assert!(builder.capabilities.elicitation.is_some());
}
#[test]
fn test_builder_chaining() {
let builder = McpClientBuilder::new()
.with_sampling()
.with_elicitation()
.with_roots(vec![Root::new("file:///project")]);
assert!(builder.capabilities.sampling.is_some());
assert!(builder.capabilities.elicitation.is_some());
assert!(builder.capabilities.roots.is_some());
}
#[tokio::test]
async fn test_bidirectional_sampling_round_trip() {
    use crate::protocol::{
        ContentRole, CreateMessageParams, CreateMessageResult, SamplingContent,
        SamplingContentOrArray,
    };

    // Handler that records whether `handle_create_message` was invoked and
    // replies with a fixed assistant text message.
    struct RecordingHandler {
        called: Arc<AtomicBool>,
    }

    #[async_trait]
    impl ClientHandler for RecordingHandler {
        async fn handle_create_message(
            &self,
            _params: CreateMessageParams,
        ) -> std::result::Result<CreateMessageResult, tower_mcp_types::JsonRpcError>
        {
            self.called.store(true, Ordering::SeqCst);
            Ok(CreateMessageResult {
                content: SamplingContentOrArray::Single(SamplingContent::Text {
                    text: "test response".to_string(),
                    annotations: None,
                    meta: None,
                }),
                model: "test-model".to_string(),
                role: ContentRole::Assistant,
                stop_reason: Some("end_turn".to_string()),
                meta: None,
            })
        }
    }

    let called = Arc::new(AtomicBool::new(false));
    let handler = RecordingHandler {
        called: called.clone(),
    };

    // Keep a second sender so the test can inject a server->client request
    // after the transport (and its own sender) has been moved into the client.
    let (inject_tx, rx) = mpsc::channel::<String>(32);
    let responses = vec![mock_initialize_response()];
    let inject_tx_clone = inject_tx.clone();
    let transport = MockTransport {
        responses: Arc::new(Mutex::new(responses)),
        response_idx: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        incoming_tx: inject_tx,
        incoming_rx: rx,
        outgoing: Arc::new(Mutex::new(Vec::new())),
        connected: Arc::new(AtomicBool::new(true)),
    };
    let client = McpClient::builder()
        .with_sampling()
        .connect(transport, handler)
        .await
        .unwrap();
    client.initialize("test-client", "1.0.0").await.unwrap();

    let sampling_request = serde_json::json!({
        "jsonrpc": "2.0",
        "id": 100,
        "method": "sampling/createMessage",
        "params": {
            "messages": [
                {
                    "role": "user",
                    "content": {
                        "type": "text",
                        "text": "Hello"
                    }
                }
            ],
            "maxTokens": 100
        }
    });
    inject_tx_clone
        .send(sampling_request.to_string())
        .await
        .unwrap();

    // The injected request is handled asynchronously, and the test can only
    // observe it through the `called` flag. Poll that flag with an upper bound
    // instead of sleeping a fixed 100ms. A fixed delay is flaky on a loaded
    // machine and needlessly slow when the handler runs right away.
    let handled = tokio::time::timeout(std::time::Duration::from_secs(5), async {
        while !called.load(Ordering::SeqCst) {
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
    })
    .await
    .is_ok();
    assert!(handled, "handle_create_message should have been called");
}
#[tokio::test]
async fn test_list_resource_templates() {
    // Canned server reply containing a single resource template.
    let templates_page = serde_json::json!({
        "resourceTemplates": [
            {
                "uriTemplate": "file:///{path}",
                "name": "File Template",
                "description": "A file template"
            }
        ]
    });
    let transport =
        MockTransport::with_responses(vec![mock_initialize_response(), templates_page]);
    let client = McpClient::connect(transport).await.unwrap();
    client.initialize("test-client", "1.0.0").await.unwrap();

    let listing = client.list_resource_templates().await.unwrap();
    let templates = &listing.resource_templates;
    assert_eq!(templates.len(), 1);
    assert_eq!(templates[0].name, "File Template");
}
#[tokio::test]
async fn test_list_all_tools_single_page() {
    // Builds one tool definition with an empty object input schema.
    let tool = |name: &str, description: &str| {
        serde_json::json!({
            "name": name,
            "description": description,
            "inputSchema": { "type": "object", "properties": {} }
        })
    };
    // A single page with no `nextCursor`, so listing stops after one request.
    let only_page = serde_json::json!({
        "tools": [tool("tool_a", "Tool A"), tool("tool_b", "Tool B")]
    });

    let client = McpClient::connect(MockTransport::with_responses(vec![
        mock_initialize_response(),
        only_page,
    ]))
    .await
    .unwrap();
    client.initialize("test-client", "1.0.0").await.unwrap();

    let tools = client.list_all_tools().await.unwrap();
    let names: Vec<_> = tools.iter().map(|t| t.name.as_str()).collect();
    assert_eq!(tools.len(), 2);
    assert_eq!(names[0], "tool_a");
    assert_eq!(names[1], "tool_b");
}
#[tokio::test]
async fn test_list_all_tools_paginated() {
    // Builds one tool definition with an empty object input schema.
    let tool = |name: &str, description: &str| {
        serde_json::json!({
            "name": name,
            "description": description,
            "inputSchema": { "type": "object", "properties": {} }
        })
    };
    // The first page advertises a cursor, so `list_all_tools` must fetch the
    // second page and concatenate the results in order.
    let first_page = serde_json::json!({
        "tools": [tool("tool_a", "Tool A")],
        "nextCursor": "page2"
    });
    let last_page = serde_json::json!({
        "tools": [tool("tool_b", "Tool B")]
    });

    let client = McpClient::connect(MockTransport::with_responses(vec![
        mock_initialize_response(),
        first_page,
        last_page,
    ]))
    .await
    .unwrap();
    client.initialize("test-client", "1.0.0").await.unwrap();

    let tools = client.list_all_tools().await.unwrap();
    let names: Vec<_> = tools.iter().map(|t| t.name.as_str()).collect();
    assert_eq!(tools.len(), 2);
    assert_eq!(names[0], "tool_a");
    assert_eq!(names[1], "tool_b");
}
#[tokio::test]
async fn test_call_tool_text_success() {
    // The tool result carries two text fragments; the test expects
    // `call_tool_text` to return them joined together.
    let tool_result = serde_json::json!({
        "content": [
            { "type": "text", "text": "Hello " },
            { "type": "text", "text": "World" }
        ]
    });
    let client = McpClient::connect(MockTransport::with_responses(vec![
        mock_initialize_response(),
        tool_result,
    ]))
    .await
    .unwrap();
    client.initialize("test-client", "1.0.0").await.unwrap();

    let empty_args = serde_json::json!({});
    let text = client.call_tool_text("test_tool", empty_args).await.unwrap();
    assert_eq!(text, "Hello World");
}
#[tokio::test]
async fn test_call_tool_text_error() {
    // A tool result flagged `isError: true` must come back from
    // `call_tool_text` as an `Err` whose message includes the tool's text.
    let client = McpClient::connect(MockTransport::with_responses(vec![
        mock_initialize_response(),
        serde_json::json!({
            "content": [
                { "type": "text", "text": "something went wrong" }
            ],
            "isError": true
        }),
    ]))
    .await
    .unwrap();
    client.initialize("test-client", "1.0.0").await.unwrap();

    // `expect_err` replaces the redundant `is_err()` + `unwrap_err()` pair and
    // prints the unexpected `Ok` value if the call wrongly succeeds.
    let err = client
        .call_tool_text("test_tool", serde_json::json!({}))
        .await
        .expect_err("call_tool_text should fail when the tool result sets isError");
    assert!(
        err.to_string().contains("something went wrong"),
        "Error message should contain tool error text, got: {}",
        err
    );
}
// `parse_server_notification` is called synchronously here and nothing is
// awaited, so a plain `#[test]` is enough and avoids starting a Tokio runtime.
#[test]
fn test_server_notification_parsing() {
    // Parameterless list-changed notifications map to unit variants.
    let notification = parse_server_notification("notifications/tools/list_changed", None);
    assert!(matches!(notification, ServerNotification::ToolsListChanged));

    let notification = parse_server_notification("notifications/resources/list_changed", None);
    assert!(matches!(
        notification,
        ServerNotification::ResourcesListChanged
    ));

    // `resources/updated` extracts the `uri` from its params.
    let notification = parse_server_notification(
        "notifications/resources/updated",
        Some(serde_json::json!({"uri": "file:///test"})),
    );
    match notification {
        ServerNotification::ResourceUpdated { uri } => {
            assert_eq!(uri, "file:///test");
        }
        _ => panic!("Expected ResourceUpdated"),
    }

    // Unrecognised methods fall through to `Unknown`, keeping method and params.
    let notification =
        parse_server_notification("custom/notification", Some(serde_json::json!({"data": 42})));
    match notification {
        ServerNotification::Unknown { method, params } => {
            assert_eq!(method, "custom/notification");
            assert!(params.is_some());
        }
        _ => panic!("Expected Unknown"),
    }
}
}