/// Hook for request types that want to rewrite themselves into a canonical
/// form before they are validated.
///
/// Implementors that have nothing to normalize can rely on the provided
/// default and write an empty `impl` block.
pub trait Normalizable {
    /// Normalizes `self` in place.
    ///
    /// The provided implementation is a no-op; override it to adjust fields
    /// (for example to fill in or canonicalize values) before validation runs.
    fn normalize(&mut self) {}
}
#[cfg(feature = "axum")]
use axum::{
extract::{rejection::JsonRejection, FromRequest, Request},
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
#[cfg(feature = "axum")]
use serde::de::DeserializeOwned;
#[cfg(feature = "axum")]
use serde_json::json;
#[cfg(feature = "axum")]
use validator::Validate;
/// JSON request-body wrapper whose payload is exposed through the public
/// tuple field.
///
/// When used as an axum extractor, the payload is deserialized from the
/// request body, normalized via [`Normalizable::normalize`], and then checked
/// with `validator::Validate::validate` before a handler receives it.
///
/// The derives mirror those on `axum::Json`, so the wrapper can be debug
/// printed, cloned, or default-constructed whenever `T` allows it.
#[cfg(feature = "axum")]
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);
#[cfg(feature = "axum")]
impl<S, T> FromRequest<S> for ValidatedJson<T>
where
    T: DeserializeOwned + Validate + Normalizable + Send,
    S: Send + Sync,
{
    /// Every failure is reported as a ready-made HTTP response.
    type Rejection = Response;

    /// Extracts a `T` from a JSON request body, normalizes it, then validates it.
    ///
    /// The steps run in this order:
    /// 1. Delegate to `Json::<T>::from_request` to read and deserialize the body.
    /// 2. Call [`Normalizable::normalize`] on the decoded value.
    /// 3. Call `Validate::validate` on the normalized value.
    ///
    /// # Errors
    ///
    /// Both failure paths answer with `400 Bad Request` and a JSON body of the
    /// shape `{"error": {"message", "type", "code"}}`, where `type` is always
    /// `"invalid_request_error"`:
    /// * a `JsonRejection` yields `code: "json_parse_error"` with a message
    ///   chosen per rejection variant;
    /// * a validation failure yields the stringified validation errors as the
    ///   message. NOTE: on this path `code` is the number `400`, not a string.
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        /// Wraps an error payload in a `400 Bad Request` JSON response.
        fn bad_request(payload: serde_json::Value) -> Response {
            (StatusCode::BAD_REQUEST, Json(payload)).into_response()
        }

        let Json(mut data) = match Json::<T>::from_request(req, state).await {
            Ok(parsed) => parsed,
            Err(rejection) => {
                // `JsonRejection` is matched with a catch-all arm so that any
                // variant not listed here still gets a generic message.
                let error_message = match &rejection {
                    JsonRejection::JsonDataError(e) => format!("Invalid JSON data: {e}"),
                    JsonRejection::JsonSyntaxError(e) => format!("JSON syntax error: {e}"),
                    JsonRejection::MissingJsonContentType(_) => {
                        String::from("Missing Content-Type: application/json header")
                    }
                    other => format!("Failed to parse JSON: {other}"),
                };
                return Err(bad_request(json!({
                    "error": {
                        "message": error_message,
                        "type": "invalid_request_error",
                        "code": "json_parse_error"
                    }
                })));
            }
        };

        // Normalization has to happen first so validation sees the final values.
        data.normalize();

        if let Err(validation_errors) = data.validate() {
            return Err(bad_request(json!({
                "error": {
                    "message": validation_errors.to_string(),
                    "type": "invalid_request_error",
                    "code": 400
                }
            })));
        }

        Ok(Self(data))
    }
}
/// Lets a `ValidatedJson<T>` be read as if it were the wrapped `T`.
#[cfg(feature = "axum")]
impl<T> std::ops::Deref for ValidatedJson<T> {
    type Target = T;

    /// Returns a shared reference to the wrapped payload.
    fn deref(&self) -> &T {
        let Self(inner) = self;
        inner
    }
}
/// Lets the wrapped `T` be mutated directly through a `ValidatedJson<T>`.
#[cfg(feature = "axum")]
impl<T> std::ops::DerefMut for ValidatedJson<T> {
    /// Returns an exclusive reference to the wrapped payload.
    fn deref_mut(&mut self) -> &mut T {
        let Self(inner) = self;
        inner
    }
}
#[cfg(all(test, feature = "axum"))]
mod tests {
    use serde::{Deserialize, Serialize};
    use validator::Validate;

    use super::*;

    /// Minimal payload with one ranged numeric field and one non-empty string
    /// field, used to drive the validation rules.
    #[derive(Debug, Deserialize, Serialize, Validate)]
    struct TestRequest {
        #[validate(range(min = 0.0, max = 1.0))]
        value: f32,
        #[validate(length(min = 1))]
        name: String,
    }

    // Relies on the default no-op `normalize`.
    impl Normalizable for TestRequest {}

    // None of these tests await anything, so they are plain synchronous tests
    // rather than `#[tokio::test]` functions that spin up a runtime.

    #[test]
    fn test_validated_json_valid() {
        let request = TestRequest {
            value: 0.5,
            name: "test".to_string(),
        };
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_validated_json_invalid_range() {
        // 1.5 lies above the declared maximum of 1.0.
        let request = TestRequest {
            value: 1.5,
            name: "test".to_string(),
        };
        assert!(request.validate().is_err());
    }

    #[test]
    fn test_validated_json_invalid_length() {
        // An empty name violates `length(min = 1)`.
        let request = TestRequest {
            value: 0.5,
            name: String::new(),
        };
        assert!(request.validate().is_err());
    }

    #[test]
    fn test_default_normalize_leaves_fields_untouched() {
        let mut request = TestRequest {
            value: 0.5,
            name: "test".to_string(),
        };
        request.normalize();
        assert_eq!(request.value, 0.5);
        assert_eq!(request.name, "test");
    }

    #[test]
    fn test_validated_json_deref_and_deref_mut() {
        let mut wrapped = ValidatedJson(TestRequest {
            value: 0.5,
            name: "test".to_string(),
        });

        // Reads go through `Deref`.
        assert_eq!(wrapped.name, "test");

        // Writes go through `DerefMut` and land on the inner value.
        wrapped.value = 0.25;
        assert_eq!(wrapped.0.value, 0.25);
    }
}