Skip to main content

axum_api_kit/
validated.rs

1use axum::{
2    extract::{rejection::JsonRejection, FromRequest, Request},
3    http::StatusCode,
4    Json,
5};
6use serde::de::DeserializeOwned;
7use validator::Validate;
8
9use crate::ApiError;
10
/// Axum extractor that parses the request body as JSON and then checks it against its
/// [`validator`](https://docs.rs/validator) rules before the handler ever runs.
///
/// When both steps succeed the handler receives `ValidatedJson(value)`. When either step
/// fails, the handler is skipped and the client gets a `(StatusCode, Json<ApiError>)`
/// response instead:
///
/// - body is not syntactically valid JSON -> `400 Bad Request`, code `INVALID_JSON`
/// - body is valid JSON but does not match the target type -> `422 Unprocessable Entity`,
///   code `INVALID_BODY`
/// - `Content-Type` is absent or is not JSON -> `415 Unsupported Media Type`, code
///   `UNSUPPORTED_MEDIA_TYPE`
/// - the value deserializes but breaks a validation rule -> `422 Unprocessable Entity`,
///   code `VALIDATION_ERROR`, with per-field `details` (built by [`ApiError`]'s
///   `From<validator::ValidationErrors>` impl)
/// - any other rejection raised by Axum's `Json` extractor (for example a body larger than
///   the configured limit) -> the status Axum reports for that rejection
///
/// Only available with the `validator` feature enabled.
///
/// # Example
///
/// ```rust,no_run
/// use axum_api_kit::ValidatedJson;
/// use serde::Deserialize;
/// use validator::Validate;
///
/// #[derive(Deserialize, Validate)]
/// struct CreateUser {
///     #[validate(length(min = 1, max = 100))]
///     name: String,
///     #[validate(email)]
///     email: String,
/// }
///
/// // By the time this body executes, the payload has been parsed and validated.
/// async fn create_user(ValidatedJson(user): ValidatedJson<CreateUser>) {
///     let _ = (user.name, user.email);
/// }
/// ```
#[derive(Clone, Debug)]
pub struct ValidatedJson<T>(pub T);

49impl<T, S> FromRequest<S> for ValidatedJson<T>
50where
51    T: DeserializeOwned + Validate,
52    S: Send + Sync,
53{
54    type Rejection = (StatusCode, Json<ApiError>);
55
56    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
57        let Json(value) = Json::<T>::from_request(req, state)
58            .await
59            .map_err(json_rejection_to_api_error)?;
60
61        value.validate().map_err(|errors| {
62            (
63                StatusCode::UNPROCESSABLE_ENTITY,
64                Json(ApiError::from(errors)),
65            )
66        })?;
67
68        Ok(ValidatedJson(value))
69    }
70}
71
72/// Map an Axum [`JsonRejection`] onto an [`ApiError`] with a stable machine-readable code.
73///
74/// The HTTP status is taken from the rejection itself so it stays in sync with Axum.
75fn json_rejection_to_api_error(rejection: JsonRejection) -> (StatusCode, Json<ApiError>) {
76    let code = match &rejection {
77        JsonRejection::JsonSyntaxError(_) => "INVALID_JSON",
78        JsonRejection::JsonDataError(_) => "INVALID_BODY",
79        JsonRejection::MissingJsonContentType(_) => "UNSUPPORTED_MEDIA_TYPE",
80        _ => "BAD_REQUEST",
81    };
82    (
83        rejection.status(),
84        Json(ApiError::new(code, rejection.body_text())),
85    )
86}
87
88#[cfg(test)]
89mod tests {
90    use super::*;
91    use axum::{body::Body, http::header::CONTENT_TYPE, http::Request};
92    use serde::Deserialize;
93
94    #[derive(Debug, Deserialize, Validate)]
95    struct Input {
96        #[validate(length(min = 2))]
97        name: String,
98    }
99
100    async fn extract(body: &str, with_content_type: bool) -> Result<Input, (StatusCode, ApiError)> {
101        let mut builder = Request::builder().method("POST").uri("/");
102        if with_content_type {
103            builder = builder.header(CONTENT_TYPE, "application/json");
104        }
105        let req = builder.body(Body::from(body.to_owned())).unwrap();
106        ValidatedJson::<Input>::from_request(req, &())
107            .await
108            .map(|ValidatedJson(v)| v)
109            .map_err(|(status, Json(err))| (status, err))
110    }
111
112    #[tokio::test]
113    async fn valid_body_extracts() {
114        let input = extract(r#"{"name":"abcd"}"#, true).await.unwrap();
115        assert_eq!(input.name, "abcd");
116    }
117
118    #[tokio::test]
119    async fn malformed_json_is_invalid_json() {
120        let (status, err) = extract("{not json", true).await.unwrap_err();
121        assert_eq!(status, StatusCode::BAD_REQUEST);
122        assert_eq!(err.code, "INVALID_JSON");
123    }
124
125    #[tokio::test]
126    async fn wrong_shape_is_invalid_body() {
127        let (status, err) = extract(r#"{"name":123}"#, true).await.unwrap_err();
128        assert_eq!(status, StatusCode::UNPROCESSABLE_ENTITY);
129        assert_eq!(err.code, "INVALID_BODY");
130    }
131
132    #[tokio::test]
133    async fn missing_content_type_is_unsupported_media_type() {
134        let (status, err) = extract(r#"{"name":"abcd"}"#, false).await.unwrap_err();
135        assert_eq!(status, StatusCode::UNSUPPORTED_MEDIA_TYPE);
136        assert_eq!(err.code, "UNSUPPORTED_MEDIA_TYPE");
137    }
138
139    #[tokio::test]
140    async fn validation_failure_is_validation_error_with_fields() {
141        let (status, err) = extract(r#"{"name":"a"}"#, true).await.unwrap_err();
142        assert_eq!(status, StatusCode::UNPROCESSABLE_ENTITY);
143        assert_eq!(err.code, "VALIDATION_ERROR");
144        let v = serde_json::to_value(&err).unwrap();
145        assert!(v["details"]["fields"]["name"].is_array());
146    }
147}