axum_api_kit/
validated.rs1use axum::{
2 extract::{rejection::JsonRejection, FromRequest, Request},
3 http::StatusCode,
4 Json,
5};
6use serde::de::DeserializeOwned;
7use validator::Validate;
8
9use crate::ApiError;
10
/// JSON body extractor that deserializes the request body into `T` and then
/// runs `T`'s [`Validate`] rules before the value reaches the handler.
///
/// The inner value is public so handlers can destructure it directly:
/// `ValidatedJson(payload): ValidatedJson<MyInput>`.
///
/// On failure the extractor rejects with `(StatusCode, Json<ApiError>)`; see
/// the `FromRequest` impl for which status/code each failure produces.
#[derive(Debug, Clone)]
pub struct ValidatedJson<T>(pub T);
48
49impl<T, S> FromRequest<S> for ValidatedJson<T>
50where
51 T: DeserializeOwned + Validate,
52 S: Send + Sync,
53{
54 type Rejection = (StatusCode, Json<ApiError>);
55
56 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
57 let Json(value) = Json::<T>::from_request(req, state)
58 .await
59 .map_err(json_rejection_to_api_error)?;
60
61 value.validate().map_err(|errors| {
62 (
63 StatusCode::UNPROCESSABLE_ENTITY,
64 Json(ApiError::from(errors)),
65 )
66 })?;
67
68 Ok(ValidatedJson(value))
69 }
70}
71
72fn json_rejection_to_api_error(rejection: JsonRejection) -> (StatusCode, Json<ApiError>) {
76 let code = match &rejection {
77 JsonRejection::JsonSyntaxError(_) => "INVALID_JSON",
78 JsonRejection::JsonDataError(_) => "INVALID_BODY",
79 JsonRejection::MissingJsonContentType(_) => "UNSUPPORTED_MEDIA_TYPE",
80 _ => "BAD_REQUEST",
81 };
82 (
83 rejection.status(),
84 Json(ApiError::new(code, rejection.body_text())),
85 )
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{header::CONTENT_TYPE, Request},
    };
    use serde::Deserialize;

    /// Minimal payload used to drive every branch of the extractor.
    #[derive(Debug, Deserialize, Validate)]
    struct Input {
        #[validate(length(min = 2))]
        name: String,
    }

    /// Runs `ValidatedJson<Input>` against a `POST /` request whose body is
    /// `body`.
    ///
    /// When `with_content_type` is false, the `Content-Type` header is omitted
    /// so the missing-header path can be exercised. On rejection, the `Json`
    /// wrapper is stripped so tests can inspect the `ApiError` directly.
    async fn extract(body: &str, with_content_type: bool) -> Result<Input, (StatusCode, ApiError)> {
        let builder = Request::builder().method("POST").uri("/");
        let builder = if with_content_type {
            builder.header(CONTENT_TYPE, "application/json")
        } else {
            builder
        };
        let request = builder
            .body(Body::from(body.to_owned()))
            .expect("static request parts are always valid");

        match ValidatedJson::<Input>::from_request(request, &()).await {
            Ok(ValidatedJson(input)) => Ok(input),
            Err((status, Json(error))) => Err((status, error)),
        }
    }

    /// Like `extract`, but requires the extraction to be rejected and
    /// returns the rejection's status and error body.
    async fn reject(body: &str, with_content_type: bool) -> (StatusCode, ApiError) {
        extract(body, with_content_type)
            .await
            .expect_err("extraction was expected to be rejected")
    }

    #[tokio::test]
    async fn valid_body_extracts() {
        let input = extract(r#"{"name":"abcd"}"#, true)
            .await
            .expect("a well-formed, valid body should extract");
        assert_eq!(input.name, "abcd");
    }

    #[tokio::test]
    async fn malformed_json_is_invalid_json() {
        let (status, err) = reject("{not json", true).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(err.code, "INVALID_JSON");
    }

    #[tokio::test]
    async fn wrong_shape_is_invalid_body() {
        let (status, err) = reject(r#"{"name":123}"#, true).await;
        assert_eq!(status, StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(err.code, "INVALID_BODY");
    }

    #[tokio::test]
    async fn missing_content_type_is_unsupported_media_type() {
        let (status, err) = reject(r#"{"name":"abcd"}"#, false).await;
        assert_eq!(status, StatusCode::UNSUPPORTED_MEDIA_TYPE);
        assert_eq!(err.code, "UNSUPPORTED_MEDIA_TYPE");
    }

    #[tokio::test]
    async fn validation_failure_is_validation_error_with_fields() {
        let (status, err) = reject(r#"{"name":"a"}"#, true).await;
        assert_eq!(status, StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(err.code, "VALIDATION_ERROR");

        let serialized = serde_json::to_value(&err).unwrap();
        assert!(serialized["details"]["fields"]["name"].is_array());
    }
}