askr/validation/rules/choice.rs

1use super::super::{PartialValidationResult, Priority, ValidationResult, Validator};
2use std::collections::HashSet;
3
4/// Choice validator for selecting from predefined options
5#[derive(Debug)]
6pub struct ChoiceValidator {
7    choices: Vec<String>,
8    case_sensitive: bool,
9    min_choices: usize,
10    max_choices: usize,
11    priority: Priority,
12    custom_message: Option<String>,
13}
14
15impl ChoiceValidator {
16    pub fn new(choices: Vec<String>) -> Self {
17        Self {
18            choices,
19            case_sensitive: false,
20            min_choices: 1,
21            max_choices: 1,
22            priority: Priority::High,
23            custom_message: None,
24        }
25    }
26
27    pub fn case_sensitive(mut self, case_sensitive: bool) -> Self {
28        self.case_sensitive = case_sensitive;
29        self
30    }
31
32    pub fn min_choices(mut self, min: usize) -> Self {
33        self.min_choices = min;
34        self
35    }
36
37    pub fn max_choices(mut self, max: usize) -> Self {
38        self.max_choices = max;
39        self
40    }
41
42    pub fn with_priority(mut self, priority: Priority) -> Self {
43        self.priority = priority;
44        self
45    }
46
47    pub fn with_message(mut self, message: impl Into<String>) -> Self {
48        self.custom_message = Some(message.into());
49        self
50    }
51
52    /// Parse input for multiple choices (comma-separated)
53    fn parse_input(&self, input: &str) -> Vec<String> {
54        if self.max_choices == 1 {
55            vec![input.trim().to_string()]
56        } else {
57            input
58                .split(',')
59                .map(|s| s.trim().to_string())
60                .filter(|s| !s.is_empty())
61                .collect()
62        }
63    }
64
65    /// Check if a choice matches any of the valid options
66    fn is_valid_choice(&self, choice: &str) -> bool {
67        if self.case_sensitive {
68            self.choices.contains(&choice.to_string())
69        } else {
70            let choice_lower = choice.to_lowercase();
71            self.choices
72                .iter()
73                .any(|c| c.to_lowercase() == choice_lower)
74        }
75    }
76
77    /// Get the canonical form of a choice (with correct case)
78    fn get_canonical_choice(&self, choice: &str) -> Option<String> {
79        if self.case_sensitive {
80            if self.choices.contains(&choice.to_string()) {
81                Some(choice.to_string())
82            } else {
83                None
84            }
85        } else {
86            let choice_lower = choice.to_lowercase();
87            self.choices
88                .iter()
89                .find(|c| c.to_lowercase() == choice_lower)
90                .cloned()
91        }
92    }
93}
94
95impl Validator for ChoiceValidator {
96    fn validate(&self, input: &str) -> ValidationResult {
97        let parsed_choices = self.parse_input(input);
98
99        // Check choice count
100        if parsed_choices.len() < self.min_choices {
101            let message = if let Some(msg) = &self.custom_message {
102                msg.clone()
103            } else {
104                format!("At least {} choice(s) required", self.min_choices)
105            };
106            return ValidationResult::failure("choice", self.priority, &message);
107        }
108
109        if parsed_choices.len() > self.max_choices {
110            let message = if let Some(msg) = &self.custom_message {
111                msg.clone()
112            } else {
113                format!("At most {} choice(s) allowed", self.max_choices)
114            };
115            return ValidationResult::failure("choice", self.priority, &message);
116        }
117
118        // Check for duplicates
119        let mut seen = HashSet::new();
120        let mut duplicates = Vec::new();
121
122        for choice in &parsed_choices {
123            if let Some(canonical) = self.get_canonical_choice(choice) {
124                if !seen.insert(canonical.clone()) {
125                    duplicates.push(canonical);
126                }
127            }
128        }
129
130        if !duplicates.is_empty() {
131            let message = if let Some(msg) = &self.custom_message {
132                msg.clone()
133            } else {
134                format!("Duplicate choices not allowed: {}", duplicates.join(", "))
135            };
136            return ValidationResult::failure("choice", self.priority, &message);
137        }
138
139        // Check each choice validity
140        let mut invalid_choices = Vec::new();
141        for choice in &parsed_choices {
142            if !self.is_valid_choice(choice) {
143                invalid_choices.push(choice.clone());
144            }
145        }
146
147        if !invalid_choices.is_empty() {
148            let message = if let Some(msg) = &self.custom_message {
149                msg.clone()
150            } else {
151                let valid_choices_str = self.choices.join(", ");
152                format!(
153                    "Invalid choice(s): {}. Valid options: {}",
154                    invalid_choices.join(", "),
155                    valid_choices_str
156                )
157            };
158            return ValidationResult::failure("choice", self.priority, &message);
159        }
160
161        ValidationResult::success("choice")
162    }
163
164    fn partial_validate(&self, input: &str, _cursor_pos: usize) -> PartialValidationResult {
165        if input.is_empty() {
166            return PartialValidationResult::valid();
167        }
168
169        // For single choice, check partial matching
170        if self.max_choices == 1 {
171            // Check if any choice starts with the current input
172            let has_partial_match = self.choices.iter().any(|choice| {
173                if self.case_sensitive {
174                    choice.starts_with(input)
175                } else {
176                    choice.to_lowercase().starts_with(&input.to_lowercase())
177                }
178            });
179
180            if !has_partial_match {
181                // Find the first character where it diverges
182                for (i, _ch) in input.char_indices() {
183                    let partial = &input[..=i];
184
185                    let has_match = self.choices.iter().any(|choice| {
186                        if self.case_sensitive {
187                            choice.starts_with(partial)
188                        } else {
189                            choice.to_lowercase().starts_with(&partial.to_lowercase())
190                        }
191                    });
192
193                    if !has_match {
194                        return PartialValidationResult::error_at(i);
195                    }
196                }
197            }
198        } else {
199            // For multiple choices, validate the current choice being typed
200            let parts: Vec<&str> = input.split(',').collect();
201            if let Some(current_choice) = parts.last() {
202                let current_choice = current_choice.trim();
203                if !current_choice.is_empty() {
204                    let has_partial_match = self.choices.iter().any(|choice| {
205                        if self.case_sensitive {
206                            choice.starts_with(current_choice)
207                        } else {
208                            choice
209                                .to_lowercase()
210                                .starts_with(&current_choice.to_lowercase())
211                        }
212                    });
213
214                    if !has_partial_match {
215                        // Calculate position of error in the current choice
216                        let prefix_len: usize = parts[..parts.len() - 1]
217                            .iter()
218                            .map(|p| p.len() + 1) // +1 for comma
219                            .sum();
220
221                        for (i, _ch) in current_choice.char_indices() {
222                            let partial = current_choice[..=i].trim();
223
224                            let has_match = self.choices.iter().any(|choice| {
225                                if self.case_sensitive {
226                                    choice.starts_with(partial)
227                                } else {
228                                    choice.to_lowercase().starts_with(&partial.to_lowercase())
229                                }
230                            });
231
232                            if !has_match {
233                                return PartialValidationResult::error_at(prefix_len + i);
234                            }
235                        }
236                    }
237                }
238            }
239        }
240
241        PartialValidationResult::valid()
242    }
243
244    fn priority(&self) -> Priority {
245        self.priority
246    }
247
248    fn name(&self) -> &str {
249        "choice"
250    }
251}