use std::future::Future;
use std::marker::PhantomData;
use std::task::{Context, Poll};
use futures::ready;
use tower::{Layer, Service};
use crate::futures::SelectServiceAndCallFut;
/// An asynchronous predicate evaluated against values of type `T`.
///
/// Implementors must be cheaply cloneable and shareable across threads
/// (`Send + Sync + Clone`).
pub trait AsyncFilter<T>: Send + Sync + Clone {
    /// Future produced by [`AsyncFilter::matches`]; it resolves to the
    /// predicate's boolean verdict.
    type Future: Send + Future<Output = bool>;

    /// Starts evaluating the predicate for `item`.
    ///
    /// The signature ties no lifetime between the returned future and either
    /// `self` or `item`. Anything the future needs from them must therefore be
    /// copied or cloned out before returning.
    fn matches(&self, item: &T) -> Self::Future;
}
/// A [`Layer`] that pairs an [`AsyncFilter`] with a dedicated `service`.
///
/// Each service it produces holds the filter, this `service`, and the wrapped
/// inner service. In this module's tests, a filter resolving to `true` yields
/// the response of `service`, and `false` yields the response of the wrapped
/// service.
///
/// Trait bounds live on the `impl` blocks rather than on the type, so the type
/// can be named without restating them.
pub struct AsyncFilterLayer<F, S, T, R, E> {
    /// Predicate cloned into every service produced by this layer.
    filter: F,
    /// Service cloned into every service produced by this layer.
    service: S,
    // `fn(T) -> (R, E)` records the request/response/error types without
    // claiming to own values of them. As a result, `Send`/`Sync` of the layer
    // depend only on `F` and `S`, and no drop-check obligations are introduced
    // for `T`, `R` or `E`.
    _marker: PhantomData<fn(T) -> (R, E)>,
}
impl<F, S, R, E, T> Clone for AsyncFilterLayer<F, S, T, R, E>
where
    F: AsyncFilter<T> + Clone,
    S: Service<T, Response = R, Error = E> + Clone,
{
    /// Duplicates the layer by cloning its filter and its service (in that
    /// order). The phantom marker carries no data and is simply recreated.
    fn clone(&self) -> Self {
        let Self {
            filter, service, ..
        } = self;
        let filter = filter.clone();
        let service = service.clone();
        Self {
            filter,
            service,
            _marker: PhantomData,
        }
    }
}
impl<F: AsyncFilter<T>, S: Service<T>, T: Send + 'static>
    AsyncFilterLayer<F, S, T, S::Response, S::Error>
{
    /// Builds a layer from a `filter` and the `service` it is paired with.
    ///
    /// The response and error type parameters of the layer are taken from
    /// `S`'s `Service<T>` implementation.
    pub fn new(filter: F, service: S) -> Self {
        let _marker = PhantomData;
        Self {
            service,
            filter,
            _marker,
        }
    }
}
impl<F, S, I, T, R, E> Layer<I> for AsyncFilterLayer<F, S, T, R, E>
where
    F: AsyncFilter<T> + Clone,
    S: Service<T, Response = R, Error = E> + Clone,
    I: Service<T, Response = R, Error = E> + Clone,
    T: Send + 'static,
{
    type Service = AsyncFilterService<F, S, I, T, R, E>;

    /// Wraps `inner_service` in an [`AsyncFilterService`].
    ///
    /// The new service receives clones of this layer's filter and service, so
    /// the layer itself is left untouched and can be applied again.
    fn layer(&self, inner_service: I) -> Self::Service {
        AsyncFilterService {
            filter: self.filter.clone(),
            service: self.service.clone(),
            inner: inner_service,
            _marker: PhantomData,
        }
    }
}
/// Service produced by `AsyncFilterLayer`.
///
/// It holds a filter, the layer's dedicated `service`, and the wrapped `inner`
/// service. In this module's tests, a filter resolving to `true` yields the
/// response of `service`, and `false` yields the response of `inner`.
///
/// Trait bounds live on the `impl` blocks rather than on the type, so the type
/// can be named without restating them.
pub struct AsyncFilterService<F, S, I, T, R, E> {
    /// Predicate evaluated for each request.
    filter: F,
    /// The layer's dedicated service.
    service: S,
    /// The service this middleware wraps.
    inner: I,
    // `fn(T) -> (R, E)` records the request/response/error types without
    // claiming to own values of them. As a result, `Send`/`Sync` of the service
    // depend only on `F`, `S` and `I`.
    _marker: PhantomData<fn(T) -> (R, E)>,
}

// Hand-written rather than derived. `#[derive(Debug)]` would also require
// `T`, `R` and `E` to be `Debug`, even though they only appear in the marker.
impl<F, S, I, T, R, E> std::fmt::Debug for AsyncFilterService<F, S, I, T, R, E>
where
    F: std::fmt::Debug,
    S: std::fmt::Debug,
    I: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AsyncFilterService")
            .field("filter", &self.filter)
            .field("service", &self.service)
            .field("inner", &self.inner)
            .finish_non_exhaustive()
    }
}
impl<F, S, I, T, R, E> Clone for AsyncFilterService<F, S, I, T, R, E>
where
    F: AsyncFilter<T>,
    S: Service<T, Response = R, Error = E> + Clone,
    I: Service<T, Response = R, Error = E> + Clone,
{
    /// Clones the filter, the dedicated service and the inner service (in that
    /// order). `F` is cloneable through `AsyncFilter`'s `Clone` supertrait.
    fn clone(&self) -> Self {
        let Self {
            filter,
            service,
            inner,
            ..
        } = self;
        let filter = filter.clone();
        let service = service.clone();
        let inner = inner.clone();
        Self {
            filter,
            service,
            inner,
            _marker: PhantomData,
        }
    }
}
impl<F, S, I, T, R, E> Service<T> for AsyncFilterService<F, S, I, T, R, E>
where
    F: AsyncFilter<T>,
    F::Future: Send + 'static,
    S: Service<T, Response = R, Error = E> + Clone + Send + 'static,
    S::Future: Send + 'static,
    I: Service<T, Response = R, Error = E> + Clone + Send + 'static,
    I::Future: Send + 'static,
    T: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = SelectServiceAndCallFut<F::Future, S, I, T, R, E>;

    /// Ready only once both `service` and `inner` are ready.
    ///
    /// `call` hands both of them to the returned future, and the filter verdict
    /// is not known yet at this point. Errors from either service are returned
    /// as-is.
    ///
    /// NOTE(review): the service that ends up unused was still polled ready. If
    /// it reserves capacity in `poll_ready` (e.g. a concurrency permit), that
    /// reservation lasts as long as `SelectServiceAndCallFut` keeps it alive.
    /// Confirm this is acceptable.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        ready!(self.service.poll_ready(cx))?;
        self.inner.poll_ready(cx)
    }

    /// Starts the filter for `req`, then moves the ready services into the
    /// returned future.
    fn call(&mut self, req: T) -> Self::Future {
        // Moves the instance that was just polled ready out of `slot` and
        // leaves a fresh clone behind. The clone has not been polled, so
        // `poll_ready` must be called again before the next `call`.
        fn take_ready<Svc: Clone>(slot: &mut Svc) -> Svc {
            let replacement = slot.clone();
            std::mem::replace(slot, replacement)
        }

        let verdict = self.filter.matches(&req);
        let ready_service = take_ready(&mut self.service);
        let ready_inner = take_ready(&mut self.inner);
        SelectServiceAndCallFut::new(verdict, req, ready_service, ready_inner)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::*;
    use futures::future::poll_fn;

    /// Calls `svc` the way the `Service` contract requires.
    ///
    /// It first drives `poll_ready` to completion, returning its error if any,
    /// and only then calls `call`. Taking the request type as `Req` also pins
    /// the middleware's request parameter without type annotations.
    async fn ready_and_call<Svc, Req>(
        svc: &mut Svc,
        req: Req,
    ) -> Result<Svc::Response, Svc::Error>
    where
        Svc: Service<Req>,
    {
        poll_fn(|cx| svc.poll_ready(cx)).await?;
        svc.call(req).await
    }

    /// A filter resolving to `true` yields the layer's dedicated service's response.
    #[tokio::test]
    async fn should_allow() {
        let service_a = TestService("a");
        let service_b = TestService("b");
        let filter = TestFilter(true);
        let filter_layer = AsyncFilterLayer::new(filter, service_a);
        let mut middleware = filter_layer.layer(service_b);
        assert_eq!(ready_and_call(&mut middleware, ()).await, Ok("a"));
    }

    /// A filter resolving to `false` yields the wrapped service's response.
    #[tokio::test]
    async fn should_fall_through() {
        let service_a = TestService("a");
        let service_b = TestService("b");
        let filter = TestFilter(false);
        let filter_layer = AsyncFilterLayer::new(filter, service_a);
        let mut middleware = filter_layer.layer(service_b);
        assert_eq!(ready_and_call(&mut middleware, ()).await, Ok("b"));
    }
}