Skip to main content

keyhog_scanner/engine/
backend.rs

1use super::*;
2use crate::hw_probe::ScanBackend;
3use keyhog_core::Chunk;
4
5use std::sync::Arc;
6
7pub(crate) struct PreparedChunk {
8    pub(crate) chunk: Arc<Chunk>,
9    pub(crate) preprocessed: ScannerPreprocessedText,
10}
11
/// Build a Hyperscan database from ALL detector patterns.
///
/// Rather than compiling Aho-Corasick literal prefixes (1500+ escaped
/// strings), this compiles the ACTUAL full regexes into a single Hyperscan
/// database so the input only has to be scanned once. Entries of `ac_map`
/// that share an identical regex string are deduplicated: each unique regex
/// becomes one Hyperscan pattern.
///
/// Each pattern handed to Hyperscan is the tuple
/// `(detector_index, hs_id, regex, has_group)`. Here `hs_id` is the
/// *deduplicated* Hyperscan id (dense, starting at 0), taken from the first
/// `ac_map` entry that used the regex.
///
/// Returns the compiled scanner plus an index map where `index_map[hs_id]`
/// lists every `ac_map` index whose regex is that Hyperscan pattern. Returns
/// `None` (after logging a warning) if Hyperscan compilation fails. Patterns
/// that Hyperscan reports as unsupported are only counted in the log.
///
/// `_fallback` is currently unused.
#[cfg(feature = "simd")]
pub(crate) fn build_simd_scanner(
    ac_map: &[CompiledPattern],
    _fallback: &[(CompiledPattern, Vec<String>)],
) -> Option<(crate::simd::backend::HsScanner, Vec<Vec<usize>>)> {
    use std::collections::HashMap;

    // Keyed by regex text borrowed from `ac_map`, so repeated regexes are
    // looked up without allocating a fresh `String` per entry.
    let mut regex_to_hs_id: HashMap<&str, usize> = HashMap::with_capacity(ac_map.len());
    // One `(detector_index, hs_id, regex, has_group)` tuple per unique regex,
    // stored in `hs_id` order so it can be passed straight to `compile`.
    let mut hs_patterns: Vec<(usize, usize, &str, bool)> = Vec::new();
    let mut index_map: Vec<Vec<usize>> = Vec::new();

    for (idx, entry) in ac_map.iter().enumerate() {
        let regex_str = entry.regex.as_str();
        let hs_id = *regex_to_hs_id.entry(regex_str).or_insert_with(|| {
            let id = hs_patterns.len();
            hs_patterns.push((entry.detector_index, id, regex_str, entry.group.is_some()));
            index_map.push(Vec::new());
            id
        });
        index_map[hs_id].push(idx);
    }

    tracing::info!(
        unique = hs_patterns.len(),
        raw = ac_map.len(),
        "compiling deduplicated AC regexes into Hyperscan"
    );

    match crate::simd::backend::HsScanner::compile(&hs_patterns) {
        Ok((scanner, unsupported)) => {
            tracing::info!(
                compiled = scanner.pattern_count(),
                unsupported = unsupported.len(),
                "HS ready"
            );
            Some((scanner, index_map))
        }
        Err(error) => {
            tracing::warn!("HS compilation failed: {error}");
            None
        }
    }
}
73
74impl CompiledScanner {
75    pub(crate) fn scan_chunks_with_backend_internal(
76        &self,
77        chunks: &[Chunk],
78        backend: ScanBackend,
79    ) -> Vec<Vec<RawMatch>> {
80        if backend != ScanBackend::Gpu || chunks.is_empty() || self.gpu_pattern_set.is_none() {
81            return chunks
82                .iter()
83                .map(|chunk| self.scan_with_backend(chunk, backend))
84                .collect();
85        }
86
87        let prepared: Vec<_> = chunks
88            .iter()
89            .map(|chunk| self.prepare_chunk(chunk))
90            .collect();
91
92        let total_patterns = self.ac_map.len() + self.fallback.len();
93        let mut triggered = vec![vec![0u64; total_patterns.div_ceil(64)]; prepared.len()];
94        if !self.populate_gpu_batch_triggers(&prepared, &mut triggered) {
95            let fallback_backend = self.degraded_backend_after_gpu_failure();
96            tracing::debug!(
97                fallback = fallback_backend.label(),
98                "gpu batch scan unavailable, degrading to non-gpu backend"
99            );
100            return chunks
101                .iter()
102                .map(|chunk| self.scan_with_backend(chunk, fallback_backend))
103                .collect();
104        }
105
106        prepared
107            .into_iter()
108            .zip(triggered)
109            .map(|(prepared, chunk_triggered)| {
110                self.scan_prepared_with_triggered(prepared, backend, chunk_triggered, None)
111            })
112            .collect()
113    }
114
115    pub(crate) fn prepare_chunk(&self, chunk: &Chunk) -> PreparedChunk {
116        let mut owned_normalized = None;
117        let owned_unicode;
118        let chunk = if chunk.data.is_ascii() {
119            chunk
120        } else {
121            normalize_scannable_chunk(chunk, &mut owned_normalized)
122        };
123
124        let chunk = if self.config.unicode_normalization {
125            let unicode_normalized = unicode_hardening::normalize_homoglyphs(&chunk.data);
126            if unicode_normalized != chunk.data {
127                owned_unicode = Some(keyhog_core::Chunk {
128                    data: unicode_normalized,
129                    metadata: chunk.metadata.clone(),
130                });
131                owned_unicode.as_ref().unwrap_or(chunk)
132            } else {
133                chunk
134            }
135        } else {
136            chunk
137        };
138
139        let preprocessed = if let Some(pp) =
140            crate::structured::preprocess(&chunk.data, chunk.metadata.path.as_deref())
141        {
142            pp
143        } else {
144            #[cfg(feature = "multiline")]
145            if crate::multiline::has_concatenation_indicators(&chunk.data) {
146                crate::multiline::preprocess_multiline(&chunk.data, &self.config.multiline)
147            } else {
148                ScannerPreprocessedText::passthrough(&chunk.data)
149            }
150            #[cfg(not(feature = "multiline"))]
151            ScannerPreprocessedText::passthrough(&chunk.data)
152        };
153
154        PreparedChunk {
155            chunk: Arc::new(chunk.clone()),
156            preprocessed,
157        }
158    }
159
160    pub(crate) fn scan_prepared_with_triggered(
161        &self,
162        prepared: PreparedChunk,
163        backend: ScanBackend,
164        triggered_patterns: Vec<u64>,
165        deadline: Option<std::time::Instant>,
166    ) -> Vec<RawMatch> {
167        let line_offsets = compute_line_offsets(&prepared.preprocessed.text);
168        let code_lines: Vec<&str> = prepared.chunk.data.lines().collect();
169        let documentation_lines = context::documentation_line_flags(&code_lines);
170        let mut scan_state = ScanState::default();
171
172        #[cfg(feature = "simdsieve")]
173        self.scan_hot_patterns_fast(
174            &prepared.preprocessed.text,
175            &line_offsets,
176            &prepared.chunk,
177            &mut scan_state,
178        );
179
180        let expanded_patterns = if backend == ScanBackend::Gpu {
181            triggered_patterns // GPU runs full regexes; no AC prefix expansion needed.
182        } else {
183            self.expand_triggered_patterns(&triggered_patterns)
184        };
185
186        let total_patterns = self.ac_map.len() + self.fallback.len();
187        let confirmed_patterns: Vec<usize> = if backend == ScanBackend::Gpu {
188            (0..total_patterns)
189                .filter(|&i| (expanded_patterns[i / 64] & (1 << (i % 64))) != 0)
190                .collect()
191        } else {
192            (0..self.ac_map.len())
193                .filter(|&i| (expanded_patterns[i / 64] & (1 << (i % 64))) != 0)
194                .collect()
195        };
196
197        self.extract_confirmed_patterns(
198            &confirmed_patterns,
199            &prepared.preprocessed,
200            &line_offsets,
201            &code_lines,
202            &documentation_lines,
203            &prepared.chunk,
204            &mut scan_state,
205            deadline,
206        );
207
208        // Generic key=value scanner: catches secrets assigned to variables
209        // with secret-related names. Only fires when no named detector already
210        // found a match on the same line AND the value has high entropy.
211        self.scan_generic_assignments(&code_lines, &prepared.chunk, &mut scan_state);
212
213        #[cfg(feature = "entropy")]
214        self.scan_entropy_fallback(
215            &prepared.preprocessed,
216            &line_offsets,
217            &prepared.chunk,
218            &mut scan_state,
219        );
220
221        #[cfg(feature = "ml")]
222        self.apply_ml_batch_scores(&mut scan_state);
223
224        tracing::debug!(
225            backend = backend.label(),
226            path = prepared
227                .chunk
228                .metadata
229                .path
230                .as_deref()
231                .unwrap_or("<memory>"),
232            matches = scan_state.matches.len(),
233            "completed scan with selected backend"
234        );
235
236        scan_state.into_matches()
237    }
238
239    pub(crate) fn collect_triggered_patterns_for_backend(
240        &self,
241        text: &str,
242        backend: ScanBackend,
243    ) -> Vec<u64> {
244        match backend {
245            ScanBackend::Gpu => self.collect_triggered_patterns_gpu(text),
246            ScanBackend::SimdCpu => self.collect_triggered_patterns_simd(text),
247            ScanBackend::CpuFallback => self.collect_triggered_patterns_cpu(text),
248        }
249    }
250
251    fn collect_triggered_patterns_gpu(&self, text: &str) -> Vec<u64> {
252        if let Some(matcher) = self.gpu_matcher() {
253            match matcher.scan_blocking(text.as_bytes()) {
254                Ok(matches) => return self.triggered_patterns_from_gpu_matches(&matches),
255                Err(error) => {
256                    tracing::debug!("gpu scan failed, degrading to CPU path: {error}");
257                }
258            }
259        }
260        self.collect_triggered_patterns_simd(text)
261    }
262
263    fn collect_triggered_patterns_simd(&self, text: &str) -> Vec<u64> {
264        #[cfg(feature = "simd")]
265        if let Some(scanner) = &self.simd_prefilter {
266            let mut triggered_patterns = vec![0u64; self.ac_map.len().div_ceil(64)];
267            for (hs_id, _start, _end) in scanner.scan(text.as_bytes()) {
268                let Some((_detector_index, ac_index, _has_group)) = scanner.pattern_info(hs_id)
269                else {
270                    continue;
271                };
272                self.mark_triggered_pattern(&mut triggered_patterns, ac_index);
273            }
274            return triggered_patterns;
275        }
276
277        self.collect_triggered_patterns_cpu(text)
278    }
279
280    fn collect_triggered_patterns_cpu(&self, text: &str) -> Vec<u64> {
281        let mut triggered_patterns = vec![0u64; self.ac_map.len().div_ceil(64)];
282        if let Some(ac) = &self.ac {
283            for ac_match in ac.scan(text.as_bytes()).unwrap_or_default() {
284                self.mark_triggered_pattern(&mut triggered_patterns, ac_match.pattern_id as usize);
285            }
286        }
287        triggered_patterns
288    }
289
290    fn triggered_patterns_from_gpu_matches(&self, matches: &[warpstate::Match]) -> Vec<u64> {
291        let total_patterns = self.ac_map.len() + self.fallback.len();
292        let mut triggered_patterns = vec![0u64; total_patterns.div_ceil(64)];
293        for matched in matches {
294            let pattern_index = matched.pattern_id as usize;
295            if pattern_index >= total_patterns {
296                continue;
297            }
298            triggered_patterns[pattern_index / 64] |= 1u64 << (pattern_index % 64);
299        }
300        triggered_patterns
301    }
302
303    fn mark_triggered_pattern(&self, triggered_patterns: &mut [u64], pattern_index: usize) {
304        if pattern_index / 64 >= triggered_patterns.len() {
305            return;
306        }
307        triggered_patterns[pattern_index / 64] |= 1u64 << (pattern_index % 64);
308        if pattern_index < self.prefix_propagation.len() {
309            for &propagated_index in &self.prefix_propagation[pattern_index] {
310                if propagated_index / 64 < triggered_patterns.len() {
311                    triggered_patterns[propagated_index / 64] |= 1u64 << (propagated_index % 64);
312                }
313            }
314        }
315    }
316
317    pub fn gpu_matcher(&self) -> Option<&warpstate::AutoMatcher> {
318        self.gpu_matcher
319            .get_or_init(|| {
320                let patterns = self.gpu_pattern_set.as_ref()?.clone();
321                let config = warpstate::AutoMatcherConfig::new()
322                    .gpu_threshold(0)
323                    .gpu_max_input_size(usize::MAX / 2)
324                    .auto_tune_threshold(false)
325                    .max_matches(self.config.max_matches_per_chunk.min(u32::MAX as usize) as u32);
326                // wgpu async calls need a real runtime — spawn a dedicated thread
327                // to avoid nesting inside the CLI's tokio runtime.
328                let handle = std::thread::spawn(move || {
329                    pollster::block_on(warpstate::AutoMatcher::with_config(&patterns, config))
330                });
331                let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
332                loop {
333                    if handle.is_finished() {
334                        break;
335                    }
336                    if std::time::Instant::now() > deadline {
337                        tracing::warn!("GPU matcher init timed out (5s)");
338                        return None;
339                    }
340                    std::thread::sleep(std::time::Duration::from_millis(50));
341                }
342                match handle.join().ok()? {
343                    Ok(matcher) => {
344                        // Warm-up: dummy 1-byte scan to amortize cold-start latency.
345                        if let Err(e) = matcher.scan_blocking(b"x") {
346                            tracing::debug!("GPU warm-up scan failed: {e}");
347                        } else {
348                            tracing::debug!("GPU warm-up scan completed");
349                        }
350                        Some(matcher)
351                    }
352                    Err(error) => {
353                        tracing::warn!("failed to initialize warpstate GPU matcher: {error}");
354                        None
355                    }
356                }
357            })
358            .as_ref()
359    }
360
361    fn degraded_backend_after_gpu_failure(&self) -> ScanBackend {
362        let caps = crate::hw_probe::probe_hardware();
363        if caps.has_avx512 || caps.has_avx2 || caps.has_neon {
364            ScanBackend::SimdCpu
365        } else {
366            ScanBackend::CpuFallback
367        }
368    }
369
370    fn populate_gpu_batch_triggers(
371        &self,
372        prepared: &[PreparedChunk],
373        triggered: &mut [Vec<u64>],
374    ) -> bool {
375        let Some(matcher) = self.gpu_matcher() else {
376            return false;
377        };
378
379        const MAX_BATCH_BYTES: usize = 64 * 1024 * 1024;
380        const MAX_BATCH_ITEMS: usize = 2048;
381
382        let mut start = 0usize;
383        while start < prepared.len() {
384            let mut end = start;
385            let mut batch_bytes = 0usize;
386            while end < prepared.len() && end - start < MAX_BATCH_ITEMS {
387                let len = prepared[end].preprocessed.text.len();
388                if end > start && batch_bytes.saturating_add(len) > MAX_BATCH_BYTES {
389                    break;
390                }
391                batch_bytes = batch_bytes.saturating_add(len);
392                end += 1;
393            }
394
395            let (entries, buffer) = coalesce_preprocessed_batch(&prepared[start..end]);
396            let matches = match matcher.scan_blocking(&buffer) {
397                Ok(matches) => matches,
398                Err(error) => {
399                    tracing::warn!("batched GPU scan failed: {error}");
400                    return false;
401                }
402            };
403
404            map_batch_matches(self, &entries, matches, &mut triggered[start..end]);
405            start = end;
406        }
407
408        true
409    }
410}
411
412fn coalesce_preprocessed_batch(
413    prepared: &[PreparedChunk],
414) -> (Vec<(usize, usize, usize)>, Vec<u8>) {
415    let total_bytes = prepared
416        .iter()
417        .map(|chunk| chunk.preprocessed.text.len())
418        .sum();
419    let mut entries = Vec::with_capacity(prepared.len());
420    let mut buffer = Vec::with_capacity(total_bytes);
421
422    for (index, chunk) in prepared.iter().enumerate() {
423        let start = buffer.len();
424        buffer.extend_from_slice(chunk.preprocessed.text.as_bytes());
425        entries.push((index, start, chunk.preprocessed.text.len()));
426    }
427
428    (entries, buffer)
429}
430
431fn map_batch_matches(
432    scanner: &CompiledScanner,
433    entries: &[(usize, usize, usize)],
434    matches: Vec<warpstate::Match>,
435    triggered: &mut [Vec<u64>],
436) {
437    let mut cursor = 0usize;
438    for matched in matches {
439        let global_start = matched.start as usize;
440        let global_end = matched.end as usize;
441
442        while cursor < entries.len() {
443            let (_, offset, len) = entries[cursor];
444            if global_start < offset + len {
445                break;
446            }
447            cursor += 1;
448        }
449        if cursor >= entries.len() {
450            break;
451        }
452
453        let (chunk_index, offset, len) = entries[cursor];
454        if global_start < offset || global_end > offset + len {
455            continue;
456        }
457        scanner.mark_triggered_pattern(&mut triggered[chunk_index], matched.pattern_id as usize);
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::{PreparedChunk, coalesce_preprocessed_batch, map_batch_matches};
    use crate::engine::CompiledScanner;
    use crate::types::ScannerPreprocessedText;
    use keyhog_core::{Chunk, ChunkMetadata, DetectorSpec, PatternSpec, Severity};
    use std::sync::Arc;

    /// A chunk with no data and default metadata. These tests only look at
    /// the preprocessed text stored next to it.
    fn empty_chunk() -> Chunk {
        Chunk {
            data: String::new(),
            metadata: ChunkMetadata::default(),
        }
    }

    /// Wrap `text` as a prepared chunk whose preprocessed text is `text` itself.
    fn prepared_from(text: &str) -> PreparedChunk {
        PreparedChunk {
            chunk: Arc::new(empty_chunk()),
            preprocessed: ScannerPreprocessedText::passthrough(text),
        }
    }

    #[test]
    fn coalescing_preserves_offsets() {
        let batch = vec![prepared_from("abc"), prepared_from("defg")];

        let (entries, buffer) = coalesce_preprocessed_batch(&batch);

        assert_eq!(entries, vec![(0, 0, 3), (1, 3, 4)]);
        assert_eq!(buffer, b"abcdefg");
    }

    #[test]
    fn cross_boundary_matches_are_dropped() {
        let detector = DetectorSpec {
            id: "demo-token".into(),
            name: "Demo Token".into(),
            service: "demo".into(),
            severity: Severity::High,
            patterns: vec![PatternSpec {
                regex: "abc".into(),
                description: None,
                group: None,
            }],
            companions: vec![],
            verify: None,
            keywords: vec!["abc".into()],
            ..Default::default()
        };
        let scanner = CompiledScanner::compile(vec![detector]).unwrap();

        // Two adjacent 3-byte chunks occupying bytes [0, 3) and [3, 6).
        let entries = vec![(0usize, 0usize, 3usize), (1usize, 3usize, 3usize)];
        let matches = vec![
            warpstate::Match::from_parts(0, 1, 2),
            warpstate::Match::from_parts(0, 2, 4),
        ];
        let mut triggered = vec![vec![0u64; 1]; 2];

        map_batch_matches(&scanner, &entries, matches, &mut triggered);

        assert_eq!(triggered[0][0], 1);
        assert_eq!(triggered[1][0], 0);
    }
}