use crate::hir::CodepointClass;
use std::collections::BTreeSet;
use std::sync::Arc;
/// Identifier of a state inside [`Nfa::states`]; [`Nfa::add_state`] hands out ids
/// sequentially as the index at which the state was pushed.
pub type StateId = u32;
#[derive(Debug, Clone)]
pub struct Nfa {
pub states: Vec<NfaState>,
pub start: StateId,
pub matches: Vec<StateId>,
pub capture_count: u32,
pub has_backrefs: bool,
pub has_lookaround: bool,
pub epsilon_closures: Option<Vec<BTreeSet<StateId>>>,
}
impl Nfa {
pub fn new() -> Self {
Self {
states: Vec::new(),
start: 0,
matches: Vec::new(),
capture_count: 0,
has_backrefs: false,
has_lookaround: false,
epsilon_closures: None,
}
}
pub fn add_state(&mut self, state: NfaState) -> StateId {
let id = self.states.len() as StateId;
self.states.push(state);
id
}
pub fn state_count(&self) -> usize {
self.states.len()
}
pub fn get(&self, id: StateId) -> Option<&NfaState> {
self.states.get(id as usize)
}
pub fn get_mut(&mut self, id: StateId) -> Option<&mut NfaState> {
self.states.get_mut(id as usize)
}
pub fn epsilon_closure(&self, states: &BTreeSet<StateId>) -> BTreeSet<StateId> {
if let Some(ref precomputed) = self.epsilon_closures {
let mut closure = BTreeSet::new();
for &state_id in states {
if let Some(state_closure) = precomputed.get(state_id as usize) {
closure.extend(state_closure.iter().copied());
}
}
return closure;
}
let mut closure = states.clone();
let mut stack: Vec<StateId> = states.iter().copied().collect();
while let Some(state_id) = stack.pop() {
if let Some(state) = self.get(state_id) {
for &next in &state.epsilon {
if closure.insert(next) {
stack.push(next);
}
}
}
}
closure
}
pub fn precompute_epsilon_closures(&mut self) {
let epsilon_count: usize = self.states.iter().map(|s| s.epsilon.len()).sum();
if epsilon_count < 100 {
return;
}
let mut closures = Vec::with_capacity(self.states.len());
for state_id in 0..self.states.len() {
let mut closure = BTreeSet::new();
closure.insert(state_id as StateId);
let mut stack = vec![state_id as StateId];
while let Some(sid) = stack.pop() {
if let Some(state) = self.get(sid) {
for &next in &state.epsilon {
if closure.insert(next) {
stack.push(next);
}
}
}
}
closures.push(closure);
}
self.epsilon_closures = Some(closures);
}
}
impl Default for Nfa {
fn default() -> Self {
Self::new()
}
}
/// One state of an [`Nfa`]: its outgoing byte-range transitions, its epsilon edges, whether it
/// is accepting, and an optional [`NfaInstruction`] attached to it.
#[derive(Debug, Clone)]
pub struct NfaState {
    /// Outgoing edges; each entry pairs a byte range with the state it leads to.
    pub transitions: Vec<(ByteRange, StateId)>,
    /// Targets of epsilon edges leaving this state.
    pub epsilon: Vec<StateId>,
    /// `true` for states built with [`NfaState::match_state`].
    pub is_match: bool,
    /// Optional instruction attached to this state.
    pub instruction: Option<NfaInstruction>,
}

impl NfaState {
    /// A non-accepting state with no edges and no instruction.
    pub fn new() -> Self {
        NfaState {
            instruction: None,
            is_match: false,
            epsilon: Vec::new(),
            transitions: Vec::new(),
        }
    }

    /// An accepting state with no edges and no instruction.
    pub fn match_state() -> Self {
        NfaState {
            is_match: true,
            ..NfaState::new()
        }
    }

    /// Appends an edge on `range` leading to `target`.
    pub fn add_transition(&mut self, range: ByteRange, target: StateId) {
        let edge = (range, target);
        self.transitions.push(edge);
    }

    /// Appends an epsilon edge leading to `target`.
    pub fn add_epsilon(&mut self, target: StateId) {
        let edges = &mut self.epsilon;
        edges.push(target);
    }
}

impl Default for NfaState {
    fn default() -> Self {
        NfaState::new()
    }
}
/// An inclusive range of byte values, `start..=end`.
///
/// No ordering between `start` and `end` is enforced; an inverted range (`start > end`)
/// contains no bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ByteRange {
    /// First byte of the range (inclusive).
    pub start: u8,
    /// Last byte of the range (inclusive).
    pub end: u8,
}

impl ByteRange {
    /// The range `start..=end`.
    pub fn new(start: u8, end: u8) -> Self {
        ByteRange { start, end }
    }

    /// The range containing exactly `byte`.
    pub fn single(byte: u8) -> Self {
        Self::new(byte, byte)
    }

    /// The range covering every byte value, `0..=255`.
    pub fn any() -> Self {
        Self::new(u8::MIN, u8::MAX)
    }

    /// Whether `byte` lies within `start..=end`.
    pub fn contains(&self, byte: u8) -> bool {
        (self.start..=self.end).contains(&byte)
    }

    /// Whether the two ranges share an endpoint interval, i.e. neither lies entirely before
    /// the other.
    pub fn overlaps(&self, other: &ByteRange) -> bool {
        let self_before_other = self.end < other.start;
        let other_before_self = other.end < self.start;
        !(self_before_other || other_before_self)
    }
}
/// A set of bytes given as a list of [`ByteRange`]s, with a 256-bit membership bitmap
/// (four `u64` words) precomputed at construction so `contains` is a single mask test.
///
/// NOTE(review): `ranges` is public but the bitmap is only computed in [`ByteClass::new`];
/// mutating `ranges` afterwards does not change what `contains` reports.
#[derive(Debug, Clone)]
pub struct ByteClass {
    pub ranges: Vec<ByteRange>,
    bitmap: [u64; 4],
}

impl ByteClass {
    /// Builds a class from `ranges`, computing its membership bitmap.
    pub fn new(ranges: Vec<ByteRange>) -> Self {
        let bitmap = Self::compute_bitmap(&ranges);
        ByteClass { ranges, bitmap }
    }

    /// Like [`ByteClass::new`], copying the ranges out of a slice.
    pub fn from_slice(ranges: &[ByteRange]) -> Self {
        Self::new(Vec::from(ranges))
    }

    /// Word index (`byte / 64`) and single-bit mask (`1 << (byte % 64)`) for `byte`.
    #[inline(always)]
    fn slot(byte: u8) -> (usize, u64) {
        (usize::from(byte >> 6), 1u64 << (byte & 63))
    }

    /// Sets one bit per byte covered by any range; inverted ranges contribute nothing.
    fn compute_bitmap(ranges: &[ByteRange]) -> [u64; 4] {
        let mut words = [0u64; 4];
        for byte in ranges.iter().flat_map(|r| r.start..=r.end) {
            let (word, mask) = Self::slot(byte);
            words[word] |= mask;
        }
        words
    }

    /// Whether `byte` is a member of the class, answered from the bitmap.
    #[inline(always)]
    pub fn contains(&self, byte: u8) -> bool {
        let (word, mask) = Self::slot(byte);
        (self.bitmap[word] & mask) != 0
    }

    /// The raw membership bitmap; bit `b % 64` of word `b / 64` is set iff byte `b` is a member.
    #[inline]
    pub fn bitmap(&self) -> &[u64; 4] {
        &self.bitmap
    }
}
/// Extra behavior attached to an [`NfaState`] through its `instruction` field.
///
/// NOTE(review): this module only declares the variants; their runtime meaning comes from
/// whatever interprets them (not visible here). The per-variant notes below are inferred from
/// the variant names and should be confirmed against that interpreter.
#[derive(Debug, Clone)]
pub enum NfaInstruction {
    /// Presumably records the start of capture group `.0`.
    CaptureStart(u32),
    /// Presumably records the end of capture group `.0`.
    CaptureEnd(u32),
    /// Presumably a backreference to capture group `.0`.
    Backref(u32),
    /// Presumably a word-boundary assertion.
    WordBoundary,
    /// Presumably a negated word-boundary assertion.
    NotWordBoundary,
    /// Presumably asserts the current position is the start of the input.
    StartOfText,
    /// Presumably asserts the current position is the end of the input.
    EndOfText,
    /// Presumably asserts the current position is the start of a line.
    StartOfLine,
    /// Presumably asserts the current position is the end of a line.
    EndOfLine,
    /// Sub-automaton for a (presumably positive) lookahead. Stored behind an `Arc`, so cloning
    /// the instruction shares the sub-NFA instead of deep-copying it.
    PositiveLookahead(Arc<Nfa>),
    /// Sub-automaton for a (presumably negative) lookahead; shared via `Arc` on clone.
    NegativeLookahead(Arc<Nfa>),
    /// Sub-automaton for a (presumably positive) lookbehind; shared via `Arc` on clone.
    PositiveLookbehind(Arc<Nfa>),
    /// Sub-automaton for a (presumably negative) lookbehind; shared via `Arc` on clone.
    NegativeLookbehind(Arc<Nfa>),
    /// Presumably marks the exit point of a non-greedy (lazy) repetition.
    NonGreedyExit,
    /// A codepoint-class check. The `StateId` is presumably the state to continue at after a
    /// codepoint in the class is consumed — TODO confirm against the matcher.
    CodepointClass(CodepointClass, StateId),
}
/// Unit tests for the NFA data structures.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_byte_range() {
        let range = ByteRange::new(b'a', b'z');
        assert!(range.contains(b'm'));
        assert!(!range.contains(b'A'));
    }

    #[test]
    fn test_byte_range_overlaps() {
        // Shared endpoint counts as overlap; adjacent-but-disjoint ranges do not.
        assert!(ByteRange::new(1, 5).overlaps(&ByteRange::new(5, 9)));
        assert!(ByteRange::new(5, 9).overlaps(&ByteRange::new(1, 5)));
        assert!(!ByteRange::new(1, 4).overlaps(&ByteRange::new(5, 9)));
    }

    #[test]
    fn test_byte_class_contains() {
        // Includes the bitmap's extreme bytes 0 and 255.
        let class = ByteClass::from_slice(&[
            ByteRange::new(b'a', b'c'),
            ByteRange::single(0),
            ByteRange::single(255),
        ]);
        assert!(class.contains(b'a'));
        assert!(class.contains(b'c'));
        assert!(class.contains(0));
        assert!(class.contains(255));
        assert!(!class.contains(b'd'));
        assert!(!class.contains(254));
        // An inverted range contributes no members.
        assert_eq!(ByteClass::new(vec![ByteRange::new(9, 3)]).bitmap(), &[0u64; 4]);
    }

    #[test]
    fn test_epsilon_closure() {
        let mut nfa = Nfa::new();
        let mut s0 = NfaState::new();
        s0.add_epsilon(1);
        nfa.add_state(s0);
        let mut s1 = NfaState::new();
        s1.add_epsilon(2);
        nfa.add_state(s1);
        nfa.add_state(NfaState::new());
        let mut initial = BTreeSet::new();
        initial.insert(0);
        let closure = nfa.epsilon_closure(&initial);
        assert!(closure.contains(&0));
        assert!(closure.contains(&1));
        assert!(closure.contains(&2));
    }

    #[test]
    fn test_precomputed_closure_matches_on_the_fly() {
        // A 150-state epsilon chain has 149 epsilon edges, enough to trigger precomputation.
        let n: StateId = 150;
        let mut nfa = Nfa::new();
        for i in 0..n {
            let mut state = NfaState::new();
            if i + 1 < n {
                state.add_epsilon(i + 1);
            }
            nfa.add_state(state);
        }
        let initial: BTreeSet<StateId> = [140].iter().copied().collect();
        let expected = nfa.epsilon_closure(&initial);
        nfa.precompute_epsilon_closures();
        assert!(nfa.epsilon_closures.is_some());
        assert_eq!(nfa.epsilon_closure(&initial), expected);
        assert_eq!(expected.len(), 10);
    }
}