axum_help/filter/
mod.rs

1//! Conditionally dispatch requests to the inner service based on the result of
2//! a predicate.
3//!
//! Unlike the [filter](https://docs.rs/tower/latest/tower/filter/index.html) module in
//! tower, this module lets you return a custom [`Response`](http::response::Response) to the
//! client when a request is rejected.
6//!
7//! # Example
8//!```
9//! # use axum::routing::{get, Router};
10//! # use axum::response::IntoResponse;
11//! # use axum::body::Body;
12//! # use axum::headers::{authorization::Basic, Authorization, HeaderMapExt};
13//! # use axum_help::filter::FilterExLayer;
14//! # use http::{Request, StatusCode};
15//! #
16//! # fn main() {
17//!     Router::new()
18//!         .route("/get", get(|| async { "get works" }))
19//!         .layer(FilterExLayer::new(|request: Request<Body>| {
20//!             if let Some(_auth) = request.headers().typed_get::<Authorization<Basic>>() {
21//!                 // TODO: do something
22//!                 Ok(request)
23//!            } else {
24//!                Err(StatusCode::UNAUTHORIZED.into_response())
25//!            }
26//!         }));
27//! # }
28//!```
29//!
30use axum::{extract::Request, response::Response};
31use future::{AsyncResponseFuture, ResponseFuture};
32use futures_util::StreamExt;
33pub use layer::{AsyncFilterExLayer, FilterExLayer};
34pub use predicate::{AsyncPredicate, Predicate};
35use std::{
36    marker::PhantomData,
37    task::{Context, Poll},
38};
39use tower::Service;
40
41mod future;
42mod layer;
43mod predicate;
44
/// A [`Service`] that runs every request through a synchronous [predicate]
/// before it may reach the wrapped service.
///
/// Requests the predicate accepts are forwarded to the inner service; for a
/// rejected request, the rejection response produced by the predicate is
/// handed to the returned future instead.
///
/// [predicate]: Predicate
#[derive(Debug)]
pub struct FilterEx<T, U> {
    /// The wrapped service that handles accepted requests.
    inner: T,
    /// The predicate every incoming request is checked against.
    predicate: U,
}
53
54impl<T: Clone, U: Clone> Clone for FilterEx<T, U> {
55    fn clone(&self) -> Self {
56        Self {
57            inner: self.inner.clone(),
58            predicate: self.predicate.clone(),
59        }
60    }
61}
62
63impl<T, U: Clone> FilterEx<T, U> {
64    /// Returns a new [FilterEx] service wrapping `inner`
65    pub fn new(inner: T, predicate: U) -> Self {
66        Self { inner, predicate }
67    }
68
69    /// Returns a new [Layer](tower::Layer) that wraps services with a [FilterEx] service
70    /// with the given [Predicate]
71    ///
72    pub fn layer(predicate: U) -> FilterExLayer<U> {
73        FilterExLayer::new(predicate)
74    }
75
76    /// Check a `Request` value against thie filter's predicate
77    pub fn check<R>(&mut self, request: R) -> Result<U::Request, U::Response>
78    where
79        U: Predicate<R>,
80    {
81        self.predicate.check(request)
82    }
83
84    /// Get a reference to the inner service
85    pub fn get_ref(&self) -> &T {
86        &self.inner
87    }
88
89    /// Get a mutable reference to the inner service
90    pub fn get_mut(&mut self) -> &mut T {
91        &mut self.inner
92    }
93
94    /// Consume `self`, returning the inner service
95    pub fn into_inner(self) -> T {
96        self.inner
97    }
98}
99
100impl<T, U> Service<Request> for FilterEx<T, U>
101where
102    T: Service<U::Request, Response = Response>,
103    U: Predicate<Request, Response = Response>,
104{
105    type Response = T::Response;
106    type Error = T::Error;
107    type Future = ResponseFuture<T::Future>;
108
109    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110        self.inner.poll_ready(cx)
111    }
112
113    fn call(&mut self, req: Request) -> Self::Future {
114        match self.predicate.check(req) {
115            Ok(req) => ResponseFuture::Future {
116                future: self.inner.call(req),
117            },
118            Err(response) => ResponseFuture::Error {
119                response: Some(response),
120            },
121        }
122    }
123}
124
/// Conditionally dispatch requests to the inner service based on an
/// asynchronous [predicate](AsyncPredicate).
///
/// `R` is the request type the predicate is checked against.
#[derive(Debug)]
pub struct AsyncFilterEx<T, U, R> {
    /// The wrapped service that handles accepted requests.
    inner: T,
    /// The asynchronous predicate every request is checked against.
    predicate: U,
    // The filter never stores an `R`, so the marker uses `fn() -> R` instead
    // of `R`. That keeps `Send`/`Sync`/`Unpin` of the filter independent of
    // the request type: with `PhantomData<R>`, a request type that is not
    // `Sync` would make the whole service `!Sync`. `fn() -> R` keeps the
    // marker covariant in `R`, like before.
    _r: PhantomData<fn() -> R>,
}
137
138impl<T: Clone, U: Clone, R> Clone for AsyncFilterEx<T, U, R>
139where
140    U: AsyncPredicate<R>,
141{
142    fn clone(&self) -> Self {
143        Self {
144            inner: self.inner.clone(),
145            predicate: self.predicate.clone(),
146            _r: PhantomData,
147        }
148    }
149}
150
151impl<T, U, R> AsyncFilterEx<T, U, R>
152where
153    U: AsyncPredicate<R>,
154{
155    /// Returns a new [AsyncFilterEx] service wrapping `inner`.
156    pub fn new(inner: T, predicate: U) -> Self {
157        Self {
158            inner,
159            predicate,
160            _r: PhantomData,
161        }
162    }
163
164    /// Returns a new [Layer](tower::Layer) that wraps services with a [AsyncFilterEx] service
165    /// with the given [AsyncPredicate]
166    ///
167    pub fn layer(predicate: U) -> AsyncFilterExLayer<U, R> {
168        AsyncFilterExLayer::new(predicate)
169    }
170
171    /// Check a `Request` value against thie filter's predicate
172    pub async fn check(&mut self, request: R) -> Result<U::Request, U::Response>
173    where
174        U: AsyncPredicate<R>,
175    {
176        self.predicate.check(request).await
177    }
178
179    /// Get a reference to the inner service
180    pub fn get_ref(&self) -> &T {
181        &self.inner
182    }
183
184    /// Get a mutable reference to the inner service
185    pub fn get_mut(&mut self) -> &mut T {
186        &mut self.inner
187    }
188
189    /// Consume `self`, returning the inner service
190    pub fn into_inner(self) -> T {
191        self.inner
192    }
193}
194
195impl<T, U> Service<Request> for AsyncFilterEx<T, U, Request>
196where
197    T: Service<U::Request, Response = Response> + Clone,
198    U: AsyncPredicate<Request, Response = Response>,
199{
200    type Response = T::Response;
201    type Error = T::Error;
202    type Future = AsyncResponseFuture<U, T, Request>;
203
204    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
205        self.inner.poll_ready(cx)
206    }
207
208    fn call(&mut self, req: Request) -> Self::Future {
209        use std::mem;
210
211        let inner = self.inner.clone();
212        // In case the inner service has state that's driven to readiness and
213        // not tracked by clones (such as `Buffer`), pass the version we have
214        // already called `poll_ready` on into the future, and leave its clone
215        // behind.
216        let inner = mem::replace(&mut self.inner, inner);
217
218        // Check the request
219        let check = self.predicate.check(req);
220
221        AsyncResponseFuture::new(check, inner)
222    }
223}
224
225pub async fn drain_body(request: Request) {
226    let mut data_stream = request.into_body().into_data_stream();
227    while let Some(_) = data_stream.next().await {}
228}