use std::cell::UnsafeCell;
/// Per-rule scratch state for one scan, tagged with the scan generation that last wrote it.
///
/// An entry only counts for the current scan when `generation` equals the scan's generation;
/// any other value marks it as stale data left over from an earlier scan.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub(super) struct WordState {
    /// Generation in which this entry was last initialized (0 = never, after a reset).
    pub(super) generation: u16,
    /// Outstanding AND segments; the rule can only be satisfied once this reaches 0.
    pub(super) remaining_and: u16,
    /// When set, the rule is not satisfied regardless of `remaining_and`.
    /// NOTE(review): presumably set when a negated (`~`) segment matches — confirm.
    pub(super) vetoed: bool,
}
/// Reusable scratch buffers for rule matching (one instance per thread lives in
/// `SIMPLE_MATCH_STATE`).
///
/// All `Vec`s except `touched_indices` are indexed by rule index. They are grown (never
/// shrunk) by `prepare`, which also advances `generation` so stale `WordState` entries can
/// be recognised without clearing the buffers between scans.
pub(super) struct SimpleMatchState {
    /// Current scan generation; private so only this type's methods can advance it.
    generation: u16,
    /// Rule indices touched during the current scan; cleared by `prepare`.
    pub(super) touched_indices: Vec<usize>,
    pub(super) word_states: Vec<WordState>,
    pub(super) satisfied_masks: Vec<u64>,
    /// Per-rule flattened count matrix (see `init_matrix`).
    pub(super) matrix: Vec<Vec<i32>>,
    /// Per-rule status row paired with `matrix`.
    pub(super) matrix_status: Vec<Vec<u8>>,
}
/// Per-thread instance of the matching scratch buffers.
///
/// `#[thread_local]` (an unstable attribute) gives every thread its own instance, which is
/// why a bare `UnsafeCell` with no locking is used. NOTE(review): soundness relies on
/// callers never holding two overlapping `&mut` obtained from `.get()` at the same time
/// (e.g. through a reentrant scan on the same thread) — confirm at the access sites.
#[thread_local]
pub(super) static SIMPLE_MATCH_STATE: UnsafeCell<SimpleMatchState> =
UnsafeCell::new(SimpleMatchState::new());
/// Borrowed view over a `SimpleMatchState`'s buffers for a single scan, together with a
/// copy of that scan's generation.
pub(super) struct ScanState<'a> {
    /// Generation this view was created for, copied from the owning state.
    pub(super) generation: u16,
    pub(super) word_states: &'a mut [WordState],
    pub(super) satisfied_masks: &'a mut [u64],
    pub(super) matrix: &'a mut [Vec<i32>],
    pub(super) matrix_status: &'a mut [Vec<u8>],
    /// A `Vec` rather than a slice because rule indices are appended during the scan.
    pub(super) touched_indices: &'a mut Vec<usize>,
}
/// Small, copyable bundle of per-scan parameters passed to the matching routines.
#[derive(Debug, Clone, Copy)]
pub(super) struct ScanContext {
    /// NOTE(review): presumably the index of the text variant being scanned — confirm.
    pub(super) text_index: usize,
    /// NOTE(review): presumably a bit set of enabled process types — confirm.
    pub(super) process_type_mask: u64,
    /// Number of columns per row when a rule's count matrix is built (see `init_matrix`).
    pub(super) num_variants: usize,
    /// NOTE(review): presumably requests stopping at the first match — confirm.
    pub(super) exit_early: bool,
    /// NOTE(review): presumably the fraction of non-ASCII characters in the text — confirm.
    pub(super) non_ascii_density: f32,
}
impl ScanState<'_> {
    /// Rule indices recorded as touched so far in this scan.
    pub(super) fn touched_indices(&self) -> &[usize] {
        self.touched_indices.as_slice()
    }

    /// Returns `true` when rule `rule_idx` was initialized in this scan's generation, has no
    /// outstanding AND segments, and has not been vetoed.
    ///
    /// `rule_idx` must be `< self.word_states.len()`; this is not checked in release builds.
    #[inline(always)]
    pub(super) fn rule_is_satisfied(&self, rule_idx: usize) -> bool {
        // SAFETY: callers must pass an in-bounds `rule_idx`; an out-of-range index is
        // undefined behavior. NOTE(review): presumably guaranteed because the buffers are
        // sized by `SimpleMatchState::prepare` for every rule — confirm at call sites.
        let state = unsafe {
            core::hint::assert_unchecked(rule_idx < self.word_states.len());
            self.word_states.get_unchecked(rule_idx)
        };
        let is_current = state.generation == self.generation;
        let all_and_matched = state.remaining_and == 0;
        is_current && all_and_matched && !state.vetoed
    }
}
#[cfg(test)]
impl ScanState<'_> {
    /// Generation of the scan this view belongs to.
    pub(super) fn generation(&self) -> u16 {
        self.generation
    }

    /// Test helper that initializes the per-scan state of rule `rule_idx`.
    ///
    /// Stamps the rule's `WordState` with the current generation, sets `remaining_and` to
    /// `and_count`, clears the veto flag and the satisfied mask, and records `rule_idx` as
    /// touched. If `and_count` or `rule.segment_counts.len()` exceeds `BITMASK_CAPACITY`, or
    /// any of the first `and_count` segment counts is not 1, or any later one is not 0, the
    /// rule's matrix is rebuilt with [`init_matrix`] using `ctx.num_variants` columns.
    ///
    /// # Panics
    ///
    /// Panics if `rule_idx` is out of bounds for the scan buffers or `and_count` does not fit
    /// in a `u16`. Slicing `rule.segment_counts` at `and_count` can also panic when
    /// `and_count` exceeds the number of segments.
    pub(super) fn init_rule(
        &mut self,
        rule: &super::rule::Rule,
        and_count: usize,
        rule_idx: usize,
        ctx: ScanContext,
    ) {
        // A silent `as u16` truncation would leave the rule with the wrong number of
        // outstanding AND segments; fail loudly in tests instead.
        assert!(
            and_count <= usize::from(u16::MAX),
            "and_count {} does not fit in WordState::remaining_and",
            and_count
        );
        // Checked indexing on purpose: this helper only runs in tests, where an
        // out-of-range index should panic rather than be undefined behavior.
        let ws = &mut self.word_states[rule_idx];
        ws.generation = self.generation;
        ws.remaining_and = and_count as u16; // in range: asserted above
        ws.vetoed = false;
        self.satisfied_masks[rule_idx] = 0;
        self.touched_indices.push(rule_idx);

        let use_matrix = and_count > super::pattern::BITMASK_CAPACITY
            || rule.segment_counts.len() > super::pattern::BITMASK_CAPACITY
            || rule.segment_counts[..and_count].iter().any(|&v| v != 1)
            || rule.segment_counts[and_count..].iter().any(|&v| v != 0);
        if use_matrix {
            init_matrix(
                &mut self.matrix[rule_idx],
                &mut self.matrix_status[rule_idx],
                &rule.segment_counts,
                ctx.num_variants,
            );
        }
    }
}
impl SimpleMatchState {
    /// Creates an empty state; `const` so it can initialize the thread-local static.
    pub(super) const fn new() -> Self {
        Self {
            generation: 0,
            touched_indices: Vec::new(),
            word_states: Vec::new(),
            satisfied_masks: Vec::new(),
            matrix: Vec::new(),
            matrix_status: Vec::new(),
        }
    }

    /// Starts a new scan over (at least) `size` rules.
    ///
    /// Advances the generation. On overflow, every existing `WordState` is reset to
    /// generation 0 and counting restarts at 1, so the generation after this call is never
    /// 0 and default-initialized entries always read as stale. Grows every per-rule buffer
    /// to `size` when the current length is smaller (never shrinking), and forgets the
    /// previous scan's touched indices.
    pub(super) fn prepare(&mut self, size: usize) {
        self.generation = match self.generation.checked_add(1) {
            Some(next) => next,
            None => {
                self.word_states.iter_mut().for_each(|ws| ws.generation = 0);
                1
            }
        };

        if size > self.word_states.len() {
            self.word_states.resize(size, WordState::default());
            self.satisfied_masks.resize(size, 0);
            self.matrix.resize_with(size, Vec::new);
            self.matrix_status.resize_with(size, Vec::new);
        }

        self.touched_indices.clear();
    }

    /// Borrows all buffers as a `ScanState` for the current generation.
    #[inline(always)]
    pub(super) fn as_scan_state(&mut self) -> ScanState<'_> {
        let generation = self.generation;
        ScanState {
            generation,
            touched_indices: &mut self.touched_indices,
            word_states: &mut self.word_states,
            satisfied_masks: &mut self.satisfied_masks,
            matrix: &mut self.matrix,
            matrix_status: &mut self.matrix_status,
        }
    }
}
/// Rebuilds one rule's count matrix and status row, reusing the given allocations.
///
/// Afterwards `flat_matrix` is a row-major grid of `segment_counts.len()` rows by
/// `num_variants` columns, where every cell in row `i` equals `segment_counts[i]`, and
/// `flat_status` holds one `0` per row.
#[cold]
#[inline(never)]
pub(super) fn init_matrix(
    flat_matrix: &mut Vec<i32>,
    flat_status: &mut Vec<u8>,
    segment_counts: &[i32],
    num_variants: usize,
) {
    flat_status.clear();
    flat_status.resize(segment_counts.len(), 0);

    flat_matrix.clear();
    flat_matrix.reserve(segment_counts.len() * num_variants);
    for &initial in segment_counts {
        flat_matrix.extend(core::iter::repeat(initial).take(num_variants));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context with the given matrix width and early-exit flag; other fields are
    /// fixed neutral values.
    fn make_ctx(num_variants: usize, exit_early: bool) -> ScanContext {
        ScanContext {
            text_index: 0,
            process_type_mask: u64::MAX,
            num_variants,
            exit_early,
            non_ascii_density: 0.0,
        }
    }

    /// `prepare` sizes the buffers, bumps the generation and clears touched indices.
    #[test]
    fn test_prepare() {
        let mut state = SimpleMatchState::new();
        assert_eq!(state.generation, 0);
        state.prepare(10);
        assert!(state.word_states.len() >= 10);
        assert!(state.matrix.len() >= 10);
        assert_eq!(state.generation, 1);
        assert!(state.touched_indices.is_empty());
        state.prepare(10);
        assert_eq!(state.generation, 2);
        state.prepare(10);
        assert_eq!(state.generation, 3);
    }

    /// Buffers grow to the largest requested size and are never shrunk.
    #[test]
    fn test_prepare_grows_but_never_shrinks() {
        let mut state = SimpleMatchState::new();
        state.prepare(4);
        state.touched_indices.push(3);
        state.prepare(2);
        assert_eq!(state.word_states.len(), 4);
        assert_eq!(state.satisfied_masks.len(), 4);
        assert_eq!(state.matrix.len(), 4);
        assert_eq!(state.matrix_status.len(), 4);
        assert!(state.touched_indices.is_empty());
        state.prepare(6);
        assert_eq!(state.word_states.len(), 6);
        assert_eq!(state.satisfied_masks.len(), 6);
        assert_eq!(state.matrix.len(), 6);
        assert_eq!(state.matrix_status.len(), 6);
    }

    /// On generation overflow, stored generations are reset and counting restarts at 1.
    #[test]
    fn test_prepare_generation_wraparound() {
        let mut state = SimpleMatchState::new();
        state.prepare(3);
        let current = state.generation;
        state.word_states[0].generation = current;
        state.word_states[1].generation = current;
        state.word_states[2].generation = current;
        state.generation = u16::MAX - 1;
        state.prepare(3);
        assert_eq!(state.generation, u16::MAX);
        state.prepare(3);
        assert_eq!(state.generation, 1);
        for ws in &state.word_states {
            assert_eq!(ws.generation, 0);
        }
    }

    /// A vetoed rule is never satisfied.
    #[test]
    fn test_rule_satisfaction() {
        let mut state = SimpleMatchState::new();
        state.prepare(1);
        let current = state.generation;
        state.word_states[0].generation = current;
        state.word_states[0].remaining_and = 0;
        state.word_states[0].vetoed = false;
        assert!(state.as_scan_state().rule_is_satisfied(0));
        state.word_states[0].vetoed = true;
        assert!(!state.as_scan_state().rule_is_satisfied(0));
    }

    /// Pending AND segments or a stale generation also prevent satisfaction.
    #[test]
    fn test_rule_satisfaction_requires_current_generation_and_no_pending_and() {
        let mut state = SimpleMatchState::new();
        state.prepare(1);
        let current = state.generation;
        state.word_states[0].generation = current;
        state.word_states[0].remaining_and = 1;
        state.word_states[0].vetoed = false;
        assert!(!state.as_scan_state().rule_is_satisfied(0));
        state.word_states[0].remaining_and = 0;
        assert!(state.as_scan_state().rule_is_satisfied(0));
        // A new scan makes the entry stale even though its fields still look satisfied.
        state.prepare(1);
        assert!(!state.as_scan_state().rule_is_satisfied(0));
    }

    /// Rules whose segment counts need per-variant tracking get a matrix.
    #[test]
    fn test_init_rule_matrix() {
        let mut state = SimpleMatchState::new();
        state.prepare(1);
        let rule = super::super::rule::Rule {
            segment_counts: vec![2, 1, 0],
            word_id: 1,
            word: "a&a&b~c".to_owned(),
        };
        let ctx = make_ctx(2, false);
        let mut ss = state.as_scan_state();
        ss.init_rule(&rule, 2, 0, ctx);
        assert_eq!(ss.word_states[0].generation, ss.generation());
        assert_eq!(ss.word_states[0].remaining_and, 2);
        assert!(!ss.word_states[0].vetoed);
        assert_eq!(ss.satisfied_masks[0], 0);
        assert_eq!(ss.touched_indices(), &[0]);
        assert_eq!(ss.matrix[0].len(), 6);
        assert_eq!(&ss.matrix[0][..], &[2, 2, 1, 1, 0, 0]);
        assert_eq!(ss.matrix_status[0].len(), 3);
        assert!(ss.matrix_status[0].iter().all(|&s| s == 0));
    }

    /// A single plain AND segment stays on the bitmask path and leaves the matrix untouched.
    #[test]
    fn test_init_rule_bitmask_path_skips_matrix() {
        let mut state = SimpleMatchState::new();
        state.prepare(1);
        let rule = super::super::rule::Rule {
            segment_counts: vec![1],
            word_id: 2,
            word: "a".to_owned(),
        };
        let ctx = make_ctx(2, false);
        let mut ss = state.as_scan_state();
        ss.init_rule(&rule, 1, 0, ctx);
        assert_eq!(ss.word_states[0].generation, ss.generation());
        assert_eq!(ss.word_states[0].remaining_and, 1);
        assert_eq!(ss.touched_indices(), &[0]);
        assert!(ss.matrix[0].is_empty());
        assert!(ss.matrix_status[0].is_empty());
        assert!(!ss.rule_is_satisfied(0));
    }

    /// `init_matrix` discards previous contents of reused buffers.
    #[test]
    fn test_init_matrix_reuses_buffers() {
        let mut matrix = vec![9; 10];
        let mut status = vec![7u8; 10];
        init_matrix(&mut matrix, &mut status, &[3, 0], 2);
        assert_eq!(matrix, [3, 3, 0, 0]);
        assert_eq!(status, [0, 0]);
    }
}