use regex_syntax::hir::{Class, Hir, HirKind, Look, Repetition};
use crate::scan::nfa::NfaPlan;
/// Number of `u32` lanes per subgroup in the GPU NFA kernel. Each lane word
/// carries 32 state bits, so a state bitset spans `LANES` words (see
/// `STATE_CAP` and the table layout in `NfaBuilder::emit_lane_major_tables`).
const LANES: usize = vyre_primitives::nfa::subgroup_nfa::LANES_PER_SUBGROUP;
/// Errors produced while compiling a regex set into GPU NFA tables.
///
/// Every `Display` message starts with `regex_compile:` and ends with a
/// `Fix:` hint.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RegexCompileError {
/// The pattern at `pattern_index` failed to parse; `message` is the
/// parser's diagnostic text.
Parse {
pattern_index: usize,
message: String,
},
/// The pattern uses a construct the GPU NFA cannot express
/// (e.g. an interior look-around or an over-large Unicode class).
Unsupported {
pattern_index: usize,
feature: &'static str,
},
/// The combined NFA needs `states` states but at most `cap` are allowed.
TooManyStates {
states: usize,
cap: usize,
},
/// The number of patterns does not fit in a `u32`.
PatternCountOverflow {
count: usize,
},
/// A pattern's computed match length does not fit in a `u32`.
MatchLengthOverflow {
pattern_index: usize,
len: usize,
},
/// Computing the word count of the named table overflowed `usize`.
TableWordCountOverflow {
table: &'static str,
},
/// Fallible reservation of `requested` elements for `field` failed;
/// `message` carries the allocator error text.
StorageReserveFailed {
field: &'static str,
requested: usize,
message: String,
},
}
impl std::fmt::Display for RegexCompileError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Parse {
pattern_index,
message,
} => write!(
f,
"regex_compile: pattern {pattern_index} parse error: {message}. \
Fix: review the regex syntax."
),
Self::Unsupported {
pattern_index,
feature,
} => write!(
f,
"regex_compile: pattern {pattern_index} uses unsupported feature `{feature}`. \
Fix: rewrite the detector into supported GPU-NFA syntax or split it into GPU-compatible rules."
),
Self::TooManyStates { states, cap } => write!(
f,
"regex_compile: NFA needs {states} states; per-pipeline cap is {cap}. \
Fix: split the pattern set across multiple pipelines."
),
Self::PatternCountOverflow { count } => write!(
f,
"regex_compile: pattern count {count} exceeds u32 capacity. Fix: shard the pattern set before GPU regex compilation."
),
Self::MatchLengthOverflow {
pattern_index,
len,
} => write!(
f,
"regex_compile: pattern {pattern_index} match length {len} exceeds u32 capacity. Fix: bound or shard the regex before GPU compilation."
),
Self::TableWordCountOverflow { table } => write!(
f,
"regex_compile: {table} table word count overflows host usize. Fix: shard the regex pattern set before table construction."
),
Self::StorageReserveFailed {
field,
requested,
message,
} => write!(
f,
"regex_compile: could not reserve {requested} {field} slot(s): {message}. Fix: shard the regex pattern set before GPU compilation."
),
}
}
}
impl std::error::Error for RegexCompileError {}
/// Output of [`compile_regex_set`]: the NFA plan plus its GPU tables.
#[derive(Debug, Clone)]
pub struct CompiledRegexSet {
/// State count and per-pattern accept metadata; `input_len` is 0 until the
/// caller sizes it for an input.
pub plan: NfaPlan,
/// Byte transitions. Word `src * 256 * LANES + byte * LANES + dst / 32`
/// has bit `dst % 32` set when `src --byte--> dst`.
pub transition_table: Vec<u32>,
/// Epsilon edges. Word `src * LANES + dst / 32` has bit `dst % 32` set
/// when `src --ε--> dst`.
pub epsilon_table: Vec<u32>,
}
/// Maximum number of NFA states: one bit per state across `LANES` 32-bit words.
const STATE_CAP: usize = LANES * 32;
pub fn compile_regex_set(patterns: &[&str]) -> Result<CompiledRegexSet, RegexCompileError> {
let mut builder = NfaBuilder::new();
let _pattern_count =
u32::try_from(patterns.len()).map_err(|_| RegexCompileError::PatternCountOverflow {
count: patterns.len(),
})?;
let mut accept_states = Vec::new();
reserve_vec(&mut accept_states, patterns.len(), "accept state")?;
let mut accept_state_ids = Vec::new();
reserve_vec(&mut accept_state_ids, patterns.len(), "accept state id")?;
let mut accept_start_anchored = Vec::new();
reserve_vec(
&mut accept_start_anchored,
patterns.len(),
"accept start-anchor flag",
)?;
let mut accept_end_anchored = Vec::new();
reserve_vec(
&mut accept_end_anchored,
patterns.len(),
"accept end-anchor flag",
)?;
let entry = builder.fresh_state()?;
for (pid, pat) in patterns.iter().enumerate() {
let hir = match regex_syntax::ParserBuilder::new()
.unicode(false)
.utf8(false)
.build()
.parse(pat)
{
Ok(h) => h,
Err(byte_mode_err) => regex_syntax::ParserBuilder::new()
.unicode(true)
.utf8(false)
.build()
.parse(pat)
.map_err(|_unicode_err| RegexCompileError::Parse {
pattern_index: pid,
message: format!("{byte_mode_err}"),
})?,
};
let (frag, anchors) = build_pattern_hir(&mut builder, &hir, pid)?;
builder.add_epsilon(entry, frag.start);
let pid_u32 = u32::try_from(pid).map_err(|_| RegexCompileError::PatternCountOverflow {
count: patterns.len(),
})?;
let match_len_u32 =
u32::try_from(frag.match_len).map_err(|_| RegexCompileError::MatchLengthOverflow {
pattern_index: pid,
len: frag.match_len,
})?;
accept_states.push((pid_u32, match_len_u32));
accept_state_ids.push(frag.end);
accept_start_anchored.push(anchors.start);
accept_end_anchored.push(anchors.end);
}
if builder.state_count() > STATE_CAP {
return Err(RegexCompileError::TooManyStates {
states: builder.state_count(),
cap: STATE_CAP,
});
}
let plan = NfaPlan {
num_states: u32::try_from(builder.state_count()).map_err(|_| {
RegexCompileError::TooManyStates {
states: builder.state_count(),
cap: STATE_CAP,
}
})?,
input_len: 0,
accept_states,
accept_state_ids,
accept_start_anchored,
accept_end_anchored,
};
let (transition_table, epsilon_table) = builder.emit_lane_major_tables()?;
Ok(CompiledRegexSet {
plan,
transition_table,
epsilon_table,
})
}
pub fn build_rule_pipeline_from_regex(
patterns: &[&str],
input_buf: &str,
hit_buf: &str,
input_len: u32,
) -> Result<crate::scan::RulePipeline, RegexCompileError> {
let compiled = compile_regex_set(patterns)?;
let has_epsilon = compiled.epsilon_table.iter().any(|word| *word != 0);
let program = crate::scan::nfa::nfa_scan_with_plan(
&compiled.plan,
has_epsilon,
input_buf,
hit_buf,
input_len,
)
.map_err(|_| RegexCompileError::TooManyStates {
states: compiled.plan.num_states as usize,
cap: STATE_CAP,
})?;
Ok(crate::scan::RulePipeline {
program,
transition_table: compiled.transition_table,
epsilon_table: compiled.epsilon_table,
plan: compiled.plan.for_input_len(input_len),
})
}
/// Incremental Thompson-style NFA builder: allocates state ids and records
/// byte-set and epsilon edges, later flattened into GPU tables.
#[derive(Debug)]
struct NfaBuilder {
/// Number of states allocated so far; ids are `0..state_count`.
state_count: usize,
/// Byte-consuming edges, each labelled with a set of accepted bytes.
transitions: Vec<ByteTransition>,
/// Epsilon (non-consuming) edges as `(src, dst)` state ids.
epsilons: Vec<(u32, u32)>,
}
/// One byte-consuming edge: from `src` to `dst` on any byte in `set`.
#[derive(Debug, Clone)]
struct ByteTransition {
src: u32,
set: ByteSet,
dst: u32,
}
/// 256-bit membership bitmap over byte values.
#[derive(Debug, Clone)]
struct ByteSet {
/// Byte `b` is present when bit `b % 64` of word `b / 64` is set.
bits: [u64; 4], }
impl ByteSet {
fn new() -> Self {
Self { bits: [0; 4] }
}
fn insert(&mut self, b: u8) {
self.bits[(b / 64) as usize] |= 1u64 << (b % 64);
}
fn from_byte(b: u8) -> Self {
let mut s = Self::new();
s.insert(b);
s
}
fn from_range(lo: u8, hi: u8) -> Self {
let mut s = Self::new();
for b in lo..=hi {
s.insert(b);
}
s
}
fn for_each_set_byte(&self, mut f: impl FnMut(u8)) {
for (word_idx, &word) in self.bits.iter().enumerate() {
let mut bits = word;
while bits != 0 {
let bit = bits.trailing_zeros() as usize;
f((word_idx * 64 + bit) as u8);
bits &= bits - 1;
}
}
}
}
/// A partially built sub-automaton with a single entry and a single exit state.
#[derive(Debug, Clone, Copy)]
struct Fragment {
/// Entry state id.
start: u32,
/// Exit state id (equal to `start` for an empty fragment).
end: u32,
/// Byte length the builders attribute to this fragment.
/// - Concatenations: the sum of the parts.
/// - Alternations and multi-byte class arms: the maximum.
/// - Repetitions: only the mandatory `min` copies.
/// Not exact for variable-length patterns.
match_len: usize,
}
/// Edge anchors stripped from the top level of a pattern.
#[derive(Debug, Clone, Copy, Default)]
struct PatternAnchors {
/// A leading `Look::Start` (e.g. `^`) was removed.
start: bool,
/// A trailing `Look::End` (e.g. `$`) was removed.
end: bool,
}
impl NfaBuilder {
fn new() -> Self {
Self {
state_count: 0,
transitions: Vec::new(),
epsilons: Vec::new(),
}
}
fn state_count(&self) -> usize {
self.state_count
}
fn fresh_state(&mut self) -> Result<u32, RegexCompileError> {
if self.state_count >= STATE_CAP {
return Err(RegexCompileError::TooManyStates {
states: self.state_count.saturating_add(1),
cap: STATE_CAP,
});
}
let state =
u32::try_from(self.state_count).map_err(|_| RegexCompileError::TooManyStates {
states: self.state_count,
cap: STATE_CAP,
})?;
self.state_count =
self.state_count
.checked_add(1)
.ok_or(RegexCompileError::TooManyStates {
states: usize::MAX,
cap: STATE_CAP,
})?;
Ok(state)
}
fn add_byte_transition(&mut self, src: u32, set: ByteSet, dst: u32) {
self.transitions.push(ByteTransition { src, set, dst });
}
fn add_epsilon(&mut self, src: u32, dst: u32) {
self.epsilons.push((src, dst));
}
fn emit_lane_major_tables(&self) -> Result<(Vec<u32>, Vec<u32>), RegexCompileError> {
let n = self.state_count();
let mut transitions = zeroed_u32_table(
table_word_count(n, 256, "transition")?,
"transition table word",
)?;
let mut epsilons =
zeroed_u32_table(table_word_count(n, 1, "epsilon")?, "epsilon table word")?;
for edge in &self.transitions {
let src = edge.src as usize;
let dst_lane = (edge.dst / 32) as usize;
let dst_bit = 1u32 << (edge.dst % 32);
edge.set.for_each_set_byte(|byte| {
let idx = src * 256 * LANES + (byte as usize) * LANES + dst_lane;
transitions[idx] |= dst_bit;
});
}
for &(src, dst) in &self.epsilons {
let dst_lane = (dst / 32) as usize;
let dst_bit = 1u32 << (dst % 32);
let idx = src as usize * LANES + dst_lane;
epsilons[idx] |= dst_bit;
}
Ok((transitions, epsilons))
}
}
/// Returns `states * byte_columns * LANES`, the number of `u32` words a
/// table with `byte_columns` columns per state needs.
///
/// # Errors
///
/// [`RegexCompileError::TableWordCountOverflow`] naming `table` if either
/// multiplication overflows `usize`.
fn table_word_count(
    states: usize,
    byte_columns: usize,
    table: &'static str,
) -> Result<usize, RegexCompileError> {
    let overflow = || RegexCompileError::TableWordCountOverflow { table };
    let words_per_lane = states.checked_mul(byte_columns).ok_or_else(overflow)?;
    words_per_lane.checked_mul(LANES).ok_or_else(overflow)
}
/// Allocates a zero-filled `Vec<u32>` of `words` elements. Capacity is
/// reserved fallibly first so an oversized table becomes an error, not an abort.
fn zeroed_u32_table(words: usize, field: &'static str) -> Result<Vec<u32>, RegexCompileError> {
    let mut zeroed: Vec<u32> = Vec::new();
    reserve_vec(&mut zeroed, words, field)?;
    zeroed.extend(std::iter::repeat(0u32).take(words));
    Ok(zeroed)
}
/// Fallibly grows `vec` to hold `requested` elements.
///
/// # Errors
///
/// [`RegexCompileError::StorageReserveFailed`] carrying `field`, `requested`
/// and the allocator's error text.
fn reserve_vec<T>(
    vec: &mut Vec<T>,
    requested: usize,
    field: &'static str,
) -> Result<(), RegexCompileError> {
    match vyre_foundation::allocation::try_reserve_vec_to_capacity(vec, requested) {
        Ok(()) => Ok(()),
        Err(source) => Err(RegexCompileError::StorageReserveFailed {
            field,
            requested,
            message: source.to_string(),
        }),
    }
}
/// Builds a fragment that matches the empty string: one fresh state serving
/// as both entry and exit, with `match_len` 0.
fn empty_fragment(b: &mut NfaBuilder) -> Result<Fragment, RegexCompileError> {
    b.fresh_state().map(|state| Fragment {
        start: state,
        end: state,
        match_len: 0,
    })
}
/// Builds the fragment for a whole pattern and strips edge anchors.
///
/// A bare `Look::Start` / `Look::End` pattern becomes an empty fragment with
/// the matching anchor flag. In a top-level concatenation, one leading
/// `Look::Start` and one trailing `Look::End` are removed and reported in
/// [`PatternAnchors`]. Any other look-around is left in place, and
/// `build_hir` rejects it.
fn build_pattern_hir(
    b: &mut NfaBuilder,
    hir: &Hir,
    pid: usize,
) -> Result<(Fragment, PatternAnchors), RegexCompileError> {
    let parts = match hir.kind() {
        HirKind::Look(Look::Start) => {
            let frag = empty_fragment(b)?;
            return Ok((
                frag,
                PatternAnchors {
                    start: true,
                    end: false,
                },
            ));
        }
        HirKind::Look(Look::End) => {
            let frag = empty_fragment(b)?;
            return Ok((
                frag,
                PatternAnchors {
                    start: false,
                    end: true,
                },
            ));
        }
        HirKind::Concat(parts) => parts,
        _ => return Ok((build_hir(b, hir, pid)?, PatternAnchors::default())),
    };

    let mut anchors = PatternAnchors::default();
    let mut body: &[Hir] = parts;
    if let Some((head, rest)) = body.split_first() {
        if is_text_start_look(head) {
            anchors.start = true;
            body = rest;
        }
    }
    if let Some((last, rest)) = body.split_last() {
        if is_text_end_look(last) {
            anchors.end = true;
            body = rest;
        }
    }
    Ok((build_hir_slice(b, body, pid)?, anchors))
}
/// True when `hir` is exactly the text-start assertion (`Look::Start`).
fn is_text_start_look(hir: &Hir) -> bool {
    match hir.kind() {
        HirKind::Look(look) => *look == Look::Start,
        _ => false,
    }
}
/// True when `hir` is exactly the text-end assertion (`Look::End`).
fn is_text_end_look(hir: &Hir) -> bool {
    match hir.kind() {
        HirKind::Look(look) => *look == Look::End,
        _ => false,
    }
}
/// Builds the concatenation of `parts`. Each child's exit state is linked to
/// the next child's entry state by an epsilon edge.
///
/// An empty slice yields a single-state empty fragment. The resulting
/// `match_len` is the sum of the children's lengths.
fn build_hir_slice(
    b: &mut NfaBuilder,
    parts: &[Hir],
    pid: usize,
) -> Result<Fragment, RegexCompileError> {
    let Some((head, rest)) = parts.split_first() else {
        return empty_fragment(b);
    };
    let first = build_hir(b, head, pid)?;
    rest.iter().try_fold(
        first,
        |chain: Fragment, child: &Hir| -> Result<Fragment, RegexCompileError> {
            let link = build_hir(b, child, pid)?;
            b.add_epsilon(chain.end, link.start);
            Ok(Fragment {
                start: chain.start,
                end: link.end,
                match_len: chain.match_len + link.match_len,
            })
        },
    )
}
/// Recursively builds the fragment for one HIR node.
///
/// - Literals become a chain of single-byte edges.
/// - Classes and repetitions are delegated to their builders.
/// - Captures are transparent.
/// - Alternations fan out from a fork state and merge into a join state.
///
/// Look-around reaching this point (i.e. not stripped as a pattern edge
/// anchor) is rejected as unsupported.
fn build_hir(b: &mut NfaBuilder, hir: &Hir, pid: usize) -> Result<Fragment, RegexCompileError> {
    match hir.kind() {
        HirKind::Empty => empty_fragment(b),
        HirKind::Literal(lit) => {
            let bytes: &[u8] = &lit.0;
            let start = b.fresh_state()?;
            let mut cursor = start;
            for byte in bytes.iter().copied() {
                let next = b.fresh_state()?;
                b.add_byte_transition(cursor, ByteSet::from_byte(byte), next);
                cursor = next;
            }
            Ok(Fragment {
                start,
                end: cursor,
                match_len: bytes.len(),
            })
        }
        HirKind::Class(class) => build_class(b, class, pid),
        HirKind::Look(_) => Err(RegexCompileError::Unsupported {
            pattern_index: pid,
            feature: "non-edge lookaround assertion",
        }),
        HirKind::Repetition(repetition) => build_repetition(b, repetition, pid),
        HirKind::Capture(capture) => build_hir(b, &capture.sub, pid),
        HirKind::Concat(parts) => build_hir_slice(b, parts, pid),
        HirKind::Alternation(alternatives) => {
            // Fork and join are allocated before any arm so state ids stay stable.
            let fork = b.fresh_state()?;
            let join = b.fresh_state()?;
            let mut widest = 0usize;
            for alternative in alternatives.iter() {
                let arm = build_hir(b, alternative, pid)?;
                b.add_epsilon(fork, arm.start);
                b.add_epsilon(arm.end, join);
                widest = widest.max(arm.match_len);
            }
            Ok(Fragment {
                start: fork,
                end: join,
                match_len: widest,
            })
        }
    }
}
/// Builds `sub{min,max}` by unrolling the sub-expression.
///
/// - `min` mandatory copies are chained.
/// - For a bounded `max`, `max - min` optional copies follow, each skippable
///   via an epsilon to its own join state.
/// - For an unbounded `max`, one looping copy with a back-edge follows,
///   skippable via an epsilon to the join state.
///
/// `match_len` counts only the mandatory copies.
///
/// # Errors
///
/// [`RegexCompileError::TooManyStates`] if `max` (checked first) or `min`
/// alone exceeds `STATE_CAP`, or if unrolling runs out of states.
fn build_repetition(
    b: &mut NfaBuilder,
    rep: &Repetition,
    pid: usize,
) -> Result<Fragment, RegexCompileError> {
    let (min, max) = (rep.min, rep.max);
    for requested in max.into_iter().chain(std::iter::once(min)) {
        if requested as usize > STATE_CAP {
            return Err(RegexCompileError::TooManyStates {
                states: requested as usize,
                cap: STATE_CAP,
            });
        }
    }

    let start = b.fresh_state()?;
    let mut tail = start;
    let mut mandatory_len = 0usize;
    for _ in 0..min {
        let copy = build_hir(b, &rep.sub, pid)?;
        b.add_epsilon(tail, copy.start);
        tail = copy.end;
        mandatory_len += copy.match_len;
    }

    if let Some(upper) = max {
        for _ in min..upper {
            let copy = build_hir(b, &rep.sub, pid)?;
            let join = b.fresh_state()?;
            b.add_epsilon(tail, copy.start);
            b.add_epsilon(copy.end, join);
            // Skip edge: this optional copy may be omitted.
            b.add_epsilon(tail, join);
            tail = join;
        }
    } else {
        let join = b.fresh_state()?;
        let body = build_hir(b, &rep.sub, pid)?;
        b.add_epsilon(tail, body.start);
        // Back-edge: repeat the body any number of times.
        b.add_epsilon(body.end, body.start);
        b.add_epsilon(body.end, join);
        // Skip edge: zero extra iterations.
        b.add_epsilon(tail, join);
        tail = join;
    }

    Ok(Fragment {
        start,
        end: tail,
        match_len: mandatory_len,
    })
}
/// Builds a character-class fragment.
///
/// Byte classes and ASCII-only Unicode classes become a single byte-set edge
/// between two fresh states. Other Unicode classes are expanded to one UTF-8
/// byte chain per codepoint, each hung between a shared entry and exit.
/// `match_len` is the longest chain (at least 1).
///
/// # Errors
///
/// Expansion errors from `class_to_utf8_sequences`, or
/// [`RegexCompileError::Unsupported`] if the expansion is empty.
fn build_class(b: &mut NfaBuilder, cls: &Class, pid: usize) -> Result<Fragment, RegexCompileError> {
    if let Some(set) = try_class_as_ascii_byte_set(cls) {
        let start = b.fresh_state()?;
        let end = b.fresh_state()?;
        b.add_byte_transition(start, set, end);
        return Ok(Fragment {
            start,
            end,
            match_len: 1,
        });
    }

    let sequences = class_to_utf8_sequences(cls, pid)?;
    if sequences.is_empty() {
        return Err(RegexCompileError::Unsupported {
            pattern_index: pid,
            feature: "empty character class after Unicode expansion",
        });
    }
    let start = b.fresh_state()?;
    let end = b.fresh_state()?;
    let mut longest = 1usize;
    for sequence in sequences.iter().filter(|sequence| !sequence.is_empty()) {
        let arm_start = b.fresh_state()?;
        b.add_epsilon(start, arm_start);
        let mut cursor = arm_start;
        for &byte in sequence.iter() {
            let next = b.fresh_state()?;
            b.add_byte_transition(cursor, ByteSet::from_byte(byte), next);
            cursor = next;
        }
        b.add_epsilon(cursor, end);
        longest = longest.max(sequence.len());
    }
    Ok(Fragment {
        start,
        end,
        match_len: longest,
    })
}
/// Converts a class into a single-byte [`ByteSet`] when possible.
///
/// Byte classes always convert. Unicode classes convert only if every range
/// ends at or below `0x7F`; otherwise `None` is returned and the caller must
/// expand to UTF-8 sequences.
fn try_class_as_ascii_byte_set(cls: &Class) -> Option<ByteSet> {
    let mut set = ByteSet::new();
    match cls {
        Class::Bytes(byte_class) => {
            for range in byte_class.iter() {
                for byte in range.start()..=range.end() {
                    set.insert(byte);
                }
            }
        }
        Class::Unicode(unicode_class) => {
            for range in unicode_class.iter() {
                let (lo, hi) = (u32::from(range.start()), u32::from(range.end()));
                if hi > 0x7F {
                    return None;
                }
                for byte in (lo as u8)..=(hi as u8) {
                    set.insert(byte);
                }
            }
        }
    }
    Some(set)
}
/// Upper bound on class members expanded into explicit byte sequences.
const MAX_CLASS_EXPANSION_CODEPOINTS: usize = 256;
/// Expands a class into one byte sequence per member: a single byte for
/// byte classes, the UTF-8 encoding for Unicode classes.
///
/// At most `MAX_CLASS_EXPANSION_CODEPOINTS` sequences are produced. Values
/// that are not Unicode scalar values (surrogates) are skipped without
/// consuming budget, but the budget is checked before each value is examined.
///
/// # Errors
///
/// [`RegexCompileError::Unsupported`] once the expansion cap is exceeded.
fn class_to_utf8_sequences(cls: &Class, pid: usize) -> Result<Vec<Vec<u8>>, RegexCompileError> {
    let cap_error = |feature: &'static str| RegexCompileError::Unsupported {
        pattern_index: pid,
        feature,
    };
    let mut remaining = MAX_CLASS_EXPANSION_CODEPOINTS;
    let mut expanded: Vec<Vec<u8>> = Vec::new();
    match cls {
        Class::Bytes(byte_class) => {
            for byte in byte_class.iter().flat_map(|range| range.start()..=range.end()) {
                if remaining == 0 {
                    return Err(cap_error("byte character class exceeded expansion cap"));
                }
                expanded.push(vec![byte]);
                remaining -= 1;
            }
        }
        Class::Unicode(unicode_class) => {
            let codepoints = unicode_class
                .iter()
                .flat_map(|range| u32::from(range.start())..=u32::from(range.end()));
            for cp in codepoints {
                if remaining == 0 {
                    return Err(cap_error("unicode character class exceeded expansion cap"));
                }
                let Some(ch) = char::from_u32(cp) else {
                    continue;
                };
                let mut utf8 = [0u8; 4];
                expanded.push(ch.encode_utf8(&mut utf8).as_bytes().to_vec());
                remaining -= 1;
            }
        }
    }
    Ok(expanded)
}
#[cfg(test)]
mod tests {
// Unit tests for the regex -> GPU-NFA compiler. Several assertions pin exact
// state counts, so they will catch changes to state-allocation order or count.
use super::*;
/// Compiles a single pattern and returns its NFA state count.
fn states_of(s: &str) -> u32 {
compile_regex_set(&[s]).unwrap().plan.num_states
}
// "abc": shared entry state + literal start state + one state per byte = 5.
#[test]
fn literal_compiles() {
let r = compile_regex_set(&["abc"]).unwrap();
assert_eq!(r.plan.num_states, 5);
assert_eq!(r.plan.accept_states.len(), 1);
}
#[test]
fn alternation_compiles() {
let r = compile_regex_set(&["a|b"]).unwrap();
assert!(r.plan.num_states > 0);
assert_eq!(r.plan.accept_states.len(), 1);
}
#[test]
fn class_compiles() {
let r = compile_regex_set(&["[a-z]"]).unwrap();
assert!(r.plan.num_states > 0);
}
// Top-level `^`/`$` are stripped from the NFA and reported as accept flags.
#[test]
fn text_anchors_compile_to_accept_flags() {
let r = compile_regex_set(&["^foo$"]).unwrap();
assert_eq!(r.plan.accept_start_anchored, vec![true]);
assert_eq!(r.plan.accept_end_anchored, vec![true]);
}
#[test]
fn bounded_repetition_above_old_cap_compiles_under_state_cap() {
let r = compile_regex_set(&["a{0,128}"]).unwrap();
assert!(r.plan.num_states > 64);
assert!(r.plan.num_states <= STATE_CAP as u32);
}
// Accept metadata is index-aligned with the input patterns, and table sizes
// follow the `states * columns * LANES` layout.
#[test]
fn regex_compile_preserves_accept_metadata_through_checked_paths() {
let r = compile_regex_set(&["a", "bc", "^de$"]).unwrap();
assert_eq!(r.plan.accept_states, vec![(0, 1), (1, 2), (2, 2)]);
assert_eq!(r.plan.accept_state_ids.len(), 3);
assert_eq!(r.plan.accept_start_anchored, vec![false, false, true]);
assert_eq!(r.plan.accept_end_anchored, vec![false, false, true]);
assert_eq!(
r.transition_table.len(),
r.plan.num_states as usize * 256 * LANES
);
assert_eq!(r.epsilon_table.len(), r.plan.num_states as usize * LANES);
}
// Source-text lint: the production section (everything before the first
// test-module attribute) must use checked conversions and fallible
// allocation rather than truncating casts or infallible zero-vectors.
#[test]
fn regex_compile_uses_checked_abi_and_table_allocation_paths() {
let production = include_str!("regex_compile.rs")
.split("#[cfg(test)]")
.next()
.expect("Fix: regex_compile.rs must contain production section");
assert!(
production.contains("u32::try_from(pid)")
&& production.contains("u32::try_from(frag.match_len)")
&& production.contains("u32::try_from(builder.state_count())")
&& production.contains("u32::try_from(self.state_count)")
&& production.contains("checked_add(1)")
&& production.contains("try_reserve_vec_to_capacity")
&& !production.contains("pid as u32")
&& !production.contains("frag.match_len as u32")
&& !production.contains("builder.state_count() as u32")
&& !production.contains("self.state_count as u32")
&& !production.contains("vec![0u32;")
&& !production.contains("Vec::with_capacity(patterns.len())"),
"Fix: regex compilation must not truncate ids/counts or allocate NFA tables with infallible zero-vector construction."
);
}
// The pipeline must be built from the compiled regex plan, not from the
// literal-string NFA compiler in `crate::scan::nfa`.
#[test]
fn regex_pipeline_uses_compiled_plan_instead_of_literal_source_plan() {
let compiled = compile_regex_set(&["a|bc"]).unwrap();
let pipeline = build_rule_pipeline_from_regex(&["a|bc"], "input", "hits", 64).unwrap();
assert_eq!(pipeline.plan.num_states, compiled.plan.num_states);
assert_eq!(
pipeline.plan.accept_state_ids,
compiled.plan.accept_state_ids
);
assert_eq!(
pipeline.epsilon_table.iter().any(|word| *word != 0),
compiled.epsilon_table.iter().any(|word| *word != 0)
);
assert_ne!(
pipeline.plan.num_states,
crate::scan::nfa::compile(&["a|bc"]).num_states,
"regex pipeline must not rebuild the scan program from literal regex source bytes"
);
}
#[test]
fn states_count_grows_with_concat() {
let one = states_of("a");
let two = states_of("ab");
let three = states_of("abc");
assert!(two > one);
assert!(three > two);
}
// A literal longer than the cap needs more than STATE_CAP states.
#[test]
fn state_cap_enforced() {
let huge: String = (0..(STATE_CAP + 4)).map(|_| 'a').collect();
let err = compile_regex_set(&[&huge]).unwrap_err();
assert!(matches!(err, RegexCompileError::TooManyStates { .. }));
}
// Word boundaries are interior look-around and must be rejected with a
// GPU-oriented remedy, never a suggestion to fall back to the CPU.
#[test]
fn unsupported_regex_diagnostic_does_not_route_to_cpu_backend() {
let err = compile_regex_set(&[r"\bsecret\b"]).unwrap_err();
let message = err.to_string().to_ascii_lowercase();
assert!(
!message.contains("cpu"),
"unsupported GPU-NFA regex diagnostics must not recommend host-side routing: {message}"
);
assert!(
message.contains("gpu"),
"unsupported GPU-NFA regex diagnostics must name the GPU-compatible rewrite contract: {message}"
);
}
// Non-ASCII literals force the Unicode-mode parse; the mixed class is then
// expanded into per-codepoint UTF-8 byte chains.
#[test]
fn unicode_class_outside_ascii_compiles_via_utf8_expansion() {
let pat = "[hнһh]f_[a-zA-Z0-9]{4}";
let result = compile_regex_set(&[pat]);
let compiled = match result {
Ok(c) => c,
Err(e) => {
panic!("unicode-extended character class must compile via UTF-8 expansion; got {e}")
}
};
assert!(
compiled.plan.num_states > 4,
"expanded NFA must have non-trivial state count"
);
assert_eq!(compiled.plan.accept_states.len(), 1);
}
#[test]
fn ascii_only_class_keeps_single_byte_transition_path() {
let r = compile_regex_set(&["[ab]"]).unwrap();
assert_eq!(
r.plan.num_states, 3,
"[ab] must stay on the single-transition fast path (entry + 2 class states); got {} states",
r.plan.num_states
);
}
// U+0100..=U+0200 is 257 codepoints, one more than the expansion budget.
#[test]
fn unicode_class_above_expansion_cap_errors_cleanly() {
let pat = "[\u{0100}-\u{0200}]";
let err = compile_regex_set(&[pat]).unwrap_err();
match err {
RegexCompileError::Unsupported { feature, .. } => {
assert!(
feature.contains("expansion cap"),
"over-cap expansion must name the cap in its diagnostic: {feature}"
);
}
other => panic!("expected Unsupported expansion-cap error, got {other:?}"),
}
}
}