use std::convert::Infallible;
use std::ops::{Deref, DerefMut};
use axum::extract::rejection::{FormRejection, JsonRejection, PathRejection, QueryRejection};
use axum::extract::{FromRequest, Multipart, Path, Query};
use axum::{async_trait, Form, Json, RequestExt};
use bytes::Bytes;
use http::header::AsHeaderName;
use http::{header, HeaderMap};
use serde::de::DeserializeOwned;
use serde_json::{Map, Value};
pub use validate::Validate;
use crate::support::filesystem;
mod validate;
/// Newtype extractor wrapping a value of type `T`.
///
/// In this module it is instantiated with `T = axum::extract::Request` to add
/// helpers for reading headers, classifying the `Content-Type`, and decoding
/// the request body.
#[derive(Debug)]
pub struct Request<T>(T);
#[async_trait]
impl<S> FromRequest<S> for Request<axum::extract::Request>
where
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(
req: axum::extract::Request,
_state: &S,
) -> Result<Self, Self::Rejection> {
Ok(Self(req))
}
}
impl From<axum::extract::Request> for Request<axum::extract::Request> {
fn from(request: axum::extract::Request) -> Self {
Self(request)
}
}
impl From<Request<axum::extract::Request>> for axum::extract::Request {
fn from(extractor: Request<axum::extract::Request>) -> Self {
extractor.0
}
}
impl<T> Deref for Request<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for Request<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl Request<axum::extract::Request> {
    /// Returns the value of header `key` as a `&str`.
    ///
    /// Returns `None` when the header is absent or when its value cannot be
    /// converted to a string (`HeaderValue::to_str` fails).
    pub fn header<K>(&self, key: K) -> Option<&str>
    where
        K: AsHeaderName,
    {
        self.headers()
            .get(key)
            .and_then(|value| value.to_str().ok())
    }

    /// Whether [`Self::content_type`] classifies the body as `application/json`.
    pub fn is_json(&self) -> bool {
        self.content_type() == mime::APPLICATION_JSON
    }

    /// Whether [`Self::content_type`] classifies the body as
    /// `application/x-www-form-urlencoded`.
    pub fn is_form(&self) -> bool {
        self.content_type() == mime::APPLICATION_WWW_FORM_URLENCODED
    }

    /// Whether [`Self::content_type`] classifies the body as `multipart/form-data`.
    pub fn is_multipart(&self) -> bool {
        self.content_type() == mime::MULTIPART_FORM_DATA
    }

    /// Classifies the request body by its `Content-Type` header.
    ///
    /// Checks, in order, whether the header value starts with
    /// `application/x-www-form-urlencoded`, `application/json` or
    /// `multipart/form-data`, and returns the first match. A prefix test is
    /// used, so trailing parameters such as `; charset=utf-8` or
    /// `; boundary=...` are accepted. A missing header, a non-string header
    /// value, or any other type yields `text/plain`.
    pub fn content_type(&self) -> mime::Mime {
        if has_content_type(self.headers(), &mime::APPLICATION_WWW_FORM_URLENCODED) {
            return mime::APPLICATION_WWW_FORM_URLENCODED;
        }
        if has_content_type(self.headers(), &mime::APPLICATION_JSON) {
            return mime::APPLICATION_JSON;
        }
        if has_content_type(self.headers(), &mime::MULTIPART_FORM_DATA) {
            return mime::MULTIPART_FORM_DATA;
        }
        mime::TEXT_PLAIN
    }

    /// Deserializes the URL path parameters into `T` using axum's `Path`
    /// extractor. The request itself stays usable afterwards.
    ///
    /// # Errors
    /// Returns the [`PathRejection`] produced by `Path<T>`.
    pub async fn path<T>(&mut self) -> eyre::Result<T, PathRejection>
    where
        // `'static` is required by `RequestExt::extract_parts`; no other
        // lifetime parameter is needed.
        T: DeserializeOwned + Send + 'static,
    {
        let Path(path) = self.extract_parts::<Path<T>>().await?;
        Ok(path)
    }

    /// Deserializes the URL query string into `T` using axum's `Query`
    /// extractor. The request itself stays usable afterwards.
    ///
    /// # Errors
    /// Returns the [`QueryRejection`] produced by `Query<T>`.
    pub async fn query<T>(&mut self) -> eyre::Result<T, QueryRejection>
    where
        T: DeserializeOwned + 'static,
    {
        let Query(query) = self.extract_parts::<Query<T>>().await?;
        Ok(query)
    }

    /// Consumes the extractor and returns the underlying axum request.
    pub fn request(self) -> axum::extract::Request {
        self.into()
    }

    /// Consumes the request and decodes it with axum's `Form<T>` extractor
    /// (called with unit state).
    ///
    /// # Errors
    /// Returns the [`FormRejection`] produced by `Form<T>`.
    pub async fn form<T: DeserializeOwned>(self) -> eyre::Result<T, FormRejection> {
        let Form(payload) = Form::<T>::from_request(self.request(), &()).await?;
        Ok(payload)
    }

    /// Consumes the request and decodes the body with axum's `Json<T>`
    /// extractor (called with unit state).
    ///
    /// # Errors
    /// Returns the [`JsonRejection`] produced by `Json<T>`.
    pub async fn json<T: DeserializeOwned>(self) -> eyre::Result<T, JsonRejection> {
        let Json(payload) = Json::<T>::from_request(self.request(), &()).await?;
        Ok(payload)
    }

    /// Consumes the request, flattens its multipart body into a JSON object,
    /// and deserializes that object into `T`.
    ///
    /// Each named part becomes one key:
    /// - file parts (those with a file name) map to the file name, and their
    ///   contents are not read;
    /// - other parts map to their text content.
    ///
    /// Unnamed parts are skipped. If a name repeats, the last part wins,
    /// because `Map::insert` overwrites existing keys.
    ///
    /// # Errors
    /// Fails if the body is not valid multipart, a text part cannot be read,
    /// or the collected object does not deserialize into `T`.
    pub async fn multipart<T: DeserializeOwned>(self) -> eyre::Result<T> {
        let mut multipart = Multipart::from_request(self.request(), &()).await?;
        let mut data = Map::new();
        while let Some(field) = multipart.next_field().await? {
            let Some(name) = field.name().map(str::to_owned) else {
                continue;
            };
            if let Some(file_name) = field.file_name() {
                data.insert(name, file_name.into());
            } else {
                data.insert(name, field.text().await?.into());
            }
        }
        let payload = serde_json::from_value::<T>(data.into())?;
        Ok(payload)
    }

    /// Consumes the request and returns the bytes of the first uploaded file
    /// whose multipart field name equals `key`.
    ///
    /// A matching part is passed to `filesystem::store`, and the stored path
    /// is read back with `filesystem::get`. If storing fails, that part is
    /// ignored and scanning continues. Returns `Ok(None)` when no matching
    /// file part was stored.
    ///
    /// # Errors
    /// Fails if the body is not valid multipart, or if reading the stored
    /// file back fails.
    pub async fn file(self, key: &str) -> eyre::Result<Option<Bytes>> {
        let mut multipart = Multipart::from_request(self.request(), &()).await?;
        // Propagate malformed-body errors instead of reporting them as
        // "no such file" (consistent with `multipart()`).
        while let Some(field) = multipart.next_field().await? {
            if field.name() != Some(key) {
                continue;
            }
            // Only parts that carry a file name are treated as files.
            let Some(file_name) = field.file_name().map(str::to_owned) else {
                continue;
            };
            if let Ok(path) = filesystem::store(&file_name, field).await {
                return Ok(Some(filesystem::get(&path).await?));
            }
        }
        Ok(None)
    }

    /// Consumes the request and decodes its body into `T`, choosing the
    /// decoder from [`Self::content_type`].
    ///
    /// JSON bodies use [`Self::json`] and url-encoded bodies use
    /// [`Self::form`]. Everything else, including a missing `Content-Type`,
    /// is attempted as [`Self::multipart`].
    ///
    /// # Errors
    /// Propagates the error of whichever decoder was selected.
    pub async fn content<T: DeserializeOwned>(self) -> eyre::Result<T> {
        let content_type = self.content_type();
        if content_type == mime::APPLICATION_JSON {
            return Ok(self.json::<T>().await?);
        }
        if content_type == mime::APPLICATION_WWW_FORM_URLENCODED {
            return Ok(self.form::<T>().await?);
        }
        self.multipart::<T>().await
    }

    /// Consumes the request and decodes its body (see [`Self::content`]) into
    /// a generic JSON object.
    pub async fn all(self) -> eyre::Result<Map<String, Value>> {
        self.content::<Map<String, Value>>().await
    }
}
/// Reports whether the `Content-Type` header begins with `expected_content_type`.
///
/// This is a case-sensitive prefix test, so parameters after the media type
/// (e.g. `; charset=utf-8`) do not prevent a match. A missing header, or one
/// whose value is not a valid string, yields `false`.
fn has_content_type(headers: &HeaderMap, expected_content_type: &mime::Mime) -> bool {
    let Some(raw) = headers.get(header::CONTENT_TYPE) else {
        return false;
    };
    match raw.to_str() {
        Ok(text) => text.starts_with(expected_content_type.as_ref()),
        Err(_) => false,
    }
}