use regex::Regex;
use std::collections::HashSet;
use std::sync::atomic::{AtomicU64, Ordering};
/// Decides whether a value is permitted by a configured list of patterns.
///
/// Three kinds of pattern are supported:
/// - exact literals (no `*`), looked up in a hash set;
/// - globs containing `*`, tried in the order they were supplied;
/// - `regex:`-prefixed regular expressions, tried in the order supplied.
///
/// Exact and glob patterns are stored lowercased unless the matcher was built
/// case-sensitively; regexes are always matched against the raw value.
#[derive(Debug)]
pub struct AllowlistMatcher {
    /// Literal patterns (lowercased when the matcher is case-insensitive).
    exact: HashSet<String>,
    /// Patterns containing `*` (lowercased when case-insensitive), in input order.
    globs: Vec<String>,
    /// `(original "regex:"-prefixed text, compiled expression)` pairs, in input order.
    regexes: Vec<(String, Regex)>,
    /// Whether exact/glob comparisons preserve case.
    case_sensitive: bool,
    /// Number of successful pattern matches observed so far.
    seen: AtomicU64,
}
impl AllowlistMatcher {
    /// Builds a case-insensitive matcher from `patterns`.
    ///
    /// Returns the matcher together with human-readable warnings: one for each
    /// `regex:` entry that failed to compile (that entry is dropped), and one
    /// for each literal/glob entry containing a regex-looking character (that
    /// entry is still kept and matched literally).
    #[must_use]
    pub fn new(patterns: Vec<String>) -> (Self, Vec<String>) {
        Self::build(patterns, false)
    }

    /// Same as [`AllowlistMatcher::new`], but exact and glob patterns are
    /// compared without case folding.
    #[must_use]
    pub fn new_case_sensitive(patterns: Vec<String>) -> (Self, Vec<String>) {
        Self::build(patterns, true)
    }

    /// Sorts every raw pattern into the exact / glob / regex buckets and
    /// collects construction warnings along the way.
    fn build(patterns: Vec<String>, case_sensitive: bool) -> (Self, Vec<String>) {
        // Characters suggesting the author expected regex semantics. Order
        // matters: the first one present is the one named in the warning.
        const REGEX_HINTS: [char; 5] = ['^', '$', '+', '(', ')'];

        let mut matcher = Self {
            exact: HashSet::new(),
            globs: Vec::new(),
            regexes: Vec::new(),
            case_sensitive,
            seen: AtomicU64::new(0),
        };
        let mut warnings = Vec::new();

        for pat in patterns {
            match pat.strip_prefix("regex:") {
                // Keep the full prefixed text so `match_pattern` can report it verbatim.
                Some(re_src) => match Regex::new(re_src) {
                    Ok(compiled) => matcher.regexes.push((pat, compiled)),
                    Err(e) => warnings.push(format!(
                        "allowlist pattern '{pat}' failed to compile: {e} — pattern skipped"
                    )),
                },
                None => {
                    // A `$` that only ever appears as part of `${` is not a regex hint.
                    let hint = REGEX_HINTS.iter().copied().find(|&c| {
                        if c == '$' {
                            pat.replace("${", "").contains('$')
                        } else {
                            pat.contains(c)
                        }
                    });
                    if let Some(ch) = hint {
                        warnings.push(format!(
                            "allowlist pattern '{pat}' contains regex character '{ch}'; \
                             it is matched literally — use the 'regex:' prefix for regex syntax"
                        ));
                    }

                    let key = if case_sensitive { pat } else { pat.to_lowercase() };
                    if key.contains('*') {
                        matcher.globs.push(key);
                    } else {
                        matcher.exact.insert(key);
                    }
                }
            }
        }

        (matcher, warnings)
    }

    /// Returns `true` if `value` matches any configured pattern.
    ///
    /// A successful match is counted in [`AllowlistMatcher::seen_count`].
    pub fn is_allowed(&self, value: &str) -> bool {
        self.match_pattern(value).is_some()
    }

    /// Returns the pattern that accepts `value`, or `None` if nothing does.
    ///
    /// Patterns are tried in a fixed order: exact literals first, then globs
    /// in the order supplied, then regexes in the order supplied. Exact and
    /// glob patterns are compared against `value` lowercased (unless the
    /// matcher is case-sensitive) and are reported in their stored, possibly
    /// lowercased, form. Regexes see the raw `value` via `Regex::is_match` and
    /// are reported with their `regex:` prefix. Every successful match
    /// increments the seen counter.
    pub fn match_pattern<'a>(&'a self, value: &str) -> Option<&'a str> {
        let folded: std::borrow::Cow<'_, str> = if self.case_sensitive {
            value.into()
        } else {
            value.to_lowercase().into()
        };
        let needle: &str = &folded;

        let hit = self
            .exact
            .get(needle)
            .map(String::as_str)
            .or_else(|| {
                self.globs
                    .iter()
                    .find(|glob| glob_matches(glob, needle))
                    .map(String::as_str)
            })
            .or_else(|| {
                self.regexes
                    .iter()
                    .find(|(_, re)| re.is_match(value))
                    .map(|(source, _)| source.as_str())
            });

        if hit.is_some() {
            self.seen.fetch_add(1, Ordering::Relaxed);
        }
        hit
    }

    /// Number of successful matches recorded so far.
    pub fn seen_count(&self) -> u64 {
        self.seen.load(Ordering::Relaxed)
    }

    /// Total number of stored patterns across all three kinds.
    ///
    /// Exact patterns are deduplicated (after case folding, if any); globs and
    /// regexes are counted once per occurrence in the input.
    pub fn pattern_count(&self) -> usize {
        [self.exact.len(), self.globs.len(), self.regexes.len()]
            .iter()
            .sum()
    }

    /// Returns `true` when no usable pattern was configured.
    pub fn is_empty(&self) -> bool {
        self.pattern_count() == 0
    }
}
/// Returns `true` when `value` matches `pattern`, where `*` matches any
/// (possibly empty) run of characters and every other character matches
/// itself literally.
///
/// Comparison is byte-exact; any case folding must be done by the caller.
/// A pattern without `*` must equal `value` exactly. Adjacent `*`s behave like
/// a single one. The literal prefix (before the first `*`) and suffix (after
/// the last `*`) may not overlap, so `a*a` and `a**a` do not match `"a"`.
///
/// Never panics, whatever the pattern or value.
pub(crate) fn glob_matches(pattern: &str, value: &str) -> bool {
    let parts: Vec<&str> = pattern.split('*').collect();
    let (prefix, middle, suffix) = match parts.as_slice() {
        [prefix, middle @ .., suffix] => (*prefix, middle, *suffix),
        // `split` yielded a single piece: no wildcard, so it is a plain literal.
        _ => return pattern == value,
    };

    // Reject overlap of prefix and suffix up front; this also guarantees the
    // slice below is in bounds regardless of how many middle parts there are.
    if value.len() < prefix.len() + suffix.len()
        || !value.starts_with(prefix)
        || !value.ends_with(suffix)
    {
        return false;
    }

    // Both bounds lie on char boundaries because `value` starts with `prefix`
    // and ends with `suffix`.
    let mut rest = &value[prefix.len()..value.len() - suffix.len()];
    // Greedy leftmost search is sufficient when `*` is the only wildcard.
    for part in middle.iter().filter(|p| !p.is_empty()) {
        match rest.find(part) {
            Some(idx) => rest = &rest[idx + part.len()..],
            None => return false,
        }
    }
    true
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a case-insensitive matcher, discarding construction warnings.
    fn matcher(pats: &[&str]) -> AllowlistMatcher {
        let (m, _) = AllowlistMatcher::new(pats.iter().map(|s| (*s).to_string()).collect());
        m
    }

    /// Builds a case-sensitive matcher, discarding construction warnings.
    fn matcher_cs(pats: &[&str]) -> AllowlistMatcher {
        let (m, _) =
            AllowlistMatcher::new_case_sensitive(pats.iter().map(|s| (*s).to_string()).collect());
        m
    }

    #[test]
    fn exact_match() {
        let m = matcher(&["localhost", "127.0.0.1"]);
        assert!(m.is_allowed("localhost"));
        assert!(m.is_allowed("127.0.0.1"));
        // Case-insensitive by default.
        assert!(m.is_allowed("Localhost"));
        assert!(m.is_allowed("LOCALHOST"));
        // Exact patterns never match by prefix.
        assert!(!m.is_allowed("localhost2"));
    }

    #[test]
    fn exact_match_case_sensitive() {
        let m = matcher_cs(&["localhost", "127.0.0.1"]);
        assert!(m.is_allowed("localhost"));
        assert!(!m.is_allowed("Localhost"));
        assert!(!m.is_allowed("LOCALHOST"));
    }

    #[test]
    fn glob_suffix() {
        let m = matcher(&["*.internal"]);
        assert!(m.is_allowed("db.internal"));
        assert!(m.is_allowed("staging.db.internal"));
        assert!(!m.is_allowed("db.internal.evil"));
        assert!(!m.is_allowed("internal"));
    }

    #[test]
    fn glob_prefix() {
        let m = matcher(&["192.168.1.*"]);
        assert!(m.is_allowed("192.168.1.1"));
        assert!(m.is_allowed("192.168.1.255"));
        assert!(!m.is_allowed("192.168.2.1"));
        assert!(m.is_allowed("192.168.1."));
    }

    #[test]
    fn glob_middle() {
        let m = matcher(&["user-*@corp.com"]);
        assert!(m.is_allowed("user-alice@corp.com"));
        assert!(m.is_allowed("user-bob@corp.com"));
        assert!(!m.is_allowed("admin@corp.com"));
        assert!(!m.is_allowed("user-alice@other.com"));
    }

    #[test]
    fn glob_star_only() {
        let m = matcher(&["*"]);
        assert!(m.is_allowed("anything"));
        assert!(m.is_allowed(""));
    }

    #[test]
    fn glob_case_sensitive() {
        let m = matcher_cs(&["*.Internal"]);
        assert!(m.is_allowed("db.Internal"));
        assert!(!m.is_allowed("db.internal"));
    }

    #[test]
    fn glob_case_insensitive_folds_pattern_and_value() {
        let m = matcher(&["*.Internal"]);
        assert!(m.is_allowed("DB.INTERNAL"));
        assert!(m.is_allowed("db.internal"));
    }

    #[test]
    fn seen_counter() {
        let m = matcher(&["ok"]);
        assert_eq!(m.seen_count(), 0);
        m.is_allowed("ok");
        m.is_allowed("ok");
        m.is_allowed("not-ok");
        assert_eq!(m.seen_count(), 2);
    }

    #[test]
    fn regex_char_warning() {
        let (_, warnings) = AllowlistMatcher::new(vec!["^bad$".into()]);
        assert!(!warnings.is_empty());
    }

    #[test]
    fn metacharacter_warning_emitted_once_per_pattern() {
        let (_, warnings) = AllowlistMatcher::new(vec!["^(a)+$".into()]);
        assert_eq!(warnings.len(), 1);
        // The first hint character in scan order is the one reported.
        assert!(warnings[0].contains("'^'"), "got: {}", warnings[0]);
    }

    #[test]
    fn dollar_brace_placeholder_does_not_warn() {
        let (_, warnings) = AllowlistMatcher::new(vec!["${HOST}.internal".into()]);
        assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
    }

    #[test]
    fn bare_dollar_still_warns() {
        let (_, warnings) = AllowlistMatcher::new(vec!["price$".into()]);
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].contains("'$'"), "got: {}", warnings[0]);
    }

    #[test]
    fn regex_prefixed_pattern_does_not_trigger_metacharacter_warning() {
        let (_, warnings) = AllowlistMatcher::new(vec!["regex:^ok$".into()]);
        assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
    }

    #[test]
    fn empty_allowlist_is_empty() {
        let m = matcher(&[]);
        assert!(m.is_empty());
        assert!(!m.is_allowed("anything"));
    }

    #[test]
    fn pattern_count_deduplicates_case_folded_exact_patterns() {
        let m = matcher(&["localhost", "LOCALHOST"]);
        assert_eq!(m.pattern_count(), 1);
        let m_cs = matcher_cs(&["localhost", "LOCALHOST"]);
        assert_eq!(m_cs.pattern_count(), 2);
    }

    #[test]
    fn match_pattern_returns_exact_pattern() {
        let m = matcher(&["localhost"]);
        assert_eq!(m.match_pattern("localhost"), Some("localhost"));
        assert_eq!(m.match_pattern("other"), None);
    }

    #[test]
    fn match_pattern_returns_glob_pattern() {
        let m = matcher(&["*.internal"]);
        assert_eq!(m.match_pattern("db.internal"), Some("*.internal"));
        assert_eq!(m.match_pattern("github.com"), None);
    }

    #[test]
    fn match_pattern_returns_first_matching_pattern() {
        let m = matcher(&["*.internal", "db.*"]);
        assert_eq!(m.match_pattern("db.internal"), Some("*.internal"));
    }

    #[test]
    fn match_pattern_increments_seen_counter() {
        let m = matcher(&["ok"]);
        assert_eq!(m.seen_count(), 0);
        m.match_pattern("ok");
        assert_eq!(m.seen_count(), 1);
        m.match_pattern("not-ok");
        assert_eq!(m.seen_count(), 1);
    }

    #[test]
    fn is_allowed_delegates_to_match_pattern() {
        let m = matcher(&["*.internal"]);
        assert!(m.is_allowed("db.internal"));
        assert!(!m.is_allowed("github.com"));
        assert_eq!(m.seen_count(), 1);
    }

    #[test]
    fn glob_multiple_wildcards() {
        let m = matcher(&["a*b*c"]);
        assert!(m.is_allowed("abc"));
        assert!(m.is_allowed("aXbYc"));
        assert!(m.is_allowed("aXXXbYYYc"));
        assert!(!m.is_allowed("abX"));
        assert!(!m.is_allowed("Xbc"));
    }

    #[test]
    fn glob_adjacent_wildcards_treated_as_one() {
        let m = matcher(&["a**b"]);
        assert!(m.is_allowed("ab"));
        assert!(m.is_allowed("aXb"));
        assert!(!m.is_allowed("ba"));
    }

    #[test]
    fn glob_empty_value_only_matches_star() {
        let m = matcher(&["*"]);
        assert!(m.is_allowed(""));
        let m2 = matcher(&["a*"]);
        assert!(!m2.is_allowed(""));
    }

    #[test]
    fn glob_prefix_suffix_overlap_rejected() {
        let m = matcher(&["a*b"]);
        assert!(!m.is_allowed("a"));
        assert!(!m.is_allowed("b"));
        assert!(m.is_allowed("ab"));
        assert!(m.is_allowed("aXb"));
    }

    #[test]
    fn large_exact_list_all_match() {
        let words: Vec<String> = (0..500).map(|i| format!("word{i}")).collect();
        let (m, _) = AllowlistMatcher::new(words.clone());
        for w in &words {
            assert!(m.is_allowed(w), "should allow {w}");
        }
        assert!(!m.is_allowed("word500"));
        assert!(!m.is_allowed("notaword"));
    }

    #[test]
    fn exact_and_glob_coexist() {
        let m = matcher(&["localhost", "127.0.0.1", "*.internal"]);
        assert!(m.is_allowed("localhost"));
        assert!(m.is_allowed("127.0.0.1"));
        assert!(m.is_allowed("db.internal"));
        assert!(!m.is_allowed("github.com"));
    }

    #[test]
    fn regex_basic_match() {
        let m = matcher(&["regex:^192\\.168\\.[0-9]+\\.[0-9]+$"]);
        assert!(m.is_allowed("192.168.1.1"));
        assert!(m.is_allowed("192.168.100.255"));
        assert!(!m.is_allowed("192.168.1."));
        assert!(!m.is_allowed("10.0.0.1"));
    }

    #[test]
    fn regex_substring_match_without_anchors() {
        let m = matcher(&["regex:internal"]);
        assert!(m.is_allowed("db.internal.corp"));
        assert!(m.is_allowed("internal"));
        assert!(!m.is_allowed("external"));
    }

    #[test]
    fn regex_anchored_full_match() {
        let m = matcher(&["regex:^token-[A-Z]{3}-[0-9]{4}$"]);
        assert!(m.is_allowed("token-ABC-1234"));
        assert!(!m.is_allowed("token-AB-1234"));
        assert!(!m.is_allowed("xtoken-ABC-1234"));
    }

    #[test]
    fn regex_case_sensitive_by_default() {
        let m = matcher(&["regex:^localhost$"]);
        assert!(m.is_allowed("localhost"));
        assert!(!m.is_allowed("LOCALHOST"));
        assert!(!m.is_allowed("Localhost"));
    }

    #[test]
    fn regex_case_insensitive_via_flag() {
        let m = matcher(&["regex:(?i)^localhost$"]);
        assert!(m.is_allowed("localhost"));
        assert!(m.is_allowed("LOCALHOST"));
        assert!(m.is_allowed("LocalHost"));
    }

    #[test]
    fn regex_invalid_pattern_produces_warning_and_is_skipped() {
        let (m, warnings) = AllowlistMatcher::new(vec!["regex:[invalid".into()]);
        assert!(!warnings.is_empty(), "invalid regex must produce a warning");
        assert!(warnings[0].contains("failed to compile"));
        assert!(!m.is_allowed("anything"));
        assert_eq!(m.pattern_count(), 0);
    }

    #[test]
    fn regex_match_pattern_returns_full_prefixed_string() {
        let m = matcher(&["regex:^10\\.0\\."]);
        assert_eq!(m.match_pattern("10.0.1.5"), Some("regex:^10\\.0\\."));
        assert_eq!(m.match_pattern("192.168.1.1"), None);
    }

    #[test]
    fn regex_seen_counter_increments() {
        let m = matcher(&["regex:^test"]);
        assert_eq!(m.seen_count(), 0);
        m.is_allowed("test-value");
        m.is_allowed("test-value");
        m.is_allowed("other");
        assert_eq!(m.seen_count(), 2);
    }

    #[test]
    fn regex_coexists_with_exact_and_glob() {
        let m = matcher(&[
            "localhost",
            "*.internal",
            "regex:^10\\.[0-9]+\\.[0-9]+\\.[0-9]+$",
        ]);
        assert!(m.is_allowed("localhost"));
        assert!(m.is_allowed("db.internal"));
        assert!(m.is_allowed("10.0.0.1"));
        assert!(m.is_allowed("10.255.255.255"));
        assert!(!m.is_allowed("192.168.1.1"));
        assert!(!m.is_allowed("github.com"));
        assert_eq!(m.pattern_count(), 3);
    }

    #[test]
    fn regex_not_subject_to_case_insensitive_lowercasing() {
        // Even in a case-insensitive matcher, regexes see the raw value.
        let m = matcher(&["regex:^[A-Z]{3}$"]);
        assert!(m.is_allowed("ABC"));
        assert!(!m.is_allowed("abc"));
    }

    #[test]
    fn metacharacter_warning_updated_to_suggest_regex_prefix() {
        let (_, warnings) = AllowlistMatcher::new(vec!["^bad$".into()]);
        assert!(!warnings.is_empty());
        assert!(
            warnings[0].contains("regex:"),
            "warning should suggest regex: prefix, got: {}",
            warnings[0],
        );
    }
}