1use crate::types::elicitation::{ElicitAction, ElicitRequest, ElicitResult, ElicitSchema};
33use std::future::Future;
34use std::pin::Pin;
35use std::sync::Arc;
36use tokio::sync::{mpsc, oneshot};
37
38pub trait ElicitationClient: Send + Sync {
42 fn elicit(
44 &self,
45 request: ElicitRequest,
46 ) -> Pin<Box<dyn Future<Output = Result<ElicitResult, ElicitationError>> + Send + '_>>;
47}
48
49pub trait ElicitationClientExt: ElicitationClient {
51 fn confirm(
53 &self,
54 message: &str,
55 ) -> Pin<Box<dyn Future<Output = Result<bool, ElicitationError>> + Send + '_>> {
56 let request = ElicitRequest::confirm(message);
57 Box::pin(async move {
58 let result = self.elicit(request).await?;
59 Ok(matches!(result.action, ElicitAction::Accepted))
60 })
61 }
62
63 fn prompt_text(
65 &self,
66 message: &str,
67 ) -> Pin<Box<dyn Future<Output = Result<Option<String>, ElicitationError>> + Send + '_>> {
68 let request = ElicitRequest::text(message);
69 Box::pin(async move {
70 let result = self.elicit(request).await?;
71 match result.action {
72 ElicitAction::Accepted => Ok(result.as_string()),
73 _ => Ok(None),
74 }
75 })
76 }
77
78 fn choose(
80 &self,
81 message: &str,
82 options: Vec<String>,
83 ) -> Pin<Box<dyn Future<Output = Result<Option<String>, ElicitationError>> + Send + '_>> {
84 let request = ElicitRequest::choice(message, options);
85 Box::pin(async move {
86 let result = self.elicit(request).await?;
87 match result.action {
88 ElicitAction::Accepted => Ok(result.as_string()),
89 _ => Ok(None),
90 }
91 })
92 }
93
94 fn prompt_number(
96 &self,
97 message: &str,
98 ) -> Pin<Box<dyn Future<Output = Result<Option<f64>, ElicitationError>> + Send + '_>> {
99 let request = ElicitRequest::with_schema(message, ElicitSchema::number());
100 Box::pin(async move {
101 let result = self.elicit(request).await?;
102 match result.action {
103 ElicitAction::Accepted => Ok(result.content_as::<f64>()),
104 _ => Ok(None),
105 }
106 })
107 }
108}
109
110impl<T: ElicitationClient + ?Sized> ElicitationClientExt for T {}
112
113#[derive(Debug, thiserror::Error)]
117pub enum ElicitationError {
118 #[error("Elicitation not supported by client")]
120 NotSupported,
121
122 #[error("Elicitation cancelled")]
124 Cancelled,
125
126 #[error("Client connection lost")]
128 ConnectionLost,
129
130 #[error("Elicitation timeout")]
132 Timeout,
133
134 #[error("Elicitation error: {0}")]
136 Other(String),
137}
138
139pub struct ElicitationRequestMessage {
143 pub request: ElicitRequest,
145 pub response_tx: oneshot::Sender<Result<ElicitResult, ElicitationError>>,
147}
148
149#[derive(Clone)]
154pub struct ChannelElicitationClient {
155 tx: mpsc::Sender<ElicitationRequestMessage>,
156}
157
158impl ChannelElicitationClient {
159 pub fn new(tx: mpsc::Sender<ElicitationRequestMessage>) -> Self {
161 Self { tx }
162 }
163
164 pub fn channel(buffer: usize) -> (Self, mpsc::Receiver<ElicitationRequestMessage>) {
166 let (tx, rx) = mpsc::channel(buffer);
167 (Self::new(tx), rx)
168 }
169}
170
171impl ElicitationClient for ChannelElicitationClient {
172 fn elicit(
173 &self,
174 request: ElicitRequest,
175 ) -> Pin<Box<dyn Future<Output = Result<ElicitResult, ElicitationError>> + Send + '_>> {
176 Box::pin(async move {
177 let (response_tx, response_rx) = oneshot::channel();
178
179 self.tx
180 .send(ElicitationRequestMessage {
181 request,
182 response_tx,
183 })
184 .await
185 .map_err(|_| ElicitationError::ConnectionLost)?;
186
187 response_rx
188 .await
189 .map_err(|_| ElicitationError::ConnectionLost)?
190 })
191 }
192}
193
194impl<T: ElicitationClient + ?Sized> ElicitationClient for Arc<T> {
197 fn elicit(
198 &self,
199 request: ElicitRequest,
200 ) -> Pin<Box<dyn Future<Output = Result<ElicitResult, ElicitationError>> + Send + '_>> {
201 (**self).elicit(request)
202 }
203}
204
205#[derive(Debug, Default)]
209pub struct ElicitationRequestBuilder {
210 message: String,
211 properties: serde_json::Map<String, serde_json::Value>,
212 required: Vec<String>,
213}
214
215impl ElicitationRequestBuilder {
216 pub fn new(message: impl Into<String>) -> Self {
218 Self {
219 message: message.into(),
220 properties: serde_json::Map::new(),
221 required: Vec::new(),
222 }
223 }
224
225 pub fn boolean(mut self, name: impl Into<String>, title: impl Into<String>) -> Self {
227 let name = name.into();
228 self.properties.insert(
229 name.clone(),
230 serde_json::json!({
231 "type": "boolean",
232 "title": title.into()
233 }),
234 );
235 self
236 }
237
238 pub fn boolean_required(mut self, name: impl Into<String>, title: impl Into<String>) -> Self {
240 let name = name.into();
241 self.required.push(name.clone());
242 self.boolean(name, title)
243 }
244
245 pub fn text(mut self, name: impl Into<String>, title: impl Into<String>) -> Self {
247 let name = name.into();
248 self.properties.insert(
249 name.clone(),
250 serde_json::json!({
251 "type": "string",
252 "title": title.into()
253 }),
254 );
255 self
256 }
257
258 pub fn text_required(mut self, name: impl Into<String>, title: impl Into<String>) -> Self {
260 let name = name.into();
261 self.required.push(name.clone());
262 self.text(name, title)
263 }
264
265 pub fn number(mut self, name: impl Into<String>, title: impl Into<String>) -> Self {
267 let name = name.into();
268 self.properties.insert(
269 name.clone(),
270 serde_json::json!({
271 "type": "number",
272 "title": title.into()
273 }),
274 );
275 self
276 }
277
278 pub fn number_required(mut self, name: impl Into<String>, title: impl Into<String>) -> Self {
280 let name = name.into();
281 self.required.push(name.clone());
282 self.number(name, title)
283 }
284
285 pub fn select(
287 mut self,
288 name: impl Into<String>,
289 title: impl Into<String>,
290 options: &[&str],
291 ) -> Self {
292 let name = name.into();
293 self.properties.insert(
294 name.clone(),
295 serde_json::json!({
296 "type": "string",
297 "title": title.into(),
298 "enum": options
299 }),
300 );
301 self
302 }
303
304 pub fn select_required(
306 mut self,
307 name: impl Into<String>,
308 title: impl Into<String>,
309 options: &[&str],
310 ) -> Self {
311 let name = name.into();
312 self.required.push(name.clone());
313 self.select(name, title, options)
314 }
315
316 pub fn build(self) -> ElicitRequest {
318 let schema = serde_json::json!({
319 "type": "object",
320 "properties": self.properties,
321 "required": self.required
322 });
323
324 ElicitRequest::with_schema(self.message, ElicitSchema::object(schema))
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawns a task that answers every queued request with an accepted
    /// result carrying a clone of `content`.
    fn spawn_responder(
        mut rx: tokio::sync::mpsc::Receiver<ElicitationRequestMessage>,
        content: serde_json::Value,
    ) {
        tokio::spawn(async move {
            while let Some(message) = rx.recv().await {
                // The caller may already have given up; a failed send is irrelevant here.
                let _ = message
                    .response_tx
                    .send(Ok(ElicitResult::accepted(content.clone())));
            }
        });
    }

    #[tokio::test]
    async fn test_channel_elicitation_client() {
        let (client, rx) = ChannelElicitationClient::channel(10);
        spawn_responder(rx, serde_json::json!("test response"));

        let answer = client.prompt_text("Enter something").await.unwrap();
        assert_eq!(answer, Some(String::from("test response")));
    }

    #[tokio::test]
    async fn test_confirm() {
        let (client, rx) = ChannelElicitationClient::channel(10);
        spawn_responder(rx, serde_json::json!(true));

        let confirmed = client.confirm("Are you sure?").await.unwrap();
        assert!(confirmed);
    }

    #[test]
    fn test_elicitation_request_builder() {
        let request = ElicitationRequestBuilder::new("Configure your project")
            .text_required("name", "Project Name")
            .boolean("private", "Private Repository")
            .number("port", "Port Number")
            .select("language", "Language", &["rust", "python", "javascript"])
            .build();

        assert_eq!(request.message, "Configure your project");

        let schema = request
            .requested_schema
            .expect("builder should always attach a schema");
        let properties = schema
            .schema
            .get("properties")
            .expect("schema should contain properties");

        for &field in &["name", "private", "port", "language"] {
            assert!(properties.get(field).is_some(), "missing property {}", field);
        }
    }
}