1use crate::security::{AuditConfig, Result, SecurityError};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use sqlx::{PgPool, Row};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tracing::{debug, error, info};
9use uuid::Uuid;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub enum AuditEventType {
14    Authentication,
15    Authorization,
16    DataAccess,
17    DataModification,
18    DataDeletion,
19    SystemAccess,
20    ConfigurationChange,
21    SecurityEvent,
22    Error,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub enum AuditSeverity {
28    Low,
29    Medium,
30    High,
31    Critical,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct AuditEvent {
37    pub id: String,
38    pub timestamp: DateTime<Utc>,
39    pub event_type: AuditEventType,
40    pub severity: AuditSeverity,
41    pub user_id: Option<String>,
42    pub session_id: Option<String>,
43    pub ip_address: Option<String>,
44    pub user_agent: Option<String>,
45    pub resource: Option<String>,
46    pub action: String,
47    pub outcome: AuditOutcome,
48    pub details: HashMap<String, Value>,
49    pub error_message: Option<String>,
50    pub request_id: Option<String>,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub enum AuditOutcome {
56    Success,
57    Failure,
58    Partial,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct AuditStatistics {
64    pub total_events: u64,
65    pub events_by_type: HashMap<String, u64>,
66    pub events_by_user: HashMap<String, u64>,
67    pub failed_events: u64,
68    pub critical_events: u64,
69    pub retention_days: u32,
70    pub oldest_event: Option<DateTime<Utc>>,
71    pub newest_event: Option<DateTime<Utc>>,
72}
73
74pub struct AuditManager {
76    config: AuditConfig,
77    db_pool: Arc<PgPool>,
78}
79
80impl AuditManager {
81    pub fn new(config: AuditConfig, db_pool: Arc<PgPool>) -> Self {
82        Self { config, db_pool }
83    }
84
85    pub async fn initialize(&self) -> Result<()> {
87        if !self.config.enabled {
88            debug!("Audit logging is disabled");
89            return Ok(());
90        }
91
92        info!("Initializing audit logging system");
93
94        let create_table_sql = r#"
96            CREATE TABLE IF NOT EXISTS audit_events (
97                id UUID PRIMARY KEY,
98                timestamp TIMESTAMPTZ NOT NULL,
99                event_type TEXT NOT NULL,
100                severity TEXT NOT NULL,
101                user_id TEXT,
102                session_id TEXT,
103                ip_address INET,
104                user_agent TEXT,
105                resource TEXT,
106                action TEXT NOT NULL,
107                outcome TEXT NOT NULL,
108                details JSONB,
109                error_message TEXT,
110                request_id TEXT,
111                created_at TIMESTAMPTZ DEFAULT NOW()
112            );
113        "#;
114
115        sqlx::query(create_table_sql)
116            .execute(self.db_pool.as_ref())
117            .await
118            .map_err(|e| SecurityError::AuditError {
119                message: format!("Failed to create audit events table: {e}"),
120            })?;
121
122        let create_indexes_sql = vec![
124            "CREATE INDEX IF NOT EXISTS idx_audit_events_timestamp ON audit_events (timestamp);",
125            "CREATE INDEX IF NOT EXISTS idx_audit_events_user_id ON audit_events (user_id);",
126            "CREATE INDEX IF NOT EXISTS idx_audit_events_event_type ON audit_events (event_type);",
127            "CREATE INDEX IF NOT EXISTS idx_audit_events_severity ON audit_events (severity);",
128            "CREATE INDEX IF NOT EXISTS idx_audit_events_outcome ON audit_events (outcome);",
129        ];
130
131        for sql in create_indexes_sql {
132            sqlx::query(sql)
133                .execute(self.db_pool.as_ref())
134                .await
135                .map_err(|e| SecurityError::AuditError {
136                    message: format!("Failed to create audit index: {e}"),
137                })?;
138        }
139
140        info!("Audit logging system initialized successfully");
141        Ok(())
142    }
143
144    pub async fn log_event(&self, event: AuditEvent) -> Result<()> {
146        if !self.config.enabled {
147            return Ok(());
148        }
149
150        if !self.should_log_event(&event.event_type) {
152            return Ok(());
153        }
154
155        debug!(
156            "Logging audit event: {:?} - {}",
157            event.event_type, event.action
158        );
159
160        let insert_sql = r#"
161            INSERT INTO audit_events (
162                id, timestamp, event_type, severity, user_id, session_id,
163                ip_address, user_agent, resource, action, outcome, details,
164                error_message, request_id
165            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
166        "#;
167
168        sqlx::query(insert_sql)
169            .bind(&event.id)
170            .bind(event.timestamp)
171            .bind(format!("{:?}", event.event_type))
172            .bind(format!("{:?}", event.severity))
173            .bind(&event.user_id)
174            .bind(&event.session_id)
175            .bind(event.ip_address.as_ref())
176            .bind(&event.user_agent)
177            .bind(&event.resource)
178            .bind(&event.action)
179            .bind(format!("{:?}", event.outcome))
180            .bind(serde_json::to_value(&event.details).unwrap_or(Value::Null))
181            .bind(&event.error_message)
182            .bind(&event.request_id)
183            .execute(self.db_pool.as_ref())
184            .await
185            .map_err(|e| SecurityError::AuditError {
186                message: format!("Failed to log audit event: {e}"),
187            })?;
188
189        if matches!(event.severity, AuditSeverity::Critical) {
191            error!(
192                "CRITICAL AUDIT EVENT - Type: {:?}, Action: {}, User: {:?}, Details: {:?}",
193                event.event_type, event.action, event.user_id, event.details
194            );
195        }
196
197        Ok(())
198    }
199
200    pub async fn log_authentication(
202        &self,
203        user_id: &str,
204        action: &str,
205        outcome: AuditOutcome,
206        ip_address: Option<String>,
207        details: HashMap<String, Value>,
208    ) -> Result<()> {
209        let event = AuditEvent {
210            id: Uuid::new_v4().to_string(),
211            timestamp: Utc::now(),
212            event_type: AuditEventType::Authentication,
213            severity: if matches!(outcome, AuditOutcome::Failure) {
214                AuditSeverity::High
215            } else {
216                AuditSeverity::Medium
217            },
218            user_id: Some(user_id.to_string()),
219            session_id: None,
220            ip_address,
221            user_agent: None,
222            resource: Some("authentication".to_string()),
223            action: action.to_string(),
224            outcome,
225            details,
226            error_message: None,
227            request_id: None,
228        };
229
230        self.log_event(event).await
231    }
232
233    pub async fn log_data_access(
235        &self,
236        user_id: Option<&str>,
237        resource: &str,
238        action: &str,
239        outcome: AuditOutcome,
240        details: HashMap<String, Value>,
241    ) -> Result<()> {
242        let event = AuditEvent {
243            id: Uuid::new_v4().to_string(),
244            timestamp: Utc::now(),
245            event_type: AuditEventType::DataAccess,
246            severity: AuditSeverity::Low,
247            user_id: user_id.map(|s| s.to_string()),
248            session_id: None,
249            ip_address: None,
250            user_agent: None,
251            resource: Some(resource.to_string()),
252            action: action.to_string(),
253            outcome,
254            details,
255            error_message: None,
256            request_id: None,
257        };
258
259        self.log_event(event).await
260    }
261
262    pub async fn log_data_modification(
264        &self,
265        user_id: Option<&str>,
266        resource: &str,
267        action: &str,
268        outcome: AuditOutcome,
269        details: HashMap<String, Value>,
270    ) -> Result<()> {
271        let event = AuditEvent {
272            id: Uuid::new_v4().to_string(),
273            timestamp: Utc::now(),
274            event_type: AuditEventType::DataModification,
275            severity: AuditSeverity::Medium,
276            user_id: user_id.map(|s| s.to_string()),
277            session_id: None,
278            ip_address: None,
279            user_agent: None,
280            resource: Some(resource.to_string()),
281            action: action.to_string(),
282            outcome,
283            details,
284            error_message: None,
285            request_id: None,
286        };
287
288        self.log_event(event).await
289    }
290
291    pub async fn log_security_event(
293        &self,
294        action: &str,
295        severity: AuditSeverity,
296        user_id: Option<&str>,
297        ip_address: Option<String>,
298        details: HashMap<String, Value>,
299    ) -> Result<()> {
300        let event = AuditEvent {
301            id: Uuid::new_v4().to_string(),
302            timestamp: Utc::now(),
303            event_type: AuditEventType::SecurityEvent,
304            severity,
305            user_id: user_id.map(|s| s.to_string()),
306            session_id: None,
307            ip_address,
308            user_agent: None,
309            resource: Some("security".to_string()),
310            action: action.to_string(),
311            outcome: AuditOutcome::Success,
312            details,
313            error_message: None,
314            request_id: None,
315        };
316
317        self.log_event(event).await
318    }
319
320    pub async fn get_events(&self, filter: AuditFilter) -> Result<Vec<AuditEvent>> {
322        if !self.config.enabled {
323            return Ok(Vec::new());
324        }
325
326        let mut where_clauses = Vec::new();
327        let mut bind_count = 0;
328
329        if let Some(_user_id) = &filter.user_id {
330            bind_count += 1;
331            where_clauses.push(format!("user_id = ${bind_count}"));
332        }
333
334        if let Some(_event_type) = &filter.event_type {
335            bind_count += 1;
336            where_clauses.push(format!("event_type = ${bind_count}"));
337        }
338
339        if let Some(_start_time) = &filter.start_time {
340            bind_count += 1;
341            where_clauses.push(format!("timestamp >= ${bind_count}"));
342        }
343
344        if let Some(_end_time) = &filter.end_time {
345            bind_count += 1;
346            where_clauses.push(format!("timestamp <= ${bind_count}"));
347        }
348
349        let where_clause = if where_clauses.is_empty() {
350            String::new()
351        } else {
352            format!("WHERE {}", where_clauses.join(" AND "))
353        };
354
355        let limit = filter.limit.unwrap_or(100).min(1000); let offset = filter.offset.unwrap_or(0);
357
358        let query = format!(
359            "SELECT * FROM audit_events {where_clause} ORDER BY timestamp DESC LIMIT {limit} OFFSET {offset}"
360        );
361
362        let mut sql_query = sqlx::query(&query);
363
364        if let Some(user_id) = &filter.user_id {
366            sql_query = sql_query.bind(user_id);
367        }
368        if let Some(event_type) = &filter.event_type {
369            sql_query = sql_query.bind(format!("{event_type:?}"));
370        }
371        if let Some(start_time) = &filter.start_time {
372            sql_query = sql_query.bind(start_time);
373        }
374        if let Some(end_time) = &filter.end_time {
375            sql_query = sql_query.bind(end_time);
376        }
377
378        let rows = sql_query
379            .fetch_all(self.db_pool.as_ref())
380            .await
381            .map_err(|e| SecurityError::AuditError {
382                message: format!("Failed to fetch audit events: {e}"),
383            })?;
384
385        let mut events = Vec::new();
386        for row in rows {
387            events.push(self.row_to_audit_event(row)?);
388        }
389
390        Ok(events)
391    }
392
393    pub async fn get_statistics(&self) -> Result<AuditStatistics> {
395        if !self.config.enabled {
396            return Ok(AuditStatistics {
397                total_events: 0,
398                events_by_type: HashMap::new(),
399                events_by_user: HashMap::new(),
400                failed_events: 0,
401                critical_events: 0,
402                retention_days: self.config.retention_days,
403                oldest_event: None,
404                newest_event: None,
405            });
406        }
407
408        let total_events: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM audit_events")
410            .fetch_one(self.db_pool.as_ref())
411            .await
412            .map_err(|e| SecurityError::AuditError {
413                message: format!("Failed to get total events count: {e}"),
414            })?;
415
416        let failed_events: i64 =
417            sqlx::query_scalar("SELECT COUNT(*) FROM audit_events WHERE outcome = 'Failure'")
418                .fetch_one(self.db_pool.as_ref())
419                .await
420                .map_err(|e| SecurityError::AuditError {
421                    message: format!("Failed to get failed events count: {e}"),
422                })?;
423
424        let critical_events: i64 =
425            sqlx::query_scalar("SELECT COUNT(*) FROM audit_events WHERE severity = 'Critical'")
426                .fetch_one(self.db_pool.as_ref())
427                .await
428                .map_err(|e| SecurityError::AuditError {
429                    message: format!("Failed to get critical events count: {e}"),
430                })?;
431
432        let oldest_event: Option<DateTime<Utc>> =
434            sqlx::query_scalar("SELECT MIN(timestamp) FROM audit_events")
435                .fetch_one(self.db_pool.as_ref())
436                .await
437                .map_err(|e| SecurityError::AuditError {
438                    message: format!("Failed to get oldest event: {e}"),
439                })?;
440
441        let newest_event: Option<DateTime<Utc>> =
442            sqlx::query_scalar("SELECT MAX(timestamp) FROM audit_events")
443                .fetch_one(self.db_pool.as_ref())
444                .await
445                .map_err(|e| SecurityError::AuditError {
446                    message: format!("Failed to get newest event: {e}"),
447                })?;
448
449        let type_rows = sqlx::query(
451            "SELECT event_type, COUNT(*) as count FROM audit_events GROUP BY event_type",
452        )
453        .fetch_all(self.db_pool.as_ref())
454        .await
455        .map_err(|e| SecurityError::AuditError {
456            message: format!("Failed to get events by type: {e}"),
457        })?;
458
459        let mut events_by_type = HashMap::new();
460        for row in type_rows {
461            let event_type: String = row.get("event_type");
462            let count: i64 = row.get("count");
463            events_by_type.insert(event_type, count as u64);
464        }
465
466        let user_rows = sqlx::query("SELECT user_id, COUNT(*) as count FROM audit_events WHERE user_id IS NOT NULL GROUP BY user_id ORDER BY count DESC LIMIT 20")
468            .fetch_all(self.db_pool.as_ref())
469            .await
470            .map_err(|e| SecurityError::AuditError {
471                message: format!("Failed to get events by user: {e}"),
472            })?;
473
474        let mut events_by_user = HashMap::new();
475        for row in user_rows {
476            let user_id: String = row.get("user_id");
477            let count: i64 = row.get("count");
478            events_by_user.insert(user_id, count as u64);
479        }
480
481        Ok(AuditStatistics {
482            total_events: total_events as u64,
483            events_by_type,
484            events_by_user,
485            failed_events: failed_events as u64,
486            critical_events: critical_events as u64,
487            retention_days: self.config.retention_days,
488            oldest_event,
489            newest_event,
490        })
491    }
492
493    pub async fn cleanup_old_events(&self) -> Result<u64> {
495        if !self.config.enabled {
496            return Ok(0);
497        }
498
499        let cutoff_date = Utc::now() - chrono::Duration::days(self.config.retention_days as i64);
500
501        let deleted_count: i64 =
502            sqlx::query_scalar("DELETE FROM audit_events WHERE timestamp < $1 RETURNING COUNT(*)")
503                .bind(cutoff_date)
504                .fetch_optional(self.db_pool.as_ref())
505                .await
506                .map_err(|e| SecurityError::AuditError {
507                    message: format!("Failed to cleanup old audit events: {e}"),
508                })?
509                .unwrap_or(0);
510
511        if deleted_count > 0 {
512            info!("Cleaned up {} old audit events", deleted_count);
513        }
514
515        Ok(deleted_count as u64)
516    }
517
518    fn should_log_event(&self, event_type: &AuditEventType) -> bool {
519        match event_type {
520            AuditEventType::Authentication => self.config.log_auth_events,
521            AuditEventType::DataAccess => self.config.log_data_access,
522            AuditEventType::DataModification => self.config.log_modifications,
523            AuditEventType::DataDeletion => self.config.log_modifications,
524            _ => true, }
526    }
527
528    fn row_to_audit_event(&self, row: sqlx::postgres::PgRow) -> Result<AuditEvent> {
529        let event_type_str: String = row.get("event_type");
530        let event_type = match event_type_str.as_str() {
531            "Authentication" => AuditEventType::Authentication,
532            "Authorization" => AuditEventType::Authorization,
533            "DataAccess" => AuditEventType::DataAccess,
534            "DataModification" => AuditEventType::DataModification,
535            "DataDeletion" => AuditEventType::DataDeletion,
536            "SystemAccess" => AuditEventType::SystemAccess,
537            "ConfigurationChange" => AuditEventType::ConfigurationChange,
538            "SecurityEvent" => AuditEventType::SecurityEvent,
539            "Error" => AuditEventType::Error,
540            _ => AuditEventType::SystemAccess,
541        };
542
543        let severity_str: String = row.get("severity");
544        let severity = match severity_str.as_str() {
545            "Low" => AuditSeverity::Low,
546            "Medium" => AuditSeverity::Medium,
547            "High" => AuditSeverity::High,
548            "Critical" => AuditSeverity::Critical,
549            _ => AuditSeverity::Low,
550        };
551
552        let outcome_str: String = row.get("outcome");
553        let outcome = match outcome_str.as_str() {
554            "Success" => AuditOutcome::Success,
555            "Failure" => AuditOutcome::Failure,
556            "Partial" => AuditOutcome::Partial,
557            _ => AuditOutcome::Success,
558        };
559
560        let details_value: Value = row.get("details");
561        let details: HashMap<String, Value> =
562            serde_json::from_value(details_value).unwrap_or_else(|_| HashMap::new());
563
564        Ok(AuditEvent {
565            id: row.get("id"),
566            timestamp: row.get("timestamp"),
567            event_type,
568            severity,
569            user_id: row.get("user_id"),
570            session_id: row.get("session_id"),
571            ip_address: row.get::<Option<String>, _>("ip_address"),
572            user_agent: row.get("user_agent"),
573            resource: row.get("resource"),
574            action: row.get("action"),
575            outcome,
576            details,
577            error_message: row.get("error_message"),
578            request_id: row.get("request_id"),
579        })
580    }
581
582    pub fn is_enabled(&self) -> bool {
583        self.config.enabled
584    }
585}
586
587#[derive(Debug, Clone, Default)]
589pub struct AuditFilter {
590    pub user_id: Option<String>,
591    pub event_type: Option<AuditEventType>,
592    pub start_time: Option<DateTime<Utc>>,
593    pub end_time: Option<DateTime<Utc>>,
594    pub limit: Option<i64>,
595    pub offset: Option<i64>,
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Renders a value with `Debug`, which is the form written to the text columns.
    fn stored_name<T: std::fmt::Debug>(value: &T) -> String {
        format!("{value:?}")
    }

    #[test]
    fn test_audit_event_creation() {
        let mut details = HashMap::new();
        details.insert("test_key".to_string(), json!("test_value"));

        let event = AuditEvent {
            id: Uuid::new_v4().to_string(),
            timestamp: Utc::now(),
            event_type: AuditEventType::Authentication,
            severity: AuditSeverity::Medium,
            user_id: Some("test-user".to_string()),
            session_id: None,
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: None,
            resource: Some("login".to_string()),
            action: "user_login".to_string(),
            outcome: AuditOutcome::Success,
            details,
            error_message: None,
            request_id: None,
        };

        // The id must be a UUID because the backing column is of type UUID.
        assert!(Uuid::parse_str(&event.id).is_ok());
        assert_eq!(event.action, "user_login");
        assert!(matches!(event.event_type, AuditEventType::Authentication));
        assert!(matches!(event.severity, AuditSeverity::Medium));
        assert!(matches!(event.outcome, AuditOutcome::Success));
        assert_eq!(event.user_id.as_deref(), Some("test-user"));
        assert_eq!(event.details.get("test_key"), Some(&json!("test_value")));
    }

    #[test]
    fn test_audit_filter_default() {
        let filter = AuditFilter::default();
        assert!(filter.user_id.is_none());
        assert!(filter.event_type.is_none());
        assert!(filter.start_time.is_none());
        assert!(filter.end_time.is_none());
        assert!(filter.limit.is_none());
        assert!(filter.offset.is_none());
    }

    #[test]
    fn test_audit_statistics_default() {
        let stats = AuditStatistics {
            total_events: 0,
            events_by_type: HashMap::new(),
            events_by_user: HashMap::new(),
            failed_events: 0,
            critical_events: 0,
            retention_days: 90,
            oldest_event: None,
            newest_event: None,
        };

        assert_eq!(stats.total_events, 0);
        assert_eq!(stats.failed_events, 0);
        assert_eq!(stats.critical_events, 0);
        assert_eq!(stats.retention_days, 90);
        assert!(stats.events_by_type.is_empty());
        assert!(stats.events_by_user.is_empty());
        assert!(stats.oldest_event.is_none());
        assert!(stats.newest_event.is_none());
    }

    #[test]
    fn test_event_type_serialization() {
        let event_type = AuditEventType::Authentication;
        let serialized = serde_json::to_string(&event_type).unwrap();
        assert_eq!(serialized, "\"Authentication\"");

        let deserialized: AuditEventType = serde_json::from_str(&serialized).unwrap();
        assert!(matches!(deserialized, AuditEventType::Authentication));
    }

    #[test]
    fn test_event_type_stored_names() {
        // These strings are written to `event_type` and matched when rows are read back.
        let names: Vec<String> = [
            AuditEventType::Authentication,
            AuditEventType::Authorization,
            AuditEventType::DataAccess,
            AuditEventType::DataModification,
            AuditEventType::DataDeletion,
            AuditEventType::SystemAccess,
            AuditEventType::ConfigurationChange,
            AuditEventType::SecurityEvent,
            AuditEventType::Error,
        ]
        .iter()
        .map(stored_name)
        .collect();

        assert_eq!(
            names,
            [
                "Authentication",
                "Authorization",
                "DataAccess",
                "DataModification",
                "DataDeletion",
                "SystemAccess",
                "ConfigurationChange",
                "SecurityEvent",
                "Error",
            ]
        );
    }

    #[test]
    fn test_severity_ordering() {
        // Pins the declaration order and the names stored in the `severity` column; the
        // statistics query counts rows whose severity is the literal 'Critical'.
        let names: Vec<String> = [
            AuditSeverity::Low,
            AuditSeverity::Medium,
            AuditSeverity::High,
            AuditSeverity::Critical,
        ]
        .iter()
        .map(stored_name)
        .collect();

        assert_eq!(names, ["Low", "Medium", "High", "Critical"]);
    }

    #[test]
    fn test_outcome_variants() {
        // Pins the names stored in the `outcome` column; the statistics query counts rows
        // whose outcome is the literal 'Failure'.
        let names: Vec<String> = [
            AuditOutcome::Success,
            AuditOutcome::Failure,
            AuditOutcome::Partial,
        ]
        .iter()
        .map(stored_name)
        .collect();

        assert_eq!(names, ["Success", "Failure", "Partial"]);
    }
}