actix_web_validation/
custom.rs1use crate::validated_definition;
8use actix_web::dev::{ServiceFactory, ServiceRequest};
9use actix_web::http::StatusCode;
10use actix_web::FromRequest;
11use actix_web::{App, HttpRequest, HttpResponse, ResponseError};
12use std::fmt::Display;
13use std::future::Future;
14use std::sync::Arc;
15use std::{fmt::Debug, ops::Deref, pin::Pin, task::Poll};
16use thiserror::Error;
17
/// Implemented by types that can check their own invariants after extraction.
///
/// `Validated<T>` calls this (through `T`'s `Deref` target) on the value
/// produced by the inner extractor and rejects the request whenever it
/// returns `Err` — including `Err(vec![])`, so return `Ok(())` on success.
pub trait Validate {
    /// Checks `self`, returning the failed rules as [`ValidationError`]s.
    ///
    /// # Errors
    ///
    /// Returns every validation failure that should be reported to the client.
    fn validate(&self) -> Result<(), Vec<ValidationError>>;
}

/// A single validation failure carrying a human-readable message.
#[derive(Debug)]
pub struct ValidationError {
    message: String,
}

impl ValidationError {
    /// Creates a validation error with the given message.
    ///
    /// The `message` field is private, so this is how code outside this
    /// module builds errors when implementing [`Validate`].
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
        }
    }

    /// Returns the human-readable message for this failure.
    pub fn message(&self) -> &str {
        &self.message
    }
}

impl Display for ValidationError {
    /// Writes the bare message; formatter flags (width, fill) are ignored,
    /// matching a plain `write!(f, "{}", ..)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}
34
35pub struct Validated<T>(pub T);
62
63validated_definition!();
64
65pub struct ValidatedFut<T: FromRequest> {
66 req: actix_web::HttpRequest,
67 fut: <T as FromRequest>::Future,
68 error_handler: Option<ValidationErrHandler>,
69}
70impl<T> Future for ValidatedFut<T>
71where
72 T: FromRequest + Debug + Deref,
73 T::Future: Unpin,
74 T::Target: Validate,
75{
76 type Output = Result<Validated<T>, actix_web::Error>;
77
78 fn poll(
79 self: std::pin::Pin<&mut Self>,
80 cx: &mut std::task::Context<'_>,
81 ) -> std::task::Poll<Self::Output> {
82 let this = self.get_mut();
83 let Poll::Ready(res) = Pin::new(&mut this.fut).poll(cx) else {
84 return std::task::Poll::Pending;
85 };
86
87 let res = match res {
88 Ok(data) => {
89 if let Err(e) = data.validate() {
90 if let Some(error_handler) = &this.error_handler {
91 Err((*error_handler)(e, &this.req))
92 } else {
93 let err: Error = e.into();
94 Err(err.into())
95 }
96 } else {
97 Ok(Validated(data))
98 }
99 }
100 Err(e) => Err(e.into()),
101 };
102
103 Poll::Ready(res)
104 }
105}
106
107impl<T> FromRequest for Validated<T>
108where
109 T: FromRequest + Debug + Deref,
110 T::Future: Unpin,
111 T::Target: Validate,
112{
113 type Error = actix_web::Error;
114
115 type Future = ValidatedFut<T>;
116
117 fn from_request(
118 req: &actix_web::HttpRequest,
119 payload: &mut actix_web::dev::Payload,
120 ) -> Self::Future {
121 let error_handler = req
122 .app_data::<ValidationErrorHandler>()
123 .map(|h| h.handler.clone());
124
125 let fut = T::from_request(req, payload);
126
127 ValidatedFut {
128 fut,
129 error_handler,
130 req: req.clone(),
131 }
132 }
133}
134
135#[derive(Error, Debug)]
136struct Error {
137 errors: Vec<ValidationError>,
138}
139
140impl From<Vec<ValidationError>> for Error {
141 fn from(value: Vec<ValidationError>) -> Self {
142 Self { errors: value }
143 }
144}
145
146impl Display for Error {
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 write!(
149 f,
150 "{}",
151 self.errors
152 .iter()
153 .map(|e| e.message.as_ref())
154 .collect::<Vec<_>>()
155 .join("\n")
156 )
157 }
158}
159
160impl ResponseError for Error {
161 fn error_response(&self) -> HttpResponse {
162 HttpResponse::build(StatusCode::BAD_REQUEST).body(format!(
163 "Validation errors in fields:\n{}",
164 &self
165 .errors
166 .iter()
167 .map(|err| { format!("\t{}", err) })
168 .collect::<Vec<_>>()
169 .join("\n")
170 ))
171 }
172}
173
174pub type ValidationErrHandler =
175 Arc<dyn Fn(Vec<ValidationError>, &HttpRequest) -> actix_web::Error + Send + Sync>;
176
177struct ValidationErrorHandler {
178 handler: ValidationErrHandler,
179}
180
181pub trait ValidationErrorHandlerExt {
183 fn validation_error_handler(self, handler: ValidationErrHandler) -> Self;
185}
186
187impl<T> ValidationErrorHandlerExt for App<T>
188where
189 T: ServiceFactory<ServiceRequest, Config = (), Error = actix_web::Error, InitError = ()>,
190{
191 fn validation_error_handler(self, handler: ValidationErrHandler) -> Self {
192 self.app_data(ValidationErrorHandler { handler })
193 }
194}
195
196impl ValidationErrorHandlerExt for &mut actix_web::web::ServiceConfig {
197 fn validation_error_handler(self, handler: ValidationErrHandler) -> Self {
198 self.app_data(ValidationErrorHandler { handler })
199 }
200}
201
#[cfg(test)]
mod test {
    use super::*;
    use actix_web::web::Bytes;
    use actix_web::{post, test, web::Json, App, Responder};
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Deserialize, Serialize)]
    struct ExamplePayload {
        name: String,
    }

    impl Validate for ExamplePayload {
        /// Accepts names strictly longer than four bytes.
        fn validate(&self) -> Result<(), Vec<ValidationError>> {
            if self.name.len() > 4 {
                Ok(())
            } else {
                Err(vec![ValidationError {
                    message: "name not long enough".to_string(),
                }])
            }
        }
    }

    #[post("/")]
    async fn endpoint(v: Validated<Json<ExamplePayload>>) -> impl Responder {
        assert!(v.name.len() > 4);
        HttpResponse::Ok().body(())
    }

    /// Builds a `POST /` request whose JSON body carries `name`.
    ///
    /// `set_json` sets `Content-Type: application/json` itself, so no extra
    /// header is needed (a previously inserted one would be overwritten).
    fn example_request(name: &str) -> test::TestRequest {
        test::TestRequest::post().uri("/").set_json(ExamplePayload {
            name: name.to_string(),
        })
    }

    #[actix_web::test]
    async fn should_validate_simple() {
        let app = test::init_service(App::new().service(endpoint)).await;

        let resp = test::call_service(&app, example_request("123456").to_request()).await;
        assert_eq!(resp.status().as_u16(), 200);

        let resp = test::call_service(&app, example_request("1234").to_request()).await;
        assert_eq!(resp.status().as_u16(), 400);
    }

    #[actix_web::test]
    async fn should_respond_with_errors_correctly() {
        let app = test::init_service(App::new().service(endpoint)).await;

        let result = test::call_and_read_body(&app, example_request("1234").to_request()).await;
        assert_eq!(
            result,
            Bytes::from_static(b"Validation errors in fields:\n\tname not long enough")
        );
    }

    #[derive(Debug, Serialize, Error)]
    struct CustomErrorResponse {
        custom_message: String,
        errors: Vec<String>,
    }

    impl Display for CustomErrorResponse {
        // A real implementation rather than `unimplemented!()`: anything that
        // formats the error (logging, debugging) must not panic.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(&self.custom_message)
        }
    }

    impl ResponseError for CustomErrorResponse {
        fn status_code(&self) -> actix_web::http::StatusCode {
            actix_web::http::StatusCode::BAD_REQUEST
        }

        fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
            HttpResponse::build(self.status_code()).body(serde_json::to_string(self).unwrap())
        }
    }

    fn error_handler(errors: Vec<ValidationError>, _: &HttpRequest) -> actix_web::Error {
        CustomErrorResponse {
            custom_message: "My custom message".to_string(),
            errors: errors.iter().map(|err| err.message.clone()).collect(),
        }
        .into()
    }

    #[actix_web::test]
    async fn should_use_allow_custom_error_responses() {
        let app = test::init_service(
            App::new()
                .service(endpoint)
                .validation_error_handler(Arc::new(error_handler)),
        )
        .await;

        let result = test::call_and_read_body(&app, example_request("1234").to_request()).await;
        assert_eq!(
            result,
            Bytes::from_static(
                b"{\"custom_message\":\"My custom message\",\"errors\":[\"name not long enough\"]}"
            )
        );
    }

    // Explicit `actix_web::test`: a bare `#[test]` on an `async fn` only
    // worked because the `test` import shadowed the built-in attribute.
    #[actix_web::test]
    async fn debug_for_validated_should_work() {
        let v = Validated(ExamplePayload {
            name: "abcde".to_string(),
        });

        assert_eq!(
            "Validated(ExamplePayload { name: \"abcde\" })",
            format!("{v:?}")
        );
    }
}