// keyhog_scanner/engine/gpu_ac_phase1.rs

1use super::*;
2
/// Process-wide latch: becomes `true` the first time a GPU AC dispatch
/// yields a degenerate match triple (`end <= start`).
///
/// While it is set, `scan_coalesced_gpu_ac_phase1` degrades every batch
/// before dispatching to the GPU. No code in this file ever resets it, so it
/// stays set for the rest of the process.
static GPU_AC_DEGENERATE_DISABLED: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);
5
6impl CompiledScanner {
7    pub fn scan_coalesced_gpu_ac_phase1(&self, chunks: &[keyhog_core::Chunk]) -> GpuPhase1Output {
8        let Some(matcher) = self.gpu_matcher() else {
9            return self.gpu_degrade_done_with_reason(
10                chunks,
11                crate::hw_probe::ScanBackend::Gpu,
12                Some("GPU literal matcher unavailable for AC dispatch"),
13            );
14        };
15        let Some(program) = self.ac_gpu_program() else {
16            return self.gpu_degrade_done_with_reason(
17                chunks,
18                crate::hw_probe::ScanBackend::Gpu,
19                Some("GPU AC dispatch program unavailable"),
20            );
21        };
22        if self.gpu_backend.is_none() {
23            return self.gpu_degrade_done_with_reason(
24                chunks,
25                crate::hw_probe::ScanBackend::Gpu,
26                Some("GPU backend handle unavailable for AC dispatch"),
27            );
28        }
29        if GPU_AC_DEGENERATE_DISABLED.load(std::sync::atomic::Ordering::Relaxed) {
30            return self.gpu_degrade_done_with_reason(
31                chunks,
32                crate::hw_probe::ScanBackend::Gpu,
33                Some("GPU AC previously emitted degenerate match triples (end <= start); skipping known-corrupt Vyre dispatch"),
34            );
35        }
36
37        let (entries, mut buffer) = super::gpu_coalesce::coalesce_chunks(chunks);
38
39        // ASCII-lowercase the coalesced haystack so the AC literal automaton
40        // matches case-INSENSITIVELY, exactly like the SIMD Hyperscan path
41        // (compiled CASELESS for every pattern). Without this the GPU drops
42        // matches on uppercase occurrences of lowercase literal prefixes
43        // (PERF-07 gpu_parity: `csb_` literal vs `CSB_...` in soc21_enum.h ->
44        // SIMD 4, GPU 0). The literal set is lowercased to the same fold in
45        // `build_gpu_literals`. This buffer is the phase-1 PREFILTER only -
46        // phase 2 re-confirms each hit on the ORIGINAL chunk bytes with the
47        // caseless regex - and ASCII fold is 1-byte-to-1-byte (only A-Z), so
48        // the match offsets attributed back to chunks are unchanged and the
49        // reported credential keeps its original case.
50        buffer.make_ascii_lowercase();
51
52        // Same buffer 4-alignment trick as `scan_coalesced_gpu`: lets
53        // every shard pass `&buffer[start..end]` straight to vyre's
54        // u32-typed haystack input instead of running pack_haystack_u32
55        // (a 2x memcopy producing byte-identical output for aligned
56        // slices). Eliminates ~2x buffer.len() of transient allocations
57        // per scan. NUL padding is recall-safe (literals can't contain
58        // NUL).
59        while !buffer.len().is_multiple_of(4) {
60            buffer.push(0);
61        }
62
63        #[cfg(target_os = "linux")]
64        // SAFETY: same contract as scan_coalesced_gpu - `buffer` is a
65        // live owned Vec describing a valid range; madvise is advisory.
66        unsafe {
67            libc::madvise(
68                buffer.as_ptr() as *mut libc::c_void,
69                buffer.len(),
70                libc::MADV_DONTDUMP,
71            );
72        }
73
74        let workgroup_x = program.workgroup_size[0] as usize;
75        // WGSL workgroups-per-dim ceiling is 65 535. At workgroup_x = 64
76        // that's a ~4 MiB shard. The shard cap is here so we never feed
77        // the dispatch a workgroup count > 65 535 (validation error).
78        const GPU_DISPATCH_MAX_WORKGROUPS_AC: usize = 65_535;
79        let gpu_dispatch_max_bytes: usize = GPU_DISPATCH_MAX_WORKGROUPS_AC * workgroup_x;
80        let started = std::time::Instant::now();
81
82        let mut shard_ranges: Vec<(usize, usize)> = Vec::new();
83        let mut shard_start = 0usize;
84        while shard_start < buffer.len() {
85            let shard_end = (shard_start + gpu_dispatch_max_bytes).min(buffer.len());
86            shard_ranges.push((shard_start, shard_end));
87            shard_start = shard_end;
88        }
89        let shard_count = shard_ranges.len();
90
91        // Constants packed ONCE per process via the scanner-level
92        // OnceLock. Same rationale as `scan_coalesced_gpu`: AC kernel
93        // re-ran four `pack_u32_slice` calls on identical bytes every
94        // dispatch.
95        // The AC program's binding layout:
96        //   0: haystack (per shard, slice into padded buffer)
97        //   1: transitions
98        //   2: output_offsets
99        //   3: output_records
100        //   4: pattern_lengths
101        //   5: haystack_len (per shard, packed)
102        //   6: match_count (per shard, atomic counter)
103        //   7: matches (output, backend-allocated from BufferDecl)
104        let ac_packs = self
105            .gpu_ac_const_packs
106            .get_or_init(|| super::gpu_cache::AcConstPacks {
107                transitions: vyre_libs::scan::dispatch_io::pack_u32_slice(&matcher.dfa.transitions),
108                output_offsets: vyre_libs::scan::dispatch_io::pack_u32_slice(
109                    &matcher.dfa.output_offsets,
110                ),
111                output_records: vyre_libs::scan::dispatch_io::pack_u32_slice(
112                    &matcher.dfa.output_records,
113                ),
114                pattern_lengths: vyre_libs::scan::dispatch_io::pack_u32_slice(
115                    &matcher.pattern_lengths,
116                ),
117            });
118
119        struct ShardOwnedAc {
120            haystack_len: Vec<u8>,
121            atomic_count: Vec<u8>,
122            config: vyre::DispatchConfig,
123        }
124        let mut shard_owned: Vec<ShardOwnedAc> = Vec::with_capacity(shard_count);
125        for &(s_start, s_end) in &shard_ranges {
126            let shard_len = (s_end - s_start) as u32;
127            shard_owned.push(ShardOwnedAc {
128                haystack_len: vyre_libs::scan::dispatch_io::pack_u32_slice(&[shard_len]),
129                atomic_count: vec![0u8; 4],
130                config: vyre_libs::scan::dispatch_io::byte_scan_dispatch_config(
131                    shard_len,
132                    program.workgroup_size[0],
133                ),
134            });
135        }
136
137        let shard_input_arrays: Vec<[&[u8]; 7]> = shard_owned
138            .iter()
139            .zip(shard_ranges.iter())
140            .map(|(s, &(start, end))| {
141                [
142                    &buffer[start..end],
143                    ac_packs.transitions.as_slice(),
144                    ac_packs.output_offsets.as_slice(),
145                    ac_packs.output_records.as_slice(),
146                    ac_packs.pattern_lengths.as_slice(),
147                    s.haystack_len.as_slice(),
148                    s.atomic_count.as_slice(),
149                ]
150            })
151            .collect();
152
153        // Sub-batched dispatch: dynamically scaled MAX_SHARDS_PER_GPU_BATCH
154        // budget based on system RAM keeps transient host-side memory
155        // bounded while maximizing dispatch concurrency for high-tier GPUs
156        // and leaving vyre's 2048-slot readback ring deeply under-subscribed.
157        let max_shards_per_gpu_batch: usize = {
158            let total_ram_mb = crate::hw_probe::probe_hardware()
159                .total_memory_mb
160                .unwrap_or(0);
161            if total_ram_mb >= 32 * 1024 {
162                256
163            } else if total_ram_mb >= 16 * 1024 {
164                128
165            } else {
166                64
167            }
168        };
169        let mut matches: Vec<vyre_libs::scan::LiteralMatch> = Vec::new();
170        for sub_start in (0..shard_count).step_by(max_shards_per_gpu_batch) {
171            let sub_end = (sub_start + max_shards_per_gpu_batch).min(shard_count);
172            let sub_inputs: Vec<&[&[u8]]> = (sub_start..sub_end)
173                .map(|i| &shard_input_arrays[i][..])
174                .collect();
175            let sub_configs: Vec<vyre::DispatchConfig> = (sub_start..sub_end)
176                .map(|i| shard_owned[i].config.clone())
177                .collect();
178
179            let batch_results = match self.dispatch_gpu_shards(program, &sub_inputs, &sub_configs) {
180                Ok(r) => r,
181                Err(e) => {
182                    tracing::error!(
183                        shards = sub_end - sub_start,
184                        "AC GPU batched dispatch failed, falling back to CPU: {e}"
185                    );
186                    let reason = format!("AC GPU batched dispatch failed: {e}");
187                    return self.gpu_degrade_done_with_reason(
188                        chunks,
189                        crate::hw_probe::ScanBackend::Gpu,
190                        Some(&reason),
191                    );
192                }
193            };
194
195            for (offset_in_sub, result) in batch_results.into_iter().enumerate() {
196                let i = sub_start + offset_in_sub;
197                let outputs = match result {
198                    Ok(o) => o,
199                    Err(e) => {
200                        tracing::error!(
201                            shard_index = i,
202                            "AC GPU shard within batch failed, falling back to CPU: {e}"
203                        );
204                        let reason = format!("AC GPU shard {i} dispatch failed: {e}");
205                        return self.gpu_degrade_done_with_reason(
206                            chunks,
207                            crate::hw_probe::ScanBackend::Gpu,
208                            Some(&reason),
209                        );
210                    }
211                };
212                if outputs.len() < 2 {
213                    tracing::error!(
214                        shard_index = i,
215                        outputs = outputs.len(),
216                        "AC GPU shard output buffer count too small; falling back to CPU"
217                    );
218                    let reason = format!(
219                        "AC GPU shard {i} returned {} output buffer(s), expected at least 2",
220                        outputs.len()
221                    );
222                    return self.gpu_degrade_done_with_reason(
223                        chunks,
224                        crate::hw_probe::ScanBackend::Gpu,
225                        Some(&reason),
226                    );
227                }
228                let count_bytes = &outputs[0];
229                let matches_bytes = &outputs[1];
230                if count_bytes.len() < 4 {
231                    tracing::error!(
232                        shard_index = i,
233                        "AC GPU shard count buffer truncated; falling back to CPU"
234                    );
235                    let reason = format!(
236                        "AC GPU shard {i} returned truncated count buffer ({} byte(s), expected 4)",
237                        count_bytes.len()
238                    );
239                    return self.gpu_degrade_done_with_reason(
240                        chunks,
241                        crate::hw_probe::ScanBackend::Gpu,
242                        Some(&reason),
243                    );
244                }
245                let count = u32::from_le_bytes([
246                    count_bytes[0],
247                    count_bytes[1],
248                    count_bytes[2],
249                    count_bytes[3],
250                ]);
251                if count > super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH {
252                    tracing::warn!(
253                        cap = super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH,
254                        count,
255                        shard_index = i,
256                        "AC GPU shard exceeded dense-prefix cap; rerouting batch through SIMD coalesced scan"
257                    );
258                    if self.has_simd_prefilter() {
259                        if std::env::var_os("KH_PERF").is_some() {
260                            eprintln!(
261                                "KH_PERF gpu_ac_cap_reroute: chunks={} shard={} shard_matches={} cap={} shard_bytes={}",
262                                chunks.len(),
263                                i,
264                                count,
265                                super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH,
266                                shard_ranges[i].1 - shard_ranges[i].0
267                            );
268                        }
269                        return GpuPhase1Output::Done(self.scan_coalesced_non_gpu(chunks));
270                    }
271                    let reason = format!(
272                        "AC GPU shard {i} reported {count} matches, exceeding dense-prefix cap {} and no SIMD fallback is available",
273                        super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH
274                    );
275                    return self.gpu_degrade_done_with_reason(
276                        chunks,
277                        crate::hw_probe::ScanBackend::Gpu,
278                        Some(&reason),
279                    );
280                }
281                let shard_matches = vyre_libs::scan::dispatch_io::unpack_match_triples(
282                    matches_bytes,
283                    count.min(super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH),
284                );
285                let offset = shard_ranges[i].0 as u32;
286                for m in &shard_matches {
287                    matches.push(vyre_libs::scan::LiteralMatch::new(
288                        m.pattern_id,
289                        m.start.saturating_add(offset),
290                        m.end.saturating_add(offset),
291                    ));
292                }
293            }
294        }
295        let elapsed_ms = started.elapsed().as_millis();
296        tracing::debug!(
297            target: "keyhog::routing",
298            chunks = chunks.len(),
299            buffer_bytes = buffer.len(),
300            matches = matches.len(),
301            shards = shard_count,
302            elapsed_ms,
303            "AC GPU batched scan completed"
304        );
305
306        // PERF-07c correctness guard: a sound AC kernel emits `end = i + 1`
307        // and `start = end - pat_len` with `pat_len >= 1`, so EVERY real match
308        // has `end > start`. A triple with `end <= start` (observed: a flood of
309        // degenerate `(pid=0, start=0, end=0)`) is impossible from correct
310        // output. The vyre CUDA PTX emit path currently produces such triples;
311        // folded to `(0,0)` they mis-attribute every PID to chunk 0 of a
312        // coalesced batch, silently dropping real hits in chunks > 0 - a
313        // fail-OPEN recall gap that only manifests on multi-file batches
314        // (single-file scans put the target in chunk 0 and mask it). Until the
315        // emitter is fixed (tracked as the vyre GPU upgrade), detect the
316        // corruption and degrade THIS batch to the SIMD/CPU literal path, which
317        // is correct and - measured on the kernel - actually faster than the
318        // GPU AC path here. The GPU MoE scorer still runs in phase 2. This is
319        // self-validating: a backend that emits sound triples (zero degenerate)
320        // never degrades, so the guard auto-clears once vyre's CUDA emit is
321        // fixed, with no keyhog change required.
322        if matches.iter().any(|m| m.end <= m.start) {
323            GPU_AC_DEGENERATE_DISABLED.store(true, std::sync::atomic::Ordering::Relaxed);
324            tracing::warn!(
325                target: "keyhog::routing",
326                raw_matches = matches.len(),
327                chunks = chunks.len(),
328                "GPU AC emitted degenerate match triples (end <= start); vyre CUDA \
329                 emit bug PERF-07c. Degrading this batch to the SIMD/CPU literal \
330                 path to preserve recall parity."
331            );
332            return self.gpu_degrade_done_with_reason(
333                chunks,
334                crate::hw_probe::ScanBackend::Gpu,
335                Some("GPU AC emitted degenerate match triples (end <= start); vyre CUDA emit bug PERF-07c"),
336            );
337        }
338        if self.has_simd_prefilter()
339            && super::gpu_postprocess::gpu_phase2_hits_are_dense(
340                matches.len(),
341                buffer.len(),
342                chunks.len(),
343            )
344        {
345            tracing::warn!(
346                target: "keyhog::routing",
347                raw_matches = matches.len(),
348                buffer_bytes = buffer.len(),
349                chunks = chunks.len(),
350                "GPU AC prefix output is too dense for phase 2; rerouting this batch through SIMD coalesced scan",
351            );
352            if std::env::var_os("KH_PERF").is_some() {
353                eprintln!(
354                    "KH_PERF gpu_ac_dense_phase2_reroute: chunks={} buffer_bytes={} raw_matches={} bytes_per_hit={:.1}",
355                    chunks.len(),
356                    buffer.len(),
357                    matches.len(),
358                    buffer.len() as f64 / matches.len().max(1) as f64
359                );
360            }
361            return GpuPhase1Output::Done(self.scan_coalesced_non_gpu(chunks));
362        }
363        super::gpu_postprocess::fold_overlapping_same_pid_inplace(&mut matches);
364        let total_patterns = self.ac_map.len() + self.fallback.len();
365        let per_chunk_hits = super::gpu_postprocess::attribute_matches_to_chunks(
366            &matches,
367            &entries,
368            total_patterns,
369            chunks.len(),
370        );
371
372        // Hand the hits back to the orchestrator so it can run phase 2
373        // on a separate thread (pipelined). Combined-wrapper callers
374        // (`scan_coalesced_gpu_ac`) call phase 2 inline immediately
375        // after this returns, preserving the original synchronous
376        // behaviour.
377        GpuPhase1Output::Hits(per_chunk_hits)
378    }
379}