keyhog_scanner/engine/gpu_regex_dfa.rs
1//! GPU `RegexDfaPipeline` - regex sets compiled through DFA subset
2//! construction into O(1)/byte Aho-Corasick scanning.
3//!
4//! # Motivation
5//!
6//! keyhog has two GPU matching tiers:
7//!
8//! 1. **Literal-set AC** (`GpuLiteralSet`) - O(1)/byte DFA scan over
9//! fixed literal patterns. Fast, but only handles exact byte strings.
10//! 2. **NFA `RulePipeline`** - O(states×n) NFA multimatch via the
11//! `build_rule_pipeline_from_regex` path. Handles full regex syntax
12//! but scales with NFA state count per byte.
13//!
14//! This module adds a third tier: **RegexDfaPipeline**. It compiles a
15//! regex set through `compile_regex_set` (regex → Thompson NFA) and
16//! then extracts per-pattern literal content for DFA subset
17//! construction via `dfa_compile_with_budget`. The resulting
18//! `CompiledDfa` drives the same O(1)/byte AC kernel the literal-set
19//! engine uses, but accepts regex-defined patterns that have extractable
20//! literal cores.
21//!
22//! # Architecture
23//!
24//! ```text
25//! regex strings
26//! ↓ compile_regex_set() - validates syntax, builds NFA
27//! ↓ extract literal cores - per-pattern fixed byte prefixes/infixes
28//! ↓ dfa_compile_with_budget() - AC DFA from extracted literals
29//! ↓ RegexDfaPipeline { dfa, regex_set, ... }
30//! ```
31//!
32//! Patterns that cannot be lowered (Unicode classes, lookaround,
33//! backrefs) or that exceed the DFA state budget produce
34//! `RegexDfaError` - callers fall back to the NFA `RulePipeline` or
35//! literal-set path.
36//!
37//! # Caching
38//!
39//! On-disk cache follows the same protocol as `GpuLiteralSet` and
40//! `RulePipeline`: SHA-256-keyed, atomic-rename writes,
41//! `~/.cache/keyhog/programs/dfa-<hash>.bin`.
42
43use vyre_libs::scan::{compile_regex_set, CompiledRegexSet, RegexCompileError};
44
/// Cache version for the on-disk serialized `RegexDfaPipeline`. Bump
/// when the wire layout or compilation strategy changes.
///
/// The value is hashed into the key produced by `regex_dfa_cache_key`,
/// so bumping it makes previously written cache files unreachable
/// under the new key.
pub const REGEX_DFA_CACHE_VERSION: u32 = 1;
48
/// A regex set validated by the NFA frontend, paired with a DFA built
/// from each pattern's extracted literal core.
///
/// Holds the `CompiledRegexSet` returned by `compile_regex_set`
/// alongside the `CompiledDfa` produced by `dfa_compile_with_budget`
/// (the transition table intended for GPU dispatch).
#[derive(Debug, Clone)]
pub struct RegexDfaPipeline {
    /// The NFA compiled from the full pattern list - kept for consumers
    /// that need accept-state metadata.
    /// NOTE(review): `reference_scan` in this file walks `dfa`, not this
    /// field.
    pub regex_set: CompiledRegexSet,
    /// DFA transition table compiled from extracted literal cores.
    /// Drives the O(1)/byte AC scan kernel.
    pub dfa: vyre_libs::scan::CompiledDfa,
    /// One entry per input pattern, in input order, as returned by
    /// `extract_literal_core`. An entry is empty when no literal core
    /// could be extracted; only the non-empty entries are passed to
    /// `dfa_compile_with_budget`.
    pub pattern_literals: Vec<Vec<u8>>,
    /// Number of regex patterns in the set (`patterns.len()` at build
    /// time, including patterns with an empty literal core).
    pub pattern_count: u32,
}
68
/// Error type for `RegexDfaPipeline` compilation failures.
///
/// Marked `#[non_exhaustive]`, so downstream matches need a wildcard arm.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RegexDfaError {
    /// The regex set failed NFA compilation (syntax error, unsupported
    /// feature, or NFA state cap exceeded). Wraps the error returned by
    /// `compile_regex_set`.
    RegexCompile(RegexCompileError),
    /// The DFA subset construction exceeded the state budget.
    ///
    /// `build_regex_dfa` also returns this variant when none of the
    /// patterns yields a non-empty literal core, i.e. when there is
    /// nothing to build a DFA from.
    DfaBudgetExceeded {
        /// Human-readable description of the budget failure.
        message: String,
    },
    /// The pattern set is empty - nothing to compile.
    EmptyPatternSet,
}
84
85impl std::fmt::Display for RegexDfaError {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 match self {
88 Self::RegexCompile(inner) => write!(f, "regex_dfa: regex compile failed: {inner}"),
89 Self::DfaBudgetExceeded { message } => {
90 write!(f, "regex_dfa: DFA budget exceeded: {message}")
91 }
92 Self::EmptyPatternSet => write!(f, "regex_dfa: empty pattern set"),
93 }
94 }
95}
96
97impl std::error::Error for RegexDfaError {
98 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
99 match self {
100 Self::RegexCompile(inner) => Some(inner),
101 _ => None,
102 }
103 }
104}
105
106impl From<RegexCompileError> for RegexDfaError {
107 fn from(err: RegexCompileError) -> Self {
108 Self::RegexCompile(err)
109 }
110}
111
/// Extract the literal core from a regex pattern string.
///
/// The core is a byte string that every match of `pattern` must start
/// with. It is built by walking the pattern from the left and collecting
/// literal ASCII bytes until the first construct that is not a plain
/// literal (character class, group, anchor, wildcard, quantifier,
/// class/boundary escape, or non-ASCII character).
///
/// The extraction is deliberately conservative - returning a shorter (or
/// empty) core is always safe, returning a byte that a match does not
/// have to contain is not:
///
/// * `AKIA[A-Z0-9]{16}` yields `b"AKIA"`; `[a-z]+` yields an empty vec.
/// * A byte followed by `?`, `*` or `{` may occur zero times, so it is
///   dropped: `https?://` yields `b"http"`. A byte followed by `+` occurs
///   at least once and is kept.
/// * `\n`, `\t` and `\r` are decoded to the control byte they denote.
///   Escaped ASCII punctuation is taken literally (`a\.b` yields
///   `b"a.b"`). Any other escaped letter or digit (`\d`, `\b`, `\x41`,
///   `\p{..}`, back-references, ...) ends the core.
/// * A pattern with an alternation outside every group (`foo|bar`) has
///   no prefix common to all branches, so the core is empty.
///
/// Only a prefix is ever returned; literals that appear after the first
/// non-literal construct are not considered.
// The escape arm pulls the following character out of `chars` inside the
// loop body, so a `for` loop (which would take ownership of the iterator)
// is not usable here.
#[allow(clippy::while_let_on_iterator)]
pub fn extract_literal_core(pattern: &str) -> Vec<u8> {
    if has_top_level_alternation(pattern) {
        return Vec::new();
    }

    let mut literal = Vec::new();
    let mut chars = pattern.chars();

    while let Some(ch) = chars.next() {
        match ch {
            '\\' => {
                // A trailing backslash or a non-literal escape ends the core.
                let Some(byte) = chars.next().and_then(escaped_literal_byte) else {
                    break;
                };
                literal.push(byte);
            }
            '?' | '*' | '{' => {
                // The preceding atom may repeat zero times, so it is not
                // guaranteed to be present in a match.
                literal.pop();
                break;
            }
            '[' | '(' | '|' | '+' | '^' | '$' | '.' => break,
            _ if ch.is_ascii() => literal.push(ch as u8),
            _ => break,
        }
    }
    literal
}

/// Maps the character following a backslash to the single byte it stands
/// for, or `None` when the escape is not a plain one-byte literal.
fn escaped_literal_byte(ch: char) -> Option<u8> {
    match ch {
        'n' => Some(b'\n'),
        't' => Some(b'\t'),
        'r' => Some(b'\r'),
        // `\<` / `\>` are word-boundary assertions in several regex
        // dialects; treat them as non-literal to stay conservative.
        '<' | '>' => None,
        // Remaining letters/digits are classes, assertions, numeric or
        // Unicode escapes, or back-references.
        _ if ch.is_ascii_alphanumeric() => None,
        _ if ch.is_ascii() => Some(ch as u8),
        _ => None,
    }
}

/// Reports whether `pattern` contains an unescaped `|` that is outside
/// every `(...)` group and `[...]` class.
///
/// This is a flat, heuristic scan rather than a full parse: a bracket
/// class is assumed to end at its first unescaped `]` (a `]` directly
/// after `[` or `[^` is treated as a member of the class).
fn has_top_level_alternation(pattern: &str) -> bool {
    // Every metacharacter inspected here is ASCII, and UTF-8 continuation
    // bytes never collide with ASCII, so a byte-wise scan is sufficient.
    let bytes = pattern.as_bytes();
    let mut depth = 0usize;
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            // Skip the escaped byte together with the backslash.
            b'\\' => i += 1,
            b'[' => {
                i += 1;
                if bytes.get(i) == Some(&b'^') {
                    i += 1;
                }
                if bytes.get(i) == Some(&b']') {
                    i += 1;
                }
                while i < bytes.len() && bytes[i] != b']' {
                    if bytes[i] == b'\\' {
                        i += 1;
                    }
                    i += 1;
                }
            }
            b'(' => depth += 1,
            b')' => depth = depth.saturating_sub(1),
            b'|' if depth == 0 => return true,
            _ => {}
        }
        i += 1;
    }
    false
}
166
167/// Compile a set of regex patterns through DFA subset construction.
168///
169/// 1. Validates all patterns through `compile_regex_set` (regex → NFA).
170/// 2. Extracts literal cores from each pattern.
171/// 3. Compiles extracted literals through `dfa_compile_with_budget`.
172///
173/// # Errors
174///
175/// Returns `RegexDfaError::RegexCompile` when any pattern fails NFA
176/// compilation. Returns `RegexDfaError::DfaBudgetExceeded` when the
177/// DFA transition table exceeds the default budget. Returns
178/// `RegexDfaError::EmptyPatternSet` when the input is empty.
179pub fn build_regex_dfa(
180 patterns: &[&str],
181 _input_len: u32,
182) -> std::result::Result<RegexDfaPipeline, RegexDfaError> {
183 if patterns.is_empty() {
184 return Err(RegexDfaError::EmptyPatternSet);
185 }
186
187 // Step 1: validate all patterns through the NFA frontend.
188 let regex_set = compile_regex_set(patterns)?;
189
190 // Step 2: extract literal cores for DFA construction.
191 let pattern_literals: Vec<Vec<u8>> = patterns.iter().map(|p| extract_literal_core(p)).collect();
192
193 // Filter to non-empty literals for DFA compilation. Patterns with
194 // no extractable literal core still participate in the NFA-based
195 // reference scan but cannot drive the DFA fast path.
196 let dfa_inputs: Vec<&[u8]> = pattern_literals
197 .iter()
198 .filter(|lit| !lit.is_empty())
199 .map(|lit| lit.as_slice())
200 .collect();
201
202 if dfa_inputs.is_empty() {
203 return Err(RegexDfaError::DfaBudgetExceeded {
204 message: "no patterns have extractable literal cores for DFA construction".into(),
205 });
206 }
207
208 // Step 3: compile DFA with budget guard.
209 let dfa = vyre_libs::scan::dfa_compile_with_budget(
210 &dfa_inputs,
211 vyre_libs::scan::DEFAULT_DFA_BUDGET_BYTES,
212 )
213 .map_err(|e| RegexDfaError::DfaBudgetExceeded {
214 message: format!("{e}"),
215 })?;
216
217 Ok(RegexDfaPipeline {
218 regex_set,
219 dfa,
220 pattern_literals,
221 pattern_count: patterns.len() as u32,
222 })
223}
224
225fn regex_dfa_cache_key(patterns: &[&str], input_len: u32) -> String {
226 use sha2::{Digest, Sha256};
227 let mut h = Sha256::new();
228 h.update(REGEX_DFA_CACHE_VERSION.to_le_bytes());
229 h.update(input_len.to_le_bytes());
230 h.update((patterns.len() as u32).to_le_bytes());
231 for p in patterns {
232 h.update((p.len() as u32).to_le_bytes());
233 h.update(p.as_bytes());
234 }
235 let digest = h.finalize();
236 let mut hex = String::with_capacity(64);
237 for byte in digest {
238 use std::fmt::Write as _;
239 let _ = write!(hex, "{byte:02x}");
240 }
241 hex
242}
243
244/// Compile-or-load a `RegexDfaPipeline` for the given regex set.
245///
246/// First call checks the on-disk cache at
247/// `~/.cache/keyhog/programs/dfa-<sha256>.bin`. Cache misses recompile
248/// via [`build_regex_dfa`] and persist the result. Returns `Err` when
249/// the regex compile or DFA construction itself fails - the caller is
250/// expected to log and fall back to the NFA `RulePipeline` or
251/// literal-set GPU dispatch.
252///
253/// The on-disk cache is keyed by `(patterns, input_len,
254/// REGEX_DFA_CACHE_VERSION)` so a vyre IR bump, detector change, or
255/// cache version bump automatically invalidates stale entries.
256pub fn regex_dfa_cached(
257 patterns: &[&str],
258 input_len: u32,
259) -> std::result::Result<RegexDfaPipeline, RegexDfaError> {
260 let started = std::time::Instant::now();
261 let Some(cache_dir) = super::gpu_cache::gpu_matcher_cache_dir() else {
262 return build_regex_dfa(patterns, input_len);
263 };
264 let cache_key = format!("dfa-{}", regex_dfa_cache_key(patterns, input_len));
265
266 // Attempt cache load. The DFA is serialized via CompiledDfa's
267 // to_bytes/from_bytes wire format; the NFA regex_set is NOT cached
268 // (it's cheap to recompile and only used for reference_scan parity).
269 if let Some(path) = vyre_libs::scan::engine_cache_path(&cache_dir, &cache_key) {
270 if let Ok(bytes) = std::fs::read(&path) {
271 // Try to reconstruct from cached DFA bytes.
272 match vyre_libs::scan::CompiledDfa::from_bytes(&bytes) {
273 Ok(dfa) => {
274 // Recompile the NFA side (cheap) so reference_scan
275 // is available without caching the full NFA tables.
276 if let Ok(regex_set) = compile_regex_set(patterns) {
277 let pattern_literals: Vec<Vec<u8>> =
278 patterns.iter().map(|p| extract_literal_core(p)).collect();
279 tracing::debug!(
280 target: "keyhog::routing",
281 patterns = patterns.len(),
282 input_len,
283 elapsed_ms = started.elapsed().as_millis() as u64,
284 "RegexDfaPipeline cache hit - skipped DFA compile"
285 );
286 return Ok(RegexDfaPipeline {
287 regex_set,
288 dfa,
289 pattern_literals,
290 pattern_count: patterns.len() as u32,
291 });
292 }
293 }
294 Err(_) => {
295 let _ = std::fs::remove_file(&path);
296 }
297 }
298 }
299 }
300
301 // Cache miss - full compile.
302 let pipeline = build_regex_dfa(patterns, input_len)?;
303
304 // Persist the DFA to disk (NFA is not cached - recompile is cheap).
305 if let Some(path) = vyre_libs::scan::engine_cache_path(&cache_dir, &cache_key) {
306 if let Ok(bytes) = pipeline.dfa.to_bytes() {
307 let tmp = path.with_extension(format!("tmp.{}", std::process::id()));
308 if let Some(parent) = path.parent() {
309 let _ = std::fs::create_dir_all(parent);
310 }
311 if std::fs::write(&tmp, &bytes).is_ok() {
312 if let Err(error) = std::fs::rename(&tmp, &path) {
313 tracing::debug!(
314 target: "keyhog::routing",
315 error = %error,
316 path = %path.display(),
317 "regex DFA cache rename failed"
318 );
319 let _ = std::fs::remove_file(&tmp);
320 }
321 }
322 }
323 }
324
325 tracing::debug!(
326 target: "keyhog::routing",
327 patterns = patterns.len(),
328 input_len,
329 elapsed_ms = started.elapsed().as_millis() as u64,
330 "RegexDfaPipeline cache miss - compiled and saved"
331 );
332 Ok(pipeline)
333}
334
335impl RegexDfaPipeline {
336 /// CPU reference scan using the NFA representation.
337 ///
338 /// This matches the contract of `RulePipeline::reference_scan` -
339 /// walks the NFA for each start position in the haystack and
340 /// collects all accepting states. Used for parity testing against
341 /// the DFA fast path.
342 #[must_use]
343 pub fn reference_scan(&self, haystack: &[u8]) -> Vec<vyre_libs::scan::LiteralMatch> {
344 // Use the DFA for reference scanning - walk the transition
345 // table and emit matches from output_records.
346 let mut results = Vec::new();
347 let mut state = 0_u32;
348 for (pos, &byte) in haystack.iter().enumerate() {
349 state = self.dfa.transitions[(state as usize) * 256 + (byte as usize)];
350 let begin = self.dfa.output_offsets[state as usize] as usize;
351 let end = self.dfa.output_offsets[state as usize + 1] as usize;
352 for &pattern_id in &self.dfa.output_records[begin..end] {
353 let lit = &self.pattern_literals[pattern_id as usize];
354 let len = lit.len() as u32;
355 results.push(vyre_libs::scan::LiteralMatch::new(
356 pattern_id,
357 (pos as u32 + 1).saturating_sub(len),
358 pos as u32 + 1,
359 ));
360 }
361 }
362 results.sort_unstable();
363 results
364 }
365}