use std::cell::Cell;
use std::mem;
use bitvec::array::BitArray;
use super::instr::{Instr, InstrParser, Offset};
use crate::re::bitmapset::BitmapSet;
use crate::re::thompson::instr::SplitId;
use crate::re::{Action, CodeLoc, DEFAULT_SCAN_LIMIT, WideIter};
/// A virtual machine that executes encoded regexp instructions by advancing
/// a set of threads over the input, one step per item yielded by the input
/// iterator.
pub(crate) struct PikeVM<'r> {
/// Encoded instructions. Thread positions are byte offsets into this slice.
code: &'r [u8],
/// Threads tested against the current input byte. Iterating the set yields
/// pairs of instruction offset and repetition count.
threads: BitmapSet<u32>,
/// Threads collected for the following input byte. Swapped with `threads`
/// and then cleared at the end of every step.
next_threads: BitmapSet<u32>,
/// Upper bound on the number of steps; the matching loop stops and drops
/// all threads once the current position reaches this value.
scan_limit: u16,
/// Scratch state handed to every `epsilon_closure` call made by this VM.
cache: EpsilonClosureState,
}
impl<'r> PikeVM<'r> {
    /// Builds a VM over `code` with empty thread sets, fresh closure state
    /// and `DEFAULT_SCAN_LIMIT` as the scan limit.
    pub fn new(code: &'r [u8]) -> Self {
        let threads = BitmapSet::new();
        let next_threads = BitmapSet::new();
        let cache = EpsilonClosureState::new();
        Self {
            code,
            threads,
            next_threads,
            cache,
            scan_limit: DEFAULT_SCAN_LIMIT,
        }
    }

    /// Builder-style setter: returns the VM with its scan limit replaced by
    /// `limit`.
    #[allow(dead_code)]
    pub fn scan_limit(self, limit: u16) -> Self {
        let mut vm = self;
        vm.scan_limit = limit;
        vm
    }

    /// Runs the code at `start` and reports every match length to `f`.
    ///
    /// `right` is the data from the current position onwards and `left` is
    /// the data before it. When `start.backwards()` is true the VM consumes
    /// `left` in reverse and only peeks at `right`; otherwise it consumes
    /// `right` and only peeks at `left` (reversed).
    ///
    /// When `wide` is true both sides are wrapped in `WideIter`, the length
    /// passed to `f` is doubled, and a match is answered with `Action::Stop`
    /// (without calling `f`) if the iterator on the consumed side recorded
    /// an error position smaller than the match length.
    /// NOTE(review): presumably that position marks where the input stopped
    /// being valid wide text; confirm against `WideIter`.
    #[inline]
    pub(crate) fn try_match<C>(
        &mut self,
        start: C,
        right: &[u8],
        left: &[u8],
        wide: bool,
        mut f: impl FnMut(usize) -> Action,
    ) where
        C: CodeLoc,
    {
        let backwards = start.backwards();

        if wide {
            if backwards {
                // Wide, right-to-left.
                let consumed_err = Cell::new(None);
                let lookbehind_err = Cell::new(None);
                self.try_match_impl(
                    start,
                    WideIter::zero_first(left.iter().rev(), &consumed_err),
                    WideIter::non_zero_first(right.iter(), &lookbehind_err),
                    |len| {
                        if matches!(consumed_err.get(), Some(pos) if pos < len)
                        {
                            Action::Stop
                        } else {
                            f(len * 2)
                        }
                    },
                )
            } else {
                // Wide, left-to-right.
                let consumed_err = Cell::new(None);
                let lookbehind_err = Cell::new(None);
                self.try_match_impl(
                    start,
                    WideIter::non_zero_first(right.iter(), &consumed_err),
                    WideIter::zero_first(left.iter().rev(), &lookbehind_err),
                    |len| {
                        if matches!(consumed_err.get(), Some(pos) if pos < len)
                        {
                            Action::Stop
                        } else {
                            f(len * 2)
                        }
                    },
                )
            }
        } else if backwards {
            // Plain bytes, right-to-left.
            self.try_match_impl(start, left.iter().rev(), right.iter(), f)
        } else {
            // Plain bytes, left-to-right.
            self.try_match_impl(start, right.iter(), left.iter().rev(), f)
        }
    }

    /// Core matching loop shared by all four direction/width combinations.
    ///
    /// `fwd_input` yields the bytes that are consumed; `bck_input` is only
    /// asked for its first item, which serves as the byte preceding the
    /// starting position. `f` receives the number of items consumed so far
    /// each time a `Match` instruction is reached.
    fn try_match_impl<'a, C, F, B>(
        &mut self,
        start: C,
        mut fwd_input: F,
        mut bck_input: B,
        mut f: impl FnMut(usize) -> Action,
    ) where
        C: CodeLoc,
        F: Iterator<Item = &'a u8>,
        B: Iterator<Item = &'a u8>,
    {
        let mut consumed = 0;
        let mut current = fwd_input.next();

        // The previous run must have left no active threads behind.
        debug_assert!(self.threads.is_empty());

        let behind = bck_input.next();
        epsilon_closure(
            self.code,
            start,
            0,
            current,
            behind,
            &mut self.cache,
            &mut self.threads,
        );

        while !self.threads.is_empty() {
            let upcoming = fwd_input.next();

            for (ip, rep_count) in self.threads.iter() {
                // Assumes every offset stored in the thread set lies inside
                // `code` (unchecked slicing) -- TODO confirm at the producer.
                let (instr, instr_size) = InstrParser::decode_instr(unsafe {
                    self.code.get_unchecked(*ip..)
                });

                let accepted = match instr {
                    Instr::AnyByte => current.is_some(),
                    Instr::Byte(byte) => current.is_some_and(|b| *b == byte),
                    Instr::MaskedByte { byte, mask } => {
                        current.is_some_and(|b| (*b & mask) == byte)
                    }
                    Instr::CaseInsensitiveChar(byte) => {
                        current.is_some_and(|b| b.to_ascii_lowercase() == byte)
                    }
                    Instr::ClassBitmap(class) => {
                        current.is_some_and(|b| class.contains(*b))
                    }
                    Instr::ClassRanges(class) => {
                        current.is_some_and(|b| class.contains(*b))
                    }
                    Instr::Match => {
                        // `Stop` abandons the remaining threads of this
                        // step; threads already queued for the next step
                        // are kept.
                        if let Action::Stop = f(consumed) {
                            break;
                        }
                        false
                    }
                    // `epsilon_closure` only inserts the instructions
                    // handled above into a thread set.
                    _ => unreachable!(),
                };

                if accepted {
                    epsilon_closure(
                        self.code,
                        C::from(*ip + instr_size),
                        *rep_count,
                        upcoming,
                        current,
                        &mut self.cache,
                        &mut self.next_threads,
                    );
                }
            }

            current = upcoming;
            consumed += 1;
            mem::swap(&mut self.threads, &mut self.next_threads);
            self.next_threads.clear();

            if consumed >= self.scan_limit as usize {
                self.threads.clear();
                break;
            }
        }
    }
}
/// Scratch state used by `epsilon_closure`. It is owned by the caller and
/// passed to every call.
pub struct EpsilonClosureState {
/// Work stack of pending threads, as (instruction offset, repetition count)
/// pairs.
threads: Vec<(usize, u32)>,
/// One bit per split identifier; a set bit means that split was already
/// expanded since the bits were last cleared.
executed_splits:
BitArray<[usize; SplitId::MAX.div_ceil(usize::BITS as usize)]>,
/// Set by `epsilon_closure` on entry. While set, the next call to
/// `executed` clears `executed_splits` before consulting it.
dirty: bool,
}
impl EpsilonClosureState {
    /// Creates a state with an empty work stack, no split marked as executed
    /// and the dirty flag cleared.
    pub fn new() -> Self {
        let threads = Vec::new();
        let executed_splits = Default::default();
        Self { threads, executed_splits, dirty: false }
    }

    /// Reports whether `split_id` was already marked as executed, and marks
    /// it if it was not.
    ///
    /// If the dirty flag is set, all marks are cleared first and the flag is
    /// reset, so the first query after the flag was raised starts from a
    /// clean slate.
    #[inline(always)]
    pub fn executed(&mut self, split_id: SplitId) -> bool {
        if self.dirty {
            self.executed_splits.fill(false);
            self.dirty = false;
        }

        let index: usize = split_id.into();

        // The bit array holds at least `SplitId::MAX` bits; this assumes
        // every `SplitId` converts to an index below that -- TODO confirm.
        unsafe {
            let already_run = *self.executed_splits.get_unchecked(index);
            if !already_run {
                self.executed_splits.set_unchecked(index, true);
            }
            already_run
        }
    }
}
/// Expands `start` into every thread reachable without consuming input and
/// inserts those threads into `closure`.
///
/// Beginning at `start` with repetition counter `rep_count`, this follows
/// splits, jumps, repetition instructions and zero-width assertions. Each
/// byte-testing instruction reached this way, as well as `Match`, is
/// inserted into `closure` together with the repetition counter it was
/// reached with.
///
/// # Arguments
///
/// * `code` - encoded instructions; all offsets handled here index into it.
/// * `start` - where expansion begins; `start.backwards()` selects which
///   arms of the direction-sensitive assertions apply.
/// * `rep_count` - repetition counter carried by the starting thread.
/// * `curr_byte` - the next byte to be consumed, or `None` if the input is
///   exhausted in the matching direction.
/// * `prev_byte` - the byte on the other side of the current position, or
///   `None` if there is none.
/// * `state` - scratch state; its work stack is drained by this call and
///   its dirty flag is raised on entry.
/// * `closure` - receives the resulting threads.
///
/// When matching backwards the roles are mirrored with respect to the
/// text: `curr_byte` is the byte to the left of the position and
/// `prev_byte` the byte to its right. So `prev_byte == None` means "end of
/// text" and `curr_byte == None` means "start of text".
///
/// # Panics
///
/// Panics if applying an instruction offset yields a negative position.
#[inline(always)]
pub(crate) fn epsilon_closure<C: CodeLoc>(
    code: &[u8],
    start: C,
    rep_count: u32,
    curr_byte: Option<&u8>,
    prev_byte: Option<&u8>,
    state: &mut EpsilonClosureState,
    closure: &mut BitmapSet<u32>,
) {
    state.threads.push((start.location(), rep_count));
    // Makes the first `state.executed` call of this expansion forget the
    // splits recorded by the previous one.
    state.dirty = true;

    let is_word_char = |c: u8| c == b'_' || c.is_ascii_alphanumeric();

    let apply_offset = |ip: usize, offset: Offset| -> usize {
        (ip as isize).saturating_add(offset.into()).try_into().unwrap()
    };

    // `state.threads` is a LIFO stack: of two positions pushed by the same
    // instruction, the one pushed last is expanded first.
    while let Some((ip, mut rep_count)) = state.threads.pop() {
        // Assumes `ip` always lies inside `code` (unchecked slicing) --
        // TODO confirm against the code generator.
        let (instr, instr_size) =
            InstrParser::decode_instr(unsafe { code.get_unchecked(ip..) });

        match instr {
            // Instructions that need an input byte (or report a match) end
            // the expansion of this path.
            Instr::AnyByte
            | Instr::Byte(_)
            | Instr::MaskedByte { .. }
            | Instr::CaseInsensitiveChar(_)
            | Instr::ClassBitmap(_)
            | Instr::ClassRanges(_)
            | Instr::Match => {
                closure.insert(ip, rep_count);
            }
            // Each split is expanded at most once per closure.
            Instr::SplitA(id, offset) => {
                if !state.executed(id) {
                    // Fall-through path is expanded before the branch.
                    state.threads.push((apply_offset(ip, offset), rep_count));
                    state.threads.push((
                        apply_offset(ip, instr_size.into()),
                        rep_count,
                    ));
                }
            }
            Instr::SplitB(id, offset) => {
                if !state.executed(id) {
                    // Branch target is expanded before the fall-through.
                    state.threads.push((
                        apply_offset(ip, instr_size.into()),
                        rep_count,
                    ));
                    state.threads.push((apply_offset(ip, offset), rep_count));
                }
            }
            Instr::SplitN(split) => {
                if !state.executed(split.id()) {
                    // Pushed in reverse so the first offset is expanded
                    // first.
                    for offset in split.offsets().rev() {
                        state
                            .threads
                            .push((apply_offset(ip, offset), rep_count));
                    }
                }
            }
            Instr::RepeatGreedy { offset, min, max } => {
                rep_count += 1;
                // Leaving the loop resets the counter to zero.
                if rep_count >= min {
                    state
                        .threads
                        .push((apply_offset(ip, instr_size.into()), 0));
                }
                // Pushed last, so another iteration is tried first.
                if rep_count < max {
                    state.threads.push((apply_offset(ip, offset), rep_count));
                }
            }
            Instr::RepeatNonGreedy { offset, min, max } => {
                rep_count += 1;
                if rep_count < max {
                    state.threads.push((apply_offset(ip, offset), rep_count));
                }
                // Pushed last, so leaving the loop is tried first.
                if rep_count >= min {
                    state
                        .threads
                        .push((apply_offset(ip, instr_size.into()), 0));
                }
            }
            Instr::Jump(offset) => {
                state.threads.push((apply_offset(ip, offset), rep_count));
            }
            Instr::Start => {
                let is_match = match (start.backwards(), prev_byte, curr_byte)
                {
                    (false, None, _) => true,
                    (true, _, None) => true,
                    _ => false,
                };
                if is_match {
                    state.threads.push((
                        apply_offset(ip, instr_size.into()),
                        rep_count,
                    ));
                }
            }
            Instr::End => {
                let is_match = match (start.backwards(), prev_byte, curr_byte)
                {
                    (false, _, None) => true,
                    (true, None, _) => true,
                    _ => false,
                };
                if is_match {
                    state.threads.push((
                        apply_offset(ip, instr_size.into()),
                        rep_count,
                    ));
                }
            }
            Instr::LineStart => {
                // A position between the two bytes of a `\n\r` or `\r\n`
                // pair is not treated as a line start.
                let is_match = match (start.backwards(), prev_byte, curr_byte)
                {
                    (false, None, _) => true,
                    (false, Some(b'\n'), None) => true,
                    (false, Some(b'\r'), None) => true,
                    (false, Some(b'\n'), Some(curr_byte)) => {
                        *curr_byte != b'\r'
                    }
                    (false, Some(b'\r'), Some(curr_byte)) => {
                        *curr_byte != b'\n'
                    }
                    (true, _, None) => true,
                    (true, None, Some(b'\n')) => true,
                    (true, None, Some(b'\r')) => true,
                    (true, Some(prev_byte), Some(b'\n')) => {
                        *prev_byte != b'\r'
                    }
                    (true, Some(prev_byte), Some(b'\r')) => {
                        *prev_byte != b'\n'
                    }
                    _ => false,
                };
                if is_match {
                    state.threads.push((
                        apply_offset(ip, instr_size.into()),
                        rep_count,
                    ));
                }
            }
            Instr::LineEnd => {
                // A position between the two bytes of a `\n\r` or `\r\n`
                // pair is not treated as a line end.
                //
                // NOTE(review): the forward arms accept a line break right
                // at the start of the text (`prev_byte == None`), but the
                // backward arms have no counterpart for `curr_byte == None`
                // with a line break in `prev_byte`. Confirm whether that
                // asymmetry is intended.
                let is_match = match (start.backwards(), prev_byte, curr_byte)
                {
                    (false, _, None) => true,
                    (false, None, Some(b'\n')) => true,
                    // Carriage return, like every other arm here. This was
                    // previously `b'\t'`, which made a leading tab count as
                    // a line end and a leading `\r` not count.
                    (false, None, Some(b'\r')) => true,
                    (false, Some(prev_byte), Some(b'\n')) => {
                        *prev_byte != b'\r'
                    }
                    (false, Some(prev_byte), Some(b'\r')) => {
                        *prev_byte != b'\n'
                    }
                    (true, None, _) => true,
                    (true, Some(b'\n'), Some(curr_byte)) => {
                        *curr_byte != b'\r'
                    }
                    (true, Some(b'\r'), Some(curr_byte)) => {
                        *curr_byte != b'\n'
                    }
                    _ => false,
                };
                if is_match {
                    state.threads.push((
                        apply_offset(ip, instr_size.into()),
                        rep_count,
                    ));
                }
            }
            Instr::WordStart => {
                // In both patterns `p` is the byte to the left of the
                // position in the text and `c` the byte to its right.
                let is_match = match (start.backwards(), prev_byte, curr_byte)
                {
                    (false, Some(p), Some(c)) | (true, Some(c), Some(p)) => {
                        !is_word_char(*p) && is_word_char(*c)
                    }
                    // Start of text: only the byte to the right exists.
                    (false, None, Some(c)) | (true, Some(c), None) => {
                        is_word_char(*c)
                    }
                    _ => false,
                };
                if is_match {
                    state.threads.push((
                        apply_offset(ip, instr_size.into()),
                        rep_count,
                    ));
                }
            }
            Instr::WordEnd => {
                // In both patterns `p` is the byte to the left of the
                // position in the text and `c` the byte to its right.
                let is_match = match (start.backwards(), prev_byte, curr_byte)
                {
                    (false, Some(p), Some(c)) | (true, Some(c), Some(p)) => {
                        is_word_char(*p) && !is_word_char(*c)
                    }
                    // End of text: only the byte to the left exists. When
                    // matching backwards that byte is `curr_byte` and
                    // `prev_byte` is the missing one.
                    (false, Some(p), None) | (true, None, Some(p)) => {
                        is_word_char(*p)
                    }
                    _ => false,
                };
                if is_match {
                    state.threads.push((
                        apply_offset(ip, instr_size.into()),
                        rep_count,
                    ));
                }
            }
            Instr::WordBoundary | Instr::WordBoundaryNeg => {
                // Direction does not matter: a boundary is symmetric.
                let mut is_match = match (prev_byte, curr_byte) {
                    (Some(p), Some(c)) => is_word_char(*p) != is_word_char(*c),
                    (None, Some(b)) | (Some(b), None) => is_word_char(*b),
                    _ => false,
                };
                if matches!(instr, Instr::WordBoundaryNeg) {
                    is_match = !is_match;
                }
                if is_match {
                    state.threads.push((
                        apply_offset(ip, instr_size.into()),
                        rep_count,
                    ));
                }
            }
        }
    }
}