Skip to main content

aimds_response/
mitigations.rs

1//! Mitigation actions and execution framework
2
3use std::collections::HashMap;
4use std::time::Duration;
5use serde::{Deserialize, Serialize};
6use crate::Result;
7use crate::adaptive::{ChallengeType, AlertPriority};
8use crate::meta_learning::ThreatIncident;
9
10/// Mitigation actions that can be taken against threats
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub enum MitigationAction {
13    /// Block the threatening request
14    BlockRequest {
15        reason: String,
16    },
17
18    /// Apply rate limiting to user/source
19    RateLimitUser {
20        duration: Duration,
21    },
22
23    /// Require additional verification
24    RequireVerification {
25        challenge_type: ChallengeType,
26    },
27
28    /// Alert human operator
29    AlertHuman {
30        priority: AlertPriority,
31    },
32
33    /// Update detection rules
34    UpdateRules {
35        new_patterns: Vec<Pattern>,
36    },
37}
38
39impl MitigationAction {
40    /// Execute mitigation action
41    pub async fn execute(&self, context: &ThreatContext) -> Result<String> {
42        match self {
43            MitigationAction::BlockRequest { reason } => {
44                self.execute_block(context, reason).await
45            }
46            MitigationAction::RateLimitUser { duration } => {
47                self.execute_rate_limit(context, *duration).await
48            }
49            MitigationAction::RequireVerification { challenge_type } => {
50                self.execute_verification(context, challenge_type).await
51            }
52            MitigationAction::AlertHuman { priority } => {
53                self.execute_alert(context, priority).await
54            }
55            MitigationAction::UpdateRules { new_patterns } => {
56                self.execute_rule_update(context, new_patterns).await
57            }
58        }
59    }
60
61    /// Rollback mitigation action
62    pub fn rollback(&self, action_id: &str) -> Result<()> {
63        // Implementation would coordinate with actual enforcement systems
64        tracing::info!("Rolling back action: {}", action_id);
65        Ok(())
66    }
67
68    /// Execute block request action
69    async fn execute_block(&self, context: &ThreatContext, reason: &str) -> Result<String> {
70        tracing::info!(
71            "Blocking request from {} - Reason: {}",
72            context.source_id,
73            reason
74        );
75
76        // Record block action
77        let action_id = uuid::Uuid::new_v4().to_string();
78
79        // In production, this would integrate with firewall/WAF
80        // For now, we simulate the action
81        metrics::counter!("mitigation.blocks").increment(1);
82
83        Ok(action_id)
84    }
85
86    /// Execute rate limit action
87    async fn execute_rate_limit(&self, context: &ThreatContext, duration: Duration) -> Result<String> {
88        tracing::info!(
89            "Rate limiting {} for {:?}",
90            context.source_id,
91            duration
92        );
93
94        let action_id = uuid::Uuid::new_v4().to_string();
95
96        // In production, integrate with rate limiter (Redis, etc.)
97        metrics::counter!("mitigation.rate_limits").increment(1);
98
99        Ok(action_id)
100    }
101
102    /// Execute verification requirement action
103    async fn execute_verification(&self, context: &ThreatContext, challenge: &ChallengeType) -> Result<String> {
104        tracing::info!(
105            "Requiring {:?} verification for {}",
106            challenge,
107            context.source_id
108        );
109
110        let action_id = uuid::Uuid::new_v4().to_string();
111
112        // In production, integrate with verification service
113        metrics::counter!("mitigation.verifications").increment(1);
114
115        Ok(action_id)
116    }
117
118    /// Execute human alert action
119    async fn execute_alert(&self, context: &ThreatContext, priority: &AlertPriority) -> Result<String> {
120        tracing::warn!(
121            "Alerting security team - Priority: {:?} - Threat: {}",
122            priority,
123            context.threat_id
124        );
125
126        let action_id = uuid::Uuid::new_v4().to_string();
127
128        // In production, integrate with alerting system (PagerDuty, etc.)
129        metrics::counter!("mitigation.alerts").increment(1);
130
131        Ok(action_id)
132    }
133
134    /// Execute rule update action
135    async fn execute_rule_update(&self, _context: &ThreatContext, patterns: &[Pattern]) -> Result<String> {
136        tracing::info!(
137            "Updating rules with {} new patterns",
138            patterns.len()
139        );
140
141        let action_id = uuid::Uuid::new_v4().to_string();
142
143        // In production, update detection engine rules
144        metrics::counter!("mitigation.rule_updates").increment(1);
145
146        Ok(action_id)
147    }
148}
149
150/// Trait for mitigation implementations
151#[async_trait::async_trait]
152pub trait Mitigation: Send + Sync {
153    /// Execute the mitigation
154    async fn execute(&self, context: &ThreatContext) -> Result<MitigationOutcome>;
155
156    /// Rollback the mitigation
157    fn rollback(&self) -> Result<()>;
158}
159
160/// Context for mitigation execution
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct ThreatContext {
163    pub threat_id: String,
164    pub source_id: String,
165    pub threat_type: String,
166    pub severity: u8,
167    pub confidence: f64,
168    pub metadata: HashMap<String, String>,
169    pub timestamp: chrono::DateTime<chrono::Utc>,
170}
171
172impl ThreatContext {
173    /// Create context from threat incident
174    pub fn from_incident(incident: &ThreatIncident) -> Self {
175        Self {
176            threat_id: incident.id.clone(),
177            source_id: format!("source_{}", incident.id),
178            threat_type: format!("{:?}", incident.threat_type),
179            severity: incident.severity,
180            confidence: incident.confidence,
181            metadata: HashMap::new(),
182            timestamp: incident.timestamp,
183        }
184    }
185
186    /// Add metadata to context
187    pub fn with_metadata(mut self, key: String, value: String) -> Self {
188        self.metadata.insert(key, value);
189        self
190    }
191}
192
193/// Outcome of mitigation execution
194#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct MitigationOutcome {
196    pub strategy_id: String,
197    pub threat_type: String,
198    pub features: HashMap<String, f64>,
199    pub success: bool,
200    pub actions_applied: Vec<String>,
201    pub duration: Duration,
202    pub timestamp: chrono::DateTime<chrono::Utc>,
203}
204
205impl MitigationOutcome {
206    /// Calculate effectiveness score
207    pub fn effectiveness_score(&self) -> f64 {
208        if self.success {
209            // Higher score for faster mitigations
210            let time_factor = 1.0 - (self.duration.as_millis() as f64 / 1000.0).min(1.0);
211            0.7 + 0.3 * time_factor
212        } else {
213            0.0
214        }
215    }
216
217    /// Check if outcome requires rollback
218    pub fn requires_rollback(&self) -> bool {
219        !self.success && !self.actions_applied.is_empty()
220    }
221}
222
223/// Pattern for rule updates
224#[derive(Debug, Clone, Serialize, Deserialize)]
225pub struct Pattern {
226    pub id: String,
227    pub pattern_type: PatternType,
228    pub confidence: f64,
229    pub features: HashMap<String, f64>,
230}
231
232/// Pattern type enumeration
233#[derive(Debug, Clone, Serialize, Deserialize)]
234pub enum PatternType {
235    Signature,
236    Anomaly,
237    Behavioral,
238    Statistical,
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ThreatContext`. Only the fields the action tests vary are parameters.
    fn sample_context(threat_id: &str, source_id: &str, severity: u8, confidence: f64) -> ThreatContext {
        ThreatContext {
            threat_id: threat_id.to_string(),
            source_id: source_id.to_string(),
            threat_type: "anomaly".to_string(),
            severity,
            confidence,
            metadata: HashMap::new(),
            timestamp: chrono::Utc::now(),
        }
    }

    /// Builds a `MitigationOutcome` with the fields the scoring tests vary.
    fn sample_outcome(success: bool, actions: &[&str], duration: Duration) -> MitigationOutcome {
        MitigationOutcome {
            strategy_id: "test".to_string(),
            threat_type: "anomaly".to_string(),
            features: HashMap::new(),
            success,
            actions_applied: actions.iter().map(|a| a.to_string()).collect(),
            duration,
            timestamp: chrono::Utc::now(),
        }
    }

    /// Unwraps a successful action result and checks the id is a valid UUID.
    ///
    /// It matches instead of calling `unwrap`, so the crate error type needs no
    /// `Debug` bound for this helper.
    fn expect_action_id(result: Result<String>) -> String {
        let id = match result {
            Ok(id) => id,
            Err(_) => panic!("mitigation action returned an error"),
        };
        assert!(
            uuid::Uuid::parse_str(&id).is_ok(),
            "action id {:?} is not a UUID",
            id
        );
        id
    }

    #[tokio::test]
    async fn test_block_action() {
        let context = sample_context("test-1", "source-1", 8, 0.9);
        let action = MitigationAction::BlockRequest {
            reason: "Test block".to_string(),
        };
        expect_action_id(action.execute(&context).await);
    }

    #[tokio::test]
    async fn test_rate_limit_action() {
        let context = sample_context("test-2", "source-2", 5, 0.7);
        let action = MitigationAction::RateLimitUser {
            duration: Duration::from_secs(300),
        };
        expect_action_id(action.execute(&context).await);
    }

    #[tokio::test]
    async fn test_rule_update_action() {
        let context = sample_context("test-4", "source-4", 6, 0.8);
        let action = MitigationAction::UpdateRules {
            new_patterns: vec![Pattern {
                id: "pattern-1".to_string(),
                pattern_type: PatternType::Signature,
                confidence: 0.8,
                features: HashMap::new(),
            }],
        };
        expect_action_id(action.execute(&context).await);
    }

    #[tokio::test]
    async fn test_action_ids_are_unique() {
        let context = sample_context("test-5", "source-5", 8, 0.9);
        let action = MitigationAction::BlockRequest {
            reason: "Repeat".to_string(),
        };
        let first = expect_action_id(action.execute(&context).await);
        let second = expect_action_id(action.execute(&context).await);
        assert_ne!(first, second);
    }

    #[test]
    fn test_effectiveness_score() {
        let score = sample_outcome(true, &["action-1"], Duration::from_millis(50)).effectiveness_score();
        // Expected: 0.7 base + 0.3 * (1 - 50ms / 1s).
        assert!((score - 0.985).abs() < 1e-9, "unexpected score {}", score);
    }

    #[test]
    fn test_slow_success_gets_base_score() {
        let score = sample_outcome(true, &["action-1"], Duration::from_secs(5)).effectiveness_score();
        assert!((score - 0.7).abs() < 1e-9, "unexpected score {}", score);
    }

    #[test]
    fn test_failed_outcome() {
        let failed = sample_outcome(false, &["action-1"], Duration::from_millis(50));
        assert_eq!(failed.effectiveness_score(), 0.0);
        assert!(failed.requires_rollback());

        // Nothing was applied, so there is nothing to undo.
        assert!(!sample_outcome(false, &[], Duration::from_millis(50)).requires_rollback());
        // Successful outcomes never need a rollback.
        assert!(!sample_outcome(true, &["action-1"], Duration::from_millis(50)).requires_rollback());
    }

    #[test]
    fn test_context_creation() {
        let incident = crate::meta_learning::ThreatIncident {
            id: "test-3".to_string(),
            threat_type: crate::meta_learning::ThreatType::Anomaly(0.85),
            severity: 7,
            confidence: 0.9,
            timestamp: chrono::Utc::now(),
        };

        let context = ThreatContext::from_incident(&incident);
        assert_eq!(context.threat_id, "test-3");
        assert_eq!(context.source_id, "source_test-3");
        assert_eq!(context.severity, 7);
        assert!(context.metadata.is_empty());
    }

    #[test]
    fn test_with_metadata_overwrites_existing_key() {
        let context = sample_context("test-6", "source-6", 3, 0.5)
            .with_metadata("region".to_string(), "eu".to_string())
            .with_metadata("region".to_string(), "us".to_string());
        assert_eq!(context.metadata.len(), 1);
        assert_eq!(context.metadata.get("region").map(String::as_str), Some("us"));
    }
}