use aho_corasick::AhoCorasick;
use regex::{Regex, RegexSet};
use crate::matcher::{CompiledMatcher, GroupMode};
/// Minimum number of `Contains` needles sharing one case mode before
/// [`optimize_any_of`] folds them into a single Aho-Corasick automaton.
pub const AHO_CORASICK_THRESHOLD: usize = 8;

/// Minimum number of regexes merged into a single `RegexSet`. [`optimize_any_of`] also
/// skips partitioning entirely (only wrapping) for disjunctions smaller than this.
pub const REGEX_SET_THRESHOLD: usize = 3;

/// Minimum number of children before a disjunction is considered for wrapping in a
/// `CaseInsensitiveGroup`; every child must additionally be pre-lowerable.
pub(crate) const CI_GROUP_THRESHOLD: usize = 2;
/// Rewrites a disjunction of compiled matchers into an equivalent, cheaper form.
///
/// * An empty input becomes an empty `AnyOf`; a single matcher is returned unchanged.
/// * Inputs shorter than [`REGEX_SET_THRESHOLD`] are not partitioned, only wrapped
///   (`CaseInsensitiveGroup` when eligible, else `AnyOf`).
/// * Larger inputs are partitioned into four buckets:
///   - case-insensitive `Contains` needles and case-sensitive `Contains` needles, each
///     handed to `consume_contains` (which may build an Aho-Corasick set);
///   - bare `Regex` children, handed to `consume_regexes` (which may build a `RegexSet`);
///   - every other matcher, appended unchanged after the merged entries.
///
/// If the rewrite leaves a single matcher it is returned directly; otherwise the
/// survivors are wrapped in a `CaseInsensitiveGroup` or `AnyOf`.
pub fn optimize_any_of(matchers: Vec<CompiledMatcher>) -> CompiledMatcher {
    if matchers.is_empty() {
        return CompiledMatcher::AnyOf(Vec::new());
    }
    if matchers.len() == 1 {
        return matchers
            .into_iter()
            .next()
            .expect("len == 1 was just checked");
    }
    if matchers.len() < REGEX_SET_THRESHOLD {
        return wrap_ci_group_or_anyof(matchers);
    }

    // Bucket children by the kind of merge they are eligible for.
    let mut ci_needles: Vec<String> = Vec::new();
    let mut cs_needles: Vec<String> = Vec::new();
    let mut patterns: Vec<Regex> = Vec::new();
    let mut untouched: Vec<CompiledMatcher> = Vec::new();
    for matcher in matchers {
        match matcher {
            CompiledMatcher::Contains {
                value,
                case_insensitive,
            } => {
                let bucket = if case_insensitive {
                    &mut ci_needles
                } else {
                    &mut cs_needles
                };
                bucket.push(value);
            }
            CompiledMatcher::Regex(re) => patterns.push(re),
            other => untouched.push(other),
        }
    }

    // Capacity hint: passthrough children plus one entry per bucket in the common case
    // where each bucket collapses into a single merged matcher.
    let mut merged: Vec<CompiledMatcher> = Vec::with_capacity(untouched.len() + 3);
    consume_contains(&mut merged, ci_needles, true);
    consume_contains(&mut merged, cs_needles, false);
    consume_regexes(&mut merged, patterns);
    merged.extend(untouched);

    if merged.is_empty() {
        CompiledMatcher::AnyOf(Vec::new())
    } else if merged.len() == 1 {
        merged
            .into_iter()
            .next()
            .expect("len == 1 was just checked")
    } else {
        wrap_ci_group_or_anyof(merged)
    }
}
/// Wraps `children` in a `CaseInsensitiveGroup` with mode `Any` when there are at
/// least [`CI_GROUP_THRESHOLD`] of them and every one is pre-lowerable. Otherwise it
/// returns a plain `AnyOf` over the same children, in the same order.
fn wrap_ci_group_or_anyof(children: Vec<CompiledMatcher>) -> CompiledMatcher {
    let eligible =
        children.len() >= CI_GROUP_THRESHOLD && children.iter().all(is_pre_lowerable);
    if !eligible {
        return CompiledMatcher::AnyOf(children);
    }
    CompiledMatcher::CaseInsensitiveGroup {
        children,
        mode: GroupMode::Any,
    }
}
/// Reports whether `m` may be evaluated against a haystack that has already been
/// lowercased. NOTE(review): presumably this is what `CaseInsensitiveGroup` does before
/// running its children; confirm against the group's matching code.
///
/// * Literal matchers (`Contains`, `StartsWith`, `EndsWith`, `Exact`, `AhoCorasickSet`)
///   qualify exactly when their `case_insensitive` flag is set.
/// * `Regex` qualifies when its pattern passes `regex_is_case_insensitive`;
///   `RegexSetMatch` qualifies when every pattern in the set does.
/// * `Not`, `AnyOf`, `AllOf` and `CaseInsensitiveGroup` qualify when all of their
///   children do.
/// * Every other matcher (e.g. `NumericEq`, `Cidr`) is rejected.
pub(crate) fn is_pre_lowerable(m: &CompiledMatcher) -> bool {
    match m {
        CompiledMatcher::Contains {
            case_insensitive, ..
        }
        | CompiledMatcher::StartsWith {
            case_insensitive, ..
        }
        | CompiledMatcher::EndsWith {
            case_insensitive, ..
        }
        | CompiledMatcher::Exact {
            case_insensitive, ..
        }
        | CompiledMatcher::AhoCorasickSet {
            case_insensitive, ..
        } => *case_insensitive,
        CompiledMatcher::Regex(re) => regex_is_case_insensitive(re.as_str()),
        CompiledMatcher::RegexSetMatch { set, .. } => set
            .patterns()
            .iter()
            .all(|pattern| regex_is_case_insensitive(pattern)),
        CompiledMatcher::Not(inner) => is_pre_lowerable(inner),
        CompiledMatcher::AnyOf(children)
        | CompiledMatcher::AllOf(children)
        | CompiledMatcher::CaseInsensitiveGroup { children, .. } => {
            children.iter().all(is_pre_lowerable)
        }
        _ => false,
    }
}
/// Reports whether `pattern` is case-insensitive over its whole length.
///
/// This is a deliberately conservative, purely syntactic check. It returns `true` only
/// when both of these hold:
///
/// 1. The pattern opens with a *global* inline flag group `(?flags)`, terminated by `)`
///    rather than the scoped form `(?flags:...)`, whose enabled flags include `i`.
/// 2. No later inline flag group clears `i` again (`(?-i)`, `(?m-i)`, `(?-i:...)`, ...).
///
/// Scoped groups such as `(?i:foo)bar` are rejected: the flag stops at the group's
/// closing parenthesis, so `bar` stays case-sensitive, and treating the whole pattern
/// as case-insensitive would let a pre-lowered haystack change the match result.
///
/// A `false` result only disables the pre-lowering optimization, so the check errs
/// toward `false`. For example, an escaped `\(?-i`, or one inside a character class, is
/// also treated as clearing the flag.
fn regex_is_case_insensitive(pattern: &str) -> bool {
    // Returns `true` if any inline flag group in `rest` clears the `i` flag.
    fn clears_case_insensitive_flag(rest: &str) -> bool {
        rest.match_indices("(?").any(|(pos, _)| {
            let mut negated = false;
            for byte in rest[pos + 2..].bytes() {
                match byte {
                    b'-' => negated = true,
                    b'i' if negated => return true,
                    b'a'..=b'z' | b'A'..=b'Z' => {}
                    // End of the flag run: `)`, `:`, or a non-flag group such as `(?P<`.
                    _ => break,
                }
            }
            false
        })
    }

    let Some(flags_and_rest) = pattern.strip_prefix("(?") else {
        return false;
    };
    let mut enables_i = false;
    let mut negated = false;
    let mut close = None;
    for (idx, byte) in flags_and_rest.bytes().enumerate() {
        match byte {
            b')' => {
                close = Some(idx);
                break;
            }
            b'-' => negated = true,
            // An `i` after `-` turns case-insensitivity off rather than on.
            b'i' if negated => return false,
            b'i' => enables_i = true,
            b'a'..=b'z' | b'A'..=b'Z' => {}
            // `:` opens a scoped group that does not cover the rest of the pattern;
            // anything else (`<`, `=`, `!`, `P<`, ...) is not a flag group at all.
            _ => return false,
        }
    }
    match close {
        // `)` is ASCII, so `idx + 1` is always a char boundary.
        Some(idx) => enables_i && !clears_case_insensitive_flag(&flags_and_rest[idx + 1..]),
        None => false,
    }
}
/// Appends `regexes` to `result`.
///
/// With at least [`REGEX_SET_THRESHOLD`] regexes, their pattern strings are compiled
/// into one `RegexSetMatch` with mode `Any`. Only the text returned by `Regex::as_str`
/// is reused, so any flags must be inlined in the pattern itself. Below the threshold,
/// or if the set fails to compile, each regex is pushed back as its own `Regex` matcher.
/// An empty `regexes` appends nothing.
fn consume_regexes(result: &mut Vec<CompiledMatcher>, regexes: Vec<Regex>) {
    let set = if regexes.len() >= REGEX_SET_THRESHOLD {
        RegexSet::new(regexes.iter().map(Regex::as_str)).ok()
    } else {
        None
    };
    match set {
        Some(set) => result.push(CompiledMatcher::RegexSetMatch {
            set,
            mode: GroupMode::Any,
        }),
        None => result.extend(regexes.into_iter().map(CompiledMatcher::Regex)),
    }
}
/// Appends `needles`, which all share the case mode `ci`, to `result`.
///
/// With at least [`AHO_CORASICK_THRESHOLD`] needles, they are compiled into one
/// `AhoCorasickSet` that keeps both the automaton and the original needle list. Below
/// the threshold, or if the automaton fails to build, each needle is pushed back as its
/// own `Contains` matcher. An empty `needles` appends nothing.
fn consume_contains(result: &mut Vec<CompiledMatcher>, needles: Vec<String>, ci: bool) {
    let automaton = if needles.len() >= AHO_CORASICK_THRESHOLD {
        AhoCorasick::new(&needles).ok()
    } else {
        None
    };
    match automaton {
        Some(automaton) => result.push(CompiledMatcher::AhoCorasickSet {
            automaton,
            case_insensitive: ci,
            needles,
        }),
        None => result.extend(needles.into_iter().map(|value| CompiledMatcher::Contains {
            value,
            case_insensitive: ci,
        })),
    }
}
// Unit tests for the any-of optimizer. They cover:
// - structural checks on what `optimize_any_of` produces;
// - classification checks for `is_pre_lowerable` and `regex_is_case_insensitive`;
// - equivalence checks that an optimized matcher gives the same answer as a plain
//   `AnyOf` built from the same children.
#[cfg(test)]
mod tests {
use super::*;
use crate::event::{EventValue, JsonEvent};
use serde_json::json;
// Builds a case-insensitive `Contains` matcher whose needle is stored lowercased.
fn ci_contains(s: &str) -> CompiledMatcher {
CompiledMatcher::Contains {
value: s.to_lowercase(),
case_insensitive: true,
}
}
// Builds a case-sensitive `Contains` matcher whose needle is stored verbatim.
fn cs_contains(s: &str) -> CompiledMatcher {
CompiledMatcher::Contains {
value: s.to_string(),
case_insensitive: false,
}
}
// No children: the optimizer returns an empty `AnyOf`.
#[test]
fn empty_input_returns_empty_anyof() {
let m = optimize_any_of(Vec::new());
assert!(matches!(m, CompiledMatcher::AnyOf(ref v) if v.is_empty()));
}
// A single child is returned as-is, with no wrapper.
#[test]
fn singleton_unwraps() {
let m = optimize_any_of(vec![ci_contains("foo")]);
assert!(matches!(
m,
CompiledMatcher::Contains {
case_insensitive: true,
..
}
));
}
// One needle short of the Aho-Corasick threshold: the needles stay individual `Contains`
// matchers, and since all are case-insensitive they are wrapped in a CI group.
#[test]
fn below_ac_threshold_wraps_ci_group_when_all_pre_lowerable() {
let needles: Vec<_> = (0..AHO_CORASICK_THRESHOLD - 1)
.map(|i| ci_contains(&format!("p{i}")))
.collect();
let m = optimize_any_of(needles);
match m {
CompiledMatcher::CaseInsensitiveGroup {
children,
mode: GroupMode::Any,
} => assert_eq!(children.len(), AHO_CORASICK_THRESHOLD - 1),
other => panic!("expected CaseInsensitiveGroup, got {other:?}"),
}
}
// A case-sensitive child blocks the CI group, so a plain `AnyOf` is produced.
#[test]
fn below_ac_threshold_keeps_anyof_when_mixed_case() {
let needles: Vec<CompiledMatcher> = vec![ci_contains("foo"), cs_contains("BAR")];
let m = optimize_any_of(needles);
assert!(matches!(m, CompiledMatcher::AnyOf(ref v) if v.len() == 2));
}
// Exactly AHO_CORASICK_THRESHOLD case-insensitive needles collapse into a single
// case-insensitive Aho-Corasick set, which is then unwrapped as the sole survivor.
#[test]
fn at_threshold_builds_aho_corasick() {
let needles: Vec<_> = (0..AHO_CORASICK_THRESHOLD)
.map(|i| ci_contains(&format!("p{i}")))
.collect();
let m = optimize_any_of(needles);
assert!(matches!(
m,
CompiledMatcher::AhoCorasickSet {
case_insensitive: true,
..
}
));
}
// CI and CS needles go to separate buckets, giving one automaton per case mode.
// The CS automaton is not pre-lowerable, so the pair stays in a plain `AnyOf`.
#[test]
fn separate_buckets_for_ci_and_cs() {
let mut needles = Vec::new();
for i in 0..AHO_CORASICK_THRESHOLD {
needles.push(ci_contains(&format!("ci{i}")));
}
for i in 0..AHO_CORASICK_THRESHOLD {
needles.push(cs_contains(&format!("CS{i}")));
}
let m = optimize_any_of(needles);
let children = match m {
CompiledMatcher::AnyOf(v) => v,
_ => panic!("expected AnyOf wrapping two AC sets"),
};
assert_eq!(children.len(), 2);
assert!(children.iter().any(|c| matches!(
c,
CompiledMatcher::AhoCorasickSet {
case_insensitive: true,
..
}
)));
assert!(children.iter().any(|c| matches!(
c,
CompiledMatcher::AhoCorasickSet {
case_insensitive: false,
..
}
)));
}
// The merged automaton comes first, then the passthrough children in input order.
// All three are case-insensitive, so the result is a CI group.
#[test]
fn mixed_pre_lowerable_children_are_wrapped_in_ci_group() {
let mut needles: Vec<_> = (0..AHO_CORASICK_THRESHOLD)
.map(|i| ci_contains(&format!("p{i}")))
.collect();
needles.push(CompiledMatcher::StartsWith {
value: "cmd".into(),
case_insensitive: true,
});
needles.push(CompiledMatcher::EndsWith {
value: ".exe".into(),
case_insensitive: true,
});
let m = optimize_any_of(needles);
let children = match m {
CompiledMatcher::CaseInsensitiveGroup {
children,
mode: GroupMode::Any,
} => children,
other => panic!("expected CaseInsensitiveGroup, got {other:?}"),
};
assert_eq!(children.len(), 3);
assert!(matches!(
children[0],
CompiledMatcher::AhoCorasickSet { .. }
));
assert!(matches!(children[1], CompiledMatcher::StartsWith { .. }));
assert!(matches!(children[2], CompiledMatcher::EndsWith { .. }));
}
// The CI needles still collapse into one automaton, but the lone case-sensitive needle
// prevents the CI group, leaving a 2-element `AnyOf`.
#[test]
fn ci_group_skipped_when_a_child_is_case_sensitive() {
let mut needles: Vec<_> = (0..AHO_CORASICK_THRESHOLD)
.map(|i| ci_contains(&format!("p{i}")))
.collect();
needles.push(cs_contains("EXACT"));
let m = optimize_any_of(needles);
assert!(matches!(m, CompiledMatcher::AnyOf(ref v) if v.len() == 2));
}
// Positive and negative cases for each matcher kind `is_pre_lowerable` handles.
#[test]
fn is_pre_lowerable_classifies_correctly() {
use regex::Regex;
assert!(is_pre_lowerable(&ci_contains("foo")));
assert!(is_pre_lowerable(&CompiledMatcher::StartsWith {
value: "x".into(),
case_insensitive: true,
}));
assert!(is_pre_lowerable(&CompiledMatcher::EndsWith {
value: "x".into(),
case_insensitive: true,
}));
assert!(is_pre_lowerable(&CompiledMatcher::Exact {
value: "x".into(),
case_insensitive: true,
}));
assert!(is_pre_lowerable(&CompiledMatcher::Regex(
Regex::new(r"(?i)foo.*bar").unwrap()
)));
assert!(is_pre_lowerable(&CompiledMatcher::Regex(
Regex::new(r"(?ims)foo").unwrap()
)));
assert!(!is_pre_lowerable(&cs_contains("foo")));
assert!(!is_pre_lowerable(&CompiledMatcher::Exact {
value: "X".into(),
case_insensitive: false,
}));
assert!(!is_pre_lowerable(&CompiledMatcher::Regex(
Regex::new(r"^foo").unwrap()
)));
assert!(!is_pre_lowerable(&CompiledMatcher::NumericEq(42.0)));
assert!(!is_pre_lowerable(&CompiledMatcher::Cidr(
"10.0.0.0/8".parse().unwrap()
)));
}
// The leading-flag-group recognizer: `i` must be enabled (not negated) in an opening
// `(?...)` group; non-flag groups such as `(?:` are rejected.
#[test]
fn regex_is_case_insensitive_recognizer() {
assert!(regex_is_case_insensitive("(?i)foo"));
assert!(regex_is_case_insensitive("(?im)foo"));
assert!(regex_is_case_insensitive("(?si)foo"));
assert!(!regex_is_case_insensitive("foo"));
assert!(!regex_is_case_insensitive("(?m)foo"));
assert!(!regex_is_case_insensitive(""));
assert!(!regex_is_case_insensitive("(?"));
assert!(!regex_is_case_insensitive("(?-i)foo"));
assert!(!regex_is_case_insensitive("(?:foo)"));
}
// Equivalence check: the CI group must accept exactly the haystacks a plain `AnyOf` over
// the same children accepts, including mixed-case and empty inputs.
#[test]
fn ci_group_matches_same_haystacks_as_anyof() {
let event_json = json!({});
let event = JsonEvent::borrow(&event_json);
let make_children = || -> Vec<CompiledMatcher> {
vec![
ci_contains("powershell"),
CompiledMatcher::StartsWith {
value: "cmd".to_lowercase(),
case_insensitive: true,
},
CompiledMatcher::EndsWith {
value: ".exe".to_lowercase(),
case_insensitive: true,
},
CompiledMatcher::Exact {
value: "whoami".to_lowercase(),
case_insensitive: true,
},
]
};
let optimized = optimize_any_of(make_children());
let unoptimized = CompiledMatcher::AnyOf(make_children());
assert!(matches!(
optimized,
CompiledMatcher::CaseInsensitiveGroup {
mode: GroupMode::Any,
..
}
));
for s in [
"PowerShell.exe -enc XYZ",
"CMD.exe /c whoami",
"C:/Windows/System32/notepad.exe",
"WHOAMI",
"no match",
"",
] {
let v = EventValue::Str(s.into());
assert_eq!(
optimized.matches(&v, &event),
unoptimized.matches(&v, &event),
"CI group disagrees with AnyOf on {s:?}"
);
}
}
// Equivalence check: the Aho-Corasick set built from lowercased needles must agree with
// individual case-insensitive `Contains` matchers on mixed-case haystacks.
#[test]
fn ac_matches_same_haystack_as_anyof() {
let needles_str = [
"whoami",
"mimikatz",
"powershell",
"invoke",
"iex",
"rundll32",
"regsvr32",
"certutil",
];
assert!(needles_str.len() >= AHO_CORASICK_THRESHOLD);
let optimized = optimize_any_of(needles_str.iter().map(|s| ci_contains(s)).collect());
let unoptimized =
CompiledMatcher::AnyOf(needles_str.iter().map(|s| ci_contains(s)).collect());
let event_json = json!({});
let event = JsonEvent::borrow(&event_json);
let test_strings = [
"cmd.exe /c whoami",
"Invoke-Mimikatz with PowerShell",
"no patterns here",
"RUNDLL32.EXE foo.dll",
"WHOAMI in caps",
"",
];
for s in test_strings {
let v = EventValue::Str(s.into());
assert_eq!(
optimized.matches(&v, &event),
unoptimized.matches(&v, &event),
"mismatch on haystack {s:?}"
);
}
}
// Builds a `Regex` matcher with the case-insensitive flag inlined as a `(?i)` prefix.
fn ci_regex(pattern: &str) -> CompiledMatcher {
CompiledMatcher::Regex(regex::Regex::new(&format!("(?i){pattern}")).unwrap())
}
// Exactly REGEX_SET_THRESHOLD regexes are merged into a single `RegexSetMatch`.
#[test]
fn at_threshold_builds_regex_set() {
let regexes: Vec<_> = (0..REGEX_SET_THRESHOLD)
.map(|i| ci_regex(&format!("foo{i}")))
.collect();
let m = optimize_any_of(regexes);
assert!(matches!(
m,
CompiledMatcher::RegexSetMatch {
mode: GroupMode::Any,
..
}
));
}
// Below the threshold the regexes stay individual. Because each one carries `(?i)`,
// they are still wrapped in a CI group.
#[test]
fn below_regex_set_threshold_keeps_individual_regexes() {
let regexes: Vec<_> = (0..REGEX_SET_THRESHOLD - 1)
.map(|i| ci_regex(&format!("foo{i}")))
.collect();
let m = optimize_any_of(regexes);
match m {
CompiledMatcher::CaseInsensitiveGroup { children, .. } => {
assert_eq!(children.len(), REGEX_SET_THRESHOLD - 1);
assert!(
children
.iter()
.all(|c| matches!(c, CompiledMatcher::Regex(_)))
);
}
other => panic!("expected CaseInsensitiveGroup with individual regexes, got {other:?}"),
}
}
// Equivalence check: the merged `RegexSet` must agree with the individual regexes,
// including anchored patterns (`^`, `$`) and mixed-case haystacks.
#[test]
fn regex_set_matches_same_haystacks_as_individual_regexes() {
let patterns = [r"^cmd\.exe", r"powershell", r"\.ps1$", r"mimikatz"];
let make_children =
|| -> Vec<CompiledMatcher> { patterns.iter().map(|p| ci_regex(p)).collect() };
let optimized = optimize_any_of(make_children());
let unoptimized = CompiledMatcher::AnyOf(make_children());
let event_json = json!({});
let event = JsonEvent::borrow(&event_json);
for s in [
"cmd.exe /c whoami",
"POWERSHELL.EXE -enc",
"C:/scripts/run.PS1",
"Invoke-MIMIKATZ",
"notepad.exe",
"",
] {
let v = EventValue::Str(s.into());
assert_eq!(
optimized.matches(&v, &event),
unoptimized.matches(&v, &event),
"RegexSet disagrees with AnyOf(Regex) on {s:?}"
);
}
}
// The input is split three ways: one automaton for the CI needles, one regex set, and
// the case-sensitive `Exact` passthrough. That `Exact` forces a plain `AnyOf`.
#[test]
fn mixed_contains_and_regex_partitions_correctly() {
let mut input = Vec::new();
for i in 0..AHO_CORASICK_THRESHOLD {
input.push(ci_contains(&format!("c{i}")));
}
for i in 0..REGEX_SET_THRESHOLD {
input.push(ci_regex(&format!("r{i}")));
}
input.push(CompiledMatcher::Exact {
value: "EXACT".into(),
case_insensitive: false,
});
let m = optimize_any_of(input);
let children = match m {
CompiledMatcher::AnyOf(v) => v,
other => panic!("expected AnyOf, got {other:?}"),
};
assert_eq!(children.len(), 3);
assert!(
children
.iter()
.any(|c| matches!(c, CompiledMatcher::AhoCorasickSet { .. }))
);
assert!(
children
.iter()
.any(|c| matches!(c, CompiledMatcher::RegexSetMatch { .. }))
);
assert!(children.iter().any(|c| matches!(
c,
CompiledMatcher::Exact {
case_insensitive: false,
..
}
)));
}
// `consume_regexes` rebuilds a `RegexSet` from `Regex::as_str`, so regex flags survive
// only if they are inlined in the pattern text. This pins that `build_regex` inlines
// them. NOTE(review): the `true` argument presumably requests case-insensitivity
// (per the assertion message); confirm against `helpers::build_regex`.
#[test]
fn build_regex_keeps_flags_in_pattern_string() {
use super::super::helpers::build_regex;
let re = build_regex("foo", true, false, false).unwrap();
let s = re.as_str();
assert!(
s.starts_with("(?") && s.contains('i'),
"build_regex must inline case-insensitive flag into pattern string, got {s:?}"
);
let set = regex::RegexSet::new([s]).unwrap();
assert!(set.is_match("FOO"), "RegexSet lost the (?i) flag");
}
// A regex set is pre-lowerable only when every one of its patterns is case-insensitive.
#[test]
fn regex_set_pre_lowerable_when_all_patterns_are_ci() {
let set = regex::RegexSet::new(["(?i)foo", "(?i)bar"]).unwrap();
let m = CompiledMatcher::RegexSetMatch {
set,
mode: GroupMode::Any,
};
assert!(is_pre_lowerable(&m));
}
// A single case-sensitive pattern makes the whole set non-pre-lowerable.
#[test]
fn regex_set_not_pre_lowerable_when_any_pattern_is_cs() {
let set = regex::RegexSet::new(["(?i)foo", "bar"]).unwrap();
let m = CompiledMatcher::RegexSetMatch {
set,
mode: GroupMode::Any,
};
assert!(!is_pre_lowerable(&m));
}
}