Skip to main content

keyhog_scanner/engine/
gpu_literal_phase1.rs

1use super::*;
2
3impl CompiledScanner {
4    pub fn scan_coalesced_gpu_phase1(&self, chunks: &[keyhog_core::Chunk]) -> GpuPhase1Output {
5        // The literal_set program embeds `append_match_subgroup`
6        // (subgroup_ballot + subgroup_shuffle), and vyre's canonical
7        // pre-emit lowering rejects that subgroup form regardless of
8        // the downstream emitter ("variable `_vyre_match_leader` is
9        // referenced before binding"). This was previously gated to
10        // CUDA only, but the rejection happens BEFORE driver-specific
11        // emission, so WGPU hosts (Apple Silicon, Intel Mac, Windows)
12        // hit the same rejection on the literal_set path and silently
13        // dropped to CPU.
14        //
15        // Until the vyre pre-emit lowering accepts the subgroup form
16        // (tracked separately), the AC kernel path is the working
17        // GPU code path for both CUDA and WGPU. KEYHOG_GPU_KERNEL=
18        // literal-set forces the broken path for diagnostic /
19        // bisection use; the default is now AC for every GPU backend.
20        // Cache the env-var lookup. `scan_coalesced_gpu_phase1` is called
21        // per batched chunk group; reading env::var on the hot path costs
22        // ~200 ns per call which adds up to milliseconds across 1k+
23        // chunks. The diagnostic override is process-static so caching
24        // once is exact.
25        static FORCE_LITERAL_SET: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
26        let force_literal_set = *FORCE_LITERAL_SET.get_or_init(|| {
27            matches!(
28                std::env::var("KEYHOG_GPU_KERNEL").ok().as_deref(),
29                Some("literal-set") | Some("literal_set")
30            )
31        });
32        if !force_literal_set {
33            return self.scan_coalesced_gpu_ac_phase1(chunks);
34        }
35
36        // Auto-degrade to the next-best backend when the GPU stack is not
37        // ready: no compiled matcher (no adapter at probe time), the cached
38        // device went away, or the persistent backend is missing.
39        let Some(matcher) = self.gpu_matcher() else {
40            return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
41        };
42        if self.gpu_backend.is_none() {
43            return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
44        }
45
46        let (entries, mut buffer) = super::gpu_coalesce::coalesce_chunks(chunks);
47
48        // 4-byte align the coalesced buffer so every shard slice can be
49        // passed to vyre's u32-typed haystack input WITHOUT a per-shard
50        // `pack_haystack_u32` call. The pack helper is a 2x memcopy
51        // (Vec<u32> intermediate + Vec<u8> output) that produces bytes
52        // byte-identical to the input on 4-aligned slices (see
53        // `vyre_foundation::byte_pack::pack_haystack_u32`). On a 1 GiB
54        // scan with 2 MiB shards that's 512 shards x 2x = ~4 GiB of
55        // throwaway allocations - load-bearing on the 25s gap GPU
56        // currently loses to SIMD at scale. Padding the source buffer
57        // once and slicing each shard collapses that to zero alloc per
58        // shard. Padding bytes are NUL, which no detector literal can
59        // match (extract_literal_prefixes drops NUL), so the trailing
60        // zero-extension is recall-safe.
61        while !buffer.len().is_multiple_of(4) {
62            buffer.push(0);
63        }
64
65        #[cfg(target_os = "linux")]
66        // SAFETY: `buffer` is a live `Vec<u8>` whose `as_ptr()` and
67        // `len()` describe a valid memory range owned by this scope.
68        // `madvise` is advisory - the kernel may ignore it on
69        // non-page-aligned ranges; we treat the call as best-effort
70        // and don't check the return value.
71        unsafe {
72            // Senior Audit §Phase 7.4: Prevent GPU buffers from leaking into core dumps.
73            libc::madvise(
74                buffer.as_ptr() as *mut libc::c_void,
75                buffer.len(),
76                libc::MADV_DONTDUMP,
77            );
78        }
79
80        // Adaptive match cap that scales with the actual buffer size
81        // rather than chunk count. Real-world ceiling: roughly one
82        // literal hit per 64 input bytes is already implausibly dense
83        // for production source code (the densest fixture in the
84        // performance regression suite is ~1 hit per 1 KiB). The
85        // chunk-count formula systematically under-sized batches that
86        // had a few large files, leading to spurious truncation and
87        // the full-CPU re-scan that wastes the GPU dispatch we just
88        // paid for.
89        //
90        // Keeps the kimi-wave2 `cap+1` sentinel-slot trick: ask the
91        // GPU for one more than the cap, and only treat `> cap` as
92        // truncation. A batch that lands EXACTLY at the cap is by
93        // definition complete (would have written into the sentinel
94        // slot otherwise).
95        const MIN_CAP: u32 = 100_000;
96        const MAX_CAP: u32 = 16_000_000;
97        let buffer_cap = (buffer.len() / 64) as u64;
98        let cap: u32 = buffer_cap.clamp(MIN_CAP as u64, MAX_CAP as u64) as u32;
99
100        // wgpu caps each compute dispatch at 65535 workgroups per
101        // dimension (WebGPU spec). Vyre's GpuLiteralSet uses
102        // workgroup_size_x = 32, so a single dispatch can handle at
103        // most 65535 × 32 = 2,097,120 input bytes. For coalesced
104        // batches larger than this (always true with the tier-aware
105        // 2 MiB activation threshold + the orchestrator's adaptive
106        // `batch_bytes_budget` - 256 MiB default, up to 1 GiB on
107        // 24-GiB-VRAM cards), shard the buffer into 2-MiB-or-less
108        // pieces, dispatch each, and merge the matches with a
109        // `start` offset added to put them back into the global
110        // buffer's coordinate space.
111        //
112        // Shard size: 65535 (max workgroups per dim) × 32 (vyre's
113        // workgroup_size_x) = 2,097,120 bytes. Exactly 2 MiB =
114        // 2,097,152 bytes overflows by one workgroup. Use the
115        // exact-aligned value to maximise per-shard throughput
116        // without tripping the wgpu dispatch validator.
117        //
118        // Extra dispatches add ~100 µs each on a high-tier GPU; for
119        // a 256 MiB batch that's ~12 ms of overhead vs SIMD's ~70 s
120        // (a 5800× win). On a 1 GiB batch (5090-class adapter) the
121        // shard count rises 4× but the GPU-vs-SIMD ratio widens
122        // because per-shard dispatch is amortized over more bytes.
123        // Dynamic per-vyre-workgroup: each shard covers
124        // (max_workgroups_per_dim × workgroup_size_x) bytes.
125        // wgpu caps workgroups per dimension at 65 535; vyre's
126        // literal-set program reports its `workgroup_size_x` via
127        // `matcher.program.workgroup_size[0]`. Was hard-coded at
128        // 65_535 × 32 when vyre's literal-set used
129        // workgroup_size_x = 32; now scales automatically when
130        // the vyre side is tuned (e.g. to 128 to cut shard count
131        // by 4×).
132        let workgroup_x = matcher.program.workgroup_size[0] as usize;
133        let gpu_dispatch_max_bytes: usize = 65_535 * workgroup_x;
134        let started = std::time::Instant::now();
135
136        // Slice the coalesced buffer into wgpu-dispatch-sized shards.
137        // The shard boundary itself is wgpu's `dispatch_workgroups`
138        // limit (65 535 workgroups per dimension × 32-byte workgroup
139        // size). The previous flow dispatched these one-by-one with
140        // `matcher.scan` - each call records its own encoder,
141        // submits, and `device.poll(Wait)`s. On a 1 GiB batch with
142        // 512 shards that adds up to ~50 ms × 512 = 25 s of pure
143        // host-side dispatch overhead, *not* GPU compute.
144        //
145        // `WgpuBackend::dispatch_borrowed_batch` records *all* shard
146        // dispatches into one command encoder, single submit, single
147        // poll. For 512 shards the wait collapses from ~25 s to
148        // a single GPU drain - close to the actual compute time.
149        let mut shard_ranges: Vec<(usize, usize)> = Vec::new();
150        let mut shard_start = 0usize;
151        while shard_start < buffer.len() {
152            let shard_end = (shard_start + gpu_dispatch_max_bytes).min(buffer.len());
153            shard_ranges.push((shard_start, shard_end));
154            shard_start = shard_end;
155        }
156        let shard_count = shard_ranges.len();
157
158        // Constants across all shards: pattern offsets/lengths/bytes
159        // and pattern_count. Pre-packed ONCE per process via the
160        // CompiledScanner-level OnceLock and borrowed every dispatch.
161        // Before this cache, `pack_u32_slice` ran four times per scan
162        // producing identical bytes; a process scanning 10 k files
163        // burned 40 k throwaway Vec<u8> allocations on data that never
164        // changes after compile.
165        let const_packs = self
166            .gpu_const_packs
167            .get_or_init(|| super::gpu_cache::GpuConstPacks {
168                pattern_offsets: vyre_libs::scan::dispatch_io::pack_u32_slice(
169                    &matcher.pattern_offsets,
170                ),
171                pattern_lengths: vyre_libs::scan::dispatch_io::pack_u32_slice(
172                    &matcher.pattern_lengths,
173                ),
174                pattern_bytes: vyre_libs::scan::dispatch_io::pack_u32_slice(&matcher.pattern_bytes),
175                pattern_count: vyre_libs::scan::dispatch_io::pack_u32_slice(&[matcher
176                    .pattern_lengths
177                    .len()
178                    as u32]),
179            });
180
181        // Per-shard tiny bytes (shard_len scalar + the two atomic
182        // counters + dispatch config). The haystack input is the
183        // 4-byte-aligned source buffer sliced in place - no Vec<u8>
184        // packing allocation per shard (see the buffer padding above
185        // for the rationale).
186        struct ShardOwned {
187            haystack_len: Vec<u8>,
188            atomic_count: Vec<u8>,
189            atomic_overflow: Vec<u8>,
190            config: vyre::DispatchConfig,
191            cap: u32,
192        }
193        let mut shard_owned: Vec<ShardOwned> = Vec::with_capacity(shard_count);
194        for (start, end) in &shard_ranges {
195            let shard_len = (*end - *start) as u32;
196            let shard_cap_u64 = ((*end - *start) / 64) as u64;
197            let shard_cap = shard_cap_u64.clamp(MIN_CAP as u64, MAX_CAP as u64) as u32;
198            shard_owned.push(ShardOwned {
199                haystack_len: vyre_libs::scan::dispatch_io::pack_u32_slice(&[shard_len]),
200                atomic_count: vec![0u8; 4],
201                atomic_overflow: vec![0u8; 4],
202                config: vyre_libs::scan::dispatch_io::byte_scan_dispatch_config(
203                    shard_len,
204                    matcher.program.workgroup_size[0],
205                ),
206                cap: shard_cap,
207            });
208        }
209
210        // Build borrowed input arrays per shard. Order must match
211        // `GpuLiteralSet::scan` because the buffer-decl order is the
212        // contract between host inputs and GPU kernel binding. The
213        // haystack slot is now a direct slice into the padded source
214        // buffer - no per-shard packing allocation.
215        let shard_input_arrays: Vec<[&[u8]; 8]> = shard_owned
216            .iter()
217            .zip(shard_ranges.iter())
218            .map(|(s, (start, end))| {
219                [
220                    &buffer[*start..*end],
221                    const_packs.pattern_offsets.as_slice(),
222                    const_packs.pattern_lengths.as_slice(),
223                    const_packs.pattern_bytes.as_slice(),
224                    s.haystack_len.as_slice(),
225                    const_packs.pattern_count.as_slice(),
226                    s.atomic_count.as_slice(),
227                    s.atomic_overflow.as_slice(),
228                ]
229            })
230            .collect();
231
232        // vyre's wgpu readback ring is sized at DEFAULT_RING_SLOTS
233        // (lifted to 2048 in vendor/vyre - see
234        // `runtime/readback_ring.rs` for the rationale). Each
235        // GpuLiteralSet dispatch produces 2 readback buffers, so
236        // a batch of N shards burns 2N slots from the 2048-slot
237        // ring. The other constraint is host-side memory: each
238        // shard's haystack is borrowed (no copy), but its
239        // per-dispatch config + atomic counters still allocate
240        // ~24 bytes per shard. The real cost is the input-arrays
241        // Vec<[&[u8]; 8]> at ~64 bytes per entry.
242        //
243        // Adaptive batch cap: a bigger batch flattens the
244        // command-encoder cost across more shards and shortens
245        // the wall-clock for a multi-GiB scan, but climbs
246        // the ring-slot occupancy. 64 was the original safe
247        // value for small hosts; 256 still leaves the 2048-slot
248        // ring deeply under-subscribed and matches the workload
249        // a 24 GiB-VRAM card actually wants.
250        //
251        //   total RAM   shards/batch   1-GiB-scan sequential batches
252        //   < 16 GiB        64           ≥ 8
253        //   16-32 GiB      128             4
254        //   ≥ 32 GiB       256             2
255        //
256        // The 96-GiB-RAM RTX-5090 workstation case drops from
257        // 8 sequential batched dispatches to 2, cutting GPU
258        // pipeline-drain stalls roughly 4x on a 1-GiB batch.
259        let max_shards_per_gpu_batch: usize = {
260            let total_ram_mb = crate::hw_probe::probe_hardware()
261                .total_memory_mb
262                .unwrap_or(0);
263            if total_ram_mb >= 32 * 1024 {
264                256
265            } else if total_ram_mb >= 16 * 1024 {
266                128
267            } else {
268                64
269            }
270        };
271        let mut matches: Vec<vyre_libs::scan::LiteralMatch> = Vec::new();
272        for sub_start in (0..shard_count).step_by(max_shards_per_gpu_batch) {
273            let sub_end = (sub_start + max_shards_per_gpu_batch).min(shard_count);
274            let sub_inputs: Vec<&[&[u8]]> = (sub_start..sub_end)
275                .map(|i| &shard_input_arrays[i][..])
276                .collect();
277            let sub_configs: Vec<vyre::DispatchConfig> = (sub_start..sub_end)
278                .map(|i| shard_owned[i].config.clone())
279                .collect();
280
281            let batch_results =
282                match self.dispatch_gpu_shards(&matcher.program, &sub_inputs, &sub_configs) {
283                    Ok(r) => r,
284                    Err(e) => {
285                        tracing::error!(
286                            shards = sub_end - sub_start,
287                            "GPU batched dispatch failed, falling back to CPU: {e}"
288                        );
289                        return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
290                    }
291                };
292
293            for (offset_in_sub, result) in batch_results.into_iter().enumerate() {
294                let i = sub_start + offset_in_sub;
295                let outputs = match result {
296                    Ok(o) => o,
297                    Err(e) => {
298                        tracing::error!(
299                            shard_index = i,
300                            "GPU shard within batch failed, falling back to CPU: {e}"
301                        );
302                        return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
303                    }
304                };
305                if outputs.len() < 2 {
306                    tracing::error!(
307                        shard_index = i,
308                        outputs = outputs.len(),
309                        "GPU shard output buffer count too small; falling back to CPU"
310                    );
311                    return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
312                }
313                let count_bytes = &outputs[0];
314                let matches_bytes = &outputs[1];
315                if count_bytes.len() < 4 {
316                    tracing::error!(
317                        shard_index = i,
318                        "GPU shard count buffer truncated; falling back to CPU"
319                    );
320                    return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
321                }
322                let count = u32::from_le_bytes([
323                    count_bytes[0],
324                    count_bytes[1],
325                    count_bytes[2],
326                    count_bytes[3],
327                ]);
328                let shard_cap = shard_owned[i].cap;
329                if count > shard_cap {
330                    tracing::warn!(
331                        cap = shard_cap,
332                        count,
333                        shard_index = i,
334                        "GPU shard exceeded its cap: truncation possible; falling back to CPU"
335                    );
336                    return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
337                }
338                let shard_matches = vyre_libs::scan::dispatch_io::unpack_match_triples(
339                    matches_bytes,
340                    count.min(shard_cap),
341                );
342                let offset = shard_ranges[i].0 as u32;
343                for m in &shard_matches {
344                    matches.push(vyre_libs::scan::LiteralMatch::new(
345                        m.pattern_id,
346                        m.start.saturating_add(offset),
347                        m.end.saturating_add(offset),
348                    ));
349                }
350            }
351        }
352        let elapsed_ms = started.elapsed().as_millis();
353        tracing::debug!(
354            target: "keyhog::routing",
355            chunks = chunks.len(),
356            buffer_bytes = buffer.len(),
357            matches = matches.len(),
358            shards = shard_count,
359            cap,
360            elapsed_ms,
361            "vyre GPU batched scan completed"
362        );
363        // Per-pid dedup + chunk attribution lives in `gpu_postprocess`,
364        // shared with the AC kernel phase-1 path. The downstream
365        // `scan_prepared_with_pattern_hits` consumer requires matches
366        // anchored to chunk-local `(pid, local_start, local_end)`
367        // triples sorted by start so the regex confirmation step runs
368        // anchored at each hit rather than re-sweeping each chunk.
369        super::gpu_postprocess::fold_overlapping_same_pid_inplace(&mut matches);
370        let total_patterns = self.ac_map.len() + self.fallback.len();
371        let per_chunk_hits = super::gpu_postprocess::attribute_matches_to_chunks(
372            &matches,
373            &entries,
374            total_patterns,
375            chunks.len(),
376        );
377
378        GpuPhase1Output::Hits(per_chunk_hits)
379    }
380}
381
382// Phase 2 (CPU post-process that runs after this file's GPU
383// literal-set dispatch produces per-chunk hits) lives in
384// `gpu_phase2.rs`. The orphan doc-comment that previously trailed
385// here described that function and was stranded when the body moved.