use std::collections::HashMap;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::Result;
use crate::adaptive::{ChallengeType, AlertPriority};
use crate::meta_learning::ThreatIncident;
/// A response that can be taken against a detected threat.
///
/// Each variant carries the parameters its handler needs; the handlers are
/// reached through `MitigationAction::execute`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MitigationAction {
    /// Block the offending request.
    BlockRequest {
        /// Justification for the block; written to the log on execution.
        reason: String,
    },
    /// Rate limit the source of the threat.
    RateLimitUser {
        /// Duration attached to the rate limit; written to the log on execution.
        duration: Duration,
    },
    /// Require the source to pass a verification challenge.
    RequireVerification {
        /// Kind of challenge to require (defined in `crate::adaptive`).
        challenge_type: ChallengeType,
    },
    /// Notify a human operator.
    AlertHuman {
        /// Priority of the alert (defined in `crate::adaptive`).
        priority: AlertPriority,
    },
    /// Extend the rule set with newly learned patterns.
    UpdateRules {
        /// Patterns to add; the handler in this file only logs how many there are.
        new_patterns: Vec<Pattern>,
    },
}
impl MitigationAction {
    /// Carries out this action against `context` and returns the id of the
    /// applied action.
    ///
    /// Dispatches on the variant to the matching private handler. Every handler
    /// in this block logs the action, generates a fresh v4 UUID as the action
    /// id, and increments a `mitigation.*` metrics counter; no enforcement
    /// beyond that is visible here.
    ///
    /// # Errors
    ///
    /// Returns whatever the selected handler returns. The handlers in this
    /// block always return `Ok`.
    pub async fn execute(&self, context: &ThreatContext) -> Result<String> {
        match self {
            Self::BlockRequest { reason } => self.execute_block(context, reason).await,
            Self::RateLimitUser { duration } => {
                self.execute_rate_limit(context, *duration).await
            }
            Self::RequireVerification { challenge_type } => {
                self.execute_verification(context, challenge_type).await
            }
            Self::AlertHuman { priority } => self.execute_alert(context, priority).await,
            Self::UpdateRules { new_patterns } => {
                self.execute_rule_update(context, new_patterns).await
            }
        }
    }

    /// Rolls back the action identified by `action_id`.
    ///
    /// Currently this only logs the request and reports success; nothing is
    /// actually undone.
    pub fn rollback(&self, action_id: &str) -> Result<()> {
        tracing::info!("Rolling back action: {}", action_id);
        Ok(())
    }

    /// Handler for `BlockRequest`: logs the source and reason, counts the
    /// block, and returns a new action id.
    async fn execute_block(&self, context: &ThreatContext, reason: &str) -> Result<String> {
        let source = &context.source_id;
        tracing::info!("Blocking request from {} - Reason: {}", source, reason);

        let id = uuid::Uuid::new_v4().to_string();
        metrics::counter!("mitigation.blocks").increment(1);
        Ok(id)
    }

    /// Handler for `RateLimitUser`: logs the source and duration, counts the
    /// rate limit, and returns a new action id.
    async fn execute_rate_limit(
        &self,
        context: &ThreatContext,
        duration: Duration,
    ) -> Result<String> {
        let source = &context.source_id;
        tracing::info!("Rate limiting {} for {:?}", source, duration);

        let id = uuid::Uuid::new_v4().to_string();
        metrics::counter!("mitigation.rate_limits").increment(1);
        Ok(id)
    }

    /// Handler for `RequireVerification`: logs the challenge and source, counts
    /// the verification, and returns a new action id.
    async fn execute_verification(
        &self,
        context: &ThreatContext,
        challenge: &ChallengeType,
    ) -> Result<String> {
        let source = &context.source_id;
        tracing::info!("Requiring {:?} verification for {}", challenge, source);

        let id = uuid::Uuid::new_v4().to_string();
        metrics::counter!("mitigation.verifications").increment(1);
        Ok(id)
    }

    /// Handler for `AlertHuman`: emits a warning-level log with the priority and
    /// threat id, counts the alert, and returns a new action id.
    async fn execute_alert(
        &self,
        context: &ThreatContext,
        priority: &AlertPriority,
    ) -> Result<String> {
        let threat = &context.threat_id;
        tracing::warn!(
            "Alerting security team - Priority: {:?} - Threat: {}",
            priority,
            threat
        );

        let id = uuid::Uuid::new_v4().to_string();
        metrics::counter!("mitigation.alerts").increment(1);
        Ok(id)
    }

    /// Handler for `UpdateRules`: logs how many patterns were supplied, counts
    /// the update, and returns a new action id. The context is unused.
    async fn execute_rule_update(
        &self,
        _context: &ThreatContext,
        patterns: &[Pattern],
    ) -> Result<String> {
        let pattern_count = patterns.len();
        tracing::info!("Updating rules with {} new patterns", pattern_count);

        let id = uuid::Uuid::new_v4().to_string();
        metrics::counter!("mitigation.rule_updates").increment(1);
        Ok(id)
    }
}
/// Contract for a mitigation that can be run against a threat and later undone.
///
/// Implementors must be `Send + Sync`. No implementation of this trait is
/// visible in this file.
#[async_trait::async_trait]
pub trait Mitigation: Send + Sync {
    /// Runs the mitigation for `context` and reports the result as a
    /// [`MitigationOutcome`].
    async fn execute(&self, context: &ThreatContext) -> Result<MitigationOutcome>;

    /// Reverts the mitigation.
    fn rollback(&self) -> Result<()>;
}
/// Description of a threat, passed to mitigation actions when they execute.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreatContext {
    /// Identifier of the threat; `from_incident` copies the incident id here.
    pub threat_id: String,
    /// Identifier of the threat's source; `from_incident` derives it as
    /// `source_<incident id>`.
    pub source_id: String,
    /// Threat category as text; `from_incident` stores the `Debug` rendering of
    /// the incident's threat type.
    pub threat_type: String,
    /// Severity score. NOTE(review): the scale is not defined in this file.
    pub severity: u8,
    /// Detection confidence. NOTE(review): presumably in `0.0..=1.0` — confirm.
    pub confidence: f64,
    /// Free-form key/value annotations; see `with_metadata`.
    pub metadata: HashMap<String, String>,
    /// UTC timestamp associated with the threat.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
impl ThreatContext {
    /// Builds a context from a recorded incident.
    ///
    /// The incident id becomes `threat_id`, and `source_id` is derived from it
    /// with a `source_` prefix. `threat_type` is the `Debug` rendering of the
    /// incident's threat type. Severity, confidence and timestamp are copied
    /// over, and the metadata map starts empty.
    pub fn from_incident(incident: &ThreatIncident) -> Self {
        let threat_id = incident.id.clone();
        let source_id = format!("source_{}", incident.id);
        let threat_type = format!("{:?}", incident.threat_type);

        Self {
            threat_id,
            source_id,
            threat_type,
            severity: incident.severity,
            confidence: incident.confidence,
            metadata: HashMap::new(),
            timestamp: incident.timestamp,
        }
    }

    /// Adds one metadata entry and returns the context, builder-style.
    ///
    /// An existing value under the same `key` is replaced.
    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        // The previous value (if any) is intentionally discarded.
        let _replaced = self.metadata.insert(key, value);
        self
    }
}
/// Record of what happened when a mitigation was executed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MitigationOutcome {
    /// Identifier of the strategy that produced this outcome.
    pub strategy_id: String,
    /// Threat category the mitigation was applied to, as text.
    pub threat_type: String,
    /// Named numeric features attached to the outcome.
    /// NOTE(review): the feature schema is not defined in this file.
    pub features: HashMap<String, f64>,
    /// Whether the mitigation succeeded; a failed outcome scores `0.0` in
    /// `effectiveness_score`.
    pub success: bool,
    /// Actions that were applied. `requires_rollback` is true when this is
    /// non-empty and `success` is false.
    pub actions_applied: Vec<String>,
    /// Duration recorded for the mitigation; `effectiveness_score` rewards
    /// values under one second.
    pub duration: Duration,
    /// UTC timestamp associated with the outcome.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
impl MitigationOutcome {
    /// Scores how effective the mitigation was.
    ///
    /// A failed outcome scores `0.0`. A successful one scores `0.7` plus a
    /// speed bonus of up to `0.3`. The bonus shrinks linearly with the
    /// duration in whole milliseconds and reaches zero at one second or more.
    pub fn effectiveness_score(&self) -> f64 {
        if !self.success {
            return 0.0;
        }

        // Fraction of a one-second budget that was used, capped at 1.0.
        let used_fraction = (self.duration.as_millis() as f64 / 1000.0).min(1.0);
        let time_factor = 1.0 - used_fraction;
        0.7 + 0.3 * time_factor
    }

    /// Returns `true` when the mitigation failed after applying at least one
    /// action.
    pub fn requires_rollback(&self) -> bool {
        !(self.success || self.actions_applied.is_empty())
    }
}
/// A detection pattern that can be supplied to `MitigationAction::UpdateRules`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Pattern {
    /// Identifier of the pattern.
    pub id: String,
    /// Category of the pattern.
    pub pattern_type: PatternType,
    /// Confidence attached to the pattern.
    /// NOTE(review): presumably in `0.0..=1.0` — confirm.
    pub confidence: f64,
    /// Named numeric features describing the pattern.
    /// NOTE(review): the feature schema is not defined in this file.
    pub features: HashMap<String, f64>,
}
/// Category of a [`Pattern`].
///
/// NOTE(review): nothing in this file branches on these variants, so their
/// precise meaning should be confirmed against the code that produces them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PatternType {
    /// Signature pattern.
    Signature,
    /// Anomaly pattern.
    Anomaly,
    /// Behavioral pattern.
    Behavioral,
    /// Statistical pattern.
    Statistical,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `"anomaly"` context with empty metadata, stamped with the
    /// current UTC time.
    fn anomaly_context(
        threat_id: &str,
        source_id: &str,
        severity: u8,
        confidence: f64,
    ) -> ThreatContext {
        ThreatContext {
            threat_id: threat_id.to_string(),
            source_id: source_id.to_string(),
            threat_type: "anomaly".to_string(),
            severity,
            confidence,
            metadata: HashMap::new(),
            timestamp: chrono::Utc::now(),
        }
    }

    /// Executing a block action succeeds.
    #[tokio::test]
    async fn test_block_action() {
        let context = anomaly_context("test-1", "source-1", 8, 0.9);
        let action = MitigationAction::BlockRequest {
            reason: "Test block".to_string(),
        };

        let result = action.execute(&context).await;
        assert!(result.is_ok());
    }

    /// Executing a rate-limit action succeeds.
    #[tokio::test]
    async fn test_rate_limit_action() {
        let context = anomaly_context("test-2", "source-2", 5, 0.7);
        let action = MitigationAction::RateLimitUser {
            duration: Duration::from_secs(300),
        };

        let result = action.execute(&context).await;
        assert!(result.is_ok());
    }

    /// A fast, successful outcome scores above the 0.7 base and at most 1.0.
    #[test]
    fn test_effectiveness_score() {
        let outcome = MitigationOutcome {
            strategy_id: "test".to_string(),
            threat_type: "anomaly".to_string(),
            features: HashMap::new(),
            success: true,
            actions_applied: vec!["action-1".to_string()],
            duration: Duration::from_millis(50),
            timestamp: chrono::Utc::now(),
        };

        let score = outcome.effectiveness_score();
        assert!(score > 0.7);
        assert!(score <= 1.0);
    }

    /// `from_incident` carries the incident id and severity into the context.
    #[test]
    fn test_context_creation() {
        let incident = crate::meta_learning::ThreatIncident {
            id: "test-3".to_string(),
            threat_type: crate::meta_learning::ThreatType::Anomaly(0.85),
            severity: 7,
            confidence: 0.9,
            timestamp: chrono::Utc::now(),
        };

        let context = ThreatContext::from_incident(&incident);
        assert_eq!(context.threat_id, "test-3");
        assert_eq!(context.severity, 7);
    }
}