1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
#![doc = include_str!("../README.md")]
#![deny(unsafe_code, missing_docs, clippy::unwrap_used)]

#[cfg(feature = "extra")]
pub mod extra;
#[cfg(feature = "form")]
pub mod form;
#[cfg(feature = "json")]
pub mod json;
#[cfg(feature = "msgpack")]
pub mod msgpack;
pub mod path;
#[cfg(feature = "query")]
pub mod query;
#[cfg(test)]
pub mod test;
#[cfg(feature = "typed_header")]
pub mod typed_header;
#[cfg(feature = "typed_multipart")]
pub mod typed_multipart;
#[cfg(feature = "yaml")]
pub mod yaml;

use axum::async_trait;
use axum::extract::{FromRef, FromRequest, FromRequestParts};
use axum::http::request::Parts;
use axum::http::{Request, StatusCode};
use axum::response::{IntoResponse, Response};
use std::error::Error;
use std::fmt::{Debug, Display, Formatter};
use std::ops::{Deref, DerefMut};
use validator::{Validate, ValidateArgs, ValidationErrors};

/// Http status code returned when there are validation errors.
///
/// Resolves to `422 Unprocessable Entity` when the `422` cargo feature is enabled,
/// and to `400 Bad Request` otherwise.
pub const VALIDATION_ERROR_STATUS: StatusCode = if cfg!(feature = "422") {
    StatusCode::UNPROCESSABLE_ENTITY
} else {
    StatusCode::BAD_REQUEST
};

/// # `Valid` data extractor
///
/// This extractor can be used in combination with axum's extractors like
/// Json, Form, Query, Path, etc. to validate their inner data automatically.
/// It can also work with custom extractors that implement the `HasValidate` trait.
///
/// See the docs for each integration module to find examples of using
/// `Valid` with that extractor.
///
/// For examples with custom extractors, check out the `tests/custom.rs` file.
///
#[derive(Debug, Clone, Copy, Default)]
pub struct Valid<E>(pub E);

impl<E> Deref for Valid<E> {
    type Target = E;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<E> DerefMut for Valid<E> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

/// Formats a `Valid` exactly like the value it wraps (formatter flags are forwarded).
impl<T: Display> Display for Valid<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        self.0.fmt(f)
    }
}

impl<E> Valid<E> {
    /// Consumes the `Valid` extractor and returns the inner type.
    pub fn into_inner(self) -> E {
        self.0
    }
}

/// # `ValidEx` data extractor
///
/// `ValidEx` can be incorporated with extractors from various modules, similar to `Valid`.
/// Two differences exist between `ValidEx` and `Valid`:
///
/// - The inner data type in `ValidEx` implements `ValidateArgs` instead of `Validate`.
/// - `ValidEx` includes a second field that represents arguments used during validation of the first field.
///
/// The implementation of `ValidateArgs` is often automatically handled by validator's derive macros
/// (for more details, please refer to the validator's documentation).
///
/// Although current module documentation predominantly showcases `Valid` examples, the usage of `ValidEx` is analogous.
///
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidEx<E, A>(pub E, pub A);

impl<E, A> Deref for ValidEx<E, A> {
    type Target = E;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<E, A> DerefMut for ValidEx<E, A> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

/// Formats a `ValidEx` exactly like the validated value it wraps; the arguments
/// in the second field are not printed.
impl<T: Display, A> Display for ValidEx<T, A> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        self.0.fmt(f)
    }
}

impl<E, A> ValidEx<E, A> {
    /// Consumes the `ValidEx` and returns the validated data within.
    ///
    /// This returns the `E` type which represents the data that has been
    /// successfully validated; the arguments are dropped.
    pub fn into_inner(self) -> E {
        self.0
    }

    /// Returns the validation arguments produced by the stored `A` value.
    ///
    /// This calls [`Arguments::get`] on the second field, so the result is whatever
    /// `ValidateArgs::Args` is for the validated type. Depending on the `Arguments`
    /// implementation it may be a borrowed or an owned value.
    pub fn arguments<'a>(&'a self) -> <<A as Arguments<'a>>::T as ValidateArgs<'a>>::Args
    where
        A: Arguments<'a>,
    {
        self.1.get()
    }
}

/// Turns validation failures into the HTTP response sent to the client.
///
/// The status is always `VALIDATION_ERROR_STATUS`. With the `into_json` feature the
/// errors are serialized as a JSON body; without it the body is their `Display` text.
fn response_builder(ve: ValidationErrors) -> Response {
    #[cfg(feature = "into_json")]
    let body = axum::Json(ve);
    #[cfg(not(feature = "into_json"))]
    let body = ve.to_string();

    (VALIDATION_ERROR_STATUS, body).into_response()
}

/// `Arguments` provides the validation arguments for the data type `T`.
///
/// The associated type `T` is the data type being validated. It must implement
/// `ValidateArgs`, and [`Arguments::get`] produces the value handed to
/// `ValidateArgs::validate_args`.
///
/// Types implementing `Arguments` should be part of the router's state (either by
/// implementing `FromRef<StateType>` or by being the state itself) so that `ValidEx`
/// can retrieve them automatically during validation.
///
pub trait Arguments<'a> {
    /// The data type validated with these arguments.
    type T: ValidateArgs<'a>;
    /// Returns the arguments required by `ValidateArgs::validate_args` for `Self::T`.
    fn get(&'a self) -> <Self::T as ValidateArgs<'a>>::Args;
}

/// `ValidRejection` is returned when the `Valid` extractor fails.
///
/// It distinguishes the two places an extraction can fail: the validation step itself
/// (`Valid`) and the wrapped extractor, which runs before validation (`Inner`).
/// The `ValidEx` extractor reports failures with this same type.
///
#[derive(Debug)]
pub enum ValidRejection<E> {
    /// Validation of the extracted data failed. Holds the `ValidationErrors`
    /// collected by `validator`, a collection of validation failures for each field.
    Valid(ValidationErrors),
    /// The inner extractor rejected the request, so validation never ran.
    /// Holds the inner extractor's own rejection.
    Inner(E),
}

impl<E: Display> Display for ValidRejection<E> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            ValidRejection::Valid(errors) => write!(f, "{errors}"),
            ValidRejection::Inner(error) => write!(f, "{error}"),
        }
    }
}

impl<E: Error + 'static> Error for ValidRejection<E> {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            ValidRejection::Valid(ve) => Some(ve),
            ValidRejection::Inner(e) => Some(e),
        }
    }
}

impl<E> From<ValidationErrors> for ValidRejection<E> {
    fn from(value: ValidationErrors) -> Self {
        Self::Valid(value)
    }
}

/// Inner rejections keep their own response; validation failures go through `response_builder`.
impl<E: IntoResponse> IntoResponse for ValidRejection<E> {
    fn into_response(self) -> Response {
        match self {
            Self::Inner(inner_rejection) => inner_rejection.into_response(),
            Self::Valid(validation_errors) => response_builder(validation_errors),
        }
    }
}

/// Trait for extractor types that expose an inner value which can be validated.
///
/// Extractor types `T` that implement this trait can be used with `Valid`: once the
/// extractor succeeds, `Valid` calls `get_validate()` and runs `Validate::validate`
/// on the returned reference.
///
pub trait HasValidate {
    /// Inner type that is checked with `Validate::validate`.
    type Validate: Validate;
    /// Returns a reference to the value that should be validated.
    fn get_validate(&self) -> &Self::Validate;
}

/// Trait for extractor types that expose an inner value which can be validated using arguments.
///
/// Extractor types `T` that implement this trait can be used with `ValidEx`: once the
/// extractor succeeds, `ValidEx` calls `get_validate_args()` and runs
/// `ValidateArgs::validate_args` on the returned reference, passing the arguments it
/// obtained from the router state.
///
pub trait HasValidateArgs<'v> {
    /// Inner type that is checked with `ValidateArgs::validate_args`.
    type ValidateArgs: ValidateArgs<'v>;
    /// Returns a reference to the value that should be validated.
    fn get_validate_args(&self) -> &Self::ValidateArgs;
}

/// Runs the inner body-consuming extractor, then validates what it produced.
///
/// An inner failure becomes `ValidRejection::Inner`; a validation failure becomes
/// `ValidRejection::Valid`.
#[async_trait]
impl<State, Body, Extractor> FromRequest<State, Body> for Valid<Extractor>
where
    State: Send + Sync,
    Body: Send + Sync + 'static,
    Extractor: HasValidate + FromRequest<State, Body>,
    Extractor::Validate: Validate,
{
    type Rejection = ValidRejection<<Extractor as FromRequest<State, Body>>::Rejection>;

    async fn from_request(req: Request<Body>, state: &State) -> Result<Self, Self::Rejection> {
        let extracted = match Extractor::from_request(req, state).await {
            Ok(value) => value,
            Err(rejection) => return Err(ValidRejection::Inner(rejection)),
        };

        // `From<ValidationErrors>` maps a failed check onto `ValidRejection::Valid`.
        extracted.get_validate().validate()?;
        Ok(Self(extracted))
    }
}

/// Runs the inner parts-based extractor, then validates what it produced.
///
/// An inner failure becomes `ValidRejection::Inner`; a validation failure becomes
/// `ValidRejection::Valid`.
#[async_trait]
impl<State, Extractor> FromRequestParts<State> for Valid<Extractor>
where
    State: Send + Sync,
    Extractor: HasValidate + FromRequestParts<State>,
    Extractor::Validate: Validate,
{
    type Rejection = ValidRejection<<Extractor as FromRequestParts<State>>::Rejection>;

    async fn from_request_parts(parts: &mut Parts, state: &State) -> Result<Self, Self::Rejection> {
        let extracted = Extractor::from_request_parts(parts, state)
            .await
            .map_err(ValidRejection::Inner)?;

        if let Err(validation_errors) = extracted.get_validate().validate() {
            return Err(ValidRejection::Valid(validation_errors));
        }
        Ok(Self(extracted))
    }
}

/// Pulls the validation arguments out of the state, runs the inner body-consuming
/// extractor, then validates its output with those arguments.
///
/// The arguments are kept in the resulting `ValidEx` alongside the extracted value.
#[async_trait]
impl<State, Body, Extractor, Args> FromRequest<State, Body> for ValidEx<Extractor, Args>
where
    State: Send + Sync,
    Body: Send + Sync + 'static,
    Args: Send
        + Sync
        + FromRef<State>
        + for<'a> Arguments<'a, T = <Extractor as HasValidateArgs<'a>>::ValidateArgs>,
    Extractor: for<'v> HasValidateArgs<'v> + FromRequest<State, Body>,
    for<'v> <Extractor as HasValidateArgs<'v>>::ValidateArgs: ValidateArgs<'v>,
{
    type Rejection = ValidRejection<<Extractor as FromRequest<State, Body>>::Rejection>;

    async fn from_request(req: Request<Body>, state: &State) -> Result<Self, Self::Rejection> {
        // Arguments are taken from the state before the inner extractor runs.
        let args = <Args as FromRef<State>>::from_ref(state);

        let extracted = match Extractor::from_request(req, state).await {
            Ok(value) => value,
            Err(rejection) => return Err(ValidRejection::Inner(rejection)),
        };

        extracted.get_validate_args().validate_args(args.get())?;
        Ok(Self(extracted, args))
    }
}

/// Pulls the validation arguments out of the state, runs the inner parts-based
/// extractor, then validates its output with those arguments.
///
/// The arguments are kept in the resulting `ValidEx` alongside the extracted value.
#[async_trait]
impl<State, Extractor, Args> FromRequestParts<State> for ValidEx<Extractor, Args>
where
    State: Send + Sync,
    Args: Send
        + Sync
        + FromRef<State>
        + for<'a> Arguments<'a, T = <Extractor as HasValidateArgs<'a>>::ValidateArgs>,
    Extractor: for<'v> HasValidateArgs<'v> + FromRequestParts<State>,
    for<'v> <Extractor as HasValidateArgs<'v>>::ValidateArgs: ValidateArgs<'v>,
{
    type Rejection = ValidRejection<<Extractor as FromRequestParts<State>>::Rejection>;

    async fn from_request_parts(parts: &mut Parts, state: &State) -> Result<Self, Self::Rejection> {
        // Arguments are taken from the state before the inner extractor runs.
        let args = <Args as FromRef<State>>::from_ref(state);

        let extracted = Extractor::from_request_parts(parts, state)
            .await
            .map_err(ValidRejection::Inner)?;

        if let Err(validation_errors) = extracted.get_validate_args().validate_args(args.get()) {
            return Err(ValidRejection::Valid(validation_errors));
        }
        Ok(Self(extracted, args))
    }
}

#[cfg(test)]
pub mod tests {
    //! Unit tests for the core types, plus helper traits that describe
    //! request-level test cases.

    use crate::{Arguments, Valid, ValidEx, ValidRejection};
    use reqwest::{RequestBuilder, StatusCode};
    use serde::Serialize;
    use std::error::Error;
    use std::io;
    use std::ops::{Deref, DerefMut};
    use validator::{Validate, ValidateArgs, ValidationError, ValidationErrors};

    /// # Valid test parameter
    pub trait ValidTestParameter: Serialize + 'static {
        /// Returns a valid parameter.
        fn valid() -> &'static Self;
        /// Returns a serializable array of key/value pairs for the error case
        /// (presumably used to build a request the inner extractor rejects — confirm at call sites).
        fn error() -> &'static [(&'static str, &'static str)];
        /// Returns an invalid parameter.
        fn invalid() -> &'static Self;
    }

    /// # Valid Tests
    ///
    /// This trait defines three test cases to check
    /// if an extractor combined with the Valid type works properly.
    ///
    /// 1. For a valid request, the server should return `200 OK`.
    /// 2. For an invalid request according to the extractor, the server should return the error HTTP status code defined by the extractor itself.
    /// 3. For an invalid request according to Valid, the server should return VALIDATION_ERROR_STATUS as the error code.
    ///
    pub trait ValidTest {
        /// The HTTP status code returned when the inner extractor fails.
        const ERROR_STATUS_CODE: StatusCode;
        /// The HTTP status code returned when the outer extractor fails.
        /// Uses `crate::VALIDATION_ERROR_STATUS` by default.
        const INVALID_STATUS_CODE: StatusCode = crate::VALIDATION_ERROR_STATUS;
        /// Whether the response body can be serialized into JSON format.
        const JSON_SERIALIZABLE: bool = true;
        /// Build a valid request, the server should return `200 OK`.
        fn set_valid_request(builder: RequestBuilder) -> RequestBuilder;
        /// Build an invalid request according to the extractor, the server should return `Self::ERROR_STATUS_CODE`
        fn set_error_request(builder: RequestBuilder) -> RequestBuilder;
        /// Build an invalid request according to Valid, the server should return VALIDATION_ERROR_STATUS
        fn set_invalid_request(builder: RequestBuilder) -> RequestBuilder;
    }

    /// Associates a rejection type with the HTTP status code tests expect it to produce.
    #[cfg(feature = "extra")]
    pub trait Rejection {
        /// The expected HTTP status code.
        const STATUS_CODE: StatusCode;
    }

    const TEST: &str = "test";

    #[test]
    fn valid_deref_deref_mut_into_inner() {
        let mut inner = String::from(TEST);
        let mut v = Valid(inner.clone());
        assert_eq!(&inner, v.deref());
        inner.push_str(TEST);
        v.deref_mut().push_str(TEST);
        assert_eq!(&inner, v.deref());
        assert_eq!(inner, v.into_inner());
    }

    #[test]
    fn valid_ex_deref_deref_mut_into_inner_arguments() {
        let mut inner = String::from(TEST);
        let mut v = ValidEx(inner.clone(), ());
        assert_eq!(&inner, v.deref());
        inner.push_str(TEST);
        v.deref_mut().push_str(TEST);
        assert_eq!(&inner, v.deref());
        assert_eq!(inner, v.into_inner());

        fn validate(_v: i32, _args: i32) -> Result<(), ValidationError> {
            Ok(())
        }

        #[derive(Validate)]
        struct Data {
            #[validate(custom(function = "validate", arg = "i32"))]
            v: i32,
        }

        struct DataVA {
            a: i32,
        }

        impl<'a> Arguments<'a> for DataVA {
            type T = Data;

            fn get(&'a self) -> <<Self as Arguments<'a>>::T as ValidateArgs<'a>>::Args {
                self.a
            }
        }

        let data = Data { v: 12 };
        let args = DataVA { a: 123 };
        let ve = ValidEx(data, args);
        assert_eq!(ve.v, 12);
        let a = ve.arguments();
        assert_eq!(a, 123);
    }

    #[test]
    fn display_error() {
        // ValidRejection::Valid Display
        let mut ve = ValidationErrors::new();
        ve.add(TEST, ValidationError::new(TEST));
        let vr = ValidRejection::<String>::Valid(ve.clone());
        assert_eq!(vr.to_string(), ve.to_string());

        // ValidRejection::Inner Display
        let inner = String::from(TEST);
        let vr = ValidRejection::<String>::Inner(inner.clone());
        assert_eq!(inner, vr.to_string());

        // ValidRejection::Valid Error
        let mut ve = ValidationErrors::new();
        ve.add(TEST, ValidationError::new(TEST));
        let vr = ValidRejection::<io::Error>::Valid(ve);
        assert!(
            matches!(vr.source(), Some(source) if source.downcast_ref::<ValidationErrors>().is_some())
        );

        // ValidRejection::Inner Error
        let vr = ValidRejection::<io::Error>::Inner(io::Error::new(io::ErrorKind::Other, TEST));
        assert!(
            matches!(vr.source(), Some(source) if source.downcast_ref::<io::Error>().is_some())
        );
    }
}