axum_request_validator/
layer.rs

1//! An axum layer for HTTP request validation.
2
3use axum::{
4    extract::Request,
5    http,
6    middleware::Next,
7    response::{IntoResponse as _, Response},
8};
9
10use crate::Error;
11
12/// A future that returns [`Response`].
13pub type ResponseFuture =
14    core::pin::Pin<Box<dyn core::future::Future<Output = Response> + Send + 'static>>;
15
16/// The type alias for the fn used in the layer.
17pub type Fn<S> = fn(axum::extract::State<S>, Request, Next) -> ResponseFuture;
18
19/// The type alias for the extractors used in the layer.
20pub type Extractors<S> = (axum::extract::State<S>, Request);
21
22/// The layer type.
23pub type Layer<State> = axum::middleware::FromFnLayer<Fn<State>, State, Extractors<State>>;
24
/// State stored inside the validating layer.
///
/// Built by [`with_error_handler`] and unpacked by [`middleware`] for each request.
#[derive(Clone, Debug)]
pub struct State<Validator, ErrorHandler> {
    /// Validator that incoming requests are checked against.
    pub validator: Validator,

    /// Handler that turns a failed validation into a response.
    pub error_handler: ErrorHandler,
}
34
35/// Create a new HTTP request validating layer.
36///
37/// ## Examples
38///
39/// ```
40/// # #[derive(Clone)]
41/// # struct MyValidator;
42/// #
43/// # impl<Data: bytes::Buf + Send + Sync> http_request_validator::Validator<Data> for MyValidator {
44/// #    type Error = &'static str;
45/// #
46/// #    async fn validate<'a>(
47/// #        &'a self,
48/// #        _parts: &'a axum::http::request::Parts,
49/// #        buffered_body: &'a Data,
50/// #    ) -> Result<(), Self::Error> {
51/// #        unimplemented!();
52/// #    }
53/// # }
54/// #
55/// use axum::{routing::get, Router};
56///
57/// let app = Router::new()
58///     .route("/", get(|| async { "Hello, World!" }))
59///     .route_layer(axum_request_validator::new(MyValidator));
60/// # let _: Router<()> = app;
61/// ```
62pub fn new<Validator>(validator: Validator) -> Layer<State<Validator, PlainDisplayErrorRenderer>>
63where
64    Validator: http_request_validator::Validator<super::Data> + Send + 'static,
65    <Validator as http_request_validator::Validator<super::Data>>::Error:
66        std::fmt::Display + Send + Sync + 'static,
67{
68    with_error_handler(validator, PlainDisplayErrorRenderer)
69}
70
71/// Create a new HTTP request validating layer with custom error handling.
72///
73/// ## Examples
74///
75/// ```
76/// # #[derive(Clone)]
77/// # struct MyValidator;
78/// #
79/// # impl<Data: bytes::Buf + Send + Sync> http_request_validator::Validator<Data> for MyValidator {
80/// #    type Error = &'static str;
81/// #
82/// #    async fn validate<'a>(
83/// #        &'a self,
84/// #        _parts: &'a axum::http::request::Parts,
85/// #        buffered_body: &'a Data,
86/// #    ) -> Result<(), Self::Error> {
87/// #        unimplemented!();
88/// #    }
89/// # }
90/// #
91/// use axum::{routing::get, Router, http::StatusCode};
92/// use axum_request_validator::{Error, ErrorHandler};
93///
94/// #[derive(Debug, Clone)]
95/// struct MyErrorHandler;
96///
97/// impl<V> ErrorHandler<V> for MyErrorHandler
98/// where
99///     V: std::fmt::Display + Send + Sync + 'static,
100/// {
101///     type Response = (StatusCode, String);
102///
103///     async fn handle_error(&self, error: Error<V>) -> Self::Response {
104///         match error {
105///             Error::BodyBuffering(error) => (
106///                 StatusCode::BAD_REQUEST,
107///                 format!("Unable to buffer the request: {error}"),
108///             ),
109///             Error::Validation(error) => (
110///                 StatusCode::FORBIDDEN,
111///                 format!("Invalid request: {error}"),
112///             ),
113///         }
114///     }
115/// }
116///
117/// let app = Router::new()
118///     .route("/", get(|| async { "Hello, World!" }))
119///     .route_layer(axum_request_validator::with_error_handler(MyValidator, MyErrorHandler));
120/// # let _: Router<()> = app;
121/// ```
122pub fn with_error_handler<Validator, ErrorHandler>(
123    validator: Validator,
124    error_handler: ErrorHandler,
125) -> Layer<State<Validator, ErrorHandler>>
126where
127    Validator: http_request_validator::Validator<super::Data, Error: Send> + Send + 'static,
128    ErrorHandler: self::ErrorHandler<Validator::Error> + Send + 'static,
129{
130    axum::middleware::from_fn_with_state(
131        State {
132            validator,
133            error_handler,
134        },
135        |state, req, next| Box::pin(middleware(state, req, next)),
136    )
137}
138
139/// The error handler for the validation errors.
140pub trait ErrorHandler<V> {
141    /// Whatever the handler should respond with.
142    type Response: axum::response::IntoResponse;
143
144    /// Handler the validation error.
145    fn handle_error(
146        &self,
147        error: Error<V>,
148    ) -> impl std::future::Future<Output = Self::Response> + Send + Sync;
149}
150
151/// A an error renderer that will simply.
152#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
153pub struct PlainDisplayErrorRenderer;
154
155impl<V> ErrorHandler<V> for PlainDisplayErrorRenderer
156where
157    V: std::fmt::Display + Send + Sync,
158    for<'a> V: 'a,
159{
160    type Response = (http::StatusCode, String);
161
162    async fn handle_error(&self, error: Error<V>) -> Self::Response {
163        match error {
164            Error::BodyBuffering(error) => (
165                http::StatusCode::BAD_REQUEST,
166                format!("Unable to buffer the request: {error}"),
167            ),
168            Error::Validation(error) => (
169                http::StatusCode::FORBIDDEN,
170                format!("Invalid request: {error}"),
171            ),
172        }
173    }
174}
175
176/// [`axum`] middleware-fn implementation.
177pub fn middleware<Validator, ErrorHandler>(
178    state: axum::extract::State<State<Validator, ErrorHandler>>,
179    req: Request,
180    next: Next,
181) -> impl core::future::Future<Output = Response>
182where
183    Validator: http_request_validator::Validator<super::Data, Error: Send> + Send,
184    ErrorHandler: self::ErrorHandler<Validator::Error> + Send,
185{
186    let axum::extract::State(State {
187        validator,
188        error_handler,
189    }) = state;
190    async move {
191        let req = match super::validate(validator, req).await {
192            Ok(req) => req,
193            Err(error) => return error_handler.handle_error(error).await.into_response(),
194        };
195        next.run(req).await
196    }
197}