use vyre_libs::scan::{compile_regex_set, CompiledRegexSet, RegexCompileError};
/// Version tag hashed into every regex-DFA cache key (see `regex_dfa_cache_key`).
///
/// The key otherwise covers only `input_len` and the pattern text, so this tag
/// is the only thing that separates cache entries written by different
/// revisions of the literal-extraction / DFA-construction code. Bump it
/// whenever the bytes fed to the DFA compiler can change for an unchanged
/// pattern set; entries written under the previous value are then no longer
/// looked up.
pub const REGEX_DFA_CACHE_VERSION: u32 = 2;
/// Output of the regex-to-DFA build: the compiled regex set together with a
/// DFA compiled from the patterns' literal cores.
///
/// Field descriptions reflect how `build_regex_dfa` and `regex_dfa_cached`
/// populate the struct.
#[derive(Debug, Clone)]
pub struct RegexDfaPipeline {
/// Regex set compiled from the full pattern list.
pub regex_set: CompiledRegexSet,
/// DFA compiled from the non-empty entries of `pattern_literals`.
pub dfa: vyre_libs::scan::CompiledDfa,
/// One literal core per input pattern, in pattern order; an entry is empty
/// when `extract_literal_core` found no leading literal.
pub pattern_literals: Vec<Vec<u8>>,
/// Number of input patterns (`patterns.len()` cast to `u32`).
pub pattern_count: u32,
}
/// Errors produced while building a [`RegexDfaPipeline`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RegexDfaError {
/// Compiling the regex set failed; carries the underlying error.
RegexCompile(RegexCompileError),
/// The DFA could not be built. `build_regex_dfa` returns this both when the
/// budgeted DFA compiler reports an error and when no pattern yields a
/// non-empty literal core.
DfaBudgetExceeded {
/// Human-readable description of what went wrong.
message: String,
},
/// The caller passed an empty pattern slice.
EmptyPatternSet,
}
impl std::fmt::Display for RegexDfaError {
    /// Writes a single-line, `regex_dfa:`-prefixed description of the error.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // No interpolation needed, so write the text directly.
            Self::EmptyPatternSet => out.write_str("regex_dfa: empty pattern set"),
            Self::DfaBudgetExceeded { message: detail } => {
                write!(out, "regex_dfa: DFA budget exceeded: {detail}")
            }
            Self::RegexCompile(cause) => {
                write!(out, "regex_dfa: regex compile failed: {cause}")
            }
        }
    }
}
impl std::error::Error for RegexDfaError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::RegexCompile(inner) => Some(inner),
_ => None,
}
}
}
impl From<RegexCompileError> for RegexDfaError {
fn from(err: RegexCompileError) -> Self {
Self::RegexCompile(err)
}
}
/// Extracts the leading run of literal ASCII bytes from a regex pattern.
///
/// The pattern is scanned left to right and scanning stops at the first
/// construct that is not a single literal ASCII character:
///
/// * `[`, `(`, `|`, `+`, `^`, `$`, `.` end the literal as-is. For `+` the
///   preceding byte is kept because it must occur at least once.
/// * `*`, `?` and `{` end the literal **and drop the byte before them**: that
///   byte is the operand of the quantifier, so the regex can match without it
///   (`https?://` yields `http`, not `https`).
/// * `\n`, `\t`, `\r` contribute the control byte they denote.
/// * Any other escaped ASCII letter or digit (`\d`, `\w`, `\b`, `\A`, `\x41`,
///   ...) is a class, an assertion or an encoded character rather than the
///   letter itself, and ends the literal.
/// * Escaped ASCII punctuation (`\.`, `\-`, `\\`, ...) contributes the
///   punctuation byte.
/// * A non-ASCII character, escaped or not, ends the literal.
///
/// A trailing lone backslash is ignored. The result is empty when the pattern
/// does not start with a literal byte.
///
/// NOTE(review): a top-level alternation such as `foo|bar` still yields `foo`
/// although the pattern can match without it; confirm whether callers rely on
/// the literal being required for every match.
pub fn extract_literal_core(pattern: &str) -> Vec<u8> {
    let mut literal = Vec::new();
    let mut escaped = false;
    for ch in pattern.chars() {
        if escaped {
            escaped = false;
            let byte = match ch {
                'n' => b'\n',
                't' => b'\t',
                'r' => b'\r',
                // Every other escaped alphanumeric has a non-literal meaning.
                _ if ch.is_ascii_alphanumeric() => break,
                _ if ch.is_ascii() => ch as u8,
                _ => break,
            };
            literal.push(byte);
            continue;
        }
        match ch {
            '\\' => escaped = true,
            '*' | '?' | '{' => {
                // The quantifier applies to the byte just pushed, which is
                // therefore not guaranteed to appear in a match.
                literal.pop();
                break;
            }
            '[' | '(' | '|' | '+' | '^' | '$' | '.' => break,
            _ if ch.is_ascii() => literal.push(ch as u8),
            _ => break,
        }
    }
    literal
}
pub fn build_regex_dfa(
patterns: &[&str],
_input_len: u32,
) -> std::result::Result<RegexDfaPipeline, RegexDfaError> {
if patterns.is_empty() {
return Err(RegexDfaError::EmptyPatternSet);
}
let regex_set = compile_regex_set(patterns)?;
let pattern_literals: Vec<Vec<u8>> = patterns.iter().map(|p| extract_literal_core(p)).collect();
let dfa_inputs: Vec<&[u8]> = pattern_literals
.iter()
.filter(|lit| !lit.is_empty())
.map(|lit| lit.as_slice())
.collect();
if dfa_inputs.is_empty() {
return Err(RegexDfaError::DfaBudgetExceeded {
message: "no patterns have extractable literal cores for DFA construction".into(),
});
}
let dfa = vyre_libs::scan::dfa_compile_with_budget(
&dfa_inputs,
vyre_libs::scan::DEFAULT_DFA_BUDGET_BYTES,
)
.map_err(|e| RegexDfaError::DfaBudgetExceeded {
message: format!("{e}"),
})?;
Ok(RegexDfaPipeline {
regex_set,
dfa,
pattern_literals,
pattern_count: patterns.len() as u32,
})
}
/// Derives the cache key for a pattern set as a 64-character lowercase hex
/// SHA-256 digest.
///
/// The digest covers, in order: `REGEX_DFA_CACHE_VERSION`, `input_len`, the
/// pattern count, and then each pattern as a little-endian `u32` byte length
/// followed by its UTF-8 bytes. The length prefixes keep adjacent patterns
/// from running together in the hashed stream.
fn regex_dfa_cache_key(patterns: &[&str], input_len: u32) -> String {
    use sha2::{Digest, Sha256};
    use std::fmt::Write as _;

    let mut hasher = Sha256::new();
    hasher.update(REGEX_DFA_CACHE_VERSION.to_le_bytes());
    hasher.update(input_len.to_le_bytes());
    hasher.update((patterns.len() as u32).to_le_bytes());
    for pattern in patterns {
        hasher.update((pattern.len() as u32).to_le_bytes());
        hasher.update(pattern.as_bytes());
    }

    hasher
        .finalize()
        .into_iter()
        .fold(String::with_capacity(64), |mut hex, byte| {
            // Writing into a `String` cannot fail; the result is ignored.
            let _ = write!(hex, "{byte:02x}");
            hex
        })
}
/// Builds a [`RegexDfaPipeline`], reusing a serialized DFA from the on-disk
/// cache when one exists for this pattern set.
///
/// Flow:
///
/// 1. If no cache directory is available, this is exactly `build_regex_dfa`.
/// 2. Otherwise the cache file for `regex_dfa_cache_key` is read. When it
///    deserializes and the regex set compiles, the pipeline is assembled
///    without compiling the DFA again. A file that fails to deserialize is
///    deleted.
/// 3. On a miss the pipeline is built from scratch and the DFA is written to
///    `<path>.tmp.<pid>` and then renamed over the final path.
///
/// All cache I/O is best-effort: read, write and rename failures never fail
/// the call. A failed write or rename is logged at debug level and the
/// temporary file is removed so that partial files do not pile up in the
/// cache directory.
///
/// # Errors
///
/// Returns whatever `build_regex_dfa` returns for the same arguments.
pub fn regex_dfa_cached(
    patterns: &[&str],
    input_len: u32,
) -> std::result::Result<RegexDfaPipeline, RegexDfaError> {
    let started = std::time::Instant::now();
    let Some(cache_dir) = super::gpu_cache::gpu_matcher_cache_dir() else {
        return build_regex_dfa(patterns, input_len);
    };
    let cache_key = format!("dfa-{}", regex_dfa_cache_key(patterns, input_len));

    // Fast path: reuse a previously serialized DFA.
    if let Some(path) = vyre_libs::scan::engine_cache_path(&cache_dir, &cache_key) {
        if let Ok(bytes) = std::fs::read(&path) {
            match vyre_libs::scan::CompiledDfa::from_bytes(&bytes) {
                Ok(dfa) => {
                    // If the regex set does not compile, fall through so that
                    // `build_regex_dfa` reports the error.
                    if let Ok(regex_set) = compile_regex_set(patterns) {
                        let pattern_literals: Vec<Vec<u8>> =
                            patterns.iter().map(|p| extract_literal_core(p)).collect();
                        tracing::debug!(
                            target: "keyhog::routing",
                            patterns = patterns.len(),
                            input_len,
                            elapsed_ms = started.elapsed().as_millis() as u64,
                            "RegexDfaPipeline cache hit - skipped DFA compile"
                        );
                        return Ok(RegexDfaPipeline {
                            regex_set,
                            dfa,
                            pattern_literals,
                            pattern_count: patterns.len() as u32,
                        });
                    }
                }
                Err(_) => {
                    // Unreadable entry: drop it so that it gets rewritten below.
                    let _ = std::fs::remove_file(&path);
                }
            }
        }
    }

    let pipeline = build_regex_dfa(patterns, input_len)?;

    // Slow path: persist the freshly compiled DFA for the next run.
    if let Some(path) = vyre_libs::scan::engine_cache_path(&cache_dir, &cache_key) {
        if let Ok(bytes) = pipeline.dfa.to_bytes() {
            // Write to a per-process temp name, then rename into place.
            let tmp = path.with_extension(format!("tmp.{}", std::process::id()));
            if let Some(parent) = path.parent() {
                let _ = std::fs::create_dir_all(parent);
            }
            match std::fs::write(&tmp, &bytes) {
                Ok(()) => {
                    if let Err(error) = std::fs::rename(&tmp, &path) {
                        tracing::debug!(
                            target: "keyhog::routing",
                            error = %error,
                            path = %path.display(),
                            "regex DFA cache rename failed"
                        );
                        let _ = std::fs::remove_file(&tmp);
                    }
                }
                Err(error) => {
                    tracing::debug!(
                        target: "keyhog::routing",
                        error = %error,
                        path = %tmp.display(),
                        "regex DFA cache write failed"
                    );
                    // A failed write can leave a truncated file behind.
                    let _ = std::fs::remove_file(&tmp);
                }
            }
        }
    }

    tracing::debug!(
        target: "keyhog::routing",
        patterns = patterns.len(),
        input_len,
        elapsed_ms = started.elapsed().as_millis() as u64,
        "RegexDfaPipeline cache miss - compiled and saved"
    );
    Ok(pipeline)
}
impl RegexDfaPipeline {
    /// Scans `haystack` by walking the DFA tables one byte at a time on the
    /// CPU and returns every literal hit, sorted.
    ///
    /// Starting in state 0, each byte selects the next state from
    /// `dfa.transitions` (256 entries per state). The pattern ids recorded for
    /// that state in `dfa.output_records`, delimited by `dfa.output_offsets`,
    /// are emitted as matches ending just after the current byte. The start
    /// offset is the end offset minus the matched literal's length, saturating
    /// at 0. The emitted `pattern_id` is the id stored in the DFA.
    ///
    /// `build_regex_dfa` compiles the DFA from the non-empty entries of
    /// `pattern_literals` only, so the literal for a DFA id is looked up among
    /// those non-empty entries rather than in `pattern_literals` directly.
    /// NOTE(review): this assumes the DFA compiler numbers patterns by their
    /// position in the slice it was given - confirm against `vyre_libs`.
    ///
    /// # Panics
    ///
    /// Panics if the DFA tables are inconsistent with each other or reference
    /// a pattern id with no corresponding non-empty literal.
    #[must_use]
    pub fn reference_scan(&self, haystack: &[u8]) -> Vec<vyre_libs::scan::LiteralMatch> {
        // Literal lengths in the order the DFA compiler saw them.
        let dfa_literal_lens: Vec<u32> = self
            .pattern_literals
            .iter()
            .filter(|lit| !lit.is_empty())
            .map(|lit| lit.len() as u32)
            .collect();

        let mut results = Vec::new();
        let mut state = 0_u32;
        for (pos, &byte) in haystack.iter().enumerate() {
            state = self.dfa.transitions[(state as usize) * 256 + (byte as usize)];
            let begin = self.dfa.output_offsets[state as usize] as usize;
            let end = self.dfa.output_offsets[state as usize + 1] as usize;
            let match_end = pos as u32 + 1;
            for &pattern_id in &self.dfa.output_records[begin..end] {
                let len = dfa_literal_lens[pattern_id as usize];
                results.push(vyre_libs::scan::LiteralMatch::new(
                    pattern_id,
                    match_end.saturating_sub(len),
                    match_end,
                ));
            }
        }
        results.sort_unstable();
        results
    }
}