use ipnet::IpNet;
use serde::{Deserialize, Serialize};
use std::net::IpAddr;
use std::str::FromStr;
/// A single header criterion used by `RequestMatcher`.
///
/// The criterion is satisfied when the request carries at least one header
/// whose name equals `name` (ASCII case-insensitively) and, if `value` is set,
/// whose value is exactly equal to it.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct HeaderMatch {
/// Header name to look for; compared ignoring ASCII case.
pub name: String,
/// Required header value (exact, case-sensitive string comparison).
/// `None` (also the default when the key is omitted) means the header only
/// has to be present.
#[serde(default)]
pub value: Option<String>,
}
/// A set of optional criteria a request must satisfy.
///
/// Every configured criterion has to hold (logical AND across fields). Every
/// field is optional when deserializing and falls back to its `Default`
/// value, which leaves that criterion unconfigured.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct RequestMatcher {
/// Allowed client addresses, each written either in CIDR notation or as a
/// bare IP address. The client has to match at least one entry; entries
/// that parse as neither form never match. Empty means no IP restriction.
#[serde(default)]
pub source_ips: Vec<String>,
/// Header criteria; each one of them must be satisfied by the request.
#[serde(default)]
pub headers: Vec<HeaderMatch>,
/// Inclusive lower bound on the body size in bytes. An unknown body size is
/// treated as 0 when checked against this bound.
#[serde(default)]
pub min_body_size_bytes: Option<usize>,
/// Inclusive upper bound on the body size in bytes. An unknown body size is
/// treated as 0 when checked against this bound.
#[serde(default)]
pub max_body_size_bytes: Option<usize>,
/// Required chunked state: `Some(true)` matches only chunked requests,
/// `Some(false)` matches only non-chunked requests, `None` accepts both.
#[serde(default)]
pub chunked_only: Option<bool>,
}
impl RequestMatcher {
    /// Returns `true` when no criterion at all is configured.
    pub fn is_empty(&self) -> bool {
        let has_constraint = !self.source_ips.is_empty()
            || !self.headers.is_empty()
            || self.min_body_size_bytes.is_some()
            || self.max_body_size_bytes.is_some()
            || self.chunked_only.is_some();
        !has_constraint
    }

    /// Decides whether a request satisfies every configured criterion.
    ///
    /// * `client_ip` - textual client address; a missing or unparsable address
    ///   fails any configured `source_ips` restriction.
    /// * `headers` - request headers as `(name, value)` pairs. The collection is
    ///   cloned once per header criterion that gets evaluated.
    /// * `body_size` - body size in bytes; `None` is treated as 0 for the size
    ///   bounds.
    /// * `is_chunked` - whether the request is chunked.
    ///
    /// An empty matcher accepts everything. Otherwise the criteria are checked in
    /// the order source IP, headers, body size, chunked state, stopping at the
    /// first one that fails.
    pub fn matches<'a, I>(
        &self,
        client_ip: Option<&str>,
        headers: I,
        body_size: Option<usize>,
        is_chunked: bool,
    ) -> bool
    where
        I: IntoIterator<Item = (&'a str, &'a str)> + Clone,
    {
        if self.is_empty() {
            return true;
        }
        // `&&` short-circuits, so later criteria are not evaluated once one fails.
        self.source_ip_matches(client_ip)
            && self.headers_match(headers)
            && self.body_size_in_range(body_size)
            && self.chunking_matches(is_chunked)
    }

    /// Checks the client address against `source_ips` (any entry may match).
    fn source_ip_matches(&self, client_ip: Option<&str>) -> bool {
        if self.source_ips.is_empty() {
            return true;
        }
        let parsed = client_ip.and_then(|raw| IpAddr::from_str(raw).ok());
        match parsed {
            Some(addr) => self.source_ips.iter().any(|entry| ip_in_cidr(addr, entry)),
            None => false,
        }
    }

    /// Checks that every header criterion is met by at least one request header.
    fn headers_match<'a, I>(&self, headers: I) -> bool
    where
        I: IntoIterator<Item = (&'a str, &'a str)> + Clone,
    {
        for rule in &self.headers {
            let satisfied = headers.clone().into_iter().any(|(name, value)| {
                // Names compare ignoring ASCII case; values must be identical.
                name.eq_ignore_ascii_case(&rule.name)
                    && match &rule.value {
                        Some(expected) => value == expected,
                        None => true,
                    }
            });
            if !satisfied {
                return false;
            }
        }
        true
    }

    /// Checks the body size against the inclusive min/max bounds.
    fn body_size_in_range(&self, body_size: Option<usize>) -> bool {
        let size = body_size.unwrap_or(0);
        // A missing bound degenerates to the widest possible limit on its side.
        let lower = self.min_body_size_bytes.unwrap_or(0);
        let upper = self.max_body_size_bytes.unwrap_or(usize::MAX);
        lower <= size && size <= upper
    }

    /// Checks the chunked state against `chunked_only`, if it is configured.
    fn chunking_matches(&self, is_chunked: bool) -> bool {
        match self.chunked_only {
            Some(required) => required == is_chunked,
            None => true,
        }
    }
}
/// Reports whether `ip` is covered by the textual entry `cidr`.
///
/// The entry is first parsed as a network in prefix notation and tested for
/// containment. If that parse fails, it is parsed as a single address and
/// compared for equality. An entry that parses as neither matches nothing.
fn ip_in_cidr(ip: IpAddr, cidr: &str) -> bool {
    match IpNet::from_str(cidr) {
        Ok(network) => network.contains(&ip),
        // Not a network: accept only a bare address identical to `ip`.
        Err(_) => IpAddr::from_str(cidr).ok() == Some(ip),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts borrowed `(name, value)` pairs into an owned header list.
    fn owned(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
        let mut out = Vec::with_capacity(pairs.len());
        for &(name, value) in pairs {
            out.push((name.to_owned(), value.to_owned()));
        }
        out
    }

    /// Borrows an owned header list as the cloneable `(&str, &str)` sequence
    /// that `RequestMatcher::matches` accepts.
    fn borrowed(list: &[(String, String)]) -> impl IntoIterator<Item = (&str, &str)> + Clone {
        let mut out = Vec::with_capacity(list.len());
        for (name, value) in list {
            out.push((name.as_str(), value.as_str()));
        }
        out
    }

    /// Builds a matcher whose only criterion is a single header rule.
    fn single_header_matcher(name: &str, value: Option<&str>) -> RequestMatcher {
        RequestMatcher {
            headers: vec![HeaderMatch {
                name: name.to_owned(),
                value: value.map(str::to_owned),
            }],
            ..Default::default()
        }
    }

    /// Builds a matcher whose only criterion is a list of source IP entries.
    fn source_ip_matcher(entry: &str) -> RequestMatcher {
        RequestMatcher {
            source_ips: vec![entry.to_owned()],
            ..Default::default()
        }
    }

    #[test]
    fn empty_matcher_matches_everything() {
        let matcher = RequestMatcher::default();
        assert!(matcher.is_empty());
        let no_headers = owned(&[]);
        assert!(matcher.matches(None, borrowed(&no_headers), None, false));
        assert!(matcher.matches(Some("8.8.8.8"), borrowed(&no_headers), Some(1024), true));
    }

    #[test]
    fn cidr_v4_match() {
        let matcher = source_ip_matcher("10.0.0.0/8");
        let no_headers = owned(&[]);
        let cases = [
            (Some("10.5.6.7"), true),
            (Some("11.0.0.1"), false),
            (None, false),
        ];
        for &(ip, expected) in &cases {
            let got = matcher.matches(ip, borrowed(&no_headers), None, false);
            assert_eq!(got, expected, "client ip {:?}", ip);
        }
    }

    #[test]
    fn bare_ip_treated_as_host() {
        let matcher = source_ip_matcher("127.0.0.1");
        let no_headers = owned(&[]);
        let cases = [("127.0.0.1", true), ("127.0.0.2", false)];
        for &(ip, expected) in &cases {
            let got = matcher.matches(Some(ip), borrowed(&no_headers), None, false);
            assert_eq!(got, expected, "client ip {:?}", ip);
        }
    }

    #[test]
    fn cidr_v6_match() {
        let matcher = source_ip_matcher("2001:db8::/32");
        let no_headers = owned(&[]);
        let cases = [("2001:db8::1", true), ("2001:db9::1", false)];
        for &(ip, expected) in &cases {
            let got = matcher.matches(Some(ip), borrowed(&no_headers), None, false);
            assert_eq!(got, expected, "client ip {:?}", ip);
        }
    }

    #[test]
    fn header_presence_only() {
        let matcher = single_header_matcher("x-test", None);
        let present = owned(&[("x-test", "anything")]);
        let absent = owned(&[("x-other", "v")]);
        assert!(matcher.matches(None, borrowed(&present), None, false));
        assert!(!matcher.matches(None, borrowed(&absent), None, false));
    }

    #[test]
    fn header_exact_value_case_insensitive_name() {
        // The rule name differs in case from the request header name on purpose.
        let matcher = single_header_matcher("X-Test", Some("abc"));
        let same_value = owned(&[("x-test", "abc")]);
        let other_value = owned(&[("x-test", "xyz")]);
        assert!(matcher.matches(None, borrowed(&same_value), None, false));
        assert!(!matcher.matches(None, borrowed(&other_value), None, false));
    }

    #[test]
    fn body_size_threshold() {
        let no_headers = owned(&[]);

        let at_least_1k = RequestMatcher {
            min_body_size_bytes: Some(1024),
            ..Default::default()
        };
        let min_cases = [(Some(2048), true), (Some(512), false), (None, false)];
        for &(size, expected) in &min_cases {
            let got = at_least_1k.matches(None, borrowed(&no_headers), size, false);
            assert_eq!(got, expected, "min bound, body size {:?}", size);
        }

        let at_most_1k = RequestMatcher {
            max_body_size_bytes: Some(1024),
            ..Default::default()
        };
        let max_cases = [(Some(512), true), (Some(2048), false)];
        for &(size, expected) in &max_cases {
            let got = at_most_1k.matches(None, borrowed(&no_headers), size, false);
            assert_eq!(got, expected, "max bound, body size {:?}", size);
        }
    }

    #[test]
    fn chunked_only() {
        let no_headers = owned(&[]);
        // (configured value, request is chunked, expected outcome)
        let cases = [
            (true, true, true),
            (true, false, false),
            (false, true, false),
            (false, false, true),
        ];
        for &(setting, is_chunked, expected) in &cases {
            let matcher = RequestMatcher {
                chunked_only: Some(setting),
                ..Default::default()
            };
            let got = matcher.matches(None, borrowed(&no_headers), None, is_chunked);
            assert_eq!(
                got, expected,
                "chunked_only {:?}, is_chunked {:?}",
                setting, is_chunked
            );
        }
    }

    #[test]
    fn and_semantics_across_fields() {
        let matcher = RequestMatcher {
            source_ips: vec!["10.0.0.0/8".to_owned()],
            headers: vec![HeaderMatch {
                name: "x-test".to_owned(),
                value: Some("yes".to_owned()),
            }],
            min_body_size_bytes: Some(100),
            chunked_only: Some(true),
            ..Default::default()
        };
        let good_headers = owned(&[("x-test", "yes")]);
        let bad_headers = owned(&[("x-test", "no")]);

        // Everything satisfied.
        assert!(matcher.matches(Some("10.1.1.1"), borrowed(&good_headers), Some(200), true));
        // Exactly one criterion violated at a time: IP, header, size, chunking.
        assert!(!matcher.matches(Some("8.8.8.8"), borrowed(&good_headers), Some(200), true));
        assert!(!matcher.matches(Some("10.1.1.1"), borrowed(&bad_headers), Some(200), true));
        assert!(!matcher.matches(Some("10.1.1.1"), borrowed(&good_headers), Some(50), true));
        assert!(!matcher.matches(Some("10.1.1.1"), borrowed(&good_headers), Some(200), false));
    }

    #[test]
    fn invalid_cidr_does_not_panic() {
        let matcher = source_ip_matcher("not-an-ip");
        let no_headers = owned(&[]);
        assert!(!matcher.matches(Some("1.2.3.4"), borrowed(&no_headers), None, false));
    }
}