Skip to main content

ferro_rs/validation/
async_validator.rs

1//! Async validator builder and error types.
2
3use std::collections::HashMap;
4
5use serde_json::Value;
6
7use crate::error::FrameworkError;
8use crate::validation::{Rule, ValidationError};
9
10use super::async_rule::AsyncRule;
11
/// Errors from [`AsyncValidator::validate_async`].
///
/// Separates field-level validation failures (→ redirect-back with old input)
/// from infrastructure failures (→ HTTP 500). A DB error is NEVER a validation
/// result.
///
/// # Usage
///
/// ```rust,ignore
/// match validator.validate_async().await {
///     Ok(()) => { /* proceed */ }
///     Err(AsyncValidationError::Validation(e)) => {
///         return Err(e.with_old_input(&data).into_action_error("/back"));
///     }
///     Err(AsyncValidationError::Infra(fe)) => {
///         return Err(ActionError::from(fe));
///     }
/// }
/// ```
#[derive(Debug)]
pub enum AsyncValidationError {
    /// One or more field validation rules failed. Use `.with_old_input()` +
    /// `redirect_back` / `redirect_to` / `into_action_error` as usual.
    ///
    /// Contains every sync-rule failure plus the async-rule failures of fields
    /// that passed their sync rules. Custom `"field.rule"` messages have
    /// already been substituted for the rules' default messages.
    Validation(ValidationError),
    /// A DB or infrastructure error occurred during an async rule. Propagate
    /// as a framework error (→ 500).
    ///
    /// Produced when an async rule returns an `Err` message starting with
    /// `__infra_error__:`. The prefix is stripped, the remainder is trimmed and
    /// wrapped in `FrameworkError::database`, and validation stops at once.
    /// Any field errors collected before that point are discarded.
    Infra(FrameworkError),
}
40
41impl std::fmt::Display for AsyncValidationError {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        match self {
44            Self::Validation(e) => write!(f, "Validation failed: {e}"),
45            Self::Infra(e) => write!(f, "Infrastructure error: {e}"),
46        }
47    }
48}
49
/// Lets `AsyncValidationError` be used wherever a `std::error::Error` is
/// expected (e.g. boxed as `Box<dyn std::error::Error>`). The message comes
/// from the `Display` impl.
///
/// NOTE(review): `source()` keeps its default (`None`). Returning the wrapped
/// `ValidationError` / `FrameworkError` would preserve the error chain; confirm
/// both implement `std::error::Error + 'static` before adding it.
impl std::error::Error for AsyncValidationError {}
51
52// NOTE: no blanket `From<AsyncValidationError> for ActionError`. Such a
53// conversion cannot preserve the field errors, the flashed old input, or the
54// redirect-back URL — all of which require caller context. A lossy `?`
55// conversion would silently drop the field-level message. Callers must match
56// the variants explicitly (see the `AsyncValidationError` doc example):
57//   Validation(e) => e.with_old_input(&data).into_action_error(back_url)
58//   Infra(fe)     => ActionError::from(fe)
59
60/// Async request validator.
61///
62/// Mirrors [`crate::validation::Validator`] ergonomics while adding support for
63/// `Box<dyn AsyncRule>` rules (e.g. DB uniqueness checks). Sync rules run
64/// first; async rules run only on fields with no sync error (fail-fast, D-03).
65///
66/// # Example
67///
68/// ```rust,ignore
69/// use ferro_rs::{AsyncValidator, AsyncValidationError, unique};
70/// use ferro_rs::validation::rules::*;
71/// use ferro_rs::rules;
72///
73/// let data = req.input::<serde_json::Value>().await?;
74/// match AsyncValidator::new(&data)
75///     .rules("slug", rules![required(), string()])
76///     .async_rule("slug", unique("articles", "slug"))
77///     .validate_async()
78///     .await
79/// {
80///     Ok(()) => {}
81///     Err(AsyncValidationError::Validation(e)) => {
82///         return Err(e.with_old_input(&data).into_action_error("/articles/new"));
83///     }
84///     Err(AsyncValidationError::Infra(fe)) => {
85///         return Err(fe.into());
86///     }
87/// }
88/// ```
89pub struct AsyncValidator<'a> {
90    data: &'a Value,
91    sync_rules: HashMap<String, Vec<Box<dyn Rule>>>,
92    async_rules: HashMap<String, Vec<Box<dyn AsyncRule>>>,
93    custom_messages: HashMap<String, String>,
94    custom_attributes: HashMap<String, String>,
95}
96
97impl<'a> AsyncValidator<'a> {
98    /// Create a new async validator for the given data.
99    pub fn new(data: &'a Value) -> Self {
100        Self {
101            data,
102            sync_rules: HashMap::new(),
103            async_rules: HashMap::new(),
104            custom_messages: HashMap::new(),
105            custom_attributes: HashMap::new(),
106        }
107    }
108
109    /// Add a single sync validation rule for a field.
110    pub fn rule<R: Rule + 'static>(mut self, field: impl Into<String>, rule: R) -> Self {
111        let field = field.into();
112        self.sync_rules
113            .entry(field)
114            .or_default()
115            .push(Box::new(rule) as Box<dyn Rule>);
116        self
117    }
118
119    /// Add multiple sync validation rules for a field using boxed rules.
120    ///
121    /// # Example
122    ///
123    /// ```rust,ignore
124    /// use ferro_rs::rules;
125    /// use ferro_rs::validation::rules::*;
126    ///
127    /// AsyncValidator::new(&data)
128    ///     .rules("email", rules![required(), email()])
129    ///     .rules("name", rules![required(), string(), max(255)]);
130    /// ```
131    pub fn rules(mut self, field: impl Into<String>, rules: Vec<Box<dyn Rule>>) -> Self {
132        self.sync_rules.insert(field.into(), rules);
133        self
134    }
135
136    /// Add a single async validation rule for a field.
137    ///
138    /// Async rules run only after all sync rules pass for the field (D-03).
139    pub fn async_rule<R: AsyncRule + 'static>(mut self, field: impl Into<String>, rule: R) -> Self {
140        self.async_rules
141            .entry(field.into())
142            .or_default()
143            .push(Box::new(rule) as Box<dyn AsyncRule>);
144        self
145    }
146
147    /// Set a custom error message for a field.rule combination.
148    ///
149    /// # Example
150    ///
151    /// ```rust,ignore
152    /// AsyncValidator::new(&data)
153    ///     .rules("email", rules![required(), email()])
154    ///     .message("email.required", "Please provide your email address");
155    /// ```
156    pub fn message(mut self, key: impl Into<String>, message: impl Into<String>) -> Self {
157        self.custom_messages.insert(key.into(), message.into());
158        self
159    }
160
161    /// Set custom messages from a map.
162    pub fn messages(mut self, messages: HashMap<String, String>) -> Self {
163        self.custom_messages.extend(messages);
164        self
165    }
166
167    /// Set a custom attribute name for a field.
168    ///
169    /// # Example
170    ///
171    /// ```rust,ignore
172    /// AsyncValidator::new(&data)
173    ///     .rules("email", rules![required()])
174    ///     .attribute("email", "email address");
175    /// // Error: "The email address field is required."
176    /// ```
177    pub fn attribute(mut self, field: impl Into<String>, name: impl Into<String>) -> Self {
178        self.custom_attributes.insert(field.into(), name.into());
179        self
180    }
181
182    /// Set custom attributes from a map.
183    pub fn attributes(mut self, attributes: HashMap<String, String>) -> Self {
184        self.custom_attributes.extend(attributes);
185        self
186    }
187
188    /// Run async validation. Returns:
189    ///
190    /// - `Ok(())` — all rules pass.
191    /// - `Err(AsyncValidationError::Validation(e))` — field-level failures.
192    /// - `Err(AsyncValidationError::Infra(e))` — DB/infra failure (handler → 500).
193    ///
194    /// # Execution order (D-03)
195    ///
196    /// Phase 1: all sync rules run across all fields.
197    /// Phase 2: async rules run only on fields with no sync error — no DB query
198    /// is issued for an already-failed field.
199    ///
200    /// # Infra sentinel (D-12)
201    ///
202    /// An async rule that returns `Err(msg)` where `msg` starts with
203    /// `__infra_error__:` is treated as an infrastructure failure, not a
204    /// field error. The stripped message is wrapped in
205    /// `AsyncValidationError::Infra(FrameworkError::database(...))` and
206    /// returned immediately.
207    pub async fn validate_async(self) -> Result<(), AsyncValidationError> {
208        let mut errors = ValidationError::new();
209
210        // Phase 1 — sync rules first (verbatim from Validator::validate).
211        for (field, rules) in &self.sync_rules {
212            let value = self.get_value(field);
213            let display_field = self.get_display_field(field);
214
215            // nullable() rule: skip all other rules for this field if value is null.
216            let has_nullable = rules.iter().any(|r| r.name() == "nullable");
217            if has_nullable && value.is_null() {
218                continue;
219            }
220
221            for rule in rules {
222                // Skip nullable rule itself — it has no validation message.
223                if rule.name() == "nullable" {
224                    continue;
225                }
226
227                if let Err(default_message) = rule.validate(&display_field, &value, self.data) {
228                    let message_key = format!("{}.{}", field, rule.name());
229                    let message = self
230                        .custom_messages
231                        .get(&message_key)
232                        .cloned()
233                        .unwrap_or(default_message);
234                    errors.add(field, message);
235                }
236            }
237        }
238
239        // Phase 2 — async rules only on fields with no sync error (D-03).
240        for (field, rules) in &self.async_rules {
241            // Fail-fast: no DB query for an already-failed field.
242            if errors.has(field) {
243                continue;
244            }
245
246            let value = self.get_value(field);
247
248            // nullable mirror: if this field carries a sync nullable() rule and
249            // the value is null, skip async rules too (prevents a DB query for
250            // a null value — mirrors sync behavior).
251            if value.is_null() {
252                let nullable = self
253                    .sync_rules
254                    .get(field)
255                    .map(|rs| rs.iter().any(|r| r.name() == "nullable"))
256                    .unwrap_or(false);
257                if nullable {
258                    continue;
259                }
260            }
261
262            let display_field = self.get_display_field(field);
263
264            for rule in rules {
265                match rule.validate(&display_field, &value, self.data).await {
266                    Ok(()) => {}
267                    Err(msg) => {
268                        // D-12: infra failures are NOT field errors.
269                        if let Some(rest) = msg.strip_prefix("__infra_error__:") {
270                            return Err(AsyncValidationError::Infra(FrameworkError::database(
271                                rest.trim().to_string(),
272                            )));
273                        }
274                        let message_key = format!("{}.{}", field, rule.name());
275                        let message = self
276                            .custom_messages
277                            .get(&message_key)
278                            .cloned()
279                            .unwrap_or(msg);
280                        errors.add(field, message);
281                    }
282                }
283            }
284        }
285
286        if errors.is_empty() {
287            Ok(())
288        } else {
289            Err(AsyncValidationError::Validation(errors))
290        }
291    }
292
293    /// Get a value from the data, supporting dot notation.
294    fn get_value(&self, field: &str) -> Value {
295        get_nested_value(self.data, field)
296            .cloned()
297            .unwrap_or(Value::Null)
298    }
299
300    /// Get the display name for a field.
301    fn get_display_field(&self, field: &str) -> String {
302        self.custom_attributes
303            .get(field)
304            .cloned()
305            .unwrap_or_else(|| field.split('_').collect::<Vec<_>>().join(" "))
306    }
307}
308
309/// Get a nested value from JSON using dot notation (verbatim from validator.rs).
310fn get_nested_value<'a>(data: &'a Value, path: &str) -> Option<&'a Value> {
311    let parts: Vec<&str> = path.split('.').collect();
312    let mut current = data;
313
314    for part in parts {
315        // Try as object key.
316        if let Value::Object(map) = current {
317            current = map.get(part)?;
318        }
319        // Try as array index.
320        else if let Value::Array(arr) = current {
321            let index: usize = part.parse().ok()?;
322            current = arr.get(index)?;
323        } else {
324            return None;
325        }
326    }
327
328    Some(current)
329}
330
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use async_trait::async_trait;
    use serde_json::json;
    use serial_test::serial;

    use super::*;
    use crate::rules;
    use crate::validation::rules::*;

    // -----------------------------------------------------------------------
    // Tiny test AsyncRule implementations.
    // -----------------------------------------------------------------------

    /// Always returns Ok(()).
    struct OkRule;

    #[async_trait]
    impl AsyncRule for OkRule {
        async fn validate(
            &self,
            _field: &str,
            _value: &Value,
            _data: &Value,
        ) -> Result<(), String> {
            Ok(())
        }

        fn name(&self) -> &'static str {
            "ok_rule"
        }
    }

    /// Increments a shared counter on every validate() call, then returns Ok(()).
    struct CountingRule {
        counter: Arc<AtomicUsize>,
    }

    impl CountingRule {
        fn new(counter: Arc<AtomicUsize>) -> Self {
            Self { counter }
        }
    }

    #[async_trait]
    impl AsyncRule for CountingRule {
        async fn validate(
            &self,
            _field: &str,
            _value: &Value,
            _data: &Value,
        ) -> Result<(), String> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        fn name(&self) -> &'static str {
            "counting_rule"
        }
    }

    /// Always returns Err("__infra_error__: boom").
    struct InfraRule;

    #[async_trait]
    impl AsyncRule for InfraRule {
        async fn validate(
            &self,
            _field: &str,
            _value: &Value,
            _data: &Value,
        ) -> Result<(), String> {
            Err("__infra_error__: boom".to_string())
        }

        fn name(&self) -> &'static str {
            "infra_rule"
        }
    }

    /// Always returns a validation failure that embeds the display field name.
    struct FailRule;

    #[async_trait]
    impl AsyncRule for FailRule {
        async fn validate(&self, field: &str, _value: &Value, _data: &Value) -> Result<(), String> {
            Err(format!("The {field} rule failed."))
        }

        fn name(&self) -> &'static str {
            "fail_rule"
        }
    }

    // -----------------------------------------------------------------------
    // Inline DB fixture (mirrors async_rule_fixture.rs).
    // -----------------------------------------------------------------------

    async fn init_test_db() {
        use crate::database::{DatabaseConfig, DB};
        use sea_orm::{ConnectionTrait, Statement};
        let config = DatabaseConfig::builder().url("sqlite::memory:").build();
        DB::init_with(config).await.expect("init in-memory sqlite");
        let db = DB::connection().expect("connection after init");
        db.execute(Statement::from_string(
            db.get_database_backend(),
            "CREATE TABLE IF NOT EXISTS widgets (id INTEGER PRIMARY KEY, slug TEXT)".to_owned(),
        ))
        .await
        .expect("create widgets scratch table");
    }

    /// Seeds one widget row. Test-only: `slug` is a fixed literal from the
    /// test itself, so string-built SQL is acceptable here.
    async fn seed_widget(id: i64, slug: &str) {
        use crate::database::DB;
        use sea_orm::{ConnectionTrait, Statement};
        let db = DB::connection().expect("connection for seed_widget");
        db.execute(Statement::from_string(
            db.get_database_backend(),
            format!("INSERT INTO widgets (id, slug) VALUES ({id}, '{slug}')"),
        ))
        .await
        .expect("seed widget row");
    }

    // -----------------------------------------------------------------------
    // Tests
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn async_validator_all_pass() {
        let data = json!({"name": "Alice"});
        let result = AsyncValidator::new(&data)
            .rule("name", required())
            .async_rule("name", OkRule)
            .validate_async()
            .await;
        assert!(result.is_ok(), "expected Ok(()), got: {result:?}");
    }

    #[tokio::test]
    async fn async_validator_sync_first() {
        // Sync rule fails → async rule must never run.
        let counter = Arc::new(AtomicUsize::new(0));
        let data = json!({"name": ""});
        let result = AsyncValidator::new(&data)
            .rule("name", required())
            .async_rule("name", CountingRule::new(counter.clone()))
            .validate_async()
            .await;
        assert!(result.is_err(), "expected Err (sync failure)");
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "async rule must not run when sync rule fails"
        );
    }

    #[tokio::test]
    async fn async_validator_skips_async_on_sync_error() {
        // Same as above but with rules![] helper; checks the error shape.
        let counter = Arc::new(AtomicUsize::new(0));
        let data = json!({"email": ""});
        let result = AsyncValidator::new(&data)
            .rules("email", rules![required()])
            .async_rule("email", CountingRule::new(counter.clone()))
            .validate_async()
            .await;
        match result {
            Err(AsyncValidationError::Validation(e)) => {
                assert!(e.has("email"), "expected 'email' field error");
            }
            other => panic!("expected Validation error, got {other:?}"),
        }
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "async rule counter must be 0 (no DB query issued)"
        );
    }

    #[tokio::test]
    async fn async_validator_fail_fast_is_per_field() {
        // A sync failure on one field must not stop async rules on another.
        let counter = Arc::new(AtomicUsize::new(0));
        let data = json!({"email": "", "slug": "fresh"});
        let result = AsyncValidator::new(&data)
            .rule("email", required())
            .async_rule("email", CountingRule::new(counter.clone()))
            .async_rule("slug", CountingRule::new(counter.clone()))
            .validate_async()
            .await;
        match result {
            Err(AsyncValidationError::Validation(e)) => {
                assert!(e.has("email"), "expected 'email' field error");
                assert!(!e.has("slug"), "'slug' must not carry an error");
            }
            other => panic!("expected Validation error, got {other:?}"),
        }
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "only the 'slug' async rule should have run"
        );
    }

    #[tokio::test]
    async fn async_validator_infra_error_shape() {
        // An async rule returning __infra_error__: → AsyncValidationError::Infra,
        // NOT Validation. The field error map must not carry the raw message.
        let data = json!({"slug": "something"});
        let result = AsyncValidator::new(&data)
            .async_rule("slug", InfraRule)
            .validate_async()
            .await;
        match result {
            Err(AsyncValidationError::Infra(_)) => {
                // Correct — infra errors must not be field errors.
            }
            Err(AsyncValidationError::Validation(e)) => {
                // Check the field error does not carry the raw sentinel.
                let msgs = e.get("slug").cloned().unwrap_or_default();
                for m in &msgs {
                    assert!(
                        !m.contains("__infra_error__"),
                        "infra sentinel must not appear in field errors: {m}"
                    );
                }
                panic!("expected Infra error, got Validation with: {msgs:?}");
            }
            Ok(()) => panic!("expected Err(Infra), got Ok(())"),
        }
    }

    #[tokio::test]
    async fn async_validator_infra_error_strips_sentinel() {
        // The sentinel prefix is removed before wrapping in FrameworkError.
        let data = json!({"slug": "something"});
        let result = AsyncValidator::new(&data)
            .async_rule("slug", InfraRule)
            .validate_async()
            .await;
        match result {
            Err(AsyncValidationError::Infra(fe)) => {
                let rendered = fe.to_string();
                assert!(
                    !rendered.contains("__infra_error__"),
                    "infra sentinel must be stripped, got: {rendered}"
                );
            }
            other => panic!("expected Infra error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn async_validator_nullable_skips_async() {
        // Field with nullable() + null value → async rule never runs.
        let counter = Arc::new(AtomicUsize::new(0));
        let data = json!({"nickname": null});
        let result = AsyncValidator::new(&data)
            .rules("nickname", rules![nullable()])
            .async_rule("nickname", CountingRule::new(counter.clone()))
            .validate_async()
            .await;
        assert!(
            result.is_ok(),
            "nullable null field should pass, got: {result:?}"
        );
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "async rule must not run for null nullable field"
        );
    }

    #[tokio::test]
    async fn async_validator_nullable_missing_field_skips_async() {
        // A missing field resolves to null, so nullable() skips it too.
        let counter = Arc::new(AtomicUsize::new(0));
        let data = json!({});
        let result = AsyncValidator::new(&data)
            .rules("nickname", rules![nullable()])
            .async_rule("nickname", CountingRule::new(counter.clone()))
            .validate_async()
            .await;
        assert!(
            result.is_ok(),
            "missing nullable field should pass, got: {result:?}"
        );
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "async rule must not run for a missing nullable field"
        );
    }

    #[tokio::test]
    async fn async_validator_validation_failure_shape() {
        // Async rule that fails → AsyncValidationError::Validation with the field.
        let data = json!({"name": "Alice"});
        let result = AsyncValidator::new(&data)
            .async_rule("name", FailRule)
            .validate_async()
            .await;
        match result {
            Err(AsyncValidationError::Validation(e)) => {
                assert!(e.has("name"), "expected 'name' field error");
            }
            other => panic!("expected Validation error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn async_validator_custom_message_overrides_async_rule() {
        let data = json!({"name": "Alice"});
        let result = AsyncValidator::new(&data)
            .async_rule("name", FailRule)
            .message("name.fail_rule", "Name is not allowed.")
            .validate_async()
            .await;
        match result {
            Err(AsyncValidationError::Validation(e)) => {
                let msgs = e.get("name").cloned().unwrap_or_default();
                assert_eq!(msgs.len(), 1, "expected exactly one message: {msgs:?}");
                assert!(
                    msgs.iter().any(|m| m == "Name is not allowed."),
                    "custom message not applied: {msgs:?}"
                );
            }
            other => panic!("expected Validation error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn async_validator_custom_message_overrides_sync_rule() {
        let data = json!({"email": ""});
        let result = AsyncValidator::new(&data)
            .rule("email", required())
            .message("email.required", "Please provide your email address")
            .validate_async()
            .await;
        match result {
            Err(AsyncValidationError::Validation(e)) => {
                let msgs = e.get("email").cloned().unwrap_or_default();
                assert!(
                    msgs.iter().any(|m| m == "Please provide your email address"),
                    "custom message not applied: {msgs:?}"
                );
            }
            other => panic!("expected Validation error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn async_validator_display_field_from_snake_case() {
        // Without a custom attribute, underscores become spaces.
        let data = json!({"first_name": "Alice"});
        let result = AsyncValidator::new(&data)
            .async_rule("first_name", FailRule)
            .validate_async()
            .await;
        match result {
            Err(AsyncValidationError::Validation(e)) => {
                let msgs = e.get("first_name").cloned().unwrap_or_default();
                assert!(
                    msgs.iter().any(|m| m == "The first name rule failed."),
                    "unexpected display field in: {msgs:?}"
                );
            }
            other => panic!("expected Validation error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn async_validator_custom_attribute_in_message() {
        let data = json!({"name": "Alice"});
        let result = AsyncValidator::new(&data)
            .async_rule("name", FailRule)
            .attribute("name", "full name")
            .validate_async()
            .await;
        match result {
            Err(AsyncValidationError::Validation(e)) => {
                let msgs = e.get("name").cloned().unwrap_or_default();
                assert!(
                    msgs.iter().any(|m| m == "The full name rule failed."),
                    "custom attribute not applied: {msgs:?}"
                );
            }
            other => panic!("expected Validation error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn async_validator_dot_notation_lookup() {
        // Nested object keys and array indices are both resolved.
        let data = json!({"user": {"email": ""}, "tags": ["ok", ""]});
        let result = AsyncValidator::new(&data)
            .rule("user.email", required())
            .rule("tags.0", required())
            .rule("tags.1", required())
            .validate_async()
            .await;
        match result {
            Err(AsyncValidationError::Validation(e)) => {
                assert!(e.has("user.email"), "expected 'user.email' field error");
                assert!(!e.has("tags.0"), "'tags.0' is non-empty and must pass");
                assert!(e.has("tags.1"), "expected 'tags.1' field error");
            }
            other => panic!("expected Validation error, got {other:?}"),
        }
    }

    #[tokio::test]
    #[serial]
    async fn async_validator_unique_duplicate_is_validation() {
        // Real unique rule with a seeded duplicate → AsyncValidationError::Validation.
        init_test_db().await;
        seed_widget(1, "taken").await;

        let data = json!({"slug": "taken"});
        let result = AsyncValidator::new(&data)
            .async_rule(
                "slug",
                crate::validation::rules_async::unique("widgets", "slug"),
            )
            .validate_async()
            .await;
        match result {
            Err(AsyncValidationError::Validation(e)) => {
                assert!(
                    e.has("slug"),
                    "expected 'slug' field error for duplicate, errors: {e:?}"
                );
            }
            other => panic!("expected Validation error for duplicate, got {other:?}"),
        }
    }
}