use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use http::{Request, Response};
use tower::Service;
use crate::http::{Body, BoxError, HttpService};
/// Shared, thread-safe test that decides whether a request is handled by a rule.
type Predicate = Arc<dyn for<'req> Fn(&'req Request<Body>) -> bool + Send + Sync + 'static>;
/// Builds a predicate that matches requests whose HTTP method equals `method`.
fn method_predicate(method: http::Method) -> impl Fn(&Request<Body>) -> bool + Send + Sync {
    move |req: &Request<Body>| method == *req.method()
}
/// Builds a predicate that matches requests whose HTTP method is any of `methods`.
///
/// An empty `methods` list yields a predicate that never matches.
fn methods_predicate(methods: Vec<http::Method>) -> impl Fn(&Request<Body>) -> bool + Send + Sync {
    move |req: &Request<Body>| {
        let actual = req.method();
        methods.iter().any(|candidate| candidate == actual)
    }
}
/// Factory that wraps an inner [`HttpService`] in a rule's layer and type-erases the result
/// back into an [`HttpService`].
type LayerFn = Box<dyn Fn(HttpService) -> HttpService + Send + Sync + 'static>;

/// One conditional rule: requests accepted by `predicate` are routed through the service
/// built by `layer_fn`.
struct Rule {
    /// Decides whether this rule handles a given request.
    predicate: Predicate,
    /// Builds this rule's layered service around the inner service it is given.
    layer_fn: LayerFn,
}
/// An ordered list of request-predicate rules, each paired with a layer.
///
/// When used as a [`tower::Layer`], it produces a [`ConditionalService`]. That service sends
/// each request through the layer of the first rule whose predicate matches, or straight to
/// the wrapped inner service when no rule matches.
///
/// This is a by-value builder: every builder method consumes and returns it. Dropping the
/// returned value therefore loses the rules, and `#[must_use]` makes that a warning.
#[must_use = "a `Conditional` has no effect until it is used as a layer"]
pub struct Conditional {
    rules: Vec<Rule>,
}
impl Conditional {
pub fn new() -> Self {
Self { rules: Vec::new() }
}
pub fn when<L>(
mut self,
predicate: impl Fn(&Request<Body>) -> bool + Send + Sync + 'static,
layer: L,
) -> Self
where
L: tower::Layer<HttpService> + Send + Sync + 'static,
L::Service:
Service<Request<Body>, Response = Response<Body>, Error = BoxError> + Send + 'static,
<L::Service as Service<Request<Body>>>::Future: Send,
{
self.rules.push(Rule {
predicate: Arc::new(predicate),
layer_fn: Box::new(move |inner| tower::util::BoxService::new(layer.layer(inner))),
});
self
}
pub fn when_path<L>(self, path: impl Into<String>, layer: L) -> Self
where
L: tower::Layer<HttpService> + Send + Sync + 'static,
L::Service:
Service<Request<Body>, Response = Response<Body>, Error = BoxError> + Send + 'static,
<L::Service as Service<Request<Body>>>::Future: Send,
{
let path = path.into();
self.when(move |req| req.uri().path() == path, layer)
}
pub fn when_path_glob<L>(self, pattern: &str, layer: L) -> Result<Self, globset::Error>
where
L: tower::Layer<HttpService> + Send + Sync + 'static,
L::Service:
Service<Request<Body>, Response = Response<Body>, Error = BoxError> + Send + 'static,
<L::Service as Service<Request<Body>>>::Future: Send,
{
let matcher = globset::GlobBuilder::new(pattern)
.literal_separator(true)
.build()?
.compile_matcher();
Ok(self.when(move |req| matcher.is_match(req.uri().path()), layer))
}
pub fn when_method<L>(self, method: http::Method, layer: L) -> Self
where
L: tower::Layer<HttpService> + Send + Sync + 'static,
L::Service:
Service<Request<Body>, Response = Response<Body>, Error = BoxError> + Send + 'static,
<L::Service as Service<Request<Body>>>::Future: Send,
{
self.when(method_predicate(method), layer)
}
pub fn when_methods<L>(self, methods: impl Into<Vec<http::Method>>, layer: L) -> Self
where
L: tower::Layer<HttpService> + Send + Sync + 'static,
L::Service:
Service<Request<Body>, Response = Response<Body>, Error = BoxError> + Send + 'static,
<L::Service as Service<Request<Body>>>::Future: Send,
{
self.when(methods_predicate(methods.into()), layer)
}
}
pub trait ConditionalLayer: tower::Layer<HttpService> + Send + Sync + 'static
where
<Self as tower::Layer<HttpService>>::Service:
Service<Request<Body>, Response = Response<Body>, Error = BoxError> + Send + 'static,
<<Self as tower::Layer<HttpService>>::Service as Service<Request<Body>>>::Future: Send,
{
fn when(
self,
predicate: impl Fn(&Request<Body>) -> bool + Send + Sync + 'static,
) -> Conditional;
fn when_path(self, path: impl Into<String>) -> Conditional;
fn when_path_glob(self, pattern: &str) -> Result<Conditional, globset::Error>;
fn when_method(self, method: http::Method) -> Conditional;
fn when_methods(self, methods: impl Into<Vec<http::Method>>) -> Conditional;
}
/// Blanket implementation: every layer satisfying the bounds gains the `when_*` helpers.
/// Each helper starts from an empty [`Conditional`] and registers `self` as its only rule.
impl<L> ConditionalLayer for L
where
    L: tower::Layer<HttpService> + Send + Sync + 'static,
    L::Service:
        Service<Request<Body>, Response = Response<Body>, Error = BoxError> + Send + 'static,
    <L::Service as Service<Request<Body>>>::Future: Send,
{
    fn when(
        self,
        predicate: impl Fn(&Request<Body>) -> bool + Send + Sync + 'static,
    ) -> Conditional {
        let empty = Conditional::default();
        empty.when(predicate, self)
    }

    fn when_path(self, path: impl Into<String>) -> Conditional {
        let empty = Conditional::default();
        empty.when_path(path, self)
    }

    fn when_path_glob(self, pattern: &str) -> Result<Conditional, globset::Error> {
        let empty = Conditional::default();
        empty.when_path_glob(pattern, self)
    }

    fn when_method(self, method: http::Method) -> Conditional {
        let empty = Conditional::default();
        empty.when_method(method, self)
    }

    fn when_methods(self, methods: impl Into<Vec<http::Method>>) -> Conditional {
        let empty = Conditional::default();
        empty.when_methods(methods, self)
    }
}
impl Default for Conditional {
fn default() -> Self {
Self::new()
}
}
impl tower::Layer<HttpService> for Conditional {
    type Service = ConditionalService;

    /// Builds a [`ConditionalService`] around `inner`.
    ///
    /// `inner` is placed behind an `Arc<Mutex<_>>` so it can be shared. Each rule's layer is
    /// applied to its own `SharedInnerService` handle onto that same service, and the returned
    /// service also keeps the shared handle for requests that match no rule.
    fn layer(&self, inner: HttpService) -> ConditionalService {
        let shared_inner = Arc::new(Mutex::new(inner));

        let mut rules = Vec::with_capacity(self.rules.len());
        for rule in &self.rules {
            let handle = SharedInnerService {
                inner: Arc::clone(&shared_inner),
            };
            let wrapped: HttpService = (rule.layer_fn)(tower::util::BoxService::new(handle));
            rules.push((Arc::clone(&rule.predicate), wrapped));
        }

        ConditionalService {
            rules,
            shared_inner,
        }
    }
}
/// A handle onto the inner service shared by all rules.
///
/// Each rule's layer wraps its own handle, and every handle forwards to the same
/// mutex-guarded service.
struct SharedInnerService {
    /// The single shared inner service.
    inner: Arc<Mutex<HttpService>>,
}
impl Service<Request<Body>> for SharedInnerService {
    type Response = Response<Body>;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Response<Body>, BoxError>> + Send>>;

    /// Forwards readiness polling to the shared inner service.
    ///
    /// The lock is held only for the duration of this call. Panics if the mutex is poisoned.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut shared = self.inner.lock().unwrap();
        shared.poll_ready(cx)
    }

    /// Forwards `req` to the shared inner service.
    ///
    /// The lock is released once the response future has been created, before that future
    /// is polled. Panics if the mutex is poisoned.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let mut shared = self.inner.lock().unwrap();
        shared.call(req)
    }
}
/// Service produced by applying a [`Conditional`] as a layer.
///
/// Holds one layered service per rule plus the shared inner service that handles requests
/// matching no rule.
pub struct ConditionalService {
    /// `(predicate, layered service)` pairs, in the order the rules were registered.
    rules: Vec<(Predicate, HttpService)>,
    /// The inner service. Each rule's layer was also given a handle onto this same service.
    shared_inner: Arc<Mutex<HttpService>>,
}
impl Service<Request<Body>> for ConditionalService {
    type Response = Response<Body>;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Response<Body>, BoxError>> + Send>>;

    /// Reports ready only when every service that `call` might dispatch to is ready.
    ///
    /// The target service is chosen by the request's predicates, and the request is not
    /// known yet. Tower requires a successful `poll_ready` on a service before `call`, so
    /// every rule's layered service is polled here, along with the shared fallback service.
    ///
    /// All services are polled on each invocation, even after one returns `Pending`, so
    /// every not-yet-ready service registers the waker from `cx`. The first error returned
    /// by any service is propagated immediately.
    ///
    /// NOTE(review): a layered service that reserves resources in `poll_ready` keeps that
    /// reservation until it is next called, even when a request is routed elsewhere. This
    /// is the same trade-off `tower::steer::Steer` makes.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut all_rules_ready = true;
        for (_, service) in &mut self.rules {
            match service.poll_ready(cx) {
                Poll::Ready(Ok(())) => {}
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                Poll::Pending => all_rules_ready = false,
            }
        }

        let fallback = self
            .shared_inner
            .lock()
            .expect("ConditionalService: shared inner service mutex poisoned")
            .poll_ready(cx);

        match fallback {
            // The fallback is ready, but a rule service is not; that service already
            // registered the waker, so it is safe to report `Pending`.
            Poll::Ready(Ok(())) if !all_rules_ready => Poll::Pending,
            other => other,
        }
    }

    /// Dispatches `req` to the first rule whose predicate matches, in registration order.
    /// Requests matching no rule go to the shared inner service.
    ///
    /// # Panics
    ///
    /// Panics if the shared inner service's mutex has been poisoned.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        for (predicate, service) in &mut self.rules {
            if (predicate)(&req) {
                return service.call(req);
            }
        }
        self.shared_inner
            .lock()
            .expect("ConditionalService: shared inner service mutex poisoned")
            .call(req)
    }
}