1use super::*;
2use crate::hw_probe::ScanBackend;
3use keyhog_core::Chunk;
4
5use std::sync::Arc;
6
7pub(crate) struct PreparedChunk {
8 pub(crate) chunk: Arc<Chunk>,
9 pub(crate) preprocessed: ScannerPreprocessedText,
10}
11
#[cfg(feature = "simd")]
/// Builds the Hyperscan prefilter for the Aho-Corasick pattern table.
///
/// Several `ac_map` entries may share the same regex source, so each distinct regex is
/// compiled into Hyperscan exactly once. Every compiled pattern is described by the tuple
/// `(detector_index, hs_id, regex, has_group)`, where `hs_id` is the position of the regex
/// among the deduplicated patterns and `detector_index`/`has_group` come from the first
/// `ac_map` entry that used the regex.
///
/// The returned index map is indexed by `hs_id` and lists, in ascending order, every
/// `ac_map` index whose regex compiled to that Hyperscan pattern.
///
/// `_fallback` is accepted for signature compatibility and is not used.
///
/// Returns `None` (after logging a warning) when Hyperscan compilation fails. Regexes that
/// Hyperscan reports as unsupported are only counted in the log; the scanner is still
/// returned in that case.
pub(crate) fn build_simd_scanner(
    ac_map: &[CompiledPattern],
    _fallback: &[(CompiledPattern, Vec<String>)],
) -> Option<(crate::simd::backend::HsScanner, Vec<Vec<usize>>)> {
    use std::collections::HashMap;

    // Borrow the regex sources straight out of `ac_map` instead of copying each one into
    // owned `String`s: `ac_map` outlives the whole function, including the compile call.
    let mut hs_id_by_regex: HashMap<&str, usize> = HashMap::new();
    let mut pattern_refs: Vec<(usize, usize, &str, bool)> = Vec::new();
    let mut index_map: Vec<Vec<usize>> = Vec::new();

    for (ac_index, entry) in ac_map.iter().enumerate() {
        let regex_str = entry.regex.as_str();
        let hs_id = *hs_id_by_regex.entry(regex_str).or_insert_with(|| {
            let id = pattern_refs.len();
            pattern_refs.push((entry.detector_index, id, regex_str, entry.group.is_some()));
            index_map.push(Vec::new());
            id
        });
        index_map[hs_id].push(ac_index);
    }

    tracing::info!(
        unique = pattern_refs.len(),
        raw = ac_map.len(),
        "compiling deduplicated AC regexes into Hyperscan"
    );

    match crate::simd::backend::HsScanner::compile(&pattern_refs) {
        Ok((scanner, unsupported)) => {
            tracing::info!(
                compiled = scanner.pattern_count(),
                unsupported = unsupported.len(),
                "HS ready"
            );
            Some((scanner, index_map))
        }
        Err(error) => {
            tracing::warn!("HS compilation failed: {error}");
            None
        }
    }
}
73
74impl CompiledScanner {
75 pub(crate) fn scan_chunks_with_backend_internal(
76 &self,
77 chunks: &[Chunk],
78 backend: ScanBackend,
79 ) -> Vec<Vec<RawMatch>> {
80 if backend != ScanBackend::Gpu || chunks.is_empty() || self.gpu_pattern_set.is_none() {
81 return chunks
82 .iter()
83 .map(|chunk| self.scan_with_backend(chunk, backend))
84 .collect();
85 }
86
87 let prepared: Vec<_> = chunks
88 .iter()
89 .map(|chunk| self.prepare_chunk(chunk))
90 .collect();
91
92 let total_patterns = self.ac_map.len() + self.fallback.len();
93 let mut triggered = vec![vec![0u64; total_patterns.div_ceil(64)]; prepared.len()];
94 if !self.populate_gpu_batch_triggers(&prepared, &mut triggered) {
95 let fallback_backend = self.degraded_backend_after_gpu_failure();
96 tracing::debug!(
97 fallback = fallback_backend.label(),
98 "gpu batch scan unavailable, degrading to non-gpu backend"
99 );
100 return chunks
101 .iter()
102 .map(|chunk| self.scan_with_backend(chunk, fallback_backend))
103 .collect();
104 }
105
106 prepared
107 .into_iter()
108 .zip(triggered)
109 .map(|(prepared, chunk_triggered)| {
110 self.scan_prepared_with_triggered(prepared, backend, chunk_triggered, None)
111 })
112 .collect()
113 }
114
115 pub(crate) fn prepare_chunk(&self, chunk: &Chunk) -> PreparedChunk {
116 let mut owned_normalized = None;
117 let owned_unicode;
118 let chunk = if chunk.data.is_ascii() {
119 chunk
120 } else {
121 normalize_scannable_chunk(chunk, &mut owned_normalized)
122 };
123
124 let chunk = if self.config.unicode_normalization {
125 let unicode_normalized = unicode_hardening::normalize_homoglyphs(&chunk.data);
126 if unicode_normalized != chunk.data {
127 owned_unicode = Some(keyhog_core::Chunk {
128 data: unicode_normalized,
129 metadata: chunk.metadata.clone(),
130 });
131 owned_unicode.as_ref().unwrap_or(chunk)
132 } else {
133 chunk
134 }
135 } else {
136 chunk
137 };
138
139 let preprocessed = if let Some(pp) =
140 crate::structured::preprocess(&chunk.data, chunk.metadata.path.as_deref())
141 {
142 pp
143 } else {
144 #[cfg(feature = "multiline")]
145 if crate::multiline::has_concatenation_indicators(&chunk.data) {
146 crate::multiline::preprocess_multiline(&chunk.data, &self.config.multiline)
147 } else {
148 ScannerPreprocessedText::passthrough(&chunk.data)
149 }
150 #[cfg(not(feature = "multiline"))]
151 ScannerPreprocessedText::passthrough(&chunk.data)
152 };
153
154 PreparedChunk {
155 chunk: Arc::new(chunk.clone()),
156 preprocessed,
157 }
158 }
159
160 pub(crate) fn scan_prepared_with_triggered(
161 &self,
162 prepared: PreparedChunk,
163 backend: ScanBackend,
164 triggered_patterns: Vec<u64>,
165 deadline: Option<std::time::Instant>,
166 ) -> Vec<RawMatch> {
167 let line_offsets = compute_line_offsets(&prepared.preprocessed.text);
168 let code_lines: Vec<&str> = prepared.chunk.data.lines().collect();
169 let documentation_lines = context::documentation_line_flags(&code_lines);
170 let mut scan_state = ScanState::default();
171
172 #[cfg(feature = "simdsieve")]
173 self.scan_hot_patterns_fast(
174 &prepared.preprocessed.text,
175 &line_offsets,
176 &prepared.chunk,
177 &mut scan_state,
178 );
179
180 let expanded_patterns = if backend == ScanBackend::Gpu {
181 triggered_patterns } else {
183 self.expand_triggered_patterns(&triggered_patterns)
184 };
185
186 let total_patterns = self.ac_map.len() + self.fallback.len();
187 let confirmed_patterns: Vec<usize> = if backend == ScanBackend::Gpu {
188 (0..total_patterns)
189 .filter(|&i| (expanded_patterns[i / 64] & (1 << (i % 64))) != 0)
190 .collect()
191 } else {
192 (0..self.ac_map.len())
193 .filter(|&i| (expanded_patterns[i / 64] & (1 << (i % 64))) != 0)
194 .collect()
195 };
196
197 self.extract_confirmed_patterns(
198 &confirmed_patterns,
199 &prepared.preprocessed,
200 &line_offsets,
201 &code_lines,
202 &documentation_lines,
203 &prepared.chunk,
204 &mut scan_state,
205 deadline,
206 );
207
208 self.scan_generic_assignments(&code_lines, &prepared.chunk, &mut scan_state);
212
213 #[cfg(feature = "entropy")]
214 self.scan_entropy_fallback(
215 &prepared.preprocessed,
216 &line_offsets,
217 &prepared.chunk,
218 &mut scan_state,
219 );
220
221 #[cfg(feature = "ml")]
222 self.apply_ml_batch_scores(&mut scan_state);
223
224 tracing::debug!(
225 backend = backend.label(),
226 path = prepared
227 .chunk
228 .metadata
229 .path
230 .as_deref()
231 .unwrap_or("<memory>"),
232 matches = scan_state.matches.len(),
233 "completed scan with selected backend"
234 );
235
236 scan_state.into_matches()
237 }
238
239 pub(crate) fn collect_triggered_patterns_for_backend(
240 &self,
241 text: &str,
242 backend: ScanBackend,
243 ) -> Vec<u64> {
244 match backend {
245 ScanBackend::Gpu => self.collect_triggered_patterns_gpu(text),
246 ScanBackend::SimdCpu => self.collect_triggered_patterns_simd(text),
247 ScanBackend::CpuFallback => self.collect_triggered_patterns_cpu(text),
248 }
249 }
250
251 fn collect_triggered_patterns_gpu(&self, text: &str) -> Vec<u64> {
252 if let Some(matcher) = self.gpu_matcher() {
253 match matcher.scan_blocking(text.as_bytes()) {
254 Ok(matches) => return self.triggered_patterns_from_gpu_matches(&matches),
255 Err(error) => {
256 tracing::debug!("gpu scan failed, degrading to CPU path: {error}");
257 }
258 }
259 }
260 self.collect_triggered_patterns_simd(text)
261 }
262
263 fn collect_triggered_patterns_simd(&self, text: &str) -> Vec<u64> {
264 #[cfg(feature = "simd")]
265 if let Some(scanner) = &self.simd_prefilter {
266 let mut triggered_patterns = vec![0u64; self.ac_map.len().div_ceil(64)];
267 for (hs_id, _start, _end) in scanner.scan(text.as_bytes()) {
268 let Some((_detector_index, ac_index, _has_group)) = scanner.pattern_info(hs_id)
269 else {
270 continue;
271 };
272 self.mark_triggered_pattern(&mut triggered_patterns, ac_index);
273 }
274 return triggered_patterns;
275 }
276
277 self.collect_triggered_patterns_cpu(text)
278 }
279
280 fn collect_triggered_patterns_cpu(&self, text: &str) -> Vec<u64> {
281 let mut triggered_patterns = vec![0u64; self.ac_map.len().div_ceil(64)];
282 if let Some(ac) = &self.ac {
283 for ac_match in ac.scan(text.as_bytes()).unwrap_or_default() {
284 self.mark_triggered_pattern(&mut triggered_patterns, ac_match.pattern_id as usize);
285 }
286 }
287 triggered_patterns
288 }
289
290 fn triggered_patterns_from_gpu_matches(&self, matches: &[warpstate::Match]) -> Vec<u64> {
291 let total_patterns = self.ac_map.len() + self.fallback.len();
292 let mut triggered_patterns = vec![0u64; total_patterns.div_ceil(64)];
293 for matched in matches {
294 let pattern_index = matched.pattern_id as usize;
295 if pattern_index >= total_patterns {
296 continue;
297 }
298 triggered_patterns[pattern_index / 64] |= 1u64 << (pattern_index % 64);
299 }
300 triggered_patterns
301 }
302
303 fn mark_triggered_pattern(&self, triggered_patterns: &mut [u64], pattern_index: usize) {
304 if pattern_index / 64 >= triggered_patterns.len() {
305 return;
306 }
307 triggered_patterns[pattern_index / 64] |= 1u64 << (pattern_index % 64);
308 if pattern_index < self.prefix_propagation.len() {
309 for &propagated_index in &self.prefix_propagation[pattern_index] {
310 if propagated_index / 64 < triggered_patterns.len() {
311 triggered_patterns[propagated_index / 64] |= 1u64 << (propagated_index % 64);
312 }
313 }
314 }
315 }
316
317 pub fn gpu_matcher(&self) -> Option<&warpstate::AutoMatcher> {
318 self.gpu_matcher
319 .get_or_init(|| {
320 let patterns = self.gpu_pattern_set.as_ref()?.clone();
321 let config = warpstate::AutoMatcherConfig::new()
322 .gpu_threshold(0)
323 .gpu_max_input_size(usize::MAX / 2)
324 .auto_tune_threshold(false)
325 .max_matches(self.config.max_matches_per_chunk.min(u32::MAX as usize) as u32);
326 let handle = std::thread::spawn(move || {
329 pollster::block_on(warpstate::AutoMatcher::with_config(&patterns, config))
330 });
331 let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
332 loop {
333 if handle.is_finished() {
334 break;
335 }
336 if std::time::Instant::now() > deadline {
337 tracing::warn!("GPU matcher init timed out (5s)");
338 return None;
339 }
340 std::thread::sleep(std::time::Duration::from_millis(50));
341 }
342 match handle.join().ok()? {
343 Ok(matcher) => {
344 if let Err(e) = matcher.scan_blocking(b"x") {
346 tracing::debug!("GPU warm-up scan failed: {e}");
347 } else {
348 tracing::debug!("GPU warm-up scan completed");
349 }
350 Some(matcher)
351 }
352 Err(error) => {
353 tracing::warn!("failed to initialize warpstate GPU matcher: {error}");
354 None
355 }
356 }
357 })
358 .as_ref()
359 }
360
361 fn degraded_backend_after_gpu_failure(&self) -> ScanBackend {
362 let caps = crate::hw_probe::probe_hardware();
363 if caps.has_avx512 || caps.has_avx2 || caps.has_neon {
364 ScanBackend::SimdCpu
365 } else {
366 ScanBackend::CpuFallback
367 }
368 }
369
370 fn populate_gpu_batch_triggers(
371 &self,
372 prepared: &[PreparedChunk],
373 triggered: &mut [Vec<u64>],
374 ) -> bool {
375 let Some(matcher) = self.gpu_matcher() else {
376 return false;
377 };
378
379 const MAX_BATCH_BYTES: usize = 64 * 1024 * 1024;
380 const MAX_BATCH_ITEMS: usize = 2048;
381
382 let mut start = 0usize;
383 while start < prepared.len() {
384 let mut end = start;
385 let mut batch_bytes = 0usize;
386 while end < prepared.len() && end - start < MAX_BATCH_ITEMS {
387 let len = prepared[end].preprocessed.text.len();
388 if end > start && batch_bytes.saturating_add(len) > MAX_BATCH_BYTES {
389 break;
390 }
391 batch_bytes = batch_bytes.saturating_add(len);
392 end += 1;
393 }
394
395 let (entries, buffer) = coalesce_preprocessed_batch(&prepared[start..end]);
396 let matches = match matcher.scan_blocking(&buffer) {
397 Ok(matches) => matches,
398 Err(error) => {
399 tracing::warn!("batched GPU scan failed: {error}");
400 return false;
401 }
402 };
403
404 map_batch_matches(self, &entries, matches, &mut triggered[start..end]);
405 start = end;
406 }
407
408 true
409 }
410}
411
412fn coalesce_preprocessed_batch(
413 prepared: &[PreparedChunk],
414) -> (Vec<(usize, usize, usize)>, Vec<u8>) {
415 let total_bytes = prepared
416 .iter()
417 .map(|chunk| chunk.preprocessed.text.len())
418 .sum();
419 let mut entries = Vec::with_capacity(prepared.len());
420 let mut buffer = Vec::with_capacity(total_bytes);
421
422 for (index, chunk) in prepared.iter().enumerate() {
423 let start = buffer.len();
424 buffer.extend_from_slice(chunk.preprocessed.text.as_bytes());
425 entries.push((index, start, chunk.preprocessed.text.len()));
426 }
427
428 (entries, buffer)
429}
430
431fn map_batch_matches(
432 scanner: &CompiledScanner,
433 entries: &[(usize, usize, usize)],
434 matches: Vec<warpstate::Match>,
435 triggered: &mut [Vec<u64>],
436) {
437 let mut cursor = 0usize;
438 for matched in matches {
439 let global_start = matched.start as usize;
440 let global_end = matched.end as usize;
441
442 while cursor < entries.len() {
443 let (_, offset, len) = entries[cursor];
444 if global_start < offset + len {
445 break;
446 }
447 cursor += 1;
448 }
449 if cursor >= entries.len() {
450 break;
451 }
452
453 let (chunk_index, offset, len) = entries[cursor];
454 if global_start < offset || global_end > offset + len {
455 continue;
456 }
457 scanner.mark_triggered_pattern(&mut triggered[chunk_index], matched.pattern_id as usize);
458 }
459}
460
#[cfg(test)]
mod tests {
    use super::{PreparedChunk, coalesce_preprocessed_batch, map_batch_matches};
    use crate::engine::CompiledScanner;
    use crate::types::ScannerPreprocessedText;
    use keyhog_core::{Chunk, ChunkMetadata, DetectorSpec, PatternSpec, Severity};
    use std::sync::Arc;

    /// Wraps `text` as passthrough-preprocessed content over an empty backing chunk.
    fn prepared_with_text(text: &str) -> PreparedChunk {
        let backing = Chunk {
            data: String::new(),
            metadata: ChunkMetadata::default(),
        };
        PreparedChunk {
            chunk: Arc::new(backing),
            preprocessed: ScannerPreprocessedText::passthrough(text),
        }
    }

    /// Compiles a scanner with a single detector whose only pattern is the literal `abc`.
    fn demo_scanner() -> CompiledScanner {
        let detector = DetectorSpec {
            id: "demo-token".into(),
            name: "Demo Token".into(),
            service: "demo".into(),
            severity: Severity::High,
            patterns: vec![PatternSpec {
                regex: "abc".into(),
                description: None,
                group: None,
            }],
            companions: vec![],
            verify: None,
            keywords: vec!["abc".into()],
            ..Default::default()
        };
        CompiledScanner::compile(vec![detector]).unwrap()
    }

    #[test]
    fn coalescing_preserves_offsets() {
        let prepared: Vec<PreparedChunk> = ["abc", "defg"]
            .iter()
            .copied()
            .map(prepared_with_text)
            .collect();

        let (entries, buffer) = coalesce_preprocessed_batch(&prepared);

        assert_eq!(entries, vec![(0, 0, 3), (1, 3, 4)]);
        assert_eq!(buffer, b"abcdefg");
    }

    #[test]
    fn cross_boundary_matches_are_dropped() {
        let scanner = demo_scanner();
        let entries = vec![(0usize, 0usize, 3usize), (1usize, 3usize, 3usize)];
        // The first match sits inside chunk 0; the second straddles the 0/1 boundary.
        let matches = vec![
            warpstate::Match::from_parts(0, 1, 2),
            warpstate::Match::from_parts(0, 2, 4),
        ];
        let mut triggered = vec![vec![0u64; 1]; 2];

        map_batch_matches(&scanner, &entries, matches, &mut triggered);

        assert_eq!(triggered[0][0], 1);
        assert_eq!(triggered[1][0], 0);
    }
}
524}