use crate::detectors::base::{
is_test_file, DetectionSummary, Detector, DetectorResult, ProgressCallback,
};
use crate::detectors::context_hmm::{ContextClassifier, FunctionContext, FunctionFeatures};
use crate::detectors::function_context::{FunctionContextBuilder, FunctionContextMap};
use crate::graph::GraphStore;
use crate::models::Finding;
use anyhow::Result;
use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tracing::{debug, error, info, warn};
/// Hard cap on the number of findings an engine returns by default; `run`
/// truncates (after severity sort) beyond this unless overridden via
/// `with_max_findings` / the builder's `max_findings`.
const MAX_FINDINGS_LIMIT: usize = 10_000;
/// Orchestrates a set of [`Detector`]s over a code graph: runs independent
/// detectors in parallel on a dedicated rayon pool, dependent ones serially,
/// then post-filters findings (test files, HMM context) and truncates.
pub struct DetectorEngine {
    /// Registered detectors, shared so they can be cloned into the pool.
    detectors: Vec<Arc<dyn Detector>>,
    /// Worker-thread count for the rayon pool (resolved in `new`; never 0).
    workers: usize,
    /// Upper bound on findings returned by `run` (see `MAX_FINDINGS_LIMIT`).
    max_findings: usize,
    /// Optional progress hook called as `(detector_name, done, total)`.
    progress_callback: Option<ProgressCallback>,
    /// Lazily built/memoized per-function contexts (see `get_or_build_contexts`).
    function_contexts: Option<Arc<FunctionContextMap>>,
    /// Lazily built/memoized HMM classifications keyed by qualified function name.
    hmm_contexts: Option<Arc<HashMap<String, FunctionContext>>>,
    /// When true (the default), drop findings whose affected files are all tests.
    skip_test_files: bool,
    /// Directory for caching the trained HMM model (`hmm_model.json`), if any.
    hmm_cache_path: Option<std::path::PathBuf>,
}
impl DetectorEngine {
    /// Creates an engine with `workers` threads; `0` means auto-detect
    /// (available parallelism, capped at 16, falling back to 4 when the
    /// query fails). Test-file filtering is on by default.
    pub fn new(workers: usize) -> Self {
        let actual_workers = if workers == 0 {
            std::thread::available_parallelism()
                .map(|p| p.get())
                .unwrap_or(4)
                .min(16)
        } else {
            workers
        };
        Self {
            detectors: Vec::new(),
            workers: actual_workers,
            max_findings: MAX_FINDINGS_LIMIT,
            progress_callback: None,
            function_contexts: None,
            hmm_contexts: None,
            skip_test_files: true,
            hmm_cache_path: None,
        }
    }
    /// Sets the directory used to load/save the cached HMM model.
    pub fn with_hmm_cache(mut self, path: std::path::PathBuf) -> Self {
        self.hmm_cache_path = Some(path);
        self
    }
    /// Convenience constructor with auto-detected workers.
    /// NOTE(review): duplicates the `Default` trait impl below (clippy
    /// `should_implement_trait`); presumably kept for API compatibility.
    pub fn default() -> Self {
        Self::new(0)
    }
    /// Overrides the findings cap applied at the end of `run`.
    pub fn with_max_findings(mut self, max: usize) -> Self {
        self.max_findings = max;
        self
    }
    /// Installs a progress hook invoked after each detector finishes.
    pub fn with_progress_callback(mut self, callback: ProgressCallback) -> Self {
        self.progress_callback = Some(callback);
        self
    }
    /// Injects pre-built function contexts, skipping the lazy build.
    pub fn with_function_contexts(mut self, contexts: Arc<FunctionContextMap>) -> Self {
        self.function_contexts = Some(contexts);
        self
    }
    /// Enables/disables dropping findings that affect only test files.
    pub fn with_skip_test_files(mut self, skip: bool) -> Self {
        self.skip_test_files = skip;
        self
    }
    /// Returns the memoized function contexts, building them from the graph
    /// on first call. Subsequent calls are a cheap `Arc` clone.
    pub fn get_or_build_contexts(
        &mut self,
        graph: &dyn crate::graph::GraphQuery,
    ) -> Arc<FunctionContextMap> {
        if let Some(ref ctx) = self.function_contexts {
            return Arc::clone(ctx);
        }
        info!("Building function contexts from graph...");
        let contexts = FunctionContextBuilder::new(graph).build();
        let arc = Arc::new(contexts);
        self.function_contexts = Some(Arc::clone(&arc));
        arc
    }
    /// Read-only access to the memoized function contexts, if built.
    pub fn function_contexts(&self) -> Option<&Arc<FunctionContextMap>> {
        self.function_contexts.as_ref()
    }
    /// Builds (and memoizes) HMM-based context classifications for every
    /// function in the graph: loads/creates a classifier, extracts features,
    /// trains, optionally persists the model, and classifies each function.
    pub fn build_hmm_contexts(
        &mut self,
        graph: &dyn crate::graph::GraphQuery,
    ) -> Arc<HashMap<String, FunctionContext>> {
        // Memoized: return the cached classification map if present.
        if let Some(ref ctx) = self.hmm_contexts {
            return Arc::clone(ctx);
        }
        let cache_path = self.hmm_cache_path.clone();
        let mut classifier = if let Some(ref path) = cache_path {
            let model_path = path.join("hmm_model.json");
            if model_path.exists() {
                info!("Loading cached HMM+CRF model from {:?}", model_path);
                // `load` appears to return an Option (closure takes no error
                // arg); on failure fall back to a codebase-specific classifier.
                ContextClassifier::load(&model_path).unwrap_or_else(|| {
                    ContextClassifier::for_codebase(Some(&model_path))
                })
            } else {
                ContextClassifier::new()
            }
        } else {
            ContextClassifier::new()
        };
        info!("Building HMM function contexts from graph...");
        let mut functions = graph.get_functions();
        if functions.is_empty() {
            // Nothing to classify; memoize and return an empty map.
            let empty = Arc::new(HashMap::new());
            self.hmm_contexts = Some(Arc::clone(&empty));
            return empty;
        }
        // Cap the working set: keep the most-called functions so training
        // stays tractable on very large codebases.
        const MAX_FUNCTIONS_FOR_HMM: usize = 20_000;
        if functions.len() > MAX_FUNCTIONS_FOR_HMM {
            warn!(
                "Limiting HMM analysis to {} functions (codebase has {})",
                MAX_FUNCTIONS_FOR_HMM,
                functions.len()
            );
            // Sort by caller count, descending, then truncate.
            functions.sort_by(|a, b| {
                let a_callers = graph.get_callers(&a.qualified_name).len();
                let b_callers = graph.get_callers(&b.qualified_name).len();
                b_callers.cmp(&a_callers)
            });
            functions.truncate(MAX_FUNCTIONS_FOR_HMM);
        }
        // First pass: aggregate corpus-wide statistics used to normalize
        // per-function features. Maxima start at 1 to avoid divide-by-zero.
        let mut max_fan_in = 1usize;
        let mut max_fan_out = 1usize;
        let mut total_complexity = 0i64;
        let mut complexity_count = 0usize;
        let mut total_loc = 0u32;
        let mut total_params = 0usize;
        for func in &functions {
            let fan_in = graph.get_callers(&func.qualified_name).len();
            let fan_out = graph.get_callees(&func.qualified_name).len();
            max_fan_in = max_fan_in.max(fan_in);
            max_fan_out = max_fan_out.max(fan_out);
            if let Some(c) = func.complexity() {
                total_complexity += c;
                complexity_count += 1;
            }
            total_loc += func.line_end.saturating_sub(func.line_start) + 1;
            // NOTE(review): parameter counts are not read from the graph —
            // a placeholder of 3 per function is assumed here and again in
            // the feature extraction below. TODO: confirm or wire real counts.
            total_params += 3;
        }
        // 10.0 is the fallback average when no function reports complexity.
        let avg_complexity = if complexity_count > 0 {
            total_complexity as f64 / complexity_count as f64
        } else {
            10.0
        };
        let avg_loc = total_loc as f64 / functions.len().max(1) as f64;
        let avg_params = total_params as f64 / functions.len().max(1) as f64;
        // Second pass: build the per-function feature vectors for training.
        let mut function_data = Vec::new();
        for func in &functions {
            let callers = graph.get_callers(&func.qualified_name);
            let fan_in = callers.len();
            let fan_out = graph.get_callees(&func.qualified_name).len();
            // Number of distinct files that call this function.
            let caller_files: std::collections::HashSet<_> =
                callers.iter().map(|c| &c.file_path).collect();
            let loc = func.line_end.saturating_sub(func.line_start) + 1;
            let address_taken = func
                .properties
                .get("address_taken")
                .and_then(|v| v.as_bool())
                .unwrap_or(false);
            let features = FunctionFeatures::extract(
                &func.name,
                &func.file_path,
                fan_in,
                fan_out,
                max_fan_in,
                max_fan_out,
                caller_files.len(),
                func.complexity(),
                avg_complexity,
                loc,
                avg_loc,
                // Placeholder param count — must match `total_params += 3` above.
                3,
                avg_params,
                address_taken,
            );
            function_data.push((features, fan_in, fan_out, address_taken));
        }
        classifier.train(&function_data);
        // Best-effort persistence of the trained model; failures are logged,
        // never fatal.
        if let Some(ref path) = cache_path {
            if let Err(e) = std::fs::create_dir_all(path) {
                warn!("Failed to create HMM cache directory: {}", e);
            } else {
                let model_path = path.join("hmm_model.json");
                if let Err(e) = classifier.save(&model_path) {
                    warn!("Failed to save HMM model: {}", e);
                } else {
                    info!("Saved HMM model to {:?}", model_path);
                }
            }
        }
        // Classify each function; `functions` and `function_data` are index-
        // aligned (both built from the same iteration), so zip is safe.
        let mut contexts = HashMap::new();
        for (func, (features, _, _, _)) in functions.iter().zip(function_data.iter()) {
            let context = classifier.classify(&func.qualified_name, features);
            contexts.insert(func.qualified_name.clone(), context);
        }
        info!("Classified {} functions using HMM", contexts.len());
        // Distribution log; assumes exactly 5 context kinds and that
        // `ctx.index()` is always < 5 — panics otherwise.
        let mut counts = [0usize; 5];
        for ctx in contexts.values() {
            counts[ctx.index()] += 1;
        }
        info!(
            "Context distribution: Utility={}, Handler={}, Core={}, Internal={}, Test={}",
            counts[0], counts[1], counts[2], counts[3], counts[4]
        );
        let arc = Arc::new(contexts);
        self.hmm_contexts = Some(Arc::clone(&arc));
        arc
    }
    /// Read-only access to the memoized HMM classifications, if built.
    pub fn hmm_contexts(&self) -> Option<&Arc<HashMap<String, FunctionContext>>> {
        self.hmm_contexts.as_ref()
    }
    /// Looks up the HMM context for one function by qualified name.
    pub fn get_function_context(&self, qualified_name: &str) -> Option<FunctionContext> {
        self.hmm_contexts
            .as_ref()
            .and_then(|ctx| ctx.get(qualified_name).copied())
    }
    /// Registers a single detector.
    pub fn register(&mut self, detector: Arc<dyn Detector>) {
        debug!("Registering detector: {}", detector.name());
        self.detectors.push(detector);
    }
    /// Registers every detector yielded by the iterator.
    pub fn register_all(&mut self, detectors: impl IntoIterator<Item = Arc<dyn Detector>>) {
        for detector in detectors {
            self.register(detector);
        }
    }
    /// Number of registered detectors.
    pub fn detector_count(&self) -> usize {
        self.detectors.len()
    }
    /// Names of all registered detectors, in registration order.
    pub fn detector_names(&self) -> Vec<&'static str> {
        self.detectors.iter().map(|d| d.name()).collect()
    }
    /// Runs all registered detectors over `graph` and returns the merged,
    /// filtered, severity-sorted findings.
    ///
    /// Pipeline: build contexts → run independent detectors in parallel on a
    /// dedicated rayon pool → run dependent detectors serially → drop
    /// test-file findings (if enabled) → HMM false-positive filter → sort by
    /// severity (descending) → truncate to `max_findings`.
    ///
    /// # Errors
    /// Fails only if the rayon thread pool cannot be built; individual
    /// detector failures are logged and skipped.
    pub fn run(&mut self, graph: &dyn crate::graph::GraphQuery) -> Result<Vec<Finding>> {
        let start = Instant::now();
        info!(
            "Starting detection with {} detectors on {} workers",
            self.detectors.len(),
            self.workers
        );
        let contexts = self.get_or_build_contexts(graph);
        let hmm_contexts = self.build_hmm_contexts(graph);
        // Dependent detectors must run after (and outside) the parallel batch.
        let (independent, dependent): (Vec<_>, Vec<_>) = self
            .detectors
            .iter()
            .cloned()
            .partition(|d| !d.is_dependent());
        info!(
            "Detectors: {} independent, {} dependent",
            independent.len(),
            dependent.len()
        );
        // Shared progress counter across pool threads and the serial phase.
        let completed = Arc::new(AtomicUsize::new(0));
        let total = self.detectors.len();
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(self.workers)
            .build()?;
        let contexts_for_parallel = Arc::clone(&contexts);
        let independent_results: Vec<DetectorResult> = pool.install(|| {
            independent
                .par_iter()
                .map(|detector| {
                    let result = self.run_single_detector(detector, graph, &contexts_for_parallel);
                    let done = completed.fetch_add(1, Ordering::SeqCst) + 1;
                    if let Some(ref callback) = self.progress_callback {
                        callback(detector.name(), done, total);
                    }
                    result
                })
                .collect()
        });
        let mut all_findings: Vec<Finding> = Vec::new();
        let mut summary = DetectionSummary::default();
        for result in independent_results {
            summary.add_result(&result);
            if result.success {
                all_findings.extend(result.findings);
            } else if let Some(err) = &result.error {
                warn!("Detector {} failed: {}", result.detector_name, err);
            }
        }
        // Dependent detectors run serially, after the independent batch.
        for detector in dependent {
            let result = self.run_single_detector(&detector, graph, &contexts);
            let done = completed.fetch_add(1, Ordering::SeqCst) + 1;
            if let Some(ref callback) = self.progress_callback {
                callback(detector.name(), done, total);
            }
            summary.add_result(&result);
            if result.success {
                all_findings.extend(result.findings);
            } else if let Some(err) = &result.error {
                warn!("Detector {} failed: {}", result.detector_name, err);
            }
        }
        if self.skip_test_files {
            let before_count = all_findings.len();
            all_findings.retain(|finding| !self.is_test_file_finding(finding));
            let filtered = before_count - all_findings.len();
            if filtered > 0 {
                debug!("Filtered out {} findings from test files", filtered);
            }
        }
        let before_hmm = all_findings.len();
        all_findings = self.apply_hmm_context_filter(all_findings, &hmm_contexts, graph);
        let hmm_filtered = before_hmm - all_findings.len();
        if hmm_filtered > 0 {
            info!(
                "HMM context filter removed {} false positives",
                hmm_filtered
            );
        }
        // Highest severity first — relies on Severity's Ord ordering
        // (presumably ascending from low to high; reversed comparison here).
        all_findings.sort_by(|a, b| b.severity.cmp(&a.severity));
        if all_findings.len() > self.max_findings {
            warn!(
                "Truncating findings from {} to {} (max limit)",
                all_findings.len(),
                self.max_findings
            );
            all_findings.truncate(self.max_findings);
        }
        let duration = start.elapsed();
        info!(
            "Detection complete: {} findings from {}/{} detectors in {:?}",
            all_findings.len(),
            summary.detectors_succeeded,
            summary.detectors_run,
            duration
        );
        Ok(all_findings)
    }
    /// Like `run`, but returns per-detector results plus a summary instead of
    /// a merged findings list. Applies test-file filtering per result; does
    /// NOT apply the HMM filter, severity sort, or findings cap.
    ///
    /// # Errors
    /// Fails only if the rayon thread pool cannot be built.
    pub fn run_detailed(
        &mut self,
        graph: &dyn crate::graph::GraphQuery,
    ) -> Result<(Vec<DetectorResult>, DetectionSummary)> {
        let start = Instant::now();
        let contexts = self.get_or_build_contexts(graph);
        let (independent, dependent): (Vec<_>, Vec<_>) = self
            .detectors
            .iter()
            .cloned()
            .partition(|d| !d.is_dependent());
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(self.workers)
            .build()?;
        let contexts_for_parallel = Arc::clone(&contexts);
        let mut all_results: Vec<DetectorResult> = pool.install(|| {
            independent
                .par_iter()
                .map(|detector| self.run_single_detector(detector, graph, &contexts_for_parallel))
                .collect()
        });
        for detector in dependent {
            all_results.push(self.run_single_detector(&detector, graph, &contexts));
        }
        if self.skip_test_files {
            for result in &mut all_results {
                let before_count = result.findings.len();
                result
                    .findings
                    .retain(|finding| !self.is_test_file_finding(finding));
                let filtered = before_count - result.findings.len();
                if filtered > 0 {
                    debug!(
                        "Filtered {} test file findings from {}",
                        filtered, result.detector_name
                    );
                }
            }
        }
        let mut summary = DetectionSummary::default();
        for result in &all_results {
            summary.add_result(result);
        }
        summary.total_duration_ms = start.elapsed().as_millis() as u64;
        Ok((all_results, summary))
    }
    /// Drops findings that the HMM context says are likely false positives:
    /// coupling-style findings for functions whose context skips coupling,
    /// and dead-code findings for contexts that skip dead-code analysis.
    fn apply_hmm_context_filter(
        &self,
        mut findings: Vec<Finding>,
        hmm_contexts: &HashMap<String, FunctionContext>,
        graph: &dyn crate::graph::GraphQuery,
    ) -> Vec<Finding> {
        // Matched by substring against `finding.detector`.
        const COUPLING_DETECTORS: &[&str] = &[
            "DegreeCentralityDetector",
            "ShotgunSurgeryDetector",
            "FeatureEnvyDetector",
            "InappropriateIntimacyDetector",
        ];
        const DEAD_CODE_DETECTORS: &[&str] = &["UnreachableCodeDetector", "DeadCodeDetector"];
        findings.retain(|finding| {
            // NOTE(review): this resolves the function per finding via a full
            // graph scan (see extract_function_from_finding) — O(findings ×
            // functions). Consider a one-time file→functions index if hot.
            let func_name = self.extract_function_from_finding(finding, graph);
            if let Some(name) = func_name {
                if let Some(context) = hmm_contexts.get(&name) {
                    if COUPLING_DETECTORS
                        .iter()
                        .any(|d| finding.detector.contains(d))
                        && context.skip_coupling()
                    {
                        debug!(
                            "HMM filter: skipping coupling finding for {} (context: {:?})",
                            name, context
                        );
                        return false;
                    }
                    if DEAD_CODE_DETECTORS
                        .iter()
                        .any(|d| finding.detector.contains(d))
                        && context.skip_dead_code()
                    {
                        debug!(
                            "HMM filter: skipping dead code finding for {} (context: {:?})",
                            name, context
                        );
                        return false;
                    }
                }
            }
            // Findings with no resolvable function, or no matching rule, pass.
            true
        });
        findings
    }
    /// Best-effort resolution of the function a finding refers to.
    /// Tries (file, line) containment first, then falls back to parsing a
    /// "prefix: name" pattern out of the finding title. Returns the
    /// qualified name, or `None` if nothing matches.
    fn extract_function_from_finding(
        &self,
        finding: &Finding,
        graph: &dyn crate::graph::GraphQuery,
    ) -> Option<String> {
        if let (Some(file), Some(line)) = (finding.affected_files.first(), finding.line_start) {
            let file_str = file.to_string_lossy();
            for func in graph.get_functions() {
                if func.file_path == file_str && func.line_start <= line && func.line_end >= line {
                    return Some(func.qualified_name.clone());
                }
            }
        }
        if finding.title.contains(':') {
            let parts: Vec<&str> = finding.title.splitn(2, ':').collect();
            if parts.len() == 2 {
                let name = parts[1].trim();
                for func in graph.get_functions() {
                    // suffix match tolerates module-qualified names
                    if func.name == name || func.qualified_name.ends_with(name) {
                        return Some(func.qualified_name.clone());
                    }
                }
            }
        }
        None
    }
    /// True only when the finding has files and ALL of them are test files;
    /// findings with no affected files are kept.
    fn is_test_file_finding(&self, finding: &Finding) -> bool {
        if finding.affected_files.is_empty() {
            return false;
        }
        finding.affected_files.iter().all(|path| is_test_file(path))
    }
    /// Runs one detector with panic isolation and per-detector timing.
    /// Panics are caught via `catch_unwind` (AssertUnwindSafe: state observed
    /// across the boundary is read-only here) and converted into failure
    /// results so one bad detector cannot abort the whole run. Applies the
    /// detector's own `max_findings` config cap on success.
    fn run_single_detector(
        &self,
        detector: &Arc<dyn Detector>,
        graph: &dyn crate::graph::GraphQuery,
        contexts: &Arc<FunctionContextMap>,
    ) -> DetectorResult {
        let name = detector.name().to_string();
        let start = Instant::now();
        debug!("Running detector: {}", name);
        let contexts_clone = Arc::clone(contexts);
        let detect_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            if detector.uses_context() {
                detector.detect_with_context(graph, &contexts_clone)
            } else {
                detector.detect(graph)
            }
        }));
        match detect_result {
            Ok(Ok(mut findings)) => {
                let duration = start.elapsed().as_millis() as u64;
                // Per-detector cap from its own configuration, if any.
                if let Some(config) = detector.config() {
                    if let Some(max) = config.max_findings {
                        if findings.len() > max {
                            findings.truncate(max);
                        }
                    }
                }
                debug!(
                    "Detector {} found {} findings in {}ms",
                    name,
                    findings.len(),
                    duration
                );
                DetectorResult::success(name, findings, duration)
            }
            Ok(Err(e)) => {
                let duration = start.elapsed().as_millis() as u64;
                debug!("Detector {} skipped (query error): {}", name, e);
                DetectorResult::failure(name, e.to_string(), duration)
            }
            Err(panic_info) => {
                let duration = start.elapsed().as_millis() as u64;
                // Panic payloads are usually &str or String; anything else is opaque.
                let panic_msg = if let Some(s) = panic_info.downcast_ref::<&str>() {
                    s.to_string()
                } else if let Some(s) = panic_info.downcast_ref::<String>() {
                    s.clone()
                } else {
                    "Unknown panic".to_string()
                };
                error!("Detector {} panicked: {}", name, panic_msg);
                DetectorResult::failure(name, format!("Panic: {}", panic_msg), duration)
            }
        }
    }
}
impl Default for DetectorEngine {
fn default() -> Self {
Self::new(0)
}
}
/// Fluent builder for [`DetectorEngine`]; `build` assembles the engine via
/// the engine's own `with_*` constructors.
pub struct DetectorEngineBuilder {
    /// Worker count passed to `DetectorEngine::new` (0 = auto-detect).
    workers: usize,
    /// Findings cap, defaults to `MAX_FINDINGS_LIMIT`.
    max_findings: usize,
    /// Detectors to register on the built engine, in order.
    detectors: Vec<Arc<dyn Detector>>,
    /// Optional progress hook forwarded to the engine.
    progress_callback: Option<ProgressCallback>,
    /// Whether the engine should drop test-file findings (default: true).
    skip_test_files: bool,
}
impl DetectorEngineBuilder {
    /// Starts a builder with auto-detected workers, the default findings
    /// cap, no detectors, no progress hook, and test-file filtering enabled.
    pub fn new() -> Self {
        Self {
            workers: 0,
            max_findings: MAX_FINDINGS_LIMIT,
            detectors: Vec::new(),
            progress_callback: None,
            skip_test_files: true,
        }
    }
    /// Sets the worker-thread count (0 = auto-detect).
    pub fn workers(mut self, count: usize) -> Self {
        self.workers = count;
        self
    }
    /// Sets the maximum number of findings the engine will return.
    pub fn max_findings(mut self, limit: usize) -> Self {
        self.max_findings = limit;
        self
    }
    /// Queues one detector for registration.
    pub fn detector(mut self, d: Arc<dyn Detector>) -> Self {
        self.detectors.push(d);
        self
    }
    /// Queues every detector yielded by the iterator.
    pub fn detectors(mut self, ds: impl IntoIterator<Item = Arc<dyn Detector>>) -> Self {
        self.detectors.extend(ds);
        self
    }
    /// Installs a progress hook on the built engine.
    pub fn on_progress(mut self, cb: ProgressCallback) -> Self {
        self.progress_callback = Some(cb);
        self
    }
    /// Enables/disables test-file filtering on the built engine.
    pub fn skip_test_files(mut self, enabled: bool) -> Self {
        self.skip_test_files = enabled;
        self
    }
    /// Consumes the builder and produces a configured [`DetectorEngine`].
    pub fn build(self) -> DetectorEngine {
        let Self {
            workers,
            max_findings,
            detectors,
            progress_callback,
            skip_test_files,
        } = self;
        let mut engine = DetectorEngine::new(workers)
            .with_max_findings(max_findings)
            .with_skip_test_files(skip_test_files);
        if let Some(cb) = progress_callback {
            engine = engine.with_progress_callback(cb);
        }
        engine.register_all(detectors);
        engine
    }
}
impl Default for DetectorEngineBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Severity;
    use std::path::PathBuf;
    /// Configurable stub detector: emits `findings_count` synthetic findings
    /// and reports itself dependent/independent per `dependent`.
    struct MockDetector {
        name: &'static str,
        findings_count: usize,
        dependent: bool,
    }
    impl Detector for MockDetector {
        fn name(&self) -> &'static str {
            self.name
        }
        fn description(&self) -> &'static str {
            "Mock detector for testing"
        }
        fn detect(&self, _graph: &dyn crate::graph::GraphQuery) -> Result<Vec<Finding>> {
            // Generate deterministic findings; unspecified Finding fields
            // come from ..Default::default().
            Ok((0..self.findings_count)
                .map(|i| Finding {
                    id: format!("{}-{}", self.name, i),
                    detector: self.name.to_string(),
                    severity: Severity::Medium,
                    title: format!("Finding {}", i),
                    description: "Test finding".to_string(),
                    affected_files: vec![PathBuf::from("test.py")],
                    line_start: Some(1),
                    line_end: Some(10),
                    suggested_fix: None,
                    estimated_effort: None,
                    category: None,
                    cwe_id: None,
                    why_it_matters: None,
                    ..Default::default()
                })
                .collect())
        }
        fn is_dependent(&self) -> bool {
            self.dependent
        }
    }
    /// Explicit worker count is stored verbatim; new engines have no detectors.
    #[test]
    fn test_engine_creation() {
        let engine = DetectorEngine::new(4);
        assert_eq!(engine.workers, 4);
        assert_eq!(engine.detector_count(), 0);
    }
    /// workers == 0 auto-detects; result must be in 1..=16 per `new`'s cap.
    #[test]
    fn test_engine_default_workers() {
        let engine = DetectorEngine::new(0);
        assert!(engine.workers > 0);
        assert!(engine.workers <= 16);
    }
    /// Registration preserves count and insertion order of detector names.
    #[test]
    fn test_register_detectors() {
        let mut engine = DetectorEngine::new(2);
        engine.register(Arc::new(MockDetector {
            name: "Detector1",
            findings_count: 5,
            dependent: false,
        }));
        engine.register(Arc::new(MockDetector {
            name: "Detector2",
            findings_count: 3,
            dependent: true,
        }));
        assert_eq!(engine.detector_count(), 2);
        assert_eq!(engine.detector_names(), vec!["Detector1", "Detector2"]);
    }
    /// Builder forwards workers, max_findings, and detectors to the engine.
    #[test]
    fn test_builder() {
        let engine = DetectorEngineBuilder::new()
            .workers(4)
            .max_findings(100)
            .detector(Arc::new(MockDetector {
                name: "Test",
                findings_count: 1,
                dependent: false,
            }))
            .build();
        assert_eq!(engine.workers, 4);
        assert_eq!(engine.max_findings, 100);
        assert_eq!(engine.detector_count(), 1);
    }
}