use std::{collections::HashMap, ops::Deref, sync::Arc};
use actix_web::{
dev::Payload, http::StatusCode, web::JsonBody, Error, FromRequest, HttpRequest, HttpResponse,
HttpResponseBuilder, ResponseError,
};
use futures_util::{future::LocalBoxFuture, FutureExt};
use serde::de::DeserializeOwned;
use serde_json::{json, Value};
use serde_valid::{validation::Errors as ValidationError, Validate};
#[derive(Debug, thiserror::Error)]
pub enum AppError {
#[error("{{\"non_field_errors\": [\"Validation failed\"]}}")]
ValidationError(HashMap<String, Value>),
}
impl ResponseError for AppError {
fn status_code(&self) -> actix_web::http::StatusCode {
match self {
AppError::ValidationError(_) => StatusCode::BAD_REQUEST,
}
}
fn error_response(&self) -> HttpResponse {
let response_body = match self {
AppError::ValidationError(errors) => {
serde_json::json!(errors)
}
};
HttpResponseBuilder::new(self.status_code()).json(response_body)
}
}
/// Turns a `serde_valid` error tree into the JSON-ready map used as the error response body.
fn format_errors(errors: ValidationError) -> HashMap<String, Value> {
    let mut formatted: HashMap<String, Value> = HashMap::default();
    // Root call: no parent key, so the root's entries land directly in `formatted`.
    process_errors(&mut formatted, None, errors);
    formatted
}
/// Converts one node of a `serde_valid` error tree into JSON entries of `result`.
///
/// A node's own messages are collected under `"non_field_errors"`. Its children are converted
/// recursively: array items are keyed by their index, object properties by their name.
///
/// * `key == None`: the node's entries are merged directly into `result` (used for the root).
/// * `key == Some(k)`: the node is stored under `result[k]`. It is stored as a bare list of
///   messages when it has only its own messages, otherwise as a nested object. Nothing is
///   inserted when the node produced no entries.
///
/// Every node goes through the same path. As a result, a keyed object's own messages are never
/// overwritten by its field errors, and items of a keyed nested array stay nested under their
/// key instead of spilling into the parent map.
fn process_errors(
    result: &mut HashMap<String, Value>,
    key: Option<String>,
    errors: ValidationError,
) {
    const NON_FIELD_ERRORS: &str = "non_field_errors";

    // Split the node into its own messages and its keyed child nodes.
    let (messages, children): (Vec<String>, Vec<(String, ValidationError)>) = match errors {
        ValidationError::Array(array_errors) => (
            array_errors.errors.iter().map(ToString::to_string).collect(),
            array_errors
                .items
                .into_iter()
                .map(|(index, item)| (index.to_string(), item))
                .collect(),
        ),
        ValidationError::Object(object_errors) => (
            object_errors.errors.iter().map(ToString::to_string).collect(),
            object_errors.properties.into_iter().collect(),
        ),
        ValidationError::NewType(vec_errors) => (
            vec_errors.iter().map(ToString::to_string).collect(),
            Vec::new(),
        ),
    };

    let mut entries: HashMap<String, Value> = HashMap::new();
    if !messages.is_empty() {
        entries.insert(NON_FIELD_ERRORS.to_string(), json!(messages));
    }
    for (child_key, child_errors) in children {
        process_errors(&mut entries, Some(child_key), child_errors);
    }

    match key {
        // Root: the node's entries become entries of `result` itself.
        None => result.extend(entries),
        Some(key) => {
            let value = match entries.len() {
                0 => return,
                // Only the node's own messages: emit the plain list, not
                // `{"non_field_errors": [...]}`.
                1 => entries
                    .remove(NON_FIELD_ERRORS)
                    .unwrap_or_else(|| json!(entries)),
                _ => json!(entries),
            };
            result.insert(key, value);
        }
    }
}
/// JSON request-body extractor that additionally runs `serde_valid` validation.
///
/// The extracted value is reachable via the public field, [`AppJson::into_inner`],
/// `AsRef`, or auto-deref.
#[derive(Debug)]
pub struct AppJson<T>(pub T);

impl<T> AppJson<T> {
    /// Consumes the wrapper and returns the extracted value.
    pub fn into_inner(self) -> T {
        let AppJson(inner) = self;
        inner
    }
}

impl<T> AsRef<T> for AppJson<T> {
    fn as_ref(&self) -> &T {
        let AppJson(inner) = self;
        inner
    }
}

impl<T> Deref for AppJson<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let AppJson(inner) = self;
        inner
    }
}
impl<T> FromRequest for AppJson<T>
where
T: DeserializeOwned + Validate + 'static,
{
type Error = AppError;
type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
#[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
let (limit, ctype) = req
.app_data::<JsonConfig>()
.map(|c| (c.limit, c.content_type.clone()))
.unwrap_or((32768, None));
JsonBody::<T>::new(req, payload, ctype.as_deref(), false)
.limit(limit)
.map(|res| match res {
Ok(data) => data
.validate()
.map_err(|err: serde_valid::validation::Errors| {
println!("{:?}", err);
Self::Error::ValidationError(format_errors(err))
})
.map(|_| AppJson(data)),
Err(e) => Err(Self::Error::ValidationError({
let mut formatted_errors = HashMap::new();
formatted_errors.insert("error".to_string(), json!(vec![e.to_string()]));
formatted_errors
})),
})
.boxed_local()
}
}
/// Shared error-handler callback stored in [`JsonConfig`].
type ErrHandler = Arc<dyn Fn(Error, &HttpRequest) -> actix_web::Error + Send + Sync>;

/// Configuration for the [`AppJson`] extractor, looked up from the request's app data.
#[derive(Clone)]
pub struct JsonConfig {
    /// Maximum accepted body size in bytes.
    limit: usize,
    /// Custom error handler.
    // NOTE(review): stored but never read in this file; the extractor always returns `AppError`.
    ehandler: Option<ErrHandler>,
    /// Predicate over the request's MIME type, handed to `JsonBody` as its content-type check.
    content_type: Option<Arc<dyn Fn(mime::Mime) -> bool + Send + Sync>>,
}

impl JsonConfig {
    /// Sets the maximum accepted body size, in bytes.
    #[must_use]
    pub fn limit(mut self, limit: usize) -> Self {
        self.limit = limit;
        self
    }

    /// Stores a custom error handler.
    ///
    /// NOTE(review): the `AppJson` extractor in this file does not invoke this handler.
    #[must_use]
    pub fn error_handler<F>(mut self, f: F) -> Self
    where
        F: Fn(Error, &HttpRequest) -> actix_web::Error + Send + Sync + 'static,
    {
        self.ehandler = Some(Arc::new(f));
        self
    }

    /// Sets the predicate used to check the request's content type.
    #[must_use]
    pub fn content_type<F>(mut self, predicate: F) -> Self
    where
        F: Fn(mime::Mime) -> bool + Send + Sync + 'static,
    {
        self.content_type = Some(Arc::new(predicate));
        self
    }
}

impl Default for JsonConfig {
    /// 32 KiB body limit, no error handler, no content-type predicate.
    fn default() -> Self {
        JsonConfig {
            limit: 32_768,
            ehandler: None,
            content_type: None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::body::MessageBody;
    use actix_web::http::StatusCode;
    use actix_web::web::Bytes;
    use actix_web::{test, ResponseError};
    use serde::Deserialize;
    use serde_json::json;
    use serde_valid::{validation::Error as SVError, Validate};

    /// Feeds `body` to the `AppJson<T>` extractor, requires it to be rejected, and returns the
    /// rejection's status code together with the serialized response body.
    async fn extract_err<T>(body: String) -> (StatusCode, Bytes)
    where
        T: DeserializeOwned + Validate + std::fmt::Debug + 'static,
    {
        let (req, mut payload) = test::TestRequest::post()
            .set_payload(body)
            .to_http_parts();
        let err = AppJson::<T>::from_request(&req, &mut payload)
            .await
            .unwrap_err();
        let bytes = err.error_response().into_body().try_into_bytes().unwrap();
        (err.status_code(), bytes)
    }

    #[actix_web::test]
    async fn test_field_level_error() {
        #[derive(Debug, Deserialize, Validate)]
        struct Test {
            #[validate(min_length = 3)]
            name: String,
        }

        let (status, body) = extract_err::<Test>(json!({"name": "tt"}).to_string()).await;

        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(
            body,
            Bytes::from_static(b"{\"name\":[\"The length of the value must be `>= 3`.\"]}")
        );
    }

    #[actix_web::test]
    async fn test_nested_field_level_error() {
        #[derive(Debug, Deserialize, Validate)]
        struct Test {
            #[validate]
            inner: Inner,
        }
        #[derive(Debug, Deserialize, Validate)]
        struct Inner {
            #[validate(min_length = 3)]
            name: String,
        }

        let request_body = json!({"inner": {"name": "tt"}}).to_string();
        let (status, body) = extract_err::<Test>(request_body).await;

        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(
            body,
            Bytes::from_static(
                b"{\"inner\":{\"name\":[\"The length of the value must be `>= 3`.\"]}}"
            )
        );
    }

    #[actix_web::test]
    async fn test_top_level_error() {
        #[derive(Debug, Deserialize, Validate)]
        #[validate(custom = top_level_check)]
        struct TestStruct {
            pub data: String,
            pub is_valid: bool,
        }
        fn top_level_check(value: &TestStruct) -> Result<(), SVError> {
            if !value.is_valid || !value.data.is_empty() {
                return Err(SVError::Custom("Overall data is invalid!".to_string()));
            }
            Ok(())
        }

        let request_body = json!({"data": "some stuff", "is_valid": false}).to_string();
        let (status, body) = extract_err::<TestStruct>(request_body).await;

        let expected = json!({
            "non_field_errors": ["Overall data is invalid!"]
        });
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body, Bytes::from(expected.to_string()));
    }

    #[actix_web::test]
    async fn test_array_error() {
        #[derive(Debug, Deserialize, Validate)]
        struct ArrayStruct {
            #[validate(min_items = 2)]
            items: Vec<String>,
        }

        let request_body = json!({"items": ["ab"]}).to_string();
        let (status, body) = extract_err::<ArrayStruct>(request_body).await;

        let expected = json!({
            "items": ["The length of the items must be `>= 2`."]
        });
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body, Bytes::from(expected.to_string()));
    }

    #[actix_web::test]
    async fn test_multiple_nested_errors() {
        #[derive(Debug, Deserialize, Validate)]
        struct Parent {
            #[validate]
            inner1: Inner,
            #[validate]
            inner2: Inner,
        }
        #[derive(Debug, Deserialize, Validate)]
        struct Inner {
            #[validate(min_length = 3)]
            name: String,
            #[validate(minimum = 10)]
            age: u8,
        }

        let request_body = json!({
            "inner1": {"name": "ab", "age": 9},
            "inner2": {"name": "cd", "age": 5}
        })
        .to_string();
        let (status, body) = extract_err::<Parent>(request_body).await;

        let expected = json!({
            "inner1": {
                "name": ["The length of the value must be `>= 3`."],
                "age": ["The number must be `>= 10`."]
            },
            "inner2": {
                "name": ["The length of the value must be `>= 3`."],
                "age": ["The number must be `>= 10`."]
            }
        });
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body, Bytes::from(expected.to_string()));
    }

    #[actix_web::test]
    async fn test_newtype_validation_error() {
        #[derive(Debug, Deserialize, Validate)]
        struct NewTypeWrapper(#[validate(minimum = 10)] i32);

        let (status, body) = extract_err::<NewTypeWrapper>(json!(5).to_string()).await;

        let expected = json!({
            "non_field_errors": ["The number must be `>= 10`."]
        });
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body, Bytes::from(expected.to_string()));
    }
}