Skip to main content

adk_tool/mcp/
elicitation.rs

1//! MCP Elicitation lifecycle support.
2//!
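//! This module provides the [`ElicitationHandler`] trait for handling MCP elicitation
//! requests from servers, an [`AutoDeclineElicitationHandler`] that declines all
//! requests, and the internal [`AdkClientHandler`] bridge to rmcp's `ClientHandler`.
//! When the `mcp-sampling` feature is enabled, the bridge can also serve
//! `sampling/createMessage` requests through an optional sampling handler.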
6
7use std::sync::Arc;
8
9use futures::FutureExt;
10use rmcp::model::{
11    ClientInfo, CreateElicitationRequestParams, CreateElicitationResult, ElicitationAction,
12    ElicitationResponseNotificationParam, ElicitationSchema,
13};
14use rmcp::service::{NotificationContext, RequestContext, RoleClient};
15use serde_json::Value;
16
17/// Trait for handling MCP elicitation requests from servers.
18///
19/// Implement this trait to provide custom elicitation behavior when
20/// an MCP server requests additional information during tool execution.
21///
22/// # Example
23///
24/// ```rust,ignore
25/// use adk_tool::ElicitationHandler;
26/// use rmcp::model::{CreateElicitationResult, ElicitationAction, ElicitationSchema};
27///
28/// struct MyHandler;
29///
30/// #[async_trait::async_trait]
31/// impl ElicitationHandler for MyHandler {
32///     async fn handle_form_elicitation(
33///         &self,
34///         message: &str,
35///         schema: &ElicitationSchema,
36///         metadata: Option<&serde_json::Value>,
37///     ) -> Result<CreateElicitationResult, Box<dyn std::error::Error + Send + Sync>> {
38///         println!("Server asks: {message}");
39///         Ok(CreateElicitationResult::new(ElicitationAction::Accept))
40///     }
41///
42///     async fn handle_url_elicitation(
43///         &self,
44///         message: &str,
45///         url: &str,
46///         elicitation_id: &str,
47///         metadata: Option<&serde_json::Value>,
48///     ) -> Result<CreateElicitationResult, Box<dyn std::error::Error + Send + Sync>> {
49///         println!("Server asks to visit: {url}");
50///         Ok(CreateElicitationResult::new(ElicitationAction::Accept))
51///     }
52/// }
53/// ```
54#[async_trait::async_trait]
55pub trait ElicitationHandler: Send + Sync {
56    /// Handle a form-based elicitation request.
57    ///
58    /// The server sends a human-readable message and a typed schema describing
59    /// the data it needs. Return `Accept` with content matching the schema,
60    /// `Decline` to refuse, or `Cancel` to abort the operation.
61    async fn handle_form_elicitation(
62        &self,
63        message: &str,
64        schema: &ElicitationSchema,
65        metadata: Option<&Value>,
66    ) -> Result<CreateElicitationResult, Box<dyn std::error::Error + Send + Sync>>;
67
68    /// Handle a URL-based elicitation request.
69    ///
70    /// The server sends a URL for the user to visit and interact with externally.
71    /// The `elicitation_id` uniquely identifies this request for the completion
72    /// notification flow.
73    async fn handle_url_elicitation(
74        &self,
75        message: &str,
76        url: &str,
77        elicitation_id: &str,
78        metadata: Option<&Value>,
79    ) -> Result<CreateElicitationResult, Box<dyn std::error::Error + Send + Sync>>;
80}
81
/// An [`ElicitationHandler`] that declines every elicitation request.
///
/// This is the fallback used when no custom handler is configured. Both form
/// and URL requests are answered with [`ElicitationAction::Decline`], mirroring
/// the decline-everything default of rmcp's `()` `ClientHandler`.
///
/// Note: when wrapped in [`AdkClientHandler`], the client still advertises the
/// elicitation capability. Servers may therefore send elicitation requests,
/// which this handler then declines, rather than never sending them at all.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AutoDeclineElicitationHandler;
88
89#[async_trait::async_trait]
90impl ElicitationHandler for AutoDeclineElicitationHandler {
91    async fn handle_form_elicitation(
92        &self,
93        _message: &str,
94        _schema: &ElicitationSchema,
95        _metadata: Option<&Value>,
96    ) -> Result<CreateElicitationResult, Box<dyn std::error::Error + Send + Sync>> {
97        Ok(CreateElicitationResult::new(ElicitationAction::Decline))
98    }
99
100    async fn handle_url_elicitation(
101        &self,
102        _message: &str,
103        _url: &str,
104        _elicitation_id: &str,
105        _metadata: Option<&Value>,
106    ) -> Result<CreateElicitationResult, Box<dyn std::error::Error + Send + Sync>> {
107        Ok(CreateElicitationResult::new(ElicitationAction::Decline))
108    }
109}
110
111/// Internal bridge between ADK's [`ElicitationHandler`] and rmcp's `ClientHandler`.
112///
113/// Wraps an `Arc<dyn ElicitationHandler>` and implements rmcp's `ClientHandler` trait,
114/// advertising elicitation capabilities and delegating requests to the handler.
115///
116/// When the `mcp-sampling` feature is enabled, also accepts an optional
117/// `Arc<dyn SamplingHandler>` to handle `sampling/createMessage` requests.
118pub struct AdkClientHandler {
119    handler: Arc<dyn ElicitationHandler>,
120    #[cfg(feature = "mcp-sampling")]
121    sampling_handler: Option<Arc<dyn crate::sampling::SamplingHandler>>,
122}
123
124impl AdkClientHandler {
125    pub fn new(handler: Arc<dyn ElicitationHandler>) -> Self {
126        Self {
127            handler,
128            #[cfg(feature = "mcp-sampling")]
129            sampling_handler: None,
130        }
131    }
132
133    /// Set a sampling handler for `sampling/createMessage` requests.
134    ///
135    /// When configured, the handler advertises sampling capability and
136    /// delegates incoming sampling requests to the provided handler.
137    #[cfg(feature = "mcp-sampling")]
138    pub fn with_sampling_handler(
139        mut self,
140        handler: Arc<dyn crate::sampling::SamplingHandler>,
141    ) -> Self {
142        self.sampling_handler = Some(handler);
143        self
144    }
145}
146
147impl rmcp::handler::client::ClientHandler for AdkClientHandler {
148    fn get_info(&self) -> ClientInfo {
149        let mut info = ClientInfo::default();
150
151        #[cfg(feature = "mcp-sampling")]
152        {
153            if self.sampling_handler.is_some() {
154                info.capabilities = rmcp::model::ClientCapabilities::builder()
155                    .enable_elicitation()
156                    .enable_sampling()
157                    .build();
158            } else {
159                info.capabilities =
160                    rmcp::model::ClientCapabilities::builder().enable_elicitation().build();
161            }
162        }
163
164        #[cfg(not(feature = "mcp-sampling"))]
165        {
166            info.capabilities =
167                rmcp::model::ClientCapabilities::builder().enable_elicitation().build();
168        }
169
170        info
171    }
172
173    #[cfg(feature = "mcp-sampling")]
174    async fn create_message(
175        &self,
176        params: rmcp::model::CreateMessageRequestParams,
177        _context: RequestContext<RoleClient>,
178    ) -> Result<rmcp::model::CreateMessageResult, rmcp::ErrorData> {
179        use crate::sampling::{SamplingContent, SamplingMessage, SamplingRequest};
180        use rmcp::model::{CreateMessageResult, Role, SamplingMessageContent};
181
182        let Some(ref sampling_handler) = self.sampling_handler else {
183            return Err(rmcp::ErrorData::new(
184                rmcp::model::ErrorCode::METHOD_NOT_FOUND,
185                "sampling handler not configured",
186                None,
187            ));
188        };
189
190        // Convert rmcp SamplingMessages → our SamplingMessages
191        let messages: Vec<SamplingMessage> = params
192            .messages
193            .iter()
194            .map(|msg| {
195                let role = match msg.role {
196                    Role::User => "user",
197                    Role::Assistant => "assistant",
198                };
199                // Extract text from the first content item
200                let content = msg
201                    .content
202                    .first()
203                    .and_then(|c| match c {
204                        SamplingMessageContent::Text(t) => {
205                            Some(SamplingContent::text(t.text.clone()))
206                        }
207                        SamplingMessageContent::Image(img) => {
208                            Some(SamplingContent::image(img.data.clone(), img.mime_type.clone()))
209                        }
210                        _ => None,
211                    })
212                    .unwrap_or_else(|| SamplingContent::text(""));
213                SamplingMessage::new(role, content)
214            })
215            .collect();
216
217        let request = SamplingRequest {
218            messages,
219            system_prompt: params.system_prompt.clone(),
220            model_preferences: None,
221            max_tokens: Some(params.max_tokens),
222            temperature: params.temperature.map(|t| t as f64),
223        };
224
225        match std::panic::AssertUnwindSafe(sampling_handler.handle_create_message(request))
226            .catch_unwind()
227            .await
228        {
229            Ok(Ok(response)) => {
230                // Convert our SamplingResponse → rmcp CreateMessageResult
231                let text = match &response.content {
232                    SamplingContent::Text { text } => text.clone(),
233                    SamplingContent::Image { .. } => String::new(),
234                };
235                let message = rmcp::model::SamplingMessage::assistant_text(text);
236                Ok(CreateMessageResult::new(message, response.model)
237                    .with_stop_reason(response.stop_reason))
238            }
239            Ok(Err(e)) => {
240                tracing::warn!(error = %e, "sampling handler returned error");
241                Err(rmcp::ErrorData::new(
242                    rmcp::model::ErrorCode::INTERNAL_ERROR,
243                    format!("sampling handler error: {e}"),
244                    None,
245                ))
246            }
247            Err(_panic) => {
248                tracing::warn!("sampling handler panicked");
249                Err(rmcp::ErrorData::new(
250                    rmcp::model::ErrorCode::INTERNAL_ERROR,
251                    "sampling handler panicked",
252                    None,
253                ))
254            }
255        }
256    }
257
258    async fn create_elicitation(
259        &self,
260        request: CreateElicitationRequestParams,
261        _context: RequestContext<RoleClient>,
262    ) -> Result<CreateElicitationResult, rmcp::ErrorData> {
263        {
264            let result = match &request {
265                CreateElicitationRequestParams::FormElicitationParams {
266                    message,
267                    requested_schema,
268                    meta,
269                    ..
270                } => {
271                    let metadata_value = meta.as_ref().and_then(|m| serde_json::to_value(m).ok());
272                    std::panic::AssertUnwindSafe(self.handler.handle_form_elicitation(
273                        message,
274                        requested_schema,
275                        metadata_value.as_ref(),
276                    ))
277                    .catch_unwind()
278                    .await
279                }
280                CreateElicitationRequestParams::UrlElicitationParams {
281                    message,
282                    url,
283                    elicitation_id,
284                    meta,
285                    ..
286                } => {
287                    let metadata_value = meta.as_ref().and_then(|m| serde_json::to_value(m).ok());
288                    std::panic::AssertUnwindSafe(self.handler.handle_url_elicitation(
289                        message,
290                        url,
291                        elicitation_id,
292                        metadata_value.as_ref(),
293                    ))
294                    .catch_unwind()
295                    .await
296                }
297            };
298
299            match result {
300                Ok(Ok(elicitation_result)) => Ok(elicitation_result),
301                Ok(Err(e)) => {
302                    tracing::warn!(error = %e, "elicitation handler returned error, declining");
303                    Ok(CreateElicitationResult::new(ElicitationAction::Decline))
304                }
305                Err(_panic) => {
306                    tracing::warn!("elicitation handler panicked, declining");
307                    Ok(CreateElicitationResult::new(ElicitationAction::Decline))
308                }
309            }
310        }
311    }
312
313    async fn on_url_elicitation_notification_complete(
314        &self,
315        _params: ElicitationResponseNotificationParam,
316        _context: NotificationContext<RoleClient>,
317    ) {
318        tracing::debug!("received URL elicitation completion notification");
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time assertion that `T` can be sent and shared across threads.
    /// `?Sized` allows checking trait objects such as `dyn ElicitationHandler`.
    fn require_send_sync<T: Send + Sync + ?Sized>() {}

    #[test]
    fn test_elicitation_handler_is_send_sync() {
        require_send_sync::<AutoDeclineElicitationHandler>();
        // AdkClientHandler stores the handler as a trait object behind an Arc,
        // so the trait object itself must be thread-safe, not just one impl.
        require_send_sync::<dyn ElicitationHandler>();
        require_send_sync::<std::sync::Arc<dyn ElicitationHandler>>();
    }

    #[test]
    fn test_adk_client_handler_is_send_sync() {
        require_send_sync::<AdkClientHandler>();
    }

    #[test]
    fn test_adk_client_handler_implements_client_handler() {
        fn require_client_handler<T: rmcp::handler::client::ClientHandler>() {}
        require_client_handler::<AdkClientHandler>();
    }
}