// Declares the `test` submodule; the `cfg` attribute means it is only
// compiled when the `test` configuration is active.
#[cfg(test)]
pub mod test;
10
11use crate::{HasValidate, ValidationRejection};
12use axum::extract::{FromRef, FromRequest, FromRequestParts, Request};
13use axum::http::request::Parts;
14use std::fmt::Display;
15use std::ops::{Deref, DerefMut};
16use validator::{Validate, ValidateArgs, ValidationErrors};
17
/// Wrapper around an extractor value `E`.
///
/// The `FromRequest` / `FromRequestParts` implementations in this file only
/// build a `Valid` after `validate()` has succeeded on the extracted data.
/// The field is public, so the wrapper can also be constructed or
/// destructured directly.
#[derive(Debug, Clone, Copy, Default)]
pub struct Valid<E>(pub E);

impl<E> Deref for Valid<E> {
    type Target = E;

    /// Borrows the wrapped value.
    fn deref(&self) -> &E {
        let Self(wrapped) = self;
        wrapped
    }
}

impl<E> DerefMut for Valid<E> {
    /// Mutably borrows the wrapped value.
    fn deref_mut(&mut self) -> &mut E {
        let Self(wrapped) = self;
        wrapped
    }
}

impl<T> Display for Valid<T>
where
    T: Display,
{
    /// Formats exactly like the wrapped value, reusing the caller's formatter.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&self.0, formatter)
    }
}

impl<E> Valid<E> {
    /// Consumes the wrapper and hands back the wrapped value.
    pub fn into_inner(self) -> E {
        let Self(wrapped) = self;
        wrapped
    }
}
58
/// With the `aide` feature enabled, `Valid<T>` reports the same operation
/// input as `T` itself.
#[cfg(feature = "aide")]
impl<T: aide::OperationInput> aide::OperationInput for Valid<T> {
    fn operation_input(
        ctx: &mut aide::generate::GenContext,
        operation: &mut aide::openapi::Operation,
    ) {
        // Pure delegation: the wrapper contributes nothing of its own.
        <T as aide::OperationInput>::operation_input(ctx, operation);
    }
}
71
/// Wrapper around an extractor value `E` that is validated with arguments.
///
/// The `FromRequest` / `FromRequestParts` implementations in this file obtain
/// the arguments from the state via `FromRef` and only build a `ValidEx`
/// after `validate_with_args` has succeeded. The field is public, so the
/// wrapper can also be constructed or destructured directly.
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidEx<E>(pub E);

impl<E> Deref for ValidEx<E> {
    type Target = E;

    /// Borrows the wrapped value.
    fn deref(&self) -> &E {
        let Self(wrapped) = self;
        wrapped
    }
}

impl<E> DerefMut for ValidEx<E> {
    /// Mutably borrows the wrapped value.
    fn deref_mut(&mut self) -> &mut E {
        let Self(wrapped) = self;
        wrapped
    }
}

impl<T> Display for ValidEx<T>
where
    T: Display,
{
    /// Formats exactly like the wrapped value, reusing the caller's formatter.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&self.0, formatter)
    }
}

impl<E> ValidEx<E> {
    /// Consumes the wrapper and hands back the wrapped value.
    pub fn into_inner(self) -> E {
        let Self(wrapped) = self;
        wrapped
    }
}
117
/// With the `aide` feature enabled, `ValidEx<T>` reports the same operation
/// input as `T` itself.
#[cfg(feature = "aide")]
impl<T: aide::OperationInput> aide::OperationInput for ValidEx<T> {
    fn operation_input(
        ctx: &mut aide::generate::GenContext,
        operation: &mut aide::openapi::Operation,
    ) {
        // Pure delegation: the wrapper contributes nothing of its own.
        <T as aide::OperationInput>::operation_input(ctx, operation);
    }
}
130
131pub type ValidRejection<E> = ValidationRejection<ValidationErrors, E>;
134
135impl<E> From<ValidationErrors> for ValidRejection<E> {
136 fn from(value: ValidationErrors) -> Self {
137 Self::Valid(value)
138 }
139}
140
/// Implemented by extractors whose data is validated with extra arguments.
///
/// `ValidEx` calls `get_validate_args` on the extracted value and then runs
/// `validate_with_args` on the returned reference.
pub trait HasValidateArgs<'v> {
    /// The type that gets validated; it must implement `ValidateArgs<'v>`.
    type ValidateArgs: ValidateArgs<'v>;
    /// Returns a reference to the value that should be validated.
    fn get_validate_args(&self) -> &Self::ValidateArgs;
}
151
152impl<State, Extractor> FromRequest<State> for Valid<Extractor>
153where
154 State: Send + Sync,
155 Extractor: HasValidate + FromRequest<State>,
156 Extractor::Validate: Validate,
157{
158 type Rejection = ValidRejection<<Extractor as FromRequest<State>>::Rejection>;
159
160 async fn from_request(req: Request, state: &State) -> Result<Self, Self::Rejection> {
161 let inner = Extractor::from_request(req, state)
162 .await
163 .map_err(ValidRejection::Inner)?;
164 inner.get_validate().validate()?;
165 Ok(Valid(inner))
166 }
167}
168
169impl<State, Extractor> FromRequestParts<State> for Valid<Extractor>
170where
171 State: Send + Sync,
172 Extractor: HasValidate + FromRequestParts<State>,
173 Extractor::Validate: Validate,
174{
175 type Rejection = ValidRejection<<Extractor as FromRequestParts<State>>::Rejection>;
176
177 async fn from_request_parts(parts: &mut Parts, state: &State) -> Result<Self, Self::Rejection> {
178 let inner = Extractor::from_request_parts(parts, state)
179 .await
180 .map_err(ValidRejection::Inner)?;
181 inner.get_validate().validate()?;
182 Ok(Valid(inner))
183 }
184}
185
186impl<State, Extractor, Args> FromRequest<State> for ValidEx<Extractor>
187where
188 State: Send + Sync,
189 Args: Send + Sync + FromRef<State>,
190 Extractor: for<'v> HasValidateArgs<'v> + FromRequest<State>,
191 for<'v> <Extractor as HasValidateArgs<'v>>::ValidateArgs: ValidateArgs<'v, Args = &'v Args>,
192{
193 type Rejection = ValidRejection<<Extractor as FromRequest<State>>::Rejection>;
194
195 async fn from_request(req: Request, state: &State) -> Result<Self, Self::Rejection> {
196 let arguments: Args = FromRef::from_ref(state);
197 let inner = Extractor::from_request(req, state)
198 .await
199 .map_err(ValidRejection::Inner)?;
200
201 inner.get_validate_args().validate_with_args(&arguments)?;
202 Ok(ValidEx(inner))
203 }
204}
205
206impl<State, Extractor, Args> FromRequestParts<State> for ValidEx<Extractor>
207where
208 State: Send + Sync,
209 Args: Send + Sync + FromRef<State>,
210 Extractor: for<'v> HasValidateArgs<'v> + FromRequestParts<State>,
211 for<'v> <Extractor as HasValidateArgs<'v>>::ValidateArgs: ValidateArgs<'v, Args = &'v Args>,
212{
213 type Rejection = ValidRejection<<Extractor as FromRequestParts<State>>::Rejection>;
214
215 async fn from_request_parts(parts: &mut Parts, state: &State) -> Result<Self, Self::Rejection> {
216 let arguments: Args = FromRef::from_ref(state);
217 let inner = Extractor::from_request_parts(parts, state)
218 .await
219 .map_err(ValidRejection::Inner)?;
220 inner.get_validate_args().validate_with_args(&arguments)?;
221 Ok(ValidEx(inner))
222 }
223}
224
// Unit tests for the `Valid` / `ValidEx` wrappers and for `ValidRejection`.
#[cfg(test)]
pub mod tests {
    use super::*;
    use std::error::Error;
    use std::fmt::Formatter;
    use std::io;
    use validator::ValidationError;
    const TEST: &str = "test";

    /// `Valid` exposes the wrapped value through `Deref`, `DerefMut`,
    /// `Display` and `into_inner`.
    #[test]
    fn valid_deref_deref_mut_into_inner() {
        let mut inner = String::from(TEST);
        let mut v = Valid(inner.clone());
        assert_eq!(&inner, v.deref());
        inner.push_str(TEST);
        v.deref_mut().push_str(TEST);
        assert_eq!(&inner, v.deref());
        // Assert on the `Display` output instead of only printing it, so a
        // broken `Display` impl actually fails the test.
        assert_eq!(inner, v.to_string());
        assert_eq!(inner, v.into_inner());
    }

    /// `ValidEx` exposes the wrapped value the same way and lets
    /// `validate_with_args` be called through auto-deref.
    #[test]
    fn valid_ex_deref_deref_mut_into_inner_arguments() {
        let mut inner = String::from(TEST);
        let mut v = ValidEx(inner.clone());
        assert_eq!(&inner, v.deref());
        inner.push_str(TEST);
        v.deref_mut().push_str(TEST);
        assert_eq!(&inner, v.deref());
        assert_eq!(inner, v.into_inner());

        // Custom validator that receives the context passed to `validate_with_args`.
        fn validate(v: i32, args: &DataVA) -> Result<(), ValidationError> {
            assert!(v < args.a);
            Ok(())
        }

        #[derive(Debug, Validate)]
        #[validate(context = DataVA)]
        struct Data {
            #[validate(custom(function = "validate", use_context))]
            v: i32,
        }

        // `Display` simply reuses the derived `Debug` representation.
        impl Display for Data {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                write!(f, "{:?}", self)
            }
        }

        struct DataVA {
            a: i32,
        }

        let v = 12;
        let data = Data { v };
        let args = DataVA { a: v + 1 };
        let ve = ValidEx(data);
        ve.validate_with_args(&args).expect("invalid");
        // The wrapper must forward `Display` to `Data`, whose output is the
        // derived `Debug` form; check it rather than just printing it.
        assert_eq!(ve.to_string(), format!("Data {{ v: {} }}", v));
        assert_eq!(ve.v, v);
    }

    /// `ValidRejection` forwards `Display` and `Error::source` to whichever
    /// variant it holds.
    #[test]
    fn display_error() {
        // `Valid` variant: displayed like the validation errors themselves.
        let mut ve = ValidationErrors::new();
        ve.add(TEST, ValidationError::new(TEST));
        let vr = ValidRejection::<String>::Valid(ve.clone());
        assert_eq!(vr.to_string(), ve.to_string());

        // `Inner` variant: displayed like the inner rejection.
        let inner = String::from(TEST);
        let vr = ValidRejection::<String>::Inner(inner.clone());
        assert_eq!(inner.to_string(), vr.to_string());

        // `Valid` variant: the error source is the `ValidationErrors` value.
        let mut ve = ValidationErrors::new();
        ve.add(TEST, ValidationError::new(TEST));
        let vr = ValidRejection::<io::Error>::Valid(ve.clone());
        assert!(
            matches!(vr.source(), Some(source) if source.downcast_ref::<ValidationErrors>().is_some())
        );

        // `Inner` variant: the error source is the inner error.
        let vr = ValidRejection::<io::Error>::Inner(io::Error::other(TEST));
        assert!(
            matches!(vr.source(), Some(source) if source.downcast_ref::<io::Error>().is_some())
        );
    }
}