Skip to main content

mcp_kit/server/
elicitation.rs

1//! Elicitation client for requesting user input from clients.
2//!
3//! Elicitation allows MCP servers to request input from users through the client.
4//! This is useful for interactive workflows where the server needs additional
5//! information that wasn't provided in the initial request.
6//!
7//! # Example
8//!
9//! ```rust,ignore
10//! use mcp_kit::server::ElicitationClientExt;
11//!
12//! // In a tool handler with access to an elicitation client
13//! let client: &impl ElicitationClientExt = get_client();
14//!
15//! // Simple confirmation
16//! if client.confirm("Delete this file?").await? {
17//!     // User confirmed
18//! }
19//!
20//! // Text input
21//! if let Some(name) = client.prompt_text("Enter project name").await? {
22//!     println!("Project: {}", name);
23//! }
24//!
25//! // Choice selection
26//! let options = vec!["small".into(), "medium".into(), "large".into()];
27//! if let Some(size) = client.choose("Select size", options).await? {
28//!     println!("Size: {}", size);
29//! }
30//! ```
31
32use crate::types::elicitation::{ElicitAction, ElicitRequest, ElicitResult, ElicitSchema};
33use std::future::Future;
34use std::pin::Pin;
35use std::sync::Arc;
36use tokio::sync::{mpsc, oneshot};
37
38// ─── Elicitation Client Trait ────────────────────────────────────────────────
39
40/// Trait for clients that can request user input through elicitation.
41pub trait ElicitationClient: Send + Sync {
42    /// Send an elicitation request to the client.
43    fn elicit(
44        &self,
45        request: ElicitRequest,
46    ) -> Pin<Box<dyn Future<Output = Result<ElicitResult, ElicitationError>> + Send + '_>>;
47}
48
49/// Extension trait for convenient elicitation methods.
50pub trait ElicitationClientExt: ElicitationClient {
51    /// Request a simple yes/no confirmation from the user.
52    fn confirm(
53        &self,
54        message: &str,
55    ) -> Pin<Box<dyn Future<Output = Result<bool, ElicitationError>> + Send + '_>> {
56        let request = ElicitRequest::confirm(message);
57        Box::pin(async move {
58            let result = self.elicit(request).await?;
59            Ok(matches!(result.action, ElicitAction::Accepted))
60        })
61    }
62
63    /// Request text input from the user.
64    fn prompt_text(
65        &self,
66        message: &str,
67    ) -> Pin<Box<dyn Future<Output = Result<Option<String>, ElicitationError>> + Send + '_>> {
68        let request = ElicitRequest::text(message);
69        Box::pin(async move {
70            let result = self.elicit(request).await?;
71            match result.action {
72                ElicitAction::Accepted => Ok(result.as_string()),
73                _ => Ok(None),
74            }
75        })
76    }
77
78    /// Request the user to choose from a list of options.
79    fn choose(
80        &self,
81        message: &str,
82        options: Vec<String>,
83    ) -> Pin<Box<dyn Future<Output = Result<Option<String>, ElicitationError>> + Send + '_>> {
84        let request = ElicitRequest::choice(message, options);
85        Box::pin(async move {
86            let result = self.elicit(request).await?;
87            match result.action {
88                ElicitAction::Accepted => Ok(result.as_string()),
89                _ => Ok(None),
90            }
91        })
92    }
93
94    /// Request a number from the user.
95    fn prompt_number(
96        &self,
97        message: &str,
98    ) -> Pin<Box<dyn Future<Output = Result<Option<f64>, ElicitationError>> + Send + '_>> {
99        let request = ElicitRequest::with_schema(message, ElicitSchema::number());
100        Box::pin(async move {
101            let result = self.elicit(request).await?;
102            match result.action {
103                ElicitAction::Accepted => Ok(result.content_as::<f64>()),
104                _ => Ok(None),
105            }
106        })
107    }
108}
109
110// Blanket implementation for all ElicitationClient
111impl<T: ElicitationClient + ?Sized> ElicitationClientExt for T {}
112
113// ─── Elicitation Error ───────────────────────────────────────────────────────
114
115/// Errors that can occur during elicitation.
116#[derive(Debug, thiserror::Error)]
117pub enum ElicitationError {
118    /// The client doesn't support elicitation.
119    #[error("Elicitation not supported by client")]
120    NotSupported,
121
122    /// The elicitation request was cancelled.
123    #[error("Elicitation cancelled")]
124    Cancelled,
125
126    /// The client connection was lost.
127    #[error("Client connection lost")]
128    ConnectionLost,
129
130    /// A timeout occurred waiting for response.
131    #[error("Elicitation timeout")]
132    Timeout,
133
134    /// An error occurred during elicitation.
135    #[error("Elicitation error: {0}")]
136    Other(String),
137}
138
139// ─── Channel-based Elicitation Client ────────────────────────────────────────
140
141/// An elicitation request with a response channel.
142pub struct ElicitationRequestMessage {
143    /// The elicitation request.
144    pub request: ElicitRequest,
145    /// Channel to send the response.
146    pub response_tx: oneshot::Sender<Result<ElicitResult, ElicitationError>>,
147}
148
149/// Channel-based elicitation client.
150///
151/// This client sends elicitation requests through an mpsc channel,
152/// which can be processed by the transport layer.
153#[derive(Clone)]
154pub struct ChannelElicitationClient {
155    tx: mpsc::Sender<ElicitationRequestMessage>,
156}
157
158impl ChannelElicitationClient {
159    /// Create a new channel-based elicitation client.
160    pub fn new(tx: mpsc::Sender<ElicitationRequestMessage>) -> Self {
161        Self { tx }
162    }
163
164    /// Create a new channel pair for elicitation.
165    pub fn channel(buffer: usize) -> (Self, mpsc::Receiver<ElicitationRequestMessage>) {
166        let (tx, rx) = mpsc::channel(buffer);
167        (Self::new(tx), rx)
168    }
169}
170
171impl ElicitationClient for ChannelElicitationClient {
172    fn elicit(
173        &self,
174        request: ElicitRequest,
175    ) -> Pin<Box<dyn Future<Output = Result<ElicitResult, ElicitationError>> + Send + '_>> {
176        Box::pin(async move {
177            let (response_tx, response_rx) = oneshot::channel();
178
179            self.tx
180                .send(ElicitationRequestMessage {
181                    request,
182                    response_tx,
183                })
184                .await
185                .map_err(|_| ElicitationError::ConnectionLost)?;
186
187            response_rx
188                .await
189                .map_err(|_| ElicitationError::ConnectionLost)?
190        })
191    }
192}
193
194// ─── Arc wrapper implementation ──────────────────────────────────────────────
195
196impl<T: ElicitationClient + ?Sized> ElicitationClient for Arc<T> {
197    fn elicit(
198        &self,
199        request: ElicitRequest,
200    ) -> Pin<Box<dyn Future<Output = Result<ElicitResult, ElicitationError>> + Send + '_>> {
201        (**self).elicit(request)
202    }
203}
204
205// ─── Elicitation Request Builder ─────────────────────────────────────────────
206
207/// Builder for creating complex elicitation requests with multiple fields.
208#[derive(Debug, Default)]
209pub struct ElicitationRequestBuilder {
210    message: String,
211    properties: serde_json::Map<String, serde_json::Value>,
212    required: Vec<String>,
213}
214
215impl ElicitationRequestBuilder {
216    /// Create a new builder with the given message.
217    pub fn new(message: impl Into<String>) -> Self {
218        Self {
219            message: message.into(),
220            properties: serde_json::Map::new(),
221            required: Vec::new(),
222        }
223    }
224
225    /// Add a boolean field to the schema.
226    pub fn boolean(mut self, name: impl Into<String>, title: impl Into<String>) -> Self {
227        let name = name.into();
228        self.properties.insert(
229            name.clone(),
230            serde_json::json!({
231                "type": "boolean",
232                "title": title.into()
233            }),
234        );
235        self
236    }
237
238    /// Add a required boolean field.
239    pub fn boolean_required(mut self, name: impl Into<String>, title: impl Into<String>) -> Self {
240        let name = name.into();
241        self.required.push(name.clone());
242        self.boolean(name, title)
243    }
244
245    /// Add a text field to the schema.
246    pub fn text(mut self, name: impl Into<String>, title: impl Into<String>) -> Self {
247        let name = name.into();
248        self.properties.insert(
249            name.clone(),
250            serde_json::json!({
251                "type": "string",
252                "title": title.into()
253            }),
254        );
255        self
256    }
257
258    /// Add a required text field.
259    pub fn text_required(mut self, name: impl Into<String>, title: impl Into<String>) -> Self {
260        let name = name.into();
261        self.required.push(name.clone());
262        self.text(name, title)
263    }
264
265    /// Add a number field to the schema.
266    pub fn number(mut self, name: impl Into<String>, title: impl Into<String>) -> Self {
267        let name = name.into();
268        self.properties.insert(
269            name.clone(),
270            serde_json::json!({
271                "type": "number",
272                "title": title.into()
273            }),
274        );
275        self
276    }
277
278    /// Add a required number field.
279    pub fn number_required(mut self, name: impl Into<String>, title: impl Into<String>) -> Self {
280        let name = name.into();
281        self.required.push(name.clone());
282        self.number(name, title)
283    }
284
285    /// Add an enum field to the schema.
286    pub fn select(
287        mut self,
288        name: impl Into<String>,
289        title: impl Into<String>,
290        options: &[&str],
291    ) -> Self {
292        let name = name.into();
293        self.properties.insert(
294            name.clone(),
295            serde_json::json!({
296                "type": "string",
297                "title": title.into(),
298                "enum": options
299            }),
300        );
301        self
302    }
303
304    /// Add a required enum field.
305    pub fn select_required(
306        mut self,
307        name: impl Into<String>,
308        title: impl Into<String>,
309        options: &[&str],
310    ) -> Self {
311        let name = name.into();
312        self.required.push(name.clone());
313        self.select(name, title, options)
314    }
315
316    /// Build the elicitation request.
317    pub fn build(self) -> ElicitRequest {
318        let schema = serde_json::json!({
319            "type": "object",
320            "properties": self.properties,
321            "required": self.required
322        });
323
324        ElicitRequest::with_schema(self.message, ElicitSchema::object(schema))
325    }
326}
327
328// ─── Tests ───────────────────────────────────────────────────────────────────
329
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_channel_elicitation_client() {
        let (client, mut rx) = ChannelElicitationClient::channel(10);

        // Spawn a handler that accepts elicitation requests
        tokio::spawn(async move {
            while let Some(msg) = rx.recv().await {
                let _ = msg
                    .response_tx
                    .send(Ok(ElicitResult::accepted(serde_json::json!(
                        "test response"
                    ))));
            }
        });

        // Test prompt_text
        let result = client.prompt_text("Enter something").await.unwrap();
        assert_eq!(result, Some("test response".to_string()));
    }

    #[tokio::test]
    async fn test_confirm() {
        let (client, mut rx) = ChannelElicitationClient::channel(10);

        tokio::spawn(async move {
            while let Some(msg) = rx.recv().await {
                let _ = msg
                    .response_tx
                    .send(Ok(ElicitResult::accepted(serde_json::json!(true))));
            }
        });

        let result = client.confirm("Are you sure?").await.unwrap();
        assert!(result);
    }

    #[tokio::test]
    async fn test_elicit_fails_when_receiver_dropped() {
        let (client, rx) = ChannelElicitationClient::channel(1);
        // With no receiver, the request can never be delivered.
        drop(rx);

        let result = client.prompt_text("Enter something").await;
        assert!(matches!(result, Err(ElicitationError::ConnectionLost)));
    }

    #[tokio::test]
    async fn test_elicit_fails_when_reply_dropped() {
        let (client, mut rx) = ChannelElicitationClient::channel(1);

        tokio::spawn(async move {
            // Take each request but drop its reply sender without answering.
            while let Some(msg) = rx.recv().await {
                drop(msg);
            }
        });

        let result = client.confirm("Are you sure?").await;
        assert!(matches!(result, Err(ElicitationError::ConnectionLost)));
    }

    #[test]
    fn test_elicitation_request_builder() {
        let request = ElicitationRequestBuilder::new("Configure your project")
            .text_required("name", "Project Name")
            .boolean("private", "Private Repository")
            .number("port", "Port Number")
            .select("language", "Language", &["rust", "python", "javascript"])
            .build();

        assert_eq!(request.message, "Configure your project");
        assert!(request.requested_schema.is_some());

        let schema = request.requested_schema.unwrap();
        let props = schema.schema.get("properties").unwrap();
        assert!(props.get("name").is_some());
        assert!(props.get("private").is_some());
        assert!(props.get("port").is_some());
        assert!(props.get("language").is_some());

        // Only the field added through a `*_required` method is listed.
        assert_eq!(
            schema.schema.get("required"),
            Some(&serde_json::json!(["name"]))
        );
    }
}