// axum_help/filter/mod.rs

1//! Conditionally dispatch requests to the inner service based on the result of
2//! a predicate.
3//!
//! Unlike the [`filter`](https://docs.rs/tower/latest/tower/filter/index.html) module in
//! tower, this lets you return a custom [`Response`](http::Response) to the user when the request is rejected.
6//!
7//! # Example
8//!```
9//! # use axum::routing::{get, Router};
10//! # use axum::response::IntoResponse;
11//! # use axum::body::Body;
12//! # use axum::headers::{authorization::Basic, Authorization, HeaderMapExt};
13//! # use axum_help::filter::FilterExLayer;
14//! # use http::{Request, StatusCode};
15//! #
16//! # fn main() {
17//!     Router::new()
18//!         .route("/get", get(|| async { "get works" }))
19//!         .layer(FilterExLayer::new(|request: Request<Body>| {
20//!             if let Some(_auth) = request.headers().typed_get::<Authorization<Basic>>() {
21//!                 // TODO: do something
22//!                 Ok(request)
23//!            } else {
24//!                Err(StatusCode::UNAUTHORIZED.into_response())
25//!            }
26//!         }));
27//! # }
28//!```
29//!
30use axum::{extract::Request, response::Response};
31use future::{AsyncResponseFuture, ResponseFuture};
32use futures_util::StreamExt;
33pub use layer::{AsyncFilterExLayer, FilterExLayer};
34pub use predicate::{AsyncPredicate, Predicate};
35use std::task::{Context, Poll};
36use tower::Service;
37
38mod future;
39mod layer;
40mod predicate;
41
/// Conditionally dispatch requests to the inner service based on a [predicate].
///
/// The wrapped service (`inner`) only sees requests the predicate accepts;
/// rejected requests are answered with the predicate's own response.
///
/// `Clone` is derived: cloning requires both the inner service and the
/// predicate to be `Clone`, exactly as a hand-written impl would.
///
/// [predicate]: Predicate
#[derive(Debug, Clone)]
pub struct FilterEx<T, U> {
    inner: T,
    predicate: U,
}
59
60impl<T, U: Clone> FilterEx<T, U> {
61    /// Returns a new [FilterEx] service wrapping `inner`
62    pub fn new(inner: T, predicate: U) -> Self {
63        Self { inner, predicate }
64    }
65
66    /// Returns a new [Layer](tower::Layer) that wraps services with a [FilterEx] service
67    /// with the given [Predicate]
68    ///
69    pub fn layer(predicate: U) -> FilterExLayer<U> {
70        FilterExLayer::new(predicate)
71    }
72
73    /// Check a `Request` value against thie filter's predicate
74    pub fn check<R>(&mut self, request: R) -> Result<U::Request, U::Response>
75    where
76        U: Predicate<R>,
77    {
78        self.predicate.check(request)
79    }
80
81    /// Get a reference to the inner service
82    pub fn get_ref(&self) -> &T {
83        &self.inner
84    }
85
86    /// Get a mutable reference to the inner service
87    pub fn get_mut(&mut self) -> &mut T {
88        &mut self.inner
89    }
90
91    /// Consume `self`, returning the inner service
92    pub fn into_inner(self) -> T {
93        self.inner
94    }
95}
96
97impl<T, U> Service<Request> for FilterEx<T, U>
98where
99    T: Service<U::Request, Response = Response>,
100    U: Predicate<Request, Response = Response>,
101{
102    type Response = T::Response;
103    type Error = T::Error;
104    type Future = ResponseFuture<T::Future>;
105
106    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
107        self.inner.poll_ready(cx)
108    }
109
110    fn call(&mut self, req: Request) -> Self::Future {
111        match self.predicate.check(req) {
112            Ok(req) => ResponseFuture::Future {
113                future: self.inner.call(req),
114            },
115            Err(response) => ResponseFuture::Error {
116                response: Some(response),
117            },
118        }
119    }
120}
121
/// Conditionally dispatch requests to the inner service based on an
/// asynchronous [predicate](AsyncPredicate).
///
/// `Clone` is derived: cloning requires both the inner service and the
/// predicate to be `Clone`, exactly as a hand-written impl would.
#[derive(Debug, Clone)]
pub struct AsyncFilterEx<T, U> {
    inner: T,
    predicate: U,
}
139
140impl<T, U> AsyncFilterEx<T, U> {
141    /// Returns a new [AsyncFilterEx] service wrapping `inner`.
142    pub fn new(inner: T, predicate: U) -> Self {
143        Self { inner, predicate }
144    }
145
146    /// Returns a new [Layer](tower::Layer) that wraps services with a [AsyncFilterEx] service
147    /// with the given [AsyncPredicate]
148    ///
149    pub fn layer(predicate: U) -> AsyncFilterExLayer<U> {
150        AsyncFilterExLayer::new(predicate)
151    }
152
153    /// Check a `Request` value against thie filter's predicate
154    pub async fn check<R>(&mut self, request: R) -> Result<U::Request, U::Response>
155    where
156        U: AsyncPredicate<R>,
157    {
158        self.predicate.check(request).await
159    }
160
161    /// Get a reference to the inner service
162    pub fn get_ref(&self) -> &T {
163        &self.inner
164    }
165
166    /// Get a mutable reference to the inner service
167    pub fn get_mut(&mut self) -> &mut T {
168        &mut self.inner
169    }
170
171    /// Consume `self`, returning the inner service
172    pub fn into_inner(self) -> T {
173        self.inner
174    }
175}
176
177impl<T, U> Service<Request> for AsyncFilterEx<T, U>
178where
179    T: Service<U::Request, Response = Response> + Clone,
180    U: AsyncPredicate<Request, Response = Response>,
181{
182    type Response = T::Response;
183    type Error = T::Error;
184    type Future = AsyncResponseFuture<U, T, Request>;
185
186    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
187        self.inner.poll_ready(cx)
188    }
189
190    fn call(&mut self, req: Request) -> Self::Future {
191        use std::mem;
192
193        let inner = self.inner.clone();
194        // In case the inner service has state that's driven to readiness and
195        // not tracked by clones (such as `Buffer`), pass the version we have
196        // already called `poll_ready` on into the future, and leave its clone
197        // behind.
198        let inner = mem::replace(&mut self.inner, inner);
199
200        // Check the request
201        let check = self.predicate.check(req);
202
203        AsyncResponseFuture::new(check, inner)
204    }
205}
206
207pub async fn drain_body(request: Request) {
208    let mut data_stream = request.into_body().into_data_stream();
209    while let Some(_) = data_stream.next().await {}
210}