use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use super::types::{
DisplayConfig, DisplayResult, Executable, ResultContentType, ToolContext, ToolType,
};
use super::user_interaction::UserInteractionRegistry;
/// Tool name reported by [`Executable::name`] for [`AskUserQuestionsTool`].
pub const ASK_USER_QUESTIONS_TOOL_NAME: &str = "ask_user_questions";

/// Tool description reported by [`Executable::description`] for [`AskUserQuestionsTool`].
pub const ASK_USER_QUESTIONS_TOOL_DESCRIPTION: &str = concat!(
    "Ask the user one or more questions with structured response options. ",
    "Supports single choice, multiple choice, and free text question types."
);
/// JSON Schema for this tool's input, reported by [`Executable::input_schema`].
///
/// The input is an object with a `questions` array. Each item requires `text` and `type`
/// (`"SingleChoice"`, `"MultiChoice"`, or `"FreeText"`) and may carry `choices`, `required`,
/// and `defaultValue`. These names mirror the serde representation of [`Question`], which is
/// internally tagged by `type` and renames `default_value` to `defaultValue`; keep the two in
/// sync when either changes.
pub const ASK_USER_QUESTIONS_TOOL_SCHEMA: &str = r#"{
"type": "object",
"properties": {
"questions": {
"type": "array",
"description": "List of questions to ask the user",
"items": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The question text to display"
},
"type": {
"type": "string",
"enum": ["SingleChoice", "MultiChoice", "FreeText"],
"description": "The type of question"
},
"choices": {
"type": "array",
"description": "Available choices for SingleChoice/MultiChoice. User can always type a custom answer instead.",
"items": {
"type": "string",
"description": "Choice text to display"
}
},
"required": {
"type": "boolean",
"description": "Whether an answer is required"
},
"defaultValue": {
"type": "string",
"description": "Default value for FreeText questions"
}
},
"required": ["text", "type"]
}
}
},
"required": ["questions"]
}"#;
/// A single question to put to the user.
///
/// Serialized with an internal `type` tag (`"SingleChoice"`, `"MultiChoice"`, or `"FreeText"`)
/// next to the variant's fields, matching the tool's input schema.
///
/// The question `text` doubles as its identifier: answers refer back to a question by
/// repeating its text.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum Question {
    /// The user picks one option (or types a custom answer instead).
    SingleChoice {
        /// Question text, also used by answers to identify this question.
        text: String,
        /// Suggested options. The input schema does not require `choices`, so a missing list
        /// deserializes as empty; request validation then reports it as `EmptyChoices` instead
        /// of the whole input failing to parse.
        #[serde(default)]
        choices: Vec<String>,
        /// Whether an answer must be given; defaults to `false` when omitted.
        #[serde(default)]
        required: bool,
    },
    /// The user may pick several options (or type custom answers).
    MultiChoice {
        /// Question text, also used by answers to identify this question.
        text: String,
        /// Suggested options; defaults to empty when omitted (see `SingleChoice::choices`).
        #[serde(default)]
        choices: Vec<String>,
        /// Whether an answer must be given; defaults to `false` when omitted.
        #[serde(default)]
        required: bool,
    },
    /// The user answers with free-form text.
    FreeText {
        /// Question text, also used by answers to identify this question.
        text: String,
        /// Optional pre-filled value, serialized as `defaultValue`.
        #[serde(default, rename = "defaultValue")]
        default_value: Option<String>,
        /// Whether an answer must be given; defaults to `false` when omitted.
        #[serde(default)]
        required: bool,
    },
}
impl Question {
pub fn text(&self) -> &str {
match self {
Question::SingleChoice { text, .. } => text,
Question::MultiChoice { text, .. } => text,
Question::FreeText { text, .. } => text,
}
}
pub fn is_required(&self) -> bool {
match self {
Question::SingleChoice { required, .. } => *required,
Question::MultiChoice { required, .. } => *required,
Question::FreeText { required, .. } => *required,
}
}
pub fn choices(&self) -> &[String] {
match self {
Question::SingleChoice { choices, .. } => choices,
Question::MultiChoice { choices, .. } => choices,
Question::FreeText { .. } => &[],
}
}
}
/// The user's answer to one question.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Answer {
    /// Text of the question being answered; must match a [`Question::text`] in the request.
    pub question: String,
    /// Selected choices or typed text. The user may enter values outside the offered choices.
    pub answer: Vec<String>,
}
/// Machine-readable reason a question or answer failed validation.
///
/// Serialized in `snake_case` (e.g. `"required_field_empty"`), the same text its `Display`
/// impl produces.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ValidationErrorCode {
    /// A required question has no answer containing a non-empty string.
    RequiredFieldEmpty,
    /// A `SingleChoice` question received more than one selection.
    TooManySelections,
    /// A `SingleChoice` or `MultiChoice` question was declared with no choices.
    EmptyChoices,
    /// An answer refers to question text that is not part of the request.
    UnknownQuestion,
}
impl std::fmt::Display for ValidationErrorCode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ValidationErrorCode::RequiredFieldEmpty => write!(f, "required_field_empty"),
ValidationErrorCode::TooManySelections => write!(f, "too_many_selections"),
ValidationErrorCode::EmptyChoices => write!(f, "empty_choices"),
ValidationErrorCode::UnknownQuestion => write!(f, "unknown_question"),
}
}
}
/// One validation problem, tied to the question text it concerns.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ValidationErrorDetail {
    /// Text of the offending question (or of the answer's question for unknown questions).
    pub question: String,
    /// What was wrong with it.
    pub error: ValidationErrorCode,
}
/// Aggregate validation failure carrying every problem found in one pass.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ValidationError {
    /// Top-level error tag; [`ValidationError::new`] sets it to `"validation_failed"`.
    pub error: String,
    /// Individual problems, in the order they were detected.
    pub details: Vec<ValidationErrorDetail>,
}
impl ValidationError {
    /// Wraps `details` in a `ValidationError` tagged `"validation_failed"`.
    pub fn new(details: Vec<ValidationErrorDetail>) -> Self {
        let error = String::from("validation_failed");
        Self { error, details }
    }
}
impl std::fmt::Display for ValidationError {
    /// Formats as `Validation failed: 'Q1': code1, 'Q2': code2`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Validation failed: ")?;
        // Empty before the first detail, ", " before every later one.
        let mut separator = "";
        for detail in &self.details {
            write!(f, "{}'{}': {}", separator, detail.question, detail.error)?;
            separator = ", ";
        }
        Ok(())
    }
}

impl std::error::Error for ValidationError {}
/// The batch of questions sent to the user in a single tool call.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AskUserQuestionsRequest {
    /// Questions in the order they were supplied.
    pub questions: Vec<Question>,
}
impl AskUserQuestionsRequest {
pub fn validate(&self) -> Result<(), ValidationError> {
let mut errors = Vec::new();
for question in &self.questions {
match question {
Question::SingleChoice { text, choices, .. }
| Question::MultiChoice { text, choices, .. } => {
if choices.is_empty() {
errors.push(ValidationErrorDetail {
question: text.clone(),
error: ValidationErrorCode::EmptyChoices,
});
}
}
Question::FreeText { .. } => {
}
}
}
if errors.is_empty() {
Ok(())
} else {
Err(ValidationError::new(errors))
}
}
}
/// The answers collected from the user, each tied to a question by its text.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AskUserQuestionsResponse {
    /// Answers as returned by the user interface.
    pub answers: Vec<Answer>,
}
impl AskUserQuestionsResponse {
    /// Checks these answers against the request they respond to.
    ///
    /// Problems are reported as follows:
    /// - [`ValidationErrorCode::UnknownQuestion`] for each answer whose `question` text matches
    ///   no question in `request`;
    /// - [`ValidationErrorCode::TooManySelections`], once per `SingleChoice` question, as soon as
    ///   the selections given for it exceed one. Selections are counted across *all* answers
    ///   naming that question, so splitting them over several `Answer` entries is also caught;
    /// - [`ValidationErrorCode::RequiredFieldEmpty`] for each required question with no answer
    ///   containing at least one non-empty string.
    ///
    /// Answers outside the offered choices are accepted: the user may always type a custom
    /// answer.
    ///
    /// # Errors
    ///
    /// Returns a [`ValidationError`] listing every problem found.
    pub fn validate(&self, request: &AskUserQuestionsRequest) -> Result<(), ValidationError> {
        let mut errors = Vec::new();
        // Answers identify questions by text. With duplicate texts the last question wins.
        let questions: HashMap<&str, &Question> =
            request.questions.iter().map(|q| (q.text(), q)).collect();
        // Running number of selections per known question text.
        let mut selections: HashMap<&str, usize> = HashMap::new();
        // Known question texts that received at least one non-empty answer string.
        let mut answered: std::collections::HashSet<&str> = std::collections::HashSet::new();

        for answer in &self.answers {
            let key = answer.question.as_str();
            let Some(question) = questions.get(key) else {
                errors.push(ValidationErrorDetail {
                    question: answer.question.clone(),
                    error: ValidationErrorCode::UnknownQuestion,
                });
                continue;
            };

            let total = selections.entry(key).or_insert(0);
            let before = *total;
            *total += answer.answer.len();
            // Report only on the answer that pushes the total past one, so a question is
            // flagged once no matter how many answer entries it has.
            if matches!(question, Question::SingleChoice { .. }) && before <= 1 && *total > 1 {
                errors.push(ValidationErrorDetail {
                    question: answer.question.clone(),
                    error: ValidationErrorCode::TooManySelections,
                });
            }

            if answer.answer.iter().any(|s| !s.is_empty()) {
                answered.insert(key);
            }
        }

        for question in &request.questions {
            let question_text = question.text();
            if question.is_required() && !answered.contains(question_text) {
                errors.push(ValidationErrorDetail {
                    question: question_text.to_string(),
                    error: ValidationErrorCode::RequiredFieldEmpty,
                });
            }
        }

        if errors.is_empty() {
            Ok(())
        } else {
            Err(ValidationError::new(errors))
        }
    }
}
/// Tool that asks the user structured questions and returns their answers.
///
/// Pending interactions are registered with the shared [`UserInteractionRegistry`], which
/// yields the user's response.
pub struct AskUserQuestionsTool {
    registry: Arc<UserInteractionRegistry>,
}

impl AskUserQuestionsTool {
    /// Creates the tool around a shared interaction registry.
    pub fn new(registry: Arc<UserInteractionRegistry>) -> Self {
        AskUserQuestionsTool { registry }
    }
}
impl Executable for AskUserQuestionsTool {
fn name(&self) -> &str {
ASK_USER_QUESTIONS_TOOL_NAME
}
fn description(&self) -> &str {
ASK_USER_QUESTIONS_TOOL_DESCRIPTION
}
fn input_schema(&self) -> &str {
ASK_USER_QUESTIONS_TOOL_SCHEMA
}
fn tool_type(&self) -> ToolType {
ToolType::UserInteraction
}
fn execute(
&self,
context: ToolContext,
input: HashMap<String, serde_json::Value>,
) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
let registry = self.registry.clone();
Box::pin(async move {
let questions_value = input
.get("questions")
.ok_or_else(|| "Missing 'questions' field".to_string())?;
let questions: Vec<Question> = serde_json::from_value(questions_value.clone())
.map_err(|e| format!("Failed to parse questions: {}", e))?;
let request = AskUserQuestionsRequest { questions };
if let Err(validation_error) = request.validate() {
return Err(serde_json::to_string(&validation_error)
.unwrap_or_else(|_| validation_error.to_string()));
}
let rx = registry
.register(
context.tool_use_id,
context.session_id,
request.clone(),
context.turn_id,
)
.await
.map_err(|e| format!("Failed to register interaction: {}", e))?;
let response = rx
.await
.map_err(|_| "User declined to answer".to_string())?;
if let Err(validation_error) = response.validate(&request) {
return Err(serde_json::to_string(&validation_error)
.unwrap_or_else(|_| validation_error.to_string()));
}
serde_json::to_string(&response)
.map_err(|e| format!("Failed to serialize response: {}", e))
})
}
fn display_config(&self) -> DisplayConfig {
DisplayConfig {
display_name: "Ask User Questions".to_string(),
display_title: Box::new(|input| {
input
.get("questions")
.and_then(|v| v.as_array())
.map(|arr| {
if arr.len() == 1 {
"1 question".to_string()
} else {
format!("{} questions", arr.len())
}
})
.unwrap_or_default()
}),
display_content: Box::new(|input, _result| {
let content = input
.get("questions")
.and_then(|v| v.as_array())
.map(|questions| {
questions
.iter()
.filter_map(|q| q.get("text").and_then(|t| t.as_str()))
.collect::<Vec<_>>()
.join("\n")
})
.unwrap_or_default();
DisplayResult {
content,
content_type: ResultContentType::PlainText,
is_truncated: false,
full_length: 0,
}
}),
}
}
fn compact_summary(&self, input: &HashMap<String, serde_json::Value>, _result: &str) -> String {
let count = input
.get("questions")
.and_then(|v| v.as_array())
.map(|arr| arr.len())
.unwrap_or(0);
format!("[AskUserQuestions: {} question(s)]", count)
}
fn handles_own_permissions(&self) -> bool {
true }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_single_choice_question() {
        let json = r#"{
            "type": "SingleChoice",
            "text": "Which database?",
            "choices": ["PostgreSQL", "MySQL", "SQLite"],
            "required": true
        }"#;
        let question: Question = serde_json::from_str(json).unwrap();
        assert_eq!(question.text(), "Which database?");
        assert!(question.is_required());
        if let Question::SingleChoice { choices, .. } = question {
            assert_eq!(choices.len(), 3);
            assert_eq!(choices[0], "PostgreSQL");
        } else {
            panic!("Expected SingleChoice");
        }
    }

    #[test]
    fn test_parse_multi_choice_question() {
        let json = r#"{
            "type": "MultiChoice",
            "text": "Which features?",
            "choices": ["Authentication", "Logging", "Caching"],
            "required": false
        }"#;
        let question: Question = serde_json::from_str(json).unwrap();
        assert_eq!(question.text(), "Which features?");
        assert!(!question.is_required());
        if let Question::MultiChoice { choices, .. } = question {
            assert_eq!(choices.len(), 3);
        } else {
            panic!("Expected MultiChoice");
        }
    }

    #[test]
    fn test_parse_free_text_question() {
        let json = r#"{
            "type": "FreeText",
            "text": "Any notes?",
            "defaultValue": "None",
            "required": false
        }"#;
        let question: Question = serde_json::from_str(json).unwrap();
        assert_eq!(question.text(), "Any notes?");
        if let Question::FreeText { default_value, .. } = question {
            assert_eq!(default_value, Some("None".to_string()));
        } else {
            panic!("Expected FreeText");
        }
    }

    #[test]
    fn test_parse_free_text_defaults() {
        // `required` and `defaultValue` are optional in the schema.
        let question: Question =
            serde_json::from_str(r#"{"type": "FreeText", "text": "Notes?"}"#).unwrap();
        assert_eq!(
            question,
            Question::FreeText {
                text: "Notes?".to_string(),
                default_value: None,
                required: false,
            }
        );
        assert!(question.choices().is_empty());
    }

    #[test]
    fn test_validate_request_empty_choices() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Question?".to_string(),
                choices: vec![],
                required: true,
            }],
        };
        let err = request.validate().unwrap_err();
        assert_eq!(err.details.len(), 1);
        assert_eq!(err.details[0].error, ValidationErrorCode::EmptyChoices);
    }

    #[test]
    fn test_validate_request_free_text_needs_no_choices() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::FreeText {
                text: "Notes?".to_string(),
                default_value: None,
                required: true,
            }],
        };
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_validate_response_too_many_selections() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Question?".to_string(),
                choices: vec!["A".to_string(), "B".to_string()],
                required: true,
            }],
        };
        let response = AskUserQuestionsResponse {
            answers: vec![Answer {
                question: "Question?".to_string(),
                answer: vec!["A".to_string(), "B".to_string()],
            }],
        };
        let err = response.validate(&request).unwrap_err();
        assert!(
            err.details
                .iter()
                .any(|d| d.error == ValidationErrorCode::TooManySelections)
        );
    }

    #[test]
    fn test_validate_response_required_field_empty() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Question?".to_string(),
                choices: vec!["A".to_string()],
                required: true,
            }],
        };
        let response = AskUserQuestionsResponse { answers: vec![] };
        let err = response.validate(&request).unwrap_err();
        assert!(
            err.details
                .iter()
                .any(|d| d.error == ValidationErrorCode::RequiredFieldEmpty)
        );
    }

    #[test]
    fn test_validate_response_required_answer_with_only_empty_strings() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Question?".to_string(),
                choices: vec!["A".to_string()],
                required: true,
            }],
        };
        let response = AskUserQuestionsResponse {
            answers: vec![Answer {
                question: "Question?".to_string(),
                answer: vec![String::new()],
            }],
        };
        let err = response.validate(&request).unwrap_err();
        assert_eq!(
            err.details,
            vec![ValidationErrorDetail {
                question: "Question?".to_string(),
                error: ValidationErrorCode::RequiredFieldEmpty,
            }]
        );
    }

    #[test]
    fn test_validate_response_unknown_question() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Question?".to_string(),
                choices: vec!["A".to_string()],
                required: false,
            }],
        };
        let response = AskUserQuestionsResponse {
            answers: vec![Answer {
                question: "Unknown question?".to_string(),
                answer: vec!["A".to_string()],
            }],
        };
        let err = response.validate(&request).unwrap_err();
        assert!(
            err.details
                .iter()
                .any(|d| d.error == ValidationErrorCode::UnknownQuestion)
        );
    }

    #[test]
    fn test_validate_response_success() {
        let request = AskUserQuestionsRequest {
            questions: vec![
                Question::SingleChoice {
                    text: "Question 1?".to_string(),
                    choices: vec!["A".to_string(), "B".to_string()],
                    required: true,
                },
                Question::MultiChoice {
                    text: "Question 2?".to_string(),
                    choices: vec!["X".to_string(), "Y".to_string()],
                    required: false,
                },
                Question::FreeText {
                    text: "Question 3?".to_string(),
                    default_value: None,
                    required: false,
                },
            ],
        };
        let response = AskUserQuestionsResponse {
            answers: vec![
                Answer {
                    question: "Question 1?".to_string(),
                    answer: vec!["A".to_string()],
                },
                Answer {
                    question: "Question 2?".to_string(),
                    answer: vec!["X".to_string(), "Y".to_string()],
                },
                Answer {
                    question: "Question 3?".to_string(),
                    answer: vec!["Some notes".to_string()],
                },
            ],
        };
        assert!(response.validate(&request).is_ok());
    }

    #[test]
    fn test_answer_serialization() {
        let answer = Answer {
            question: "Which database?".to_string(),
            answer: vec!["PostgreSQL".to_string()],
        };
        let json = serde_json::to_string(&answer).unwrap();
        // Pin the exact JSON shape; substring checks would accept almost any output.
        assert_eq!(
            json,
            r#"{"question":"Which database?","answer":["PostgreSQL"]}"#
        );
        let round_trip: Answer = serde_json::from_str(&json).unwrap();
        assert_eq!(round_trip, answer);
    }

    #[test]
    fn test_custom_answer_allowed() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Which database?".to_string(),
                choices: vec!["PostgreSQL".to_string(), "MySQL".to_string()],
                required: true,
            }],
        };
        let response = AskUserQuestionsResponse {
            answers: vec![Answer {
                question: "Which database?".to_string(),
                // Not one of the offered choices: custom answers are allowed.
                answer: vec!["MongoDB".to_string()],
            }],
        };
        assert!(response.validate(&request).is_ok());
    }

    #[test]
    fn test_validation_error_display_and_json() {
        let err = ValidationError::new(vec![
            ValidationErrorDetail {
                question: "A?".to_string(),
                error: ValidationErrorCode::RequiredFieldEmpty,
            },
            ValidationErrorDetail {
                question: "B?".to_string(),
                error: ValidationErrorCode::UnknownQuestion,
            },
        ]);
        assert_eq!(
            err.to_string(),
            "Validation failed: 'A?': required_field_empty, 'B?': unknown_question"
        );
        assert_eq!(
            serde_json::to_value(&err).unwrap(),
            serde_json::json!({
                "error": "validation_failed",
                "details": [
                    {"question": "A?", "error": "required_field_empty"},
                    {"question": "B?", "error": "unknown_question"}
                ]
            })
        );
    }

    #[test]
    fn test_error_code_display_matches_serde() {
        for code in [
            ValidationErrorCode::RequiredFieldEmpty,
            ValidationErrorCode::TooManySelections,
            ValidationErrorCode::EmptyChoices,
            ValidationErrorCode::UnknownQuestion,
        ] {
            assert_eq!(serde_json::to_string(&code).unwrap(), format!("\"{}\"", code));
        }
    }
}