#[cfg(test)]
pub mod test;
use crate::{HasValidate, ValidationRejection};
use axum::async_trait;
use axum::extract::{FromRef, FromRequest, FromRequestParts, Request};
use axum::http::request::Parts;
use std::fmt::Display;
use std::ops::{Deref, DerefMut};
use validator::{Validate, ValidateArgs, ValidationErrors};
/// Extractor wrapper that validates the extracted value.
///
/// `Valid<E>` first runs `E`'s own `FromRequest` / `FromRequestParts`
/// extraction and then calls `Validate::validate` on the value returned by
/// `HasValidate::get_validate`. The extractor impls in this module only
/// produce a `Valid<E>` when both steps succeed.
#[derive(Debug, Clone, Copy, Default)]
pub struct Valid<E>(pub E);

/// Read access to the wrapped value, so `Valid<E>` can be used like `E`.
impl<E> Deref for Valid<E> {
    type Target = E;

    fn deref(&self) -> &E {
        let Valid(inner) = self;
        inner
    }
}

/// Mutable access to the wrapped value.
impl<E> DerefMut for Valid<E> {
    fn deref_mut(&mut self) -> &mut E {
        let Valid(inner) = self;
        inner
    }
}

/// Formats exactly like the wrapped value. The formatter, including any
/// width/precision flags, is passed through unchanged.
impl<T: Display> Display for Valid<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&self.0, f)
    }
}

impl<E> Valid<E> {
    /// Consumes the wrapper and returns the validated value.
    pub fn into_inner(self) -> E {
        let Valid(inner) = self;
        inner
    }
}
/// OpenAPI (aide) support: `Valid<T>` documents the same request input as
/// the wrapped extractor `T`.
#[cfg(feature = "aide")]
impl<T> aide::OperationInput for Valid<T>
where
    T: aide::OperationInput,
{
    fn operation_input(ctx: &mut aide::gen::GenContext, operation: &mut aide::openapi::Operation) {
        <T as aide::OperationInput>::operation_input(ctx, operation)
    }
}
/// Extractor wrapper that validates the extracted value with arguments.
///
/// `ValidEx<E, A>` builds the argument source `A` from the application state
/// via `FromRef`, extracts `E`, and validates it with
/// `ValidateArgs::validate_args`, passing the arguments produced by
/// `Arguments::get`. The argument source is kept in the second field.
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidEx<E, A>(pub E, pub A);

/// Read access to the extracted value. The arguments are not exposed through
/// `Deref`.
impl<E, A> Deref for ValidEx<E, A> {
    type Target = E;

    fn deref(&self) -> &E {
        let ValidEx(inner, _) = self;
        inner
    }
}

/// Mutable access to the extracted value.
impl<E, A> DerefMut for ValidEx<E, A> {
    fn deref_mut(&mut self) -> &mut E {
        let ValidEx(inner, _) = self;
        inner
    }
}

/// Formats exactly like the extracted value. Formatter flags are forwarded
/// and the arguments are ignored.
impl<T: Display, A> Display for ValidEx<T, A> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&self.0, f)
    }
}
impl<E, A> ValidEx<E, A> {
pub fn into_inner(self) -> E {
self.0
}
pub fn arguments<'a>(&'a self) -> <<A as Arguments<'a>>::T as ValidateArgs<'a>>::Args
where
A: Arguments<'a>,
{
self.1.get()
}
}
/// OpenAPI (aide) support: `ValidEx<T, A>` documents the same request input as
/// the wrapped extractor `T`. The argument source `A` does not contribute.
#[cfg(feature = "aide")]
impl<T, A> aide::OperationInput for ValidEx<T, A>
where
    T: aide::OperationInput,
{
    fn operation_input(ctx: &mut aide::gen::GenContext, operation: &mut aide::openapi::Operation) {
        <T as aide::OperationInput>::operation_input(ctx, operation)
    }
}
/// Source of the arguments passed to `ValidateArgs::validate_args`.
///
/// `ValidEx` builds an implementor from the application state (via
/// `FromRef`) and calls [`Arguments::get`] to obtain the arguments used to
/// validate a value of type `T`.
pub trait Arguments<'a> {
    /// The type whose `ValidateArgs<'a>` arguments this implementor provides.
    type T: ValidateArgs<'a>;

    /// Returns the arguments to use when validating a `Self::T`.
    fn get(&'a self) -> <Self::T as ValidateArgs<'a>>::Args;
}
pub type ValidRejection<E> = ValidationRejection<ValidationErrors, E>;
impl<E> From<ValidationErrors> for ValidRejection<E> {
fn from(value: ValidationErrors) -> Self {
Self::Valid(value)
}
}
/// Implemented by extractors whose payload is validated with arguments.
///
/// `ValidEx` calls [`HasValidateArgs::get_validate_args`] and runs
/// `ValidateArgs::validate_args` on the returned value.
pub trait HasValidateArgs<'v> {
    /// The type that is validated. Its `ValidateArgs<'v>` implementation must
    /// match the `T` of the `Arguments` source used with `ValidEx`.
    type ValidateArgs: ValidateArgs<'v>;

    /// Borrows the value to validate from the extractor.
    fn get_validate_args(&self) -> &Self::ValidateArgs;
}
/// Extracts `Extractor` from the whole request, then validates it.
///
/// # Errors
///
/// Returns `ValidRejection::Inner` with the inner extractor's rejection when
/// extraction fails, or `ValidRejection::Valid` with the collected
/// `ValidationErrors` when validation fails.
#[async_trait]
impl<State, Extractor> FromRequest<State> for Valid<Extractor>
where
    State: Send + Sync,
    Extractor: HasValidate + FromRequest<State>,
    Extractor::Validate: Validate,
{
    type Rejection = ValidRejection<<Extractor as FromRequest<State>>::Rejection>;

    async fn from_request(req: Request, state: &State) -> Result<Self, Self::Rejection> {
        let extracted = match Extractor::from_request(req, state).await {
            Ok(value) => value,
            Err(rejection) => return Err(ValidRejection::Inner(rejection)),
        };
        if let Err(errors) = extracted.get_validate().validate() {
            return Err(ValidRejection::Valid(errors));
        }
        Ok(Valid(extracted))
    }
}
/// Extracts `Extractor` from the request parts, then validates it.
///
/// # Errors
///
/// Returns `ValidRejection::Inner` with the inner extractor's rejection when
/// extraction fails, or `ValidRejection::Valid` with the collected
/// `ValidationErrors` when validation fails.
#[async_trait]
impl<State, Extractor> FromRequestParts<State> for Valid<Extractor>
where
    State: Send + Sync,
    Extractor: HasValidate + FromRequestParts<State>,
    Extractor::Validate: Validate,
{
    type Rejection = ValidRejection<<Extractor as FromRequestParts<State>>::Rejection>;

    async fn from_request_parts(parts: &mut Parts, state: &State) -> Result<Self, Self::Rejection> {
        let extracted = match Extractor::from_request_parts(parts, state).await {
            Ok(value) => value,
            Err(rejection) => return Err(ValidRejection::Inner(rejection)),
        };
        if let Err(errors) = extracted.get_validate().validate() {
            return Err(ValidRejection::Valid(errors));
        }
        Ok(Valid(extracted))
    }
}
/// Builds the argument source from state, extracts `Extractor` from the whole
/// request, and validates it with those arguments.
///
/// # Errors
///
/// Returns `ValidRejection::Inner` with the inner extractor's rejection when
/// extraction fails, or `ValidRejection::Valid` with the collected
/// `ValidationErrors` when validation fails.
#[async_trait]
impl<State, Extractor, Args> FromRequest<State> for ValidEx<Extractor, Args>
where
    State: Send + Sync,
    Args: Send
        + Sync
        + FromRef<State>
        + for<'a> Arguments<'a, T = <Extractor as HasValidateArgs<'a>>::ValidateArgs>,
    Extractor: for<'v> HasValidateArgs<'v> + FromRequest<State>,
{
    type Rejection = ValidRejection<<Extractor as FromRequest<State>>::Rejection>;

    async fn from_request(req: Request, state: &State) -> Result<Self, Self::Rejection> {
        let args = <Args as FromRef<State>>::from_ref(state);
        let extracted = match Extractor::from_request(req, state).await {
            Ok(value) => value,
            Err(rejection) => return Err(ValidRejection::Inner(rejection)),
        };
        if let Err(errors) = extracted.get_validate_args().validate_args(args.get()) {
            return Err(ValidRejection::Valid(errors));
        }
        Ok(ValidEx(extracted, args))
    }
}
/// Builds the argument source from state, extracts `Extractor` from the
/// request parts, and validates it with those arguments.
///
/// # Errors
///
/// Returns `ValidRejection::Inner` with the inner extractor's rejection when
/// extraction fails, or `ValidRejection::Valid` with the collected
/// `ValidationErrors` when validation fails.
#[async_trait]
impl<State, Extractor, Args> FromRequestParts<State> for ValidEx<Extractor, Args>
where
    State: Send + Sync,
    Args: Send
        + Sync
        + FromRef<State>
        + for<'a> Arguments<'a, T = <Extractor as HasValidateArgs<'a>>::ValidateArgs>,
    Extractor: for<'v> HasValidateArgs<'v> + FromRequestParts<State>,
{
    type Rejection = ValidRejection<<Extractor as FromRequestParts<State>>::Rejection>;

    async fn from_request_parts(parts: &mut Parts, state: &State) -> Result<Self, Self::Rejection> {
        let args = <Args as FromRef<State>>::from_ref(state);
        let extracted = match Extractor::from_request_parts(parts, state).await {
            Ok(value) => value,
            Err(rejection) => return Err(ValidRejection::Inner(rejection)),
        };
        if let Err(errors) = extracted.get_validate_args().validate_args(args.get()) {
            return Err(ValidRejection::Valid(errors));
        }
        Ok(ValidEx(extracted, args))
    }
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use std::error::Error;
    use std::fmt::Formatter;
    use std::io;
    use validator::ValidationError;

    // Sample payload shared by the tests below.
    const TEST: &str = "test";

    // `Valid` exposes its payload through `Deref`/`DerefMut`, forwards
    // `Display`, and returns the payload from `into_inner`.
    #[test]
    fn valid_deref_deref_mut_into_inner() {
        let mut expected = String::from(TEST);
        let mut wrapped = Valid(expected.clone());
        assert_eq!(&expected, wrapped.deref());

        // Mutate the expected value and the wrapped value in lock-step.
        expected.push_str(TEST);
        wrapped.deref_mut().push_str(TEST);
        assert_eq!(&expected, wrapped.deref());

        println!("{}", wrapped);
        assert_eq!(expected, wrapped.into_inner());
    }

    // Same checks for `ValidEx`, plus `Display` forwarding and retrieval of
    // the validation arguments through `ValidEx::arguments`.
    #[test]
    fn valid_ex_deref_deref_mut_into_inner_arguments() {
        fn validate(_v: i32, _args: i32) -> Result<(), ValidationError> {
            Ok(())
        }

        #[derive(Debug, Validate)]
        struct Data {
            #[validate(custom(function = "validate", arg = "i32"))]
            v: i32,
        }

        impl Display for Data {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                write!(f, "{:?}", self)
            }
        }

        // Argument source that hands a fixed `i32` to `Data`'s custom validator.
        struct DataVA {
            a: i32,
        }

        impl<'a> Arguments<'a> for DataVA {
            type T = Data;
            fn get(&'a self) -> <<Self as Arguments<'a>>::T as ValidateArgs<'a>>::Args {
                self.a
            }
        }

        let mut expected = String::from(TEST);
        let mut wrapped = ValidEx(expected.clone(), ());
        assert_eq!(&expected, wrapped.deref());

        expected.push_str(TEST);
        wrapped.deref_mut().push_str(TEST);
        assert_eq!(&expected, wrapped.deref());
        assert_eq!(expected, wrapped.into_inner());

        let with_args = ValidEx(Data { v: 12 }, DataVA { a: 123 });
        println!("{}", with_args);
        assert_eq!(with_args.v, 12);
        assert_eq!(with_args.arguments(), 123);
    }

    // `ValidRejection` forwards `Display` and `Error::source` to the value it
    // wraps, for both variants.
    #[test]
    fn display_error() {
        let mut errors = ValidationErrors::new();
        errors.add(TEST, ValidationError::new(TEST));
        let rejection = ValidRejection::<String>::Valid(errors.clone());
        assert_eq!(rejection.to_string(), errors.to_string());

        let inner = String::from(TEST);
        let rejection = ValidRejection::<String>::Inner(inner.clone());
        assert_eq!(inner.to_string(), rejection.to_string());

        let mut errors = ValidationErrors::new();
        errors.add(TEST, ValidationError::new(TEST));
        let rejection = ValidRejection::<io::Error>::Valid(errors);
        assert!(matches!(
            rejection.source(),
            Some(source) if source.downcast_ref::<ValidationErrors>().is_some()
        ));

        let rejection =
            ValidRejection::<io::Error>::Inner(io::Error::new(io::ErrorKind::Other, TEST));
        assert!(matches!(
            rejection.source(),
            Some(source) if source.downcast_ref::<io::Error>().is_some()
        ));
    }
}