use super::SerializerError;
use reinhardt_db::backends::DatabaseConnection;
use reinhardt_db::orm::{Filter, FilterOperator, FilterValue, Model};
use std::marker::PhantomData;
use thiserror::Error;
/// Errors returned by the database-backed uniqueness validators in this module.
///
/// `Display` text is generated by `thiserror` from each variant's `#[error]`
/// attribute. The optional `message` on the uniqueness variants is not part of
/// that text; it holds a caller-supplied override (set through the validators'
/// `with_message`).
#[derive(Debug, Error, Clone, PartialEq)]
pub enum DatabaseValidatorError {
/// A row with the same value for a single unique field already exists.
#[error("Unique constraint violated: {field} = '{value}' already exists in table {table}")]
UniqueConstraintViolation {
/// Name of the field that was checked.
field: String,
/// The conflicting value, as a string.
value: String,
/// Table that was queried.
table: String,
/// Optional custom message supplied by the validator.
message: Option<String>,
},
/// A row with the same combination of values for a group of fields already exists.
#[error(
"Unique together constraint violated: fields ({fields:?}) with values ({values:?}) already exist in table {table}"
)]
UniqueTogetherViolation {
/// Names of the fields in the checked combination, in validator order.
fields: Vec<String>,
/// Values for `fields`, index-aligned with them.
values: Vec<String>,
/// Table that was queried.
table: String,
/// Optional custom message supplied by the validator.
message: Option<String>,
},
/// The uniqueness query itself failed.
#[error("Database error during validation: {message}")]
DatabaseError {
/// Text of the underlying database error.
message: String,
/// Optional query text for diagnostics; the validators shown here set it to `None`.
query: Option<String>,
},
/// The input map passed to a validator lacked one of the required fields.
#[error("Required field '{field}' not found in validation data")]
FieldNotFound {
/// Name of the missing field.
field: String,
},
}
impl From<DatabaseValidatorError> for SerializerError {
fn from(err: DatabaseValidatorError) -> Self {
SerializerError::Other {
message: err.to_string(),
}
}
}
impl From<DatabaseValidatorError> for reinhardt_core::exception::Error {
    /// Translates a validator failure into the framework's error type.
    ///
    /// - Uniqueness violations (single field or field combination) become
    ///   `Conflict`. The custom `message` is used when set; otherwise a
    ///   description is generated from the structured fields.
    /// - A backend failure becomes `Database`, carrying only its message text;
    ///   the optional `query` is dropped.
    /// - A missing input field becomes `Validation`.
    fn from(err: DatabaseValidatorError) -> Self {
        match err {
            DatabaseValidatorError::UniqueConstraintViolation {
                field,
                value,
                table,
                message,
            } => {
                let text = match message {
                    Some(custom) => custom,
                    None => format!(
                        "Field '{}' with value '{}' already exists in {}",
                        field, value, table
                    ),
                };
                Self::Conflict(text)
            }
            DatabaseValidatorError::UniqueTogetherViolation {
                fields,
                values,
                table,
                message,
            } => {
                let text = match message {
                    Some(custom) => custom,
                    None => format!(
                        "Combination of fields {:?} with values {:?} already exists in {}",
                        fields, values, table
                    ),
                };
                Self::Conflict(text)
            }
            DatabaseValidatorError::DatabaseError { message, query: _ } => Self::Database(message),
            DatabaseValidatorError::FieldNotFound { field } => {
                Self::Validation(format!("Required field '{}' not found", field))
            }
        }
    }
}
/// Validator that checks one column of model `M`'s table for an existing row
/// holding the same value.
///
/// Build it with [`UniqueValidator::new`], optionally customise the error with
/// [`UniqueValidator::with_message`], and run it with [`UniqueValidator::validate`].
#[derive(Debug, Clone)]
pub struct UniqueValidator<M: Model> {
/// Column name to check.
field_name: String,
/// Optional custom message carried in the violation error.
message: Option<String>,
/// Ties the validator to model `M` without storing an instance.
_phantom: PhantomData<M>,
}
impl<M: Model> UniqueValidator<M> {
    /// Creates a validator enforcing uniqueness of `field_name` in `M`'s table.
    pub fn new(field_name: impl Into<String>) -> Self {
        let field_name = field_name.into();
        Self {
            field_name,
            message: None,
            _phantom: PhantomData,
        }
    }

    /// Returns the validator with a custom message attached to violation errors.
    pub fn with_message(self, message: impl Into<String>) -> Self {
        Self {
            message: Some(message.into()),
            ..self
        }
    }

    /// Name of the column this validator checks.
    pub fn field_name(&self) -> &str {
        self.field_name.as_str()
    }

    /// Fails if another row of `M` already stores `value` in the checked column.
    ///
    /// The value is sent as `FilterValue::String`. When `instance_pk` is
    /// `Some`, that row is excluded, so an existing instance does not conflict
    /// with itself. `_connection` is currently unused; the query is built from
    /// `M::objects()`.
    ///
    /// # Errors
    ///
    /// - [`DatabaseValidatorError::UniqueConstraintViolation`] if a matching row exists.
    /// - [`DatabaseValidatorError::DatabaseError`] if the count query fails.
    pub async fn validate(
        &self,
        _connection: &DatabaseConnection,
        value: &str,
        instance_pk: Option<&M::PrimaryKey>,
    ) -> Result<(), DatabaseValidatorError>
    where
        M::PrimaryKey: std::fmt::Display,
    {
        let same_value = Filter::new(
            self.field_name.clone(),
            FilterOperator::Eq,
            FilterValue::String(value.to_string()),
        );
        let mut query = M::objects().all().filter(same_value);

        // Skip the row being updated so re-saving it is not reported as a duplicate.
        if let Some(pk) = instance_pk {
            let other_rows = Filter::new(
                M::primary_key_field().to_string(),
                FilterOperator::Ne,
                FilterValue::String(pk.to_string()),
            );
            query = query.filter(other_rows);
        }

        let matches = query
            .count()
            .await
            .map_err(|e| DatabaseValidatorError::DatabaseError {
                message: e.to_string(),
                query: None,
            })?;

        if matches == 0 {
            return Ok(());
        }

        Err(DatabaseValidatorError::UniqueConstraintViolation {
            field: self.field_name.clone(),
            value: value.to_string(),
            table: M::table_name().to_string(),
            message: self.message.clone(),
        })
    }
}
/// Validator that checks whether a combination of columns of model `M`'s table
/// already exists in some row.
///
/// Build it with [`UniqueTogetherValidator::new`], optionally customise the error
/// with [`UniqueTogetherValidator::with_message`], and run it with
/// [`UniqueTogetherValidator::validate`].
#[derive(Debug, Clone)]
pub struct UniqueTogetherValidator<M: Model> {
/// Column names that must be unique together, in declaration order.
field_names: Vec<String>,
/// Optional custom message carried in the violation error.
message: Option<String>,
/// Ties the validator to model `M` without storing an instance.
_phantom: PhantomData<M>,
}
impl<M: Model> UniqueTogetherValidator<M> {
    /// Creates a validator enforcing that the combination of `field_names` is
    /// unique across `M`'s table.
    ///
    /// Field order is preserved. It determines the order of `fields`/`values`
    /// reported in [`DatabaseValidatorError::UniqueTogetherViolation`].
    pub fn new(field_names: Vec<impl Into<String>>) -> Self {
        Self {
            field_names: field_names.into_iter().map(|f| f.into()).collect(),
            message: None,
            _phantom: PhantomData,
        }
    }

    /// Returns the validator with a custom message attached to violation errors.
    pub fn with_message(mut self, message: impl Into<String>) -> Self {
        self.message = Some(message.into());
        self
    }

    /// Column names checked together, in declaration order.
    pub fn field_names(&self) -> &[String] {
        &self.field_names
    }

    /// Fails if another row of `M` already holds the same values for every
    /// configured field.
    ///
    /// `values` must contain an entry for each configured field name. Each
    /// value is sent as `FilterValue::String`. When `instance_pk` is `Some`,
    /// that row is excluded, so an existing instance does not conflict with
    /// itself. `_connection` is currently unused; the query is built from
    /// `M::objects()`.
    ///
    /// A validator configured with no fields imposes no constraint. It returns
    /// `Ok(())` without querying the database.
    ///
    /// # Errors
    ///
    /// - [`DatabaseValidatorError::FieldNotFound`] if `values` lacks a configured field.
    /// - [`DatabaseValidatorError::UniqueTogetherViolation`] if a matching row exists.
    /// - [`DatabaseValidatorError::DatabaseError`] if the count query fails.
    pub async fn validate(
        &self,
        _connection: &DatabaseConnection,
        values: &std::collections::HashMap<String, String>,
        instance_pk: Option<&M::PrimaryKey>,
    ) -> Result<(), DatabaseValidatorError>
    where
        M::PrimaryKey: std::fmt::Display,
    {
        // Without any equality filters the count below would cover the whole
        // table. Any existing row would then be misreported as a violation.
        if self.field_names.is_empty() {
            return Ok(());
        }

        let table_name = M::table_name();
        let mut qs = M::objects().all();
        let mut field_values = Vec::with_capacity(self.field_names.len());
        for field_name in &self.field_names {
            let value =
                values
                    .get(field_name)
                    .ok_or_else(|| DatabaseValidatorError::FieldNotFound {
                        field: field_name.clone(),
                    })?;
            field_values.push(value.clone());
            qs = qs.filter(Filter::new(
                field_name.clone(),
                FilterOperator::Eq,
                FilterValue::String(value.clone()),
            ));
        }

        // Skip the row being updated so re-saving it is not reported as a duplicate.
        if let Some(pk) = instance_pk {
            qs = qs.filter(Filter::new(
                M::primary_key_field().to_string(),
                FilterOperator::Ne,
                FilterValue::String(pk.to_string()),
            ));
        }

        let count = qs
            .count()
            .await
            .map_err(|e| DatabaseValidatorError::DatabaseError {
                message: e.to_string(),
                query: None,
            })?;

        if count > 0 {
            Err(DatabaseValidatorError::UniqueTogetherViolation {
                fields: self.field_names.clone(),
                values: field_values,
                table: table_name.to_string(),
                message: self.message.clone(),
            })
        } else {
            Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use reinhardt_db::orm::FieldSelector;

    /// Minimal model used to instantiate the generic validators.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    struct TestUser {
        id: Option<i64>,
        username: String,
        email: String,
    }

    #[derive(Debug, Clone)]
    struct TestUserFields;

    impl FieldSelector for TestUserFields {
        fn with_alias(self, _alias: &str) -> Self {
            self
        }
    }

    impl Model for TestUser {
        type PrimaryKey = i64;
        type Fields = TestUserFields;

        fn table_name() -> &'static str {
            "test_users"
        }

        fn new_fields() -> Self::Fields {
            TestUserFields
        }

        fn primary_key(&self) -> Option<Self::PrimaryKey> {
            self.id
        }

        fn set_primary_key(&mut self, value: Self::PrimaryKey) {
            self.id = Some(value);
        }
    }

    #[test]
    fn test_unique_validator_new() {
        let validator = UniqueValidator::<TestUser>::new("username");
        assert_eq!(validator.field_name(), "username");
        assert!(validator.message.is_none());
    }

    #[test]
    fn test_unique_validator_with_message() {
        let validator =
            UniqueValidator::<TestUser>::new("username").with_message("Custom error message");
        assert_eq!(validator.field_name(), "username");
        // Check the exact text, not just presence, so a mangled message is caught.
        assert_eq!(validator.message.as_deref(), Some("Custom error message"));
    }

    #[test]
    fn test_unique_together_validator_new() {
        let validator = UniqueTogetherValidator::<TestUser>::new(vec!["username", "email"]);
        assert_eq!(validator.field_names().len(), 2);
        assert_eq!(validator.field_names()[0], "username");
        assert_eq!(validator.field_names()[1], "email");
        assert!(validator.message.is_none());
    }

    #[test]
    fn test_unique_together_validator_with_message() {
        let validator = UniqueTogetherValidator::<TestUser>::new(vec!["username", "email"])
            .with_message("Custom combination message");
        assert_eq!(validator.field_names().len(), 2);
        assert_eq!(
            validator.message.as_deref(),
            Some("Custom combination message")
        );
    }

    #[test]
    fn test_unique_constraint_violation_display() {
        let err = DatabaseValidatorError::UniqueConstraintViolation {
            field: "username".to_string(),
            value: "alice".to_string(),
            table: "test_users".to_string(),
            message: None,
        };
        assert_eq!(
            err.to_string(),
            "Unique constraint violated: username = 'alice' already exists in table test_users"
        );
    }

    #[test]
    fn test_field_not_found_display() {
        let err = DatabaseValidatorError::FieldNotFound {
            field: "email".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Required field 'email' not found in validation data"
        );
    }
}