use std::any::type_name_of_val;
use std::fmt::{Debug, Formatter};
use crate::RequestContext;
use http::{HeaderName, HeaderValue, Method};
/// A predicate evaluated against a [`RequestContext`].
///
/// Every implementor must also be `Send + Sync + Debug`. Because these are
/// supertraits, boxed trait objects (`Box<dyn Filter>`) inherit them as well.
pub trait Filter
where
    Self: Send + Sync + Debug,
{
    /// Returns `true` when `request` satisfies this filter.
    fn matches(&self, request: &RequestContext) -> bool;
}
/// Adapts a predicate closure into a [`Filter`].
///
/// The struct carries no trait bounds of its own. The `Filter` impl states
/// what the wrapped value must satisfy.
struct FnFilter<F>(F);

impl<F> Debug for FnFilter<F> {
    /// `F` is not required to implement `Debug`, so the wrapped value's type
    /// name is printed instead of its contents.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FnFilter")
            .field("fn", &type_name_of_val(&self.0))
            .finish()
    }
}

impl<F> Filter for FnFilter<F>
where
    F: Fn(&RequestContext) -> bool + Send + Sync,
{
    fn matches(&self, req: &RequestContext) -> bool {
        (self.0)(req)
    }
}

/// Wraps a predicate closure as a [`Filter`].
///
/// The returned filter calls `f` for every request. Its `Debug` output shows
/// the closure's type name.
pub fn filter_fn<F>(f: F) -> impl Filter
where
    F: Fn(&RequestContext) -> bool + Send + Sync,
{
    FnFilter(f)
}
/// A filter that accepts every request.
#[derive(Debug)]
pub struct TrueFilter;

impl Filter for TrueFilter {
    #[inline(always)]
    fn matches(&self, _: &RequestContext) -> bool {
        true
    }
}

/// A filter that rejects every request.
#[derive(Debug)]
pub struct FalseFilter;

impl Filter for FalseFilter {
    #[inline(always)]
    fn matches(&self, _: &RequestContext) -> bool {
        false
    }
}

/// Returns a [`TrueFilter`]. This is a `const fn`, so it is usable in constant contexts.
#[inline(always)]
pub const fn true_filter() -> TrueFilter {
    TrueFilter
}

/// Returns a [`FalseFilter`]. This is a `const fn`, so it is usable in constant contexts.
#[inline(always)]
pub const fn false_filter() -> FalseFilter {
    FalseFilter
}
/// Creates an [`AnyFilter`] with no alternatives. Add alternatives with [`AnyFilter::or`].
pub fn any_filter() -> AnyFilter {
    AnyFilter::new()
}

/// A composite filter that matches when at least one child filter matches.
///
/// An `AnyFilter` with no children matches every request.
#[derive(Debug)]
pub struct AnyFilter {
    filters: Vec<Box<dyn Filter>>,
}

impl AnyFilter {
    /// Creates a filter with no children.
    fn new() -> Self {
        Self {
            filters: Vec::new(),
        }
    }

    /// Appends `filter` as another alternative and returns `self` for chaining.
    pub fn or<F: Filter + 'static>(&mut self, filter: F) -> &mut Self {
        let boxed: Box<dyn Filter> = Box::new(filter);
        self.filters.push(boxed);
        self
    }
}

impl Filter for AnyFilter {
    fn matches(&self, req: &RequestContext) -> bool {
        // An empty list imposes no constraint, so it accepts everything.
        // Otherwise, stop at the first child that matches.
        self.filters.is_empty() || self.filters.iter().any(|child| child.matches(req))
    }
}
/// Creates an [`AllFilter`] with no requirements. Add requirements with [`AllFilter::and`].
pub fn all_filter() -> AllFilter {
    AllFilter::new()
}

/// A composite filter that matches only when every child filter matches.
///
/// An `AllFilter` with no children matches every request.
#[derive(Debug)]
pub struct AllFilter {
    filters: Vec<Box<dyn Filter>>,
}

impl AllFilter {
    /// Creates a filter with no children.
    fn new() -> Self {
        Self {
            filters: Vec::new(),
        }
    }

    /// Appends `filter` as another requirement and returns `self` for chaining.
    pub fn and<F: Filter + 'static>(&mut self, filter: F) -> &mut Self {
        let boxed: Box<dyn Filter> = Box::new(filter);
        self.filters.push(boxed);
        self
    }
}

impl Filter for AllFilter {
    fn matches(&self, req: &RequestContext) -> bool {
        // `all` is vacuously true for an empty list and stops at the first
        // child that rejects the request.
        self.filters.iter().all(|child| child.matches(req))
    }
}
/// A filter that matches requests whose method equals the wrapped [`Method`].
#[derive(Debug)]
pub struct MethodFilter(Method);

impl Filter for MethodFilter {
    fn matches(&self, req: &RequestContext) -> bool {
        self.0 == *req.method()
    }
}
/// Generates `pub fn <name>() -> MethodFilter` constructors.
///
/// Accepts either a single `name, METHOD` pair or a comma-separated list of
/// `name => METHOD` pairs. The list form expands each pair through the
/// single-pair arm.
macro_rules! method_filter {
    ($method:ident, $upper_case_method:ident) => {
        #[doc = concat!("Creates a filter that matches HTTP ", stringify!($upper_case_method), " requests.")]
        #[inline]
        pub fn $method() -> MethodFilter {
            MethodFilter(Method::$upper_case_method)
        }
    };
    ($($method:ident => $upper_case_method:ident),+ $(,)?) => {
        $(
            method_filter!($method, $upper_case_method);
        )+
    };
}

method_filter! {
    get_method => GET,
    post_method => POST,
    put_method => PUT,
    delete_method => DELETE,
    head_method => HEAD,
    options_method => OPTIONS,
    connect_method => CONNECT,
    patch_method => PATCH,
    trace_method => TRACE,
}
#[inline]
pub fn header<K, V>(header_name: K, header_value: V) -> HeaderFilter
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
let name = <HeaderName as TryFrom<K>>::try_from(header_name).map_err(Into::into).unwrap();
let value = <HeaderValue as TryFrom<V>>::try_from(header_value).map_err(Into::into).unwrap();
HeaderFilter(name, value)
}
/// A filter that matches requests carrying header `.0` with a value equal to `.1`.
#[derive(Debug)]
pub struct HeaderFilter(HeaderName, HeaderValue);

impl Filter for HeaderFilter {
    fn matches(&self, req: &RequestContext) -> bool {
        let HeaderFilter(name, expected) = self;
        // A missing header never matches.
        // NOTE(review): if `headers()` is an `http::HeaderMap`, `get` returns only
        // the first value of a multi-valued header. Confirm that is intended.
        req.headers()
            .get(name)
            .is_some_and(|actual| *expected == *actual)
    }
}