use super::ValidatorError;
use super::validators::{DatabaseValidatorError, UniqueTogetherValidator, UniqueValidator};
use reinhardt_db::backends::DatabaseConnection;
use reinhardt_db::orm::Model;
use serde::Serialize;
use std::marker::PhantomData;
use std::sync::Arc;
/// Synchronous validation hook that inspects a whole model instance.
///
/// Implementors must be `Send + Sync + Debug`; `ValidatorConfig` stores them
/// as `Arc<dyn ModelLevelValidator<M>>` and calls them from
/// `ValidatorConfig::validate`.
pub trait ModelLevelValidator<M>: Send + Sync + std::fmt::Debug {
/// Checks `instance` and returns `Ok(())` when it is acceptable.
///
/// # Errors
///
/// Returns a [`ValidatorError`] describing why the instance was rejected.
fn validate(&self, instance: &M) -> Result<(), ValidatorError>;
}
/// Collection of validators registered for the model type `M`.
///
/// Three independent lists are kept: single-field unique validators,
/// multi-field unique-together validators, and synchronous model-level
/// validators. The first two are run by `validate_async` (which takes a
/// database connection); the last is run by `validate`.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ValidatorConfig<M: Model> {
/// `UniqueValidator`s, run in insertion order by `validate_async`.
unique_validators: Vec<UniqueValidator<M>>,
/// `UniqueTogetherValidator`s, run in insertion order by `validate_async`
/// after all entries of `unique_validators`.
unique_together_validators: Vec<UniqueTogetherValidator<M>>,
/// Shared model-level validators, run in insertion order by `validate`.
sync_model_validators: Vec<Arc<dyn ModelLevelValidator<M>>>,
/// Zero-sized marker for the model type; carries no data.
_phantom: PhantomData<M>,
}
impl<M: Model> ValidatorConfig<M> {
pub fn new() -> Self {
Self {
unique_validators: Vec::new(),
unique_together_validators: Vec::new(),
sync_model_validators: Vec::new(),
_phantom: PhantomData,
}
}
pub fn add_unique_validator(&mut self, validator: UniqueValidator<M>) {
self.unique_validators.push(validator);
}
pub fn add_unique_together_validator(&mut self, validator: UniqueTogetherValidator<M>) {
self.unique_together_validators.push(validator);
}
pub fn add_sync_model_validator(&mut self, validator: Arc<dyn ModelLevelValidator<M>>) {
self.sync_model_validators.push(validator);
}
pub fn unique_validators(&self) -> &[UniqueValidator<M>] {
&self.unique_validators
}
pub fn unique_together_validators(&self) -> &[UniqueTogetherValidator<M>] {
&self.unique_together_validators
}
pub fn sync_model_validators(&self) -> &[Arc<dyn ModelLevelValidator<M>>] {
&self.sync_model_validators
}
pub fn has_validators(&self) -> bool {
!self.unique_validators.is_empty()
|| !self.unique_together_validators.is_empty()
|| !self.sync_model_validators.is_empty()
}
pub fn validate(&self, instance: &M) -> Result<(), ValidatorError> {
for validator in &self.sync_model_validators {
validator.validate(instance)?;
}
Ok(())
}
pub async fn validate_async(
&self,
connection: &DatabaseConnection,
instance: &M,
instance_pk: Option<&M::PrimaryKey>,
) -> Result<(), DatabaseValidatorError>
where
M: Serialize,
M::PrimaryKey: std::fmt::Display,
{
let value =
serde_json::to_value(instance).map_err(|e| DatabaseValidatorError::DatabaseError {
message: format!("Failed to serialize model: {}", e),
query: None,
})?;
let obj = value
.as_object()
.ok_or_else(|| DatabaseValidatorError::DatabaseError {
message: "Model must serialize to an object".to_string(),
query: None,
})?;
for validator in &self.unique_validators {
let field_value = obj
.get(validator.field_name())
.and_then(|v| v.as_str())
.ok_or_else(|| DatabaseValidatorError::FieldNotFound {
field: validator.field_name().to_string(),
})?;
validator
.validate(connection, field_value, instance_pk)
.await?;
}
for validator in &self.unique_together_validators {
let mut values = std::collections::HashMap::new();
for field in validator.field_names() {
let value = obj.get(field).and_then(|v| v.as_str()).ok_or_else(|| {
DatabaseValidatorError::FieldNotFound {
field: field.clone(),
}
})?;
values.insert(field.clone(), value.to_string());
}
validator.validate(connection, &values, instance_pk).await?;
}
Ok(())
}
}
impl<M: Model> Default for ValidatorConfig<M> {
fn default() -> Self {
Self::new()
}
}
// Explicitly mark the configuration as unwind-safe for every model type.
// NOTE(review): presumably required because `Arc<dyn ModelLevelValidator<M>>`
// does not carry these auto traits (the trait object has no `RefUnwindSafe`
// bound), so they are not derived automatically — confirm that no stored
// validator relies on interior mutability that a panic could leave in a
// broken state.
impl<M: Model> std::panic::UnwindSafe for ValidatorConfig<M> {}
impl<M: Model> std::panic::RefUnwindSafe for ValidatorConfig<M> {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serializers::validators::{UniqueTogetherValidator, UniqueValidator};
    use reinhardt_db::orm::FieldSelector;

    /// Minimal model used to instantiate `ValidatorConfig` in tests.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    struct TestUser {
        id: Option<i64>,
        username: String,
        email: String,
    }

    /// Field selector stub required by the `Model` implementation below.
    #[derive(Debug, Clone)]
    struct TestUserFields;

    impl FieldSelector for TestUserFields {
        fn with_alias(self, _alias: &str) -> Self {
            self
        }
    }

    impl Model for TestUser {
        type PrimaryKey = i64;
        type Fields = TestUserFields;

        fn table_name() -> &'static str {
            "test_users"
        }

        fn new_fields() -> Self::Fields {
            TestUserFields
        }

        fn primary_key(&self) -> Option<Self::PrimaryKey> {
            self.id
        }

        fn set_primary_key(&mut self, value: Self::PrimaryKey) {
            self.id = Some(value);
        }
    }

    /// Model-level validator that accepts every instance.
    #[derive(Debug)]
    struct AcceptAll;

    impl ModelLevelValidator<TestUser> for AcceptAll {
        fn validate(&self, _instance: &TestUser) -> Result<(), ValidatorError> {
            Ok(())
        }
    }

    /// Builds a fully populated user for tests that need an instance.
    fn sample_user() -> TestUser {
        TestUser {
            id: Some(1),
            username: "alice".to_string(),
            email: "alice@example.com".to_string(),
        }
    }

    #[test]
    fn test_validator_config_new() {
        let config = ValidatorConfig::<TestUser>::new();
        assert_eq!(config.unique_validators().len(), 0);
        assert_eq!(config.unique_together_validators().len(), 0);
        assert_eq!(config.sync_model_validators().len(), 0);
        assert!(!config.has_validators());
    }

    #[test]
    fn test_default_is_empty() {
        let config = ValidatorConfig::<TestUser>::default();
        assert_eq!(config.unique_validators().len(), 0);
        assert_eq!(config.unique_together_validators().len(), 0);
        assert_eq!(config.sync_model_validators().len(), 0);
        assert!(!config.has_validators());
    }

    #[test]
    fn test_add_unique_validator() {
        let mut config = ValidatorConfig::<TestUser>::new();
        config.add_unique_validator(UniqueValidator::new("username"));
        assert_eq!(config.unique_validators().len(), 1);
        assert!(config.has_validators());
    }

    #[test]
    fn test_add_unique_together_validator() {
        let mut config = ValidatorConfig::<TestUser>::new();
        config
            .add_unique_together_validator(UniqueTogetherValidator::new(vec!["username", "email"]));
        assert_eq!(config.unique_together_validators().len(), 1);
        assert!(config.has_validators());
    }

    #[test]
    fn test_add_sync_model_validator() {
        let mut config = ValidatorConfig::<TestUser>::new();
        config.add_sync_model_validator(Arc::new(AcceptAll));
        assert_eq!(config.sync_model_validators().len(), 1);
        // A sync validator alone must be enough for `has_validators`.
        assert_eq!(config.unique_validators().len(), 0);
        assert_eq!(config.unique_together_validators().len(), 0);
        assert!(config.has_validators());
    }

    #[test]
    fn test_validate_without_validators_is_ok() {
        let config = ValidatorConfig::<TestUser>::new();
        assert!(config.validate(&sample_user()).is_ok());
    }

    #[test]
    fn test_validate_runs_sync_validators() {
        let mut config = ValidatorConfig::<TestUser>::new();
        config.add_sync_model_validator(Arc::new(AcceptAll));
        config.add_sync_model_validator(Arc::new(AcceptAll));
        assert!(config.validate(&sample_user()).is_ok());
    }

    #[test]
    fn test_multiple_validators() {
        let mut config = ValidatorConfig::<TestUser>::new();
        config.add_unique_validator(UniqueValidator::new("username"));
        config.add_unique_validator(UniqueValidator::new("email"));
        config
            .add_unique_together_validator(UniqueTogetherValidator::new(vec!["username", "email"]));
        assert_eq!(config.unique_validators().len(), 2);
        assert_eq!(config.unique_together_validators().len(), 1);
        assert!(config.has_validators());
    }
}