use crate::error::{Result, TidewayError};
use axum::{Json, extract::Request};
use serde::Deserialize;
use std::future::Future;
use validator::Validate;
pub struct ValidatedJson<T>(pub T);
impl<T, S> axum::extract::FromRequest<S> for ValidatedJson<T>
where
T: for<'de> Deserialize<'de> + Validate + Send,
S: Send + Sync,
{
type Rejection = TidewayError;
async fn from_request(req: Request, state: &S) -> std::result::Result<Self, Self::Rejection> {
let json: Json<T> = Json::from_request(req, state)
.await
.map_err(|e| TidewayError::bad_request(format!("Invalid JSON: {}", e)))?;
json.0.validate().map_err(|errors| {
let error_messages: Vec<String> = errors
.field_errors()
.iter()
.flat_map(|(field, errors)| {
errors.iter().map(move |error| {
let msg = error
.message
.as_ref()
.map(|m| m.as_ref())
.unwrap_or_else(|| error.code.as_ref());
format!("{}: {}", field, msg)
})
})
.collect();
TidewayError::bad_request(format!("Validation failed: {}", error_messages.join(", ")))
})?;
Ok(ValidatedJson(json.0))
}
}
pub fn validate_json<T: Validate>(json: Json<T>) -> Result<ValidatedJson<T>> {
json.0.validate().map_err(|errors| {
let error_messages: Vec<String> = errors
.field_errors()
.iter()
.flat_map(|(field, errors)| {
errors.iter().map(move |error| {
let msg = error
.message
.as_ref()
.map(|m| m.as_ref())
.unwrap_or_else(|| error.code.as_ref());
format!("{}: {}", field, msg)
})
})
.collect();
TidewayError::bad_request(format!("Validation failed: {}", error_messages.join(", ")))
})?;
Ok(ValidatedJson(json.0))
}
pub struct ValidatedQuery<T>(pub T);
impl<T, S> axum::extract::FromRequestParts<S> for ValidatedQuery<T>
where
T: for<'de> Deserialize<'de> + Validate + Send,
S: Send + Sync,
{
type Rejection = TidewayError;
fn from_request_parts(
parts: &mut axum::http::request::Parts,
_state: &S,
) -> impl Future<Output = std::result::Result<Self, Self::Rejection>> + Send {
Box::pin(async move {
let query_string = parts.uri.query().unwrap_or("");
let query: T = serde_urlencoded::from_str(query_string).map_err(|e| {
TidewayError::bad_request(format!("Invalid query parameters: {}", e))
})?;
query.validate().map_err(|errors| {
let error_messages: Vec<String> = errors
.field_errors()
.iter()
.flat_map(|(field, errors)| {
errors.iter().map(move |error| {
let msg = error
.message
.as_ref()
.map(|m| m.as_ref())
.unwrap_or_else(|| error.code.as_ref());
format!("{}: {}", field, msg)
})
})
.collect();
TidewayError::bad_request(format!(
"Validation failed: {}",
error_messages.join(", ")
))
})?;
Ok(ValidatedQuery(query))
})
}
}
pub struct ValidatedForm<T>(pub T);
pub fn validate_form<T: Validate>(form: axum::extract::Form<T>) -> Result<ValidatedForm<T>> {
form.0.validate().map_err(|errors| {
let error_messages: Vec<String> = errors
.field_errors()
.iter()
.flat_map(|(field, errors)| {
errors.iter().map(move |error| {
let msg = error
.message
.as_ref()
.map(|m| m.as_ref())
.unwrap_or_else(|| error.code.as_ref());
format!("{}: {}", field, msg)
})
})
.collect();
TidewayError::bad_request(format!("Validation failed: {}", error_messages.join(", ")))
})?;
Ok(ValidatedForm(form.0))
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::extract::{Form, FromRequest, FromRequestParts};
    use serde::Deserialize;
    use validator::Validate;

    #[derive(Deserialize, Validate)]
    struct TestRequest {
        #[validate(email)]
        email: String,
        #[validate(range(min = 18, max = 100))]
        age: u32,
    }

    /// Convenience constructor for `TestRequest`.
    fn request(email: &str, age: u32) -> TestRequest {
        TestRequest {
            email: email.to_string(),
            age,
        }
    }

    /// Builds a JSON POST request carrying `body`.
    fn json_request(body: &'static str) -> Request {
        axum::http::Request::builder()
            .method("POST")
            .header("content-type", "application/json")
            .body(Body::from(body))
            .expect("statically known request must build")
    }

    /// Builds request parts for `uri`, as seen by a `FromRequestParts` extractor.
    fn query_parts(uri: &str) -> axum::http::request::Parts {
        let (parts, ()) = axum::http::Request::builder()
            .uri(uri)
            .body(())
            .expect("statically known request must build")
            .into_parts();
        parts
    }

    #[test]
    fn test_validated_json_success() {
        assert!(request("test@example.com", 25).validate().is_ok());
    }

    #[test]
    fn test_validation_failure() {
        assert!(request("not-an-email", 15).validate().is_err());
    }

    #[test]
    fn validate_json_accepts_valid_payload() {
        match validate_json(Json(request("test@example.com", 25))) {
            Ok(ValidatedJson(inner)) => {
                assert_eq!(inner.email, "test@example.com");
                assert_eq!(inner.age, 25);
            }
            Err(_) => panic!("valid payload was rejected"),
        }
    }

    #[test]
    fn validate_json_rejects_each_invalid_field() {
        assert!(validate_json(Json(request("not-an-email", 25))).is_err());
        assert!(validate_json(Json(request("test@example.com", 15))).is_err());
        assert!(validate_json(Json(request("test@example.com", 101))).is_err());
    }

    #[test]
    fn validate_form_accepts_valid_and_rejects_invalid() {
        assert!(validate_form(Form(request("test@example.com", 25))).is_ok());
        assert!(validate_form(Form(request("not-an-email", 101))).is_err());
    }

    #[tokio::test]
    async fn validated_json_extractor() {
        let ok = ValidatedJson::<TestRequest>::from_request(
            json_request(r#"{"email":"test@example.com","age":25}"#),
            &(),
        )
        .await;
        assert!(
            matches!(ok, Ok(ValidatedJson(ref r)) if r.email == "test@example.com" && r.age == 25)
        );

        let invalid = ValidatedJson::<TestRequest>::from_request(
            json_request(r#"{"email":"not-an-email","age":25}"#),
            &(),
        )
        .await;
        assert!(invalid.is_err());

        let malformed =
            ValidatedJson::<TestRequest>::from_request(json_request("{not json"), &()).await;
        assert!(malformed.is_err());
    }

    #[tokio::test]
    async fn validated_query_extractor() {
        let mut parts = query_parts("/?email=test%40example.com&age=30");
        let ok = ValidatedQuery::<TestRequest>::from_request_parts(&mut parts, &()).await;
        assert!(
            matches!(ok, Ok(ValidatedQuery(ref r)) if r.email == "test@example.com" && r.age == 30)
        );

        // Failed validation, undeserializable value, and missing query string.
        for uri in [
            "/?email=test%40example.com&age=15",
            "/?email=test%40example.com&age=abc",
            "/",
        ] {
            let mut parts = query_parts(uri);
            let result = ValidatedQuery::<TestRequest>::from_request_parts(&mut parts, &()).await;
            assert!(result.is_err(), "expected rejection for {uri}");
        }
    }
}