use std::collections::HashMap;
use tracing::{debug, info};
use crate::features::{self, ProbeContext};
use crate::scorer::ModelScorer;
use crate::types::Vulnerability;
pub struct MlEnhancer {
scorer: ModelScorer,
min_confidence_threshold: f64,
}
impl MlEnhancer {
pub fn new(scorer: ModelScorer) -> Self {
Self {
scorer,
min_confidence_threshold: 0.15, }
}
pub fn score_finding(
&self,
vuln: &Vulnerability,
probe_context: Option<&ProbeContext>,
) -> Option<(f64, HashMap<String, f64>)> {
let mut extracted_features = HashMap::new();
if let Some(ctx) = probe_context {
extracted_features = features::extract_features(ctx);
}
self.enrich_features_from_vuln(&mut extracted_features, vuln);
if extracted_features.is_empty() {
return None;
}
let result = self.scorer.score_detailed(&extracted_features);
Some((result.confidence, extracted_features))
}
pub fn enhance_findings(&self, vulns: &mut Vec<Vulnerability>, filter_fps: bool) {
let original_count = vulns.len();
let chain_features = features::extract_chain_features(vulns);
let mut enhanced = Vec::with_capacity(vulns.len());
for vuln in vulns.drain(..) {
let mut extracted_features = HashMap::new();
self.enrich_features_from_vuln(&mut extracted_features, &vuln);
extracted_features.extend(chain_features.clone());
if extracted_features.is_empty() {
enhanced.push(vuln);
continue;
}
let result = self.scorer.score_detailed(&extracted_features);
let ml_confidence = result.confidence;
let mut v = vuln;
v.ml_confidence = Some(ml_confidence);
if filter_fps && ml_confidence < self.min_confidence_threshold {
debug!(
"[ML] Filtered FP: {} on {} (ml_conf={:.3})",
v.vuln_type, v.url, ml_confidence
);
continue; }
enhanced.push(v);
}
let filtered = original_count - enhanced.len();
if filtered > 0 {
info!(
"[ML] Enhanced {} findings, filtered {} likely FPs",
enhanced.len(),
filtered
);
}
*vulns = enhanced;
}
fn enrich_features_from_vuln(
&self,
extracted_features: &mut HashMap<String, f64>,
vuln: &Vulnerability,
) {
let vuln_type_lower = vuln.vuln_type.to_lowercase();
let category = if vuln_type_lower.contains("sql") || vuln_type_lower.contains("sqli") {
"sqli"
} else if vuln_type_lower.contains("xss")
|| vuln_type_lower.contains("cross-site scripting")
{
"xss"
} else if vuln_type_lower.contains("ssrf") {
"ssrf"
} else if vuln_type_lower.contains("traversal") || vuln_type_lower.contains("lfi") {
"traversal"
} else if vuln_type_lower.contains("command")
|| vuln_type_lower.contains("cmdi")
|| vuln_type_lower.contains("rce")
{
"cmdi"
} else if vuln_type_lower.contains("ssti") || vuln_type_lower.contains("template") {
"ssti"
} else if vuln_type_lower.contains("auth") || vuln_type_lower.contains("jwt") {
"auth"
} else if vuln_type_lower.contains("idor") {
"idor"
} else if vuln_type_lower.contains("xxe") {
"xxe"
} else if vuln_type_lower.contains("redirect") {
"redirect"
} else if vuln_type_lower.contains("csrf") {
"csrf"
} else if vuln_type_lower.contains("deser") {
"deser"
} else if vuln_type_lower.contains("upload") {
"upload"
} else if vuln_type_lower.contains("graphql") {
"graphql"
} else if vuln_type_lower.contains("nosql") {
"nosqli"
} else if vuln_type_lower.contains("cors") || vuln_type_lower.contains("config") {
"config"
} else if vuln_type_lower.contains("cloud")
|| vuln_type_lower.contains("s3")
|| vuln_type_lower.contains("aws")
{
"cloud"
} else if vuln_type_lower.contains("api") || vuln_type_lower.contains("bola") {
"api"
} else {
"signal"
};
let url_lower = vuln.url.to_lowercase();
if url_lower.contains("/login")
|| url_lower.contains("/auth")
|| url_lower.contains("/signin")
{
extracted_features.insert("severity:endpoint_is_login".into(), 1.0);
}
if url_lower.contains("/admin") || url_lower.contains("/dashboard") {
extracted_features.insert("severity:endpoint_is_admin".into(), 1.0);
}
if url_lower.contains("/api/")
|| url_lower.contains("/v1/")
|| url_lower.contains("/v2/")
{
extracted_features.insert("severity:endpoint_is_api".into(), 1.0);
}
if url_lower.contains("/pay")
|| url_lower.contains("/checkout")
|| url_lower.contains("/billing")
{
extracted_features.insert("severity:endpoint_is_payment".into(), 1.0);
}
if url_lower.contains("/health")
|| url_lower.contains("/status")
|| url_lower.contains("/ping")
{
extracted_features.insert("api:endpoint_is_healthcheck".into(), 1.0);
}
if let Some(ref evidence) = vuln.evidence {
let evidence_lower = evidence.to_lowercase();
if evidence_lower.contains("reflected") || evidence_lower.contains("echo") {
extracted_features.insert("signal:input_reflected_anywhere".into(), 1.0);
}
if evidence_lower.contains("error") || evidence_lower.contains("exception") || evidence_lower.contains("traceback") {
extracted_features.insert("signal:error_triggered".into(), 1.0);
}
if evidence_lower.contains("time") || evidence_lower.contains("delay") || evidence_lower.contains("sleep") {
extracted_features.insert("signal:response_time_anomaly".into(), 1.0);
}
if evidence_lower.contains("accesskeyid") || evidence_lower.contains("secretaccesskey") {
extracted_features.insert("cloud:iam_role_credentials_leaked".into(), 1.0);
}
if evidence_lower.contains("169.254.169.254") || evidence_lower.contains("metadata") {
extracted_features.insert("cloud:imds_v1_accessible".into(), 1.0);
}
if evidence_lower.contains("listbucketresult") {
extracted_features.insert("cloud:s3_bucket_listing".into(), 1.0);
}
}
let conf = match vuln.confidence {
crate::types::Confidence::High => 0.95,
crate::types::Confidence::Medium => 0.75,
crate::types::Confidence::Low => 0.4,
};
if conf > 0.9 {
match category {
"sqli" => {
extracted_features.insert("sqli:union_select_reflected".into(), 1.0);
}
"xss" => {
extracted_features.insert("xss:script_tag_reflected".into(), 1.0);
}
"ssrf" => {
extracted_features.insert("ssrf:cloud_metadata_accessed".into(), 1.0);
}
"cmdi" => {
extracted_features.insert("cmdi:os_command_output".into(), 1.0);
}
"ssti" => {
extracted_features.insert("ssti:rce_via_template".into(), 1.0);
}
"traversal" => {
extracted_features.insert("traversal:etc_passwd_content".into(), 1.0);
}
_ => {}
}
} else if conf > 0.7 {
match category {
"sqli" => {
extracted_features.insert("sqli:error_generic_db".into(), 1.0);
}
"xss" => {
extracted_features.insert("xss:reflection_unencoded".into(), 1.0);
}
"ssrf" => {
extracted_features.insert("ssrf:internal_ip_in_response".into(), 1.0);
}
"cmdi" => {
extracted_features.insert("cmdi:time_delay_via_sleep".into(), 1.0);
}
_ => {}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{Confidence, Severity};
fn make_scorer() -> ModelScorer {
ModelScorer {
weights: [
("sqli:error_mysql_syntax".to_string(), 2.0),
("sqli:union_select_reflected".to_string(), 3.0),
("signal:error_triggered".to_string(), 0.5),
("severity:endpoint_is_api".to_string(), 0.3),
("xss:reflection_unencoded".to_string(), 1.5),
("xss:script_tag_reflected".to_string(), 2.5),
]
.into_iter()
.collect(),
bias: -0.35,
}
}
fn make_vuln(vuln_type: &str, url: &str, confidence: Confidence) -> Vulnerability {
Vulnerability {
id: "test-1".to_string(),
vuln_type: vuln_type.to_string(),
severity: Severity::High,
confidence,
category: "Injection".to_string(),
url: url.to_string(),
parameter: Some("id".to_string()),
payload: "' OR 1=1--".to_string(),
description: "Test vulnerability".to_string(),
evidence: Some("SQL error in response".to_string()),
cwe: "CWE-89".to_string(),
cvss: 8.5,
verified: false,
false_positive: false,
remediation: "Use parameterized queries".to_string(),
discovered_at: "2026-01-01T00:00:00Z".to_string(),
ml_confidence: None,
ml_data: None,
}
}
#[test]
fn test_score_finding_from_vuln_metadata() {
let enhancer = MlEnhancer::new(make_scorer());
let vuln = make_vuln(
"SQL Injection",
"https://example.com/api/users?id=1",
Confidence::High,
);
let result = enhancer.score_finding(&vuln, None);
assert!(result.is_some());
let (confidence, features) = result.unwrap();
assert!(confidence > 0.5); assert!(features.contains_key("sqli:union_select_reflected")); assert!(features.contains_key("severity:endpoint_is_api"));
assert!(features.contains_key("signal:error_triggered")); }
#[test]
fn test_enhance_findings_keeps_high_confidence() {
let enhancer = MlEnhancer::new(make_scorer());
let mut vulns = vec![make_vuln(
"SQL Injection",
"https://example.com/api/v1/users",
Confidence::High,
)];
enhancer.enhance_findings(&mut vulns, true);
assert_eq!(vulns.len(), 1);
assert!(vulns[0].ml_confidence.is_some());
assert!(vulns[0].ml_confidence.unwrap() > 0.15);
}
#[test]
fn test_enhance_findings_sets_ml_confidence() {
let enhancer = MlEnhancer::new(make_scorer());
let mut vulns = vec![make_vuln(
"Cross-Site Scripting (XSS)",
"https://example.com/search",
Confidence::Medium,
)];
enhancer.enhance_findings(&mut vulns, false);
assert_eq!(vulns.len(), 1);
assert!(vulns[0].ml_confidence.is_some());
}
#[test]
fn test_enrich_endpoint_features() {
let enhancer = MlEnhancer::new(make_scorer());
let vuln = make_vuln(
"SQL Injection",
"https://example.com/admin/dashboard",
Confidence::Medium,
);
let (_, features) = enhancer.score_finding(&vuln, None).unwrap();
assert!(features.contains_key("severity:endpoint_is_admin"));
}
#[test]
fn test_enrich_payment_endpoint() {
let enhancer = MlEnhancer::new(make_scorer());
let vuln = make_vuln(
"SQL Injection",
"https://shop.example.com/checkout",
Confidence::High,
);
let (_, features) = enhancer.score_finding(&vuln, None).unwrap();
assert!(features.contains_key("severity:endpoint_is_payment"));
}
}