1use super::*;
2
3static GPU_AC_DEGENERATE_DISABLED: std::sync::atomic::AtomicBool =
4 std::sync::atomic::AtomicBool::new(false);
5
6impl CompiledScanner {
7 pub fn scan_coalesced_gpu_ac_phase1(&self, chunks: &[keyhog_core::Chunk]) -> GpuPhase1Output {
8 let Some(matcher) = self.gpu_matcher() else {
9 return self.gpu_degrade_done_with_reason(
10 chunks,
11 crate::hw_probe::ScanBackend::Gpu,
12 Some("GPU literal matcher unavailable for AC dispatch"),
13 );
14 };
15 let Some(program) = self.ac_gpu_program() else {
16 return self.gpu_degrade_done_with_reason(
17 chunks,
18 crate::hw_probe::ScanBackend::Gpu,
19 Some("GPU AC dispatch program unavailable"),
20 );
21 };
22 if self.gpu_backend.is_none() {
23 return self.gpu_degrade_done_with_reason(
24 chunks,
25 crate::hw_probe::ScanBackend::Gpu,
26 Some("GPU backend handle unavailable for AC dispatch"),
27 );
28 }
29 if GPU_AC_DEGENERATE_DISABLED.load(std::sync::atomic::Ordering::Relaxed) {
30 return self.gpu_degrade_done_with_reason(
31 chunks,
32 crate::hw_probe::ScanBackend::Gpu,
33 Some("GPU AC previously emitted degenerate match triples (end <= start); skipping known-corrupt Vyre dispatch"),
34 );
35 }
36
37 let (entries, mut buffer) = super::gpu_coalesce::coalesce_chunks(chunks);
38
39 buffer.make_ascii_lowercase();
51
52 while !buffer.len().is_multiple_of(4) {
60 buffer.push(0);
61 }
62
63 #[cfg(target_os = "linux")]
64 unsafe {
67 libc::madvise(
68 buffer.as_ptr() as *mut libc::c_void,
69 buffer.len(),
70 libc::MADV_DONTDUMP,
71 );
72 }
73
74 let workgroup_x = program.workgroup_size[0] as usize;
75 const GPU_DISPATCH_MAX_WORKGROUPS_AC: usize = 65_535;
79 let gpu_dispatch_max_bytes: usize = GPU_DISPATCH_MAX_WORKGROUPS_AC * workgroup_x;
80 let started = std::time::Instant::now();
81
82 let mut shard_ranges: Vec<(usize, usize)> = Vec::new();
83 let mut shard_start = 0usize;
84 while shard_start < buffer.len() {
85 let shard_end = (shard_start + gpu_dispatch_max_bytes).min(buffer.len());
86 shard_ranges.push((shard_start, shard_end));
87 shard_start = shard_end;
88 }
89 let shard_count = shard_ranges.len();
90
91 let ac_packs = self
105 .gpu_ac_const_packs
106 .get_or_init(|| super::gpu_cache::AcConstPacks {
107 transitions: vyre_libs::scan::dispatch_io::pack_u32_slice(&matcher.dfa.transitions),
108 output_offsets: vyre_libs::scan::dispatch_io::pack_u32_slice(
109 &matcher.dfa.output_offsets,
110 ),
111 output_records: vyre_libs::scan::dispatch_io::pack_u32_slice(
112 &matcher.dfa.output_records,
113 ),
114 pattern_lengths: vyre_libs::scan::dispatch_io::pack_u32_slice(
115 &matcher.pattern_lengths,
116 ),
117 });
118
119 struct ShardOwnedAc {
120 haystack_len: Vec<u8>,
121 atomic_count: Vec<u8>,
122 config: vyre::DispatchConfig,
123 }
124 let mut shard_owned: Vec<ShardOwnedAc> = Vec::with_capacity(shard_count);
125 for &(s_start, s_end) in &shard_ranges {
126 let shard_len = (s_end - s_start) as u32;
127 shard_owned.push(ShardOwnedAc {
128 haystack_len: vyre_libs::scan::dispatch_io::pack_u32_slice(&[shard_len]),
129 atomic_count: vec![0u8; 4],
130 config: vyre_libs::scan::dispatch_io::byte_scan_dispatch_config(
131 shard_len,
132 program.workgroup_size[0],
133 ),
134 });
135 }
136
137 let shard_input_arrays: Vec<[&[u8]; 7]> = shard_owned
138 .iter()
139 .zip(shard_ranges.iter())
140 .map(|(s, &(start, end))| {
141 [
142 &buffer[start..end],
143 ac_packs.transitions.as_slice(),
144 ac_packs.output_offsets.as_slice(),
145 ac_packs.output_records.as_slice(),
146 ac_packs.pattern_lengths.as_slice(),
147 s.haystack_len.as_slice(),
148 s.atomic_count.as_slice(),
149 ]
150 })
151 .collect();
152
153 let max_shards_per_gpu_batch: usize = {
158 let total_ram_mb = crate::hw_probe::probe_hardware()
159 .total_memory_mb
160 .unwrap_or(0);
161 if total_ram_mb >= 32 * 1024 {
162 256
163 } else if total_ram_mb >= 16 * 1024 {
164 128
165 } else {
166 64
167 }
168 };
169 let mut matches: Vec<vyre_libs::scan::LiteralMatch> = Vec::new();
170 for sub_start in (0..shard_count).step_by(max_shards_per_gpu_batch) {
171 let sub_end = (sub_start + max_shards_per_gpu_batch).min(shard_count);
172 let sub_inputs: Vec<&[&[u8]]> = (sub_start..sub_end)
173 .map(|i| &shard_input_arrays[i][..])
174 .collect();
175 let sub_configs: Vec<vyre::DispatchConfig> = (sub_start..sub_end)
176 .map(|i| shard_owned[i].config.clone())
177 .collect();
178
179 let batch_results = match self.dispatch_gpu_shards(program, &sub_inputs, &sub_configs) {
180 Ok(r) => r,
181 Err(e) => {
182 tracing::error!(
183 shards = sub_end - sub_start,
184 "AC GPU batched dispatch failed, falling back to CPU: {e}"
185 );
186 let reason = format!("AC GPU batched dispatch failed: {e}");
187 return self.gpu_degrade_done_with_reason(
188 chunks,
189 crate::hw_probe::ScanBackend::Gpu,
190 Some(&reason),
191 );
192 }
193 };
194
195 for (offset_in_sub, result) in batch_results.into_iter().enumerate() {
196 let i = sub_start + offset_in_sub;
197 let outputs = match result {
198 Ok(o) => o,
199 Err(e) => {
200 tracing::error!(
201 shard_index = i,
202 "AC GPU shard within batch failed, falling back to CPU: {e}"
203 );
204 let reason = format!("AC GPU shard {i} dispatch failed: {e}");
205 return self.gpu_degrade_done_with_reason(
206 chunks,
207 crate::hw_probe::ScanBackend::Gpu,
208 Some(&reason),
209 );
210 }
211 };
212 if outputs.len() < 2 {
213 tracing::error!(
214 shard_index = i,
215 outputs = outputs.len(),
216 "AC GPU shard output buffer count too small; falling back to CPU"
217 );
218 let reason = format!(
219 "AC GPU shard {i} returned {} output buffer(s), expected at least 2",
220 outputs.len()
221 );
222 return self.gpu_degrade_done_with_reason(
223 chunks,
224 crate::hw_probe::ScanBackend::Gpu,
225 Some(&reason),
226 );
227 }
228 let count_bytes = &outputs[0];
229 let matches_bytes = &outputs[1];
230 if count_bytes.len() < 4 {
231 tracing::error!(
232 shard_index = i,
233 "AC GPU shard count buffer truncated; falling back to CPU"
234 );
235 let reason = format!(
236 "AC GPU shard {i} returned truncated count buffer ({} byte(s), expected 4)",
237 count_bytes.len()
238 );
239 return self.gpu_degrade_done_with_reason(
240 chunks,
241 crate::hw_probe::ScanBackend::Gpu,
242 Some(&reason),
243 );
244 }
245 let count = u32::from_le_bytes([
246 count_bytes[0],
247 count_bytes[1],
248 count_bytes[2],
249 count_bytes[3],
250 ]);
251 if count > super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH {
252 tracing::warn!(
253 cap = super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH,
254 count,
255 shard_index = i,
256 "AC GPU shard exceeded dense-prefix cap; rerouting batch through SIMD coalesced scan"
257 );
258 if self.has_simd_prefilter() {
259 if std::env::var_os("KH_PERF").is_some() {
260 eprintln!(
261 "KH_PERF gpu_ac_cap_reroute: chunks={} shard={} shard_matches={} cap={} shard_bytes={}",
262 chunks.len(),
263 i,
264 count,
265 super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH,
266 shard_ranges[i].1 - shard_ranges[i].0
267 );
268 }
269 return GpuPhase1Output::Done(self.scan_coalesced_non_gpu(chunks));
270 }
271 let reason = format!(
272 "AC GPU shard {i} reported {count} matches, exceeding dense-prefix cap {} and no SIMD fallback is available",
273 super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH
274 );
275 return self.gpu_degrade_done_with_reason(
276 chunks,
277 crate::hw_probe::ScanBackend::Gpu,
278 Some(&reason),
279 );
280 }
281 let shard_matches = vyre_libs::scan::dispatch_io::unpack_match_triples(
282 matches_bytes,
283 count.min(super::rule_pipeline::AC_GPU_MAX_MATCHES_PER_DISPATCH),
284 );
285 let offset = shard_ranges[i].0 as u32;
286 for m in &shard_matches {
287 matches.push(vyre_libs::scan::LiteralMatch::new(
288 m.pattern_id,
289 m.start.saturating_add(offset),
290 m.end.saturating_add(offset),
291 ));
292 }
293 }
294 }
295 let elapsed_ms = started.elapsed().as_millis();
296 tracing::debug!(
297 target: "keyhog::routing",
298 chunks = chunks.len(),
299 buffer_bytes = buffer.len(),
300 matches = matches.len(),
301 shards = shard_count,
302 elapsed_ms,
303 "AC GPU batched scan completed"
304 );
305
306 if matches.iter().any(|m| m.end <= m.start) {
323 GPU_AC_DEGENERATE_DISABLED.store(true, std::sync::atomic::Ordering::Relaxed);
324 tracing::warn!(
325 target: "keyhog::routing",
326 raw_matches = matches.len(),
327 chunks = chunks.len(),
328 "GPU AC emitted degenerate match triples (end <= start); vyre CUDA \
329 emit bug PERF-07c. Degrading this batch to the SIMD/CPU literal \
330 path to preserve recall parity."
331 );
332 return self.gpu_degrade_done_with_reason(
333 chunks,
334 crate::hw_probe::ScanBackend::Gpu,
335 Some("GPU AC emitted degenerate match triples (end <= start); vyre CUDA emit bug PERF-07c"),
336 );
337 }
338 if self.has_simd_prefilter()
339 && super::gpu_postprocess::gpu_phase2_hits_are_dense(
340 matches.len(),
341 buffer.len(),
342 chunks.len(),
343 )
344 {
345 tracing::warn!(
346 target: "keyhog::routing",
347 raw_matches = matches.len(),
348 buffer_bytes = buffer.len(),
349 chunks = chunks.len(),
350 "GPU AC prefix output is too dense for phase 2; rerouting this batch through SIMD coalesced scan",
351 );
352 if std::env::var_os("KH_PERF").is_some() {
353 eprintln!(
354 "KH_PERF gpu_ac_dense_phase2_reroute: chunks={} buffer_bytes={} raw_matches={} bytes_per_hit={:.1}",
355 chunks.len(),
356 buffer.len(),
357 matches.len(),
358 buffer.len() as f64 / matches.len().max(1) as f64
359 );
360 }
361 return GpuPhase1Output::Done(self.scan_coalesced_non_gpu(chunks));
362 }
363 super::gpu_postprocess::fold_overlapping_same_pid_inplace(&mut matches);
364 let total_patterns = self.ac_map.len() + self.fallback.len();
365 let per_chunk_hits = super::gpu_postprocess::attribute_matches_to_chunks(
366 &matches,
367 &entries,
368 total_patterns,
369 chunks.len(),
370 );
371
372 GpuPhase1Output::Hits(per_chunk_hits)
378 }
379}