use aho_corasick::AhoCorasick;
use regex::bytes::Regex;
use crate::pattern::{PatternDef, PatternKind};
use crate::MatchError;
/// One hit produced by [`MatchEngine::scan`].
///
/// The matched bytes are `input[start..end]`: `start` is inclusive and `end` is exclusive,
/// both measured in bytes from the beginning of the scanned input.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MatchResult {
    /// The `id` of the pattern definition that produced this match.
    pub pattern_id: usize,
    /// Byte offset of the first matched byte.
    pub start: usize,
    /// Byte offset one past the last matched byte.
    pub end: usize,
}
/// Compiled multi-pattern matcher built by `MatchEngine::compile`.
///
/// Literals are matched with Aho-Corasick automatons; every other pattern is held
/// as an individually compiled byte regex.
pub struct MatchEngine {
    /// Automaton over the case-sensitive literals, or `None` when there are none.
    ac_exact: Option<AhoCorasick>,
    /// Maps each `ac_exact` automaton pattern index to the user-supplied pattern id.
    ac_exact_map: Vec<usize>,
    /// Slot for an ASCII case-insensitive literal automaton. `compile` currently
    /// always leaves this as `None`; case-insensitive literals go into `regexes`.
    ac_ci: Option<AhoCorasick>,
    /// Maps each `ac_ci` automaton pattern index to the user-supplied pattern id.
    ac_ci_map: Vec<usize>,
    /// `(pattern id, compiled regex)` pairs for regex patterns and case-insensitive literals.
    regexes: Vec<(usize, Regex)>,
}
impl MatchEngine {
pub fn compile(patterns: Vec<PatternDef>) -> Result<Self, MatchError> {
let mut exact_literals: Vec<(String, usize)> = Vec::new();
let mut regexes: Vec<(usize, Regex)> = Vec::new();
for pat in &patterns {
match &pat.kind {
PatternKind::Literal(s) => {
if pat.case_insensitive {
let pattern_str = format!("(?i){}", regex::escape(s));
let compiled = Regex::new(&pattern_str).map_err(|e| {
MatchError::InvalidRegex {
id: pat.id,
source: regex::Error::Syntax(e.to_string()),
}
})?;
regexes.push((pat.id, compiled));
} else {
exact_literals.push((s.clone(), pat.id));
}
}
PatternKind::Regex(r) => {
let pattern_str = if pat.case_insensitive {
format!("(?i){r}")
} else {
r.clone()
};
let compiled = Regex::new(&pattern_str).map_err(|e| {
MatchError::InvalidRegex {
id: pat.id,
source: regex::Error::Syntax(e.to_string()),
}
})?;
regexes.push((pat.id, compiled));
}
}
}
let (ac_exact, ac_exact_map) = Self::build_ac(exact_literals, false)?;
Ok(Self {
ac_exact,
ac_exact_map,
ac_ci: None,
ac_ci_map: Vec::new(),
regexes,
})
}
fn build_ac(literals: Vec<(String, usize)>, ci: bool) -> Result<(Option<AhoCorasick>, Vec<usize>), MatchError> {
if literals.is_empty() {
return Ok((None, Vec::new()));
}
let strs: Vec<&str> = literals.iter().map(|(s, _)| s.as_str()).collect();
let id_map: Vec<usize> = literals.iter().map(|(_, id)| *id).collect();
let mut builder = aho_corasick::AhoCorasick::builder();
builder.ascii_case_insensitive(ci);
let ac = builder.build(&strs).map_err(|e| MatchError::AhoCorasick(e.to_string()))?;
Ok((Some(ac), id_map))
}
pub fn scan(&self, input: &[u8]) -> Vec<MatchResult> {
let mut results = Vec::with_capacity(
self.ac_exact_map.len() + self.ac_ci_map.len() + self.regexes.len()
);
if let Some(ac) = &self.ac_exact {
for mat in ac.find_overlapping_iter(input) {
let ac_idx = mat.pattern().as_usize();
if let Some(&user_id) = self.ac_exact_map.get(ac_idx) {
results.push(MatchResult {
pattern_id: user_id,
start: mat.start(),
end: mat.end(),
});
}
}
}
if let Some(ac) = &self.ac_ci {
for mat in ac.find_overlapping_iter(input) {
let ac_idx = mat.pattern().as_usize();
if let Some(&user_id) = self.ac_ci_map.get(ac_idx) {
results.push(MatchResult {
pattern_id: user_id,
start: mat.start(),
end: mat.end(),
});
}
}
}
for (user_id, regex) in &self.regexes {
for mat in regex.find_iter(input) {
results.push(MatchResult {
pattern_id: *user_id,
start: mat.start(),
end: mat.end(),
});
}
}
results
}
pub fn is_match(&self, input: &[u8]) -> bool {
if let Some(ac) = &self.ac_exact {
if ac.is_match(input) {
return true;
}
}
if let Some(ac) = &self.ac_ci {
if ac.is_match(input) {
return true;
}
}
self.regexes.iter().any(|(_, r)| r.is_match(input))
}
}
#[cfg(test)]
mod tests {
    //! End-to-end tests that drive the matcher through the public `PatternSet` API.
    use crate::{PatternSet, Scanner};

    #[test]
    fn literal_match() {
        let ps = PatternSet::builder()
            .add_literal("password", 0)
            .add_literal("secret", 1)
            .build()
            .unwrap();
        let matches = ps.scan(b"my password is secret");
        assert!(matches.iter().any(|m| m.pattern_id == 0));
        assert!(matches.iter().any(|m| m.pattern_id == 1));
    }

    #[test]
    fn literal_no_match() {
        let ps = PatternSet::builder()
            .add_literal("foobar", 0)
            .build()
            .unwrap();
        let matches = ps.scan(b"nothing here");
        assert!(matches.is_empty());
    }

    #[test]
    fn regex_match() {
        let ps = PatternSet::builder()
            .add_regex(r"[0-9]{3}-[0-9]{4}", 0)
            .build()
            .unwrap();
        let matches = ps.scan(b"call 555-1234 now");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pattern_id, 0);
        assert_eq!(matches[0].start, 5);
        assert_eq!(matches[0].end, 13);
    }

    #[test]
    fn mixed_literal_and_regex() {
        let ps = PatternSet::builder()
            .add_literal("token", 0)
            .add_regex(r"[A-Za-z0-9]{20,}", 1)
            .build()
            .unwrap();
        let matches = ps.scan(b"token=abcdefghij1234567890XYZ");
        assert!(matches.iter().any(|m| m.pattern_id == 0));
        assert!(matches.iter().any(|m| m.pattern_id == 1));
    }

    #[test]
    fn case_insensitive_literal() {
        let ps = PatternSet::builder()
            .add_literal_ci("SECRET", 0)
            .build()
            .unwrap();
        let matches = ps.scan(b"my secret key");
        assert_eq!(matches.len(), 1);
        assert_eq!((matches[0].start, matches[0].end), (3, 9));
    }

    #[test]
    fn case_insensitive_literal_unicode() {
        // Case-insensitive literals go through regex `(?i)`, which folds non-ASCII
        // letters too.
        let ps = PatternSet::builder()
            .add_literal_ci("ПАРОЛЬ", 0)
            .build()
            .unwrap();
        let matches = ps.scan("мой пароль тут".as_bytes());
        assert_eq!(matches.len(), 1);
    }

    #[test]
    fn case_insensitive_regex() {
        let ps = PatternSet::builder()
            .add_regex_ci("password", 0)
            .build()
            .unwrap();
        let matches = ps.scan(b"my PASSWORD is here");
        assert_eq!(matches.len(), 1);
    }

    #[test]
    fn is_match_short_circuits() {
        let ps = PatternSet::builder()
            .add_literal("needle", 0)
            .build()
            .unwrap();
        assert!(ps.is_match(b"find the needle"));
        assert!(!ps.is_match(b"nothing here"));
    }

    #[test]
    fn overlapping_literals() {
        let ps = PatternSet::builder()
            .add_literal("ab", 0)
            .add_literal("bc", 1)
            .build()
            .unwrap();
        let matches = ps.scan(b"abc");
        assert!(matches.iter().any(|m| m.pattern_id == 0));
        assert!(matches.iter().any(|m| m.pattern_id == 1));
    }

    #[test]
    fn multiple_regex_matches() {
        let ps = PatternSet::builder()
            .add_regex(r"\d+", 0)
            .build()
            .unwrap();
        let matches = ps.scan(b"a1b22c333");
        assert_eq!(matches.len(), 3);
    }

    #[test]
    fn empty_input() {
        let ps = PatternSet::builder()
            .add_literal("x", 0)
            .build()
            .unwrap();
        assert!(ps.scan(b"").is_empty());
        assert!(!ps.is_match(b""));
    }

    #[test]
    fn scan_str_convenience() {
        let ps = PatternSet::builder()
            .add_literal("hello", 0)
            .build()
            .unwrap();
        let matches = ps.scan_str("say hello world");
        assert_eq!(matches.len(), 1);
    }

    #[test]
    fn invalid_regex_errors() {
        let result = PatternSet::builder()
            .add_regex("[invalid(", 0)
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn large_pattern_set() {
        let mut builder = PatternSet::builder();
        for i in 0..100 {
            builder = builder.add_literal(&format!("pattern_{i}"), i);
        }
        let ps = builder.build().unwrap();
        assert_eq!(ps.pattern_count(), 100);
        let matches = ps.scan(b"contains pattern_42 and pattern_99");
        assert!(matches.iter().any(|m| m.pattern_id == 42));
        assert!(matches.iter().any(|m| m.pattern_id == 99));
    }

    #[test]
    fn match_result_offsets_correct() {
        let ps = PatternSet::builder()
            .add_literal("xyz", 0)
            .build()
            .unwrap();
        let matches = ps.scan(b"--xyz--");
        // Check the count first so a miss fails with a clear message, not an index panic.
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].start, 2);
        assert_eq!(matches[0].end, 5);
    }

    #[test]
    fn binary_input() {
        let ps = PatternSet::builder()
            .add_literal("ELF", 0)
            .build()
            .unwrap();
        let input = b"\x7fELF\x02\x01\x01\x00";
        let matches = ps.scan(input);
        assert_eq!(matches.len(), 1);
        assert_eq!((matches[0].start, matches[0].end), (1, 4));
    }

    #[test]
    fn unicode_input() {
        let ps = PatternSet::builder()
            .add_literal("пароль", 0)
            .build()
            .unwrap();
        let matches = ps.scan("мой пароль тут".as_bytes());
        assert_eq!(matches.len(), 1);
    }
}