use crate::error::{Result, ScanError};
use crate::types::*;
use aho_corasick::{AhoCorasick, AhoCorasickBuilder};
use keyhog_core::{CompanionSpec, DetectorSpec, PatternSpec};
use regex::Regex;
use super::compiler_prefix::extract_literal_prefixes;
/// Builds the Aho-Corasick automaton over the given literal set.
///
/// The automaton matches ASCII case-insensitively. An empty `literals` slice
/// yields `Ok(None)` and no automaton is constructed.
///
/// # Errors
/// Returns the automaton build failure, converted into the crate error type
/// by `?`.
pub fn build_ac_pattern_set(literals: &[String]) -> Result<Option<AhoCorasick>> {
    if literals.is_empty() {
        return Ok(None);
    }

    let automaton = AhoCorasickBuilder::new()
        .ascii_case_insensitive(true)
        .build(literals)?;
    Ok(Some(automaton))
}
/// Prepares the literal set in the byte form handed to the GPU literal scan.
///
/// Every literal is ASCII-lowercased and converted to raw bytes.
///
/// Returns `None` when `ac_literals` is empty, or when any literal is the
/// empty string (a warning is logged in that case); otherwise returns the
/// converted literals behind an `Arc`.
pub fn build_gpu_literals(ac_literals: &[String]) -> Option<std::sync::Arc<Vec<Vec<u8>>>> {
    let mut lowered: Vec<Vec<u8>> = Vec::with_capacity(ac_literals.len());
    for literal in ac_literals {
        // A single empty literal disables the whole GPU literal set.
        if literal.is_empty() {
            tracing::warn!("GPU literal set contains an empty literal; disabling GPU literal scan");
            return None;
        }
        lowered.push(literal.to_ascii_lowercase().into_bytes());
    }

    if lowered.is_empty() {
        return None;
    }

    tracing::info!(
        patterns = lowered.len(),
        "GPU literal set prepared for Vyre"
    );
    Some(std::sync::Arc::new(lowered))
}
/// For every literal, lists the indices of the *other* entries in `literals`
/// that hold exactly the same string.
///
/// The result has one entry per input literal. A literal that occurs only
/// once gets an empty list; duplicates get the indices of their twins in
/// ascending order, never including their own index.
pub fn build_same_prefix_patterns(literals: &[String]) -> Vec<Vec<usize>> {
    use std::collections::HashMap;

    // Bucket the positions of each distinct literal; positions are appended
    // in input order, so every bucket is already sorted ascending.
    let mut positions_by_literal: HashMap<&str, Vec<usize>> = HashMap::new();
    for (position, literal) in literals.iter().enumerate() {
        positions_by_literal
            .entry(literal.as_str())
            .or_default()
            .push(position);
    }

    literals
        .iter()
        .enumerate()
        .map(|(position, literal)| {
            positions_by_literal
                .get(literal.as_str())
                .into_iter()
                .flatten()
                .copied()
                .filter(|&other| other != position)
                .collect::<Vec<usize>>()
        })
        .collect()
}
/// Builds the prefix propagation table for `literals`.
///
/// Thin wrapper: the work is done entirely by
/// `crate::prefix_trie::build_propagation_table`, whose result is returned
/// unchanged.
pub fn build_prefix_propagation(literals: &[String]) -> Vec<Vec<usize>> {
    crate::prefix_trie::build_propagation_table(literals)
}
pub fn build_fallback_keyword_ac(
fallback: &[(CompiledPattern, Vec<String>)],
) -> (Option<AhoCorasick>, Vec<Vec<usize>>) {
let mut all_keywords = Vec::new();
let mut keyword_to_patterns = Vec::new();
let mut keyword_map: std::collections::HashMap<String, usize> =
std::collections::HashMap::new();
for (pattern_idx, (_, keywords)) in fallback.iter().enumerate() {
for kw in keywords {
if kw.len() < 4 {
continue;
}
let idx = *keyword_map.entry(kw.clone()).or_insert_with(|| {
all_keywords.push(kw.clone());
keyword_to_patterns.push(Vec::new());
all_keywords.len() - 1
});
keyword_to_patterns[idx].push(pattern_idx);
}
}
if all_keywords.is_empty() {
return (None, Vec::new());
}
let ac = AhoCorasickBuilder::new()
.ascii_case_insensitive(true)
.build(all_keywords)
.ok();
(ac, keyword_to_patterns)
}
/// Emits each entry of `warnings` as a `warn`-level tracing event under the
/// `keyhog::scanner::quality` target, in slice order.
pub fn log_quality_warnings(warnings: &[String]) {
    warnings.iter().for_each(|warning| {
        tracing::warn!(target: "keyhog::scanner::quality", "{}", warning);
    });
}
/// Compiles every companion spec of `detector`, preserving their order.
///
/// # Errors
/// Stops at, and returns, the first error produced by [`compile_companion`].
pub fn compile_detector_companions(detector: &DetectorSpec) -> Result<Vec<CompiledCompanion>> {
    let mut compiled = Vec::new();
    for companion in detector.companions.iter() {
        compiled.push(compile_companion(companion, &detector.id)?);
    }
    Ok(compiled)
}
/// Compiles one pattern of a detector and files it under either the
/// literal-prefix (Aho-Corasick) path or the fallback path.
///
/// * If `extract_literal_prefixes` finds prefixes in the pattern's regex,
///   each prefix is appended to `ac_literals` together with a clone of the
///   compiled pattern in `ac_map`, so the two vectors grow in lockstep.
/// * Otherwise the compiled pattern is appended to `fallback` with a copy of
///   the detector's keywords; when the detector has no keywords either, a
///   message is added to `quality_warnings`.
///
/// # Errors
/// Returns the error from [`compile_pattern`]; in that case none of the
/// output vectors is modified.
// The output accumulators are passed individually, hence the argument count.
#[allow(clippy::too_many_arguments)]
pub fn compile_detector_pattern(
    detector_index: usize,
    detector: &DetectorSpec,
    pattern_index: usize,
    pattern: &PatternSpec,
    ac_literals: &mut Vec<String>,
    ac_map: &mut Vec<CompiledPattern>,
    fallback: &mut Vec<(CompiledPattern, Vec<String>)>,
    quality_warnings: &mut Vec<String>,
) -> Result<()> {
    // NOTE: `detector_id` and `prefixes` double as tracing field names and
    // format-string captures below; renaming them changes the emitted output.
    let detector_id = &detector.id;
    let compiled = compile_pattern(detector_index, pattern_index, pattern, detector_id)?;
    let prefixes = extract_literal_prefixes(&pattern.regex);

    if prefixes.is_empty() {
        // Fallback path: no literal prefix to feed the automaton.
        if detector.keywords.is_empty() {
            quality_warnings.push(format!(
                "Detector {detector_id} pattern {pattern_index} has no literal prefix and no keywords."
            ));
        }
        fallback.push((compiled, detector.keywords.clone()));
        return Ok(());
    }

    tracing::debug!(
        detector_id,
        ?prefixes,
        mode = "AC",
        "compiled detector pattern"
    );
    for prefix in prefixes {
        ac_literals.push(prefix);
        ac_map.push(compiled.clone());
    }
    Ok(())
}
/// Process-wide set of regex source strings, created on first use.
static VALIDATED_REGEX_SOURCES: std::sync::OnceLock<
    parking_lot::Mutex<std::collections::HashSet<std::sync::Arc<str>>>,
> = std::sync::OnceLock::new();

/// Returns the global set behind [`VALIDATED_REGEX_SOURCES`], initialising it
/// to an empty set on the first call.
fn validated_regex_sources(
) -> &'static parking_lot::Mutex<std::collections::HashSet<std::sync::Arc<str>>> {
    VALIDATED_REGEX_SOURCES.get_or_init(Default::default)
}
pub fn compile_pattern(
detector_index: usize,
pattern_index: usize,
spec: &PatternSpec,
detector_id: &str,
) -> Result<CompiledPattern> {
let already_validated = validated_regex_sources()
.lock()
.contains(spec.regex.as_str());
if !already_validated {
if regex_syntax::Parser::new().parse(&spec.regex).is_err() {
let source = regex::Regex::new(&spec.regex)
.err()
.unwrap_or_else(|| regex::Error::Syntax(spec.regex.clone()));
return Err(ScanError::RegexCompile {
detector_id: detector_id.to_string(),
index: pattern_index,
source,
});
}
validated_regex_sources()
.lock()
.insert(std::sync::Arc::from(spec.regex.as_str()));
}
Ok(CompiledPattern {
detector_index,
regex: LazyRegex::detector(spec.regex.as_str()),
group: spec.group,
client_safe: spec.client_safe,
})
}
/// Number of independently locked shards in the shared regex cache.
const REGEX_CACHE_SHARDS: usize = 64;
/// Total entry budget of the shared regex cache, split evenly across shards.
const REGEX_CACHE_CAPACITY: usize = 8192;
/// One cache shard: an LRU map from pattern source to compiled regex, guarded
/// by its own mutex.
type RegexCacheShard = parking_lot::Mutex<lru::LruCache<String, std::sync::Arc<Regex>>>;
/// The shard array, created on first use.
static REGEX_CACHE: std::sync::OnceLock<Box<[RegexCacheShard]>> = std::sync::OnceLock::new();

/// Returns the shard array of the shared regex cache, building it on the
/// first call.
///
/// Each shard is sized to `REGEX_CACHE_CAPACITY / REGEX_CACHE_SHARDS`
/// entries, with a floor of one entry.
fn regex_cache() -> &'static [RegexCacheShard] {
    REGEX_CACHE.get_or_init(|| {
        // `NonZeroUsize::new` yields `None` for a zero quotient, which is
        // then clamped up to the minimum capacity of 1.
        let shard_capacity =
            std::num::NonZeroUsize::new(REGEX_CACHE_CAPACITY / REGEX_CACHE_SHARDS)
                .unwrap_or(std::num::NonZeroUsize::MIN);
        std::iter::repeat_with(|| parking_lot::Mutex::new(lru::LruCache::new(shard_capacity)))
            .take(REGEX_CACHE_SHARDS)
            .collect()
    })
}
fn regex_cache_shard(pattern: &str) -> &'static RegexCacheShard {
use std::hash::{Hash, Hasher};
let mut hasher = std::collections::hash_map::DefaultHasher::new();
pattern.hash(&mut hasher);
let idx = (hasher.finish() as usize) % REGEX_CACHE_SHARDS;
®ex_cache()[idx]
}
/// Compiles `pattern` with the settings used for shared (cached) regexes,
/// without consulting or populating the cache.
///
/// The regex is built case-insensitive, in CRLF mode, with the size limit
/// `REGEX_SIZE_LIMIT_BYTES` and the DFA size limit from `regex_dfa_limit()`.
///
/// # Errors
/// Returns the `regex::Error` reported by the builder.
pub fn shared_regex_compile(
    pattern: &str,
) -> std::result::Result<std::sync::Arc<Regex>, regex::Error> {
    let mut builder = regex::RegexBuilder::new(pattern);
    builder
        .case_insensitive(true)
        .size_limit(REGEX_SIZE_LIMIT_BYTES)
        .dfa_size_limit(regex_dfa_limit())
        .crlf(true);
    builder.build().map(std::sync::Arc::new)
}
/// Seeds the shared regex cache with already-compiled regexes.
///
/// Each `(pattern, result)` pair whose result is `Ok` is stored in the shard
/// chosen for that pattern; pairs carrying an error are skipped.
pub fn warm_shared_regex_cache(
    compiled: Vec<(
        String,
        std::result::Result<std::sync::Arc<Regex>, regex::Error>,
    )>,
) {
    for (pattern, outcome) in compiled {
        let Ok(regex) = outcome else {
            continue;
        };
        let shard = regex_cache_shard(&pattern);
        shard.lock().put(pattern, regex);
    }
}
/// Returns the compiled regex for `pattern`, using the shared cache.
///
/// On a cache hit the stored `Arc` is cloned and returned. On a miss the
/// pattern is compiled with [`shared_regex_compile`] while the shard lock is
/// *not* held; the shard is then locked again and re-checked, so if an entry
/// appeared in the meantime that entry is returned and the fresh compile is
/// dropped. Otherwise the fresh regex is inserted and returned.
///
/// # Errors
/// Returns the `regex::Error` from compilation; nothing is cached then.
pub(crate) fn shared_regex(
    pattern: &str,
) -> std::result::Result<std::sync::Arc<Regex>, regex::Error> {
    let shard = regex_cache_shard(pattern);

    // First probe; the guard is dropped at the end of this scope so that the
    // compile below runs unlocked.
    {
        let mut guard = shard.lock();
        if let Some(cached) = guard.get(pattern) {
            return Ok(std::sync::Arc::clone(cached));
        }
    }

    let compiled = shared_regex_compile(pattern)?;

    let mut guard = shard.lock();
    // Second probe: prefer an entry that was inserted while we were compiling.
    if let Some(existing) = guard.get(pattern).cloned() {
        return Ok(existing);
    }
    guard.put(pattern.to_string(), std::sync::Arc::clone(&compiled));
    Ok(compiled)
}
/// Compiles a companion spec into a [`CompiledCompanion`].
///
/// The regex is built in CRLF mode with the size limit
/// `REGEX_SIZE_LIMIT_BYTES` and the DFA size limit from `regex_dfa_limit()`;
/// unlike [`shared_regex_compile`], case-insensitivity is not enabled here.
/// `capture_group` is `Some(FIRST_CAPTURE_GROUP_INDEX)` when the regex has at
/// least one explicit capture group, otherwise `None`.
///
/// # Errors
/// Returns `ScanError::RegexCompile` with `detector_id` when the regex fails
/// to build.
pub fn compile_companion(spec: &CompanionSpec, detector_id: &str) -> Result<CompiledCompanion> {
    let mut builder = regex::RegexBuilder::new(&spec.regex);
    builder
        .size_limit(REGEX_SIZE_LIMIT_BYTES)
        .dfa_size_limit(regex_dfa_limit())
        .crlf(true);

    let regex = match builder.build() {
        Ok(regex) => regex,
        Err(source) => {
            return Err(ScanError::RegexCompile {
                detector_id: detector_id.to_string(),
                // NOTE(review): the error's `index` is filled with the
                // capture-group constant rather than a companion position —
                // confirm this is intended.
                index: FIRST_CAPTURE_GROUP_INDEX,
                source,
            });
        }
    };

    // `captures_len()` counts the implicit whole-match group, so a value
    // above 1 means the regex declares at least one explicit group.
    let capture_group = if regex.captures_len() > 1 {
        Some(FIRST_CAPTURE_GROUP_INDEX)
    } else {
        None
    };

    Ok(CompiledCompanion {
        name: spec.name.clone(),
        regex,
        capture_group,
        within_lines: spec.within_lines,
        required: spec.required,
    })
}