use crate::error::HttpError;
use async_trait::async_trait;
use axum::{
extract::{FromRequest, Request},
Json,
};
use serde::de::DeserializeOwned;
#[cfg(feature = "validation")]
use validator::Validate;
/// Request-body extractor that deserializes the body as JSON and, when the
/// `validation` feature is enabled, runs the payload's validation rules before
/// the handler sees it.
///
/// Rejections are reported as [`HttpError`] rather than axum's own rejection types.
#[derive(Clone, Copy, Debug, Default)]
pub struct ValidatedJson<T>(pub T);
#[cfg(feature = "validation")]
#[async_trait]
impl<T, S> FromRequest<S> for ValidatedJson<T>
where
    T: DeserializeOwned + Validate,
    S: Send + Sync,
{
    type Rejection = HttpError;

    /// Deserializes the body with axum's [`Json`] extractor, then runs
    /// [`Validate::validate`] on the resulting value.
    ///
    /// # Errors
    ///
    /// - [`HttpError::BadRequest`] (prefixed `"Invalid JSON: "`) for any
    ///   rejection produced by the `Json` extractor.
    /// - [`HttpError::ValidationError`] when the body parses but fails validation.
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        let payload = Json::<T>::from_request(req, state)
            .await
            .map_err(|rejection| HttpError::BadRequest(format!("Invalid JSON: {}", rejection)))?
            .0;

        if let Err(report) = payload.validate() {
            return Err(HttpError::ValidationError(format_validation_errors(&report)));
        }

        Ok(ValidatedJson(payload))
    }
}
#[cfg(not(feature = "validation"))]
#[async_trait]
impl<T, S> FromRequest<S> for ValidatedJson<T>
where
    T: DeserializeOwned,
    S: Send + Sync,
{
    type Rejection = HttpError;

    /// Deserializes the body with axum's [`Json`] extractor.
    ///
    /// With the `validation` feature disabled no validation is performed; the
    /// wrapper only converts the rejection into an [`HttpError`].
    ///
    /// # Errors
    ///
    /// Returns [`HttpError::BadRequest`] (prefixed `"Invalid JSON: "`) for any
    /// rejection produced by the `Json` extractor.
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        match Json::<T>::from_request(req, state).await {
            Ok(Json(payload)) => Ok(ValidatedJson(payload)),
            Err(rejection) => Err(HttpError::BadRequest(format!(
                "Invalid JSON: {}",
                rejection
            ))),
        }
    }
}
/// Query-string extractor that deserializes the URI query with axum's `Query`
/// and, when the `validation` feature is enabled, runs the value's validation
/// rules before the handler sees it.
///
/// Rejections are reported as [`HttpError`] rather than axum's own rejection types.
#[derive(Clone, Copy, Debug, Default)]
pub struct ValidatedQuery<T>(pub T);
/// Extracts `T` from the URI query string and validates it.
///
/// This is implemented as `FromRequestParts` rather than `FromRequest` because
/// only the request head is read. A `FromRequest` impl marks an extractor as
/// body-consuming, which forces it to be the last handler argument and makes it
/// impossible to combine with a body extractor such as `ValidatedJson`. axum's
/// blanket impl still lets this type be used wherever a `FromRequest` extractor
/// is expected.
#[cfg(feature = "validation")]
#[async_trait]
impl<T, S> axum::extract::FromRequestParts<S> for ValidatedQuery<T>
where
    T: DeserializeOwned + Validate,
    S: Send + Sync,
{
    type Rejection = HttpError;

    /// # Errors
    ///
    /// - `HttpError::BadRequest` (prefixed `"Invalid query parameters: "`) when
    ///   the query string cannot be deserialized into `T`.
    /// - `HttpError::ValidationError` when deserialization succeeds but
    ///   `T::validate` fails.
    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        // Fully-qualified call: the `FromRequestParts` trait is not imported at file level.
        let axum::extract::Query(value) =
            <axum::extract::Query<T> as axum::extract::FromRequestParts<S>>::from_request_parts(
                parts, state,
            )
            .await
            .map_err(|e| HttpError::BadRequest(format!("Invalid query parameters: {}", e)))?;
        value
            .validate()
            .map_err(|e| HttpError::ValidationError(format_validation_errors(&e)))?;
        Ok(ValidatedQuery(value))
    }
}
/// Extracts `T` from the URI query string (no validation: the `validation`
/// feature is disabled).
///
/// This is implemented as `FromRequestParts` rather than `FromRequest` because
/// only the request head is read, so it can be combined with a body extractor
/// such as `ValidatedJson` in the same handler. axum's blanket impl still lets
/// this type be used wherever a `FromRequest` extractor is expected.
#[cfg(not(feature = "validation"))]
#[async_trait]
impl<T, S> axum::extract::FromRequestParts<S> for ValidatedQuery<T>
where
    T: DeserializeOwned,
    S: Send + Sync,
{
    type Rejection = HttpError;

    /// # Errors
    ///
    /// Returns `HttpError::BadRequest` (prefixed `"Invalid query parameters: "`)
    /// when the query string cannot be deserialized into `T`.
    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        // Fully-qualified call: the `FromRequestParts` trait is not imported at file level.
        let axum::extract::Query(value) =
            <axum::extract::Query<T> as axum::extract::FromRequestParts<S>>::from_request_parts(
                parts, state,
            )
            .await
            .map_err(|e| HttpError::BadRequest(format!("Invalid query parameters: {}", e)))?;
        Ok(ValidatedQuery(value))
    }
}
/// Renders validation errors as a single human-readable message.
///
/// Each failing rule becomes `"<path>: <message>"`, or
/// `"<path>: validation failed (<code>)"` when the rule has no custom message.
/// Entries are joined with `"; "`. Errors inside nested structs are reported
/// with dotted paths (`address.city`) and list elements as `items[0].name`.
///
/// Fields are emitted in sorted order so the message is deterministic:
/// `ValidationErrors` keeps its entries in a `HashMap`, whose iteration order
/// is unspecified.
///
/// Returns `"Validation failed"` when there is no individual entry to render.
#[cfg(feature = "validation")]
fn format_validation_errors(errors: &validator::ValidationErrors) -> String {
    let mut messages = Vec::new();
    collect_validation_messages(errors, "", &mut messages);
    if messages.is_empty() {
        "Validation failed".to_string()
    } else {
        messages.join("; ")
    }
}

/// Appends one message per failing rule in `errors` to `out`, recursing into
/// nested struct/list errors; `prefix` is the path accumulated so far.
#[cfg(feature = "validation")]
fn collect_validation_messages(
    errors: &validator::ValidationErrors,
    prefix: &str,
    out: &mut Vec<String>,
) {
    use validator::ValidationErrorsKind;

    let mut entries: Vec<_> = errors.errors().iter().collect();
    entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

    for (field, kind) in entries {
        let path = if prefix.is_empty() {
            field.to_string()
        } else {
            format!("{}.{}", prefix, field)
        };
        match kind {
            ValidationErrorsKind::Field(field_errors) => {
                out.extend(field_errors.iter().map(|error| match &error.message {
                    Some(msg) => format!("{}: {}", path, msg),
                    None => format!("{}: validation failed ({})", path, error.code),
                }));
            }
            ValidationErrorsKind::Struct(inner) => {
                collect_validation_messages(inner, &path, out);
            }
            ValidationErrorsKind::List(items) => {
                for (index, inner) in items {
                    collect_validation_messages(inner, &format!("{}[{}]", path, index), out);
                }
            }
        }
    }
}
#[cfg(all(test, feature = "validation"))]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::{get, post},
        Router,
    };
    use serde::{Deserialize, Serialize};
    use tower::ServiceExt;
    use validator::Validate;

    #[derive(Debug, Deserialize, Serialize, Validate)]
    struct TestPayload {
        #[validate(length(min = 3, max = 10))]
        name: String,
        #[validate(range(min = 18, max = 100))]
        age: u32,
    }

    #[derive(Debug, Deserialize, Validate)]
    struct TestQuery {
        #[validate(range(min = 1, max = 100))]
        limit: u32,
    }

    async fn test_handler(ValidatedJson(payload): ValidatedJson<TestPayload>) -> StatusCode {
        assert_eq!(payload.name.len(), 5);
        StatusCode::OK
    }

    async fn query_handler(ValidatedQuery(query): ValidatedQuery<TestQuery>) -> StatusCode {
        assert!((1..=100).contains(&query.limit));
        StatusCode::OK
    }

    /// Router exposing `test_handler` at `POST /`.
    fn json_app() -> Router {
        Router::new().route("/", post(test_handler))
    }

    /// Router exposing `query_handler` at `GET /`.
    fn query_app() -> Router {
        Router::new().route("/", get(query_handler))
    }

    /// Builds a `POST /` request carrying `body` with an `application/json` content type.
    fn json_request(body: impl Into<Body>) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri("/")
            .header("content-type", "application/json")
            .body(body.into())
            .unwrap()
    }

    /// Builds a bodiless `GET` request for `uri`.
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder()
            .method("GET")
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// Sends `request` through `app` and returns the response status.
    async fn status_of(app: Router, request: Request<Body>) -> StatusCode {
        app.oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_validated_json_success() {
        let payload = TestPayload {
            name: "Alice".to_string(),
            age: 25,
        };
        let body = serde_json::to_string(&payload).unwrap();
        assert_eq!(status_of(json_app(), json_request(body)).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_validated_json_validation_error() {
        // "AB" is shorter than the `length(min = 3)` rule allows.
        let payload = TestPayload {
            name: "AB".to_string(),
            age: 25,
        };
        let body = serde_json::to_string(&payload).unwrap();
        assert_eq!(
            status_of(json_app(), json_request(body)).await,
            StatusCode::UNPROCESSABLE_ENTITY
        );
    }

    #[tokio::test]
    async fn test_validated_json_invalid_json() {
        assert_eq!(
            status_of(json_app(), json_request("invalid json")).await,
            StatusCode::BAD_REQUEST
        );
    }

    #[tokio::test]
    async fn test_validated_query_success() {
        assert_eq!(
            status_of(query_app(), get_request("/?limit=10")).await,
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn test_validated_query_validation_error() {
        // 0 is below the `range(min = 1)` rule.
        assert_eq!(
            status_of(query_app(), get_request("/?limit=0")).await,
            StatusCode::UNPROCESSABLE_ENTITY
        );
    }

    #[tokio::test]
    async fn test_validated_query_invalid_parameters() {
        // "abc" cannot be deserialized into a u32.
        assert_eq!(
            status_of(query_app(), get_request("/?limit=abc")).await,
            StatusCode::BAD_REQUEST
        );
    }
}