actix_web_validation/
garde.rs1use crate::validated_definition;
18use ::garde::Validate;
19use actix_web::dev::{ServiceFactory, ServiceRequest};
20use actix_web::http::StatusCode;
21use actix_web::FromRequest;
22use actix_web::{App, HttpRequest, HttpResponse, ResponseError};
23use std::fmt::Display;
24use std::future::Future;
25use std::sync::Arc;
26use std::{fmt::Debug, ops::Deref, pin::Pin, task::Poll};
27use thiserror::Error;
28
/// Extractor wrapper that runs [`garde`] validation on the value produced by the inner
/// extractor `T` (for example `Validated<Json<Payload>>`).
///
/// Extraction itself is delegated to `T`; the dereferenced value (`T::Target`) is then
/// checked with `garde::Validate::validate`. When validation fails the request is rejected
/// with the error built by a handler registered through
/// [`GardeErrorHandlerExt::garde_error_handler`], or — if none is registered — with a
/// `400 Bad Request` whose plain-text body lists each failing field and its message.
pub struct Validated<T>(pub T);

// NOTE(review): `validated_definition!` is defined elsewhere in the crate. The tests in this
// file access fields through `Validated` (deref) and format it with `{:?}`, so presumably the
// macro generates `Deref`/`Debug` (and possibly more) for `Validated` — confirm at its definition.
validated_definition!();
53
54pub struct ValidatedFut<T: FromRequest> {
58 req: actix_web::HttpRequest,
59 fut: <T as FromRequest>::Future,
60 error_handler: Option<GardeErrHandler>,
61}
62
63impl<T> Future for ValidatedFut<T>
64where
65 T: FromRequest + Debug + Deref,
66 T::Future: Unpin,
67 T::Target: Validate,
68 <T::Target as garde::Validate>::Context: Default,
69{
70 type Output = Result<Validated<T>, actix_web::Error>;
71
72 fn poll(
73 self: std::pin::Pin<&mut Self>,
74 cx: &mut std::task::Context<'_>,
75 ) -> std::task::Poll<Self::Output> {
76 let this = self.get_mut();
77 let Poll::Ready(res) = Pin::new(&mut this.fut).poll(cx) else {
78 return std::task::Poll::Pending;
79 };
80
81 let res = match res {
82 Ok(data) => {
83 if let Err(e) = data.validate() {
84 if let Some(error_handler) = &this.error_handler {
85 Err((*error_handler)(e, &this.req))
86 } else {
87 let err: Error = e.into();
88 Err(err.into())
89 }
90 } else {
91 Ok(Validated(data))
92 }
93 }
94 Err(e) => Err(e.into()),
95 };
96
97 Poll::Ready(res)
98 }
99}
100
101impl<T> FromRequest for Validated<T>
102where
103 T: FromRequest + Debug + Deref,
104 T::Future: Unpin,
105 T::Target: Validate,
106 <T::Target as garde::Validate>::Context: Default,
107{
108 type Error = actix_web::Error;
109
110 type Future = ValidatedFut<T>;
111
112 fn from_request(
113 req: &actix_web::HttpRequest,
114 payload: &mut actix_web::dev::Payload,
115 ) -> Self::Future {
116 let error_handler = req
117 .app_data::<GardeErrorHandler>()
118 .map(|h| h.handler.clone());
119
120 let fut = T::from_request(req, payload);
121
122 ValidatedFut {
123 fut,
124 error_handler,
125 req: req.clone(),
126 }
127 }
128}
129
/// Default rejection used when no custom [`GardeErrHandler`] is registered; it carries the
/// garde report describing every failing field.
// `Display` for this type is implemented by hand in this module, so no `#[error(...)]`
// attribute is given to the derive.
#[derive(Error, Debug)]
struct Error {
    /// Validation failures reported by garde.
    report: garde::Report,
}
134
135impl From<garde::Report> for Error {
136 fn from(value: garde::Report) -> Self {
137 Self { report: value }
138 }
139}
140
141impl Display for Error {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 write!(f, "{}", self.report)
144 }
145}
146
147impl ResponseError for Error {
148 fn error_response(&self) -> HttpResponse {
149 let message = self
150 .report
151 .iter()
152 .map(|(path, error)| format!("{path}: {}", error.message()))
153 .collect::<Vec<_>>()
154 .join("\n");
155
156 HttpResponse::build(StatusCode::BAD_REQUEST)
157 .body(format!("Validation errors in fields:\n{}", message))
158 }
159}
160
/// Custom handler for validation failures.
///
/// It receives the garde [`garde::Report`] and the request being extracted, and returns the
/// error actix-web should respond with in place of the default `400 Bad Request` response.
/// Register one with [`GardeErrorHandlerExt::garde_error_handler`].
pub type GardeErrHandler =
    Arc<dyn Fn(garde::Report, &HttpRequest) -> actix_web::Error + Send + Sync>;
163
/// App-data wrapper under which a [`GardeErrHandler`] is stored, so that `Validated`
/// can find it with `req.app_data::<GardeErrorHandler>()`.
struct GardeErrorHandler {
    handler: GardeErrHandler,
}
167
/// Extension trait for registering a [`GardeErrHandler`] on an [`App`] or on a
/// `&mut actix_web::web::ServiceConfig`.
pub trait GardeErrorHandlerExt {
    /// Stores `handler` as app data; [`Validated`] extractors then call it to build the
    /// error response instead of using the default `400 Bad Request` response.
    fn garde_error_handler(self, handler: GardeErrHandler) -> Self;
}
173
174impl<T> GardeErrorHandlerExt for App<T>
175where
176 T: ServiceFactory<ServiceRequest, Config = (), Error = actix_web::Error, InitError = ()>,
177{
178 fn garde_error_handler(self, handler: GardeErrHandler) -> Self {
179 self.app_data(GardeErrorHandler { handler })
180 }
181}
182
183impl GardeErrorHandlerExt for &mut actix_web::web::ServiceConfig {
184 fn garde_error_handler(self, handler: GardeErrHandler) -> Self {
185 self.app_data(GardeErrorHandler { handler })
186 }
187}
188
#[cfg(test)]
mod test {
    use super::*;
    use actix_web::web::Bytes;
    use actix_web::{post, test, web::Json, App, Responder};
    use garde::Validate;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Deserialize, Serialize, Validate)]
    struct ExamplePayload {
        #[garde(length(min = 5))]
        name: String,
    }

    #[post("/")]
    async fn endpoint(v: Validated<Json<ExamplePayload>>) -> impl Responder {
        // Only reached when validation succeeded.
        assert!(v.name.len() > 4);
        HttpResponse::Ok().body(())
    }

    /// Builds a `POST /` request whose JSON body is `ExamplePayload { name }`.
    ///
    /// `set_json` also sets `Content-Type: application/json`, which the `Json` extractor
    /// requires.
    fn post_payload(name: &str) -> test::TestRequest {
        test::TestRequest::post().uri("/").set_json(ExamplePayload {
            name: name.to_string(),
        })
    }

    #[actix_web::test]
    async fn should_validate_simple() {
        let app = test::init_service(App::new().service(endpoint)).await;

        let resp = test::call_service(&app, post_payload("123456").to_request()).await;
        assert_eq!(resp.status().as_u16(), 200);

        let resp = test::call_service(&app, post_payload("1234").to_request()).await;
        assert_eq!(resp.status().as_u16(), 400);
    }

    #[actix_web::test]
    async fn should_respond_with_errors_correctly() {
        let app = test::init_service(App::new().service(endpoint)).await;

        let result = test::call_and_read_body(&app, post_payload("1234").to_request()).await;
        assert_eq!(
            result,
            Bytes::from_static(b"Validation errors in fields:\nname: length is lower than 5")
        );
    }

    #[derive(Debug, Serialize, Error)]
    struct CustomErrorResponse {
        custom_message: String,
        errors: Vec<String>,
    }

    impl Display for CustomErrorResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // A real implementation (not `unimplemented!()`), so anything that formats the
            // error — logging, for example — cannot panic the test.
            f.write_str(&self.custom_message)
        }
    }

    impl ResponseError for CustomErrorResponse {
        fn status_code(&self) -> actix_web::http::StatusCode {
            actix_web::http::StatusCode::BAD_REQUEST
        }

        fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
            HttpResponse::build(self.status_code()).body(serde_json::to_string(self).unwrap())
        }
    }

    fn error_handler(errors: ::garde::Report, _: &HttpRequest) -> actix_web::Error {
        CustomErrorResponse {
            custom_message: "My custom message".to_string(),
            errors: errors.iter().map(|(_, err)| err.to_string()).collect(),
        }
        .into()
    }

    const CUSTOM_BODY: &[u8] =
        b"{\"custom_message\":\"My custom message\",\"errors\":[\"length is lower than 5\"]}";

    #[actix_web::test]
    async fn should_allow_custom_error_responses() {
        let app = test::init_service(
            App::new()
                .service(endpoint)
                .garde_error_handler(Arc::new(error_handler)),
        )
        .await;

        let result = test::call_and_read_body(&app, post_payload("1234").to_request()).await;
        assert_eq!(result, Bytes::from_static(CUSTOM_BODY));
    }

    #[actix_web::test]
    async fn should_allow_custom_error_responses_via_service_config() {
        let app = test::init_service(App::new().configure(|cfg| {
            cfg.garde_error_handler(Arc::new(error_handler))
                .service(endpoint);
        }))
        .await;

        let result = test::call_and_read_body(&app, post_payload("1234").to_request()).await;
        assert_eq!(result, Bytes::from_static(CUSTOM_BODY));
    }

    // Explicitly actix's test macro: the test is `async`, and the bare `#[test]` here only
    // worked because `use actix_web::test` shadowed the builtin attribute.
    #[actix_web::test]
    async fn debug_for_validated_should_work() {
        let v = Validated(ExamplePayload {
            name: "abcde".to_string(),
        });

        assert_eq!(
            "Validated(ExamplePayload { name: \"abcde\" })",
            format!("{v:?}")
        );
    }
}
319}