by_loco/controller/extractor/
validate.rs1use crate::Error;
2use axum::extract::{Form, FromRequest, Json, Request};
3use serde::de::DeserializeOwned;
4use validator::Validate;
5
6#[derive(Debug, Clone, Copy, Default)]
7pub struct JsonValidateWithMessage<T>(pub T);
8
9impl<T, S> FromRequest<S> for JsonValidateWithMessage<T>
10where
11 T: DeserializeOwned + Validate,
12 S: Send + Sync,
13{
14 type Rejection = Error;
15
16 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
17 let Json(value) = Json::<T>::from_request(req, state).await?;
18 value.validate()?;
19 Ok(Self(value))
20 }
21}
22
23#[derive(Debug, Clone, Copy, Default)]
24pub struct FormValidateWithMessage<T>(pub T);
25
26impl<T, S> FromRequest<S> for FormValidateWithMessage<T>
27where
28 T: DeserializeOwned + Validate,
29 S: Send + Sync,
30{
31 type Rejection = Error;
32
33 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
34 let Form(value) = Form::<T>::from_request(req, state).await?;
35 value.validate()?;
36 Ok(Self(value))
37 }
38}
39
40#[derive(Debug, Clone, Copy, Default)]
41pub struct JsonValidate<T>(pub T);
42
43impl<T, S> FromRequest<S> for JsonValidate<T>
44where
45 T: DeserializeOwned + Validate,
46 S: Send + Sync,
47{
48 type Rejection = Error;
49
50 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
51 let Json(value) = Json::<T>::from_request(req, state).await?;
52 value.validate().map_err(|err| {
53 tracing::debug!(err = ?err, "request validation error occurred");
54 Error::BadRequest(String::new())
55 })?;
56 Ok(Self(value))
57 }
58}
59
60#[derive(Debug, Clone, Copy, Default)]
61pub struct FormValidate<T>(pub T);
62
63impl<T, S> FromRequest<S> for FormValidate<T>
64where
65 T: DeserializeOwned + Validate,
66 S: Send + Sync,
67{
68 type Rejection = Error;
69
70 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
71 let Form(value) = Form::<T>::from_request(req, state).await?;
72 value.validate().map_err(|err| {
73 tracing::debug!(err = ?err, "request validation error occurred");
74 Error::BadRequest(String::new())
75 })?;
76 Ok(Self(value))
77 }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::{to_bytes, Body},
        http::{self, Request as HttpRequest, StatusCode},
        response::IntoResponse,
    };
    use serde::{Deserialize, Serialize};
    use serde_json::{json, Value};
    use validator::Validate;

    /// Payload shared by all tests: `username` must be at least 3 characters
    /// long and `email` must pass `validator`'s email check.
    #[derive(Debug, Serialize, Deserialize, Validate)]
    struct TestUser {
        #[validate(length(min = 3, message = "username must be at least 3 characters"))]
        username: String,
        #[validate(email(message = "email must be valid"))]
        email: String,
    }

    /// Builds a `POST /test` request with the given `Content-Type` and body.
    fn create_request(content_type: &str, body: &str) -> HttpRequest<Body> {
        HttpRequest::builder()
            .method(http::Method::POST)
            .uri("/test")
            .header(http::header::CONTENT_TYPE, content_type)
            .body(Body::from(body.to_owned()))
            .expect("test request should be well-formed")
    }

    /// Builds a `POST /test` request with an `application/json` body.
    fn create_json_request(json: &str) -> HttpRequest<Body> {
        create_request("application/json", json)
    }

    /// Builds a `POST /test` request with an `application/x-www-form-urlencoded` body.
    fn create_form_request(form_data: &str) -> HttpRequest<Body> {
        create_request("application/x-www-form-urlencoded", form_data)
    }

    /// Response body the `*WithMessage` extractors are expected to produce for
    /// a `TestUser` with `username = "ab"` and `email = "invalid-email"`.
    fn expected_validation_errors() -> Value {
        json!({
            "errors": {
                "username": [
                    {
                        "code": "length",
                        "message": "username must be at least 3 characters",
                        "params": {
                            "min": 3,
                            "value": "ab"
                        }
                    }
                ],
                "email": [
                    {
                        "code": "email",
                        "message": "email must be valid",
                        "params": {
                            "value": "invalid-email"
                        }
                    }
                ]
            }
        })
    }

    /// Renders `err` into an HTTP response and asserts its status and JSON body.
    async fn assert_response_status_and_body(
        err: Error,
        expected_status: StatusCode,
        expected_json: Value,
    ) {
        let response = err.into_response();
        assert_eq!(response.status(), expected_status);

        let body = to_bytes(response.into_body(), 1024 * 1024)
            .await
            .expect("Failed to read response body");
        let actual_json: Value =
            serde_json::from_slice(&body).expect("Response body is not valid JSON");

        assert_eq!(actual_json, expected_json);
    }

    /// Asserts that `err` is `Error::BadRequest` carrying an empty message,
    /// printing the actual error otherwise.
    fn assert_empty_bad_request(err: &Error) {
        assert!(
            matches!(err, Error::BadRequest(msg) if msg.is_empty()),
            "expected Error::BadRequest with an empty message, got {err:?}"
        );
    }

    #[tokio::test]
    async fn test_json_validate_with_message_valid() {
        let valid_json = r#"{"username": "valid_user", "email": "test@example.com"}"#;
        let request = create_json_request(valid_json);

        let user = JsonValidateWithMessage::<TestUser>::from_request(request, &())
            .await
            .expect("valid JSON payload should be accepted")
            .0;
        assert_eq!(user.username, "valid_user");
        assert_eq!(user.email, "test@example.com");
    }

    #[tokio::test]
    async fn test_json_validate_with_message_invalid() {
        let invalid_json = r#"{"username": "ab", "email": "invalid-email"}"#;
        let request = create_json_request(invalid_json);

        let err = JsonValidateWithMessage::<TestUser>::from_request(request, &())
            .await
            .expect_err("payload violating both rules should be rejected");

        assert_response_status_and_body(err, StatusCode::BAD_REQUEST, expected_validation_errors())
            .await;
    }

    #[tokio::test]
    async fn test_form_validate_with_message_valid() {
        let valid_form = "username=valid_user&email=test@example.com";
        let request = create_form_request(valid_form);

        let user = FormValidateWithMessage::<TestUser>::from_request(request, &())
            .await
            .expect("valid form payload should be accepted")
            .0;
        assert_eq!(user.username, "valid_user");
        assert_eq!(user.email, "test@example.com");
    }

    #[tokio::test]
    async fn test_form_validate_with_message_invalid() {
        let invalid_form = "username=ab&email=invalid-email";
        let request = create_form_request(invalid_form);

        let err = FormValidateWithMessage::<TestUser>::from_request(request, &())
            .await
            .expect_err("payload violating both rules should be rejected");

        assert_response_status_and_body(err, StatusCode::BAD_REQUEST, expected_validation_errors())
            .await;
    }

    #[tokio::test]
    async fn test_json_validate_valid() {
        let valid_json = r#"{"username": "valid_user", "email": "test@example.com"}"#;
        let request = create_json_request(valid_json);

        let user = JsonValidate::<TestUser>::from_request(request, &())
            .await
            .expect("valid JSON payload should be accepted")
            .0;
        assert_eq!(user.username, "valid_user");
        assert_eq!(user.email, "test@example.com");
    }

    #[tokio::test]
    async fn test_json_validate_invalid() {
        let invalid_json = r#"{"username": "ab", "email": "invalid-email"}"#;
        let request = create_json_request(invalid_json);

        let err = JsonValidate::<TestUser>::from_request(request, &())
            .await
            .expect_err("payload violating both rules should be rejected");
        assert_empty_bad_request(&err);

        let expected = json!({
            "error": "Bad Request"
        });
        assert_response_status_and_body(err, StatusCode::BAD_REQUEST, expected).await;
    }

    #[tokio::test]
    async fn test_form_validate_valid() {
        let valid_form = "username=valid_user&email=test@example.com";
        let request = create_form_request(valid_form);

        let user = FormValidate::<TestUser>::from_request(request, &())
            .await
            .expect("valid form payload should be accepted")
            .0;
        assert_eq!(user.username, "valid_user");
        assert_eq!(user.email, "test@example.com");
    }

    #[tokio::test]
    async fn test_form_validate_invalid() {
        let invalid_form = "username=ab&email=invalid-email";
        let request = create_form_request(invalid_form);

        let err = FormValidate::<TestUser>::from_request(request, &())
            .await
            .expect_err("payload violating both rules should be rejected");
        assert_empty_bad_request(&err);

        let expected = json!({
            "error": "Bad Request"
        });
        assert_response_status_and_body(err, StatusCode::BAD_REQUEST, expected).await;
    }

    #[tokio::test]
    async fn test_malformed_json() {
        // Missing the closing brace, so deserialization fails before validation.
        let invalid_json = r#"{"username": "valid_user", "email": "test@example.com"#;
        let request = create_json_request(invalid_json);

        let err = JsonValidate::<TestUser>::from_request(request, &())
            .await
            .expect_err("truncated JSON should be rejected");

        let expected = json!({
            "error": "Bad Request"
        });
        assert_response_status_and_body(err, StatusCode::BAD_REQUEST, expected).await;
    }

    #[tokio::test]
    async fn test_malformed_form() {
        // `email%invalid_format` has no `=`, so the required `email` field is missing.
        let invalid_form = "username=valid_user&email%invalid_format";
        let request = create_form_request(invalid_form);

        let err = FormValidate::<TestUser>::from_request(request, &())
            .await
            .expect_err("form missing a required field should be rejected");

        // NOTE(review): this pins the current mapping of axum's form rejection to a
        // 500. A client-side body error arguably deserves a 4xx; revisit together
        // with the `Error` conversion if that changes.
        let expected = json!({
            "error": "internal_server_error",
            "description": "Internal Server Error"
        });
        assert_response_status_and_body(err, StatusCode::INTERNAL_SERVER_ERROR, expected).await;
    }
}