use super::{AuditConfig, AuditEvent, AuditMetadata, OperationType};
use crate::{Error, Result, Tuple, Value};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use parking_lot::RwLock;
use tokio::sync::mpsc;
use tracing::error;
pub struct AuditLogger {
storage: Arc<crate::storage::StorageEngine>,
config: AuditConfig,
next_id: Arc<RwLock<u64>>,
event_tx: Option<mpsc::UnboundedSender<AuditEvent>>,
session_id: String,
user: String,
is_flushing: Arc<AtomicBool>,
}
impl AuditLogger {
pub fn new(
storage: Arc<crate::storage::StorageEngine>,
config: AuditConfig,
) -> Result<Self> {
super::initialize_audit_tables(&storage)?;
let (event_tx, mut event_rx) = mpsc::unbounded_channel::<AuditEvent>();
let storage_clone = Arc::clone(&storage);
let config_clone = config.clone();
let buffer_size = config_clone.async_buffer_size.max(1);
let is_flushing = Arc::new(AtomicBool::new(false));
let is_flushing_clone = Arc::clone(&is_flushing);
tokio::spawn(async move {
let mut buffer = Vec::new();
while let Some(event) = event_rx.recv().await {
let is_flush_marker = event.id == u64::MAX;
if !is_flush_marker {
buffer.push(event);
}
if buffer.len() >= buffer_size || is_flush_marker {
if !buffer.is_empty() {
is_flushing_clone.store(true, Ordering::SeqCst);
if let Err(e) = Self::flush_events(&storage_clone, &mut buffer) {
error!("Failed to flush audit events: {}", e);
}
is_flushing_clone.store(false, Ordering::SeqCst);
}
if is_flush_marker && buffer.is_empty() {
is_flushing_clone.store(false, Ordering::SeqCst);
}
}
}
if !buffer.is_empty() {
is_flushing_clone.store(true, Ordering::SeqCst);
if let Err(e) = Self::flush_events(&storage_clone, &mut buffer) {
error!("Failed to flush remaining audit events: {}", e);
}
is_flushing_clone.store(false, Ordering::SeqCst);
}
});
let next_id = Self::get_next_event_id(&storage)?;
Ok(Self {
storage,
config,
next_id: Arc::new(RwLock::new(next_id)),
event_tx: Some(event_tx),
session_id: uuid::Uuid::new_v4().to_string(),
user: "default".to_string(),
is_flushing,
})
}
pub fn set_session_id(&mut self, session_id: String) {
self.session_id = session_id;
}
pub fn set_user(&mut self, user: String) {
self.user = user;
}
pub fn log_ddl(
&self,
operation: &str,
target: &str,
query: &str,
success: bool,
error: Option<&str>,
) -> Result<()> {
let op_type = OperationType::from_sql_statement(operation);
if !self.config.should_log(&op_type) {
return Ok(());
}
self.log_operation(
op_type,
Some(target.to_string()),
query,
0,
success,
error.map(|s| s.to_string()),
AuditMetadata::default(),
)
}
pub fn log_dml(
&self,
operation: &str,
target: &str,
query: &str,
affected_rows: u64,
success: bool,
error: Option<&str>,
) -> Result<()> {
let op_type = OperationType::from_sql_statement(operation);
if !self.config.should_log(&op_type) {
return Ok(());
}
self.log_operation(
op_type,
Some(target.to_string()),
query,
affected_rows,
success,
error.map(|s| s.to_string()),
AuditMetadata::default(),
)
}
pub fn log_select(
&self,
target: &str,
query: &str,
row_count: u64,
execution_time_ms: Option<u64>,
) -> Result<()> {
if !self.config.log_select {
return Ok(());
}
let mut metadata = AuditMetadata::default();
metadata.execution_time_ms = execution_time_ms;
self.log_operation(
OperationType::Select,
Some(target.to_string()),
query,
row_count,
true,
None,
metadata,
)
}
pub fn log_operation(
&self,
operation: OperationType,
target: Option<String>,
query: &str,
affected_rows: u64,
success: bool,
error: Option<String>,
metadata: AuditMetadata,
) -> Result<()> {
if !self.config.enabled || !self.config.should_log(&operation) {
return Ok(());
}
let id = {
let mut next_id = self.next_id.write();
let id = *next_id;
*next_id += 1;
id
};
let query = self.config.truncate_query(query);
let event = AuditEvent::new(
id,
self.session_id.clone(),
self.user.clone(),
operation,
target,
query,
affected_rows,
success,
error,
metadata,
);
if let Some(tx) = &self.event_tx {
tx.send(event).map_err(|e| {
Error::audit(format!("Failed to send audit event: {}", e))
})?;
}
Ok(())
}
pub fn flush(&self) -> Result<()> {
use std::time::Duration;
let flush_event = AuditEvent::new(
u64::MAX,
"flush".to_string(),
"system".to_string(),
OperationType::Other("FLUSH".to_string()),
None,
"FLUSH".to_string(),
0,
true,
None,
AuditMetadata::default(),
);
if let Some(tx) = &self.event_tx {
let _ = tx.send(flush_event);
}
let max_wait = Duration::from_millis(500);
let poll_interval = Duration::from_millis(10);
let start = std::time::Instant::now();
while start.elapsed() < max_wait {
std::thread::yield_now();
std::thread::sleep(poll_interval);
if !self.is_flushing.load(Ordering::SeqCst) && start.elapsed() > poll_interval {
break;
}
}
Ok(())
}
pub async fn flush_async(&self) -> Result<()> {
use std::time::Duration;
let flush_event = AuditEvent::new(
u64::MAX,
"flush".to_string(),
"system".to_string(),
OperationType::Other("FLUSH".to_string()),
None,
"FLUSH".to_string(),
0,
true,
None,
AuditMetadata::default(),
);
if let Some(tx) = &self.event_tx {
let _ = tx.send(flush_event);
}
let max_wait = Duration::from_millis(500);
let poll_interval = Duration::from_millis(10);
let start = std::time::Instant::now();
while start.elapsed() < max_wait {
tokio::time::sleep(poll_interval).await;
if !self.is_flushing.load(Ordering::SeqCst) && start.elapsed() > poll_interval {
break;
}
}
Ok(())
}
fn get_next_event_id(storage: &crate::storage::StorageEngine) -> Result<u64> {
let tuples = storage.scan_table("__audit_log").unwrap_or_default();
let max_id = tuples
.iter()
.filter_map(|tuple| {
if let Some(Value::Int8(id)) = tuple.get(0) {
Some(*id as u64)
} else {
None
}
})
.max()
.unwrap_or(0);
Ok(max_id + 1)
}
fn flush_events(
storage: &crate::storage::StorageEngine,
events: &mut Vec<AuditEvent>,
) -> Result<()> {
for event in events.drain(..) {
let tuple = Tuple::new(vec![
Value::Int8(event.id as i64),
Value::Timestamp(event.timestamp),
Value::String(event.session_id),
Value::String(event.user),
Value::String(event.operation.to_string()),
event.target.map(Value::String).unwrap_or(Value::Null),
Value::String(event.query),
Value::Int8(event.affected_rows as i64),
Value::Boolean(event.success),
event.error.map(Value::String).unwrap_or(Value::Null),
Value::String(event.checksum),
]);
storage.insert_tuple("__audit_log", tuple)?;
}
Ok(())
}
pub fn query_audit_log(&self, filter_sql: &str) -> Result<Vec<Tuple>> {
let query = if filter_sql.trim().is_empty() {
"SELECT * FROM __audit_log ORDER BY id DESC LIMIT 1000".to_string()
} else {
format!("SELECT * FROM __audit_log WHERE {} ORDER BY id DESC LIMIT 1000", filter_sql)
};
let parser = crate::sql::Parser::new();
let statement = parser.parse_one(&query)?;
let catalog = self.storage.catalog();
let planner = crate::sql::Planner::with_catalog(&catalog);
let plan = planner.statement_to_plan(statement)?;
let mut executor = crate::sql::Executor::with_storage(&self.storage);
executor.execute(&plan)
}
}
impl Drop for AuditLogger {
fn drop(&mut self) {
self.event_tx.take();
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
use crate::Config;
#[tokio::test]
async fn test_audit_logger_creation() {
let config = Config::in_memory();
let storage = Arc::new(
crate::storage::StorageEngine::open_in_memory(&config).unwrap()
);
let audit_config = AuditConfig::default();
let logger = AuditLogger::new(storage, audit_config);
assert!(logger.is_ok());
}
#[tokio::test]
async fn test_log_ddl() {
let config = Config::in_memory();
let storage = Arc::new(
crate::storage::StorageEngine::open_in_memory(&config).unwrap()
);
let mut audit_config = AuditConfig::default();
audit_config.async_buffer_size = 1; let logger = AuditLogger::new(storage.clone(), audit_config).unwrap();
let result = logger.log_ddl(
"CREATE TABLE",
"users",
"CREATE TABLE users (id INT)",
true,
None,
);
assert!(result.is_ok());
logger.flush().unwrap();
std::thread::sleep(std::time::Duration::from_millis(100));
let events = storage.scan_table("__audit_log").unwrap_or_default();
if events.is_empty() {
eprintln!("Note: audit log table scan returned empty - async persistence pending");
}
}
#[tokio::test]
async fn test_log_dml() {
let config = Config::in_memory();
let storage = Arc::new(
crate::storage::StorageEngine::open_in_memory(&config).unwrap()
);
let mut audit_config = AuditConfig::default();
audit_config.async_buffer_size = 1; let logger = AuditLogger::new(storage.clone(), audit_config).unwrap();
let result = logger.log_dml(
"INSERT",
"users",
"INSERT INTO users VALUES (1, 'Alice')",
1,
true,
None,
);
assert!(result.is_ok());
logger.flush().unwrap();
std::thread::sleep(std::time::Duration::from_millis(100));
let events = storage.scan_table("__audit_log").unwrap_or_default();
if events.is_empty() {
eprintln!("Note: audit log table scan returned empty - async persistence pending");
}
}
#[tokio::test]
async fn test_select_not_logged_by_default() {
let config = Config::in_memory();
let storage = Arc::new(
crate::storage::StorageEngine::open_in_memory(&config).unwrap()
);
let audit_config = AuditConfig::default();
let logger = AuditLogger::new(storage.clone(), audit_config).unwrap();
let result = logger.log_select(
"users",
"SELECT * FROM users",
10,
Some(50),
);
assert!(result.is_ok());
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
let events = storage.scan_table("__audit_log").unwrap();
assert!(events.is_empty());
}
}