use super::auto_learning::{AutoLearner, AutoVerification};
use super::federated::FederatedClient;
use super::fp_classifier::FalsePositiveClassifier;
use super::privacy::PrivacyManager;
use super::training_data::VerificationStatus;
use crate::http_client::HttpResponse;
use crate::types::Vulnerability;
use anyhow::Result;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
/// Coordinates the ML components used to triage scan findings and keeps
/// per-session counters of the outcomes it has seen.
pub struct MlPipeline {
    /// Whether findings are processed at all. Seeded from the privacy
    /// manager's `is_ml_allowed` and toggled by `enable` / `disable`.
    enabled: bool,

    /// Learns from individual findings and produces auto-verifications.
    auto_learner: Arc<RwLock<AutoLearner>>,
    /// Scores extracted features with a true-positive probability.
    fp_classifier: Arc<RwLock<FalsePositiveClassifier>>,
    /// Fetches global models and contributes local weights.
    federated_client: Arc<RwLock<FederatedClient>>,
    /// Records and withdraws the user's ML / federated consent.
    privacy_manager: Arc<RwLock<PrivacyManager>>,

    /// Findings processed since the last `on_scan_complete`.
    findings_processed: usize,
    /// Findings auto-confirmed since the last `on_scan_complete`.
    auto_confirmed: usize,
    /// Findings auto-rejected as false positives since the last `on_scan_complete`.
    auto_rejected: usize,
}
impl MlPipeline {
pub fn new() -> Result<Self> {
let privacy_manager = PrivacyManager::new()?;
let enabled = privacy_manager.is_ml_allowed();
Ok(Self {
auto_learner: Arc::new(RwLock::new(AutoLearner::new()?)),
fp_classifier: Arc::new(RwLock::new(FalsePositiveClassifier::new()?)),
federated_client: Arc::new(RwLock::new(FederatedClient::new()?)),
privacy_manager: Arc::new(RwLock::new(privacy_manager)),
enabled,
findings_processed: 0,
auto_confirmed: 0,
auto_rejected: 0,
})
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
pub async fn process_features(
&mut self,
vuln: &Vulnerability,
features: &super::VulnFeatures,
) -> Result<bool> {
if !self.enabled {
return Ok(false);
}
let mut learner = self.auto_learner.write().await;
let verification = learner.learn_from_features(vuln, features)?;
self.findings_processed += 1;
match verification.status {
VerificationStatus::Confirmed => {
self.auto_confirmed += 1;
debug!(
"ML: Auto-confirmed {} at {} (confidence: {:.0}%)",
vuln.vuln_type,
vuln.url,
verification.confidence * 100.0
);
}
VerificationStatus::FalsePositive => {
self.auto_rejected += 1;
debug!(
"ML: Auto-rejected {} at {} as FP (confidence: {:.0}%)",
vuln.vuln_type,
vuln.url,
verification.confidence * 100.0
);
}
VerificationStatus::Unverified => {
debug!(
"ML: {} at {} needs more data (confidence: {:.0}%)",
vuln.vuln_type,
vuln.url,
verification.confidence * 100.0
);
}
}
Ok(true)
}
pub async fn process_finding(
&mut self,
vuln: &Vulnerability,
response: &HttpResponse,
baseline: Option<&HttpResponse>,
payload: Option<&str>,
) -> Result<Option<AutoVerification>> {
if !self.enabled {
return Ok(None);
}
let mut learner = self.auto_learner.write().await;
let verification = learner.learn_from_finding(vuln, response, baseline, payload)?;
self.findings_processed += 1;
match verification.status {
VerificationStatus::Confirmed => {
self.auto_confirmed += 1;
debug!(
"ML: Auto-confirmed {} at {} (confidence: {:.0}%)",
vuln.vuln_type,
vuln.url,
verification.confidence * 100.0
);
}
VerificationStatus::FalsePositive => {
self.auto_rejected += 1;
debug!(
"ML: Auto-rejected {} at {} as FP (confidence: {:.0}%)",
vuln.vuln_type,
vuln.url,
verification.confidence * 100.0
);
}
VerificationStatus::Unverified => {
debug!(
"ML: {} at {} needs more data (confidence: {:.0}%)",
vuln.vuln_type,
vuln.url,
verification.confidence * 100.0
);
}
}
Ok(Some(verification))
}
pub async fn process_findings_batch(
&mut self,
findings: &[(
Vulnerability,
HttpResponse,
Option<HttpResponse>,
Option<String>,
)],
) -> Result<Vec<AutoVerification>> {
if !self.enabled {
return Ok(Vec::new());
}
let mut verifications = Vec::with_capacity(findings.len());
for (vuln, response, baseline, payload) in findings {
if let Some(verification) = self
.process_finding(vuln, response, baseline.as_ref(), payload.as_deref())
.await?
{
verifications.push(verification);
}
}
Ok(verifications)
}
pub async fn predict_false_positive(
&self,
_vuln: &Vulnerability,
response: &HttpResponse,
baseline: Option<&HttpResponse>,
payload: Option<&str>,
) -> Result<f32> {
if !self.enabled {
return Ok(0.5); }
let learner = self.auto_learner.read().await;
let features = learner
.feature_extractor
.extract(response, baseline, payload);
let classifier = self.fp_classifier.read().await;
let prediction = classifier.predict(&features);
Ok(1.0 - prediction.true_positive_probability)
}
pub async fn filter_likely_true_positives(
&self,
findings: Vec<(
Vulnerability,
HttpResponse,
Option<HttpResponse>,
Option<String>,
)>,
threshold: f32,
) -> Result<Vec<Vulnerability>> {
if !self.enabled {
return Ok(findings.into_iter().map(|(v, _, _, _)| v).collect());
}
let mut filtered = Vec::new();
for (vuln, response, baseline, payload) in findings {
let fp_prob = self
.predict_false_positive(&vuln, &response, baseline.as_ref(), payload.as_deref())
.await?;
if fp_prob < threshold {
filtered.push(vuln);
} else {
debug!(
"ML: Filtered {} at {} (FP probability: {:.0}%)",
vuln.vuln_type,
vuln.url,
fp_prob * 100.0
);
}
}
Ok(filtered)
}
pub async fn predict_false_positive_from_features(
&self,
features: &super::VulnFeatures,
) -> Result<f32> {
if !self.enabled {
return Ok(0.5); }
let classifier = self.fp_classifier.read().await;
let prediction = classifier.predict(features);
Ok(1.0 - prediction.true_positive_probability)
}
pub async fn filter_vulns_by_features(
&self,
vulns: Vec<Vulnerability>,
threshold: f32,
) -> Result<(Vec<Vulnerability>, usize)> {
if !self.enabled {
return Ok((vulns, 0));
}
let mut filtered = Vec::new();
let mut filtered_count = 0;
for vuln in vulns {
if let Some(ref ml_data) = vuln.ml_data {
let fp_prob = self
.predict_false_positive_from_features(&ml_data.features)
.await?;
if fp_prob < threshold {
filtered.push(vuln);
} else {
filtered_count += 1;
debug!(
"ML: Filtered {} at {} (FP probability: {:.0}%)",
vuln.vuln_type,
vuln.url,
fp_prob * 100.0
);
}
} else {
filtered.push(vuln);
}
}
Ok((filtered, filtered_count))
}
pub async fn on_scan_complete(&mut self) -> Result<()> {
if !self.enabled {
return Ok(());
}
info!(
"ML: Scan complete - processed {} findings ({} confirmed, {} rejected)",
self.findings_processed, self.auto_confirmed, self.auto_rejected
);
let privacy = self.privacy_manager.read().await;
if privacy.is_federated_allowed() {
drop(privacy);
let mut federated = self.federated_client.write().await;
if let Ok(Some(model)) = federated.fetch_global_model().await {
info!(
"ML: Fetched global model v{} ({} contributors)",
model.global_version, model.contributor_count
);
}
if federated.can_contribute() {
match federated.contribute_weights().await {
Ok(true) => info!("ML: Contributed to federated network"),
Ok(false) => debug!("ML: Not enough data to contribute yet"),
Err(e) => warn!("ML: Failed to contribute: {}", e),
}
}
if let Ok(count) = federated.upload_pending().await {
if count > 0 {
info!("ML: Uploaded {} pending contributions", count);
}
}
}
self.findings_processed = 0;
self.auto_confirmed = 0;
self.auto_rejected = 0;
Ok(())
}
pub async fn enable(&mut self, federated_opt_in: bool) -> Result<()> {
let mut privacy = self.privacy_manager.write().await;
privacy.record_consent(federated_opt_in)?;
self.enabled = true;
info!("ML: Enabled (federated: {})", federated_opt_in);
Ok(())
}
pub async fn disable(&mut self, delete_data: bool) -> Result<()> {
self.enabled = false;
if delete_data {
let mut privacy = self.privacy_manager.write().await;
privacy.withdraw_consent()?;
info!("ML: Disabled and all data deleted");
} else {
info!("ML: Disabled (data retained)");
}
Ok(())
}
pub async fn get_stats(&self) -> MlPipelineStats {
let learner = self.auto_learner.read().await;
let learning_stats = learner.get_stats();
let federated = self.federated_client.read().await;
let federated_stats = federated.get_stats();
MlPipelineStats {
enabled: self.enabled,
session_processed: self.findings_processed,
session_confirmed: self.auto_confirmed,
session_rejected: self.auto_rejected,
total_confirmed: learning_stats.auto_confirmed,
total_rejected: learning_stats.auto_rejected,
pending_learning: learning_stats.pending_learning,
endpoint_patterns: learning_stats.endpoint_patterns,
federated_enabled: federated_stats.has_global_model,
federated_contributors: federated_stats.global_contributors,
can_contribute: federated_stats.can_contribute,
}
}
}
impl Default for MlPipeline {
    /// Tries [`MlPipeline::new`]; if that fails, logs a warning and returns a
    /// disabled pipeline assembled from each component's `Default`.
    fn default() -> Self {
        match Self::new() {
            Ok(pipeline) => pipeline,
            Err(init_error) => {
                warn!("ML: Failed to initialize pipeline: {}", init_error);
                Self {
                    auto_learner: Arc::new(RwLock::new(AutoLearner::default())),
                    fp_classifier: Arc::new(RwLock::new(FalsePositiveClassifier::default())),
                    federated_client: Arc::new(RwLock::new(FederatedClient::default())),
                    privacy_manager: Arc::new(RwLock::new(PrivacyManager::default())),
                    // The fallback never processes findings.
                    enabled: false,
                    findings_processed: 0,
                    auto_confirmed: 0,
                    auto_rejected: 0,
                }
            }
        }
    }
}
/// Point-in-time snapshot of an ML pipeline's session and cumulative counters.
#[derive(Debug, Clone)]
pub struct MlPipelineStats {
    /// Whether the pipeline was processing findings when the snapshot was taken.
    pub enabled: bool,
    /// Findings processed in the current session.
    pub session_processed: usize,
    /// Findings auto-confirmed in the current session.
    pub session_confirmed: usize,
    /// Findings auto-rejected as false positives in the current session.
    pub session_rejected: usize,
    /// Auto-confirmed findings as reported by the learner's own statistics.
    pub total_confirmed: usize,
    /// Auto-rejected findings as reported by the learner's own statistics.
    pub total_rejected: usize,
    /// Learner-reported count of items still pending learning.
    pub pending_learning: usize,
    /// Learner-reported number of endpoint patterns.
    pub endpoint_patterns: usize,
    /// Federated status flag reported by the pipeline.
    pub federated_enabled: bool,
    /// Number of contributors to the global model, when known.
    pub federated_contributors: Option<usize>,
    /// Whether the federated client reports it can contribute weights.
    pub can_contribute: bool,
}
/// Shareable handle around an [`MlPipeline`] guarded by an async `RwLock`,
/// so several tasks can feed findings into the same pipeline.
pub struct MlIntegration {
    /// The wrapped pipeline; clone the `Arc` via `pipeline()` to share it.
    pipeline: Arc<RwLock<MlPipeline>>,
}
impl MlIntegration {
    /// Wraps a freshly created [`MlPipeline`].
    ///
    /// # Errors
    ///
    /// Returns the error from [`MlPipeline::new`] unchanged.
    pub fn new() -> Result<Self> {
        MlPipeline::new().map(|inner| Self {
            pipeline: Arc::new(RwLock::new(inner)),
        })
    }

    /// Feeds a single finding and its response to the pipeline (no baseline,
    /// no payload), discarding the resulting verification.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`MlPipeline::process_finding`].
    pub async fn learn(&self, vuln: &Vulnerability, response: &HttpResponse) -> Result<()> {
        let mut guard = self.pipeline.write().await;
        guard
            .process_finding(vuln, response, None, None)
            .await
            .map(|_verification| ())
    }

    /// Forwards the end-of-scan notification to the pipeline.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`MlPipeline::on_scan_complete`].
    pub async fn scan_complete(&self) -> Result<()> {
        let mut guard = self.pipeline.write().await;
        guard.on_scan_complete().await
    }

    /// Returns another handle to the shared pipeline.
    pub fn pipeline(&self) -> Arc<RwLock<MlPipeline>> {
        Arc::clone(&self.pipeline)
    }
}
impl Default for MlIntegration {
    /// Wraps [`MlPipeline::default`], which tries [`MlPipeline::new`] and, if
    /// that fails, logs a warning and falls back to a disabled pipeline.
    fn default() -> Self {
        // Going through `Self::new()` first would run `MlPipeline::new` twice on
        // failure and silently discard the first error; `MlPipeline::default`
        // already performs the single attempt plus logged fallback.
        Self {
            pipeline: Arc::new(RwLock::new(MlPipeline::default())),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Confidence, Severity};
    use std::collections::HashMap;

    /// Builds a representative SQL-injection finding with no ML metadata.
    fn create_test_vuln() -> Vulnerability {
        Vulnerability {
            id: "test-123".to_string(),
            vuln_type: "SQL Injection".to_string(),
            severity: Severity::High,
            confidence: Confidence::High,
            category: "Injection".to_string(),
            url: "https://example.com/api/users/123".to_string(),
            parameter: Some("id".to_string()),
            payload: Some("' OR '1'='1".to_string()),
            description: "Test SQL injection".to_string(),
            evidence: None,
            cwe: Some("CWE-89".to_string()),
            cvss: None,
            verified: false,
            false_positive: false,
            remediation: None,
            discovered_at: chrono::Utc::now().to_rfc3339(),
            // `filter_vulns_by_features` reads `vuln.ml_data`, so the struct
            // literal must set it or the test module does not compile.
            ml_data: None,
        }
    }

    /// Builds a minimal response with the given body and status code.
    fn create_test_response(body: &str, status: u16) -> HttpResponse {
        HttpResponse {
            status_code: status,
            headers: HashMap::new(),
            body: body.to_string(),
            duration_ms: 100,
        }
    }

    #[tokio::test]
    async fn test_pipeline_creation() {
        // `new` initializes several fallible components and may fail in a test
        // environment; when it succeeds, session counters must start at zero.
        if let Ok(pipeline) = MlPipeline::new() {
            let stats = pipeline.get_stats().await;
            assert_eq!(stats.enabled, pipeline.is_enabled());
            assert_eq!(stats.session_processed, 0);
            assert_eq!(stats.session_confirmed, 0);
            assert_eq!(stats.session_rejected, 0);
        }
    }

    #[tokio::test]
    async fn test_disabled_pipeline_passes_findings_through() {
        let mut pipeline = MlPipeline::default();
        pipeline
            .disable(false)
            .await
            .expect("disabling without deleting data does not touch the privacy manager");
        assert!(!pipeline.is_enabled());

        let response = create_test_response("ok", 200);
        let verification = pipeline
            .process_finding(&create_test_vuln(), &response, None, None)
            .await
            .expect("disabled pipeline must not error");
        assert!(verification.is_none());

        let fp_prob = pipeline
            .predict_false_positive(&create_test_vuln(), &response, None, None)
            .await
            .expect("disabled pipeline must not error");
        assert!((fp_prob - 0.5).abs() < f32::EPSILON);

        let batch = pipeline
            .process_findings_batch(&[(create_test_vuln(), create_test_response("ok", 200), None, None)])
            .await
            .expect("disabled pipeline must not error");
        assert!(batch.is_empty());

        let (kept, dropped) = pipeline
            .filter_vulns_by_features(vec![create_test_vuln()], 0.0)
            .await
            .expect("disabled pipeline must not error");
        assert_eq!(kept.len(), 1);
        assert_eq!(dropped, 0);

        let kept = pipeline
            .filter_likely_true_positives(
                vec![(create_test_vuln(), create_test_response("ok", 200), None, None)],
                0.0,
            )
            .await
            .expect("disabled pipeline must not error");
        assert_eq!(kept.len(), 1);

        pipeline
            .on_scan_complete()
            .await
            .expect("disabled pipeline must not error");
    }
}