mll_axum_utils/validator/
validate.rs1use crate::res::Res;
2use axum::extract::rejection::{BytesRejection, RawFormRejection};
3use axum::{
4 async_trait,
5 body::HttpBody,
6 extract::{FromRequest, RawForm},
7 headers::{ContentType, HeaderMapExt},
8 http::{HeaderMap, Request},
9 BoxError, RequestExt,
10};
11use bytes::Bytes;
12use serde::de::DeserializeOwned;
13use validator::Validate;
14
/// Code attached to every rejection built in this module (`Res::msg` / `Res::msgs`).
///
/// 422 is HTTP *Unprocessable Entity*; presumably `Res` surfaces it as the response
/// status — confirm against `crate::res::Res`.
const ERROR_CODE: u16 = 422;
16
17#[must_use]
19#[derive(Debug, Clone, Default)]
20pub struct VJson<T: Validate>(pub T);
21
22#[async_trait]
23impl<T, S, B> FromRequest<S, B> for VJson<T>
24where
25 T: DeserializeOwned + Validate,
26 B: HttpBody + Send + 'static,
27 B::Data: Send,
28 B::Error: Into<BoxError>,
29 S: Send + Sync,
30{
31 type Rejection = Res<()>;
32
33 async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
34 if !json_content_type(req.headers()) {
35 return Err(Res::msg(ERROR_CODE, "请求头必须为: application/json"));
36 }
37
38 let data = des_json(Bytes::from_request(req, state).await)?;
39 Ok(VJson(data))
40 }
41}
42
43#[must_use]
45#[derive(Debug, Clone, Default)]
46pub struct VForm<T: Validate>(pub T);
47
48#[async_trait]
49impl<T, S, B> FromRequest<S, B> for VForm<T>
50where
51 T: DeserializeOwned + Validate,
52 B: HttpBody + Send + 'static,
53 B::Data: Send,
54 B::Error: Into<BoxError>,
55 S: Send + Sync,
56{
57 type Rejection = Res<()>;
58
59 async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
60 let data = des_form(req.extract::<RawForm, _>().await)?;
61 Ok(VForm(data))
62 }
63}
64
65#[must_use]
67#[derive(Debug, Clone, Default)]
68pub struct VJsonOrForm<T: Validate>(pub T);
69
70#[async_trait]
71impl<T, S, B> FromRequest<S, B> for VJsonOrForm<T>
72where
73 T: DeserializeOwned + Validate,
74 B: HttpBody + Send + 'static,
75 B::Data: Send,
76 B::Error: Into<BoxError>,
77 S: Send + Sync,
78{
79 type Rejection = Res<()>;
80
81 async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
82 let data = if json_content_type(req.headers()) {
83 des_json(Bytes::from_request(req, state).await)?
84 } else {
85 des_form(req.extract::<RawForm, _>().await)?
86 };
87
88 Ok(VJsonOrForm(data))
89 }
90}
91
92#[must_use]
94#[derive(Debug, Clone, Default)]
95pub struct VQuery<T: Validate>(pub T);
96
97#[async_trait]
98impl<T, S, B> FromRequest<S, B> for VQuery<T>
99where
100 T: DeserializeOwned + Validate,
101 B: HttpBody + Send + 'static,
102 B::Data: Send,
103 B::Error: Into<BoxError>,
104 S: Send + Sync,
105{
106 type Rejection = Res<()>;
107
108 async fn from_request(req: Request<B>, _: &S) -> Result<Self, Self::Rejection> {
109 let data = serde_urlencoded::from_str::<T>(req.uri().query().unwrap_or_default())
110 .map_err(|err| Res::msg(ERROR_CODE, err))?;
111
112 validate(&data)?;
113 Ok(VQuery(data))
114 }
115}
116
117pub fn json_content_type(headers: &HeaderMap) -> bool {
119 headers
120 .typed_get::<ContentType>()
121 .map(|t| t.to_string() == "application/json")
122 .unwrap_or(false)
123}
124
125pub fn validate(data: impl Validate) -> Result<(), Res<()>> {
127 if let Err(err) = data.validate() {
128 let mut msg = Vec::new();
129 for (k, v) in err.field_errors() {
130 for item in v {
131 msg.push(format!(
132 "{k:}: validate failed tips: {}",
133 item.message.as_ref().unwrap_or(&item.code)
134 ));
135 }
136 }
137 return Err(Res::msgs(ERROR_CODE, msg));
138 }
139 Ok(())
140}
141
142fn des_json<T>(data: Result<Bytes, BytesRejection>) -> Result<T, Res<()>>
144where
145 T: Validate + DeserializeOwned,
146{
147 let bytes = data.map_err(|_| Res::msg(ERROR_CODE, "获取数据流失败"))?;
148 let data = serde_json::from_slice::<T>(&bytes).map_err(|e| {
149 Res::msg(
150 ERROR_CODE,
151 e.to_string().split(" at line").next().unwrap_or_default(),
152 )
153 })?;
154
155 validate(&data)?;
156 Ok(data)
157}
158
159fn des_form<T>(data: Result<RawForm, RawFormRejection>) -> Result<T, Res<()>>
161where
162 T: Validate + DeserializeOwned,
163{
164 let data = match data {
165 Ok(RawForm(bytes)) => serde_urlencoded::from_bytes::<T>(&bytes)
166 .map_err(|err| Res::msg(ERROR_CODE, err.to_string()))?,
167 Err(_) => return Err(Res::msg(ERROR_CODE, "无法获取到表单数据")),
168 };
169
170 validate(&data)?;
171 Ok(data)
172}