use crate::nfa::{Nfa, NfaInstruction, StateId};
use std::collections::VecDeque;
/// A small set of capture-group indices packed into one `u64`.
///
/// Bit `i` of the wrapped word is set exactly when index `i` is a member, so
/// only indices `0..64` can be stored. Mutating operations ignore larger
/// indices and membership queries report them as absent.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CaptureBitSet(pub u64);

impl CaptureBitSet {
    /// Returns the single-bit mask for `idx`, or `0` when `idx >= 64` and the
    /// index therefore has no bit in the backing word.
    #[inline]
    fn bit_mask(idx: u32) -> u64 {
        1u64.checked_shl(idx).unwrap_or(0)
    }

    /// Creates a set with no members.
    pub const fn empty() -> Self {
        CaptureBitSet(0)
    }

    /// Creates a set containing every index in `0..count`.
    ///
    /// Counts of 64 or more saturate to a completely filled set.
    pub fn all(count: usize) -> Self {
        match count {
            0 => Self(0),
            // Keep the low `count` bits by shifting the surplus high bits out.
            1..=63 => Self(u64::MAX >> (64 - count)),
            _ => Self(u64::MAX),
        }
    }

    /// Adds `idx` to the set; indices of 64 or more are ignored.
    #[inline]
    pub fn set(&mut self, idx: u32) {
        self.0 |= Self::bit_mask(idx);
    }

    /// Removes `idx` from the set; indices of 64 or more are ignored.
    #[inline]
    pub fn clear(&mut self, idx: u32) {
        self.0 &= !Self::bit_mask(idx);
    }

    /// Returns `true` when `idx` is a member. Always `false` for `idx >= 64`.
    #[inline]
    pub fn contains(&self, idx: u32) -> bool {
        (self.0 & Self::bit_mask(idx)) != 0
    }

    /// Returns `true` when the set has no members.
    #[inline]
    pub fn is_empty(&self) -> bool {
        *self == Self::empty()
    }

    /// Returns the number of members.
    #[inline]
    pub fn count(&self) -> u32 {
        let Self(bits) = *self;
        bits.count_ones()
    }

    /// Returns a new set holding the members of either operand.
    #[inline]
    pub fn union(&self, other: &Self) -> Self {
        let (Self(lhs), Self(rhs)) = (*self, *other);
        Self(lhs | rhs)
    }

    /// Returns a new set holding only the members common to both operands.
    #[inline]
    pub fn intersect(&self, other: &Self) -> Self {
        let (Self(lhs), Self(rhs)) = (*self, *other);
        Self(lhs & rhs)
    }

    /// Iterates over the members in ascending index order.
    pub fn iter(&self) -> impl Iterator<Item = u32> + '_ {
        let bits = self.0;
        (0..u64::BITS).filter(move |&idx| ((bits >> idx) & 1) == 1)
    }
}
/// Capture-liveness facts recorded for a single NFA state.
///
/// All three sets default to empty; `analyze_liveness` fills them in.
#[derive(Debug, Clone, Default)]
pub struct StateLiveness {
/// Capture indices treated as live at this state: back-references found on
/// the state itself, everything live at any successor, and (for match
/// states) every group index in `0..=capture_count`.
pub live_reads: CaptureBitSet,
/// Capture indices named by a `CaptureStart` or `CaptureEnd` instruction
/// attached to this state.
pub writes: CaptureBitSet,
/// Currently assigned as an exact copy of `live_reads` by `analyze_liveness`.
/// NOTE(review): presumably the capture slots a matcher must copy when it
/// forks at this state — confirm against the consumer.
pub copy_mask: CaptureBitSet,
}
#[derive(Debug, Clone)]
pub struct NfaLiveness {
pub states: Vec<StateLiveness>,
pub capture_count: u32,
pub lookaround_count: u32,
}
impl NfaLiveness {
pub fn copy_mask(&self, state: StateId) -> CaptureBitSet {
self.states
.get(state as usize)
.map(|s| s.copy_mask)
.unwrap_or_default()
}
pub fn needs_copy(&self, state: StateId) -> bool {
!self.copy_mask(state).is_empty()
}
}
/// Computes capture-liveness information for every state of `nfa`.
///
/// The analysis runs in four steps:
///
/// 1. **Local effects** – each state's instruction is inspected once:
///    `CaptureStart`/`CaptureEnd` mark a write, `Backref` marks a read, and
///    the four lookaround instructions are counted.
/// 2. **Backward read propagation** – every match state is seeded with all
///    group indices in `0..=capture_count`, then live reads flow from each
///    state to its predecessors (over both epsilon edges and consuming
///    transitions) until nothing changes. Writes do not remove reads, so the
///    result is an over-approximation.
/// 3. **Forward write propagation** – starting from `nfa.start`, the set of
///    captures written on some path before reaching each state is computed.
/// 4. **Copy mask** – each state's `copy_mask` is set to its `live_reads`.
///
/// # Panics
///
/// Panics if `nfa.start`, a match state id, or any epsilon/transition target
/// does not index into `nfa.states` (this includes an NFA with no states).
pub fn analyze_liveness(nfa: &Nfa) -> NfaLiveness {
    let state_count = nfa.states.len();
    let capture_count = nfa.capture_count;
    let mut states: Vec<StateLiveness> = vec![StateLiveness::default(); state_count];
    let mut lookaround_count = 0u32;

    // Step 1: record what each state reads or writes on its own.
    for (id, state) in nfa.states.iter().enumerate() {
        if let Some(ref instr) = state.instruction {
            match instr {
                NfaInstruction::CaptureStart(idx) | NfaInstruction::CaptureEnd(idx) => {
                    states[id].writes.set(*idx);
                }
                NfaInstruction::Backref(idx) => {
                    states[id].live_reads.set(*idx);
                }
                NfaInstruction::PositiveLookahead(_)
                | NfaInstruction::NegativeLookahead(_)
                | NfaInstruction::PositiveLookbehind(_)
                | NfaInstruction::NegativeLookbehind(_) => {
                    lookaround_count += 1;
                }
                _ => {}
            }
        }
    }

    let mut worklist: VecDeque<StateId> = VecDeque::new();
    // Mirrors worklist membership so a state is never queued twice at once.
    let mut in_worklist = vec![false; state_count];

    // Reverse edges: the backward pass walks from a state to its predecessors.
    let mut predecessors: Vec<Vec<StateId>> = vec![Vec::new(); state_count];
    for (id, state) in nfa.states.iter().enumerate() {
        let id = id as StateId;
        for &target in &state.epsilon {
            predecessors[target as usize].push(id);
        }
        for (_, target) in &state.transitions {
            predecessors[*target as usize].push(id);
        }
    }

    // Step 2 seeds: every state that already reads something (back-references).
    for (id, state_liveness) in states.iter().enumerate() {
        if !state_liveness.live_reads.is_empty() {
            worklist.push_back(id as StateId);
            in_worklist[id] = true;
        }
    }

    // Match states make every group index in `0..=capture_count` live.
    let all_captures = CaptureBitSet::all(capture_count as usize + 1);
    for &match_id in &nfa.matches {
        states[match_id as usize].live_reads = all_captures;
        if !in_worklist[match_id as usize] {
            worklist.push_back(match_id);
            in_worklist[match_id as usize] = true;
        }
    }

    // Step 2: push live reads backwards until a fixpoint is reached.
    while let Some(state_id) = worklist.pop_front() {
        in_worklist[state_id as usize] = false;
        let current_reads = states[state_id as usize].live_reads;
        for &pred_id in &predecessors[state_id as usize] {
            let pred_state = &mut states[pred_id as usize];
            let old_reads = pred_state.live_reads;
            let propagated = current_reads.union(&pred_state.live_reads);
            if propagated != old_reads {
                pred_state.live_reads = propagated;
                // Only a state whose set grew needs to be revisited.
                if !in_worklist[pred_id as usize] {
                    worklist.push_back(pred_id);
                    in_worklist[pred_id as usize] = true;
                }
            }
        }
    }

    // Step 3: forward pass collecting the captures written on some path from
    // the start state before each state is reached.
    // NOTE(review): the result is not consumed by step 4 yet; it looks like
    // groundwork for a tighter copy mask — confirm before removing.
    let mut writes_before: Vec<CaptureBitSet> = vec![CaptureBitSet::empty(); state_count];
    worklist.clear();
    in_worklist.fill(false);
    worklist.push_back(nfa.start);
    in_worklist[nfa.start as usize] = true;
    while let Some(state_id) = worklist.pop_front() {
        in_worklist[state_id as usize] = false;
        let state = &nfa.states[state_id as usize];
        // Everything written before this state plus what this state writes.
        let current_writes =
            writes_before[state_id as usize].union(&states[state_id as usize].writes);
        let mut propagate = |target: StateId| {
            let old = writes_before[target as usize];
            let new = old.union(&current_writes);
            if new != old {
                writes_before[target as usize] = new;
                if !in_worklist[target as usize] {
                    worklist.push_back(target);
                    in_worklist[target as usize] = true;
                }
            }
        };
        for &target in &state.epsilon {
            propagate(target);
        }
        for (_, target) in &state.transitions {
            propagate(*target);
        }
    }

    // Step 4: the copy mask is the full live-read set of each state.
    for state_liveness in states.iter_mut() {
        state_liveness.copy_mask = state_liveness.live_reads;
    }

    NfaLiveness {
        states,
        capture_count,
        lookaround_count,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::translate;
    use crate::nfa::compile;
    use crate::parser::parse;

    /// Parses, translates and compiles `pattern`, then runs the liveness
    /// analysis on the resulting NFA. Any pipeline failure panics.
    fn analyze_pattern(pattern: &str) -> NfaLiveness {
        let parsed = parse(pattern).unwrap();
        let lowered = translate(&parsed).unwrap();
        let compiled = compile(&lowered).unwrap();
        analyze_liveness(&compiled)
    }

    #[test]
    fn test_bitset_operations() {
        // Sparse set built one bit at a time.
        let mut sparse = CaptureBitSet::empty();
        assert!(sparse.is_empty());
        for idx in [0, 3, 7] {
            sparse.set(idx);
        }
        for idx in [0, 3, 7] {
            assert!(sparse.contains(idx));
        }
        assert!(!sparse.contains(1));
        assert_eq!(sparse.count(), 3);

        // Dense prefix set covering indices 0..4.
        let prefix = CaptureBitSet::all(4);
        assert!(prefix.contains(0));
        assert!(prefix.contains(3));
        assert!(!prefix.contains(4));

        let merged = sparse.union(&prefix);
        for idx in [0, 3, 7] {
            assert!(merged.contains(idx));
        }

        let common = sparse.intersect(&prefix);
        assert!(common.contains(0));
        assert!(common.contains(3));
        assert!(!common.contains(7));
    }

    #[test]
    fn test_simple_capture() {
        let liveness = analyze_pattern(r"(a)");
        assert_eq!(liveness.capture_count, 1);

        // Evaluated but deliberately not asserted on.
        let _capture_1_is_live = liveness
            .states
            .iter()
            .any(|state| state.live_reads.contains(1));
        let capture_1_is_written = liveness
            .states
            .iter()
            .any(|state| state.writes.contains(1));
        assert!(capture_1_is_written, "Should have capture 1 write");
    }

    #[test]
    fn test_alternation_no_captures() {
        let liveness = analyze_pattern(r"a|b");
        assert_eq!(liveness.capture_count, 0);
        for state in liveness.states.iter() {
            assert!(
                !state.copy_mask.contains(1),
                "No explicit captures means group 1+ should not be in copy_mask"
            );
        }
    }

    #[test]
    fn test_alternation_with_captures() {
        let liveness = analyze_pattern(r"(a)|(b)");
        assert_eq!(liveness.capture_count, 2);
    }

    #[test]
    fn test_backref_makes_capture_live() {
        let liveness = analyze_pattern(r"(a)\1");
        let capture_1_is_live = liveness
            .states
            .iter()
            .any(|state| state.live_reads.contains(1));
        assert!(capture_1_is_live, "Backref should make capture 1 live");
    }

    #[test]
    fn test_nested_captures() {
        let liveness = analyze_pattern(r"((a)(b))");
        assert_eq!(liveness.capture_count, 3);
    }

    #[test]
    fn test_lookaround_count() {
        let single = analyze_pattern(r"a(?=b)");
        assert_eq!(single.lookaround_count, 1);

        let double = analyze_pattern(r"(?=a)b(?!c)");
        assert_eq!(double.lookaround_count, 2);
    }
}