1use std::collections::HashMap;
7use std::future::Future;
8use std::pin::Pin;
9
10use serde::{Deserialize, Serialize};
11
12use std::sync::Arc;
13
14use super::types::{
15 DisplayConfig, DisplayResult, Executable, ResultContentType, ToolContext, ToolType,
16};
17use super::user_interaction::UserInteractionRegistry;
18
/// Identifier of the ask-user-questions tool; this is the value returned by
/// `AskUserQuestionsTool`'s `Executable::name` implementation.
pub const ASK_USER_QUESTIONS_TOOL_NAME: &str = "ask_user_questions";
21
/// Human-readable description of the tool; this is the value returned by
/// `AskUserQuestionsTool`'s `Executable::description` implementation.
///
/// The trailing `\` is a string continuation: the line break and the indentation that
/// follows it are not part of the resulting string.
pub const ASK_USER_QUESTIONS_TOOL_DESCRIPTION: &str = "Ask the user one or more questions with structured response options. \
    Supports single choice, multiple choice, and free text question types.";
25
/// JSON Schema (as a raw string) describing the tool input; this is the value returned by
/// `AskUserQuestionsTool`'s `Executable::input_schema` implementation.
///
/// The input is an object with a required `questions` array. Each item carries a `text`,
/// a `type` (`SingleChoice`, `MultiChoice` or `FreeText`) and, optionally, `choices`,
/// `required` and `defaultValue`.
///
/// NOTE(review): the schema lists only `text` and `type` as required per question, while
/// `Question::SingleChoice` / `Question::MultiChoice` declare `choices` without a serde
/// default, so a choice question that omits `choices` passes the schema but fails to
/// deserialize — confirm whether the schema should be tightened.
pub const ASK_USER_QUESTIONS_TOOL_SCHEMA: &str = r#"{
    "type": "object",
    "properties": {
        "questions": {
            "type": "array",
            "description": "List of questions to ask the user",
            "items": {
                "type": "object",
                "properties": {
                    "text": {
                        "type": "string",
                        "description": "The question text to display"
                    },
                    "type": {
                        "type": "string",
                        "enum": ["SingleChoice", "MultiChoice", "FreeText"],
                        "description": "The type of question"
                    },
                    "choices": {
                        "type": "array",
                        "description": "Available choices for SingleChoice/MultiChoice. User can always type a custom answer instead.",
                        "items": {
                            "type": "string",
                            "description": "Choice text to display"
                        }
                    },
                    "required": {
                        "type": "boolean",
                        "description": "Whether an answer is required"
                    },
                    "defaultValue": {
                        "type": "string",
                        "description": "Default value for FreeText questions"
                    }
                },
                "required": ["text", "type"]
            }
        }
    },
    "required": ["questions"]
}"#;
68
/// A single question to present to the user.
///
/// Serialized as an internally tagged enum: the JSON `type` field selects the variant
/// (`"SingleChoice"`, `"MultiChoice"` or `"FreeText"`) and the remaining fields sit next
/// to it in the same object.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum Question {
    /// A question with a list of suggested choices; response validation rejects answers
    /// that carry more than one value for this variant.
    SingleChoice {
        /// Question text; also the key used to match answers to this question.
        text: String,
        /// Suggested choices. Request validation rejects an empty list; response
        /// validation does not require the answer to be one of these values.
        choices: Vec<String>,
        /// Whether a non-empty answer must be supplied. Defaults to `false` when the
        /// field is absent from the input.
        #[serde(default)]
        required: bool,
    },
    /// A question with a list of suggested choices; response validation places no upper
    /// limit on the number of answer values for this variant.
    MultiChoice {
        /// Question text; also the key used to match answers to this question.
        text: String,
        /// Suggested choices. Request validation rejects an empty list.
        choices: Vec<String>,
        /// Whether a non-empty answer must be supplied. Defaults to `false` when the
        /// field is absent from the input.
        #[serde(default)]
        required: bool,
    },
    /// A question without choices.
    FreeText {
        /// Question text; also the key used to match answers to this question.
        text: String,
        /// Optional default value, (de)serialized under the JSON key `defaultValue`;
        /// `None` when the key is absent.
        #[serde(default, rename = "defaultValue")]
        default_value: Option<String>,
        /// Whether a non-empty answer must be supplied. Defaults to `false` when the
        /// field is absent from the input.
        #[serde(default)]
        required: bool,
    },
}
107
108impl Question {
109 pub fn text(&self) -> &str {
111 match self {
112 Question::SingleChoice { text, .. } => text,
113 Question::MultiChoice { text, .. } => text,
114 Question::FreeText { text, .. } => text,
115 }
116 }
117
118 pub fn is_required(&self) -> bool {
120 match self {
121 Question::SingleChoice { required, .. } => *required,
122 Question::MultiChoice { required, .. } => *required,
123 Question::FreeText { required, .. } => *required,
124 }
125 }
126
127 pub fn choices(&self) -> &[String] {
129 match self {
130 Question::SingleChoice { choices, .. } => choices,
131 Question::MultiChoice { choices, .. } => choices,
132 Question::FreeText { .. } => &[],
133 }
134 }
135}
136
/// The user's answer to one question.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Answer {
    /// Text of the question being answered; response validation looks the question up
    /// by comparing this with `Question::text`.
    pub question: String,
    /// Answer values. Response validation rejects more than one value for a
    /// `SingleChoice` question and treats an answer with no non-empty value as missing
    /// when the question is required.
    pub answer: Vec<String>,
}
146
/// Machine-readable reason a request or response failed validation.
///
/// Serialized in `snake_case` (e.g. `required_field_empty`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ValidationErrorCode {
    /// A required question has no answer containing a non-empty value.
    RequiredFieldEmpty,
    /// A `SingleChoice` question was answered with more than one value.
    TooManySelections,
    /// A `SingleChoice` / `MultiChoice` question was declared with no choices.
    EmptyChoices,
    /// An answer names a question text that is not part of the request.
    UnknownQuestion,
}
160
161impl std::fmt::Display for ValidationErrorCode {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 match self {
164 ValidationErrorCode::RequiredFieldEmpty => write!(f, "required_field_empty"),
165 ValidationErrorCode::TooManySelections => write!(f, "too_many_selections"),
166 ValidationErrorCode::EmptyChoices => write!(f, "empty_choices"),
167 ValidationErrorCode::UnknownQuestion => write!(f, "unknown_question"),
168 }
169 }
170}
171
/// One validation failure, tied to the question it concerns.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ValidationErrorDetail {
    /// Text of the offending question (for `UnknownQuestion`, the question text carried
    /// by the offending answer).
    pub question: String,
    /// Reason the validation failed.
    pub error: ValidationErrorCode,
}
180
/// Aggregate validation failure listing every problem that was found.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ValidationError {
    /// Error kind marker; `ValidationError::new` sets it to `"validation_failed"`.
    pub error: String,
    /// Individual failures, in the order they were detected.
    pub details: Vec<ValidationErrorDetail>,
}
189
190impl ValidationError {
191 pub fn new(details: Vec<ValidationErrorDetail>) -> Self {
193 Self {
194 error: "validation_failed".to_string(),
195 details,
196 }
197 }
198}
199
200impl std::fmt::Display for ValidationError {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 write!(f, "Validation failed: ")?;
203 for (i, detail) in self.details.iter().enumerate() {
204 if i > 0 {
205 write!(f, ", ")?;
206 }
207 write!(f, "'{}': {}", detail.question, detail.error)?;
208 }
209 Ok(())
210 }
211}
212
// Marker impl so `ValidationError` can be used wherever a `std::error::Error` is
// expected; the trait's default methods are sufficient (no underlying source error).
impl std::error::Error for ValidationError {}
214
/// The set of questions a single tool invocation asks the user.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AskUserQuestionsRequest {
    /// Questions to ask, in the order they were supplied.
    pub questions: Vec<Question>,
}
221
222impl AskUserQuestionsRequest {
223 pub fn validate(&self) -> Result<(), ValidationError> {
225 let mut errors = Vec::new();
226
227 for question in &self.questions {
228 match question {
229 Question::SingleChoice { text, choices, .. }
230 | Question::MultiChoice { text, choices, .. } => {
231 if choices.is_empty() {
233 errors.push(ValidationErrorDetail {
234 question: text.clone(),
235 error: ValidationErrorCode::EmptyChoices,
236 });
237 }
238 }
239 Question::FreeText { .. } => {
240 }
242 }
243 }
244
245 if errors.is_empty() {
246 Ok(())
247 } else {
248 Err(ValidationError::new(errors))
249 }
250 }
251}
252
/// The user's answers to an [`AskUserQuestionsRequest`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AskUserQuestionsResponse {
    /// Answers supplied by the user; each one names its question by text.
    pub answers: Vec<Answer>,
}
259
260impl AskUserQuestionsResponse {
261 pub fn validate(&self, request: &AskUserQuestionsRequest) -> Result<(), ValidationError> {
263 let mut errors = Vec::new();
264
265 let questions: HashMap<&str, &Question> =
267 request.questions.iter().map(|q| (q.text(), q)).collect();
268
269 let mut answered: std::collections::HashSet<&str> = std::collections::HashSet::new();
271
272 for answer in &self.answers {
273 let Some(question) = questions.get(answer.question.as_str()) else {
275 errors.push(ValidationErrorDetail {
276 question: answer.question.clone(),
277 error: ValidationErrorCode::UnknownQuestion,
278 });
279 continue;
280 };
281
282 answered.insert(answer.question.as_str());
283
284 if let Question::SingleChoice { .. } = question
286 && answer.answer.len() > 1
287 {
288 errors.push(ValidationErrorDetail {
289 question: answer.question.clone(),
290 error: ValidationErrorCode::TooManySelections,
291 });
292 }
293 }
294
295 for question in &request.questions {
297 let question_text = question.text();
298 if question.is_required() {
299 let has_valid_answer = self.answers.iter().any(|a| {
301 a.question == question_text
302 && !a.answer.is_empty()
303 && a.answer.iter().any(|s| !s.is_empty())
304 });
305
306 if !has_valid_answer {
307 errors.push(ValidationErrorDetail {
308 question: question_text.to_string(),
309 error: ValidationErrorCode::RequiredFieldEmpty,
310 });
311 }
312 }
313 }
314
315 if errors.is_empty() {
316 Ok(())
317 } else {
318 Err(ValidationError::new(errors))
319 }
320 }
321}
322
/// Tool that asks the user structured questions and returns the answers.
pub struct AskUserQuestionsTool {
    /// Shared registry with which `execute` registers each request and from which it
    /// receives the channel that later delivers the user's response.
    registry: Arc<UserInteractionRegistry>,
}
328
329impl AskUserQuestionsTool {
330 pub fn new(registry: Arc<UserInteractionRegistry>) -> Self {
335 Self { registry }
336 }
337}
338
339impl Executable for AskUserQuestionsTool {
340 fn name(&self) -> &str {
341 ASK_USER_QUESTIONS_TOOL_NAME
342 }
343
344 fn description(&self) -> &str {
345 ASK_USER_QUESTIONS_TOOL_DESCRIPTION
346 }
347
348 fn input_schema(&self) -> &str {
349 ASK_USER_QUESTIONS_TOOL_SCHEMA
350 }
351
352 fn tool_type(&self) -> ToolType {
353 ToolType::UserInteraction
354 }
355
356 fn execute(
357 &self,
358 context: ToolContext,
359 input: HashMap<String, serde_json::Value>,
360 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
361 let registry = self.registry.clone();
362
363 Box::pin(async move {
364 let questions_value = input
366 .get("questions")
367 .ok_or_else(|| "Missing 'questions' field".to_string())?;
368
369 let questions: Vec<Question> = serde_json::from_value(questions_value.clone())
370 .map_err(|e| format!("Failed to parse questions: {}", e))?;
371
372 let request = AskUserQuestionsRequest { questions };
373
374 if let Err(validation_error) = request.validate() {
376 return Err(serde_json::to_string(&validation_error)
377 .unwrap_or_else(|_| validation_error.to_string()));
378 }
379
380 let rx = registry
382 .register(
383 context.tool_use_id,
384 context.session_id,
385 request.clone(),
386 context.turn_id,
387 )
388 .await
389 .map_err(|e| format!("Failed to register interaction: {}", e))?;
390
391 let response = rx
393 .await
394 .map_err(|_| "User declined to answer".to_string())?;
395
396 if let Err(validation_error) = response.validate(&request) {
398 return Err(serde_json::to_string(&validation_error)
399 .unwrap_or_else(|_| validation_error.to_string()));
400 }
401
402 serde_json::to_string(&response)
404 .map_err(|e| format!("Failed to serialize response: {}", e))
405 })
406 }
407
408 fn display_config(&self) -> DisplayConfig {
409 DisplayConfig {
410 display_name: "Ask User Questions".to_string(),
411 display_title: Box::new(|input| {
412 input
413 .get("questions")
414 .and_then(|v| v.as_array())
415 .map(|arr| {
416 if arr.len() == 1 {
417 "1 question".to_string()
418 } else {
419 format!("{} questions", arr.len())
420 }
421 })
422 .unwrap_or_default()
423 }),
424 display_content: Box::new(|input, _result| {
425 let content = input
426 .get("questions")
427 .and_then(|v| v.as_array())
428 .map(|questions| {
429 questions
430 .iter()
431 .filter_map(|q| q.get("text").and_then(|t| t.as_str()))
432 .collect::<Vec<_>>()
433 .join("\n")
434 })
435 .unwrap_or_default();
436
437 DisplayResult {
438 content,
439 content_type: ResultContentType::PlainText,
440 is_truncated: false,
441 full_length: 0,
442 }
443 }),
444 }
445 }
446
447 fn compact_summary(&self, input: &HashMap<String, serde_json::Value>, _result: &str) -> String {
448 let count = input
449 .get("questions")
450 .and_then(|v| v.as_array())
451 .map(|arr| arr.len())
452 .unwrap_or(0);
453 format!("[AskUserQuestions: {} question(s)]", count)
454 }
455
456 fn handles_own_permissions(&self) -> bool {
457 true }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned strings the model types store.
    fn owned(values: &[&str]) -> Vec<String> {
        values.iter().map(|value| value.to_string()).collect()
    }

    /// Builds a `SingleChoice` question fixture.
    fn single_choice(text: &str, choices: &[&str], required: bool) -> Question {
        Question::SingleChoice {
            text: text.to_string(),
            choices: owned(choices),
            required,
        }
    }

    /// Builds an `Answer` fixture.
    fn answer(question: &str, values: &[&str]) -> Answer {
        Answer {
            question: question.to_string(),
            answer: owned(values),
        }
    }

    /// Reports whether `err` contains at least one detail with the given `code`.
    fn reports(err: &ValidationError, code: ValidationErrorCode) -> bool {
        err.details.iter().any(|detail| detail.error == code)
    }

    // --- Deserialization -------------------------------------------------------------

    #[test]
    fn test_parse_single_choice_question() {
        let json = r#"{
            "type": "SingleChoice",
            "text": "Which database?",
            "choices": ["PostgreSQL", "MySQL", "SQLite"],
            "required": true
        }"#;

        let question: Question = serde_json::from_str(json).unwrap();
        assert_eq!(question.text(), "Which database?");
        assert!(question.is_required());

        let Question::SingleChoice { choices, .. } = question else {
            panic!("Expected SingleChoice");
        };
        assert_eq!(choices.len(), 3);
        assert_eq!(choices[0], "PostgreSQL");
    }

    #[test]
    fn test_parse_multi_choice_question() {
        let json = r#"{
            "type": "MultiChoice",
            "text": "Which features?",
            "choices": ["Authentication", "Logging", "Caching"],
            "required": false
        }"#;

        let question: Question = serde_json::from_str(json).unwrap();
        assert_eq!(question.text(), "Which features?");
        assert!(!question.is_required());

        let Question::MultiChoice { choices, .. } = question else {
            panic!("Expected MultiChoice");
        };
        assert_eq!(choices.len(), 3);
    }

    #[test]
    fn test_parse_free_text_question() {
        let json = r#"{
            "type": "FreeText",
            "text": "Any notes?",
            "defaultValue": "None",
            "required": false
        }"#;

        let question: Question = serde_json::from_str(json).unwrap();
        assert_eq!(question.text(), "Any notes?");

        let Question::FreeText { default_value, .. } = question else {
            panic!("Expected FreeText");
        };
        assert_eq!(default_value, Some("None".to_string()));
    }

    // --- Request validation ----------------------------------------------------------

    #[test]
    fn test_validate_request_empty_choices() {
        let request = AskUserQuestionsRequest {
            questions: vec![single_choice("Question?", &[], true)],
        };

        let err = request.validate().unwrap_err();
        assert_eq!(err.details.len(), 1);
        assert_eq!(err.details[0].error, ValidationErrorCode::EmptyChoices);
    }

    // --- Response validation ---------------------------------------------------------

    #[test]
    fn test_validate_response_too_many_selections() {
        let request = AskUserQuestionsRequest {
            questions: vec![single_choice("Question?", &["A", "B"], true)],
        };
        let response = AskUserQuestionsResponse {
            answers: vec![answer("Question?", &["A", "B"])],
        };

        let err = response.validate(&request).unwrap_err();
        assert!(reports(&err, ValidationErrorCode::TooManySelections));
    }

    #[test]
    fn test_validate_response_required_field_empty() {
        let request = AskUserQuestionsRequest {
            questions: vec![single_choice("Question?", &["A"], true)],
        };
        let response = AskUserQuestionsResponse { answers: vec![] };

        let err = response.validate(&request).unwrap_err();
        assert!(reports(&err, ValidationErrorCode::RequiredFieldEmpty));
    }

    #[test]
    fn test_validate_response_unknown_question() {
        let request = AskUserQuestionsRequest {
            questions: vec![single_choice("Question?", &["A"], false)],
        };
        let response = AskUserQuestionsResponse {
            answers: vec![answer("Unknown question?", &["A"])],
        };

        let err = response.validate(&request).unwrap_err();
        assert!(reports(&err, ValidationErrorCode::UnknownQuestion));
    }

    #[test]
    fn test_validate_response_success() {
        let request = AskUserQuestionsRequest {
            questions: vec![
                single_choice("Question 1?", &["A", "B"], true),
                Question::MultiChoice {
                    text: "Question 2?".to_string(),
                    choices: owned(&["X", "Y"]),
                    required: false,
                },
                Question::FreeText {
                    text: "Question 3?".to_string(),
                    default_value: None,
                    required: false,
                },
            ],
        };
        let response = AskUserQuestionsResponse {
            answers: vec![
                answer("Question 1?", &["A"]),
                answer("Question 2?", &["X", "Y"]),
                answer("Question 3?", &["Some notes"]),
            ],
        };

        assert!(response.validate(&request).is_ok());
    }

    #[test]
    fn test_custom_answer_allowed() {
        let request = AskUserQuestionsRequest {
            questions: vec![single_choice("Which database?", &["PostgreSQL", "MySQL"], true)],
        };
        // "MongoDB" is not among the offered choices; custom answers must still pass.
        let response = AskUserQuestionsResponse {
            answers: vec![answer("Which database?", &["MongoDB"])],
        };

        assert!(response.validate(&request).is_ok());
    }

    // --- Serialization ---------------------------------------------------------------

    #[test]
    fn test_answer_serialization() {
        let json = serde_json::to_string(&answer("Which database?", &["PostgreSQL"])).unwrap();

        for expected in ["question", "answer", "PostgreSQL"] {
            assert!(json.contains(expected));
        }
    }
}