actix_web_validation/
custom.rs

//! Validation using your own hand-written logic, implemented through the [`Validate`] trait.
//! Requires the `custom` feature flag.
3//!
4//! For usage examples, see the documentation for [`Validated`]
5//!
6
7use crate::validated_definition;
8use actix_web::dev::{ServiceFactory, ServiceRequest};
9use actix_web::http::StatusCode;
10use actix_web::FromRequest;
11use actix_web::{App, HttpRequest, HttpResponse, ResponseError};
12use std::fmt::Display;
13use std::future::Future;
14use std::sync::Arc;
15use std::{fmt::Debug, ops::Deref, pin::Pin, task::Poll};
16use thiserror::Error;
17
18/// A trait that can be implemented to provide validation logic.
19pub trait Validate {
20    fn validate(&self) -> Result<(), Vec<ValidationError>>;
21}
22
/// A single validation failure, described by a human-readable message.
///
/// Returned (in a `Vec`) from [`Validate::validate`]. Its [`Display`] output is
/// exactly the message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    message: String,
}

impl ValidationError {
    /// Creates a validation error carrying `message`.
    ///
    /// The `message` field is private, so this constructor is how code outside
    /// this module (e.g. user implementations of [`Validate`]) reports a failure.
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
        }
    }

    /// Returns the message describing this failure.
    pub fn message(&self) -> &str {
        &self.message
    }
}

impl Display for ValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

// Lets the error be boxed as `dyn std::error::Error` or used with `?` in
// generic error-handling code. The path is fully qualified because this module
// defines its own `Error` type.
impl std::error::Error for ValidationError {}
34
/// A validated extractor.
///
/// Wraps another extractor `T` and, once `T` has been extracted, runs
/// [`Validate::validate`] on the value `T` dereferences to (e.g. the `Info`
/// inside `Json<Info>`). If validation fails, the request is rejected. By default
/// the response is a `400 Bad Request` listing the error messages. If a handler is
/// registered through [`ValidationErrorHandlerExt`], the error it builds is used instead.
///
/// ```
/// use actix_web::{post, web::{self, Json}, App};
/// use serde::Deserialize;
/// use actix_web_validation::custom::{Validated, Validate, ValidationError};
///
/// #[derive(Debug, Deserialize)]
/// struct Info {
///     username: String,
/// }
///
/// impl Validate for Info {
///     fn validate(&self) -> Result<(), Vec<ValidationError>> {
///         // Do validation logic here...
///         Ok(())
///     }
/// }
///
/// #[post("/")]
/// async fn index(info: Validated<Json<Info>>) -> String {
///     format!("Welcome {}!", info.username)
/// }
/// ```
pub struct Validated<T>(pub T);
62
// Expands the crate-wide `validated_definition!` macro (defined elsewhere in the
// crate) for this module's `Validated<T>`. NOTE(review): the tests below format
// `Validated` with `{:?}` and access `v.name` through it, so this macro presumably
// provides its `Debug` and `Deref` impls. Confirm against the macro definition.
validated_definition!();
64
65pub struct ValidatedFut<T: FromRequest> {
66    req: actix_web::HttpRequest,
67    fut: <T as FromRequest>::Future,
68    error_handler: Option<ValidationErrHandler>,
69}
70impl<T> Future for ValidatedFut<T>
71where
72    T: FromRequest + Debug + Deref,
73    T::Future: Unpin,
74    T::Target: Validate,
75{
76    type Output = Result<Validated<T>, actix_web::Error>;
77
78    fn poll(
79        self: std::pin::Pin<&mut Self>,
80        cx: &mut std::task::Context<'_>,
81    ) -> std::task::Poll<Self::Output> {
82        let this = self.get_mut();
83        let Poll::Ready(res) = Pin::new(&mut this.fut).poll(cx) else {
84            return std::task::Poll::Pending;
85        };
86
87        let res = match res {
88            Ok(data) => {
89                if let Err(e) = data.validate() {
90                    if let Some(error_handler) = &this.error_handler {
91                        Err((*error_handler)(e, &this.req))
92                    } else {
93                        let err: Error = e.into();
94                        Err(err.into())
95                    }
96                } else {
97                    Ok(Validated(data))
98                }
99            }
100            Err(e) => Err(e.into()),
101        };
102
103        Poll::Ready(res)
104    }
105}
106
107impl<T> FromRequest for Validated<T>
108where
109    T: FromRequest + Debug + Deref,
110    T::Future: Unpin,
111    T::Target: Validate,
112{
113    type Error = actix_web::Error;
114
115    type Future = ValidatedFut<T>;
116
117    fn from_request(
118        req: &actix_web::HttpRequest,
119        payload: &mut actix_web::dev::Payload,
120    ) -> Self::Future {
121        let error_handler = req
122            .app_data::<ValidationErrorHandler>()
123            .map(|h| h.handler.clone());
124
125        let fut = T::from_request(req, payload);
126
127        ValidatedFut {
128            fut,
129            error_handler,
130            req: req.clone(),
131        }
132    }
133}
134
135#[derive(Error, Debug)]
136struct Error {
137    errors: Vec<ValidationError>,
138}
139
140impl From<Vec<ValidationError>> for Error {
141    fn from(value: Vec<ValidationError>) -> Self {
142        Self { errors: value }
143    }
144}
145
146impl Display for Error {
147    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148        write!(
149            f,
150            "{}",
151            self.errors
152                .iter()
153                .map(|e| e.message.as_ref())
154                .collect::<Vec<_>>()
155                .join("\n")
156        )
157    }
158}
159
160impl ResponseError for Error {
161    fn error_response(&self) -> HttpResponse {
162        HttpResponse::build(StatusCode::BAD_REQUEST).body(format!(
163            "Validation errors in fields:\n{}",
164            &self
165                .errors
166                .iter()
167                .map(|err| { format!("\t{}", err) })
168                .collect::<Vec<_>>()
169                .join("\n")
170        ))
171    }
172}
173
174pub type ValidationErrHandler =
175    Arc<dyn Fn(Vec<ValidationError>, &HttpRequest) -> actix_web::Error + Send + Sync>;
176
177struct ValidationErrorHandler {
178    handler: ValidationErrHandler,
179}
180
181/// Extension trait to provide a convenience method for adding custom error handler
182pub trait ValidationErrorHandlerExt {
183    /// Add a custom error handler for garde validated requests
184    fn validation_error_handler(self, handler: ValidationErrHandler) -> Self;
185}
186
187impl<T> ValidationErrorHandlerExt for App<T>
188where
189    T: ServiceFactory<ServiceRequest, Config = (), Error = actix_web::Error, InitError = ()>,
190{
191    fn validation_error_handler(self, handler: ValidationErrHandler) -> Self {
192        self.app_data(ValidationErrorHandler { handler })
193    }
194}
195
196impl ValidationErrorHandlerExt for &mut actix_web::web::ServiceConfig {
197    fn validation_error_handler(self, handler: ValidationErrHandler) -> Self {
198        self.app_data(ValidationErrorHandler { handler })
199    }
200}
201
#[cfg(test)]
mod test {
    use super::*;
    use actix_web::web::Bytes;
    use actix_web::{post, test, web::Json, App, Responder};
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Deserialize, Serialize)]
    struct ExamplePayload {
        name: String,
    }

    // Accepts only names longer than four bytes.
    impl Validate for ExamplePayload {
        fn validate(&self) -> Result<(), Vec<ValidationError>> {
            if self.name.len() > 4 {
                Ok(())
            } else {
                Err(vec![ValidationError {
                    message: "name not long enough".to_string(),
                }])
            }
        }
    }

    #[post("/")]
    async fn endpoint(v: Validated<Json<ExamplePayload>>) -> impl Responder {
        assert!(v.name.len() > 4);
        HttpResponse::Ok().body(())
    }

    /// Builds a `POST /` request whose JSON body carries `name`.
    ///
    /// `set_json` also sets `Content-Type: application/json`.
    fn post_name(name: &str) -> test::TestRequest {
        test::TestRequest::post().uri("/").set_json(ExamplePayload {
            name: name.to_string(),
        })
    }

    #[actix_web::test]
    async fn should_validate_simple() {
        let app = test::init_service(App::new().service(endpoint)).await;

        // Valid request
        let resp = test::call_service(&app, post_name("123456").to_request()).await;
        assert_eq!(resp.status().as_u16(), 200);

        // Invalid request
        let resp = test::call_service(&app, post_name("1234").to_request()).await;
        assert_eq!(resp.status().as_u16(), 400);
    }

    #[actix_web::test]
    async fn should_respond_with_errors_correctly() {
        let app = test::init_service(App::new().service(endpoint)).await;

        // Invalid request
        let result = test::call_and_read_body(&app, post_name("1234").to_request()).await;
        assert_eq!(
            result,
            Bytes::from_static(b"Validation errors in fields:\n\tname not long enough")
        );
    }

    #[derive(Debug, Serialize, Error)]
    struct CustomErrorResponse {
        custom_message: String,
        errors: Vec<String>,
    }

    impl Display for CustomErrorResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}: {}", self.custom_message, self.errors.join(", "))
        }
    }

    impl ResponseError for CustomErrorResponse {
        fn status_code(&self) -> actix_web::http::StatusCode {
            actix_web::http::StatusCode::BAD_REQUEST
        }

        fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
            HttpResponse::build(self.status_code()).body(serde_json::to_string(self).unwrap())
        }
    }

    fn error_handler(errors: Vec<ValidationError>, _: &HttpRequest) -> actix_web::Error {
        CustomErrorResponse {
            custom_message: "My custom message".to_string(),
            errors: errors.iter().map(|err| err.message.clone()).collect(),
        }
        .into()
    }

    const CUSTOM_BODY: &[u8] =
        b"{\"custom_message\":\"My custom message\",\"errors\":[\"name not long enough\"]}";

    #[actix_web::test]
    async fn should_use_allow_custom_error_responses() {
        let app = test::init_service(
            App::new()
                .service(endpoint)
                .validation_error_handler(Arc::new(error_handler)),
        )
        .await;

        let result = test::call_and_read_body(&app, post_name("1234").to_request()).await;
        assert_eq!(result, Bytes::from_static(CUSTOM_BODY));
    }

    #[actix_web::test]
    async fn should_use_custom_error_handler_registered_on_service_config() {
        let app = test::init_service(
            App::new()
                .configure(|cfg| {
                    cfg.validation_error_handler(Arc::new(error_handler));
                })
                .service(endpoint),
        )
        .await;

        let result = test::call_and_read_body(&app, post_name("1234").to_request()).await;
        assert_eq!(result, Bytes::from_static(CUSTOM_BODY));
    }

    // Spelled out as `actix_web::test` rather than relying on the `test` import
    // shadowing the built-in `#[test]` attribute (which rejects `async fn`).
    #[actix_web::test]
    async fn debug_for_validated_should_work() {
        let v = Validated(ExamplePayload {
            name: "abcde".to_string(),
        });

        assert_eq!(
            "Validated(ExamplePayload { name: \"abcde\" })",
            format!("{v:?}")
        );
    }
}