// agent_core/controller/tools/ask_user_questions.rs

1//! AskUserQuestions tool implementation
2//!
3//! This tool allows the LLM to ask the user one or more questions
4//! with structured response options (single choice, multi choice, or free text).
5
6use std::collections::HashMap;
7use std::future::Future;
8use std::pin::Pin;
9
10use serde::{Deserialize, Serialize};
11
12use std::sync::Arc;
13
14use super::types::{DisplayConfig, DisplayResult, Executable, ResultContentType, ToolContext, ToolType};
15use super::user_interaction::UserInteractionRegistry;
16
/// Identifier under which this tool is exposed and invoked.
pub const ASK_USER_QUESTIONS_TOOL_NAME: &str = "ask_user_questions";

/// Human-readable explanation of what the tool does.
pub const ASK_USER_QUESTIONS_TOOL_DESCRIPTION: &str = concat!(
    "Ask the user one or more questions with structured response options. ",
    "Supports single choice, multiple choice, and free text question types."
);
24
/// JSON Schema describing the tool's input object.
///
/// This mirrors the serde shape of [`Question`]:
/// - `"type"` selects the variant.
/// - `choices` must be present for `SingleChoice`/`MultiChoice`, because deserialization fails
///   without it.
/// - `choices` must also be non-empty, because [`AskUserQuestionsRequest::validate`] rejects an
///   empty list.
/// - `required` defaults to `false` when omitted.
///
/// Answers are matched back to questions by their text, so question texts should be unique.
pub const ASK_USER_QUESTIONS_TOOL_SCHEMA: &str = r#"{
    "type": "object",
    "properties": {
        "questions": {
            "type": "array",
            "description": "List of questions to ask the user",
            "minItems": 1,
            "items": {
                "type": "object",
                "properties": {
                    "text": {
                        "type": "string",
                        "description": "The question text to display. Must be unique within the request; answers are matched to questions by this text."
                    },
                    "type": {
                        "type": "string",
                        "enum": ["SingleChoice", "MultiChoice", "FreeText"],
                        "description": "The type of question"
                    },
                    "choices": {
                        "type": "array",
                        "description": "Available choices. Required and must be non-empty for SingleChoice/MultiChoice; ignored for FreeText. User can always type a custom answer instead.",
                        "minItems": 1,
                        "items": {
                            "type": "string",
                            "description": "Choice text to display"
                        }
                    },
                    "required": {
                        "type": "boolean",
                        "description": "Whether an answer is required (defaults to false)"
                    },
                    "defaultValue": {
                        "type": "string",
                        "description": "Default value for FreeText questions"
                    }
                },
                "required": ["text", "type"]
            }
        }
    },
    "required": ["questions"]
}"#;
67
68/// Question types supported by the tool.
69#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
70#[serde(tag = "type")]
71pub enum Question {
72    /// Single choice question - user can select exactly one option.
73    /// User can always type a custom answer in addition to the provided choices.
74    SingleChoice {
75        /// Question text to display.
76        text: String,
77        /// Available choices (as simple strings).
78        choices: Vec<String>,
79        /// Whether an answer is required.
80        #[serde(default)]
81        required: bool,
82    },
83    /// Multiple choice question - user can select zero or more options.
84    /// User can always type a custom answer in addition to the provided choices.
85    MultiChoice {
86        /// Question text to display.
87        text: String,
88        /// Available choices (as simple strings).
89        choices: Vec<String>,
90        /// Whether an answer is required.
91        #[serde(default)]
92        required: bool,
93    },
94    /// Free text question - user can enter any text.
95    FreeText {
96        /// Question text to display.
97        text: String,
98        /// Default value for the text field.
99        #[serde(default, rename = "defaultValue")]
100        default_value: Option<String>,
101        /// Whether an answer is required.
102        #[serde(default)]
103        required: bool,
104    },
105}
106
107impl Question {
108    /// Get the question text.
109    pub fn text(&self) -> &str {
110        match self {
111            Question::SingleChoice { text, .. } => text,
112            Question::MultiChoice { text, .. } => text,
113            Question::FreeText { text, .. } => text,
114        }
115    }
116
117    /// Check if the question is required.
118    pub fn is_required(&self) -> bool {
119        match self {
120            Question::SingleChoice { required, .. } => *required,
121            Question::MultiChoice { required, .. } => *required,
122            Question::FreeText { required, .. } => *required,
123        }
124    }
125
126    /// Get the choices for this question (empty for FreeText).
127    pub fn choices(&self) -> &[String] {
128        match self {
129            Question::SingleChoice { choices, .. } => choices,
130            Question::MultiChoice { choices, .. } => choices,
131            Question::FreeText { .. } => &[],
132        }
133    }
134}
135
136/// Answer to a question - simplified to just question text and answer values.
137#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
138pub struct Answer {
139    /// The question text this answer corresponds to.
140    pub question: String,
141    /// The answer value(s). For SingleChoice/FreeText this will have one element.
142    /// For MultiChoice this may have multiple elements.
143    pub answer: Vec<String>,
144}
145
146/// Error codes for validation failures.
147#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
148#[serde(rename_all = "snake_case")]
149pub enum ValidationErrorCode {
150    /// Required field was not answered.
151    RequiredFieldEmpty,
152    /// More than one selection for SingleChoice.
153    TooManySelections,
154    /// No choices provided for choice question.
155    EmptyChoices,
156    /// Unknown question in answer.
157    UnknownQuestion,
158}
159
160impl std::fmt::Display for ValidationErrorCode {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        match self {
163            ValidationErrorCode::RequiredFieldEmpty => write!(f, "required_field_empty"),
164            ValidationErrorCode::TooManySelections => write!(f, "too_many_selections"),
165            ValidationErrorCode::EmptyChoices => write!(f, "empty_choices"),
166            ValidationErrorCode::UnknownQuestion => write!(f, "unknown_question"),
167        }
168    }
169}
170
171/// Detail about a single validation error.
172#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
173pub struct ValidationErrorDetail {
174    /// Question text where the error occurred.
175    pub question: String,
176    /// Error code.
177    pub error: ValidationErrorCode,
178}
179
180/// Validation error response.
181#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
182pub struct ValidationError {
183    /// Error type identifier.
184    pub error: String,
185    /// List of validation error details.
186    pub details: Vec<ValidationErrorDetail>,
187}
188
189impl ValidationError {
190    /// Create a new validation error.
191    pub fn new(details: Vec<ValidationErrorDetail>) -> Self {
192        Self {
193            error: "validation_failed".to_string(),
194            details,
195        }
196    }
197}
198
199impl std::fmt::Display for ValidationError {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        write!(f, "Validation failed: ")?;
202        for (i, detail) in self.details.iter().enumerate() {
203            if i > 0 {
204                write!(f, ", ")?;
205            }
206            write!(f, "'{}': {}", detail.question, detail.error)?;
207        }
208        Ok(())
209    }
210}
211
212impl std::error::Error for ValidationError {}
213
214/// Request to ask user questions.
215#[derive(Debug, Clone, Serialize, Deserialize)]
216pub struct AskUserQuestionsRequest {
217    /// List of questions to ask.
218    pub questions: Vec<Question>,
219}
220
221impl AskUserQuestionsRequest {
222    /// Validate the request structure (before sending to user).
223    pub fn validate(&self) -> Result<(), ValidationError> {
224        let mut errors = Vec::new();
225
226        for question in &self.questions {
227            match question {
228                Question::SingleChoice { text, choices, .. }
229                | Question::MultiChoice { text, choices, .. } => {
230                    // Check for empty choices
231                    if choices.is_empty() {
232                        errors.push(ValidationErrorDetail {
233                            question: text.clone(),
234                            error: ValidationErrorCode::EmptyChoices,
235                        });
236                    }
237                }
238                Question::FreeText { .. } => {
239                    // No validation needed for FreeText
240                }
241            }
242        }
243
244        if errors.is_empty() {
245            Ok(())
246        } else {
247            Err(ValidationError::new(errors))
248        }
249    }
250}
251
252/// Response containing answers to questions.
253#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
254pub struct AskUserQuestionsResponse {
255    /// List of answers.
256    pub answers: Vec<Answer>,
257}
258
259impl AskUserQuestionsResponse {
260    /// Validate the response against the request.
261    pub fn validate(&self, request: &AskUserQuestionsRequest) -> Result<(), ValidationError> {
262        let mut errors = Vec::new();
263
264        // Build a map of questions by text
265        let questions: HashMap<&str, &Question> =
266            request.questions.iter().map(|q| (q.text(), q)).collect();
267
268        // Track which questions have been answered
269        let mut answered: std::collections::HashSet<&str> = std::collections::HashSet::new();
270
271        for answer in &self.answers {
272            // Check if question exists
273            let Some(question) = questions.get(answer.question.as_str()) else {
274                errors.push(ValidationErrorDetail {
275                    question: answer.question.clone(),
276                    error: ValidationErrorCode::UnknownQuestion,
277                });
278                continue;
279            };
280
281            answered.insert(answer.question.as_str());
282
283            // Check SingleChoice has at most one selection
284            if let Question::SingleChoice { .. } = question {
285                if answer.answer.len() > 1 {
286                    errors.push(ValidationErrorDetail {
287                        question: answer.question.clone(),
288                        error: ValidationErrorCode::TooManySelections,
289                    });
290                }
291            }
292        }
293
294        // Check required questions are answered
295        for question in &request.questions {
296            let question_text = question.text();
297            if question.is_required() {
298                // Check if question was answered with non-empty content
299                let has_valid_answer = self.answers.iter().any(|a| {
300                    a.question == question_text && !a.answer.is_empty() && a.answer.iter().any(|s| !s.is_empty())
301                });
302
303                if !has_valid_answer {
304                    errors.push(ValidationErrorDetail {
305                        question: question_text.to_string(),
306                        error: ValidationErrorCode::RequiredFieldEmpty,
307                    });
308                }
309            }
310        }
311
312        if errors.is_empty() {
313            Ok(())
314        } else {
315            Err(ValidationError::new(errors))
316        }
317    }
318}
319
320/// Tool that asks the user structured questions.
321pub struct AskUserQuestionsTool {
322    /// Registry for managing pending user interactions.
323    registry: Arc<UserInteractionRegistry>,
324}
325
326impl AskUserQuestionsTool {
327    /// Create a new AskUserQuestionsTool instance.
328    ///
329    /// # Arguments
330    /// * `registry` - The user interaction registry to use for tracking pending questions.
331    pub fn new(registry: Arc<UserInteractionRegistry>) -> Self {
332        Self { registry }
333    }
334}
335
336impl Executable for AskUserQuestionsTool {
337    fn name(&self) -> &str {
338        ASK_USER_QUESTIONS_TOOL_NAME
339    }
340
341    fn description(&self) -> &str {
342        ASK_USER_QUESTIONS_TOOL_DESCRIPTION
343    }
344
345    fn input_schema(&self) -> &str {
346        ASK_USER_QUESTIONS_TOOL_SCHEMA
347    }
348
349    fn tool_type(&self) -> ToolType {
350        ToolType::UserInteraction
351    }
352
353    fn execute(
354        &self,
355        context: ToolContext,
356        input: HashMap<String, serde_json::Value>,
357    ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
358        let registry = self.registry.clone();
359
360        Box::pin(async move {
361            // Parse the input into a request
362            let questions_value = input
363                .get("questions")
364                .ok_or_else(|| "Missing 'questions' field".to_string())?;
365
366            let questions: Vec<Question> = serde_json::from_value(questions_value.clone())
367                .map_err(|e| format!("Failed to parse questions: {}", e))?;
368
369            let request = AskUserQuestionsRequest { questions };
370
371            // Validate the request
372            if let Err(validation_error) = request.validate() {
373                return Err(serde_json::to_string(&validation_error)
374                    .unwrap_or_else(|_| validation_error.to_string()));
375            }
376
377            // Register the interaction and wait for user response
378            let rx = registry
379                .register(
380                    context.tool_use_id,
381                    context.session_id,
382                    request.clone(),
383                    context.turn_id,
384                )
385                .await
386                .map_err(|e| format!("Failed to register interaction: {}", e))?;
387
388            // Block waiting for user response
389            let response = rx
390                .await
391                .map_err(|_| "User declined to answer".to_string())?;
392
393            // Validate the response against the request
394            if let Err(validation_error) = response.validate(&request) {
395                return Err(serde_json::to_string(&validation_error)
396                    .unwrap_or_else(|_| validation_error.to_string()));
397            }
398
399            // Return the validated response as JSON
400            serde_json::to_string(&response)
401                .map_err(|e| format!("Failed to serialize response: {}", e))
402        })
403    }
404
405    fn display_config(&self) -> DisplayConfig {
406        DisplayConfig {
407            display_name: "Ask User Questions".to_string(),
408            display_title: Box::new(|input| {
409                input
410                    .get("questions")
411                    .and_then(|v| v.as_array())
412                    .map(|arr| {
413                        if arr.len() == 1 {
414                            "1 question".to_string()
415                        } else {
416                            format!("{} questions", arr.len())
417                        }
418                    })
419                    .unwrap_or_default()
420            }),
421            display_content: Box::new(|input, _result| {
422                let content = input
423                    .get("questions")
424                    .and_then(|v| v.as_array())
425                    .map(|questions| {
426                        questions
427                            .iter()
428                            .filter_map(|q| q.get("text").and_then(|t| t.as_str()))
429                            .collect::<Vec<_>>()
430                            .join("\n")
431                    })
432                    .unwrap_or_default();
433
434                DisplayResult {
435                    content,
436                    content_type: ResultContentType::PlainText,
437                    is_truncated: false,
438                    full_length: 0,
439                }
440            }),
441        }
442    }
443
444    fn compact_summary(
445        &self,
446        input: &HashMap<String, serde_json::Value>,
447        _result: &str,
448    ) -> String {
449        let count = input
450            .get("questions")
451            .and_then(|v| v.as_array())
452            .map(|arr| arr.len())
453            .unwrap_or(0);
454        format!("[AskUserQuestions: {} question(s)]", count)
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `Vec<String>` from string literals.
    fn strings(values: &[&str]) -> Vec<String> {
        values.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_parse_single_choice_question() {
        let json = r#"{
            "type": "SingleChoice",
            "text": "Which database?",
            "choices": ["PostgreSQL", "MySQL", "SQLite"],
            "required": true
        }"#;

        let question: Question = serde_json::from_str(json).unwrap();
        assert_eq!(question.text(), "Which database?");
        assert!(question.is_required());

        if let Question::SingleChoice { choices, .. } = question {
            assert_eq!(choices.len(), 3);
            assert_eq!(choices[0], "PostgreSQL");
        } else {
            panic!("Expected SingleChoice");
        }
    }

    #[test]
    fn test_parse_multi_choice_question() {
        let json = r#"{
            "type": "MultiChoice",
            "text": "Which features?",
            "choices": ["Authentication", "Logging", "Caching"],
            "required": false
        }"#;

        let question: Question = serde_json::from_str(json).unwrap();
        assert_eq!(question.text(), "Which features?");
        assert!(!question.is_required());

        if let Question::MultiChoice { choices, .. } = question {
            assert_eq!(choices.len(), 3);
        } else {
            panic!("Expected MultiChoice");
        }
    }

    #[test]
    fn test_parse_free_text_question() {
        let json = r#"{
            "type": "FreeText",
            "text": "Any notes?",
            "defaultValue": "None",
            "required": false
        }"#;

        let question: Question = serde_json::from_str(json).unwrap();
        assert_eq!(question.text(), "Any notes?");
        assert!(question.choices().is_empty());

        if let Question::FreeText { default_value, .. } = question {
            assert_eq!(default_value, Some("None".to_string()));
        } else {
            panic!("Expected FreeText");
        }
    }

    #[test]
    fn test_parse_required_defaults_to_false() {
        let json = r#"{"type": "FreeText", "text": "Notes?"}"#;
        let question: Question = serde_json::from_str(json).unwrap();
        assert!(!question.is_required());
    }

    #[test]
    fn test_parse_choice_question_without_choices_fails() {
        let json = r#"{"type": "SingleChoice", "text": "Which?"}"#;
        assert!(serde_json::from_str::<Question>(json).is_err());
    }

    #[test]
    fn test_parse_unknown_question_type_fails() {
        let json = r#"{"type": "Rating", "text": "How good?"}"#;
        assert!(serde_json::from_str::<Question>(json).is_err());
    }

    #[test]
    fn test_serialize_question_uses_type_tag_and_camel_case_default() {
        let question = Question::FreeText {
            text: "Notes?".to_string(),
            default_value: Some("n/a".to_string()),
            required: false,
        };
        let value = serde_json::to_value(&question).unwrap();
        assert_eq!(value["type"], "FreeText");
        assert_eq!(value["defaultValue"], "n/a");
    }

    #[test]
    fn test_validate_request_empty_choices() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Question?".to_string(),
                choices: vec![],
                required: true,
            }],
        };

        let err = request.validate().unwrap_err();
        assert_eq!(err.details.len(), 1);
        assert_eq!(err.details[0].error, ValidationErrorCode::EmptyChoices);
    }

    #[test]
    fn test_validate_request_empty_choices_multi_choice() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::MultiChoice {
                text: "Pick some".to_string(),
                choices: vec![],
                required: false,
            }],
        };

        let err = request.validate().unwrap_err();
        assert_eq!(err.details.len(), 1);
        assert_eq!(err.details[0].question, "Pick some");
        assert_eq!(err.details[0].error, ValidationErrorCode::EmptyChoices);
    }

    #[test]
    fn test_validate_request_free_text_ok() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::FreeText {
                text: "Anything else?".to_string(),
                default_value: None,
                required: true,
            }],
        };

        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_validate_response_too_many_selections() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Question?".to_string(),
                choices: strings(&["A", "B"]),
                required: true,
            }],
        };

        let response = AskUserQuestionsResponse {
            answers: vec![Answer {
                question: "Question?".to_string(),
                answer: strings(&["A", "B"]),
            }],
        };

        let err = response.validate(&request).unwrap_err();
        assert!(err
            .details
            .iter()
            .any(|d| d.error == ValidationErrorCode::TooManySelections));
    }

    #[test]
    fn test_validate_response_required_field_empty() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Question?".to_string(),
                choices: strings(&["A"]),
                required: true,
            }],
        };

        let response = AskUserQuestionsResponse { answers: vec![] };

        let err = response.validate(&request).unwrap_err();
        assert!(err
            .details
            .iter()
            .any(|d| d.error == ValidationErrorCode::RequiredFieldEmpty));
    }

    #[test]
    fn test_validate_response_required_with_only_empty_strings() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::FreeText {
                text: "Name?".to_string(),
                default_value: None,
                required: true,
            }],
        };

        let response = AskUserQuestionsResponse {
            answers: vec![Answer {
                question: "Name?".to_string(),
                answer: strings(&[""]),
            }],
        };

        let err = response.validate(&request).unwrap_err();
        assert!(err
            .details
            .iter()
            .any(|d| d.error == ValidationErrorCode::RequiredFieldEmpty && d.question == "Name?"));
    }

    #[test]
    fn test_validate_response_optional_question_may_be_omitted() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::MultiChoice {
                text: "Extras?".to_string(),
                choices: strings(&["X", "Y"]),
                required: false,
            }],
        };

        let response = AskUserQuestionsResponse { answers: vec![] };
        assert!(response.validate(&request).is_ok());
    }

    #[test]
    fn test_validate_response_unknown_question() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Question?".to_string(),
                choices: strings(&["A"]),
                required: false,
            }],
        };

        let response = AskUserQuestionsResponse {
            answers: vec![Answer {
                question: "Unknown question?".to_string(),
                answer: strings(&["A"]),
            }],
        };

        let err = response.validate(&request).unwrap_err();
        assert!(err
            .details
            .iter()
            .any(|d| d.error == ValidationErrorCode::UnknownQuestion));
    }

    #[test]
    fn test_validate_response_success() {
        let request = AskUserQuestionsRequest {
            questions: vec![
                Question::SingleChoice {
                    text: "Question 1?".to_string(),
                    choices: strings(&["A", "B"]),
                    required: true,
                },
                Question::MultiChoice {
                    text: "Question 2?".to_string(),
                    choices: strings(&["X", "Y"]),
                    required: false,
                },
                Question::FreeText {
                    text: "Question 3?".to_string(),
                    default_value: None,
                    required: false,
                },
            ],
        };

        let response = AskUserQuestionsResponse {
            answers: vec![
                Answer {
                    question: "Question 1?".to_string(),
                    answer: strings(&["A"]),
                },
                Answer {
                    question: "Question 2?".to_string(),
                    answer: strings(&["X", "Y"]),
                },
                Answer {
                    question: "Question 3?".to_string(),
                    answer: strings(&["Some notes"]),
                },
            ],
        };

        assert!(response.validate(&request).is_ok());
    }

    #[test]
    fn test_answer_serialization() {
        let answer = Answer {
            question: "Which database?".to_string(),
            answer: strings(&["PostgreSQL"]),
        };

        let json = serde_json::to_string(&answer).unwrap();
        assert_eq!(json, r#"{"question":"Which database?","answer":["PostgreSQL"]}"#);

        let round_trip: Answer = serde_json::from_str(&json).unwrap();
        assert_eq!(round_trip, answer);
    }

    #[test]
    fn test_validation_error_display() {
        let err = ValidationError::new(vec![
            ValidationErrorDetail {
                question: "Q1".to_string(),
                error: ValidationErrorCode::RequiredFieldEmpty,
            },
            ValidationErrorDetail {
                question: "Q2".to_string(),
                error: ValidationErrorCode::UnknownQuestion,
            },
        ]);

        assert_eq!(
            err.to_string(),
            "Validation failed: 'Q1': required_field_empty, 'Q2': unknown_question"
        );
    }

    #[test]
    fn test_validation_error_json_shape() {
        let err = ValidationError::new(vec![ValidationErrorDetail {
            question: "Q".to_string(),
            error: ValidationErrorCode::EmptyChoices,
        }]);

        assert_eq!(
            serde_json::to_string(&err).unwrap(),
            r#"{"error":"validation_failed","details":[{"question":"Q","error":"empty_choices"}]}"#
        );
    }

    #[test]
    fn test_custom_answer_allowed() {
        // Users can always type a custom answer, even for choice questions
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Which database?".to_string(),
                choices: strings(&["PostgreSQL", "MySQL"]),
                required: true,
            }],
        };

        let response = AskUserQuestionsResponse {
            answers: vec![Answer {
                question: "Which database?".to_string(),
                answer: strings(&["MongoDB"]), // Custom answer not in choices
            }],
        };

        // Should succeed - custom answers are always allowed
        assert!(response.validate(&request).is_ok());
    }
}