adk_tool/mcp/
elicitation.rs1use std::sync::Arc;
8
9use futures::FutureExt;
10use rmcp::model::{
11 ClientInfo, CreateElicitationRequestParams, CreateElicitationResult, ElicitationAction,
12 ElicitationResponseNotificationParam, ElicitationSchema,
13};
14use rmcp::service::{NotificationContext, RequestContext, RoleClient};
15use serde_json::Value;
16
17#[async_trait::async_trait]
55pub trait ElicitationHandler: Send + Sync {
56 async fn handle_form_elicitation(
62 &self,
63 message: &str,
64 schema: &ElicitationSchema,
65 metadata: Option<&Value>,
66 ) -> Result<CreateElicitationResult, Box<dyn std::error::Error + Send + Sync>>;
67
68 async fn handle_url_elicitation(
74 &self,
75 message: &str,
76 url: &str,
77 elicitation_id: &str,
78 metadata: Option<&Value>,
79 ) -> Result<CreateElicitationResult, Box<dyn std::error::Error + Send + Sync>>;
80}
81
82#[derive(Debug, Clone, Copy)]
87pub struct AutoDeclineElicitationHandler;
88
89#[async_trait::async_trait]
90impl ElicitationHandler for AutoDeclineElicitationHandler {
91 async fn handle_form_elicitation(
92 &self,
93 _message: &str,
94 _schema: &ElicitationSchema,
95 _metadata: Option<&Value>,
96 ) -> Result<CreateElicitationResult, Box<dyn std::error::Error + Send + Sync>> {
97 Ok(CreateElicitationResult::new(ElicitationAction::Decline))
98 }
99
100 async fn handle_url_elicitation(
101 &self,
102 _message: &str,
103 _url: &str,
104 _elicitation_id: &str,
105 _metadata: Option<&Value>,
106 ) -> Result<CreateElicitationResult, Box<dyn std::error::Error + Send + Sync>> {
107 Ok(CreateElicitationResult::new(ElicitationAction::Decline))
108 }
109}
110
111pub struct AdkClientHandler {
116 handler: Arc<dyn ElicitationHandler>,
117}
118
119impl AdkClientHandler {
120 pub fn new(handler: Arc<dyn ElicitationHandler>) -> Self {
121 Self { handler }
122 }
123}
124
125impl rmcp::handler::client::ClientHandler for AdkClientHandler {
126 fn get_info(&self) -> ClientInfo {
127 let mut info = ClientInfo::default();
128 info.capabilities = rmcp::model::ClientCapabilities::builder().enable_elicitation().build();
129 info
130 }
131
132 async fn create_elicitation(
133 &self,
134 request: CreateElicitationRequestParams,
135 _context: RequestContext<RoleClient>,
136 ) -> Result<CreateElicitationResult, rmcp::ErrorData> {
137 {
138 let result = match &request {
139 CreateElicitationRequestParams::FormElicitationParams {
140 message,
141 requested_schema,
142 meta,
143 ..
144 } => {
145 let metadata_value = meta.as_ref().and_then(|m| serde_json::to_value(m).ok());
146 std::panic::AssertUnwindSafe(self.handler.handle_form_elicitation(
147 message,
148 requested_schema,
149 metadata_value.as_ref(),
150 ))
151 .catch_unwind()
152 .await
153 }
154 CreateElicitationRequestParams::UrlElicitationParams {
155 message,
156 url,
157 elicitation_id,
158 meta,
159 ..
160 } => {
161 let metadata_value = meta.as_ref().and_then(|m| serde_json::to_value(m).ok());
162 std::panic::AssertUnwindSafe(self.handler.handle_url_elicitation(
163 message,
164 url,
165 elicitation_id,
166 metadata_value.as_ref(),
167 ))
168 .catch_unwind()
169 .await
170 }
171 };
172
173 match result {
174 Ok(Ok(elicitation_result)) => Ok(elicitation_result),
175 Ok(Err(e)) => {
176 tracing::warn!(error = %e, "elicitation handler returned error, declining");
177 Ok(CreateElicitationResult::new(ElicitationAction::Decline))
178 }
179 Err(_panic) => {
180 tracing::warn!("elicitation handler panicked, declining");
181 Ok(CreateElicitationResult::new(ElicitationAction::Decline))
182 }
183 }
184 }
185 }
186
187 async fn on_url_elicitation_notification_complete(
188 &self,
189 _params: ElicitationResponseNotificationParam,
190 _context: NotificationContext<RoleClient>,
191 ) {
192 tracing::debug!("received URL elicitation completion notification");
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time assertion that `T` is `Send + Sync`.
    fn require_send_sync<T: Send + Sync>() {}

    #[test]
    fn test_elicitation_handler_is_send_sync() {
        require_send_sync::<AutoDeclineElicitationHandler>();
    }

    #[test]
    fn test_adk_client_handler_is_send_sync() {
        require_send_sync::<AdkClientHandler>();
    }

    #[test]
    fn test_elicitation_handler_is_object_safe() {
        // `AdkClientHandler::new` takes a trait object, so the trait must stay usable as `dyn`.
        let handler: Arc<dyn ElicitationHandler> = Arc::new(AutoDeclineElicitationHandler);
        let _client = AdkClientHandler::new(handler);
    }

    #[test]
    fn test_auto_decline_declines_url_elicitation() {
        // The handler's future has no await points, so one poll (`now_or_never`) completes it.
        let result = AutoDeclineElicitationHandler
            .handle_url_elicitation(
                "Please sign in",
                "https://example.com/auth",
                "elicitation-1",
                None,
            )
            .now_or_never()
            .expect("auto-decline future completes on first poll")
            .expect("auto-decline handler never returns an error");
        assert!(matches!(result.action, ElicitationAction::Decline));
    }
}