aiscript_directive/validator/
mod.rs

1use std::any::Any;
2
3use date::DateValidator;
4use regex::RegexValidator;
5use serde_json::Value;
6
7use crate::{Directive, DirectiveParams, FromDirective};
8
9mod array;
10mod date;
11mod format;
12mod regex;
13
14pub trait Validator: Send + Sync + Any {
15    fn name(&self) -> &'static str;
16    fn validate(&self, value: &Value) -> Result<(), String>;
17    fn as_any(&self) -> &dyn Any;
18    fn downcast_ref<U: Any>(&self) -> Option<&U>
19    where
20        Self: Sized,
21    {
22        self.as_any().downcast_ref::<U>()
23    }
24}
25
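// Needs trait upcasting (`&dyn Validator` -> `&dyn Any`), which older toolchains only allow on nightly via `#![feature(trait_upcasting)]`.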
27// impl dyn Validator {
28//     pub fn as_any(&self) -> &dyn Any where Self: 'static {
29//         self
30//     }
31// }
32
33impl Validator for Box<dyn Validator> {
34    fn name(&self) -> &'static str {
35        self.as_ref().name()
36    }
37
38    fn validate(&self, value: &Value) -> Result<(), String> {
39        self.as_ref().validate(value)
40    }
41
42    fn as_any(&self) -> &dyn Any {
43        self.as_ref().as_any()
44    }
45}
46
47impl<T: Validator> Validator for Box<T> {
48    fn name(&self) -> &'static str {
49        self.as_ref().name()
50    }
51
52    fn validate(&self, value: &Value) -> Result<(), String> {
53        self.as_ref().validate(value)
54    }
55
56    fn as_any(&self) -> &dyn Any {
57        self.as_ref().as_any()
58    }
59}
60
/// Combinator built from an `@any` directive: holds the nested validators it
/// combines (see its `Validator` impl for how their results are combined).
pub struct AnyValidator<V>(pub Box<[V]>);

/// Combinator built from a `@not` directive: inverts the wrapped validator.
pub struct NotValidator<V>(pub V);
64
/// `@string`: constraints on a JSON string value.
///
/// Every field is optional; `None` skips the corresponding check. Pattern
/// matching is handled by the separate `@regex` validator.
#[derive(Default)]
pub struct StringValidator {
    /// Minimum allowed length.
    pub min_len: Option<u32>,
    /// Maximum allowed length.
    pub max_len: Option<u32>,
    /// Exact required length.
    pub exact_len: Option<u32>,
    /// Required prefix.
    pub start_with: Option<String>,
    /// Required suffix.
    pub end_with: Option<String>,
}
74
/// `@number`: constraints on a JSON number value.
///
/// Every field is optional; `None` skips the corresponding check.
pub struct NumberValidator {
    /// Inclusive lower bound.
    min: Option<f64>,
    /// Inclusive upper bound.
    max: Option<f64>,
    /// Value the number must equal exactly.
    equal: Option<f64>,
    /// When `Some(true)`, only integers are accepted.
    strict_int: Option<bool>,
    /// When `Some(true)`, only floats are accepted.
    strict_float: Option<bool>,
}
82
83pub struct InValidator(pub Vec<Value>);
84
85impl<V: Validator> Validator for AnyValidator<V> {
86    fn name(&self) -> &'static str {
87        "@any"
88    }
89
90    fn validate(&self, value: &Value) -> Result<(), String> {
91        for validator in &self.0 {
92            validator.validate(value)?
93        }
94        Ok(())
95    }
96
97    fn as_any(&self) -> &dyn Any {
98        self
99    }
100}
101
102impl<V: Validator> Validator for NotValidator<V> {
103    fn name(&self) -> &'static str {
104        "@not"
105    }
106
107    fn validate(&self, value: &Value) -> Result<(), String> {
108        let validator = &self.0;
109        if validator.validate(value).is_ok() {
110            return Err("Value does not meet the validation criteria".into());
111        }
112        Ok(())
113    }
114
115    fn as_any(&self) -> &dyn Any {
116        self
117    }
118}
119
120impl Validator for StringValidator {
121    fn name(&self) -> &'static str {
122        "@string"
123    }
124
125    fn validate(&self, value: &Value) -> Result<(), String> {
126        let value = value.as_str().unwrap();
127        if let Some(min_len) = self.min_len {
128            if value.len() < min_len as usize {
129                return Err(format!(
130                    "String length is less than the minimum length of {}",
131                    min_len
132                ));
133            }
134        }
135        if let Some(max_len) = self.max_len {
136            if value.len() > max_len as usize {
137                return Err(format!(
138                    "String length is greater than the maximum length of {}",
139                    max_len
140                ));
141            }
142        }
143
144        if let Some(exact_len) = self.exact_len {
145            if value.len() != exact_len as usize {
146                return Err(format!(
147                    "String length is not equal to the exact length of {}",
148                    exact_len
149                ));
150            }
151        }
152
153        // if let Some(regex) = &self.regex {
154        //     let regex = regex::Regex::new(regex).unwrap();
155        //     if !regex.is_match(value) {
156        //         return Err(format!(
157        //             "String does not match the required regex pattern: {}",
158        //             regex
159        //         ));
160        //     }
161        // }
162
163        if let Some(start_with) = &self.start_with {
164            if !value.starts_with(start_with) {
165                return Err(format!(
166                    "String does not start with the required string: {}",
167                    start_with
168                ));
169            }
170        }
171
172        if let Some(end_with) = &self.end_with {
173            if !value.ends_with(end_with) {
174                return Err(format!(
175                    "String does not end with the required string: {}",
176                    end_with
177                ));
178            }
179        }
180
181        Ok(())
182    }
183
184    fn as_any(&self) -> &dyn Any {
185        self
186    }
187}
188
189impl Validator for NumberValidator {
190    fn name(&self) -> &'static str {
191        "@number"
192    }
193
194    fn validate(&self, value: &Value) -> Result<(), String> {
195        let num = value.as_number().unwrap();
196        let value = num.as_f64().unwrap();
197        if let (Some(true), Some(true)) = (self.strict_int, self.strict_float) {
198            return Err("Cannot set both strict_int and strict_float to true".into());
199        }
200        if let Some(true) = self.strict_int {
201            if !num.is_i64() {
202                return Err("Value must be an integer".into());
203            }
204        }
205        if let Some(true) = self.strict_float {
206            if num.is_i64() {
207                return Err("Value must be a float".into());
208            }
209        }
210        if let Some(min) = self.min {
211            if value < min {
212                return Err(format!("Number is less than the minimum value of {}", min));
213            }
214        }
215        if let Some(max) = self.max {
216            if value > max {
217                return Err(format!(
218                    "Number is greater than the maximum value of {}",
219                    max
220                ));
221            }
222        }
223        if let Some(equal) = self.equal {
224            if value != equal {
225                return Err(format!(
226                    "Number is not equal to the required value of {}",
227                    equal
228                ));
229            }
230        }
231        Ok(())
232    }
233
234    fn as_any(&self) -> &dyn Any {
235        self
236    }
237}
238
239impl Validator for InValidator {
240    fn name(&self) -> &'static str {
241        "@in"
242    }
243
244    fn validate(&self, value: &Value) -> Result<(), String> {
245        if self.0.contains(value) {
246            Ok(())
247        } else {
248            Err("Value is not in the list of allowed values".into())
249        }
250    }
251
252    fn as_any(&self) -> &dyn Any {
253        self
254    }
255}
256
257impl FromDirective for Box<dyn Validator> {
258    fn from_directive(directive: Directive) -> Result<Self, String>
259    where
260        Self: Sized,
261    {
262        match directive.name.as_str() {
263            "string" => Ok(Box::new(StringValidator::from_directive(directive)?)),
264            "number" => Ok(Box::new(NumberValidator::from_directive(directive)?)),
265            "in" => Ok(Box::new(InValidator::from_directive(directive)?)),
266            "any" => Ok(Box::new(AnyValidator::from_directive(directive)?)),
267            "not" => Ok(Box::new(NotValidator::from_directive(directive)?)),
268            "date" => Ok(Box::new(DateValidator::from_directive(directive)?)),
269            "array" => Ok(Box::new(AnyValidator::from_directive(directive)?)),
270            "regex" => Ok(Box::new(RegexValidator::from_directive(directive)?)),
271            v => Err(format!("Invalid validators: @{}", v)),
272        }
273    }
274}
275
276impl FromDirective for StringValidator {
277    fn from_directive(Directive { params, .. }: Directive) -> Result<Self, String> {
278        match params {
279            DirectiveParams::KeyValue(params) => {
280                Ok(Self {
281                    min_len: params
282                        .get("min_len")
283                        .and_then(|v| v.as_u64().map(|v| v as u32)),
284                    max_len: params
285                        .get("max_len")
286                        .and_then(|v| v.as_u64().map(|v| v as u32)),
287                    exact_len: params
288                        .get("exact_len")
289                        .and_then(|v| v.as_u64().map(|v| v as u32)),
290                    // regex: params
291                    //     .get("regex")
292                    //     .and_then(|v| v.as_str().map(|v| v.to_string())),
293                    start_with: params
294                        .get("start_with")
295                        .and_then(|v| v.as_str().map(|v| v.to_string())),
296                    end_with: params
297                        .get("end_with")
298                        .and_then(|v| v.as_str().map(|v| v.to_string())),
299                })
300            }
301            _ => Err("Invalid params for @string directive".into()),
302        }
303    }
304}
305
306impl FromDirective for InValidator {
307    fn from_directive(Directive { params, .. }: Directive) -> Result<Self, String> {
308        match params {
309            DirectiveParams::Array(values) => Ok(Self(values)),
310            _ => Err("Invalid params for @in directive".into()),
311        }
312    }
313}
314
315impl FromDirective for AnyValidator<Box<dyn Validator>> {
316    fn from_directive(Directive { params, .. }: Directive) -> Result<Self, String> {
317        match params {
318            DirectiveParams::Directives(directives) => {
319                let mut validators = Vec::with_capacity(directives.len());
320                for directive in directives {
321                    validators.push(FromDirective::from_directive(directive)?);
322                }
323                Ok(Self(validators.into_boxed_slice()))
324            }
325            _ => Err("Invalid params for @any directive".into()),
326        }
327    }
328}
329
330impl FromDirective for NotValidator<Box<dyn Validator>> {
331    fn from_directive(Directive { params, .. }: Directive) -> Result<Self, String> {
332        match params {
333            DirectiveParams::Directives(mut directives) => {
334                if let Some(directive) = directives.pop() {
335                    let validator = FromDirective::from_directive(directive)?;
336                    if !directives.is_empty() {
337                        return Err("@not directive only support one directive".into());
338                    }
339
340                    Ok(Self(validator))
341                } else {
342                    Err("@not directive requires one directive".into())
343                }
344            }
345            _ => Err("Invalid params for @not directive, expect a directive".into()),
346        }
347    }
348}
349
350impl FromDirective for NumberValidator {
351    fn from_directive(Directive { params, .. }: Directive) -> Result<Self, String> {
352        match params {
353            DirectiveParams::KeyValue(params) => Ok(Self {
354                min: params.get("min").and_then(|v| v.as_f64()),
355                max: params.get("max").and_then(|v| v.as_f64()),
356                equal: params.get("equal").and_then(|v| v.as_f64()),
357                strict_int: params.get("strict_int").and_then(|v| v.as_bool()),
358                strict_float: params.get("strict_float").and_then(|v| v.as_bool()),
359            }),
360            _ => Err("Invalid params for @number directive".into()),
361        }
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal integer range validator used to exercise downcasting and the
    /// combinator validators.
    #[derive(Debug)]
    struct RangeValidator {
        min: i64,
        max: i64,
    }

    impl RangeValidator {
        fn new(min: i64, max: i64) -> Self {
            RangeValidator { min, max }
        }
    }

    impl Validator for RangeValidator {
        fn name(&self) -> &'static str {
            "range"
        }

        fn validate(&self, value: &Value) -> Result<(), String> {
            let num = value
                .as_i64()
                .ok_or_else(|| "Value must be an integer".to_string())?;
            if (self.min..=self.max).contains(&num) {
                Ok(())
            } else {
                Err(format!(
                    "Value must be between {} and {}",
                    self.min, self.max
                ))
            }
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    #[test]
    fn test_direct_downcast() {
        // Concrete receiver.
        let concrete = RangeValidator::new(1, 10);
        let range = concrete
            .downcast_ref::<RangeValidator>()
            .expect("concrete validator should downcast to itself");
        assert_eq!(range.min, 1);
        assert_eq!(range.max, 10);

        // Boxed trait object: `as_any` forwards to the inner validator.
        let boxed: Box<dyn Validator> = Box::new(RangeValidator::new(1, 10));
        let range = boxed
            .downcast_ref::<RangeValidator>()
            .expect("boxed validator should downcast to its inner type");
        assert_eq!(range.min, 1);
        assert_eq!(range.max, 10);
    }

    #[test]
    fn test_nested_downcast() {
        let not_validator = NotValidator(Box::new(RangeValidator::new(1, 10)));
        let downcast = not_validator.downcast_ref::<NotValidator<Box<RangeValidator>>>();
        assert!(downcast.is_some());
    }

    #[test]
    fn test_any_validator_downcast() {
        let any_validator = AnyValidator(
            vec![
                Box::new(RangeValidator::new(1, 10)),
                Box::new(RangeValidator::new(0, 5)),
            ]
            .into_boxed_slice(),
        );

        let any = any_validator
            .downcast_ref::<AnyValidator<Box<RangeValidator>>>()
            .expect("should downcast to AnyValidator<Box<RangeValidator>>");
        assert_eq!(any.0.len(), 2);
    }

    #[test]
    fn test_wrong_downcast() {
        let range_validator = RangeValidator::new(1, 10);

        assert!(range_validator
            .downcast_ref::<NotValidator<Box<dyn Validator>>>()
            .is_none());
        assert!(range_validator
            .downcast_ref::<AnyValidator<Box<dyn Validator>>>()
            .is_none());
    }

    #[test]
    fn test_downcast_and_validate() {
        let range_validator = RangeValidator::new(1, 10);
        let range = range_validator
            .downcast_ref::<RangeValidator>()
            .expect("Downcast failed");

        assert!(range.validate(&json!(5)).is_ok());
        for rejected in &[json!(0), json!(11), json!("not a number")] {
            assert!(range.validate(rejected).is_err());
        }
    }

    #[test]
    fn test_nested_validator_chain() {
        let not = NotValidator(Box::new(RangeValidator::new(1, 10)));
        let any = AnyValidator(vec![not].into_boxed_slice());

        // 0 is outside the range, so the NotValidator accepts it.
        assert!(any.validate(&json!(0)).is_ok());
        // 5 is inside the range, so the NotValidator rejects it.
        assert!(any.validate(&json!(5)).is_err());

        let layered = any
            .downcast_ref::<AnyValidator<NotValidator<Box<RangeValidator>>>>()
            .expect("Should downcast to AnyValidator");
        assert_eq!(layered.0.len(), 1);
    }
}