use std::cell::UnsafeCell;
/// Per-rule scratch slot consulted while deciding whether a rule matched.
///
/// `ScanState::rule_is_satisfied` treats a slot as satisfied only when its
/// `generation` equals the scan's generation, `remaining_and` is zero and
/// `vetoed` is false.
#[derive(Default, Clone, Copy)]
pub(super) struct RuleState {
/// Generation stamp; the slot only counts when it equals the scan's generation.
pub(super) generation: u16,
/// Must be zero for the rule to be satisfied.
/// NOTE(review): presumably the number of AND segments still unmatched — confirm at the decrement site.
pub(super) remaining_and: u16,
/// While `true`, the rule is never reported as satisfied.
pub(super) vetoed: bool,
/// Bitmask that starts at zero for a freshly initialised rule.
/// NOTE(review): bit meaning is not visible in this file — confirm where it is set.
pub(super) satisfied_mask: u64,
}
/// Owned, reusable scratch buffers plus the generation counter for matching.
///
/// `prepare` advances the generation and grows the per-rule vectors;
/// `as_scan_state` lends the buffers out as a `ScanState`.
pub(super) struct SimpleMatchState {
/// One `RuleState` slot per rule index.
pub(super) rule_states: Vec<RuleState>,
/// One `i32` buffer per rule index; `init_matrix` (re)fills an individual buffer.
pub(super) matrix: Vec<Vec<i32>>,
/// One `u8` buffer per rule index, (re)filled by `init_matrix` together with `matrix`.
pub(super) matrix_status: Vec<Vec<u8>>,
/// Rule indices recorded during the current scan; emptied by `prepare`.
pub(super) touched_indices: Vec<usize>,
/// Current generation; advanced by every `prepare` call.
generation: u16,
}
/// Scratch state declared with `#[thread_local]`, so every thread gets its own instance.
///
/// The value sits inside an `UnsafeCell`, so mutable access goes through the
/// raw pointer returned by `UnsafeCell::get`.
/// NOTE(review): callers are presumably responsible for never holding two live
/// mutable borrows on the same thread — confirm at the use sites.
#[thread_local]
pub(super) static SIMPLE_MATCH_STATE: UnsafeCell<SimpleMatchState> =
UnsafeCell::new(SimpleMatchState::new());
/// Borrowed view over the buffers of a `SimpleMatchState` for a single scan.
///
/// Built by `SimpleMatchState::as_scan_state`, which also copies the current
/// generation into the view.
pub(super) struct ScanState<'a> {
/// Per-rule slots, indexed by rule index.
pub(super) rule_states: &'a mut [RuleState],
/// Rule indices recorded during this scan.
pub(super) touched_indices: &'a mut Vec<usize>,
/// Per-rule `i32` buffers, indexed by rule index.
pub(super) matrix: &'a mut [Vec<i32>],
/// Per-rule `u8` buffers, indexed by rule index.
pub(super) matrix_status: &'a mut [Vec<u8>],
/// Generation of this scan; slots stamped with a different value are not satisfied.
pub(super) generation: u16,
}
/// Small `Copy` bundle of parameters passed by value into scan routines.
///
/// NOTE(review): only `num_variants` is consumed inside this file; the meaning
/// of the other fields should be confirmed against their users.
#[derive(Clone, Copy)]
pub(super) struct ScanContext {
/// Not used in this file — semantics to be confirmed against callers.
pub(super) text_index: usize,
/// Not used in this file — semantics to be confirmed against callers.
pub(super) process_type_mask: u64,
/// Forwarded to `init_matrix` as the number of cells per segment row.
pub(super) num_variants: usize,
/// Not used in this file — semantics to be confirmed against callers.
pub(super) exit_early: bool,
/// Not used in this file — semantics to be confirmed against callers.
pub(super) non_ascii_density: f32,
}
impl ScanState<'_> {
    /// Returns the rule indices recorded so far during this scan.
    pub(super) fn touched_indices(&self) -> &[usize] {
        self.touched_indices.as_slice()
    }

    /// Returns `true` as soon as any recorded rule index is satisfied.
    #[inline(always)]
    pub(super) fn has_match(&self) -> bool {
        for &rule_idx in self.touched_indices() {
            if self.rule_is_satisfied(rule_idx) {
                return true;
            }
        }
        false
    }

    /// Reports whether the rule at `rule_idx` is satisfied in this scan.
    ///
    /// A rule is satisfied when its slot carries this scan's generation, has
    /// no remaining AND requirement and has not been vetoed.
    ///
    /// `rule_idx` must be smaller than `rule_states.len()`; the bounds check
    /// is elided, so an out-of-range index is undefined behaviour.
    #[inline(always)]
    pub(super) fn rule_is_satisfied(&self, rule_idx: usize) -> bool {
        let states = &*self.rule_states;
        // SAFETY: relies on callers only passing indices below
        // `rule_states.len()`. NOTE(review): this function is safe to call,
        // so that invariant must be upheld by every caller — confirm.
        unsafe { core::hint::assert_unchecked(rule_idx < states.len()) };
        let slot = &states[rule_idx];
        !slot.vetoed && slot.remaining_and == 0 && slot.generation == self.generation
    }
}
#[cfg(test)]
impl ScanState<'_> {
    /// Returns the generation this scan view was created with.
    pub(super) fn generation(&self) -> u16 {
        self.generation
    }

    /// Test helper: marks `rule_idx` as touched and resets its slot.
    ///
    /// The slot is stamped with the current generation, `remaining_and` is set
    /// to `and_count` (truncated to `u16`), and the veto flag and mask are
    /// cleared. The rule's matrix buffers are initialised only when the rule
    /// cannot be described by the plain bitmask layout: too many segments, a
    /// count other than 1 among the first `and_count` segments, or a non-zero
    /// count among the rest.
    pub(super) fn init_rule(
        &mut self,
        rule: &super::rule::Rule,
        and_count: usize,
        rule_idx: usize,
        ctx: ScanContext,
    ) {
        let stamp = self.generation;
        let slot = &mut self.rule_states[rule_idx];
        slot.satisfied_mask = 0;
        slot.vetoed = false;
        slot.remaining_and = and_count as u16;
        slot.generation = stamp;
        self.touched_indices.push(rule_idx);

        let counts: &[i32] = &rule.segment_counts;
        let capacity = super::pattern::BITMASK_CAPACITY;
        // Size checks come first so the split below is only reached when the
        // rule fits the bitmask capacity.
        let needs_matrix = if and_count > capacity || counts.len() > capacity {
            true
        } else {
            let (required, rest) = counts.split_at(and_count);
            !(required.iter().all(|&c| c == 1) && rest.iter().all(|&c| c == 0))
        };

        if !needs_matrix {
            return;
        }
        init_matrix(
            &mut self.matrix[rule_idx],
            &mut self.matrix_status[rule_idx],
            counts,
            ctx.num_variants,
        );
    }
}
impl SimpleMatchState {
    /// Creates an empty state at generation 0 with no buffers allocated.
    pub(super) const fn new() -> Self {
        Self {
            generation: 0,
            touched_indices: Vec::new(),
            rule_states: Vec::new(),
            matrix: Vec::new(),
            matrix_status: Vec::new(),
        }
    }

    /// Starts a new scan over at least `size` rules.
    ///
    /// Advances the generation so slots stamped by earlier scans no longer
    /// compare equal, grows the per-rule buffers to `size` if they are
    /// shorter, and empties the touched-index list.
    pub(super) fn prepare(&mut self, size: usize) {
        self.generation = match self.generation.checked_add(1) {
            Some(next) => next,
            None => {
                // The counter is exhausted: wipe every stamp so old slots
                // cannot collide with reused generation values, then restart
                // at 1.
                self.rule_states
                    .iter_mut()
                    .for_each(|slot| slot.generation = 0);
                1
            }
        };

        if size > self.rule_states.len() {
            self.rule_states.resize(size, RuleState::default());
            self.matrix.resize(size, Vec::new());
            self.matrix_status.resize(size, Vec::new());
        }

        self.touched_indices.clear();
    }

    /// Lends the buffers out as a `ScanState` carrying the current generation.
    #[inline(always)]
    pub(super) fn as_scan_state(&mut self) -> ScanState<'_> {
        let Self {
            rule_states,
            matrix,
            matrix_status,
            touched_indices,
            generation,
        } = self;
        ScanState {
            rule_states: rule_states.as_mut_slice(),
            touched_indices,
            matrix: matrix.as_mut_slice(),
            matrix_status: matrix_status.as_mut_slice(),
            generation: *generation,
        }
    }
}
/// Resets one rule's matrix buffers from its segment counts.
///
/// After the call `flat_matrix` holds `segment_counts.len() * num_variants`
/// cells laid out row by row, where every cell of row `i` equals
/// `segment_counts[i]`, and `flat_status` holds one zero byte per segment.
/// Previous contents of both buffers are discarded.
#[cold]
#[inline(never)]
pub(super) fn init_matrix(
    flat_matrix: &mut Vec<i32>,
    flat_status: &mut Vec<u8>,
    segment_counts: &[i32],
    num_variants: usize,
) {
    let row_count = segment_counts.len();

    flat_status.clear();
    flat_status.resize(row_count, 0u8);

    flat_matrix.clear();
    flat_matrix.resize(row_count * num_variants, 0i32);

    // `chunks_exact_mut` rejects a chunk size of zero; with zero variants the
    // matrix is empty and there is nothing to fill anyway.
    if num_variants == 0 {
        return;
    }
    for (row, &count) in flat_matrix
        .chunks_exact_mut(num_variants)
        .zip(segment_counts)
    {
        row.fill(count);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ScanContext` with fixed values for everything but the two arguments.
    fn make_ctx(num_variants: usize, exit_early: bool) -> ScanContext {
        ScanContext {
            num_variants,
            exit_early,
            text_index: 0,
            process_type_mask: u64::MAX,
            non_ascii_density: 0.0,
        }
    }

    /// `prepare` grows the buffers, clears touched indices and bumps the generation.
    #[test]
    fn test_prepare() {
        let mut state = SimpleMatchState::new();
        assert_eq!(state.generation, 0);

        state.prepare(10);
        assert!(state.rule_states.len() >= 10);
        assert!(state.matrix.len() >= 10);
        assert_eq!(state.generation, 1);
        assert!(state.touched_indices.is_empty());

        for expected in 2..=3u16 {
            state.prepare(10);
            assert_eq!(state.generation, expected);
        }
    }

    /// Past `u16::MAX` the generation restarts at 1 and every stamp is wiped.
    #[test]
    fn test_prepare_generation_wraparound() {
        let mut state = SimpleMatchState::new();
        state.prepare(3);

        let stamp = state.generation;
        for slot in state.rule_states.iter_mut().take(3) {
            slot.generation = stamp;
        }

        state.generation = u16::MAX - 1;
        state.prepare(3);
        assert_eq!(state.generation, u16::MAX);

        state.prepare(3);
        assert_eq!(state.generation, 1);
        assert!(state.rule_states.iter().all(|slot| slot.generation == 0));
    }

    /// A current, fully matched rule is satisfied until it is vetoed.
    #[test]
    fn test_rule_satisfaction() {
        let mut state = SimpleMatchState::new();
        state.prepare(1);

        state.rule_states[0] = RuleState {
            generation: state.generation,
            remaining_and: 0,
            vetoed: false,
            ..state.rule_states[0]
        };
        assert!(state.as_scan_state().rule_is_satisfied(0));

        state.rule_states[0].vetoed = true;
        assert!(!state.as_scan_state().rule_is_satisfied(0));
    }

    /// A rule with a repeated segment takes the matrix path in `init_rule`.
    #[test]
    fn test_init_rule_matrix() {
        let mut state = SimpleMatchState::new();
        state.prepare(1);

        let rule = super::super::rule::Rule {
            segment_counts: vec![2, 1, 0],
            word_id: 1,
            word: "a&a&b~c".to_owned(),
        };
        let mut scan = state.as_scan_state();
        scan.init_rule(&rule, 2, 0, make_ctx(2, false));

        let slot = scan.rule_states[0];
        assert_eq!(slot.generation, scan.generation());
        assert_eq!(slot.remaining_and, 2);
        assert!(!slot.vetoed);
        assert_eq!(slot.satisfied_mask, 0);
        assert_eq!(scan.touched_indices(), &[0]);

        let cells = &scan.matrix[0];
        assert_eq!(cells.len(), 6);
        assert_eq!(&cells[..], &[2, 2, 1, 1, 0, 0]);

        let status = &scan.matrix_status[0];
        assert_eq!(status.len(), 3);
        assert!(status.iter().all(|&flag| flag == 0));
    }
}