1use elicitor::{
4 DefaultValue, ListElementKind, Question, QuestionKind, ResponsePath, ResponseValue, Responses,
5 SELECTED_VARIANT_KEY, SELECTED_VARIANTS_KEY, SurveyBackend, SurveyDefinition,
6};
7use thiserror::Error;
8
/// Errors that can occur while [`RequesttyBackend`] collects survey responses.
#[derive(Debug, Error)]
pub enum RequesttyError {
    /// The user interrupted a prompt (requestty reported `ErrorKind::Interrupted`).
    #[error("Survey cancelled by user")]
    Cancelled,

    /// A prompt failed; the payload is a human-readable message, e.g. the display
    /// text of a non-interrupt requestty error.
    #[error("Prompt error: {0}")]
    PromptError(String),

    /// requestty returned an answer of a different kind than the prompt produces.
    /// `expected` names the wanted answer kind; `got` is the `Debug` form of what
    /// was actually returned.
    #[error("Unexpected answer type: expected {expected}, got {got}")]
    UnexpectedAnswerType { expected: String, got: String },
}
24
25impl From<requestty::ErrorKind> for RequesttyError {
26 fn from(err: requestty::ErrorKind) -> Self {
27 match err {
28 requestty::ErrorKind::Interrupted => Self::Cancelled,
29 _ => Self::PromptError(err.to_string()),
30 }
31 }
32}
33
/// A [`SurveyBackend`] that asks survey questions interactively in the terminal
/// using the `requestty` prompt library.
///
/// The backend carries no state; create it with [`RequesttyBackend::new`] or
/// `Default::default()`.
#[derive(Debug, Default, Clone)]
pub struct RequesttyBackend;
40
41impl RequesttyBackend {
42 pub const fn new() -> Self {
44 Self
45 }
46
47 fn ask_question(
49 &self,
50 question: &Question,
51 responses: &mut Responses,
52 validate: &dyn Fn(&ResponseValue, &Responses, &ResponsePath) -> Result<(), String>,
53 path_prefix: Option<&ResponsePath>,
54 ) -> Result<(), RequesttyError> {
55 let path = match path_prefix {
56 Some(prefix) => prefix.child(question.path().as_str()),
57 None => question.path().clone(),
58 };
59
60 let prompt = if question.ask().is_empty() {
62 path.as_str()
64 .split('.')
65 .last()
66 .unwrap_or("")
67 .split('_')
68 .map(|word| {
69 let mut chars = word.chars();
70 match chars.next() {
71 None => String::new(),
72 Some(first) => first.to_uppercase().chain(chars).collect(),
73 }
74 })
75 .collect::<Vec<_>>()
76 .join(" ")
77 } else {
78 question.ask().to_string()
79 };
80
81 if let DefaultValue::Assumed(value) = question.default() {
83 responses.insert(path, value.clone());
84 return Ok(());
85 }
86
87 match question.kind() {
88 QuestionKind::Unit => {
89 Ok(())
91 }
92
93 QuestionKind::Input(input_q) => self.ask_input(
94 &path,
95 &prompt,
96 input_q,
97 question.default(),
98 responses,
99 validate,
100 ),
101
102 QuestionKind::Multiline(multiline_q) => self.ask_multiline(
103 &path,
104 &prompt,
105 multiline_q,
106 question.default(),
107 responses,
108 validate,
109 ),
110
111 QuestionKind::Masked(masked_q) => self.ask_masked(
112 &path,
113 &prompt,
114 masked_q,
115 question.default(),
116 responses,
117 validate,
118 ),
119
120 QuestionKind::Int(int_q) => self.ask_int(
121 &path,
122 &prompt,
123 int_q,
124 question.default(),
125 responses,
126 validate,
127 ),
128
129 QuestionKind::Float(float_q) => self.ask_float(
130 &path,
131 &prompt,
132 float_q,
133 question.default(),
134 responses,
135 validate,
136 ),
137
138 QuestionKind::Confirm(confirm_q) => {
139 self.ask_confirm(&path, &prompt, confirm_q, question.default(), responses)
140 }
141
142 QuestionKind::List(list_q) => self.ask_list(
143 &path,
144 &prompt,
145 list_q,
146 question.default(),
147 responses,
148 validate,
149 ),
150
151 QuestionKind::OneOf(one_of) => {
152 self.ask_one_of(&path, &prompt, one_of, responses, validate)
153 }
154
155 QuestionKind::AnyOf(any_of) => {
156 self.ask_any_of(&path, &prompt, any_of, responses, validate)
157 }
158
159 QuestionKind::AllOf(all_of) => {
160 for nested_q in all_of.questions() {
162 self.ask_question(nested_q, responses, validate, Some(&path))?;
163 }
164 Ok(())
165 }
166 }
167 }
168
169 fn ask_input(
170 &self,
171 path: &ResponsePath,
172 prompt: &str,
173 input_q: &elicitor::InputQuestion,
174 default: &DefaultValue,
175 responses: &mut Responses,
176 validate: &dyn Fn(&ResponseValue, &Responses, &ResponsePath) -> Result<(), String>,
177 ) -> Result<(), RequesttyError> {
178 loop {
179 let mut q = requestty::Question::input(path.as_str()).message(prompt);
180
181 if let Some(default_val) = default.value() {
183 if let ResponseValue::String(s) = default_val {
184 q = q.default(s.clone());
185 }
186 } else if let Some(ref def) = input_q.default {
187 q = q.default(def.clone());
188 }
189
190 let responses_clone = responses.clone();
192 let path_clone = path.clone();
193 let validate_fn = move |value: &str, _: &requestty::Answers| -> Result<(), String> {
194 let rv = ResponseValue::String(value.to_string());
195 validate(&rv, &responses_clone, &path_clone)
196 };
197
198 let result = requestty::prompt_one(q.validate(validate_fn).build());
199
200 match result {
201 Ok(requestty::Answer::String(s)) => {
202 responses.insert(path.clone(), ResponseValue::String(s));
203 return Ok(());
204 }
205 Ok(other) => {
206 return Err(RequesttyError::UnexpectedAnswerType {
207 expected: "String".to_string(),
208 got: format!("{other:?}"),
209 });
210 }
211 Err(e) => {
212 if matches!(e, requestty::ErrorKind::Interrupted) {
213 return Err(RequesttyError::Cancelled);
214 }
215 eprintln!("Error: {e}");
217 continue;
218 }
219 }
220 }
221 }
222
223 fn ask_multiline(
224 &self,
225 path: &ResponsePath,
226 prompt: &str,
227 multiline_q: &elicitor::MultilineQuestion,
228 default: &DefaultValue,
229 responses: &mut Responses,
230 validate: &dyn Fn(&ResponseValue, &Responses, &ResponsePath) -> Result<(), String>,
231 ) -> Result<(), RequesttyError> {
232 loop {
233 let mut q = requestty::Question::editor(path.as_str()).message(prompt);
234
235 if let Some(default_val) = default.value() {
236 if let ResponseValue::String(s) = default_val {
237 q = q.default(s.clone());
238 }
239 } else if let Some(ref def) = multiline_q.default {
240 q = q.default(def.clone());
241 }
242
243 let responses_clone = responses.clone();
244 let path_clone = path.clone();
245 let validate_fn = move |value: &str, _: &requestty::Answers| -> Result<(), String> {
246 let rv = ResponseValue::String(value.to_string());
247 validate(&rv, &responses_clone, &path_clone)
248 };
249
250 let result = requestty::prompt_one(q.validate(validate_fn).build());
251
252 match result {
253 Ok(requestty::Answer::String(s)) => {
254 responses.insert(path.clone(), ResponseValue::String(s));
255 return Ok(());
256 }
257 Ok(other) => {
258 return Err(RequesttyError::UnexpectedAnswerType {
259 expected: "String".to_string(),
260 got: format!("{other:?}"),
261 });
262 }
263 Err(e) => {
264 if matches!(e, requestty::ErrorKind::Interrupted) {
265 return Err(RequesttyError::Cancelled);
266 }
267 eprintln!("Error: {e}");
268 continue;
269 }
270 }
271 }
272 }
273
274 fn ask_masked(
275 &self,
276 path: &ResponsePath,
277 prompt: &str,
278 masked_q: &elicitor::MaskedQuestion,
279 default: &DefaultValue,
280 responses: &mut Responses,
281 validate: &dyn Fn(&ResponseValue, &Responses, &ResponsePath) -> Result<(), String>,
282 ) -> Result<(), RequesttyError> {
283 let _ = default;
285
286 loop {
287 let mut q = requestty::Question::password(path.as_str()).message(prompt);
288
289 if let Some(mask) = masked_q.mask {
290 q = q.mask(mask);
291 }
292
293 let responses_clone = responses.clone();
294 let path_clone = path.clone();
295 let validate_fn = move |value: &str, _: &requestty::Answers| -> Result<(), String> {
296 let rv = ResponseValue::String(value.to_string());
297 validate(&rv, &responses_clone, &path_clone)
298 };
299
300 let result = requestty::prompt_one(q.validate(validate_fn).build());
301
302 match result {
303 Ok(requestty::Answer::String(s)) => {
304 responses.insert(path.clone(), ResponseValue::String(s));
305 return Ok(());
306 }
307 Ok(other) => {
308 return Err(RequesttyError::UnexpectedAnswerType {
309 expected: "String".to_string(),
310 got: format!("{other:?}"),
311 });
312 }
313 Err(e) => {
314 if matches!(e, requestty::ErrorKind::Interrupted) {
315 return Err(RequesttyError::Cancelled);
316 }
317 eprintln!("Error: {e}");
318 continue;
319 }
320 }
321 }
322 }
323
324 fn ask_int(
325 &self,
326 path: &ResponsePath,
327 prompt: &str,
328 int_q: &elicitor::IntQuestion,
329 default: &DefaultValue,
330 responses: &mut Responses,
331 validate: &dyn Fn(&ResponseValue, &Responses, &ResponsePath) -> Result<(), String>,
332 ) -> Result<(), RequesttyError> {
333 loop {
334 let mut q = requestty::Question::int(path.as_str()).message(prompt);
335
336 if let Some(default_val) = default.value() {
337 if let ResponseValue::Int(i) = default_val {
338 q = q.default(*i);
339 }
340 } else if let Some(def) = int_q.default {
341 q = q.default(def);
342 }
343
344 let min = int_q.min;
346 let max = int_q.max;
347 let responses_clone = responses.clone();
348 let path_clone = path.clone();
349
350 let validate_fn = move |value: i64, _: &requestty::Answers| -> Result<(), String> {
351 if let Some(min_val) = min
353 && value < min_val
354 {
355 return Err(format!("Value must be at least {min_val}"));
356 }
357 if let Some(max_val) = max
358 && value > max_val
359 {
360 return Err(format!("Value must be at most {max_val}"));
361 }
362 let rv = ResponseValue::Int(value);
364 validate(&rv, &responses_clone, &path_clone)
365 };
366
367 let result = requestty::prompt_one(q.validate(validate_fn).build());
368
369 match result {
370 Ok(requestty::Answer::Int(i)) => {
371 responses.insert(path.clone(), ResponseValue::Int(i));
372 return Ok(());
373 }
374 Ok(other) => {
375 return Err(RequesttyError::UnexpectedAnswerType {
376 expected: "Int".to_string(),
377 got: format!("{other:?}"),
378 });
379 }
380 Err(e) => {
381 if matches!(e, requestty::ErrorKind::Interrupted) {
382 return Err(RequesttyError::Cancelled);
383 }
384 eprintln!("Error: {e}");
385 continue;
386 }
387 }
388 }
389 }
390
391 fn ask_float(
392 &self,
393 path: &ResponsePath,
394 prompt: &str,
395 float_q: &elicitor::FloatQuestion,
396 default: &DefaultValue,
397 responses: &mut Responses,
398 validate: &dyn Fn(&ResponseValue, &Responses, &ResponsePath) -> Result<(), String>,
399 ) -> Result<(), RequesttyError> {
400 loop {
401 let mut q = requestty::Question::float(path.as_str()).message(prompt);
402
403 if let Some(default_val) = default.value() {
404 if let ResponseValue::Float(f) = default_val {
405 q = q.default(*f);
406 }
407 } else if let Some(def) = float_q.default {
408 q = q.default(def);
409 }
410
411 let min = float_q.min;
413 let max = float_q.max;
414 let responses_clone = responses.clone();
415 let path_clone = path.clone();
416
417 let validate_fn = move |value: f64, _: &requestty::Answers| -> Result<(), String> {
418 if let Some(min_val) = min
419 && value < min_val
420 {
421 return Err(format!("Value must be at least {min_val}"));
422 }
423 if let Some(max_val) = max
424 && value > max_val
425 {
426 return Err(format!("Value must be at most {max_val}"));
427 }
428 let rv = ResponseValue::Float(value);
429 validate(&rv, &responses_clone, &path_clone)
430 };
431
432 let result = requestty::prompt_one(q.validate(validate_fn).build());
433
434 match result {
435 Ok(requestty::Answer::Float(f)) => {
436 responses.insert(path.clone(), ResponseValue::Float(f));
437 return Ok(());
438 }
439 Ok(other) => {
440 return Err(RequesttyError::UnexpectedAnswerType {
441 expected: "Float".to_string(),
442 got: format!("{other:?}"),
443 });
444 }
445 Err(e) => {
446 if matches!(e, requestty::ErrorKind::Interrupted) {
447 return Err(RequesttyError::Cancelled);
448 }
449 eprintln!("Error: {e}");
450 continue;
451 }
452 }
453 }
454 }
455
456 fn ask_confirm(
457 &self,
458 path: &ResponsePath,
459 prompt: &str,
460 confirm_q: &elicitor::ConfirmQuestion,
461 default: &DefaultValue,
462 responses: &mut Responses,
463 ) -> Result<(), RequesttyError> {
464 let default_val = if let Some(ResponseValue::Bool(b)) = default.value() {
465 *b
466 } else {
467 confirm_q.default
468 };
469
470 let q = requestty::Question::confirm(path.as_str())
471 .message(prompt)
472 .default(default_val)
473 .build();
474
475 let result = requestty::prompt_one(q)?;
476
477 match result {
478 requestty::Answer::Bool(b) => {
479 responses.insert(path.clone(), ResponseValue::Bool(b));
480 Ok(())
481 }
482 other => Err(RequesttyError::UnexpectedAnswerType {
483 expected: "Bool".to_string(),
484 got: format!("{other:?}"),
485 }),
486 }
487 }
488
489 fn ask_list(
490 &self,
491 path: &ResponsePath,
492 prompt: &str,
493 list_q: &elicitor::ListQuestion,
494 _default: &DefaultValue,
495 responses: &mut Responses,
496 validate: &dyn Fn(&ResponseValue, &Responses, &ResponsePath) -> Result<(), String>,
497 ) -> Result<(), RequesttyError> {
498 let mut items: Vec<ResponseValue> = Vec::new();
499
500 println!("{}", prompt);
501 println!(" (Enter values one per line, empty line to finish)");
502
503 loop {
504 let item_prompt = format!("[{}]", items.len() + 1);
505
506 let q = requestty::Question::input(&item_prompt)
507 .message(&item_prompt)
508 .build();
509
510 let result = requestty::prompt_one(q)?;
511
512 match result {
513 requestty::Answer::String(s) if s.is_empty() => break,
514 requestty::Answer::String(s) => {
515 let value = match &list_q.element_kind {
516 ListElementKind::String => Some(ResponseValue::String(s)),
517 ListElementKind::Int { min, max } => match s.parse::<i64>() {
518 Ok(n) => {
519 if let Some(min_val) = min {
520 if n < *min_val {
521 println!(" Error: Value must be at least {min_val}");
522 continue;
523 }
524 }
525 if let Some(max_val) = max {
526 if n > *max_val {
527 println!(" Error: Value must be at most {max_val}");
528 continue;
529 }
530 }
531 Some(ResponseValue::Int(n))
532 }
533 Err(_) => {
534 println!(" Error: Please enter a valid integer");
535 continue;
536 }
537 },
538 ListElementKind::Float { min, max } => match s.parse::<f64>() {
539 Ok(n) => {
540 if let Some(min_val) = min {
541 if n < *min_val {
542 println!(" Error: Value must be at least {min_val}");
543 continue;
544 }
545 }
546 if let Some(max_val) = max {
547 if n > *max_val {
548 println!(" Error: Value must be at most {max_val}");
549 continue;
550 }
551 }
552 Some(ResponseValue::Float(n))
553 }
554 Err(_) => {
555 println!(" Error: Please enter a valid number");
556 continue;
557 }
558 },
559 };
560
561 if let Some(v) = value {
562 items.push(v);
563 }
564 }
565 _ => break,
566 }
567 }
568
569 let rv = match &list_q.element_kind {
571 ListElementKind::String => {
572 let strings: Vec<String> = items
573 .into_iter()
574 .filter_map(|v| {
575 if let ResponseValue::String(s) = v {
576 Some(s)
577 } else {
578 None
579 }
580 })
581 .collect();
582 ResponseValue::StringList(strings)
583 }
584 ListElementKind::Int { .. } => {
585 let ints: Vec<i64> = items
586 .into_iter()
587 .filter_map(|v| {
588 if let ResponseValue::Int(n) = v {
589 Some(n)
590 } else {
591 None
592 }
593 })
594 .collect();
595 ResponseValue::IntList(ints)
596 }
597 ListElementKind::Float { .. } => {
598 let floats: Vec<f64> = items
599 .into_iter()
600 .filter_map(|v| {
601 if let ResponseValue::Float(n) = v {
602 Some(n)
603 } else {
604 None
605 }
606 })
607 .collect();
608 ResponseValue::FloatList(floats)
609 }
610 };
611
612 if let Err(msg) = validate(&rv, responses, path) {
614 return Err(RequesttyError::PromptError(msg));
615 }
616
617 responses.insert(path.clone(), rv);
618 Ok(())
619 }
620
621 fn ask_one_of(
622 &self,
623 path: &ResponsePath,
624 prompt: &str,
625 one_of: &elicitor::OneOfQuestion,
626 responses: &mut Responses,
627 validate: &dyn Fn(&ResponseValue, &Responses, &ResponsePath) -> Result<(), String>,
628 ) -> Result<(), RequesttyError> {
629 let choices: Vec<String> = one_of.variants.iter().map(|v| v.name.clone()).collect();
631
632 let mut q = requestty::Question::select(path.as_str())
633 .message(prompt)
634 .choices(choices);
635
636 if let Some(default_idx) = one_of.default {
637 q = q.default(default_idx);
638 }
639
640 let result = requestty::prompt_one(q.build())?;
641
642 let selection = match result {
643 requestty::Answer::ListItem(item) => item.index,
644 other => {
645 return Err(RequesttyError::UnexpectedAnswerType {
646 expected: "ListItem".to_string(),
647 got: format!("{other:?}"),
648 });
649 }
650 };
651
652 let variant_path = path.child(SELECTED_VARIANT_KEY);
654 responses.insert(variant_path, ResponseValue::ChosenVariant(selection));
655
656 let selected_variant = &one_of.variants[selection];
658 match &selected_variant.kind {
659 QuestionKind::Unit => {
660 }
662 QuestionKind::AllOf(all_of) => {
663 for nested_q in all_of.questions() {
664 self.ask_question(nested_q, responses, validate, Some(path))?;
665 }
666 }
667 QuestionKind::Input(_)
668 | QuestionKind::Int(_)
669 | QuestionKind::Float(_)
670 | QuestionKind::Confirm(_)
671 | QuestionKind::Masked(_)
672 | QuestionKind::Multiline(_)
673 | QuestionKind::List(_) => {
674 let variant_q = Question::new(
676 selected_variant.name.clone(),
677 format!("Enter {} value:", selected_variant.name),
678 selected_variant.kind.clone(),
679 );
680 self.ask_question(&variant_q, responses, validate, Some(path))?;
681 }
682 QuestionKind::OneOf(nested_one_of) => {
683 let variant_q = Question::new(
685 selected_variant.name.clone(),
686 format!("Select {}:", selected_variant.name),
687 QuestionKind::OneOf(nested_one_of.clone()),
688 );
689 self.ask_question(&variant_q, responses, validate, Some(path))?;
690 }
691 QuestionKind::AnyOf(nested_any_of) => {
692 let variant_q = Question::new(
693 selected_variant.name.clone(),
694 format!("Select {} options:", selected_variant.name),
695 QuestionKind::AnyOf(nested_any_of.clone()),
696 );
697 self.ask_question(&variant_q, responses, validate, Some(path))?;
698 }
699 }
700
701 Ok(())
702 }
703
704 fn ask_any_of(
705 &self,
706 path: &ResponsePath,
707 prompt: &str,
708 any_of: &elicitor::AnyOfQuestion,
709 responses: &mut Responses,
710 validate: &dyn Fn(&ResponseValue, &Responses, &ResponsePath) -> Result<(), String>,
711 ) -> Result<(), RequesttyError> {
712 let selections = loop {
714 let choices: Vec<_> = any_of
716 .variants
717 .iter()
718 .enumerate()
719 .map(|(idx, v)| {
720 let selected = any_of.defaults.contains(&idx);
721 (v.name.clone(), selected)
722 })
723 .collect();
724
725 let q = requestty::Question::multi_select(path.as_str())
726 .message(prompt)
727 .choices_with_default(choices)
728 .build();
729
730 let result = requestty::prompt_one(q)?;
731
732 let selections = match result {
733 requestty::Answer::ListItems(items) => {
734 items.iter().map(|item| item.index).collect::<Vec<_>>()
735 }
736 other => {
737 return Err(RequesttyError::UnexpectedAnswerType {
738 expected: "ListItems".to_string(),
739 got: format!("{other:?}"),
740 });
741 }
742 };
743
744 let selection_value = ResponseValue::ChosenVariants(selections.clone());
746 if let Err(msg) = validate(&selection_value, responses, path) {
747 println!("Error: {msg}");
749 continue;
750 }
751
752 break selections;
753 };
754
755 let variants_path = path.child(SELECTED_VARIANTS_KEY);
757 responses.insert(
758 variants_path,
759 ResponseValue::ChosenVariants(selections.clone()),
760 );
761
762 for (item_idx, &variant_idx) in selections.iter().enumerate() {
765 let variant = &any_of.variants[variant_idx];
766 let item_path = path.child(&item_idx.to_string());
767
768 let item_variant_path = item_path.child(SELECTED_VARIANT_KEY);
770 responses.insert(item_variant_path, ResponseValue::ChosenVariant(variant_idx));
771
772 match &variant.kind {
773 QuestionKind::Unit => {
774 }
776 QuestionKind::AllOf(all_of) => {
777 for nested_q in all_of.questions() {
778 self.ask_question(nested_q, responses, validate, Some(&item_path))?;
779 }
780 }
781 _ => {
782 }
784 }
785 }
786
787 Ok(())
788 }
789}
790
791impl SurveyBackend for RequesttyBackend {
792 type Error = RequesttyError;
793
794 fn collect(
795 &self,
796 definition: &SurveyDefinition,
797 validate: &dyn Fn(&ResponseValue, &Responses, &ResponsePath) -> Result<(), String>,
798 ) -> Result<Responses, Self::Error> {
799 let mut responses = Responses::new();
800
801 if let Some(prelude) = &definition.prelude {
803 println!("{prelude}");
804 println!();
805 }
806
807 for question in definition.questions() {
809 self.ask_question(question, &mut responses, validate, None)?;
810 }
811
812 if let Some(epilogue) = &definition.epilogue {
814 println!();
815 println!("{epilogue}");
816 }
817
818 Ok(responses)
819 }
820}
821
#[cfg(test)]
mod tests {
    use super::*;

    /// The backend is a stateless unit value and can always be constructed.
    #[test]
    fn backend_creation() {
        let _: RequesttyBackend = RequesttyBackend::new();
    }

    /// Each error variant renders the message declared in its `#[error]` attribute.
    #[test]
    fn error_types() {
        let cases = [
            (RequesttyError::Cancelled, "Survey cancelled by user"),
            (
                RequesttyError::PromptError("test error".to_string()),
                "Prompt error: test error",
            ),
            (
                RequesttyError::UnexpectedAnswerType {
                    expected: "String".to_string(),
                    got: "Int".to_string(),
                },
                "Unexpected answer type: expected String, got Int",
            ),
        ];

        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }
}
848}