use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Coarse grouping of detectors, used as the key for per-category thresholds
/// (see `CategoryThresholds`).
///
/// NOTE: with the derived serde impls, variant names appear in serialized threshold
/// files, so renaming a variant would break previously saved files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DetectorCategory {
/// Security detectors (injection, XSS, path traversal, secrets, auth, eval, ...).
Security,
/// Maintainability / code-smell detectors (complexity, coupling, dead code, ...).
CodeQuality,
/// ML-framework detectors (torch, tensorflow, keras, gradients, seeds, ...).
MachineLearning,
/// Performance detectors (N+1 queries, lazy loading, caching, bottlenecks, ...).
Performance,
/// Detectors whose name matched none of the other categories.
Other,
}
impl DetectorCategory {
pub fn from_detector(detector: &str) -> Self {
let lower = detector.to_lowercase();
if lower.contains("injection")
|| lower.contains("xss")
|| lower.contains("traversal")
|| lower.contains("crypto")
|| lower.contains("credential")
|| lower.contains("secret")
|| lower.contains("auth")
|| lower.contains("csrf")
|| lower.contains("ssrf")
|| lower.contains("xxe")
|| lower.contains("deserializ")
|| lower.contains("eval")
{
return Self::Security;
}
if lower.contains("torch")
|| lower.contains("tensorflow")
|| lower.contains("keras")
|| lower.contains("pytorch")
|| lower.contains("grad")
|| lower.contains("nan")
|| lower.contains("forward")
|| lower.contains("seed")
|| lower.contains("chain_index")
|| lower.contains("deprecated")
{
return Self::MachineLearning;
}
if lower.contains("n+1")
|| lower.contains("nplus")
|| lower.contains("lazy")
|| lower.contains("cache")
|| lower.contains("bottleneck")
|| lower.contains("performance")
{
return Self::Performance;
}
if lower.contains("complexity")
|| lower.contains("coupling")
|| lower.contains("dead")
|| lower.contains("unreachable")
|| lower.contains("god")
|| lower.contains("long")
|| lower.contains("envy")
|| lower.contains("intimacy")
|| lower.contains("duplicate")
|| lower.contains("magic")
|| lower.contains("inconsistent")
|| lower.contains("centrality")
|| lower.contains("cohesion")
{
return Self::CodeQuality;
}
Self::Other
}
}
/// Per-category thresholds for turning a true-positive probability into
/// filter / confidence decisions.
///
/// A lookup for a category with no entry in `configs` falls back to `default`.
/// Persisted as JSON via `save` / `load`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryThresholds {
/// Explicit thresholds keyed by category.
configs: HashMap<DetectorCategory, ThresholdConfig>,
/// Fallback used when `configs` has no entry for the requested category.
default: ThresholdConfig,
}
/// Decision thresholds applied to a detector's true-positive probability.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThresholdConfig {
/// Probabilities strictly below this are filtered out. Probabilities at or above it
/// count as true positives.
pub filter_threshold: f32,
/// Probabilities at or above this are reported as high confidence.
pub high_confidence_threshold: f32,
/// Probabilities strictly below this are flagged as likely false positives.
pub likely_fp_threshold: f32,
/// NOTE(review): not read by any code in this file. Presumably it scales feature weights
/// in the scoring model for this category; confirm with the consumer.
pub feature_weight_multiplier: f32,
}
impl Default for ThresholdConfig {
fn default() -> Self {
Self {
filter_threshold: 0.5,
high_confidence_threshold: 0.8,
likely_fp_threshold: 0.3,
feature_weight_multiplier: 1.0,
}
}
}
impl Default for CategoryThresholds {
fn default() -> Self {
let mut configs = HashMap::new();
configs.insert(
DetectorCategory::Security,
ThresholdConfig {
filter_threshold: 0.35, high_confidence_threshold: 0.85,
likely_fp_threshold: 0.2, feature_weight_multiplier: 1.2, },
);
configs.insert(
DetectorCategory::CodeQuality,
ThresholdConfig {
filter_threshold: 0.52, high_confidence_threshold: 0.75,
likely_fp_threshold: 0.45, feature_weight_multiplier: 1.0,
},
);
configs.insert(
DetectorCategory::MachineLearning,
ThresholdConfig {
filter_threshold: 0.45,
high_confidence_threshold: 0.8,
likely_fp_threshold: 0.35,
feature_weight_multiplier: 1.1,
},
);
configs.insert(
DetectorCategory::Performance,
ThresholdConfig {
filter_threshold: 0.52, high_confidence_threshold: 0.75,
likely_fp_threshold: 0.40,
feature_weight_multiplier: 1.0,
},
);
configs.insert(DetectorCategory::Other, ThresholdConfig::default());
Self {
configs,
default: ThresholdConfig::default(),
}
}
}
impl CategoryThresholds {
pub fn new() -> Self {
Self::default()
}
pub fn get(&self, detector: &str) -> &ThresholdConfig {
let category = DetectorCategory::from_detector(detector);
self.configs.get(&category).unwrap_or(&self.default)
}
pub fn get_category(&self, category: DetectorCategory) -> &ThresholdConfig {
self.configs.get(&category).unwrap_or(&self.default)
}
pub fn set(&mut self, category: DetectorCategory, config: ThresholdConfig) {
self.configs.insert(category, config);
}
pub fn load(path: &std::path::Path) -> Result<Self, std::io::Error> {
let content = std::fs::read_to_string(path)?;
serde_json::from_str(&content)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
}
pub fn save(&self, path: &std::path::Path) -> Result<(), std::io::Error> {
let content = serde_json::to_string_pretty(self)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
std::fs::write(path, content)
}
pub fn should_filter(&self, detector: &str, tp_probability: f32) -> bool {
let config = self.get(detector);
tp_probability < config.filter_threshold
}
pub fn is_high_confidence(&self, detector: &str, tp_probability: f32) -> bool {
let config = self.get(detector);
tp_probability >= config.high_confidence_threshold
}
pub fn is_likely_fp(&self, detector: &str, tp_probability: f32) -> bool {
let config = self.get(detector);
tp_probability < config.likely_fp_threshold
}
}
/// A true-positive probability annotated with the detector's category and that
/// category's threshold decisions. Built by `from_prediction`.
#[derive(Debug, Clone)]
pub struct CategoryAwarePrediction {
/// Probability that the finding is a true positive, as passed to `from_prediction`.
pub tp_probability: f32,
/// Complement of `tp_probability` (`1.0 - tp_probability`).
pub fp_probability: f32,
/// Category derived from the detector name.
pub category: DetectorCategory,
/// True-positive verdict from the category's `filter_threshold`.
pub is_true_positive: bool,
/// `tp_probability` is at or above the category's `high_confidence_threshold`.
pub high_confidence: bool,
/// `tp_probability` is below the category's `likely_fp_threshold`.
pub likely_fp: bool,
/// Whether the finding should be dropped, per the category's `filter_threshold`.
pub should_filter: bool,
}
impl CategoryAwarePrediction {
    /// Builds a prediction record for `detector`. The detector is classified into its
    /// category, and that category's thresholds from `thresholds` are applied.
    ///
    /// `fp_probability` is `1.0 - tp_probability`. The filter decision is computed once, and
    /// `is_true_positive` is its negation, so the two flags can never disagree. This holds
    /// even for a NaN `tp_probability`: NaN is not filtered, which matches
    /// `CategoryThresholds::should_filter`.
    pub fn from_prediction(
        tp_probability: f32,
        detector: &str,
        thresholds: &CategoryThresholds,
    ) -> Self {
        let category = DetectorCategory::from_detector(detector);
        let config = thresholds.get_category(category);
        // Use the same single comparison as `CategoryThresholds::should_filter`. Deriving
        // `is_true_positive` from it keeps the pair complementary. Two independent
        // comparisons (`<` and `>=`) would both be false for NaN.
        let should_filter = tp_probability < config.filter_threshold;
        Self {
            tp_probability,
            fp_probability: 1.0 - tp_probability,
            category,
            is_true_positive: !should_filter,
            high_confidence: tp_probability >= config.high_confidence_threshold,
            likely_fp: tp_probability < config.likely_fp_threshold,
            should_filter,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// All categories, used to check round-trips exhaustively.
    const ALL_CATEGORIES: [DetectorCategory; 5] = [
        DetectorCategory::Security,
        DetectorCategory::CodeQuality,
        DetectorCategory::MachineLearning,
        DetectorCategory::Performance,
        DetectorCategory::Other,
    ];

    #[test]
    fn test_detector_categorization() {
        assert_eq!(
            DetectorCategory::from_detector("SQLInjectionDetector"),
            DetectorCategory::Security
        );
        assert_eq!(
            DetectorCategory::from_detector("CommandInjectionDetector"),
            DetectorCategory::Security
        );
        assert_eq!(
            DetectorCategory::from_detector("TorchLoadUnsafeDetector"),
            DetectorCategory::MachineLearning
        );
        assert_eq!(
            DetectorCategory::from_detector("MissingZeroGradDetector"),
            DetectorCategory::MachineLearning
        );
        assert_eq!(
            DetectorCategory::from_detector("ComplexitySpike"),
            DetectorCategory::CodeQuality
        );
        assert_eq!(
            DetectorCategory::from_detector("NPlusOneDetector"),
            DetectorCategory::Performance
        );
        assert_eq!(
            DetectorCategory::from_detector("SomeRandomDetector"),
            DetectorCategory::Other
        );
    }

    #[test]
    fn test_category_thresholds() {
        let thresholds = CategoryThresholds::default();
        let security = thresholds.get("SQLInjectionDetector");
        let quality = thresholds.get("ComplexitySpike");
        assert!(security.filter_threshold < quality.filter_threshold);
    }

    #[test]
    fn test_filtering_decisions() {
        let thresholds = CategoryThresholds::default();
        assert!(!thresholds.should_filter("SQLInjectionDetector", 0.40));
        assert!(thresholds.should_filter("ComplexitySpike", 0.40));
        assert!(!thresholds.should_filter("SQLInjectionDetector", 0.60));
        assert!(!thresholds.should_filter("ComplexitySpike", 0.60));
    }

    #[test]
    fn test_threshold_boundaries() {
        let thresholds = CategoryThresholds::default();
        // A probability exactly at a threshold is kept / counts as high confidence.
        assert!(!thresholds.should_filter("SQLInjectionDetector", 0.35));
        assert!(thresholds.is_high_confidence("SQLInjectionDetector", 0.85));
        assert!(!thresholds.is_likely_fp("SQLInjectionDetector", 0.2));
        assert!(thresholds.is_likely_fp("SQLInjectionDetector", 0.19));
    }

    #[test]
    fn test_set_overrides_category() {
        let mut thresholds = CategoryThresholds::new();
        thresholds.set(
            DetectorCategory::Security,
            ThresholdConfig {
                filter_threshold: 0.9,
                ..ThresholdConfig::default()
            },
        );
        assert!(thresholds.should_filter("SQLInjectionDetector", 0.6));
        // Other categories are unaffected by the override.
        assert!(!thresholds.should_filter("ComplexitySpike", 0.6));
    }

    #[test]
    fn test_category_aware_prediction() {
        let thresholds = CategoryThresholds::default();
        let pred =
            CategoryAwarePrediction::from_prediction(0.40, "SQLInjectionDetector", &thresholds);
        assert!(!pred.should_filter);
        assert!(pred.is_true_positive);
        assert!(!pred.high_confidence);
        let pred = CategoryAwarePrediction::from_prediction(0.40, "ComplexitySpike", &thresholds);
        assert!(pred.should_filter);
        assert!(!pred.is_true_positive);
    }

    #[test]
    fn test_save_load() {
        // The platform temp dir works off Unix. The pid suffix stops concurrent test runs
        // from clobbering each other's file.
        let path = std::env::temp_dir().join(format!(
            "category_thresholds_test_{}.json",
            std::process::id()
        ));
        let thresholds = CategoryThresholds::default();
        let saved = thresholds.save(&path);
        let loaded = CategoryThresholds::load(&path);
        // Clean up before asserting so a failing assertion doesn't leave the file behind.
        std::fs::remove_file(&path).ok();

        saved.expect("saving thresholds should succeed");
        let loaded = loaded.expect("loading just-saved thresholds should succeed");
        for category in ALL_CATEGORIES {
            let orig = thresholds.get_category(category);
            let load = loaded.get_category(category);
            assert_eq!(orig.filter_threshold, load.filter_threshold, "{:?}", category);
            assert_eq!(
                orig.high_confidence_threshold, load.high_confidence_threshold,
                "{:?}",
                category
            );
            assert_eq!(orig.likely_fp_threshold, load.likely_fp_threshold, "{:?}", category);
            assert_eq!(
                orig.feature_weight_multiplier, load.feature_weight_multiplier,
                "{:?}",
                category
            );
        }
    }
}