Skip to main content

keyhog_scanner/engine/
gpu_ac_phase1.rs

1use super::*;
2
3impl CompiledScanner {
4    pub fn scan_coalesced_gpu_ac_phase1(&self, chunks: &[keyhog_core::Chunk]) -> GpuPhase1Output {
5        let Some(matcher) = self.gpu_matcher() else {
6            return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
7        };
8        let Some(program) = self.ac_gpu_program() else {
9            return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
10        };
11        if self.gpu_backend.is_none() {
12            return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
13        }
14
15        let (entries, mut buffer) = super::gpu_coalesce::coalesce_chunks(chunks);
16
17        // Same buffer 4-alignment trick as `scan_coalesced_gpu`: lets
18        // every shard pass `&buffer[start..end]` straight to vyre's
19        // u32-typed haystack input instead of running pack_haystack_u32
20        // (a 2x memcopy producing byte-identical output for aligned
21        // slices). Eliminates ~2x buffer.len() of transient allocations
22        // per scan. NUL padding is recall-safe (literals can't contain
23        // NUL).
24        while !buffer.len().is_multiple_of(4) {
25            buffer.push(0);
26        }
27
28        #[cfg(target_os = "linux")]
29        // SAFETY: same contract as scan_coalesced_gpu - `buffer` is a
30        // live owned Vec describing a valid range; madvise is advisory.
31        unsafe {
32            libc::madvise(
33                buffer.as_ptr() as *mut libc::c_void,
34                buffer.len(),
35                libc::MADV_DONTDUMP,
36            );
37        }
38
39        let workgroup_x = program.workgroup_size[0] as usize;
40        // WGSL workgroups-per-dim ceiling is 65 535. At workgroup_x = 64
41        // that's a ~4 MiB shard. The shard cap is here so we never feed
42        // the dispatch a workgroup count > 65 535 (validation error).
43        const GPU_DISPATCH_MAX_WORKGROUPS_AC: usize = 65_535;
44        let gpu_dispatch_max_bytes: usize = GPU_DISPATCH_MAX_WORKGROUPS_AC * workgroup_x;
45        let started = std::time::Instant::now();
46
47        let mut shard_ranges: Vec<(usize, usize)> = Vec::new();
48        let mut shard_start = 0usize;
49        while shard_start < buffer.len() {
50            let shard_end = (shard_start + gpu_dispatch_max_bytes).min(buffer.len());
51            shard_ranges.push((shard_start, shard_end));
52            shard_start = shard_end;
53        }
54        let shard_count = shard_ranges.len();
55
56        // Constants packed ONCE per process via the scanner-level
57        // OnceLock. Same rationale as `scan_coalesced_gpu`: AC kernel
58        // re-ran four `pack_u32_slice` calls on identical bytes every
59        // dispatch.
60        // The AC program's binding layout:
61        //   0: haystack (per shard, slice into padded buffer)
62        //   1: transitions
63        //   2: output_offsets
64        //   3: output_records
65        //   4: pattern_lengths
66        //   5: haystack_len (per shard, packed)
67        //   6: match_count (per shard, atomic counter)
68        //   7: matches (output, backend-allocated from BufferDecl)
69        let ac_packs = self
70            .gpu_ac_const_packs
71            .get_or_init(|| super::gpu_cache::AcConstPacks {
72                transitions: vyre_libs::scan::dispatch_io::pack_u32_slice(&matcher.dfa.transitions),
73                output_offsets: vyre_libs::scan::dispatch_io::pack_u32_slice(
74                    &matcher.dfa.output_offsets,
75                ),
76                output_records: vyre_libs::scan::dispatch_io::pack_u32_slice(
77                    &matcher.dfa.output_records,
78                ),
79                pattern_lengths: vyre_libs::scan::dispatch_io::pack_u32_slice(
80                    &matcher.pattern_lengths,
81                ),
82            });
83
84        struct ShardOwnedAc {
85            haystack_len: Vec<u8>,
86            atomic_count: Vec<u8>,
87            config: vyre::DispatchConfig,
88        }
89        let mut shard_owned: Vec<ShardOwnedAc> = Vec::with_capacity(shard_count);
90        for &(s_start, s_end) in &shard_ranges {
91            let shard_len = (s_end - s_start) as u32;
92            shard_owned.push(ShardOwnedAc {
93                haystack_len: vyre_libs::scan::dispatch_io::pack_u32_slice(&[shard_len]),
94                atomic_count: vec![0u8; 4],
95                config: vyre_libs::scan::dispatch_io::byte_scan_dispatch_config(
96                    shard_len,
97                    program.workgroup_size[0],
98                ),
99            });
100        }
101
102        let shard_input_arrays: Vec<[&[u8]; 7]> = shard_owned
103            .iter()
104            .zip(shard_ranges.iter())
105            .map(|(s, &(start, end))| {
106                [
107                    &buffer[start..end],
108                    ac_packs.transitions.as_slice(),
109                    ac_packs.output_offsets.as_slice(),
110                    ac_packs.output_records.as_slice(),
111                    ac_packs.pattern_lengths.as_slice(),
112                    s.haystack_len.as_slice(),
113                    s.atomic_count.as_slice(),
114                ]
115            })
116            .collect();
117
118        // Sub-batched dispatch: dynamically scaled MAX_SHARDS_PER_GPU_BATCH
119        // budget based on system RAM keeps transient host-side memory
120        // bounded while maximizing dispatch concurrency for high-tier GPUs
121        // and leaving vyre's 2048-slot readback ring deeply under-subscribed.
122        let max_shards_per_gpu_batch: usize = {
123            let total_ram_mb = crate::hw_probe::probe_hardware()
124                .total_memory_mb
125                .unwrap_or(0);
126            if total_ram_mb >= 32 * 1024 {
127                256
128            } else if total_ram_mb >= 16 * 1024 {
129                128
130            } else {
131                64
132            }
133        };
134        let mut matches: Vec<vyre_libs::scan::LiteralMatch> = Vec::new();
135        for sub_start in (0..shard_count).step_by(max_shards_per_gpu_batch) {
136            let sub_end = (sub_start + max_shards_per_gpu_batch).min(shard_count);
137            let sub_inputs: Vec<&[&[u8]]> = (sub_start..sub_end)
138                .map(|i| &shard_input_arrays[i][..])
139                .collect();
140            let sub_configs: Vec<vyre::DispatchConfig> = (sub_start..sub_end)
141                .map(|i| shard_owned[i].config.clone())
142                .collect();
143
144            let batch_results = match self.dispatch_gpu_shards(program, &sub_inputs, &sub_configs) {
145                Ok(r) => r,
146                Err(e) => {
147                    tracing::error!(
148                        shards = sub_end - sub_start,
149                        "AC GPU batched dispatch failed, falling back to CPU: {e}"
150                    );
151                    return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
152                }
153            };
154
155            for (offset_in_sub, result) in batch_results.into_iter().enumerate() {
156                let i = sub_start + offset_in_sub;
157                let outputs = match result {
158                    Ok(o) => o,
159                    Err(e) => {
160                        tracing::error!(
161                            shard_index = i,
162                            "AC GPU shard within batch failed, falling back to CPU: {e}"
163                        );
164                        return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
165                    }
166                };
167                if outputs.len() < 2 {
168                    tracing::error!(
169                        shard_index = i,
170                        outputs = outputs.len(),
171                        "AC GPU shard output buffer count too small; falling back to CPU"
172                    );
173                    return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
174                }
175                let count_bytes = &outputs[0];
176                let matches_bytes = &outputs[1];
177                if count_bytes.len() < 4 {
178                    tracing::error!(
179                        shard_index = i,
180                        "AC GPU shard count buffer truncated; falling back to CPU"
181                    );
182                    return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
183                }
184                let count = u32::from_le_bytes([
185                    count_bytes[0],
186                    count_bytes[1],
187                    count_bytes[2],
188                    count_bytes[3],
189                ]);
190                if count > super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH {
191                    tracing::warn!(
192                        cap = super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH,
193                        count,
194                        shard_index = i,
195                        "AC GPU shard exceeded program cap: truncation possible; falling back to CPU"
196                    );
197                    return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
198                }
199                let shard_matches = vyre_libs::scan::dispatch_io::unpack_match_triples(
200                    matches_bytes,
201                    count.min(super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH),
202                );
203                let offset = shard_ranges[i].0 as u32;
204                for m in &shard_matches {
205                    matches.push(vyre_libs::scan::LiteralMatch::new(
206                        m.pattern_id,
207                        m.start.saturating_add(offset),
208                        m.end.saturating_add(offset),
209                    ));
210                }
211            }
212        }
213        let elapsed_ms = started.elapsed().as_millis();
214        tracing::debug!(
215            target: "keyhog::routing",
216            chunks = chunks.len(),
217            buffer_bytes = buffer.len(),
218            matches = matches.len(),
219            shards = shard_count,
220            elapsed_ms,
221            "AC GPU batched scan completed"
222        );
223
224        super::gpu_postprocess::fold_overlapping_same_pid_inplace(&mut matches);
225        let total_patterns = self.ac_map.len() + self.fallback.len();
226        let per_chunk_hits = super::gpu_postprocess::attribute_matches_to_chunks(
227            &matches,
228            &entries,
229            total_patterns,
230            chunks.len(),
231        );
232
233        // Hand the hits back to the orchestrator so it can run phase 2
234        // on a separate thread (pipelined). Combined-wrapper callers
235        // (`scan_coalesced_gpu_ac`) call phase 2 inline immediately
236        // after this returns, preserving the original synchronous
237        // behaviour.
238        GpuPhase1Output::Hits(per_chunk_hits)
239    }
240}