use axum::{
extract::Request,
http,
middleware::Next,
response::{IntoResponse as _, Response},
};
use crate::Error;
/// Heap-allocated, pinned future resolving to a `Response`.
///
/// The trait object is `Send + 'static`; this is the return type of the
/// function pointer described by the `Fn` alias below.
pub type ResponseFuture =
core::pin::Pin<Box<dyn core::future::Future<Output = Response> + Send + 'static>>;
/// Function-pointer signature of the middleware function: it receives the
/// state extractor, the request and `Next`, and returns a `ResponseFuture`.
pub type Fn<S> = fn(axum::extract::State<S>, Request, Next) -> ResponseFuture;
/// Tuple of the leading arguments of `Fn` (state extractor and request), used
/// as the third type parameter of `FromFnLayer` in `Layer`.
pub type Extractors<S> = (axum::extract::State<S>, Request);
/// Concrete `FromFnLayer` type returned by the constructors of this module.
pub type Layer<State> = axum::middleware::FromFnLayer<Fn<State>, State, Extractors<State>>;
/// State carried by the validation layer: the validator together with the
/// handler that turns failures into responses.
#[derive(Debug, Clone)]
pub struct State<Validator, ErrorHandler> {
/// Validator handed to `super::validate` by `middleware`.
pub validator: Validator,
/// Handler whose `handle_error` is called by `middleware` when
/// `super::validate` returns an error.
pub error_handler: ErrorHandler,
}
/// Builds the validation layer for `validator`, using
/// [`PlainDisplayErrorRenderer`] as the error handler.
///
/// This is a shorthand for [`with_error_handler`] with the default renderer,
/// which is why the validator's error type has to implement `Display`.
pub fn new<Validator>(validator: Validator) -> Layer<State<Validator, PlainDisplayErrorRenderer>>
where
    Validator: http_request_validator::Validator<super::Data> + Send + 'static,
    <Validator as http_request_validator::Validator<super::Data>>::Error:
        std::fmt::Display + Send + Sync + 'static,
{
    let default_renderer = PlainDisplayErrorRenderer;
    with_error_handler(validator, default_renderer)
}
/// Builds the validation layer for `validator`, reporting failures through
/// the supplied `error_handler`.
///
/// Both values are stored in a [`State`] that is passed to
/// `axum::middleware::from_fn_with_state` together with a wrapper around
/// [`middleware`].
pub fn with_error_handler<Validator, ErrorHandler>(
    validator: Validator,
    error_handler: ErrorHandler,
) -> Layer<State<Validator, ErrorHandler>>
where
    Validator: http_request_validator::Validator<super::Data, Error: Send> + Send + 'static,
    ErrorHandler: self::ErrorHandler<Validator::Error> + Send + 'static,
{
    let layer_state = State {
        validator,
        error_handler,
    };
    // The closure captures nothing, so it can coerce to the function pointer
    // type named by `Layer`; boxing erases the concrete future type.
    axum::middleware::from_fn_with_state(layer_state, |state, req, next| {
        Box::pin(middleware(state, req, next))
    })
}
/// Turns an `Error<V>` raised while validating a request into a response.
///
/// `V` is the error type carried by `Error`; the constructors of this module
/// instantiate it with the validator's associated `Error` type.
pub trait ErrorHandler<V> {
/// Value produced for a failed request; it must be convertible into an
/// axum response.
type Response: axum::response::IntoResponse;
/// Produces the response for `error`.
///
/// The returned future is required to be both `Send` and `Sync`.
fn handle_error(
&self,
error: Error<V>,
) -> impl std::future::Future<Output = Self::Response> + Send + Sync;
}
/// Error handler that answers with a status code and a `String` message
/// formatted from the error's `Display` output.
///
/// This is the handler installed by `new`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct PlainDisplayErrorRenderer;
impl<V> ErrorHandler<V> for PlainDisplayErrorRenderer
where
V: std::fmt::Display + Send + Sync,
for<'a> V: 'a,
{
type Response = (http::StatusCode, String);
async fn handle_error(&self, error: Error<V>) -> Self::Response {
match error {
Error::BodyBuffering(error) => (
http::StatusCode::BAD_REQUEST,
format!("Unable to buffer the request: {error}"),
),
Error::Validation(error) => (
http::StatusCode::FORBIDDEN,
format!("Invalid request: {error}"),
),
}
}
}
/// Middleware body: runs `super::validate` on the request and either forwards
/// the request it returns to `next` or answers with the error handler's
/// response.
///
/// The state is taken apart eagerly, before the future is created, so the
/// returned future owns the validator and the error handler directly.
pub fn middleware<Validator, ErrorHandler>(
    state: axum::extract::State<State<Validator, ErrorHandler>>,
    req: Request,
    next: Next,
) -> impl core::future::Future<Output = Response>
where
    Validator: http_request_validator::Validator<super::Data, Error: Send> + Send,
    ErrorHandler: self::ErrorHandler<Validator::Error> + Send,
{
    let axum::extract::State(layer_state) = state;
    let State {
        validator,
        error_handler,
    } = layer_state;

    async move {
        match super::validate(validator, req).await {
            // Validation passed: continue down the stack with the request
            // handed back by `super::validate`.
            Ok(validated) => next.run(validated).await,
            // Validation failed: `next` is never invoked.
            Err(error) => error_handler.handle_error(error).await.into_response(),
        }
    }
}