use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::net::IpAddr;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// Kind of action or occurrence recorded in the audit log.
///
/// Serialized internally tagged under `type` with snake_case variant names,
/// e.g. `{"type":"api_key_created"}`. `AuditLogger::is_security_event` selects a fixed
/// subset of these variants for SIEM forwarding.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AuditEventType {
// API key lifecycle.
ApiKeyCreated,
ApiKeyValidated,
ApiKeyValidationFailed,
ApiKeyRevoked,
ApiKeyRotated,
ApiKeyExpired,
// Access control.
AccessGranted,
AccessDenied,
ScopeEscalationAttempt,
// Data access.
DataRead,
DataWrite,
DataDelete,
DataExport,
// Configuration and tenant administration.
ConfigChanged,
IpConfigUpdated,
RateLimitConfigUpdated,
TenantCreated,
TenantSuspended,
// Threat detection.
BruteForceDetected,
IpBlocked,
AnomalousActivity,
SecurityIncident,
// Rate limiting.
RateLimitExceeded,
RateLimitThrottled,
// System lifecycle and operations.
SystemStartup,
SystemShutdown,
HealthCheckFailed,
BackupCompleted,
// Key rotation.
KeyRotationStarted,
KeyRotationCompleted,
KeyRotationFailed,
}
/// Who performed an audited action.
///
/// Serialized adjacently tagged. The snake_case variant name goes under `actor_type` and any
/// payload under `id`, e.g. `{"actor_type":"api_key","id":"<uuid>"}` or
/// `{"actor_type":"system"}`.
// Internal tagging (`tag` without `content`) cannot represent a newtype variant wrapping a
// string-like value such as `Uuid` or `String`. serde_json fails at runtime for those
// variants, which made `AuditEvent` unserializable for `ApiKey`/`Admin`/`ServiceAccount`
// actors. It also made the chain hash fall back to empty bytes for the actor. Unit variants
// serialize identically under both representations, so `system`/`anonymous` records (and
// their hashes) are unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "actor_type", content = "id", rename_all = "snake_case")]
pub enum AuditActor {
    System,
    /// Payload: the API key's UUID.
    ApiKey(Uuid),
    /// Payload: the admin's UUID.
    Admin(Uuid),
    /// Payload: the service account identifier.
    ServiceAccount(String),
    Anonymous,
}
/// Result of the audited action.
///
/// Serialized internally tagged under `status`, e.g. `{"status":"success"}` or
/// `{"status":"failure","error_code":"...","error_message":"..."}`.
/// `AuditEvent::new` defaults to `Success`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum AuditOutcome {
Success,
Failure {
error_code: String,
error_message: String,
},
// Presumably `completed` of `total` sub-operations succeeded — confirm with callers.
Partial {
completed: u32,
total: u32,
},
}
/// A single audit-log record.
///
/// Built with `AuditEvent::new` plus the `with_*` builder methods, then passed to
/// `AuditLogger::log`. That call redacts `details` and, when chaining is enabled, fills in
/// `previous_hash` and `event_hash`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditEvent {
/// Unique id; `AuditEvent::new` assigns a random v4 UUID.
pub event_id: Uuid,
pub event_type: AuditEventType,
pub tenant_id: Option<Uuid>,
pub key_id: Option<Uuid>,
pub actor: AuditActor,
/// UTC creation time, set by `AuditEvent::new` (not at logging time).
pub timestamp: DateTime<Utc>,
/// Free-form JSON payload. `AuditLogger::log` replaces values whose keys match the
/// configured sensitive-field patterns with `"[REDACTED]"` before buffering.
pub details: serde_json::Value,
pub ip_address: Option<IpAddr>,
pub user_agent: Option<String>,
pub request_id: Option<Uuid>,
pub session_id: Option<Uuid>,
pub resource: Option<String>,
pub outcome: AuditOutcome,
/// `event_hash` of the preceding chained event; `None` for the first event in the chain.
pub previous_hash: Option<String>,
/// Hex-encoded SHA-256 chain hash. Stays an empty string unless chaining is enabled.
pub event_hash: String,
}
impl AuditEvent {
    /// Creates an event of `event_type` attributed to `actor`.
    ///
    /// The event gets:
    /// - a fresh random (v4) `event_id` and the current UTC time as `timestamp`;
    /// - an empty JSON object as `details` and `AuditOutcome::Success` as `outcome`;
    /// - `None` for every optional field.
    ///
    /// `event_hash` starts empty. `AuditLogger::log` fills the hash fields when chaining is
    /// enabled.
    pub fn new(event_type: AuditEventType, actor: AuditActor) -> Self {
        Self {
            event_id: Uuid::new_v4(),
            event_type,
            tenant_id: None,
            key_id: None,
            actor,
            timestamp: Utc::now(),
            details: serde_json::json!({}),
            ip_address: None,
            user_agent: None,
            request_id: None,
            session_id: None,
            resource: None,
            outcome: AuditOutcome::Success,
            previous_hash: None,
            event_hash: String::new(),
        }
    }

    /// Sets the tenant the event belongs to.
    #[must_use]
    pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
        self.tenant_id = Some(tenant_id);
        self
    }

    /// Sets the API key the event concerns.
    #[must_use]
    pub fn with_key(mut self, key_id: Uuid) -> Self {
        self.key_id = Some(key_id);
        self
    }

    /// Replaces the JSON details payload.
    ///
    /// Sensitive keys are redacted later by `AuditLogger::log`, not here.
    #[must_use]
    pub fn with_details(mut self, details: serde_json::Value) -> Self {
        self.details = details;
        self
    }

    /// Sets the client IP address.
    #[must_use]
    pub fn with_ip(mut self, ip: IpAddr) -> Self {
        self.ip_address = Some(ip);
        self
    }

    /// Sets the client user agent string.
    #[must_use]
    pub fn with_user_agent(mut self, user_agent: String) -> Self {
        self.user_agent = Some(user_agent);
        self
    }

    /// Sets the request id.
    #[must_use]
    pub fn with_request_id(mut self, request_id: Uuid) -> Self {
        self.request_id = Some(request_id);
        self
    }

    /// Sets the session id.
    #[must_use]
    pub fn with_session_id(mut self, session_id: Uuid) -> Self {
        self.session_id = Some(session_id);
        self
    }

    /// Sets the outcome (defaults to `AuditOutcome::Success`).
    #[must_use]
    pub fn with_outcome(mut self, outcome: AuditOutcome) -> Self {
        self.outcome = outcome;
        self
    }

    /// Sets the affected resource identifier.
    #[must_use]
    pub fn with_resource(mut self, resource: String) -> Self {
        self.resource = Some(resource);
        self
    }
}
#[derive(Debug, Clone, Deserialize)]
pub struct AuditConfig {
pub retention_days: u32,
pub batch_size: usize,
pub flush_interval_secs: u32,
pub enable_siem: bool,
pub siem_endpoint: Option<String>,
pub enable_chaining: bool,
pub sensitive_fields: Vec<String>,
}
impl Default for AuditConfig {
fn default() -> Self {
Self {
retention_days: 365,
batch_size: 100,
flush_interval_secs: 5,
enable_siem: false,
siem_endpoint: None,
enable_chaining: true,
sensitive_fields: vec![
"password".to_string(),
"secret".to_string(),
"token".to_string(),
"api_key".to_string(),
"credential".to_string(),
"authorization".to_string(),
"cookie".to_string(),
"session".to_string(),
"private_key".to_string(),
],
}
}
}
/// Persistence backend for audit events.
#[async_trait::async_trait]
pub trait AuditStore: Send + Sync {
/// Persists a batch of events; `AuditLogger` calls this when flushing its buffer.
async fn store(&self, events: &[AuditEvent]) -> Result<(), AuditError>;
/// Returns one page of events matching `filter`, along with the total match count.
async fn query(
&self,
filter: AuditQueryFilter,
pagination: Pagination,
) -> Result<AuditQueryResult, AuditError>;
/// Returns the events whose timestamps fall between `start` and `end`.
///
/// `AuditLogger::verify_integrity` walks the result as a hash chain, so it presumably needs
/// events in the order they were stored — TODO confirm for each implementation.
async fn get_range(
&self,
start: DateTime<Utc>,
end: DateTime<Utc>,
) -> Result<Vec<AuditEvent>, AuditError>;
}
/// Sink for security events; `AuditLogger` calls `forward` once per security event it logs.
#[async_trait::async_trait]
pub trait SiemForwarder: Send + Sync {
async fn forward(&self, event: &AuditEvent) -> Result<(), AuditError>;
}
/// Buffers audit events, optionally hash-chains them, and writes them in batches to a
/// primary store (plus an optional best-effort secondary store). Security events are also
/// sent to an optional SIEM forwarder.
pub struct AuditLogger {
primary_store: Arc<dyn AuditStore>,
secondary_store: Option<Arc<dyn AuditStore>>,
siem_forwarder: Option<Arc<dyn SiemForwarder>>,
// `event_hash` of the most recently chained event; becomes the next event's `previous_hash`.
last_hash: RwLock<Option<String>>,
// Events not yet written; flushed at `config.batch_size` or via `flush()`.
buffer: RwLock<Vec<AuditEvent>>,
config: AuditConfig,
}
impl AuditLogger {
    /// Creates a logger that writes batches to `primary_store`, with no secondary store and no
    /// SIEM forwarder attached.
    pub fn new(primary_store: Arc<dyn AuditStore>, config: AuditConfig) -> Self {
        Self {
            primary_store,
            secondary_store: None,
            siem_forwarder: None,
            last_hash: RwLock::new(None),
            buffer: RwLock::new(Vec::new()),
            config,
        }
    }

    /// Attaches a secondary store that receives a best-effort copy of every flushed batch.
    #[must_use]
    pub fn with_secondary_store(mut self, store: Arc<dyn AuditStore>) -> Self {
        self.secondary_store = Some(store);
        self
    }

    /// Attaches the forwarder that receives security events (see `is_security_event`).
    #[must_use]
    pub fn with_siem(mut self, forwarder: Arc<dyn SiemForwarder>) -> Self {
        self.siem_forwarder = Some(forwarder);
        self
    }

    /// Records `event` and returns its `event_id`.
    ///
    /// Steps, in order:
    /// 1. Redacts sensitive detail fields.
    /// 2. Links the event into the hash chain (when `enable_chaining` is set).
    /// 3. Appends it to the buffer, flushing once the buffer holds `batch_size` events.
    /// 4. Forwards security events to the SIEM.
    ///
    /// # Errors
    ///
    /// Returns the primary store's error if a triggered flush fails, or the forwarder's error
    /// if SIEM forwarding fails. In both cases the event is already buffered and chained, so
    /// it must not be logged again. A failed flush leaves the buffer intact for the next
    /// attempt.
    pub async fn log(&self, mut event: AuditEvent) -> Result<Uuid, AuditError> {
        // Redact before hashing: the chain hash must cover exactly the details that get
        // stored, otherwise `verify_integrity` flags every event that carried a secret.
        event.details = self.redact_sensitive_fields(event.details);
        let event_id = event.event_id;

        let siem_event = {
            // Hold the chain head across read -> hash -> append -> publish. Otherwise
            // concurrent calls could link to the same predecessor or reach the buffer out of
            // chain order.
            let mut chain_head = self.last_hash.write().await;
            if self.config.enable_chaining {
                event.previous_hash = (*chain_head).clone();
                event.event_hash = self.compute_hash(&event);
            }
            let siem_event = self
                .is_security_event(&event.event_type)
                .then(|| event.clone());

            let mut buffer = self.buffer.write().await;
            if self.config.enable_chaining {
                // Publish before any fallible I/O: once buffered, the event is part of the
                // chain whether or not the flush below succeeds.
                *chain_head = Some(event.event_hash.clone());
            }
            buffer.push(event);
            // The buffer lock alone keeps appends ordered while we flush.
            drop(chain_head);

            if buffer.len() >= self.config.batch_size {
                self.flush_buffer_internal(&mut buffer).await?;
            }
            siem_event
        };

        if let Some(siem_event) = siem_event {
            self.forward_to_siem(&siem_event).await?;
        }
        Ok(event_id)
    }

    /// Writes all buffered events to the store(s) now.
    ///
    /// # Errors
    ///
    /// Returns the primary store's error; the buffer is then left untouched.
    pub async fn flush(&self) -> Result<(), AuditError> {
        let mut buffer = self.buffer.write().await;
        self.flush_buffer_internal(&mut buffer).await
    }

    /// Writes `buffer` to the primary store, mirrors it to the secondary store, then clears it.
    /// The buffer is cleared only if the primary write succeeds.
    async fn flush_buffer_internal(&self, buffer: &mut Vec<AuditEvent>) -> Result<(), AuditError> {
        if buffer.is_empty() {
            return Ok(());
        }
        self.primary_store.store(buffer).await?;
        // The secondary store is a deliberate best-effort replica: its failure must neither
        // fail the flush nor cause the batch to be re-sent to the primary.
        if let Some(secondary) = &self.secondary_store {
            let _ = secondary.store(buffer).await;
        }
        buffer.clear();
        Ok(())
    }

    /// Computes the hex-encoded SHA-256 chain hash of `event`.
    ///
    /// The hash covers the event id, the timestamp in whole seconds, the JSON-serialized event
    /// type and actor, the previous hash (when present), and the serialized details.
    // NOTE(review): tenant_id, key_id, ip_address, user_agent, request_id, session_id,
    // resource and outcome are not covered, sub-second timestamp changes go undetected, and
    // fields are concatenated without length prefixes. Widening coverage changes every hash,
    // so it needs a versioned migration for already-stored chains.
    fn compute_hash(&self, event: &AuditEvent) -> String {
        let mut hasher = Sha256::new();
        hasher.update(event.event_id.as_bytes());
        hasher.update(event.timestamp.timestamp().to_le_bytes());
        hasher.update(serde_json::to_vec(&event.event_type).unwrap_or_default());
        hasher.update(serde_json::to_vec(&event.actor).unwrap_or_default());
        if let Some(prev) = &event.previous_hash {
            hasher.update(prev.as_bytes());
        }
        hasher.update(event.details.to_string().as_bytes());
        hex::encode(hasher.finalize())
    }

    /// Returns true for the event types that are forwarded to the SIEM.
    fn is_security_event(&self, event_type: &AuditEventType) -> bool {
        matches!(
            event_type,
            AuditEventType::ApiKeyValidationFailed
                | AuditEventType::AccessDenied
                | AuditEventType::ScopeEscalationAttempt
                | AuditEventType::BruteForceDetected
                | AuditEventType::IpBlocked
                | AuditEventType::AnomalousActivity
                | AuditEventType::SecurityIncident
        )
    }

    /// Sends `event` to the attached SIEM forwarder, if any.
    // NOTE(review): `config.enable_siem` is not consulted; forwarding depends only on
    // whether `with_siem` was called — confirm this is intended.
    async fn forward_to_siem(&self, event: &AuditEvent) -> Result<(), AuditError> {
        if let Some(forwarder) = &self.siem_forwarder {
            forwarder.forward(event).await?;
        }
        Ok(())
    }

    /// Replaces the value of every sensitive key in `details` with `"[REDACTED]"`.
    ///
    /// A key is sensitive when its lowercase form contains any configured pattern (the
    /// patterns are lowercased too). Nested objects and arrays are searched recursively, so
    /// credentials inside e.g. a `headers` object are also caught.
    fn redact_sensitive_fields(&self, mut details: serde_json::Value) -> serde_json::Value {
        let patterns: Vec<String> = self
            .config
            .sensitive_fields
            .iter()
            .map(|pattern| pattern.to_lowercase())
            .collect();
        Self::redact_value(&mut details, &patterns);
        details
    }

    /// Recursive worker for `redact_sensitive_fields`; `patterns` must already be lowercase.
    fn redact_value(value: &mut serde_json::Value, patterns: &[String]) {
        match value {
            serde_json::Value::Object(map) => {
                for (key, child) in map.iter_mut() {
                    let key_lower = key.to_lowercase();
                    if patterns.iter().any(|p| key_lower.contains(p.as_str())) {
                        *child = serde_json::json!("[REDACTED]");
                    } else {
                        Self::redact_value(child, patterns);
                    }
                }
            }
            serde_json::Value::Array(items) => {
                for item in items.iter_mut() {
                    Self::redact_value(item, patterns);
                }
            }
            _ => {}
        }
    }

    /// Queries the primary store (buffered, unflushed events are not visible).
    pub async fn query(
        &self,
        filter: AuditQueryFilter,
        pagination: Pagination,
    ) -> Result<AuditQueryResult, AuditError> {
        self.primary_store.query(filter, pagination).await
    }

    /// Re-verifies the hash chain for the events the primary store returns for `[start, end]`.
    ///
    /// For each event it checks two things:
    /// - `previous_hash` equals the preceding event's `event_hash`. This is skipped for the
    ///   first event, whose predecessor may lie before `start`.
    /// - The recomputed hash matches `event_hash`.
    ///
    /// An event is counted as valid only if neither check fails. One event can produce up to
    /// two violations. Assumes `get_range` returns events in the order they were stored —
    /// TODO confirm for non-in-memory stores.
    ///
    /// # Errors
    ///
    /// Propagates the store's `get_range` error.
    pub async fn verify_integrity(
        &self,
        start: DateTime<Utc>,
        end: DateTime<Utc>,
    ) -> Result<IntegrityReport, AuditError> {
        let events = self.primary_store.get_range(start, end).await?;
        let mut valid_events = 0u64;
        let mut violations = Vec::new();
        let mut expected_previous: Option<&str> = None;

        for (index, event) in events.iter().enumerate() {
            let mut event_ok = true;

            if index > 0 && event.previous_hash.as_deref() != expected_previous {
                violations.push(IntegrityViolation {
                    event_id: event.event_id,
                    violation_type: ViolationType::ChainBroken,
                    details: "Previous hash mismatch".to_string(),
                });
                event_ok = false;
            }

            if self.compute_hash(event) != event.event_hash {
                violations.push(IntegrityViolation {
                    event_id: event.event_id,
                    violation_type: ViolationType::HashMismatch,
                    details: "Event hash verification failed".to_string(),
                });
                event_ok = false;
            }

            if event_ok {
                valid_events += 1;
            }
            expected_previous = Some(event.event_hash.as_str());
        }

        Ok(IntegrityReport {
            start,
            end,
            // Count events, not violations: one event may carry two violations.
            total_events: events.len() as u64,
            valid_events,
            violations,
            verified_at: Utc::now(),
        })
    }
}
/// Criteria for `AuditStore::query`.
///
/// Each `Some` field is meant to narrow the result; `None` leaves that dimension
/// unconstrained. How each field is applied is up to the store implementation.
/// `Default` yields a filter with every field `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AuditQueryFilter {
pub tenant_id: Option<Uuid>,
// Any of these types matches.
pub event_types: Option<Vec<AuditEventType>>,
pub start_time: Option<DateTime<Utc>>,
pub end_time: Option<DateTime<Utc>>,
pub ip_address: Option<IpAddr>,
pub key_id: Option<Uuid>,
}
/// Offset/limit page selection for `AuditStore::query`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Pagination {
// Number of matching events to skip.
pub offset: u64,
// Maximum number of events to return.
pub limit: u64,
}
impl Default for Pagination {
fn default() -> Self {
Self {
offset: 0,
limit: 100,
}
}
}
/// One page of query results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditQueryResult {
// Events in the requested page.
pub events: Vec<AuditEvent>,
// Number of matching events before pagination.
pub total_count: u64,
// True when matches exist beyond this page.
pub has_more: bool,
}
/// Outcome of `AuditLogger::verify_integrity` for a time range.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IntegrityReport {
pub start: DateTime<Utc>,
pub end: DateTime<Utc>,
pub total_events: u64,
pub valid_events: u64,
// One entry per detected problem; a single event can contribute more than one.
pub violations: Vec<IntegrityViolation>,
// When the verification ran.
pub verified_at: DateTime<Utc>,
}
/// A single integrity problem found in one event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IntegrityViolation {
pub event_id: Uuid,
pub violation_type: ViolationType,
// Human-readable description.
pub details: String,
}
/// Category of an integrity violation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ViolationType {
// `previous_hash` does not match the preceding event's `event_hash`.
ChainBroken,
// Recomputed hash differs from the stored `event_hash`.
HashMismatch,
// Not produced by `AuditLogger::verify_integrity` as written.
TimestampAnomaly,
// Not produced by `AuditLogger::verify_integrity` as written.
MissingEvent,
}
/// Errors returned by audit storage, SIEM forwarding and verification.
/// Each variant carries a free-form message; presumably produced by `AuditStore` /
/// `SiemForwarder` implementations — confirm against them.
#[derive(Debug, thiserror::Error)]
pub enum AuditError {
#[error("Storage error: {0}")]
Storage(String),
#[error("SIEM forwarding error: {0}")]
Siem(String),
#[error("Serialization error: {0}")]
Serialization(String),
#[error("Integrity verification failed: {0}")]
Integrity(String),
}
/// `AuditStore` that keeps every event in a process-local vector, in insertion order.
/// Contents do not survive the process.
pub struct InMemoryAuditStore {
    events: RwLock<Vec<AuditEvent>>,
}

impl InMemoryAuditStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        let events = RwLock::new(Vec::new());
        Self { events }
    }
}

impl Default for InMemoryAuditStore {
    /// Same as `InMemoryAuditStore::new`.
    fn default() -> Self {
        InMemoryAuditStore::new()
    }
}
#[async_trait::async_trait]
impl AuditStore for InMemoryAuditStore {
    /// Appends `events` to the log, preserving their order.
    async fn store(&self, events: &[AuditEvent]) -> Result<(), AuditError> {
        self.events.write().await.extend_from_slice(events);
        Ok(())
    }

    /// Returns the requested page of events matching every `Some` field of `filter`, in
    /// insertion order. Time bounds are inclusive.
    async fn query(
        &self,
        filter: AuditQueryFilter,
        pagination: Pagination,
    ) -> Result<AuditQueryResult, AuditError> {
        let store = self.events.read().await;

        let matches = |e: &AuditEvent| -> bool {
            if let Some(tenant_id) = filter.tenant_id {
                if e.tenant_id != Some(tenant_id) {
                    return false;
                }
            }
            if let Some(ref types) = filter.event_types {
                if !types.contains(&e.event_type) {
                    return false;
                }
            }
            if let Some(start) = filter.start_time {
                if e.timestamp < start {
                    return false;
                }
            }
            if let Some(end) = filter.end_time {
                if e.timestamp > end {
                    return false;
                }
            }
            // These two fields were previously ignored, so the filter silently matched
            // everything on them.
            if let Some(ip) = filter.ip_address {
                if e.ip_address != Some(ip) {
                    return false;
                }
            }
            if let Some(key_id) = filter.key_id {
                if e.key_id != Some(key_id) {
                    return false;
                }
            }
            true
        };

        let total_count = store.iter().filter(|&e| matches(e)).count() as u64;
        // Clone only the requested page rather than every match.
        let events: Vec<AuditEvent> = store
            .iter()
            .filter(|&e| matches(e))
            .skip(pagination.offset as usize)
            .take(pagination.limit as usize)
            .cloned()
            .collect();
        // Saturate so a huge caller-supplied offset cannot overflow.
        let has_more = pagination.offset.saturating_add(events.len() as u64) < total_count;

        Ok(AuditQueryResult {
            events,
            total_count,
            has_more,
        })
    }

    /// Returns all events with `start <= timestamp <= end`, in insertion order.
    async fn get_range(
        &self,
        start: DateTime<Utc>,
        end: DateTime<Utc>,
    ) -> Result<Vec<AuditEvent>, AuditError> {
        let store = self.events.read().await;
        Ok(store
            .iter()
            .filter(|e| e.timestamp >= start && e.timestamp <= end)
            .cloned()
            .collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods populate the fields they name.
    #[tokio::test]
    async fn test_audit_event_creation() {
        let tenant = Uuid::new_v4();
        let event = AuditEvent::new(AuditEventType::ApiKeyCreated, AuditActor::System)
            .with_tenant(tenant)
            .with_details(serde_json::json!({"test": "value"}));

        assert_eq!(event.event_type, AuditEventType::ApiKeyCreated);
        assert!(event.tenant_id.is_some());
    }

    /// Default sensitive patterns hide credentials but keep ordinary keys.
    #[tokio::test]
    async fn test_audit_logger_redaction() {
        let logger = AuditLogger::new(
            Arc::new(InMemoryAuditStore::new()),
            AuditConfig::default(),
        );
        let input = serde_json::json!({
            "user": "test",
            "password": "secret123",
            "api_key": "key123"
        });

        let output = logger.redact_sensitive_fields(input);

        assert_eq!(output["user"], "test");
        for hidden in ["password", "api_key"].iter() {
            assert_eq!(output[*hidden], "[REDACTED]");
        }
    }

    /// A stored event comes back from an unfiltered query.
    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemoryAuditStore::new();
        let event = AuditEvent::new(AuditEventType::ApiKeyCreated, AuditActor::System);
        let expected_id = event.event_id;

        store.store(&[event]).await.expect("in-memory store never fails");
        let page = store
            .query(AuditQueryFilter::default(), Pagination::default())
            .await
            .expect("in-memory query never fails");

        assert_eq!(page.total_count, 1);
        assert_eq!(page.events[0].event_id, expected_id);
    }
}