// flag_rs/validator.rs

//! Argument validation for commands.
//!
//! This module provides validation for command arguments, letting commands
//! constrain how many arguments they accept and which values are allowed
//! (either from a fixed list or via a custom check).
6
7use crate::error::{Error, Result};
8
9/// Type alias for custom validation functions
10pub type ValidatorFn = dyn Fn(&[String]) -> Result<()> + Send + Sync;
11
12/// Defines validation rules for command arguments
13#[derive(Clone)]
14pub enum ArgValidator {
15    /// Exactly N arguments required
16    ExactArgs(usize),
17    /// At least N arguments required
18    MinimumArgs(usize),
19    /// At most N arguments allowed
20    MaximumArgs(usize),
21    /// Between min and max arguments (inclusive)
22    RangeArgs(usize, usize),
23    /// Arguments must be in the valid args list
24    OnlyValidArgs(Vec<String>),
25    /// Custom validation function
26    Custom(std::sync::Arc<ValidatorFn>),
27}
28
29impl ArgValidator {
30    /// Validates the given arguments against this validator
31    pub fn validate(&self, args: &[String]) -> Result<()> {
32        match self {
33            Self::ExactArgs(expected) => {
34                if args.len() != *expected {
35                    return Err(Error::ArgumentValidation {
36                        message: format!("accepts {} arg(s), received {}", expected, args.len()),
37                        expected: expected.to_string(),
38                        received: args.len(),
39                    });
40                }
41                Ok(())
42            }
43            Self::MinimumArgs(min) => {
44                if args.len() < *min {
45                    return Err(Error::ArgumentValidation {
46                        message: format!(
47                            "requires at least {} arg(s), received {}",
48                            min,
49                            args.len()
50                        ),
51                        expected: format!("at least {min}"),
52                        received: args.len(),
53                    });
54                }
55                Ok(())
56            }
57            Self::MaximumArgs(max) => {
58                if args.len() > *max {
59                    return Err(Error::ArgumentValidation {
60                        message: format!("accepts at most {} arg(s), received {}", max, args.len()),
61                        expected: format!("at most {max}"),
62                        received: args.len(),
63                    });
64                }
65                Ok(())
66            }
67            Self::RangeArgs(min, max) => {
68                if args.len() < *min || args.len() > *max {
69                    return Err(Error::ArgumentValidation {
70                        message: format!(
71                            "accepts between {} and {} arg(s), received {}",
72                            min,
73                            max,
74                            args.len()
75                        ),
76                        expected: format!("{min} to {max}"),
77                        received: args.len(),
78                    });
79                }
80                Ok(())
81            }
82            Self::OnlyValidArgs(valid_args) => {
83                for arg in args {
84                    if !valid_args.contains(arg) {
85                        return Err(Error::ArgumentValidation {
86                            message: format!("invalid argument \"{arg}\""),
87                            expected: format!("one of: {}", valid_args.join(", ")),
88                            received: 1,
89                        });
90                    }
91                }
92                Ok(())
93            }
94            Self::Custom(validator) => validator(args),
95        }
96    }
97}
98
99impl std::fmt::Debug for ArgValidator {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        match self {
102            Self::ExactArgs(n) => write!(f, "ExactArgs({n})"),
103            Self::MinimumArgs(n) => write!(f, "MinimumArgs({n})"),
104            Self::MaximumArgs(n) => write!(f, "MaximumArgs({n})"),
105            Self::RangeArgs(min, max) => write!(f, "RangeArgs({min}, {max})"),
106            Self::OnlyValidArgs(args) => write!(f, "OnlyValidArgs({args:?})"),
107            Self::Custom(_) => write!(f, "Custom(<function>)"),
108        }
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the owned argument vector that `validate` expects from string literals.
    fn to_args(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_string()).collect()
    }

    /// Unpacks an `ArgumentValidation` error into `(message, expected, received)`,
    /// panicking if validation succeeded or produced a different error variant.
    fn validation_error(result: Result<()>) -> (String, String, usize) {
        match result {
            Err(Error::ArgumentValidation {
                message,
                expected,
                received,
            }) => (message, expected, received),
            Ok(()) => panic!("expected validation to fail"),
            // Needed if `Error` has other variants; flagged as unreachable
            // (hence the allow) when `ArgumentValidation` is the only one.
            #[allow(unreachable_patterns)]
            Err(_) => panic!("expected Error::ArgumentValidation"),
        }
    }

    #[test]
    fn test_exact_args() {
        let validator = ArgValidator::ExactArgs(2);

        // Should pass with exactly 2 args
        assert!(validator.validate(&to_args(&["arg1", "arg2"])).is_ok());

        // Should fail with wrong number of args
        assert!(validator.validate(&to_args(&["arg1"])).is_err());
        assert!(validator
            .validate(&to_args(&["arg1", "arg2", "arg3"]))
            .is_err());
    }

    #[test]
    fn test_minimum_args() {
        let validator = ArgValidator::MinimumArgs(2);

        // Should pass with 2 or more args
        assert!(validator.validate(&to_args(&["arg1", "arg2"])).is_ok());
        assert!(validator
            .validate(&to_args(&["arg1", "arg2", "arg3"]))
            .is_ok());

        // Should fail with fewer args
        assert!(validator.validate(&to_args(&["arg1"])).is_err());
        assert!(validator.validate(&[]).is_err());
    }

    #[test]
    fn test_maximum_args() {
        let validator = ArgValidator::MaximumArgs(2);

        // Should pass with 2 or fewer args
        assert!(validator.validate(&[]).is_ok());
        assert!(validator.validate(&to_args(&["arg1"])).is_ok());
        assert!(validator.validate(&to_args(&["arg1", "arg2"])).is_ok());

        // Should fail with more args
        assert!(validator
            .validate(&to_args(&["arg1", "arg2", "arg3"]))
            .is_err());
    }

    #[test]
    fn test_range_args() {
        let validator = ArgValidator::RangeArgs(1, 3);

        // Should pass within range (both bounds inclusive)
        assert!(validator.validate(&to_args(&["arg1"])).is_ok());
        assert!(validator.validate(&to_args(&["arg1", "arg2"])).is_ok());
        assert!(validator
            .validate(&to_args(&["arg1", "arg2", "arg3"]))
            .is_ok());

        // Should fail outside range
        assert!(validator.validate(&[]).is_err());
        assert!(validator.validate(&to_args(&["1", "2", "3", "4"])).is_err());
    }

    #[test]
    fn test_only_valid_args() {
        let validator = ArgValidator::OnlyValidArgs(to_args(&["start", "stop", "restart"]));

        // Should pass with valid args
        assert!(validator.validate(&to_args(&["start"])).is_ok());
        assert!(validator.validate(&to_args(&["stop", "restart"])).is_ok());

        // An empty argument list contains no invalid argument
        assert!(validator.validate(&[]).is_ok());

        // Should fail with invalid args
        assert!(validator.validate(&to_args(&["invalid"])).is_err());
        assert!(validator
            .validate(&to_args(&["start", "invalid"]))
            .is_err());
    }

    #[test]
    fn test_only_valid_args_error_details() {
        let validator = ArgValidator::OnlyValidArgs(to_args(&["start", "stop"]));

        // Only the first offending argument is reported
        let (message, expected, _) =
            validation_error(validator.validate(&to_args(&["start", "bad", "worse"])));
        assert_eq!(message, "invalid argument \"bad\"");
        assert_eq!(expected, "one of: start, stop");
    }

    #[test]
    fn test_count_error_details() {
        assert_eq!(
            validation_error(ArgValidator::ExactArgs(2).validate(&to_args(&["arg1"]))),
            (
                "accepts 2 arg(s), received 1".to_string(),
                "2".to_string(),
                1
            )
        );
        assert_eq!(
            validation_error(ArgValidator::MinimumArgs(2).validate(&[])),
            (
                "requires at least 2 arg(s), received 0".to_string(),
                "at least 2".to_string(),
                0
            )
        );
        assert_eq!(
            validation_error(ArgValidator::MaximumArgs(1).validate(&to_args(&["a", "b"]))),
            (
                "accepts at most 1 arg(s), received 2".to_string(),
                "at most 1".to_string(),
                2
            )
        );
        assert_eq!(
            validation_error(ArgValidator::RangeArgs(1, 3).validate(&[])),
            (
                "accepts between 1 and 3 arg(s), received 0".to_string(),
                "1 to 3".to_string(),
                0
            )
        );
    }

    #[test]
    fn test_custom_validator() {
        let validator = ArgValidator::Custom(std::sync::Arc::new(|args| {
            if args.iter().all(|arg| arg.parse::<i32>().is_ok()) {
                Ok(())
            } else {
                Err(Error::ArgumentValidation {
                    message: "all arguments must be integers".to_string(),
                    expected: "integers".to_string(),
                    received: args.len(),
                })
            }
        }));

        // Should pass with all integers
        assert!(validator.validate(&to_args(&["1", "2", "3"])).is_ok());

        // Should fail with non-integers
        assert!(validator.validate(&to_args(&["1", "abc"])).is_err());
    }

    #[test]
    fn test_debug_format() {
        assert_eq!(format!("{:?}", ArgValidator::ExactArgs(2)), "ExactArgs(2)");
        assert_eq!(format!("{:?}", ArgValidator::MinimumArgs(1)), "MinimumArgs(1)");
        assert_eq!(format!("{:?}", ArgValidator::MaximumArgs(3)), "MaximumArgs(3)");
        assert_eq!(
            format!("{:?}", ArgValidator::RangeArgs(1, 3)),
            "RangeArgs(1, 3)"
        );
        assert_eq!(
            format!("{:?}", ArgValidator::OnlyValidArgs(to_args(&["a", "b"]))),
            "OnlyValidArgs([\"a\", \"b\"])"
        );
        let custom = ArgValidator::Custom(std::sync::Arc::new(
            |_args: &[String]| -> Result<()> { Ok(()) },
        ));
        assert_eq!(format!("{custom:?}"), "Custom(<function>)");
    }
}