actix_web_validation/
garde.rs

//! Validation for the [garde](https://docs.rs/garde/latest/garde) crate.
//! Requires the `garde` feature flag.
//!
//! Garde is a popular validation library for Rust.
//!
//! You will need to add the garde crate to your `Cargo.toml`:
//!
//! ```toml
//! [dependencies]
//! garde = { version = "0.0.0", features = ["derive"] }
//! actix-web-validation = { version = "0.0.0", features = ["garde"] }
//! ```
//!
//! For usage examples, see the documentation for [`Validated`].
//!
16
17use crate::validated_definition;
18use ::garde::Validate;
19use actix_web::dev::{ServiceFactory, ServiceRequest};
20use actix_web::http::StatusCode;
21use actix_web::FromRequest;
22use actix_web::{App, HttpRequest, HttpResponse, ResponseError};
23use std::fmt::Display;
24use std::future::Future;
25use std::sync::Arc;
26use std::{fmt::Debug, ops::Deref, pin::Pin, task::Poll};
27use thiserror::Error;
28
/// A validated extractor.
///
/// Wraps another extractor `T` (for example [`Json`](actix_web::web::Json)).
/// Once `T` has been extracted, this runs the `garde` validations defined on the
/// extracted value (`T::Target`). The handler only receives the value if
/// validation succeeds.
///
/// If `T` itself fails to extract, its error is returned unchanged. If
/// validation fails, the request is rejected with `400 Bad Request` and a
/// textual list of the failing fields. A custom handler registered with
/// [`GardeErrorHandlerExt::garde_error_handler`] can replace that response.
///
/// ```
/// use actix_web::{post, web::Json};
/// use serde::Deserialize;
/// use garde::Validate;
/// use actix_web_validation::garde::Validated;
///
/// #[derive(Debug, Deserialize, Validate)]
/// struct Info {
///     #[garde(length(min = 3))]
///     username: String,
/// }
///
/// #[post("/")]
/// async fn index(info: Validated<Json<Info>>) -> String {
///     format!("Welcome {}!", info.username)
/// }
/// ```
pub struct Validated<T>(pub T);
51
// Invokes the crate-shared `validated_definition!` macro (imported from the
// crate root) for this module's `Validated` wrapper.
// NOTE(review): presumably this supplies the trait impls that `Validated` does
// not derive itself. Examples are `Debug` (exercised by
// `debug_for_validated_should_work`) and access to the inner value through the
// wrapper (`v.name` in the tests). Confirm against `crate::validated_definition`.
validated_definition!();
53
54/// Future that extracts and validates actix requests using the Actix Web [`FromRequest`] trait
55///
56/// End users of this library should not need to use this directly for most usecases
57pub struct ValidatedFut<T: FromRequest> {
58    req: actix_web::HttpRequest,
59    fut: <T as FromRequest>::Future,
60    error_handler: Option<GardeErrHandler>,
61}
62
63impl<T> Future for ValidatedFut<T>
64where
65    T: FromRequest + Debug + Deref,
66    T::Future: Unpin,
67    T::Target: Validate,
68    <T::Target as garde::Validate>::Context: Default,
69{
70    type Output = Result<Validated<T>, actix_web::Error>;
71
72    fn poll(
73        self: std::pin::Pin<&mut Self>,
74        cx: &mut std::task::Context<'_>,
75    ) -> std::task::Poll<Self::Output> {
76        let this = self.get_mut();
77        let Poll::Ready(res) = Pin::new(&mut this.fut).poll(cx) else {
78            return std::task::Poll::Pending;
79        };
80
81        let res = match res {
82            Ok(data) => {
83                if let Err(e) = data.validate() {
84                    if let Some(error_handler) = &this.error_handler {
85                        Err((*error_handler)(e, &this.req))
86                    } else {
87                        let err: Error = e.into();
88                        Err(err.into())
89                    }
90                } else {
91                    Ok(Validated(data))
92                }
93            }
94            Err(e) => Err(e.into()),
95        };
96
97        Poll::Ready(res)
98    }
99}
100
101impl<T> FromRequest for Validated<T>
102where
103    T: FromRequest + Debug + Deref,
104    T::Future: Unpin,
105    T::Target: Validate,
106    <T::Target as garde::Validate>::Context: Default,
107{
108    type Error = actix_web::Error;
109
110    type Future = ValidatedFut<T>;
111
112    fn from_request(
113        req: &actix_web::HttpRequest,
114        payload: &mut actix_web::dev::Payload,
115    ) -> Self::Future {
116        let error_handler = req
117            .app_data::<GardeErrorHandler>()
118            .map(|h| h.handler.clone());
119
120        let fut = T::from_request(req, payload);
121
122        ValidatedFut {
123            fut,
124            error_handler,
125            req: req.clone(),
126        }
127    }
128}
129
130#[derive(Error, Debug)]
131struct Error {
132    report: garde::Report,
133}
134
135impl From<garde::Report> for Error {
136    fn from(value: garde::Report) -> Self {
137        Self { report: value }
138    }
139}
140
141impl Display for Error {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        write!(f, "{}", self.report)
144    }
145}
146
147impl ResponseError for Error {
148    fn error_response(&self) -> HttpResponse {
149        let message = self
150            .report
151            .iter()
152            .map(|(path, error)| format!("{path}: {}", error.message()))
153            .collect::<Vec<_>>()
154            .join("\n");
155
156        HttpResponse::build(StatusCode::BAD_REQUEST)
157            .body(format!("Validation errors in fields:\n{}", message))
158    }
159}
160
161pub type GardeErrHandler =
162    Arc<dyn Fn(garde::Report, &HttpRequest) -> actix_web::Error + Send + Sync>;
163
/// `app_data` wrapper around the registered [`GardeErrHandler`].
///
/// The [`GardeErrorHandlerExt`] impls insert it, and `Validated`'s
/// `from_request` reads it back with `req.app_data::<GardeErrorHandler>()`.
struct GardeErrorHandler {
    handler: GardeErrHandler,
}
167
/// Extension trait providing a convenience method for registering a custom
/// error handler for `garde` validation failures.
///
/// Implemented in this module for [`App`] and for `&mut ServiceConfig`.
pub trait GardeErrorHandlerExt {
    /// Add a custom error handler for garde validated requests.
    ///
    /// The handler is stored as app data. When a [`Validated`] extractor's value
    /// fails validation, the handler receives the `garde` report and the request.
    /// The error it returns replaces the default `400 Bad Request` response.
    fn garde_error_handler(self, handler: GardeErrHandler) -> Self;
}
173
174impl<T> GardeErrorHandlerExt for App<T>
175where
176    T: ServiceFactory<ServiceRequest, Config = (), Error = actix_web::Error, InitError = ()>,
177{
178    fn garde_error_handler(self, handler: GardeErrHandler) -> Self {
179        self.app_data(GardeErrorHandler { handler })
180    }
181}
182
183impl GardeErrorHandlerExt for &mut actix_web::web::ServiceConfig {
184    fn garde_error_handler(self, handler: GardeErrHandler) -> Self {
185        self.app_data(GardeErrorHandler { handler })
186    }
187}
188
#[cfg(test)]
mod test {
    use super::*;
    use actix_web::web::Bytes;
    use actix_web::{post, test, web::Json, App, Responder};
    use garde::Validate;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Deserialize, Serialize, Validate)]
    struct ExamplePayload {
        #[garde(length(min = 5))]
        name: String,
    }

    #[post("/")]
    async fn endpoint(v: Validated<Json<ExamplePayload>>) -> impl Responder {
        // The handler must only ever see payloads that passed validation.
        assert!(v.name.len() > 4);
        HttpResponse::Ok().body(())
    }

    /// Builds a `POST /` request whose JSON body carries `name`.
    ///
    /// `set_json` sets the `application/json` content type itself, so no
    /// extra header is needed.
    fn json_request(name: &str) -> test::TestRequest {
        test::TestRequest::post().uri("/").set_json(ExamplePayload {
            name: name.to_string(),
        })
    }

    #[actix_web::test]
    async fn should_validate_simple() {
        let app = test::init_service(App::new().service(endpoint)).await;

        // Valid request: name meets the minimum length of 5.
        let resp = test::call_service(&app, json_request("123456").to_request()).await;
        assert_eq!(resp.status().as_u16(), 200);

        // Invalid request: name is one character too short.
        let resp = test::call_service(&app, json_request("1234").to_request()).await;
        assert_eq!(resp.status().as_u16(), 400);
    }

    #[actix_web::test]
    async fn should_respond_with_errors_correctly() {
        let app = test::init_service(App::new().service(endpoint)).await;

        let result = test::call_and_read_body(&app, json_request("1234").to_request()).await;
        assert_eq!(
            result,
            Bytes::from_static(b"Validation errors in fields:\nname: length is lower than 5")
        );
    }

    #[derive(Debug, Serialize, Error)]
    struct CustomErrorResponse {
        custom_message: String,
        errors: Vec<String>,
    }

    impl Display for CustomErrorResponse {
        // `ResponseError` requires `Display`; formatting the error (e.g. when
        // it is logged) must not panic.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}: {}", self.custom_message, self.errors.join(", "))
        }
    }

    impl ResponseError for CustomErrorResponse {
        fn status_code(&self) -> actix_web::http::StatusCode {
            actix_web::http::StatusCode::BAD_REQUEST
        }

        fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
            HttpResponse::build(self.status_code()).body(serde_json::to_string(self).unwrap())
        }
    }

    fn error_handler(errors: ::garde::Report, _: &HttpRequest) -> actix_web::Error {
        CustomErrorResponse {
            custom_message: "My custom message".to_string(),
            errors: errors.iter().map(|(_, err)| err.to_string()).collect(),
        }
        .into()
    }

    #[actix_web::test]
    async fn should_allow_custom_error_responses() {
        let app = test::init_service(
            App::new()
                .service(endpoint)
                .garde_error_handler(Arc::new(error_handler)),
        )
        .await;

        let result = test::call_and_read_body(&app, json_request("1234").to_request()).await;
        assert_eq!(
            result,
            Bytes::from_static(
                b"{\"custom_message\":\"My custom message\",\"errors\":[\"length is lower than 5\"]}"
            )
        );
    }

    // Explicitly use actix's test attribute. The imported `test` name would
    // otherwise silently shadow the builtin `#[test]`, which rejects `async fn`.
    #[actix_web::test]
    async fn debug_for_validated_should_work() {
        let v = Validated(ExamplePayload {
            name: "abcde".to_string(),
        });

        assert_eq!(
            "Validated(ExamplePayload { name: \"abcde\" })",
            format!("{v:?}")
        );
    }
}