use std::{borrow::Cow, collections::HashMap};
use super::{
SimpleResult,
pattern::PatternKind,
state::{ScanContext, ScanState, init_matrix},
};
use crate::process::ProcessType;
/// Rule table with borrowed strings: for each `ProcessType`, a map from a `u32` key to a `&'a str`.
///
/// NOTE(review): the `u32` key presumably corresponds to `Rule::word_id` — confirm against the
/// code that builds the rule set.
pub type SimpleTable<'a> = HashMap<ProcessType, HashMap<u32, &'a str>>;
/// Same shape as `SimpleTable`, but every string is a `Cow<'a, str>`, so entries may be either
/// borrowed or owned.
///
/// NOTE(review): the name suggests this is the (de)serialization-facing form — confirm.
pub type SimpleTableSerde<'a> = HashMap<ProcessType, HashMap<u32, Cow<'a, str>>>;
/// Strategy `RuleSet::eval_hit` uses to decide when a rule's AND requirement has been met.
///
/// The enum is `#[repr(u8)]` with explicit discriminants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub(super) enum SatisfactionMethod {
/// A single AND hit completes the rule: `eval_hit` sets `remaining_and` to 0 right away.
Immediate = 0,
/// Each segment offset owns one bit of a `u64` mask; the first hit on a bit decrements
/// `remaining_and`.
Bitmask = 1,
/// Hits are counted in a flat per-rule matrix that `init_matrix` prepares from
/// `Rule::segment_counts`.
Matrix = 2,
}
impl SatisfactionMethod {
    /// Returns `true` only for the `Matrix` method, i.e. when the per-rule counter matrix is in
    /// use.
    #[inline(always)]
    pub(super) fn use_matrix(self) -> bool {
        match self {
            Self::Matrix => true,
            Self::Immediate | Self::Bitmask => false,
        }
    }
}
/// One rule: the payload reported on a match plus the per-segment counts used by the matrix
/// method.
#[derive(Debug, Clone)]
pub(super) struct Rule {
/// Per-segment counts handed to `init_matrix` when a `Matrix`-method rule is first touched.
/// NOTE(review): the tests in this file use `0` for a NOT segment — confirm the sign convention
/// against the builder.
pub(super) segment_counts: Vec<i32>,
/// Identifier copied into `SimpleResult::word_id` when the rule matches.
pub(super) word_id: u32,
/// Text borrowed into `SimpleResult::word` when the rule matches.
pub(super) word: String,
}
/// Compact, `Copy` metadata that `RuleSet::eval_hit` reads for every hit on a rule.
#[derive(Debug, Clone, Copy)]
pub(super) struct RuleInfo {
/// Initial value loaded into the rule's `remaining_and` counter the first time the rule is
/// touched in a scan generation.
pub(super) and_count: u8,
/// Which bookkeeping strategy `eval_hit` applies to AND hits (and whether NOT hits are counted
/// in the matrix).
pub(super) method: SatisfactionMethod,
/// When `true`, `eval_hit` never signals an early exit for this rule and checks the veto flag
/// before counting further AND hits.
pub(super) has_not: bool,
}
/// Collection of rules with their evaluation metadata, stored as two parallel vectors.
#[derive(Clone)]
pub(super) struct RuleSet {
// Rule payloads, indexed by rule index.
rules: Vec<Rule>,
// Metadata for the rule at the same index in `rules`; `RuleSet::new` expects equal lengths.
rule_info: Vec<RuleInfo>,
}
impl RuleSet {
/// Builds a rule set from two parallel vectors: `rule_info[i]` describes `rules[i]`.
///
/// # Panics
///
/// Panics if `rules` and `rule_info` have different lengths. The check is unconditional (not a
/// `debug_assert`) because the accessors of this type index both vectors with unchecked accesses
/// using the same rule index; a length mismatch would therefore lead to undefined behaviour in
/// release builds rather than a mere logic error.
pub(super) fn new(rules: Vec<Rule>, rule_info: Vec<RuleInfo>) -> Self {
    assert_eq!(
        rules.len(),
        rule_info.len(),
        "rules and rule_info must have the same length"
    );
    Self { rules, rule_info }
}
/// Borrows the per-rule metadata as a slice indexed by rule index.
pub(super) fn rule_info(&self) -> &[RuleInfo] {
    self.rule_info.as_slice()
}
/// Returns a copy of the metadata for rule `rule_idx` without a bounds check.
///
/// `rule_idx` must be smaller than the number of rules; a larger index is undefined behaviour.
#[inline(always)]
pub(super) fn info(&self, rule_idx: usize) -> RuleInfo {
    let infos = self.rule_info.as_slice();
    // SAFETY: not verified here — relies on callers only passing indices below `infos.len()`.
    unsafe {
        core::hint::assert_unchecked(rule_idx < infos.len());
        *infos.get_unchecked(rule_idx)
    }
}
/// Estimates the heap bytes owned by this rule set, computed from buffer capacities.
///
/// Counts the two top-level vectors plus each rule's `segment_counts` and `word` buffers.
pub(super) fn heap_bytes(&self) -> usize {
    let rules_buf = self.rules.capacity() * size_of::<Rule>();
    let info_buf = self.rule_info.capacity() * size_of::<RuleInfo>();

    // Buffers owned by the individual rules.
    let mut per_rule = 0usize;
    for rule in &self.rules {
        per_rule += rule.segment_counts.capacity() * size_of::<i32>() + rule.word.capacity();
    }

    rules_buf + info_buf + per_rule
}
#[inline(always)]
pub(super) fn len(&self) -> usize {
self.rules.len()
}
/// Returns `true` if at least one rule touched during the scan is reported as satisfied by `ss`.
#[inline(always)]
pub(super) fn has_match(&self, ss: &ScanState<'_>) -> bool {
    for &rule_idx in ss.touched_indices() {
        if ss.rule_is_satisfied(rule_idx) {
            return true;
        }
    }
    false
}
/// Appends a `SimpleResult` to `results` for every touched rule that `ss` reports as satisfied,
/// in the order the rules were touched.
pub(super) fn collect_matches<'a>(
    &'a self,
    ss: &ScanState<'_>,
    results: &mut Vec<SimpleResult<'a>>,
) {
    ss.touched_indices()
        .iter()
        .copied()
        .filter(|&rule_idx| ss.rule_is_satisfied(rule_idx))
        .for_each(|rule_idx| self.push_result(rule_idx, results));
}
/// Calls `on_match` for each touched rule that `ss` reports as satisfied, in touch order.
///
/// Stops and returns `true` as soon as `on_match` returns `true`; returns `false` if every
/// satisfied rule was visited without the callback asking to stop.
pub(super) fn for_each_satisfied<'a>(
    &'a self,
    ss: &ScanState<'_>,
    mut on_match: impl FnMut(SimpleResult<'a>) -> bool,
) -> bool {
    ss.touched_indices()
        .iter()
        .copied()
        .any(|rule_idx| ss.rule_is_satisfied(rule_idx) && on_match(self.result_at(rule_idx)))
}
#[inline(always)]
pub(super) fn result_at<'a>(&'a self, rule_idx: usize) -> SimpleResult<'a> {
let rule = unsafe {
core::hint::assert_unchecked(rule_idx < self.rules.len());
self.rules.get_unchecked(rule_idx)
};
SimpleResult {
word_id: rule.word_id,
word: Cow::Borrowed(&rule.word),
}
}
/// Records one pattern hit against rule `rule_idx` in `ss` and reports whether the caller may
/// stop scanning early.
///
/// Per-rule state is reset lazily: when the rule's stored generation differs from
/// `ss.generation`, its counters, veto flag and bitmask are re-initialised and the rule index is
/// appended to `ss.touched_indices`.
///
/// * `rule_idx` — index of the rule; used for unchecked indexing, so it must be in bounds for
///   `self` and for the per-rule vectors in `ss`.
/// * `kind` — `PatternKind::Not` takes the veto path; every other kind is treated as an AND hit.
/// * `offset` — segment offset within the rule: the bit position for `Bitmask`, and the index
///   into the status vector / row of the flat matrix for `Matrix`.
/// * `ctx` — supplies `num_variants`, `text_index` and `exit_early`.
/// * `ss` — mutable scan state updated in place.
///
/// Returns `true` only when `ctx.exit_early` is set, the rule has no NOT part
/// (`!info.has_not`) and its AND requirement is complete. NOT hits always return `false`.
#[inline(always)]
pub(super) fn eval_hit(
&self,
rule_idx: usize,
kind: PatternKind,
offset: usize,
ctx: ScanContext,
ss: &mut ScanState<'_>,
) -> bool {
let generation = ss.generation;
let info = self.info(rule_idx);
// SAFETY: not verifiable locally — relies on the caller passing a valid rule index and on `ss`
// having been sized for at least that many rules.
// NOTE(review): only `word_states` and `rules` are asserted here; `satisfied_masks`, `matrix`
// and `matrix_status` are indexed unchecked below on the same assumption — confirm they are
// always sized together.
unsafe {
core::hint::assert_unchecked(rule_idx < ss.word_states.len());
core::hint::assert_unchecked(rule_idx < self.rules.len());
}
// SAFETY: see the note above; `rule_idx` is assumed to be in bounds for `ss.word_states`.
let ws = unsafe { ss.word_states.get_unchecked_mut(rule_idx) };
// ---- NOT hit: may veto the rule, never signals early exit ----
if matches!(kind, PatternKind::Not) {
if ws.generation == generation {
// Already vetoed in this generation: nothing more to record.
if ws.vetoed {
return false;
}
} else {
// First touch of this rule in the current generation: reset its state.
ws.generation = generation;
ws.remaining_and = info.and_count as u16;
ws.vetoed = false;
unsafe { *ss.satisfied_masks.get_unchecked_mut(rule_idx) = 0 };
ss.touched_indices.push(rule_idx);
if info.method.use_matrix() {
let rule = unsafe { self.rules.get_unchecked(rule_idx) };
init_matrix(
unsafe { ss.matrix.get_unchecked_mut(rule_idx) },
unsafe { ss.matrix_status.get_unchecked_mut(rule_idx) },
&rule.segment_counts,
ctx.num_variants,
);
}
}
if info.method.use_matrix() {
let flat_matrix = unsafe { ss.matrix.get_unchecked_mut(rule_idx) };
let flat_status = unsafe { ss.matrix_status.get_unchecked_mut(rule_idx) };
// The matrix is addressed as `offset * num_variants + text_index`.
let counter = &mut flat_matrix[offset * ctx.num_variants + ctx.text_index];
// NOT hits count upwards; the veto fires once the counter becomes positive and the segment
// has not been marked in `flat_status` yet.
*counter += 1;
if flat_status[offset] == 0 && *counter > 0 {
flat_status[offset] = 1;
ws.vetoed = true;
}
} else {
// Non-matrix methods veto on the first NOT hit.
ws.vetoed = true;
}
return false;
}
// ---- AND hit ----
if ws.generation == generation {
// A vetoed rule with a NOT part cannot become satisfied any more.
if info.has_not && ws.vetoed {
return false;
}
// AND requirement already complete: skip the bookkeeping and only report the early-exit
// signal.
if ws.remaining_and == 0 {
return !info.has_not && ctx.exit_early;
}
} else {
// First touch of this rule in the current generation: same reset as on the NOT path.
ws.generation = generation;
ws.remaining_and = info.and_count as u16;
ws.vetoed = false;
unsafe { *ss.satisfied_masks.get_unchecked_mut(rule_idx) = 0 };
ss.touched_indices.push(rule_idx);
if info.method.use_matrix() {
let rule = unsafe { self.rules.get_unchecked(rule_idx) };
init_matrix(
unsafe { ss.matrix.get_unchecked_mut(rule_idx) },
unsafe { ss.matrix_status.get_unchecked_mut(rule_idx) },
&rule.segment_counts,
ctx.num_variants,
);
}
}
let is_satisfied = match info.method {
SatisfactionMethod::Matrix => {
let flat_matrix = unsafe { ss.matrix.get_unchecked_mut(rule_idx) };
let flat_status = unsafe { ss.matrix_status.get_unchecked_mut(rule_idx) };
let counter = &mut flat_matrix[offset * ctx.num_variants + ctx.text_index];
// AND hits count downwards; a segment is complete once its counter reaches zero or below,
// and `flat_status` ensures it is only credited once.
*counter -= 1;
if flat_status[offset] == 0 && *counter <= 0 {
flat_status[offset] = 1;
ws.remaining_and -= 1;
}
ws.remaining_and == 0
}
SatisfactionMethod::Immediate => {
// A single hit completes the rule.
ws.remaining_and = 0;
true
}
SatisfactionMethod::Bitmask => {
// NOTE(review): assumes `offset < 64`; a larger shift would overflow — confirm that the
// builder only selects `Bitmask` for rules with at most 64 segments.
let bit = 1u64 << offset;
let mask = unsafe { ss.satisfied_masks.get_unchecked_mut(rule_idx) };
// Only the first hit on a given segment bit counts towards the AND requirement.
if *mask & bit == 0 {
*mask |= bit;
ws.remaining_and -= 1;
}
ws.remaining_and == 0
}
};
// Early exit is only signalled for rules without a NOT part that are not vetoed.
ctx.exit_early && is_satisfied && !info.has_not && !ws.vetoed
}
#[inline(always)]
fn push_result<'a>(&'a self, rule_idx: usize, results: &mut Vec<SimpleResult<'a>>) {
let rule = unsafe {
core::hint::assert_unchecked(rule_idx < self.rules.len());
self.rules.get_unchecked(rule_idx)
};
results.push(SimpleResult {
word_id: rule.word_id,
word: Cow::Borrowed(&rule.word),
});
}
}
#[cfg(test)]
mod tests {
use super::{super::state::SimpleMatchState, *};
/// Test helper: a `ScanContext` with one variant at text index 0 and every process-type bit set.
fn make_ctx(exit_early: bool) -> ScanContext {
    let num_variants = 1;
    let text_index = 0;
    ScanContext {
        exit_early,
        num_variants,
        text_index,
        process_type_mask: u64::MAX,
        non_ascii_density: 0.0,
    }
}
/// Test helper: a rule set holding a single `Immediate` rule with no NOT part.
fn make_simple_ruleset(word_id: u32, word: &str) -> RuleSet {
    let rule = Rule {
        word: word.to_owned(),
        word_id,
        segment_counts: vec![1],
    };
    let info = RuleInfo {
        has_not: false,
        method: SatisfactionMethod::Immediate,
        and_count: 1,
    };
    RuleSet::new(vec![rule], vec![info])
}
/// An `Immediate` rule is satisfied by its first AND hit and keeps signalling early exit.
#[test]
fn test_eval_hit_simple_kind() {
    let rule_set = make_simple_ruleset(1, "hello");
    let mut match_state = SimpleMatchState::new();
    match_state.prepare(1);
    let mut ss = match_state.as_scan_state();
    let ctx = make_ctx(true);

    let first = rule_set.eval_hit(0, PatternKind::And, 0, ctx, &mut ss);
    assert!(first, "Simple AND with exit_early should return true");
    assert!(ss.rule_is_satisfied(0));

    // A repeated hit on the already-satisfied rule must still report early exit.
    let repeated = rule_set.eval_hit(0, PatternKind::And, 0, ctx, &mut ss);
    assert!(
        repeated,
        "already-satisfied Simple should still return exit_early"
    );
}
/// A three-segment `Bitmask` rule is satisfied only after all three offsets have been hit.
#[test]
fn test_eval_hit_and_bitmask() {
    let rule = Rule {
        word: "a&b&c".to_owned(),
        word_id: 1,
        segment_counts: vec![1, 1, 1],
    };
    let info = RuleInfo {
        has_not: false,
        method: SatisfactionMethod::Bitmask,
        and_count: 3,
    };
    let rule_set = RuleSet::new(vec![rule], vec![info]);

    let mut match_state = SimpleMatchState::new();
    match_state.prepare(1);
    let mut ss = match_state.as_scan_state();
    let ctx = make_ctx(true);

    // The first two segments leave the rule unsatisfied.
    for offset in 0..2 {
        assert!(!rule_set.eval_hit(0, PatternKind::And, offset, ctx, &mut ss));
        assert!(!ss.rule_is_satisfied(0));
    }

    // The third segment completes it.
    assert!(rule_set.eval_hit(0, PatternKind::And, 2, ctx, &mut ss));
    assert!(ss.rule_is_satisfied(0));
}
/// A NOT hit arriving after the AND hit vetoes a rule that was already satisfied.
#[test]
fn test_eval_hit_not_veto() {
    let rule = Rule {
        word: "a~b".to_owned(),
        word_id: 1,
        segment_counts: vec![1, 0],
    };
    let info = RuleInfo {
        has_not: true,
        method: SatisfactionMethod::Immediate,
        and_count: 1,
    };
    let rule_set = RuleSet::new(vec![rule], vec![info]);

    let mut match_state = SimpleMatchState::new();
    match_state.prepare(1);
    let mut ss = match_state.as_scan_state();
    let ctx = make_ctx(false);

    rule_set.eval_hit(0, PatternKind::And, 0, ctx, &mut ss);
    assert!(ss.rule_is_satisfied(0));

    rule_set.eval_hit(0, PatternKind::Not, 1, ctx, &mut ss);
    assert!(!ss.rule_is_satisfied(0), "NOT should veto the rule");
}
/// A `Matrix` rule with segment counts `[2, 1]` needs two hits on segment 0 and one on segment 1.
#[test]
fn test_eval_hit_matrix_counters() {
    let rule = Rule {
        word: "a&a&b".to_owned(),
        word_id: 1,
        segment_counts: vec![2, 1],
    };
    let info = RuleInfo {
        has_not: false,
        method: SatisfactionMethod::Matrix,
        and_count: 2,
    };
    let rule_set = RuleSet::new(vec![rule], vec![info]);

    let mut match_state = SimpleMatchState::new();
    match_state.prepare(1);
    let mut ss = match_state.as_scan_state();
    let ctx = make_ctx(true);

    // (segment offset hit, whether the rule is complete afterwards)
    let steps = [(0usize, false), (1, false), (0, true)];
    for (offset, complete) in steps {
        let signalled = rule_set.eval_hit(0, PatternKind::And, offset, ctx, &mut ss);
        assert_eq!(signalled, complete);
        assert_eq!(ss.rule_is_satisfied(0), complete);
    }
}
}