lmrc_http_common/
extractors.rs1use crate::error::HttpError;
35use async_trait::async_trait;
36use axum::{
37 extract::{FromRequest, Request},
38 Json,
39};
40use serde::de::DeserializeOwned;
41
42#[cfg(feature = "validation")]
43use validator::Validate;
44
/// Extractor newtype wrapping a request payload of type `T`.
///
/// The wrapped value is public, so handlers can destructure it directly:
/// `ValidatedJson(payload): ValidatedJson<MyPayload>`.
///
/// How the value is produced (JSON decoding, optional validation) is defined
/// by the `FromRequest` implementations for this type.
#[derive(Clone, Copy, Debug, Default)]
pub struct ValidatedJson<T>(pub T);
74
/// `FromRequest` implementation used when the `validation` feature is enabled.
///
/// The request body is decoded with axum's [`Json`] extractor and the decoded
/// value is then checked with [`Validate::validate`].
///
/// # Errors
///
/// * `HttpError::BadRequest` when the `Json` extractor rejects the request;
///   the rejection is rendered into the message after `"Invalid JSON: "`.
/// * `HttpError::ValidationError` when `validate()` fails; the message is
///   produced by `format_validation_errors`.
#[cfg(feature = "validation")]
#[async_trait]
impl<T, S> FromRequest<S> for ValidatedJson<T>
where
    T: DeserializeOwned + Validate,
    S: Send + Sync,
{
    type Rejection = HttpError;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        // Step 1: decode the body; every `Json` rejection becomes a bad request.
        let payload = match Json::<T>::from_request(req, state).await {
            Ok(Json(decoded)) => decoded,
            Err(rejection) => {
                return Err(HttpError::BadRequest(format!("Invalid JSON: {}", rejection)));
            }
        };

        // Step 2: run the validator rules declared on `T`.
        if let Err(violations) = payload.validate() {
            return Err(HttpError::ValidationError(format_validation_errors(
                &violations,
            )));
        }

        Ok(ValidatedJson(payload))
    }
}
96
97#[cfg(not(feature = "validation"))]
98#[async_trait]
99impl<T, S> FromRequest<S> for ValidatedJson<T>
100where
101 T: DeserializeOwned,
102 S: Send + Sync,
103{
104 type Rejection = HttpError;
105
106 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
107 let Json(value) = Json::<T>::from_request(req, state)
108 .await
109 .map_err(|e| HttpError::BadRequest(format!("Invalid JSON: {}", e)))?;
110
111 Ok(ValidatedJson(value))
112 }
113}
114
/// Extractor newtype wrapping query-string parameters of type `T`.
///
/// The wrapped value is public, so handlers can destructure it directly:
/// `ValidatedQuery(params): ValidatedQuery<MyParams>`.
///
/// How the value is produced (query decoding, optional validation) is defined
/// by the `FromRequest` implementations for this type.
#[derive(Clone, Copy, Debug, Default)]
pub struct ValidatedQuery<T>(pub T);
142
/// `FromRequest` implementation used when the `validation` feature is enabled.
///
/// The query string is decoded with axum's `Query` extractor and the decoded
/// value is then checked with [`Validate::validate`].
///
/// NOTE(review): this implements `FromRequest`, which takes the whole request
/// by value, although only the query string is read. In axum, extractors that
/// consume the request presumably have to be the last handler argument —
/// confirm whether `FromRequestParts` would be the better fit here.
///
/// # Errors
///
/// * `HttpError::BadRequest` when the `Query` extractor rejects the request;
///   the rejection is rendered after `"Invalid query parameters: "`.
/// * `HttpError::ValidationError` when `validate()` fails; the message is
///   produced by `format_validation_errors`.
#[cfg(feature = "validation")]
#[async_trait]
impl<T, S> FromRequest<S> for ValidatedQuery<T>
where
    T: DeserializeOwned + Validate,
    S: Send + Sync,
{
    type Rejection = HttpError;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        // Step 1: decode the query string; any rejection becomes a bad request.
        let params = match axum::extract::Query::<T>::from_request(req, state).await {
            Ok(axum::extract::Query(decoded)) => decoded,
            Err(rejection) => {
                return Err(HttpError::BadRequest(format!(
                    "Invalid query parameters: {}",
                    rejection
                )));
            }
        };

        // Step 2: run the validator rules declared on `T`.
        if let Err(violations) = params.validate() {
            return Err(HttpError::ValidationError(format_validation_errors(
                &violations,
            )));
        }

        Ok(ValidatedQuery(params))
    }
}
164
165#[cfg(not(feature = "validation"))]
166#[async_trait]
167impl<T, S> FromRequest<S> for ValidatedQuery<T>
168where
169 T: DeserializeOwned,
170 S: Send + Sync,
171{
172 type Rejection = HttpError;
173
174 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
175 let axum::extract::Query(value) = axum::extract::Query::<T>::from_request(req, state)
176 .await
177 .map_err(|e| HttpError::BadRequest(format!("Invalid query parameters: {}", e)))?;
178
179 Ok(ValidatedQuery(value))
180 }
181}
182
/// Renders validator errors as a single human-readable line.
///
/// Each individual error becomes `"<field>: <text>"`, where `<text>` is the
/// error's custom message when one is set and `"validation failed (<code>)"`
/// otherwise. Entries are joined with `"; "`.
///
/// Fields are emitted in ascending name order. `field_errors()` hands back a
/// `HashMap`, whose iteration order is unspecified, so without sorting the same
/// invalid request could produce differently ordered messages on each call.
///
/// Returns the fallback `"Validation failed"` when `field_errors()` yields no
/// entries.
///
/// NOTE(review): `field_errors()` appears to cover field-level errors only, so
/// failures inside nested structs or lists would presumably end up as the bare
/// fallback text — confirm against the validator documentation.
#[cfg(feature = "validation")]
fn format_validation_errors(errors: &validator::ValidationErrors) -> String {
    // Field names are unique map keys, so an unstable sort is sufficient.
    let mut by_field: Vec<_> = errors.field_errors().into_iter().collect();
    by_field.sort_unstable_by(|left, right| left.0.cmp(&right.0));

    let mut message = String::new();

    for (field, field_errors) in by_field {
        for error in field_errors {
            // Every entry written so far is non-empty, so a non-empty buffer
            // means at least one entry precedes this one.
            if !message.is_empty() {
                message.push_str("; ");
            }

            message.push_str(&field);
            message.push_str(": ");

            match &error.message {
                Some(custom) => message.push_str(custom),
                None => {
                    message.push_str("validation failed (");
                    message.push_str(&error.code);
                    message.push(')');
                }
            }
        }
    }

    if message.is_empty() {
        "Validation failed".to_string()
    } else {
        message
    }
}
214
// Integration-style tests for `ValidatedJson`; they only build when the
// `validation` feature is enabled.
#[cfg(all(test, feature = "validation"))]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::post,
        Router,
    };
    use serde::{Deserialize, Serialize};
    use tower::ServiceExt;
    use validator::Validate;

    /// Payload with one length-constrained and one range-constrained field.
    #[derive(Debug, Deserialize, Serialize, Validate)]
    struct TestPayload {
        #[validate(length(min = 3, max = 10))]
        name: String,
        #[validate(range(min = 18, max = 100))]
        age: u32,
    }

    /// Handler that is only reached when extraction and validation succeed.
    async fn test_handler(ValidatedJson(payload): ValidatedJson<TestPayload>) -> StatusCode {
        assert_eq!(payload.name.len(), 5);
        StatusCode::OK
    }

    /// Sends `body` as a JSON `POST /` to a fresh router that mounts
    /// `test_handler`, and returns the response status.
    async fn post_json(body: Body) -> StatusCode {
        let request = Request::builder()
            .method("POST")
            .uri("/")
            .header("content-type", "application/json")
            .body(body)
            .unwrap();

        Router::new()
            .route("/", post(test_handler))
            .oneshot(request)
            .await
            .unwrap()
            .status()
    }

    /// Serializes `payload` to JSON and posts it.
    async fn post_payload(payload: &TestPayload) -> StatusCode {
        post_json(Body::from(serde_json::to_string(payload).unwrap())).await
    }

    #[tokio::test]
    async fn test_validated_json_success() {
        let payload = TestPayload {
            name: "Alice".to_string(),
            age: 25,
        };

        assert_eq!(post_payload(&payload).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_validated_json_validation_error() {
        // Two characters is below the minimum length of 3 declared on `name`.
        let payload = TestPayload {
            name: "AB".to_string(),
            age: 25,
        };

        assert_eq!(
            post_payload(&payload).await,
            StatusCode::UNPROCESSABLE_ENTITY
        );
    }

    #[tokio::test]
    async fn test_validated_json_invalid_json() {
        let status = post_json(Body::from("invalid json")).await;

        assert_eq!(status, StatusCode::BAD_REQUEST);
    }
}