aimds_response/
adaptive.rs

//! Adaptive mitigation with self-improving strategy selection.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use crate::meta_learning::ThreatIncident;
7use crate::{MitigationAction, MitigationOutcome, ThreatContext, Result, ResponseError};
8use serde::{Deserialize, Serialize};
9
/// Adaptive mitigator with strategy selection and effectiveness tracking.
///
/// Holds the built-in catalogue of [`MitigationStrategy`] values, a per-strategy
/// effectiveness score, and a log of reported outcomes. Within this module the scores are
/// only changed by [`AdaptiveMitigator::update_effectiveness`]; selection
/// (`apply_mitigation`) reads them but does not update them.
pub struct AdaptiveMitigator {
    /// Available mitigation strategies (built once in `new`, never modified afterwards here)
    strategies: Vec<MitigationStrategy>,

    /// Effectiveness scores per strategy, keyed by `MitigationStrategy::id`
    effectiveness_scores: HashMap<String, f64>,

    /// Strategy application history: one entry per `update_effectiveness` call
    application_history: Vec<StrategyApplication>,

    /// Strategy selector; records every successful selection made by `select_strategy`
    selector: Arc<RwLock<StrategySelector>>,
}
24
25impl AdaptiveMitigator {
26    /// Create new adaptive mitigator
27    pub fn new() -> Self {
28        let strategies = Self::initialize_strategies();
29        let effectiveness_scores = strategies.iter()
30            .map(|s| (s.id.clone(), 0.5))
31            .collect();
32
33        Self {
34            strategies,
35            effectiveness_scores,
36            application_history: Vec::new(),
37            selector: Arc::new(RwLock::new(StrategySelector::new())),
38        }
39    }
40
41    /// Apply mitigation to threat
42    pub async fn apply_mitigation(&self, threat: &ThreatIncident) -> Result<MitigationOutcome> {
43        // Select best strategy for threat
44        let strategy = self.select_strategy(threat).await?;
45
46        // Create threat context
47        let context = ThreatContext::from_incident(threat);
48
49        // Execute mitigation actions
50        let start = std::time::Instant::now();
51        let result = strategy.execute(&context).await;
52        let duration = start.elapsed();
53
54        // Build outcome
55        let outcome = match result {
56            Ok(actions_applied) => {
57                MitigationOutcome {
58                    strategy_id: strategy.id.clone(),
59                    threat_type: Self::threat_type_string(&threat.threat_type),
60                    features: Self::extract_features(threat),
61                    success: true,
62                    actions_applied,
63                    duration,
64                    timestamp: chrono::Utc::now(),
65                }
66            }
67            Err(_e) => {
68                MitigationOutcome {
69                    strategy_id: strategy.id.clone(),
70                    threat_type: Self::threat_type_string(&threat.threat_type),
71                    features: Self::extract_features(threat),
72                    success: false,
73                    actions_applied: Vec::new(),
74                    duration,
75                    timestamp: chrono::Utc::now(),
76                }
77            }
78        };
79
80        Ok(outcome)
81    }
82
83    /// Update effectiveness score for strategy
84    pub fn update_effectiveness(&mut self, strategy_id: &str, success: bool) {
85        if let Some(score) = self.effectiveness_scores.get_mut(strategy_id) {
86            // Exponential moving average
87            let alpha = 0.3;
88            let new_value = if success { 1.0 } else { 0.0 };
89            *score = alpha * new_value + (1.0 - alpha) * *score;
90        }
91
92        // Record application
93        self.application_history.push(StrategyApplication {
94            strategy_id: strategy_id.to_string(),
95            success,
96            timestamp: chrono::Utc::now(),
97        });
98    }
99
100    /// Get count of active strategies
101    pub fn active_strategies_count(&self) -> usize {
102        self.strategies.iter()
103            .filter(|s| self.effectiveness_scores.get(&s.id).is_some_and(|&score| score > 0.3))
104            .count()
105    }
106
107    /// Select best strategy for threat
108    async fn select_strategy(&self, threat: &ThreatIncident) -> Result<MitigationStrategy> {
109        let mut selector = self.selector.write().await;
110
111        // Get candidate strategies
112        let candidates: Vec<_> = self.strategies.iter()
113            .filter(|s| s.applicable_to(threat))
114            .collect();
115
116        if candidates.is_empty() {
117            return Err(ResponseError::StrategyNotFound(
118                "No applicable strategies found".to_string()
119            ));
120        }
121
122        // Select based on effectiveness scores
123        let best = candidates.iter()
124            .max_by(|a, b| {
125                let score_a = self.effectiveness_scores.get(&a.id).unwrap_or(&0.0);
126                let score_b = self.effectiveness_scores.get(&b.id).unwrap_or(&0.0);
127                score_a.partial_cmp(score_b).unwrap()
128            })
129            .unwrap();
130
131        // Update selector statistics
132        selector.record_selection(&best.id);
133
134        Ok((*best).clone())
135    }
136
137    /// Initialize default mitigation strategies
138    fn initialize_strategies() -> Vec<MitigationStrategy> {
139        vec![
140            MitigationStrategy::block_request(),
141            MitigationStrategy::rate_limit(),
142            MitigationStrategy::require_verification(),
143            MitigationStrategy::alert_human(),
144            MitigationStrategy::update_rules(),
145            MitigationStrategy::quarantine_source(),
146            MitigationStrategy::adaptive_throttle(),
147        ]
148    }
149
150    /// Convert threat type to string
151    fn threat_type_string(threat_type: &crate::meta_learning::ThreatType) -> String {
152        match threat_type {
153            crate::meta_learning::ThreatType::Anomaly(_) => "anomaly".to_string(),
154            crate::meta_learning::ThreatType::Attack(attack) => format!("attack_{:?}", attack),
155            crate::meta_learning::ThreatType::Intrusion(_) => "intrusion".to_string(),
156        }
157    }
158
159    /// Extract features from threat
160    fn extract_features(threat: &ThreatIncident) -> HashMap<String, f64> {
161        let mut features = HashMap::new();
162        features.insert("severity".to_string(), threat.severity as f64);
163        features.insert("confidence".to_string(), threat.confidence);
164        features
165    }
166}
167
168impl Default for AdaptiveMitigator {
169    fn default() -> Self {
170        Self::new()
171    }
172}
173
/// Mitigation strategy with actions and applicability rules.
///
/// NOTE(review): only `min_severity` is consulted by `applicable_to`; `applicable_threats`
/// and `priority` are carried as data but are not read anywhere in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MitigationStrategy {
    /// Identifier; the mitigator keys effectiveness scores and selector counts by it.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// Short description of what the strategy does.
    pub description: String,
    /// Actions run in order by `execute`; a failing action is logged and skipped.
    pub actions: Vec<MitigationAction>,
    /// Minimum threat severity (inclusive) for `applicable_to` to return `true`.
    pub min_severity: u8,
    /// Threat categories the strategy targets (e.g. `"attack"`, `"anomaly"`).
    pub applicable_threats: Vec<String>,
    /// Priority value. NOTE(review): unused here; whether higher means "preferred" is unconfirmed.
    pub priority: u8,
}
185
186impl MitigationStrategy {
187    /// Check if strategy applies to threat
188    pub fn applicable_to(&self, threat: &ThreatIncident) -> bool {
189        threat.severity >= self.min_severity
190    }
191
192    /// Execute strategy actions
193    pub async fn execute(&self, context: &ThreatContext) -> Result<Vec<String>> {
194        let mut applied_actions = Vec::new();
195
196        for action in &self.actions {
197            match action.execute(context).await {
198                Ok(action_id) => {
199                    applied_actions.push(action_id);
200                }
201                Err(e) => {
202                    tracing::warn!("Action failed: {:?}", e);
203                    // Continue with remaining actions
204                }
205            }
206        }
207
208        Ok(applied_actions)
209    }
210
211    /// Create block request strategy
212    pub fn block_request() -> Self {
213        Self {
214            id: "block_request".to_string(),
215            name: "Block Request".to_string(),
216            description: "Immediately block the threatening request".to_string(),
217            actions: vec![
218                MitigationAction::BlockRequest {
219                    reason: "Threat detected".to_string(),
220                }
221            ],
222            min_severity: 7,
223            applicable_threats: vec!["attack".to_string(), "intrusion".to_string()],
224            priority: 9,
225        }
226    }
227
228    /// Create rate limit strategy
229    pub fn rate_limit() -> Self {
230        Self {
231            id: "rate_limit".to_string(),
232            name: "Rate Limit".to_string(),
233            description: "Apply rate limiting to source".to_string(),
234            actions: vec![
235                MitigationAction::RateLimitUser {
236                    duration: std::time::Duration::from_secs(300),
237                }
238            ],
239            min_severity: 5,
240            applicable_threats: vec!["anomaly".to_string(), "attack".to_string()],
241            priority: 6,
242        }
243    }
244
245    /// Create verification requirement strategy
246    pub fn require_verification() -> Self {
247        Self {
248            id: "require_verification".to_string(),
249            name: "Require Verification".to_string(),
250            description: "Require additional verification from user".to_string(),
251            actions: vec![
252                MitigationAction::RequireVerification {
253                    challenge_type: ChallengeType::Captcha,
254                }
255            ],
256            min_severity: 4,
257            applicable_threats: vec!["anomaly".to_string()],
258            priority: 5,
259        }
260    }
261
262    /// Create human alert strategy
263    pub fn alert_human() -> Self {
264        Self {
265            id: "alert_human".to_string(),
266            name: "Alert Human".to_string(),
267            description: "Alert security team for manual review".to_string(),
268            actions: vec![
269                MitigationAction::AlertHuman {
270                    priority: AlertPriority::High,
271                }
272            ],
273            min_severity: 8,
274            applicable_threats: vec!["attack".to_string(), "intrusion".to_string()],
275            priority: 8,
276        }
277    }
278
279    /// Create rule update strategy
280    pub fn update_rules() -> Self {
281        Self {
282            id: "update_rules".to_string(),
283            name: "Update Rules".to_string(),
284            description: "Dynamically update detection rules".to_string(),
285            actions: vec![
286                MitigationAction::UpdateRules {
287                    new_patterns: Vec::new(),
288                }
289            ],
290            min_severity: 3,
291            applicable_threats: vec!["anomaly".to_string()],
292            priority: 3,
293        }
294    }
295
296    /// Create quarantine strategy
297    pub fn quarantine_source() -> Self {
298        Self {
299            id: "quarantine_source".to_string(),
300            name: "Quarantine Source".to_string(),
301            description: "Isolate threat source".to_string(),
302            actions: vec![
303                MitigationAction::BlockRequest {
304                    reason: "Source quarantined".to_string(),
305                }
306            ],
307            min_severity: 9,
308            applicable_threats: vec!["attack".to_string(), "intrusion".to_string()],
309            priority: 10,
310        }
311    }
312
313    /// Create adaptive throttle strategy
314    pub fn adaptive_throttle() -> Self {
315        Self {
316            id: "adaptive_throttle".to_string(),
317            name: "Adaptive Throttle".to_string(),
318            description: "Dynamically adjust rate limits".to_string(),
319            actions: vec![
320                MitigationAction::RateLimitUser {
321                    duration: std::time::Duration::from_secs(60),
322                }
323            ],
324            min_severity: 3,
325            applicable_threats: vec!["anomaly".to_string()],
326            priority: 4,
327        }
328    }
329}
330
/// Bookkeeping for strategy selections: how often each strategy id was chosen and which
/// one was chosen most recently.
struct StrategySelector {
    /// Number of times each strategy id has been selected.
    selection_counts: HashMap<String, u64>,
    /// Id of the most recently selected strategy, if any.
    last_selected: Option<String>,
}

impl StrategySelector {
    /// Creates a selector with no recorded selections.
    fn new() -> Self {
        let selection_counts = HashMap::new();
        StrategySelector {
            last_selected: None,
            selection_counts,
        }
    }

    /// Records one selection of `strategy_id`: bumps its counter and marks it as the latest pick.
    fn record_selection(&mut self, strategy_id: &str) {
        let owned_id = String::from(strategy_id);
        self.selection_counts
            .entry(owned_id.clone())
            .and_modify(|count| *count += 1)
            .or_insert(1);
        self.last_selected = Some(owned_id);
    }
}
350
/// Record of strategy application.
///
/// One entry is appended per `AdaptiveMitigator::update_effectiveness` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct StrategyApplication {
    /// Id passed to `update_effectiveness` (not validated against known strategies).
    strategy_id: String,
    /// Whether the reported application succeeded.
    success: bool,
    /// UTC time at which the outcome was recorded.
    timestamp: chrono::DateTime<chrono::Utc>,
}
358
/// Challenge type for verification.
///
/// Of these, only `Captcha` is used by the built-in strategies in this module
/// (`MitigationStrategy::require_verification`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ChallengeType {
    Captcha,
    TwoFactor,
    EmailVerification,
    PhoneVerification,
}
367
/// Alert priority levels.
///
/// Declared in the order `Low`, `Medium`, `High`, `Critical`. No ordering or equality traits
/// are derived, so values cannot be compared with `<` or `==`. Only `High` is used by the
/// built-in strategies in this module (`MitigationStrategy::alert_human`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AlertPriority {
    Low,
    Medium,
    High,
    Critical,
}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::meta_learning::{ThreatIncident, ThreatType};

    /// Builds an anomaly incident with the given severity and confidence.
    fn anomaly(severity: u8, confidence: f64) -> ThreatIncident {
        ThreatIncident {
            id: "test".to_string(),
            threat_type: ThreatType::Anomaly(0.9),
            severity,
            confidence,
            timestamp: chrono::Utc::now(),
        }
    }

    #[tokio::test]
    async fn test_mitigator_creation() {
        let mitigator = AdaptiveMitigator::new();
        // Every strategy starts at 0.5, above the 0.3 activity threshold.
        assert!(mitigator.active_strategies_count() > 0);
        assert_eq!(
            mitigator.active_strategies_count(),
            mitigator.strategies.len()
        );
    }

    #[tokio::test]
    async fn test_strategy_selection() {
        let mitigator = AdaptiveMitigator::new();
        let threat = anomaly(7, 0.9);

        let Ok(strategy) = mitigator.select_strategy(&threat).await else {
            panic!("a severity-7 threat should have at least one applicable strategy");
        };
        assert!(strategy.applicable_to(&threat));

        // The pick is recorded in the selector.
        let selector = mitigator.selector.read().await;
        assert_eq!(
            selector.last_selected.as_deref(),
            Some(strategy.id.as_str())
        );
        assert_eq!(selector.selection_counts.get(&strategy.id), Some(&1));
    }

    #[tokio::test]
    async fn test_no_applicable_strategy() {
        let mitigator = AdaptiveMitigator::new();
        // Below every built-in strategy's `min_severity` (the lowest is 3).
        let threat = anomaly(0, 0.5);

        assert!(mitigator.select_strategy(&threat).await.is_err());
        // A failed selection records nothing.
        assert!(mitigator.selector.read().await.last_selected.is_none());
    }

    #[tokio::test]
    async fn test_selection_prefers_most_effective() {
        let mut mitigator = AdaptiveMitigator::new();
        mitigator.update_effectiveness("block_request", true);

        let Ok(strategy) = mitigator.select_strategy(&anomaly(9, 0.9)).await else {
            panic!("a severity-9 threat should have applicable strategies");
        };
        assert_eq!(strategy.id, "block_request");
    }

    #[test]
    fn test_effectiveness_update() {
        let mut mitigator = AdaptiveMitigator::new();
        let strategy_id = "block_request";

        let initial = mitigator.effectiveness_scores.get(strategy_id).copied().unwrap();

        mitigator.update_effectiveness(strategy_id, true);
        let updated = mitigator.effectiveness_scores.get(strategy_id).copied().unwrap();

        assert!(updated > initial);
    }

    #[test]
    fn test_effectiveness_decreases_on_failure() {
        let mut mitigator = AdaptiveMitigator::new();
        let initial = mitigator.effectiveness_scores["rate_limit"];

        mitigator.update_effectiveness("rate_limit", false);
        let updated = mitigator.effectiveness_scores["rate_limit"];

        assert!(updated < initial);
    }

    #[test]
    fn test_repeated_failures_deactivate_strategy() {
        let mut mitigator = AdaptiveMitigator::new();
        let total = mitigator.strategies.len();

        // 0.5 -> 0.35 -> 0.245, which drops below the 0.3 activity threshold.
        mitigator.update_effectiveness("block_request", false);
        mitigator.update_effectiveness("block_request", false);

        assert_eq!(mitigator.active_strategies_count(), total - 1);
    }

    #[test]
    fn test_unknown_strategy_update_keeps_scores() {
        let mut mitigator = AdaptiveMitigator::new();
        let before = mitigator.effectiveness_scores.clone();

        mitigator.update_effectiveness("no_such_strategy", true);

        assert_eq!(mitigator.effectiveness_scores, before);
        // The call is still logged in the history.
        assert_eq!(mitigator.application_history.len(), 1);
    }

    #[test]
    fn test_feature_extraction() {
        let features = AdaptiveMitigator::extract_features(&anomaly(7, 0.9));
        assert_eq!(features.len(), 2);
        assert_eq!(features.get("severity"), Some(&7.0));
        assert_eq!(features.get("confidence"), Some(&0.9));
    }

    #[test]
    fn test_threat_type_string_for_anomaly() {
        assert_eq!(
            AdaptiveMitigator::threat_type_string(&ThreatType::Anomaly(0.5)),
            "anomaly"
        );
    }

    #[test]
    fn test_strategy_applicability() {
        let strategy = MitigationStrategy::block_request();

        assert!(strategy.applicable_to(&anomaly(9, 0.9)));
        assert!(!strategy.applicable_to(&anomaly(3, 0.5)));
    }
}
440}