Skip to main content

keyhog_scanner/engine/
gpu_megascan.rs

1use super::*;
2
3impl CompiledScanner {
4    pub fn scan_coalesced_megascan(
5        &self,
6        chunks: &[keyhog_core::Chunk],
7    ) -> Vec<Vec<keyhog_core::RawMatch>> {
8        use crate::hw_probe::ScanBackend;
9
10        let Some(pipeline) = self.rule_pipeline() else {
11            super::gpu_forced::deny_silent_megascan_degrade(
12                "regex pipeline compile rejected the detector set",
13            );
14            tracing::debug!(
15                "MegaScan: regex pipeline unavailable, dispatching via literal-set GPU"
16            );
17            return self.scan_coalesced_gpu(chunks);
18        };
19        let Some(backend) = self.gpu_backend.as_ref() else {
20            super::gpu_forced::deny_silent_megascan_degrade(
21                "no GPU backend acquired at compile time",
22            );
23            return self.scan_coalesced_gpu(chunks);
24        };
25
26        let (entries, buffer) = super::gpu_coalesce::coalesce_chunks(chunks);
27
28        // Pipeline was pre-built for at most `megascan_input_len()` bytes;
29        // bigger batches can't dispatch. Auto-degrade rather than
30        // truncate (truncation = silent false negatives).
31        let input_cap = super::rule_pipeline::megascan_input_len();
32        if buffer.len() > input_cap {
33            super::gpu_forced::deny_silent_megascan_degrade(
34                "coalesced batch exceeds RulePipeline input_len cap",
35            );
36            tracing::debug!(
37                buffer_bytes = buffer.len(),
38                input_len = input_cap,
39                "MegaScan: batch exceeds RulePipeline input_len cap, falling back to literal-set GPU"
40            );
41            return self.scan_coalesced_gpu(chunks);
42        }
43
44        #[cfg(target_os = "linux")]
45        // SAFETY: same contract as scan_coalesced_gpu - `buffer` is a
46        // live owned Vec describing a valid range; madvise is advisory.
47        unsafe {
48            libc::madvise(
49                buffer.as_ptr() as *mut libc::c_void,
50                buffer.len(),
51                libc::MADV_DONTDUMP,
52            );
53        }
54
55        // Same buffer-scaled cap as the literal-set path.
56        const MIN_CAP: u32 = 100_000;
57        const MAX_CAP: u32 = 16_000_000;
58        let buffer_cap = (buffer.len() / 64) as u64;
59        let cap: u32 = buffer_cap.clamp(MIN_CAP as u64, MAX_CAP as u64) as u32;
60        let max_matches = cap.saturating_add(1);
61
62        let started = std::time::Instant::now();
63        let raw_matches = match pipeline.scan(&**backend, &buffer, max_matches) {
64            Ok(matches) => matches,
65            Err(error) => {
66                tracing::error!(
67                    %error,
68                    "MegaScan dispatch failed: falling back to literal-set GPU"
69                );
70                super::gpu_forced::deny_silent_megascan_degrade(
71                    "MegaScan dispatch returned an error at runtime",
72                );
73                return self.scan_coalesced_gpu(chunks);
74            }
75        };
76        let elapsed_ms = started.elapsed().as_millis();
77        tracing::debug!(
78            target: "keyhog::routing",
79            chunks = chunks.len(),
80            buffer_bytes = buffer.len(),
81            matches = raw_matches.len(),
82            cap,
83            elapsed_ms,
84            "MegaScan RulePipeline scan completed"
85        );
86
87        if raw_matches.len() > cap as usize {
88            tracing::warn!(
89                cap,
90                "MegaScan exceeded cap: truncation possible; dispatching via literal-set GPU"
91            );
92            super::gpu_forced::deny_silent_megascan_degrade(
93                "match count exceeded MegaScan dispatch cap (truncation risk)",
94            );
95            return self.scan_coalesced_gpu(chunks);
96        }
97
98        let mut matches: Vec<vyre_libs::scan::LiteralMatch> = raw_matches
99            .iter()
100            .map(|m| vyre_libs::scan::LiteralMatch::new(m.pattern_id, m.start, m.end))
101            .collect();
102        // In-place dedup: sort by (pattern_id, start, end) and fold overlapping spans.
103        matches.sort_unstable_by(|a, b| {
104            a.pattern_id
105                .cmp(&b.pattern_id)
106                .then(a.start.cmp(&b.start))
107                .then(a.end.cmp(&b.end))
108        });
109        {
110            let mut write = 0;
111            for read in 1..matches.len() {
112                if matches[read].pattern_id == matches[write].pattern_id
113                    && matches[read].start <= matches[write].end
114                {
115                    if matches[read].end > matches[write].end {
116                        matches[write] = vyre_libs::scan::LiteralMatch::new(
117                            matches[write].pattern_id,
118                            matches[write].start,
119                            matches[read].end,
120                        );
121                    }
122                } else {
123                    write += 1;
124                    matches[write] = matches[read];
125                }
126            }
127            if !matches.is_empty() {
128                matches.truncate(write + 1);
129            }
130        }
131        matches.sort_unstable_by_key(|m| m.start);
132
133        let total_patterns = self.ac_map.len() + self.fallback.len();
134        let mut per_chunk_triggers: Vec<Vec<u64>> = chunks
135            .iter()
136            .map(|_| vec![0u64; total_patterns.div_ceil(64)])
137            .collect();
138        let mut cursor = 0usize;
139        for matched in &matches {
140            let global_start = matched.start as usize;
141            let global_end = matched.end as usize;
142            while cursor < entries.len() {
143                let (_, offset, len) = entries[cursor];
144                if global_start < offset + len {
145                    break;
146                }
147                cursor += 1;
148            }
149            if cursor >= entries.len() {
150                break;
151            }
152            let (chunk_index, offset, len) = entries[cursor];
153            if global_start < offset || global_end > offset + len {
154                continue;
155            }
156            let pattern_index = matched.pattern_id as usize;
157            if pattern_index < total_patterns {
158                per_chunk_triggers[chunk_index][pattern_index / 64] |= 1u64 << (pattern_index % 64);
159            }
160        }
161
162        use rayon::prelude::*;
163        let mut results: Vec<Vec<keyhog_core::RawMatch>> = chunks
164            .par_iter()
165            .zip(per_chunk_triggers.into_par_iter())
166            .map(|(chunk, triggered)| {
167                let prepared = self.prepare_chunk(chunk);
168                let mut matches = self.scan_prepared_with_triggered(
169                    prepared,
170                    ScanBackend::MegaScan,
171                    triggered,
172                    None,
173                );
174                self.post_process_matches(chunk, &mut matches, None);
175                matches
176            })
177            .collect();
178
179        // Same boundary reassembly as the literal-set path.
180        super::boundary::scan_chunk_boundaries(self, chunks, &mut results);
181        results
182    }
183}