1use std::collections::HashMap;
7use std::future::Future;
8use std::pin::Pin;
9
10use serde::{Deserialize, Serialize};
11
12use std::sync::Arc;
13
14use super::types::{DisplayConfig, DisplayResult, Executable, ResultContentType, ToolContext, ToolType};
15use super::user_interaction::UserInteractionRegistry;
16
/// Name under which this tool is exposed to the model.
pub const ASK_USER_QUESTIONS_TOOL_NAME: &str = "ask_user_questions";

/// Human-readable description of the tool, advertised alongside its schema.
pub const ASK_USER_QUESTIONS_TOOL_DESCRIPTION: &str =
    "Ask the user one or more questions with structured response options. \
     Supports single choice, multiple choice, and free text question types.";

/// JSON Schema for the tool input: an object with a `questions` array.
///
/// Each question is tagged by `type` and must carry `text`. `SingleChoice` and
/// `MultiChoice` questions are rejected unless they also provide a non-empty
/// `choices` list (either when parsing or by `AskUserQuestionsRequest::validate`).
/// Answers are matched back to questions by `text`, so duplicate texts are
/// ambiguous. The field descriptions spell both rules out for the model.
pub const ASK_USER_QUESTIONS_TOOL_SCHEMA: &str = r#"{
  "type": "object",
  "properties": {
    "questions": {
      "type": "array",
      "description": "List of questions to ask the user",
      "items": {
        "type": "object",
        "properties": {
          "text": {
            "type": "string",
            "description": "The question text to display. Must be unique within the request, since answers refer to questions by their text."
          },
          "type": {
            "type": "string",
            "enum": ["SingleChoice", "MultiChoice", "FreeText"],
            "description": "The type of question"
          },
          "choices": {
            "type": "array",
            "description": "Available choices for SingleChoice/MultiChoice. Required and must be non-empty for those types; ignored for FreeText. User can always type a custom answer instead.",
            "items": {
              "type": "string",
              "description": "Choice text to display"
            }
          },
          "required": {
            "type": "boolean",
            "description": "Whether an answer is required (defaults to false)"
          },
          "defaultValue": {
            "type": "string",
            "description": "Default value for FreeText questions"
          }
        },
        "required": ["text", "type"]
      }
    }
  },
  "required": ["questions"]
}"#;
67
68#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
70#[serde(tag = "type")]
71pub enum Question {
72 SingleChoice {
75 text: String,
77 choices: Vec<String>,
79 #[serde(default)]
81 required: bool,
82 },
83 MultiChoice {
86 text: String,
88 choices: Vec<String>,
90 #[serde(default)]
92 required: bool,
93 },
94 FreeText {
96 text: String,
98 #[serde(default, rename = "defaultValue")]
100 default_value: Option<String>,
101 #[serde(default)]
103 required: bool,
104 },
105}
106
107impl Question {
108 pub fn text(&self) -> &str {
110 match self {
111 Question::SingleChoice { text, .. } => text,
112 Question::MultiChoice { text, .. } => text,
113 Question::FreeText { text, .. } => text,
114 }
115 }
116
117 pub fn is_required(&self) -> bool {
119 match self {
120 Question::SingleChoice { required, .. } => *required,
121 Question::MultiChoice { required, .. } => *required,
122 Question::FreeText { required, .. } => *required,
123 }
124 }
125
126 pub fn choices(&self) -> &[String] {
128 match self {
129 Question::SingleChoice { choices, .. } => choices,
130 Question::MultiChoice { choices, .. } => choices,
131 Question::FreeText { .. } => &[],
132 }
133 }
134}
135
/// The user's answer to one question, as carried in an `AskUserQuestionsResponse`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Answer {
    /// Text of the question being answered. It must match a `Question::text()`
    /// in the request, otherwise validation reports `UnknownQuestion`.
    pub question: String,
    /// Selected choice(s) or typed text. Values outside the offered choices are
    /// not rejected by validation (custom answers are allowed).
    pub answer: Vec<String>,
}
145
/// Machine-readable reason a question or answer failed validation.
///
/// Serialized in `snake_case` (e.g. `"required_field_empty"`), the same text
/// the `Display` impl writes.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ValidationErrorCode {
    /// A required question has no answer containing a non-empty value.
    RequiredFieldEmpty,
    /// A single-choice question received more than one selection.
    TooManySelections,
    /// A single- or multi-choice question was defined without any choices.
    EmptyChoices,
    /// An answer names a question text that is not in the request.
    UnknownQuestion,
}
159
160impl std::fmt::Display for ValidationErrorCode {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 match self {
163 ValidationErrorCode::RequiredFieldEmpty => write!(f, "required_field_empty"),
164 ValidationErrorCode::TooManySelections => write!(f, "too_many_selections"),
165 ValidationErrorCode::EmptyChoices => write!(f, "empty_choices"),
166 ValidationErrorCode::UnknownQuestion => write!(f, "unknown_question"),
167 }
168 }
169}
170
/// One validation problem, tied to the question text it concerns.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ValidationErrorDetail {
    /// Text of the offending question, or the unrecognized question text an
    /// answer referred to.
    pub question: String,
    /// What went wrong.
    pub error: ValidationErrorCode,
}
179
/// Aggregate validation failure produced by the request/response `validate`
/// methods.
///
/// When the request or response fails validation, the tool returns this value
/// serialized to JSON as its error string.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ValidationError {
    /// Error tag; `ValidationError::new` sets it to `"validation_failed"`.
    pub error: String,
    /// Every problem found, in the order the validator detected them.
    pub details: Vec<ValidationErrorDetail>,
}
188
189impl ValidationError {
190 pub fn new(details: Vec<ValidationErrorDetail>) -> Self {
192 Self {
193 error: "validation_failed".to_string(),
194 details,
195 }
196 }
197}
198
199impl std::fmt::Display for ValidationError {
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 write!(f, "Validation failed: ")?;
202 for (i, detail) in self.details.iter().enumerate() {
203 if i > 0 {
204 write!(f, ", ")?;
205 }
206 write!(f, "'{}': {}", detail.question, detail.error)?;
207 }
208 Ok(())
209 }
210}
211
212impl std::error::Error for ValidationError {}
213
/// The set of questions presented to the user in one interaction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AskUserQuestionsRequest {
    /// Questions parsed from the tool input; answers refer to them by text.
    pub questions: Vec<Question>,
}
220
221impl AskUserQuestionsRequest {
222 pub fn validate(&self) -> Result<(), ValidationError> {
224 let mut errors = Vec::new();
225
226 for question in &self.questions {
227 match question {
228 Question::SingleChoice { text, choices, .. }
229 | Question::MultiChoice { text, choices, .. } => {
230 if choices.is_empty() {
232 errors.push(ValidationErrorDetail {
233 question: text.clone(),
234 error: ValidationErrorCode::EmptyChoices,
235 });
236 }
237 }
238 Question::FreeText { .. } => {
239 }
241 }
242 }
243
244 if errors.is_empty() {
245 Ok(())
246 } else {
247 Err(ValidationError::new(errors))
248 }
249 }
250}
251
/// The user's answers to an `AskUserQuestionsRequest`. After validation it is
/// serialized to JSON as the tool's successful result.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AskUserQuestionsResponse {
    /// One entry per answered question, each naming its question by text.
    pub answers: Vec<Answer>,
}
258
259impl AskUserQuestionsResponse {
260 pub fn validate(&self, request: &AskUserQuestionsRequest) -> Result<(), ValidationError> {
262 let mut errors = Vec::new();
263
264 let questions: HashMap<&str, &Question> =
266 request.questions.iter().map(|q| (q.text(), q)).collect();
267
268 let mut answered: std::collections::HashSet<&str> = std::collections::HashSet::new();
270
271 for answer in &self.answers {
272 let Some(question) = questions.get(answer.question.as_str()) else {
274 errors.push(ValidationErrorDetail {
275 question: answer.question.clone(),
276 error: ValidationErrorCode::UnknownQuestion,
277 });
278 continue;
279 };
280
281 answered.insert(answer.question.as_str());
282
283 if let Question::SingleChoice { .. } = question {
285 if answer.answer.len() > 1 {
286 errors.push(ValidationErrorDetail {
287 question: answer.question.clone(),
288 error: ValidationErrorCode::TooManySelections,
289 });
290 }
291 }
292 }
293
294 for question in &request.questions {
296 let question_text = question.text();
297 if question.is_required() {
298 let has_valid_answer = self.answers.iter().any(|a| {
300 a.question == question_text && !a.answer.is_empty() && a.answer.iter().any(|s| !s.is_empty())
301 });
302
303 if !has_valid_answer {
304 errors.push(ValidationErrorDetail {
305 question: question_text.to_string(),
306 error: ValidationErrorCode::RequiredFieldEmpty,
307 });
308 }
309 }
310 }
311
312 if errors.is_empty() {
313 Ok(())
314 } else {
315 Err(ValidationError::new(errors))
316 }
317 }
318}
319
320pub struct AskUserQuestionsTool {
322 registry: Arc<UserInteractionRegistry>,
324}
325
326impl AskUserQuestionsTool {
327 pub fn new(registry: Arc<UserInteractionRegistry>) -> Self {
332 Self { registry }
333 }
334}
335
336impl Executable for AskUserQuestionsTool {
337 fn name(&self) -> &str {
338 ASK_USER_QUESTIONS_TOOL_NAME
339 }
340
341 fn description(&self) -> &str {
342 ASK_USER_QUESTIONS_TOOL_DESCRIPTION
343 }
344
345 fn input_schema(&self) -> &str {
346 ASK_USER_QUESTIONS_TOOL_SCHEMA
347 }
348
349 fn tool_type(&self) -> ToolType {
350 ToolType::UserInteraction
351 }
352
353 fn execute(
354 &self,
355 context: ToolContext,
356 input: HashMap<String, serde_json::Value>,
357 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
358 let registry = self.registry.clone();
359
360 Box::pin(async move {
361 let questions_value = input
363 .get("questions")
364 .ok_or_else(|| "Missing 'questions' field".to_string())?;
365
366 let questions: Vec<Question> = serde_json::from_value(questions_value.clone())
367 .map_err(|e| format!("Failed to parse questions: {}", e))?;
368
369 let request = AskUserQuestionsRequest { questions };
370
371 if let Err(validation_error) = request.validate() {
373 return Err(serde_json::to_string(&validation_error)
374 .unwrap_or_else(|_| validation_error.to_string()));
375 }
376
377 let rx = registry
379 .register(
380 context.tool_use_id,
381 context.session_id,
382 request.clone(),
383 context.turn_id,
384 )
385 .await
386 .map_err(|e| format!("Failed to register interaction: {}", e))?;
387
388 let response = rx
390 .await
391 .map_err(|_| "User declined to answer".to_string())?;
392
393 if let Err(validation_error) = response.validate(&request) {
395 return Err(serde_json::to_string(&validation_error)
396 .unwrap_or_else(|_| validation_error.to_string()));
397 }
398
399 serde_json::to_string(&response)
401 .map_err(|e| format!("Failed to serialize response: {}", e))
402 })
403 }
404
405 fn display_config(&self) -> DisplayConfig {
406 DisplayConfig {
407 display_name: "Ask User Questions".to_string(),
408 display_title: Box::new(|input| {
409 input
410 .get("questions")
411 .and_then(|v| v.as_array())
412 .map(|arr| {
413 if arr.len() == 1 {
414 "1 question".to_string()
415 } else {
416 format!("{} questions", arr.len())
417 }
418 })
419 .unwrap_or_default()
420 }),
421 display_content: Box::new(|input, _result| {
422 let content = input
423 .get("questions")
424 .and_then(|v| v.as_array())
425 .map(|questions| {
426 questions
427 .iter()
428 .filter_map(|q| q.get("text").and_then(|t| t.as_str()))
429 .collect::<Vec<_>>()
430 .join("\n")
431 })
432 .unwrap_or_default();
433
434 DisplayResult {
435 content,
436 content_type: ResultContentType::PlainText,
437 is_truncated: false,
438 full_length: 0,
439 }
440 }),
441 }
442 }
443
444 fn compact_summary(
445 &self,
446 input: &HashMap<String, serde_json::Value>,
447 _result: &str,
448 ) -> String {
449 let count = input
450 .get("questions")
451 .and_then(|v| v.as_array())
452 .map(|arr| arr.len())
453 .unwrap_or(0);
454 format!("[AskUserQuestions: {} question(s)]", count)
455 }
456}
457
// Unit tests for question parsing, request/response validation and answer
// serialization.
#[cfg(test)]
mod tests {
    use super::*;

    // A tagged SingleChoice question deserializes with its choices and flags.
    #[test]
    fn test_parse_single_choice_question() {
        let json = r#"{
            "type": "SingleChoice",
            "text": "Which database?",
            "choices": ["PostgreSQL", "MySQL", "SQLite"],
            "required": true
        }"#;

        let question: Question = serde_json::from_str(json).unwrap();
        assert_eq!(question.text(), "Which database?");
        assert!(question.is_required());

        if let Question::SingleChoice { choices, .. } = question {
            assert_eq!(choices.len(), 3);
            assert_eq!(choices[0], "PostgreSQL");
        } else {
            panic!("Expected SingleChoice");
        }
    }

    // A tagged MultiChoice question deserializes with `required: false`.
    #[test]
    fn test_parse_multi_choice_question() {
        let json = r#"{
            "type": "MultiChoice",
            "text": "Which features?",
            "choices": ["Authentication", "Logging", "Caching"],
            "required": false
        }"#;

        let question: Question = serde_json::from_str(json).unwrap();
        assert_eq!(question.text(), "Which features?");
        assert!(!question.is_required());

        if let Question::MultiChoice { choices, .. } = question {
            assert_eq!(choices.len(), 3);
        } else {
            panic!("Expected MultiChoice");
        }
    }

    // FreeText reads its default from the camelCase `defaultValue` key.
    #[test]
    fn test_parse_free_text_question() {
        let json = r#"{
            "type": "FreeText",
            "text": "Any notes?",
            "defaultValue": "None",
            "required": false
        }"#;

        let question: Question = serde_json::from_str(json).unwrap();
        assert_eq!(question.text(), "Any notes?");

        if let Question::FreeText { default_value, .. } = question {
            assert_eq!(default_value, Some("None".to_string()));
        } else {
            panic!("Expected FreeText");
        }
    }

    // A choice question without choices is rejected with EmptyChoices.
    #[test]
    fn test_validate_request_empty_choices() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Question?".to_string(),
                choices: vec![],
                required: true,
            }],
        };

        let err = request.validate().unwrap_err();
        assert_eq!(err.details.len(), 1);
        assert_eq!(err.details[0].error, ValidationErrorCode::EmptyChoices);
    }

    // Two selections for a SingleChoice question yield TooManySelections.
    #[test]
    fn test_validate_response_too_many_selections() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Question?".to_string(),
                choices: vec!["A".to_string(), "B".to_string()],
                required: true,
            }],
        };

        let response = AskUserQuestionsResponse {
            answers: vec![Answer {
                question: "Question?".to_string(),
                answer: vec!["A".to_string(), "B".to_string()],
            }],
        };

        let err = response.validate(&request).unwrap_err();
        assert!(err
            .details
            .iter()
            .any(|d| d.error == ValidationErrorCode::TooManySelections));
    }

    // A required question with no answer yields RequiredFieldEmpty.
    #[test]
    fn test_validate_response_required_field_empty() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Question?".to_string(),
                choices: vec!["A".to_string()],
                required: true,
            }],
        };

        let response = AskUserQuestionsResponse { answers: vec![] };

        let err = response.validate(&request).unwrap_err();
        assert!(err
            .details
            .iter()
            .any(|d| d.error == ValidationErrorCode::RequiredFieldEmpty));
    }

    // An answer naming a question not in the request yields UnknownQuestion.
    #[test]
    fn test_validate_response_unknown_question() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Question?".to_string(),
                choices: vec!["A".to_string()],
                required: false,
            }],
        };

        let response = AskUserQuestionsResponse {
            answers: vec![Answer {
                question: "Unknown question?".to_string(),
                answer: vec!["A".to_string()],
            }],
        };

        let err = response.validate(&request).unwrap_err();
        assert!(err
            .details
            .iter()
            .any(|d| d.error == ValidationErrorCode::UnknownQuestion));
    }

    // Well-formed answers for all three question kinds validate cleanly.
    #[test]
    fn test_validate_response_success() {
        let request = AskUserQuestionsRequest {
            questions: vec![
                Question::SingleChoice {
                    text: "Question 1?".to_string(),
                    choices: vec!["A".to_string(), "B".to_string()],
                    required: true,
                },
                Question::MultiChoice {
                    text: "Question 2?".to_string(),
                    choices: vec!["X".to_string(), "Y".to_string()],
                    required: false,
                },
                Question::FreeText {
                    text: "Question 3?".to_string(),
                    default_value: None,
                    required: false,
                },
            ],
        };

        let response = AskUserQuestionsResponse {
            answers: vec![
                Answer {
                    question: "Question 1?".to_string(),
                    answer: vec!["A".to_string()],
                },
                Answer {
                    question: "Question 2?".to_string(),
                    answer: vec!["X".to_string(), "Y".to_string()],
                },
                Answer {
                    question: "Question 3?".to_string(),
                    answer: vec!["Some notes".to_string()],
                },
            ],
        };

        assert!(response.validate(&request).is_ok());
    }

    // Serialized answers carry both field names and the selected value.
    #[test]
    fn test_answer_serialization() {
        let answer = Answer {
            question: "Which database?".to_string(),
            answer: vec!["PostgreSQL".to_string()],
        };

        let json = serde_json::to_string(&answer).unwrap();
        assert!(json.contains("question"));
        assert!(json.contains("answer"));
        assert!(json.contains("PostgreSQL"));
    }

    // A value outside the offered choices is accepted as a custom answer.
    #[test]
    fn test_custom_answer_allowed() {
        let request = AskUserQuestionsRequest {
            questions: vec![Question::SingleChoice {
                text: "Which database?".to_string(),
                choices: vec!["PostgreSQL".to_string(), "MySQL".to_string()],
                required: true,
            }],
        };

        let response = AskUserQuestionsResponse {
            answers: vec![Answer {
                question: "Which database?".to_string(),
                answer: vec!["MongoDB".to_string()],
            }],
        };

        assert!(response.validate(&request).is_ok());
    }
}