1use crate::security::{AuditConfig, Result, SecurityError};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use sqlx::{PgPool, Row};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tracing::{debug, error, info};
9use uuid::Uuid;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub enum AuditEventType {
14 Authentication,
15 Authorization,
16 DataAccess,
17 DataModification,
18 DataDeletion,
19 SystemAccess,
20 ConfigurationChange,
21 SecurityEvent,
22 Error,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub enum AuditSeverity {
28 Low,
29 Medium,
30 High,
31 Critical,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct AuditEvent {
37 pub id: String,
38 pub timestamp: DateTime<Utc>,
39 pub event_type: AuditEventType,
40 pub severity: AuditSeverity,
41 pub user_id: Option<String>,
42 pub session_id: Option<String>,
43 pub ip_address: Option<String>,
44 pub user_agent: Option<String>,
45 pub resource: Option<String>,
46 pub action: String,
47 pub outcome: AuditOutcome,
48 pub details: HashMap<String, Value>,
49 pub error_message: Option<String>,
50 pub request_id: Option<String>,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub enum AuditOutcome {
56 Success,
57 Failure,
58 Partial,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct AuditStatistics {
64 pub total_events: u64,
65 pub events_by_type: HashMap<String, u64>,
66 pub events_by_user: HashMap<String, u64>,
67 pub failed_events: u64,
68 pub critical_events: u64,
69 pub retention_days: u32,
70 pub oldest_event: Option<DateTime<Utc>>,
71 pub newest_event: Option<DateTime<Utc>>,
72}
73
74pub struct AuditManager {
76 config: AuditConfig,
77 db_pool: Arc<PgPool>,
78}
79
80impl AuditManager {
81 pub fn new(config: AuditConfig, db_pool: Arc<PgPool>) -> Self {
82 Self { config, db_pool }
83 }
84
85 pub async fn initialize(&self) -> Result<()> {
87 if !self.config.enabled {
88 debug!("Audit logging is disabled");
89 return Ok(());
90 }
91
92 info!("Initializing audit logging system");
93
94 let create_table_sql = r#"
96 CREATE TABLE IF NOT EXISTS audit_events (
97 id UUID PRIMARY KEY,
98 timestamp TIMESTAMPTZ NOT NULL,
99 event_type TEXT NOT NULL,
100 severity TEXT NOT NULL,
101 user_id TEXT,
102 session_id TEXT,
103 ip_address INET,
104 user_agent TEXT,
105 resource TEXT,
106 action TEXT NOT NULL,
107 outcome TEXT NOT NULL,
108 details JSONB,
109 error_message TEXT,
110 request_id TEXT,
111 created_at TIMESTAMPTZ DEFAULT NOW()
112 );
113 "#;
114
115 sqlx::query(create_table_sql)
116 .execute(self.db_pool.as_ref())
117 .await
118 .map_err(|e| SecurityError::AuditError {
119 message: format!("Failed to create audit events table: {e}"),
120 })?;
121
122 let create_indexes_sql = vec![
124 "CREATE INDEX IF NOT EXISTS idx_audit_events_timestamp ON audit_events (timestamp);",
125 "CREATE INDEX IF NOT EXISTS idx_audit_events_user_id ON audit_events (user_id);",
126 "CREATE INDEX IF NOT EXISTS idx_audit_events_event_type ON audit_events (event_type);",
127 "CREATE INDEX IF NOT EXISTS idx_audit_events_severity ON audit_events (severity);",
128 "CREATE INDEX IF NOT EXISTS idx_audit_events_outcome ON audit_events (outcome);",
129 ];
130
131 for sql in create_indexes_sql {
132 sqlx::query(sql)
133 .execute(self.db_pool.as_ref())
134 .await
135 .map_err(|e| SecurityError::AuditError {
136 message: format!("Failed to create audit index: {e}"),
137 })?;
138 }
139
140 info!("Audit logging system initialized successfully");
141 Ok(())
142 }
143
144 pub async fn log_event(&self, event: AuditEvent) -> Result<()> {
146 if !self.config.enabled {
147 return Ok(());
148 }
149
150 if !self.should_log_event(&event.event_type) {
152 return Ok(());
153 }
154
155 debug!(
156 "Logging audit event: {:?} - {}",
157 event.event_type, event.action
158 );
159
160 let insert_sql = r#"
161 INSERT INTO audit_events (
162 id, timestamp, event_type, severity, user_id, session_id,
163 ip_address, user_agent, resource, action, outcome, details,
164 error_message, request_id
165 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
166 "#;
167
168 sqlx::query(insert_sql)
169 .bind(&event.id)
170 .bind(event.timestamp)
171 .bind(format!("{:?}", event.event_type))
172 .bind(format!("{:?}", event.severity))
173 .bind(&event.user_id)
174 .bind(&event.session_id)
175 .bind(event.ip_address.as_ref())
176 .bind(&event.user_agent)
177 .bind(&event.resource)
178 .bind(&event.action)
179 .bind(format!("{:?}", event.outcome))
180 .bind(serde_json::to_value(&event.details).unwrap_or(Value::Null))
181 .bind(&event.error_message)
182 .bind(&event.request_id)
183 .execute(self.db_pool.as_ref())
184 .await
185 .map_err(|e| SecurityError::AuditError {
186 message: format!("Failed to log audit event: {e}"),
187 })?;
188
189 if matches!(event.severity, AuditSeverity::Critical) {
191 error!(
192 "CRITICAL AUDIT EVENT - Type: {:?}, Action: {}, User: {:?}, Details: {:?}",
193 event.event_type, event.action, event.user_id, event.details
194 );
195 }
196
197 Ok(())
198 }
199
200 pub async fn log_authentication(
202 &self,
203 user_id: &str,
204 action: &str,
205 outcome: AuditOutcome,
206 ip_address: Option<String>,
207 details: HashMap<String, Value>,
208 ) -> Result<()> {
209 let event = AuditEvent {
210 id: Uuid::new_v4().to_string(),
211 timestamp: Utc::now(),
212 event_type: AuditEventType::Authentication,
213 severity: if matches!(outcome, AuditOutcome::Failure) {
214 AuditSeverity::High
215 } else {
216 AuditSeverity::Medium
217 },
218 user_id: Some(user_id.to_string()),
219 session_id: None,
220 ip_address,
221 user_agent: None,
222 resource: Some("authentication".to_string()),
223 action: action.to_string(),
224 outcome,
225 details,
226 error_message: None,
227 request_id: None,
228 };
229
230 self.log_event(event).await
231 }
232
233 pub async fn log_data_access(
235 &self,
236 user_id: Option<&str>,
237 resource: &str,
238 action: &str,
239 outcome: AuditOutcome,
240 details: HashMap<String, Value>,
241 ) -> Result<()> {
242 let event = AuditEvent {
243 id: Uuid::new_v4().to_string(),
244 timestamp: Utc::now(),
245 event_type: AuditEventType::DataAccess,
246 severity: AuditSeverity::Low,
247 user_id: user_id.map(|s| s.to_string()),
248 session_id: None,
249 ip_address: None,
250 user_agent: None,
251 resource: Some(resource.to_string()),
252 action: action.to_string(),
253 outcome,
254 details,
255 error_message: None,
256 request_id: None,
257 };
258
259 self.log_event(event).await
260 }
261
262 pub async fn log_data_modification(
264 &self,
265 user_id: Option<&str>,
266 resource: &str,
267 action: &str,
268 outcome: AuditOutcome,
269 details: HashMap<String, Value>,
270 ) -> Result<()> {
271 let event = AuditEvent {
272 id: Uuid::new_v4().to_string(),
273 timestamp: Utc::now(),
274 event_type: AuditEventType::DataModification,
275 severity: AuditSeverity::Medium,
276 user_id: user_id.map(|s| s.to_string()),
277 session_id: None,
278 ip_address: None,
279 user_agent: None,
280 resource: Some(resource.to_string()),
281 action: action.to_string(),
282 outcome,
283 details,
284 error_message: None,
285 request_id: None,
286 };
287
288 self.log_event(event).await
289 }
290
291 pub async fn log_security_event(
293 &self,
294 action: &str,
295 severity: AuditSeverity,
296 user_id: Option<&str>,
297 ip_address: Option<String>,
298 details: HashMap<String, Value>,
299 ) -> Result<()> {
300 let event = AuditEvent {
301 id: Uuid::new_v4().to_string(),
302 timestamp: Utc::now(),
303 event_type: AuditEventType::SecurityEvent,
304 severity,
305 user_id: user_id.map(|s| s.to_string()),
306 session_id: None,
307 ip_address,
308 user_agent: None,
309 resource: Some("security".to_string()),
310 action: action.to_string(),
311 outcome: AuditOutcome::Success,
312 details,
313 error_message: None,
314 request_id: None,
315 };
316
317 self.log_event(event).await
318 }
319
320 pub async fn get_events(&self, filter: AuditFilter) -> Result<Vec<AuditEvent>> {
322 if !self.config.enabled {
323 return Ok(Vec::new());
324 }
325
326 let mut where_clauses = Vec::new();
327 let mut bind_count = 0;
328
329 if let Some(_user_id) = &filter.user_id {
330 bind_count += 1;
331 where_clauses.push(format!("user_id = ${bind_count}"));
332 }
333
334 if let Some(_event_type) = &filter.event_type {
335 bind_count += 1;
336 where_clauses.push(format!("event_type = ${bind_count}"));
337 }
338
339 if let Some(_start_time) = &filter.start_time {
340 bind_count += 1;
341 where_clauses.push(format!("timestamp >= ${bind_count}"));
342 }
343
344 if let Some(_end_time) = &filter.end_time {
345 bind_count += 1;
346 where_clauses.push(format!("timestamp <= ${bind_count}"));
347 }
348
349 let where_clause = if where_clauses.is_empty() {
350 String::new()
351 } else {
352 format!("WHERE {}", where_clauses.join(" AND "))
353 };
354
355 let limit = filter.limit.unwrap_or(100).min(1000); let offset = filter.offset.unwrap_or(0);
357
358 let query = format!(
359 "SELECT * FROM audit_events {where_clause} ORDER BY timestamp DESC LIMIT {limit} OFFSET {offset}"
360 );
361
362 let mut sql_query = sqlx::query(&query);
363
364 if let Some(user_id) = &filter.user_id {
366 sql_query = sql_query.bind(user_id);
367 }
368 if let Some(event_type) = &filter.event_type {
369 sql_query = sql_query.bind(format!("{event_type:?}"));
370 }
371 if let Some(start_time) = &filter.start_time {
372 sql_query = sql_query.bind(start_time);
373 }
374 if let Some(end_time) = &filter.end_time {
375 sql_query = sql_query.bind(end_time);
376 }
377
378 let rows = sql_query
379 .fetch_all(self.db_pool.as_ref())
380 .await
381 .map_err(|e| SecurityError::AuditError {
382 message: format!("Failed to fetch audit events: {e}"),
383 })?;
384
385 let mut events = Vec::new();
386 for row in rows {
387 events.push(self.row_to_audit_event(row)?);
388 }
389
390 Ok(events)
391 }
392
393 pub async fn get_statistics(&self) -> Result<AuditStatistics> {
395 if !self.config.enabled {
396 return Ok(AuditStatistics {
397 total_events: 0,
398 events_by_type: HashMap::new(),
399 events_by_user: HashMap::new(),
400 failed_events: 0,
401 critical_events: 0,
402 retention_days: self.config.retention_days,
403 oldest_event: None,
404 newest_event: None,
405 });
406 }
407
408 let total_events: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM audit_events")
410 .fetch_one(self.db_pool.as_ref())
411 .await
412 .map_err(|e| SecurityError::AuditError {
413 message: format!("Failed to get total events count: {e}"),
414 })?;
415
416 let failed_events: i64 =
417 sqlx::query_scalar("SELECT COUNT(*) FROM audit_events WHERE outcome = 'Failure'")
418 .fetch_one(self.db_pool.as_ref())
419 .await
420 .map_err(|e| SecurityError::AuditError {
421 message: format!("Failed to get failed events count: {e}"),
422 })?;
423
424 let critical_events: i64 =
425 sqlx::query_scalar("SELECT COUNT(*) FROM audit_events WHERE severity = 'Critical'")
426 .fetch_one(self.db_pool.as_ref())
427 .await
428 .map_err(|e| SecurityError::AuditError {
429 message: format!("Failed to get critical events count: {e}"),
430 })?;
431
432 let oldest_event: Option<DateTime<Utc>> =
434 sqlx::query_scalar("SELECT MIN(timestamp) FROM audit_events")
435 .fetch_one(self.db_pool.as_ref())
436 .await
437 .map_err(|e| SecurityError::AuditError {
438 message: format!("Failed to get oldest event: {e}"),
439 })?;
440
441 let newest_event: Option<DateTime<Utc>> =
442 sqlx::query_scalar("SELECT MAX(timestamp) FROM audit_events")
443 .fetch_one(self.db_pool.as_ref())
444 .await
445 .map_err(|e| SecurityError::AuditError {
446 message: format!("Failed to get newest event: {e}"),
447 })?;
448
449 let type_rows = sqlx::query(
451 "SELECT event_type, COUNT(*) as count FROM audit_events GROUP BY event_type",
452 )
453 .fetch_all(self.db_pool.as_ref())
454 .await
455 .map_err(|e| SecurityError::AuditError {
456 message: format!("Failed to get events by type: {e}"),
457 })?;
458
459 let mut events_by_type = HashMap::new();
460 for row in type_rows {
461 let event_type: String = row.get("event_type");
462 let count: i64 = row.get("count");
463 events_by_type.insert(event_type, count as u64);
464 }
465
466 let user_rows = sqlx::query("SELECT user_id, COUNT(*) as count FROM audit_events WHERE user_id IS NOT NULL GROUP BY user_id ORDER BY count DESC LIMIT 20")
468 .fetch_all(self.db_pool.as_ref())
469 .await
470 .map_err(|e| SecurityError::AuditError {
471 message: format!("Failed to get events by user: {e}"),
472 })?;
473
474 let mut events_by_user = HashMap::new();
475 for row in user_rows {
476 let user_id: String = row.get("user_id");
477 let count: i64 = row.get("count");
478 events_by_user.insert(user_id, count as u64);
479 }
480
481 Ok(AuditStatistics {
482 total_events: total_events as u64,
483 events_by_type,
484 events_by_user,
485 failed_events: failed_events as u64,
486 critical_events: critical_events as u64,
487 retention_days: self.config.retention_days,
488 oldest_event,
489 newest_event,
490 })
491 }
492
493 pub async fn cleanup_old_events(&self) -> Result<u64> {
495 if !self.config.enabled {
496 return Ok(0);
497 }
498
499 let cutoff_date = Utc::now() - chrono::Duration::days(self.config.retention_days as i64);
500
501 let deleted_count: i64 =
502 sqlx::query_scalar("DELETE FROM audit_events WHERE timestamp < $1 RETURNING COUNT(*)")
503 .bind(cutoff_date)
504 .fetch_optional(self.db_pool.as_ref())
505 .await
506 .map_err(|e| SecurityError::AuditError {
507 message: format!("Failed to cleanup old audit events: {e}"),
508 })?
509 .unwrap_or(0);
510
511 if deleted_count > 0 {
512 info!("Cleaned up {} old audit events", deleted_count);
513 }
514
515 Ok(deleted_count as u64)
516 }
517
518 fn should_log_event(&self, event_type: &AuditEventType) -> bool {
519 match event_type {
520 AuditEventType::Authentication => self.config.log_auth_events,
521 AuditEventType::DataAccess => self.config.log_data_access,
522 AuditEventType::DataModification => self.config.log_modifications,
523 AuditEventType::DataDeletion => self.config.log_modifications,
524 _ => true, }
526 }
527
528 fn row_to_audit_event(&self, row: sqlx::postgres::PgRow) -> Result<AuditEvent> {
529 let event_type_str: String = row.get("event_type");
530 let event_type = match event_type_str.as_str() {
531 "Authentication" => AuditEventType::Authentication,
532 "Authorization" => AuditEventType::Authorization,
533 "DataAccess" => AuditEventType::DataAccess,
534 "DataModification" => AuditEventType::DataModification,
535 "DataDeletion" => AuditEventType::DataDeletion,
536 "SystemAccess" => AuditEventType::SystemAccess,
537 "ConfigurationChange" => AuditEventType::ConfigurationChange,
538 "SecurityEvent" => AuditEventType::SecurityEvent,
539 "Error" => AuditEventType::Error,
540 _ => AuditEventType::SystemAccess,
541 };
542
543 let severity_str: String = row.get("severity");
544 let severity = match severity_str.as_str() {
545 "Low" => AuditSeverity::Low,
546 "Medium" => AuditSeverity::Medium,
547 "High" => AuditSeverity::High,
548 "Critical" => AuditSeverity::Critical,
549 _ => AuditSeverity::Low,
550 };
551
552 let outcome_str: String = row.get("outcome");
553 let outcome = match outcome_str.as_str() {
554 "Success" => AuditOutcome::Success,
555 "Failure" => AuditOutcome::Failure,
556 "Partial" => AuditOutcome::Partial,
557 _ => AuditOutcome::Success,
558 };
559
560 let details_value: Value = row.get("details");
561 let details: HashMap<String, Value> =
562 serde_json::from_value(details_value).unwrap_or_else(|_| HashMap::new());
563
564 Ok(AuditEvent {
565 id: row.get("id"),
566 timestamp: row.get("timestamp"),
567 event_type,
568 severity,
569 user_id: row.get("user_id"),
570 session_id: row.get("session_id"),
571 ip_address: row.get::<Option<String>, _>("ip_address"),
572 user_agent: row.get("user_agent"),
573 resource: row.get("resource"),
574 action: row.get("action"),
575 outcome,
576 details,
577 error_message: row.get("error_message"),
578 request_id: row.get("request_id"),
579 })
580 }
581
582 pub fn is_enabled(&self) -> bool {
583 self.config.enabled
584 }
585}
586
587#[derive(Debug, Clone, Default)]
589pub struct AuditFilter {
590 pub user_id: Option<String>,
591 pub event_type: Option<AuditEventType>,
592 pub start_time: Option<DateTime<Utc>>,
593 pub end_time: Option<DateTime<Utc>>,
594 pub limit: Option<i64>,
595 pub offset: Option<i64>,
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a fully populated authentication event for the tests below.
    fn sample_event() -> AuditEvent {
        let mut details = HashMap::new();
        details.insert("test_key".to_string(), json!("test_value"));

        AuditEvent {
            id: Uuid::new_v4().to_string(),
            timestamp: Utc::now(),
            event_type: AuditEventType::Authentication,
            severity: AuditSeverity::Medium,
            user_id: Some("test-user".to_string()),
            session_id: None,
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: None,
            resource: Some("login".to_string()),
            action: "user_login".to_string(),
            outcome: AuditOutcome::Success,
            details,
            error_message: None,
            request_id: None,
        }
    }

    #[test]
    fn test_audit_event_creation() {
        let event = sample_event();

        assert!(!event.id.is_empty());
        assert_eq!(event.action, "user_login");
        assert!(matches!(event.event_type, AuditEventType::Authentication));
        assert!(matches!(event.severity, AuditSeverity::Medium));
        assert!(matches!(event.outcome, AuditOutcome::Success));
        assert_eq!(event.user_id.as_deref(), Some("test-user"));
    }

    #[test]
    fn test_audit_event_json_round_trip() {
        let event = sample_event();

        let encoded = serde_json::to_string(&event).unwrap();
        let decoded: AuditEvent = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.id, event.id);
        assert_eq!(decoded.timestamp, event.timestamp);
        assert_eq!(decoded.action, event.action);
        assert_eq!(decoded.ip_address, event.ip_address);
        assert_eq!(decoded.details.get("test_key"), Some(&json!("test_value")));
        assert!(matches!(decoded.outcome, AuditOutcome::Success));
    }

    #[test]
    fn test_audit_filter_default() {
        let filter = AuditFilter::default();
        assert!(filter.user_id.is_none());
        assert!(filter.event_type.is_none());
        assert!(filter.start_time.is_none());
        assert!(filter.end_time.is_none());
        assert!(filter.limit.is_none());
        assert!(filter.offset.is_none());
    }

    #[test]
    fn test_audit_statistics_default() {
        let stats = AuditStatistics {
            total_events: 0,
            events_by_type: HashMap::new(),
            events_by_user: HashMap::new(),
            failed_events: 0,
            critical_events: 0,
            retention_days: 90,
            oldest_event: None,
            newest_event: None,
        };

        assert_eq!(stats.total_events, 0);
        assert_eq!(stats.failed_events, 0);
        assert_eq!(stats.critical_events, 0);
        assert_eq!(stats.retention_days, 90);
        assert!(stats.events_by_type.is_empty());
        assert!(stats.events_by_user.is_empty());
    }

    #[test]
    fn test_event_type_serialization() {
        let event_type = AuditEventType::Authentication;
        let serialized = serde_json::to_string(&event_type).unwrap();
        assert_eq!(serialized, "\"Authentication\"");

        let deserialized: AuditEventType = serde_json::from_str(&serialized).unwrap();
        assert!(matches!(deserialized, AuditEventType::Authentication));
    }

    #[test]
    fn test_severity_serialization() {
        let cases = [
            (AuditSeverity::Low, "\"Low\""),
            (AuditSeverity::Medium, "\"Medium\""),
            (AuditSeverity::High, "\"High\""),
            (AuditSeverity::Critical, "\"Critical\""),
        ];

        for (severity, expected) in cases {
            assert_eq!(serde_json::to_string(&severity).unwrap(), expected);
        }
    }

    #[test]
    fn test_outcome_serialization() {
        let cases = [
            (AuditOutcome::Success, "\"Success\""),
            (AuditOutcome::Failure, "\"Failure\""),
            (AuditOutcome::Partial, "\"Partial\""),
        ];

        for (outcome, expected) in cases {
            assert_eq!(serde_json::to_string(&outcome).unwrap(), expected);
        }
    }

    /// The database stores the bare variant names, and the statistics queries
    /// match on them; pin the names so a change is caught here.
    #[test]
    fn test_variant_names_match_storage_format() {
        assert_eq!(format!("{:?}", AuditOutcome::Failure), "Failure");
        assert_eq!(format!("{:?}", AuditSeverity::Critical), "Critical");
        assert_eq!(format!("{:?}", AuditEventType::DataDeletion), "DataDeletion");
        assert_eq!(
            format!("{:?}", AuditEventType::ConfigurationChange),
            "ConfigurationChange"
        );
    }
}