use crate::error::ValidationError;
use crate::RuleContext;
use async_trait::async_trait;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Shared, type-erased resources made available to asynchronous validators.
///
/// Resources are stored as `Arc<dyn Any + Send + Sync>` under `&'static str` keys and
/// recovered with [`ValidationContext::get_resource`]. Cloning a context copies the key map
/// but shares the underlying resources (only the `Arc` handles are cloned).
#[derive(Clone, Default)]
pub struct ValidationContext {
    resources: std::collections::HashMap<&'static str, Arc<dyn std::any::Any + Send + Sync>>,
}

impl ValidationContext {
    /// Creates an empty context with no resources registered.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `resource` under `key` and returns the updated context (builder style).
    ///
    /// A resource previously stored under the same `key` is replaced.
    #[must_use]
    pub fn with_resource<T: Send + Sync + 'static>(
        mut self,
        key: &'static str,
        resource: Arc<T>,
    ) -> Self {
        // `Arc<T>` unsizes to `Arc<dyn Any + Send + Sync>` at this call.
        self.resources.insert(key, resource);
        self
    }

    /// Looks up the resource stored under `key` and returns it as an `Arc<T>`.
    ///
    /// `key` may have any lifetime (e.g. borrowed from a `String`); it is only used for the
    /// lookup. Returns `None` when nothing is stored under `key` *or* when the stored value
    /// is not a `T` — the two cases are not distinguished.
    pub fn get_resource<T: Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
        self.resources
            .get(key)
            .and_then(|any| Arc::clone(any).downcast::<T>().ok())
    }
}

impl std::fmt::Debug for ValidationContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Values are type-erased (`dyn Any`) and cannot be printed, so only the registered
        // keys are shown, sorted so the output does not depend on hash-map iteration order.
        let mut keys: Vec<&str> = self.resources.keys().copied().collect();
        keys.sort_unstable();
        f.debug_struct("ValidationContext")
            .field("resources", &keys)
            .finish()
    }
}
/// Asynchronous validation of a whole value.
///
/// The implementor receives the shared [`ValidationContext`], through which resources
/// registered with [`ValidationContext::with_resource`] can be looked up.
#[async_trait]
pub trait AsyncValidate {
    /// Validates `self`, yielding `Ok(())` on success or the `ValidationError` on failure.
    async fn validate_async(&self, context: &ValidationContext) -> Result<(), ValidationError>;
}
/// Type-erased signature of an asynchronous rule over values of type `T`.
///
/// The callable receives the value, its `RuleContext` and the shared `ValidationContext`,
/// and returns a pinned, boxed, `Send` future that resolves to a `ValidationError`.
/// Both the callable and the returned future are `'static` (spelled out here; it is also
/// the default object lifetime in this position), so the future cannot borrow any argument.
pub type AsyncRuleFn<T> = Arc<
    dyn Fn(&T, &RuleContext, &ValidationContext)
            -> Pin<Box<dyn Future<Output = ValidationError> + Send + 'static>>
        + Send
        + Sync
        + 'static,
>;
/// An asynchronous validation rule for values of type `T`.
///
/// Wraps a type-erased [`AsyncRuleFn`], so rules built from different closures share one
/// concrete type. Cloning is cheap: it only bumps the reference count of the shared closure.
pub struct AsyncRule<T: ?Sized> {
    func: AsyncRuleFn<T>,
}

impl<T: ?Sized> AsyncRule<T> {
    /// Builds a rule from a closure that returns a future.
    ///
    /// The future type `Fut` must be `'static`, so it cannot borrow `value` or either
    /// context. Copy what it needs (e.g. `value.len()`, `ctx.full_path()`, an `Arc`
    /// resource) out of the arguments before the `async move` block.
    #[must_use]
    pub fn new<F, Fut>(func: F) -> Self
    where
        F: Fn(&T, &RuleContext, &ValidationContext) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ValidationError> + Send + 'static,
    {
        Self {
            // Box and pin each concrete future so it matches the erased return type.
            func: Arc::new(move |value, ctx, vctx| Box::pin(func(value, ctx, vctx))),
        }
    }

    /// Runs the rule against `value` and awaits the future produced by the wrapped closure.
    ///
    /// NOTE(review): by the convention used in this file's tests, an empty
    /// `ValidationError` (as from `ValidationError::default()`) means the value passed.
    pub async fn apply(
        &self,
        value: &T,
        ctx: &RuleContext,
        vctx: &ValidationContext,
    ) -> ValidationError {
        (self.func)(value, ctx, vctx).await
    }
}

// Implemented by hand: `#[derive(Clone)]` would add a `T: Clone` bound, which unsized
// targets such as `str` cannot satisfy.
impl<T: ?Sized> Clone for AsyncRule<T> {
    fn clone(&self) -> Self {
        Self {
            func: Arc::clone(&self.func),
        }
    }
}

impl<T: ?Sized> std::fmt::Debug for AsyncRule<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The wrapped closure is opaque, so there is nothing meaningful to print.
        f.debug_struct("AsyncRule").finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Rule shared by the async-rule tests: the value must be at least 3 bytes long.
    fn min_length_rule() -> AsyncRule<str> {
        AsyncRule::new(|value: &str, _ctx: &RuleContext, _vctx: &ValidationContext| {
            // Copy the length out so the `'static` future does not borrow `value`.
            let len = value.len();
            async move {
                if len >= 3 {
                    ValidationError::default()
                } else {
                    ValidationError::single(
                        crate::Path::root(),
                        "too_short",
                        "Value must be at least 3 characters",
                    )
                }
            }
        })
    }

    #[test]
    fn test_context_new() {
        let ctx = ValidationContext::new();
        assert!(ctx.resources.is_empty());
    }

    #[test]
    fn test_context_with_resource() {
        let ctx = ValidationContext::new().with_resource("test", Arc::new(42i32));
        assert_eq!(ctx.get_resource::<i32>("test").as_deref(), Some(&42));
    }

    #[test]
    fn test_context_with_resource_replaces_existing_key() {
        let ctx = ValidationContext::new()
            .with_resource("test", Arc::new(1i32))
            .with_resource("test", Arc::new(2i32));
        assert_eq!(ctx.get_resource::<i32>("test").as_deref(), Some(&2));
    }

    #[test]
    fn test_context_get_resource_wrong_type() {
        let value = Arc::new(42i32);
        let ctx = ValidationContext::new().with_resource("test", value);
        let retrieved = ctx.get_resource::<String>("test");
        assert!(retrieved.is_none());
    }

    #[test]
    fn test_context_get_resource_missing_key() {
        let ctx = ValidationContext::new();
        let retrieved = ctx.get_resource::<i32>("missing");
        assert!(retrieved.is_none());
    }

    struct TestUser {
        email: String,
    }

    #[async_trait]
    impl AsyncValidate for TestUser {
        async fn validate_async(&self, _ctx: &ValidationContext) -> Result<(), ValidationError> {
            if self.email.is_empty() {
                return Err(ValidationError::single(
                    crate::Path::from("email"),
                    "required",
                    "Email is required",
                ));
            }
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_async_validate_success() {
        let user = TestUser {
            email: "test@example.com".to_string(),
        };
        let ctx = ValidationContext::new();
        assert!(user.validate_async(&ctx).await.is_ok());
    }

    #[tokio::test]
    async fn test_async_validate_failure() {
        let user = TestUser {
            email: String::new(),
        };
        let ctx = ValidationContext::new();
        let result = user.validate_async(&ctx).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_async_rule_success() {
        let rule = min_length_rule();
        let ctx = RuleContext::root("test");
        let vctx = ValidationContext::new();
        assert!(rule.apply("hello", &ctx, &vctx).await.is_empty());
        // Exactly the minimum length is accepted.
        assert!(rule.apply("abc", &ctx, &vctx).await.is_empty());
    }

    #[tokio::test]
    async fn test_async_rule_failure() {
        let rule = min_length_rule();
        let ctx = RuleContext::root("test");
        let vctx = ValidationContext::new();
        let result = rule.apply("ab", &ctx, &vctx).await;
        assert!(!result.is_empty());
        assert_eq!(result.violations.len(), 1);
        assert_eq!(result.violations[0].code, "too_short");
        // A cloned rule shares the same closure and behaves identically.
        let cloned = rule.clone();
        assert!(!cloned.apply("ab", &ctx, &vctx).await.is_empty());
    }

    #[tokio::test]
    async fn test_async_rule_with_context_resource() {
        struct MockDatabase {
            taken_emails: Vec<String>,
        }

        impl MockDatabase {
            fn new() -> Self {
                Self {
                    taken_emails: vec!["taken@example.com".to_string()],
                }
            }

            fn exists(&self, email: &str) -> bool {
                // Compare borrowed; no per-lookup `String` allocation.
                self.taken_emails.iter().any(|taken| taken == email)
            }
        }

        let rule = AsyncRule::new(|email: &str, ctx: &RuleContext, vctx: &ValidationContext| {
            let db = vctx
                .get_resource::<MockDatabase>("db")
                .expect("Database not in context");
            let email = email.to_string();
            let path = ctx.full_path();
            async move {
                if db.exists(&email) {
                    ValidationError::single(path, "email_taken", "Email is already registered")
                } else {
                    ValidationError::default()
                }
            }
        });
        let db = Arc::new(MockDatabase::new());
        let vctx = ValidationContext::new().with_resource("db", db);
        let ctx = RuleContext::root("email");
        let result = rule.apply("taken@example.com", &ctx, &vctx).await;
        assert!(!result.is_empty());
        assert_eq!(result.violations[0].code, "email_taken");
        let result = rule.apply("available@example.com", &ctx, &vctx).await;
        assert!(result.is_empty());
    }
}