Skip to main content

mll_axum_utils/validator/
validate.rs

1use crate::res::Res;
2use axum::extract::rejection::{BytesRejection, RawFormRejection};
3use axum::{
4    async_trait,
5    body::HttpBody,
6    extract::{FromRequest, RawForm},
7    headers::{ContentType, HeaderMapExt},
8    http::{HeaderMap, Request},
9    BoxError, RequestExt,
10};
11use bytes::Bytes;
12use serde::de::DeserializeOwned;
13use validator::Validate;
14
/// Code passed as the first argument to `Res::msg` / `Res::msgs` for every
/// rejection built in this file.
// NOTE(review): 422 presumably maps to HTTP 422 Unprocessable Entity — confirm against `Res`.
15const ERROR_CODE: u16 = 422;
16
17/// 提取 Json 类型数据 并验证数据
18#[must_use]
19#[derive(Debug, Clone, Default)]
20pub struct VJson<T: Validate>(pub T);
21
22#[async_trait]
23impl<T, S, B> FromRequest<S, B> for VJson<T>
24where
25    T: DeserializeOwned + Validate,
26    B: HttpBody + Send + 'static,
27    B::Data: Send,
28    B::Error: Into<BoxError>,
29    S: Send + Sync,
30{
31    type Rejection = Res<()>;
32
33    async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
34        if !json_content_type(req.headers()) {
35            return Err(Res::msg(ERROR_CODE, "请求头必须为: application/json"));
36        }
37
38        let data = des_json(Bytes::from_request(req, state).await)?;
39        Ok(VJson(data))
40    }
41}
42
43/// 提取 Form 类型数据 并验证数据
44#[must_use]
45#[derive(Debug, Clone, Default)]
46pub struct VForm<T: Validate>(pub T);
47
48#[async_trait]
49impl<T, S, B> FromRequest<S, B> for VForm<T>
50where
51    T: DeserializeOwned + Validate,
52    B: HttpBody + Send + 'static,
53    B::Data: Send,
54    B::Error: Into<BoxError>,
55    S: Send + Sync,
56{
57    type Rejection = Res<()>;
58
59    async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
60        let data = des_form(req.extract::<RawForm, _>().await)?;
61        Ok(VForm(data))
62    }
63}
64
65/// 提取 Json 或者 Form 类型数据 并验证数据
66#[must_use]
67#[derive(Debug, Clone, Default)]
68pub struct VJsonOrForm<T: Validate>(pub T);
69
70#[async_trait]
71impl<T, S, B> FromRequest<S, B> for VJsonOrForm<T>
72where
73    T: DeserializeOwned + Validate,
74    B: HttpBody + Send + 'static,
75    B::Data: Send,
76    B::Error: Into<BoxError>,
77    S: Send + Sync,
78{
79    type Rejection = Res<()>;
80
81    async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
82        let data = if json_content_type(req.headers()) {
83            des_json(Bytes::from_request(req, state).await)?
84        } else {
85            des_form(req.extract::<RawForm, _>().await)?
86        };
87
88        Ok(VJsonOrForm(data))
89    }
90}
91
92/// 提取 Query 类型数据 并验证数据
93#[must_use]
94#[derive(Debug, Clone, Default)]
95pub struct VQuery<T: Validate>(pub T);
96
97#[async_trait]
98impl<T, S, B> FromRequest<S, B> for VQuery<T>
99where
100    T: DeserializeOwned + Validate,
101    B: HttpBody + Send + 'static,
102    B::Data: Send,
103    B::Error: Into<BoxError>,
104    S: Send + Sync,
105{
106    type Rejection = Res<()>;
107
108    async fn from_request(req: Request<B>, _: &S) -> Result<Self, Self::Rejection> {
109        let data = serde_urlencoded::from_str::<T>(req.uri().query().unwrap_or_default())
110            .map_err(|err| Res::msg(ERROR_CODE, err))?;
111
112        validate(&data)?;
113        Ok(VQuery(data))
114    }
115}
116
117/// 判断 json 请求头
118pub fn json_content_type(headers: &HeaderMap) -> bool {
119    headers
120        .typed_get::<ContentType>()
121        .map(|t| t.to_string() == "application/json")
122        .unwrap_or(false)
123}
124
125/// 数据验证
126pub fn validate(data: impl Validate) -> Result<(), Res<()>> {
127    if let Err(err) = data.validate() {
128        let mut msg = Vec::new();
129        for (k, v) in err.field_errors() {
130            for item in v {
131                msg.push(format!(
132                    "{k:}: validate failed tips: {}",
133                    item.message.as_ref().unwrap_or(&item.code)
134                ));
135            }
136        }
137        return Err(Res::msgs(ERROR_CODE, msg));
138    }
139    Ok(())
140}
141
142/// 返序列化 json
143fn des_json<T>(data: Result<Bytes, BytesRejection>) -> Result<T, Res<()>>
144where
145    T: Validate + DeserializeOwned,
146{
147    let bytes = data.map_err(|_| Res::msg(ERROR_CODE, "获取数据流失败"))?;
148    let data = serde_json::from_slice::<T>(&bytes).map_err(|e| {
149        Res::msg(
150            ERROR_CODE,
151            e.to_string().split(" at line").next().unwrap_or_default(),
152        )
153    })?;
154
155    validate(&data)?;
156    Ok(data)
157}
158
159/// 返序列化 form
160fn des_form<T>(data: Result<RawForm, RawFormRejection>) -> Result<T, Res<()>>
161where
162    T: Validate + DeserializeOwned,
163{
164    let data = match data {
165        Ok(RawForm(bytes)) => serde_urlencoded::from_bytes::<T>(&bytes)
166            .map_err(|err| Res::msg(ERROR_CODE, err.to_string()))?,
167        Err(_) => return Err(Res::msg(ERROR_CODE, "无法获取到表单数据")),
168    };
169
170    validate(&data)?;
171    Ok(data)
172}