actix_json_validator/
lib.rs

1use std::{collections::HashMap, ops::Deref, sync::Arc};
2
3use actix_web::{
4    dev::Payload, http::StatusCode, web::JsonBody, Error, FromRequest, HttpRequest, HttpResponse,
5    HttpResponseBuilder, ResponseError,
6};
7use futures_util::{future::LocalBoxFuture, FutureExt};
8use serde::de::DeserializeOwned;
9use serde_json::{json, Value};
10use serde_valid::{validation::Errors as ValidationError, Validate};
11
/// Error produced by the [`AppJson`] extractor.
///
/// Converted into a `400 Bad Request` response whose JSON body is the wrapped map.
#[derive(Debug, thiserror::Error)]
pub enum AppError {
    /// Extraction or validation of the request body failed.
    ///
    /// For validation failures the map is keyed by field name, with
    /// `"non_field_errors"` for errors not tied to a field. For payload/parse
    /// failures it holds a single `"error"` key.
    ///
    /// NOTE(review): the `Display` text below is a fixed string and does not
    /// include the collected errors; only the HTTP response body carries them.
    #[error("{{\"non_field_errors\": [\"Validation failed\"]}}")]
    ValidationError(HashMap<String, Value>),
}
17
18impl ResponseError for AppError {
19    fn status_code(&self) -> actix_web::http::StatusCode {
20        match self {
21            AppError::ValidationError(_) => StatusCode::BAD_REQUEST,
22        }
23    }
24
25    fn error_response(&self) -> HttpResponse {
26        let response_body = match self {
27            AppError::ValidationError(errors) => {
28                serde_json::json!(errors)
29            }
30        };
31
32        HttpResponseBuilder::new(self.status_code()).json(response_body)
33    }
34}
35
36fn format_errors(errors: ValidationError) -> HashMap<String, Value> {
37    let mut result = HashMap::new();
38    process_errors(&mut result, None, errors);
39    result
40}
41
42fn process_errors(
43    result: &mut HashMap<String, Value>,
44    key: Option<String>,
45    errors: ValidationError,
46) {
47    match errors {
48        ValidationError::Array(array_errors) => {
49            if !array_errors.errors.is_empty() {
50                let error_messages: Vec<String> = array_errors
51                    .errors
52                    .iter()
53                    .map(ToString::to_string)
54                    .collect();
55                result.insert(
56                    key.clone()
57                        .unwrap_or_else(|| "non_field_errors".to_string()),
58                    json!(error_messages),
59                );
60            }
61
62            // Recursively process nested errors
63            if !array_errors.items.is_empty() {
64                let mut nested_map: HashMap<String, Value> = HashMap::new();
65                for (prop, error) in array_errors.items {
66                    process_errors(&mut nested_map, Some(prop.to_string()), error);
67                }
68                for (prop, value) in nested_map {
69                    result.insert(prop, value);
70                }
71            }
72        }
73
74        ValidationError::Object(object_errors) => {
75            // 1) Collect any direct (top-level) errors on this object
76            if !object_errors.errors.is_empty() {
77                let msgs: Vec<String> = object_errors
78                    .errors
79                    .iter()
80                    .map(ToString::to_string)
81                    .collect();
82
83                result.insert(
84                    // If there's a parent key, use it; otherwise use "non_field_errors"
85                    key.clone().unwrap_or_else(|| "non_field_errors".into()),
86                    json!(msgs),
87                );
88            }
89
90            // 2) For each property, recurse and gather its errors in a local map
91            let mut child_map = serde_json::Map::new();
92            for (prop, err) in object_errors.properties {
93                let mut child_result = HashMap::new();
94                process_errors(&mut child_result, None, err);
95                // child_result is HashMap<String, Value>; we typically expect
96                // it to have either "non_field_errors" or property keys.
97
98                // Merge child_result into a single Value
99                // If it has only one key that is "non_field_errors", we flatten:
100                //    "prop": [ ...error array... ]
101                // else store the entire map:
102                //    "prop": { ... }
103
104                if child_result.len() == 1 && child_result.contains_key("non_field_errors") {
105                    child_map.insert(prop, child_result.remove("non_field_errors").unwrap());
106                } else {
107                    child_map.insert(prop, json!(child_result));
108                }
109            }
110
111            // 3) Now we have a map of child properties. If there's a parent key,
112            //    nest them under that parent key. Otherwise, store them top-level.
113            if !child_map.is_empty() {
114                if let Some(parent) = key {
115                    // If the parent key already exists in result and is an object,
116                    // we can merge. If it's an array, or doesn't exist yet, handle accordingly.
117                    match result.get_mut(&parent) {
118                        Some(val) if val.is_object() => {
119                            // Merge child_map into the existing object
120                            if let Some(obj) = val.as_object_mut() {
121                                for (child_prop, child_val) in child_map {
122                                    obj.insert(child_prop, child_val);
123                                }
124                            }
125                        }
126                        _ => {
127                            // Overwrite or create new
128                            result.insert(parent, json!(child_map));
129                        }
130                    }
131                } else {
132                    // We are top-level
133                    for (child_prop, child_val) in child_map {
134                        result.insert(child_prop, child_val);
135                    }
136                }
137            }
138        }
139
140        ValidationError::NewType(vec_errors) => {
141            if !vec_errors.is_empty() {
142                let error_messages: Vec<String> =
143                    vec_errors.iter().map(ToString::to_string).collect();
144                result.insert(
145                    key.unwrap_or_else(|| "non_field_errors".to_string()),
146                    json!(error_messages),
147                );
148            }
149        }
150    }
151}
152
/// Extractor that deserializes a JSON request body into `T` and then runs
/// `serde_valid` validation on it before the handler receives it.
///
/// Failures are reported as [`AppError::ValidationError`].
#[derive(Debug)]
pub struct AppJson<T>(pub T);
155
156impl<T> AppJson<T> {
157    /// Deconstruct to an inner value
158    pub fn into_inner(self) -> T {
159        self.0
160    }
161}
162
163impl<T> AsRef<T> for AppJson<T> {
164    fn as_ref(&self) -> &T {
165        &self.0
166    }
167}
168
169impl<T> Deref for AppJson<T> {
170    type Target = T;
171
172    fn deref(&self) -> &T {
173        &self.0
174    }
175}
176
177impl<T> FromRequest for AppJson<T>
178where
179    T: DeserializeOwned + Validate + 'static,
180{
181    type Error = AppError;
182    type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
183
184    #[inline]
185    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
186        let (limit, ctype) = req
187            .app_data::<JsonConfig>()
188            .map(|c| (c.limit, c.content_type.clone()))
189            .unwrap_or((32768, None));
190
191        JsonBody::<T>::new(req, payload, ctype.as_deref(), false)
192            .limit(limit)
193            .map(|res| match res {
194                Ok(data) => data
195                    .validate()
196                    .map_err(|err: serde_valid::validation::Errors| {
197                        println!("{:?}", err);
198                        Self::Error::ValidationError(format_errors(err))
199                    })
200                    .map(|_| AppJson(data)),
201                Err(e) => Err(Self::Error::ValidationError({
202                    let mut formatted_errors = HashMap::new();
203                    formatted_errors.insert("error".to_string(), json!(vec![e.to_string()]));
204                    formatted_errors
205                })),
206            })
207            .boxed_local()
208    }
209}
210
/// Shared, thread-safe error-handler callback stored by [`JsonConfig::error_handler`].
type ErrHandler = Arc<dyn Fn(Error, &HttpRequest) -> actix_web::Error + Send + Sync>;
212
/// Configuration for the [`AppJson`] extractor, looked up from the request's
/// app data. Build it from [`JsonConfig::default`] plus the builder methods.
#[derive(Clone)]
pub struct JsonConfig {
    // Maximum accepted payload size in bytes.
    limit: usize,
    // Custom error handler set through `error_handler`.
    ehandler: Option<ErrHandler>,
    // Optional predicate for additional allowed content types; passed to `JsonBody::new`.
    content_type: Option<Arc<dyn Fn(mime::Mime) -> bool + Send + Sync>>,
}
219
220impl JsonConfig {
221    /// Change max size of payload. By default max size is 32Kb
222    pub fn limit(mut self, limit: usize) -> Self {
223        self.limit = limit;
224        self
225    }
226
227    /// Set custom error handler
228    pub fn error_handler<F>(mut self, f: F) -> Self
229    where
230        F: Fn(Error, &HttpRequest) -> actix_web::Error + Send + Sync + 'static,
231    {
232        self.ehandler = Some(Arc::new(f));
233        self
234    }
235
236    /// Set predicate for allowed content types
237    pub fn content_type<F>(mut self, predicate: F) -> Self
238    where
239        F: Fn(mime::Mime) -> bool + Send + Sync + 'static,
240    {
241        self.content_type = Some(Arc::new(predicate));
242        self
243    }
244}
245
246impl Default for JsonConfig {
247    fn default() -> Self {
248        JsonConfig {
249            limit: 32768,
250            ehandler: None,
251            content_type: None,
252        }
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::body::MessageBody;
    use actix_web::http::header::ContentType;
    use actix_web::http::StatusCode;
    use actix_web::web::Bytes;
    use actix_web::{test, ResponseError};
    use serde::Deserialize;
    use serde_json::json;
    use serde_valid::{validation::Error as SVError, Validate};

    /// Builds the parts of a `POST` request that carries `body` with an
    /// `application/json` content type. The extractor requires a JSON content
    /// type, so without the header it would reject the payload before validation.
    fn json_post(body: String) -> (HttpRequest, Payload) {
        test::TestRequest::post()
            .insert_header(ContentType::json())
            .set_payload(body)
            .to_http_parts()
    }

    #[actix_web::test]
    async fn test_field_level_error() {
        #[derive(Debug, Deserialize, Validate)]
        struct Test {
            #[validate(min_length = 3)]
            name: String,
        }
        let (req, mut payload) = json_post(json!({"name": "tt"}).to_string());

        let res = AppJson::<Test>::from_request(&req, &mut payload)
            .await
            .unwrap_err();

        assert_eq!(res.status_code(), StatusCode::BAD_REQUEST);
        let body = res.error_response().into_body().try_into_bytes().unwrap();
        assert_eq!(
            body,
            Bytes::from_static(b"{\"name\":[\"The length of the value must be `>= 3`.\"]}")
        );
    }

    #[actix_web::test]
    async fn test_nested_field_level_error() {
        #[derive(Debug, Deserialize, Validate)]
        struct Test {
            #[validate]
            inner: Inner,
        }

        #[derive(Debug, Deserialize, Validate)]
        struct Inner {
            #[validate(min_length = 3)]
            name: String,
        }

        let (req, mut payload) = json_post(json!({"inner": {"name": "tt"}}).to_string());

        let res = AppJson::<Test>::from_request(&req, &mut payload)
            .await
            .unwrap_err();

        assert_eq!(res.status_code(), StatusCode::BAD_REQUEST);
        let body = res.error_response().into_body().try_into_bytes().unwrap();
        assert_eq!(
            body,
            Bytes::from_static(
                b"{\"inner\":{\"name\":[\"The length of the value must be `>= 3`.\"]}}"
            )
        );
    }

    #[actix_web::test]
    async fn test_top_level_error() {
        /// Struct-level validator registered with `#[validate(custom = ...)]`.
        /// The struct is valid only when `is_valid` is true and `data` is empty.
        #[derive(Debug, Deserialize, Validate)]
        #[validate(custom = top_level_check)]
        struct TestStruct {
            pub data: String,
            pub is_valid: bool,
        }

        fn top_level_check(value: &TestStruct) -> Result<(), SVError> {
            if !value.is_valid || !value.data.is_empty() {
                return Err(SVError::Custom("Overall data is invalid!".to_string()));
            }
            Ok(())
        }

        // Invalid input so that the struct-level check fails.
        let payload_data = json!({"data": "some stuff", "is_valid": false}).to_string();
        let (req, mut payload) = json_post(payload_data);

        let res = AppJson::<TestStruct>::from_request(&req, &mut payload)
            .await
            .unwrap_err();

        // Struct-level errors are reported under "non_field_errors".
        assert_eq!(res.status_code(), StatusCode::BAD_REQUEST);
        let body = res.error_response().into_body().try_into_bytes().unwrap();
        let expected_json = json!({
            "non_field_errors": ["Overall data is invalid!"]
        });
        let expected_bytes = Bytes::from(expected_json.to_string());

        assert_eq!(body, expected_bytes);
    }

    /// Array-level validation error (`min_items`) on a field.
    #[actix_web::test]
    async fn test_array_error() {
        #[derive(Debug, Deserialize, Validate)]
        struct ArrayStruct {
            #[validate(min_items = 2)] // at least 2 items
            items: Vec<String>,
        }

        // Invalid data: only 1 item.
        let payload_data = json!({"items": ["ab"]}).to_string();
        let (req, mut payload) = json_post(payload_data);

        let res = AppJson::<ArrayStruct>::from_request(&req, &mut payload)
            .await
            .unwrap_err();

        assert_eq!(res.status_code(), StatusCode::BAD_REQUEST);
        let body = res.error_response().into_body().try_into_bytes().unwrap();

        let expected = json!({
            "items": ["The length of the items must be `>= 2`."]
        });
        let expected_bytes = Bytes::from(expected.to_string());
        assert_eq!(body, expected_bytes);
    }

    /// Multiple nested properties failing at once.
    #[actix_web::test]
    async fn test_multiple_nested_errors() {
        #[derive(Debug, Deserialize, Validate)]
        struct Parent {
            #[validate]
            inner1: Inner,
            #[validate]
            inner2: Inner,
        }

        #[derive(Debug, Deserialize, Validate)]
        struct Inner {
            #[validate(min_length = 3)]
            name: String,
            #[validate(minimum = 10)]
            age: u8,
        }

        let payload_data = json!({
            "inner1": {"name": "ab", "age": 9},
            "inner2": {"name": "cd", "age": 5}
        })
        .to_string();
        let (req, mut payload) = json_post(payload_data);

        let res = AppJson::<Parent>::from_request(&req, &mut payload)
            .await
            .unwrap_err();

        assert_eq!(res.status_code(), StatusCode::BAD_REQUEST);
        let body = res.error_response().into_body().try_into_bytes().unwrap();

        let expected = json!({
            "inner1": {
                "name": ["The length of the value must be `>= 3`."],
                "age": ["The number must be `>= 10`."]
            },
            "inner2": {
                "name": ["The length of the value must be `>= 3`."],
                "age": ["The number must be `>= 10`."]
            }
        });

        let expected_bytes = Bytes::from(expected.to_string());
        assert_eq!(body, expected_bytes);
    }

    #[actix_web::test]
    async fn test_newtype_validation_error() {
        #[derive(Debug, Deserialize, Validate)]
        struct NewTypeWrapper(#[validate(minimum = 10)] i32);

        let payload_data = json!(5).to_string(); // invalid: must be >= 10
        let (req, mut payload) = json_post(payload_data);

        let res = AppJson::<NewTypeWrapper>::from_request(&req, &mut payload)
            .await
            .unwrap_err();

        assert_eq!(res.status_code(), StatusCode::BAD_REQUEST);
        let body = res.error_response().into_body().try_into_bytes().unwrap();
        let expected = json!({
            "non_field_errors": ["The number must be `>= 10`."]
        });

        let expected_bytes = Bytes::from(expected.to_string());
        assert_eq!(body, expected_bytes);
    }
}