1use axum::async_trait;
4use axum::body::Bytes;
5use axum::extract::{FromRequest, Request};
6use axum::http::header;
7use garde::Validate;
8use serde::de::DeserializeOwned;
9
10use crate::Error;
11
/// Axum extractor that buffers the request body, deserializes it into `T`, and validates it
/// with `garde` before the handler runs.
///
/// The wrapped value is public, so handlers can destructure it directly, e.g.
/// `ValidatedForm(payload): ValidatedForm<MyPayload>`.
pub struct ValidatedForm<T>(pub T);
15
16#[async_trait]
17impl<T, S> FromRequest<S> for ValidatedForm<T>
18where
19 T: DeserializeOwned + Validate + Send + 'static,
20 T::Context: Default,
21 S: Send + Sync,
22{
23 type Rejection = Error;
24
25 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
26 let content_type = req
27 .headers()
28 .get(header::CONTENT_TYPE)
29 .and_then(|v| v.to_str().ok())
30 .unwrap_or("")
31 .to_string();
32
33 let bytes = Bytes::from_request(req, state)
34 .await
35 .map_err(|e| Error::bad_request(e.to_string()))?;
36
37 let value: T = if content_type.starts_with("application/json") {
38 serde_json::from_slice(&bytes).map_err(|e| Error::bad_request(e.to_string()))?
39 } else if content_type.starts_with("application/x-www-form-urlencoded") {
40 serde_urlencoded::from_bytes(&bytes).map_err(|e| Error::bad_request(e.to_string()))?
41 } else if bytes.is_empty() {
42 return Err(Error::bad_request("empty request body"));
43 } else {
44 serde_json::from_slice(&bytes).map_err(|e| Error::bad_request(e.to_string()))?
46 };
47
48 value.validate_with(&Default::default())?;
49 Ok(ValidatedForm(value))
50 }
51}