use std::collections::HashMap;
use serde_json::Value;
use crate::error::FrameworkError;
use crate::validation::{Rule, ValidationError};
use super::async_rule::AsyncRule;
/// Failure returned by [`AsyncValidator::validate_async`].
#[derive(Debug)]
pub enum AsyncValidationError {
    /// One or more fields failed their rules; carries the per-field messages.
    Validation(ValidationError),
    /// An async rule signalled an infrastructure failure instead of a
    /// validation result.
    Infra(FrameworkError),
}

impl std::fmt::Display for AsyncValidationError {
    /// Renders a category label followed by the inner error's `Display` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail): (&str, &dyn std::fmt::Display) = match self {
            AsyncValidationError::Validation(inner) => ("Validation failed", inner),
            AsyncValidationError::Infra(inner) => ("Infrastructure error", inner),
        };
        write!(f, "{label}: {detail}")
    }
}

// The inner error is already rendered by `Display`, so no `source()` is exposed.
impl std::error::Error for AsyncValidationError {}
pub struct AsyncValidator<'a> {
data: &'a Value,
sync_rules: HashMap<String, Vec<Box<dyn Rule>>>,
async_rules: HashMap<String, Vec<Box<dyn AsyncRule>>>,
custom_messages: HashMap<String, String>,
custom_attributes: HashMap<String, String>,
}
impl<'a> AsyncValidator<'a> {
pub fn new(data: &'a Value) -> Self {
Self {
data,
sync_rules: HashMap::new(),
async_rules: HashMap::new(),
custom_messages: HashMap::new(),
custom_attributes: HashMap::new(),
}
}
pub fn rule<R: Rule + 'static>(mut self, field: impl Into<String>, rule: R) -> Self {
let field = field.into();
self.sync_rules
.entry(field)
.or_default()
.push(Box::new(rule) as Box<dyn Rule>);
self
}
pub fn rules(mut self, field: impl Into<String>, rules: Vec<Box<dyn Rule>>) -> Self {
self.sync_rules.insert(field.into(), rules);
self
}
pub fn async_rule<R: AsyncRule + 'static>(mut self, field: impl Into<String>, rule: R) -> Self {
self.async_rules
.entry(field.into())
.or_default()
.push(Box::new(rule) as Box<dyn AsyncRule>);
self
}
pub fn message(mut self, key: impl Into<String>, message: impl Into<String>) -> Self {
self.custom_messages.insert(key.into(), message.into());
self
}
pub fn messages(mut self, messages: HashMap<String, String>) -> Self {
self.custom_messages.extend(messages);
self
}
pub fn attribute(mut self, field: impl Into<String>, name: impl Into<String>) -> Self {
self.custom_attributes.insert(field.into(), name.into());
self
}
pub fn attributes(mut self, attributes: HashMap<String, String>) -> Self {
self.custom_attributes.extend(attributes);
self
}
pub async fn validate_async(self) -> Result<(), AsyncValidationError> {
let mut errors = ValidationError::new();
for (field, rules) in &self.sync_rules {
let value = self.get_value(field);
let display_field = self.get_display_field(field);
let has_nullable = rules.iter().any(|r| r.name() == "nullable");
if has_nullable && value.is_null() {
continue;
}
for rule in rules {
if rule.name() == "nullable" {
continue;
}
if let Err(default_message) = rule.validate(&display_field, &value, self.data) {
let message_key = format!("{}.{}", field, rule.name());
let message = self
.custom_messages
.get(&message_key)
.cloned()
.unwrap_or(default_message);
errors.add(field, message);
}
}
}
for (field, rules) in &self.async_rules {
if errors.has(field) {
continue;
}
let value = self.get_value(field);
if value.is_null() {
let nullable = self
.sync_rules
.get(field)
.map(|rs| rs.iter().any(|r| r.name() == "nullable"))
.unwrap_or(false);
if nullable {
continue;
}
}
let display_field = self.get_display_field(field);
for rule in rules {
match rule.validate(&display_field, &value, self.data).await {
Ok(()) => {}
Err(msg) => {
if let Some(rest) = msg.strip_prefix("__infra_error__:") {
return Err(AsyncValidationError::Infra(FrameworkError::database(
rest.trim().to_string(),
)));
}
let message_key = format!("{}.{}", field, rule.name());
let message = self
.custom_messages
.get(&message_key)
.cloned()
.unwrap_or(msg);
errors.add(field, message);
}
}
}
}
if errors.is_empty() {
Ok(())
} else {
Err(AsyncValidationError::Validation(errors))
}
}
fn get_value(&self, field: &str) -> Value {
get_nested_value(self.data, field)
.cloned()
.unwrap_or(Value::Null)
}
fn get_display_field(&self, field: &str) -> String {
self.custom_attributes
.get(field)
.cloned()
.unwrap_or_else(|| field.split('_').collect::<Vec<_>>().join(" "))
}
}
/// Resolves a `.`-separated `path` inside `data`.
///
/// Each segment is used as an object key, or, when the current node is an
/// array, parsed as a `usize` index. Returns `None` as soon as a segment is
/// missing, an index does not parse or is out of range, or a scalar is reached
/// before the path is exhausted.
fn get_nested_value<'a>(data: &'a Value, path: &str) -> Option<&'a Value> {
    path.split('.').try_fold(data, |node, segment| match node {
        Value::Object(map) => map.get(segment),
        Value::Array(items) => segment
            .parse::<usize>()
            .ok()
            .and_then(|idx| items.get(idx)),
        _ => None,
    })
}
// Tests for `AsyncValidator`. They cover sync-before-async ordering, the
// `nullable` short-circuit, and how rule failures map onto the two
// `AsyncValidationError` variants.
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use async_trait::async_trait;
    use serde_json::json;
    use serial_test::serial;
    use super::*;
    use crate::rules;
    use crate::validation::rules::*;

    // Async rule that always passes.
    struct OkRule;
    #[async_trait]
    impl AsyncRule for OkRule {
        async fn validate(
            &self,
            _field: &str,
            _value: &Value,
            _data: &Value,
        ) -> Result<(), String> {
            Ok(())
        }
        fn name(&self) -> &'static str {
            "ok_rule"
        }
    }

    // Async rule that always passes but increments a shared counter on every
    // call. Tests use it to prove whether async rules were invoked at all.
    struct CountingRule {
        counter: Arc<AtomicUsize>,
    }
    impl CountingRule {
        fn new(counter: Arc<AtomicUsize>) -> Self {
            Self { counter }
        }
    }
    #[async_trait]
    impl AsyncRule for CountingRule {
        async fn validate(
            &self,
            _field: &str,
            _value: &Value,
            _data: &Value,
        ) -> Result<(), String> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        fn name(&self) -> &'static str {
            "counting_rule"
        }
    }

    // Async rule that always reports an infrastructure failure. It uses the
    // `__infra_error__:` sentinel prefix that `validate_async` turns into
    // `AsyncValidationError::Infra`.
    struct InfraRule;
    #[async_trait]
    impl AsyncRule for InfraRule {
        async fn validate(
            &self,
            _field: &str,
            _value: &Value,
            _data: &Value,
        ) -> Result<(), String> {
            Err("__infra_error__: boom".to_string())
        }
        fn name(&self) -> &'static str {
            "infra_rule"
        }
    }

    // Async rule that always fails with an ordinary validation message
    // mentioning the (display) field name.
    struct FailRule;
    #[async_trait]
    impl AsyncRule for FailRule {
        async fn validate(&self, field: &str, _value: &Value, _data: &Value) -> Result<(), String> {
            Err(format!("The {field} rule failed."))
        }
        fn name(&self) -> &'static str {
            "fail_rule"
        }
    }

    // Initialises `DB` with an in-memory SQLite database and creates the
    // `widgets` scratch table. NOTE(review): `DB::connection()` takes no
    // handle, so the connection appears to be process-global state. That is
    // presumably why tests using this are `#[serial]`; confirm.
    async fn init_test_db() {
        use crate::database::{DatabaseConfig, DB};
        use sea_orm::{ConnectionTrait, Statement};
        let config = DatabaseConfig::builder().url("sqlite::memory:").build();
        DB::init_with(config).await.expect("init in-memory sqlite");
        let db = DB::connection().expect("connection after init");
        db.execute(Statement::from_string(
            db.get_database_backend(),
            "CREATE TABLE IF NOT EXISTS widgets (id INTEGER PRIMARY KEY, slug TEXT)".to_owned(),
        ))
        .await
        .expect("create widgets scratch table");
    }

    // Inserts one row into `widgets`. `slug` is interpolated directly into the
    // SQL text, so only call this with trusted test literals.
    async fn seed_widget(id: i64, slug: &str) {
        use crate::database::DB;
        use sea_orm::{ConnectionTrait, Statement};
        let db = DB::connection().expect("connection for seed_widget");
        db.execute(Statement::from_string(
            db.get_database_backend(),
            format!("INSERT INTO widgets (id, slug) VALUES ({id}, '{slug}')"),
        ))
        .await
        .expect("seed widget row");
    }

    // Passing sync and async rules yield `Ok(())`.
    #[tokio::test]
    async fn async_validator_all_pass() {
        let data = json!({"name": "Alice"});
        let result = AsyncValidator::new(&data)
            .rule("name", required())
            .async_rule("name", OkRule)
            .validate_async()
            .await;
        assert!(result.is_ok(), "expected Ok(()), got: {result:?}");
    }

    // A failing sync rule (registered via `rule`) prevents the field's async
    // rules from running.
    #[tokio::test]
    async fn async_validator_sync_first() {
        let counter = Arc::new(AtomicUsize::new(0));
        let data = json!({"name": ""});
        let result = AsyncValidator::new(&data)
            .rule("name", required())
            .async_rule("name", CountingRule::new(counter.clone()))
            .validate_async()
            .await;
        assert!(result.is_err(), "expected Err (sync failure)");
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "async rule must not run when sync rule fails"
        );
    }

    // Same guarantee as above, with rules registered via `rules` / `rules!`.
    // It also checks that the error is the `Validation` variant for the field.
    #[tokio::test]
    async fn async_validator_skips_async_on_sync_error() {
        let counter = Arc::new(AtomicUsize::new(0));
        let data = json!({"email": ""});
        let result = AsyncValidator::new(&data)
            .rules("email", rules![required()])
            .async_rule("email", CountingRule::new(counter.clone()))
            .validate_async()
            .await;
        match result {
            Err(AsyncValidationError::Validation(e)) => {
                assert!(e.has("email"), "expected 'email' field error");
            }
            other => panic!("expected Validation error, got {other:?}"),
        }
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "async rule counter must be 0 (no DB query issued)"
        );
    }

    // The infra sentinel must surface as `Infra`. It must never leak into
    // per-field validation messages.
    #[tokio::test]
    async fn async_validator_infra_error_shape() {
        let data = json!({"slug": "something"});
        let result = AsyncValidator::new(&data)
            .async_rule("slug", InfraRule)
            .validate_async()
            .await;
        match result {
            Err(AsyncValidationError::Infra(_)) => {
            }
            Err(AsyncValidationError::Validation(e)) => {
                let msgs = e.get("slug").cloned().unwrap_or_default();
                for m in &msgs {
                    assert!(
                        !m.contains("__infra_error__"),
                        "infra sentinel must not appear in field errors: {m}"
                    );
                }
                panic!("expected Infra error, got Validation with: {msgs:?}");
            }
            Ok(()) => panic!("expected Err(Infra), got Ok(())"),
        }
    }

    // A null value on a field with a `nullable` sync rule passes, and its async
    // rules are not invoked.
    #[tokio::test]
    async fn async_validator_nullable_skips_async() {
        let counter = Arc::new(AtomicUsize::new(0));
        let data = json!({"nickname": null});
        let result = AsyncValidator::new(&data)
            .rules("nickname", rules![nullable()])
            .async_rule("nickname", CountingRule::new(counter.clone()))
            .validate_async()
            .await;
        assert!(
            result.is_ok(),
            "nullable null field should pass, got: {result:?}"
        );
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "async rule must not run for null nullable field"
        );
    }

    // An ordinary async rule failure is reported as `Validation` on its field.
    #[tokio::test]
    async fn async_validator_validation_failure_shape() {
        let data = json!({"name": "Alice"});
        let result = AsyncValidator::new(&data)
            .async_rule("name", FailRule)
            .validate_async()
            .await;
        match result {
            Err(AsyncValidationError::Validation(e)) => {
                assert!(e.has("name"), "expected 'name' field error");
            }
            other => panic!("expected Validation error, got {other:?}"),
        }
    }

    // End-to-end check against SQLite: a duplicate value for the `unique`
    // async rule is a `Validation` error, not `Infra`. Marked `#[serial]`
    // because it initialises and writes to the shared `DB` connection.
    #[tokio::test]
    #[serial]
    async fn async_validator_unique_duplicate_is_validation() {
        init_test_db().await;
        seed_widget(1, "taken").await;
        let data = json!({"slug": "taken"});
        let result = AsyncValidator::new(&data)
            .async_rule(
                "slug",
                crate::validation::rules_async::unique("widgets", "slug"),
            )
            .validate_async()
            .await;
        match result {
            Err(AsyncValidationError::Validation(e)) => {
                assert!(
                    e.has("slug"),
                    "expected 'slug' field error for duplicate, errors: {e:?}"
                );
            }
            other => panic!("expected Validation error for duplicate, got {other:?}"),
        }
    }
}