1use std::sync::Arc;
8
9use futures::FutureExt;
10use rmcp::model::{
11 ClientInfo, CreateElicitationRequestParams, CreateElicitationResult, ElicitationAction,
12 ElicitationResponseNotificationParam, ElicitationSchema,
13};
14use rmcp::service::{NotificationContext, RequestContext, RoleClient};
15use serde_json::Value;
16
#[async_trait::async_trait]
/// Application-supplied handler for elicitation requests sent by an MCP server.
///
/// [`AdkClientHandler`] dispatches form requests to
/// [`handle_form_elicitation`](Self::handle_form_elicitation) and URL requests to
/// [`handle_url_elicitation`](Self::handle_url_elicitation). If either method returns
/// `Err` or panics, `AdkClientHandler` logs a warning and answers the server with
/// [`ElicitationAction::Decline`]; the failure itself is never sent to the server.
///
/// Implementations are shared as `Arc<dyn ElicitationHandler>`, hence the
/// `Send + Sync` supertraits. [`AutoDeclineElicitationHandler`] is a ready-made
/// implementation that declines every request.
pub trait ElicitationHandler: Send + Sync {
    /// Handles a form elicitation request.
    ///
    /// * `message` - the message text carried by the request.
    /// * `schema` - the request's `requested_schema`.
    /// * `metadata` - the request's `meta` converted to JSON; `None` when the request
    ///   has no `meta` or when converting it to JSON failed.
    ///
    /// # Errors
    ///
    /// Any `Err` is turned into a decline by `AdkClientHandler` (and logged).
    async fn handle_form_elicitation(
        &self,
        message: &str,
        schema: &ElicitationSchema,
        metadata: Option<&Value>,
    ) -> Result<CreateElicitationResult, Box<dyn std::error::Error + Send + Sync>>;

    /// Handles a URL elicitation request.
    ///
    /// * `message` - the message text carried by the request.
    /// * `url` - the URL carried by the request.
    /// * `elicitation_id` - the request's elicitation identifier.
    /// * `metadata` - the request's `meta` converted to JSON; `None` when absent or when
    ///   conversion failed.
    ///
    /// NOTE(review): `AdkClientHandler` only logs the server's URL-elicitation completion
    /// notification; that notification is not forwarded to this trait.
    ///
    /// # Errors
    ///
    /// Any `Err` is turned into a decline by `AdkClientHandler` (and logged).
    async fn handle_url_elicitation(
        &self,
        message: &str,
        url: &str,
        elicitation_id: &str,
        metadata: Option<&Value>,
    ) -> Result<CreateElicitationResult, Box<dyn std::error::Error + Send + Sync>>;
}
81
/// An elicitation handler that declines every request without inspecting it.
///
/// The type is stateless, so it is `Copy`, `Default`, and every instance compares equal
/// to every other instance.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AutoDeclineElicitationHandler;
88
89#[async_trait::async_trait]
90impl ElicitationHandler for AutoDeclineElicitationHandler {
91 async fn handle_form_elicitation(
92 &self,
93 _message: &str,
94 _schema: &ElicitationSchema,
95 _metadata: Option<&Value>,
96 ) -> Result<CreateElicitationResult, Box<dyn std::error::Error + Send + Sync>> {
97 Ok(CreateElicitationResult::new(ElicitationAction::Decline))
98 }
99
100 async fn handle_url_elicitation(
101 &self,
102 _message: &str,
103 _url: &str,
104 _elicitation_id: &str,
105 _metadata: Option<&Value>,
106 ) -> Result<CreateElicitationResult, Box<dyn std::error::Error + Send + Sync>> {
107 Ok(CreateElicitationResult::new(ElicitationAction::Decline))
108 }
109}
110
111pub struct AdkClientHandler {
119 handler: Arc<dyn ElicitationHandler>,
120 #[cfg(feature = "mcp-sampling")]
121 sampling_handler: Option<Arc<dyn crate::sampling::SamplingHandler>>,
122}
123
124impl AdkClientHandler {
125 pub fn new(handler: Arc<dyn ElicitationHandler>) -> Self {
126 Self {
127 handler,
128 #[cfg(feature = "mcp-sampling")]
129 sampling_handler: None,
130 }
131 }
132
133 #[cfg(feature = "mcp-sampling")]
138 pub fn with_sampling_handler(
139 mut self,
140 handler: Arc<dyn crate::sampling::SamplingHandler>,
141 ) -> Self {
142 self.sampling_handler = Some(handler);
143 self
144 }
145}
146
147impl rmcp::handler::client::ClientHandler for AdkClientHandler {
148 fn get_info(&self) -> ClientInfo {
149 let mut info = ClientInfo::default();
150
151 #[cfg(feature = "mcp-sampling")]
152 {
153 if self.sampling_handler.is_some() {
154 info.capabilities = rmcp::model::ClientCapabilities::builder()
155 .enable_elicitation()
156 .enable_sampling()
157 .build();
158 } else {
159 info.capabilities =
160 rmcp::model::ClientCapabilities::builder().enable_elicitation().build();
161 }
162 }
163
164 #[cfg(not(feature = "mcp-sampling"))]
165 {
166 info.capabilities =
167 rmcp::model::ClientCapabilities::builder().enable_elicitation().build();
168 }
169
170 info
171 }
172
173 #[cfg(feature = "mcp-sampling")]
174 async fn create_message(
175 &self,
176 params: rmcp::model::CreateMessageRequestParams,
177 _context: RequestContext<RoleClient>,
178 ) -> Result<rmcp::model::CreateMessageResult, rmcp::ErrorData> {
179 use crate::sampling::{SamplingContent, SamplingMessage, SamplingRequest};
180 use rmcp::model::{CreateMessageResult, Role, SamplingMessageContent};
181
182 let Some(ref sampling_handler) = self.sampling_handler else {
183 return Err(rmcp::ErrorData::new(
184 rmcp::model::ErrorCode::METHOD_NOT_FOUND,
185 "sampling handler not configured",
186 None,
187 ));
188 };
189
190 let messages: Vec<SamplingMessage> = params
192 .messages
193 .iter()
194 .map(|msg| {
195 let role = match msg.role {
196 Role::User => "user",
197 Role::Assistant => "assistant",
198 };
199 let content = msg
201 .content
202 .first()
203 .and_then(|c| match c {
204 SamplingMessageContent::Text(t) => {
205 Some(SamplingContent::text(t.text.clone()))
206 }
207 SamplingMessageContent::Image(img) => {
208 Some(SamplingContent::image(img.data.clone(), img.mime_type.clone()))
209 }
210 _ => None,
211 })
212 .unwrap_or_else(|| SamplingContent::text(""));
213 SamplingMessage::new(role, content)
214 })
215 .collect();
216
217 let request = SamplingRequest {
218 messages,
219 system_prompt: params.system_prompt.clone(),
220 model_preferences: None,
221 max_tokens: Some(params.max_tokens),
222 temperature: params.temperature.map(|t| t as f64),
223 };
224
225 match std::panic::AssertUnwindSafe(sampling_handler.handle_create_message(request))
226 .catch_unwind()
227 .await
228 {
229 Ok(Ok(response)) => {
230 let text = match &response.content {
232 SamplingContent::Text { text } => text.clone(),
233 SamplingContent::Image { .. } => String::new(),
234 };
235 let message = rmcp::model::SamplingMessage::assistant_text(text);
236 Ok(CreateMessageResult::new(message, response.model)
237 .with_stop_reason(response.stop_reason))
238 }
239 Ok(Err(e)) => {
240 tracing::warn!(error = %e, "sampling handler returned error");
241 Err(rmcp::ErrorData::new(
242 rmcp::model::ErrorCode::INTERNAL_ERROR,
243 format!("sampling handler error: {e}"),
244 None,
245 ))
246 }
247 Err(_panic) => {
248 tracing::warn!("sampling handler panicked");
249 Err(rmcp::ErrorData::new(
250 rmcp::model::ErrorCode::INTERNAL_ERROR,
251 "sampling handler panicked",
252 None,
253 ))
254 }
255 }
256 }
257
258 async fn create_elicitation(
259 &self,
260 request: CreateElicitationRequestParams,
261 _context: RequestContext<RoleClient>,
262 ) -> Result<CreateElicitationResult, rmcp::ErrorData> {
263 {
264 let result = match &request {
265 CreateElicitationRequestParams::FormElicitationParams {
266 message,
267 requested_schema,
268 meta,
269 ..
270 } => {
271 let metadata_value = meta.as_ref().and_then(|m| serde_json::to_value(m).ok());
272 std::panic::AssertUnwindSafe(self.handler.handle_form_elicitation(
273 message,
274 requested_schema,
275 metadata_value.as_ref(),
276 ))
277 .catch_unwind()
278 .await
279 }
280 CreateElicitationRequestParams::UrlElicitationParams {
281 message,
282 url,
283 elicitation_id,
284 meta,
285 ..
286 } => {
287 let metadata_value = meta.as_ref().and_then(|m| serde_json::to_value(m).ok());
288 std::panic::AssertUnwindSafe(self.handler.handle_url_elicitation(
289 message,
290 url,
291 elicitation_id,
292 metadata_value.as_ref(),
293 ))
294 .catch_unwind()
295 .await
296 }
297 };
298
299 match result {
300 Ok(Ok(elicitation_result)) => Ok(elicitation_result),
301 Ok(Err(e)) => {
302 tracing::warn!(error = %e, "elicitation handler returned error, declining");
303 Ok(CreateElicitationResult::new(ElicitationAction::Decline))
304 }
305 Err(_panic) => {
306 tracing::warn!("elicitation handler panicked, declining");
307 Ok(CreateElicitationResult::new(ElicitationAction::Decline))
308 }
309 }
310 }
311 }
312
313 async fn on_url_elicitation_notification_complete(
314 &self,
315 _params: ElicitationResponseNotificationParam,
316 _context: NotificationContext<RoleClient>,
317 ) {
318 tracing::debug!("received URL elicitation completion notification");
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check: this only type-checks when `T` is both `Send` and `Sync`.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn test_elicitation_handler_is_send_sync() {
        assert_send_sync::<AutoDeclineElicitationHandler>();
    }

    #[test]
    fn test_adk_client_handler_is_send_sync() {
        assert_send_sync::<AdkClientHandler>();
    }
}