Skip to main content

agent_air_runtime/controller/tools/
ask_user_questions.rs

//! AskUserQuestions tool implementation.
//!
//! This tool lets the LLM ask the user one or more questions with structured
//! response options (single choice, multiple choice, or free text).
5
6use std::collections::HashMap;
7use std::future::Future;
8use std::pin::Pin;
9
10use serde::{Deserialize, Serialize};
11
12use std::sync::Arc;
13
14use super::types::{
15    DisplayConfig, DisplayResult, Executable, ResultContentType, ToolContext, ToolType,
16};
17use super::user_interaction::UserInteractionRegistry;
18
/// Name under which the tool is registered and invoked.
pub const ASK_USER_QUESTIONS_TOOL_NAME: &str = "ask_user_questions";

/// Description of the tool shown to the LLM.
pub const ASK_USER_QUESTIONS_TOOL_DESCRIPTION: &str = "Ask the user one or more questions with structured response options. \
     Supports single choice, multiple choice, and free text question types.";

/// JSON schema for the tool's input.
///
/// The schema asks for at least one question (`minItems`). It tells the model that
/// `choices` must be non-empty for `SingleChoice`/`MultiChoice`, because the request
/// validator rejects empty choice lists. It also says question texts should be unique,
/// because answers are matched back to questions by their text.
pub const ASK_USER_QUESTIONS_TOOL_SCHEMA: &str = r#"{
    "type": "object",
    "properties": {
        "questions": {
            "type": "array",
            "description": "List of questions to ask the user",
            "minItems": 1,
            "items": {
                "type": "object",
                "properties": {
                    "text": {
                        "type": "string",
                        "description": "The question text to display. Should be unique within the request, since answers are matched to questions by text."
                    },
                    "type": {
                        "type": "string",
                        "enum": ["SingleChoice", "MultiChoice", "FreeText"],
                        "description": "The type of question"
                    },
                    "choices": {
                        "type": "array",
                        "description": "Available choices for SingleChoice/MultiChoice; must be non-empty for those types. User can always type a custom answer instead.",
                        "items": {
                            "type": "string",
                            "description": "Choice text to display"
                        }
                    },
                    "required": {
                        "type": "boolean",
                        "description": "Whether an answer is required"
                    },
                    "defaultValue": {
                        "type": "string",
                        "description": "Default value for FreeText questions"
                    }
                },
                "required": ["text", "type"]
            }
        }
    },
    "required": ["questions"]
}"#;
68
69/// Question types supported by the tool.
70#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
71#[serde(tag = "type")]
72pub enum Question {
73    /// Single choice question - user can select exactly one option.
74    /// User can always type a custom answer in addition to the provided choices.
75    SingleChoice {
76        /// Question text to display.
77        text: String,
78        /// Available choices (as simple strings).
79        choices: Vec<String>,
80        /// Whether an answer is required.
81        #[serde(default)]
82        required: bool,
83    },
84    /// Multiple choice question - user can select zero or more options.
85    /// User can always type a custom answer in addition to the provided choices.
86    MultiChoice {
87        /// Question text to display.
88        text: String,
89        /// Available choices (as simple strings).
90        choices: Vec<String>,
91        /// Whether an answer is required.
92        #[serde(default)]
93        required: bool,
94    },
95    /// Free text question - user can enter any text.
96    FreeText {
97        /// Question text to display.
98        text: String,
99        /// Default value for the text field.
100        #[serde(default, rename = "defaultValue")]
101        default_value: Option<String>,
102        /// Whether an answer is required.
103        #[serde(default)]
104        required: bool,
105    },
106}
107
108impl Question {
109    /// Get the question text.
110    pub fn text(&self) -> &str {
111        match self {
112            Question::SingleChoice { text, .. } => text,
113            Question::MultiChoice { text, .. } => text,
114            Question::FreeText { text, .. } => text,
115        }
116    }
117
118    /// Check if the question is required.
119    pub fn is_required(&self) -> bool {
120        match self {
121            Question::SingleChoice { required, .. } => *required,
122            Question::MultiChoice { required, .. } => *required,
123            Question::FreeText { required, .. } => *required,
124        }
125    }
126
127    /// Get the choices for this question (empty for FreeText).
128    pub fn choices(&self) -> &[String] {
129        match self {
130            Question::SingleChoice { choices, .. } => choices,
131            Question::MultiChoice { choices, .. } => choices,
132            Question::FreeText { .. } => &[],
133        }
134    }
135}
136
137/// Answer to a question - simplified to just question text and answer values.
138#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
139pub struct Answer {
140    /// The question text this answer corresponds to.
141    pub question: String,
142    /// The answer value(s). For SingleChoice/FreeText this will have one element.
143    /// For MultiChoice this may have multiple elements.
144    pub answer: Vec<String>,
145}
146
147/// Error codes for validation failures.
148#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
149#[serde(rename_all = "snake_case")]
150pub enum ValidationErrorCode {
151    /// Required field was not answered.
152    RequiredFieldEmpty,
153    /// More than one selection for SingleChoice.
154    TooManySelections,
155    /// No choices provided for choice question.
156    EmptyChoices,
157    /// Unknown question in answer.
158    UnknownQuestion,
159}
160
161impl std::fmt::Display for ValidationErrorCode {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        match self {
164            ValidationErrorCode::RequiredFieldEmpty => write!(f, "required_field_empty"),
165            ValidationErrorCode::TooManySelections => write!(f, "too_many_selections"),
166            ValidationErrorCode::EmptyChoices => write!(f, "empty_choices"),
167            ValidationErrorCode::UnknownQuestion => write!(f, "unknown_question"),
168        }
169    }
170}
171
172/// Detail about a single validation error.
173#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
174pub struct ValidationErrorDetail {
175    /// Question text where the error occurred.
176    pub question: String,
177    /// Error code.
178    pub error: ValidationErrorCode,
179}
180
181/// Validation error response.
182#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
183pub struct ValidationError {
184    /// Error type identifier.
185    pub error: String,
186    /// List of validation error details.
187    pub details: Vec<ValidationErrorDetail>,
188}
189
190impl ValidationError {
191    /// Create a new validation error.
192    pub fn new(details: Vec<ValidationErrorDetail>) -> Self {
193        Self {
194            error: "validation_failed".to_string(),
195            details,
196        }
197    }
198}
199
200impl std::fmt::Display for ValidationError {
201    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202        write!(f, "Validation failed: ")?;
203        for (i, detail) in self.details.iter().enumerate() {
204            if i > 0 {
205                write!(f, ", ")?;
206            }
207            write!(f, "'{}': {}", detail.question, detail.error)?;
208        }
209        Ok(())
210    }
211}
212
213impl std::error::Error for ValidationError {}
214
215/// Request to ask user questions.
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct AskUserQuestionsRequest {
218    /// List of questions to ask.
219    pub questions: Vec<Question>,
220}
221
222impl AskUserQuestionsRequest {
223    /// Validate the request structure (before sending to user).
224    pub fn validate(&self) -> Result<(), ValidationError> {
225        let mut errors = Vec::new();
226
227        for question in &self.questions {
228            match question {
229                Question::SingleChoice { text, choices, .. }
230                | Question::MultiChoice { text, choices, .. } => {
231                    // Check for empty choices
232                    if choices.is_empty() {
233                        errors.push(ValidationErrorDetail {
234                            question: text.clone(),
235                            error: ValidationErrorCode::EmptyChoices,
236                        });
237                    }
238                }
239                Question::FreeText { .. } => {
240                    // No validation needed for FreeText
241                }
242            }
243        }
244
245        if errors.is_empty() {
246            Ok(())
247        } else {
248            Err(ValidationError::new(errors))
249        }
250    }
251}
252
253/// Response containing answers to questions.
254#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
255pub struct AskUserQuestionsResponse {
256    /// List of answers.
257    pub answers: Vec<Answer>,
258}
259
260impl AskUserQuestionsResponse {
261    /// Validate the response against the request.
262    pub fn validate(&self, request: &AskUserQuestionsRequest) -> Result<(), ValidationError> {
263        let mut errors = Vec::new();
264
265        // Build a map of questions by text
266        let questions: HashMap<&str, &Question> =
267            request.questions.iter().map(|q| (q.text(), q)).collect();
268
269        // Track which questions have been answered
270        let mut answered: std::collections::HashSet<&str> = std::collections::HashSet::new();
271
272        for answer in &self.answers {
273            // Check if question exists
274            let Some(question) = questions.get(answer.question.as_str()) else {
275                errors.push(ValidationErrorDetail {
276                    question: answer.question.clone(),
277                    error: ValidationErrorCode::UnknownQuestion,
278                });
279                continue;
280            };
281
282            answered.insert(answer.question.as_str());
283
284            // Check SingleChoice has at most one selection
285            if let Question::SingleChoice { .. } = question
286                && answer.answer.len() > 1
287            {
288                errors.push(ValidationErrorDetail {
289                    question: answer.question.clone(),
290                    error: ValidationErrorCode::TooManySelections,
291                });
292            }
293        }
294
295        // Check required questions are answered
296        for question in &request.questions {
297            let question_text = question.text();
298            if question.is_required() {
299                // Check if question was answered with non-empty content
300                let has_valid_answer = self.answers.iter().any(|a| {
301                    a.question == question_text
302                        && !a.answer.is_empty()
303                        && a.answer.iter().any(|s| !s.is_empty())
304                });
305
306                if !has_valid_answer {
307                    errors.push(ValidationErrorDetail {
308                        question: question_text.to_string(),
309                        error: ValidationErrorCode::RequiredFieldEmpty,
310                    });
311                }
312            }
313        }
314
315        if errors.is_empty() {
316            Ok(())
317        } else {
318            Err(ValidationError::new(errors))
319        }
320    }
321}
322
323/// Tool that asks the user structured questions.
324pub struct AskUserQuestionsTool {
325    /// Registry for managing pending user interactions.
326    registry: Arc<UserInteractionRegistry>,
327}
328
329impl AskUserQuestionsTool {
330    /// Create a new AskUserQuestionsTool instance.
331    ///
332    /// # Arguments
333    /// * `registry` - The user interaction registry to use for tracking pending questions.
334    pub fn new(registry: Arc<UserInteractionRegistry>) -> Self {
335        Self { registry }
336    }
337}
338
339impl Executable for AskUserQuestionsTool {
340    fn name(&self) -> &str {
341        ASK_USER_QUESTIONS_TOOL_NAME
342    }
343
344    fn description(&self) -> &str {
345        ASK_USER_QUESTIONS_TOOL_DESCRIPTION
346    }
347
348    fn input_schema(&self) -> &str {
349        ASK_USER_QUESTIONS_TOOL_SCHEMA
350    }
351
352    fn tool_type(&self) -> ToolType {
353        ToolType::UserInteraction
354    }
355
356    fn execute(
357        &self,
358        context: ToolContext,
359        input: HashMap<String, serde_json::Value>,
360    ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
361        let registry = self.registry.clone();
362
363        Box::pin(async move {
364            // Parse the input into a request
365            let questions_value = input
366                .get("questions")
367                .ok_or_else(|| "Missing 'questions' field".to_string())?;
368
369            let questions: Vec<Question> = serde_json::from_value(questions_value.clone())
370                .map_err(|e| format!("Failed to parse questions: {}", e))?;
371
372            let request = AskUserQuestionsRequest { questions };
373
374            // Validate the request
375            if let Err(validation_error) = request.validate() {
376                return Err(serde_json::to_string(&validation_error)
377                    .unwrap_or_else(|_| validation_error.to_string()));
378            }
379
380            // Register the interaction and wait for user response
381            let rx = registry
382                .register(
383                    context.tool_use_id,
384                    context.session_id,
385                    request.clone(),
386                    context.turn_id,
387                )
388                .await
389                .map_err(|e| format!("Failed to register interaction: {}", e))?;
390
391            // Block waiting for user response
392            let response = rx
393                .await
394                .map_err(|_| "User declined to answer".to_string())?;
395
396            // Validate the response against the request
397            if let Err(validation_error) = response.validate(&request) {
398                return Err(serde_json::to_string(&validation_error)
399                    .unwrap_or_else(|_| validation_error.to_string()));
400            }
401
402            // Return the validated response as JSON
403            serde_json::to_string(&response)
404                .map_err(|e| format!("Failed to serialize response: {}", e))
405        })
406    }
407
408    fn display_config(&self) -> DisplayConfig {
409        DisplayConfig {
410            display_name: "Ask User Questions".to_string(),
411            display_title: Box::new(|input| {
412                input
413                    .get("questions")
414                    .and_then(|v| v.as_array())
415                    .map(|arr| {
416                        if arr.len() == 1 {
417                            "1 question".to_string()
418                        } else {
419                            format!("{} questions", arr.len())
420                        }
421                    })
422                    .unwrap_or_default()
423            }),
424            display_content: Box::new(|input, _result| {
425                let content = input
426                    .get("questions")
427                    .and_then(|v| v.as_array())
428                    .map(|questions| {
429                        questions
430                            .iter()
431                            .filter_map(|q| q.get("text").and_then(|t| t.as_str()))
432                            .collect::<Vec<_>>()
433                            .join("\n")
434                    })
435                    .unwrap_or_default();
436
437                DisplayResult {
438                    content,
439                    content_type: ResultContentType::PlainText,
440                    is_truncated: false,
441                    full_length: 0,
442                }
443            }),
444        }
445    }
446
447    fn compact_summary(&self, input: &HashMap<String, serde_json::Value>, _result: &str) -> String {
448        let count = input
449            .get("questions")
450            .and_then(|v| v.as_array())
451            .map(|arr| arr.len())
452            .unwrap_or(0);
453        format!("[AskUserQuestions: {} question(s)]", count)
454    }
455
456    fn handles_own_permissions(&self) -> bool {
457        true // User interaction tool - no permissions needed
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned strings the question/answer types store.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Builds a `SingleChoice` question.
    fn single_choice(text: &str, choices: &[&str], required: bool) -> Question {
        Question::SingleChoice {
            text: text.to_string(),
            choices: owned(choices),
            required,
        }
    }

    /// Builds an answer entry for `question`.
    fn answer(question: &str, values: &[&str]) -> Answer {
        Answer {
            question: question.to_string(),
            answer: owned(values),
        }
    }

    /// Returns `true` if any detail in `err` carries `code`.
    fn has_code(err: &ValidationError, code: ValidationErrorCode) -> bool {
        err.details.iter().any(|detail| detail.error == code)
    }

    #[test]
    fn test_parse_single_choice_question() {
        let json = r#"{
            "type": "SingleChoice",
            "text": "Which database?",
            "choices": ["PostgreSQL", "MySQL", "SQLite"],
            "required": true
        }"#;
        let question: Question = serde_json::from_str(json).unwrap();

        assert_eq!(question.text(), "Which database?");
        assert!(question.is_required());

        let Question::SingleChoice { choices, .. } = question else {
            panic!("Expected SingleChoice");
        };
        assert_eq!(choices.len(), 3);
        assert_eq!(choices[0], "PostgreSQL");
    }

    #[test]
    fn test_parse_multi_choice_question() {
        let json = r#"{
            "type": "MultiChoice",
            "text": "Which features?",
            "choices": ["Authentication", "Logging", "Caching"],
            "required": false
        }"#;
        let question: Question = serde_json::from_str(json).unwrap();

        assert_eq!(question.text(), "Which features?");
        assert!(!question.is_required());

        let Question::MultiChoice { choices, .. } = question else {
            panic!("Expected MultiChoice");
        };
        assert_eq!(choices.len(), 3);
    }

    #[test]
    fn test_parse_free_text_question() {
        let json = r#"{
            "type": "FreeText",
            "text": "Any notes?",
            "defaultValue": "None",
            "required": false
        }"#;
        let question: Question = serde_json::from_str(json).unwrap();

        assert_eq!(question.text(), "Any notes?");

        let Question::FreeText { default_value, .. } = question else {
            panic!("Expected FreeText");
        };
        assert_eq!(default_value.as_deref(), Some("None"));
    }

    #[test]
    fn test_validate_request_empty_choices() {
        let request = AskUserQuestionsRequest {
            questions: vec![single_choice("Question?", &[], true)],
        };

        let err = request.validate().unwrap_err();
        assert_eq!(err.details.len(), 1);
        assert_eq!(err.details[0].error, ValidationErrorCode::EmptyChoices);
    }

    #[test]
    fn test_validate_response_too_many_selections() {
        let request = AskUserQuestionsRequest {
            questions: vec![single_choice("Question?", &["A", "B"], true)],
        };
        let response = AskUserQuestionsResponse {
            answers: vec![answer("Question?", &["A", "B"])],
        };

        let err = response.validate(&request).unwrap_err();
        assert!(has_code(&err, ValidationErrorCode::TooManySelections));
    }

    #[test]
    fn test_validate_response_required_field_empty() {
        let request = AskUserQuestionsRequest {
            questions: vec![single_choice("Question?", &["A"], true)],
        };
        let response = AskUserQuestionsResponse { answers: Vec::new() };

        let err = response.validate(&request).unwrap_err();
        assert!(has_code(&err, ValidationErrorCode::RequiredFieldEmpty));
    }

    #[test]
    fn test_validate_response_unknown_question() {
        let request = AskUserQuestionsRequest {
            questions: vec![single_choice("Question?", &["A"], false)],
        };
        let response = AskUserQuestionsResponse {
            answers: vec![answer("Unknown question?", &["A"])],
        };

        let err = response.validate(&request).unwrap_err();
        assert!(has_code(&err, ValidationErrorCode::UnknownQuestion));
    }

    #[test]
    fn test_validate_response_success() {
        let request = AskUserQuestionsRequest {
            questions: vec![
                single_choice("Question 1?", &["A", "B"], true),
                Question::MultiChoice {
                    text: "Question 2?".to_string(),
                    choices: owned(&["X", "Y"]),
                    required: false,
                },
                Question::FreeText {
                    text: "Question 3?".to_string(),
                    default_value: None,
                    required: false,
                },
            ],
        };
        let response = AskUserQuestionsResponse {
            answers: vec![
                answer("Question 1?", &["A"]),
                answer("Question 2?", &["X", "Y"]),
                answer("Question 3?", &["Some notes"]),
            ],
        };

        assert!(response.validate(&request).is_ok());
    }

    #[test]
    fn test_answer_serialization() {
        let json = serde_json::to_string(&answer("Which database?", &["PostgreSQL"])).unwrap();

        for needle in ["question", "answer", "PostgreSQL"] {
            assert!(json.contains(needle), "missing {needle:?} in {json}");
        }
    }

    #[test]
    fn test_custom_answer_allowed() {
        // Users can always type a custom answer, even for choice questions.
        let request = AskUserQuestionsRequest {
            questions: vec![single_choice("Which database?", &["PostgreSQL", "MySQL"], true)],
        };
        // "MongoDB" is not among the offered choices.
        let response = AskUserQuestionsResponse {
            answers: vec![answer("Which database?", &["MongoDB"])],
        };

        // Custom answers are always accepted.
        assert!(response.validate(&request).is_ok());
    }
}