Skip to main content

pipi/
validation.rs

//! This module provides utility functions for handling validation errors for
//! structs. It is useful when you want to validate a model before inserting it
//! into the database.
3//!
4//! # Example:
5//!
6//! In the following example you can see how you can validate a user model
7//! ```rust,ignore
8//! use pipi::prelude::*;
9//! pub use myapp::_entities::users::ActiveModel;
10//!
11//! // Validation structure
12//! #[derive(Debug, Validate, Deserialize)]
13//! pub struct Validator {
14//!     #[validate(length(min = 2, message = "Name must be at least 2 characters long."))]
15//!     pub name: String,
16//! }
17//!
18//! impl Validatable for ActiveModel {
19//!   fn validator(&self) -> Box<dyn Validate> {
20//!     Box::new(Validator {
21//!         name: self.name.as_ref().to_owned(),
22//!     })
23//!   }
24//! }
25//!
26//! /// Override `before_save` function and run validation to make sure that we insert valid data.
27//! #[async_trait::async_trait]
28//! impl ActiveModelBehavior for ActiveModel {
//!     async fn before_save<C>(self, _db: &C, _insert: bool) -> Result<Self, DbErr>
//!     where
//!         C: ConnectionTrait,
//!     {
//!         self.validate()?;
//!         Ok(self)
//!     }
38//! }
39//! ```
40
41#[cfg(feature = "with-db")]
42use sea_orm::DbErr;
43use serde::{Deserialize, Serialize};
44use std::collections::{BTreeMap, HashMap};
45use validator::ValidationErrors;
46
47// this is a line-serialization type. it is used as an intermediate format
48// to hold validation error data when we transform from
49// validation::ValidationErrors to DbErr and encode all information in json.
50#[derive(Debug, Deserialize, Serialize)]
51#[allow(clippy::module_name_repetitions)]
52pub struct ModelValidationMessage {
53    pub code: String,
54    pub message: Option<String>,
55}
56
57/// <DbErr conversion hack>
58///
59/// Convert `ModelValidationErrors` (pretty) into a `DbErr` (ugly) for database
60/// handling.
61///
62/// Because `DbErr` is used in model hooks and we implement the hooks
63/// in the trait, we MUST use `DbErr`, so we need to "hide" a _representation_
64/// of the error in `DbErr::Custom`, so that it can be unpacked later down the
65/// stream, in the central error response handler.
66#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
67pub struct ValidationError {
68    pub code: String,
69    pub message: Option<String>,
70    #[serde(skip_serializing_if = "HashMap::is_empty")]
71    pub params: HashMap<String, serde_json::Value>,
72}
73
74#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone, PartialEq, Eq)]
75#[error("Model validation failed")]
76pub struct ModelValidationErrors {
77    pub errors: BTreeMap<String, Vec<ValidationError>>,
78}
79
80impl From<ValidationErrors> for ModelValidationErrors {
81    fn from(value: ValidationErrors) -> Self {
82        let mut map: BTreeMap<String, Vec<ValidationError>> = BTreeMap::new();
83        for (field, errs) in &value.field_errors() {
84            let mut list: Vec<ValidationError> = Vec::with_capacity(errs.len());
85            for err in *errs {
86                let mut params: HashMap<String, serde_json::Value> = HashMap::new();
87                for (k, v) in &err.params {
88                    params.insert(k.to_string(), v.clone());
89                }
90                list.push(ValidationError {
91                    code: err.code.to_string(),
92                    message: err.message.as_ref().map(std::string::ToString::to_string),
93                    params,
94                });
95            }
96            map.insert((*field).to_string(), list);
97        }
98        Self { errors: map }
99    }
100}
101
#[cfg(feature = "with-db")]
impl From<ModelValidationErrors> for DbErr {
    /// Lets `?` turn validation failures into a `DbErr` inside model hooks.
    /// The actual encoding is delegated to `into_db_error`.
    fn from(value: ModelValidationErrors) -> Self {
        into_db_error(&value)
    }
}
108
/// Encode validation errors into a `DbErr` (DbErr conversion hack).
///
/// Model hooks are implemented in a trait whose methods must return `DbErr`.
/// A JSON representation of the errors is therefore hidden inside
/// `DbErr::Custom`, so it can be unpacked later down the stream, in the
/// central error response handler.
///
/// The JSON is an object that maps each field key to a list of
/// [`ModelValidationMessage`] objects (`code` and `message`). The per-error
/// `params` are not included.
///
/// If JSON encoding fails, the returned `DbErr::Custom` holds a plain-text
/// description of that failure instead.
#[cfg(feature = "with-db")]
#[must_use]
pub fn into_db_error(errors: &ModelValidationErrors) -> sea_orm::DbErr {
    // Borrow field keys instead of cloning them. A `BTreeMap<&str, _>`
    // serializes exactly like the owned-key map (same keys, same order).
    let compact: BTreeMap<&str, Vec<ModelValidationMessage>> = errors
        .errors
        .iter()
        .map(|(field, list)| {
            let messages: Vec<ModelValidationMessage> = list
                .iter()
                .map(|e| ModelValidationMessage {
                    code: e.code.clone(),
                    message: e.message.clone(),
                })
                .collect();
            (field.as_str(), messages)
        })
        .collect();

    match serde_json::to_string(&compact) {
        Ok(s) => sea_orm::DbErr::Custom(s),
        Err(err) => sea_orm::DbErr::Custom(format!(
            "[before_save] could not serialize validation errors. err: {err}"
        )),
    }
}
134
135/// Implement `Validatable` for `ActiveModel` when you want it to have a
136/// `validate()` function.
137pub trait ValidatorTrait {
138    /// Perform validation and return a normalized error type
139    ///
140    /// # Errors
141    ///
142    /// Returns `ModelValidationErrors` when validation fails.
143    fn validate(&self) -> Result<(), ModelValidationErrors>;
144}
145
146/// Adapter: allow using the `validator` crate seamlessly
147impl<T: validator::Validate> ValidatorTrait for T {
148    fn validate(&self) -> Result<(), ModelValidationErrors> {
149        validator::Validate::validate(self).map_err(ModelValidationErrors::from)
150    }
151}
152
153/// Implement `Validatable` for `ActiveModel` when you want it to have a
154/// `validate()` function.
155pub trait Validatable {
156    /// Perform validation
157    ///
158    /// # Errors
159    ///
160    /// This function will return an error if there are validation errors
161    fn validate(&self) -> Result<(), ModelValidationErrors> {
162        let v = self.validator();
163        validator::Validate::validate(&*v).map_err(ModelValidationErrors::from)
164    }
165    fn validator(&self) -> Box<dyn validator::Validate>;
166}
167
#[cfg(test)]
mod tests {
    use insta::assert_debug_snapshot;
    use rstest::rstest;
    use serde::Deserialize;
    use validator::Validate;

    use super::*;

    /// `validator`-based fixture: `name` must be at least 4 characters long.
    #[derive(Debug, Deserialize, Validate)]
    pub struct TestValidator {
        #[validate(length(min = 4, message = "Invalid min characters long."))]
        pub name: String,
    }

    #[cfg(feature = "with-db")]
    #[rstest]
    #[case("foo")]
    #[case("foo-bar")]
    fn can_validate_into_db_error(#[case] name: &str) {
        let data = TestValidator {
            name: name.to_owned(),
        };

        // Keep the macro arguments unchanged: insta records the expression
        // text alongside the named snapshot.
        assert_debug_snapshot!(
            format!("struct-[{name}]"),
            validator::Validate::validate(&data)
                .map_err(|e| into_db_error(&ModelValidationErrors::from(e)))
        );
    }

    /// Hand-written fixture that implements `ValidatorTrait` without the
    /// `validator` crate.
    #[derive(Debug, Deserialize)]
    pub struct CustomValidator {
        pub name: String,
    }

    impl ValidatorTrait for CustomValidator {
        fn validate(&self) -> Result<(), ModelValidationErrors> {
            if self.name.len() >= 4 {
                return Ok(());
            }
            let errors = BTreeMap::from([(
                "name".to_string(),
                vec![ValidationError {
                    code: "length".to_string(),
                    message: Some("Invalid min characters long.".to_string()),
                    params: HashMap::new(),
                }],
            )]);
            Err(ModelValidationErrors { errors })
        }
    }

    #[rstest]
    #[case("ab")]
    #[case("abcd")]
    fn custom_validator_works(#[case] name: &str) {
        let subject = CustomValidator {
            name: name.to_owned(),
        };
        let should_pass = name.len() >= 4;
        assert_eq!(subject.validate().is_ok(), should_pass);
    }
}