use std::borrow::Cow;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;

use crate::{MitigationOutcome, ResponseError, ThreatContext};
8
9pub struct AuditLogger {
19 entries: Arc<RwLock<Vec<AuditEntry>>>,
21
22 stats: Arc<RwLock<AuditStatistics>>,
24
25 total_mitigations: Arc<AtomicU64>,
27 successful_mitigations: Arc<AtomicU64>,
28
29 max_entries: usize,
31}
32
33impl AuditLogger {
34 pub fn new() -> Self {
36 Self {
37 entries: Arc::new(RwLock::new(Vec::new())),
38 stats: Arc::new(RwLock::new(AuditStatistics::default())),
39 total_mitigations: Arc::new(AtomicU64::new(0)),
40 successful_mitigations: Arc::new(AtomicU64::new(0)),
41 max_entries: 10000,
42 }
43 }
44
45 pub fn with_max_entries(max_entries: usize) -> Self {
47 Self {
48 entries: Arc::new(RwLock::new(Vec::new())),
49 stats: Arc::new(RwLock::new(AuditStatistics::default())),
50 total_mitigations: Arc::new(AtomicU64::new(0)),
51 successful_mitigations: Arc::new(AtomicU64::new(0)),
52 max_entries,
53 }
54 }
55
56 pub async fn log_mitigation_start(&self, context: &ThreatContext) {
58 let entry = AuditEntry {
59 id: uuid::Uuid::new_v4().to_string(),
60 event_type: AuditEventType::MitigationStart,
61 threat_id: context.threat_id.clone(),
62 source_id: context.source_id.clone(),
63 severity: context.severity,
64 details: serde_json::to_value(context).ok(),
65 timestamp: chrono::Utc::now(),
66 };
67
68 self.add_entry(entry).await;
69
70 self.total_mitigations.fetch_add(1, Ordering::Relaxed);
74 let mut stats = self.stats.write().await;
75 stats.total_mitigations += 1;
76 }
77
78 pub async fn log_mitigation_success(&self, context: &ThreatContext, outcome: &MitigationOutcome) {
80 let entry = AuditEntry {
81 id: uuid::Uuid::new_v4().to_string(),
82 event_type: AuditEventType::MitigationSuccess,
83 threat_id: context.threat_id.clone(),
84 source_id: context.source_id.clone(),
85 severity: context.severity,
86 details: serde_json::to_value(outcome).ok(),
87 timestamp: chrono::Utc::now(),
88 };
89
90 self.add_entry(entry).await;
91
92 self.successful_mitigations.fetch_add(1, Ordering::Relaxed);
93 let mut stats = self.stats.write().await;
94 stats.successful_mitigations += 1;
95 stats.total_actions_applied += outcome.actions_applied.len() as u64;
96 }
97
98 pub async fn log_mitigation_failure(&self, context: &ThreatContext, error: &ResponseError) {
100 let entry = AuditEntry {
101 id: uuid::Uuid::new_v4().to_string(),
102 event_type: AuditEventType::MitigationFailure,
103 threat_id: context.threat_id.clone(),
104 source_id: context.source_id.clone(),
105 severity: context.severity,
106 details: serde_json::json!({
107 "error": error.to_string(),
108 "severity": error.severity(),
109 }).into(),
110 timestamp: chrono::Utc::now(),
111 };
112
113 self.add_entry(entry).await;
114
115 let mut stats = self.stats.write().await;
116 stats.failed_mitigations += 1;
117 }
118
119 pub async fn log_rollback(&self, action_id: &str, success: bool) {
121 let entry = AuditEntry {
122 id: uuid::Uuid::new_v4().to_string(),
123 event_type: if success {
124 AuditEventType::RollbackSuccess
125 } else {
126 AuditEventType::RollbackFailure
127 },
128 threat_id: String::new(),
129 source_id: String::new(),
130 severity: 0,
131 details: serde_json::json!({ "action_id": action_id }).into(),
132 timestamp: chrono::Utc::now(),
133 };
134
135 self.add_entry(entry).await;
136
137 let mut stats = self.stats.write().await;
138 if success {
139 stats.successful_rollbacks += 1;
140 } else {
141 stats.failed_rollbacks += 1;
142 }
143 }
144
145 pub async fn log_strategy_update(&self, strategy_id: &str, details: serde_json::Value) {
147 let entry = AuditEntry {
148 id: uuid::Uuid::new_v4().to_string(),
149 event_type: AuditEventType::StrategyUpdate,
150 threat_id: String::new(),
151 source_id: String::new(),
152 severity: 0,
153 details: Some(serde_json::json!({
154 "strategy_id": strategy_id,
155 "details": details,
156 })),
157 timestamp: chrono::Utc::now(),
158 };
159
160 self.add_entry(entry).await;
161
162 let mut stats = self.stats.write().await;
163 stats.strategy_updates += 1;
164 }
165
166 pub fn total_mitigations(&self) -> u64 {
168 self.total_mitigations.load(Ordering::Relaxed)
169 }
170
171 pub fn successful_mitigations(&self) -> u64 {
173 self.successful_mitigations.load(Ordering::Relaxed)
174 }
175
176 pub async fn entries(&self) -> Vec<AuditEntry> {
178 self.entries.read().await.clone()
179 }
180
181 pub async fn statistics(&self) -> AuditStatistics {
183 self.stats.read().await.clone()
184 }
185
186 pub async fn query(&self, criteria: AuditQuery) -> Vec<AuditEntry> {
188 let entries = self.entries.read().await;
189
190 entries.iter()
191 .filter(|e| criteria.matches(e))
192 .cloned()
193 .collect()
194 }
195
196 pub async fn export(&self, format: ExportFormat) -> Result<String, ResponseError> {
198 let entries = self.entries.read().await;
199
200 match format {
201 ExportFormat::Json => {
202 serde_json::to_string_pretty(&*entries)
203 .map_err(ResponseError::Serialization)
204 }
205 ExportFormat::Csv => {
206 self.export_csv(&entries)
207 }
208 }
209 }
210
211 async fn add_entry(&self, entry: AuditEntry) {
213 let mut entries = self.entries.write().await;
214
215 if entries.len() >= self.max_entries {
217 entries.remove(0);
218 }
219
220 tracing::info!(
222 event_type = ?entry.event_type,
223 threat_id = %entry.threat_id,
224 "Audit event recorded"
225 );
226
227 entries.push(entry);
228 }
229
230 fn export_csv(&self, entries: &[AuditEntry]) -> Result<String, ResponseError> {
232 let mut csv = String::from("id,event_type,threat_id,source_id,severity,timestamp\n");
233
234 for entry in entries {
235 csv.push_str(&format!(
236 "{},{:?},{},{},{},{}\n",
237 entry.id,
238 entry.event_type,
239 entry.threat_id,
240 entry.source_id,
241 entry.severity,
242 entry.timestamp.to_rfc3339()
243 ));
244 }
245
246 Ok(csv)
247 }
248}
249
250impl Default for AuditLogger {
251 fn default() -> Self {
252 Self::new()
253 }
254}
255
256#[derive(Debug, Clone, Serialize, Deserialize)]
258pub struct AuditEntry {
259 pub id: String,
260 pub event_type: AuditEventType,
261 pub threat_id: String,
262 pub source_id: String,
263 pub severity: u8,
264 pub details: Option<serde_json::Value>,
265 pub timestamp: chrono::DateTime<chrono::Utc>,
266}
267
268#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
270pub enum AuditEventType {
271 MitigationStart,
272 MitigationSuccess,
273 MitigationFailure,
274 RollbackSuccess,
275 RollbackFailure,
276 StrategyUpdate,
277 RuleUpdate,
278 AlertGenerated,
279}
280
281#[derive(Debug, Clone, Default, Serialize, Deserialize)]
283pub struct AuditStatistics {
284 pub total_mitigations: u64,
285 pub successful_mitigations: u64,
286 pub failed_mitigations: u64,
287 pub total_actions_applied: u64,
288 pub successful_rollbacks: u64,
289 pub failed_rollbacks: u64,
290 pub strategy_updates: u64,
291}
292
293impl AuditStatistics {
294 pub fn success_rate(&self) -> f64 {
296 if self.total_mitigations == 0 {
297 return 0.0;
298 }
299 self.successful_mitigations as f64 / self.total_mitigations as f64
300 }
301
302 pub fn rollback_rate(&self) -> f64 {
304 let total_rollbacks = self.successful_rollbacks + self.failed_rollbacks;
305 if total_rollbacks == 0 {
306 return 0.0;
307 }
308 self.successful_rollbacks as f64 / total_rollbacks as f64
309 }
310}
311
312#[derive(Debug, Clone, Default)]
314pub struct AuditQuery {
315 pub event_type: Option<AuditEventType>,
316 pub threat_id: Option<String>,
317 pub source_id: Option<String>,
318 pub min_severity: Option<u8>,
319 pub after: Option<chrono::DateTime<chrono::Utc>>,
320 pub before: Option<chrono::DateTime<chrono::Utc>>,
321}
322
323impl AuditQuery {
324 fn matches(&self, entry: &AuditEntry) -> bool {
326 if let Some(_event_type) = self.event_type {
327 }
330
331 if let Some(ref threat_id) = self.threat_id {
332 if entry.threat_id != *threat_id {
333 return false;
334 }
335 }
336
337 if let Some(ref source_id) = self.source_id {
338 if entry.source_id != *source_id {
339 return false;
340 }
341 }
342
343 if let Some(min_severity) = self.min_severity {
344 if entry.severity < min_severity {
345 return false;
346 }
347 }
348
349 if let Some(after) = self.after {
350 if entry.timestamp < after {
351 return false;
352 }
353 }
354
355 if let Some(before) = self.before {
356 if entry.timestamp > before {
357 return false;
358 }
359 }
360
361 true
362 }
363}
364
/// Output format accepted by `AuditLogger::export`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExportFormat {
    /// Pretty-printed JSON array of the full entries, including `details`.
    Json,
    /// CSV with a header row; the `details` payload is not included.
    Csv,
}
371
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ThreatContext;
    use std::collections::HashMap;

    /// Builds a severity-7 anomaly context for `threat_id`, shared by the async tests below.
    fn sample_context(threat_id: &str) -> ThreatContext {
        ThreatContext {
            threat_id: threat_id.to_string(),
            source_id: "source-1".to_string(),
            threat_type: "anomaly".to_string(),
            severity: 7,
            confidence: 0.9,
            metadata: HashMap::new(),
            timestamp: chrono::Utc::now(),
        }
    }

    #[tokio::test]
    async fn test_audit_logger_creation() {
        let logger = AuditLogger::new();
        assert_eq!(logger.entries().await.len(), 0);
    }

    #[tokio::test]
    async fn test_log_mitigation_start() {
        let logger = AuditLogger::new();
        logger.log_mitigation_start(&sample_context("test-1")).await;

        let entries = logger.entries().await;
        assert_eq!(entries.len(), 1);
        assert!(matches!(entries[0].event_type, AuditEventType::MitigationStart));
    }

    #[tokio::test]
    async fn test_statistics() {
        let logger = AuditLogger::new();
        logger.log_mitigation_start(&sample_context("test-1")).await;

        let stats = logger.statistics().await;
        assert_eq!(stats.total_mitigations, 1);
    }

    #[tokio::test]
    async fn test_audit_query() {
        let logger = AuditLogger::new();
        logger.log_mitigation_start(&sample_context("test-1")).await;

        let query = AuditQuery {
            min_severity: Some(5),
            ..Default::default()
        };

        let results = logger.query(query).await;
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn test_export_json() {
        let logger = AuditLogger::new();
        logger.log_mitigation_start(&sample_context("test-1")).await;

        let json = logger.export(ExportFormat::Json).await;
        assert!(json.is_ok());
    }

    #[tokio::test]
    async fn test_export_csv() {
        let logger = AuditLogger::new();
        logger.log_mitigation_start(&sample_context("test-1")).await;

        let csv = logger.export(ExportFormat::Csv).await.unwrap_or_default();
        let mut lines = csv.lines();
        assert_eq!(
            lines.next(),
            Some("id,event_type,threat_id,source_id,severity,timestamp")
        );
        let row = lines.next().unwrap_or_default();
        assert!(row.contains(",MitigationStart,test-1,source-1,7,"));
        assert!(lines.next().is_none());
    }

    #[tokio::test]
    async fn test_max_entries_evicts_oldest() {
        let logger = AuditLogger::with_max_entries(2);
        for i in 0..3 {
            logger
                .log_mitigation_start(&sample_context(&format!("test-{i}")))
                .await;
        }

        let entries = logger.entries().await;
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].threat_id, "test-1");
        assert_eq!(entries[1].threat_id, "test-2");
        // Eviction trims only the stored entries; the counters still see every event.
        assert_eq!(logger.total_mitigations(), 3);
    }

    #[test]
    fn test_statistics_calculations() {
        let stats = AuditStatistics {
            total_mitigations: 100,
            successful_mitigations: 85,
            failed_mitigations: 15,
            total_actions_applied: 200,
            successful_rollbacks: 8,
            failed_rollbacks: 2,
            strategy_updates: 5,
        };

        assert_eq!(stats.success_rate(), 0.85);
        assert_eq!(stats.rollback_rate(), 0.8);
    }
}