Skip to main content

keyhog_scanner/engine/
gpu_lazy.rs

1use super::*;
2use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
3
4fn append_match_bound_slot(
5    hits_buffer: &str,
6    count_buffer: &str,
7    tag: impl Into<Expr>,
8    start: impl Into<Expr>,
9    end: impl Into<Expr>,
10) -> Node {
11    let slot_name = "_keyhog_match_slot";
12    let max_hits = Expr::div(Expr::buf_len(hits_buffer), Expr::u32(3));
13
14    Node::block(vec![
15        Node::let_bind(
16            slot_name,
17            Expr::atomic_add(count_buffer, Expr::u32(0), Expr::u32(1)),
18        ),
19        Node::if_then(
20            Expr::lt(Expr::var(slot_name), max_hits),
21            vec![
22                Node::store(
23                    hits_buffer,
24                    Expr::mul(Expr::var(slot_name), Expr::u32(3)),
25                    tag.into(),
26                ),
27                Node::store(
28                    hits_buffer,
29                    Expr::add(Expr::mul(Expr::var(slot_name), Expr::u32(3)), Expr::u32(1)),
30                    start.into(),
31                ),
32                Node::store(
33                    hits_buffer,
34                    Expr::add(Expr::mul(Expr::var(slot_name), Expr::u32(3)), Expr::u32(2)),
35                    end.into(),
36                ),
37            ],
38        ),
39    ])
40}
41
42fn build_ac_bounded_ranges_program_bound_atomic(
43    dfa: &vyre_libs::scan::dfa::CompiledDfa,
44    pattern_count: u32,
45    max_matches: u32,
46) -> Option<Program> {
47    let output_records_len = u32::try_from(dfa.output_records.len()).ok()?;
48    let max_pattern_len = dfa.max_pattern_len.max(1);
49
50    let haystack = "haystack";
51    let transitions = "transitions";
52    let output_offsets = "output_offsets";
53    let output_records = "output_records";
54    let pattern_lengths = "pattern_lengths";
55    let haystack_len = "haystack_len";
56    let match_count = "match_count";
57    let matches = "matches";
58
59    let i = Expr::var("i");
60    let end = Expr::add(i.clone(), Expr::u32(1));
61    let scan_start = Expr::select(
62        Expr::lt(i.clone(), Expr::u32(max_pattern_len - 1)),
63        Expr::u32(0),
64        Expr::sub(end.clone(), Expr::u32(max_pattern_len)),
65    );
66    let (load_step_byte, step_byte) =
67        vyre_libs::scan::builders::load_packed_byte(haystack, Expr::var("step"));
68
69    let walk_body = vec![
70        Node::let_bind("i", Expr::InvocationId { axis: 0 }),
71        Node::if_then(
72            Expr::lt(i.clone(), Expr::load(haystack_len, Expr::u32(0))),
73            vec![
74                Node::let_bind("state", Expr::u32(0)),
75                Node::let_bind("scan_start", scan_start),
76                Node::let_bind("scan_end", end),
77                Node::loop_for(
78                    "step",
79                    Expr::var("scan_start"),
80                    Expr::var("scan_end"),
81                    vec![
82                        load_step_byte,
83                        Node::assign(
84                            "state",
85                            Expr::load(
86                                transitions,
87                                Expr::add(Expr::mul(Expr::var("state"), Expr::u32(256)), step_byte),
88                            ),
89                        ),
90                    ],
91                ),
92                Node::let_bind("out_begin", Expr::load(output_offsets, Expr::var("state"))),
93                Node::let_bind(
94                    "out_end",
95                    Expr::load(output_offsets, Expr::add(Expr::var("state"), Expr::u32(1))),
96                ),
97                Node::loop_for(
98                    "out_idx",
99                    Expr::var("out_begin"),
100                    Expr::var("out_end"),
101                    vec![
102                        Node::let_bind(
103                            "pattern_id",
104                            Expr::load(output_records, Expr::var("out_idx")),
105                        ),
106                        Node::let_bind(
107                            "pat_len",
108                            Expr::load(pattern_lengths, Expr::var("pattern_id")),
109                        ),
110                        Node::let_bind(
111                            "match_start",
112                            Expr::select(
113                                Expr::lt(Expr::var("scan_end"), Expr::var("pat_len")),
114                                Expr::u32(0),
115                                Expr::sub(Expr::var("scan_end"), Expr::var("pat_len")),
116                            ),
117                        ),
118                        append_match_bound_slot(
119                            matches,
120                            match_count,
121                            Expr::var("pattern_id"),
122                            Expr::var("match_start"),
123                            Expr::var("scan_end"),
124                        ),
125                    ],
126                ),
127            ],
128        ),
129    ];
130
131    Some(Program::wrapped(
132        vec![
133            BufferDecl::storage(haystack, 0, BufferAccess::ReadOnly, DataType::U32),
134            BufferDecl::storage(transitions, 1, BufferAccess::ReadOnly, DataType::U32)
135                .with_count(dfa.state_count.saturating_mul(256)),
136            BufferDecl::storage(output_offsets, 2, BufferAccess::ReadOnly, DataType::U32)
137                .with_count(dfa.state_count.saturating_add(1)),
138            BufferDecl::storage(output_records, 3, BufferAccess::ReadOnly, DataType::U32)
139                .with_count(output_records_len),
140            BufferDecl::storage(pattern_lengths, 4, BufferAccess::ReadOnly, DataType::U32)
141                .with_count(pattern_count),
142            BufferDecl::storage(haystack_len, 5, BufferAccess::ReadOnly, DataType::U32)
143                .with_count(1),
144            BufferDecl::read_write(match_count, 6, DataType::U32).with_count(1),
145            BufferDecl::output(matches, 7, DataType::U32).with_count(max_matches.saturating_mul(3)),
146        ],
147        [128, 1, 1],
148        vec![vyre_libs::region::wrap_anonymous(
149            "keyhog::matching::classic_ac_bounded_ranges",
150            walk_body,
151        )],
152    ))
153}
154
/// Tracks whether the subgroup-coalesced match-append form
/// (subgroup_ballot + subgroup_shuffle -> `_vyre_match_leader`) is
/// enabled for the AC GPU dispatch Program. Read by
/// `CompiledScanner::ac_gpu_program` to choose between the vyre_libs
/// `build_ac_bounded_ranges_program_ext` builder and this file's
/// `build_ac_bounded_ranges_program_bound_atomic`.
///
/// Held forced-off because vyre's substrate-neutral pre-emit lowering
/// rejects that form on every backend (CUDA and wgpu both:
/// "_vyre_match_leader referenced before binding", Innovation I.17).
/// This is a named, greppable dead-path marker rather than a silent
/// inline `false`: once the vyre IR gap is closed, flipping this to
/// `true` re-enables the ~32x atomic-contention reduction on the
/// shared match-count buffer across every backend in one place.
///
/// The interim contention win (per-workgroup local reduction -> one
/// atomic add per group) lives in the kernel builder, not here.
/// NOTE(review): the builder in this file, `append_match_bound_slot`,
/// issues one atomic add per match - confirm which builder this
/// sentence refers to.
const AC_GPU_SUBGROUP_COALESCE: bool = false;
168
169impl CompiledScanner {
170    /// Lazily compile the GPU literal-set on first call. Returns `None`
171    /// when no compatible adapter was detected at probe time.
172    ///
173    /// Persists the compiled matcher to `~/.cache/keyhog/programs/<hash>.bin`.
174    /// On a cache hit the matcher is loaded from disk and the GPU
175    /// recompile is skipped entirely - biggest cold-start win on
176    /// `keyhog scan` / `scan-system` runs that re-launch repeatedly.
177    /// Cache misses (no file, version-mismatch, corrupt blob) silently
178    /// recompile and re-cache.
179    pub fn gpu_matcher(&self) -> Option<&vyre_libs::scan::GpuLiteralSet> {
180        self.gpu_matcher
181            .get_or_init(|| {
182                let Some(literals) = &self.gpu_literals else {
183                    return None;
184                };
185                let literal_refs: Vec<&[u8]> = literals.iter().map(|v| v.as_slice()).collect();
186                let cache_dir = super::gpu_cache::gpu_matcher_cache_dir()?;
187                let cache_key = format!(
188                    "lit-{}",
189                    super::gpu_cache::gpu_matcher_cache_key(&literal_refs)
190                );
191                let started = std::time::Instant::now();
192                // One-line lego-block cache wiring courtesy of
193                // `vyre_libs::scan::cached_load_or_compile`. The
194                // helper handles atomic-rename, stale-blob deletion,
195                // and silent fall-through on cache-side I/O errors -
196                // every behaviour the previous hand-rolled
197                // load/save pair tried to match. We log compile cost
198                // here so the operator can still see warm-vs-cold
199                // start latency in `--verbose` output.
200                let matcher =
201                    vyre_libs::scan::cached_load_or_compile(&cache_dir, &cache_key, || {
202                        vyre_libs::scan::GpuLiteralSet::compile(&literal_refs)
203                    });
204                tracing::debug!(
205                    target: "keyhog::routing",
206                    patterns = literal_refs.len(),
207                    elapsed_ms = started.elapsed().as_millis() as u64,
208                    "GpuLiteralSet ready (warm cache or compiled)"
209                );
210                Some(matcher)
211            })
212            .as_ref()
213    }
214
215    /// Lazily build the Aho-Corasick bounded-ranges dispatch Program
216    /// from the GpuLiteralSet's CompiledDfa. The two engines share the
217    /// same DFA - only the dispatch Program (and therefore the
218    /// per-byte algorithm) differs:
219    ///
220    /// * `gpu_matcher().program` - `build_literal_set_program`:
221    ///   walks every pattern × every literal byte per haystack
222    ///   position. `O(N × L) per byte`. Works for any pattern set
223    ///   that fits the DFA budget.
224    /// * `ac_gpu_program()` - `classic_ac_bounded_ranges_program`:
225    ///   walks the AC transition table forward `L_max` bytes per
226    ///   position, emits every pattern in the accepting state's
227    ///   flat output_links. `O(L_max) per byte` regardless of N.
228    ///
229    /// Selected at scan time via `KEYHOG_GPU_KERNEL=ac`. Returns
230    /// `None` when no GPU matcher is available; callers fall through
231    /// to the literal-set path or non-GPU backend.
232    ///
233    /// Cap of `super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH` triples per shard
234    /// dispatch matches the existing literal-set output-buffer cap.
235    /// Truncation (count > cap on readback) is handled by the same
236    /// fall-back-to-CPU branch the literal-set path uses.
237    pub fn ac_gpu_program(&self) -> Option<&vyre::Program> {
238        self.ac_gpu_program
239            .get_or_init(|| {
240                let matcher = self.gpu_matcher()?;
241                let pattern_count = matcher.pattern_lengths.len() as u32;
242                // Pick the match-append strategy. The subgroup form
243                // (subgroup_ballot + subgroup_shuffle producing
244                // _vyre_match_leader) was originally gated to wgpu
245                // only because vyre-driver-cuda rejects it during
246                // canonical pre-emit lowering. Runtime testing on
247                // Apple Silicon M4 Pro with vyre v0.4.2 confirmed
248                // the SAME "_vyre_match_leader referenced before
249                // binding" rejection on the wgpu path: the lowering
250                // gap is in vyre's substrate-neutral pre-emit step,
251                // not the driver-specific emitter. Until the IR gap
252                // is closed, use_subgroup_coalesce stays false on
253                // every backend. We lose the ~32x atomic-contention
254                // reduction the subgroup form would have provided
255                // (Innovation I.17), but recall and correctness are
256                // preserved; the plain append_match path produces
257                // bit-identical match output, just with more atomic
258                // pressure on the shared count buffer.
259                let backend_id = self.gpu_backend.as_ref().map(|b| b.id()).unwrap_or("none");
260                let use_subgroup_coalesce = AC_GPU_SUBGROUP_COALESCE;
261                let program = if use_subgroup_coalesce {
262                    vyre_libs::scan::classic_ac::build_ac_bounded_ranges_program_ext(
263                        &matcher.dfa,
264                        pattern_count,
265                        super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH,
266                        true,
267                    )
268                } else {
269                    build_ac_bounded_ranges_program_bound_atomic(
270                        &matcher.dfa,
271                        pattern_count,
272                        super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH,
273                    )?
274                };
275                tracing::debug!(
276                    target: "keyhog::routing",
277                    pattern_count,
278                    state_count = matcher.dfa.state_count,
279                    max_pattern_len = matcher.dfa.max_pattern_len,
280                    backend = backend_id,
281                    use_subgroup_coalesce,
282                    "AC GPU dispatch Program built"
283                );
284                Some(program)
285            })
286            .as_ref()
287    }
288
289    /// Lazily compile the regex-NFA `RulePipeline` on first call.
290    /// Returns `None` once the OnceLock has fired when the regex
291    /// compile failed - typically because the combined NFA exceeds
292    /// vyre's per-subgroup state cap (`LANES * 32`) or because one
293    /// of the detector regexes uses a feature the byte-NFA frontend
294    /// can't represent (Unicode classes, lookaround, backrefs).
295    /// Callers should fall back to the literal-set GPU dispatch on
296    /// `None`.
297    ///
298    /// Pipeline is sized for [`super::rule_pipeline::megascan_input_len()`] bytes; batches
299    /// larger than that must take a different path. The orchestrator
300    /// caps batches at the same value (256 MiB default, up to 1 GiB
301    /// on 24+ GiB-VRAM cards) so this matches normal scan flow.
302    pub fn rule_pipeline(&self) -> Option<&vyre_libs::scan::RulePipeline> {
303        self.rule_pipeline
304            .get_or_init(|| {
305                let pattern_strs: Vec<&str> = self
306                    .ac_map
307                    .iter()
308                    .map(|p| p.regex.as_str())
309                    .chain(self.fallback.iter().map(|(p, _)| p.regex.as_str()))
310                    .collect();
311                if pattern_strs.is_empty() {
312                    return None;
313                }
314                let started = std::time::Instant::now();
315                let input_cap = super::rule_pipeline::megascan_input_len();
316                match super::rule_pipeline::rule_pipeline_cached(&pattern_strs, input_cap as u32) {
317                    Ok(pipe) => {
318                        tracing::info!(
319                            target: "keyhog::routing",
320                            patterns = pattern_strs.len(),
321                            input_len = input_cap,
322                            elapsed_ms = started.elapsed().as_millis() as u64,
323                            "MegaScan RulePipeline compiled"
324                        );
325                        Some(pipe)
326                    }
327                    Err(error) => {
328                        // Demoted from `warn` to `debug` - the
329                        // fallback to literal-set GPU dispatch is the
330                        // designed degradation when vyre's byte-NFA
331                        // frontend can't represent every pattern (e.g.
332                        // lookaround in pattern 990 of the bundled
333                        // detector corpus). The user can't fix it, and
334                        // hitting this WARN once per `--backend mega-
335                        // scan` invocation creates noise without
336                        // signal. kimi-dogfood-3 #138.
337                        tracing::debug!(
338                            patterns = pattern_strs.len(),
339                            error = %format!("{error:?}"),
340                            "MegaScan RulePipeline compile failed - falling back to literal-set GPU dispatch. \
341                             Common causes: regex set exceeds vyre's per-subgroup state cap, or one or more \
342                             patterns use Unicode classes / lookaround / backrefs that the byte-NFA frontend \
343                             can't represent."
344                        );
345                        None
346                    }
347                }
348            })
349            .as_ref()
350    }
351
352    /// Lazily build fused GPU decode→scan programs (base64 + hex).
353    ///
354    /// Returns `None` when no GPU matcher is available (no literals, no
355    /// adapter). The fused programs share the same DFA transition tables
356    /// as the literal-set engine but prepend an on-GPU decode stage,
357    /// eliminating the CPU→GPU round-trip for encoded content.
358    pub fn fused_decode_programs(
359        &self,
360    ) -> Option<&super::gpu_decode_scan::FusedDecodeScanPrograms> {
361        self.fused_decode_programs
362            .get_or_init(|| {
363                let matcher = self.gpu_matcher()?;
364                let state_count = matcher.dfa.state_count;
365                let input_len = super::rule_pipeline::megascan_input_len() as u32;
366                let programs = super::gpu_decode_scan::build_fused_programs(state_count, input_len);
367                if programs.any_available() {
368                    tracing::info!(
369                        target: "keyhog::gpu",
370                        base64 = programs.base64_program.is_some(),
371                        hex = programs.hex_program.is_some(),
372                        state_count,
373                        input_len,
374                        "fused decode+scan programs built"
375                    );
376                    Some(programs)
377                } else {
378                    tracing::debug!(
379                        target: "keyhog::gpu",
380                        "fused decode+scan programs not available - CPU decode path will be used"
381                    );
382                    None
383                }
384            })
385            .as_ref()
386    }
387}