Skip to main content

keyhog_core/spec/
load.rs

//! Detector loading pipeline: read TOML files, run the quality gate, and inject
//! small compatibility shims for legacy token formats when needed.
3
4use std::io;
5use std::path::{Path, PathBuf};
6
7use rayon::prelude::*;
8use serde::{Deserialize, Serialize};
9
10use super::{DetectorFile, DetectorSpec, PatternSpec, QualityIssue, SpecError, validate_detector};
11
/// Format version stamped into every cache file by `save_detector_cache`.
/// `load_detector_cache` returns `None` for a cache carrying any other version.
const DETECTOR_CACHE_VERSION: u32 = 2;
13
/// Serialized layout of the detector cache file: a version stamp plus the
/// cached detector specs.
#[derive(Serialize, Deserialize)]
struct DetectorCacheFile {
    /// Cache format version; compared against `DETECTOR_CACHE_VERSION` on load.
    version: u32,
    /// Detector specs stored in the cache.
    detectors: Vec<DetectorSpec>,
}
19
20/// Save detectors to a JSON cache file for fast subsequent loads.
21///
22/// # Examples
23///
24/// ```rust,no_run
25/// use keyhog_core::{DetectorSpec, save_detector_cache};
26/// use std::path::Path;
27///
28/// let detectors: Vec<DetectorSpec> = Vec::new();
29/// save_detector_cache(&detectors, Path::new(".keyhog-cache.json")).unwrap();
30/// ```
31pub fn save_detector_cache(
32    detectors: &[DetectorSpec],
33    cache_path: &Path,
34) -> Result<(), std::io::Error> {
35    for detector in detectors {
36        let issues = validate_detector(detector);
37        if issues
38            .iter()
39            .any(|issue| matches!(issue, QualityIssue::Error(_)))
40        {
41            return Err(io::Error::new(
42                io::ErrorKind::InvalidData,
43                format!(
44                    "refusing to cache invalid detector '{}'. Fix: repair the detector before writing the cache",
45                    detector.id
46                ),
47            ));
48        }
49    }
50
51    let json = serde_json::to_vec(&DetectorCacheFile {
52        version: DETECTOR_CACHE_VERSION,
53        detectors: detectors.to_vec(),
54    })?;
55    std::fs::write(cache_path, json)
56}
57
58/// Load detectors from a JSON cache file. Returns None if cache is stale or missing.
59///
60/// # Examples
61///
62/// ```rust,no_run
63/// use keyhog_core::load_detector_cache;
64/// use std::path::Path;
65///
66/// let _cached = load_detector_cache(
67///     Path::new(".keyhog-cache.json"),
68///     Path::new("detectors"),
69/// );
70/// ```
71///
72/// # Security
73///
74/// Cached detectors are re-validated through the quality gate to prevent cache
75/// poisoning attacks where a malicious `.keyhog-cache.json` injects evil regex
76/// patterns that bypass the TOML quality gate.
77pub fn load_detector_cache(cache_path: &Path, source_dir: &Path) -> Option<Vec<DetectorSpec>> {
78    let cache_meta = std::fs::metadata(cache_path).ok()?;
79    let cache_mtime = cache_meta.modified().ok()?;
80
81    // Check if any TOML in source_dir is newer than the cache
82    let entries = std::fs::read_dir(source_dir).ok()?;
83    for entry in entries.flatten() {
84        let path = entry.path();
85        if path.extension().is_some_and(|ext| ext == "toml") {
86            let is_stale = std::fs::metadata(&path)
87                .and_then(|meta| meta.modified())
88                .is_ok_and(|mtime| mtime > cache_mtime);
89
90            if is_stale {
91                return None; // Cache is stale
92            }
93        }
94    }
95
96    let data = match std::fs::read(cache_path) {
97        Ok(data) => data,
98        Err(error) => {
99            tracing::warn!(
100                "failed to read detector cache {}: {}",
101                cache_path.display(),
102                error
103            );
104            return None;
105        }
106    };
107    let cache: DetectorCacheFile = match serde_json::from_slice(&data) {
108        Ok(cache) => cache,
109        Err(error) => {
110            tracing::warn!(
111                "failed to parse detector cache {}: {}",
112                cache_path.display(),
113                error
114            );
115            return None;
116        }
117    };
118    if cache.version != DETECTOR_CACHE_VERSION {
119        return None;
120    }
121
122    let mut validated = Vec::with_capacity(cache.detectors.len());
123    for spec in cache.detectors {
124        let issues = validate_detector(&spec);
125        if issues
126            .iter()
127            .any(|issue| matches!(issue, QualityIssue::Error(_)))
128        {
129            tracing::warn!(
130                "cached detector '{}' failed quality gate; discarding the entire cache",
131                spec.id
132            );
133            return None;
134        }
135        validated.push(spec);
136    }
137
138    if validated.is_empty() {
139        tracing::warn!("detector cache is empty after validation, falling back to TOML load");
140        return None;
141    }
142
143    Some(validated)
144}
145
/// Load all detector specs from the `.toml` files in a directory, with the
/// quality gate enforced.
///
/// This is shorthand for `load_detectors_with_gate(dir, true)`: detectors that
/// report quality errors are rejected, and quality warnings are only logged.
///
/// # Errors
///
/// Returns whatever error `load_detectors_with_gate` produces, e.g. when `dir`
/// cannot be read.
///
/// # Examples
///
/// ```rust,no_run
/// use keyhog_core::load_detectors;
/// use std::path::Path;
///
/// let detectors = load_detectors(Path::new("detectors")).unwrap();
/// assert!(!detectors.is_empty());
/// ```
pub fn load_detectors(dir: &Path) -> Result<Vec<DetectorSpec>, SpecError> {
    load_detectors_with_gate(dir, true)
}
161
162/// Load detectors with optional quality gate enforcement.
163/// When `enforce_gate` is `true`, detectors with quality errors are skipped.
164///
165/// # Examples
166///
167/// ```rust,no_run
168/// use keyhog_core::load_detectors_with_gate;
169/// use std::path::Path;
170///
171/// let _detectors = load_detectors_with_gate(Path::new("detectors"), true).unwrap();
172/// ```
173pub fn load_detectors_with_gate(
174    dir: &Path,
175    enforce_gate: bool,
176) -> Result<Vec<DetectorSpec>, SpecError> {
177    // Phase 1: collect all TOML file paths (fast, sequential)
178    let entries = std::fs::read_dir(dir).map_err(|e| SpecError::ReadFile {
179        path: dir.display().to_string(),
180        source: e,
181    })?;
182    let toml_paths: Vec<PathBuf> = entries
183        .filter_map(|entry| {
184            let entry = entry.ok()?;
185            let path = entry.path();
186            if path.extension().is_some_and(|ext| ext == "toml") {
187                Some(path)
188            } else {
189                None
190            }
191        })
192        .collect();
193
194    // Phase 2: read + parse all TOMLs in parallel
195    let parsed: Vec<ReadDetectorOutcome> = toml_paths
196        .par_iter()
197        .map(|path| read_detector_file(path))
198        .collect();
199
200    // Phase 3: validate + filter (sequential for logging)
201    let mut load_state = DetectorLoadState::default();
202    let mut detectors = Vec::with_capacity(parsed.len());
203
204    for outcome in parsed {
205        match outcome {
206            ReadDetectorOutcome::Loaded(spec) => {
207                if should_reject_detector(
208                    &spec,
209                    enforce_gate,
210                    &mut load_state.gate_rejected,
211                    &mut load_state.total_warnings,
212                ) {
213                    continue;
214                }
215                detectors.push(spec);
216            }
217            ReadDetectorOutcome::Skipped { message } => {
218                load_state.skipped += 1;
219                load_state.load_errors.push(message);
220            }
221        }
222    }
223
224    if should_inject_github_classic_pat_detector(&detectors) {
225        inject_github_classic_pat_detector(&mut detectors);
226    }
227
228    log_load_summary(&load_state);
229
230    detectors.sort_by(|a, b| a.id.cmp(&b.id));
231    Ok(detectors)
232}
233
/// Counters and messages accumulated during one detector load, consumed by
/// `log_load_summary`.
#[derive(Default)]
struct DetectorLoadState {
    /// Number of detector files that could not be read or parsed.
    skipped: usize,
    /// One message per skipped file describing why it was skipped.
    load_errors: Vec<String>,
    /// Number of detectors dropped by the quality gate.
    gate_rejected: usize,
    /// Number of quality warnings reported across all detectors.
    total_warnings: usize,
}
241
242fn log_load_summary(state: &DetectorLoadState) {
243    if state.skipped > 0 {
244        tracing::warn!("skipped {} malformed detector files", state.skipped);
245    }
246    for error in &state.load_errors {
247        tracing::warn!("detector load issue: {error}");
248    }
249    if state.gate_rejected > 0 {
250        tracing::warn!("quality gate: rejected {} detectors", state.gate_rejected);
251    }
252    if state.total_warnings > 0 {
253        tracing::debug!("quality gate: {} warnings", state.total_warnings);
254    }
255}
256
/// Result of attempting to read and parse a single detector TOML file.
enum ReadDetectorOutcome {
    /// The file was read and parsed; carries the detector it defines.
    Loaded(DetectorSpec),
    /// The file could not be read or parsed; `message` describes the failure.
    Skipped { message: String },
}
261
262fn read_detector_file(path: &Path) -> ReadDetectorOutcome {
263    let contents = match std::fs::read_to_string(path) {
264        Ok(contents) => contents,
265        Err(error) => {
266            let message = format!("failed to read {}: {}", path.display(), error);
267            tracing::debug!("{message}");
268            return ReadDetectorOutcome::Skipped { message };
269        }
270    };
271
272    match toml::from_str::<DetectorFile>(&contents) {
273        Ok(file) => ReadDetectorOutcome::Loaded(file.detector),
274        Err(error) => {
275            let message = format!("failed to parse {}: {}", path.display(), error);
276            tracing::debug!("{message}");
277            ReadDetectorOutcome::Skipped { message }
278        }
279    }
280}
281
282fn should_reject_detector(
283    spec: &DetectorSpec,
284    enforce_gate: bool,
285    gate_rejected: &mut usize,
286    total_warnings: &mut usize,
287) -> bool {
288    let mut has_errors = false;
289    for issue in validate_detector(spec) {
290        match issue {
291            QualityIssue::Warning(warning) => {
292                tracing::debug!("quality: {} — {}", spec.id, warning);
293                *total_warnings += 1;
294            }
295            QualityIssue::Error(error) => {
296                tracing::warn!("failed to validate detector: {}: {}", spec.id, error);
297                has_errors = true;
298            }
299        }
300    }
301
302    if has_errors && enforce_gate {
303        *gate_rejected += 1;
304        return true;
305    }
306
307    false
308}
309
310pub(super) fn inject_github_classic_pat_detector(detectors: &mut Vec<DetectorSpec>) {
311    let Some(github_fine_grained) = detectors
312        .iter()
313        .find(|d| d.id == "github-pat-fine-grained")
314        .cloned()
315    else {
316        return;
317    };
318
319    let mut compat = github_fine_grained;
320    compat.id = "github-classic-pat".into();
321    compat.name = "GitHub Classic PAT".into();
322    compat.keywords = vec!["ghp_".into(), "github".into()];
323    compat.patterns = vec![PatternSpec {
324        regex: "ghp_[a-zA-Z0-9]{36,40}".into(),
325        description: Some("GitHub classic personal access token".into()),
326        group: None,
327    }];
328
329    detectors.push(compat);
330}
331
332fn should_inject_github_classic_pat_detector(detectors: &[DetectorSpec]) -> bool {
333    !detectors.iter().any(|d| d.id == "github-classic-pat")
334        && detectors.iter().any(|d| d.id == "github-pat-fine-grained")
335}