Skip to main content

adk_tool/mcp/
toolset.rs

1// MCP (Model Context Protocol) Toolset Integration
2//
3// Based on Go implementation: adk-go/tool/mcptoolset/
4// Uses official Rust SDK: https://github.com/modelcontextprotocol/rust-sdk
5//
6// The McpToolset connects to an MCP server, discovers available tools,
7// and exposes them as ADK-compatible tools for use with LlmAgent.
8
9use super::task::{McpTaskConfig, TaskError, TaskStatus};
10use super::{ConnectionFactory, RefreshConfig, should_refresh_connection};
11use adk_core::{AdkError, ReadonlyContext, Result, Tool, ToolContext, Toolset};
12use async_trait::async_trait;
13use rmcp::{
14    RoleClient,
15    model::{
16        CallToolRequestParams, ErrorCode, RawContent, ReadResourceRequestParams, Resource,
17        ResourceContents, ResourceTemplate,
18    },
19    service::RunningService,
20};
21use serde_json::{Value, json};
22use std::ops::Deref;
23use std::sync::Arc;
24use std::time::Instant;
25use tokio::sync::Mutex;
26use tracing::{debug, warn};
27
/// Shared, type-erased [`ConnectionFactory`] handle used to recreate MCP
/// connections for refresh/retry.
type DynConnectionFactory<S> = Arc<dyn ConnectionFactory<S>>;
30
/// Predicate that decides which tools a toolset exposes.
///
/// It is called with a tool name; tools for which it returns `false` are
/// skipped when the toolset builds its tool list.
pub type ToolFilter = Arc<dyn Fn(&str) -> bool + Send + Sync>;
33
34/// Sanitize JSON schema for LLM compatibility.
35/// Removes fields like `$schema`, `additionalProperties`, `definitions`, `$ref`
36/// that some LLM APIs (like Gemini) don't accept.
37fn sanitize_schema(value: &mut Value) {
38    if let Value::Object(map) = value {
39        map.remove("$schema");
40        map.remove("definitions");
41        map.remove("$ref");
42        map.remove("additionalProperties");
43
44        for (_, v) in map.iter_mut() {
45            sanitize_schema(v);
46        }
47    } else if let Value::Array(arr) = value {
48        for v in arr.iter_mut() {
49            sanitize_schema(v);
50        }
51    }
52}
53
54fn should_retry_mcp_operation(
55    error: &str,
56    attempt: u32,
57    refresh_config: &RefreshConfig,
58    has_connection_factory: bool,
59) -> bool {
60    has_connection_factory
61        && attempt < refresh_config.max_attempts
62        && should_refresh_connection(error)
63}
64
65/// Returns `true` when the `ServiceError` wraps an MCP `MethodNotFound` (-32601)
66/// JSON-RPC error, indicating the server does not implement the requested method.
67fn is_method_not_found(err: &rmcp::ServiceError) -> bool {
68    matches!(
69        err,
70        rmcp::ServiceError::McpError(e) if e.code == ErrorCode::METHOD_NOT_FOUND
71    )
72}
73
74/// MCP Toolset - connects to an MCP server and exposes its tools as ADK tools.
75///
76/// This toolset implements the ADK `Toolset` trait and bridges the gap between
77/// MCP servers and ADK agents. It:
78/// 1. Connects to an MCP server via the provided transport
79/// 2. Discovers available tools from the server
80/// 3. Converts MCP tools to ADK-compatible `Tool` implementations
81/// 4. Proxies tool execution calls to the MCP server
82///
83/// # Example
84///
85/// ```rust,ignore
86/// use adk_tool::McpToolset;
87/// use rmcp::{ServiceExt, transport::TokioChildProcess};
88/// use tokio::process::Command;
89///
90/// // Create MCP client connection to a local server
91/// let client = ().serve(TokioChildProcess::new(
92///     Command::new("npx")
93///         .arg("-y")
94///         .arg("@modelcontextprotocol/server-everything")
95/// )?).await?;
96///
97/// // Create toolset from the client
98/// let toolset = McpToolset::new(client);
99///
100/// // Add to agent
101/// let agent = LlmAgentBuilder::new("assistant")
102///     .toolset(Arc::new(toolset))
103///     .build()?;
104/// ```
pub struct McpToolset<S = ()>
where
    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
    /// The running MCP client service.
    ///
    /// Held behind `Arc<Mutex<..>>` because every tool created by `tools()`
    /// receives a clone of this handle, and a reconnect replaces the service
    /// in place.
    client: Arc<Mutex<RunningService<RoleClient, S>>>,
    /// Optional filter to select which tools to expose.
    /// When `None`, every tool reported by the server is exposed.
    tool_filter: Option<ToolFilter>,
    /// Name of this toolset (`"mcp_toolset"` unless overridden).
    name: String,
    /// Task configuration for long-running operations; cloned into each
    /// created tool.
    task_config: McpTaskConfig,
    /// Optional connection factory used for reconnection on transport failures.
    /// When `None`, failed operations are not retried.
    connection_factory: Option<DynConnectionFactory<S>>,
    /// Reconnection/retry configuration (attempt limit, delay between
    /// retries, and whether reconnections are logged).
    refresh_config: RefreshConfig,
}
122
123impl<S> McpToolset<S>
124where
125    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
126{
127    /// Create a new MCP toolset from a running MCP client service.
128    ///
129    /// The client should already be connected and initialized.
130    /// Use `rmcp::ServiceExt::serve()` to create the client.
131    ///
132    /// # Example
133    ///
134    /// ```rust,ignore
135    /// use rmcp::{ServiceExt, transport::TokioChildProcess};
136    /// use tokio::process::Command;
137    ///
138    /// let client = ().serve(TokioChildProcess::new(
139    ///     Command::new("my-mcp-server")
140    /// )?).await?;
141    ///
142    /// let toolset = McpToolset::new(client);
143    /// ```
144    pub fn new(client: RunningService<RoleClient, S>) -> Self {
145        Self {
146            client: Arc::new(Mutex::new(client)),
147            tool_filter: None,
148            name: "mcp_toolset".to_string(),
149            task_config: McpTaskConfig::default(),
150            connection_factory: None,
151            refresh_config: RefreshConfig::default(),
152        }
153    }
154
    /// Create a McpToolset from a RunningService with a custom ClientHandler.
    ///
    /// This is functionally identical to `new()` (it simply forwards to it)
    /// but makes the intent explicit when using a custom `ClientHandler` type.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// use rmcp::ServiceExt;
    /// use adk_tool::McpToolset;
    ///
    /// let client = my_custom_handler.serve(transport).await?;
    /// let toolset = McpToolset::with_client_handler(client);
    /// ```
    pub fn with_client_handler(client: RunningService<RoleClient, S>) -> Self {
        // All defaults (name, filter, task/refresh config) come from `new`.
        Self::new(client)
    }
172
173    /// Set a custom name for this toolset.
174    pub fn with_name(mut self, name: impl Into<String>) -> Self {
175        self.name = name.into();
176        self
177    }
178
179    /// Enable task support for long-running operations.
180    ///
181    /// When enabled, tools marked as `is_long_running()` will use MCP's
182    /// async task lifecycle (SEP-1686) instead of blocking calls.
183    ///
184    /// # Example
185    ///
186    /// ```rust,ignore
187    /// let toolset = McpToolset::new(client)
188    ///     .with_task_support(McpTaskConfig::enabled()
189    ///         .poll_interval(Duration::from_secs(2))
190    ///         .timeout(Duration::from_secs(300)));
191    /// ```
192    pub fn with_task_support(mut self, config: McpTaskConfig) -> Self {
193        self.task_config = config;
194        self
195    }
196
197    /// Provide a connection factory to enable automatic MCP reconnection.
198    pub fn with_connection_factory<F>(mut self, factory: Arc<F>) -> Self
199    where
200        F: ConnectionFactory<S> + 'static,
201    {
202        self.connection_factory = Some(factory);
203        self
204    }
205
206    /// Configure MCP reconnect/retry behavior.
207    pub fn with_refresh_config(mut self, config: RefreshConfig) -> Self {
208        self.refresh_config = config;
209        self
210    }
211
212    /// Add a filter to select which tools to expose.
213    ///
214    /// The filter function receives a tool name and returns true if the tool
215    /// should be included.
216    ///
217    /// # Example
218    ///
219    /// ```rust,ignore
220    /// let toolset = McpToolset::new(client)
221    ///     .with_filter(|name| {
222    ///         matches!(name, "read_file" | "list_directory" | "search_files")
223    ///     });
224    /// ```
225    pub fn with_filter<F>(mut self, filter: F) -> Self
226    where
227        F: Fn(&str) -> bool + Send + Sync + 'static,
228    {
229        self.tool_filter = Some(Arc::new(filter));
230        self
231    }
232
233    /// Add a filter that only includes tools with the specified names.
234    ///
235    /// # Example
236    ///
237    /// ```rust,ignore
238    /// let toolset = McpToolset::new(client)
239    ///     .with_tools(&["read_file", "write_file"]);
240    /// ```
241    pub fn with_tools(self, tool_names: &[&str]) -> Self {
242        let names: Vec<String> = tool_names.iter().map(|s| s.to_string()).collect();
243        self.with_filter(move |name| names.iter().any(|n| n == name))
244    }
245
246    /// Get a cancellation token that can be used to shutdown the MCP server.
247    ///
248    /// Call `cancel()` on the returned token to cleanly shutdown the MCP server.
249    /// This should be called before exiting to avoid EPIPE errors.
250    ///
251    /// # Example
252    ///
253    /// ```rust,ignore
254    /// let toolset = McpToolset::new(client);
255    /// let cancel_token = toolset.cancellation_token().await;
256    ///
257    /// // ... use the toolset ...
258    ///
259    /// // Before exiting:
260    /// cancel_token.cancel();
261    /// ```
262    pub async fn cancellation_token(&self) -> rmcp::service::RunningServiceCancellationToken {
263        let client = self.client.lock().await;
264        client.cancellation_token()
265    }
266
267    /// Check whether the underlying MCP service connection has been closed or cancelled.
268    ///
269    /// Returns `true` if the service loop has terminated (transport closed,
270    /// cancellation token fired, or the background task completed). This is
271    /// useful for health monitoring — a closed connection indicates the server
272    /// process has crashed or the transport has been lost.
273    ///
274    /// # Example
275    ///
276    /// ```rust,ignore
277    /// if toolset.is_closed().await {
278    ///     tracing::warn!("MCP server connection lost");
279    /// }
280    /// ```
281    pub async fn is_closed(&self) -> bool {
282        let client = self.client.lock().await;
283        client.is_closed()
284    }
285
286    async fn try_refresh_connection(&self) -> Result<bool> {
287        let Some(factory) = self.connection_factory.clone() else {
288            return Ok(false);
289        };
290
291        let new_client = factory
292            .create_connection()
293            .await
294            .map_err(|e| AdkError::tool(format!("Failed to refresh MCP connection: {e}")))?;
295
296        let mut client = self.client.lock().await;
297        let old_token = client.cancellation_token();
298        old_token.cancel();
299        *client = new_client;
300        Ok(true)
301    }
302
303    /// List static resources from the connected MCP server.
304    ///
305    /// Returns the list of resources advertised by the server via the
306    /// `resources/list` protocol method. Returns an empty `Vec` when the
307    /// server does not support resources (i.e. responds with
308    /// `MethodNotFound`).
309    ///
310    /// # Errors
311    ///
312    /// Returns `AdkError::Tool` on transport or unexpected server errors.
313    pub async fn list_resources(&self) -> Result<Vec<Resource>> {
314        let client = self.client.lock().await;
315        match client.list_all_resources().await {
316            Ok(resources) => Ok(resources),
317            Err(e) => {
318                if is_method_not_found(&e) {
319                    Ok(vec![])
320                } else {
321                    Err(AdkError::tool(format!("Failed to list MCP resources: {e}")))
322                }
323            }
324        }
325    }
326
327    /// List URI template resources from the connected MCP server.
328    ///
329    /// Returns the list of resource templates advertised by the server via
330    /// the `resourceTemplates/list` protocol method. Returns an empty `Vec`
331    /// when the server does not support resource templates (i.e. responds
332    /// with `MethodNotFound`).
333    ///
334    /// # Errors
335    ///
336    /// Returns `AdkError::Tool` on transport or unexpected server errors.
337    pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>> {
338        let client = self.client.lock().await;
339        match client.list_all_resource_templates().await {
340            Ok(templates) => Ok(templates),
341            Err(e) => {
342                if is_method_not_found(&e) {
343                    Ok(vec![])
344                } else {
345                    Err(AdkError::tool(format!("Failed to list MCP resource templates: {e}")))
346                }
347            }
348        }
349    }
350
351    /// Read a resource by URI from the connected MCP server.
352    ///
353    /// Delegates to the `resources/read` protocol method. Returns the
354    /// resource contents on success.
355    ///
356    /// # Errors
357    ///
358    /// Returns `AdkError::Tool("resource not found: {uri}")` when the URI
359    /// does not match any resource on the server. Returns a generic
360    /// `AdkError::Tool` on transport or other server errors.
361    pub async fn read_resource(&self, uri: &str) -> Result<Vec<ResourceContents>> {
362        let client = self.client.lock().await;
363        let params = ReadResourceRequestParams::new(uri.to_string());
364        match client.read_resource(params).await {
365            Ok(result) => Ok(result.contents),
366            Err(e) => {
367                if is_method_not_found(&e) {
368                    Err(AdkError::tool(format!("resource not found: {uri}")))
369                } else {
370                    Err(AdkError::tool(format!("Failed to read MCP resource '{uri}': {e}")))
371                }
372            }
373        }
374    }
375}
376
377#[async_trait]
378impl<S> Toolset for McpToolset<S>
379where
380    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
381{
382    fn name(&self) -> &str {
383        &self.name
384    }
385
386    async fn tools(&self, _ctx: Arc<dyn ReadonlyContext>) -> Result<Vec<Arc<dyn Tool>>> {
387        let mut attempt = 0u32;
388        let has_connection_factory = self.connection_factory.is_some();
389        let mcp_tools = loop {
390            let list_result = {
391                let client = self.client.lock().await;
392                client.list_all_tools().await.map_err(|e| e.to_string())
393            };
394
395            match list_result {
396                Ok(tools) => break tools,
397                Err(error) => {
398                    if !should_retry_mcp_operation(
399                        &error,
400                        attempt,
401                        &self.refresh_config,
402                        has_connection_factory,
403                    ) {
404                        return Err(AdkError::tool(format!("Failed to list MCP tools: {error}")));
405                    }
406
407                    let retry_attempt = attempt + 1;
408                    if self.refresh_config.log_reconnections {
409                        warn!(
410                            attempt = retry_attempt,
411                            max_attempts = self.refresh_config.max_attempts,
412                            error = %error,
413                            "MCP list_all_tools failed; reconnecting and retrying"
414                        );
415                    }
416
417                    if self.refresh_config.retry_delay_ms > 0 {
418                        tokio::time::sleep(tokio::time::Duration::from_millis(
419                            self.refresh_config.retry_delay_ms,
420                        ))
421                        .await;
422                    }
423
424                    if !self.try_refresh_connection().await? {
425                        return Err(AdkError::tool(format!("Failed to list MCP tools: {error}")));
426                    }
427                    attempt += 1;
428                }
429            }
430        };
431
432        // Convert MCP tools to ADK tools
433        let mut tools: Vec<Arc<dyn Tool>> = Vec::new();
434
435        for mcp_tool in mcp_tools {
436            let tool_name = mcp_tool.name.to_string();
437
438            // Apply filter if present
439            if let Some(ref filter) = self.tool_filter {
440                if !filter(&tool_name) {
441                    continue;
442                }
443            }
444
445            let adk_tool = McpTool {
446                name: tool_name,
447                description: mcp_tool.description.map(|d| d.to_string()).unwrap_or_default(),
448                input_schema: {
449                    let mut schema = Value::Object(mcp_tool.input_schema.as_ref().clone());
450                    sanitize_schema(&mut schema);
451                    Some(schema)
452                },
453                output_schema: mcp_tool.output_schema.map(|s| {
454                    let mut schema = Value::Object(s.as_ref().clone());
455                    sanitize_schema(&mut schema);
456                    schema
457                }),
458                client: self.client.clone(),
459                connection_factory: self.connection_factory.clone(),
460                refresh_config: self.refresh_config.clone(),
461                // MCP ToolAnnotations (read_only_hint, destructive_hint, etc.)
462                // do not include a "long_running" hint. When task support is
463                // enabled on this toolset, treat non-read-only open-world tools
464                // as potentially long-running so the task lifecycle activates.
465                is_long_running: self.task_config.enable_tasks
466                    && mcp_tool.annotations.as_ref().is_some_and(|a| {
467                        a.read_only_hint != Some(true) && a.open_world_hint != Some(false)
468                    }),
469                task_config: self.task_config.clone(),
470            };
471
472            tools.push(Arc::new(adk_tool) as Arc<dyn Tool>);
473        }
474
475        Ok(tools)
476    }
477}
478
479impl McpToolset<super::elicitation::AdkClientHandler> {
480    /// Create a McpToolset with elicitation support from a transport.
481    ///
482    /// This creates the MCP client using `AdkClientHandler`, which advertises
483    /// elicitation capabilities and delegates requests to the provided handler.
484    ///
485    /// # Example
486    ///
487    /// ```rust,ignore
488    /// use adk_tool::{McpToolset, ElicitationHandler, AutoDeclineElicitationHandler};
489    /// use rmcp::transport::TokioChildProcess;
490    /// use tokio::process::Command;
491    /// use std::sync::Arc;
492    ///
493    /// let transport = TokioChildProcess::new(Command::new("my-mcp-server"))?;
494    /// let handler = Arc::new(AutoDeclineElicitationHandler);
495    /// let toolset = McpToolset::with_elicitation_handler(transport, handler).await?;
496    /// ```
497    ///
498    /// # ConnectionFactory with Elicitation
499    ///
500    /// To preserve elicitation across reconnections, clone the `Arc<dyn ElicitationHandler>`
501    /// into your `ConnectionFactory` implementation:
502    ///
503    /// ```rust,ignore
504    /// use adk_tool::{McpToolset, ElicitationHandler};
505    /// use adk_tool::mcp::ConnectionFactory;
506    /// use rmcp::{ServiceExt, service::{RoleClient, RunningService}};
507    /// use rmcp::transport::TokioChildProcess;
508    /// use tokio::process::Command;
509    /// use std::sync::Arc;
510    ///
511    /// struct MyReconnectFactory {
512    ///     handler: Arc<dyn ElicitationHandler>,
513    ///     server_command: String,
514    /// }
515    ///
516    /// // The factory creates a fresh AdkClientHandler on each reconnection,
517    /// // so the new connection advertises elicitation capabilities.
518    /// // The ConnectionFactory trait itself is unchanged.
519    /// ```
520    pub async fn with_elicitation_handler<T, E, A>(
521        transport: T,
522        handler: std::sync::Arc<dyn super::elicitation::ElicitationHandler>,
523    ) -> Result<Self>
524    where
525        T: rmcp::transport::IntoTransport<rmcp::RoleClient, E, A> + Send + 'static,
526        E: std::error::Error + Send + Sync + 'static,
527    {
528        use rmcp::ServiceExt;
529        let adk_handler = super::elicitation::AdkClientHandler::new(handler);
530        let client = adk_handler
531            .serve(transport)
532            .await
533            .map_err(|e| AdkError::tool(format!("failed to connect MCP server: {e}")))?;
534        Ok(Self::new(client))
535    }
536
537    /// Create a McpToolset with MCP sampling support from a transport.
538    ///
539    /// This creates the MCP client using `AdkClientHandler`, which advertises
540    /// both elicitation and sampling capabilities. When the connected MCP server
541    /// sends a `sampling/createMessage` request, it is delegated to the provided
542    /// [`SamplingHandler`](crate::sampling::SamplingHandler).
543    ///
544    /// An elicitation handler is also required because `AdkClientHandler` always
545    /// advertises elicitation. Use [`AutoDeclineElicitationHandler`] if you don't
546    /// need custom elicitation behavior.
547    ///
548    /// # Example
549    ///
550    /// ```rust,ignore
551    /// use adk_tool::{McpToolset, AutoDeclineElicitationHandler};
552    /// use adk_tool::sampling::LlmSamplingHandler;
553    /// use rmcp::transport::TokioChildProcess;
554    /// use tokio::process::Command;
555    /// use std::sync::Arc;
556    ///
557    /// let transport = TokioChildProcess::new(Command::new("my-mcp-server"))?;
558    /// let elicitation = Arc::new(AutoDeclineElicitationHandler);
559    /// let sampling = Arc::new(LlmSamplingHandler::new(my_llm.clone()));
560    /// let toolset = McpToolset::with_sampling_handler(transport, elicitation, sampling).await?;
561    /// ```
562    ///
563    /// # ConnectionFactory with Sampling
564    ///
565    /// To preserve sampling across reconnections, clone both handler `Arc`s
566    /// into your `ConnectionFactory` implementation and rebuild the
567    /// `AdkClientHandler` on each reconnection.
568    #[cfg(feature = "mcp-sampling")]
569    pub async fn with_sampling_handler<T, E, A>(
570        transport: T,
571        elicitation_handler: std::sync::Arc<dyn super::elicitation::ElicitationHandler>,
572        sampling_handler: std::sync::Arc<dyn crate::sampling::SamplingHandler>,
573    ) -> Result<Self>
574    where
575        T: rmcp::transport::IntoTransport<rmcp::RoleClient, E, A> + Send + 'static,
576        E: std::error::Error + Send + Sync + 'static,
577    {
578        use rmcp::ServiceExt;
579        let adk_handler = super::elicitation::AdkClientHandler::new(elicitation_handler)
580            .with_sampling_handler(sampling_handler);
581        let client = adk_handler
582            .serve(transport)
583            .await
584            .map_err(|e| AdkError::tool(format!("failed to connect MCP server: {e}")))?;
585        Ok(Self::new(client))
586    }
587}
588
/// Individual MCP tool wrapper that implements the ADK `Tool` trait.
///
/// This struct wraps an MCP tool and proxies execution calls to the MCP server.
/// Instances are created by `McpToolset::tools()`, one per discovered tool.
struct McpTool<S>
where
    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
    /// Tool name as reported by the MCP server.
    name: String,
    /// Tool description; empty when the server provides none.
    description: String,
    /// Input schema, already passed through `sanitize_schema`.
    input_schema: Option<Value>,
    /// Output schema (sanitized), if the server declared one.
    output_schema: Option<Value>,
    /// Connection handle shared with the owning toolset (a clone of its `Arc`).
    client: Arc<Mutex<RunningService<RoleClient, S>>>,
    /// Factory used to reconnect when a call fails; `None` disables retries.
    connection_factory: Option<DynConnectionFactory<S>>,
    /// Reconnection/retry configuration copied from the toolset.
    refresh_config: RefreshConfig,
    /// Whether this tool is treated as long-running. MCP annotations have no
    /// such hint, so the toolset derives it from task support being enabled
    /// plus the tool's read-only/open-world hints.
    is_long_running: bool,
    /// Task configuration
    task_config: McpTaskConfig,
}
608
609impl<S> McpTool<S>
610where
611    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
612{
613    async fn try_refresh_connection(&self) -> Result<bool> {
614        let Some(factory) = self.connection_factory.clone() else {
615            return Ok(false);
616        };
617
618        let new_client = factory
619            .create_connection()
620            .await
621            .map_err(|e| AdkError::tool(format!("Failed to refresh MCP connection: {e}")))?;
622
623        let mut client = self.client.lock().await;
624        let old_token = client.cancellation_token();
625        old_token.cancel();
626        *client = new_client;
627        Ok(true)
628    }
629
630    async fn call_tool_with_retry(
631        &self,
632        params: CallToolRequestParams,
633    ) -> Result<rmcp::model::CallToolResult> {
634        let has_connection_factory = self.connection_factory.is_some();
635        let mut attempt = 0u32;
636
637        loop {
638            let call_result = {
639                let client = self.client.lock().await;
640                client.call_tool(params.clone()).await.map_err(|e| e.to_string())
641            };
642
643            match call_result {
644                Ok(result) => return Ok(result),
645                Err(error) => {
646                    if !should_retry_mcp_operation(
647                        &error,
648                        attempt,
649                        &self.refresh_config,
650                        has_connection_factory,
651                    ) {
652                        return Err(AdkError::tool(format!(
653                            "Failed to call MCP tool '{}': {error}",
654                            self.name
655                        )));
656                    }
657
658                    let retry_attempt = attempt + 1;
659                    if self.refresh_config.log_reconnections {
660                        warn!(
661                            tool = %self.name,
662                            attempt = retry_attempt,
663                            max_attempts = self.refresh_config.max_attempts,
664                            error = %error,
665                            "MCP call_tool failed; reconnecting and retrying"
666                        );
667                    }
668
669                    if self.refresh_config.retry_delay_ms > 0 {
670                        tokio::time::sleep(tokio::time::Duration::from_millis(
671                            self.refresh_config.retry_delay_ms,
672                        ))
673                        .await;
674                    }
675
676                    if !self.try_refresh_connection().await? {
677                        return Err(AdkError::tool(format!(
678                            "Failed to call MCP tool '{}': {error}",
679                            self.name
680                        )));
681                    }
682                    attempt += 1;
683                }
684            }
685        }
686    }
687
688    /// Poll a task until completion or timeout
689    async fn poll_task(&self, task_id: &str) -> std::result::Result<Value, TaskError> {
690        let start = Instant::now();
691        let mut attempts = 0u32;
692
693        loop {
694            // Check timeout
695            if let Some(timeout_ms) = self.task_config.timeout_ms {
696                let elapsed = start.elapsed().as_millis() as u64;
697                if elapsed >= timeout_ms {
698                    return Err(TaskError::Timeout {
699                        task_id: task_id.to_string(),
700                        elapsed_ms: elapsed,
701                    });
702                }
703            }
704
705            // Check max attempts
706            if let Some(max_attempts) = self.task_config.max_poll_attempts {
707                if attempts >= max_attempts {
708                    return Err(TaskError::MaxAttemptsExceeded {
709                        task_id: task_id.to_string(),
710                        attempts,
711                    });
712                }
713            }
714
715            // Wait before polling
716            tokio::time::sleep(self.task_config.poll_duration()).await;
717            attempts += 1;
718
719            debug!(task_id = task_id, attempt = attempts, "Polling MCP task status");
720
721            // Poll task status using tasks/get
722            // Note: This requires the MCP server to support SEP-1686 task lifecycle
723            let poll_result = self
724                .call_tool_with_retry(CallToolRequestParams::new("tasks/get").with_arguments(
725                    serde_json::Map::from_iter([(
726                        "task_id".to_string(),
727                        Value::String(task_id.to_string()),
728                    )]),
729                ))
730                .await
731                .map_err(|e| TaskError::PollFailed(e.to_string()))?;
732
733            // Parse task status from response
734            let status = self.parse_task_status(&poll_result)?;
735
736            match status {
737                TaskStatus::Completed => {
738                    debug!(task_id = task_id, "Task completed successfully");
739                    // Extract result from the poll response
740                    return self.extract_task_result(&poll_result);
741                }
742                TaskStatus::Failed => {
743                    let error_msg = self.extract_error_message(&poll_result);
744                    return Err(TaskError::TaskFailed {
745                        task_id: task_id.to_string(),
746                        error: error_msg,
747                    });
748                }
749                TaskStatus::Cancelled => {
750                    return Err(TaskError::Cancelled(task_id.to_string()));
751                }
752                TaskStatus::Pending | TaskStatus::Running => {
753                    // Continue polling
754                    debug!(
755                        task_id = task_id,
756                        status = ?status,
757                        "Task still in progress"
758                    );
759                }
760            }
761        }
762    }
763
764    /// Parse task status from poll response
765    fn parse_task_status(
766        &self,
767        result: &rmcp::model::CallToolResult,
768    ) -> std::result::Result<TaskStatus, TaskError> {
769        // Try to extract status from structured content first
770        if let Some(ref structured) = result.structured_content {
771            if let Some(status_str) = structured.get("status").and_then(|v| v.as_str()) {
772                return match status_str {
773                    "pending" => Ok(TaskStatus::Pending),
774                    "running" => Ok(TaskStatus::Running),
775                    "completed" => Ok(TaskStatus::Completed),
776                    "failed" => Ok(TaskStatus::Failed),
777                    "cancelled" => Ok(TaskStatus::Cancelled),
778                    _ => {
779                        warn!(status = status_str, "Unknown task status");
780                        Ok(TaskStatus::Running) // Assume still running
781                    }
782                };
783            }
784        }
785
786        // Try to extract from text content
787        for content in &result.content {
788            if let Some(text_content) = content.deref().as_text() {
789                // Try to parse as JSON
790                if let Ok(parsed) = serde_json::from_str::<Value>(&text_content.text) {
791                    if let Some(status_str) = parsed.get("status").and_then(|v| v.as_str()) {
792                        return match status_str {
793                            "pending" => Ok(TaskStatus::Pending),
794                            "running" => Ok(TaskStatus::Running),
795                            "completed" => Ok(TaskStatus::Completed),
796                            "failed" => Ok(TaskStatus::Failed),
797                            "cancelled" => Ok(TaskStatus::Cancelled),
798                            _ => Ok(TaskStatus::Running),
799                        };
800                    }
801                }
802            }
803        }
804
805        // Default to running if we can't determine status
806        Ok(TaskStatus::Running)
807    }
808
809    /// Extract result from completed task
810    fn extract_task_result(
811        &self,
812        result: &rmcp::model::CallToolResult,
813    ) -> std::result::Result<Value, TaskError> {
814        // Try structured content first
815        if let Some(ref structured) = result.structured_content {
816            if let Some(output) = structured.get("result") {
817                return Ok(json!({ "output": output }));
818            }
819            return Ok(json!({ "output": structured }));
820        }
821
822        // Fall back to text content
823        let mut text_parts: Vec<String> = Vec::new();
824        for content in &result.content {
825            if let Some(text_content) = content.deref().as_text() {
826                text_parts.push(text_content.text.clone());
827            }
828        }
829
830        if text_parts.is_empty() {
831            Ok(json!({ "output": null }))
832        } else {
833            Ok(json!({ "output": text_parts.join("\n") }))
834        }
835    }
836
837    /// Extract error message from failed task
838    fn extract_error_message(&self, result: &rmcp::model::CallToolResult) -> String {
839        // Try structured content
840        if let Some(ref structured) = result.structured_content {
841            if let Some(error) = structured.get("error").and_then(|v| v.as_str()) {
842                return error.to_string();
843            }
844        }
845
846        // Try text content
847        for content in &result.content {
848            if let Some(text_content) = content.deref().as_text() {
849                return text_content.text.clone();
850            }
851        }
852
853        "Unknown error".to_string()
854    }
855
856    /// Extract task ID from create task response
857    fn extract_task_id(
858        &self,
859        result: &rmcp::model::CallToolResult,
860    ) -> std::result::Result<String, TaskError> {
861        // Try structured content
862        if let Some(ref structured) = result.structured_content {
863            if let Some(task_id) = structured.get("task_id").and_then(|v| v.as_str()) {
864                return Ok(task_id.to_string());
865            }
866        }
867
868        // Try text content (might be JSON)
869        for content in &result.content {
870            if let Some(text_content) = content.deref().as_text() {
871                if let Ok(parsed) = serde_json::from_str::<Value>(&text_content.text) {
872                    if let Some(task_id) = parsed.get("task_id").and_then(|v| v.as_str()) {
873                        return Ok(task_id.to_string());
874                    }
875                }
876            }
877        }
878
879        Err(TaskError::CreateFailed("No task_id in response".to_string()))
880    }
881}
882
883#[async_trait]
884impl<S> Tool for McpTool<S>
885where
886    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
887{
888    fn name(&self) -> &str {
889        &self.name
890    }
891
892    fn description(&self) -> &str {
893        &self.description
894    }
895
896    fn is_long_running(&self) -> bool {
897        self.is_long_running
898    }
899
900    fn parameters_schema(&self) -> Option<Value> {
901        self.input_schema.clone()
902    }
903
904    fn response_schema(&self) -> Option<Value> {
905        self.output_schema.clone()
906    }
907
908    async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
909        // Determine if we should use task mode
910        let use_task_mode = self.task_config.enable_tasks && self.is_long_running;
911
912        if use_task_mode {
913            debug!(tool = self.name, "Executing tool in task mode (long-running)");
914
915            // Create task request with task parameters
916            let task_params = self.task_config.to_task_params();
917            let task_map = task_params.as_object().cloned();
918
919            let create_result = self
920                .call_tool_with_retry({
921                    let mut params = CallToolRequestParams::new(self.name.clone());
922                    if !(args.is_null() || args == json!({})) {
923                        match args {
924                            Value::Object(map) => {
925                                params = params.with_arguments(map);
926                            }
927                            _ => {
928                                return Err(AdkError::tool("Tool arguments must be an object"));
929                            }
930                        }
931                    }
932                    if let Some(task_map) = task_map {
933                        params = params.with_task(task_map);
934                    }
935                    params
936                })
937                .await?;
938
939            // Extract task ID
940            let task_id = self
941                .extract_task_id(&create_result)
942                .map_err(|e| AdkError::tool(format!("Failed to get task ID: {e}")))?;
943
944            debug!(tool = self.name, task_id = task_id, "Task created, polling for completion");
945
946            // Poll for completion
947            let result = self
948                .poll_task(&task_id)
949                .await
950                .map_err(|e| AdkError::tool(format!("Task execution failed: {e}")))?;
951
952            return Ok(result);
953        }
954
955        // Standard synchronous execution
956        let result = self
957            .call_tool_with_retry({
958                let mut params = CallToolRequestParams::new(self.name.clone());
959                if !(args.is_null() || args == json!({})) {
960                    match args {
961                        Value::Object(map) => {
962                            params = params.with_arguments(map);
963                        }
964                        _ => {
965                            return Err(AdkError::tool("Tool arguments must be an object"));
966                        }
967                    }
968                }
969                params
970            })
971            .await?;
972
973        // Check for error response
974        if result.is_error.unwrap_or(false) {
975            let mut error_msg = format!("MCP tool '{}' execution failed", self.name);
976
977            // Extract error details from content
978            for content in &result.content {
979                // Use Deref to access the inner RawContent
980                if let Some(text_content) = content.deref().as_text() {
981                    error_msg.push_str(": ");
982                    error_msg.push_str(&text_content.text);
983                    break;
984                }
985            }
986
987            return Err(AdkError::tool(error_msg));
988        }
989
990        // Return structured content if available
991        if let Some(structured) = result.structured_content {
992            return Ok(json!({ "output": structured }));
993        }
994
995        // Otherwise, collect text content
996        let mut text_parts: Vec<String> = Vec::new();
997
998        for content in &result.content {
999            // Access the inner RawContent via Deref
1000            let raw: &RawContent = content.deref();
1001            match raw {
1002                RawContent::Text(text_content) => {
1003                    text_parts.push(text_content.text.clone());
1004                }
1005                RawContent::Image(image_content) => {
1006                    // Return image data as base64
1007                    text_parts.push(format!(
1008                        "[Image: {} bytes, mime: {}]",
1009                        image_content.data.len(),
1010                        image_content.mime_type
1011                    ));
1012                }
1013                RawContent::Resource(resource_content) => {
1014                    let uri = match &resource_content.resource {
1015                        ResourceContents::TextResourceContents { uri, .. } => uri,
1016                        ResourceContents::BlobResourceContents { uri, .. } => uri,
1017                    };
1018                    text_parts.push(format!("[Resource: {}]", uri));
1019                }
1020                RawContent::Audio(_) => {
1021                    text_parts.push("[Audio content]".to_string());
1022                }
1023                RawContent::ResourceLink(link) => {
1024                    text_parts.push(format!("[ResourceLink: {}]", link.uri));
1025                }
1026            }
1027        }
1028
1029        if text_parts.is_empty() {
1030            return Err(AdkError::tool(format!("MCP tool '{}' returned no content", self.name)));
1031        }
1032
1033        Ok(json!({ "output": text_parts.join("\n") }))
1034    }
1035}
1036
1037// McpTool<S> is Send + Sync when S: Send + Sync because all fields are
1038// composed of Send + Sync primitives (String, Arc<Mutex<_>>, Arc<dyn Send + Sync>, etc.).
1039// The compiler enforces this through the Tool trait bound (Tool: Send + Sync).
1040// No unsafe impl needed — the previous unsafe impl was removed as unnecessary.
1041
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time proof that `McpTool<S>` and `McpToolset<S>` are
    /// `Send + Sync` without any `unsafe impl`.
    ///
    /// If a field of either type stopped being `Send` or `Sync`, this test
    /// would fail to build. It stands in for an earlier, unnecessary
    /// `unsafe impl Send/Sync for McpTool<S>`.
    #[test]
    fn mcp_tool_is_send_and_sync() {
        fn assert_send_sync<T>()
        where
            T: Send + Sync,
        {
        }

        // `()` is used as the service type: it satisfies
        // `Service<RoleClient>` through rmcp's `ClientHandler` blanket impl,
        // which makes these valid concrete instantiations.
        assert_send_sync::<McpTool<()>>();
        assert_send_sync::<McpToolset<()>>();
    }

    /// Reconnectable errors are retried while attempts remain and a factory
    /// is available.
    #[test]
    fn test_should_retry_mcp_operation_reconnectable_errors() {
        let config = RefreshConfig::default().with_max_attempts(3);

        let retries_on_eof = should_retry_mcp_operation("EOF", 0, &config, true);
        let retries_on_reset =
            should_retry_mcp_operation("connection reset by peer", 1, &config, true);

        assert!(retries_on_eof);
        assert!(retries_on_reset);
    }

    /// No retry once the attempt counter has reached the configured maximum.
    #[test]
    fn test_should_retry_mcp_operation_stops_at_max_attempts() {
        let config = RefreshConfig::default().with_max_attempts(2);

        let retries = should_retry_mcp_operation("EOF", 2, &config, true);

        assert!(!retries);
    }

    /// No retry when there is no connection factory to reconnect with.
    #[test]
    fn test_should_retry_mcp_operation_requires_factory() {
        let config = RefreshConfig::default().with_max_attempts(3);

        let retries = should_retry_mcp_operation("EOF", 0, &config, false);

        assert!(!retries);
    }

    /// Errors that are not reconnectable are never retried.
    #[test]
    fn test_should_retry_mcp_operation_non_reconnectable_error() {
        let config = RefreshConfig::default().with_max_attempts(3);

        let retries = should_retry_mcp_operation("invalid arguments for tool", 0, &config, true);

        assert!(!retries);
    }
}
1087    fn test_should_retry_mcp_operation_non_reconnectable_error() {
1088        let config = RefreshConfig::default().with_max_attempts(3);
1089        assert!(!should_retry_mcp_operation("invalid arguments for tool", 0, &config, true));
1090    }
1091}