use std::sync::LazyLock;
use prost_reflect::{DynamicMessage, ReflectMessage};
use crate::config::ValidationConfig;
use crate::error::{Error, ValidationError};
use crate::violation::Violation;
use super::MessageEvaluator;
use super::value::ValueEval;
/// Descriptor of the field named `required` on `FieldRules`, resolved lazily on first access.
///
/// Attached to "required" violations as their rule descriptor. It is `None` if the
/// `FieldRules` message descriptor has no field called `required`, in which case
/// violations are emitted without a rule descriptor.
static REQUIRED_RULE_DESCRIPTOR: LazyLock<Option<prost_reflect::FieldDescriptor>> =
    LazyLock::new(|| {
        // An instance is only needed to reach its message descriptor via `ReflectMessage`.
        let field_rules = prost_protovalidate_types::FieldRules::default();
        field_rules.descriptor().get_field_by_name("required")
    });
/// Evaluator that applies the configured rules to a single field of a message.
pub(crate) struct FieldEval {
    /// Evaluator for the field's value; its `descriptor` identifies the field being checked.
    pub value: ValueEval,
    /// When `true`, an unset field produces a single `required` violation.
    pub required: bool,
    /// When `true`, an unset field skips its value rules entirely.
    pub has_presence: bool,
    /// When `true`, the field is always treated as set during evaluation.
    /// NOTE(review): presumably marks proto2 `required`-label fields; confirm at the builder.
    pub is_legacy_required: bool,
    /// How (or whether) this field should be skipped during evaluation.
    pub ignore: IgnoreMode,
    /// A compilation error recorded for this field. If present, it is returned at
    /// evaluation time instead of running any rules.
    pub err: Option<crate::error::CompilationError>,
}
/// Controls when a field's validation is skipped.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub(crate) enum IgnoreMode {
    /// No explicit ignore setting. Unset fields are skipped only if they track presence.
    #[default]
    Unspecified,
    /// The field is never validated.
    Always,
    /// An unset (zero-valued) field skips its value rules.
    IfZeroValue,
}
impl FieldEval {
    /// Returns `true` when the field is configured to never be validated.
    fn should_ignore_always(&self) -> bool {
        matches!(self.ignore, IgnoreMode::Always)
    }

    /// Returns `true` when an unset field should skip its value rules.
    ///
    /// This holds when the field is configured with `IgnoreMode::IfZeroValue`,
    /// or when the field tracks presence.
    fn should_ignore_empty(&self) -> bool {
        matches!(self.ignore, IgnoreMode::IfZeroValue) || self.has_presence
    }
}
impl MessageEvaluator for FieldEval {
    /// Reports whether this evaluator can never fail.
    ///
    /// That is only the case when the field is not `required`, the wrapped value
    /// evaluator is itself a tautology, and no compilation error was recorded.
    fn tautology(&self) -> bool {
        !self.required && self.value.tautology() && self.err.is_none()
    }

    /// Validates this evaluator's field on `msg`.
    ///
    /// Checks run in this order:
    /// 1. `IgnoreMode::Always` skips the field.
    /// 2. `cfg.filter` may opt the field out of validation.
    /// 3. A recorded compilation error is returned.
    /// 4. A `required` field that is not set yields one `required` violation.
    /// 5. An unset field is skipped when it tracks presence or uses
    ///    `IgnoreMode::IfZeroValue`.
    /// 6. Otherwise the value rules run. Any resulting violations are annotated
    ///    with this field's descriptor and value where they lack them.
    ///
    /// # Errors
    ///
    /// Returns the recorded compilation error, a validation error for the
    /// `required` rule, or whatever the value evaluator reports.
    fn evaluate_message(&self, msg: &DynamicMessage, cfg: &ValidationConfig) -> Result<(), Error> {
        if self.should_ignore_always() {
            return Ok(());
        }
        let field_desc = &self.value.descriptor;
        if !cfg.filter.should_validate_field(msg, field_desc) {
            return Ok(());
        }
        if let Some(err) = &self.err {
            // `self` is only borrowed, so the stored error is rebuilt rather than moved out.
            return Err(crate::error::CompilationError {
                cause: err.cause.clone(),
            }
            .into());
        }

        // Allocate the name only once we know the field is actually going to be evaluated;
        // the early returns above never need it.
        let field_name = field_desc.name().to_string();

        // Fields flagged `is_legacy_required` always count as set, so they never hit the
        // `required` violation or the empty-field skip below.
        let field_is_set = self.is_legacy_required || msg.has_field(field_desc);
        if self.required && !field_is_set {
            let mut violation = Violation::new(&field_name, "required", "value is required")
                .with_rule_value(prost_reflect::Value::Bool(true))
                .with_field_descriptor(field_desc);
            if let Some(rule_descriptor) = REQUIRED_RULE_DESCRIPTOR.clone() {
                violation = violation.with_rule_descriptor(rule_descriptor);
            }
            return Err(ValidationError::single(violation).into());
        }
        if self.should_ignore_empty() && !field_is_set {
            return Ok(());
        }

        let val = msg.get_field(field_desc);
        let result = self.value.evaluate_value(msg, &val, cfg, &field_name);
        enrich_field_violations(result, field_desc, &val)
    }
}
fn enrich_field_violations(
result: Result<(), Error>,
field_desc: &prost_reflect::FieldDescriptor,
value: &prost_reflect::Value,
) -> Result<(), Error> {
match result {
Ok(()) => Ok(()),
Err(Error::Validation(mut ve)) => {
for violation in ve.violations_mut() {
if !violation.has_field_descriptor() {
violation.set_field_descriptor(field_desc);
}
if !violation.has_field_value() {
violation.set_field_value(value.clone());
}
}
Err(Error::Validation(ve))
}
Err(other) => Err(other),
}
}