Skip to main content

aimds_response/
audit.rs

1//! Audit logging for mitigation actions
2
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use tokio::sync::RwLock;
6use serde::{Deserialize, Serialize};
7use crate::{ThreatContext, MitigationOutcome, ResponseError};
8
9/// Audit logger for tracking all mitigation activities.
10///
11/// Hot-path counters (`total_mitigations`, `successful_mitigations`)
12/// are kept in `AtomicU64` so they can be read from sync contexts
13/// (e.g. `ResponseSystem::metrics()`) without acquiring the async
14/// RwLock that guards the full `AuditStatistics` record. The atomics
15/// and the lock are kept in sync — atomics are bumped *before* the
16/// lock-protected stats are mutated, so an observer never sees a
17/// counter that's smaller than the lock-protected version.
18pub struct AuditLogger {
19    /// Audit log entries
20    entries: Arc<RwLock<Vec<AuditEntry>>>,
21
22    /// Statistics (full snapshot, async access)
23    stats: Arc<RwLock<AuditStatistics>>,
24
25    /// Hot-path counters — readable from sync code
26    total_mitigations: Arc<AtomicU64>,
27    successful_mitigations: Arc<AtomicU64>,
28
29    /// Maximum entries to retain
30    max_entries: usize,
31}
32
33impl AuditLogger {
34    /// Create new audit logger
35    pub fn new() -> Self {
36        Self {
37            entries: Arc::new(RwLock::new(Vec::new())),
38            stats: Arc::new(RwLock::new(AuditStatistics::default())),
39            total_mitigations: Arc::new(AtomicU64::new(0)),
40            successful_mitigations: Arc::new(AtomicU64::new(0)),
41            max_entries: 10000,
42        }
43    }
44
45    /// Create with custom max entries
46    pub fn with_max_entries(max_entries: usize) -> Self {
47        Self {
48            entries: Arc::new(RwLock::new(Vec::new())),
49            stats: Arc::new(RwLock::new(AuditStatistics::default())),
50            total_mitigations: Arc::new(AtomicU64::new(0)),
51            successful_mitigations: Arc::new(AtomicU64::new(0)),
52            max_entries,
53        }
54    }
55
56    /// Log mitigation start
57    pub async fn log_mitigation_start(&self, context: &ThreatContext) {
58        let entry = AuditEntry {
59            id: uuid::Uuid::new_v4().to_string(),
60            event_type: AuditEventType::MitigationStart,
61            threat_id: context.threat_id.clone(),
62            source_id: context.source_id.clone(),
63            severity: context.severity,
64            details: serde_json::to_value(context).ok(),
65            timestamp: chrono::Utc::now(),
66        };
67
68        self.add_entry(entry).await;
69
70        // Bump the sync-readable atomic first so a concurrent
71        // `total_mitigations()` reader never sees a lower value than
72        // the lock-protected stats.
73        self.total_mitigations.fetch_add(1, Ordering::Relaxed);
74        let mut stats = self.stats.write().await;
75        stats.total_mitigations += 1;
76    }
77
78    /// Log successful mitigation
79    pub async fn log_mitigation_success(&self, context: &ThreatContext, outcome: &MitigationOutcome) {
80        let entry = AuditEntry {
81            id: uuid::Uuid::new_v4().to_string(),
82            event_type: AuditEventType::MitigationSuccess,
83            threat_id: context.threat_id.clone(),
84            source_id: context.source_id.clone(),
85            severity: context.severity,
86            details: serde_json::to_value(outcome).ok(),
87            timestamp: chrono::Utc::now(),
88        };
89
90        self.add_entry(entry).await;
91
92        self.successful_mitigations.fetch_add(1, Ordering::Relaxed);
93        let mut stats = self.stats.write().await;
94        stats.successful_mitigations += 1;
95        stats.total_actions_applied += outcome.actions_applied.len() as u64;
96    }
97
98    /// Log failed mitigation
99    pub async fn log_mitigation_failure(&self, context: &ThreatContext, error: &ResponseError) {
100        let entry = AuditEntry {
101            id: uuid::Uuid::new_v4().to_string(),
102            event_type: AuditEventType::MitigationFailure,
103            threat_id: context.threat_id.clone(),
104            source_id: context.source_id.clone(),
105            severity: context.severity,
106            details: serde_json::json!({
107                "error": error.to_string(),
108                "severity": error.severity(),
109            }).into(),
110            timestamp: chrono::Utc::now(),
111        };
112
113        self.add_entry(entry).await;
114
115        let mut stats = self.stats.write().await;
116        stats.failed_mitigations += 1;
117    }
118
119    /// Log rollback event
120    pub async fn log_rollback(&self, action_id: &str, success: bool) {
121        let entry = AuditEntry {
122            id: uuid::Uuid::new_v4().to_string(),
123            event_type: if success {
124                AuditEventType::RollbackSuccess
125            } else {
126                AuditEventType::RollbackFailure
127            },
128            threat_id: String::new(),
129            source_id: String::new(),
130            severity: 0,
131            details: serde_json::json!({ "action_id": action_id }).into(),
132            timestamp: chrono::Utc::now(),
133        };
134
135        self.add_entry(entry).await;
136
137        let mut stats = self.stats.write().await;
138        if success {
139            stats.successful_rollbacks += 1;
140        } else {
141            stats.failed_rollbacks += 1;
142        }
143    }
144
145    /// Log strategy update
146    pub async fn log_strategy_update(&self, strategy_id: &str, details: serde_json::Value) {
147        let entry = AuditEntry {
148            id: uuid::Uuid::new_v4().to_string(),
149            event_type: AuditEventType::StrategyUpdate,
150            threat_id: String::new(),
151            source_id: String::new(),
152            severity: 0,
153            details: Some(serde_json::json!({
154                "strategy_id": strategy_id,
155                "details": details,
156            })),
157            timestamp: chrono::Utc::now(),
158        };
159
160        self.add_entry(entry).await;
161
162        let mut stats = self.stats.write().await;
163        stats.strategy_updates += 1;
164    }
165
166    /// Get total mitigations count (sync, lock-free).
167    pub fn total_mitigations(&self) -> u64 {
168        self.total_mitigations.load(Ordering::Relaxed)
169    }
170
171    /// Get successful mitigations count (sync, lock-free).
172    pub fn successful_mitigations(&self) -> u64 {
173        self.successful_mitigations.load(Ordering::Relaxed)
174    }
175
176    /// Get audit entries
177    pub async fn entries(&self) -> Vec<AuditEntry> {
178        self.entries.read().await.clone()
179    }
180
181    /// Get audit statistics
182    pub async fn statistics(&self) -> AuditStatistics {
183        self.stats.read().await.clone()
184    }
185
186    /// Query entries by criteria
187    pub async fn query(&self, criteria: AuditQuery) -> Vec<AuditEntry> {
188        let entries = self.entries.read().await;
189
190        entries.iter()
191            .filter(|e| criteria.matches(e))
192            .cloned()
193            .collect()
194    }
195
196    /// Export audit log
197    pub async fn export(&self, format: ExportFormat) -> Result<String, ResponseError> {
198        let entries = self.entries.read().await;
199
200        match format {
201            ExportFormat::Json => {
202                serde_json::to_string_pretty(&*entries)
203                    .map_err(ResponseError::Serialization)
204            }
205            ExportFormat::Csv => {
206                self.export_csv(&entries)
207            }
208        }
209    }
210
211    /// Add entry to log
212    async fn add_entry(&self, entry: AuditEntry) {
213        let mut entries = self.entries.write().await;
214
215        // Maintain max size
216        if entries.len() >= self.max_entries {
217            entries.remove(0);
218        }
219
220        // Log to tracing
221        tracing::info!(
222            event_type = ?entry.event_type,
223            threat_id = %entry.threat_id,
224            "Audit event recorded"
225        );
226
227        entries.push(entry);
228    }
229
230    /// Export entries as CSV
231    fn export_csv(&self, entries: &[AuditEntry]) -> Result<String, ResponseError> {
232        let mut csv = String::from("id,event_type,threat_id,source_id,severity,timestamp\n");
233
234        for entry in entries {
235            csv.push_str(&format!(
236                "{},{:?},{},{},{},{}\n",
237                entry.id,
238                entry.event_type,
239                entry.threat_id,
240                entry.source_id,
241                entry.severity,
242                entry.timestamp.to_rfc3339()
243            ));
244        }
245
246        Ok(csv)
247    }
248}
249
250impl Default for AuditLogger {
251    fn default() -> Self {
252        Self::new()
253    }
254}
255
256/// Audit log entry
257#[derive(Debug, Clone, Serialize, Deserialize)]
258pub struct AuditEntry {
259    pub id: String,
260    pub event_type: AuditEventType,
261    pub threat_id: String,
262    pub source_id: String,
263    pub severity: u8,
264    pub details: Option<serde_json::Value>,
265    pub timestamp: chrono::DateTime<chrono::Utc>,
266}
267
268/// Audit event types
269#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
270pub enum AuditEventType {
271    MitigationStart,
272    MitigationSuccess,
273    MitigationFailure,
274    RollbackSuccess,
275    RollbackFailure,
276    StrategyUpdate,
277    RuleUpdate,
278    AlertGenerated,
279}
280
281/// Audit statistics
282#[derive(Debug, Clone, Default, Serialize, Deserialize)]
283pub struct AuditStatistics {
284    pub total_mitigations: u64,
285    pub successful_mitigations: u64,
286    pub failed_mitigations: u64,
287    pub total_actions_applied: u64,
288    pub successful_rollbacks: u64,
289    pub failed_rollbacks: u64,
290    pub strategy_updates: u64,
291}
292
293impl AuditStatistics {
294    /// Calculate success rate
295    pub fn success_rate(&self) -> f64 {
296        if self.total_mitigations == 0 {
297            return 0.0;
298        }
299        self.successful_mitigations as f64 / self.total_mitigations as f64
300    }
301
302    /// Calculate rollback rate
303    pub fn rollback_rate(&self) -> f64 {
304        let total_rollbacks = self.successful_rollbacks + self.failed_rollbacks;
305        if total_rollbacks == 0 {
306            return 0.0;
307        }
308        self.successful_rollbacks as f64 / total_rollbacks as f64
309    }
310}
311
312/// Query criteria for audit entries
313#[derive(Debug, Clone, Default)]
314pub struct AuditQuery {
315    pub event_type: Option<AuditEventType>,
316    pub threat_id: Option<String>,
317    pub source_id: Option<String>,
318    pub min_severity: Option<u8>,
319    pub after: Option<chrono::DateTime<chrono::Utc>>,
320    pub before: Option<chrono::DateTime<chrono::Utc>>,
321}
322
323impl AuditQuery {
324    /// Check if entry matches criteria
325    fn matches(&self, entry: &AuditEntry) -> bool {
326        if let Some(_event_type) = self.event_type {
327            // TODO: Implement proper event type matching when enum comparison is needed
328            // For now, we skip this filter
329        }
330
331        if let Some(ref threat_id) = self.threat_id {
332            if entry.threat_id != *threat_id {
333                return false;
334            }
335        }
336
337        if let Some(ref source_id) = self.source_id {
338            if entry.source_id != *source_id {
339                return false;
340            }
341        }
342
343        if let Some(min_severity) = self.min_severity {
344            if entry.severity < min_severity {
345                return false;
346            }
347        }
348
349        if let Some(after) = self.after {
350            if entry.timestamp < after {
351                return false;
352            }
353        }
354
355        if let Some(before) = self.before {
356            if entry.timestamp > before {
357                return false;
358            }
359        }
360
361        true
362    }
363}
364
/// Output format for `AuditLogger::export`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExportFormat {
    /// Pretty-printed JSON array of entries.
    Json,
    /// CSV with a header row: `id,event_type,threat_id,source_id,severity,timestamp`.
    Csv,
}
371
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ThreatContext;
    use std::collections::HashMap;

    /// Fixture shared by the async tests: threat `test-1` of type `anomaly`
    /// from `source-1`, severity 7, confidence 0.9.
    fn sample_context() -> ThreatContext {
        ThreatContext {
            threat_id: "test-1".to_string(),
            source_id: "source-1".to_string(),
            threat_type: "anomaly".to_string(),
            severity: 7,
            confidence: 0.9,
            metadata: HashMap::new(),
            timestamp: chrono::Utc::now(),
        }
    }

    #[tokio::test]
    async fn test_audit_logger_creation() {
        let logger = AuditLogger::new();
        assert!(logger.entries().await.is_empty());
    }

    #[tokio::test]
    async fn test_log_mitigation_start() {
        let logger = AuditLogger::new();
        logger.log_mitigation_start(&sample_context()).await;

        let recorded = logger.entries().await;
        assert_eq!(recorded.len(), 1);
        assert!(matches!(recorded[0].event_type, AuditEventType::MitigationStart));
    }

    #[tokio::test]
    async fn test_statistics() {
        let logger = AuditLogger::new();
        logger.log_mitigation_start(&sample_context()).await;

        assert_eq!(logger.statistics().await.total_mitigations, 1);
    }

    #[tokio::test]
    async fn test_audit_query() {
        let logger = AuditLogger::new();
        logger.log_mitigation_start(&sample_context()).await;

        let severe_only = AuditQuery {
            min_severity: Some(5),
            ..Default::default()
        };
        assert_eq!(logger.query(severe_only).await.len(), 1);
    }

    #[tokio::test]
    async fn test_export_json() {
        let logger = AuditLogger::new();
        logger.log_mitigation_start(&sample_context()).await;

        assert!(logger.export(ExportFormat::Json).await.is_ok());
    }

    #[test]
    fn test_statistics_calculations() {
        let stats = AuditStatistics {
            total_mitigations: 100,
            successful_mitigations: 85,
            failed_mitigations: 15,
            total_actions_applied: 200,
            successful_rollbacks: 8,
            failed_rollbacks: 2,
            strategy_updates: 5,
        };

        assert_eq!(stats.success_rate(), 0.85);
        assert_eq!(stats.rollback_rate(), 0.8);
    }
}