// micro_web/router/filter.rs

//! Request filtering module that provides composable request filters.
//!
//! This module implements a filter system that allows you to:
//! - Filter requests based on HTTP methods
//! - Filter requests based on headers
//! - Combine multiple filters using AND/OR logic
//! - Create custom filters using closures
//!
//! # Thread Safety
//!
//! All filters must implement the `Filter` trait, which requires `Send + Sync`.
//! This ensures that filters can be safely shared and used across threads,
//! which is essential for concurrent request handling in a web server environment.
//!
//! # Examples
//!
//! ```
//! use micro_web::router::filter::{all_filter, any_filter, get_method, header, post_method};
//!
//! // Create a filter that matches GET requests
//! let get_filter = get_method();
//!
//! // Create a filter that checks for a specific header
//! let auth_filter = header("Authorization", "Bearer token");
//!
//! // Combine filters with AND logic
//! let mut combined = all_filter();
//! combined.and(get_filter).and(auth_filter);
//!
//! // Combine filters with OR logic: matches GET or POST requests
//! let mut either = any_filter();
//! either.or(get_method()).or(post_method());
//! ```
30
31use std::any::type_name_of_val;
32use std::fmt::{Debug, Formatter};
33use crate::RequestContext;
34use http::{HeaderName, HeaderValue, Method};
35
36/// Core trait for request filtering.
37///
38/// Implementors of this trait can be used to filter HTTP requests
39/// based on custom logic. Filters can be composed using [`AllFilter`]
40/// and [`AnyFilter`].
41///
42///
43/// The `Filter` trait requires `Send + Sync`, ensuring that filters
44/// can be safely used in a multithreaded environment.
45pub trait Filter: Send + Sync + Debug {
46    /// Check if the request matches this filter's criteria.
47    ///
48    /// Returns `true` if the request should be allowed, `false` otherwise.
49    fn matches(&self, req: &RequestContext) -> bool;
50}
51
52/// A filter that wraps a closure.
53struct FnFilter<F: Fn(&RequestContext) -> bool>(F);
54
55impl<F: Fn(&RequestContext) -> bool> Debug for FnFilter<F> {
56    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("FnFilter")
58            .field("fn", &type_name_of_val(&self.0))
59            .finish()
60    }
61}
62
63impl<F: Fn(&RequestContext) -> bool + Send + Sync> Filter for FnFilter<F> {
64    fn matches(&self, req: &RequestContext) -> bool {
65        self.0(req)
66    }
67}
68
69/// Creates a new filter from a closure.
70///
71/// This allows creating custom filters using simple closures.
72///
73/// # Example
74/// ```
75/// use micro_web::router::filter::filter_fn;
76///
77/// let custom_filter = filter_fn(|req| {
78///     req.uri().path().starts_with("/api")
79/// });
80/// ```
81pub fn filter_fn<F>(f: F) -> impl Filter
82where
83    F: Fn(&RequestContext) -> bool + Send + Sync,
84{
85    FnFilter(f)
86}
87
88/// Creates a filter that always returns true.
89#[inline(always)]
90pub const fn true_filter() -> TrueFilter {
91    TrueFilter
92}
93
94/// Creates a filter that always returns false.
95#[inline(always)]
96pub const fn false_filter() -> FalseFilter {
97    FalseFilter
98}
99
100/// A filter that always returns true.
101#[derive(Debug)]
102pub struct TrueFilter;
103impl Filter for TrueFilter {
104    #[inline(always)]
105    fn matches(&self, _req: &RequestContext) -> bool {
106        true
107    }
108}
109
110/// A filter that always returns false.
111#[derive(Debug)]
112pub struct FalseFilter;
113impl Filter for FalseFilter {
114    #[inline(always)]
115    fn matches(&self, _req: &RequestContext) -> bool {
116        false
117    }
118}
119
120/// Creates a new OR-composed filter chain.
121pub fn any_filter() -> AnyFilter {
122    AnyFilter::new()
123}
124
125/// Compose filters with OR logic.
126///
127/// If any inner filter succeeds, the whole filter succeeds.
128/// An empty filter chain returns true by default.
129#[derive(Debug)]
130pub struct AnyFilter {
131    filters: Vec<Box<dyn Filter>>,
132}
133
134impl AnyFilter {
135    fn new() -> Self {
136        Self { filters: vec![] }
137    }
138
139    /// Add a new filter to the OR chain.
140    pub fn or<F: Filter + 'static>(&mut self, filter: F) -> &mut Self {
141        self.filters.push(Box::new(filter));
142        self
143    }
144}
145
146impl Filter for AnyFilter {
147    fn matches(&self, req: &RequestContext) -> bool {
148        if self.filters.is_empty() {
149            return true;
150        }
151
152        for filter in &self.filters {
153            if filter.matches(req) {
154                return true;
155            }
156        }
157
158        false
159    }
160}
161
162/// Creates a new AND-composed filter chain.
163pub fn all_filter() -> AllFilter {
164    AllFilter::new()
165}
166
167/// Compose filters with AND logic.
168///
169/// All inner filters must succeed for the whole filter to succeed.
170/// An empty filter chain returns true by default.
171#[derive(Debug)]
172pub struct AllFilter {
173    filters: Vec<Box<dyn Filter>>,
174}
175
176impl AllFilter {
177    fn new() -> Self {
178        Self { filters: vec![] }
179    }
180
181    /// Add a new filter to the AND chain.
182    pub fn and<F: Filter + 'static>(&mut self, filter: F) -> &mut Self {
183        self.filters.push(Box::new(filter));
184        self
185    }
186}
187
188impl Filter for AllFilter {
189    fn matches(&self, req: &RequestContext) -> bool {
190        if self.filters.is_empty() {
191            return true;
192        }
193
194        for filter in &self.filters {
195            if !filter.matches(req) {
196                return false;
197            }
198        }
199
200        true
201    }
202}
203
204/// A filter that matches HTTP methods.
205#[derive(Debug)]
206pub struct MethodFilter(Method);
207
208impl Filter for MethodFilter {
209    fn matches(&self, req: &RequestContext) -> bool {
210        self.0.eq(req.method())
211    }
212}
213
214macro_rules! method_filter {
215    ($method:ident, $upper_case_method:ident) => {
216        #[doc = concat!("Creates a filter that matches HTTP ", stringify!($upper_case_method), " requests.")]
217        #[inline]
218        pub fn $method() -> MethodFilter {
219            MethodFilter(Method::$upper_case_method)
220        }
221    };
222}
223
224method_filter!(get_method, GET);
225method_filter!(post_method, POST);
226method_filter!(put_method, PUT);
227method_filter!(delete_method, DELETE);
228method_filter!(head_method, HEAD);
229method_filter!(options_method, OPTIONS);
230method_filter!(connect_method, CONNECT);
231method_filter!(patch_method, PATCH);
232method_filter!(trace_method, TRACE);
233
234/// Creates a filter that matches a specific header name and value.
235#[inline]
236pub fn header<K, V>(header_name: K, header_value: V) -> HeaderFilter
237where
238    HeaderName: TryFrom<K>,
239    <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
240    HeaderValue: TryFrom<V>,
241    <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
242{
243    // TODO: need to process the unwrap
244    let name = <HeaderName as TryFrom<K>>::try_from(header_name).map_err(Into::into).unwrap();
245    let value = <HeaderValue as TryFrom<V>>::try_from(header_value).map_err(Into::into).unwrap();
246    HeaderFilter(name, value)
247}
248
249/// A filter that matches HTTP headers.
250#[derive(Debug)]
251pub struct HeaderFilter(HeaderName, HeaderValue);
252
253impl Filter for HeaderFilter {
254    fn matches(&self, req: &RequestContext) -> bool {
255        let value_option = req.headers().get(&self.0);
256        value_option.map(|value| self.1.eq(value)).unwrap_or(false)
257    }
258}