1use std::collections::HashMap;
4use std::time::Duration;
5use serde::{Deserialize, Serialize};
6use crate::Result;
7use crate::adaptive::{ChallengeType, AlertPriority};
8use crate::meta_learning::ThreatIncident;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub enum MitigationAction {
13 BlockRequest {
15 reason: String,
16 },
17
18 RateLimitUser {
20 duration: Duration,
21 },
22
23 RequireVerification {
25 challenge_type: ChallengeType,
26 },
27
28 AlertHuman {
30 priority: AlertPriority,
31 },
32
33 UpdateRules {
35 new_patterns: Vec<Pattern>,
36 },
37}
38
39impl MitigationAction {
40 pub async fn execute(&self, context: &ThreatContext) -> Result<String> {
42 match self {
43 MitigationAction::BlockRequest { reason } => {
44 self.execute_block(context, reason).await
45 }
46 MitigationAction::RateLimitUser { duration } => {
47 self.execute_rate_limit(context, *duration).await
48 }
49 MitigationAction::RequireVerification { challenge_type } => {
50 self.execute_verification(context, challenge_type).await
51 }
52 MitigationAction::AlertHuman { priority } => {
53 self.execute_alert(context, priority).await
54 }
55 MitigationAction::UpdateRules { new_patterns } => {
56 self.execute_rule_update(context, new_patterns).await
57 }
58 }
59 }
60
61 pub fn rollback(&self, action_id: &str) -> Result<()> {
63 tracing::info!("Rolling back action: {}", action_id);
65 Ok(())
66 }
67
68 async fn execute_block(&self, context: &ThreatContext, reason: &str) -> Result<String> {
70 tracing::info!(
71 "Blocking request from {} - Reason: {}",
72 context.source_id,
73 reason
74 );
75
76 let action_id = uuid::Uuid::new_v4().to_string();
78
79 metrics::counter!("mitigation.blocks").increment(1);
82
83 Ok(action_id)
84 }
85
86 async fn execute_rate_limit(&self, context: &ThreatContext, duration: Duration) -> Result<String> {
88 tracing::info!(
89 "Rate limiting {} for {:?}",
90 context.source_id,
91 duration
92 );
93
94 let action_id = uuid::Uuid::new_v4().to_string();
95
96 metrics::counter!("mitigation.rate_limits").increment(1);
98
99 Ok(action_id)
100 }
101
102 async fn execute_verification(&self, context: &ThreatContext, challenge: &ChallengeType) -> Result<String> {
104 tracing::info!(
105 "Requiring {:?} verification for {}",
106 challenge,
107 context.source_id
108 );
109
110 let action_id = uuid::Uuid::new_v4().to_string();
111
112 metrics::counter!("mitigation.verifications").increment(1);
114
115 Ok(action_id)
116 }
117
118 async fn execute_alert(&self, context: &ThreatContext, priority: &AlertPriority) -> Result<String> {
120 tracing::warn!(
121 "Alerting security team - Priority: {:?} - Threat: {}",
122 priority,
123 context.threat_id
124 );
125
126 let action_id = uuid::Uuid::new_v4().to_string();
127
128 metrics::counter!("mitigation.alerts").increment(1);
130
131 Ok(action_id)
132 }
133
134 async fn execute_rule_update(&self, _context: &ThreatContext, patterns: &[Pattern]) -> Result<String> {
136 tracing::info!(
137 "Updating rules with {} new patterns",
138 patterns.len()
139 );
140
141 let action_id = uuid::Uuid::new_v4().to_string();
142
143 metrics::counter!("mitigation.rule_updates").increment(1);
145
146 Ok(action_id)
147 }
148}
149
150#[async_trait::async_trait]
152pub trait Mitigation: Send + Sync {
153 async fn execute(&self, context: &ThreatContext) -> Result<MitigationOutcome>;
155
156 fn rollback(&self) -> Result<()>;
158}
159
160#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct ThreatContext {
163 pub threat_id: String,
164 pub source_id: String,
165 pub threat_type: String,
166 pub severity: u8,
167 pub confidence: f64,
168 pub metadata: HashMap<String, String>,
169 pub timestamp: chrono::DateTime<chrono::Utc>,
170}
171
172impl ThreatContext {
173 pub fn from_incident(incident: &ThreatIncident) -> Self {
175 Self {
176 threat_id: incident.id.clone(),
177 source_id: format!("source_{}", incident.id),
178 threat_type: format!("{:?}", incident.threat_type),
179 severity: incident.severity,
180 confidence: incident.confidence,
181 metadata: HashMap::new(),
182 timestamp: incident.timestamp,
183 }
184 }
185
186 pub fn with_metadata(mut self, key: String, value: String) -> Self {
188 self.metadata.insert(key, value);
189 self
190 }
191}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct MitigationOutcome {
196 pub strategy_id: String,
197 pub threat_type: String,
198 pub features: HashMap<String, f64>,
199 pub success: bool,
200 pub actions_applied: Vec<String>,
201 pub duration: Duration,
202 pub timestamp: chrono::DateTime<chrono::Utc>,
203}
204
205impl MitigationOutcome {
206 pub fn effectiveness_score(&self) -> f64 {
208 if self.success {
209 let time_factor = 1.0 - (self.duration.as_millis() as f64 / 1000.0).min(1.0);
211 0.7 + 0.3 * time_factor
212 } else {
213 0.0
214 }
215 }
216
217 pub fn requires_rollback(&self) -> bool {
219 !self.success && !self.actions_applied.is_empty()
220 }
221}
222
223#[derive(Debug, Clone, Serialize, Deserialize)]
225pub struct Pattern {
226 pub id: String,
227 pub pattern_type: PatternType,
228 pub confidence: f64,
229 pub features: HashMap<String, f64>,
230}
231
232#[derive(Debug, Clone, Serialize, Deserialize)]
234pub enum PatternType {
235 Signature,
236 Anomaly,
237 Behavioral,
238 Statistical,
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context with ids `test-<suffix>` / `source-<suffix>`.
    fn sample_context(suffix: &str, severity: u8, confidence: f64) -> ThreatContext {
        ThreatContext {
            threat_id: format!("test-{}", suffix),
            source_id: format!("source-{}", suffix),
            threat_type: "anomaly".to_string(),
            severity,
            confidence,
            metadata: HashMap::new(),
            timestamp: chrono::Utc::now(),
        }
    }

    /// Builds an outcome with the given success flag, duration and action ids.
    fn sample_outcome(success: bool, duration: Duration, actions: &[&str]) -> MitigationOutcome {
        MitigationOutcome {
            strategy_id: "test".to_string(),
            threat_type: "anomaly".to_string(),
            features: HashMap::new(),
            success,
            actions_applied: actions.iter().map(|a| a.to_string()).collect(),
            duration,
            timestamp: chrono::Utc::now(),
        }
    }

    #[tokio::test]
    async fn test_block_action() {
        let context = sample_context("1", 8, 0.9);
        let action = MitigationAction::BlockRequest {
            reason: "Test block".to_string(),
        };

        let result = action.execute(&context).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_rate_limit_action() {
        let context = sample_context("2", 5, 0.7);
        let action = MitigationAction::RateLimitUser {
            duration: Duration::from_secs(300),
        };

        let result = action.execute(&context).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_rule_update_action() {
        let context = sample_context("4", 3, 0.6);
        let pattern = Pattern {
            id: "p-1".to_string(),
            pattern_type: PatternType::Signature,
            confidence: 0.8,
            features: HashMap::new(),
        };
        let action = MitigationAction::UpdateRules {
            new_patterns: vec![pattern],
        };

        assert!(action.execute(&context).await.is_ok());
    }

    #[tokio::test]
    async fn test_execute_returns_unique_uuid_ids() {
        let context = sample_context("5", 8, 0.9);
        let action = MitigationAction::BlockRequest {
            reason: "id check".to_string(),
        };

        let first = action.execute(&context).await.expect("block should succeed");
        let second = action.execute(&context).await.expect("block should succeed");
        assert!(uuid::Uuid::parse_str(&first).is_ok());
        assert_ne!(first, second);
    }

    #[test]
    fn test_rollback_succeeds() {
        let action = MitigationAction::BlockRequest {
            reason: "Test block".to_string(),
        };
        assert!(action.rollback("action-1").is_ok());
    }

    #[test]
    fn test_effectiveness_score() {
        let outcome = sample_outcome(true, Duration::from_millis(50), &["action-1"]);

        let score = outcome.effectiveness_score();
        assert!(score > 0.7);
        assert!(score <= 1.0);
    }

    #[test]
    fn test_effectiveness_score_bounds() {
        // At or beyond the 1 s budget a success keeps only its 0.7 base.
        let slow = sample_outcome(true, Duration::from_secs(2), &[]).effectiveness_score();
        assert!((slow - 0.7).abs() < 1e-12);

        // Failure always scores zero, however fast it was.
        let failed =
            sample_outcome(false, Duration::from_millis(10), &["action-1"]).effectiveness_score();
        assert!(failed.abs() < 1e-12);
    }

    #[test]
    fn test_requires_rollback() {
        let short = Duration::from_millis(10);
        assert!(sample_outcome(false, short, &["action-1"]).requires_rollback());
        assert!(!sample_outcome(false, short, &[]).requires_rollback());
        assert!(!sample_outcome(true, short, &["action-1"]).requires_rollback());
    }

    #[test]
    fn test_with_metadata_inserts_and_overwrites() {
        let context = sample_context("6", 1, 0.5)
            .with_metadata("region".to_string(), "eu".to_string())
            .with_metadata("region".to_string(), "us".to_string());

        assert_eq!(context.metadata.len(), 1);
        assert_eq!(context.metadata.get("region").map(String::as_str), Some("us"));
    }

    #[test]
    fn test_context_creation() {
        let incident = crate::meta_learning::ThreatIncident {
            id: "test-3".to_string(),
            threat_type: crate::meta_learning::ThreatType::Anomaly(0.85),
            severity: 7,
            confidence: 0.9,
            timestamp: chrono::Utc::now(),
        };

        let context = ThreatContext::from_incident(&incident);
        assert_eq!(context.threat_id, "test-3");
        assert_eq!(context.source_id, "source_test-3");
        assert_eq!(context.severity, 7);
        assert_eq!(context.timestamp, incident.timestamp);
        assert!(context.metadata.is_empty());
    }
}