keyhog_scanner/engine/gpu_lazy.rs
1use super::*;
2use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
3
4fn append_match_bound_slot(
5 hits_buffer: &str,
6 count_buffer: &str,
7 tag: impl Into<Expr>,
8 start: impl Into<Expr>,
9 end: impl Into<Expr>,
10) -> Node {
11 let slot_name = "_keyhog_match_slot";
12 let max_hits = Expr::div(Expr::buf_len(hits_buffer), Expr::u32(3));
13
14 Node::block(vec![
15 Node::let_bind(
16 slot_name,
17 Expr::atomic_add(count_buffer, Expr::u32(0), Expr::u32(1)),
18 ),
19 Node::if_then(
20 Expr::lt(Expr::var(slot_name), max_hits),
21 vec![
22 Node::store(
23 hits_buffer,
24 Expr::mul(Expr::var(slot_name), Expr::u32(3)),
25 tag.into(),
26 ),
27 Node::store(
28 hits_buffer,
29 Expr::add(Expr::mul(Expr::var(slot_name), Expr::u32(3)), Expr::u32(1)),
30 start.into(),
31 ),
32 Node::store(
33 hits_buffer,
34 Expr::add(Expr::mul(Expr::var(slot_name), Expr::u32(3)), Expr::u32(2)),
35 end.into(),
36 ),
37 ],
38 ),
39 ])
40}
41
42fn build_ac_bounded_ranges_program_bound_atomic(
43 dfa: &vyre_libs::scan::dfa::CompiledDfa,
44 pattern_count: u32,
45 max_matches: u32,
46) -> Option<Program> {
47 let output_records_len = u32::try_from(dfa.output_records.len()).ok()?;
48 let max_pattern_len = dfa.max_pattern_len.max(1);
49
50 let haystack = "haystack";
51 let transitions = "transitions";
52 let output_offsets = "output_offsets";
53 let output_records = "output_records";
54 let pattern_lengths = "pattern_lengths";
55 let haystack_len = "haystack_len";
56 let match_count = "match_count";
57 let matches = "matches";
58
59 let i = Expr::var("i");
60 let end = Expr::add(i.clone(), Expr::u32(1));
61 let scan_start = Expr::select(
62 Expr::lt(i.clone(), Expr::u32(max_pattern_len - 1)),
63 Expr::u32(0),
64 Expr::sub(end.clone(), Expr::u32(max_pattern_len)),
65 );
66 let (load_step_byte, step_byte) =
67 vyre_libs::scan::builders::load_packed_byte(haystack, Expr::var("step"));
68
69 let walk_body = vec![
70 Node::let_bind("i", Expr::InvocationId { axis: 0 }),
71 Node::if_then(
72 Expr::lt(i.clone(), Expr::load(haystack_len, Expr::u32(0))),
73 vec![
74 Node::let_bind("state", Expr::u32(0)),
75 Node::let_bind("scan_start", scan_start),
76 Node::let_bind("scan_end", end),
77 Node::loop_for(
78 "step",
79 Expr::var("scan_start"),
80 Expr::var("scan_end"),
81 vec![
82 load_step_byte,
83 Node::assign(
84 "state",
85 Expr::load(
86 transitions,
87 Expr::add(Expr::mul(Expr::var("state"), Expr::u32(256)), step_byte),
88 ),
89 ),
90 ],
91 ),
92 Node::let_bind("out_begin", Expr::load(output_offsets, Expr::var("state"))),
93 Node::let_bind(
94 "out_end",
95 Expr::load(output_offsets, Expr::add(Expr::var("state"), Expr::u32(1))),
96 ),
97 Node::loop_for(
98 "out_idx",
99 Expr::var("out_begin"),
100 Expr::var("out_end"),
101 vec![
102 Node::let_bind(
103 "pattern_id",
104 Expr::load(output_records, Expr::var("out_idx")),
105 ),
106 Node::let_bind(
107 "pat_len",
108 Expr::load(pattern_lengths, Expr::var("pattern_id")),
109 ),
110 Node::let_bind(
111 "match_start",
112 Expr::select(
113 Expr::lt(Expr::var("scan_end"), Expr::var("pat_len")),
114 Expr::u32(0),
115 Expr::sub(Expr::var("scan_end"), Expr::var("pat_len")),
116 ),
117 ),
118 append_match_bound_slot(
119 matches,
120 match_count,
121 Expr::var("pattern_id"),
122 Expr::var("match_start"),
123 Expr::var("scan_end"),
124 ),
125 ],
126 ),
127 ],
128 ),
129 ];
130
131 Some(Program::wrapped(
132 vec![
133 BufferDecl::storage(haystack, 0, BufferAccess::ReadOnly, DataType::U32),
134 BufferDecl::storage(transitions, 1, BufferAccess::ReadOnly, DataType::U32)
135 .with_count(dfa.state_count.saturating_mul(256)),
136 BufferDecl::storage(output_offsets, 2, BufferAccess::ReadOnly, DataType::U32)
137 .with_count(dfa.state_count.saturating_add(1)),
138 BufferDecl::storage(output_records, 3, BufferAccess::ReadOnly, DataType::U32)
139 .with_count(output_records_len),
140 BufferDecl::storage(pattern_lengths, 4, BufferAccess::ReadOnly, DataType::U32)
141 .with_count(pattern_count),
142 BufferDecl::storage(haystack_len, 5, BufferAccess::ReadOnly, DataType::U32)
143 .with_count(1),
144 BufferDecl::read_write(match_count, 6, DataType::U32).with_count(1),
145 BufferDecl::output(matches, 7, DataType::U32).with_count(max_matches.saturating_mul(3)),
146 ],
147 [128, 1, 1],
148 vec![vyre_libs::region::wrap_anonymous(
149 "keyhog::matching::classic_ac_bounded_ranges",
150 walk_body,
151 )],
152 ))
153}
154
/// Switch for the subgroup-coalesced match-append form of the AC GPU
/// dispatch Program (subgroup_ballot + subgroup_shuffle producing
/// `_vyre_match_leader`).
///
/// Kept forced-off: vyre's substrate-neutral pre-emit lowering rejects
/// that form on every backend (CUDA and wgpu alike) with
/// "_vyre_match_leader referenced before binding" (Innovation I.17).
///
/// The value lives in a named constant rather than an inline `false` so
/// the dead path stays greppable. Once the vyre IR gap is closed, setting
/// this to `true` restores the ~32x atomic-contention reduction on the
/// shared match-count buffer for all backends from a single place.
///
/// Per the original note, the interim contention win (per-workgroup local
/// reduction -> one atomic add per group) belongs to the kernel builder,
/// not to this flag.
const AC_GPU_SUBGROUP_COALESCE: bool = false;
168
169impl CompiledScanner {
170 /// Lazily compile the GPU literal-set on first call. Returns `None`
171 /// when no compatible adapter was detected at probe time.
172 ///
173 /// Persists the compiled matcher to `~/.cache/keyhog/programs/<hash>.bin`.
174 /// On a cache hit the matcher is loaded from disk and the GPU
175 /// recompile is skipped entirely - biggest cold-start win on
176 /// `keyhog scan` / `scan-system` runs that re-launch repeatedly.
177 /// Cache misses (no file, version-mismatch, corrupt blob) silently
178 /// recompile and re-cache.
179 pub fn gpu_matcher(&self) -> Option<&vyre_libs::scan::GpuLiteralSet> {
180 self.gpu_matcher
181 .get_or_init(|| {
182 let Some(literals) = &self.gpu_literals else {
183 return None;
184 };
185 let literal_refs: Vec<&[u8]> = literals.iter().map(|v| v.as_slice()).collect();
186 let cache_dir = super::gpu_cache::gpu_matcher_cache_dir()?;
187 let cache_key = format!(
188 "lit-{}",
189 super::gpu_cache::gpu_matcher_cache_key(&literal_refs)
190 );
191 let started = std::time::Instant::now();
192 // One-line lego-block cache wiring courtesy of
193 // `vyre_libs::scan::cached_load_or_compile`. The
194 // helper handles atomic-rename, stale-blob deletion,
195 // and silent fall-through on cache-side I/O errors -
196 // every behaviour the previous hand-rolled
197 // load/save pair tried to match. We log compile cost
198 // here so the operator can still see warm-vs-cold
199 // start latency in `--verbose` output.
200 let matcher =
201 vyre_libs::scan::cached_load_or_compile(&cache_dir, &cache_key, || {
202 vyre_libs::scan::GpuLiteralSet::compile(&literal_refs)
203 });
204 tracing::debug!(
205 target: "keyhog::routing",
206 patterns = literal_refs.len(),
207 elapsed_ms = started.elapsed().as_millis() as u64,
208 "GpuLiteralSet ready (warm cache or compiled)"
209 );
210 Some(matcher)
211 })
212 .as_ref()
213 }
214
215 /// Lazily build the Aho-Corasick bounded-ranges dispatch Program
216 /// from the GpuLiteralSet's CompiledDfa. The two engines share the
217 /// same DFA - only the dispatch Program (and therefore the
218 /// per-byte algorithm) differs:
219 ///
220 /// * `gpu_matcher().program` - `build_literal_set_program`:
221 /// walks every pattern × every literal byte per haystack
222 /// position. `O(N × L) per byte`. Works for any pattern set
223 /// that fits the DFA budget.
224 /// * `ac_gpu_program()` - `classic_ac_bounded_ranges_program`:
225 /// walks the AC transition table forward `L_max` bytes per
226 /// position, emits every pattern in the accepting state's
227 /// flat output_links. `O(L_max) per byte` regardless of N.
228 ///
229 /// Selected at scan time via `KEYHOG_GPU_KERNEL=ac`. Returns
230 /// `None` when no GPU matcher is available; callers fall through
231 /// to the literal-set path or non-GPU backend.
232 ///
233 /// Cap of `super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH` triples per shard
234 /// dispatch matches the existing literal-set output-buffer cap.
235 /// Truncation (count > cap on readback) is handled by the same
236 /// fall-back-to-CPU branch the literal-set path uses.
237 pub fn ac_gpu_program(&self) -> Option<&vyre::Program> {
238 self.ac_gpu_program
239 .get_or_init(|| {
240 let matcher = self.gpu_matcher()?;
241 let pattern_count = matcher.pattern_lengths.len() as u32;
242 // Pick the match-append strategy. The subgroup form
243 // (subgroup_ballot + subgroup_shuffle producing
244 // _vyre_match_leader) was originally gated to wgpu
245 // only because vyre-driver-cuda rejects it during
246 // canonical pre-emit lowering. Runtime testing on
247 // Apple Silicon M4 Pro with vyre v0.4.2 confirmed
248 // the SAME "_vyre_match_leader referenced before
249 // binding" rejection on the wgpu path: the lowering
250 // gap is in vyre's substrate-neutral pre-emit step,
251 // not the driver-specific emitter. Until the IR gap
252 // is closed, use_subgroup_coalesce stays false on
253 // every backend. We lose the ~32x atomic-contention
254 // reduction the subgroup form would have provided
255 // (Innovation I.17), but recall and correctness are
256 // preserved; the plain append_match path produces
257 // bit-identical match output, just with more atomic
258 // pressure on the shared count buffer.
259 let backend_id = self.gpu_backend.as_ref().map(|b| b.id()).unwrap_or("none");
260 let use_subgroup_coalesce = AC_GPU_SUBGROUP_COALESCE;
261 let program = if use_subgroup_coalesce {
262 vyre_libs::scan::classic_ac::build_ac_bounded_ranges_program_ext(
263 &matcher.dfa,
264 pattern_count,
265 super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH,
266 true,
267 )
268 } else {
269 build_ac_bounded_ranges_program_bound_atomic(
270 &matcher.dfa,
271 pattern_count,
272 super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH,
273 )?
274 };
275 tracing::debug!(
276 target: "keyhog::routing",
277 pattern_count,
278 state_count = matcher.dfa.state_count,
279 max_pattern_len = matcher.dfa.max_pattern_len,
280 backend = backend_id,
281 use_subgroup_coalesce,
282 "AC GPU dispatch Program built"
283 );
284 Some(program)
285 })
286 .as_ref()
287 }
288
289 /// Lazily compile the regex-NFA `RulePipeline` on first call.
290 /// Returns `None` once the OnceLock has fired when the regex
291 /// compile failed - typically because the combined NFA exceeds
292 /// vyre's per-subgroup state cap (`LANES * 32`) or because one
293 /// of the detector regexes uses a feature the byte-NFA frontend
294 /// can't represent (Unicode classes, lookaround, backrefs).
295 /// Callers should fall back to the literal-set GPU dispatch on
296 /// `None`.
297 ///
298 /// Pipeline is sized for [`super::rule_pipeline::megascan_input_len()`] bytes; batches
299 /// larger than that must take a different path. The orchestrator
300 /// caps batches at the same value (256 MiB default, up to 1 GiB
301 /// on 24+ GiB-VRAM cards) so this matches normal scan flow.
302 pub fn rule_pipeline(&self) -> Option<&vyre_libs::scan::RulePipeline> {
303 self.rule_pipeline
304 .get_or_init(|| {
305 let pattern_strs: Vec<&str> = self
306 .ac_map
307 .iter()
308 .map(|p| p.regex.as_str())
309 .chain(self.fallback.iter().map(|(p, _)| p.regex.as_str()))
310 .collect();
311 if pattern_strs.is_empty() {
312 return None;
313 }
314 let started = std::time::Instant::now();
315 let input_cap = super::rule_pipeline::megascan_input_len();
316 match super::rule_pipeline::rule_pipeline_cached(&pattern_strs, input_cap as u32) {
317 Ok(pipe) => {
318 tracing::info!(
319 target: "keyhog::routing",
320 patterns = pattern_strs.len(),
321 input_len = input_cap,
322 elapsed_ms = started.elapsed().as_millis() as u64,
323 "MegaScan RulePipeline compiled"
324 );
325 Some(pipe)
326 }
327 Err(error) => {
328 // Demoted from `warn` to `debug` - the
329 // fallback to literal-set GPU dispatch is the
330 // designed degradation when vyre's byte-NFA
331 // frontend can't represent every pattern (e.g.
332 // lookaround in pattern 990 of the bundled
333 // detector corpus). The user can't fix it, and
334 // hitting this WARN once per `--backend mega-
335 // scan` invocation creates noise without
336 // signal. kimi-dogfood-3 #138.
337 tracing::debug!(
338 patterns = pattern_strs.len(),
339 error = %format!("{error:?}"),
340 "MegaScan RulePipeline compile failed - falling back to literal-set GPU dispatch. \
341 Common causes: regex set exceeds vyre's per-subgroup state cap, or one or more \
342 patterns use Unicode classes / lookaround / backrefs that the byte-NFA frontend \
343 can't represent."
344 );
345 None
346 }
347 }
348 })
349 .as_ref()
350 }
351
352 /// Lazily build fused GPU decode→scan programs (base64 + hex).
353 ///
354 /// Returns `None` when no GPU matcher is available (no literals, no
355 /// adapter). The fused programs share the same DFA transition tables
356 /// as the literal-set engine but prepend an on-GPU decode stage,
357 /// eliminating the CPU→GPU round-trip for encoded content.
358 pub fn fused_decode_programs(
359 &self,
360 ) -> Option<&super::gpu_decode_scan::FusedDecodeScanPrograms> {
361 self.fused_decode_programs
362 .get_or_init(|| {
363 let matcher = self.gpu_matcher()?;
364 let state_count = matcher.dfa.state_count;
365 let input_len = super::rule_pipeline::megascan_input_len() as u32;
366 let programs = super::gpu_decode_scan::build_fused_programs(state_count, input_len);
367 if programs.any_available() {
368 tracing::info!(
369 target: "keyhog::gpu",
370 base64 = programs.base64_program.is_some(),
371 hex = programs.hex_program.is_some(),
372 state_count,
373 input_len,
374 "fused decode+scan programs built"
375 );
376 Some(programs)
377 } else {
378 tracing::debug!(
379 target: "keyhog::gpu",
380 "fused decode+scan programs not available - CPU decode path will be used"
381 );
382 None
383 }
384 })
385 .as_ref()
386 }
387}