garde_actix_web/web/
query.rs1use crate::validate_for_request;
2use actix_web::dev::Payload;
3use actix_web::error::QueryPayloadError;
4use actix_web::{Error, FromRequest, HttpRequest};
5use derive_more::{AsRef, Deref, DerefMut, Display, From};
6use futures::future::{Ready, err, ok};
7use garde::Validate;
8use serde::de::DeserializeOwned;
9use std::sync::Arc;
10
11#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deref, DerefMut, AsRef, Display, From)]
13pub struct Query<T>(pub T);
14
15impl<T> Query<T> {
16 pub fn into_inner(self) -> T {
17 self.0
18 }
19}
20
21impl<T: DeserializeOwned> Query<T> {
22 pub fn from_query(query_str: &str) -> Result<Self, QueryPayloadError> {
23 serde_urlencoded::from_str::<T>(query_str)
24 .map(Self)
25 .map_err(QueryPayloadError::Deserialize)
26 }
27}
28
29impl<T> FromRequest for Query<T>
30where
31 T: DeserializeOwned + Validate + 'static,
32 T::Context: Default,
33{
34 type Error = Error;
35 type Future = Ready<Result<Self, Error>>;
36
37 #[inline]
38 fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
39 let req_copy = req.clone();
40 let error_handler = req.app_data::<QueryConfig>().and_then(|c| c.err_handler.clone());
41
42 serde_urlencoded::from_str::<T>(req.query_string())
43 .map_err(|e| {
44 let e = QueryPayloadError::Deserialize(e);
45 crate::error::Error::QueryPayloadError(e)
46 })
47 .and_then(|data: T| {
48 let req = req_copy;
49 validate_for_request(data, &req)
50 })
51 .map(|val| ok(Query(val)))
52 .unwrap_or_else(move |e| {
53 log::debug!(
54 "Failed during Query extractor deserialization. \
55 Request path: {:?}",
56 req.path()
57 );
58
59 let e = if let Some(error_handler) = error_handler {
60 (error_handler)(e, req)
61 } else {
62 e.into()
63 };
64
65 err(e)
66 })
67 }
68}
69
70#[derive(Clone, Default)]
73pub struct QueryConfig {
74 #[allow(clippy::type_complexity)]
75 pub(crate) err_handler: Option<Arc<dyn Fn(crate::error::Error, &HttpRequest) -> Error + Send + Sync>>,
76}
77
78impl QueryConfig {
79 pub fn error_handler<F>(mut self, f: F) -> Self
80 where
81 F: Fn(crate::error::Error, &HttpRequest) -> Error + Send + Sync + 'static,
82 {
83 self.err_handler = Some(Arc::new(f));
84 self
85 }
86}
87
#[cfg(test)]
mod test {
    use crate::web::{Query, QueryConfig};
    use actix_http::StatusCode;
    use actix_web::error::InternalError;
    use actix_web::test::{TestRequest, call_service, init_service};
    use actix_web::web::{post, resource};
    use actix_web::{App, HttpResponse};
    use garde::Validate;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, PartialEq, Validate, Serialize, Deserialize)]
    struct QueryData {
        #[garde(range(min = 18, max = 28))]
        age: u8,
    }

    #[derive(Debug, PartialEq, Validate, Serialize, Deserialize)]
    #[garde(context(NumberContext))]
    struct QueryDataWithContext {
        #[garde(custom(is_big_enough))]
        age: u8,
    }

    #[derive(Default, Debug)]
    struct NumberContext {
        min: u8,
    }

    fn is_big_enough(value: &u8, context: &NumberContext) -> garde::Result {
        if value < &context.min {
            return Err(garde::Error::new("Number is too low"));
        }
        Ok(())
    }

    async fn test_handler(_: Query<QueryData>) -> HttpResponse {
        HttpResponse::Ok().finish()
    }

    async fn test_handler_with_context(_: Query<QueryDataWithContext>) -> HttpResponse {
        HttpResponse::Ok().finish()
    }

    /// Config whose error handler maps every extraction failure to 409 Conflict, so tests
    /// can tell the custom handler apart from the default error response.
    fn conflict_config() -> QueryConfig {
        QueryConfig::default()
            .error_handler(|err, _req| InternalError::from_response(err, HttpResponse::Conflict().finish()).into())
    }

    #[tokio::test]
    async fn test_simple_query_validation() {
        let app = init_service(App::new().service(resource("/").route(post().to(test_handler)))).await;

        let req = TestRequest::post().uri("/?age=24").to_request();
        let resp = call_service(&app, req).await;
        assert_eq!(resp.status(), StatusCode::OK);

        let req = TestRequest::post().uri("/?age=30").to_request();
        let resp = call_service(&app, req).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_query_validation_custom_config() {
        let app = init_service(
            App::new()
                .app_data(conflict_config())
                .service(resource("/").route(post().to(test_handler))),
        )
        .await;

        let req = TestRequest::post().uri("/?age=24").to_request();
        let resp = call_service(&app, req).await;
        assert_eq!(resp.status(), StatusCode::OK);

        let req = TestRequest::post().uri("/?age=30").to_request();
        let resp = call_service(&app, req).await;
        assert_eq!(resp.status(), StatusCode::CONFLICT);
    }

    #[tokio::test]
    async fn test_query_deserialization_error() {
        let app = init_service(App::new().service(resource("/").route(post().to(test_handler)))).await;

        // Not a number: rejected during deserialization, before validation runs.
        let req = TestRequest::post().uri("/?age=abc").to_request();
        let resp = call_service(&app, req).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // Required field missing entirely.
        let req = TestRequest::post().uri("/").to_request();
        let resp = call_service(&app, req).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_query_deserialization_error_custom_config() {
        let app = init_service(
            App::new()
                .app_data(conflict_config())
                .service(resource("/").route(post().to(test_handler))),
        )
        .await;

        // The custom handler must also receive deserialization failures, not only
        // validation failures.
        let req = TestRequest::post().uri("/?age=abc").to_request();
        let resp = call_service(&app, req).await;
        assert_eq!(resp.status(), StatusCode::CONFLICT);
    }

    #[tokio::test]
    async fn test_query_validation_with_context() {
        let number_context = NumberContext { min: 25 };
        let app = init_service(
            App::new()
                .app_data(number_context)
                .service(resource("/").route(post().to(test_handler_with_context))),
        )
        .await;

        let req = TestRequest::post().uri("/?age=24").to_request();
        let resp = call_service(&app, req).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let req = TestRequest::post().uri("/?age=30").to_request();
        let resp = call_service(&app, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_query_validation_with_missing_context() {
        let app = init_service(App::new().service(resource("/").route(post().to(test_handler_with_context)))).await;

        let req = TestRequest::post().uri("/?age=24").to_request();
        let resp = call_service(&app, req).await;
        assert_eq!(resp.status(), StatusCode::OK);

        let req = TestRequest::post().uri("/?age=30").to_request();
        let resp = call_service(&app, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[test]
    fn test_from_query_does_not_validate() {
        // 30 is outside the garde range 18..=28, but `from_query` only deserializes.
        let query = Query::<QueryData>::from_query("age=30").expect("well-formed query string");
        assert_eq!(query.into_inner(), QueryData { age: 30 });

        assert!(Query::<QueryData>::from_query("age=abc").is_err());
    }
}