acton_htmx/extractors/
validated.rs1use axum::{
28 extract::{Form, FromRequest, Request},
29 http::StatusCode,
30 response::{IntoResponse, Response},
31};
32use serde::de::DeserializeOwned;
33use std::fmt;
34use validator::Validate;
35
/// Axum extractor that deserializes form data into `T` (via axum's [`Form`]
/// extractor) and then runs `T`'s [`Validate`] rules before the handler sees it.
///
/// Deserialization failures are rejected with [`ValidationError::FormRejection`]
/// and rule violations with [`ValidationError::Validation`]; see the
/// `IntoResponse` impl of [`ValidationError`] for the resulting HTTP responses.
///
/// # Examples
///
/// ```ignore
/// async fn signup(ValidatedForm(form): ValidatedForm<SignupForm>) -> String {
///     format!("welcome, {}", form.email)
/// }
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct ValidatedForm<T>(
    /// The deserialized value; only handed to handlers after validation passed.
    pub T,
);
70
71impl<T, S> FromRequest<S> for ValidatedForm<T>
72where
73 T: DeserializeOwned + Validate + 'static,
74 S: Send + Sync + 'static,
75{
76 type Rejection = ValidationError;
77
78 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
79 let Form(data) = Form::<T>::from_request(req, state)
81 .await
82 .map_err(|err| {
83 ValidationError::FormRejection(format!("Failed to parse form data: {err}"))
84 })?;
85
86 data.validate()
88 .map_err(ValidationError::Validation)?;
89
90 Ok(Self(data))
91 }
92}
93
94#[derive(Debug)]
99pub enum ValidationError {
100 FormRejection(String),
102 Validation(validator::ValidationErrors),
104}
105
106impl fmt::Display for ValidationError {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 match self {
109 Self::FormRejection(msg) => write!(f, "Form parsing error: {msg}"),
110 Self::Validation(errors) => {
111 write!(f, "Validation failed: ")?;
112 for (field, errors) in errors.field_errors() {
113 write!(f, "{field}: ")?;
114 for error in errors {
115 if let Some(message) = &error.message {
116 write!(f, "{message}, ")?;
117 } else {
118 write!(f, "{}, ", error.code)?;
119 }
120 }
121 }
122 Ok(())
123 }
124 }
125 }
126}
127
128impl std::error::Error for ValidationError {}
129
130impl IntoResponse for ValidationError {
131 fn into_response(self) -> Response {
132 match self {
133 Self::FormRejection(msg) => {
134 (StatusCode::BAD_REQUEST, format!("Invalid form data: {msg}")).into_response()
135 }
136 Self::Validation(errors) => {
137 let error_messages = format_validation_errors(&errors);
138 (
139 StatusCode::UNPROCESSABLE_ENTITY,
140 format!("Validation failed:\n{error_messages}"),
141 )
142 .into_response()
143 }
144 }
145 }
146}
147
148#[must_use]
162pub fn format_validation_errors(errors: &validator::ValidationErrors) -> String {
163 let mut messages = Vec::new();
164
165 for (field, field_errors) in errors.field_errors() {
166 for error in field_errors {
167 let message = error.message.as_ref().map_or_else(
168 || format!("{field}: {}", error.code),
169 ToString::to_string,
170 );
171 messages.push(message);
172 }
173 }
174
175 messages.join("\n")
176}
177
178#[must_use]
210pub fn validation_errors_json(errors: &validator::ValidationErrors) -> serde_json::Value {
211 let mut error_map = serde_json::Map::new();
212
213 for (field, field_errors) in errors.field_errors() {
214 let messages: Vec<String> = field_errors
215 .iter()
216 .map(|error| {
217 error.message.as_ref().map_or_else(
218 || error.code.to_string(),
219 ToString::to_string,
220 )
221 })
222 .collect();
223
224 error_map.insert(field.to_string(), serde_json::json!(messages));
225 }
226
227 serde_json::json!({
228 "errors": error_map
229 })
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Method, Request, StatusCode},
        routing::post,
        Router,
    };
    use serde::Deserialize;
    use tower::ServiceExt;
    use validator::Validate;

    #[derive(Debug, Deserialize, Validate)]
    struct TestForm {
        #[validate(email)]
        email: String,
        #[validate(length(min = 8))]
        password: String,
    }

    async fn test_handler(ValidatedForm(form): ValidatedForm<TestForm>) -> String {
        format!("Email: {}, Password length: {}", form.email, form.password.len())
    }

    /// Sends `POST /` with `body` (and an optional `content-type` header)
    /// through a fresh router whose only route extracts `ValidatedForm<TestForm>`.
    async fn send(content_type: Option<&str>, body: &'static str) -> axum::response::Response {
        let app = Router::new().route("/", post(test_handler));

        let mut builder = Request::builder().method(Method::POST).uri("/");
        if let Some(content_type) = content_type {
            builder = builder.header("content-type", content_type);
        }
        let request = builder.body(Body::from(body)).unwrap();

        app.oneshot(request).await.unwrap()
    }

    /// Shorthand for a correctly-typed urlencoded form submission.
    async fn post_form(body: &'static str) -> axum::response::Response {
        send(Some("application/x-www-form-urlencoded"), body).await
    }

    /// Collects a response body into a UTF-8 `String`.
    async fn body_text(response: axum::response::Response) -> String {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_valid_form() {
        let response = post_form("email=test@example.com&password=password123").await;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            body_text(response).await,
            "Email: test@example.com, Password length: 11"
        );
    }

    #[tokio::test]
    async fn test_invalid_email() {
        let response = post_form("email=invalid-email&password=password123").await;

        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        let text = body_text(response).await;
        assert!(text.starts_with("Validation failed:"));
        assert!(text.contains("email"));
    }

    #[tokio::test]
    async fn test_short_password() {
        let response = post_form("email=test@example.com&password=short").await;

        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        let text = body_text(response).await;
        assert!(text.starts_with("Validation failed:"));
        assert!(text.contains("password"));
    }

    #[tokio::test]
    async fn test_missing_field_is_bad_request() {
        // Deserialization fails before validation runs -> FormRejection -> 400.
        let response = post_form("email=test@example.com").await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        assert!(body_text(response).await.starts_with("Invalid form data:"));
    }

    #[tokio::test]
    async fn test_missing_content_type_is_bad_request() {
        let response = send(None, "email=test@example.com&password=password123").await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_format_validation_errors() {
        let mut errors = validator::ValidationErrors::new();
        errors.add(
            "email",
            validator::ValidationError::new("email")
                .with_message(std::borrow::Cow::Borrowed("Invalid email address")),
        );

        let formatted = format_validation_errors(&errors);
        assert!(formatted.contains("Invalid email address"));
    }

    #[test]
    fn test_validation_errors_json() {
        let mut errors = validator::ValidationErrors::new();
        errors.add(
            "email",
            validator::ValidationError::new("email")
                .with_message(std::borrow::Cow::Borrowed("Invalid email address")),
        );
        errors.add(
            "password",
            validator::ValidationError::new("length")
                .with_message(std::borrow::Cow::Borrowed("Password too short")),
        );

        let json = validation_errors_json(&errors);
        assert!(json.get("errors").is_some());

        let errors_obj = json.get("errors").unwrap().as_object().unwrap();
        assert!(errors_obj.contains_key("email"));
        assert!(errors_obj.contains_key("password"));
        assert_eq!(json["errors"]["email"][0], "Invalid email address");
        assert_eq!(json["errors"]["password"][0], "Password too short");
    }

    #[test]
    fn test_validation_error_display() {
        let mut errors = validator::ValidationErrors::new();
        errors.add(
            "email",
            validator::ValidationError::new("email")
                .with_message(std::borrow::Cow::Borrowed("Invalid email")),
        );

        let error = ValidationError::Validation(errors);
        let display = format!("{error}");
        assert!(display.contains("Validation failed"));
        assert!(display.contains("email"));
    }

    #[test]
    fn test_form_rejection_display() {
        let error = ValidationError::FormRejection("boom".to_string());
        assert_eq!(format!("{error}"), "Form parsing error: boom");
    }
}