use super::dos_analyzer::DosAnalyzer;
use super::error_analyzer::ErrorAnalyzer;
use super::pii_detector::PiiDetector;
use super::remediation_generator::RemediationGenerator;
use super::schema_analyzer::SchemaAnalyzer;
use super::types::{
AggregationLevel, ThreatAssessment, ThreatCategory, ThreatFinding, ThreatLevel,
ThreatModelingConfig,
};
use chrono::Utc;
use mockforge_foundation::Result;
use mockforge_openapi::OpenApiSpec;
/// Runs the individual threat analyzers over an OpenAPI contract and combines their
/// findings into a single assessment.
///
/// Every sub-analyzer is configured from the [`ThreatModelingConfig`] passed to `new`.
pub struct ThreatAnalyzer {
    // Built from `config.pii_patterns`; contributes PII findings.
    pii_detector: PiiDetector,
    // Built with its `Default` configuration; contributes DoS-risk findings.
    dos_analyzer: DosAnalyzer,
    // Built from `config.detect_error_leakage`; contributes error-related findings.
    error_analyzer: ErrorAnalyzer,
    // Built from `config.max_optional_fields_threshold`; contributes schema findings.
    schema_analyzer: SchemaAnalyzer,
    // `Some` only when `config.generate_remediations` was true at construction time.
    remediation_generator: Option<RemediationGenerator>,
    // The configuration this analyzer was built with; `enabled` is checked on every analysis.
    config: ThreatModelingConfig,
}
impl ThreatAnalyzer {
    /// Provider identifier handed to `RemediationGenerator::new`.
    const REMEDIATION_PROVIDER: &'static str = "openai";
    /// Model identifier handed to `RemediationGenerator::new`.
    const REMEDIATION_MODEL: &'static str = "gpt-4";

    /// Lowest aggregate score (without a Critical finding) classified as `ThreatLevel::High`.
    const HIGH_SCORE_THRESHOLD: f64 = 0.75;
    /// Lowest aggregate score (without a Critical finding) classified as `ThreatLevel::Medium`.
    const MEDIUM_SCORE_THRESHOLD: f64 = 0.5;

    /// Builds an analyzer whose sub-analyzers are configured from `config`.
    ///
    /// A remediation generator is only created when `config.generate_remediations` is set.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `RemediationGenerator::new` if constructing the
    /// remediation generator fails; it is the only fallible step here.
    pub fn new(config: ThreatModelingConfig) -> Result<Self> {
        let pii_detector = PiiDetector::new(config.pii_patterns.clone());
        let dos_analyzer = DosAnalyzer::default();
        let error_analyzer = ErrorAnalyzer::new(config.detect_error_leakage);
        let schema_analyzer = SchemaAnalyzer::new(config.max_optional_fields_threshold);
        let remediation_generator = if config.generate_remediations {
            // NOTE(review): the leading `true` and trailing `None` are forwarded unchanged;
            // their meaning is defined by `RemediationGenerator::new` — confirm there.
            Some(RemediationGenerator::new(
                true,
                Self::REMEDIATION_PROVIDER.to_string(),
                Self::REMEDIATION_MODEL.to_string(),
                None,
            )?)
        } else {
            None
        };
        Ok(Self {
            pii_detector,
            dos_analyzer,
            error_analyzer,
            schema_analyzer,
            remediation_generator,
            config,
        })
    }

    /// Analyzes `spec` and assembles a [`ThreatAssessment`].
    ///
    /// `workspace_id`, `service_id`, `service_name`, `endpoint` and `method` are copied into
    /// the assessment unchanged. The aggregation level is `Endpoint` when both `endpoint`
    /// and `method` are given, otherwise `Service`.
    ///
    /// When `config.enabled` is `false`, no analyzer runs. The assessment returned has no
    /// findings, a `Low` threat level and a score of `0.0`.
    ///
    /// Otherwise, findings are gathered from the PII, DoS, error and schema analyzers, in
    /// that order. `threat_categories` lists each distinct finding type once, in the order
    /// it first appears among the findings.
    ///
    /// Remediation suggestions are best-effort. If no generator is configured, or generation
    /// does not yield a result, the suggestion list is empty and the assessment is still
    /// returned.
    ///
    /// # Errors
    ///
    /// This body currently always returns `Ok`; the `Result` is kept for API stability.
    pub async fn analyze_contract(
        &self,
        spec: &OpenApiSpec,
        workspace_id: Option<String>,
        service_id: Option<String>,
        service_name: Option<String>,
        endpoint: Option<String>,
        method: Option<String>,
    ) -> Result<ThreatAssessment> {
        // Computed once, up front, so `endpoint` and `method` can be moved into the
        // assessment instead of cloned.
        let aggregation_level =
            self.determine_aggregation_level(endpoint.as_ref(), method.as_ref());

        if !self.config.enabled {
            return Ok(ThreatAssessment {
                workspace_id,
                service_id,
                service_name,
                endpoint,
                method,
                aggregation_level,
                threat_level: ThreatLevel::Low,
                threat_score: 0.0,
                threat_categories: Vec::new(),
                findings: Vec::new(),
                remediation_suggestions: Vec::new(),
                assessed_at: Utc::now(),
            });
        }

        let mut all_findings = Vec::new();
        all_findings.extend(self.pii_detector.detect_pii(spec));
        all_findings.extend(self.dos_analyzer.analyze_dos_risks(spec));
        all_findings.extend(self.error_analyzer.analyze_errors(spec));
        all_findings.extend(self.schema_analyzer.analyze_schemas(spec));

        // Deliberately best-effort: a failed remediation pass must not discard the findings.
        let remediation_suggestions = match &self.remediation_generator {
            Some(generator) => generator
                .generate_remediations(&all_findings)
                .await
                .unwrap_or_default(),
            None => Vec::new(),
        };

        let threat_score = self.calculate_threat_score(&all_findings);
        let threat_level = self.determine_threat_level(threat_score, &all_findings);
        let threat_categories = Self::distinct_categories(&all_findings);

        Ok(ThreatAssessment {
            workspace_id,
            service_id,
            service_name,
            endpoint,
            method,
            aggregation_level,
            threat_level,
            threat_score,
            threat_categories,
            findings: all_findings,
            remediation_suggestions,
            assessed_at: Utc::now(),
        })
    }

    /// Returns each distinct finding type once, in the order it first appears in `findings`.
    ///
    /// Iterating a `HashSet` directly would give an arbitrary order that can differ from run
    /// to run, so otherwise identical assessments would not compare or serialize equal.
    fn distinct_categories(findings: &[ThreatFinding]) -> Vec<ThreatCategory> {
        let mut seen = std::collections::HashSet::new();
        findings
            .iter()
            .map(|finding| &finding.finding_type)
            // `insert` returns `false` for a category already seen, dropping the duplicate.
            .filter(|category| seen.insert(*category))
            .cloned()
            .collect()
    }

    /// Returns `Endpoint` when both `endpoint` and `method` are present, otherwise `Service`.
    fn determine_aggregation_level(
        &self,
        endpoint: Option<&String>,
        method: Option<&String>,
    ) -> AggregationLevel {
        if endpoint.is_some() && method.is_some() {
            AggregationLevel::Endpoint
        } else {
            AggregationLevel::Service
        }
    }

    /// Averages `severity weight * confidence` over `findings`, capped at `1.0`.
    ///
    /// The severity weights are Critical 1.0, High 0.75, Medium 0.5 and Low 0.25. With no
    /// findings the score is `0.0`, which also avoids dividing by zero.
    fn calculate_threat_score(&self, findings: &[ThreatFinding]) -> f64 {
        if findings.is_empty() {
            return 0.0;
        }
        let total_score: f64 = findings
            .iter()
            .map(|f| {
                let severity_score = match f.severity {
                    ThreatLevel::Critical => 1.0,
                    ThreatLevel::High => 0.75,
                    ThreatLevel::Medium => 0.5,
                    ThreatLevel::Low => 0.25,
                };
                severity_score * f.confidence
            })
            .sum();
        (total_score / findings.len() as f64).min(1.0)
    }

    /// Maps an aggregate `score` to a [`ThreatLevel`].
    ///
    /// Any single Critical finding makes the whole assessment Critical, regardless of
    /// `score`. Otherwise the level is High at `score >= 0.75`, Medium at `score >= 0.5`,
    /// and Low below that.
    fn determine_threat_level(&self, score: f64, findings: &[ThreatFinding]) -> ThreatLevel {
        let has_critical = findings.iter().any(|f| matches!(f.severity, ThreatLevel::Critical));
        if has_critical {
            return ThreatLevel::Critical;
        }
        if score >= Self::HIGH_SCORE_THRESHOLD {
            ThreatLevel::High
        } else if score >= Self::MEDIUM_SCORE_THRESHOLD {
            ThreatLevel::Medium
        } else {
            ThreatLevel::Low
        }
    }
}
impl Default for ThreatAnalyzer {
    /// Builds an analyzer from `ThreatModelingConfig::default()`.
    ///
    /// # Panics
    ///
    /// Panics if `ThreatAnalyzer::new` fails for the default configuration. The only
    /// fallible step in `new` is constructing the remediation generator, so this can only
    /// happen when the default config enables remediations and that construction fails.
    fn default() -> Self {
        Self::new(ThreatModelingConfig::default())
            .expect("ThreatAnalyzer::new must succeed for ThreatModelingConfig::default()")
    }
}