lmrc_http_common/
extractors.rs

//! Request extractors with automatic validation
//!
//! This module provides Axum extractors that deserialize request data and, when
//! the crate's `validation` feature is enabled, automatically validate it using
//! the `validator` crate. Without that feature the extractors only deserialize.
//!
//! ## Example
//!
//! ```rust
//! use axum::{Router, routing::post};
//! use lmrc_http_common::extractors::ValidatedJson;
//! use serde::Deserialize;
//! use validator::Validate;
//!
//! #[derive(Debug, Deserialize, Validate)]
//! struct CreateUser {
//!     #[validate(length(min = 3, max = 50))]
//!     username: String,
//!     #[validate(email)]
//!     email: String,
//! }
//!
//! async fn create_user(
//!     ValidatedJson(payload): ValidatedJson<CreateUser>
//! ) -> &'static str {
//!     // payload is automatically validated!
//!     "User created"
//! }
//!
//! # async fn example() {
//! let app: Router = Router::new().route("/users", post(create_user));
//! # }
//! ```
33
34use crate::error::HttpError;
35use async_trait::async_trait;
36use axum::{
37    extract::{FromRequest, Request},
38    Json,
39};
40use serde::de::DeserializeOwned;
41
42#[cfg(feature = "validation")]
43use validator::Validate;
44
/// JSON body extractor with automatic validation.
///
/// Deserializes the request body as JSON into `T` using [`axum::Json`]. When the
/// crate's `validation` feature is enabled, `T` must also implement
/// `validator::Validate` and is validated right after deserialization; without
/// that feature the extractor only deserializes and performs no validation.
///
/// Like [`axum::Json`], this extractor consumes the request body, so it must be
/// the last extractor in a handler's argument list.
///
/// # Errors
///
/// - Any JSON rejection (unreadable body, missing or wrong `Content-Type`,
///   malformed JSON, type mismatch): rejects with `HttpError::BadRequest`
///   whose message starts with `"Invalid JSON: "`.
/// - Validation failure (`validation` feature only): rejects with
///   `HttpError::ValidationError` carrying a formatted summary of the
///   failures, which is returned to the client as 422 Unprocessable Entity.
///
/// ## Example
///
/// ```rust
/// use lmrc_http_common::extractors::ValidatedJson;
/// use serde::Deserialize;
/// use validator::Validate;
///
/// #[derive(Deserialize, Validate)]
/// struct SignupRequest {
///     #[validate(length(min = 3))]
///     username: String,
///     #[validate(email)]
///     email: String,
/// }
///
/// async fn signup(
///     ValidatedJson(req): ValidatedJson<SignupRequest>
/// ) -> &'static str {
///     "Success"
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);
74
#[cfg(feature = "validation")]
#[async_trait]
impl<T, S> FromRequest<S> for ValidatedJson<T>
where
    T: DeserializeOwned + Validate,
    S: Send + Sync,
{
    type Rejection = HttpError;

    /// Parses the body as JSON, then runs `Validate::validate` on the result.
    ///
    /// A JSON rejection becomes `HttpError::BadRequest`; a validation failure
    /// becomes `HttpError::ValidationError` with a formatted error summary.
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        let payload = match Json::<T>::from_request(req, state).await {
            Ok(Json(payload)) => payload,
            Err(rejection) => {
                return Err(HttpError::BadRequest(format!("Invalid JSON: {}", rejection)));
            }
        };

        match payload.validate() {
            Ok(()) => Ok(ValidatedJson(payload)),
            Err(errors) => Err(HttpError::ValidationError(format_validation_errors(
                &errors,
            ))),
        }
    }
}
96
97#[cfg(not(feature = "validation"))]
98#[async_trait]
99impl<T, S> FromRequest<S> for ValidatedJson<T>
100where
101    T: DeserializeOwned,
102    S: Send + Sync,
103{
104    type Rejection = HttpError;
105
106    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
107        let Json(value) = Json::<T>::from_request(req, state)
108            .await
109            .map_err(|e| HttpError::BadRequest(format!("Invalid JSON: {}", e)))?;
110
111        Ok(ValidatedJson(value))
112    }
113}
114
115/// Query parameters extractor with automatic validation
116///
117/// This extractor deserializes query parameters and automatically validates them.
118///
119/// ## Example
120///
121/// ```rust
122/// use lmrc_http_common::extractors::ValidatedQuery;
123/// use serde::Deserialize;
124/// use validator::Validate;
125///
126/// #[derive(Deserialize, Validate)]
127/// struct Pagination {
128///     #[validate(range(min = 1, max = 100))]
129///     page: u32,
130///     #[validate(range(min = 1, max = 100))]
131///     per_page: u32,
132/// }
133///
134/// async fn list_items(
135///     ValidatedQuery(params): ValidatedQuery<Pagination>
136/// ) -> &'static str {
137///     "Items list"
138/// }
139/// ```
140#[derive(Debug, Clone, Copy, Default)]
141pub struct ValidatedQuery<T>(pub T);
142
143#[cfg(feature = "validation")]
144#[async_trait]
145impl<T, S> FromRequest<S> for ValidatedQuery<T>
146where
147    T: DeserializeOwned + Validate,
148    S: Send + Sync,
149{
150    type Rejection = HttpError;
151
152    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
153        let axum::extract::Query(value) = axum::extract::Query::<T>::from_request(req, state)
154            .await
155            .map_err(|e| HttpError::BadRequest(format!("Invalid query parameters: {}", e)))?;
156
157        value
158            .validate()
159            .map_err(|e| HttpError::ValidationError(format_validation_errors(&e)))?;
160
161        Ok(ValidatedQuery(value))
162    }
163}
164
165#[cfg(not(feature = "validation"))]
166#[async_trait]
167impl<T, S> FromRequest<S> for ValidatedQuery<T>
168where
169    T: DeserializeOwned,
170    S: Send + Sync,
171{
172    type Rejection = HttpError;
173
174    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
175        let axum::extract::Query(value) = axum::extract::Query::<T>::from_request(req, state)
176            .await
177            .map_err(|e| HttpError::BadRequest(format!("Invalid query parameters: {}", e)))?;
178
179        Ok(ValidatedQuery(value))
180    }
181}
182
/// Format validation errors into a user-friendly string.
///
/// Every failing rule becomes one `path: message` entry, or
/// `path: validation failed (<code>)` when the rule has no custom message.
/// Entries are joined with `"; "`. Errors from nested structs and lists are
/// included under dotted / indexed paths such as `address.city` or
/// `items[0].name`. Field names are sorted at every level, so the same errors
/// always yield the same message even though `validator` stores them in a
/// `HashMap`. Returns `"Validation failed"` if no entry could be produced.
#[cfg(feature = "validation")]
fn format_validation_errors(errors: &validator::ValidationErrors) -> String {
    let mut entries = Vec::new();
    collect_validation_errors(errors, "", &mut entries);

    if entries.is_empty() {
        "Validation failed".to_string()
    } else {
        entries.join("; ")
    }
}

/// Recursively appends one formatted entry per failing rule in `errors` to
/// `out`, prefixing field names with `prefix` (empty at the top level).
#[cfg(feature = "validation")]
fn collect_validation_errors(
    errors: &validator::ValidationErrors,
    prefix: &str,
    out: &mut Vec<String>,
) {
    use validator::ValidationErrorsKind;

    // `errors()` is a HashMap whose iteration order is unspecified; sort by
    // field name so the resulting message is deterministic.
    let mut fields: Vec<_> = errors.errors().iter().collect();
    fields.sort_by(|(a, _), (b, _)| a.cmp(b));

    for (field, kind) in fields {
        let path = if prefix.is_empty() {
            field.to_string()
        } else {
            format!("{}.{}", prefix, field)
        };

        match kind {
            ValidationErrorsKind::Field(field_errors) => {
                out.extend(field_errors.iter().map(|error| match &error.message {
                    Some(msg) => format!("{}: {}", path, msg),
                    None => format!("{}: validation failed ({})", path, error.code),
                }));
            }
            ValidationErrorsKind::Struct(nested) => {
                collect_validation_errors(nested, &path, out);
            }
            ValidationErrorsKind::List(items) => {
                // Keyed by element index in a BTreeMap, so already in order.
                for (index, nested) in items {
                    collect_validation_errors(nested, &format!("{}[{}]", path, index), out);
                }
            }
        }
    }
}
214
#[cfg(all(test, feature = "validation"))]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::{get, post},
        Router,
    };
    use serde::{Deserialize, Serialize};
    use tower::ServiceExt;
    use validator::Validate;

    #[derive(Debug, Deserialize, Serialize, Validate)]
    struct TestPayload {
        #[validate(length(min = 3, max = 10))]
        name: String,
        #[validate(range(min = 18, max = 100))]
        age: u32,
    }

    #[derive(Debug, Deserialize, Validate)]
    struct TestQuery {
        #[validate(range(min = 1, max = 100))]
        page: u32,
    }

    async fn test_handler(ValidatedJson(payload): ValidatedJson<TestPayload>) -> StatusCode {
        assert_eq!(payload.name.len(), 5);
        StatusCode::OK
    }

    async fn query_handler(ValidatedQuery(query): ValidatedQuery<TestQuery>) -> StatusCode {
        assert_eq!(query.page, 2);
        StatusCode::OK
    }

    /// Router with a single `POST /` route using `ValidatedJson`.
    fn json_app() -> Router {
        Router::new().route("/", post(test_handler))
    }

    /// Router with a single `GET /` route using `ValidatedQuery`.
    fn query_app() -> Router {
        Router::new().route("/", get(query_handler))
    }

    /// Builds a `POST /` request carrying `body` with a JSON content type.
    fn json_request(body: impl Into<Body>) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri("/")
            .header("content-type", "application/json")
            .body(body.into())
            .unwrap()
    }

    /// Builds a body-less `GET` request for `uri`.
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder()
            .method("GET")
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// Sends `request` through `app` and returns the response status.
    async fn status_of(app: Router, request: Request<Body>) -> StatusCode {
        app.oneshot(request).await.unwrap().status()
    }

    fn payload_json(name: &str, age: u32) -> String {
        serde_json::to_string(&TestPayload {
            name: name.to_string(),
            age,
        })
        .unwrap()
    }

    #[tokio::test]
    async fn test_validated_json_success() {
        let status = status_of(json_app(), json_request(payload_json("Alice", 25))).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_validated_json_validation_error() {
        // Name is too short.
        let status = status_of(json_app(), json_request(payload_json("AB", 25))).await;
        assert_eq!(status, StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[tokio::test]
    async fn test_validated_json_out_of_range() {
        // Age is below the allowed minimum.
        let status = status_of(json_app(), json_request(payload_json("Alice", 17))).await;
        assert_eq!(status, StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[tokio::test]
    async fn test_validated_json_invalid_json() {
        let status = status_of(json_app(), json_request("invalid json")).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_validated_query_success() {
        let status = status_of(query_app(), get_request("/?page=2")).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_validated_query_validation_error() {
        let status = status_of(query_app(), get_request("/?page=0")).await;
        assert_eq!(status, StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[tokio::test]
    async fn test_validated_query_invalid_params() {
        let status = status_of(query_app(), get_request("/?page=abc")).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
    }
}