1use std::collections::HashMap;
4
5use serde_json::Value;
6
7use crate::error::FrameworkError;
8use crate::validation::{Rule, ValidationError};
9
10use super::async_rule::AsyncRule;
11
12#[derive(Debug)]
32pub enum AsyncValidationError {
33 Validation(ValidationError),
36 Infra(FrameworkError),
39}
40
41impl std::fmt::Display for AsyncValidationError {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 match self {
44 Self::Validation(e) => write!(f, "Validation failed: {e}"),
45 Self::Infra(e) => write!(f, "Infrastructure error: {e}"),
46 }
47 }
48}
49
50impl std::error::Error for AsyncValidationError {}
51
52pub struct AsyncValidator<'a> {
90 data: &'a Value,
91 sync_rules: HashMap<String, Vec<Box<dyn Rule>>>,
92 async_rules: HashMap<String, Vec<Box<dyn AsyncRule>>>,
93 custom_messages: HashMap<String, String>,
94 custom_attributes: HashMap<String, String>,
95}
96
97impl<'a> AsyncValidator<'a> {
98 pub fn new(data: &'a Value) -> Self {
100 Self {
101 data,
102 sync_rules: HashMap::new(),
103 async_rules: HashMap::new(),
104 custom_messages: HashMap::new(),
105 custom_attributes: HashMap::new(),
106 }
107 }
108
109 pub fn rule<R: Rule + 'static>(mut self, field: impl Into<String>, rule: R) -> Self {
111 let field = field.into();
112 self.sync_rules
113 .entry(field)
114 .or_default()
115 .push(Box::new(rule) as Box<dyn Rule>);
116 self
117 }
118
119 pub fn rules(mut self, field: impl Into<String>, rules: Vec<Box<dyn Rule>>) -> Self {
132 self.sync_rules.insert(field.into(), rules);
133 self
134 }
135
136 pub fn async_rule<R: AsyncRule + 'static>(mut self, field: impl Into<String>, rule: R) -> Self {
140 self.async_rules
141 .entry(field.into())
142 .or_default()
143 .push(Box::new(rule) as Box<dyn AsyncRule>);
144 self
145 }
146
147 pub fn message(mut self, key: impl Into<String>, message: impl Into<String>) -> Self {
157 self.custom_messages.insert(key.into(), message.into());
158 self
159 }
160
161 pub fn messages(mut self, messages: HashMap<String, String>) -> Self {
163 self.custom_messages.extend(messages);
164 self
165 }
166
167 pub fn attribute(mut self, field: impl Into<String>, name: impl Into<String>) -> Self {
178 self.custom_attributes.insert(field.into(), name.into());
179 self
180 }
181
182 pub fn attributes(mut self, attributes: HashMap<String, String>) -> Self {
184 self.custom_attributes.extend(attributes);
185 self
186 }
187
188 pub async fn validate_async(self) -> Result<(), AsyncValidationError> {
208 let mut errors = ValidationError::new();
209
210 for (field, rules) in &self.sync_rules {
212 let value = self.get_value(field);
213 let display_field = self.get_display_field(field);
214
215 let has_nullable = rules.iter().any(|r| r.name() == "nullable");
217 if has_nullable && value.is_null() {
218 continue;
219 }
220
221 for rule in rules {
222 if rule.name() == "nullable" {
224 continue;
225 }
226
227 if let Err(default_message) = rule.validate(&display_field, &value, self.data) {
228 let message_key = format!("{}.{}", field, rule.name());
229 let message = self
230 .custom_messages
231 .get(&message_key)
232 .cloned()
233 .unwrap_or(default_message);
234 errors.add(field, message);
235 }
236 }
237 }
238
239 for (field, rules) in &self.async_rules {
241 if errors.has(field) {
243 continue;
244 }
245
246 let value = self.get_value(field);
247
248 if value.is_null() {
252 let nullable = self
253 .sync_rules
254 .get(field)
255 .map(|rs| rs.iter().any(|r| r.name() == "nullable"))
256 .unwrap_or(false);
257 if nullable {
258 continue;
259 }
260 }
261
262 let display_field = self.get_display_field(field);
263
264 for rule in rules {
265 match rule.validate(&display_field, &value, self.data).await {
266 Ok(()) => {}
267 Err(msg) => {
268 if let Some(rest) = msg.strip_prefix("__infra_error__:") {
270 return Err(AsyncValidationError::Infra(FrameworkError::database(
271 rest.trim().to_string(),
272 )));
273 }
274 let message_key = format!("{}.{}", field, rule.name());
275 let message = self
276 .custom_messages
277 .get(&message_key)
278 .cloned()
279 .unwrap_or(msg);
280 errors.add(field, message);
281 }
282 }
283 }
284 }
285
286 if errors.is_empty() {
287 Ok(())
288 } else {
289 Err(AsyncValidationError::Validation(errors))
290 }
291 }
292
293 fn get_value(&self, field: &str) -> Value {
295 get_nested_value(self.data, field)
296 .cloned()
297 .unwrap_or(Value::Null)
298 }
299
300 fn get_display_field(&self, field: &str) -> String {
302 self.custom_attributes
303 .get(field)
304 .cloned()
305 .unwrap_or_else(|| field.split('_').collect::<Vec<_>>().join(" "))
306 }
307}
308
309fn get_nested_value<'a>(data: &'a Value, path: &str) -> Option<&'a Value> {
311 let parts: Vec<&str> = path.split('.').collect();
312 let mut current = data;
313
314 for part in parts {
315 if let Value::Object(map) = current {
317 current = map.get(part)?;
318 }
319 else if let Value::Array(arr) = current {
321 let index: usize = part.parse().ok()?;
322 current = arr.get(index)?;
323 } else {
324 return None;
325 }
326 }
327
328 Some(current)
329}
330
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use async_trait::async_trait;
    use serde_json::json;
    use serial_test::serial;

    use super::*;
    use crate::rules;
    use crate::validation::rules::*;

    /// Async rule that accepts every value.
    struct OkRule;

    #[async_trait]
    impl AsyncRule for OkRule {
        async fn validate(
            &self,
            _field: &str,
            _value: &Value,
            _data: &Value,
        ) -> Result<(), String> {
            Ok(())
        }

        fn name(&self) -> &'static str {
            "ok_rule"
        }
    }

    /// Async rule that accepts every value and counts how often it was invoked.
    struct CountingRule {
        counter: Arc<AtomicUsize>,
    }

    impl CountingRule {
        fn new(counter: Arc<AtomicUsize>) -> Self {
            Self { counter }
        }
    }

    #[async_trait]
    impl AsyncRule for CountingRule {
        async fn validate(
            &self,
            _field: &str,
            _value: &Value,
            _data: &Value,
        ) -> Result<(), String> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        fn name(&self) -> &'static str {
            "counting_rule"
        }
    }

    /// Async rule that always fails with the infrastructure sentinel prefix.
    struct InfraRule;

    #[async_trait]
    impl AsyncRule for InfraRule {
        async fn validate(
            &self,
            _field: &str,
            _value: &Value,
            _data: &Value,
        ) -> Result<(), String> {
            Err("__infra_error__: boom".to_string())
        }

        fn name(&self) -> &'static str {
            "infra_rule"
        }
    }

    /// Async rule that always fails with an ordinary validation message.
    struct FailRule;

    #[async_trait]
    impl AsyncRule for FailRule {
        async fn validate(&self, field: &str, _value: &Value, _data: &Value) -> Result<(), String> {
            Err(format!("The {field} rule failed."))
        }

        fn name(&self) -> &'static str {
            "fail_rule"
        }
    }

    /// Initialises the database through `DB::init_with` using an in-memory
    /// SQLite URL and creates the `widgets` scratch table if it is missing.
    async fn init_test_db() {
        use crate::database::{DatabaseConfig, DB};
        use sea_orm::{ConnectionTrait, Statement};

        let config = DatabaseConfig::builder().url("sqlite::memory:").build();
        DB::init_with(config).await.expect("init in-memory sqlite");

        let db = DB::connection().expect("connection after init");
        let create_table = Statement::from_string(
            db.get_database_backend(),
            "CREATE TABLE IF NOT EXISTS widgets (id INTEGER PRIMARY KEY, slug TEXT)".to_owned(),
        );
        db.execute(create_table)
            .await
            .expect("create widgets scratch table");
    }

    /// Inserts one row into `widgets`. The SQL is built by string formatting,
    /// so this helper is only meant for the hard-coded fixture values below.
    async fn seed_widget(id: i64, slug: &str) {
        use crate::database::DB;
        use sea_orm::{ConnectionTrait, Statement};

        let db = DB::connection().expect("connection for seed_widget");
        let insert = Statement::from_string(
            db.get_database_backend(),
            format!("INSERT INTO widgets (id, slug) VALUES ({id}, '{slug}')"),
        );
        db.execute(insert).await.expect("seed widget row");
    }

    /// A passing sync rule plus a passing async rule yields `Ok(())`.
    #[tokio::test]
    async fn async_validator_all_pass() {
        let payload = json!({"name": "Alice"});
        let validator = AsyncValidator::new(&payload)
            .rule("name", required())
            .async_rule("name", OkRule);

        let result = validator.validate_async().await;

        assert!(result.is_ok(), "expected Ok(()), got: {result:?}");
    }

    /// A failing sync rule prevents the async rule of the same field from running.
    #[tokio::test]
    async fn async_validator_sync_first() {
        let calls = Arc::new(AtomicUsize::new(0));
        let payload = json!({"name": ""});
        let validator = AsyncValidator::new(&payload)
            .rule("name", required())
            .async_rule("name", CountingRule::new(Arc::clone(&calls)));

        let result = validator.validate_async().await;

        assert!(result.is_err(), "expected Err (sync failure)");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            0,
            "async rule must not run when sync rule fails"
        );
    }

    /// Same guarantee as above via `rules(...)`, also checking the error shape.
    #[tokio::test]
    async fn async_validator_skips_async_on_sync_error() {
        let calls = Arc::new(AtomicUsize::new(0));
        let payload = json!({"email": ""});
        let validator = AsyncValidator::new(&payload)
            .rules("email", rules![required()])
            .async_rule("email", CountingRule::new(Arc::clone(&calls)));

        let result = validator.validate_async().await;

        let e = match result {
            Err(AsyncValidationError::Validation(e)) => e,
            other => panic!("expected Validation error, got {other:?}"),
        };
        assert!(e.has("email"), "expected 'email' field error");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            0,
            "async rule counter must be 0 (no DB query issued)"
        );
    }

    /// A sentinel-prefixed async error surfaces as `Infra`, never as a field error.
    #[tokio::test]
    async fn async_validator_infra_error_shape() {
        let payload = json!({"slug": "something"});
        let validator = AsyncValidator::new(&payload).async_rule("slug", InfraRule);

        let result = validator.validate_async().await;

        match result {
            Err(AsyncValidationError::Infra(_)) => {}
            Ok(()) => panic!("expected Err(Infra), got Ok(())"),
            Err(AsyncValidationError::Validation(e)) => {
                let msgs = e.get("slug").cloned().unwrap_or_default();
                for m in &msgs {
                    assert!(
                        !m.contains("__infra_error__"),
                        "infra sentinel must not appear in field errors: {m}"
                    );
                }
                panic!("expected Infra error, got Validation with: {msgs:?}");
            }
        }
    }

    /// A null value on a `nullable` field passes without invoking async rules.
    #[tokio::test]
    async fn async_validator_nullable_skips_async() {
        let calls = Arc::new(AtomicUsize::new(0));
        let payload = json!({"nickname": null});
        let validator = AsyncValidator::new(&payload)
            .rules("nickname", rules![nullable()])
            .async_rule("nickname", CountingRule::new(Arc::clone(&calls)));

        let result = validator.validate_async().await;

        assert!(
            result.is_ok(),
            "nullable null field should pass, got: {result:?}"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            0,
            "async rule must not run for null nullable field"
        );
    }

    /// An ordinary async failure is reported as a `Validation` error on the field.
    #[tokio::test]
    async fn async_validator_validation_failure_shape() {
        let payload = json!({"name": "Alice"});
        let validator = AsyncValidator::new(&payload).async_rule("name", FailRule);

        let result = validator.validate_async().await;

        let e = match result {
            Err(AsyncValidationError::Validation(e)) => e,
            other => panic!("expected Validation error, got {other:?}"),
        };
        assert!(e.has("name"), "expected 'name' field error");
    }

    /// The async `unique` rule reports an existing row as a validation failure.
    #[tokio::test]
    #[serial]
    async fn async_validator_unique_duplicate_is_validation() {
        init_test_db().await;
        seed_widget(1, "taken").await;

        let payload = json!({"slug": "taken"});
        let unique_slug = crate::validation::rules_async::unique("widgets", "slug");
        let validator = AsyncValidator::new(&payload).async_rule("slug", unique_slug);

        let result = validator.validate_async().await;

        let e = match result {
            Err(AsyncValidationError::Validation(e)) => e,
            other => panic!("expected Validation error for duplicate, got {other:?}"),
        };
        assert!(
            e.has("slug"),
            "expected 'slug' field error for duplicate, errors: {e:?}"
        );
    }
}