Skip to main content

keyhog_scanner/
compiler.rs

1//! Logic for compiling detector specifications into an efficient scanning engine.
2
3use crate::error::{Result, ScanError};
4use crate::types::*;
5use aho_corasick::{AhoCorasick, AhoCorasickBuilder};
6use keyhog_core::{CompanionSpec, DetectorSpec, PatternSpec};
7use regex::Regex;
8use warpstate::PatternSet;
9
10pub struct CompileState {
11    pub ac_literals: Vec<String>,
12    pub ac_map: Vec<CompiledPattern>,
13    pub fallback: Vec<(CompiledPattern, Vec<String>)>,
14    pub companions: Vec<Vec<CompiledCompanion>>,
15    pub quality_warnings: Vec<String>,
16}
17
18pub fn build_compile_state(detectors: &[DetectorSpec]) -> Result<CompileState> {
19    use rayon::prelude::*;
20
21    // Phase 1: Pre-compile all regexes in parallel (the expensive part).
22    let compiled_results: Vec<Result<(Vec<CompiledPattern>, Vec<CompiledCompanion>)>> = detectors
23        .par_iter()
24        .enumerate()
25        .map(|(detector_index, detector)| {
26            let companions = compile_detector_companions(detector)?;
27            let mut patterns = Vec::new();
28            for (pattern_index, pattern) in detector.patterns.iter().enumerate() {
29                patterns.push(compile_pattern(
30                    detector_index,
31                    pattern_index,
32                    pattern,
33                    &detector.id,
34                )?);
35            }
36            Ok((patterns, companions))
37        })
38        .collect();
39
40    // Phase 2: Assemble results sequentially (fast, no regex compilation).
41    let mut ac_literals = Vec::new();
42    let mut ac_map = Vec::new();
43    let mut fallback = Vec::new();
44    let mut companions = Vec::with_capacity(detectors.len());
45    let mut quality_warnings = Vec::new();
46
47    for (detector_index, (result, detector)) in compiled_results
48        .into_iter()
49        .zip(detectors.iter())
50        .enumerate()
51    {
52        let (compiled_patterns, detector_companions) = result?;
53        companions.push(detector_companions);
54
55        for (pattern_index, (compiled, pattern)) in compiled_patterns
56            .into_iter()
57            .zip(detector.patterns.iter())
58            .enumerate()
59        {
60            let prefixes = extract_literal_prefixes(&pattern.regex);
61
62            // Homoglyph expansion for high-confidence patterns
63            for prefix in &prefixes {
64                if prefix.len() >= 3 {
65                    let expanded_prefix = crate::homoglyph::expand_homoglyphs(prefix);
66                    if expanded_prefix != *prefix
67                        && let Ok(re) = Regex::new(&format!("^{}", expanded_prefix))
68                    {
69                        let expanded_pattern = CompiledPattern {
70                            detector_index,
71                            regex: re,
72                            group: pattern.group,
73                        };
74                        fallback.push((expanded_pattern, detector.keywords.clone()));
75                    }
76                }
77            }
78
79            if !prefixes.is_empty() {
80                for prefix in prefixes {
81                    ac_literals.push(prefix);
82                    ac_map.push(compiled.clone());
83                }
84            } else {
85                if detector.keywords.is_empty() {
86                    quality_warnings.push(format!(
87                        "Detector {} pattern {pattern_index} has no literal prefix and no keywords.",
88                        detector.id
89                    ));
90                }
91                fallback.push((compiled, detector.keywords.clone()));
92            }
93        }
94    }
95
96    Ok(CompileState {
97        ac_literals,
98        ac_map,
99        fallback,
100        companions,
101        quality_warnings,
102    })
103}
104
105pub fn build_ac_pattern_set(literals: &[String]) -> Result<Option<PatternSet>> {
106    if literals.is_empty() {
107        return Ok(None);
108    }
109    let mut builder = PatternSet::builder();
110    for lit in literals {
111        builder = builder.literal(lit);
112    }
113    Ok(Some(builder.build()?))
114}
115
116/// Build a complete PatternSet containing ALL patterns (AC regexes + fallback regexes)
117/// for GPU matching. Falls back to None if compilation fails (e.g., overly complex regexes).
118/// Build a GPU PatternSet from AC LITERAL prefixes (not regexes).
119///
120/// The GPU shader runs an AC automaton on wgpu compute cores — pattern count
121/// is irrelevant because all patterns are evaluated in parallel. Uses
122/// `.literal()` which builds an AC trie (no DFA state explosion).
123///
124/// The regex-based `.regex()` builder uses regex-automata DFA internally
125/// which explodes at >100 patterns. We NEVER use that for GPU.
126pub fn build_gpu_pattern_set(ac_literals: &[String]) -> Option<PatternSet> {
127    if ac_literals.is_empty() {
128        return None;
129    }
130    let mut builder = PatternSet::builder();
131    for lit in ac_literals {
132        if !lit.is_empty() {
133            builder = builder.literal(lit);
134        }
135    }
136    match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| builder.build())) {
137        Ok(Ok(ps)) => {
138            tracing::info!(
139                patterns = ac_literals.len(),
140                "GPU PatternSet compiled (AC literals)"
141            );
142            Some(ps)
143        }
144        Ok(Err(e)) => {
145            tracing::warn!("GPU PatternSet error: {e}");
146            None
147        }
148        Err(_) => {
149            tracing::warn!("GPU PatternSet panicked");
150            None
151        }
152    }
153}
154
155pub fn build_detector_to_patterns(
156    ac_map: &[CompiledPattern],
157    detector_count: usize,
158) -> Vec<Vec<usize>> {
159    let mut map = vec![Vec::new(); detector_count];
160    for (pat_idx, entry) in ac_map.iter().enumerate() {
161        map[entry.detector_index].push(pat_idx);
162    }
163    map
164}
165
/// For every literal, list the indices of the *other* literals with exactly
/// the same text.
///
/// The result has one entry per input literal. A literal that is unique gets
/// an empty list; otherwise its list holds the indices of its duplicates in
/// ascending order, never including the literal's own index.
pub fn build_same_prefix_patterns(literals: &[String]) -> Vec<Vec<usize>> {
    // Text -> every position where that text occurs (ascending by construction).
    let mut positions: std::collections::HashMap<&str, Vec<usize>> =
        std::collections::HashMap::new();
    for (idx, literal) in literals.iter().enumerate() {
        positions.entry(literal.as_str()).or_default().push(idx);
    }

    literals
        .iter()
        .enumerate()
        .map(|(idx, literal)| {
            // Every literal was inserted above, so the lookup cannot miss.
            let twins = &positions[literal.as_str()];
            twins
                .iter()
                .copied()
                .filter(|&other| other != idx)
                .collect::<Vec<usize>>()
        })
        .collect()
}
181
/// For every literal, list the indices of the strictly longer literals that
/// start with it.
///
/// The result has one entry per input literal. Identical literals are not
/// considered extensions of each other. Within one entry, indices are ordered
/// by the length of the extending literal, ties broken by input position.
pub fn build_prefix_propagation(literals: &[String]) -> Vec<Vec<usize>> {
    // (original index, text), shortest first; the sort is stable so equal
    // lengths keep their input order.
    let mut by_length: Vec<(usize, &str)> = literals
        .iter()
        .map(String::as_str)
        .enumerate()
        .collect();
    by_length.sort_by_key(|&(_, text)| text.len());

    let mut extensions = vec![Vec::new(); literals.len()];
    for (pos, &(short_idx, short)) in by_length.iter().enumerate() {
        // Only entries sorted after `short` can be longer than it.
        let longer = by_length[pos + 1..]
            .iter()
            .filter(|&&(_, candidate)| candidate != short && candidate.starts_with(short))
            .map(|&(long_idx, _)| long_idx);
        extensions[short_idx].extend(longer);
    }
    extensions
}
203
204pub fn build_fallback_keyword_ac(
205    fallback: &[(CompiledPattern, Vec<String>)],
206) -> (Option<AhoCorasick>, Vec<Vec<usize>>) {
207    let mut all_keywords = Vec::new();
208    let mut keyword_to_patterns = Vec::new();
209    let mut keyword_map: std::collections::HashMap<String, usize> =
210        std::collections::HashMap::new();
211
212    for (pattern_idx, (_, keywords)) in fallback.iter().enumerate() {
213        for kw in keywords {
214            if kw.len() < 4 {
215                continue;
216            }
217            let idx = *keyword_map.entry(kw.clone()).or_insert_with(|| {
218                all_keywords.push(kw.clone());
219                keyword_to_patterns.push(Vec::new());
220                all_keywords.len() - 1
221            });
222            keyword_to_patterns[idx].push(pattern_idx);
223        }
224    }
225
226    if all_keywords.is_empty() {
227        return (None, Vec::new());
228    }
229
230    let ac = AhoCorasickBuilder::new()
231        .ascii_case_insensitive(true)
232        .build(all_keywords)
233        .ok();
234
235    (ac, keyword_to_patterns)
236}
237
238pub fn log_quality_warnings(warnings: &[String]) {
239    for warning in warnings {
240        tracing::warn!(target: "keyhog::scanner::quality", "{}", warning);
241    }
242}
243
244pub fn compile_detector_companions(detector: &DetectorSpec) -> Result<Vec<CompiledCompanion>> {
245    detector
246        .companions
247        .iter()
248        .map(|companion| compile_companion(companion, &detector.id))
249        .collect()
250}
251
252#[allow(clippy::too_many_arguments, dead_code)]
253pub fn compile_detector_pattern(
254    detector_index: usize,
255    detector: &DetectorSpec,
256    pattern_index: usize,
257    pattern: &PatternSpec,
258    ac_literals: &mut Vec<String>,
259    ac_map: &mut Vec<CompiledPattern>,
260    fallback: &mut Vec<(CompiledPattern, Vec<String>)>,
261    quality_warnings: &mut Vec<String>,
262) -> Result<()> {
263    let detector_id = &detector.id;
264    let compiled = compile_pattern(detector_index, pattern_index, pattern, detector_id)?;
265
266    // Prefix extraction for Aho-Corasick prefiltering
267    let prefixes = extract_literal_prefixes(&pattern.regex);
268
269    // Proactive Homoglyph Expansion:
270    // For high-confidence patterns (with literal prefixes), add an expanded
271    // version that handles common Unicode lookalike characters.
272    for prefix in &prefixes {
273        if prefix.len() >= 3 {
274            let expanded_prefix = crate::homoglyph::expand_homoglyphs(prefix);
275            if expanded_prefix != *prefix
276                && let Ok(re) = Regex::new(&format!("^{}", expanded_prefix))
277            {
278                let expanded_pattern = CompiledPattern {
279                    detector_index,
280                    regex: re,
281                    group: pattern.group,
282                };
283                // Always put homoglyph variants in fallback (they are regexes)
284                fallback.push((expanded_pattern, detector.keywords.clone()));
285            }
286        }
287    }
288
289    if !prefixes.is_empty() {
290        tracing::debug!(
291            detector_id,
292            ?prefixes,
293            mode = "AC",
294            "compiled detector pattern"
295        );
296        for prefix in prefixes {
297            ac_literals.push(prefix);
298            ac_map.push(compiled.clone());
299        }
300    } else {
301        // No literal prefix. With Hyperscan, these will be compiled directly
302        // into the HS database alongside the AC-prefix patterns. Without
303        // Hyperscan, they go to the keyword-gated regex fallback.
304        if detector.keywords.is_empty() {
305            quality_warnings.push(format!(
306                "Detector {detector_id} pattern {pattern_index} has no literal prefix and no keywords."
307            ));
308        }
309        fallback.push((compiled, detector.keywords.clone()));
310    }
311    Ok(())
312}
313
314pub fn compile_pattern(
315    detector_index: usize,
316    pattern_index: usize,
317    spec: &PatternSpec,
318    detector_id: &str,
319) -> Result<CompiledPattern> {
320    let regex = regex::RegexBuilder::new(&spec.regex)
321        .size_limit(REGEX_SIZE_LIMIT_BYTES)
322        .dfa_size_limit(REGEX_SIZE_LIMIT_BYTES)
323        .crlf(true)
324        .build()
325        .map_err(|e| ScanError::RegexCompile {
326            detector_id: detector_id.to_string(),
327            index: pattern_index,
328            source: e,
329        })?;
330    Ok(CompiledPattern {
331        detector_index,
332        regex,
333        group: spec.group,
334    })
335}
336
337pub fn compile_companion(spec: &CompanionSpec, detector_id: &str) -> Result<CompiledCompanion> {
338    let regex = regex::RegexBuilder::new(&spec.regex)
339        .size_limit(REGEX_SIZE_LIMIT_BYTES)
340        .dfa_size_limit(REGEX_SIZE_LIMIT_BYTES)
341        .crlf(true)
342        .build()
343        .map_err(|e| ScanError::RegexCompile {
344            detector_id: detector_id.to_string(),
345            index: FIRST_CAPTURE_GROUP_INDEX,
346            source: e,
347        })?;
348    let capture_group = (regex.captures_len() > 1).then_some(FIRST_CAPTURE_GROUP_INDEX);
349    Ok(CompiledCompanion {
350        name: spec.name.clone(),
351        regex,
352        capture_group,
353        within_lines: spec.within_lines,
354        required: spec.required,
355    })
356}
357
358/// Extract literal prefixes from a regex pattern for Aho-Corasick.
359/// Handles simple literals and top-level groups like (AKIA|ASIA).
360pub fn extract_literal_prefixes(pattern: &str) -> Vec<String> {
361    // Strip leading inline flags like (?i), (?m), (?s), (?x), (?im), etc.
362    // These set regex modes but don't consume input.
363    let pattern = strip_leading_inline_flags(pattern);
364
365    if pattern.starts_with('(') && pattern.contains('|') {
366        // Handle (A|B|C)
367        let mut depth = 0;
368        let mut end_idx = None;
369        for (i, ch) in pattern.char_indices() {
370            match ch {
371                '(' => depth += 1,
372                ')' => {
373                    depth -= 1;
374                    if depth == 0 {
375                        end_idx = Some(i);
376                        break;
377                    }
378                }
379                _ => {}
380            }
381        }
382
383        if let Some(end) = end_idx {
384            let mut inner = &pattern[1..end];
385            // Strip non-capturing group prefix (?:, (?i:, (?im:, etc.)
386            if inner.starts_with("?:") {
387                inner = &inner[2..];
388            } else if inner.starts_with("?i:")
389                || inner.starts_with("?m:")
390                || inner.starts_with("?s:")
391            {
392                inner = &inner[3..];
393            } else if inner.starts_with("?im:")
394                || inner.starts_with("?is:")
395                || inner.starts_with("?ms:")
396            {
397                inner = &inner[4..];
398            }
399            // Split by |, but only at depth 0
400            let mut parts = Vec::new();
401            let mut start = 0;
402            let mut d = 0;
403            for (i, ch) in inner.char_indices() {
404                match ch {
405                    '(' => d += 1,
406                    ')' => d -= 1,
407                    '|' if d == 0 => {
408                        parts.push(&inner[start..i]);
409                        start = i + 1;
410                    }
411                    _ => {}
412                }
413            }
414            parts.push(&inner[start..]);
415
416            let mut results = Vec::new();
417            for part in parts {
418                if let Some(p) = extract_literal_prefix(part) {
419                    results.push(p);
420                }
421            }
422            if !results.is_empty() {
423                return results;
424            }
425        }
426    }
427
428    // Default: try to extract a single prefix from the start
429    extract_literal_prefix(pattern).into_iter().collect()
430}
431
/// Remove one leading flags-only group such as `(?i)`, `(?m)` or `(?ims)`.
///
/// Such a group switches modes for the rest of the regex without consuming
/// input. Recognised flag letters are `i`, `m`, `s`, `x`, `u` and `U`.
/// Anything else — including scoped groups like `(?i:...)`, named groups and
/// inputs shorter than four bytes — is returned unchanged.
fn strip_leading_inline_flags(pattern: &str) -> &str {
    let Some(after_open) = pattern.strip_prefix("(?") else {
        return pattern;
    };
    if pattern.len() < 4 {
        return pattern;
    }
    let after_flags = after_open
        .trim_start_matches(|c: char| matches!(c, 'i' | 'm' | 's' | 'x' | 'u' | 'U'));
    // Only a directly following `)` makes this a bare flag group.
    match after_flags.strip_prefix(')') {
        Some(remainder) => remainder,
        None => pattern,
    }
}
454
455pub fn extract_literal_prefix(pattern: &str) -> Option<String> {
456    let mut prefix = String::new();
457    let mut chars = pattern.chars().peekable();
458    while let Some(ch) = chars.next() {
459        match ch {
460            '\\' => {
461                let Some(next) = chars.next() else {
462                    break;
463                };
464                if is_escaped_literal(next) {
465                    prefix.push(next);
466                } else {
467                    break;
468                }
469            }
470            '[' | '.' | '*' | '+' | '?' | '{' | '|' | '^' | '$' => break,
471            '(' => {
472                // Mid-pattern alternation: try to extend the prefix with
473                // the group's alternatives. This turns "secret_(key|token)"
474                // into prefix "secret_key" (the longest common prefix after
475                // expanding alternatives). If the group has no pipe, continue
476                // extracting the literal inside it.
477                let group_start = chars.clone().collect::<String>();
478                if let Some(alternatives) = extract_group_alternatives(&group_start) {
479                    // Find the longest common prefix of all alternatives
480                    if let Some(first) = alternatives.first() {
481                        let common: String = first.chars()
482                            .enumerate()
483                            .take_while(|(i, c)| {
484                                alternatives.iter().all(|alt| {
485                                    alt.chars().nth(*i) == Some(*c)
486                                })
487                            })
488                            .map(|(_, c)| c)
489                            .collect();
490                        if !common.is_empty() {
491                            prefix.push_str(&common);
492                        }
493                    }
494                }
495                break;
496            }
497            _ => {
498                prefix.push(ch);
499            }
500        }
501    }
502    if prefix.len() >= MIN_LITERAL_PREFIX_CHARS {
503        Some(prefix)
504    } else {
505        None
506    }
507}
508
/// Extract the leading literal of each alternative of a group.
///
/// `s` is the text right after the group's opening parenthesis, e.g.
/// `"key|token)rest..."` → `Some(["key", "token"])`. A non-capturing or flag
/// prefix made of `i`/`m`/`s` (`?:`, `?i:`, `?ms:`, …) is skipped.
///
/// For each alternative only the leading run of plain characters
/// (ASCII letters, digits, `_`, `-`, `:`, `=`, space) is returned; anything
/// after the first other character is ignored. An unescaped `.` is the
/// any-character metacharacter and therefore ends the run. If the run is
/// directly followed by `?`, `*` or `{`, its last character may be absent
/// and is dropped (`keys?` → `key`).
///
/// Returns `None` when the group is never closed or when any alternative has
/// no leading literal at all. Parentheses and pipes inside escapes (`\)`,
/// `\|`) or character classes (`[)|]`) are not treated as structure.
fn extract_group_alternatives(s: &str) -> Option<Vec<String>> {
    // Skip `?:` / `?flags:`; anything else starting with `?` (named groups,
    // bare flags) is left in place and will fail the literal check below.
    let inner = match s.strip_prefix('?') {
        Some(rest) => rest
            .trim_start_matches(|c: char| matches!(c, 'i' | 'm' | 's'))
            .strip_prefix(':')
            .unwrap_or(s),
        None => s,
    };

    // Single pass: split at depth-0 pipes until the group's closing paren.
    let mut parts: Vec<&str> = Vec::new();
    let mut start = 0usize;
    let mut depth = 0usize;
    let mut in_class = false;
    let mut escaped = false;
    let mut end = None;
    for (i, ch) in inner.char_indices() {
        if escaped {
            escaped = false;
            continue;
        }
        match ch {
            '\\' => escaped = true,
            '[' if !in_class => in_class = true,
            ']' if in_class => in_class = false,
            _ if in_class => {}
            '(' => depth += 1,
            ')' if depth == 0 => {
                end = Some(i);
                break;
            }
            ')' => depth -= 1,
            '|' if depth == 0 => {
                parts.push(&inner[start..i]);
                start = i + 1;
            }
            _ => {}
        }
    }
    let end = end?;
    parts.push(&inner[start..end]);

    let mut literals = Vec::with_capacity(parts.len());
    for part in &parts {
        let mut literal = String::new();
        for ch in part.chars() {
            match ch {
                'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-' | ':' | '=' | ' ' => {
                    literal.push(ch);
                }
                '?' | '*' | '{' => {
                    // The quantifier applies to the character just taken.
                    literal.pop();
                    break;
                }
                _ => break, // escape or metacharacter ends the literal run
            }
        }
        if literal.is_empty() {
            // One alternative without a literal start spoils the whole group.
            return None;
        }
        literals.push(literal);
    }
    Some(literals)
}
577
/// Report whether `\` followed by `ch` denotes the literal character `ch`.
///
/// Only these regex metacharacters qualify:
/// `[ ] ( ) . * + ? { } \ | ^ $`.
pub fn is_escaped_literal(ch: char) -> bool {
    const ESCAPABLE_METACHARACTERS: &str = r"[]().*+?{}\|^$";
    ESCAPABLE_METACHARACTERS.contains(ch)
}