1use crate::error::{Error, Result};
8
/// Signature of the closure wrapped by `ArgValidator::Custom`.
///
/// The closure receives the full argument slice and returns `Ok(())` to accept
/// it or an error to reject it; `ArgValidator::validate` passes that result
/// through unchanged. The `Send + Sync` bounds allow a validator, which is
/// stored behind an `Arc`, to be shared between threads.
pub type ValidatorFn = dyn Fn(&[String]) -> Result<()> + Send + Sync;
11
/// A rule applied to a list of string arguments via `ArgValidator::validate`.
///
/// `Debug` is implemented by hand rather than derived because the closure
/// held by `Custom` (a `dyn Fn`) has no `Debug` implementation.
#[derive(Clone)]
pub enum ArgValidator {
    /// Requires exactly this many arguments.
    ExactArgs(usize),
    /// Requires at least this many arguments.
    MinimumArgs(usize),
    /// Allows at most this many arguments.
    MaximumArgs(usize),
    /// Requires a count between the first field (minimum) and the second
    /// field (maximum), both bounds inclusive.
    RangeArgs(usize, usize),
    /// Requires every argument to be one of the listed values.
    OnlyValidArgs(Vec<String>),
    /// Delegates validation to a user-supplied closure. Cloning the validator
    /// clones the `Arc`, so all clones share the same closure.
    Custom(std::sync::Arc<ValidatorFn>),
}
28
29impl ArgValidator {
30 pub fn validate(&self, args: &[String]) -> Result<()> {
32 match self {
33 Self::ExactArgs(expected) => {
34 if args.len() != *expected {
35 return Err(Error::ArgumentValidation {
36 message: format!("accepts {} arg(s), received {}", expected, args.len()),
37 expected: expected.to_string(),
38 received: args.len(),
39 });
40 }
41 Ok(())
42 }
43 Self::MinimumArgs(min) => {
44 if args.len() < *min {
45 return Err(Error::ArgumentValidation {
46 message: format!(
47 "requires at least {} arg(s), received {}",
48 min,
49 args.len()
50 ),
51 expected: format!("at least {min}"),
52 received: args.len(),
53 });
54 }
55 Ok(())
56 }
57 Self::MaximumArgs(max) => {
58 if args.len() > *max {
59 return Err(Error::ArgumentValidation {
60 message: format!("accepts at most {} arg(s), received {}", max, args.len()),
61 expected: format!("at most {max}"),
62 received: args.len(),
63 });
64 }
65 Ok(())
66 }
67 Self::RangeArgs(min, max) => {
68 if args.len() < *min || args.len() > *max {
69 return Err(Error::ArgumentValidation {
70 message: format!(
71 "accepts between {} and {} arg(s), received {}",
72 min,
73 max,
74 args.len()
75 ),
76 expected: format!("{min} to {max}"),
77 received: args.len(),
78 });
79 }
80 Ok(())
81 }
82 Self::OnlyValidArgs(valid_args) => {
83 for arg in args {
84 if !valid_args.contains(arg) {
85 return Err(Error::ArgumentValidation {
86 message: format!("invalid argument \"{arg}\""),
87 expected: format!("one of: {}", valid_args.join(", ")),
88 received: 1,
89 });
90 }
91 }
92 Ok(())
93 }
94 Self::Custom(validator) => validator(args),
95 }
96 }
97}
98
99impl std::fmt::Debug for ArgValidator {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 match self {
102 Self::ExactArgs(n) => write!(f, "ExactArgs({n})"),
103 Self::MinimumArgs(n) => write!(f, "MinimumArgs({n})"),
104 Self::MaximumArgs(n) => write!(f, "MaximumArgs({n})"),
105 Self::RangeArgs(min, max) => write!(f, "RangeArgs({min}, {max})"),
106 Self::OnlyValidArgs(args) => write!(f, "OnlyValidArgs({args:?})"),
107 Self::Custom(_) => write!(f, "Custom(<function>)"),
108 }
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned `Vec<String>` that `validate` takes.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Asserts that `validator` rejects `args` with exactly this
    /// `ArgumentValidation` payload (message, expected, received).
    fn assert_rejected(
        validator: &ArgValidator,
        args: &[String],
        message: &str,
        expected: &str,
        received: usize,
    ) {
        let outcome = validator.validate(args);
        assert!(
            matches!(
                &outcome,
                Err(Error::ArgumentValidation {
                    message: actual_message,
                    expected: actual_expected,
                    received: actual_received,
                }) if actual_message == message
                    && actual_expected == expected
                    && *actual_received == received
            ),
            "{validator:?} did not reject {args:?} with message {message:?}"
        );
    }

    #[test]
    fn test_exact_args() {
        let validator = ArgValidator::ExactArgs(2);

        assert!(validator.validate(&owned(&["arg1", "arg2"])).is_ok());

        assert!(validator.validate(&owned(&["arg1"])).is_err());
        assert!(validator.validate(&owned(&["arg1", "arg2", "arg3"])).is_err());
    }

    #[test]
    fn test_minimum_args() {
        let validator = ArgValidator::MinimumArgs(2);

        assert!(validator.validate(&owned(&["arg1", "arg2"])).is_ok());
        assert!(validator.validate(&owned(&["arg1", "arg2", "arg3"])).is_ok());

        assert!(validator.validate(&owned(&["arg1"])).is_err());
        assert!(validator.validate(&owned(&[])).is_err());
    }

    #[test]
    fn test_maximum_args() {
        let validator = ArgValidator::MaximumArgs(2);

        assert!(validator.validate(&owned(&[])).is_ok());
        assert!(validator.validate(&owned(&["arg1"])).is_ok());
        assert!(validator.validate(&owned(&["arg1", "arg2"])).is_ok());

        assert!(validator.validate(&owned(&["arg1", "arg2", "arg3"])).is_err());
    }

    #[test]
    fn test_range_args() {
        let validator = ArgValidator::RangeArgs(1, 3);

        assert!(validator.validate(&owned(&["arg1"])).is_ok());
        assert!(validator.validate(&owned(&["arg1", "arg2"])).is_ok());
        assert!(validator.validate(&owned(&["arg1", "arg2", "arg3"])).is_ok());

        assert!(validator.validate(&owned(&[])).is_err());
        assert!(validator.validate(&owned(&["1", "2", "3", "4"])).is_err());
    }

    #[test]
    fn test_only_valid_args() {
        let validator = ArgValidator::OnlyValidArgs(owned(&["start", "stop", "restart"]));

        assert!(validator.validate(&owned(&["start"])).is_ok());
        assert!(validator.validate(&owned(&["stop", "restart"])).is_ok());

        assert!(validator.validate(&owned(&["invalid"])).is_err());
        assert!(validator.validate(&owned(&["start", "invalid"])).is_err());
    }

    #[test]
    fn test_custom_validator() {
        let validator = ArgValidator::Custom(std::sync::Arc::new(|args| {
            if args.iter().all(|arg| arg.parse::<i32>().is_ok()) {
                Ok(())
            } else {
                Err(Error::ArgumentValidation {
                    message: "all arguments must be integers".to_string(),
                    expected: "integers".to_string(),
                    received: args.len(),
                })
            }
        }));

        assert!(validator.validate(&owned(&["1", "2", "3"])).is_ok());

        assert!(validator.validate(&owned(&["1", "abc"])).is_err());
    }

    /// Pins the exact error payload of each built-in rule, not just `is_err()`.
    #[test]
    fn test_error_details() {
        assert_rejected(
            &ArgValidator::ExactArgs(2),
            &owned(&["arg1"]),
            "accepts 2 arg(s), received 1",
            "2",
            1,
        );
        assert_rejected(
            &ArgValidator::MinimumArgs(2),
            &owned(&[]),
            "requires at least 2 arg(s), received 0",
            "at least 2",
            0,
        );
        assert_rejected(
            &ArgValidator::MaximumArgs(2),
            &owned(&["a", "b", "c"]),
            "accepts at most 2 arg(s), received 3",
            "at most 2",
            3,
        );
        assert_rejected(
            &ArgValidator::RangeArgs(1, 3),
            &owned(&[]),
            "accepts between 1 and 3 arg(s), received 0",
            "1 to 3",
            0,
        );
        // Only the first unknown argument is reported.
        assert_rejected(
            &ArgValidator::OnlyValidArgs(owned(&["start", "stop"])),
            &owned(&["start", "bogus", "other"]),
            "invalid argument \"bogus\"",
            "one of: start, stop",
            1,
        );
    }

    /// Pins the hand-written `Debug` output for every variant.
    #[test]
    fn test_debug_format() {
        assert_eq!(format!("{:?}", ArgValidator::ExactArgs(2)), "ExactArgs(2)");
        assert_eq!(format!("{:?}", ArgValidator::MinimumArgs(1)), "MinimumArgs(1)");
        assert_eq!(format!("{:?}", ArgValidator::MaximumArgs(3)), "MaximumArgs(3)");
        assert_eq!(format!("{:?}", ArgValidator::RangeArgs(1, 3)), "RangeArgs(1, 3)");
        assert_eq!(
            format!("{:?}", ArgValidator::OnlyValidArgs(owned(&["start", "stop"]))),
            r#"OnlyValidArgs(["start", "stop"])"#
        );

        let custom = ArgValidator::Custom(std::sync::Arc::new(|_args| Ok(())));
        assert_eq!(format!("{custom:?}"), "Custom(<function>)");
    }
}