use std::collections::HashMap;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use chrono::{DateTime, Utc, Duration};
use crate::security::SecurityEvent;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditEntry {
pub id: String,
pub timestamp: DateTime<Utc>,
pub event_type: String,
pub user_id: Option<String>,
pub session_id: Option<String>,
pub tenant_id: Option<String>,
pub resource: Option<String>,
pub action: Option<String>,
pub outcome: AuditOutcome,
pub risk_score: u32,
pub ip_address: Option<String>,
pub user_agent: Option<String>,
pub metadata: HashMap<String, String>,
pub event_data: SecurityEvent,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AuditOutcome {
Success,
Failure,
Warning,
Information,
}
#[derive(Debug, Clone)]
pub struct AuditQuery {
pub start_time: DateTime<Utc>,
pub end_time: DateTime<Utc>,
pub user_id: Option<String>,
pub event_types: Vec<String>,
pub outcome: Option<AuditOutcome>,
pub resource: Option<String>,
pub min_risk_score: Option<u32>,
pub limit: Option<usize>,
pub offset: Option<usize>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditStatistics {
pub total_events: u64,
pub events_by_type: HashMap<String, u64>,
pub events_by_outcome: HashMap<String, u64>,
pub events_by_user: HashMap<String, u64>,
pub risk_score_distribution: RiskScoreDistribution,
pub time_period: DateRange,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskScoreDistribution {
pub low: u64, pub medium: u64, pub high: u64, pub critical: u64, }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DateRange {
pub start: DateTime<Utc>,
pub end: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IntegrityVerification {
pub verified: bool,
pub total_entries: u64,
pub verified_entries: u64,
pub corrupted_entries: Vec<String>,
pub verification_timestamp: DateTime<Utc>,
pub hash_chain_valid: bool,
}
#[async_trait]
pub trait AuditService: Send + Sync {
async fn log_event(&self, event: SecurityEvent) -> crate::core::error::Result<String>;
async fn query_events(&self, query: AuditQuery) -> crate::core::error::Result<Vec<AuditEntry>>;
async fn get_statistics(&self, start: DateTime<Utc>, end: DateTime<Utc>) -> crate::core::error::Result<AuditStatistics>;
async fn verify_integrity(&self, start: DateTime<Utc>, end: DateTime<Utc>) -> crate::core::error::Result<IntegrityVerification>;
async fn export_audit_log(&self, query: AuditQuery, format: ExportFormat) -> crate::core::error::Result<Vec<u8>>;
async fn count_events_by_type(&self, event_type: &str, period: Duration) -> crate::core::error::Result<u64>;
async fn purge_old_entries(&self, older_than: DateTime<Utc>) -> crate::core::error::Result<u64>;
}
#[derive(Debug, Clone)]
pub enum ExportFormat {
Json,
Csv,
Xml,
Pdf,
}
pub struct DatabaseAuditService {
entries: std::sync::Mutex<Vec<AuditEntry>>,
hash_chain: std::sync::Mutex<Vec<String>>,
}
impl DatabaseAuditService {
pub fn new() -> crate::core::error::Result<Self> {
Ok(Self {
entries: std::sync::Mutex::new(Vec::new()),
hash_chain: std::sync::Mutex::new(Vec::new()),
})
}
fn calculate_risk_score(&self, event: &SecurityEvent) -> u32 {
match event {
SecurityEvent::Authentication { success, .. } => {
if *success { 10 } else { 40 }
}
SecurityEvent::Authorization { granted, .. } => {
if *granted { 5 } else { 35 }
}
SecurityEvent::DataAccess { classification, .. } => {
match classification {
crate::security::SecurityLevel::Public => 5,
crate::security::SecurityLevel::Internal => 15,
crate::security::SecurityLevel::Confidential => 30,
crate::security::SecurityLevel::Restricted => 50,
crate::security::SecurityLevel::TopSecret => 75,
}
}
SecurityEvent::ThreatDetected { severity, .. } => {
match severity {
crate::security::ThreatSeverity::Low => 30,
crate::security::ThreatSeverity::Medium => 50,
crate::security::ThreatSeverity::High => 75,
crate::security::ThreatSeverity::Critical => 95,
}
}
SecurityEvent::PolicyViolation { severity, .. } => {
match severity {
crate::security::ViolationSeverity::Minor => 20,
crate::security::ViolationSeverity::Major => 60,
crate::security::ViolationSeverity::Critical => 90,
}
}
}
}
fn determine_outcome(&self, event: &SecurityEvent) -> AuditOutcome {
match event {
SecurityEvent::Authentication { success, .. } => {
if *success { AuditOutcome::Success } else { AuditOutcome::Failure }
}
SecurityEvent::Authorization { granted, .. } => {
if *granted { AuditOutcome::Success } else { AuditOutcome::Warning }
}
SecurityEvent::DataAccess { .. } => AuditOutcome::Information,
SecurityEvent::ThreatDetected { .. } => AuditOutcome::Warning,
SecurityEvent::PolicyViolation { .. } => AuditOutcome::Failure,
}
}
fn extract_metadata(&self, event: &SecurityEvent) -> (Option<String>, Option<String>, HashMap<String, String>) {
let mut metadata = HashMap::new();
match event {
SecurityEvent::Authentication { user_id, method, ip_address, .. } => {
metadata.insert("auth_method".to_string(), method.clone());
if let Some(ip) = ip_address {
metadata.insert("ip_address".to_string(), ip.clone());
}
(Some(user_id.clone()), None, metadata)
}
SecurityEvent::Authorization { user_id, resource, action, .. } => {
metadata.insert("action".to_string(), action.clone());
(Some(user_id.clone()), Some(resource.clone()), metadata)
}
SecurityEvent::DataAccess { user_id, resource, operation, classification } => {
metadata.insert("operation".to_string(), operation.clone());
metadata.insert("classification".to_string(), format!("{:?}", classification));
(Some(user_id.clone()), Some(resource.clone()), metadata)
}
SecurityEvent::ThreatDetected { threat_type, source, details, .. } => {
metadata.insert("threat_type".to_string(), threat_type.clone());
metadata.extend(details.clone());
(Some(source.clone()), None, metadata)
}
SecurityEvent::PolicyViolation { user_id, policy, violation_type, .. } => {
metadata.insert("policy".to_string(), policy.clone());
metadata.insert("violation_type".to_string(), violation_type.clone());
(Some(user_id.clone()), None, metadata)
}
}
}
fn extract_context_info(&self, event: &SecurityEvent) -> (Option<String>, Option<String>, Option<String>, Option<String>, Option<String>) {
match event {
SecurityEvent::Authentication { ip_address, .. } => {
(None, None, None, ip_address.clone(), None)
}
SecurityEvent::Authorization { action, .. } => {
(None, None, Some(action.clone()), None, None)
}
SecurityEvent::DataAccess { operation, .. } => {
(None, None, Some(operation.clone()), None, None)
}
SecurityEvent::ThreatDetected { details, .. } => {
let ip_address = details.get("ip_address").cloned();
let user_agent = details.get("user_agent").cloned();
let session_id = details.get("session_id").cloned();
let tenant_id = details.get("tenant_id").cloned();
(session_id, tenant_id, None, ip_address, user_agent)
}
SecurityEvent::PolicyViolation { violation_type, .. } => {
(None, None, Some(violation_type.clone()), None, None)
}
}
}
fn calculate_entry_hash(&self, entry: &AuditEntry, previous_hash: &str) -> String {
use sha2::{Sha256, Digest};
let mut hasher = Sha256::new();
hasher.update(previous_hash);
hasher.update(entry.id.as_bytes());
hasher.update(entry.timestamp.timestamp().to_be_bytes());
hasher.update(entry.event_type.as_bytes());
if let Some(user_id) = &entry.user_id {
hasher.update(user_id.as_bytes());
}
hex::encode(hasher.finalize())
}
fn matches_query(&self, entry: &AuditEntry, query: &AuditQuery) -> bool {
if entry.timestamp < query.start_time || entry.timestamp > query.end_time {
return false;
}
if let Some(user_id) = &query.user_id {
if entry.user_id.as_ref() != Some(user_id) {
return false;
}
}
if !query.event_types.is_empty() && !query.event_types.contains(&entry.event_type) {
return false;
}
if let Some(outcome) = &query.outcome {
if !matches!((outcome, &entry.outcome),
(AuditOutcome::Success, AuditOutcome::Success) |
(AuditOutcome::Failure, AuditOutcome::Failure) |
(AuditOutcome::Warning, AuditOutcome::Warning) |
(AuditOutcome::Information, AuditOutcome::Information)
) {
return false;
}
}
if let Some(resource) = &query.resource {
if entry.resource.as_ref() != Some(resource) {
return false;
}
}
if let Some(min_risk) = query.min_risk_score {
if entry.risk_score < min_risk {
return false;
}
}
true
}
}
#[async_trait]
impl AuditService for DatabaseAuditService {
async fn log_event(&self, event: SecurityEvent) -> crate::core::error::Result<String> {
let entry_id = uuid::Uuid::new_v4().to_string();
let timestamp = Utc::now();
let risk_score = self.calculate_risk_score(&event);
let outcome = self.determine_outcome(&event);
let (user_id, resource, metadata) = self.extract_metadata(&event);
let event_type = match &event {
SecurityEvent::Authentication { .. } => "authentication",
SecurityEvent::Authorization { .. } => "authorization",
SecurityEvent::DataAccess { .. } => "data_access",
SecurityEvent::ThreatDetected { .. } => "threat_detected",
SecurityEvent::PolicyViolation { .. } => "policy_violation",
}.to_string();
let (session_id, tenant_id, action, ip_address, user_agent) = self.extract_context_info(&event);
let entry = AuditEntry {
id: entry_id.clone(),
timestamp,
event_type,
user_id,
session_id,
tenant_id,
resource,
action,
outcome,
risk_score,
ip_address,
user_agent,
metadata,
event_data: event,
};
let mut hash_chain = self.hash_chain.lock().unwrap();
let previous_hash = hash_chain.last().unwrap_or(&"genesis".to_string()).clone();
let entry_hash = self.calculate_entry_hash(&entry, &previous_hash);
hash_chain.push(entry_hash);
let mut entries = self.entries.lock().unwrap();
entries.push(entry);
Ok(entry_id)
}
async fn query_events(&self, query: AuditQuery) -> crate::core::error::Result<Vec<AuditEntry>> {
let entries = self.entries.lock().unwrap();
let mut matching_entries: Vec<AuditEntry> = entries.iter()
.filter(|entry| self.matches_query(entry, &query))
.cloned()
.collect();
matching_entries.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
if let Some(offset) = query.offset {
matching_entries = matching_entries.into_iter().skip(offset).collect();
}
if let Some(limit) = query.limit {
matching_entries.truncate(limit);
}
Ok(matching_entries)
}
async fn get_statistics(&self, start: DateTime<Utc>, end: DateTime<Utc>) -> crate::core::error::Result<AuditStatistics> {
let entries = self.entries.lock().unwrap();
let period_entries: Vec<&AuditEntry> = entries.iter()
.filter(|entry| entry.timestamp >= start && entry.timestamp <= end)
.collect();
let total_events = period_entries.len() as u64;
let mut events_by_type = HashMap::new();
let mut events_by_outcome = HashMap::new();
let mut events_by_user = HashMap::new();
let mut risk_distribution = RiskScoreDistribution {
low: 0,
medium: 0,
high: 0,
critical: 0,
};
for entry in &period_entries {
*events_by_type.entry(entry.event_type.clone()).or_insert(0) += 1;
let outcome_str = format!("{:?}", entry.outcome);
*events_by_outcome.entry(outcome_str).or_insert(0) += 1;
if let Some(user_id) = &entry.user_id {
*events_by_user.entry(user_id.clone()).or_insert(0) += 1;
}
match entry.risk_score {
0..=25 => risk_distribution.low += 1,
26..=50 => risk_distribution.medium += 1,
51..=75 => risk_distribution.high += 1,
76..=100 => risk_distribution.critical += 1,
_ => {}
}
}
Ok(AuditStatistics {
total_events,
events_by_type,
events_by_outcome,
events_by_user,
risk_score_distribution: risk_distribution,
time_period: DateRange { start, end },
})
}
async fn verify_integrity(&self, start: DateTime<Utc>, end: DateTime<Utc>) -> crate::core::error::Result<IntegrityVerification> {
let entries = self.entries.lock().unwrap();
let hash_chain = self.hash_chain.lock().unwrap();
let period_entries: Vec<&AuditEntry> = entries.iter()
.filter(|entry| entry.timestamp >= start && entry.timestamp <= end)
.collect();
let total_entries = period_entries.len() as u64;
let mut verified_entries = 0u64;
let mut corrupted_entries = Vec::new();
let mut hash_chain_valid = true;
let mut previous_hash = "genesis".to_string();
for (i, entry) in period_entries.iter().enumerate() {
let expected_hash = self.calculate_entry_hash(entry, &previous_hash);
if i < hash_chain.len() && hash_chain[i] == expected_hash {
verified_entries += 1;
previous_hash = expected_hash;
} else {
hash_chain_valid = false;
corrupted_entries.push(entry.id.clone());
}
}
Ok(IntegrityVerification {
verified: corrupted_entries.is_empty(),
total_entries,
verified_entries,
corrupted_entries,
verification_timestamp: Utc::now(),
hash_chain_valid,
})
}
async fn export_audit_log(&self, query: AuditQuery, format: ExportFormat) -> crate::core::error::Result<Vec<u8>> {
let entries = self.query_events(query).await?;
match format {
ExportFormat::Json => {
let json = serde_json::to_string_pretty(&entries)
.map_err(|e| crate::core::error::RustChainError::Security(format!("JSON serialization error: {}", e)))?;
Ok(json.into_bytes())
}
ExportFormat::Csv => {
let mut csv = String::new();
csv.push_str("ID,Timestamp,EventType,UserID,Resource,Outcome,RiskScore\n");
for entry in entries {
csv.push_str(&format!(
"{},{},{},{},{},{:?},{}\n",
entry.id,
entry.timestamp.to_rfc3339(),
entry.event_type,
entry.user_id.unwrap_or_default(),
entry.resource.unwrap_or_default(),
entry.outcome,
entry.risk_score
));
}
Ok(csv.into_bytes())
}
ExportFormat::Xml => {
let mut xml = String::new();
xml.push_str("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<audit_log>\n");
for entry in entries {
xml.push_str(&format!(
" <entry id=\"{}\" timestamp=\"{}\" type=\"{}\" risk=\"{}\">\n",
entry.id, entry.timestamp.to_rfc3339(), entry.event_type, entry.risk_score
));
if let Some(user_id) = entry.user_id {
xml.push_str(&format!(" <user>{}</user>\n", user_id));
}
if let Some(resource) = entry.resource {
xml.push_str(&format!(" <resource>{}</resource>\n", resource));
}
xml.push_str(" </entry>\n");
}
xml.push_str("</audit_log>\n");
Ok(xml.into_bytes())
}
ExportFormat::Pdf => {
Ok(b"PDF export not implemented".to_vec())
}
}
}
async fn count_events_by_type(&self, event_type: &str, period: Duration) -> crate::core::error::Result<u64> {
let entries = self.entries.lock().unwrap();
let cutoff = Utc::now() - period;
let count = entries.iter()
.filter(|entry| entry.timestamp >= cutoff && entry.event_type == event_type)
.count() as u64;
Ok(count)
}
async fn purge_old_entries(&self, older_than: DateTime<Utc>) -> crate::core::error::Result<u64> {
let mut entries = self.entries.lock().unwrap();
let initial_count = entries.len();
entries.retain(|entry| entry.timestamp >= older_than);
let purged_count = (initial_count - entries.len()) as u64;
Ok(purged_count)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::security::{SecurityEvent, ThreatSeverity};
#[tokio::test]
async fn test_log_and_query_events() {
let audit_service = DatabaseAuditService::new().unwrap();
let event = SecurityEvent::Authentication {
user_id: "test_user".to_string(),
method: "password".to_string(),
success: true,
ip_address: Some("127.0.0.1".to_string()),
};
let entry_id = audit_service.log_event(event).await.unwrap();
assert!(!entry_id.is_empty());
let query = AuditQuery {
start_time: Utc::now() - Duration::minutes(5),
end_time: Utc::now() + Duration::minutes(5),
user_id: Some("test_user".to_string()),
event_types: vec![],
outcome: None,
resource: None,
min_risk_score: None,
limit: None,
offset: None,
};
let events = audit_service.query_events(query).await.unwrap();
assert_eq!(events.len(), 1);
assert_eq!(events[0].id, entry_id);
}
#[tokio::test]
async fn test_risk_score_calculation() {
let audit_service = DatabaseAuditService::new().unwrap();
let threat_event = SecurityEvent::ThreatDetected {
threat_type: "SQL Injection".to_string(),
severity: ThreatSeverity::Critical,
source: "external".to_string(),
details: std::collections::HashMap::new(),
};
let entry_id = audit_service.log_event(threat_event).await.unwrap();
let query = AuditQuery {
start_time: Utc::now() - Duration::minutes(5),
end_time: Utc::now() + Duration::minutes(5),
user_id: None,
event_types: vec!["threat_detected".to_string()],
outcome: None,
resource: None,
min_risk_score: Some(90),
limit: None,
offset: None,
};
let events = audit_service.query_events(query).await.unwrap();
assert_eq!(events.len(), 1);
assert!(events[0].risk_score >= 90);
}
#[tokio::test]
async fn test_integrity_verification() {
let audit_service = DatabaseAuditService::new().unwrap();
for i in 0..5 {
let event = SecurityEvent::Authentication {
user_id: format!("user_{}", i),
method: "password".to_string(),
success: true,
ip_address: Some("127.0.0.1".to_string()),
};
audit_service.log_event(event).await.unwrap();
}
let verification = audit_service.verify_integrity(
Utc::now() - Duration::hours(1),
Utc::now() + Duration::hours(1)
).await.unwrap();
assert!(verification.verified);
assert_eq!(verification.total_entries, 5);
assert_eq!(verification.verified_entries, 5);
assert!(verification.hash_chain_valid);
}
}