1use std::net::SocketAddr;
3
4use http::Request;
5
6use crate::Body;
7use crate::glob::glob_matches;
8
/// Tests whether a request `path` satisfies a path `pattern`.
///
/// Pattern shapes are tried in this order:
/// - `"*"` or `"/*"`: matches every path.
/// - `"*suffix"` (but not `"**…"`): matches paths ending in `suffix`, e.g. `"*.php"`.
/// - `"dir/*"`: matches `dir` itself and anything below `dir/`, but not `dirfoo`.
/// - `"prefix*"`: matches any path that starts with `prefix`.
/// - anything else: exact, case-sensitive equality.
pub fn path_matches(pattern: &str, path: &str) -> bool {
    match pattern {
        "*" | "/*" => true,
        _ => {
            let suffix_glob = pattern
                .strip_prefix('*')
                .filter(|_| !pattern.starts_with("**"));
            if let Some(suffix) = suffix_glob {
                path.ends_with(suffix)
            } else if let Some(dir) = pattern.strip_suffix("/*") {
                // Require a `/` boundary (or the bare directory) after `dir`.
                matches!(path.strip_prefix(dir), Some(rest) if rest.is_empty() || rest.starts_with('/'))
            } else if let Some(prefix) = pattern.strip_suffix('*') {
                path.starts_with(prefix)
            } else {
                path == pattern
            }
        }
    }
}
40
/// Ranks a path pattern by specificity; a higher value should take precedence.
///
/// - `"*"` / `"/*"` (match everything) rank 0.
/// - Patterns starting with `*` (suffix matches such as `"*.php"`) rank by length.
/// - `"dir/*"` ranks `dir.len() + 1`.
/// - Any other pattern ending in `*` ranks by length.
/// - Exact patterns rank `len + 1000`, placing them above wildcard patterns of
///   ordinary length.
pub fn pattern_specificity(pattern: &str) -> usize {
    let len = pattern.len();
    match pattern {
        "*" | "/*" => 0,
        p if p.starts_with('*') => len,
        // `dir/*`: length of `dir` (len - 2) plus one.
        p if p.ends_with("/*") => len - 1,
        p if p.ends_with('*') => len,
        _ => len + 1000,
    }
}
60
/// A predicate over an incoming request, evaluated by [`RequestMatcher::matches`].
///
/// Leaf variants inspect one aspect of the request; `Not`, `And`, and `Or`
/// combine other matchers.
#[derive(Debug, Clone, serde::Serialize)]
pub enum RequestMatcher {
    /// Matches the URI path against a pattern using [`path_matches`].
    Path(String),
    /// Matches if the request method equals any listed method, ignoring case.
    Method(Vec<String>),
    /// Matches if the first value of header `name` matches the glob `pattern`
    /// (via `glob_matches`). An invalid header name or a missing header never matches.
    Header { name: String, pattern: String },
    /// Matches the first value of header `name` against `regex`.
    ///
    /// NOTE(review): `matches` currently evaluates `regex` with `glob_matches`,
    /// not a regular-expression engine; confirm this is intended.
    HeaderRegex { name: String, regex: String },
    /// Matches if the raw query string contains `key`. When `value` is `Some`,
    /// the key must also carry exactly that value (compared without percent-decoding).
    Query { key: String, value: Option<String> },
    /// Matches if the client IP lies in any listed CIDR range (e.g. `10.0.0.0/8`)
    /// or equals a listed bare address.
    RemoteIp(Vec<String>),
    /// Matches the request URI's scheme ASCII case-insensitively; a URI without
    /// a scheme is treated as `"http"`.
    Protocol(String),
    /// Evaluates a small `||`/`&&` condition expression (see `eval_expression`).
    Expression(String),
    /// Inverts the wrapped matcher.
    Not(Box<RequestMatcher>),
    /// Matches if every child matches (an empty list matches every request).
    And(Vec<RequestMatcher>),
    /// Matches if any child matches (an empty list matches no request).
    Or(Vec<RequestMatcher>),
    /// Matches against the `Accept-Language` header. A configured tag `c`
    /// matches an accepted language `a` when `a == c` or `a` begins with `c-`
    /// (so `en` matches `en-US`), compared case-insensitively.
    Language(Vec<String>),
}
95
96impl RequestMatcher {
97 pub fn matches(&self, req: &Request<Body>, client_addr: SocketAddr) -> bool {
99 match self {
100 RequestMatcher::Path(pattern) => {
101 let path = req.uri().path();
102 path_matches(pattern, path)
103 }
104
105 RequestMatcher::Method(methods) => {
106 let req_method = req.method().as_str().to_uppercase();
107 methods.iter().any(|m| m.to_uppercase() == req_method)
108 }
109
110 RequestMatcher::Header { name, pattern } => {
111 if let Ok(header_name) = name.parse::<http::header::HeaderName>() {
112 req.headers()
113 .get(&header_name)
114 .and_then(|v| v.to_str().ok())
115 .map(|v| glob_matches(pattern, v))
116 .unwrap_or(false)
117 } else {
118 false
119 }
120 }
121
122 RequestMatcher::HeaderRegex { name, regex } => {
123 if let Ok(header_name) = name.parse::<http::header::HeaderName>() {
124 req.headers()
125 .get(&header_name)
126 .and_then(|v| v.to_str().ok())
127 .map(|v| glob_matches(regex, v))
128 .unwrap_or(false)
129 } else {
130 false
131 }
132 }
133
134 RequestMatcher::Query { key, value } => {
135 let query_str = req.uri().query().unwrap_or("");
136 match_query_param(query_str, key, value.as_deref())
137 }
138
139 RequestMatcher::RemoteIp(cidrs) => {
140 let client_ip = client_addr.ip();
141 cidrs.iter().any(|cidr| match_cidr(cidr, &client_ip))
142 }
143
144 RequestMatcher::Protocol(proto) => {
145 let scheme = req.uri().scheme_str().unwrap_or("http");
146 scheme.eq_ignore_ascii_case(proto)
147 }
148
149 RequestMatcher::Expression(expr) => eval_expression(expr, req, client_addr),
150
151 RequestMatcher::Not(inner) => !inner.matches(req, client_addr),
152
153 RequestMatcher::And(matchers) => matchers.iter().all(|m| m.matches(req, client_addr)),
154
155 RequestMatcher::Or(matchers) => matchers.iter().any(|m| m.matches(req, client_addr)),
156
157 RequestMatcher::Language(langs) => {
158 let header_value = req
161 .headers()
162 .get(http::header::ACCEPT_LANGUAGE)
163 .and_then(|v| v.to_str().ok())
164 .unwrap_or("");
165 let accepted: Vec<&str> = header_value
167 .split(',')
168 .map(|part| part.split(';').next().unwrap_or("").trim())
169 .filter(|s| !s.is_empty())
170 .collect();
171 langs.iter().any(|configured| {
172 accepted.iter().any(|accepted_lang| {
173 let c = configured.to_lowercase();
175 let a = accepted_lang.to_lowercase();
176 a == c || a.starts_with(&format!("{c}-"))
177 })
178 })
179 }
180 }
181 }
182}
183
/// Returns `true` if the raw query string contains a parameter named `key`.
///
/// `query` is split on `&`, and each non-empty pair is split on its first `=`.
/// With `value == None` the key only has to be present, with or without a value.
/// With `Some(expected)`, some occurrence of the key must carry exactly
/// `expected`. Keys and values are compared verbatim, without percent-decoding.
fn match_query_param(query: &str, key: &str, value: Option<&str>) -> bool {
    query
        .split('&')
        .filter(|pair| !pair.is_empty())
        .map(|pair| match pair.split_once('=') {
            Some((k, v)) => (k, Some(v)),
            None => (pair, None),
        })
        .any(|(k, v)| {
            k == key
                && match value {
                    None => true,
                    Some(expected) => v == Some(expected),
                }
        })
}
217
/// Public entry point for CIDR matching: returns `true` if `ip` lies inside
/// `cidr` (e.g. `"10.0.0.0/8"`, `"2001:db8::/32"`) or equals it when `cidr` is a
/// bare address (e.g. `"192.168.1.100"`).
pub fn match_cidr_pub(cidr: &str, ip: &std::net::IpAddr) -> bool {
    match_cidr(cidr, ip)
}

/// Returns `true` if `ip` lies in the network described by `cidr`.
///
/// `cidr` is either `"<addr>/<prefix-len>"` or a bare address, which is treated
/// as a full-length prefix (exact equality). Malformed input, a prefix length
/// longer than the address, or an address-family mismatch yields `false`.
///
/// An IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) is additionally checked
/// through its embedded IPv4 address, so IPv4 rules also apply to IPv4 clients
/// seen through a dual-stack socket.
fn match_cidr(cidr: &str, ip: &std::net::IpAddr) -> bool {
    let (network_str, prefix_len) = match cidr.split_once('/') {
        Some((network, prefix)) => match prefix.parse::<u32>() {
            Ok(len) => (network, Some(len)),
            Err(_) => return false,
        },
        None => (cidr, None),
    };
    let network: std::net::IpAddr = match network_str.parse() {
        Ok(addr) => addr,
        Err(_) => return false,
    };

    if ip_in_network(network, prefix_len, *ip) {
        return true;
    }
    match ipv4_mapped(ip) {
        Some(v4) => ip_in_network(network, prefix_len, std::net::IpAddr::V4(v4)),
        None => false,
    }
}

/// Returns `true` if `ip` shares its leading `prefix_len` bits with `network`.
///
/// `None` means a full-length prefix (exact match). Mixed address families and
/// prefix lengths longer than the address never match.
fn ip_in_network(
    network: std::net::IpAddr,
    prefix_len: Option<u32>,
    ip: std::net::IpAddr,
) -> bool {
    use std::net::IpAddr;

    match (network, ip) {
        (IpAddr::V4(net), IpAddr::V4(addr)) => {
            let bits = prefix_len.unwrap_or(32);
            if bits > 32 {
                return false;
            }
            // A /0 prefix shifts by the full width; `checked_shl` turns that into an empty mask.
            let mask = u32::MAX.checked_shl(32 - bits).unwrap_or(0);
            (u32::from(addr) & mask) == (u32::from(net) & mask)
        }
        (IpAddr::V6(net), IpAddr::V6(addr)) => {
            let bits = prefix_len.unwrap_or(128);
            if bits > 128 {
                return false;
            }
            let mask = u128::MAX.checked_shl(128 - bits).unwrap_or(0);
            (u128::from(addr) & mask) == (u128::from(net) & mask)
        }
        _ => false,
    }
}

/// Returns the embedded IPv4 address if `ip` is an IPv4-mapped IPv6 address
/// (`::ffff:a.b.c.d`), otherwise `None`.
fn ipv4_mapped(ip: &std::net::IpAddr) -> Option<std::net::Ipv4Addr> {
    match ip {
        std::net::IpAddr::V6(v6) => match v6.octets() {
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] => {
                Some(std::net::Ipv4Addr::new(a, b, c, d))
            }
            _ => None,
        },
        std::net::IpAddr::V4(_) => None,
    }
}
280
281fn eval_expression(expr: &str, req: &Request<Body>, client_addr: SocketAddr) -> bool {
302 let or_parts: Vec<&str> = expr.split("||").collect();
304 for or_part in &or_parts {
305 let and_parts: Vec<&str> = or_part.split("&&").collect();
306 let all_match = and_parts
307 .iter()
308 .all(|part| eval_single_condition(part.trim(), req, client_addr));
309 if all_match {
310 return true;
311 }
312 }
313 false
314}
315
316fn eval_single_condition(cond: &str, req: &Request<Body>, client_addr: SocketAddr) -> bool {
317 let (var, op, value) = if let Some(pos) = cond.find("!=") {
319 let var = cond[..pos].trim();
320 let value = cond[pos + 2..].trim();
321 (var, "!=", value)
322 } else if let Some(pos) = cond.find("==") {
323 let var = cond[..pos].trim();
324 let value = cond[pos + 2..].trim();
325 (var, "==", value)
326 } else if let Some(pos) = cond.find('~') {
327 let var = cond[..pos].trim();
328 let value = cond[pos + 1..].trim();
329 (var, "~", value)
330 } else {
331 return false;
333 };
334
335 let resolved = resolve_variable(var, req, client_addr);
336
337 match op {
338 "==" => resolved == value,
339 "!=" => resolved != value,
340 "~" => glob_matches(value, &resolved),
341 _ => false,
342 }
343}
344
345fn resolve_variable(var: &str, req: &Request<Body>, client_addr: SocketAddr) -> String {
346 match var.trim_matches(|c| c == '{' || c == '}') {
347 "method" => req.method().to_string(),
348 "path" => req.uri().path().to_string(),
349 "host" => req
350 .headers()
351 .get(http::header::HOST)
352 .and_then(|v| v.to_str().ok())
353 .unwrap_or("")
354 .to_string(),
355 "remote_ip" => client_addr.ip().to_string(),
356 "scheme" | "protocol" => req.uri().scheme_str().unwrap_or("http").to_string(),
357 "query" => req.uri().query().unwrap_or("").to_string(),
358 _ => String::new(),
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wildcard() {
        assert!(path_matches("/*", "/anything"));
        assert!(path_matches("/*", "/"));
        assert!(path_matches("*", "/foo"));
    }

    #[test]
    fn test_prefix() {
        assert!(path_matches("/api/*", "/api/users"));
        assert!(path_matches("/api/*", "/api/"));
        assert!(path_matches("/api/*", "/api"));
        assert!(!path_matches("/api/*", "/apifoo"));
        assert!(!path_matches("/api/*", "/other"));
    }

    #[test]
    fn test_exact() {
        assert!(path_matches("/health", "/health"));
        assert!(!path_matches("/health", "/health/check"));
    }

    #[test]
    fn test_extension_match() {
        assert!(path_matches("*.php", "/index.php"));
        assert!(path_matches("*.php", "/app/page.php"));
        assert!(!path_matches("*.php", "/index.html"));
    }

    #[test]
    fn test_specificity_ordering() {
        assert!(pattern_specificity("/api/*") > pattern_specificity("/*"));
        assert!(pattern_specificity("/api/v1/*") > pattern_specificity("/api/*"));
        assert!(pattern_specificity("/exact") > pattern_specificity("/api/v1/*"));
    }

    #[test]
    fn test_specificity_match_all_and_suffix() {
        assert_eq!(pattern_specificity("/*"), 0);
        assert_eq!(pattern_specificity("*"), 0);
        assert!(pattern_specificity("*.php") > pattern_specificity("/*"));
    }

    #[test]
    fn test_glob_matches_star() {
        assert!(glob_matches("foo*", "foobar"));
        assert!(glob_matches("foo*", "foo"));
        assert!(!glob_matches("foo*", "baz"));
        assert!(glob_matches("*bar", "foobar"));
        assert!(!glob_matches("foo*", "foo/bar"));
    }

    #[test]
    fn test_glob_matches_double_star() {
        assert!(glob_matches("**", "anything/at/all"));
        assert!(glob_matches("/api/**", "/api/v1/users"));
        assert!(glob_matches("foo/**/bar", "foo/a/b/c/bar"));
    }

    #[test]
    fn test_glob_matches_question() {
        assert!(glob_matches("fo?", "foo"));
        assert!(glob_matches("fo?", "fob"));
        assert!(!glob_matches("fo?", "fooo"));
    }

    #[test]
    fn test_query_param() {
        assert!(match_query_param("a=1&b=2", "a", Some("1")));
        assert!(match_query_param("a=1&b=2", "b", None));
        assert!(!match_query_param("a=1&b=2", "c", None));
        assert!(!match_query_param("a=1", "a", Some("2")));
    }

    #[test]
    fn test_query_param_edge_cases() {
        // A key without `=` is present but carries no value.
        assert!(match_query_param("flag&b=2", "flag", None));
        assert!(!match_query_param("flag&b=2", "flag", Some("")));
        // An explicit empty value is distinct from no value at all.
        assert!(match_query_param("a=&b=2", "a", Some("")));
        // Only the first `=` separates key from value.
        assert!(match_query_param("x=1=2", "x", Some("1=2")));
        // Any occurrence of a repeated key may satisfy the value.
        assert!(match_query_param("a=1&a=2", "a", Some("2")));
        assert!(!match_query_param("", "a", None));
    }

    #[test]
    fn test_cidr_match_v4() {
        let ip: std::net::IpAddr = "192.168.1.100".parse().unwrap();
        assert!(match_cidr("192.168.0.0/16", &ip));
        assert!(match_cidr("192.168.1.0/24", &ip));
        assert!(!match_cidr("10.0.0.0/8", &ip));
        assert!(match_cidr("192.168.1.100", &ip));
    }

    #[test]
    fn test_cidr_match_v6() {
        let ip: std::net::IpAddr = "::1".parse().unwrap();
        assert!(match_cidr("::1", &ip));
        assert!(match_cidr("::0/0", &ip));
    }

    #[test]
    fn test_cidr_boundaries() {
        let ip: std::net::IpAddr = "192.168.1.100".parse().unwrap();
        assert!(match_cidr("0.0.0.0/0", &ip));
        assert!(match_cidr("192.168.1.100/32", &ip));
        assert!(!match_cidr("192.168.1.101/32", &ip));
        assert!(!match_cidr("192.168.0.0/33", &ip));
        assert!(!match_cidr("not-an-ip/8", &ip));
        assert!(!match_cidr("192.168.0.0/abc", &ip));
        // Address families never cross-match.
        assert!(!match_cidr("::/0", &ip));

        let v6: std::net::IpAddr = "2001:db8::5".parse().unwrap();
        assert!(match_cidr("2001:db8::/32", &v6));
        assert!(!match_cidr("2001:db9::/32", &v6));
        assert!(!match_cidr("2001:db8::/129", &v6));
    }

    #[test]
    fn test_request_matcher_method() {
        let req = http::Request::builder()
            .method("GET")
            .uri("/test")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        let matcher = RequestMatcher::Method(vec!["GET".into(), "POST".into()]);
        assert!(matcher.matches(&req, addr));

        let matcher = RequestMatcher::Method(vec!["POST".into()]);
        assert!(!matcher.matches(&req, addr));

        // Configured methods are compared case-insensitively.
        let matcher = RequestMatcher::Method(vec!["get".into()]);
        assert!(matcher.matches(&req, addr));
    }

    #[test]
    fn test_request_matcher_query() {
        let req = http::Request::builder()
            .uri("/test?foo=bar&baz=1")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        let matcher = RequestMatcher::Query {
            key: "foo".into(),
            value: Some("bar".into()),
        };
        assert!(matcher.matches(&req, addr));

        let matcher = RequestMatcher::Query {
            key: "baz".into(),
            value: None,
        };
        assert!(matcher.matches(&req, addr));

        let matcher = RequestMatcher::Query {
            key: "missing".into(),
            value: None,
        };
        assert!(!matcher.matches(&req, addr));
    }

    #[test]
    fn test_request_matcher_header() {
        let req = http::Request::builder()
            .uri("/test")
            .header("x-test", "abcdef")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        let matcher = RequestMatcher::Header {
            name: "x-test".into(),
            pattern: "abc*".into(),
        };
        assert!(matcher.matches(&req, addr));

        let matcher = RequestMatcher::Header {
            name: "x-missing".into(),
            pattern: "*".into(),
        };
        assert!(!matcher.matches(&req, addr));

        // An invalid header name never matches.
        let matcher = RequestMatcher::Header {
            name: "bad name".into(),
            pattern: "*".into(),
        };
        assert!(!matcher.matches(&req, addr));
    }

    #[test]
    fn test_request_matcher_language() {
        let req = http::Request::builder()
            .uri("/test")
            .header("accept-language", "en-US,fr;q=0.8")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        assert!(RequestMatcher::Language(vec!["en".into()]).matches(&req, addr));
        assert!(RequestMatcher::Language(vec!["FR".into()]).matches(&req, addr));
        assert!(!RequestMatcher::Language(vec!["de".into()]).matches(&req, addr));
        // `e` is not a whole primary subtag of `en-US`.
        assert!(!RequestMatcher::Language(vec!["e".into()]).matches(&req, addr));
    }

    #[test]
    fn test_request_matcher_expression_and_protocol() {
        let req = http::Request::builder()
            .method("GET")
            .uri("/test?x=1")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        assert!(RequestMatcher::Expression("{method} == GET".into()).matches(&req, addr));
        assert!(
            RequestMatcher::Expression("{method} != POST && {path} == /test".into())
                .matches(&req, addr)
        );
        assert!(
            RequestMatcher::Expression("{method} == POST || {query} == x=1".into())
                .matches(&req, addr)
        );
        assert!(!RequestMatcher::Expression("no operator here".into()).matches(&req, addr));

        // An origin-form URI has no scheme, which is treated as "http".
        assert!(RequestMatcher::Protocol("HTTP".into()).matches(&req, addr));
        assert!(!RequestMatcher::Protocol("https".into()).matches(&req, addr));
    }

    #[test]
    fn test_request_matcher_not() {
        let req = http::Request::builder()
            .method("GET")
            .uri("/test")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        let matcher = RequestMatcher::Not(Box::new(RequestMatcher::Method(vec!["POST".into()])));
        assert!(matcher.matches(&req, addr));
    }

    #[test]
    fn test_request_matcher_and_or() {
        let req = http::Request::builder()
            .method("GET")
            .uri("/api/test?debug=1")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        let matcher = RequestMatcher::And(vec![
            RequestMatcher::Method(vec!["GET".into()]),
            RequestMatcher::Path("/api/*".into()),
        ]);
        assert!(matcher.matches(&req, addr));

        let matcher = RequestMatcher::Or(vec![
            RequestMatcher::Method(vec!["POST".into()]),
            RequestMatcher::Path("/api/*".into()),
        ]);
        assert!(matcher.matches(&req, addr));
    }

    #[test]
    fn test_request_matcher_empty_combinators() {
        let req = http::Request::builder()
            .uri("/test")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();

        // `all` over nothing is true; `any` over nothing is false.
        assert!(RequestMatcher::And(vec![]).matches(&req, addr));
        assert!(!RequestMatcher::Or(vec![]).matches(&req, addr));
    }

    #[test]
    fn test_request_matcher_remote_ip() {
        let req = http::Request::builder()
            .uri("/test")
            .body(crate::empty_body())
            .unwrap();
        let addr: SocketAddr = "192.168.1.50:1234".parse().unwrap();

        let matcher = RequestMatcher::RemoteIp(vec!["192.168.0.0/16".into()]);
        assert!(matcher.matches(&req, addr));

        let matcher = RequestMatcher::RemoteIp(vec!["10.0.0.0/8".into()]);
        assert!(!matcher.matches(&req, addr));
    }
}