axum_request_validator/layer.rs
1//! An axum layer for HTTP request validation.
2
3use axum::{
4 extract::Request,
5 http,
6 middleware::Next,
7 response::{IntoResponse as _, Response},
8};
9
10use crate::Error;
11
12/// A future that returns [`Response`].
13pub type ResponseFuture =
14 core::pin::Pin<Box<dyn core::future::Future<Output = Response> + Send + 'static>>;
15
16/// The type alias for the fn used in the layer.
17pub type Fn<S> = fn(axum::extract::State<S>, Request, Next) -> ResponseFuture;
18
19/// The type alias for the extractors used in the layer.
20pub type Extractors<S> = (axum::extract::State<S>, Request);
21
22/// The layer type.
23pub type Layer<State> = axum::middleware::FromFnLayer<Fn<State>, State, Extractors<State>>;
24
/// Data carried by the layer and handed to [`middleware`] together with each request.
#[derive(Clone, Debug)]
pub struct State<Validator, ErrorHandler> {
    /// Validator that every incoming request is checked against.
    pub validator: Validator,

    /// Handler that turns a buffering or validation failure into a response.
    pub error_handler: ErrorHandler,
}
34
35/// Create a new HTTP request validating layer.
36///
37/// ## Examples
38///
39/// ```
40/// # #[derive(Clone)]
41/// # struct MyValidator;
42/// #
43/// # impl<Data: bytes::Buf + Send + Sync> http_request_validator::Validator<Data> for MyValidator {
44/// # type Error = &'static str;
45/// #
46/// # async fn validate<'a>(
47/// # &'a self,
48/// # _parts: &'a axum::http::request::Parts,
49/// # buffered_body: &'a Data,
50/// # ) -> Result<(), Self::Error> {
51/// # unimplemented!();
52/// # }
53/// # }
54/// #
55/// use axum::{routing::get, Router};
56///
57/// let app = Router::new()
58/// .route("/", get(|| async { "Hello, World!" }))
59/// .route_layer(axum_request_validator::new(MyValidator));
60/// # let _: Router<()> = app;
61/// ```
62pub fn new<Validator>(validator: Validator) -> Layer<State<Validator, PlainDisplayErrorRenderer>>
63where
64 Validator: http_request_validator::Validator<super::Data> + Send + 'static,
65 <Validator as http_request_validator::Validator<super::Data>>::Error:
66 std::fmt::Display + Send + Sync + 'static,
67{
68 with_error_handler(validator, PlainDisplayErrorRenderer)
69}
70
71/// Create a new HTTP request validating layer with custom error handling.
72///
73/// ## Examples
74///
75/// ```
76/// # #[derive(Clone)]
77/// # struct MyValidator;
78/// #
79/// # impl<Data: bytes::Buf + Send + Sync> http_request_validator::Validator<Data> for MyValidator {
80/// # type Error = &'static str;
81/// #
82/// # async fn validate<'a>(
83/// # &'a self,
84/// # _parts: &'a axum::http::request::Parts,
85/// # buffered_body: &'a Data,
86/// # ) -> Result<(), Self::Error> {
87/// # unimplemented!();
88/// # }
89/// # }
90/// #
91/// use axum::{routing::get, Router, http::StatusCode};
92/// use axum_request_validator::{Error, ErrorHandler};
93///
94/// #[derive(Debug, Clone)]
95/// struct MyErrorHandler;
96///
97/// impl<V> ErrorHandler<V> for MyErrorHandler
98/// where
99/// V: std::fmt::Display + Send + Sync + 'static,
100/// {
101/// type Response = (StatusCode, String);
102///
103/// async fn handle_error(&self, error: Error<V>) -> Self::Response {
104/// match error {
105/// Error::BodyBuffering(error) => (
106/// StatusCode::BAD_REQUEST,
107/// format!("Unable to buffer the request: {error}"),
108/// ),
109/// Error::Validation(error) => (
110/// StatusCode::FORBIDDEN,
111/// format!("Invalid request: {error}"),
112/// ),
113/// }
114/// }
115/// }
116///
117/// let app = Router::new()
118/// .route("/", get(|| async { "Hello, World!" }))
119/// .route_layer(axum_request_validator::with_error_handler(MyValidator, MyErrorHandler));
120/// # let _: Router<()> = app;
121/// ```
122pub fn with_error_handler<Validator, ErrorHandler>(
123 validator: Validator,
124 error_handler: ErrorHandler,
125) -> Layer<State<Validator, ErrorHandler>>
126where
127 Validator: http_request_validator::Validator<super::Data, Error: Send> + Send + 'static,
128 ErrorHandler: self::ErrorHandler<Validator::Error> + Send + 'static,
129{
130 axum::middleware::from_fn_with_state(
131 State {
132 validator,
133 error_handler,
134 },
135 |state, req, next| Box::pin(middleware(state, req, next)),
136 )
137}
138
139/// The error handler for the validation errors.
140pub trait ErrorHandler<V> {
141 /// Whatever the handler should respond with.
142 type Response: axum::response::IntoResponse;
143
144 /// Handler the validation error.
145 fn handle_error(
146 &self,
147 error: Error<V>,
148 ) -> impl std::future::Future<Output = Self::Response> + Send + Sync;
149}
150
151/// A an error renderer that will simply.
152#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
153pub struct PlainDisplayErrorRenderer;
154
155impl<V> ErrorHandler<V> for PlainDisplayErrorRenderer
156where
157 V: std::fmt::Display + Send + Sync,
158 for<'a> V: 'a,
159{
160 type Response = (http::StatusCode, String);
161
162 async fn handle_error(&self, error: Error<V>) -> Self::Response {
163 match error {
164 Error::BodyBuffering(error) => (
165 http::StatusCode::BAD_REQUEST,
166 format!("Unable to buffer the request: {error}"),
167 ),
168 Error::Validation(error) => (
169 http::StatusCode::FORBIDDEN,
170 format!("Invalid request: {error}"),
171 ),
172 }
173 }
174}
175
176/// [`axum`] middleware-fn implementation.
177pub fn middleware<Validator, ErrorHandler>(
178 state: axum::extract::State<State<Validator, ErrorHandler>>,
179 req: Request,
180 next: Next,
181) -> impl core::future::Future<Output = Response>
182where
183 Validator: http_request_validator::Validator<super::Data, Error: Send> + Send,
184 ErrorHandler: self::ErrorHandler<Validator::Error> + Send,
185{
186 let axum::extract::State(State {
187 validator,
188 error_handler,
189 }) = state;
190 async move {
191 let req = match super::validate(validator, req).await {
192 Ok(req) => req,
193 Err(error) => return error_handler.handle_error(error).await.into_response(),
194 };
195 next.run(req).await
196 }
197}