Skip to main content

rustia/
validate.rs

1use std::collections::{BTreeMap, HashMap};
2
3use base64::Engine as _;
4use regex::Regex;
5use serde::de::DeserializeOwned;
6use serde_json::Value;
7use url::Url;
8use uuid::Uuid;
9
10/// Validation result with rustia-compatible success/failure discriminator.
11#[derive(Debug, Clone, PartialEq)]
12pub enum IValidation<T> {
13    Success {
14        data: T,
15    },
16    Failure {
17        data: Value,
18        errors: Vec<IValidationError>,
19    },
20}
21
22/// Validation error detail compatible with rustia's validate() payload shape.
23#[derive(Debug, Clone, PartialEq)]
24pub struct IValidationError {
25    pub path: String,
26    pub expected: String,
27    pub value: Value,
28    pub description: Option<String>,
29}
30
31/// Runtime validator trait.
32pub trait Validate: DeserializeOwned + Sized {
33    fn validate(value: Value) -> IValidation<Self>;
34    fn validate_equals(value: Value) -> IValidation<Self>;
35}
36
/// Runtime form of the constraint tags attached at derive time.
///
/// `apply_tags` interprets these against a JSON value.
#[derive(Clone, Debug, PartialEq)]
pub enum TagRuntime {
    // String length bounds, counted in `char`s.
    MinLength(usize),
    MaxLength(usize),

    // Array constraints.
    MinItems(usize),
    MaxItems(usize),
    /// Uniqueness is only enforced when the flag is `true`.
    UniqueItems(bool),

    // Numeric constraints.
    Minimum(f64),
    Maximum(f64),
    ExclusiveMinimum(f64),
    ExclusiveMaximum(f64),
    MultipleOf(f64),

    /// Regular expression a string must match.
    Pattern(String),
    /// Named string format such as `uuid` or `date-time`.
    Format(String),
    /// Named numeric type such as `int32` or `double`.
    Type(String),

    /// Tags applied to every element of an array.
    Items(Vec<TagRuntime>),
    /// Tags applied to every key of an object.
    Keys(Vec<TagRuntime>),
    /// Tags applied to every value of an object.
    Values(Vec<TagRuntime>),

    /// Informational tag; it carries no runtime constraint.
    Metadata { kind: String, args: Vec<String> },
}
58
59pub fn validate_with_serde<T>(value: Value) -> IValidation<T>
60where
61    T: DeserializeOwned,
62{
63    let encoded = match serde_json::to_vec(&value) {
64        Ok(encoded) => encoded,
65        Err(error) => {
66            return IValidation::Failure {
67                data: value.clone(),
68                errors: vec![IValidationError {
69                    path: "$input".to_owned(),
70                    expected: "JSON value".to_owned(),
71                    value,
72                    description: Some(error.to_string()),
73                }],
74            };
75        }
76    };
77
78    let mut deserializer = serde_json::Deserializer::from_slice(&encoded);
79    match serde_path_to_error::deserialize::<_, T>(&mut deserializer) {
80        Ok(data) => IValidation::Success { data },
81        Err(error) => {
82            let raw_path = error.path().to_string();
83            let path = normalize_path(&raw_path);
84            let description = Some(error.into_inner().to_string());
85            let error_value = read_value_on_path(&value, &path).unwrap_or(Value::Null);
86
87            IValidation::Failure {
88                data: value,
89                errors: vec![IValidationError {
90                    path,
91                    expected: "serde-compatible schema".to_owned(),
92                    value: error_value,
93                    description,
94                }],
95            }
96        }
97    }
98}
99
/// Converts a serde_path_to_error path string into a `$input`-rooted path.
///
/// * the root path becomes `$input`;
/// * a path starting with an index (`[0]…`) is appended directly;
/// * any other path is appended after a `.` separator.
///
/// serde_path_to_error renders the root (empty) path as `"."`, so both `""`
/// and `"."` are treated as the root. Without that, a root-level failure
/// would be reported as `$input..`, which cannot be resolved back to a value.
fn normalize_path(path: &str) -> String {
    if path.is_empty() || path == "." {
        "$input".to_owned()
    } else if path.starts_with('[') {
        format!("$input{path}")
    } else {
        format!("$input.{path}")
    }
}
109
110fn read_value_on_path(root: &Value, path: &str) -> Option<Value> {
111    if path == "$input" {
112        return Some(root.clone());
113    }
114    let mut cursor = root;
115    let mut chars = path.strip_prefix("$input")?.chars().peekable();
116
117    while let Some(ch) = chars.peek().copied() {
118        match ch {
119            '.' => {
120                chars.next();
121                let mut key = String::new();
122                while let Some(next) = chars.peek().copied() {
123                    if next == '.' || next == '[' {
124                        break;
125                    }
126                    key.push(next);
127                    chars.next();
128                }
129                cursor = cursor.get(&key)?;
130            }
131            '[' => {
132                chars.next();
133                if chars.peek().copied() == Some('"') {
134                    chars.next();
135                    let mut key = String::new();
136                    while let Some(next) = chars.next() {
137                        if next == '"' {
138                            break;
139                        }
140                        if next == '\\' {
141                            if let Some(escaped) = chars.next() {
142                                key.push(escaped);
143                            } else {
144                                return None;
145                            }
146                        } else {
147                            key.push(next);
148                        }
149                    }
150                    if chars.next() != Some(']') {
151                        return None;
152                    }
153                    cursor = cursor.get(&key)?;
154                } else {
155                    let mut index = String::new();
156                    while let Some(next) = chars.peek().copied() {
157                        if next == ']' {
158                            break;
159                        }
160                        index.push(next);
161                        chars.next();
162                    }
163                    if chars.next() != Some(']') {
164                        return None;
165                    }
166                    let parsed = index.parse::<usize>().ok()?;
167                    cursor = cursor.get(parsed)?;
168                }
169            }
170            _ => return None,
171        }
172    }
173
174    Some(cursor.clone())
175}
176
177macro_rules! impl_validate_with_serde {
178    ($($ty:ty),* $(,)?) => {
179        $(
180            impl Validate for $ty {
181                fn validate(value: Value) -> IValidation<Self> {
182                    validate_with_serde(value)
183                }
184
185                fn validate_equals(value: Value) -> IValidation<Self> {
186                    validate_with_serde(value)
187                }
188            }
189        )*
190    };
191}
192
193impl_validate_with_serde!(
194    bool, String, char, i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64,
195);
196
197impl Validate for Value {
198    fn validate(value: Value) -> IValidation<Self> {
199        IValidation::Success { data: value }
200    }
201
202    fn validate_equals(value: Value) -> IValidation<Self> {
203        IValidation::Success { data: value }
204    }
205}
206
207impl<T> Validate for Option<T>
208where
209    T: Validate,
210{
211    fn validate(value: Value) -> IValidation<Self> {
212        if value.is_null() {
213            return IValidation::Success { data: None };
214        }
215
216        let original = value.clone();
217        match T::validate(value) {
218            IValidation::Success { data } => IValidation::Success { data: Some(data) },
219            IValidation::Failure { errors, .. } => IValidation::Failure {
220                data: original,
221                errors,
222            },
223        }
224    }
225
226    fn validate_equals(value: Value) -> IValidation<Self> {
227        if value.is_null() {
228            return IValidation::Success { data: None };
229        }
230
231        let original = value.clone();
232        match T::validate_equals(value) {
233            IValidation::Success { data } => IValidation::Success { data: Some(data) },
234            IValidation::Failure { errors, .. } => IValidation::Failure {
235                data: original,
236                errors,
237            },
238        }
239    }
240}
241
242impl<T> Validate for Vec<T>
243where
244    T: Validate,
245{
246    fn validate(value: Value) -> IValidation<Self> {
247        validate_vec(value, false)
248    }
249
250    fn validate_equals(value: Value) -> IValidation<Self> {
251        validate_vec(value, true)
252    }
253}
254
255fn validate_vec<T>(value: Value, strict: bool) -> IValidation<Vec<T>>
256where
257    T: Validate,
258{
259    let original = value.clone();
260    let array = match value {
261        Value::Array(array) => array,
262        other => {
263            return IValidation::Failure {
264                data: other.clone(),
265                errors: vec![IValidationError {
266                    path: "$input".to_owned(),
267                    expected: "array".to_owned(),
268                    value: other,
269                    description: Some("expected an array value".to_owned()),
270                }],
271            };
272        }
273    };
274
275    let mut data = Vec::with_capacity(array.len());
276    let mut errors = Vec::new();
277    for (index, item) in array.iter().cloned().enumerate() {
278        let validated = if strict {
279            T::validate_equals(item)
280        } else {
281            T::validate(item)
282        };
283        match validated {
284            IValidation::Success { data: item } => data.push(item),
285            IValidation::Failure { errors: nested, .. } => {
286                merge_prefixed_errors(&mut errors, &join_index_path("$input", index), nested);
287            }
288        }
289    }
290
291    if errors.is_empty() {
292        IValidation::Success { data }
293    } else {
294        IValidation::Failure {
295            data: original,
296            errors,
297        }
298    }
299}
300
301impl<T> Validate for HashMap<String, T>
302where
303    T: Validate,
304{
305    fn validate(value: Value) -> IValidation<Self> {
306        validate_hash_map(value, false)
307    }
308
309    fn validate_equals(value: Value) -> IValidation<Self> {
310        validate_hash_map(value, true)
311    }
312}
313
314fn validate_hash_map<T>(value: Value, strict: bool) -> IValidation<HashMap<String, T>>
315where
316    T: Validate,
317{
318    let original = value.clone();
319    let object = match value {
320        Value::Object(object) => object,
321        other => {
322            return IValidation::Failure {
323                data: other.clone(),
324                errors: vec![IValidationError {
325                    path: "$input".to_owned(),
326                    expected: "object".to_owned(),
327                    value: other,
328                    description: Some("expected an object value".to_owned()),
329                }],
330            };
331        }
332    };
333
334    let mut data = HashMap::with_capacity(object.len());
335    let mut errors = Vec::new();
336    for (key, item) in object.iter() {
337        let validated = if strict {
338            T::validate_equals(item.clone())
339        } else {
340            T::validate(item.clone())
341        };
342        match validated {
343            IValidation::Success { data: item } => {
344                data.insert(key.clone(), item);
345            }
346            IValidation::Failure { errors: nested, .. } => {
347                merge_prefixed_errors(&mut errors, &join_object_path("$input", key), nested);
348            }
349        }
350    }
351
352    if errors.is_empty() {
353        IValidation::Success { data }
354    } else {
355        IValidation::Failure {
356            data: original,
357            errors,
358        }
359    }
360}
361
362impl<T> Validate for BTreeMap<String, T>
363where
364    T: Validate,
365{
366    fn validate(value: Value) -> IValidation<Self> {
367        validate_btree_map(value, false)
368    }
369
370    fn validate_equals(value: Value) -> IValidation<Self> {
371        validate_btree_map(value, true)
372    }
373}
374
375fn validate_btree_map<T>(value: Value, strict: bool) -> IValidation<BTreeMap<String, T>>
376where
377    T: Validate,
378{
379    let original = value.clone();
380    let object = match value {
381        Value::Object(object) => object,
382        other => {
383            return IValidation::Failure {
384                data: other.clone(),
385                errors: vec![IValidationError {
386                    path: "$input".to_owned(),
387                    expected: "object".to_owned(),
388                    value: other,
389                    description: Some("expected an object value".to_owned()),
390                }],
391            };
392        }
393    };
394
395    let mut data = BTreeMap::new();
396    let mut errors = Vec::new();
397    for (key, item) in object.iter() {
398        let validated = if strict {
399            T::validate_equals(item.clone())
400        } else {
401            T::validate(item.clone())
402        };
403        match validated {
404            IValidation::Success { data: item } => {
405                data.insert(key.clone(), item);
406            }
407            IValidation::Failure { errors: nested, .. } => {
408                merge_prefixed_errors(&mut errors, &join_object_path("$input", key), nested);
409            }
410        }
411    }
412
413    if errors.is_empty() {
414        IValidation::Success { data }
415    } else {
416        IValidation::Failure {
417            data: original,
418            errors,
419        }
420    }
421}
422
423pub fn merge_prefixed_errors(
424    target: &mut Vec<IValidationError>,
425    prefix: &str,
426    mut nested: Vec<IValidationError>,
427) {
428    for error in &mut nested {
429        error.path = prepend_path(prefix, &error.path);
430    }
431    target.extend(nested);
432}
433
/// Re-roots `path` under `prefix`.
///
/// A leading `$input` in `path` is replaced by `prefix`. A path without that
/// root is appended to `prefix`, directly when it starts with `[` and after a
/// `.` separator otherwise.
pub fn prepend_path(prefix: &str, path: &str) -> String {
    match path.strip_prefix("$input") {
        // `path` was exactly the root.
        Some("") => prefix.to_owned(),
        Some(tail) => format!("{prefix}{tail}"),
        None if path.starts_with('[') => format!("{prefix}{path}"),
        None => format!("{prefix}.{path}"),
    }
}
445
/// Appends an object key to `base`.
///
/// Keys made only of ASCII letters, digits and `_`, and not starting with a
/// digit, are appended as `.key`. Every other key (including the empty key)
/// is appended as `["key"]`, with `\` and `"` escaped by a backslash.
pub fn join_object_path(base: &str, key: &str) -> String {
    let mut rest = key.chars();
    let starts_like_identifier =
        matches!(rest.next(), Some(first) if first == '_' || first.is_ascii_alphabetic());
    if starts_like_identifier && rest.all(|ch| ch == '_' || ch.is_ascii_alphanumeric()) {
        return format!("{base}.{key}");
    }

    let mut escaped = String::with_capacity(key.len());
    for ch in key.chars() {
        if matches!(ch, '\\' | '"') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    format!("{base}[\"{escaped}\"]")
}
461
/// Appends an array index to `base` as `[index]`.
pub fn join_index_path(base: &str, index: usize) -> String {
    let position = index.to_string();
    [base, "[", position.as_str(), "]"].concat()
}
465
466pub fn apply_tags(
467    value: &Value,
468    path: &str,
469    tags: &[TagRuntime],
470    errors: &mut Vec<IValidationError>,
471) {
472    for tag in tags {
473        match tag {
474            TagRuntime::MinLength(min) => {
475                if let Some(text) = value.as_str()
476                    && text.chars().count() < *min
477                {
478                    errors.push(tag_error(
479                        path,
480                        &format!("string & MinLength<{min}>"),
481                        value,
482                        Some(format!("string length must be >= {min}")),
483                    ));
484                }
485            }
486            TagRuntime::MaxLength(max) => {
487                if let Some(text) = value.as_str()
488                    && text.chars().count() > *max
489                {
490                    errors.push(tag_error(
491                        path,
492                        &format!("string & MaxLength<{max}>"),
493                        value,
494                        Some(format!("string length must be <= {max}")),
495                    ));
496                }
497            }
498            TagRuntime::MinItems(min) => {
499                if let Some(items) = value.as_array()
500                    && items.len() < *min
501                {
502                    errors.push(tag_error(
503                        path,
504                        &format!("array & MinItems<{min}>"),
505                        value,
506                        Some(format!("array length must be >= {min}")),
507                    ));
508                }
509            }
510            TagRuntime::MaxItems(max) => {
511                if let Some(items) = value.as_array()
512                    && items.len() > *max
513                {
514                    errors.push(tag_error(
515                        path,
516                        &format!("array & MaxItems<{max}>"),
517                        value,
518                        Some(format!("array length must be <= {max}")),
519                    ));
520                }
521            }
522            TagRuntime::UniqueItems(enabled) => {
523                if *enabled
524                    && let Some(items) = value.as_array()
525                    && !is_unique_items(items)
526                {
527                    errors.push(tag_error(
528                        path,
529                        "array & UniqueItems<true>",
530                        value,
531                        Some("array items must be unique".to_owned()),
532                    ));
533                }
534            }
535            TagRuntime::Minimum(minimum) => {
536                if let Some(number) = json_number_to_f64(value)
537                    && number < *minimum
538                {
539                    errors.push(tag_error(
540                        path,
541                        &format!("number & Minimum<{minimum}>"),
542                        value,
543                        Some(format!("number must be >= {minimum}")),
544                    ));
545                }
546            }
547            TagRuntime::Maximum(maximum) => {
548                if let Some(number) = json_number_to_f64(value)
549                    && number > *maximum
550                {
551                    errors.push(tag_error(
552                        path,
553                        &format!("number & Maximum<{maximum}>"),
554                        value,
555                        Some(format!("number must be <= {maximum}")),
556                    ));
557                }
558            }
559            TagRuntime::ExclusiveMinimum(minimum) => {
560                if let Some(number) = json_number_to_f64(value)
561                    && number <= *minimum
562                {
563                    errors.push(tag_error(
564                        path,
565                        &format!("number & ExclusiveMinimum<{minimum}>"),
566                        value,
567                        Some(format!("number must be > {minimum}")),
568                    ));
569                }
570            }
571            TagRuntime::ExclusiveMaximum(maximum) => {
572                if let Some(number) = json_number_to_f64(value)
573                    && number >= *maximum
574                {
575                    errors.push(tag_error(
576                        path,
577                        &format!("number & ExclusiveMaximum<{maximum}>"),
578                        value,
579                        Some(format!("number must be < {maximum}")),
580                    ));
581                }
582            }
583            TagRuntime::MultipleOf(divisor) => {
584                if let Some(number) = json_number_to_f64(value)
585                    && !is_multiple_of(number, *divisor)
586                {
587                    errors.push(tag_error(
588                        path,
589                        &format!("number & MultipleOf<{divisor}>"),
590                        value,
591                        Some(format!("number must be a multiple of {divisor}")),
592                    ));
593                }
594            }
595            TagRuntime::Pattern(pattern) => {
596                if let Some(text) = value.as_str() {
597                    match Regex::new(pattern) {
598                        Ok(regex) => {
599                            if !regex.is_match(text) {
600                                errors.push(tag_error(
601                                    path,
602                                    &format!("string & Pattern<{pattern}>"),
603                                    value,
604                                    Some("string does not match the required pattern".to_owned()),
605                                ));
606                            }
607                        }
608                        Err(error) => {
609                            errors.push(tag_error(
610                                path,
611                                &format!("string & Pattern<{pattern}>"),
612                                value,
613                                Some(format!("invalid pattern: {error}")),
614                            ));
615                        }
616                    }
617                }
618            }
619            TagRuntime::Format(format_name) => {
620                if let Some(text) = value.as_str()
621                    && !is_valid_format(format_name, text)
622                {
623                    errors.push(tag_error(
624                        path,
625                        &format!("string & Format<{format_name}>"),
626                        value,
627                        Some(format!("string does not satisfy format `{format_name}`")),
628                    ));
629                }
630            }
631            TagRuntime::Type(type_name) => {
632                if !matches_numeric_type(type_name, value) {
633                    errors.push(tag_error(
634                        path,
635                        &format!("number & Type<{type_name}>"),
636                        value,
637                        Some(format!("value does not satisfy numeric type `{type_name}`")),
638                    ));
639                }
640            }
641            TagRuntime::Items(nested) => {
642                if let Some(items) = value.as_array() {
643                    for (index, item) in items.iter().enumerate() {
644                        apply_tags(item, &join_index_path(path, index), nested, errors);
645                    }
646                }
647            }
648            TagRuntime::Keys(nested) => {
649                if let Some(object) = value.as_object() {
650                    for key in object.keys() {
651                        let key_value = Value::String(key.clone());
652                        apply_tags(&key_value, &join_object_path(path, key), nested, errors);
653                    }
654                }
655            }
656            TagRuntime::Values(nested) => {
657                if let Some(object) = value.as_object() {
658                    for (key, item) in object {
659                        apply_tags(item, &join_object_path(path, key), nested, errors);
660                    }
661                }
662            }
663            TagRuntime::Metadata { .. } => {}
664        }
665    }
666}
667
668fn tag_error(
669    path: &str,
670    expected: &str,
671    value: &Value,
672    description: Option<String>,
673) -> IValidationError {
674    IValidationError {
675        path: path.to_owned(),
676        expected: expected.to_owned(),
677        value: value.clone(),
678        description,
679    }
680}
681
682fn is_unique_items(items: &[Value]) -> bool {
683    for i in 0..items.len() {
684        for j in (i + 1)..items.len() {
685            if items[i] == items[j] {
686                return false;
687            }
688        }
689    }
690    true
691}
692
693fn json_number_to_f64(value: &Value) -> Option<f64> {
694    match value {
695        Value::Number(number) => number
696            .as_f64()
697            .or_else(|| number.as_i64().map(|number| number as f64))
698            .or_else(|| number.as_u64().map(|number| number as f64)),
699        _ => None,
700    }
701}
702
/// Returns `true` when `value / divisor` is within `1e-12` of an integer.
///
/// A zero divisor is never satisfied.
fn is_multiple_of(value: f64, divisor: f64) -> bool {
    // Absorbs floating-point rounding in the quotient.
    const TOLERANCE: f64 = 1e-12;

    if divisor == 0.0 {
        return false;
    }
    let ratio = value / divisor;
    let distance = (ratio - ratio.round()).abs();
    distance <= TOLERANCE
}
710
711fn is_valid_format(format_name: &str, input: &str) -> bool {
712    match format_name {
713        "byte" => base64::engine::general_purpose::STANDARD
714            .decode(input)
715            .is_ok(),
716        "password" => true,
717        "regex" => Regex::new(input).is_ok(),
718        "uuid" => Uuid::parse_str(input).is_ok(),
719        "email" => Regex::new(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
720            .map(|regex| regex.is_match(input))
721            .unwrap_or(false),
722        "hostname" => is_valid_hostname(input),
723        "idn-email" => is_valid_idn_email(input),
724        "idn-hostname" => !input.is_empty(),
725        "iri" | "iri-reference" | "uri-reference" | "uri-template" => {
726            !input.trim().is_empty() && !input.contains(char::is_whitespace)
727        }
728        "ipv4" => input.parse::<std::net::Ipv4Addr>().is_ok(),
729        "ipv6" => input.parse::<std::net::Ipv6Addr>().is_ok(),
730        "uri" | "url" => Url::parse(input).is_ok(),
731        "date-time" => chrono::DateTime::parse_from_rfc3339(input).is_ok(),
732        "date" => chrono::NaiveDate::parse_from_str(input, "%Y-%m-%d").is_ok(),
733        "time" => is_valid_time(input),
734        "duration" => is_valid_duration(input),
735        "json-pointer" => is_valid_json_pointer(input),
736        "relative-json-pointer" => is_valid_relative_json_pointer(input),
737        _ => false,
738    }
739}
740
/// Checks that `input` is a dot-separated hostname.
///
/// The whole name must be 1..=253 bytes. Each label must be 1..=63 bytes of
/// ASCII letters, digits or `-`, and must neither start nor end with `-`.
fn is_valid_hostname(input: &str) -> bool {
    let label_ok = |label: &str| {
        (1..=63).contains(&label.len())
            && !label.starts_with('-')
            && !label.ends_with('-')
            && label
                .bytes()
                .all(|byte| byte.is_ascii_alphanumeric() || byte == b'-')
    };

    (1..=253).contains(&input.len()) && input.split('.').all(label_ok)
}
762
/// Loose check for an internationalized e-mail address.
///
/// Requires exactly one `@`, with a non-empty, whitespace-free part on each
/// side.
fn is_valid_idn_email(input: &str) -> bool {
    let Some((local, domain)) = input.split_once('@') else {
        return false;
    };
    let has_whitespace = |part: &str| part.contains(char::is_whitespace);

    !local.is_empty()
        && !domain.is_empty()
        // A second `@` would have produced a third part.
        && !domain.contains('@')
        && !has_whitespace(local)
        && !has_whitespace(domain)
}
778
779fn is_valid_time(input: &str) -> bool {
780    let formats = ["%H:%M:%S", "%H:%M:%S%.f", "%H:%M:%S%:z", "%H:%M:%S%.f%:z"];
781    formats
782        .iter()
783        .any(|format| chrono::NaiveTime::parse_from_str(input, format).is_ok())
784}
785
/// Checks that `input` is a JSON Pointer.
///
/// The empty string is valid. Otherwise the pointer must start with `/`, and
/// every `~` must be followed by `0` or `1`.
fn is_valid_json_pointer(input: &str) -> bool {
    if input.is_empty() {
        return true;
    }
    if !input.starts_with('/') {
        return false;
    }
    let mut chars = input.chars();
    while let Some(character) = chars.next() {
        // `~` is only legal as part of the escapes `~0` and `~1`.
        if character == '~' && !matches!(chars.next(), Some('0' | '1')) {
            return false;
        }
    }
    true
}

/// Checks that `input` is a Relative JSON Pointer.
///
/// The input must start with a non-negative integer without leading zeros
/// (`0` itself is allowed). The integer may be followed by `#` or by a JSON
/// Pointer, or by nothing at all.
///
/// Leading zeros are rejected for every length, so `"01"` and `"007"` fail
/// just like `"00"`, whether or not a suffix follows.
fn is_valid_relative_json_pointer(input: &str) -> bool {
    let digit_count = input
        .find(|character: char| !character.is_ascii_digit())
        .unwrap_or(input.len());
    let (digits, suffix) = input.split_at(digit_count);

    if digits.is_empty() || (digits.len() > 1 && digits.starts_with('0')) {
        return false;
    }
    // An empty suffix is accepted by `is_valid_json_pointer`.
    suffix == "#" || is_valid_json_pointer(suffix)
}
815
816fn is_valid_duration(input: &str) -> bool {
817    let Ok(regex) = Regex::new(r"^P(\d+Y)?(\d+M)?(\d+W)?(\d+D)?(T(\d+H)?(\d+M)?(\d+(\.\d+)?S)?)?$")
818    else {
819        return false;
820    };
821
822    if !regex.is_match(input) {
823        return false;
824    }
825
826    input != "P" && input != "PT"
827}
828
829fn matches_numeric_type(type_name: &str, value: &Value) -> bool {
830    match type_name {
831        "int32" => value
832            .as_i64()
833            .is_some_and(|number| i32::MIN as i64 <= number && number <= i32::MAX as i64),
834        "uint32" => value
835            .as_u64()
836            .is_some_and(|number| number <= u32::MAX as u64),
837        "int64" => value.as_i64().is_some(),
838        "uint64" => value.as_u64().is_some(),
839        "float" => value
840            .as_f64()
841            .is_some_and(|number| number.is_finite() && (number as f32).is_finite()),
842        "double" => value.as_f64().is_some_and(f64::is_finite),
843        _ => false,
844    }
845}
846
847#[doc(hidden)]
848pub mod __private {
849    pub use super::{
850        TagRuntime, apply_tags, join_index_path, join_object_path, merge_prefixed_errors,
851        prepend_path, validate_with_serde,
852    };
853}