use std::collections::HashSet;
use crate::config::ValidationConfig;
use crate::error::{Error, ValidationError};
use crate::violation::Violation;
pub(crate) struct RepeatedRuleEval {
min_items: Option<u64>,
max_items: Option<u64>,
unique: bool,
}
impl RepeatedRuleEval {
pub fn new(rules: &prost_protovalidate_types::RepeatedRules) -> Self {
Self {
min_items: rules.min_items,
max_items: rules.max_items,
unique: rules.unique.unwrap_or(false),
}
}
pub fn tautology(&self) -> bool {
self.min_items.is_none() && self.max_items.is_none() && !self.unique
}
pub fn evaluate(
&self,
val: &prost_reflect::Value,
_cfg: &ValidationConfig,
) -> Result<(), Error> {
let Some(list) = val.as_list() else {
return Ok(());
};
let mut violations = Vec::new();
#[allow(clippy::cast_possible_truncation)]
let len = list.len() as u64;
if let Some(min) = self.min_items {
if len < min {
violations.push(Violation::new(
"",
"repeated.min_items",
format!("must have at least {min} items"),
));
}
}
if let Some(max) = self.max_items {
if len > max {
violations.push(Violation::new(
"",
"repeated.max_items",
format!("must have at most {max} items"),
));
}
}
if self.unique && !is_unique(list) {
violations.push(Violation::new(
"",
"repeated.unique",
"items must be unique",
));
}
if violations.is_empty() {
Ok(())
} else {
Err(ValidationError::new(violations).into())
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum UniqueKey {
Bool(bool),
I32(i32),
I64(i64),
U32(u32),
U64(u64),
F32(u32),
F64(u64),
String(String),
Bytes(Vec<u8>),
Enum(i32),
}
enum UniqueKeyExtraction {
Key(UniqueKey),
AlwaysUnique,
Unsupported,
}
fn canonical_f32_bits(value: f32) -> Option<u32> {
if value.is_nan() {
return None;
}
Some(if value == 0.0 {
0.0_f32.to_bits()
} else {
value.to_bits()
})
}
fn canonical_f64_bits(value: f64) -> Option<u64> {
if value.is_nan() {
return None;
}
Some(if value == 0.0 {
0.0_f64.to_bits()
} else {
value.to_bits()
})
}
fn unique_key(value: &prost_reflect::Value) -> UniqueKeyExtraction {
match value {
prost_reflect::Value::Bool(v) => UniqueKeyExtraction::Key(UniqueKey::Bool(*v)),
prost_reflect::Value::I32(v) => UniqueKeyExtraction::Key(UniqueKey::I32(*v)),
prost_reflect::Value::I64(v) => UniqueKeyExtraction::Key(UniqueKey::I64(*v)),
prost_reflect::Value::U32(v) => UniqueKeyExtraction::Key(UniqueKey::U32(*v)),
prost_reflect::Value::U64(v) => UniqueKeyExtraction::Key(UniqueKey::U64(*v)),
prost_reflect::Value::F32(v) => canonical_f32_bits(*v)
.map_or(UniqueKeyExtraction::AlwaysUnique, |bits| {
UniqueKeyExtraction::Key(UniqueKey::F32(bits))
}),
prost_reflect::Value::F64(v) => canonical_f64_bits(*v)
.map_or(UniqueKeyExtraction::AlwaysUnique, |bits| {
UniqueKeyExtraction::Key(UniqueKey::F64(bits))
}),
prost_reflect::Value::String(v) => UniqueKeyExtraction::Key(UniqueKey::String(v.clone())),
prost_reflect::Value::Bytes(v) => UniqueKeyExtraction::Key(UniqueKey::Bytes(v.to_vec())),
prost_reflect::Value::EnumNumber(v) => UniqueKeyExtraction::Key(UniqueKey::Enum(*v)),
_ => UniqueKeyExtraction::Unsupported,
}
}
fn is_unique(list: &[prost_reflect::Value]) -> bool {
let mut seen = HashSet::with_capacity(list.len());
for value in list {
match unique_key(value) {
UniqueKeyExtraction::Key(key) => {
if !seen.insert(key) {
return false;
}
}
UniqueKeyExtraction::AlwaysUnique => {}
UniqueKeyExtraction::Unsupported => return false,
}
}
true
}
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use super::*;
use crate::config::ValidationConfig;
#[test]
fn unique_rejects_equal_values() {
let eval = RepeatedRuleEval::new(&prost_protovalidate_types::RepeatedRules {
unique: Some(true),
..Default::default()
});
let value = prost_reflect::Value::List(vec![
prost_reflect::Value::I32(7),
prost_reflect::Value::I32(7),
]);
let err = eval
.evaluate(&value, &ValidationConfig::default())
.expect_err("duplicate list values should fail unique check");
match err {
Error::Validation(err) => {
assert_eq!(err.len(), 1);
assert_eq!(err.violations()[0].rule_id(), "repeated.unique");
}
other => panic!("unexpected error type: {other}"),
}
}
#[test]
fn unique_accepts_distinct_values() {
let eval = RepeatedRuleEval::new(&prost_protovalidate_types::RepeatedRules {
unique: Some(true),
..Default::default()
});
let value = prost_reflect::Value::List(vec![
prost_reflect::Value::I32(7),
prost_reflect::Value::I32(8),
]);
assert!(eval.evaluate(&value, &ValidationConfig::default()).is_ok());
}
#[test]
fn unique_rejects_negative_zero_and_positive_zero() {
let eval = RepeatedRuleEval::new(&prost_protovalidate_types::RepeatedRules {
unique: Some(true),
..Default::default()
});
let value = prost_reflect::Value::List(vec![
prost_reflect::Value::F64(-0.0),
prost_reflect::Value::F64(0.0),
]);
let err = eval
.evaluate(&value, &ValidationConfig::default())
.expect_err("signed and unsigned zero should be treated as equal");
let Error::Validation(err) = err else {
panic!("unexpected error type");
};
assert_eq!(err.violations()[0].rule_id(), "repeated.unique");
}
#[test]
fn unique_allows_multiple_nan_values() {
let eval = RepeatedRuleEval::new(&prost_protovalidate_types::RepeatedRules {
unique: Some(true),
..Default::default()
});
let value = prost_reflect::Value::List(vec![
prost_reflect::Value::F64(f64::NAN),
prost_reflect::Value::F64(f64::NAN),
]);
assert!(eval.evaluate(&value, &ValidationConfig::default()).is_ok());
}
}