use super::{AuditEntry, AuditLog, Outcome};
use crate::metrics::GatewayMetrics;
use async_trait::async_trait;
use rusqlite::{Connection, params};
use std::sync::{Arc, Mutex};
use std::time::SystemTime;
use tokio::sync::mpsc;
/// Capacity of the bounded channel that carries audit entries from `record`
/// to the background writer task. `record` uses a non-blocking send, so once
/// this many entries are queued further entries are dropped rather than waited on.
const CHANNEL_CAPACITY: usize = 4096;
/// Audit sink that writes entries to the `audit_log` table of a SQLite database.
///
/// Entries are handed to a background task over a bounded channel, so the
/// caller of `record` never performs the database write itself.
pub struct SqliteAudit {
/// Sending half of the entry channel. `flush` replaces it with `None`, which
/// drops the sender and lets the writer task's receive loop end.
tx: Arc<Mutex<Option<mpsc::Sender<Arc<AuditEntry>>>>>,
/// Join handle of the writer task spawned in `with_rotation`; `flush` takes
/// it out and awaits it.
handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
/// Metrics sink used by `record` to count entries that could not be queued.
metrics: Arc<GatewayMetrics>,
}
impl SqliteAudit {
pub fn new(path: &str, metrics: Arc<GatewayMetrics>) -> anyhow::Result<Self> {
Self::with_rotation(path, None, None, metrics)
}
pub fn with_rotation(
path: &str,
max_entries: Option<usize>,
max_age_days: Option<u64>,
metrics: Arc<GatewayMetrics>,
) -> anyhow::Result<Self> {
let conn = Connection::open(path)?;
conn.execute_batch(
"CREATE TABLE IF NOT EXISTS audit_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ts INTEGER NOT NULL,
agent_id TEXT NOT NULL,
method TEXT NOT NULL,
tool TEXT,
arguments TEXT,
outcome TEXT NOT NULL,
reason TEXT,
input_tokens INTEGER NOT NULL DEFAULT 0
);",
)?;
let _ = conn.execute_batch("ALTER TABLE audit_log ADD COLUMN arguments TEXT;");
let _ = conn.execute_batch(
"ALTER TABLE audit_log ADD COLUMN input_tokens INTEGER NOT NULL DEFAULT 0;",
);
let conn = Arc::new(Mutex::new(conn));
let (tx, mut rx) = mpsc::channel::<Arc<AuditEntry>>(CHANNEL_CAPACITY);
let handle = tokio::spawn(async move {
while let Some(entry) = rx.recv().await {
let ts = entry
.ts
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64;
let (outcome_str, reason) = match &entry.outcome {
Outcome::Allowed => ("allowed".to_string(), None),
Outcome::Blocked(r) => ("blocked".to_string(), Some(r.clone())),
Outcome::Forwarded => ("forwarded".to_string(), None),
Outcome::Shadowed => ("shadowed".to_string(), None),
};
match &entry.outcome {
Outcome::Allowed => tracing::info!(
outcome = "allowed",
agent = %entry.agent_id,
method = %entry.method,
tool = entry.tool.as_deref().unwrap_or("-"),
),
Outcome::Blocked(r) => tracing::info!(
outcome = "blocked",
agent = %entry.agent_id,
method = %entry.method,
tool = entry.tool.as_deref().unwrap_or("-"),
reason = %r,
),
Outcome::Forwarded => tracing::info!(
outcome = "forwarded",
agent = %entry.agent_id,
method = %entry.method,
),
Outcome::Shadowed => tracing::info!(
outcome = "shadowed",
agent = %entry.agent_id,
method = %entry.method,
tool = entry.tool.as_deref().unwrap_or("-"),
),
}
let conn = conn.clone();
tokio::task::spawn_blocking(move || {
if let Ok(c) = conn.lock() {
let args_json = entry
.arguments
.as_ref()
.and_then(|v| serde_json::to_string(v).ok());
if let Err(e) = c.execute(
"INSERT INTO audit_log \
(ts, agent_id, method, tool, arguments, outcome, reason, input_tokens) \
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
params![
ts,
entry.agent_id,
entry.method,
entry.tool,
args_json,
outcome_str,
reason,
entry.input_tokens as i64
],
) {
tracing::error!(
error = %e,
agent = %entry.agent_id,
"audit insert failed"
);
}
if let Some(max) = max_entries
&& let Err(e) = c.execute(
"DELETE FROM audit_log WHERE id NOT IN \
(SELECT id FROM audit_log ORDER BY id DESC LIMIT ?1)",
params![max as i64],
)
{
tracing::warn!(error = %e, "audit rotation (max_entries) failed");
}
if let Some(days) = max_age_days {
let cutoff = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64
- (days as i64 * 86400);
if let Err(e) =
c.execute("DELETE FROM audit_log WHERE ts < ?1", params![cutoff])
{
tracing::warn!(error = %e, "audit rotation (max_age_days) failed");
}
}
}
})
.await
.ok();
}
});
Ok(Self {
tx: Arc::new(Mutex::new(Some(tx))),
handle: Arc::new(Mutex::new(Some(handle))),
metrics,
})
}
}
#[async_trait]
impl AuditLog for SqliteAudit {
fn record(&self, entry: Arc<AuditEntry>) {
if let Ok(guard) = self.tx.lock()
&& let Some(tx) = guard.as_ref()
&& tx.try_send(entry).is_err()
{
tracing::warn!("sqlite audit channel full — entry dropped");
self.metrics.record_audit_drop("sqlite");
}
}
async fn flush(&self) {
{
let mut guard = self.tx.lock().unwrap();
*guard = None;
}
let handle = {
let mut guard = self.handle.lock().unwrap();
guard.take()
};
if let Some(h) = handle {
let _ = h.await;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::audit::Outcome;
    use crate::metrics::GatewayMetrics;
    use rusqlite::Connection;
    use std::time::{Duration, UNIX_EPOCH};
    use tempfile::NamedTempFile;

    /// Fresh metrics registry for a single test.
    fn test_metrics() -> Arc<GatewayMetrics> {
        let metrics = GatewayMetrics::new().unwrap();
        Arc::new(metrics)
    }

    /// Sample `tools/call` entry stamped in 1970, so any age-based rotation purges it.
    fn entry(outcome: Outcome) -> Arc<AuditEntry> {
        let sample = AuditEntry {
            ts: UNIX_EPOCH + Duration::from_secs(1_000_000),
            agent_id: "agent-x".to_string(),
            method: "tools/call".to_string(),
            tool: Some("read_file".to_string()),
            arguments: Some(serde_json::json!({"path": "/tmp/foo"})),
            outcome,
            request_id: "req-1".to_string(),
            input_tokens: 5,
        };
        Arc::new(sample)
    }

    /// Path of the temporary database file as UTF-8 text.
    fn path_of(db: &NamedTempFile) -> &str {
        db.path().to_str().unwrap()
    }

    /// Opens an audit sink without rotation on the temporary database.
    fn open_audit(db: &NamedTempFile) -> SqliteAudit {
        SqliteAudit::new(path_of(db), test_metrics()).unwrap()
    }

    /// Number of rows currently stored in `audit_log`.
    fn count_rows(path: &str) -> usize {
        let db = Connection::open(path).unwrap();
        let total: i64 = db
            .query_row("SELECT COUNT(*) FROM audit_log", [], |row| row.get(0))
            .unwrap();
        total as usize
    }

    /// All `outcome` values in insertion order.
    fn fetch_outcomes(path: &str) -> Vec<String> {
        let db = Connection::open(path).unwrap();
        let mut query = db
            .prepare("SELECT outcome FROM audit_log ORDER BY id")
            .unwrap();
        let rows = query.query_map([], |row| row.get::<_, String>(0)).unwrap();
        rows.collect::<Result<Vec<_>, _>>().unwrap()
    }

    /// All `reason` values (NULL as `None`) in insertion order.
    fn fetch_reasons(path: &str) -> Vec<Option<String>> {
        let db = Connection::open(path).unwrap();
        let mut query = db
            .prepare("SELECT reason FROM audit_log ORDER BY id")
            .unwrap();
        let rows = query
            .query_map([], |row| row.get::<_, Option<String>>(0))
            .unwrap();
        rows.collect::<Result<Vec<_>, _>>().unwrap()
    }

    #[tokio::test]
    async fn records_are_persisted() {
        let db = NamedTempFile::new().unwrap();
        let audit = open_audit(&db);

        audit.record(entry(Outcome::Allowed));
        audit.flush().await;

        assert_eq!(count_rows(path_of(&db)), 1);
    }

    #[tokio::test]
    async fn outcome_strings_are_correct() {
        let db = NamedTempFile::new().unwrap();
        let audit = open_audit(&db);

        let recorded = [
            Outcome::Allowed,
            Outcome::Forwarded,
            Outcome::Shadowed,
            Outcome::Blocked("denied".to_string()),
        ];
        for outcome in recorded {
            audit.record(entry(outcome));
        }
        audit.flush().await;

        let outcomes = fetch_outcomes(path_of(&db));
        assert_eq!(outcomes, ["allowed", "forwarded", "shadowed", "blocked"]);
    }

    #[tokio::test]
    async fn null_reason_stored_for_non_blocked_outcomes() {
        let db = NamedTempFile::new().unwrap();
        let audit = open_audit(&db);

        for outcome in [Outcome::Allowed, Outcome::Forwarded, Outcome::Shadowed] {
            audit.record(entry(outcome));
        }
        audit.flush().await;

        let reasons = fetch_reasons(path_of(&db));
        assert!(
            reasons.iter().all(Option::is_none),
            "expected all NULL reasons"
        );
    }

    #[tokio::test]
    async fn blocked_reason_stored() {
        let db = NamedTempFile::new().unwrap();
        let audit = open_audit(&db);

        audit.record(entry(Outcome::Blocked("rate limit".to_string())));
        audit.flush().await;

        let reasons = fetch_reasons(path_of(&db));
        assert_eq!(reasons, [Some("rate limit".to_string())]);
    }

    #[tokio::test]
    async fn max_entries_rotation_keeps_newest() {
        let db = NamedTempFile::new().unwrap();
        let audit =
            SqliteAudit::with_rotation(path_of(&db), Some(3), None, test_metrics()).unwrap();

        (0..6).for_each(|_| audit.record(entry(Outcome::Allowed)));
        audit.flush().await;

        assert_eq!(
            count_rows(path_of(&db)),
            3,
            "rotation should keep only 3 rows"
        );
    }

    #[tokio::test]
    async fn max_age_days_rotation_purges_old() {
        let db = NamedTempFile::new().unwrap();
        let audit =
            SqliteAudit::with_rotation(path_of(&db), None, Some(1), test_metrics()).unwrap();

        audit.record(entry(Outcome::Allowed));
        audit.flush().await;

        assert_eq!(
            count_rows(path_of(&db)),
            0,
            "old entry should have been purged"
        );
    }

    #[tokio::test]
    async fn flush_is_idempotent() {
        let db = NamedTempFile::new().unwrap();
        let audit = open_audit(&db);

        audit.record(entry(Outcome::Allowed));
        audit.flush().await;
        audit.flush().await;

        assert_eq!(count_rows(path_of(&db)), 1);
    }

    #[tokio::test]
    async fn multiple_entries_all_persisted() {
        let db = NamedTempFile::new().unwrap();
        let audit = open_audit(&db);

        (0..10).for_each(|_| audit.record(entry(Outcome::Forwarded)));
        audit.flush().await;

        assert_eq!(count_rows(path_of(&db)), 10);
    }

    #[tokio::test]
    async fn input_tokens_persisted() {
        let db = NamedTempFile::new().unwrap();
        let audit = open_audit(&db);

        audit.record(entry(Outcome::Allowed));
        audit.flush().await;

        let reader = Connection::open(path_of(&db)).unwrap();
        let tokens: i64 = reader
            .query_row("SELECT input_tokens FROM audit_log LIMIT 1", [], |row| {
                row.get(0)
            })
            .unwrap();
        assert_eq!(tokens, 5);
    }

    #[tokio::test]
    async fn full_channel_drops_entry_and_increments_counter() {
        let db = NamedTempFile::new().unwrap();
        let metrics = test_metrics();
        let audit = SqliteAudit::new(path_of(&db), Arc::clone(&metrics)).unwrap();

        // Nothing is awaited while recording, so the queue is overfilled by ten entries.
        let attempts = CHANNEL_CAPACITY + 10;
        (0..attempts).for_each(|_| audit.record(entry(Outcome::Allowed)));

        let rendered = metrics.render();
        assert!(
            rendered.contains("arbit_audit_drops_total"),
            "drop counter must be registered"
        );

        audit.flush().await;
        let rows = count_rows(path_of(&db));
        assert!(rows > 0, "at least some entries must have been persisted");
    }
}