1use axum::{extract::Request, response::Response};
31use future::{AsyncResponseFuture, ResponseFuture};
32use futures_util::StreamExt;
33pub use layer::{AsyncFilterExLayer, FilterExLayer};
34pub use predicate::{AsyncPredicate, Predicate};
35use std::{
36 marker::PhantomData,
37 task::{Context, Poll},
38};
39use tower::Service;
40
41mod future;
42mod layer;
43mod predicate;
44
/// A [`Service`] middleware that runs every request through a synchronous
/// [`Predicate`] before forwarding it to the wrapped service.
///
/// When the predicate accepts a request (`Ok`), the request it returns is
/// passed on to `inner`. When it rejects a request (`Err`), the predicate's
/// response is returned directly and `inner` is never called.
///
/// `Clone` is derived: cloning a `FilterEx` clones both the inner service and
/// the predicate, and requires both to be `Clone`.
#[derive(Debug, Clone)]
pub struct FilterEx<T, U> {
    /// The wrapped service.
    inner: T,
    /// The predicate consulted for every request.
    predicate: U,
}
62
63impl<T, U: Clone> FilterEx<T, U> {
64 pub fn new(inner: T, predicate: U) -> Self {
66 Self { inner, predicate }
67 }
68
69 pub fn layer(predicate: U) -> FilterExLayer<U> {
73 FilterExLayer::new(predicate)
74 }
75
76 pub fn check<R>(&mut self, request: R) -> Result<U::Request, U::Response>
78 where
79 U: Predicate<R>,
80 {
81 self.predicate.check(request)
82 }
83
84 pub fn get_ref(&self) -> &T {
86 &self.inner
87 }
88
89 pub fn get_mut(&mut self) -> &mut T {
91 &mut self.inner
92 }
93
94 pub fn into_inner(self) -> T {
96 self.inner
97 }
98}
99
100impl<T, U> Service<Request> for FilterEx<T, U>
101where
102 T: Service<U::Request, Response = Response>,
103 U: Predicate<Request, Response = Response>,
104{
105 type Response = T::Response;
106 type Error = T::Error;
107 type Future = ResponseFuture<T::Future>;
108
109 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110 self.inner.poll_ready(cx)
111 }
112
113 fn call(&mut self, req: Request) -> Self::Future {
114 match self.predicate.check(req) {
115 Ok(req) => ResponseFuture::Future {
116 future: self.inner.call(req),
117 },
118 Err(response) => ResponseFuture::Error {
119 response: Some(response),
120 },
121 }
122 }
123}
124
/// A [`Service`] middleware that runs every request through an
/// [`AsyncPredicate`] before forwarding it to the wrapped service.
///
/// `R` is the request type the predicate is checked against. It is tracked
/// only at the type level; no `R` is ever stored in the wrapper.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the
/// struct itself, so naming the type never requires the predicate bound.
#[derive(Debug)]
pub struct AsyncFilterEx<T, U, R> {
    /// The wrapped service.
    inner: T,
    /// The asynchronous predicate consulted for every request.
    predicate: U,
    /// `fn() -> R` (rather than plain `R`) records the request type without
    /// claiming ownership of an `R`. The wrapper stays covariant in `R`, but
    /// its `Send`/`Sync` and drop-check behaviour no longer depend on `R`.
    /// This matters because request types carrying streaming bodies need not
    /// be `Sync`, which would otherwise make the whole service `!Sync`.
    _r: PhantomData<fn() -> R>,
}
137
138impl<T: Clone, U: Clone, R> Clone for AsyncFilterEx<T, U, R>
139where
140 U: AsyncPredicate<R>,
141{
142 fn clone(&self) -> Self {
143 Self {
144 inner: self.inner.clone(),
145 predicate: self.predicate.clone(),
146 _r: PhantomData,
147 }
148 }
149}
150
151impl<T, U, R> AsyncFilterEx<T, U, R>
152where
153 U: AsyncPredicate<R>,
154{
155 pub fn new(inner: T, predicate: U) -> Self {
157 Self {
158 inner,
159 predicate,
160 _r: PhantomData,
161 }
162 }
163
164 pub fn layer(predicate: U) -> AsyncFilterExLayer<U, R> {
168 AsyncFilterExLayer::new(predicate)
169 }
170
171 pub async fn check(&mut self, request: R) -> Result<U::Request, U::Response>
173 where
174 U: AsyncPredicate<R>,
175 {
176 self.predicate.check(request).await
177 }
178
179 pub fn get_ref(&self) -> &T {
181 &self.inner
182 }
183
184 pub fn get_mut(&mut self) -> &mut T {
186 &mut self.inner
187 }
188
189 pub fn into_inner(self) -> T {
191 self.inner
192 }
193}
194
195impl<T, U> Service<Request> for AsyncFilterEx<T, U, Request>
196where
197 T: Service<U::Request, Response = Response> + Clone,
198 U: AsyncPredicate<Request, Response = Response>,
199{
200 type Response = T::Response;
201 type Error = T::Error;
202 type Future = AsyncResponseFuture<U, T, Request>;
203
204 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
205 self.inner.poll_ready(cx)
206 }
207
208 fn call(&mut self, req: Request) -> Self::Future {
209 use std::mem;
210
211 let inner = self.inner.clone();
212 let inner = mem::replace(&mut self.inner, inner);
217
218 let check = self.predicate.check(req);
220
221 AsyncResponseFuture::new(check, inner)
222 }
223}
224
225pub async fn drain_body(request: Request) {
226 let mut data_stream = request.into_body().into_data_stream();
227 while let Some(_) = data_stream.next().await {}
228}