by_loco/controller/extractor/
validate.rs

1use crate::Error;
2use axum::extract::{Form, FromRequest, Json, Request};
3use serde::de::DeserializeOwned;
4use validator::Validate;
5
6#[derive(Debug, Clone, Copy, Default)]
7pub struct JsonValidateWithMessage<T>(pub T);
8
9impl<T, S> FromRequest<S> for JsonValidateWithMessage<T>
10where
11    T: DeserializeOwned + Validate,
12    S: Send + Sync,
13{
14    type Rejection = Error;
15
16    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
17        let Json(value) = Json::<T>::from_request(req, state).await?;
18        value.validate()?;
19        Ok(Self(value))
20    }
21}
22
23#[derive(Debug, Clone, Copy, Default)]
24pub struct FormValidateWithMessage<T>(pub T);
25
26impl<T, S> FromRequest<S> for FormValidateWithMessage<T>
27where
28    T: DeserializeOwned + Validate,
29    S: Send + Sync,
30{
31    type Rejection = Error;
32
33    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
34        let Form(value) = Form::<T>::from_request(req, state).await?;
35        value.validate()?;
36        Ok(Self(value))
37    }
38}
39
40#[derive(Debug, Clone, Copy, Default)]
41pub struct JsonValidate<T>(pub T);
42
43impl<T, S> FromRequest<S> for JsonValidate<T>
44where
45    T: DeserializeOwned + Validate,
46    S: Send + Sync,
47{
48    type Rejection = Error;
49
50    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
51        let Json(value) = Json::<T>::from_request(req, state).await?;
52        value.validate().map_err(|err| {
53            tracing::debug!(err = ?err, "request validation error occurred");
54            Error::BadRequest(String::new())
55        })?;
56        Ok(Self(value))
57    }
58}
59
60#[derive(Debug, Clone, Copy, Default)]
61pub struct FormValidate<T>(pub T);
62
63impl<T, S> FromRequest<S> for FormValidate<T>
64where
65    T: DeserializeOwned + Validate,
66    S: Send + Sync,
67{
68    type Rejection = Error;
69
70    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
71        let Form(value) = Form::<T>::from_request(req, state).await?;
72        value.validate().map_err(|err| {
73            tracing::debug!(err = ?err, "request validation error occurred");
74            Error::BadRequest(String::new())
75        })?;
76        Ok(Self(value))
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::{to_bytes, Body},
        http::{self, Request as HttpRequest, StatusCode},
        response::IntoResponse,
    };
    use serde::{Deserialize, Serialize};
    use serde_json::{json, Value};
    use validator::Validate;

    /// JSON body that deserializes into a [`TestUser`] passing every rule.
    const VALID_JSON: &str = r#"{"username": "valid_user", "email": "test@example.com"}"#;
    /// JSON body that deserializes but violates both the username and email rules.
    const INVALID_JSON: &str = r#"{"username": "ab", "email": "invalid-email"}"#;
    /// URL-encoded equivalent of [`VALID_JSON`].
    const VALID_FORM: &str = "username=valid_user&email=test@example.com";
    /// URL-encoded equivalent of [`INVALID_JSON`].
    const INVALID_FORM: &str = "username=ab&email=invalid-email";
    /// Truncated JSON: the raw-string delimiter consumes the final quote, so both
    /// the `email` string and the enclosing object are left unterminated.
    const MALFORMED_JSON: &str = r#"{"username": "valid_user", "email": "test@example.com"#;
    /// Upper bound passed to `to_bytes` when reading an error response body.
    const MAX_BODY_BYTES: usize = 1024 * 1024;

    /// Payload fixture with one length rule and one email rule, so a single
    /// invalid request produces errors on two independent fields.
    #[derive(Debug, Serialize, Deserialize, Validate)]
    struct TestUser {
        #[validate(length(min = 3, message = "username must be at least 3 characters"))]
        username: String,
        #[validate(email(message = "email must be valid"))]
        email: String,
    }

    /// Builds a `POST /test` request with the given `Content-Type` and body.
    fn create_request(content_type: &str, body: &str) -> HttpRequest<Body> {
        HttpRequest::builder()
            .method(http::Method::POST)
            .uri("/test")
            .header(http::header::CONTENT_TYPE, content_type)
            .body(Body::from(body.to_owned()))
            .expect("method, URI and header values are valid")
    }

    /// Builds a JSON `POST /test` request.
    fn create_json_request(json: &str) -> HttpRequest<Body> {
        create_request("application/json", json)
    }

    /// Builds a URL-encoded form `POST /test` request.
    fn create_form_request(form_data: &str) -> HttpRequest<Body> {
        create_request("application/x-www-form-urlencoded", form_data)
    }

    /// Asserts the fields decoded from [`VALID_JSON`] / [`VALID_FORM`].
    fn assert_valid_fixture(user: &TestUser) {
        assert_eq!(user.username, "valid_user");
        assert_eq!(user.email, "test@example.com");
    }

    /// Asserts that `err` is `Error::BadRequest` with an empty message, which is
    /// how `JsonValidate` / `FormValidate` report validation failures.
    fn assert_empty_bad_request(err: &Error) {
        match err {
            Error::BadRequest(msg) => {
                assert!(msg.is_empty(), "expected empty message, got {msg:?}");
            }
            other => panic!("Expected BadRequest error, got {other:?}"),
        }
    }

    /// Response body produced for [`INVALID_JSON`] / [`INVALID_FORM`] by the
    /// `*WithMessage` extractors.
    fn validation_error_body() -> Value {
        json!({
            "errors": {
                "username": [
                    {
                        "code": "length",
                        "message": "username must be at least 3 characters",
                        "params": {
                            "min": 3,
                            "value": "ab"
                        }
                    }
                ],
                "email": [
                    {
                        "code": "email",
                        "message": "email must be valid",
                        "params": {
                            "value": "invalid-email"
                        }
                    }
                ]
            }
        })
    }

    /// Generic 400 body that hides validation and parse details.
    fn bad_request_body() -> Value {
        json!({
            "error": "Bad Request"
        })
    }

    /// Renders `err` as an HTTP response and asserts its status and JSON body.
    async fn assert_response_status_and_body(
        err: Error,
        expected_status: StatusCode,
        expected_json: Value,
    ) {
        let response = err.into_response();
        assert_eq!(response.status(), expected_status);

        let body = to_bytes(response.into_body(), MAX_BODY_BYTES)
            .await
            .expect("Failed to read response body");

        // `from_slice` rejects invalid UTF-8 itself; no intermediate `String` needed.
        let actual_json: Value =
            serde_json::from_slice(&body).expect("Response body is not valid JSON");

        assert_eq!(actual_json, expected_json);
    }

    #[tokio::test]
    async fn test_json_validate_with_message_valid() {
        let request = create_json_request(VALID_JSON);

        let JsonValidateWithMessage(user) =
            JsonValidateWithMessage::<TestUser>::from_request(request, &())
                .await
                .expect("valid JSON should pass extraction and validation");

        assert_valid_fixture(&user);
    }

    #[tokio::test]
    async fn test_json_validate_with_message_invalid() {
        let request = create_json_request(INVALID_JSON);

        let err = JsonValidateWithMessage::<TestUser>::from_request(request, &())
            .await
            .expect_err("invalid JSON payload should be rejected");

        assert_response_status_and_body(err, StatusCode::BAD_REQUEST, validation_error_body())
            .await;
    }

    #[tokio::test]
    async fn test_form_validate_with_message_valid() {
        let request = create_form_request(VALID_FORM);

        let FormValidateWithMessage(user) =
            FormValidateWithMessage::<TestUser>::from_request(request, &())
                .await
                .expect("valid form should pass extraction and validation");

        assert_valid_fixture(&user);
    }

    #[tokio::test]
    async fn test_form_validate_with_message_invalid() {
        let request = create_form_request(INVALID_FORM);

        let err = FormValidateWithMessage::<TestUser>::from_request(request, &())
            .await
            .expect_err("invalid form payload should be rejected");

        assert_response_status_and_body(err, StatusCode::BAD_REQUEST, validation_error_body())
            .await;
    }

    #[tokio::test]
    async fn test_json_validate_valid() {
        let request = create_json_request(VALID_JSON);

        let JsonValidate(user) = JsonValidate::<TestUser>::from_request(request, &())
            .await
            .expect("valid JSON should pass extraction and validation");

        assert_valid_fixture(&user);
    }

    #[tokio::test]
    async fn test_json_validate_invalid() {
        let request = create_json_request(INVALID_JSON);

        let err = JsonValidate::<TestUser>::from_request(request, &())
            .await
            .expect_err("invalid JSON payload should be rejected");

        assert_empty_bad_request(&err);
        assert_response_status_and_body(err, StatusCode::BAD_REQUEST, bad_request_body()).await;
    }

    #[tokio::test]
    async fn test_form_validate_valid() {
        let request = create_form_request(VALID_FORM);

        let FormValidate(user) = FormValidate::<TestUser>::from_request(request, &())
            .await
            .expect("valid form should pass extraction and validation");

        assert_valid_fixture(&user);
    }

    #[tokio::test]
    async fn test_form_validate_invalid() {
        let request = create_form_request(INVALID_FORM);

        let err = FormValidate::<TestUser>::from_request(request, &())
            .await
            .expect_err("invalid form payload should be rejected");

        assert_empty_bad_request(&err);
        assert_response_status_and_body(err, StatusCode::BAD_REQUEST, bad_request_body()).await;
    }

    #[tokio::test]
    async fn test_malformed_json() {
        let request = create_json_request(MALFORMED_JSON);

        let err = JsonValidate::<TestUser>::from_request(request, &())
            .await
            .expect_err("truncated JSON should be rejected");

        assert_response_status_and_body(err, StatusCode::BAD_REQUEST, bad_request_body()).await;
    }

    #[tokio::test]
    async fn test_malformed_json_with_message() {
        // Decoding goes through the same `Json` rejection conversion as
        // `JsonValidate`, so the response is the same generic 400 body.
        let request = create_json_request(MALFORMED_JSON);

        let err = JsonValidateWithMessage::<TestUser>::from_request(request, &())
            .await
            .expect_err("truncated JSON should be rejected");

        assert_response_status_and_body(err, StatusCode::BAD_REQUEST, bad_request_body()).await;
    }

    #[tokio::test]
    async fn test_malformed_form() {
        // NOTE(review): `%in` is not a valid percent escape, but URL-encoded form
        // parsing is presumably lenient about it, so this request most likely fails
        // because no `email` field is present rather than because of the encoding.
        // The 500 asserted below pins current behavior for what is a client-side
        // input error — confirm whether a 4xx is intended.
        let request = create_form_request("username=valid_user&email%invalid_format");

        let err = FormValidate::<TestUser>::from_request(request, &())
            .await
            .expect_err("form without a usable `email` field should be rejected");

        let expected = json!({
            "error": "internal_server_error",
            "description": "Internal Server Error"
        });

        assert_response_status_and_body(err, StatusCode::INTERNAL_SERVER_ERROR, expected).await;
    }
}