1use std::{collections::HashMap, ops::Deref, sync::Arc};
2
3use actix_web::{
4 dev::Payload, http::StatusCode, web::JsonBody, Error, FromRequest, HttpRequest, HttpResponse,
5 HttpResponseBuilder, ResponseError,
6};
7use futures_util::{future::LocalBoxFuture, FutureExt};
8use serde::de::DeserializeOwned;
9use serde_json::{json, Value};
10use serde_valid::{validation::Errors as ValidationError, Validate};
11
12#[derive(Debug, thiserror::Error)]
13pub enum AppError {
14 #[error("{{\"non_field_errors\": [\"Validation failed\"]}}")]
15 ValidationError(HashMap<String, Value>),
16}
17
18impl ResponseError for AppError {
19 fn status_code(&self) -> actix_web::http::StatusCode {
20 match self {
21 AppError::ValidationError(_) => StatusCode::BAD_REQUEST,
22 }
23 }
24
25 fn error_response(&self) -> HttpResponse {
26 let response_body = match self {
27 AppError::ValidationError(errors) => {
28 serde_json::json!(errors)
29 }
30 };
31
32 HttpResponseBuilder::new(self.status_code()).json(response_body)
33 }
34}
35
36fn format_errors(errors: ValidationError) -> HashMap<String, Value> {
37 let mut result = HashMap::new();
38 process_errors(&mut result, None, errors);
39 result
40}
41
42fn process_errors(
43 result: &mut HashMap<String, Value>,
44 key: Option<String>,
45 errors: ValidationError,
46) {
47 match errors {
48 ValidationError::Array(array_errors) => {
49 if !array_errors.errors.is_empty() {
50 let error_messages: Vec<String> = array_errors
51 .errors
52 .iter()
53 .map(ToString::to_string)
54 .collect();
55 result.insert(
56 key.clone()
57 .unwrap_or_else(|| "non_field_errors".to_string()),
58 json!(error_messages),
59 );
60 }
61
62 if !array_errors.items.is_empty() {
64 let mut nested_map: HashMap<String, Value> = HashMap::new();
65 for (prop, error) in array_errors.items {
66 process_errors(&mut nested_map, Some(prop.to_string()), error);
67 }
68 for (prop, value) in nested_map {
69 result.insert(prop, value);
70 }
71 }
72 }
73
74 ValidationError::Object(object_errors) => {
75 if !object_errors.errors.is_empty() {
77 let msgs: Vec<String> = object_errors
78 .errors
79 .iter()
80 .map(ToString::to_string)
81 .collect();
82
83 result.insert(
84 key.clone().unwrap_or_else(|| "non_field_errors".into()),
86 json!(msgs),
87 );
88 }
89
90 let mut child_map = serde_json::Map::new();
92 for (prop, err) in object_errors.properties {
93 let mut child_result = HashMap::new();
94 process_errors(&mut child_result, None, err);
95 if child_result.len() == 1 && child_result.contains_key("non_field_errors") {
105 child_map.insert(prop, child_result.remove("non_field_errors").unwrap());
106 } else {
107 child_map.insert(prop, json!(child_result));
108 }
109 }
110
111 if !child_map.is_empty() {
114 if let Some(parent) = key {
115 match result.get_mut(&parent) {
118 Some(val) if val.is_object() => {
119 if let Some(obj) = val.as_object_mut() {
121 for (child_prop, child_val) in child_map {
122 obj.insert(child_prop, child_val);
123 }
124 }
125 }
126 _ => {
127 result.insert(parent, json!(child_map));
129 }
130 }
131 } else {
132 for (child_prop, child_val) in child_map {
134 result.insert(child_prop, child_val);
135 }
136 }
137 }
138 }
139
140 ValidationError::NewType(vec_errors) => {
141 if !vec_errors.is_empty() {
142 let error_messages: Vec<String> =
143 vec_errors.iter().map(ToString::to_string).collect();
144 result.insert(
145 key.unwrap_or_else(|| "non_field_errors".to_string()),
146 json!(error_messages),
147 );
148 }
149 }
150 }
151}
152
/// Thin newtype wrapper around a value of type `T`.
///
/// The wrapped value is public and can also be reached through
/// [`AppJson::into_inner`], `AsRef<T>` or `Deref`.
#[derive(Debug)]
pub struct AppJson<T>(pub T);

impl<T> AppJson<T> {
    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}

impl<T> AsRef<T> for AppJson<T> {
    /// Borrows the wrapped value.
    fn as_ref(&self) -> &T {
        let Self(inner) = self;
        inner
    }
}

impl<T> Deref for AppJson<T> {
    type Target = T;

    /// Dereferences to the wrapped value so its methods can be called directly.
    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
176
177impl<T> FromRequest for AppJson<T>
178where
179 T: DeserializeOwned + Validate + 'static,
180{
181 type Error = AppError;
182 type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
183
184 #[inline]
185 fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
186 let (limit, ctype) = req
187 .app_data::<JsonConfig>()
188 .map(|c| (c.limit, c.content_type.clone()))
189 .unwrap_or((32768, None));
190
191 JsonBody::<T>::new(req, payload, ctype.as_deref(), false)
192 .limit(limit)
193 .map(|res| match res {
194 Ok(data) => data
195 .validate()
196 .map_err(|err: serde_valid::validation::Errors| {
197 println!("{:?}", err);
198 Self::Error::ValidationError(format_errors(err))
199 })
200 .map(|_| AppJson(data)),
201 Err(e) => Err(Self::Error::ValidationError({
202 let mut formatted_errors = HashMap::new();
203 formatted_errors.insert("error".to_string(), json!(vec![e.to_string()]));
204 formatted_errors
205 })),
206 })
207 .boxed_local()
208 }
209}
210
211type ErrHandler = Arc<dyn Fn(Error, &HttpRequest) -> actix_web::Error + Send + Sync>;
212
213#[derive(Clone)]
214pub struct JsonConfig {
215 limit: usize,
216 ehandler: Option<ErrHandler>,
217 content_type: Option<Arc<dyn Fn(mime::Mime) -> bool + Send + Sync>>,
218}
219
220impl JsonConfig {
221 pub fn limit(mut self, limit: usize) -> Self {
223 self.limit = limit;
224 self
225 }
226
227 pub fn error_handler<F>(mut self, f: F) -> Self
229 where
230 F: Fn(Error, &HttpRequest) -> actix_web::Error + Send + Sync + 'static,
231 {
232 self.ehandler = Some(Arc::new(f));
233 self
234 }
235
236 pub fn content_type<F>(mut self, predicate: F) -> Self
238 where
239 F: Fn(mime::Mime) -> bool + Send + Sync + 'static,
240 {
241 self.content_type = Some(Arc::new(predicate));
242 self
243 }
244}
245
246impl Default for JsonConfig {
247 fn default() -> Self {
248 JsonConfig {
249 limit: 32768,
250 ehandler: None,
251 content_type: None,
252 }
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::body::MessageBody;
    use actix_web::http::StatusCode;
    use actix_web::{test, ResponseError};
    use serde::de::DeserializeOwned;
    use serde::Deserialize;
    use serde_json::{json, Value};
    use serde_valid::{validation::Error as SVError, Validate};

    /// Posts `payload` through the `AppJson<T>` extractor, expects extraction to
    /// fail, and returns the resulting status code together with the parsed
    /// JSON error body.
    ///
    /// Bodies are compared as parsed JSON rather than raw bytes so the
    /// assertions do not depend on key order or whitespace of the serializer.
    async fn reject<T>(payload: Value) -> (StatusCode, Value)
    where
        T: DeserializeOwned + Validate + std::fmt::Debug + 'static,
    {
        let (req, mut body) = test::TestRequest::post()
            .set_payload(payload.to_string())
            .to_http_parts();

        let err = AppJson::<T>::from_request(&req, &mut body)
            .await
            .unwrap_err();

        let raw = err.error_response().into_body().try_into_bytes().unwrap();
        let parsed: Value = serde_json::from_slice(&raw).expect("error body must be valid JSON");

        (err.status_code(), parsed)
    }

    #[actix_web::test]
    async fn test_field_level_error() {
        #[derive(Debug, Deserialize, Validate)]
        struct Test {
            #[validate(min_length = 3)]
            name: String,
        }

        let (status, body) = reject::<Test>(json!({"name": "tt"})).await;

        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(
            body,
            json!({"name": ["The length of the value must be `>= 3`."]})
        );
    }

    #[actix_web::test]
    async fn test_nested_field_level_error() {
        #[derive(Debug, Deserialize, Validate)]
        struct Test {
            #[validate]
            inner: Inner,
        }

        #[derive(Debug, Deserialize, Validate)]
        struct Inner {
            #[validate(min_length = 3)]
            name: String,
        }

        let (status, body) = reject::<Test>(json!({"inner": {"name": "tt"}})).await;

        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(
            body,
            json!({"inner": {"name": ["The length of the value must be `>= 3`."]}})
        );
    }

    #[actix_web::test]
    async fn test_top_level_error() {
        #[derive(Debug, Deserialize, Validate)]
        #[validate(custom = top_level_check)]
        struct TestStruct {
            pub data: String,
            pub is_valid: bool,
        }

        fn top_level_check(value: &TestStruct) -> Result<(), SVError> {
            if !value.is_valid || !value.data.is_empty() {
                return Err(SVError::Custom("Overall data is invalid!".to_string()));
            }
            Ok(())
        }

        let (status, body) =
            reject::<TestStruct>(json!({"data": "some stuff", "is_valid": false})).await;

        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(
            body,
            json!({"non_field_errors": ["Overall data is invalid!"]})
        );
    }

    #[actix_web::test]
    async fn test_array_error() {
        #[derive(Debug, Deserialize, Validate)]
        struct ArrayStruct {
            #[validate(min_items = 2)]
            items: Vec<String>,
        }

        let (status, body) = reject::<ArrayStruct>(json!({"items": ["ab"]})).await;

        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(
            body,
            json!({"items": ["The length of the items must be `>= 2`."]})
        );
    }

    #[actix_web::test]
    async fn test_multiple_nested_errors() {
        #[derive(Debug, Deserialize, Validate)]
        struct Parent {
            #[validate]
            inner1: Inner,
            #[validate]
            inner2: Inner,
        }

        #[derive(Debug, Deserialize, Validate)]
        struct Inner {
            #[validate(min_length = 3)]
            name: String,
            #[validate(minimum = 10)]
            age: u8,
        }

        let (status, body) = reject::<Parent>(json!({
            "inner1": {"name": "ab", "age": 9},
            "inner2": {"name": "cd", "age": 5}
        }))
        .await;

        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(
            body,
            json!({
                "inner1": {
                    "name": ["The length of the value must be `>= 3`."],
                    "age": ["The number must be `>= 10`."]
                },
                "inner2": {
                    "name": ["The length of the value must be `>= 3`."],
                    "age": ["The number must be `>= 10`."]
                }
            })
        );
    }

    #[actix_web::test]
    async fn test_newtype_validation_error() {
        #[derive(Debug, Deserialize, Validate)]
        struct NewTypeWrapper(#[validate(minimum = 10)] i32);

        let (status, body) = reject::<NewTypeWrapper>(json!(5)).await;

        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(
            body,
            json!({"non_field_errors": ["The number must be `>= 10`."]})
        );
    }
}