keyhog_scanner/engine/
gpu_ac_phase1.rs1use super::*;
2
3impl CompiledScanner {
4 pub fn scan_coalesced_gpu_ac_phase1(&self, chunks: &[keyhog_core::Chunk]) -> GpuPhase1Output {
5 let Some(matcher) = self.gpu_matcher() else {
6 return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
7 };
8 let Some(program) = self.ac_gpu_program() else {
9 return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
10 };
11 if self.gpu_backend.is_none() {
12 return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
13 }
14
15 let (entries, mut buffer) = super::gpu_coalesce::coalesce_chunks(chunks);
16
17 while !buffer.len().is_multiple_of(4) {
25 buffer.push(0);
26 }
27
28 #[cfg(target_os = "linux")]
29 unsafe {
32 libc::madvise(
33 buffer.as_ptr() as *mut libc::c_void,
34 buffer.len(),
35 libc::MADV_DONTDUMP,
36 );
37 }
38
39 let workgroup_x = program.workgroup_size[0] as usize;
40 const GPU_DISPATCH_MAX_WORKGROUPS_AC: usize = 65_535;
44 let gpu_dispatch_max_bytes: usize = GPU_DISPATCH_MAX_WORKGROUPS_AC * workgroup_x;
45 let started = std::time::Instant::now();
46
47 let mut shard_ranges: Vec<(usize, usize)> = Vec::new();
48 let mut shard_start = 0usize;
49 while shard_start < buffer.len() {
50 let shard_end = (shard_start + gpu_dispatch_max_bytes).min(buffer.len());
51 shard_ranges.push((shard_start, shard_end));
52 shard_start = shard_end;
53 }
54 let shard_count = shard_ranges.len();
55
56 let ac_packs = self
70 .gpu_ac_const_packs
71 .get_or_init(|| super::gpu_cache::AcConstPacks {
72 transitions: vyre_libs::scan::dispatch_io::pack_u32_slice(&matcher.dfa.transitions),
73 output_offsets: vyre_libs::scan::dispatch_io::pack_u32_slice(
74 &matcher.dfa.output_offsets,
75 ),
76 output_records: vyre_libs::scan::dispatch_io::pack_u32_slice(
77 &matcher.dfa.output_records,
78 ),
79 pattern_lengths: vyre_libs::scan::dispatch_io::pack_u32_slice(
80 &matcher.pattern_lengths,
81 ),
82 });
83
84 struct ShardOwnedAc {
85 haystack_len: Vec<u8>,
86 atomic_count: Vec<u8>,
87 config: vyre::DispatchConfig,
88 }
89 let mut shard_owned: Vec<ShardOwnedAc> = Vec::with_capacity(shard_count);
90 for &(s_start, s_end) in &shard_ranges {
91 let shard_len = (s_end - s_start) as u32;
92 shard_owned.push(ShardOwnedAc {
93 haystack_len: vyre_libs::scan::dispatch_io::pack_u32_slice(&[shard_len]),
94 atomic_count: vec![0u8; 4],
95 config: vyre_libs::scan::dispatch_io::byte_scan_dispatch_config(
96 shard_len,
97 program.workgroup_size[0],
98 ),
99 });
100 }
101
102 let shard_input_arrays: Vec<[&[u8]; 7]> = shard_owned
103 .iter()
104 .zip(shard_ranges.iter())
105 .map(|(s, &(start, end))| {
106 [
107 &buffer[start..end],
108 ac_packs.transitions.as_slice(),
109 ac_packs.output_offsets.as_slice(),
110 ac_packs.output_records.as_slice(),
111 ac_packs.pattern_lengths.as_slice(),
112 s.haystack_len.as_slice(),
113 s.atomic_count.as_slice(),
114 ]
115 })
116 .collect();
117
118 let max_shards_per_gpu_batch: usize = {
123 let total_ram_mb = crate::hw_probe::probe_hardware()
124 .total_memory_mb
125 .unwrap_or(0);
126 if total_ram_mb >= 32 * 1024 {
127 256
128 } else if total_ram_mb >= 16 * 1024 {
129 128
130 } else {
131 64
132 }
133 };
134 let mut matches: Vec<vyre_libs::scan::LiteralMatch> = Vec::new();
135 for sub_start in (0..shard_count).step_by(max_shards_per_gpu_batch) {
136 let sub_end = (sub_start + max_shards_per_gpu_batch).min(shard_count);
137 let sub_inputs: Vec<&[&[u8]]> = (sub_start..sub_end)
138 .map(|i| &shard_input_arrays[i][..])
139 .collect();
140 let sub_configs: Vec<vyre::DispatchConfig> = (sub_start..sub_end)
141 .map(|i| shard_owned[i].config.clone())
142 .collect();
143
144 let batch_results = match self.dispatch_gpu_shards(program, &sub_inputs, &sub_configs) {
145 Ok(r) => r,
146 Err(e) => {
147 tracing::error!(
148 shards = sub_end - sub_start,
149 "AC GPU batched dispatch failed, falling back to CPU: {e}"
150 );
151 return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
152 }
153 };
154
155 for (offset_in_sub, result) in batch_results.into_iter().enumerate() {
156 let i = sub_start + offset_in_sub;
157 let outputs = match result {
158 Ok(o) => o,
159 Err(e) => {
160 tracing::error!(
161 shard_index = i,
162 "AC GPU shard within batch failed, falling back to CPU: {e}"
163 );
164 return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
165 }
166 };
167 if outputs.len() < 2 {
168 tracing::error!(
169 shard_index = i,
170 outputs = outputs.len(),
171 "AC GPU shard output buffer count too small; falling back to CPU"
172 );
173 return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
174 }
175 let count_bytes = &outputs[0];
176 let matches_bytes = &outputs[1];
177 if count_bytes.len() < 4 {
178 tracing::error!(
179 shard_index = i,
180 "AC GPU shard count buffer truncated; falling back to CPU"
181 );
182 return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
183 }
184 let count = u32::from_le_bytes([
185 count_bytes[0],
186 count_bytes[1],
187 count_bytes[2],
188 count_bytes[3],
189 ]);
190 if count > super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH {
191 tracing::warn!(
192 cap = super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH,
193 count,
194 shard_index = i,
195 "AC GPU shard exceeded program cap: truncation possible; falling back to CPU"
196 );
197 return self.gpu_degrade_done(chunks, crate::hw_probe::ScanBackend::Gpu);
198 }
199 let shard_matches = vyre_libs::scan::dispatch_io::unpack_match_triples(
200 matches_bytes,
201 count.min(super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH),
202 );
203 let offset = shard_ranges[i].0 as u32;
204 for m in &shard_matches {
205 matches.push(vyre_libs::scan::LiteralMatch::new(
206 m.pattern_id,
207 m.start.saturating_add(offset),
208 m.end.saturating_add(offset),
209 ));
210 }
211 }
212 }
213 let elapsed_ms = started.elapsed().as_millis();
214 tracing::debug!(
215 target: "keyhog::routing",
216 chunks = chunks.len(),
217 buffer_bytes = buffer.len(),
218 matches = matches.len(),
219 shards = shard_count,
220 elapsed_ms,
221 "AC GPU batched scan completed"
222 );
223
224 super::gpu_postprocess::fold_overlapping_same_pid_inplace(&mut matches);
225 let total_patterns = self.ac_map.len() + self.fallback.len();
226 let per_chunk_hits = super::gpu_postprocess::attribute_matches_to_chunks(
227 &matches,
228 &entries,
229 total_patterns,
230 chunks.len(),
231 );
232
233 GpuPhase1Output::Hits(per_chunk_hits)
239 }
240}