1use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::time::Duration;
14
15use crate::tools::base::{PermissionCheckResult, Tool};
16use crate::tools::context::{ToolContext, ToolResult};
17use crate::tools::error::ToolError;
18
/// Timeout, in seconds, that [`AskTool::new`] configures for waiting on the callback.
pub const DEFAULT_ASK_TIMEOUT_SECS: u64 = 300;
21
/// Shared callback used to put a question to the user.
///
/// It is invoked with the question text and, when options were supplied, the
/// display string of each option. The returned future resolves to `Some(text)`
/// with the answer, or to `None`, which [`AskTool::ask`] reports as a cancelled
/// interaction.
pub type AskCallback = Arc<
    dyn Fn(String, Option<Vec<String>>) -> Pin<Box<dyn Future<Output = Option<String>> + Send>>
        + Send
        + Sync,
>;
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct AskOption {
35 pub value: String,
37 pub label: Option<String>,
39}
40
41impl AskOption {
42 pub fn new(value: impl Into<String>) -> Self {
44 Self {
45 value: value.into(),
46 label: None,
47 }
48 }
49
50 pub fn with_label(value: impl Into<String>, label: impl Into<String>) -> Self {
52 Self {
53 value: value.into(),
54 label: Some(label.into()),
55 }
56 }
57
58 pub fn display(&self) -> &str {
60 self.label.as_deref().unwrap_or(&self.value)
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct AskResult {
67 pub response: String,
69 pub from_option: bool,
71 pub option_index: Option<usize>,
73}
74
75impl AskResult {
76 pub fn from_input(response: String) -> Self {
78 Self {
79 response,
80 from_option: false,
81 option_index: None,
82 }
83 }
84
85 pub fn from_option(response: String, index: usize) -> Self {
87 Self {
88 response,
89 from_option: true,
90 option_index: Some(index),
91 }
92 }
93}
94
95pub struct AskTool {
105 callback: Option<AskCallback>,
107 timeout: Duration,
109}
110
111impl Default for AskTool {
112 fn default() -> Self {
113 Self::new()
114 }
115}
116
117impl AskTool {
118 pub fn new() -> Self {
123 Self {
124 callback: None,
125 timeout: Duration::from_secs(DEFAULT_ASK_TIMEOUT_SECS),
126 }
127 }
128
129 pub fn with_callback(mut self, callback: AskCallback) -> Self {
131 self.callback = Some(callback);
132 self
133 }
134
135 pub fn with_timeout(mut self, timeout: Duration) -> Self {
137 self.timeout = timeout;
138 self
139 }
140
141 pub fn has_callback(&self) -> bool {
143 self.callback.is_some()
144 }
145
146 pub fn timeout(&self) -> Duration {
148 self.timeout
149 }
150
151 pub async fn ask(
164 &self,
165 question: &str,
166 options: Option<&[AskOption]>,
167 ) -> Result<AskResult, ToolError> {
168 let callback = self.callback.as_ref().ok_or_else(|| {
169 ToolError::execution_failed("No callback configured for user interaction")
170 })?;
171
172 let option_labels: Option<Vec<String>> =
174 options.map(|opts| opts.iter().map(|o| o.display().to_string()).collect());
175
176 let response = tokio::time::timeout(
178 self.timeout,
179 callback(question.to_string(), option_labels.clone()),
180 )
181 .await
182 .map_err(|_| ToolError::timeout(self.timeout))?;
183
184 match response {
186 Some(response_text) => {
187 if let Some(opts) = options {
189 for (idx, opt) in opts.iter().enumerate() {
190 if response_text == opt.value || response_text == opt.display() {
191 return Ok(AskResult::from_option(opt.value.clone(), idx));
192 }
193 }
194 }
195 Ok(AskResult::from_input(response_text))
197 }
198 None => Err(ToolError::execution_failed(
199 "User cancelled the interaction",
200 )),
201 }
202 }
203}
204
205#[async_trait]
206impl Tool for AskTool {
207 fn name(&self) -> &str {
208 "ask"
209 }
210
211 fn description(&self) -> &str {
212 "Ask a question to the user and wait for their response. \
213 Supports free-form text input or selection from predefined options. \
214 Use this tool when you need clarification, confirmation, or user input \
215 to proceed with a task."
216 }
217
218 fn input_schema(&self) -> serde_json::Value {
219 serde_json::json!({
220 "type": "object",
221 "properties": {
222 "question": {
223 "type": "string",
224 "description": "The question to ask the user"
225 },
226 "options": {
227 "type": "array",
228 "description": "Optional predefined options for the user to select from",
229 "items": {
230 "type": "object",
231 "properties": {
232 "value": {
233 "type": "string",
234 "description": "The value to return if this option is selected"
235 },
236 "label": {
237 "type": "string",
238 "description": "Optional display label (defaults to value)"
239 }
240 },
241 "required": ["value"]
242 }
243 }
244 },
245 "required": ["question"]
246 })
247 }
248
249 async fn execute(
250 &self,
251 params: serde_json::Value,
252 _context: &ToolContext,
253 ) -> Result<ToolResult, ToolError> {
254 let question = params
256 .get("question")
257 .and_then(|v| v.as_str())
258 .ok_or_else(|| ToolError::invalid_params("Missing required parameter: question"))?;
259
260 let options: Option<Vec<AskOption>> = params
262 .get("options")
263 .and_then(|v| serde_json::from_value(v.clone()).ok());
264
265 let result = self.ask(question, options.as_deref()).await?;
267
268 let output = if result.from_option {
270 format!(
271 "User selected option {}: {}",
272 result.option_index.unwrap_or(0) + 1,
273 result.response
274 )
275 } else {
276 format!("User response: {}", result.response)
277 };
278
279 Ok(ToolResult::success(output)
280 .with_metadata("response", serde_json::json!(result.response))
281 .with_metadata("from_option", serde_json::json!(result.from_option))
282 .with_metadata("option_index", serde_json::json!(result.option_index)))
283 }
284
285 async fn check_permissions(
286 &self,
287 _params: &serde_json::Value,
288 _context: &ToolContext,
289 ) -> PermissionCheckResult {
290 PermissionCheckResult::allow()
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Callback that immediately resolves to `response`.
    fn mock_callback(response: Option<String>) -> AskCallback {
        Arc::new(move |_question, _options| {
            let resp = response.clone();
            Box::pin(async move { resp })
        })
    }

    /// Callback that resolves to `response` after sleeping for `delay_ms` milliseconds.
    fn mock_callback_delayed(response: Option<String>, delay_ms: u64) -> AskCallback {
        Arc::new(move |_question, _options| {
            let resp = response.clone();
            Box::pin(async move {
                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
                resp
            })
        })
    }

    #[test]
    fn test_ask_option_new() {
        let opt = AskOption::new("yes");
        assert_eq!(opt.value, "yes");
        assert!(opt.label.is_none());
        assert_eq!(opt.display(), "yes");
    }

    #[test]
    fn test_ask_option_with_label() {
        let opt = AskOption::with_label("y", "Yes, proceed");
        assert_eq!(opt.value, "y");
        assert_eq!(opt.label, Some("Yes, proceed".to_string()));
        assert_eq!(opt.display(), "Yes, proceed");
    }

    #[test]
    fn test_ask_result_from_input() {
        let result = AskResult::from_input("hello".to_string());
        assert_eq!(result.response, "hello");
        assert!(!result.from_option);
        assert!(result.option_index.is_none());
    }

    #[test]
    fn test_ask_result_from_option() {
        let result = AskResult::from_option("yes".to_string(), 0);
        assert_eq!(result.response, "yes");
        assert!(result.from_option);
        assert_eq!(result.option_index, Some(0));
    }

    #[test]
    fn test_ask_tool_new() {
        let tool = AskTool::new();
        assert!(!tool.has_callback());
        assert_eq!(
            tool.timeout(),
            Duration::from_secs(DEFAULT_ASK_TIMEOUT_SECS)
        );
    }

    #[test]
    fn test_ask_tool_with_callback() {
        let callback = mock_callback(Some("test".to_string()));
        let tool = AskTool::new().with_callback(callback);
        assert!(tool.has_callback());
    }

    #[test]
    fn test_ask_tool_with_timeout() {
        let tool = AskTool::new().with_timeout(Duration::from_secs(60));
        assert_eq!(tool.timeout(), Duration::from_secs(60));
    }

    #[test]
    fn test_ask_tool_default() {
        let tool = AskTool::default();
        assert!(!tool.has_callback());
        assert_eq!(
            tool.timeout(),
            Duration::from_secs(DEFAULT_ASK_TIMEOUT_SECS)
        );
    }

    #[tokio::test]
    async fn test_ask_without_callback() {
        let tool = AskTool::new();
        let result = tool.ask("What is your name?", None).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ToolError::ExecutionFailed(_)));
    }

    #[tokio::test]
    async fn test_ask_free_form_response() {
        let callback = mock_callback(Some("John".to_string()));
        let tool = AskTool::new().with_callback(callback);

        let result = tool.ask("What is your name?", None).await.unwrap();
        assert_eq!(result.response, "John");
        assert!(!result.from_option);
        assert!(result.option_index.is_none());
    }

    #[tokio::test]
    async fn test_ask_with_options_select_by_value() {
        let callback = mock_callback(Some("yes".to_string()));
        let tool = AskTool::new().with_callback(callback);

        let options = vec![AskOption::new("yes"), AskOption::new("no")];

        let result = tool.ask("Continue?", Some(&options)).await.unwrap();
        assert_eq!(result.response, "yes");
        assert!(result.from_option);
        assert_eq!(result.option_index, Some(0));
    }

    #[tokio::test]
    async fn test_ask_with_options_select_by_label() {
        let callback = mock_callback(Some("Yes, proceed".to_string()));
        let tool = AskTool::new().with_callback(callback);

        let options = vec![
            AskOption::with_label("y", "Yes, proceed"),
            AskOption::with_label("n", "No, cancel"),
        ];

        // Answering with the label must still report the option's value.
        let result = tool.ask("Continue?", Some(&options)).await.unwrap();
        assert_eq!(result.response, "y");
        assert!(result.from_option);
        assert_eq!(result.option_index, Some(0));
    }

    #[tokio::test]
    async fn test_ask_with_options_free_form() {
        let callback = mock_callback(Some("maybe".to_string()));
        let tool = AskTool::new().with_callback(callback);

        let options = vec![AskOption::new("yes"), AskOption::new("no")];

        let result = tool.ask("Continue?", Some(&options)).await.unwrap();
        assert_eq!(result.response, "maybe");
        assert!(!result.from_option);
        assert!(result.option_index.is_none());
    }

    #[tokio::test]
    async fn test_ask_user_cancels() {
        let callback = mock_callback(None);
        let tool = AskTool::new().with_callback(callback);

        let result = tool.ask("What is your name?", None).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ToolError::ExecutionFailed(_)));
    }

    #[tokio::test]
    async fn test_ask_timeout() {
        // The callback takes 200 ms but the tool only waits 50 ms.
        let callback = mock_callback_delayed(Some("response".to_string()), 200);
        let tool = AskTool::new()
            .with_callback(callback)
            .with_timeout(Duration::from_millis(50));

        let result = tool.ask("What is your name?", None).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ToolError::Timeout(_)));
    }

    #[tokio::test]
    async fn test_ask_tool_trait_name() {
        let tool = AskTool::new();
        assert_eq!(tool.name(), "ask");
    }

    #[tokio::test]
    async fn test_ask_tool_trait_description() {
        let tool = AskTool::new();
        assert!(tool.description().contains("Ask a question"));
    }

    #[tokio::test]
    async fn test_ask_tool_trait_input_schema() {
        let tool = AskTool::new();
        let schema = tool.input_schema();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["question"].is_object());
        assert!(schema["properties"]["options"].is_object());
        assert!(schema["required"]
            .as_array()
            .unwrap()
            .contains(&serde_json::json!("question")));
    }

    #[tokio::test]
    async fn test_ask_tool_execute_success() {
        let callback = mock_callback(Some("John".to_string()));
        let tool = AskTool::new().with_callback(callback);
        let context = ToolContext::new(PathBuf::from("/tmp"));

        let params = serde_json::json!({
            "question": "What is your name?"
        });

        let result = tool.execute(params, &context).await.unwrap();
        assert!(result.is_success());
        assert!(result.output.unwrap().contains("John"));
        assert_eq!(
            result.metadata.get("response"),
            Some(&serde_json::json!("John"))
        );
        assert_eq!(
            result.metadata.get("from_option"),
            Some(&serde_json::json!(false))
        );
    }

    #[tokio::test]
    async fn test_ask_tool_execute_with_options() {
        let callback = mock_callback(Some("yes".to_string()));
        let tool = AskTool::new().with_callback(callback);
        let context = ToolContext::new(PathBuf::from("/tmp"));

        let params = serde_json::json!({
            "question": "Continue?",
            "options": [
                { "value": "yes", "label": "Yes" },
                { "value": "no", "label": "No" }
            ]
        });

        let result = tool.execute(params, &context).await.unwrap();
        assert!(result.is_success());
        assert!(result.output.unwrap().contains("selected option"));
        assert_eq!(
            result.metadata.get("from_option"),
            Some(&serde_json::json!(true))
        );
        assert_eq!(
            result.metadata.get("option_index"),
            Some(&serde_json::json!(0))
        );
    }

    #[tokio::test]
    async fn test_ask_tool_execute_missing_question() {
        let callback = mock_callback(Some("test".to_string()));
        let tool = AskTool::new().with_callback(callback);
        let context = ToolContext::new(PathBuf::from("/tmp"));

        let params = serde_json::json!({});

        let result = tool.execute(params, &context).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ToolError::InvalidParams(_)));
    }

    #[tokio::test]
    async fn test_ask_tool_check_permissions() {
        let tool = AskTool::new();
        let context = ToolContext::new(PathBuf::from("/tmp"));
        let params = serde_json::json!({"question": "test"});

        // `&params` had been garbled into `¶ms`, which is not valid Rust.
        let result = tool.check_permissions(&params, &context).await;
        assert!(result.is_allowed());
    }

    #[test]
    fn test_ask_option_serialization() {
        let opt = AskOption::with_label("y", "Yes");
        let json = serde_json::to_string(&opt).unwrap();
        let deserialized: AskOption = serde_json::from_str(&json).unwrap();

        assert_eq!(opt.value, deserialized.value);
        assert_eq!(opt.label, deserialized.label);
    }

    #[test]
    fn test_ask_result_serialization() {
        let result = AskResult::from_option("yes".to_string(), 0);
        let json = serde_json::to_string(&result).unwrap();
        let deserialized: AskResult = serde_json::from_str(&json).unwrap();

        assert_eq!(result.response, deserialized.response);
        assert_eq!(result.from_option, deserialized.from_option);
        assert_eq!(result.option_index, deserialized.option_index);
    }
}