use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use aho_corasick::{AhoCorasick, AhoCorasickBuilder, AhoCorasickKind};
use super::{
CallbackEvents, ScanCallbackResult, ScanData, ScanError, ScanEvent, StringIdentifier,
StringMatch,
};
use crate::atoms::pick_atom_in_literal;
use crate::compiler::variable::Variable;
use crate::compiler::CompilerProfile;
use crate::matcher::{AcMatchStatus, Matcher};
use crate::memory::Region;
/// Scanner that prefilters string variables with a single Aho-Corasick automaton.
///
/// Every literal of every variable contributes an atom to the automaton. Hits on an atom
/// are then confirmed against the literals that produced it.
#[derive(Debug)]
pub(crate) struct AcScan {
    /// Case-insensitive automaton built over the lowercased, deduplicated atoms.
    aho: AhoCorasick,
    /// For each pattern index of `aho`, every literal whose atom is that pattern.
    aho_index_to_literal_info: Vec<Vec<LiteralInfo>>,
    /// Indexes of the variables with no literals. `aho` cannot prefilter these, so they
    /// are scanned individually instead.
    non_handled_var_indexes: Vec<usize>,
}
/// Links one Aho-Corasick pattern back to one literal of one variable.
#[derive(Debug)]
struct LiteralInfo {
    /// Index of the variable owning the literal.
    variable_index: usize,
    /// Index of the literal in the variable matcher's `literals`.
    literal_index: usize,
    /// Number of bytes of the literal `(before, after)` the atom registered in the automaton.
    /// A hit on the atom is widened by these amounts to cover the whole literal.
    slice_offset: (usize, usize),
}
impl AcScan {
pub(crate) fn new(variables: &[Variable], profile: CompilerProfile) -> Self {
let mut lits = Vec::new();
let mut known_lits = HashMap::new();
let mut aho_index_to_literal_info = Vec::new();
let mut non_handled_var_indexes = Vec::new();
for (variable_index, var) in variables.iter().enumerate() {
if var.matcher.literals.is_empty() {
non_handled_var_indexes.push(variable_index);
} else {
let mut known_literals_of_var = HashSet::new();
for (literal_index, lit) in var.matcher.literals.iter().enumerate() {
let (start, end) = pick_atom_in_literal(lit);
let mut atom = lit[start..(lit.len() - end)].to_vec();
let literal_info = LiteralInfo {
variable_index,
literal_index,
slice_offset: (start, end),
};
if !known_literals_of_var.insert((atom.clone(), start)) {
continue;
}
atom.make_ascii_lowercase();
match known_lits.entry(atom.clone()) {
Entry::Vacant(v) => {
let _r = v.insert(lits.len());
aho_index_to_literal_info.push(vec![literal_info]);
lits.push(atom);
}
Entry::Occupied(o) => {
let index = o.get();
aho_index_to_literal_info[*index].push(literal_info);
}
}
}
}
}
let mut builder = AhoCorasickBuilder::new();
let builder = builder.ascii_case_insensitive(true);
let builder = builder.kind(Some(match profile {
CompilerProfile::Speed => AhoCorasickKind::DFA,
CompilerProfile::Memory => AhoCorasickKind::ContiguousNFA,
}));
let aho = builder.build(&lits).unwrap();
Self {
aho,
aho_index_to_literal_info,
non_handled_var_indexes,
}
}
pub(super) fn scan_region<'scanner>(
&self,
region: &Region,
scanner: &'scanner super::Inner,
scan_data: &mut ScanData<'scanner, '_>,
matches: &mut [Vec<StringMatch>],
) -> Result<(), ScanError> {
#[cfg(feature = "profiling")]
if let Some(stats) = scan_data.statistics.as_mut() {
stats.nb_memory_chunks += 1;
stats.memory_scanned_size += region.mem.len();
}
for mat in self.aho.find_overlapping_iter(region.mem) {
if scan_data.check_timeout() {
return Err(ScanError::Timeout);
}
self.handle_possible_match(region, scanner, &mat, scan_data, matches)?;
}
if !self.non_handled_var_indexes.is_empty() {
#[cfg(feature = "profiling")]
let start = std::time::Instant::now();
for variable_index in &self.non_handled_var_indexes {
let var = &scanner.variables[*variable_index].matcher;
scan_single_variable(region, var, scan_data, &mut matches[*variable_index]);
}
#[cfg(feature = "profiling")]
if let Some(stats) = scan_data.statistics.as_mut() {
stats.raw_regexes_eval_duration += start.elapsed();
}
}
Ok(())
}
fn handle_possible_match<'scanner>(
&self,
region: &Region,
scanner: &'scanner super::Inner,
mat: &aho_corasick::Match,
scan_data: &mut ScanData<'scanner, '_>,
matches: &mut [Vec<StringMatch>],
) -> Result<(), ScanError> {
for literal_info in &self.aho_index_to_literal_info[mat.pattern()] {
let LiteralInfo {
variable_index,
literal_index,
slice_offset: (start_offset, end_offset),
} = *literal_info;
let var = &scanner.variables[variable_index].matcher;
#[cfg(feature = "profiling")]
if let Some(stats) = scan_data.statistics.as_mut() {
stats.nb_ac_matches += 1;
}
#[cfg(feature = "profiling")]
let start_instant = std::time::Instant::now();
let Some(start) = mat.start().checked_sub(start_offset) else {
continue;
};
let end = match mat.end().checked_add(end_offset) {
Some(v) if v <= region.mem.len() => v,
_ => continue,
};
let m = start..end;
let Some(match_type) = var.confirm_ac_literal(region.mem, &m, literal_index) else {
continue;
};
let var_matches = &mut matches[variable_index];
let start_position = match var_matches.last() {
Some(mat) if mat.base == region.start => mat.offset + 1,
_ => 0,
};
let res = var.process_ac_match(region.mem, m, start_position, match_type);
#[cfg(feature = "profiling")]
{
if let Some(stats) = scan_data.statistics.as_mut() {
stats.ac_confirm_duration += start_instant.elapsed();
}
}
match res {
AcMatchStatus::None => (),
AcMatchStatus::Multiple(v) if v.is_empty() => (),
AcMatchStatus::Multiple(found_matches) => {
var_matches.extend(found_matches.into_iter().map(|m| {
StringMatch::new(region, m, scan_data.params.match_max_length, 0)
}));
}
AcMatchStatus::Single(m) => {
let xor_key = var.get_xor_key(literal_index);
var_matches.push(StringMatch::new(
region,
m,
scan_data.params.match_max_length,
xor_key,
));
}
}
if var_matches.len() > (scan_data.params.string_max_nb_matches as usize) {
var_matches.truncate(scan_data.params.string_max_nb_matches as usize);
if (scan_data.params.callback_events & CallbackEvents::STRING_REACHED_MATCH_LIMIT).0
!= 0
&& scan_data.string_reached_match_limit.insert(variable_index)
{
if let Some(cb) = &mut scan_data.callback {
if let Some(string_identifier) =
build_string_identifier(scanner, variable_index)
{
match (cb)(ScanEvent::StringReachedMatchLimit(string_identifier)) {
ScanCallbackResult::Continue => (),
ScanCallbackResult::Abort => return Err(ScanError::CallbackAbort),
}
}
}
}
}
}
Ok(())
}
}
/// Searches a region for every match of a variable that has no literals to prefilter on.
///
/// Matches are appended to `string_matches`, which may already hold matches from previously
/// scanned regions. Each search restarts one byte after the start of the previous match, so
/// overlapping matches are reported.
///
/// The search stops as soon as `string_matches` holds `string_max_nb_matches` entries. The
/// limit is checked *before* each search, so the vector never grows past it. This holds even
/// when it was already full on entry or when the limit is 0.
fn scan_single_variable(
    region: &Region,
    matcher: &Matcher,
    scan_data: &mut ScanData,
    string_matches: &mut Vec<StringMatch>,
) {
    // `u32 -> usize` is lossless, unlike the `usize -> u32` direction.
    let max_nb_matches = scan_data.params.string_max_nb_matches as usize;

    let mut offset = 0;
    while offset < region.mem.len() && string_matches.len() < max_nb_matches {
        let Some(mat) = matcher.find_next_match_at(region.mem, offset) else {
            break;
        };
        offset = mat.start + 1;
        string_matches.push(StringMatch::new(
            region,
            mat,
            scan_data.params.match_max_length,
            0,
        ));
    }
}
/// Builds the user-facing identifier of the string variable at `variable_index`.
///
/// This assumes the variables are laid out rule by rule, global rules first and then the
/// other rules, each rule owning `nb_variables` consecutive indexes. The returned
/// `string_index` is the variable's position within its owning rule.
///
/// Returns `None` (and trips a debug assertion) if no rule owns `variable_index`.
fn build_string_identifier(
    scanner: &super::Inner,
    variable_index: usize,
) -> Option<StringIdentifier<'_>> {
    let all_rules = scanner.global_rules.iter().chain(scanner.rules.iter());

    // Index of the first variable owned by the rule currently being inspected.
    let mut rule_first_var = 0;
    for rule in all_rules {
        let next_rule_first_var = rule_first_var + rule.nb_variables;
        if variable_index < next_rule_first_var {
            let identifier = StringIdentifier {
                rule_namespace: scanner.namespaces[rule.namespace_index].as_ref(),
                rule_name: &rule.name,
                string_name: &scanner.variables[variable_index].name,
                string_index: variable_index - rule_first_var,
            };
            return Some(identifier);
        }
        rule_first_var = next_rule_first_var;
    }

    // Only reachable with an out-of-range variable index.
    debug_assert!(false);
    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::test_type_traits_non_clonable;

    /// Checks the common trait implementations of the types of this module.
    #[test]
    fn test_types_traits() {
        let empty_scan = AcScan::new(&[], CompilerProfile::Speed);
        test_type_traits_non_clonable(empty_scan);

        let info = LiteralInfo {
            variable_index: 0,
            literal_index: 0,
            slice_offset: (0, 0),
        };
        test_type_traits_non_clonable(info);
    }
}