1use crate::validated_definition;
18use ::validator::Validate;
19use actix_web::dev::{ServiceFactory, ServiceRequest};
20use actix_web::http::StatusCode;
21use actix_web::FromRequest;
22use actix_web::{App, HttpRequest, HttpResponse, ResponseError};
23use std::fmt::Display;
24use std::future::Future;
25use std::sync::Arc;
26use std::{fmt::Debug, ops::Deref, pin::Pin, task::Poll};
27use thiserror::Error;
28use validator::{ValidationError, ValidationErrors, ValidationErrorsKind};
29
/// Extractor wrapper that validates the value produced by the inner extractor `T`.
///
/// Extraction of `Validated<T>` first runs `T`'s own extraction, then calls
/// `Validate::validate` on the dereferenced value. A value that fails validation is
/// turned into an error (custom handler if one is registered, otherwise a 400
/// response) instead of reaching the route handler. The wrapped value is public and
/// can be taken out with `.0`.
pub struct Validated<T>(pub T);
52
// Expands the crate-level `validated_definition!` macro (defined elsewhere; its
// output is not visible in this file).
// NOTE(review): the tests below dereference `Validated<Json<_>>` to reach payload
// fields and format it with `{:?}`, so this expansion presumably provides `Deref` and
// `Debug` for `Validated<T>` — confirm against the macro definition.
validated_definition!();
54
55pub struct ValidatedFut<T: FromRequest> {
59 req: actix_web::HttpRequest,
60 fut: <T as FromRequest>::Future,
61 error_handler: Option<ValidatorErrHandler>,
62}
63impl<T> Future for ValidatedFut<T>
64where
65 T: FromRequest + Debug + Deref,
66 T::Future: Unpin,
67 T::Target: Validate,
68{
69 type Output = Result<Validated<T>, actix_web::Error>;
70
71 fn poll(
72 self: std::pin::Pin<&mut Self>,
73 cx: &mut std::task::Context<'_>,
74 ) -> std::task::Poll<Self::Output> {
75 let this = self.get_mut();
76
77 let Poll::Ready(res) = Pin::new(&mut this.fut).poll(cx) else {
78 return std::task::Poll::Pending;
79 };
80
81 let res = match res {
82 Ok(data) => {
83 if let Err(e) = data.validate() {
84 if let Some(error_handler) = &this.error_handler {
85 Err((*error_handler)(e, &this.req))
86 } else {
87 let err: Error = e.into();
88 Err(err.into())
89 }
90 } else {
91 Ok(Validated(data))
92 }
93 }
94 Err(e) => Err(e.into()),
95 };
96
97 Poll::Ready(res)
98 }
99}
100
101impl<T> FromRequest for Validated<T>
102where
103 T: FromRequest + Debug + Deref,
104 T::Future: Unpin,
105 T::Target: Validate,
106{
107 type Error = actix_web::Error;
108
109 type Future = ValidatedFut<T>;
110
111 fn from_request(
112 req: &actix_web::HttpRequest,
113 payload: &mut actix_web::dev::Payload,
114 ) -> Self::Future {
115 let error_handler = req
116 .app_data::<ValidatorErrorHandler>()
117 .map(|h| h.handler.clone());
118
119 let fut = T::from_request(req, payload);
120
121 ValidatedFut {
122 fut,
123 error_handler,
124 req: req.clone(),
125 }
126 }
127}
128
129#[derive(Error, Debug)]
130struct Error {
131 errors: validator::ValidationErrors,
132}
133
134impl From<validator::ValidationErrors> for Error {
135 fn from(value: validator::ValidationErrors) -> Self {
136 Self { errors: value }
137 }
138}
139
140impl Display for Error {
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 write!(f, "{}", self.errors)
143 }
144}
145
146impl ResponseError for Error {
147 fn error_response(&self) -> HttpResponse {
148 HttpResponse::build(StatusCode::BAD_REQUEST).body(format!(
149 "Validation errors in fields:\n{}",
150 flatten_errors(&self.errors)
151 .iter()
152 .map(|(_, field, err)| { format!("\t{}: {}", field, err) })
153 .collect::<Vec<_>>()
154 .join("\n")
155 ))
156 }
157}
158
159#[inline]
163fn flatten_errors(errors: &ValidationErrors) -> Vec<(u16, String, &ValidationError)> {
164 _flatten_errors(errors, None, None)
165}
166
167#[inline]
168fn _flatten_errors(
169 errors: &ValidationErrors,
170 path: Option<String>,
171 indent: Option<u16>,
172) -> Vec<(u16, String, &ValidationError)> {
173 errors
174 .errors()
175 .iter()
176 .flat_map(|(field, err)| {
177 let indent = indent.unwrap_or(0);
178 let actual_path = path
179 .as_ref()
180 .map(|path| [path.as_str(), &field].join("."))
181 .unwrap_or_else(|| field.to_string());
182 match err {
183 ValidationErrorsKind::Field(field_errors) => field_errors
184 .iter()
185 .map(|error| (indent, actual_path.clone(), error))
186 .collect::<Vec<_>>(),
187 ValidationErrorsKind::List(list_error) => list_error
188 .iter()
189 .flat_map(|(index, errors)| {
190 let actual_path = format!("{}[{}]", actual_path.as_str(), index);
191 _flatten_errors(errors, Some(actual_path), Some(indent + 1))
192 })
193 .collect::<Vec<_>>(),
194 ValidationErrorsKind::Struct(struct_errors) => {
195 _flatten_errors(struct_errors, Some(actual_path), Some(indent + 1))
196 }
197 }
198 })
199 .collect::<Vec<_>>()
200}
201
202pub type ValidatorErrHandler =
203 Arc<dyn Fn(validator::ValidationErrors, &HttpRequest) -> actix_web::Error + Send + Sync>;
204
/// App-data wrapper around the registered `ValidatorErrHandler`.
///
/// Inserted via `app_data` by the `ValidatorErrorHandlerExt` impls and looked up with
/// `req.app_data::<ValidatorErrorHandler>()` when a `Validated` extraction starts.
struct ValidatorErrorHandler {
    /// The user-supplied handler invoked on validation failure.
    handler: ValidatorErrHandler,
}
208
/// Extension trait for registering a custom `ValidatorErrHandler` as app data.
pub trait ValidatorErrorHandlerExt {
    /// Registers `handler` so that `Validated` extractors which find it in app data
    /// call it on validation failure, instead of producing the default 400 response.
    fn validator_error_handler(self, handler: ValidatorErrHandler) -> Self;
}
214
215impl<T> ValidatorErrorHandlerExt for App<T>
216where
217 T: ServiceFactory<ServiceRequest, Config = (), Error = actix_web::Error, InitError = ()>,
218{
219 fn validator_error_handler(self, handler: ValidatorErrHandler) -> Self {
220 self.app_data(ValidatorErrorHandler { handler })
221 }
222}
223
224impl ValidatorErrorHandlerExt for &mut actix_web::web::ServiceConfig {
225 fn validator_error_handler(self, handler: ValidatorErrHandler) -> Self {
226 self.app_data(ValidatorErrorHandler { handler })
227 }
228}
229
#[cfg(test)]
mod test {
    use super::*;
    use actix_web::web::Bytes;
    use actix_web::{post, test, web::Json, App, Responder};
    use serde::{Deserialize, Serialize};
    use validator::Validate;

    #[derive(Debug, Deserialize, Serialize, Validate)]
    struct ExamplePayload {
        #[validate(length(min = 5))]
        name: String,
    }

    #[post("/")]
    async fn endpoint(v: Validated<Json<ExamplePayload>>) -> impl Responder {
        assert!(v.name.len() > 4);
        HttpResponse::Ok().body(())
    }

    /// Builds a `POST /` test request whose JSON body is `ExamplePayload { name }`.
    ///
    /// `set_json` sets `Content-Type: application/json` itself, so no other
    /// content-type header is inserted.
    fn payload_request(name: &str) -> test::TestRequest {
        test::TestRequest::post().uri("/").set_json(ExamplePayload {
            name: name.to_string(),
        })
    }

    #[actix_web::test]
    async fn should_validate_simple() {
        let app = test::init_service(App::new().service(endpoint)).await;

        let resp = test::call_service(&app, payload_request("123456").to_request()).await;
        assert_eq!(resp.status().as_u16(), 200);

        let resp = test::call_service(&app, payload_request("1234").to_request()).await;
        assert_eq!(resp.status().as_u16(), 400);
    }

    #[ignore]
    #[actix_web::test]
    async fn should_respond_with_errors_correctly() {
        let app = test::init_service(App::new().service(endpoint)).await;

        let result = test::call_and_read_body(&app, payload_request("1234").to_request()).await;
        assert_eq!(
            result,
            Bytes::from_static(b"Validation errors in fields:\n\tname: Validation error: length [{\"min\": Number(5), \"value\": String(\"1234\")}]")
        );
    }

    #[derive(Debug, Serialize, Error)]
    struct CustomErrorResponse {
        custom_message: String,
        errors: Vec<String>,
    }

    impl Display for CustomErrorResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // A real implementation rather than `unimplemented!()`, so anything that
            // displays this error (e.g. logging) cannot panic mid-test.
            write!(f, "{}", self.custom_message)
        }
    }

    impl ResponseError for CustomErrorResponse {
        fn status_code(&self) -> actix_web::http::StatusCode {
            actix_web::http::StatusCode::BAD_REQUEST
        }

        fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
            HttpResponse::build(self.status_code()).body(serde_json::to_string(self).unwrap())
        }
    }

    fn error_handler(errors: ::validator::ValidationErrors, _: &HttpRequest) -> actix_web::Error {
        CustomErrorResponse {
            custom_message: "My custom message".to_string(),
            // The keys of `errors()` are the names of the fields that failed validation.
            errors: errors.errors().keys().map(|field| field.to_string()).collect(),
        }
        .into()
    }

    #[actix_web::test]
    async fn should_use_allow_custom_error_responses() {
        let app = test::init_service(
            App::new()
                .service(endpoint)
                .validator_error_handler(Arc::new(error_handler)),
        )
        .await;

        let result = test::call_and_read_body(&app, payload_request("1234").to_request()).await;
        assert_eq!(
            result,
            Bytes::from_static(b"{\"custom_message\":\"My custom message\",\"errors\":[\"name\"]}")
        );
    }

    // Explicit `actix_web::test`: the bare `#[test]` here only worked because
    // `use actix_web::test` shadows the builtin attribute (which rejects `async fn`).
    #[actix_web::test]
    async fn debug_for_validated_should_work() {
        let v = Validated(ExamplePayload {
            name: "abcde".to_string(),
        });

        assert_eq!(
            "Validated(ExamplePayload { name: \"abcde\" })",
            format!("{v:?}")
        );
    }
}
366}