actix_web_validation/
validator.rs

1//! Validation for the [validator](https://docs.rs/validator/latest/validator) crate.
//! Requires the `validator` feature flag.
3//!
4//! Validator is a popular validation library for Rust.
5//!
//! You will also need to add the `validator` crate (with its `derive` feature) as a dependency in your `Cargo.toml`:
7//!
8//! ```toml
9//! [dependencies]
10//! validator = { version = "0.0.0", features = ["derive"] }
11//! actix-web-validation = { version = "0.0.0", features = ["validator"]}
12//! ```
13//!
//! For usage examples, see the documentation for [`Validated`].
15//!
16
17use crate::validated_definition;
18use ::validator::Validate;
19use actix_web::dev::{ServiceFactory, ServiceRequest};
20use actix_web::http::StatusCode;
21use actix_web::FromRequest;
22use actix_web::{App, HttpRequest, HttpResponse, ResponseError};
23use std::fmt::Display;
24use std::future::Future;
25use std::sync::Arc;
26use std::{fmt::Debug, ops::Deref, pin::Pin, task::Poll};
27use thiserror::Error;
28use validator::{ValidationError, ValidationErrors, ValidationErrorsKind};
29
/// A validated extractor.
///
/// Wraps another extractor `T` (for example [`actix_web::web::Json`]). Once `T` has been
/// extracted, it runs the [`Validate`] implementation of the value `T` dereferences to.
/// If validation fails, the request is rejected. By default the rejection is a
/// `400 Bad Request` listing the failing fields. If a handler was registered through
/// [`ValidatorErrorHandlerExt`], the error that handler returns is used instead.
///
/// ```
/// use actix_web::{post, web::{self, Json}, App};
/// use serde::Deserialize;
/// use validator::Validate;
/// use actix_web_validation::validator::Validated;
///
/// #[derive(Debug, Deserialize, Validate)]
/// struct Info {
///     #[validate(length(min = 5))]
///     username: String,
/// }
///
/// #[post("/")]
/// async fn index(info: Validated<Json<Info>>) -> String {
///     format!("Welcome {}!", info.username)
/// }
/// ```
pub struct Validated<T>(pub T);

// Expands the crate-level `validated_definition!` macro for `Validated`. Its definition is not
// visible here. NOTE(review): presumably it supplies shared impls such as `Deref` and `Debug`
// (this file's tests use `v.name` and `{:?}` on a `Validated`); confirm against the macro.
validated_definition!();
54
55/// Future that extracts and validates actix requests using the Actix Web [`FromRequest`] trait
56///
57/// End users of this library should not need to use this directly for most usecases
58pub struct ValidatedFut<T: FromRequest> {
59    req: actix_web::HttpRequest,
60    fut: <T as FromRequest>::Future,
61    error_handler: Option<ValidatorErrHandler>,
62}
63impl<T> Future for ValidatedFut<T>
64where
65    T: FromRequest + Debug + Deref,
66    T::Future: Unpin,
67    T::Target: Validate,
68{
69    type Output = Result<Validated<T>, actix_web::Error>;
70
71    fn poll(
72        self: std::pin::Pin<&mut Self>,
73        cx: &mut std::task::Context<'_>,
74    ) -> std::task::Poll<Self::Output> {
75        let this = self.get_mut();
76
77        let Poll::Ready(res) = Pin::new(&mut this.fut).poll(cx) else {
78            return std::task::Poll::Pending;
79        };
80
81        let res = match res {
82            Ok(data) => {
83                if let Err(e) = data.validate() {
84                    if let Some(error_handler) = &this.error_handler {
85                        Err((*error_handler)(e, &this.req))
86                    } else {
87                        let err: Error = e.into();
88                        Err(err.into())
89                    }
90                } else {
91                    Ok(Validated(data))
92                }
93            }
94            Err(e) => Err(e.into()),
95        };
96
97        Poll::Ready(res)
98    }
99}
100
101impl<T> FromRequest for Validated<T>
102where
103    T: FromRequest + Debug + Deref,
104    T::Future: Unpin,
105    T::Target: Validate,
106{
107    type Error = actix_web::Error;
108
109    type Future = ValidatedFut<T>;
110
111    fn from_request(
112        req: &actix_web::HttpRequest,
113        payload: &mut actix_web::dev::Payload,
114    ) -> Self::Future {
115        let error_handler = req
116            .app_data::<ValidatorErrorHandler>()
117            .map(|h| h.handler.clone());
118
119        let fut = T::from_request(req, payload);
120
121        ValidatedFut {
122            fut,
123            error_handler,
124            req: req.clone(),
125        }
126    }
127}
128
129#[derive(Error, Debug)]
130struct Error {
131    errors: validator::ValidationErrors,
132}
133
134impl From<validator::ValidationErrors> for Error {
135    fn from(value: validator::ValidationErrors) -> Self {
136        Self { errors: value }
137    }
138}
139
140impl Display for Error {
141    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        write!(f, "{}", self.errors)
143    }
144}
145
146impl ResponseError for Error {
147    fn error_response(&self) -> HttpResponse {
148        HttpResponse::build(StatusCode::BAD_REQUEST).body(format!(
149            "Validation errors in fields:\n{}",
150            flatten_errors(&self.errors)
151                .iter()
152                .map(|(_, field, err)| { format!("\t{}: {}", field, err) })
153                .collect::<Vec<_>>()
154                .join("\n")
155        ))
156    }
157}
158
159/// Helper function for error extraction and formatting.
160/// Return Vec of tuples where first element is full field path (separated by dot)
161/// and second is error.
162#[inline]
163fn flatten_errors(errors: &ValidationErrors) -> Vec<(u16, String, &ValidationError)> {
164    _flatten_errors(errors, None, None)
165}
166
167#[inline]
168fn _flatten_errors(
169    errors: &ValidationErrors,
170    path: Option<String>,
171    indent: Option<u16>,
172) -> Vec<(u16, String, &ValidationError)> {
173    errors
174        .errors()
175        .iter()
176        .flat_map(|(field, err)| {
177            let indent = indent.unwrap_or(0);
178            let actual_path = path
179                .as_ref()
180                .map(|path| [path.as_str(), &field].join("."))
181                .unwrap_or_else(|| field.to_string());
182            match err {
183                ValidationErrorsKind::Field(field_errors) => field_errors
184                    .iter()
185                    .map(|error| (indent, actual_path.clone(), error))
186                    .collect::<Vec<_>>(),
187                ValidationErrorsKind::List(list_error) => list_error
188                    .iter()
189                    .flat_map(|(index, errors)| {
190                        let actual_path = format!("{}[{}]", actual_path.as_str(), index);
191                        _flatten_errors(errors, Some(actual_path), Some(indent + 1))
192                    })
193                    .collect::<Vec<_>>(),
194                ValidationErrorsKind::Struct(struct_errors) => {
195                    _flatten_errors(struct_errors, Some(actual_path), Some(indent + 1))
196                }
197            }
198        })
199        .collect::<Vec<_>>()
200}
201
/// A custom handler invoked when a [`Validated`] extractor fails validation.
///
/// It receives the [`validator::ValidationErrors`] and the request being extracted. The
/// [`actix_web::Error`] it returns becomes the extractor's error, replacing the default
/// `400 Bad Request` response. Register a handler with
/// [`ValidatorErrorHandlerExt::validator_error_handler`].
pub type ValidatorErrHandler =
    Arc<dyn Fn(validator::ValidationErrors, &HttpRequest) -> actix_web::Error + Send + Sync>;
204
/// The app-data entry that holds a registered [`ValidatorErrHandler`].
///
/// It is inserted by `ValidatorErrorHandlerExt::validator_error_handler`. When a `Validated`
/// extractor is created, it is looked up by type via `HttpRequest::app_data`.
struct ValidatorErrorHandler {
    handler: ValidatorErrHandler,
}
208
/// Extension trait that adds a convenience method for registering a custom
/// [`ValidatorErrHandler`].
///
/// Implemented for [`App`] and for `&mut` [`ServiceConfig`](actix_web::web::ServiceConfig).
pub trait ValidatorErrorHandlerExt {
    /// Registers `handler` to produce the error for [`Validated`] extractors whose validation
    /// fails, replacing the default `400 Bad Request` response.
    fn validator_error_handler(self, handler: ValidatorErrHandler) -> Self;
}
214
215impl<T> ValidatorErrorHandlerExt for App<T>
216where
217    T: ServiceFactory<ServiceRequest, Config = (), Error = actix_web::Error, InitError = ()>,
218{
219    fn validator_error_handler(self, handler: ValidatorErrHandler) -> Self {
220        self.app_data(ValidatorErrorHandler { handler })
221    }
222}
223
224impl ValidatorErrorHandlerExt for &mut actix_web::web::ServiceConfig {
225    fn validator_error_handler(self, handler: ValidatorErrHandler) -> Self {
226        self.app_data(ValidatorErrorHandler { handler })
227    }
228}
229
#[cfg(test)]
mod test {
    use super::*;
    use actix_web::web::Bytes;
    use actix_web::{post, test, web::Json, App, Responder};
    use serde::{Deserialize, Serialize};
    use validator::Validate;

    #[derive(Debug, Deserialize, Serialize, Validate)]
    struct ExamplePayload {
        #[validate(length(min = 5))]
        name: String,
    }

    #[post("/")]
    async fn endpoint(v: Validated<Json<ExamplePayload>>) -> impl Responder {
        assert!(v.name.len() > 4);
        HttpResponse::Ok().body(())
    }

    /// Builds a `POST /` request whose JSON body carries the given `name`.
    ///
    /// `set_json` also sets the `Content-Type: application/json` header that the `Json`
    /// extractor requires, so no extra content-type header is needed.
    fn post_name(name: &str) -> test::TestRequest {
        test::TestRequest::post().uri("/").set_json(ExamplePayload {
            name: name.to_string(),
        })
    }

    #[actix_web::test]
    async fn should_validate_simple() {
        let app = test::init_service(App::new().service(endpoint)).await;

        // Valid request
        let resp = test::call_service(&app, post_name("123456").to_request()).await;
        assert_eq!(resp.status().as_u16(), 200);

        // Invalid request: `name` is shorter than the 5-character minimum.
        let resp = test::call_service(&app, post_name("1234").to_request()).await;
        assert_eq!(resp.status().as_u16(), 400);
    }

    // TODO: This test is unstable because the rendered error appears to be non-deterministic.
    // Presumably the validator crate keeps each error's params (`min`, `value`) in a hash map,
    // so their display order varies between runs.
    #[ignore]
    #[actix_web::test]
    async fn should_respond_with_errors_correctly() {
        let app = test::init_service(App::new().service(endpoint)).await;

        // Invalid request
        let result = test::call_and_read_body(&app, post_name("1234").to_request()).await;
        assert_eq!(
            result,
            Bytes::from_static(b"Validation errors in fields:\n\tname: Validation error: length [{\"min\": Number(5), \"value\": String(\"1234\")}]")
        );
    }

    #[derive(Debug, Serialize, Error)]
    struct CustomErrorResponse {
        custom_message: String,
        errors: Vec<String>,
    }

    impl Display for CustomErrorResponse {
        /// A human-readable form. The HTTP body itself comes from `error_response`. This is
        /// implemented (not left panicking) so that formatting or logging the error is safe.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}: {}", self.custom_message, self.errors.join(", "))
        }
    }

    impl ResponseError for CustomErrorResponse {
        fn status_code(&self) -> actix_web::http::StatusCode {
            actix_web::http::StatusCode::BAD_REQUEST
        }

        fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
            HttpResponse::build(self.status_code()).body(serde_json::to_string(self).unwrap())
        }
    }

    /// A custom handler used by the tests: reports the names of the failing fields as JSON.
    fn error_handler(errors: ::validator::ValidationErrors, _: &HttpRequest) -> actix_web::Error {
        CustomErrorResponse {
            custom_message: "My custom message".to_string(),
            errors: errors
                .errors()
                .keys()
                .map(|field| field.to_string())
                .collect(),
        }
        .into()
    }

    #[actix_web::test]
    async fn should_allow_custom_error_responses() {
        let app = test::init_service(
            App::new()
                .service(endpoint)
                .validator_error_handler(Arc::new(error_handler)),
        )
        .await;

        let result = test::call_and_read_body(&app, post_name("1234").to_request()).await;
        assert_eq!(
            result,
            Bytes::from_static(b"{\"custom_message\":\"My custom message\",\"errors\":[\"name\"]}")
        );
    }

    // Spelled out as `actix_web::test` rather than a bare `#[test]`, which only worked here
    // because the glob-imported `actix_web::test` macro shadowed the builtin attribute.
    #[actix_web::test]
    async fn debug_for_validated_should_work() {
        let v = Validated(ExamplePayload {
            name: "abcde".to_string(),
        });

        assert_eq!(
            "Validated(ExamplePayload { name: \"abcde\" })",
            format!("{v:?}")
        );
    }
}