Skip to main content

keyhog_scanner/engine/
gpu_regex_dfa.rs

1//! GPU `RegexDfaPipeline` - regex sets compiled through DFA subset
2//! construction into O(1)/byte Aho-Corasick scanning.
3//!
4//! # Motivation
5//!
6//! keyhog has two GPU matching tiers:
7//!
8//! 1. **Literal-set AC** (`GpuLiteralSet`) - O(1)/byte DFA scan over
9//!    fixed literal patterns. Fast, but only handles exact byte strings.
10//! 2. **NFA `RulePipeline`** - O(states×n) NFA multimatch via the
11//!    `build_rule_pipeline_from_regex` path. Handles full regex syntax
12//!    but scales with NFA state count per byte.
13//!
14//! This module adds a third tier: **RegexDfaPipeline**. It compiles a
15//! regex set through `compile_regex_set` (regex → Thompson NFA) and
16//! then extracts per-pattern literal content for DFA subset
17//! construction via `dfa_compile_with_budget`. The resulting
18//! `CompiledDfa` drives the same O(1)/byte AC kernel the literal-set
19//! engine uses, but accepts regex-defined patterns that have extractable
20//! literal cores.
21//!
22//! # Architecture
23//!
24//! ```text
25//! regex strings
26//!   ↓  compile_regex_set()   - validates syntax, builds NFA
//!   ↓  extract literal cores - per-pattern fixed byte prefixes
28//!   ↓  dfa_compile_with_budget() - AC DFA from extracted literals
29//!   ↓  RegexDfaPipeline { dfa, regex_set, ... }
30//! ```
31//!
//! Pattern sets that `compile_regex_set` rejects, that yield no
//! extractable literal core at all, or whose literals exceed the DFA
//! budget produce `RegexDfaError` - callers fall back to the NFA
//! `RulePipeline` or literal-set path. Individual patterns without a
//! literal core stay in the NFA but are skipped by the DFA.
36//!
37//! # Caching
38//!
39//! On-disk cache follows the same protocol as `GpuLiteralSet` and
40//! `RulePipeline`: SHA-256-keyed, atomic-rename writes,
41//! `~/.cache/keyhog/programs/dfa-<hash>.bin`.
42
43use vyre_libs::scan::{compile_regex_set, CompiledRegexSet, RegexCompileError};
44
/// Version tag mixed into every on-disk `RegexDfaPipeline` cache key.
///
/// Increment it whenever the serialized layout or the way the DFA is
/// derived from the patterns changes, so previously cached entries stop
/// matching and are rebuilt.
pub const REGEX_DFA_CACHE_VERSION: u32 = 1;
48
/// A regex set compiled through DFA subset construction.
///
/// Holds the validated `CompiledRegexSet` (the NFA produced by
/// `compile_regex_set`) alongside the `CompiledDfa` built from the
/// patterns' extracted literal prefixes (O(1)/byte transition table for
/// GPU dispatch).
#[derive(Debug, Clone)]
pub struct RegexDfaPipeline {
    /// The NFA compiled from the regex set. Note: `reference_scan` walks
    /// `dfa`, not this NFA; it is retained for consumers that need the
    /// regex-level representation.
    pub regex_set: CompiledRegexSet,
    /// DFA transition table compiled from the non-empty entries of
    /// `pattern_literals`. Drives the O(1)/byte AC scan kernel.
    pub dfa: vyre_libs::scan::CompiledDfa,
    /// One entry per input pattern, in input order: the literal prefix
    /// returned by `extract_literal_core`. Entries are empty for patterns
    /// with no extractable literal; those entries are skipped when building
    /// `dfa`, so DFA pattern ids count only the non-empty entries
    /// (assuming `dfa_compile_with_budget` numbers patterns by input
    /// position - TODO confirm).
    pub pattern_literals: Vec<Vec<u8>>,
    /// Number of regex patterns in the set (including ones without a
    /// literal core).
    pub pattern_count: u32,
}
68
/// Error type for `RegexDfaPipeline` compilation failures.
///
/// Returned by `build_regex_dfa` and `regex_dfa_cached`; callers are
/// expected to fall back to a slower matching tier on any variant.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RegexDfaError {
    /// `compile_regex_set` rejected the pattern set; wraps its error.
    /// NOTE(review): presumably covers syntax errors, unsupported features,
    /// or an NFA state cap - confirm against `vyre_libs::scan`.
    RegexCompile(RegexCompileError),
    /// DFA construction could not proceed: either `dfa_compile_with_budget`
    /// failed (presumably because the state budget was exceeded), or no
    /// pattern in the set had an extractable literal core.
    DfaBudgetExceeded {
        /// Human-readable description of the failure.
        message: String,
    },
    /// The pattern set is empty - nothing to compile.
    EmptyPatternSet,
}
84
85impl std::fmt::Display for RegexDfaError {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        match self {
88            Self::RegexCompile(inner) => write!(f, "regex_dfa: regex compile failed: {inner}"),
89            Self::DfaBudgetExceeded { message } => {
90                write!(f, "regex_dfa: DFA budget exceeded: {message}")
91            }
92            Self::EmptyPatternSet => write!(f, "regex_dfa: empty pattern set"),
93        }
94    }
95}
96
97impl std::error::Error for RegexDfaError {
98    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
99        match self {
100            Self::RegexCompile(inner) => Some(inner),
101            _ => None,
102        }
103    }
104}
105
106impl From<RegexCompileError> for RegexDfaError {
107    fn from(err: RegexCompileError) -> Self {
108        Self::RegexCompile(err)
109    }
110}
111
/// Extract a literal prefix that every match of `pattern` must start with.
///
/// Walks the pattern from the beginning and collects literal bytes,
/// stopping at the first metacharacter (character class, group,
/// alternation, anchor, `.`, quantifier) or non-literal escape. Only a
/// prefix is extracted - never an infix.
///
/// The result must be a *necessary* substring of every match, because it
/// feeds a literal DFA. Therefore:
/// - a quantifier that allows zero repetitions (`*`, `?`, and
///   conservatively any `{…}`) also drops the atom it applies to, while
///   `+` keeps it (that atom occurs at least once);
/// - a pattern with a top-level `|` yields nothing, since other branches
///   need not contain the prefix;
/// - escaped punctuation is literal; `\n \t \r \f \v` map to their control
///   bytes; every other alphanumeric escape (`\d`, `\x41`, `\p{..}`, ...)
///   and `\<`/`\>` stop extraction.
///
/// Conservative by design: `AKIA[A-Z0-9]{16}` yields `b"AKIA"`, `ab?c`
/// yields `b"a"`, and `[a-z]+` or `abc|def` yield an empty vec.
// The loop body pulls the escaped character from `chars` itself, so a
// `for` loop (which would borrow the iterator for the whole loop) does not
// fit - while-let is the correct shape here.
#[allow(clippy::while_let_on_iterator)]
pub fn extract_literal_core(pattern: &str) -> Vec<u8> {
    if has_top_level_alternation(pattern) {
        return Vec::new();
    }

    let mut literal = Vec::new();
    let mut chars = pattern.chars();
    while let Some(ch) = chars.next() {
        let byte = match ch {
            '\\' => match chars.next().and_then(escaped_literal_byte) {
                Some(byte) => byte,
                // Shorthand class, assertion, numeric/unicode escape, or a
                // trailing backslash: nothing more can be required.
                None => break,
            },
            // These quantifiers may repeat the preceding atom zero times
            // (`{0,n}` does too), so that atom - always the last pushed
            // byte, since extraction stops at the first metacharacter - is
            // not required.
            '*' | '?' | '{' => {
                literal.pop();
                break;
            }
            // `+` requires one repetition, so the preceding byte stays.
            '[' | '(' | '|' | '+' | '^' | '$' | '.' => break,
            _ if ch.is_ascii() => ch as u8,
            _ => break,
        };
        literal.push(byte);
    }
    literal
}

/// Map the character following a `\` to the single literal byte it
/// denotes, or `None` when the escape is not a plain literal byte.
fn escaped_literal_byte(ch: char) -> Option<u8> {
    match ch {
        'n' => Some(b'\n'),
        't' => Some(b'\t'),
        'r' => Some(b'\r'),
        'f' => Some(0x0C),
        'v' => Some(0x0B),
        // Word-boundary assertions in some dialects (e.g. Rust `regex`).
        '<' | '>' => None,
        _ if ch.is_ascii_punctuation() || ch == ' ' => Some(ch as u8),
        _ => None,
    }
}

/// Report whether `pattern` contains a `|` outside every group and
/// character class.
///
/// Unbalanced parentheses or brackets are reported as alternation too, so
/// a misparse can only shrink the extracted literal, never make it wrong.
// Escapes and class openers consume extra characters from `chars` inside
// the loop body.
#[allow(clippy::while_let_on_iterator)]
fn has_top_level_alternation(pattern: &str) -> bool {
    let mut group_depth = 0_usize;
    let mut class_depth = 0_usize;
    let mut chars = pattern.chars().peekable();
    while let Some(ch) = chars.next() {
        match ch {
            // An escaped character is never structural.
            '\\' => {
                chars.next();
            }
            '[' => {
                class_depth += 1;
                // After `[` or `[^`, a leading `]` is a literal member.
                if chars.peek() == Some(&'^') {
                    chars.next();
                }
                if chars.peek() == Some(&']') {
                    chars.next();
                }
            }
            ']' if class_depth > 0 => class_depth -= 1,
            _ if class_depth > 0 => {}
            '(' => group_depth += 1,
            ')' => {
                if group_depth == 0 {
                    return true;
                }
                group_depth -= 1;
            }
            '|' if group_depth == 0 => return true,
            _ => {}
        }
    }
    group_depth != 0 || class_depth != 0
}
166
167/// Compile a set of regex patterns through DFA subset construction.
168///
169/// 1. Validates all patterns through `compile_regex_set` (regex → NFA).
170/// 2. Extracts literal cores from each pattern.
171/// 3. Compiles extracted literals through `dfa_compile_with_budget`.
172///
173/// # Errors
174///
175/// Returns `RegexDfaError::RegexCompile` when any pattern fails NFA
176/// compilation. Returns `RegexDfaError::DfaBudgetExceeded` when the
177/// DFA transition table exceeds the default budget. Returns
178/// `RegexDfaError::EmptyPatternSet` when the input is empty.
179pub fn build_regex_dfa(
180    patterns: &[&str],
181    _input_len: u32,
182) -> std::result::Result<RegexDfaPipeline, RegexDfaError> {
183    if patterns.is_empty() {
184        return Err(RegexDfaError::EmptyPatternSet);
185    }
186
187    // Step 1: validate all patterns through the NFA frontend.
188    let regex_set = compile_regex_set(patterns)?;
189
190    // Step 2: extract literal cores for DFA construction.
191    let pattern_literals: Vec<Vec<u8>> = patterns.iter().map(|p| extract_literal_core(p)).collect();
192
193    // Filter to non-empty literals for DFA compilation. Patterns with
194    // no extractable literal core still participate in the NFA-based
195    // reference scan but cannot drive the DFA fast path.
196    let dfa_inputs: Vec<&[u8]> = pattern_literals
197        .iter()
198        .filter(|lit| !lit.is_empty())
199        .map(|lit| lit.as_slice())
200        .collect();
201
202    if dfa_inputs.is_empty() {
203        return Err(RegexDfaError::DfaBudgetExceeded {
204            message: "no patterns have extractable literal cores for DFA construction".into(),
205        });
206    }
207
208    // Step 3: compile DFA with budget guard.
209    let dfa = vyre_libs::scan::dfa_compile_with_budget(
210        &dfa_inputs,
211        vyre_libs::scan::DEFAULT_DFA_BUDGET_BYTES,
212    )
213    .map_err(|e| RegexDfaError::DfaBudgetExceeded {
214        message: format!("{e}"),
215    })?;
216
217    Ok(RegexDfaPipeline {
218        regex_set,
219        dfa,
220        pattern_literals,
221        pattern_count: patterns.len() as u32,
222    })
223}
224
225fn regex_dfa_cache_key(patterns: &[&str], input_len: u32) -> String {
226    use sha2::{Digest, Sha256};
227    let mut h = Sha256::new();
228    h.update(REGEX_DFA_CACHE_VERSION.to_le_bytes());
229    h.update(input_len.to_le_bytes());
230    h.update((patterns.len() as u32).to_le_bytes());
231    for p in patterns {
232        h.update((p.len() as u32).to_le_bytes());
233        h.update(p.as_bytes());
234    }
235    let digest = h.finalize();
236    let mut hex = String::with_capacity(64);
237    for byte in digest {
238        use std::fmt::Write as _;
239        let _ = write!(hex, "{byte:02x}");
240    }
241    hex
242}
243
244/// Compile-or-load a `RegexDfaPipeline` for the given regex set.
245///
246/// First call checks the on-disk cache at
247/// `~/.cache/keyhog/programs/dfa-<sha256>.bin`. Cache misses recompile
248/// via [`build_regex_dfa`] and persist the result. Returns `Err` when
249/// the regex compile or DFA construction itself fails - the caller is
250/// expected to log and fall back to the NFA `RulePipeline` or
251/// literal-set GPU dispatch.
252///
253/// The on-disk cache is keyed by `(patterns, input_len,
254/// REGEX_DFA_CACHE_VERSION)` so a vyre IR bump, detector change, or
255/// cache version bump automatically invalidates stale entries.
256pub fn regex_dfa_cached(
257    patterns: &[&str],
258    input_len: u32,
259) -> std::result::Result<RegexDfaPipeline, RegexDfaError> {
260    let started = std::time::Instant::now();
261    let Some(cache_dir) = super::gpu_cache::gpu_matcher_cache_dir() else {
262        return build_regex_dfa(patterns, input_len);
263    };
264    let cache_key = format!("dfa-{}", regex_dfa_cache_key(patterns, input_len));
265
266    // Attempt cache load. The DFA is serialized via CompiledDfa's
267    // to_bytes/from_bytes wire format; the NFA regex_set is NOT cached
268    // (it's cheap to recompile and only used for reference_scan parity).
269    if let Some(path) = vyre_libs::scan::engine_cache_path(&cache_dir, &cache_key) {
270        if let Ok(bytes) = std::fs::read(&path) {
271            // Try to reconstruct from cached DFA bytes.
272            match vyre_libs::scan::CompiledDfa::from_bytes(&bytes) {
273                Ok(dfa) => {
274                    // Recompile the NFA side (cheap) so reference_scan
275                    // is available without caching the full NFA tables.
276                    if let Ok(regex_set) = compile_regex_set(patterns) {
277                        let pattern_literals: Vec<Vec<u8>> =
278                            patterns.iter().map(|p| extract_literal_core(p)).collect();
279                        tracing::debug!(
280                            target: "keyhog::routing",
281                            patterns = patterns.len(),
282                            input_len,
283                            elapsed_ms = started.elapsed().as_millis() as u64,
284                            "RegexDfaPipeline cache hit - skipped DFA compile"
285                        );
286                        return Ok(RegexDfaPipeline {
287                            regex_set,
288                            dfa,
289                            pattern_literals,
290                            pattern_count: patterns.len() as u32,
291                        });
292                    }
293                }
294                Err(_) => {
295                    let _ = std::fs::remove_file(&path);
296                }
297            }
298        }
299    }
300
301    // Cache miss - full compile.
302    let pipeline = build_regex_dfa(patterns, input_len)?;
303
304    // Persist the DFA to disk (NFA is not cached - recompile is cheap).
305    if let Some(path) = vyre_libs::scan::engine_cache_path(&cache_dir, &cache_key) {
306        if let Ok(bytes) = pipeline.dfa.to_bytes() {
307            let tmp = path.with_extension(format!("tmp.{}", std::process::id()));
308            if let Some(parent) = path.parent() {
309                let _ = std::fs::create_dir_all(parent);
310            }
311            if std::fs::write(&tmp, &bytes).is_ok() {
312                if let Err(error) = std::fs::rename(&tmp, &path) {
313                    tracing::debug!(
314                        target: "keyhog::routing",
315                        error = %error,
316                        path = %path.display(),
317                        "regex DFA cache rename failed"
318                    );
319                    let _ = std::fs::remove_file(&tmp);
320                }
321            }
322        }
323    }
324
325    tracing::debug!(
326        target: "keyhog::routing",
327        patterns = patterns.len(),
328        input_len,
329        elapsed_ms = started.elapsed().as_millis() as u64,
330        "RegexDfaPipeline cache miss - compiled and saved"
331    );
332    Ok(pipeline)
333}
334
335impl RegexDfaPipeline {
336    /// CPU reference scan using the NFA representation.
337    ///
338    /// This matches the contract of `RulePipeline::reference_scan` -
339    /// walks the NFA for each start position in the haystack and
340    /// collects all accepting states. Used for parity testing against
341    /// the DFA fast path.
342    #[must_use]
343    pub fn reference_scan(&self, haystack: &[u8]) -> Vec<vyre_libs::scan::LiteralMatch> {
344        // Use the DFA for reference scanning - walk the transition
345        // table and emit matches from output_records.
346        let mut results = Vec::new();
347        let mut state = 0_u32;
348        for (pos, &byte) in haystack.iter().enumerate() {
349            state = self.dfa.transitions[(state as usize) * 256 + (byte as usize)];
350            let begin = self.dfa.output_offsets[state as usize] as usize;
351            let end = self.dfa.output_offsets[state as usize + 1] as usize;
352            for &pattern_id in &self.dfa.output_records[begin..end] {
353                let lit = &self.pattern_literals[pattern_id as usize];
354                let len = lit.len() as u32;
355                results.push(vyre_libs::scan::LiteralMatch::new(
356                    pattern_id,
357                    (pos as u32 + 1).saturating_sub(len),
358                    pos as u32 + 1,
359                ));
360            }
361        }
362        results.sort_unstable();
363        results
364    }
365}