Skip to main content

aster/tools/
ask.rs

//! Ask Tool Implementation
//!
//! Provides user interaction capabilities for the agent to ask questions
//! and receive responses from the user.
//!
//! Requirements: 6.1, 6.2, 6.3, 6.4, 6.5
7
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::time::Duration;
14
15use crate::tools::base::{PermissionCheckResult, Tool};
16use crate::tools::context::{ToolContext, ToolResult};
17use crate::tools::error::ToolError;
18
/// How many seconds [`AskTool`] waits for the user's answer by default (five minutes).
pub const DEFAULT_ASK_TIMEOUT_SECS: u64 = 5 * 60;
21
/// Handler that [`AskTool`] invokes to put a question in front of the user.
///
/// It is called with the question text and, when the agent supplied choices,
/// the display text of each option in order. The boxed future it returns
/// resolves to `Some(answer)` when the user replies, or `None` when the
/// interaction was cancelled. The handler is shared through an `Arc` and must
/// be `Send + Sync`; the future it produces must be `Send`.
pub type AskCallback = Arc<
    dyn Fn(
            String,
            Option<Vec<String>>,
        ) -> Pin<Box<dyn Future<Output = Option<String>> + Send>>
        + Send
        + Sync,
>;
31
32/// A predefined option for the user to select
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct AskOption {
35    /// The value to return if this option is selected
36    pub value: String,
37    /// Optional display label (defaults to value if not provided)
38    pub label: Option<String>,
39}
40
41impl AskOption {
42    /// Create a new option with just a value
43    pub fn new(value: impl Into<String>) -> Self {
44        Self {
45            value: value.into(),
46            label: None,
47        }
48    }
49
50    /// Create a new option with a value and label
51    pub fn with_label(value: impl Into<String>, label: impl Into<String>) -> Self {
52        Self {
53            value: value.into(),
54            label: Some(label.into()),
55        }
56    }
57
58    /// Get the display text for this option
59    pub fn display(&self) -> &str {
60        self.label.as_deref().unwrap_or(&self.value)
61    }
62}
63
64/// Result of an ask operation
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct AskResult {
67    /// The user's response
68    pub response: String,
69    /// Whether the response was from a predefined option
70    pub from_option: bool,
71    /// The index of the selected option (if applicable)
72    pub option_index: Option<usize>,
73}
74
75impl AskResult {
76    /// Create a new AskResult from free-form input
77    pub fn from_input(response: String) -> Self {
78        Self {
79            response,
80            from_option: false,
81            option_index: None,
82        }
83    }
84
85    /// Create a new AskResult from an option selection
86    pub fn from_option(response: String, index: usize) -> Self {
87        Self {
88            response,
89            from_option: true,
90            option_index: Some(index),
91        }
92    }
93}
94
95/// Ask tool for user interaction
96///
97/// Allows the agent to ask questions to the user and receive responses.
98/// Supports:
99/// - Free-form text questions
100/// - Predefined options for selection
101/// - Configurable timeout
102///
103/// Requirements: 6.1, 6.2, 6.3, 6.4, 6.5
104pub struct AskTool {
105    /// Callback for handling user questions
106    callback: Option<AskCallback>,
107    /// Default timeout for user response
108    timeout: Duration,
109}
110
111impl Default for AskTool {
112    fn default() -> Self {
113        Self::new()
114    }
115}
116
117impl AskTool {
118    /// Create a new AskTool without a callback
119    ///
120    /// Note: Without a callback, the tool will return an error when executed.
121    /// Use `with_callback` to set up the user interaction handler.
122    pub fn new() -> Self {
123        Self {
124            callback: None,
125            timeout: Duration::from_secs(DEFAULT_ASK_TIMEOUT_SECS),
126        }
127    }
128
129    /// Set the callback for handling user questions
130    pub fn with_callback(mut self, callback: AskCallback) -> Self {
131        self.callback = Some(callback);
132        self
133    }
134
135    /// Set the default timeout for user responses
136    pub fn with_timeout(mut self, timeout: Duration) -> Self {
137        self.timeout = timeout;
138        self
139    }
140
141    /// Check if a callback is configured
142    pub fn has_callback(&self) -> bool {
143        self.callback.is_some()
144    }
145
146    /// Get the configured timeout
147    pub fn timeout(&self) -> Duration {
148        self.timeout
149    }
150
151    /// Ask a question to the user
152    ///
153    /// This method invokes the callback with the question and optional options,
154    /// and waits for the user's response with timeout.
155    ///
156    /// # Arguments
157    /// * `question` - The question to ask
158    /// * `options` - Optional predefined options for the user to select from
159    ///
160    /// # Returns
161    /// * `Ok(AskResult)` - The user's response
162    /// * `Err(ToolError)` - If no callback is configured, timeout occurs, or user cancels
163    pub async fn ask(
164        &self,
165        question: &str,
166        options: Option<&[AskOption]>,
167    ) -> Result<AskResult, ToolError> {
168        let callback = self.callback.as_ref().ok_or_else(|| {
169            ToolError::execution_failed("No callback configured for user interaction")
170        })?;
171
172        // Convert options to string labels for the callback
173        let option_labels: Option<Vec<String>> =
174            options.map(|opts| opts.iter().map(|o| o.display().to_string()).collect());
175
176        // Call the callback with timeout
177        let response = tokio::time::timeout(
178            self.timeout,
179            callback(question.to_string(), option_labels.clone()),
180        )
181        .await
182        .map_err(|_| ToolError::timeout(self.timeout))?;
183
184        // Handle the response
185        match response {
186            Some(response_text) => {
187                // Check if response matches an option
188                if let Some(opts) = options {
189                    for (idx, opt) in opts.iter().enumerate() {
190                        if response_text == opt.value || response_text == opt.display() {
191                            return Ok(AskResult::from_option(opt.value.clone(), idx));
192                        }
193                    }
194                }
195                // Free-form response
196                Ok(AskResult::from_input(response_text))
197            }
198            None => Err(ToolError::execution_failed(
199                "User cancelled the interaction",
200            )),
201        }
202    }
203}
204
205#[async_trait]
206impl Tool for AskTool {
207    fn name(&self) -> &str {
208        "ask"
209    }
210
211    fn description(&self) -> &str {
212        "Ask a question to the user and wait for their response. \
213         Supports free-form text input or selection from predefined options. \
214         Use this tool when you need clarification, confirmation, or user input \
215         to proceed with a task."
216    }
217
218    fn input_schema(&self) -> serde_json::Value {
219        serde_json::json!({
220            "type": "object",
221            "properties": {
222                "question": {
223                    "type": "string",
224                    "description": "The question to ask the user"
225                },
226                "options": {
227                    "type": "array",
228                    "description": "Optional predefined options for the user to select from",
229                    "items": {
230                        "type": "object",
231                        "properties": {
232                            "value": {
233                                "type": "string",
234                                "description": "The value to return if this option is selected"
235                            },
236                            "label": {
237                                "type": "string",
238                                "description": "Optional display label (defaults to value)"
239                            }
240                        },
241                        "required": ["value"]
242                    }
243                }
244            },
245            "required": ["question"]
246        })
247    }
248
249    async fn execute(
250        &self,
251        params: serde_json::Value,
252        _context: &ToolContext,
253    ) -> Result<ToolResult, ToolError> {
254        // Parse question
255        let question = params
256            .get("question")
257            .and_then(|v| v.as_str())
258            .ok_or_else(|| ToolError::invalid_params("Missing required parameter: question"))?;
259
260        // Parse options if provided
261        let options: Option<Vec<AskOption>> = params
262            .get("options")
263            .and_then(|v| serde_json::from_value(v.clone()).ok());
264
265        // Ask the question
266        let result = self.ask(question, options.as_deref()).await?;
267
268        // Format the response
269        let output = if result.from_option {
270            format!(
271                "User selected option {}: {}",
272                result.option_index.unwrap_or(0) + 1,
273                result.response
274            )
275        } else {
276            format!("User response: {}", result.response)
277        };
278
279        Ok(ToolResult::success(output)
280            .with_metadata("response", serde_json::json!(result.response))
281            .with_metadata("from_option", serde_json::json!(result.from_option))
282            .with_metadata("option_index", serde_json::json!(result.option_index)))
283    }
284
285    async fn check_permissions(
286        &self,
287        _params: &serde_json::Value,
288        _context: &ToolContext,
289    ) -> PermissionCheckResult {
290        // Ask tool always requires user interaction, so it's always allowed
291        // The actual permission is implicit in the user's response
292        PermissionCheckResult::allow()
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Build a callback that resolves immediately to `reply`.
    fn fixed_reply(reply: Option<String>) -> AskCallback {
        Arc::new(move |_question, _options| {
            let answer = reply.clone();
            Box::pin(async move { answer })
        })
    }

    /// Build a callback that waits `delay_ms` milliseconds before resolving to `reply`.
    fn delayed_reply(reply: Option<String>, delay_ms: u64) -> AskCallback {
        Arc::new(move |_question, _options| {
            let answer = reply.clone();
            Box::pin(async move {
                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
                answer
            })
        })
    }

    /// An `AskTool` whose callback always answers `reply`.
    fn tool_answering(reply: &str) -> AskTool {
        AskTool::new().with_callback(fixed_reply(Some(reply.to_string())))
    }

    /// Context rooted at `/tmp`, used by the trait-level tests.
    fn tmp_context() -> ToolContext {
        ToolContext::new(PathBuf::from("/tmp"))
    }

    // ---- AskOption / AskResult ------------------------------------------

    #[test]
    fn test_ask_option_new() {
        let option = AskOption::new("yes");
        assert_eq!(option.value, "yes");
        assert_eq!(option.label, None);
        assert_eq!(option.display(), "yes");
    }

    #[test]
    fn test_ask_option_with_label() {
        let option = AskOption::with_label("y", "Yes, proceed");
        assert_eq!(option.value, "y");
        assert_eq!(option.label.as_deref(), Some("Yes, proceed"));
        assert_eq!(option.display(), "Yes, proceed");
    }

    #[test]
    fn test_ask_result_from_input() {
        let outcome = AskResult::from_input("hello".to_string());
        assert_eq!(outcome.response, "hello");
        assert!(!outcome.from_option);
        assert_eq!(outcome.option_index, None);
    }

    #[test]
    fn test_ask_result_from_option() {
        let outcome = AskResult::from_option("yes".to_string(), 0);
        assert_eq!(outcome.response, "yes");
        assert!(outcome.from_option);
        assert_eq!(outcome.option_index, Some(0));
    }

    // ---- AskTool construction -------------------------------------------

    #[test]
    fn test_ask_tool_new() {
        let tool = AskTool::new();
        assert!(!tool.has_callback());
        assert_eq!(tool.timeout(), Duration::from_secs(DEFAULT_ASK_TIMEOUT_SECS));
    }

    #[test]
    fn test_ask_tool_with_callback() {
        assert!(tool_answering("test").has_callback());
    }

    #[test]
    fn test_ask_tool_with_timeout() {
        let minute = Duration::from_secs(60);
        assert_eq!(AskTool::new().with_timeout(minute).timeout(), minute);
    }

    #[test]
    fn test_ask_tool_default() {
        let tool = AskTool::default();
        assert!(!tool.has_callback());
        assert_eq!(tool.timeout(), Duration::from_secs(DEFAULT_ASK_TIMEOUT_SECS));
    }

    // ---- AskTool::ask ---------------------------------------------------

    #[tokio::test]
    async fn test_ask_without_callback() {
        let tool = AskTool::new();
        let outcome = tool.ask("What is your name?", None).await;
        assert!(matches!(outcome, Err(ToolError::ExecutionFailed(_))));
    }

    #[tokio::test]
    async fn test_ask_free_form_response() {
        let tool = tool_answering("John");
        let outcome = tool.ask("What is your name?", None).await.unwrap();
        assert_eq!(outcome.response, "John");
        assert!(!outcome.from_option);
        assert_eq!(outcome.option_index, None);
    }

    #[tokio::test]
    async fn test_ask_with_options_select_by_value() {
        let tool = tool_answering("yes");
        let choices = [AskOption::new("yes"), AskOption::new("no")];

        let outcome = tool.ask("Continue?", Some(&choices)).await.unwrap();
        assert_eq!(outcome.response, "yes");
        assert!(outcome.from_option);
        assert_eq!(outcome.option_index, Some(0));
    }

    #[tokio::test]
    async fn test_ask_with_options_select_by_label() {
        let tool = tool_answering("Yes, proceed");
        let choices = [
            AskOption::with_label("y", "Yes, proceed"),
            AskOption::with_label("n", "No, cancel"),
        ];

        let outcome = tool.ask("Continue?", Some(&choices)).await.unwrap();
        assert_eq!(outcome.response, "y");
        assert!(outcome.from_option);
        assert_eq!(outcome.option_index, Some(0));
    }

    #[tokio::test]
    async fn test_ask_with_options_free_form() {
        let tool = tool_answering("maybe");
        let choices = [AskOption::new("yes"), AskOption::new("no")];

        let outcome = tool.ask("Continue?", Some(&choices)).await.unwrap();
        assert_eq!(outcome.response, "maybe");
        assert!(!outcome.from_option);
        assert_eq!(outcome.option_index, None);
    }

    #[tokio::test]
    async fn test_ask_user_cancels() {
        let tool = AskTool::new().with_callback(fixed_reply(None));
        let outcome = tool.ask("What is your name?", None).await;
        assert!(matches!(outcome, Err(ToolError::ExecutionFailed(_))));
    }

    #[tokio::test]
    async fn test_ask_timeout() {
        let tool = AskTool::new()
            .with_timeout(Duration::from_millis(50))
            .with_callback(delayed_reply(Some("response".to_string()), 200));
        let outcome = tool.ask("What is your name?", None).await;
        assert!(matches!(outcome, Err(ToolError::Timeout(_))));
    }

    // ---- Tool trait -----------------------------------------------------

    #[tokio::test]
    async fn test_ask_tool_trait_name() {
        assert_eq!(AskTool::new().name(), "ask");
    }

    #[tokio::test]
    async fn test_ask_tool_trait_description() {
        let tool = AskTool::new();
        assert!(tool.description().contains("Ask a question"));
    }

    #[tokio::test]
    async fn test_ask_tool_trait_input_schema() {
        let schema = AskTool::new().input_schema();
        assert_eq!(schema["type"], "object");

        let properties = &schema["properties"];
        assert!(properties["question"].is_object());
        assert!(properties["options"].is_object());

        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&serde_json::json!("question")));
    }

    #[tokio::test]
    async fn test_ask_tool_execute_success() {
        let tool = tool_answering("John");
        let ctx = tmp_context();
        let params = serde_json::json!({ "question": "What is your name?" });

        let outcome = tool.execute(params, &ctx).await.unwrap();
        assert!(outcome.is_success());
        let text = outcome.output.unwrap();
        assert!(text.contains("John"));
        assert_eq!(
            outcome.metadata.get("response"),
            Some(&serde_json::json!("John"))
        );
        assert_eq!(
            outcome.metadata.get("from_option"),
            Some(&serde_json::json!(false))
        );
    }

    #[tokio::test]
    async fn test_ask_tool_execute_with_options() {
        let tool = tool_answering("yes");
        let ctx = tmp_context();
        let params = serde_json::json!({
            "question": "Continue?",
            "options": [
                { "value": "yes", "label": "Yes" },
                { "value": "no", "label": "No" }
            ]
        });

        let outcome = tool.execute(params, &ctx).await.unwrap();
        assert!(outcome.is_success());
        let text = outcome.output.unwrap();
        assert!(text.contains("selected option"));
        assert_eq!(
            outcome.metadata.get("from_option"),
            Some(&serde_json::json!(true))
        );
        assert_eq!(
            outcome.metadata.get("option_index"),
            Some(&serde_json::json!(0))
        );
    }

    #[tokio::test]
    async fn test_ask_tool_execute_missing_question() {
        let tool = tool_answering("test");
        let ctx = tmp_context();

        let outcome = tool.execute(serde_json::json!({}), &ctx).await;
        assert!(matches!(outcome, Err(ToolError::InvalidParams(_))));
    }

    #[tokio::test]
    async fn test_ask_tool_check_permissions() {
        let tool = AskTool::new();
        let ctx = tmp_context();
        let params = serde_json::json!({ "question": "test" });

        let verdict = tool.check_permissions(&params, &ctx).await;
        assert!(verdict.is_allowed());
    }

    // ---- Serde round-trips ----------------------------------------------

    #[test]
    fn test_ask_option_serialization() {
        let original = AskOption::with_label("y", "Yes");
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: AskOption = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.value, original.value);
        assert_eq!(decoded.label, original.label);
    }

    #[test]
    fn test_ask_result_serialization() {
        let original = AskResult::from_option("yes".to_string(), 0);
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: AskResult = serde_json::from_str(&encoded).unwrap();

        assert_eq!(
            (decoded.response.as_str(), decoded.from_option, decoded.option_index),
            (original.response.as_str(), original.from_option, original.option_index)
        );
    }
}