1use axum::{extract::Request, response::Response};
31use future::{AsyncResponseFuture, ResponseFuture};
32use futures_util::StreamExt;
33pub use layer::{AsyncFilterExLayer, FilterExLayer};
34pub use predicate::{AsyncPredicate, Predicate};
35use std::task::{Context, Poll};
36use tower::Service;
37
38mod future;
39mod layer;
40mod predicate;
41
/// Middleware that runs a synchronous [`Predicate`] on every request before it
/// reaches the wrapped service.
///
/// When the predicate accepts a request, the value it produces is forwarded to
/// `inner`; when it rejects one, the predicate's own response is returned and
/// `inner` is never called.
#[derive(Debug, Clone)]
pub struct FilterEx<T, U> {
    /// The wrapped service that receives accepted requests.
    inner: T,
    /// Decides, per request, whether (and in what form) it reaches `inner`.
    predicate: U,
}
59
60impl<T, U: Clone> FilterEx<T, U> {
61 pub fn new(inner: T, predicate: U) -> Self {
63 Self { inner, predicate }
64 }
65
66 pub fn layer(predicate: U) -> FilterExLayer<U> {
70 FilterExLayer::new(predicate)
71 }
72
73 pub fn check<R>(&mut self, request: R) -> Result<U::Request, U::Response>
75 where
76 U: Predicate<R>,
77 {
78 self.predicate.check(request)
79 }
80
81 pub fn get_ref(&self) -> &T {
83 &self.inner
84 }
85
86 pub fn get_mut(&mut self) -> &mut T {
88 &mut self.inner
89 }
90
91 pub fn into_inner(self) -> T {
93 self.inner
94 }
95}
96
97impl<T, U> Service<Request> for FilterEx<T, U>
98where
99 T: Service<U::Request, Response = Response>,
100 U: Predicate<Request, Response = Response>,
101{
102 type Response = T::Response;
103 type Error = T::Error;
104 type Future = ResponseFuture<T::Future>;
105
106 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
107 self.inner.poll_ready(cx)
108 }
109
110 fn call(&mut self, req: Request) -> Self::Future {
111 match self.predicate.check(req) {
112 Ok(req) => ResponseFuture::Future {
113 future: self.inner.call(req),
114 },
115 Err(response) => ResponseFuture::Error {
116 response: Some(response),
117 },
118 }
119 }
120}
121
/// Middleware that runs an asynchronous [`AsyncPredicate`] on every request
/// before it reaches the wrapped service.
///
/// The predicate's future is awaited first; only then is the (possibly
/// replaced) request handed to `inner`, or the predicate's response returned.
#[derive(Debug, Clone)]
pub struct AsyncFilterEx<T, U> {
    /// The wrapped service that receives accepted requests.
    inner: T,
    /// Asynchronously decides, per request, whether it reaches `inner`.
    predicate: U,
}
139
140impl<T, U> AsyncFilterEx<T, U> {
141 pub fn new(inner: T, predicate: U) -> Self {
143 Self { inner, predicate }
144 }
145
146 pub fn layer(predicate: U) -> AsyncFilterExLayer<U> {
150 AsyncFilterExLayer::new(predicate)
151 }
152
153 pub async fn check<R>(&mut self, request: R) -> Result<U::Request, U::Response>
155 where
156 U: AsyncPredicate<R>,
157 {
158 self.predicate.check(request).await
159 }
160
161 pub fn get_ref(&self) -> &T {
163 &self.inner
164 }
165
166 pub fn get_mut(&mut self) -> &mut T {
168 &mut self.inner
169 }
170
171 pub fn into_inner(self) -> T {
173 self.inner
174 }
175}
176
177impl<T, U> Service<Request> for AsyncFilterEx<T, U>
178where
179 T: Service<U::Request, Response = Response> + Clone,
180 U: AsyncPredicate<Request, Response = Response>,
181{
182 type Response = T::Response;
183 type Error = T::Error;
184 type Future = AsyncResponseFuture<U, T, Request>;
185
186 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
187 self.inner.poll_ready(cx)
188 }
189
190 fn call(&mut self, req: Request) -> Self::Future {
191 use std::mem;
192
193 let inner = self.inner.clone();
194 let inner = mem::replace(&mut self.inner, inner);
199
200 let check = self.predicate.check(req);
202
203 AsyncResponseFuture::new(check, inner)
204 }
205}
206
207pub async fn drain_body(request: Request) {
208 let mut data_stream = request.into_body().into_data_stream();
209 while let Some(_) = data_stream.next().await {}
210}