keyhog_scanner/engine/gpu_literal_phase1.rs
1use super::*;
2
3impl CompiledScanner {
4 pub fn scan_coalesced_gpu_phase1(&self, chunks: &[keyhog_core::Chunk]) -> GpuPhase1Output {
5 // The literal_set program embeds `append_match_subgroup`
6 // (subgroup_ballot + subgroup_shuffle), and vyre's canonical
7 // pre-emit lowering rejects that subgroup form regardless of
8 // the downstream emitter ("variable `_vyre_match_leader` is
9 // referenced before binding"). This was previously gated to
10 // CUDA only, but the rejection happens BEFORE driver-specific
11 // emission, so WGPU hosts (Apple Silicon, Intel Mac, Windows)
12 // hit the same rejection on the literal_set path and silently
13 // dropped to CPU.
14 //
15 // Until the vyre pre-emit lowering accepts the subgroup form
16 // (tracked separately), the AC kernel path is the working
17 // GPU code path for both CUDA and WGPU. KEYHOG_GPU_KERNEL=
18 // literal-set forces the broken path for diagnostic /
19 // bisection use; the default is now AC for every GPU backend.
20 // Cache the env-var lookup. `scan_coalesced_gpu_phase1` is called
21 // per batched chunk group; reading env::var on the hot path costs
22 // ~200 ns per call which adds up to milliseconds across 1k+
23 // chunks. The diagnostic override is process-static so caching
24 // once is exact.
25 static FORCE_LITERAL_SET: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
26 let force_literal_set = *FORCE_LITERAL_SET.get_or_init(|| {
27 matches!(
28 std::env::var("KEYHOG_GPU_KERNEL").ok().as_deref(),
29 Some("literal-set") | Some("literal_set")
30 )
31 });
32 if !force_literal_set {
33 return self.scan_coalesced_gpu_ac_phase1(chunks);
34 }
35
36 // Auto-degrade to the next-best backend when the GPU stack is not
37 // ready: no compiled matcher (no adapter at probe time), the cached
38 // device went away, or the persistent backend is missing.
39 let Some(matcher) = self.gpu_matcher() else {
40 return self.gpu_degrade_done_with_reason(
41 chunks,
42 crate::hw_probe::ScanBackend::Gpu,
43 Some("GPU literal-set matcher unavailable"),
44 );
45 };
46 if self.gpu_backend.is_none() {
47 return self.gpu_degrade_done_with_reason(
48 chunks,
49 crate::hw_probe::ScanBackend::Gpu,
50 Some("GPU backend handle unavailable for literal-set dispatch"),
51 );
52 }
53
54 let (entries, mut buffer) = super::gpu_coalesce::coalesce_chunks(chunks);
55
56 // ASCII-lowercase the coalesced haystack so the literal-set automaton
57 // matches case-INSENSITIVELY, matching the SIMD Hyperscan path (CASELESS
58 // for every pattern) and the lowercased literal set from
59 // `build_gpu_literals`. Prefilter-only buffer (phase 2 re-confirms on
60 // original bytes); ASCII fold is position-preserving so offsets are
61 // unchanged. See PERF-07 / gpu_ac_phase1 for the full rationale.
62 buffer.make_ascii_lowercase();
63
64 // 4-byte align the coalesced buffer so every shard slice can be
65 // passed to vyre's u32-typed haystack input WITHOUT a per-shard
66 // `pack_haystack_u32` call. The pack helper is a 2x memcopy
67 // (Vec<u32> intermediate + Vec<u8> output) that produces bytes
68 // byte-identical to the input on 4-aligned slices (see
69 // `vyre_foundation::byte_pack::pack_haystack_u32`). On a 1 GiB
70 // scan with 2 MiB shards that's 512 shards x 2x = ~4 GiB of
71 // throwaway allocations - load-bearing on the 25s gap GPU
72 // currently loses to SIMD at scale. Padding the source buffer
73 // once and slicing each shard collapses that to zero alloc per
74 // shard. Padding bytes are NUL, which no detector literal can
75 // match (extract_literal_prefixes drops NUL), so the trailing
76 // zero-extension is recall-safe.
77 while !buffer.len().is_multiple_of(4) {
78 buffer.push(0);
79 }
80
81 #[cfg(target_os = "linux")]
82 // SAFETY: `buffer` is a live `Vec<u8>` whose `as_ptr()` and
83 // `len()` describe a valid memory range owned by this scope.
84 // `madvise` is advisory - the kernel may ignore it on
85 // non-page-aligned ranges; we treat the call as best-effort
86 // and don't check the return value.
87 unsafe {
88 // Senior Audit §Phase 7.4: Prevent GPU buffers from leaking into core dumps.
89 libc::madvise(
90 buffer.as_ptr() as *mut libc::c_void,
91 buffer.len(),
92 libc::MADV_DONTDUMP,
93 );
94 }
95
96 // Adaptive match cap that scales with the actual buffer size
97 // rather than chunk count. Real-world ceiling: roughly one
98 // literal hit per 64 input bytes is already implausibly dense
99 // for production source code (the densest fixture in the
100 // performance regression suite is ~1 hit per 1 KiB). The
101 // chunk-count formula systematically under-sized batches that
102 // had a few large files, leading to spurious truncation and
103 // the full-CPU re-scan that wastes the GPU dispatch we just
104 // paid for.
105 //
106 // Keeps the kimi-wave2 `cap+1` sentinel-slot trick: ask the
107 // GPU for one more than the cap, and only treat `> cap` as
108 // truncation. A batch that lands EXACTLY at the cap is by
109 // definition complete (would have written into the sentinel
110 // slot otherwise).
111 const MIN_CAP: u32 = 100_000;
112 const MAX_CAP: u32 = 16_000_000;
113 let buffer_cap = (buffer.len() / 64) as u64;
114 let cap: u32 = buffer_cap.clamp(MIN_CAP as u64, MAX_CAP as u64) as u32;
115
116 // wgpu caps each compute dispatch at 65535 workgroups per
117 // dimension (WebGPU spec). Vyre's GpuLiteralSet uses
118 // workgroup_size_x = 32, so a single dispatch can handle at
119 // most 65535 × 32 = 2,097,120 input bytes. For coalesced
120 // batches larger than this (always true with the tier-aware
121 // 2 MiB activation threshold + the orchestrator's adaptive
122 // `batch_bytes_budget` - 256 MiB default, up to 1 GiB on
123 // 24-GiB-VRAM cards), shard the buffer into 2-MiB-or-less
124 // pieces, dispatch each, and merge the matches with a
125 // `start` offset added to put them back into the global
126 // buffer's coordinate space.
127 //
128 // Shard size: 65535 (max workgroups per dim) × 32 (vyre's
129 // workgroup_size_x) = 2,097,120 bytes. Exactly 2 MiB =
130 // 2,097,152 bytes overflows by one workgroup. Use the
131 // exact-aligned value to maximise per-shard throughput
132 // without tripping the wgpu dispatch validator.
133 //
134 // Extra dispatches add ~100 µs each on a high-tier GPU; for
135 // a 256 MiB batch that's ~12 ms of overhead vs SIMD's ~70 s
136 // (a 5800× win). On a 1 GiB batch (5090-class adapter) the
137 // shard count rises 4× but the GPU-vs-SIMD ratio widens
138 // because per-shard dispatch is amortized over more bytes.
139 // Dynamic per-vyre-workgroup: each shard covers
140 // (max_workgroups_per_dim × workgroup_size_x) bytes.
141 // wgpu caps workgroups per dimension at 65 535; vyre's
142 // literal-set program reports its `workgroup_size_x` via
143 // `matcher.program.workgroup_size[0]`. Was hard-coded at
144 // 65_535 × 32 when vyre's literal-set used
145 // workgroup_size_x = 32; now scales automatically when
146 // the vyre side is tuned (e.g. to 128 to cut shard count
147 // by 4×).
148 let workgroup_x = matcher.program.workgroup_size[0] as usize;
149 let gpu_dispatch_max_bytes: usize = 65_535 * workgroup_x;
150 let started = std::time::Instant::now();
151
152 // Slice the coalesced buffer into wgpu-dispatch-sized shards.
153 // The shard boundary itself is wgpu's `dispatch_workgroups`
154 // limit (65 535 workgroups per dimension × 32-byte workgroup
155 // size). The previous flow dispatched these one-by-one with
156 // `matcher.scan` - each call records its own encoder,
157 // submits, and `device.poll(Wait)`s. On a 1 GiB batch with
158 // 512 shards that adds up to ~50 ms × 512 = 25 s of pure
159 // host-side dispatch overhead, *not* GPU compute.
160 //
161 // `WgpuBackend::dispatch_borrowed_batch` records *all* shard
162 // dispatches into one command encoder, single submit, single
163 // poll. For 512 shards the wait collapses from ~25 s to
164 // a single GPU drain - close to the actual compute time.
165 let mut shard_ranges: Vec<(usize, usize)> = Vec::new();
166 let mut shard_start = 0usize;
167 while shard_start < buffer.len() {
168 let shard_end = (shard_start + gpu_dispatch_max_bytes).min(buffer.len());
169 shard_ranges.push((shard_start, shard_end));
170 shard_start = shard_end;
171 }
172 let shard_count = shard_ranges.len();
173
174 // Constants across all shards: pattern offsets/lengths/bytes
175 // and pattern_count. Pre-packed ONCE per process via the
176 // CompiledScanner-level OnceLock and borrowed every dispatch.
177 // Before this cache, `pack_u32_slice` ran four times per scan
178 // producing identical bytes; a process scanning 10 k files
179 // burned 40 k throwaway Vec<u8> allocations on data that never
180 // changes after compile.
181 let const_packs = self
182 .gpu_const_packs
183 .get_or_init(|| super::gpu_cache::GpuConstPacks {
184 pattern_offsets: vyre_libs::scan::dispatch_io::pack_u32_slice(
185 &matcher.pattern_offsets,
186 ),
187 pattern_lengths: vyre_libs::scan::dispatch_io::pack_u32_slice(
188 &matcher.pattern_lengths,
189 ),
190 pattern_bytes: vyre_libs::scan::dispatch_io::pack_u32_slice(&matcher.pattern_bytes),
191 pattern_count: vyre_libs::scan::dispatch_io::pack_u32_slice(&[matcher
192 .pattern_lengths
193 .len()
194 as u32]),
195 });
196
197 // Per-shard tiny bytes (shard_len scalar + the two atomic
198 // counters + dispatch config). The haystack input is the
199 // 4-byte-aligned source buffer sliced in place - no Vec<u8>
200 // packing allocation per shard (see the buffer padding above
201 // for the rationale).
202 struct ShardOwned {
203 haystack_len: Vec<u8>,
204 atomic_count: Vec<u8>,
205 atomic_overflow: Vec<u8>,
206 config: vyre::DispatchConfig,
207 cap: u32,
208 }
209 let mut shard_owned: Vec<ShardOwned> = Vec::with_capacity(shard_count);
210 for (start, end) in &shard_ranges {
211 let shard_len = (*end - *start) as u32;
212 let shard_cap_u64 = ((*end - *start) / 64) as u64;
213 let shard_cap = shard_cap_u64.clamp(MIN_CAP as u64, MAX_CAP as u64) as u32;
214 shard_owned.push(ShardOwned {
215 haystack_len: vyre_libs::scan::dispatch_io::pack_u32_slice(&[shard_len]),
216 atomic_count: vec![0u8; 4],
217 atomic_overflow: vec![0u8; 4],
218 config: vyre_libs::scan::dispatch_io::byte_scan_dispatch_config(
219 shard_len,
220 matcher.program.workgroup_size[0],
221 ),
222 cap: shard_cap,
223 });
224 }
225
226 // Build borrowed input arrays per shard. Order must match
227 // `GpuLiteralSet::scan` because the buffer-decl order is the
228 // contract between host inputs and GPU kernel binding. The
229 // haystack slot is now a direct slice into the padded source
230 // buffer - no per-shard packing allocation.
231 let shard_input_arrays: Vec<[&[u8]; 8]> = shard_owned
232 .iter()
233 .zip(shard_ranges.iter())
234 .map(|(s, (start, end))| {
235 [
236 &buffer[*start..*end],
237 const_packs.pattern_offsets.as_slice(),
238 const_packs.pattern_lengths.as_slice(),
239 const_packs.pattern_bytes.as_slice(),
240 s.haystack_len.as_slice(),
241 const_packs.pattern_count.as_slice(),
242 s.atomic_count.as_slice(),
243 s.atomic_overflow.as_slice(),
244 ]
245 })
246 .collect();
247
248 // vyre's wgpu readback ring is sized at DEFAULT_RING_SLOTS
249 // (lifted to 2048 in vendor/vyre - see
250 // `runtime/readback_ring.rs` for the rationale). Each
251 // GpuLiteralSet dispatch produces 2 readback buffers, so
252 // a batch of N shards burns 2N slots from the 2048-slot
253 // ring. The other constraint is host-side memory: each
254 // shard's haystack is borrowed (no copy), but its
255 // per-dispatch config + atomic counters still allocate
256 // ~24 bytes per shard. The real cost is the input-arrays
257 // Vec<[&[u8]; 8]> at ~64 bytes per entry.
258 //
259 // Adaptive batch cap: a bigger batch flattens the
260 // command-encoder cost across more shards and shortens
261 // the wall-clock for a multi-GiB scan, but climbs
262 // the ring-slot occupancy. 64 was the original safe
263 // value for small hosts; 256 still leaves the 2048-slot
264 // ring deeply under-subscribed and matches the workload
265 // a 24 GiB-VRAM card actually wants.
266 //
267 // total RAM shards/batch 1-GiB-scan sequential batches
268 // < 16 GiB 64 ≥ 8
269 // 16-32 GiB 128 4
270 // ≥ 32 GiB 256 2
271 //
272 // The 96-GiB-RAM RTX-5090 workstation case drops from
273 // 8 sequential batched dispatches to 2, cutting GPU
274 // pipeline-drain stalls roughly 4x on a 1-GiB batch.
275 let max_shards_per_gpu_batch: usize = {
276 let total_ram_mb = crate::hw_probe::probe_hardware()
277 .total_memory_mb
278 .unwrap_or(0);
279 if total_ram_mb >= 32 * 1024 {
280 256
281 } else if total_ram_mb >= 16 * 1024 {
282 128
283 } else {
284 64
285 }
286 };
287 let mut matches: Vec<vyre_libs::scan::LiteralMatch> = Vec::new();
288 for sub_start in (0..shard_count).step_by(max_shards_per_gpu_batch) {
289 let sub_end = (sub_start + max_shards_per_gpu_batch).min(shard_count);
290 let sub_inputs: Vec<&[&[u8]]> = (sub_start..sub_end)
291 .map(|i| &shard_input_arrays[i][..])
292 .collect();
293 let sub_configs: Vec<vyre::DispatchConfig> = (sub_start..sub_end)
294 .map(|i| shard_owned[i].config.clone())
295 .collect();
296
297 let batch_results =
298 match self.dispatch_gpu_shards(&matcher.program, &sub_inputs, &sub_configs) {
299 Ok(r) => r,
300 Err(e) => {
301 tracing::error!(
302 shards = sub_end - sub_start,
303 "GPU batched dispatch failed, falling back to CPU: {e}"
304 );
305 let reason = format!("GPU literal-set batched dispatch failed: {e}");
306 return self.gpu_degrade_done_with_reason(
307 chunks,
308 crate::hw_probe::ScanBackend::Gpu,
309 Some(&reason),
310 );
311 }
312 };
313
314 for (offset_in_sub, result) in batch_results.into_iter().enumerate() {
315 let i = sub_start + offset_in_sub;
316 let outputs = match result {
317 Ok(o) => o,
318 Err(e) => {
319 tracing::error!(
320 shard_index = i,
321 "GPU shard within batch failed, falling back to CPU: {e}"
322 );
323 let reason = format!("GPU literal-set shard {i} dispatch failed: {e}");
324 return self.gpu_degrade_done_with_reason(
325 chunks,
326 crate::hw_probe::ScanBackend::Gpu,
327 Some(&reason),
328 );
329 }
330 };
331 if outputs.len() < 2 {
332 tracing::error!(
333 shard_index = i,
334 outputs = outputs.len(),
335 "GPU shard output buffer count too small; falling back to CPU"
336 );
337 let reason = format!(
338 "GPU literal-set shard {i} returned {} output buffer(s), expected at least 2",
339 outputs.len()
340 );
341 return self.gpu_degrade_done_with_reason(
342 chunks,
343 crate::hw_probe::ScanBackend::Gpu,
344 Some(&reason),
345 );
346 }
347 let count_bytes = &outputs[0];
348 let matches_bytes = &outputs[1];
349 if count_bytes.len() < 4 {
350 tracing::error!(
351 shard_index = i,
352 "GPU shard count buffer truncated; falling back to CPU"
353 );
354 let reason = format!(
355 "GPU literal-set shard {i} returned truncated count buffer ({} byte(s), expected 4)",
356 count_bytes.len()
357 );
358 return self.gpu_degrade_done_with_reason(
359 chunks,
360 crate::hw_probe::ScanBackend::Gpu,
361 Some(&reason),
362 );
363 }
364 let count = u32::from_le_bytes([
365 count_bytes[0],
366 count_bytes[1],
367 count_bytes[2],
368 count_bytes[3],
369 ]);
370 let shard_cap = shard_owned[i].cap;
371 if count > shard_cap {
372 tracing::warn!(
373 cap = shard_cap,
374 count,
375 shard_index = i,
376 "GPU shard exceeded its cap: truncation possible; falling back to CPU"
377 );
378 let reason = format!(
379 "GPU literal-set shard {i} reported {count} matches, exceeding cap {shard_cap}"
380 );
381 return self.gpu_degrade_done_with_reason(
382 chunks,
383 crate::hw_probe::ScanBackend::Gpu,
384 Some(&reason),
385 );
386 }
387 let shard_matches = vyre_libs::scan::dispatch_io::unpack_match_triples(
388 matches_bytes,
389 count.min(shard_cap),
390 );
391 let offset = shard_ranges[i].0 as u32;
392 for m in &shard_matches {
393 matches.push(vyre_libs::scan::LiteralMatch::new(
394 m.pattern_id,
395 m.start.saturating_add(offset),
396 m.end.saturating_add(offset),
397 ));
398 }
399 }
400 }
401 let elapsed_ms = started.elapsed().as_millis();
402 tracing::debug!(
403 target: "keyhog::routing",
404 chunks = chunks.len(),
405 buffer_bytes = buffer.len(),
406 matches = matches.len(),
407 shards = shard_count,
408 cap,
409 elapsed_ms,
410 "vyre GPU batched scan completed"
411 );
412 if self.has_simd_prefilter()
413 && super::gpu_postprocess::gpu_phase2_hits_are_dense(
414 matches.len(),
415 buffer.len(),
416 chunks.len(),
417 )
418 {
419 tracing::warn!(
420 target: "keyhog::routing",
421 raw_matches = matches.len(),
422 buffer_bytes = buffer.len(),
423 chunks = chunks.len(),
424 "GPU literal prefix output is too dense for phase 2; rerouting this batch through SIMD coalesced scan",
425 );
426 if std::env::var_os("KH_PERF").is_some() {
427 eprintln!(
428 "KH_PERF gpu_literal_dense_phase2_reroute: chunks={} buffer_bytes={} raw_matches={} bytes_per_hit={:.1}",
429 chunks.len(),
430 buffer.len(),
431 matches.len(),
432 buffer.len() as f64 / matches.len().max(1) as f64
433 );
434 }
435 return GpuPhase1Output::Done(self.scan_coalesced_non_gpu(chunks));
436 }
437 // Per-pid dedup + chunk attribution lives in `gpu_postprocess`,
438 // shared with the AC kernel phase-1 path. The downstream
439 // `scan_prepared_with_pattern_hits` consumer requires matches
440 // anchored to chunk-local `(pid, local_start, local_end)`
441 // triples sorted by start so the regex confirmation step runs
442 // anchored at each hit rather than re-sweeping each chunk.
443 super::gpu_postprocess::fold_overlapping_same_pid_inplace(&mut matches);
444 let total_patterns = self.ac_map.len() + self.fallback.len();
445 let per_chunk_hits = super::gpu_postprocess::attribute_matches_to_chunks(
446 &matches,
447 &entries,
448 total_patterns,
449 chunks.len(),
450 );
451
452 GpuPhase1Output::Hits(per_chunk_hits)
453 }
454}
455
456// Phase 2 (CPU post-process that runs after this file's GPU
457// literal-set dispatch produces per-chunk hits) lives in
458// `gpu_phase2.rs`. The orphan doc-comment that previously trailed
459// here described that function and was stranded when the body moved.