use std::borrow::Cow;
/// Implemented by message types that can check their own field constraints.
///
/// Implementations typically build a `ValidationRuleSet`, chain the relevant
/// `require_*` checks, and return its `validate()` result.
pub trait ProtoValidator {
    /// Checks every constraint on `self`.
    ///
    /// # Errors
    ///
    /// Returns a `ValidationError` carrying the violated field rules when any
    /// constraint does not hold; `Ok(())` otherwise.
    fn validate(&self) -> Result<(), ValidationError>;
}
/// A single field constraint, used to describe why a field failed validation.
///
/// Rendered via `Display` as `"<field>: <constraint>"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FieldRule {
    /// Name of the field the constraint refers to.
    pub field: Cow<'static, str>,
    /// Human-readable description of the constraint.
    pub constraint: Cow<'static, str>,
}

impl FieldRule {
    /// Shared constructor used by every public builder below: copies `field`
    /// into an owned string and attaches the given constraint text.
    fn with_constraint(field: &str, constraint: Cow<'static, str>) -> Self {
        Self {
            field: Cow::Owned(field.to_owned()),
            constraint,
        }
    }

    /// Rule for a field that must be present: `"field is required"`.
    #[must_use]
    pub fn required(field: &str) -> Self {
        // The message is a static literal, so it is borrowed rather than allocated.
        Self::with_constraint(field, Cow::Borrowed("field is required"))
    }

    /// Rule for a numeric field that must lie between `min` and `max`.
    #[must_use]
    pub fn range(field: &str, min: i64, max: i64) -> Self {
        Self::with_constraint(
            field,
            Cow::Owned(format!("value must be between {min} and {max}")),
        )
    }

    /// Rule for a text field whose length must not exceed `max`.
    #[must_use]
    pub fn max_length(field: &str, max: usize) -> Self {
        Self::with_constraint(field, Cow::Owned(format!("length must not exceed {max}")))
    }

    /// Rule for a text field whose length must be at least `min`.
    #[must_use]
    pub fn min_length(field: &str, min: usize) -> Self {
        Self::with_constraint(field, Cow::Owned(format!("length must be at least {min}")))
    }

    /// Rule for a repeated field holding at most `max` items.
    #[must_use]
    pub fn max_items(field: &str, max: usize) -> Self {
        Self::with_constraint(
            field,
            Cow::Owned(format!("number of items must not exceed {max}")),
        )
    }

    /// Rule with caller-supplied constraint text, copied verbatim.
    #[must_use]
    pub fn custom(field: &str, constraint: &str) -> Self {
        Self::with_constraint(field, Cow::Owned(constraint.to_owned()))
    }
}

impl std::fmt::Display for FieldRule {
    /// Formats the rule as `"<field>: <constraint>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.field, self.constraint)
    }
}
/// Error produced when one or more field constraints are violated.
///
/// Its `Display` text is `"protobuf validation failed: "` followed by the
/// violations rendered by `format_violations`, i.e. joined with `"; "`.
#[derive(Debug, thiserror::Error)]
#[error("protobuf validation failed: {}", format_violations(&self.violations))]
pub struct ValidationError {
    /// Violated rules, in the order they were supplied.
    violations: Vec<FieldRule>,
}

impl ValidationError {
    /// Wraps an already-collected list of violated rules.
    pub fn constraint_violations(violations: Vec<FieldRule>) -> Self {
        ValidationError { violations }
    }

    /// Borrows the violated rules.
    pub fn violations(&self) -> &[FieldRule] {
        self.violations.as_slice()
    }

    /// Converts into a gRPC `InvalidArgument` status.
    ///
    /// The status message is only the joined violations; unlike `Display`, it
    /// omits the `"protobuf validation failed: "` prefix.
    pub fn into_status(self) -> tonic::Status {
        tonic::Status::invalid_argument(format_violations(self.violations()))
    }
}
/// Renders each violation through its `Display` impl and joins them with `"; "`.
///
/// An empty slice yields an empty string.
fn format_violations(violations: &[FieldRule]) -> String {
    let rendered: Vec<String> = violations.iter().map(ToString::to_string).collect();
    rendered.join("; ")
}
/// Accumulates field-constraint violations through a chain of `require_*`
/// checks and turns them into one result with [`ValidationRuleSet::validate`].
///
/// Every check runs regardless of earlier failures, so the resulting error
/// lists all violations at once, in the order the checks were applied.
#[derive(Debug, Default)]
#[must_use = "a ValidationRuleSet reports nothing until `validate()` is called"]
pub struct ValidationRuleSet {
    /// Violations recorded so far, in check order.
    violations: Vec<FieldRule>,
}

impl ValidationRuleSet {
    /// Creates a rule set with no recorded violations.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records a `required` violation when `value` is the empty string.
    ///
    /// Whitespace-only values count as non-empty.
    pub fn require_non_empty(mut self, field: &str, value: &str) -> Self {
        if value.is_empty() {
            self.violations.push(FieldRule::required(field));
        }
        self
    }

    /// Records a violation when `value` lies outside the inclusive range `min..=max`.
    ///
    /// If `min > max` the range is empty and every value is reported.
    pub fn require_range(mut self, field: &str, value: i64, min: i64, max: i64) -> Self {
        if !(min..=max).contains(&value) {
            self.violations.push(FieldRule::range(field, min, max));
        }
        self
    }

    /// Records a violation when `value` has more than `max` characters.
    ///
    /// Length is counted in `char`s (Unicode scalar values), not bytes. At most
    /// `max + 1` characters are walked, so oversized inputs are rejected early.
    pub fn require_max_length(mut self, field: &str, value: &str, max: usize) -> Self {
        // The character at index `max` exists exactly when there are more than
        // `max` chars. Unlike `take(max + 1)`, this cannot overflow when
        // `max == usize::MAX`.
        if value.chars().nth(max).is_some() {
            self.violations.push(FieldRule::max_length(field, max));
        }
        self
    }

    /// Records a violation when `value` has fewer than `min` characters.
    ///
    /// Length is counted in `char`s; at most `min` characters are walked.
    pub fn require_min_length(mut self, field: &str, value: &str, min: usize) -> Self {
        if value.chars().take(min).count() < min {
            self.violations.push(FieldRule::min_length(field, min));
        }
        self
    }

    /// Records a violation when a repeated field's `count` exceeds `max`.
    pub fn require_max_items(mut self, field: &str, count: usize, max: usize) -> Self {
        if count > max {
            self.violations.push(FieldRule::max_items(field, max));
        }
        self
    }

    /// Records a custom violation with message `constraint` when `condition` is false.
    pub fn require(mut self, field: &str, constraint: &str, condition: bool) -> Self {
        if !condition {
            self.violations.push(FieldRule::custom(field, constraint));
        }
        self
    }

    /// Finishes the chain.
    ///
    /// # Errors
    ///
    /// Returns a `ValidationError` holding every recorded violation if any
    /// check failed; `Ok(())` otherwise.
    pub fn validate(self) -> Result<(), ValidationError> {
        if self.violations.is_empty() {
            Ok(())
        } else {
            Err(ValidationError::constraint_violations(self.violations))
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for field-rule formatting, the error type, and the
    // `ValidationRuleSet` builder, including the boundary behaviour of each check.
    use super::*;
    use rstest::rstest;

    #[rstest]
    fn field_rule_required_display() {
        let rule = FieldRule::required("name");
        let display = rule.to_string();
        assert_eq!(display, "name: field is required");
    }

    #[rstest]
    fn field_rule_range_display() {
        let rule = FieldRule::range("count", 0, 100);
        let display = rule.to_string();
        assert_eq!(display, "count: value must be between 0 and 100");
    }

    #[rstest]
    fn field_rule_max_length_display() {
        let rule = FieldRule::max_length("description", 255);
        let display = rule.to_string();
        assert_eq!(display, "description: length must not exceed 255");
    }

    #[rstest]
    fn field_rule_min_length_display() {
        let rule = FieldRule::min_length("password", 8);
        let display = rule.to_string();
        assert_eq!(display, "password: length must be at least 8");
    }

    #[rstest]
    fn field_rule_max_items_display() {
        let rule = FieldRule::max_items("errors", 50);
        let display = rule.to_string();
        assert_eq!(display, "errors: number of items must not exceed 50");
    }

    #[rstest]
    fn field_rule_custom_display() {
        let rule = FieldRule::custom("email", "must be a valid email address");
        let display = rule.to_string();
        assert_eq!(display, "email: must be a valid email address");
    }

    #[rstest]
    fn validation_error_single_violation() {
        let error = ValidationError::constraint_violations(vec![FieldRule::required("name")]);
        let message = error.to_string();
        let violations = error.violations();
        assert_eq!(
            message,
            "protobuf validation failed: name: field is required"
        );
        assert_eq!(violations.len(), 1);
    }

    #[rstest]
    fn validation_error_multiple_violations() {
        let error = ValidationError::constraint_violations(vec![
            FieldRule::required("name"),
            FieldRule::range("count", 0, 100),
        ]);
        let message = error.to_string();
        assert_eq!(
            message,
            "protobuf validation failed: name: field is required; count: value must be between 0 and 100"
        );
    }

    #[rstest]
    fn validation_error_into_status() {
        let error = ValidationError::constraint_violations(vec![FieldRule::required("query")]);
        let status = error.into_status();
        assert_eq!(status.code(), tonic::Code::InvalidArgument);
        assert!(status.message().contains("query: field is required"));
    }

    #[rstest]
    fn rule_set_passes_when_all_valid() {
        let result = ValidationRuleSet::new()
            .require_non_empty("name", "Alice")
            .require_range("age", 25, 0, 150)
            .require_max_length("bio", "Short bio", 1000)
            .validate();
        assert!(result.is_ok());
    }

    #[rstest]
    fn rule_set_fails_on_empty_required() {
        let result = ValidationRuleSet::new()
            .require_non_empty("name", "")
            .validate();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.violations().len(), 1);
        assert_eq!(err.violations()[0].field, "name");
    }

    #[rstest]
    fn rule_set_fails_on_out_of_range() {
        let result = ValidationRuleSet::new()
            .require_range("page", -1, 0, 1000)
            .validate();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.violations()[0].field, "page");
    }

    #[rstest]
    fn rule_set_fails_on_excessive_length() {
        let long_string = "x".repeat(300);
        let result = ValidationRuleSet::new()
            .require_max_length("name", &long_string, 255)
            .validate();
        assert!(result.is_err());
    }

    #[rstest]
    fn rule_set_fails_on_insufficient_length() {
        let result = ValidationRuleSet::new()
            .require_min_length("password", "abc", 8)
            .validate();
        assert!(result.is_err());
    }

    #[rstest]
    fn rule_set_fails_on_excessive_items() {
        let result = ValidationRuleSet::new()
            .require_max_items("errors", 100, 50)
            .validate();
        assert!(result.is_err());
    }

    #[rstest]
    fn rule_set_custom_constraint() {
        let result = ValidationRuleSet::new()
            .require("email", "must contain @", "invalid".contains('@'))
            .validate();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.violations()[0].field, "email");
        assert_eq!(err.violations()[0].constraint, "must contain @");
    }

    #[rstest]
    fn rule_set_collects_multiple_violations() {
        let result = ValidationRuleSet::new()
            .require_non_empty("name", "")
            .require_range("page", -1, 0, 100)
            .require_max_length("query", &"x".repeat(10000), 1000)
            .validate();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.violations().len(), 3);
    }

    #[rstest]
    fn empty_rule_set_passes() {
        assert!(ValidationRuleSet::new().validate().is_ok());
    }

    #[rstest]
    #[case("abc", 3, true)]
    #[case("abcd", 3, false)]
    #[case("ééé", 3, true)] // 3 chars but 6 bytes: the limit counts characters
    #[case("", 0, true)]
    #[case("a", 0, false)]
    fn rule_set_max_length_boundaries(
        #[case] value: &str,
        #[case] max: usize,
        #[case] expected_ok: bool,
    ) {
        let result = ValidationRuleSet::new()
            .require_max_length("field", value, max)
            .validate();
        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    #[case("abc", 3, true)]
    #[case("ab", 3, false)]
    #[case("ééé", 3, true)]
    #[case("éé", 3, false)]
    #[case("", 0, true)]
    fn rule_set_min_length_boundaries(
        #[case] value: &str,
        #[case] min: usize,
        #[case] expected_ok: bool,
    ) {
        let result = ValidationRuleSet::new()
            .require_min_length("field", value, min)
            .validate();
        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    #[case(0, true)]
    #[case(100, true)]
    #[case(-1, false)]
    #[case(101, false)]
    fn rule_set_range_bounds_are_inclusive(#[case] value: i64, #[case] expected_ok: bool) {
        let result = ValidationRuleSet::new()
            .require_range("page", value, 0, 100)
            .validate();
        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    fn rule_set_range_accepts_full_i64_span() {
        let result = ValidationRuleSet::new()
            .require_range("offset", i64::MIN, i64::MIN, i64::MAX)
            .require_range("offset", i64::MAX, i64::MIN, i64::MAX)
            .validate();
        assert!(result.is_ok());
    }

    #[rstest]
    #[case(50, 50, true)]
    #[case(51, 50, false)]
    #[case(0, 0, true)]
    fn rule_set_max_items_boundaries(
        #[case] count: usize,
        #[case] max: usize,
        #[case] expected_ok: bool,
    ) {
        let result = ValidationRuleSet::new()
            .require_max_items("items", count, max)
            .validate();
        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    fn rule_set_custom_constraint_passes_when_condition_holds() {
        let result = ValidationRuleSet::new()
            .require("email", "must contain @", "user@example.com".contains('@'))
            .validate();
        assert!(result.is_ok());
    }

    #[rstest]
    fn rule_set_preserves_violation_order() {
        let err = ValidationRuleSet::new()
            .require_range("page", -1, 0, 100)
            .require_non_empty("name", "")
            .validate()
            .unwrap_err();
        let fields: Vec<&str> = err.violations().iter().map(|v| &*v.field).collect();
        assert_eq!(fields, ["page", "name"]);
    }

    #[rstest]
    fn proto_validator_trait_implementation() {
        struct TestMessage {
            name: String,
            count: i32,
        }

        impl ProtoValidator for TestMessage {
            fn validate(&self) -> Result<(), ValidationError> {
                ValidationRuleSet::new()
                    .require_non_empty("name", &self.name)
                    .require_range("count", self.count as i64, 0, 1000)
                    .validate()
            }
        }

        let valid = TestMessage {
            name: "test".to_string(),
            count: 42,
        };
        let invalid = TestMessage {
            name: String::new(),
            count: -1,
        };
        assert!(valid.validate().is_ok());
        assert!(invalid.validate().is_err());
        assert_eq!(invalid.validate().unwrap_err().violations().len(), 2);
    }
}