use std::cell::RefCell;
use std::sync::Arc;
use rustc_hash::FxHashMap;
use smallvec::{SmallVec, smallvec};
use crate::automaton::{
BYTE_CEILING, FieldMatcher,
arena::{ARENA_VALUE_TERMINATOR, SmallTable, StateArena, StateId},
};
use super::parser::{
Branch as RegexpBranch, QuantifiedAtom, REGEXP_QUANTIFIER_MAX, Root as RegexpRoot, RuneRange,
};
// Largest code points encodable in 1, 2 and 3 UTF-8 bytes. They are used as
// segment boundaries so that every sub-range handed to the byte-tree builder
// encodes to a single byte length.
const UTF8_1BYTE_MAX: u32 = 0x7F;
const UTF8_2BYTE_MAX: u32 = 0x7FF;
const UTF8_3BYTE_MAX: u32 = 0xFFFF;
// Inclusive bounds of the UTF-16 surrogate block, which holds no valid `char`
// and is therefore cut out of rune ranges before they are encoded.
const SURROGATE_START: u32 = 0xD800;
const SURROGATE_END: u32 = 0xDFFF;
/// Returns the UTF-8 encoding of `r` as an owned byte vector (one to four bytes).
fn rune_to_utf8(r: char) -> Vec<u8> {
    let mut scratch = [0u8; 4];
    // `encode_utf8` writes into the front of the buffer and reports how much it used.
    let width = r.encode_utf8(&mut scratch).len();
    scratch[..width].to_vec()
}
#[must_use]
pub fn regexp_has_plus_star(root: &RegexpRoot) -> bool {
for branch in root {
for qa in branch {
if qa.is_plus() || qa.is_star() {
return true;
}
if let Some(ref subtree) = qa.subtree
&& regexp_has_plus_star(subtree)
{
return true;
}
}
}
false
}
/// Builds an arena-backed NFA for a parsed regular expression.
///
/// The automaton consumes an opening `"` byte, then the pattern body, then a
/// closing `"` byte followed by `ARENA_VALUE_TERMINATOR`, and ends in a match
/// state whose `field_transitions` holds the returned [`FieldMatcher`].
/// An empty `root` yields an automaton whose body is empty (opening quote
/// leads straight to the closing quote).
///
/// Returns `(arena, start_state, next_field)`. Epsilon closures are
/// precomputed on the arena before it is returned.
#[must_use]
pub fn make_regexp_nfa_arena(root: RegexpRoot) -> (StateArena, StateId, Arc<FieldMatcher>) {
    let next_field = Arc::new(FieldMatcher::new());
    let has_body = !root.is_empty();

    // The empty pattern only ever needs the four fixed states.
    let mut arena = StateArena::with_capacity(if has_body { 16 } else { 4 });

    // Tail of every automaton: closing quote -> value terminator -> match.
    let match_state = arena.alloc();
    arena[match_state]
        .field_transitions
        .push(Arc::clone(&next_field));
    let terminator_state = arena.alloc_with_table(SmallTable::with_mappings(
        StateId::NONE,
        &[ARENA_VALUE_TERMINATOR],
        &[match_state],
    ));
    let closing_quote = arena.alloc_with_table(SmallTable::with_mappings(
        StateId::NONE,
        b"\"",
        &[terminator_state],
    ));

    // The body is wired in front of the closing quote; with no body the
    // opening quote leads directly to it.
    let body_start = if has_body {
        make_arena_nfa_from_branches(&root, &mut arena, closing_quote)
    } else {
        closing_quote
    };

    let start = arena.alloc_with_table(SmallTable::with_mappings(
        StateId::NONE,
        b"\"",
        &[body_start],
    ));

    arena.precompute_epsilon_closures();
    (arena, start, next_field)
}
/// Builds the automaton for an alternation of branches, all continuing to
/// `next_step`, and returns its entry state.
///
/// * No branches: `next_step` itself is the entry.
/// * One branch: the branch automaton is returned directly.
/// * Several branches: a fresh state is allocated whose epsilon transitions
///   fan out to each branch start (an empty branch contributes `next_step`).
fn make_arena_nfa_from_branches(
    root: &RegexpRoot,
    arena: &mut StateArena,
    next_step: StateId,
) -> StateId {
    match root.len() {
        0 => next_step,
        1 => make_one_arena_branch_fa(&root[0], arena, next_step),
        branch_count => {
            let mut alternatives = Vec::with_capacity(branch_count);
            for branch in root {
                let entry = if branch.is_empty() {
                    next_step
                } else {
                    make_one_arena_branch_fa(branch, arena, next_step)
                };
                alternatives.push(entry);
            }
            // Allocated after the branches so state numbering is unchanged.
            let fork = arena.alloc();
            arena[fork].table.epsilons = SmallVec::from_vec(alternatives);
            fork
        }
    }
}
fn make_one_arena_branch_fa(
branch: &RegexpBranch,
arena: &mut StateArena,
next_step: StateId,
) -> StateId {
let mut current_next = next_step;
for qa in branch.iter().rev() {
let original_next = current_next;
if qa.is_plus() || qa.is_star() {
current_next = create_arena_plus_star_loop(qa, arena, original_next, qa.is_star());
} else if qa.is_qm() {
let atom_state = make_arena_atom_fa(qa, arena, current_next);
arena[atom_state].table.epsilons.push(original_next);
current_next = atom_state;
} else if qa.is_singleton() {
current_next = make_arena_atom_fa(qa, arena, current_next);
} else {
let n = usize::try_from(qa.quant_min).expect("quant_min must be non-negative");
let m = usize::try_from(qa.quant_max).expect("quant_max must be non-negative");
debug_assert!(
qa.quant_max <= REGEXP_QUANTIFIER_MAX,
"parser must bound quantifier repetition counts"
);
for _ in n..m {
let atom_state = make_arena_atom_fa(qa, arena, current_next);
arena[atom_state].table.epsilons.push(current_next);
current_next = atom_state;
}
for _ in 0..n {
current_next = make_arena_atom_fa(qa, arena, current_next);
}
}
}
current_next
}
/// Builds the loop for a `+` (one or more) or `*` (zero or more) quantified
/// atom and returns its entry state.
///
/// A `loopback` state is allocated whose epsilon transitions lead to
/// `exit_state` and back to the atom's start; the atom itself continues to
/// `loopback`. For `+` the entry is the atom's start. For `*` the loop must
/// also be skippable:
///
/// * dot and rune-range atoms have a freshly allocated start state, so an
///   epsilon to `exit_state` is added to it and it remains the entry;
/// * a parenthesised group may hand back a start state that is also re-entered
///   from inside the group (e.g. the re-entry point of a leading `+`/`*`
///   loop). Adding the exit epsilon there would allow leaving the group
///   part-way through (`(a+b)*` would accept `a`), so `loopback` — which
///   already offers exactly "exit or run the atom" — is used as the entry.
///
/// When the atom carries `ascii_negated_bytes`, an `AccelInfo` holding those
/// bytes is attached to both the atom's start state and `loopback`.
///
/// # Panics
///
/// Panics if `ascii_negated_bytes` holds more entries than `exit_bytes` can store.
fn create_arena_plus_star_loop(
    qa: &QuantifiedAtom,
    arena: &mut StateArena,
    exit_state: StateId,
    is_star: bool,
) -> StateId {
    let loopback = arena.alloc();
    let start = make_arena_atom_fa(qa, arena, loopback);
    arena[loopback].table.epsilons = smallvec![exit_state, start];

    // Mirrors the dispatch in `make_arena_atom_fa`: only the subtree path can
    // return a state that already has other incoming transitions.
    let start_may_be_shared = !qa.is_dot && qa.subtree.is_some();
    let entry = if !is_star {
        start
    } else if start_may_be_shared {
        loopback
    } else {
        arena[start].table.epsilons.push(exit_state);
        start
    };

    let accel = qa.ascii_negated_bytes.as_ref().map(|bytes| {
        let mut accel = crate::automaton::AccelInfo {
            exit_bytes: [0; 3],
            len: u8::try_from(bytes.len()).expect("ascii_negated_bytes is bounded to 3"),
        };
        for (i, &b) in bytes.iter().enumerate() {
            accel.exit_bytes[i] = b;
        }
        accel
    });
    if let Some(accel) = accel {
        arena[start].table.accel = Some(accel.clone());
        arena[loopback].table.accel = Some(accel);
    }
    entry
}
/// Builds the automaton for a single (unquantified) copy of `qa`'s atom that
/// continues to `next`, and returns its entry state.
///
/// Dispatch order: dot, then parenthesised subtree, then cache-keyed rune
/// range (with the `"wb_W"` key routed to the non-word-character automaton),
/// and finally an uncached rune range.
fn make_arena_atom_fa(qa: &QuantifiedAtom, arena: &mut StateArena, next: StateId) -> StateId {
    if qa.is_dot {
        return make_arena_dot_fa(arena, next);
    }
    if let Some(subtree) = qa.subtree.as_ref() {
        return make_arena_nfa_from_branches(subtree, arena, next);
    }
    match qa.cache_key.as_ref() {
        None => make_arena_rune_range_fa(&qa.runes, arena, next),
        Some(cache_key) if cache_key == "wb_W" => make_nonword_char_fa(arena, next),
        Some(cache_key) => make_cached_rune_range_fa(cache_key, &qa.runes, arena, next),
    }
}
#[allow(clippy::type_complexity)]
#[allow(clippy::similar_names)]
fn make_utf8_char_fa(
arena: &mut StateArena,
dest: StateId,
ascii_filter: Option<&dyn Fn(&mut [StateId; BYTE_CEILING])>,
) -> StateId {
let s_last = arena.alloc_with_table({
let mut table = SmallTable::new();
let mut unpacked = [StateId::NONE; BYTE_CEILING];
unpacked[0x80..0xC0].fill(dest);
table.pack(&unpacked);
table
});
let s_last_inter = arena.alloc_with_table({
let mut table = SmallTable::new();
let mut unpacked = [StateId::NONE; BYTE_CEILING];
unpacked[0x80..0xC0].fill(s_last);
table.pack(&unpacked);
table
});
let s_first_inter = arena.alloc_with_table({
let mut table = SmallTable::new();
let mut unpacked = [StateId::NONE; BYTE_CEILING];
unpacked[0x80..0xC0].fill(s_last_inter);
table.pack(&unpacked);
table
});
let target_e0 = arena.alloc_with_table({
let mut table = SmallTable::new();
let mut unpacked = [StateId::NONE; BYTE_CEILING];
unpacked[0xA0..0xC0].fill(s_last);
table.pack(&unpacked);
table
});
let target_ed = arena.alloc_with_table({
let mut table = SmallTable::new();
let mut unpacked = [StateId::NONE; BYTE_CEILING];
unpacked[0x80..0xA0].fill(s_last);
table.pack(&unpacked);
table
});
let target_f0 = arena.alloc_with_table({
let mut table = SmallTable::new();
let mut unpacked = [StateId::NONE; BYTE_CEILING];
unpacked[0x90..0xC0].fill(s_last_inter);
table.pack(&unpacked);
table
});
let target_f4 = arena.alloc_with_table({
let mut table = SmallTable::new();
let mut unpacked = [StateId::NONE; BYTE_CEILING];
unpacked[0x80..0x90].fill(s_last_inter);
table.pack(&unpacked);
table
});
arena.alloc_with_table({
let mut unpacked = [StateId::NONE; BYTE_CEILING];
unpacked[..0x80].fill(dest);
if let Some(filter) = ascii_filter {
filter(&mut unpacked);
}
unpacked[0xC2..0xE0].fill(s_last);
unpacked[0xE0] = target_e0;
unpacked[0xE1..0xED].fill(s_last_inter);
unpacked[0xED] = target_ed;
unpacked[0xEE..0xF0].fill(s_last_inter);
unpacked[0xF0] = target_f0;
unpacked[0xF1..0xF4].fill(s_first_inter);
unpacked[0xF4] = target_f4;
let mut table = SmallTable::new();
table.pack(&unpacked);
table
})
}
/// Builds the automaton for the `.` atom: one UTF-8 character with no ASCII
/// bytes filtered out, continuing to `dest`. Returns the entry state.
fn make_arena_dot_fa(arena: &mut StateArena, dest: StateId) -> StateId {
make_utf8_char_fa(arena, dest, None)
}
/// Builds an automaton that consumes one UTF-8 character that is not an ASCII
/// word character (`a-z`, `A-Z`, `0-9`, `_`) and continues to `dest`.
///
/// Every non-ASCII character is accepted. Returns the entry state.
pub fn make_nonword_char_fa(arena: &mut StateArena, dest: StateId) -> StateId {
    make_utf8_char_fa(
        arena,
        dest,
        Some(&|unpacked| {
            // Inclusive ASCII spans that make up the word-character class.
            let word_spans = [(b'a', b'z'), (b'A', b'Z'), (b'0', b'9'), (b'_', b'_')];
            for (first, last) in word_spans {
                unpacked[usize::from(first)..=usize::from(last)].fill(StateId::NONE);
            }
        }),
    )
}
/// A reusable snapshot of a rune-range automaton, built once in a temporary
/// arena and later copied into real arenas with its state ids remapped.
struct CachedShell {
// Transition table of every state of the temporary arena, in state-index
// order. Index 0 is the placeholder that stands in for the continuation state.
tables: Vec<SmallTable>,
// Index (into `tables`) of the automaton's entry state.
root: u32,
}
// Per-thread cache of prebuilt rune-range shells, keyed by the atom's cache key.
thread_local! {
static FA_SHELL_CACHE: RefCell<FxHashMap<String, CachedShell>> = RefCell::new(FxHashMap::default());
}
fn build_shell(rr: &RuneRange) -> CachedShell {
let mut temp_arena = StateArena::with_capacity(16);
let placeholder = temp_arena.alloc();
let root_id = make_arena_rune_range_fa(rr, &mut temp_arena, placeholder);
let mut tables = Vec::with_capacity(temp_arena.len());
for i in 0..temp_arena.len() {
let id = StateId::from_index(i);
tables.push(temp_arena[id].table.clone());
}
CachedShell {
tables,
root: u32::try_from(root_id.index())
.expect("shell arenas have far fewer than u32::MAX states"),
}
}
/// Copies a cached shell into `arena`, wiring its placeholder continuation to
/// `next`, and returns the entry state of the copy.
///
/// Shell state 0 (the placeholder) maps to `next`; every other shell state
/// gets a freshly allocated arena state. Each copied table then has its step
/// and epsilon targets translated from shell indices to the new state ids.
fn instantiate_shell(shell: &CachedShell, arena: &mut StateArena, next: StateId) -> StateId {
    let state_count = shell.tables.len();

    // Shell index -> real state id. All states are allocated up front so that
    // forward references can be resolved while copying tables.
    let mut id_map: Vec<StateId> = Vec::with_capacity(state_count);
    id_map.push(next);
    id_map.extend((1..state_count).map(|_| arena.alloc()));

    // Skip index 0: the placeholder's table is never copied.
    for (template, &real_id) in shell.tables.iter().zip(&id_map).skip(1) {
        let mut table = template.clone();
        for step in &mut table.steps {
            if !step.is_none() {
                *step = id_map[step.index()];
            }
        }
        for eps in &mut table.epsilons {
            debug_assert!(
                !eps.is_none(),
                "cached shell epsilons must not contain NONE entries"
            );
            *eps = id_map[eps.index()];
        }
        arena[real_id].table = table;
    }

    id_map[shell.root as usize]
}
/// Builds the rune-range automaton for `rr` in `arena`, continuing to `next`,
/// using the per-thread shell cache.
///
/// On a cache hit for `cache_key` the stored shell is instantiated; on a miss
/// a shell is built from `rr`, instantiated, and then stored under
/// `cache_key`. Returns the entry state.
fn make_cached_rune_range_fa(
    cache_key: &str,
    rr: &RuneRange,
    arena: &mut StateArena,
    next: StateId,
) -> StateId {
    FA_SHELL_CACHE.with(|cell| {
        let mut shells = cell.borrow_mut();
        match shells.get(cache_key) {
            Some(shell) => instantiate_shell(shell, arena, next),
            None => {
                let shell = build_shell(rr);
                let entry = instantiate_shell(&shell, arena, next);
                shells.insert(cache_key.to_owned(), shell);
                entry
            }
        }
    })
}
/// Drops every shell stored in the calling thread's rune-range shell cache.
pub fn clear_fa_shell_cache() {
    FA_SHELL_CACHE.with_borrow_mut(|shells| shells.clear());
}
/// Number of shells currently stored in the calling thread's cache (test-only).
#[cfg(test)]
pub fn fa_shell_cache_size() -> usize {
    FA_SHELL_CACHE.with_borrow(|shells| shells.len())
}
/// One populated byte slot in a rune tree node.
struct ArenaRuneTreeEntry {
// Destination state when this byte completes a character. When both fields
// are set, table construction uses `next` and ignores `child`.
next: Option<StateId>,
// Subtree describing the bytes that may follow this one.
child: Option<ArenaRuneTreeNode>,
}
/// A rune tree node: one optional entry per byte value, indexed by the byte.
type ArenaRuneTreeNode = Vec<Option<ArenaRuneTreeEntry>>;
/// Creates a rune tree node with `BYTE_CEILING` slots, all empty.
fn new_arena_rune_tree_node() -> ArenaRuneTreeNode {
    let mut node = Vec::with_capacity(BYTE_CEILING);
    node.resize_with(BYTE_CEILING, || None);
    node
}
/// Converts a complete rune tree into arena states and returns the state built
/// for the tree's root node.
fn arena_nfa_from_rune_tree(arena: &mut StateArena, root: &ArenaRuneTreeNode) -> StateId {
arena_table_from_rune_tree_node(arena, root)
}
/// Allocates a state whose table mirrors `node` and returns its id.
///
/// For each populated slot, the byte transitions to the entry's `next` state
/// if it has one; otherwise to a state built recursively from the entry's
/// `child`. Child states are allocated before the state for `node` itself.
fn arena_table_from_rune_tree_node(arena: &mut StateArena, node: &ArenaRuneTreeNode) -> StateId {
    let mut transitions = [StateId::NONE; BYTE_CEILING];
    for (byte, slot) in node.iter().enumerate() {
        let Some(entry) = slot else {
            continue;
        };
        let target = match (entry.next, entry.child.as_ref()) {
            (Some(next), _) => next,
            (None, Some(child)) => arena_table_from_rune_tree_node(arena, child),
            (None, None) => continue,
        };
        transitions[byte] = target;
    }

    let mut table = SmallTable::new();
    table.pack(&transitions);
    arena.alloc_with_table(table)
}
/// Builds an automaton that consumes one character falling in any of the
/// `lo..=hi` pairs of `rr` and continues to `next`; returns its entry state.
///
/// The pairs are first merged into a byte-level tree, which is then turned
/// into arena states.
fn make_arena_rune_range_fa(rr: &RuneRange, arena: &mut StateArena, next: StateId) -> StateId {
    let mut tree = new_arena_rune_tree_node();
    for span in rr {
        let (first, last) = (span.lo, span.hi);
        add_arena_rune_pair_tree_entry(&mut tree, first, last, next);
    }
    arena_nfa_from_rune_tree(arena, &tree)
}
/// Adds the inclusive code point range `lo..=hi` to the rune tree `root`, with
/// every matching character leading to `dest`.
///
/// The range is cut at the 1-, 2- and 3-byte UTF-8 boundaries so that each
/// piece encodes to a single byte length, and the surrogate block
/// (`SURROGATE_START..=SURROGATE_END`) is removed from any piece that overlaps
/// it. A range with `lo > hi` adds nothing.
fn add_arena_rune_pair_tree_entry(root: &mut ArenaRuneTreeNode, lo: char, hi: char, dest: StateId) {
    let last = u32::from(hi);
    let mut cursor = u32::from(lo);

    // Adds one piece; silently skipped if either end is not a valid `char`.
    let mut emit = |from: u32, to: u32| {
        if let (Some(start), Some(end)) = (char::from_u32(from), char::from_u32(to)) {
            add_arena_utf8_range_to_tree(root, start, end, dest);
        }
    };

    for boundary in [UTF8_1BYTE_MAX, UTF8_2BYTE_MAX, UTF8_3BYTE_MAX, u32::MAX] {
        if cursor > last {
            break;
        }
        if boundary < cursor {
            // This byte-length class lies entirely below the remaining range.
            continue;
        }
        let segment_end = last.min(boundary);
        if intersects_surrogate(cursor, segment_end) {
            if before_surrogate(cursor) {
                emit(cursor, (SURROGATE_START - 1).min(segment_end));
            }
            if after_surrogate(segment_end) {
                emit((SURROGATE_END + 1).max(cursor), segment_end);
            }
        } else {
            emit(cursor, segment_end);
        }
        cursor = segment_end + 1;
    }
}
/// Adds the range `lo..=hi` to the rune tree, working on the UTF-8 encodings
/// of the two endpoints. Both endpoints are expected to encode to the same
/// number of bytes (checked in debug builds).
fn add_arena_utf8_range_to_tree(root: &mut ArenaRuneTreeNode, lo: char, hi: char, dest: StateId) {
    let (first, last) = (rune_to_utf8(lo), rune_to_utf8(hi));
    debug_assert_eq!(first.len(), last.len());
    add_arena_byte_range_recursive(root, first.as_slice(), last.as_slice(), 0, dest);
}
fn add_arena_byte_range_recursive(
node: &mut ArenaRuneTreeNode,
lo_bytes: &[u8],
hi_bytes: &[u8],
idx: usize,
dest: StateId,
) {
if idx >= lo_bytes.len() {
return;
}
let lo_byte = lo_bytes[idx];
let hi_byte = hi_bytes[idx];
let is_last = idx == lo_bytes.len() - 1;
if lo_byte == hi_byte {
ensure_arena_tree_entry(node, lo_byte);
let entry = node[lo_byte as usize].as_mut().unwrap();
if is_last {
entry.next = Some(dest);
} else {
if entry.child.is_none() {
entry.child = Some(new_arena_rune_tree_node());
}
add_arena_byte_range_recursive(
entry.child.as_mut().unwrap(),
lo_bytes,
hi_bytes,
idx + 1,
dest,
);
}
} else {
add_arena_lo_range_to_tree(node, lo_bytes, idx, dest);
add_arena_middle_range_to_tree(
node,
lo_byte + 1,
hi_byte - 1,
remaining_byte_depth(lo_bytes.len(), idx, 1),
dest,
);
add_arena_hi_range_to_tree(node, hi_bytes, idx, dest);
}
}
/// Number of bytes of a `byte_len`-byte sequence that remain after position
/// `idx + offset - 1`, i.e. `byte_len - (idx + offset)`.
const fn remaining_byte_depth(byte_len: usize, idx: usize, offset: usize) -> usize {
    let consumed = idx + offset;
    byte_len - consumed
}
/// Returns `true` when the inclusive range `current..=segment_end` overlaps
/// the surrogate block `SURROGATE_START..=SURROGATE_END`.
const fn intersects_surrogate(current: u32, segment_end: u32) -> bool {
    let starts_after_block = current > SURROGATE_END;
    let ends_before_block = segment_end < SURROGATE_START;
    !(starts_after_block || ends_before_block)
}
/// Returns `true` when `current` lies strictly below the surrogate block.
const fn before_surrogate(current: u32) -> bool {
    SURROGATE_START > current
}
/// Returns `true` when `segment_end` lies strictly above the surrogate block.
const fn after_surrogate(segment_end: u32) -> bool {
    SURROGATE_END < segment_end
}
fn add_arena_lo_range_to_tree(
node: &mut ArenaRuneTreeNode,
lo_bytes: &[u8],
idx: usize,
dest: StateId,
) {
let lo_byte = lo_bytes[idx];
let is_last = idx == lo_bytes.len() - 1;
ensure_arena_tree_entry(node, lo_byte);
let entry = node[lo_byte as usize].as_mut().unwrap();
if is_last {
entry.next = Some(dest);
} else {
if entry.child.is_none() {
entry.child = Some(new_arena_rune_tree_node());
}
let child = entry.child.as_mut().unwrap();
let next_byte = lo_bytes[idx + 1];
add_arena_lo_range_to_tree(child, lo_bytes, idx + 1, dest);
add_arena_middle_range_to_tree(
child,
next_byte.wrapping_add(1),
0xBF,
remaining_byte_depth(lo_bytes.len(), idx, 2),
dest,
);
}
}
fn add_arena_hi_range_to_tree(
node: &mut ArenaRuneTreeNode,
hi_bytes: &[u8],
idx: usize,
dest: StateId,
) {
let hi_byte = hi_bytes[idx];
let is_last = idx == hi_bytes.len() - 1;
ensure_arena_tree_entry(node, hi_byte);
let entry = node[hi_byte as usize].as_mut().unwrap();
if is_last {
entry.next = Some(dest);
} else {
if entry.child.is_none() {
entry.child = Some(new_arena_rune_tree_node());
}
let child = entry.child.as_mut().unwrap();
let next_byte = hi_bytes[idx + 1];
add_arena_middle_range_to_tree(
child,
0x80,
next_byte.wrapping_sub(1),
remaining_byte_depth(hi_bytes.len(), idx, 2),
dest,
);
add_arena_hi_range_to_tree(child, hi_bytes, idx + 1, dest);
}
}
fn add_arena_middle_range_to_tree(
node: &mut ArenaRuneTreeNode,
lo: u8,
hi: u8,
depth: usize,
dest: StateId,
) {
debug_assert!(
depth <= 3,
"rune-range middle depth bounded by UTF-8 byte length, got {depth}"
);
if depth == 0 {
for byte in lo..=hi {
ensure_arena_tree_entry(node, byte);
node[byte as usize].as_mut().unwrap().next = Some(dest);
}
} else {
for byte in lo..=hi {
ensure_arena_tree_entry(node, byte);
let entry = node[byte as usize].as_mut().unwrap();
if entry.child.is_none() {
entry.child = Some(new_arena_rune_tree_node());
}
add_arena_middle_range_to_tree(
entry.child.as_mut().unwrap(),
0x80,
0xBF,
depth - 1,
dest,
);
}
}
}
/// Makes sure the slot for `byte` in `node` is populated, creating an entry
/// with neither `next` nor `child` set if the slot was empty.
fn ensure_arena_tree_entry(node: &mut ArenaRuneTreeNode, byte: u8) {
    let slot = &mut node[usize::from(byte)];
    let _ = slot.get_or_insert_with(|| ArenaRuneTreeEntry {
        next: None,
        child: None,
    });
}
// Unit tests for the pure helper functions and for the shell cache lifecycle.
#[cfg(test)]
mod tests {
use super::*;
// `remaining_byte_depth(len, idx, offset)` must equal `len - (idx + offset)`.
#[test]
fn test_remaining_byte_depth() {
assert_eq!(remaining_byte_depth(4, 0, 1), 3);
assert_eq!(remaining_byte_depth(4, 1, 1), 2);
assert_eq!(remaining_byte_depth(4, 2, 1), 1);
assert_eq!(remaining_byte_depth(4, 3, 1), 0);
assert_eq!(remaining_byte_depth(4, 0, 2), 2);
assert_eq!(remaining_byte_depth(4, 1, 2), 1);
assert_eq!(remaining_byte_depth(4, 2, 2), 0);
assert_eq!(remaining_byte_depth(2, 0, 1), 1);
assert_eq!(remaining_byte_depth(3, 1, 2), 0);
}
// Checks the overlap test exactly at both edges of the surrogate block.
#[test]
fn test_intersects_surrogate_boundary() {
assert!(!intersects_surrogate(0, 0x100));
assert!(!intersects_surrogate(0, SURROGATE_START - 1));
assert!(intersects_surrogate(0, SURROGATE_START));
assert!(intersects_surrogate(SURROGATE_START, SURROGATE_END));
assert!(!intersects_surrogate(SURROGATE_END + 1, 0xFFFF));
assert!(intersects_surrogate(0, 0xFFFF));
}
// Shared body for the cache tests: start from an empty cache, build an NFA for
// `pattern`, require that at least one shell was cached, then clear the cache
// and require that it is empty again.
fn verify_shell_cache_populate_then_clear(pattern: &str) {
clear_fa_shell_cache();
assert_eq!(fa_shell_cache_size(), 0);
let root = crate::regexp::parse_regexp(pattern).unwrap();
let _ = make_regexp_nfa_arena(root);
let populated = fa_shell_cache_size();
assert!(
populated > 0,
"building a cache-keyed regex must populate the cache"
);
clear_fa_shell_cache();
assert_eq!(
fa_shell_cache_size(),
0,
"clear_fa_shell_cache must drop all entries"
);
}
// Skipped under Miri; the test below covers the same path there with a
// different pattern.
#[test]
#[cfg_attr(miri, ignore)]
fn test_clear_fa_shell_cache_drops_entries() {
verify_shell_cache_populate_then_clear("~i");
}
// Miri-only variant of the cache test.
#[cfg(miri)]
#[test]
fn test_clear_fa_shell_cache_drops_entries_miri_friendly() {
verify_shell_cache_populate_then_clear("~p{Cc}");
}
// Both predicates are strict comparisons against the block's own bounds.
#[test]
fn test_before_after_surrogate() {
assert!(before_surrogate(SURROGATE_START - 1));
assert!(!before_surrogate(SURROGATE_START));
assert!(!before_surrogate(SURROGATE_END));
assert!(after_surrogate(SURROGATE_END + 1));
assert!(!after_surrogate(SURROGATE_END));
assert!(!after_surrogate(SURROGATE_START));
}
}