use super::task::{McpTaskConfig, TaskError, TaskStatus};
use super::{ConnectionFactory, RefreshConfig, should_refresh_connection};
use adk_core::{AdkError, ReadonlyContext, Result, Tool, ToolContext, Toolset};
use async_trait::async_trait;
use rmcp::{
RoleClient,
model::{
CallToolRequestParams, ErrorCode, RawContent, ReadResourceRequestParams, Resource,
ResourceContents, ResourceTemplate,
},
service::RunningService,
};
use serde_json::{Value, json};
use std::ops::Deref;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Mutex;
use tracing::{debug, warn};
/// Shared trait-object handle to a [`ConnectionFactory`].
///
/// Held (optionally) by both the toolset and every tool it creates, so either
/// can ask the factory for a replacement connection.
type DynConnectionFactory<S> = Arc<dyn ConnectionFactory<S>>;
/// Shared predicate over a tool name.
///
/// When a filter is installed on the toolset, tools for which the predicate
/// returns `false` are skipped when the toolset lists its tools.
pub type ToolFilter = Arc<dyn Fn(&str) -> bool + Send + Sync>;
/// Recursively strips a fixed set of JSON-Schema keywords from `value`.
///
/// Every JSON object reachable from `value` (through nested objects and
/// arrays) has the keys `$schema`, `definitions`, `$ref` and
/// `additionalProperties` removed in place. Scalars are left untouched.
///
/// NOTE(review): removal is purely by key name at any depth, so a tool
/// parameter that is literally named e.g. `definitions` inside a `properties`
/// map is dropped as well — confirm this is intended.
fn sanitize_schema(value: &mut Value) {
    // Keys removed from every object, in this order.
    const STRIPPED_KEYS: [&str; 4] = ["$schema", "definitions", "$ref", "additionalProperties"];

    match value {
        Value::Object(fields) => {
            for key in STRIPPED_KEYS {
                fields.remove(key);
            }
            // Descend into whatever entries survived the removal above.
            for (_key, child) in fields.iter_mut() {
                sanitize_schema(child);
            }
        }
        Value::Array(items) => {
            for item in items.iter_mut() {
                sanitize_schema(item);
            }
        }
        _ => {}
    }
}
fn should_retry_mcp_operation(
error: &str,
attempt: u32,
refresh_config: &RefreshConfig,
has_connection_factory: bool,
) -> bool {
has_connection_factory
&& attempt < refresh_config.max_attempts
&& should_refresh_connection(error)
}
/// Returns `true` when `err` is an MCP protocol error whose code is
/// `ErrorCode::METHOD_NOT_FOUND`; every other error variant yields `false`.
fn is_method_not_found(err: &rmcp::ServiceError) -> bool {
    match err {
        rmcp::ServiceError::McpError(inner) => inner.code == ErrorCode::METHOD_NOT_FOUND,
        _ => false,
    }
}
/// A toolset backed by a running MCP client service.
///
/// `S` is the client-side service handler type and defaults to `()`.
pub struct McpToolset<S = ()>
where
S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
// Running MCP client, shared behind an async mutex; the same handle is
// cloned into every tool this toolset produces.
client: Arc<Mutex<RunningService<RoleClient, S>>>,
// Optional name predicate; `None` means no tool is filtered out.
tool_filter: Option<ToolFilter>,
// Name reported by this toolset.
name: String,
// Task settings copied into each produced tool.
task_config: McpTaskConfig,
// Optional factory used to build a replacement connection; without it no
// reconnect is attempted.
connection_factory: Option<DynConnectionFactory<S>>,
// Retry settings (attempt limit, delay, logging) for reconnects.
refresh_config: RefreshConfig,
}
impl<S> McpToolset<S>
where
S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
pub fn new(client: RunningService<RoleClient, S>) -> Self {
Self {
client: Arc::new(Mutex::new(client)),
tool_filter: None,
name: "mcp_toolset".to_string(),
task_config: McpTaskConfig::default(),
connection_factory: None,
refresh_config: RefreshConfig::default(),
}
}
pub fn with_client_handler(client: RunningService<RoleClient, S>) -> Self {
Self::new(client)
}
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.name = name.into();
self
}
pub fn with_task_support(mut self, config: McpTaskConfig) -> Self {
self.task_config = config;
self
}
pub fn with_connection_factory<F>(mut self, factory: Arc<F>) -> Self
where
F: ConnectionFactory<S> + 'static,
{
self.connection_factory = Some(factory);
self
}
pub fn with_refresh_config(mut self, config: RefreshConfig) -> Self {
self.refresh_config = config;
self
}
pub fn with_filter<F>(mut self, filter: F) -> Self
where
F: Fn(&str) -> bool + Send + Sync + 'static,
{
self.tool_filter = Some(Arc::new(filter));
self
}
pub fn with_tools(self, tool_names: &[&str]) -> Self {
let names: Vec<String> = tool_names.iter().map(|s| s.to_string()).collect();
self.with_filter(move |name| names.iter().any(|n| n == name))
}
pub async fn cancellation_token(&self) -> rmcp::service::RunningServiceCancellationToken {
let client = self.client.lock().await;
client.cancellation_token()
}
async fn try_refresh_connection(&self) -> Result<bool> {
let Some(factory) = self.connection_factory.clone() else {
return Ok(false);
};
let new_client = factory
.create_connection()
.await
.map_err(|e| AdkError::tool(format!("Failed to refresh MCP connection: {e}")))?;
let mut client = self.client.lock().await;
let old_token = client.cancellation_token();
old_token.cancel();
*client = new_client;
Ok(true)
}
pub async fn list_resources(&self) -> Result<Vec<Resource>> {
let client = self.client.lock().await;
match client.list_all_resources().await {
Ok(resources) => Ok(resources),
Err(e) => {
if is_method_not_found(&e) {
Ok(vec![])
} else {
Err(AdkError::tool(format!("Failed to list MCP resources: {e}")))
}
}
}
}
pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>> {
let client = self.client.lock().await;
match client.list_all_resource_templates().await {
Ok(templates) => Ok(templates),
Err(e) => {
if is_method_not_found(&e) {
Ok(vec![])
} else {
Err(AdkError::tool(format!("Failed to list MCP resource templates: {e}")))
}
}
}
}
pub async fn read_resource(&self, uri: &str) -> Result<Vec<ResourceContents>> {
let client = self.client.lock().await;
let params = ReadResourceRequestParams::new(uri.to_string());
match client.read_resource(params).await {
Ok(result) => Ok(result.contents),
Err(e) => {
if is_method_not_found(&e) {
Err(AdkError::tool(format!("resource not found: {uri}")))
} else {
Err(AdkError::tool(format!("Failed to read MCP resource '{uri}': {e}")))
}
}
}
}
}
#[async_trait]
impl<S> Toolset for McpToolset<S>
where
S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
fn name(&self) -> &str {
&self.name
}
async fn tools(&self, _ctx: Arc<dyn ReadonlyContext>) -> Result<Vec<Arc<dyn Tool>>> {
let mut attempt = 0u32;
let has_connection_factory = self.connection_factory.is_some();
let mcp_tools = loop {
let list_result = {
let client = self.client.lock().await;
client.list_all_tools().await.map_err(|e| e.to_string())
};
match list_result {
Ok(tools) => break tools,
Err(error) => {
if !should_retry_mcp_operation(
&error,
attempt,
&self.refresh_config,
has_connection_factory,
) {
return Err(AdkError::tool(format!("Failed to list MCP tools: {error}")));
}
let retry_attempt = attempt + 1;
if self.refresh_config.log_reconnections {
warn!(
attempt = retry_attempt,
max_attempts = self.refresh_config.max_attempts,
error = %error,
"MCP list_all_tools failed; reconnecting and retrying"
);
}
if self.refresh_config.retry_delay_ms > 0 {
tokio::time::sleep(tokio::time::Duration::from_millis(
self.refresh_config.retry_delay_ms,
))
.await;
}
if !self.try_refresh_connection().await? {
return Err(AdkError::tool(format!("Failed to list MCP tools: {error}")));
}
attempt += 1;
}
}
};
let mut tools: Vec<Arc<dyn Tool>> = Vec::new();
for mcp_tool in mcp_tools {
let tool_name = mcp_tool.name.to_string();
if let Some(ref filter) = self.tool_filter {
if !filter(&tool_name) {
continue;
}
}
let adk_tool = McpTool {
name: tool_name,
description: mcp_tool.description.map(|d| d.to_string()).unwrap_or_default(),
input_schema: {
let mut schema = Value::Object(mcp_tool.input_schema.as_ref().clone());
sanitize_schema(&mut schema);
Some(schema)
},
output_schema: mcp_tool.output_schema.map(|s| {
let mut schema = Value::Object(s.as_ref().clone());
sanitize_schema(&mut schema);
schema
}),
client: self.client.clone(),
connection_factory: self.connection_factory.clone(),
refresh_config: self.refresh_config.clone(),
is_long_running: self.task_config.enable_tasks
&& mcp_tool.annotations.as_ref().is_some_and(|a| {
a.read_only_hint != Some(true) && a.open_world_hint != Some(false)
}),
task_config: self.task_config.clone(),
};
tools.push(Arc::new(adk_tool) as Arc<dyn Tool>);
}
Ok(tools)
}
}
impl McpToolset<super::elicitation::AdkClientHandler> {
    /// Connects to an MCP server over `transport` using a client handler that
    /// wraps the given elicitation `handler`, and returns a toolset around the
    /// resulting running service.
    ///
    /// # Errors
    ///
    /// Returns a tool error (`failed to connect MCP server: ...`) when serving
    /// the transport fails.
    pub async fn with_elicitation_handler<T, E, A>(
        transport: T,
        handler: std::sync::Arc<dyn super::elicitation::ElicitationHandler>,
    ) -> Result<Self>
    where
        T: rmcp::transport::IntoTransport<rmcp::RoleClient, E, A> + Send + 'static,
        E: std::error::Error + Send + Sync + 'static,
    {
        use rmcp::ServiceExt;

        let connected = super::elicitation::AdkClientHandler::new(handler).serve(transport).await;
        match connected {
            Ok(service) => Ok(Self::new(service)),
            Err(e) => Err(AdkError::tool(format!("failed to connect MCP server: {e}"))),
        }
    }
}
/// A single MCP server tool exposed through the ADK `Tool` trait.
struct McpTool<S>
where
S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
// Tool name; also used as the name of the MCP tool call.
name: String,
// Tool description (may be empty).
description: String,
// Schema returned as the tool's parameters schema.
input_schema: Option<Value>,
// Schema returned as the tool's response schema.
output_schema: Option<Value>,
// Running MCP client behind an async mutex; replaced in place on reconnect.
client: Arc<Mutex<RunningService<RoleClient, S>>>,
// Optional factory for replacement connections; `None` disables reconnects.
connection_factory: Option<DynConnectionFactory<S>>,
// Retry settings (attempt limit, delay, logging) for reconnects.
refresh_config: RefreshConfig,
// Whether the tool is reported as long-running; together with
// `task_config.enable_tasks` it selects task-mode execution.
is_long_running: bool,
// Task settings: enable flag, poll interval, timeout and attempt limit.
task_config: McpTaskConfig,
}
impl<S> McpTool<S>
where
    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
    /// Replaces the held client with a fresh connection from the factory.
    ///
    /// Returns `Ok(false)` when no factory is configured, `Ok(true)` once the
    /// client has been swapped, and an error when the factory fails. The old
    /// client's cancellation token is cancelled before the swap.
    async fn try_refresh_connection(&self) -> Result<bool> {
        let Some(factory) = self.connection_factory.clone() else {
            return Ok(false);
        };
        // The new connection is created before the client lock is taken.
        let new_client = factory
            .create_connection()
            .await
            .map_err(|e| AdkError::tool(format!("Failed to refresh MCP connection: {e}")))?;
        let mut client = self.client.lock().await;
        let old_token = client.cancellation_token();
        old_token.cancel();
        *client = new_client;
        Ok(true)
    }

    /// Issues a tool call, reconnecting and retrying on reconnectable errors.
    ///
    /// The call is repeated with a clone of `params` for as long as
    /// `should_retry_mcp_operation` allows it; between attempts the configured
    /// delay is awaited and the connection is refreshed.
    ///
    /// # Errors
    ///
    /// Returns a tool error when the call fails and no retry is permitted,
    /// when no connection could be refreshed, or when the connection factory
    /// itself fails.
    async fn call_tool_with_retry(
        &self,
        params: CallToolRequestParams,
    ) -> Result<rmcp::model::CallToolResult> {
        let has_connection_factory = self.connection_factory.is_some();
        let mut attempt = 0u32;
        loop {
            // The client lock is scoped to this block so it is released
            // before any reconnect below needs it.
            let call_result = {
                let client = self.client.lock().await;
                client.call_tool(params.clone()).await.map_err(|e| e.to_string())
            };
            match call_result {
                Ok(result) => return Ok(result),
                Err(error) => {
                    if !should_retry_mcp_operation(
                        &error,
                        attempt,
                        &self.refresh_config,
                        has_connection_factory,
                    ) {
                        return Err(AdkError::tool(format!(
                            "Failed to call MCP tool '{}': {error}",
                            self.name
                        )));
                    }
                    let retry_attempt = attempt + 1;
                    if self.refresh_config.log_reconnections {
                        warn!(
                            tool = %self.name,
                            attempt = retry_attempt,
                            max_attempts = self.refresh_config.max_attempts,
                            error = %error,
                            "MCP call_tool failed; reconnecting and retrying"
                        );
                    }
                    if self.refresh_config.retry_delay_ms > 0 {
                        tokio::time::sleep(tokio::time::Duration::from_millis(
                            self.refresh_config.retry_delay_ms,
                        ))
                        .await;
                    }
                    if !self.try_refresh_connection().await? {
                        return Err(AdkError::tool(format!(
                            "Failed to call MCP tool '{}': {error}",
                            self.name
                        )));
                    }
                    attempt += 1;
                }
            }
        }
    }

    /// Polls the task `task_id` until it reaches a terminal status.
    ///
    /// Each round first checks the optional timeout and the optional poll
    /// attempt limit, then sleeps for the configured poll interval, then
    /// issues a `tasks/get` tool call carrying the `task_id` argument.
    ///
    /// # Errors
    ///
    /// * `TaskError::Timeout` when the configured timeout has elapsed.
    /// * `TaskError::MaxAttemptsExceeded` when the attempt limit is reached.
    /// * `TaskError::PollFailed` when the poll call itself fails.
    /// * `TaskError::TaskFailed` / `TaskError::Cancelled` for those statuses.
    async fn poll_task(&self, task_id: &str) -> std::result::Result<Value, TaskError> {
        let start = Instant::now();
        let mut attempts = 0u32;
        loop {
            if let Some(timeout_ms) = self.task_config.timeout_ms {
                // `as_millis` yields a u128; saturate instead of silently
                // truncating when narrowing to u64.
                let elapsed = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
                if elapsed >= timeout_ms {
                    return Err(TaskError::Timeout {
                        task_id: task_id.to_string(),
                        elapsed_ms: elapsed,
                    });
                }
            }
            if let Some(max_attempts) = self.task_config.max_poll_attempts {
                if attempts >= max_attempts {
                    return Err(TaskError::MaxAttemptsExceeded {
                        task_id: task_id.to_string(),
                        attempts,
                    });
                }
            }
            tokio::time::sleep(self.task_config.poll_duration()).await;
            attempts += 1;
            debug!(task_id = task_id, attempt = attempts, "Polling MCP task status");
            let poll_result = self
                .call_tool_with_retry(CallToolRequestParams::new("tasks/get").with_arguments(
                    serde_json::Map::from_iter([(
                        "task_id".to_string(),
                        Value::String(task_id.to_string()),
                    )]),
                ))
                .await
                .map_err(|e| TaskError::PollFailed(e.to_string()))?;
            let status = self.parse_task_status(&poll_result)?;
            match status {
                TaskStatus::Completed => {
                    debug!(task_id = task_id, "Task completed successfully");
                    return self.extract_task_result(&poll_result);
                }
                TaskStatus::Failed => {
                    let error_msg = self.extract_error_message(&poll_result);
                    return Err(TaskError::TaskFailed {
                        task_id: task_id.to_string(),
                        error: error_msg,
                    });
                }
                TaskStatus::Cancelled => {
                    return Err(TaskError::Cancelled(task_id.to_string()));
                }
                TaskStatus::Pending | TaskStatus::Running => {
                    debug!(
                        task_id = task_id,
                        status = ?status,
                        "Task still in progress"
                    );
                }
            }
        }
    }

    /// Maps a textual task status onto [`TaskStatus`].
    ///
    /// Unrecognised values are logged with a warning and treated as
    /// `Running`, so polling continues. Shared by both lookup paths of
    /// `parse_task_status` so they cannot drift apart.
    fn task_status_from_str(status: &str) -> TaskStatus {
        match status {
            "pending" => TaskStatus::Pending,
            "running" => TaskStatus::Running,
            "completed" => TaskStatus::Completed,
            "failed" => TaskStatus::Failed,
            "cancelled" => TaskStatus::Cancelled,
            other => {
                warn!(status = other, "Unknown task status");
                TaskStatus::Running
            }
        }
    }

    /// Extracts the task status from a poll result.
    ///
    /// Looks for a string `status` field first in the structured content and
    /// then in each text content item that parses as JSON. When no status is
    /// found anywhere the task is assumed to be `Running`.
    fn parse_task_status(
        &self,
        result: &rmcp::model::CallToolResult,
    ) -> std::result::Result<TaskStatus, TaskError> {
        if let Some(ref structured) = result.structured_content {
            if let Some(status_str) = structured.get("status").and_then(|v| v.as_str()) {
                return Ok(Self::task_status_from_str(status_str));
            }
        }
        for content in &result.content {
            if let Some(text_content) = content.deref().as_text() {
                if let Ok(parsed) = serde_json::from_str::<Value>(&text_content.text) {
                    if let Some(status_str) = parsed.get("status").and_then(|v| v.as_str()) {
                        return Ok(Self::task_status_from_str(status_str));
                    }
                }
            }
        }
        Ok(TaskStatus::Running)
    }

    /// Builds the `{ "output": ... }` value for a completed task.
    ///
    /// Structured content wins: its `result` field is used when present,
    /// otherwise the whole structured value. Without structured content the
    /// text items are joined with newlines; with no text items the output is
    /// `null`.
    fn extract_task_result(
        &self,
        result: &rmcp::model::CallToolResult,
    ) -> std::result::Result<Value, TaskError> {
        if let Some(ref structured) = result.structured_content {
            if let Some(output) = structured.get("result") {
                return Ok(json!({ "output": output }));
            }
            return Ok(json!({ "output": structured }));
        }
        let mut text_parts: Vec<String> = Vec::new();
        for content in &result.content {
            if let Some(text_content) = content.deref().as_text() {
                text_parts.push(text_content.text.clone());
            }
        }
        if text_parts.is_empty() {
            Ok(json!({ "output": null }))
        } else {
            Ok(json!({ "output": text_parts.join("\n") }))
        }
    }

    /// Returns a human-readable error for a failed task.
    ///
    /// Prefers a string `error` field in the structured content, then the
    /// first text content item, and finally the literal `"Unknown error"`.
    fn extract_error_message(&self, result: &rmcp::model::CallToolResult) -> String {
        if let Some(ref structured) = result.structured_content {
            if let Some(error) = structured.get("error").and_then(|v| v.as_str()) {
                return error.to_string();
            }
        }
        for content in &result.content {
            if let Some(text_content) = content.deref().as_text() {
                return text_content.text.clone();
            }
        }
        "Unknown error".to_string()
    }

    /// Extracts the task identifier from a task-creation result.
    ///
    /// Looks for a string `task_id` field first in the structured content and
    /// then in each text content item that parses as JSON.
    ///
    /// # Errors
    ///
    /// Returns `TaskError::CreateFailed` when no `task_id` is found.
    fn extract_task_id(
        &self,
        result: &rmcp::model::CallToolResult,
    ) -> std::result::Result<String, TaskError> {
        if let Some(ref structured) = result.structured_content {
            if let Some(task_id) = structured.get("task_id").and_then(|v| v.as_str()) {
                return Ok(task_id.to_string());
            }
        }
        for content in &result.content {
            if let Some(text_content) = content.deref().as_text() {
                if let Ok(parsed) = serde_json::from_str::<Value>(&text_content.text) {
                    if let Some(task_id) = parsed.get("task_id").and_then(|v| v.as_str()) {
                        return Ok(task_id.to_string());
                    }
                }
            }
        }
        Err(TaskError::CreateFailed("No task_id in response".to_string()))
    }
}
#[async_trait]
impl<S> Tool for McpTool<S>
where
S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
&self.description
}
fn is_long_running(&self) -> bool {
self.is_long_running
}
fn parameters_schema(&self) -> Option<Value> {
self.input_schema.clone()
}
fn response_schema(&self) -> Option<Value> {
self.output_schema.clone()
}
async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
let use_task_mode = self.task_config.enable_tasks && self.is_long_running;
if use_task_mode {
debug!(tool = self.name, "Executing tool in task mode (long-running)");
let task_params = self.task_config.to_task_params();
let task_map = task_params.as_object().cloned();
let create_result = self
.call_tool_with_retry({
let mut params = CallToolRequestParams::new(self.name.clone());
if !(args.is_null() || args == json!({})) {
match args {
Value::Object(map) => {
params = params.with_arguments(map);
}
_ => {
return Err(AdkError::tool("Tool arguments must be an object"));
}
}
}
if let Some(task_map) = task_map {
params = params.with_task(task_map);
}
params
})
.await?;
let task_id = self
.extract_task_id(&create_result)
.map_err(|e| AdkError::tool(format!("Failed to get task ID: {e}")))?;
debug!(tool = self.name, task_id = task_id, "Task created, polling for completion");
let result = self
.poll_task(&task_id)
.await
.map_err(|e| AdkError::tool(format!("Task execution failed: {e}")))?;
return Ok(result);
}
let result = self
.call_tool_with_retry({
let mut params = CallToolRequestParams::new(self.name.clone());
if !(args.is_null() || args == json!({})) {
match args {
Value::Object(map) => {
params = params.with_arguments(map);
}
_ => {
return Err(AdkError::tool("Tool arguments must be an object"));
}
}
}
params
})
.await?;
if result.is_error.unwrap_or(false) {
let mut error_msg = format!("MCP tool '{}' execution failed", self.name);
for content in &result.content {
if let Some(text_content) = content.deref().as_text() {
error_msg.push_str(": ");
error_msg.push_str(&text_content.text);
break;
}
}
return Err(AdkError::tool(error_msg));
}
if let Some(structured) = result.structured_content {
return Ok(json!({ "output": structured }));
}
let mut text_parts: Vec<String> = Vec::new();
for content in &result.content {
let raw: &RawContent = content.deref();
match raw {
RawContent::Text(text_content) => {
text_parts.push(text_content.text.clone());
}
RawContent::Image(image_content) => {
text_parts.push(format!(
"[Image: {} bytes, mime: {}]",
image_content.data.len(),
image_content.mime_type
));
}
RawContent::Resource(resource_content) => {
let uri = match &resource_content.resource {
ResourceContents::TextResourceContents { uri, .. } => uri,
ResourceContents::BlobResourceContents { uri, .. } => uri,
};
text_parts.push(format!("[Resource: {}]", uri));
}
RawContent::Audio(_) => {
text_parts.push("[Audio content]".to_string());
}
RawContent::ResourceLink(link) => {
text_parts.push(format!("[ResourceLink: {}]", link.uri));
}
}
}
if text_parts.is_empty() {
return Err(AdkError::tool(format!("MCP tool '{}' returned no content", self.name)));
}
Ok(json!({ "output": text_parts.join("\n") }))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that the tool and toolset can cross thread boundaries.
    #[test]
    fn mcp_tool_is_send_and_sync() {
        fn assert_thread_safe<T>()
        where
            T: Send + Sync,
        {
        }
        assert_thread_safe::<McpTool<()>>();
        assert_thread_safe::<McpToolset<()>>();
    }

    /// Reconnectable error texts are retried while attempts remain.
    #[test]
    fn test_should_retry_mcp_operation_reconnectable_errors() {
        let refresh = RefreshConfig::default().with_max_attempts(3);
        for (message, attempt) in [("EOF", 0), ("connection reset by peer", 1)] {
            assert!(should_retry_mcp_operation(message, attempt, &refresh, true));
        }
    }

    /// No retry once the attempt counter has reached the configured maximum.
    #[test]
    fn test_should_retry_mcp_operation_stops_at_max_attempts() {
        let refresh = RefreshConfig::default().with_max_attempts(2);
        let retry = should_retry_mcp_operation("EOF", 2, &refresh, true);
        assert!(!retry);
    }

    /// No retry without a connection factory, even for a reconnectable error.
    #[test]
    fn test_should_retry_mcp_operation_requires_factory() {
        let refresh = RefreshConfig::default().with_max_attempts(3);
        let retry = should_retry_mcp_operation("EOF", 0, &refresh, false);
        assert!(!retry);
    }

    /// Errors that are not connection-related are never retried.
    #[test]
    fn test_should_retry_mcp_operation_non_reconnectable_error() {
        let refresh = RefreshConfig::default().with_max_attempts(3);
        let retry = should_retry_mcp_operation("invalid arguments for tool", 0, &refresh, true);
        assert!(!retry);
    }
}