use crate::{Form, Query};
use axum_core::{
extract::{FromRequest, FromRequestParts, Request},
response::{IntoResponse, Response},
};
use facet_core::Facet;
use http::{StatusCode, header, request::Parts};
use http_body_util::BodyExt;
use std::fmt;
/// Rejection returned by the [`Form`] extractor when a request cannot be
/// turned into the target type.
#[derive(Debug)]
pub struct FormRejection {
    kind: FormRejectionKind,
}

/// The specific reason a [`Form`] extraction failed.
#[derive(Debug)]
enum FormRejectionKind {
    /// Collecting the request body failed.
    BodyError(axum_core::Error),
    /// The body was read but could not be deserialized into the target type.
    DeserializeError(crate::UrlEncodedError),
    /// The body bytes are not valid UTF-8.
    InvalidUtf8,
    /// The `Content-Type` header is missing or does not name a form-urlencoded body.
    InvalidContentType,
}
impl FormRejection {
    /// HTTP status code that best describes this rejection.
    ///
    /// The mapping is as follows:
    /// - A body that could not be collected, or is not valid UTF-8, maps to
    ///   `400 Bad Request`.
    /// - A body that decodes as text but does not fit the target type maps to
    ///   `422 Unprocessable Entity`.
    /// - A wrong or missing `Content-Type` maps to `415 Unsupported Media Type`.
    pub const fn status(&self) -> StatusCode {
        match self.kind {
            FormRejectionKind::BodyError(_) | FormRejectionKind::InvalidUtf8 => {
                StatusCode::BAD_REQUEST
            }
            FormRejectionKind::DeserializeError(_) => StatusCode::UNPROCESSABLE_ENTITY,
            FormRejectionKind::InvalidContentType => StatusCode::UNSUPPORTED_MEDIA_TYPE,
        }
    }
}
impl fmt::Display for FormRejection {
    /// Writes a human-readable explanation of why the form was rejected.
    ///
    /// The same text becomes the response body in `into_response`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.kind {
            FormRejectionKind::BodyError(ref cause) => {
                write!(f, "Failed to read request body: {cause}")
            }
            FormRejectionKind::DeserializeError(ref cause) => {
                write!(f, "Failed to deserialize form data: {cause}")
            }
            FormRejectionKind::InvalidUtf8 => f.write_str("Request body is not valid UTF-8"),
            FormRejectionKind::InvalidContentType => f.write_str(
                "Invalid `Content-Type` header: expected `application/x-www-form-urlencoded`",
            ),
        }
    }
}
impl std::error::Error for FormRejection {
    /// Returns the wrapped cause for body-read and deserialization failures.
    ///
    /// Returns `None` for the checks this module performs itself (UTF-8 and
    /// `Content-Type`).
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match &self.kind {
            FormRejectionKind::BodyError(inner) => inner,
            FormRejectionKind::DeserializeError(inner) => inner,
            FormRejectionKind::InvalidUtf8 | FormRejectionKind::InvalidContentType => {
                return None;
            }
        };
        Some(cause)
    }
}
impl IntoResponse for FormRejection {
    /// Builds a response from the rejection.
    ///
    /// The status comes from [`FormRejection::status`] and the body is the
    /// rejection's `Display` text.
    fn into_response(self) -> Response {
        (self.status(), self.to_string()).into_response()
    }
}
impl From<axum_core::Error> for FormRejection {
fn from(err: axum_core::Error) -> Self {
FormRejection {
kind: FormRejectionKind::BodyError(err),
}
}
}
impl From<crate::UrlEncodedError> for FormRejection {
    /// Classifies a failure to deserialize the form body, so `?` can
    /// propagate it out of the extractor.
    fn from(source: crate::UrlEncodedError) -> Self {
        let kind = FormRejectionKind::DeserializeError(source);
        Self { kind }
    }
}
/// Rejection returned by the [`Query`] extractor when the query string cannot
/// be turned into the target type.
#[derive(Debug)]
pub struct QueryRejection {
    kind: QueryRejectionKind,
}

/// The specific reason a [`Query`] extraction failed.
#[derive(Debug)]
enum QueryRejectionKind {
    /// The query string could not be deserialized into the target type.
    DeserializeError(crate::UrlEncodedError),
}
impl QueryRejection {
    /// HTTP status code for this rejection.
    ///
    /// Always `400 Bad Request`: the only failure mode is a query string that
    /// does not deserialize into the target type.
    pub const fn status(&self) -> StatusCode {
        match self.kind {
            QueryRejectionKind::DeserializeError(_) => StatusCode::BAD_REQUEST,
        }
    }
}
impl fmt::Display for QueryRejection {
    /// Writes a human-readable explanation of why the query string was rejected.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum, so this destructure is irrefutable.
        let QueryRejectionKind::DeserializeError(cause) = &self.kind;
        write!(f, "Failed to deserialize query parameters: {cause}")
    }
}
impl std::error::Error for QueryRejection {
    /// Exposes the underlying deserialization error as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let QueryRejectionKind::DeserializeError(cause) = &self.kind;
        Some(cause)
    }
}
impl IntoResponse for QueryRejection {
    /// Builds a response from the rejection.
    ///
    /// The status comes from [`QueryRejection::status`] and the body is the
    /// rejection's `Display` text.
    fn into_response(self) -> Response {
        (self.status(), self.to_string()).into_response()
    }
}
impl From<crate::UrlEncodedError> for QueryRejection {
    /// Classifies a query-string deserialization failure, so `?` can propagate
    /// it out of the extractor.
    fn from(source: crate::UrlEncodedError) -> Self {
        let kind = QueryRejectionKind::DeserializeError(source);
        Self { kind }
    }
}
/// Returns `true` when the request's `Content-Type` names the
/// `application/x-www-form-urlencoded` media type.
///
/// Only the media-type essence is compared: the part before any `;`
/// parameters, with surrounding whitespace trimmed. The comparison ignores
/// ASCII case because media type names are case-insensitive (RFC 9110
/// §8.3.1). A missing header, or one that `HeaderValue::to_str` cannot
/// convert, yields `false`.
fn is_form_content_type(req: &Request) -> bool {
    let Some(content_type) = req.headers().get(header::CONTENT_TYPE) else {
        return false;
    };
    let Ok(content_type) = content_type.to_str() else {
        return false;
    };
    // A bare `starts_with` is wrong in both directions: it rejects
    // `Application/X-WWW-Form-Urlencoded` yet accepts unrelated types that
    // merely share the prefix, e.g. `application/x-www-form-urlencodedfoo`.
    let essence = content_type
        .split_once(';')
        .map_or(content_type, |(essence, _params)| essence)
        .trim();
    essence.eq_ignore_ascii_case("application/x-www-form-urlencoded")
}
/// Extracts a [`Form`] from an `application/x-www-form-urlencoded` request body.
///
/// The request is rejected as follows:
/// - `415 Unsupported Media Type` when the `Content-Type` is not
///   form-urlencoded.
/// - `400 Bad Request` when the body cannot be collected (including when it
///   exceeds the configured body limit) or is not valid UTF-8.
/// - `422 Unprocessable Entity` when the text cannot be deserialized into `T`.
impl<T, S> FromRequest<S> for Form<T>
where
    T: Facet<'static>,
    S: Send + Sync,
{
    type Rejection = FormRejection;

    async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
        if !is_form_content_type(&req) {
            return Err(FormRejection {
                kind: FormRejectionKind::InvalidContentType,
            });
        }
        // `into_limited_body` honours axum's `DefaultBodyLimit` (and its
        // built-in default when no layer configures one). Collecting
        // `into_body()` directly bypasses that limit, letting a client force
        // the server to buffer an arbitrarily large payload in memory.
        let bytes = axum_core::RequestExt::into_limited_body(req)
            .collect()
            .await
            .map_err(axum_core::Error::new)?
            .to_bytes();
        let body_str = std::str::from_utf8(&bytes).map_err(|_| FormRejection {
            kind: FormRejectionKind::InvalidUtf8,
        })?;
        let value: T = crate::from_str_owned(body_str)?;
        Ok(Form(value))
    }
}
impl<T, S> FromRequestParts<S> for Query<T>
where
T: Facet<'static>,
S: Send + Sync,
{
type Rejection = QueryRejection;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let query = parts.uri.query().unwrap_or_default();
let value: T = crate::from_str_owned(query)?;
Ok(Query(value))
}
}