use chrono::{DateTime, Utc};
use serde_json::Value;
use std::sync::Arc;
use super::entry::{AuditAction, AuditEntry, AuditFilter, AuditResult, ExportFormat};
use super::storage::{AuditStorage, MemoryStorage};
use anyhow::Result;
/// Runtime switches that control how [`AuditLogger`] records entries.
#[derive(Debug, Clone)]
pub struct AuditConfig {
    /// When `false`, `AuditLogger::log` returns `Ok(())` without storing anything.
    pub enabled: bool,
    /// When `true`, the entry's `details` JSON is redacted before it is stored.
    pub sanitize_sensitive: bool,
    /// JSON object keys whose values are replaced with `"***REDACTED***"` during sanitisation.
    pub sensitive_fields: Vec<String>,
    /// NOTE(review): not read anywhere in this module; presumably the user id to
    /// attribute system-originated events to — confirm against callers.
    pub default_user_id: String,
    /// NOTE(review): not read anywhere in this module; `AuditLogger::log` always
    /// awaits `storage.save` inline regardless of this flag.
    pub async_write: bool,
    /// NOTE(review): not read anywhere in this module (`AuditLogger::cleanup` takes an
    /// explicit cutoff); presumably the number of days entries should be kept — confirm.
    pub retention_days: u32,
}
impl Default for AuditConfig {
    /// Returns the stock configuration: logging and sanitisation enabled, the
    /// five built-in sensitive key names, `"system"` as `default_user_id`,
    /// `async_write` set, and `retention_days` of 90.
    fn default() -> Self {
        const BUILTIN_SENSITIVE: [&str; 5] = ["password", "token", "api_key", "secret", "credential"];

        let sensitive_fields = BUILTIN_SENSITIVE
            .iter()
            .map(|name| (*name).to_owned())
            .collect();

        Self {
            enabled: true,
            sanitize_sensitive: true,
            sensitive_fields,
            default_user_id: String::from("system"),
            async_write: true,
            retention_days: 90,
        }
    }
}
/// Front-end for recording, querying, exporting and pruning audit entries.
///
/// Entries are optionally redacted according to [`AuditConfig`] and then
/// handed to a pluggable [`AuditStorage`] backend.
pub struct AuditLogger {
    // Backend that persists entries; `AuditLogger::new` installs a `MemoryStorage`,
    // and `with_storage` can replace it.
    storage: Arc<dyn AuditStorage>,
    // Switches consulted on every `log` call (enabled / sanitisation settings).
    config: AuditConfig,
}
impl AuditLogger {
    /// Creates a logger with `config`, backed by a fresh `MemoryStorage::new(10000)`.
    ///
    /// Use [`AuditLogger::with_storage`] to swap in a different backend.
    pub fn new(config: AuditConfig) -> Self {
        Self {
            storage: Arc::new(MemoryStorage::new(10000)),
            config,
        }
    }

    /// Replaces the storage backend and returns the logger, so it can be
    /// chained directly after [`AuditLogger::new`].
    #[must_use]
    pub fn with_storage(mut self, storage: Arc<dyn AuditStorage>) -> Self {
        self.storage = storage;
        self
    }

    /// Returns the configuration this logger was built with.
    pub fn config(&self) -> &AuditConfig {
        &self.config
    }

    /// Records `entry`.
    ///
    /// If `config.enabled` is `false`, this returns `Ok(())` without storing
    /// anything. Otherwise the entry's `details` are redacted first (when
    /// `config.sanitize_sensitive` is set), and then the entry is saved.
    ///
    /// # Errors
    /// Propagates any error returned by the storage backend's `save`.
    pub async fn log(&self, entry: AuditEntry) -> Result<()> {
        if !self.config.enabled {
            return Ok(());
        }
        let entry = if self.config.sanitize_sensitive {
            self.sanitize_entry(entry)
        } else {
            entry
        };
        self.storage.save(&entry).await
    }

    /// Builds an entry from the given parts and records it via [`AuditLogger::log`].
    ///
    /// `resource_id` is attached only when it is `Some`.
    ///
    /// # Errors
    /// Propagates any error from [`AuditLogger::log`].
    pub async fn log_action(
        &self,
        user_id: &str,
        action: AuditAction,
        resource_type: &str,
        resource_id: Option<&str>,
        result: AuditResult,
    ) -> Result<()> {
        let mut entry = AuditEntry::new(user_id, action, resource_type);
        if let Some(id) = resource_id {
            entry = entry.with_resource_id(id);
        }
        entry = entry.with_result(result);
        self.log(entry).await
    }

    /// Records an action whose result is `AuditResult::Success`.
    ///
    /// # Errors
    /// Propagates any error from [`AuditLogger::log`].
    pub async fn log_success(
        &self,
        user_id: &str,
        action: AuditAction,
        resource_type: &str,
        resource_id: Option<&str>,
    ) -> Result<()> {
        self.log_action(
            user_id,
            action,
            resource_type,
            resource_id,
            AuditResult::Success,
        )
        .await
    }

    /// Records an action whose result is `AuditResult::failure(error_code, error_message)`.
    ///
    /// # Errors
    /// Propagates any error from [`AuditLogger::log`].
    pub async fn log_failure(
        &self,
        user_id: &str,
        action: AuditAction,
        resource_type: &str,
        resource_id: Option<&str>,
        error_code: &str,
        error_message: &str,
    ) -> Result<()> {
        self.log_action(
            user_id,
            action,
            resource_type,
            resource_id,
            AuditResult::failure(error_code, error_message),
        )
        .await
    }

    /// Records an action whose result is `AuditResult::denied(reason)`.
    ///
    /// # Errors
    /// Propagates any error from [`AuditLogger::log`].
    pub async fn log_denied(
        &self,
        user_id: &str,
        action: AuditAction,
        resource_type: &str,
        resource_id: Option<&str>,
        reason: &str,
    ) -> Result<()> {
        self.log_action(
            user_id,
            action,
            resource_type,
            resource_id,
            AuditResult::denied(reason),
        )
        .await
    }

    /// Returns the stored entries matching `filter`, as reported by the backend.
    ///
    /// # Errors
    /// Propagates any error from the storage backend.
    pub async fn query(&self, filter: AuditFilter) -> Result<Vec<AuditEntry>> {
        self.storage.query(&filter).await
    }

    /// Serialises the entries matching `filter` into `format`, as bytes.
    ///
    /// # Errors
    /// Propagates any error from the storage backend.
    pub async fn export(&self, format: ExportFormat, filter: AuditFilter) -> Result<Vec<u8>> {
        self.storage.export(format, &filter).await
    }

    /// Asks the backend to remove entries older than `before`.
    ///
    /// The returned count is produced by the backend; presumably it is the
    /// number of removed entries.
    ///
    /// # Errors
    /// Propagates any error from the storage backend.
    pub async fn cleanup(&self, before: DateTime<Utc>) -> Result<usize> {
        self.storage.cleanup(before).await
    }

    /// Returns the number of entries currently held by the backend.
    ///
    /// # Errors
    /// Propagates any error from the storage backend.
    pub async fn count(&self) -> Result<usize> {
        self.storage.count().await
    }

    /// Redacts sensitive fields inside `entry.details` and returns the entry.
    fn sanitize_entry(&self, mut entry: AuditEntry) -> AuditEntry {
        // Mutate through the Option rather than cloning the whole details tree.
        if let Some(details) = entry.details.as_mut() {
            self.sanitize_value(details);
        }
        entry
    }

    /// Returns `true` when `key` names one of the configured sensitive fields.
    ///
    /// Matching is ASCII case-insensitive, so `"Password"` and `"API_KEY"` are
    /// redacted just like `"password"` and `"api_key"`.
    fn is_sensitive_key(&self, key: &str) -> bool {
        self.config
            .sensitive_fields
            .iter()
            .any(|field| field.eq_ignore_ascii_case(key))
    }

    /// Recursively redacts `value` in place.
    ///
    /// Object members with a sensitive key are overwritten with
    /// `"***REDACTED***"`. All other members and every array element are walked
    /// recursively. Scalars are left untouched. Working through `&mut Value`
    /// avoids re-cloning each subtree at every nesting level.
    fn sanitize_value(&self, value: &mut Value) {
        match value {
            Value::Object(map) => {
                for (key, child) in map.iter_mut() {
                    if self.is_sensitive_key(key) {
                        *child = Value::String("***REDACTED***".to_string());
                    } else {
                        self.sanitize_value(child);
                    }
                }
            }
            Value::Array(items) => {
                for item in items.iter_mut() {
                    self.sanitize_value(item);
                }
            }
            _ => {}
        }
    }

    /// Starts a fluent [`AuditEntryBuilder`] for a new entry bound to this logger.
    ///
    /// Nothing is recorded until the builder's `log` is awaited.
    #[must_use = "the audit entry is only recorded when `.log().await` is called"]
    pub fn builder(
        &self,
        user_id: &str,
        action: AuditAction,
        resource_type: &str,
    ) -> AuditEntryBuilder<'_> {
        AuditEntryBuilder {
            entry: AuditEntry::new(user_id, action, resource_type),
            logger: self,
        }
    }
}
/// Fluent builder for a single [`AuditEntry`], bound to an [`AuditLogger`].
///
/// The entry is recorded only when the builder's `log` method is awaited.
/// Dropping the builder without calling it silently discards the audit record,
/// which is why the type is `#[must_use]`.
#[must_use = "the audit entry is only recorded when `.log().await` is called"]
pub struct AuditEntryBuilder<'a> {
    // Entry being assembled; each `with_*` call replaces it with the value
    // returned by the matching `AuditEntry::with_*` method.
    entry: AuditEntry,
    // Logger that receives the finished entry.
    logger: &'a AuditLogger,
}
impl<'a> AuditEntryBuilder<'a> {
pub fn with_session(mut self, session_id: &str) -> Self {
self.entry = self.entry.with_session(session_id);
self
}
pub fn with_resource_id(mut self, resource_id: &str) -> Self {
self.entry = self.entry.with_resource_id(resource_id);
self
}
pub fn with_details(mut self, details: Value) -> Self {
self.entry = self.entry.with_details(details);
self
}
pub fn with_ip(mut self, ip_address: &str) -> Self {
self.entry = self.entry.with_ip(ip_address);
self
}
pub fn with_result(mut self, result: AuditResult) -> Self {
self.entry = self.entry.with_result(result);
self
}
pub fn with_duration(mut self, duration_ms: u64) -> Self {
self.entry = self.entry.with_duration(duration_ms);
self
}
pub async fn log(self) -> Result<()> {
self.logger.log(self.entry).await
}
}
impl Default for AuditLogger {
fn default() -> Self {
Self::new(AuditConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Logs a single `Create` entry carrying `details` and returns the stored details.
    async fn log_and_fetch_details(logger: &AuditLogger, details: Value) -> Value {
        let entry = AuditEntry::new("user1", AuditAction::Create, "user").with_details(details);
        logger.log(entry).await.unwrap();
        let entries = logger.query(AuditFilter::default()).await.unwrap();
        entries[0].details.clone().unwrap()
    }

    #[tokio::test]
    async fn test_audit_logger_creation() {
        let logger = AuditLogger::default();
        assert!(logger.config().enabled);
    }

    #[tokio::test]
    async fn test_log_action() {
        let logger = AuditLogger::default();
        logger
            .log_success("user1", AuditAction::Read, "document", Some("doc123"))
            .await
            .unwrap();
        let count = logger.count().await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn test_query() {
        let logger = AuditLogger::default();
        logger
            .log_success("user1", AuditAction::Read, "doc", None)
            .await
            .unwrap();
        logger
            .log_success("user2", AuditAction::Read, "doc", None)
            .await
            .unwrap();
        let filter = AuditFilter {
            user_id: Some("user1".to_string()),
            ..Default::default()
        };
        let entries = logger.query(filter).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].user_id, "user1");
    }

    #[tokio::test]
    async fn test_sanitize_sensitive_data() {
        let config = AuditConfig {
            sanitize_sensitive: true,
            ..Default::default()
        };
        let logger = AuditLogger::new(config);
        let details = log_and_fetch_details(
            &logger,
            json!({
                "username": "test",
                "password": "secret123"
            }),
        )
        .await;
        assert_eq!(details.get("password").unwrap(), "***REDACTED***");
        assert_eq!(details["username"], "test");
    }

    #[tokio::test]
    async fn test_sanitize_nested_objects_and_arrays() {
        // Exercises the recursive branch: sensitive keys inside nested objects
        // and inside objects held in arrays must be redacted too.
        let logger = AuditLogger::default();
        let details = log_and_fetch_details(
            &logger,
            json!({
                "profile": { "name": "alice", "token": "abc" },
                "keys": [ { "api_key": "k1" }, { "label": "public" } ]
            }),
        )
        .await;
        assert_eq!(details["profile"]["token"], "***REDACTED***");
        assert_eq!(details["profile"]["name"], "alice");
        assert_eq!(details["keys"][0]["api_key"], "***REDACTED***");
        assert_eq!(details["keys"][1]["label"], "public");
    }

    #[tokio::test]
    async fn test_sanitize_disabled_keeps_values() {
        let logger = AuditLogger::new(AuditConfig {
            sanitize_sensitive: false,
            ..Default::default()
        });
        let details = log_and_fetch_details(&logger, json!({ "password": "secret123" })).await;
        assert_eq!(details["password"], "secret123");
    }

    #[tokio::test]
    async fn test_builder_pattern() {
        let logger = AuditLogger::default();
        logger
            .builder("user1", AuditAction::Login, "session")
            .with_ip("192.168.1.1")
            .with_duration(100)
            .log()
            .await
            .unwrap();
        let entries = logger.query(AuditFilter::default()).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].ip_address, Some("192.168.1.1".to_string()));
    }

    #[tokio::test]
    async fn test_disabled_logger() {
        let config = AuditConfig {
            enabled: false,
            ..Default::default()
        };
        let logger = AuditLogger::new(config);
        logger
            .log_success("user1", AuditAction::Read, "doc", None)
            .await
            .unwrap();
        let count = logger.count().await.unwrap();
        assert_eq!(count, 0);
    }
}