1use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use crate::meta_learning::ThreatIncident;
7use crate::{MitigationAction, MitigationOutcome, ThreatContext, Result, ResponseError};
8use serde::{Deserialize, Serialize};
9
10pub struct AdaptiveMitigator {
12 strategies: Vec<MitigationStrategy>,
14
15 effectiveness_scores: HashMap<String, f64>,
17
18 application_history: Vec<StrategyApplication>,
20
21 selector: Arc<RwLock<StrategySelector>>,
23}
24
25impl AdaptiveMitigator {
26 pub fn new() -> Self {
28 let strategies = Self::initialize_strategies();
29 let effectiveness_scores = strategies.iter()
30 .map(|s| (s.id.clone(), 0.5))
31 .collect();
32
33 Self {
34 strategies,
35 effectiveness_scores,
36 application_history: Vec::new(),
37 selector: Arc::new(RwLock::new(StrategySelector::new())),
38 }
39 }
40
41 pub async fn apply_mitigation(&self, threat: &ThreatIncident) -> Result<MitigationOutcome> {
43 let strategy = self.select_strategy(threat).await?;
45
46 let context = ThreatContext::from_incident(threat);
48
49 let start = std::time::Instant::now();
51 let result = strategy.execute(&context).await;
52 let duration = start.elapsed();
53
54 let outcome = match result {
56 Ok(actions_applied) => {
57 MitigationOutcome {
58 strategy_id: strategy.id.clone(),
59 threat_type: Self::threat_type_string(&threat.threat_type),
60 features: Self::extract_features(threat),
61 success: true,
62 actions_applied,
63 duration,
64 timestamp: chrono::Utc::now(),
65 }
66 }
67 Err(_e) => {
68 MitigationOutcome {
69 strategy_id: strategy.id.clone(),
70 threat_type: Self::threat_type_string(&threat.threat_type),
71 features: Self::extract_features(threat),
72 success: false,
73 actions_applied: Vec::new(),
74 duration,
75 timestamp: chrono::Utc::now(),
76 }
77 }
78 };
79
80 Ok(outcome)
81 }
82
83 pub fn update_effectiveness(&mut self, strategy_id: &str, success: bool) {
85 if let Some(score) = self.effectiveness_scores.get_mut(strategy_id) {
86 let alpha = 0.3;
88 let new_value = if success { 1.0 } else { 0.0 };
89 *score = alpha * new_value + (1.0 - alpha) * *score;
90 }
91
92 self.application_history.push(StrategyApplication {
94 strategy_id: strategy_id.to_string(),
95 success,
96 timestamp: chrono::Utc::now(),
97 });
98 }
99
100 pub fn active_strategies_count(&self) -> usize {
102 self.strategies.iter()
103 .filter(|s| self.effectiveness_scores.get(&s.id).is_some_and(|&score| score > 0.3))
104 .count()
105 }
106
107 async fn select_strategy(&self, threat: &ThreatIncident) -> Result<MitigationStrategy> {
109 let mut selector = self.selector.write().await;
110
111 let candidates: Vec<_> = self.strategies.iter()
113 .filter(|s| s.applicable_to(threat))
114 .collect();
115
116 if candidates.is_empty() {
117 return Err(ResponseError::StrategyNotFound(
118 "No applicable strategies found".to_string()
119 ));
120 }
121
122 let best = candidates.iter()
124 .max_by(|a, b| {
125 let score_a = self.effectiveness_scores.get(&a.id).unwrap_or(&0.0);
126 let score_b = self.effectiveness_scores.get(&b.id).unwrap_or(&0.0);
127 score_a.partial_cmp(score_b).unwrap()
128 })
129 .unwrap();
130
131 selector.record_selection(&best.id);
133
134 Ok((*best).clone())
135 }
136
137 fn initialize_strategies() -> Vec<MitigationStrategy> {
139 vec![
140 MitigationStrategy::block_request(),
141 MitigationStrategy::rate_limit(),
142 MitigationStrategy::require_verification(),
143 MitigationStrategy::alert_human(),
144 MitigationStrategy::update_rules(),
145 MitigationStrategy::quarantine_source(),
146 MitigationStrategy::adaptive_throttle(),
147 ]
148 }
149
150 fn threat_type_string(threat_type: &crate::meta_learning::ThreatType) -> String {
152 match threat_type {
153 crate::meta_learning::ThreatType::Anomaly(_) => "anomaly".to_string(),
154 crate::meta_learning::ThreatType::Attack(attack) => format!("attack_{:?}", attack),
155 crate::meta_learning::ThreatType::Intrusion(_) => "intrusion".to_string(),
156 }
157 }
158
159 fn extract_features(threat: &ThreatIncident) -> HashMap<String, f64> {
161 let mut features = HashMap::new();
162 features.insert("severity".to_string(), threat.severity as f64);
163 features.insert("confidence".to_string(), threat.confidence);
164 features
165 }
166}
167
168impl Default for AdaptiveMitigator {
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
/// A named mitigation plan: the actions to run and the conditions under which it applies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MitigationStrategy {
    /// Identifier of the strategy; `AdaptiveMitigator` keys its effectiveness scores by it.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// Human-readable description of what the strategy does.
    pub description: String,
    /// Actions run, in order, by `execute`.
    pub actions: Vec<MitigationAction>,
    /// Minimum threat severity (inclusive) for which `applicable_to` returns `true`.
    pub min_severity: u8,
    /// Threat categories the strategy is intended for (e.g. "anomaly", "attack").
    /// NOTE(review): not consulted by `applicable_to`; only `min_severity` is checked.
    pub applicable_threats: Vec<String>,
    /// Relative priority of the strategy.
    /// NOTE(review): not used by strategy selection in this file, which ranks by score only.
    pub priority: u8,
}
185
186impl MitigationStrategy {
187 pub fn applicable_to(&self, threat: &ThreatIncident) -> bool {
189 threat.severity >= self.min_severity
190 }
191
192 pub async fn execute(&self, context: &ThreatContext) -> Result<Vec<String>> {
194 let mut applied_actions = Vec::new();
195
196 for action in &self.actions {
197 match action.execute(context).await {
198 Ok(action_id) => {
199 applied_actions.push(action_id);
200 }
201 Err(e) => {
202 tracing::warn!("Action failed: {:?}", e);
203 }
205 }
206 }
207
208 Ok(applied_actions)
209 }
210
211 pub fn block_request() -> Self {
213 Self {
214 id: "block_request".to_string(),
215 name: "Block Request".to_string(),
216 description: "Immediately block the threatening request".to_string(),
217 actions: vec![
218 MitigationAction::BlockRequest {
219 reason: "Threat detected".to_string(),
220 }
221 ],
222 min_severity: 7,
223 applicable_threats: vec!["attack".to_string(), "intrusion".to_string()],
224 priority: 9,
225 }
226 }
227
228 pub fn rate_limit() -> Self {
230 Self {
231 id: "rate_limit".to_string(),
232 name: "Rate Limit".to_string(),
233 description: "Apply rate limiting to source".to_string(),
234 actions: vec![
235 MitigationAction::RateLimitUser {
236 duration: std::time::Duration::from_secs(300),
237 }
238 ],
239 min_severity: 5,
240 applicable_threats: vec!["anomaly".to_string(), "attack".to_string()],
241 priority: 6,
242 }
243 }
244
245 pub fn require_verification() -> Self {
247 Self {
248 id: "require_verification".to_string(),
249 name: "Require Verification".to_string(),
250 description: "Require additional verification from user".to_string(),
251 actions: vec![
252 MitigationAction::RequireVerification {
253 challenge_type: ChallengeType::Captcha,
254 }
255 ],
256 min_severity: 4,
257 applicable_threats: vec!["anomaly".to_string()],
258 priority: 5,
259 }
260 }
261
262 pub fn alert_human() -> Self {
264 Self {
265 id: "alert_human".to_string(),
266 name: "Alert Human".to_string(),
267 description: "Alert security team for manual review".to_string(),
268 actions: vec![
269 MitigationAction::AlertHuman {
270 priority: AlertPriority::High,
271 }
272 ],
273 min_severity: 8,
274 applicable_threats: vec!["attack".to_string(), "intrusion".to_string()],
275 priority: 8,
276 }
277 }
278
279 pub fn update_rules() -> Self {
281 Self {
282 id: "update_rules".to_string(),
283 name: "Update Rules".to_string(),
284 description: "Dynamically update detection rules".to_string(),
285 actions: vec![
286 MitigationAction::UpdateRules {
287 new_patterns: Vec::new(),
288 }
289 ],
290 min_severity: 3,
291 applicable_threats: vec!["anomaly".to_string()],
292 priority: 3,
293 }
294 }
295
296 pub fn quarantine_source() -> Self {
298 Self {
299 id: "quarantine_source".to_string(),
300 name: "Quarantine Source".to_string(),
301 description: "Isolate threat source".to_string(),
302 actions: vec![
303 MitigationAction::BlockRequest {
304 reason: "Source quarantined".to_string(),
305 }
306 ],
307 min_severity: 9,
308 applicable_threats: vec!["attack".to_string(), "intrusion".to_string()],
309 priority: 10,
310 }
311 }
312
313 pub fn adaptive_throttle() -> Self {
315 Self {
316 id: "adaptive_throttle".to_string(),
317 name: "Adaptive Throttle".to_string(),
318 description: "Dynamically adjust rate limits".to_string(),
319 actions: vec![
320 MitigationAction::RateLimitUser {
321 duration: std::time::Duration::from_secs(60),
322 }
323 ],
324 min_severity: 3,
325 applicable_threats: vec!["anomaly".to_string()],
326 priority: 4,
327 }
328 }
329}
330
/// Bookkeeping of the strategy selections made by the mitigator.
struct StrategySelector {
    /// Identifier of the most recently selected strategy, if any.
    last_selected: Option<String>,
    /// Number of times each strategy identifier has been selected.
    selection_counts: HashMap<String, u64>,
}

impl StrategySelector {
    /// Creates a selector with no recorded selections.
    fn new() -> Self {
        StrategySelector {
            last_selected: None,
            selection_counts: HashMap::new(),
        }
    }

    /// Records one selection of `strategy_id`: bumps its count and marks it as the latest.
    fn record_selection(&mut self, strategy_id: &str) {
        let id = strategy_id.to_owned();
        self.last_selected = Some(id.clone());
        *self.selection_counts.entry(id).or_default() += 1;
    }
}
350
/// One recorded effectiveness update: which strategy, whether it succeeded, and when.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct StrategyApplication {
    /// Identifier of the strategy the outcome was reported for.
    strategy_id: String,
    /// Whether the reported outcome was a success.
    success: bool,
    /// UTC time at which the update was recorded.
    timestamp: chrono::DateTime<chrono::Utc>,
}
358
359#[derive(Debug, Clone, Serialize, Deserialize)]
361pub enum ChallengeType {
362 Captcha,
363 TwoFactor,
364 EmailVerification,
365 PhoneVerification,
366}
367
368#[derive(Debug, Clone, Serialize, Deserialize)]
370pub enum AlertPriority {
371 Low,
372 Medium,
373 High,
374 Critical,
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::meta_learning::{ThreatIncident, ThreatType};

    /// Builds a test incident with the given type, severity and confidence.
    fn incident(threat_type: ThreatType, severity: u8, confidence: f64) -> ThreatIncident {
        ThreatIncident {
            id: "test".to_string(),
            threat_type,
            severity,
            confidence,
            timestamp: chrono::Utc::now(),
        }
    }

    /// Current effectiveness score of `id`; panics if the strategy is unknown.
    fn score(mitigator: &AdaptiveMitigator, id: &str) -> f64 {
        mitigator.effectiveness_scores.get(id).copied().unwrap()
    }

    #[test]
    fn test_mitigator_creation() {
        let mitigator = AdaptiveMitigator::new();
        assert!(mitigator.active_strategies_count() > 0);
        // Every strategy starts with the same score, above the activity threshold.
        assert_eq!(mitigator.active_strategies_count(), mitigator.strategies.len());
    }

    #[tokio::test]
    async fn test_strategy_selection() {
        let mitigator = AdaptiveMitigator::new();
        let threat = incident(ThreatType::Anomaly(0.85), 7, 0.9);

        let strategy = mitigator.select_strategy(&threat).await;
        assert!(strategy.is_ok_and(|s| s.applicable_to(&threat)));
    }

    #[tokio::test]
    async fn test_strategy_selection_without_candidates() {
        let mitigator = AdaptiveMitigator::new();
        // Severity below every built-in strategy's `min_severity`.
        let threat = incident(ThreatType::Anomaly(0.1), 2, 0.5);

        assert!(mitigator.select_strategy(&threat).await.is_err());
    }

    #[test]
    fn test_effectiveness_update() {
        let mut mitigator = AdaptiveMitigator::new();
        let initial = score(&mitigator, "block_request");

        mitigator.update_effectiveness("block_request", true);

        assert!(score(&mitigator, "block_request") > initial);
        assert_eq!(mitigator.application_history.len(), 1);
    }

    #[test]
    fn test_effectiveness_decreases_on_failure() {
        let mut mitigator = AdaptiveMitigator::new();
        let active_before = mitigator.active_strategies_count();
        let initial = score(&mitigator, "block_request");

        mitigator.update_effectiveness("block_request", false);
        assert!(score(&mitigator, "block_request") < initial);

        // A second failure drops the score below the activity threshold.
        mitigator.update_effectiveness("block_request", false);
        assert_eq!(mitigator.active_strategies_count(), active_before - 1);
    }

    #[test]
    fn test_unknown_strategy_only_recorded_in_history() {
        let mut mitigator = AdaptiveMitigator::new();
        let scores_before = mitigator.effectiveness_scores.clone();

        mitigator.update_effectiveness("does_not_exist", true);

        assert_eq!(mitigator.effectiveness_scores, scores_before);
        assert_eq!(mitigator.application_history.len(), 1);
    }

    #[test]
    fn test_strategy_applicability() {
        let strategy = MitigationStrategy::block_request();

        let high_severity = incident(ThreatType::Anomaly(0.9), 9, 0.9);
        let boundary_severity = incident(ThreatType::Anomaly(0.7), 7, 0.7);
        let low_severity = incident(ThreatType::Anomaly(0.5), 3, 0.5);

        assert!(strategy.applicable_to(&high_severity));
        // `min_severity` is inclusive.
        assert!(strategy.applicable_to(&boundary_severity));
        assert!(!strategy.applicable_to(&low_severity));
    }
}