Skip to main content

gatel_core/router/
matcher.rs

//! Path matching strategies and advanced request matchers.
2use std::net::SocketAddr;
3
4use http::Request;
5
6use crate::Body;
7use crate::glob::glob_matches;
8
9// ---------------------------------------------------------------------------
10// Path matching (existing functionality)
11// ---------------------------------------------------------------------------
12
/// Match a request path against a route pattern.
///
/// Supported patterns, checked in this order:
/// - `"/*"` or `"*"` — matches everything
/// - `"*.php"` — suffix/extension match: any pattern starting with a single
///   `*`; the remainder is compared literally against the end of the path
/// - `"/api/*"` — segment-aware prefix match: `/api`, `/api/`, `/api/foo`,
///   `/api/foo/bar`, but not `/apifoo`
/// - `"/static*"` — raw prefix match: `/static`, `/staticfiles`, ...
/// - `"/exact"` — exact match
///
/// Patterns beginning with `**` are not treated as extension globs. They fall
/// through to the prefix/exact rules, where any remaining `*` is compared
/// literally (so `"**"` matches only paths that start with `*`).
pub fn path_matches(pattern: &str, path: &str) -> bool {
    if pattern == "/*" || pattern == "*" {
        return true;
    }
    // Extension glob: "*.php" matches any path ending with ".php".
    if let Some(suffix) = pattern.strip_prefix('*') {
        if !suffix.starts_with('*') {
            return path.ends_with(suffix);
        }
    }
    if let Some(prefix) = pattern.strip_suffix("/*") {
        // Accept the bare prefix or anything below `prefix/`. This is checked
        // on the remainder so no `format!` string is allocated per request.
        return match path.strip_prefix(prefix) {
            Some(rest) => rest.is_empty() || rest.starts_with('/'),
            None => false,
        };
    }
    if let Some(prefix) = pattern.strip_suffix('*') {
        // Glob-like: "/static*" matches "/static", "/staticfiles", etc.
        return path.starts_with(prefix);
    }
    // Exact match
    path == pattern
}
40
/// Rank a route pattern by specificity; a larger value means more specific.
///
/// Meant to be used as a sort key so that more specific routes are tried
/// first:
/// - `"/*"` / `"*"` (match-all) → `0`
/// - patterns starting with `*` (e.g. `"*.php"`) → the pattern's length
/// - `"/prefix/*"` → length of `"/prefix"` plus one
/// - any other trailing-`*` glob (e.g. `"/static*"`) → the pattern's length
/// - exact paths → the pattern's length plus `1000`, so they outrank the
///   wildcard forms above for any realistic pattern length
pub fn pattern_specificity(pattern: &str) -> usize {
    match pattern {
        "/*" | "*" => 0,
        extension if extension.starts_with('*') => extension.len(),
        other => match other.strip_suffix("/*") {
            Some(dir) => dir.len() + 1,
            None if other.ends_with('*') => other.len(),
            None => other.len() + 1000,
        },
    }
}
60
61// ---------------------------------------------------------------------------
62// Advanced request matchers
63// ---------------------------------------------------------------------------
64
/// A composable request matcher that can test various aspects of an incoming
/// HTTP request.  Matchers can be combined with `And`, `Or`, and `Not`.
///
/// Evaluated by `RequestMatcher::matches`.
/// NOTE(review): `Serialize` is derived, so reordering or renaming variants may
/// change the serialized form seen by consumers — check before doing so.
#[derive(Debug, Clone, serde::Serialize)]
pub enum RequestMatcher {
    /// Match the request path with `path_matches` (`/*`, `/api/*`, `*.php`,
    /// `/static*`, or an exact path).
    Path(String),
    /// Match the HTTP method (e.g. `["GET", "POST"]`), case-insensitively.
    Method(Vec<String>),
    /// Match a header value with glob pattern (e.g. `name="X-Custom"`, `pattern="foo*"`).
    /// An unparsable header name or a missing header never matches.
    Header { name: String, pattern: String },
    /// Match a header value against `regex`. Despite the name, `regex` is
    /// evaluated with the same glob matcher (`glob_matches`) as `Header`, not
    /// with a regular-expression engine.
    HeaderRegex { name: String, regex: String },
    /// Match a query parameter by comparing the raw (not percent-decoded)
    /// key and value.  If `value` is `None`, just check presence.
    Query { key: String, value: Option<String> },
    /// Match client IP against CIDR ranges or exact addresses
    /// (e.g. `["192.168.0.0/16", "10.0.0.0/8"]`).
    RemoteIp(Vec<String>),
    /// Match the URI scheme case-insensitively (e.g. `"https"`, `"http"`).
    /// A request URI without a scheme is treated as `"http"`.
    Protocol(String),
    /// Simple expression matcher: `"{method} == GET && {path} ~ /api/*"`.
    Expression(String),
    /// Logical NOT.
    Not(Box<RequestMatcher>),
    /// Logical AND: all must match (an empty list matches).
    And(Vec<RequestMatcher>),
    /// Logical OR: at least one must match (an empty list never matches).
    Or(Vec<RequestMatcher>),
    /// Match the Accept-Language header against a list of language tags.
    /// Case-insensitive; a configured tag matches itself or any tag extending
    /// it with a `-` subtag (e.g. "en" matches "en-US" but not "eng").
    /// Quality (`q=`) parameters are not used for ranking; a missing header
    /// never matches.
    Language(Vec<String>),
}
95
96impl RequestMatcher {
97    /// Test whether an incoming request matches this matcher.
98    pub fn matches(&self, req: &Request<Body>, client_addr: SocketAddr) -> bool {
99        match self {
100            RequestMatcher::Path(pattern) => {
101                let path = req.uri().path();
102                path_matches(pattern, path)
103            }
104
105            RequestMatcher::Method(methods) => {
106                let req_method = req.method().as_str().to_uppercase();
107                methods.iter().any(|m| m.to_uppercase() == req_method)
108            }
109
110            RequestMatcher::Header { name, pattern } => {
111                if let Ok(header_name) = name.parse::<http::header::HeaderName>() {
112                    req.headers()
113                        .get(&header_name)
114                        .and_then(|v| v.to_str().ok())
115                        .map(|v| glob_matches(pattern, v))
116                        .unwrap_or(false)
117                } else {
118                    false
119                }
120            }
121
122            RequestMatcher::HeaderRegex { name, regex } => {
123                if let Ok(header_name) = name.parse::<http::header::HeaderName>() {
124                    req.headers()
125                        .get(&header_name)
126                        .and_then(|v| v.to_str().ok())
127                        .map(|v| glob_matches(regex, v))
128                        .unwrap_or(false)
129                } else {
130                    false
131                }
132            }
133
134            RequestMatcher::Query { key, value } => {
135                let query_str = req.uri().query().unwrap_or("");
136                match_query_param(query_str, key, value.as_deref())
137            }
138
139            RequestMatcher::RemoteIp(cidrs) => {
140                let client_ip = client_addr.ip();
141                cidrs.iter().any(|cidr| match_cidr(cidr, &client_ip))
142            }
143
144            RequestMatcher::Protocol(proto) => {
145                let scheme = req.uri().scheme_str().unwrap_or("http");
146                scheme.eq_ignore_ascii_case(proto)
147            }
148
149            RequestMatcher::Expression(expr) => eval_expression(expr, req, client_addr),
150
151            RequestMatcher::Not(inner) => !inner.matches(req, client_addr),
152
153            RequestMatcher::And(matchers) => matchers.iter().all(|m| m.matches(req, client_addr)),
154
155            RequestMatcher::Or(matchers) => matchers.iter().any(|m| m.matches(req, client_addr)),
156
157            RequestMatcher::Language(langs) => {
158                // Parse Accept-Language header value and check for prefix matches.
159                // E.g. "en-US,en;q=0.9,fr;q=0.8" → ["en-US", "en", "fr"]
160                let header_value = req
161                    .headers()
162                    .get(http::header::ACCEPT_LANGUAGE)
163                    .and_then(|v| v.to_str().ok())
164                    .unwrap_or("");
165                // Extract the language tags (strip quality values).
166                let accepted: Vec<&str> = header_value
167                    .split(',')
168                    .map(|part| part.split(';').next().unwrap_or("").trim())
169                    .filter(|s| !s.is_empty())
170                    .collect();
171                langs.iter().any(|configured| {
172                    accepted.iter().any(|accepted_lang| {
173                        // Prefix match: "en" matches "en-US" or "en"
174                        let c = configured.to_lowercase();
175                        let a = accepted_lang.to_lowercase();
176                        a == c || a.starts_with(&format!("{c}-"))
177                    })
178                })
179            }
180        }
181    }
182}
183
// ---------------------------------------------------------------------------
// Glob matching: provided by `crate::glob::glob_matches` (imported above).
// ---------------------------------------------------------------------------
187
188// ---------------------------------------------------------------------------
189// Query parameter matching
190// ---------------------------------------------------------------------------
191
/// Report whether the raw query string `query` contains parameter `key`.
///
/// `query` is split on `&` (empty segments are ignored) and each segment on
/// its first `=`. Keys and values are compared verbatim, without
/// percent-decoding. With `value == None`, any occurrence of `key` (with or
/// without `=`) matches. With `Some(expected)`, some occurrence must carry
/// exactly that value, so a bare `key` never matches `Some("")`, while
/// `key=` does.
fn match_query_param(query: &str, key: &str, value: Option<&str>) -> bool {
    query
        .split('&')
        .filter(|segment| !segment.is_empty())
        .map(|segment| match segment.split_once('=') {
            Some((name, val)) => (name, Some(val)),
            None => (segment, None),
        })
        .filter(|(name, _)| *name == key)
        .any(|(_, val)| value.is_none() || val == value)
}
217
218// ---------------------------------------------------------------------------
219// CIDR matching
220// ---------------------------------------------------------------------------
221
222/// Public re-export of CIDR matching used by the router condition evaluator.
223pub fn match_cidr_pub(cidr: &str, ip: &std::net::IpAddr) -> bool {
224    match_cidr(cidr, ip)
225}
226
/// Match an IP address against a CIDR range.
///
/// Supports:
/// - Exact IP: `"192.168.1.1"`
/// - CIDR notation: `"192.168.0.0/16"`, `"10.0.0.0/8"`
/// - IPv6: `"::1"`, `"fd00::/8"`
///
/// An IPv4-mapped IPv6 address (`::ffff:a.b.c.d`, the form dual-stack
/// listeners use for IPv4 peers) additionally matches IPv4 rules for
/// `a.b.c.d`, on top of any IPv6 rule it matches directly.
///
/// Malformed rules, prefixes longer than the address width, and address
/// family mismatches never match.
fn match_cidr(cidr: &str, ip: &std::net::IpAddr) -> bool {
    if cidr_contains(cidr, ip) {
        return true;
    }
    match ipv4_mapped(ip) {
        Some(v4) => cidr_contains(cidr, &v4),
        None => false,
    }
}

/// Return the embedded IPv4 address if `ip` is IPv4-mapped (`::ffff:a.b.c.d`).
///
/// This is deliberately narrower than `Ipv6Addr::to_ipv4`, which would also
/// convert IPv4-compatible addresses such as `::1` into `0.0.0.1`.
fn ipv4_mapped(ip: &std::net::IpAddr) -> Option<std::net::IpAddr> {
    use std::net::{IpAddr, Ipv4Addr};

    match ip {
        IpAddr::V6(v6) => {
            let o = v6.octets();
            let is_mapped = o[..10].iter().all(|&b| b == 0) && o[10] == 0xff && o[11] == 0xff;
            if is_mapped {
                Some(IpAddr::V4(Ipv4Addr::new(o[12], o[13], o[14], o[15])))
            } else {
                None
            }
        }
        IpAddr::V4(_) => None,
    }
}

/// Match `ip` against `cidr` with no address-family translation.
fn cidr_contains(cidr: &str, ip: &std::net::IpAddr) -> bool {
    use std::net::IpAddr;

    let (network_str, prefix_str) = match cidr.split_once('/') {
        Some(parts) => parts,
        // No slash: exact address match.
        None => return matches!(cidr.parse::<IpAddr>(), Ok(expected) if expected == *ip),
    };
    let network: IpAddr = match network_str.parse() {
        Ok(addr) => addr,
        Err(_) => return false,
    };
    let prefix_len: u32 = match prefix_str.parse() {
        Ok(p) => p,
        Err(_) => return false,
    };

    match (network, ip) {
        (IpAddr::V4(net), IpAddr::V4(addr)) => {
            if prefix_len > 32 {
                return false;
            }
            // A /0 prefix would shift by the full width; checked_shl turns that
            // into an all-zero mask, which matches every address.
            let mask = u32::MAX.checked_shl(32 - prefix_len).unwrap_or(0);
            (u32::from(*addr) & mask) == (u32::from(net) & mask)
        }
        (IpAddr::V6(net), IpAddr::V6(addr)) => {
            if prefix_len > 128 {
                return false;
            }
            let mask = u128::MAX.checked_shl(128 - prefix_len).unwrap_or(0);
            (u128::from(*addr) & mask) == (u128::from(net) & mask)
        }
        _ => false, // v4/v6 mismatch
    }
}
280
281// ---------------------------------------------------------------------------
282// Simple expression evaluation
283// ---------------------------------------------------------------------------
284
285/// Evaluate a simple expression against a request.
286///
287/// Supported tokens:
288/// - `{method}` — HTTP method
289/// - `{path}` — request path
290/// - `{host}` — Host header
291/// - `{remote_ip}` — client IP
292///
293/// Operators:
294/// - `==` — exact equality
295/// - `!=` — inequality
296/// - `~` — glob match
297///
298/// Combinators:
299/// - `&&` — logical AND
300/// - `||` — logical OR
301fn eval_expression(expr: &str, req: &Request<Body>, client_addr: SocketAddr) -> bool {
302    // Split by "||" first (lowest precedence), then "&&".
303    let or_parts: Vec<&str> = expr.split("||").collect();
304    for or_part in &or_parts {
305        let and_parts: Vec<&str> = or_part.split("&&").collect();
306        let all_match = and_parts
307            .iter()
308            .all(|part| eval_single_condition(part.trim(), req, client_addr));
309        if all_match {
310            return true;
311        }
312    }
313    false
314}
315
316fn eval_single_condition(cond: &str, req: &Request<Body>, client_addr: SocketAddr) -> bool {
317    // Try to parse: "{var} op value"
318    let (var, op, value) = if let Some(pos) = cond.find("!=") {
319        let var = cond[..pos].trim();
320        let value = cond[pos + 2..].trim();
321        (var, "!=", value)
322    } else if let Some(pos) = cond.find("==") {
323        let var = cond[..pos].trim();
324        let value = cond[pos + 2..].trim();
325        (var, "==", value)
326    } else if let Some(pos) = cond.find('~') {
327        let var = cond[..pos].trim();
328        let value = cond[pos + 1..].trim();
329        (var, "~", value)
330    } else {
331        // Cannot parse; treat as false.
332        return false;
333    };
334
335    let resolved = resolve_variable(var, req, client_addr);
336
337    match op {
338        "==" => resolved == value,
339        "!=" => resolved != value,
340        "~" => glob_matches(value, &resolved),
341        _ => false,
342    }
343}
344
345fn resolve_variable(var: &str, req: &Request<Body>, client_addr: SocketAddr) -> String {
346    match var.trim_matches(|c| c == '{' || c == '}') {
347        "method" => req.method().to_string(),
348        "path" => req.uri().path().to_string(),
349        "host" => req
350            .headers()
351            .get(http::header::HOST)
352            .and_then(|v| v.to_str().ok())
353            .unwrap_or("")
354            .to_string(),
355        "remote_ip" => client_addr.ip().to_string(),
356        "scheme" | "protocol" => req.uri().scheme_str().unwrap_or("http").to_string(),
357        "query" => req.uri().query().unwrap_or("").to_string(),
358        _ => String::new(),
359    }
360}
361
362// ---------------------------------------------------------------------------
363// Tests
364// ---------------------------------------------------------------------------
365
#[cfg(test)]
mod tests {
    use super::*;

    // -- path_matches / pattern_specificity ---------------------------------

    #[test]
    fn test_wildcard() {
        assert!(path_matches("/*", "/anything"));
        assert!(path_matches("/*", "/"));
        assert!(path_matches("*", "/foo"));
    }

    #[test]
    fn test_prefix() {
        assert!(path_matches("/api/*", "/api/users"));
        assert!(path_matches("/api/*", "/api/"));
        assert!(path_matches("/api/*", "/api"));
        assert!(!path_matches("/api/*", "/apifoo"));
        assert!(!path_matches("/api/*", "/other"));
    }

    #[test]
    fn test_exact() {
        assert!(path_matches("/health", "/health"));
        assert!(!path_matches("/health", "/health/check"));
    }

    #[test]
    fn test_extension_match() {
        assert!(path_matches("*.php", "/index.php"));
        assert!(path_matches("*.php", "/app/page.php"));
        assert!(!path_matches("*.php", "/index.html"));
    }

    #[test]
    fn test_specificity_ordering() {
        assert!(pattern_specificity("/api/*") > pattern_specificity("/*"));
        assert!(pattern_specificity("/api/v1/*") > pattern_specificity("/api/*"));
        assert!(pattern_specificity("/exact") > pattern_specificity("/api/v1/*"));
    }

    #[test]
    fn test_specificity_extension() {
        // Extension matchers rank above the catch-all but below exact paths.
        assert!(pattern_specificity("*.php") > pattern_specificity("/*"));
        assert!(pattern_specificity("/index.php") > pattern_specificity("*.php"));
    }

    // -- glob_matches (from crate::glob) ------------------------------------

    #[test]
    fn test_glob_matches_star() {
        assert!(glob_matches("foo*", "foobar"));
        assert!(glob_matches("foo*", "foo"));
        assert!(!glob_matches("foo*", "baz"));
        assert!(glob_matches("*bar", "foobar"));
        assert!(!glob_matches("foo*", "foo/bar"));
    }

    #[test]
    fn test_glob_matches_double_star() {
        assert!(glob_matches("**", "anything/at/all"));
        assert!(glob_matches("/api/**", "/api/v1/users"));
        assert!(glob_matches("foo/**/bar", "foo/a/b/c/bar"));
    }

    #[test]
    fn test_glob_matches_question() {
        assert!(glob_matches("fo?", "foo"));
        assert!(glob_matches("fo?", "fob"));
        assert!(!glob_matches("fo?", "fooo"));
    }

    // -- query / CIDR helpers -----------------------------------------------

    #[test]
    fn test_query_param() {
        assert!(match_query_param("a=1&b=2", "a", Some("1")));
        assert!(match_query_param("a=1&b=2", "b", None));
        assert!(!match_query_param("a=1&b=2", "c", None));
        assert!(!match_query_param("a=1", "a", Some("2")));
    }

    #[test]
    fn test_query_param_edge_cases() {
        // `key=` carries an empty value; a bare `key` carries no value at all.
        assert!(match_query_param("a=", "a", Some("")));
        assert!(!match_query_param("a", "a", Some("")));
        assert!(match_query_param("a", "a", None));
        // Only the first `=` separates key from value.
        assert!(match_query_param("a=b=c", "a", Some("b=c")));
        assert!(!match_query_param("", "a", None));
    }

    #[test]
    fn test_cidr_match_v4() {
        let ip: std::net::IpAddr = "192.168.1.100".parse().unwrap();
        assert!(match_cidr("192.168.0.0/16", &ip));
        assert!(match_cidr("192.168.1.0/24", &ip));
        assert!(!match_cidr("10.0.0.0/8", &ip));
        assert!(match_cidr("192.168.1.100", &ip));
    }

    #[test]
    fn test_cidr_match_v6() {
        let ip: std::net::IpAddr = "::1".parse().unwrap();
        assert!(match_cidr("::1", &ip));
        assert!(match_cidr("::0/0", &ip));
    }

    #[test]
    fn test_cidr_invalid_rules() {
        let ip: std::net::IpAddr = "192.168.1.100".parse().unwrap();
        assert!(!match_cidr("192.168.0.0/33", &ip));
        assert!(!match_cidr("not-an-ip/8", &ip));
        assert!(!match_cidr("192.168.0.0/abc", &ip));
        assert!(match_cidr("0.0.0.0/0", &ip));
        // An IPv6 rule never matches a plain IPv4 address.
        assert!(!match_cidr("::/0", &ip));
    }

    // -- RequestMatcher -----------------------------------------------------

    #[test]
    fn test_request_matcher_method() {
        let req = http::Request::builder()
            .method("GET")
            .uri("/test")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        let matcher = RequestMatcher::Method(vec!["GET".into(), "POST".into()]);
        assert!(matcher.matches(&req, addr));

        let matcher = RequestMatcher::Method(vec!["POST".into()]);
        assert!(!matcher.matches(&req, addr));

        // Configured methods are compared case-insensitively.
        let matcher = RequestMatcher::Method(vec!["get".into()]);
        assert!(matcher.matches(&req, addr));
    }

    #[test]
    fn test_request_matcher_query() {
        let req = http::Request::builder()
            .uri("/test?foo=bar&baz=1")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        let matcher = RequestMatcher::Query {
            key: "foo".into(),
            value: Some("bar".into()),
        };
        assert!(matcher.matches(&req, addr));

        let matcher = RequestMatcher::Query {
            key: "baz".into(),
            value: None,
        };
        assert!(matcher.matches(&req, addr));

        let matcher = RequestMatcher::Query {
            key: "missing".into(),
            value: None,
        };
        assert!(!matcher.matches(&req, addr));
    }

    #[test]
    fn test_request_matcher_header() {
        let req = http::Request::builder()
            .uri("/test")
            .header("X-Custom", "foobar")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        let matcher = RequestMatcher::Header {
            name: "X-Custom".into(),
            pattern: "foo*".into(),
        };
        assert!(matcher.matches(&req, addr));

        let matcher = RequestMatcher::Header {
            name: "x-custom".into(),
            pattern: "baz*".into(),
        };
        assert!(!matcher.matches(&req, addr));

        let matcher = RequestMatcher::Header {
            name: "X-Missing".into(),
            pattern: "*".into(),
        };
        assert!(!matcher.matches(&req, addr));

        // An unparsable header name never matches.
        let matcher = RequestMatcher::Header {
            name: "bad name".into(),
            pattern: "*".into(),
        };
        assert!(!matcher.matches(&req, addr));

        // HeaderRegex uses the same glob semantics as Header.
        let matcher = RequestMatcher::HeaderRegex {
            name: "X-Custom".into(),
            regex: "*bar".into(),
        };
        assert!(matcher.matches(&req, addr));
    }

    #[test]
    fn test_request_matcher_language() {
        let req = http::Request::builder()
            .uri("/test")
            .header("Accept-Language", "en-US,en;q=0.9,fr;q=0.8")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        assert!(RequestMatcher::Language(vec!["fr".into()]).matches(&req, addr));
        assert!(RequestMatcher::Language(vec!["EN".into()]).matches(&req, addr));
        assert!(RequestMatcher::Language(vec!["en-us".into()]).matches(&req, addr));
        assert!(!RequestMatcher::Language(vec!["de".into()]).matches(&req, addr));
        // Prefix matching is per subtag: "e" is not a prefix of "en".
        assert!(!RequestMatcher::Language(vec!["e".into()]).matches(&req, addr));

        let bare = http::Request::builder()
            .uri("/test")
            .body(crate::empty_body())
            .unwrap();
        assert!(!RequestMatcher::Language(vec!["en".into()]).matches(&bare, addr));
    }

    #[test]
    fn test_request_matcher_protocol() {
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        // No scheme in the URI: treated as "http".
        let req = http::Request::builder()
            .uri("/test")
            .body(crate::empty_body())
            .unwrap();
        assert!(RequestMatcher::Protocol("HTTP".into()).matches(&req, addr));
        assert!(!RequestMatcher::Protocol("https".into()).matches(&req, addr));

        let req = http::Request::builder()
            .uri("https://example.com/test")
            .body(crate::empty_body())
            .unwrap();
        assert!(RequestMatcher::Protocol("https".into()).matches(&req, addr));
    }

    #[test]
    fn test_request_matcher_expression() {
        let req = http::Request::builder()
            .method("GET")
            .uri("/api/test?debug=1")
            .header("Host", "example.com")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();
        let eval = |e: &str| RequestMatcher::Expression(e.into()).matches(&req, addr);

        assert!(eval("{method} == GET && {path} ~ /api/*"));
        assert!(eval("{method} == POST || {path} == /api/test"));
        assert!(!eval("{method} != GET"));
        assert!(eval("{remote_ip} == 127.0.0.1"));
        assert!(eval("{host} == example.com"));
        assert!(eval("{query} == debug=1"));
        assert!(!eval("no operator here"));
        assert!(!eval(""));
    }

    #[test]
    fn test_request_matcher_not() {
        let req = http::Request::builder()
            .method("GET")
            .uri("/test")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        let matcher = RequestMatcher::Not(Box::new(RequestMatcher::Method(vec!["POST".into()])));
        assert!(matcher.matches(&req, addr));
    }

    #[test]
    fn test_request_matcher_and_or() {
        let req = http::Request::builder()
            .method("GET")
            .uri("/api/test?debug=1")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        let matcher = RequestMatcher::And(vec![
            RequestMatcher::Method(vec!["GET".into()]),
            RequestMatcher::Path("/api/*".into()),
        ]);
        assert!(matcher.matches(&req, addr));

        let matcher = RequestMatcher::Or(vec![
            RequestMatcher::Method(vec!["POST".into()]),
            RequestMatcher::Path("/api/*".into()),
        ]);
        assert!(matcher.matches(&req, addr));

        // Empty combinators: AND is vacuously true, OR is false.
        assert!(RequestMatcher::And(vec![]).matches(&req, addr));
        assert!(!RequestMatcher::Or(vec![]).matches(&req, addr));
    }

    #[test]
    fn test_request_matcher_remote_ip() {
        let req = http::Request::builder()
            .uri("/test")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "192.168.1.50:1234".parse().unwrap();

        let matcher = RequestMatcher::RemoteIp(vec!["192.168.0.0/16".into()]);
        assert!(matcher.matches(&req, addr));

        let matcher = RequestMatcher::RemoteIp(vec!["10.0.0.0/8".into()]);
        assert!(!matcher.matches(&req, addr));
    }
}