Skip to main content

keyhog_scanner/engine/
gpu_literal_phase1.rs

1use super::*;
2
3impl CompiledScanner {
4    pub fn scan_coalesced_gpu_phase1(&self, chunks: &[keyhog_core::Chunk]) -> GpuPhase1Output {
5        // The literal_set program embeds `append_match_subgroup`
6        // (subgroup_ballot + subgroup_shuffle), and vyre's canonical
7        // pre-emit lowering rejects that subgroup form regardless of
8        // the downstream emitter ("variable `_vyre_match_leader` is
9        // referenced before binding"). This was previously gated to
10        // CUDA only, but the rejection happens BEFORE driver-specific
11        // emission, so WGPU hosts (Apple Silicon, Intel Mac, Windows)
12        // hit the same rejection on the literal_set path and silently
13        // dropped to CPU.
14        //
15        // Until the vyre pre-emit lowering accepts the subgroup form
16        // (tracked separately), the AC kernel path is the working
17        // GPU code path for both CUDA and WGPU. KEYHOG_GPU_KERNEL=
18        // literal-set forces the broken path for diagnostic /
19        // bisection use; the default is now AC for every GPU backend.
20        // Cache the env-var lookup. `scan_coalesced_gpu_phase1` is called
21        // per batched chunk group; reading env::var on the hot path costs
22        // ~200 ns per call which adds up to milliseconds across 1k+
23        // chunks. The diagnostic override is process-static so caching
24        // once is exact.
25        static FORCE_LITERAL_SET: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
26        let force_literal_set = *FORCE_LITERAL_SET.get_or_init(|| {
27            matches!(
28                std::env::var("KEYHOG_GPU_KERNEL").ok().as_deref(),
29                Some("literal-set") | Some("literal_set")
30            )
31        });
32        if !force_literal_set {
33            return self.scan_coalesced_gpu_ac_phase1(chunks);
34        }
35
36        // Auto-degrade to the next-best backend when the GPU stack is not
37        // ready: no compiled matcher (no adapter at probe time), the cached
38        // device went away, or the persistent backend is missing.
39        let Some(matcher) = self.gpu_matcher() else {
40            return self.gpu_degrade_done_with_reason(
41                chunks,
42                crate::hw_probe::ScanBackend::Gpu,
43                Some("GPU literal-set matcher unavailable"),
44            );
45        };
46        if self.gpu_backend.is_none() {
47            return self.gpu_degrade_done_with_reason(
48                chunks,
49                crate::hw_probe::ScanBackend::Gpu,
50                Some("GPU backend handle unavailable for literal-set dispatch"),
51            );
52        }
53
54        let (entries, mut buffer) = super::gpu_coalesce::coalesce_chunks(chunks);
55
56        // ASCII-lowercase the coalesced haystack so the literal-set automaton
57        // matches case-INSENSITIVELY, matching the SIMD Hyperscan path (CASELESS
58        // for every pattern) and the lowercased literal set from
59        // `build_gpu_literals`. Prefilter-only buffer (phase 2 re-confirms on
60        // original bytes); ASCII fold is position-preserving so offsets are
61        // unchanged. See PERF-07 / gpu_ac_phase1 for the full rationale.
62        buffer.make_ascii_lowercase();
63
64        // 4-byte align the coalesced buffer so every shard slice can be
65        // passed to vyre's u32-typed haystack input WITHOUT a per-shard
66        // `pack_haystack_u32` call. The pack helper is a 2x memcopy
67        // (Vec<u32> intermediate + Vec<u8> output) that produces bytes
68        // byte-identical to the input on 4-aligned slices (see
69        // `vyre_foundation::byte_pack::pack_haystack_u32`). On a 1 GiB
70        // scan with 2 MiB shards that's 512 shards x 2x = ~4 GiB of
71        // throwaway allocations - load-bearing on the 25s gap GPU
72        // currently loses to SIMD at scale. Padding the source buffer
73        // once and slicing each shard collapses that to zero alloc per
74        // shard. Padding bytes are NUL, which no detector literal can
75        // match (extract_literal_prefixes drops NUL), so the trailing
76        // zero-extension is recall-safe.
77        while !buffer.len().is_multiple_of(4) {
78            buffer.push(0);
79        }
80
81        #[cfg(target_os = "linux")]
82        // SAFETY: `buffer` is a live `Vec<u8>` whose `as_ptr()` and
83        // `len()` describe a valid memory range owned by this scope.
84        // `madvise` is advisory - the kernel may ignore it on
85        // non-page-aligned ranges; we treat the call as best-effort
86        // and don't check the return value.
87        unsafe {
88            // Senior Audit §Phase 7.4: Prevent GPU buffers from leaking into core dumps.
89            libc::madvise(
90                buffer.as_ptr() as *mut libc::c_void,
91                buffer.len(),
92                libc::MADV_DONTDUMP,
93            );
94        }
95
96        // Adaptive match cap that scales with the actual buffer size
97        // rather than chunk count. Real-world ceiling: roughly one
98        // literal hit per 64 input bytes is already implausibly dense
99        // for production source code (the densest fixture in the
100        // performance regression suite is ~1 hit per 1 KiB). The
101        // chunk-count formula systematically under-sized batches that
102        // had a few large files, leading to spurious truncation and
103        // the full-CPU re-scan that wastes the GPU dispatch we just
104        // paid for.
105        //
106        // Keeps the kimi-wave2 `cap+1` sentinel-slot trick: ask the
107        // GPU for one more than the cap, and only treat `> cap` as
108        // truncation. A batch that lands EXACTLY at the cap is by
109        // definition complete (would have written into the sentinel
110        // slot otherwise).
111        const MIN_CAP: u32 = 100_000;
112        const MAX_CAP: u32 = 16_000_000;
113        let buffer_cap = (buffer.len() / 64) as u64;
114        let cap: u32 = buffer_cap.clamp(MIN_CAP as u64, MAX_CAP as u64) as u32;
115
116        // wgpu caps each compute dispatch at 65535 workgroups per
117        // dimension (WebGPU spec). Vyre's GpuLiteralSet uses
118        // workgroup_size_x = 32, so a single dispatch can handle at
119        // most 65535 × 32 = 2,097,120 input bytes. For coalesced
120        // batches larger than this (always true with the tier-aware
121        // 2 MiB activation threshold + the orchestrator's adaptive
122        // `batch_bytes_budget` - 256 MiB default, up to 1 GiB on
123        // 24-GiB-VRAM cards), shard the buffer into 2-MiB-or-less
124        // pieces, dispatch each, and merge the matches with a
125        // `start` offset added to put them back into the global
126        // buffer's coordinate space.
127        //
128        // Shard size: 65535 (max workgroups per dim) × 32 (vyre's
129        // workgroup_size_x) = 2,097,120 bytes. Exactly 2 MiB =
130        // 2,097,152 bytes overflows by one workgroup. Use the
131        // exact-aligned value to maximise per-shard throughput
132        // without tripping the wgpu dispatch validator.
133        //
134        // Extra dispatches add ~100 µs each on a high-tier GPU; for
135        // a 256 MiB batch that's ~12 ms of overhead vs SIMD's ~70 s
136        // (a 5800× win). On a 1 GiB batch (5090-class adapter) the
137        // shard count rises 4× but the GPU-vs-SIMD ratio widens
138        // because per-shard dispatch is amortized over more bytes.
139        // Dynamic per-vyre-workgroup: each shard covers
140        // (max_workgroups_per_dim × workgroup_size_x) bytes.
141        // wgpu caps workgroups per dimension at 65 535; vyre's
142        // literal-set program reports its `workgroup_size_x` via
143        // `matcher.program.workgroup_size[0]`. Was hard-coded at
144        // 65_535 × 32 when vyre's literal-set used
145        // workgroup_size_x = 32; now scales automatically when
146        // the vyre side is tuned (e.g. to 128 to cut shard count
147        // by 4×).
148        let workgroup_x = matcher.program.workgroup_size[0] as usize;
149        let gpu_dispatch_max_bytes: usize = 65_535 * workgroup_x;
150        let started = std::time::Instant::now();
151
152        // Slice the coalesced buffer into wgpu-dispatch-sized shards.
153        // The shard boundary itself is wgpu's `dispatch_workgroups`
154        // limit (65 535 workgroups per dimension × 32-byte workgroup
155        // size). The previous flow dispatched these one-by-one with
156        // `matcher.scan` - each call records its own encoder,
157        // submits, and `device.poll(Wait)`s. On a 1 GiB batch with
158        // 512 shards that adds up to ~50 ms × 512 = 25 s of pure
159        // host-side dispatch overhead, *not* GPU compute.
160        //
161        // `WgpuBackend::dispatch_borrowed_batch` records *all* shard
162        // dispatches into one command encoder, single submit, single
163        // poll. For 512 shards the wait collapses from ~25 s to
164        // a single GPU drain - close to the actual compute time.
165        let mut shard_ranges: Vec<(usize, usize)> = Vec::new();
166        let mut shard_start = 0usize;
167        while shard_start < buffer.len() {
168            let shard_end = (shard_start + gpu_dispatch_max_bytes).min(buffer.len());
169            shard_ranges.push((shard_start, shard_end));
170            shard_start = shard_end;
171        }
172        let shard_count = shard_ranges.len();
173
174        // Constants across all shards: pattern offsets/lengths/bytes
175        // and pattern_count. Pre-packed ONCE per process via the
176        // CompiledScanner-level OnceLock and borrowed every dispatch.
177        // Before this cache, `pack_u32_slice` ran four times per scan
178        // producing identical bytes; a process scanning 10 k files
179        // burned 40 k throwaway Vec<u8> allocations on data that never
180        // changes after compile.
181        let const_packs = self
182            .gpu_const_packs
183            .get_or_init(|| super::gpu_cache::GpuConstPacks {
184                pattern_offsets: vyre_libs::scan::dispatch_io::pack_u32_slice(
185                    &matcher.pattern_offsets,
186                ),
187                pattern_lengths: vyre_libs::scan::dispatch_io::pack_u32_slice(
188                    &matcher.pattern_lengths,
189                ),
190                pattern_bytes: vyre_libs::scan::dispatch_io::pack_u32_slice(&matcher.pattern_bytes),
191                pattern_count: vyre_libs::scan::dispatch_io::pack_u32_slice(&[matcher
192                    .pattern_lengths
193                    .len()
194                    as u32]),
195            });
196
197        // Per-shard tiny bytes (shard_len scalar + the two atomic
198        // counters + dispatch config). The haystack input is the
199        // 4-byte-aligned source buffer sliced in place - no Vec<u8>
200        // packing allocation per shard (see the buffer padding above
201        // for the rationale).
202        struct ShardOwned {
203            haystack_len: Vec<u8>,
204            atomic_count: Vec<u8>,
205            atomic_overflow: Vec<u8>,
206            config: vyre::DispatchConfig,
207            cap: u32,
208        }
209        let mut shard_owned: Vec<ShardOwned> = Vec::with_capacity(shard_count);
210        for (start, end) in &shard_ranges {
211            let shard_len = (*end - *start) as u32;
212            let shard_cap_u64 = ((*end - *start) / 64) as u64;
213            let shard_cap = shard_cap_u64.clamp(MIN_CAP as u64, MAX_CAP as u64) as u32;
214            shard_owned.push(ShardOwned {
215                haystack_len: vyre_libs::scan::dispatch_io::pack_u32_slice(&[shard_len]),
216                atomic_count: vec![0u8; 4],
217                atomic_overflow: vec![0u8; 4],
218                config: vyre_libs::scan::dispatch_io::byte_scan_dispatch_config(
219                    shard_len,
220                    matcher.program.workgroup_size[0],
221                ),
222                cap: shard_cap,
223            });
224        }
225
226        // Build borrowed input arrays per shard. Order must match
227        // `GpuLiteralSet::scan` because the buffer-decl order is the
228        // contract between host inputs and GPU kernel binding. The
229        // haystack slot is now a direct slice into the padded source
230        // buffer - no per-shard packing allocation.
231        let shard_input_arrays: Vec<[&[u8]; 8]> = shard_owned
232            .iter()
233            .zip(shard_ranges.iter())
234            .map(|(s, (start, end))| {
235                [
236                    &buffer[*start..*end],
237                    const_packs.pattern_offsets.as_slice(),
238                    const_packs.pattern_lengths.as_slice(),
239                    const_packs.pattern_bytes.as_slice(),
240                    s.haystack_len.as_slice(),
241                    const_packs.pattern_count.as_slice(),
242                    s.atomic_count.as_slice(),
243                    s.atomic_overflow.as_slice(),
244                ]
245            })
246            .collect();
247
248        // vyre's wgpu readback ring is sized at DEFAULT_RING_SLOTS
249        // (lifted to 2048 in vendor/vyre - see
250        // `runtime/readback_ring.rs` for the rationale). Each
251        // GpuLiteralSet dispatch produces 2 readback buffers, so
252        // a batch of N shards burns 2N slots from the 2048-slot
253        // ring. The other constraint is host-side memory: each
254        // shard's haystack is borrowed (no copy), but its
255        // per-dispatch config + atomic counters still allocate
256        // ~24 bytes per shard. The real cost is the input-arrays
257        // Vec<[&[u8]; 8]> at ~64 bytes per entry.
258        //
259        // Adaptive batch cap: a bigger batch flattens the
260        // command-encoder cost across more shards and shortens
261        // the wall-clock for a multi-GiB scan, but climbs
262        // the ring-slot occupancy. 64 was the original safe
263        // value for small hosts; 256 still leaves the 2048-slot
264        // ring deeply under-subscribed and matches the workload
265        // a 24 GiB-VRAM card actually wants.
266        //
267        //   total RAM   shards/batch   1-GiB-scan sequential batches
268        //   < 16 GiB        64           ≥ 8
269        //   16-32 GiB      128             4
270        //   ≥ 32 GiB       256             2
271        //
272        // The 96-GiB-RAM RTX-5090 workstation case drops from
273        // 8 sequential batched dispatches to 2, cutting GPU
274        // pipeline-drain stalls roughly 4x on a 1-GiB batch.
275        let max_shards_per_gpu_batch: usize = {
276            let total_ram_mb = crate::hw_probe::probe_hardware()
277                .total_memory_mb
278                .unwrap_or(0);
279            if total_ram_mb >= 32 * 1024 {
280                256
281            } else if total_ram_mb >= 16 * 1024 {
282                128
283            } else {
284                64
285            }
286        };
287        let mut matches: Vec<vyre_libs::scan::LiteralMatch> = Vec::new();
288        for sub_start in (0..shard_count).step_by(max_shards_per_gpu_batch) {
289            let sub_end = (sub_start + max_shards_per_gpu_batch).min(shard_count);
290            let sub_inputs: Vec<&[&[u8]]> = (sub_start..sub_end)
291                .map(|i| &shard_input_arrays[i][..])
292                .collect();
293            let sub_configs: Vec<vyre::DispatchConfig> = (sub_start..sub_end)
294                .map(|i| shard_owned[i].config.clone())
295                .collect();
296
297            let batch_results =
298                match self.dispatch_gpu_shards(&matcher.program, &sub_inputs, &sub_configs) {
299                    Ok(r) => r,
300                    Err(e) => {
301                        tracing::error!(
302                            shards = sub_end - sub_start,
303                            "GPU batched dispatch failed, falling back to CPU: {e}"
304                        );
305                        let reason = format!("GPU literal-set batched dispatch failed: {e}");
306                        return self.gpu_degrade_done_with_reason(
307                            chunks,
308                            crate::hw_probe::ScanBackend::Gpu,
309                            Some(&reason),
310                        );
311                    }
312                };
313
314            for (offset_in_sub, result) in batch_results.into_iter().enumerate() {
315                let i = sub_start + offset_in_sub;
316                let outputs = match result {
317                    Ok(o) => o,
318                    Err(e) => {
319                        tracing::error!(
320                            shard_index = i,
321                            "GPU shard within batch failed, falling back to CPU: {e}"
322                        );
323                        let reason = format!("GPU literal-set shard {i} dispatch failed: {e}");
324                        return self.gpu_degrade_done_with_reason(
325                            chunks,
326                            crate::hw_probe::ScanBackend::Gpu,
327                            Some(&reason),
328                        );
329                    }
330                };
331                if outputs.len() < 2 {
332                    tracing::error!(
333                        shard_index = i,
334                        outputs = outputs.len(),
335                        "GPU shard output buffer count too small; falling back to CPU"
336                    );
337                    let reason = format!(
338                        "GPU literal-set shard {i} returned {} output buffer(s), expected at least 2",
339                        outputs.len()
340                    );
341                    return self.gpu_degrade_done_with_reason(
342                        chunks,
343                        crate::hw_probe::ScanBackend::Gpu,
344                        Some(&reason),
345                    );
346                }
347                let count_bytes = &outputs[0];
348                let matches_bytes = &outputs[1];
349                if count_bytes.len() < 4 {
350                    tracing::error!(
351                        shard_index = i,
352                        "GPU shard count buffer truncated; falling back to CPU"
353                    );
354                    let reason = format!(
355                        "GPU literal-set shard {i} returned truncated count buffer ({} byte(s), expected 4)",
356                        count_bytes.len()
357                    );
358                    return self.gpu_degrade_done_with_reason(
359                        chunks,
360                        crate::hw_probe::ScanBackend::Gpu,
361                        Some(&reason),
362                    );
363                }
364                let count = u32::from_le_bytes([
365                    count_bytes[0],
366                    count_bytes[1],
367                    count_bytes[2],
368                    count_bytes[3],
369                ]);
370                let shard_cap = shard_owned[i].cap;
371                if count > shard_cap {
372                    tracing::warn!(
373                        cap = shard_cap,
374                        count,
375                        shard_index = i,
376                        "GPU shard exceeded its cap: truncation possible; falling back to CPU"
377                    );
378                    let reason = format!(
379                        "GPU literal-set shard {i} reported {count} matches, exceeding cap {shard_cap}"
380                    );
381                    return self.gpu_degrade_done_with_reason(
382                        chunks,
383                        crate::hw_probe::ScanBackend::Gpu,
384                        Some(&reason),
385                    );
386                }
387                let shard_matches = vyre_libs::scan::dispatch_io::unpack_match_triples(
388                    matches_bytes,
389                    count.min(shard_cap),
390                );
391                let offset = shard_ranges[i].0 as u32;
392                for m in &shard_matches {
393                    matches.push(vyre_libs::scan::LiteralMatch::new(
394                        m.pattern_id,
395                        m.start.saturating_add(offset),
396                        m.end.saturating_add(offset),
397                    ));
398                }
399            }
400        }
401        let elapsed_ms = started.elapsed().as_millis();
402        tracing::debug!(
403            target: "keyhog::routing",
404            chunks = chunks.len(),
405            buffer_bytes = buffer.len(),
406            matches = matches.len(),
407            shards = shard_count,
408            cap,
409            elapsed_ms,
410            "vyre GPU batched scan completed"
411        );
412        if self.has_simd_prefilter()
413            && super::gpu_postprocess::gpu_phase2_hits_are_dense(
414                matches.len(),
415                buffer.len(),
416                chunks.len(),
417            )
418        {
419            tracing::warn!(
420                target: "keyhog::routing",
421                raw_matches = matches.len(),
422                buffer_bytes = buffer.len(),
423                chunks = chunks.len(),
424                "GPU literal prefix output is too dense for phase 2; rerouting this batch through SIMD coalesced scan",
425            );
426            if std::env::var_os("KH_PERF").is_some() {
427                eprintln!(
428                    "KH_PERF gpu_literal_dense_phase2_reroute: chunks={} buffer_bytes={} raw_matches={} bytes_per_hit={:.1}",
429                    chunks.len(),
430                    buffer.len(),
431                    matches.len(),
432                    buffer.len() as f64 / matches.len().max(1) as f64
433                );
434            }
435            return GpuPhase1Output::Done(self.scan_coalesced_non_gpu(chunks));
436        }
437        // Per-pid dedup + chunk attribution lives in `gpu_postprocess`,
438        // shared with the AC kernel phase-1 path. The downstream
439        // `scan_prepared_with_pattern_hits` consumer requires matches
440        // anchored to chunk-local `(pid, local_start, local_end)`
441        // triples sorted by start so the regex confirmation step runs
442        // anchored at each hit rather than re-sweeping each chunk.
443        super::gpu_postprocess::fold_overlapping_same_pid_inplace(&mut matches);
444        let total_patterns = self.ac_map.len() + self.fallback.len();
445        let per_chunk_hits = super::gpu_postprocess::attribute_matches_to_chunks(
446            &matches,
447            &entries,
448            total_patterns,
449            chunks.len(),
450        );
451
452        GpuPhase1Output::Hits(per_chunk_hits)
453    }
454}
455
// Phase 2 (CPU post-process that runs after this file's GPU
// literal-set dispatch produces per-chunk hits) lives in
// `gpu_phase2.rs`. The orphan doc-comment that previously trailed
// here described that function and was stranded when the body moved.