use rusqlite::{Connection, Transaction, TransactionBehavior, params};
use solo_core::{Error, Result};
use std::path::PathBuf;
use std::time::Duration;
use tokio::sync::mpsc;
use crate::init::open_sqlcipher;
use crate::key_material::KeyMaterial;
/// Capacity of the bounded channel between `AuditWriter` handles and the
/// drainer task. When it is full, `AuditWriter::emit_async` drops the new event
/// (and logs a warning) instead of waiting.
pub const AUDIT_QUEUE_CAPACITY: usize = 1_024;
/// Maximum number of events the drainer writes in a single transaction.
pub const AUDIT_BATCH_FLUSH_MAX_EVENTS: usize = 64;
/// Maximum time, in milliseconds, the drainer keeps waiting for more events to
/// join a batch it has already started before flushing it.
pub const AUDIT_BATCH_FLUSH_MAX_MILLIS: u64 = 50;
/// Kind of operation recorded in the audit log.
///
/// Each variant maps to a fixed dotted `namespace.action` name (see
/// [`AuditOperation::as_str`]). That name is the value written to the
/// `operation` column of both `audit_events` and `audit_events_admin`, so
/// changing a string changes what new rows store.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuditOperation {
    MemoryRemember,
    MemoryRememberBatch,
    MemoryUpdate,
    MemoryForget,
    MemoryConsolidate,
    MemoryReembed,
    MemoryIngestDocument,
    MemoryForgetDocument,
    MemoryNormalizeSubjects,
    MemoryBackup,
    MemorySaveSnapshot,
    MemoryTriplesExtract,
    LlmSamplingCall,
    MemoryRecall,
    MemoryContext,
    MemoryInspect,
    MemoryThemes,
    MemoryFactsAbout,
    MemoryEntities,
    MemoryContradictions,
    MemoryContradictionResolve,
    MemoryInspectCluster,
    MemorySearchDocs,
    MemoryInspectDocument,
    MemoryListDocuments,
    RedactionApplied,
    TenantCreate,
    TenantDelete,
    TenantBackup,
    TenantRestore,
    TenantSetQuota,
    GdprForgetUser,
}

impl AuditOperation {
    /// Canonical `namespace.action` name for this operation.
    ///
    /// Declared `const` so it is usable in constant contexts.
    pub const fn as_str(&self) -> &'static str {
        match *self {
            AuditOperation::MemoryRemember => "memory.remember",
            AuditOperation::MemoryRememberBatch => "memory.remember_batch",
            AuditOperation::MemoryUpdate => "memory.update",
            AuditOperation::MemoryForget => "memory.forget",
            AuditOperation::MemoryConsolidate => "memory.consolidate",
            AuditOperation::MemoryReembed => "memory.reembed",
            AuditOperation::MemoryIngestDocument => "memory.ingest_document",
            AuditOperation::MemoryForgetDocument => "memory.forget_document",
            AuditOperation::MemoryNormalizeSubjects => "memory.normalize_subjects",
            AuditOperation::MemoryBackup => "memory.backup",
            AuditOperation::MemorySaveSnapshot => "memory.save_snapshot",
            AuditOperation::MemoryTriplesExtract => "memory.triples_extract",
            AuditOperation::LlmSamplingCall => "llm.sampling_call",
            AuditOperation::MemoryRecall => "memory.recall",
            AuditOperation::MemoryContext => "memory.context",
            AuditOperation::MemoryInspect => "memory.inspect",
            AuditOperation::MemoryThemes => "memory.themes",
            AuditOperation::MemoryFactsAbout => "memory.facts_about",
            AuditOperation::MemoryEntities => "memory.entities",
            AuditOperation::MemoryContradictions => "memory.contradictions",
            AuditOperation::MemoryContradictionResolve => "memory.contradiction_resolve",
            AuditOperation::MemoryInspectCluster => "memory.inspect_cluster",
            AuditOperation::MemorySearchDocs => "memory.search_docs",
            AuditOperation::MemoryInspectDocument => "memory.inspect_document",
            AuditOperation::MemoryListDocuments => "memory.list_documents",
            AuditOperation::RedactionApplied => "redaction.applied",
            AuditOperation::TenantCreate => "tenant.create",
            AuditOperation::TenantDelete => "tenant.delete",
            AuditOperation::TenantBackup => "tenant.backup",
            AuditOperation::TenantRestore => "tenant.restore",
            AuditOperation::TenantSetQuota => "tenant.set_quota",
            AuditOperation::GdprForgetUser => "gdpr.forget_user",
        }
    }
}

/// Formats as the canonical name from [`AuditOperation::as_str`],
/// e.g. `memory.recall`.
impl std::fmt::Display for AuditOperation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let canonical = self.as_str();
        f.write_str(canonical)
    }
}
/// Outcome stored in the `result` column of an audit row.
///
/// The strings from [`AuditResult::as_str`] are what get persisted. The
/// `audit_events.result` column rejects unknown values via a CHECK constraint
/// (exercised by the `audit_result_check_constraint_rejects_unknown_value`
/// test), so these strings must stay in sync with the schema migration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuditResult {
    /// The operation succeeded; persisted as `"ok"`.
    Ok,
    /// The operation failed; persisted as `"error"`. `AuditEvent::error_now`
    /// attaches the failure message under `details.error`.
    Error,
    /// Persisted as `"forbidden"`; presumably an authorization denial —
    /// TODO confirm against callers.
    Forbidden,
}

impl AuditResult {
    /// Canonical lowercase name persisted in the `result` column.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Ok => "ok",
            Self::Error => "error",
            Self::Forbidden => "forbidden",
        }
    }
}

/// Formats as [`AuditResult::as_str`], mirroring `AuditOperation`'s `Display`,
/// so a result can be used directly in `format!` or as a `tracing` `%` field.
impl std::fmt::Display for AuditResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
#[derive(Debug, Clone)]
pub struct AuditEvent {
pub ts_ms: i64,
pub principal_subject: Option<String>,
pub operation: AuditOperation,
pub target_id: Option<String>,
pub result: AuditResult,
pub details: Option<serde_json::Value>,
}
impl AuditEvent {
pub fn ok_now(
principal_subject: Option<String>,
operation: AuditOperation,
target_id: Option<String>,
) -> Self {
Self {
ts_ms: chrono::Utc::now().timestamp_millis(),
principal_subject,
operation,
target_id,
result: AuditResult::Ok,
details: None,
}
}
pub fn error_now(
principal_subject: Option<String>,
operation: AuditOperation,
target_id: Option<String>,
error_message: impl Into<String>,
) -> Self {
Self {
ts_ms: chrono::Utc::now().timestamp_millis(),
principal_subject,
operation,
target_id,
result: AuditResult::Error,
details: Some(serde_json::json!({ "error": error_message.into() })),
}
}
}
pub fn insert_audit_row_in_tx(tx: &Transaction<'_>, event: &AuditEvent) -> Result<()> {
let details_json: Option<String> = match event.details.as_ref() {
Some(v) => Some(
serde_json::to_string(v)
.map_err(|e| Error::storage(format!("serialize audit details: {e}")))?,
),
None => None,
};
tx.execute(
"INSERT INTO audit_events (
ts_ms, principal_subject, operation, target_id, result, details_json
) VALUES (?, ?, ?, ?, ?, ?)",
params![
event.ts_ms,
event.principal_subject.as_deref(),
event.operation.as_str(),
event.target_id.as_deref(),
event.result.as_str(),
details_json,
],
)
.map_err(|e| Error::storage(format!("INSERT audit_events: {e}")))?;
Ok(())
}
pub fn insert_audit_admin_row(
conn: &Connection,
ts_ms: i64,
principal_subject: Option<&str>,
operation: AuditOperation,
target_tenant_id: Option<&str>,
result: AuditResult,
details: Option<&serde_json::Value>,
) -> Result<i64> {
let details_json: Option<String> = match details {
Some(v) => Some(
serde_json::to_string(v)
.map_err(|e| Error::storage(format!("serialize admin audit details: {e}")))?,
),
None => None,
};
conn.execute(
"INSERT INTO audit_events_admin (
ts_ms, principal_subject, operation, target_tenant_id, result, details_json
) VALUES (?, ?, ?, ?, ?, ?)",
params![
ts_ms,
principal_subject,
operation.as_str(),
target_tenant_id,
result.as_str(),
details_json,
],
)
.map_err(|e| Error::storage(format!("INSERT audit_events_admin: {e}")))?;
Ok(conn.last_insert_rowid())
}
#[allow(dead_code)]
fn insert_audit_row_one_off(conn: &Connection, event: &AuditEvent) -> Result<()> {
let details_json: Option<String> = match event.details.as_ref() {
Some(v) => Some(
serde_json::to_string(v)
.map_err(|e| Error::storage(format!("serialize audit details: {e}")))?,
),
None => None,
};
conn.execute(
"INSERT INTO audit_events (
ts_ms, principal_subject, operation, target_id, result, details_json
) VALUES (?, ?, ?, ?, ?, ?)",
params![
event.ts_ms,
event.principal_subject.as_deref(),
event.operation.as_str(),
event.target_id.as_deref(),
event.result.as_str(),
details_json,
],
)
.map_err(|e| Error::storage(format!("INSERT audit_events (async): {e}")))?;
Ok(())
}
/// Clonable handle for enqueueing audit events.
///
/// Every clone shares the sending half of one `tokio::sync::mpsc` channel. The
/// channel is created either by `AuditWriter::spawn` (bounded at
/// `AUDIT_QUEUE_CAPACITY`, drained by a background task) or by
/// `AuditWriter::noop` (whose receiver is dropped immediately, so emits are
/// discarded).
#[derive(Clone)]
pub struct AuditWriter {
    // Sending half of the queue; emits use `try_send`, so they never wait.
    tx: mpsc::Sender<AuditEvent>,
}
/// Handle for awaiting the drainer task started by `AuditWriter::spawn`.
///
/// The drainer exits once every sender of its channel has been dropped and the
/// queue is empty. Drop all `AuditWriter` clones first, then call
/// [`AuditWriterShutdown::join`].
pub struct AuditWriterShutdown {
    // Join handle of the spawned `audit_drainer_loop` task.
    drainer: tokio::task::JoinHandle<()>,
}

impl AuditWriterShutdown {
    /// Wait for the drainer task to finish.
    ///
    /// A join error (the task panicked or was cancelled) is logged at `warn`
    /// and otherwise ignored.
    pub async fn join(self) {
        match self.drainer.await {
            Ok(()) => {}
            Err(e) => {
                tracing::warn!(error = %e, "audit drainer task join error");
            }
        }
    }
}
impl AuditWriter {
pub fn spawn(db_path: PathBuf, key: Option<KeyMaterial>) -> (Self, AuditWriterShutdown) {
let (tx, rx) = mpsc::channel::<AuditEvent>(AUDIT_QUEUE_CAPACITY);
let drainer = tokio::spawn(audit_drainer_loop(rx, db_path, key));
(Self { tx }, AuditWriterShutdown { drainer })
}
pub fn noop() -> Self {
let (tx, _rx) = mpsc::channel::<AuditEvent>(1);
Self { tx }
}
pub fn emit_async(&self, event: AuditEvent) -> bool {
match self.tx.try_send(event) {
Ok(()) => true,
Err(mpsc::error::TrySendError::Full(ev)) => {
tracing::warn!(
operation = %ev.operation,
"audit: mpsc full, dropping event (queue capacity {AUDIT_QUEUE_CAPACITY})"
);
false
}
Err(mpsc::error::TrySendError::Closed(ev)) => {
tracing::debug!(
operation = %ev.operation,
"audit: writer closed, dropping event (shutdown in progress?)"
);
false
}
}
}
pub fn emit_ok(
&self,
principal_subject: Option<String>,
operation: AuditOperation,
target_id: Option<String>,
) {
let _ = self.emit_async(AuditEvent::ok_now(principal_subject, operation, target_id));
}
pub fn emit_error(
&self,
principal_subject: Option<String>,
operation: AuditOperation,
target_id: Option<String>,
err: impl std::fmt::Display,
) {
let _ = self.emit_async(AuditEvent::error_now(
principal_subject,
operation,
target_id,
err.to_string(),
));
}
}
/// Body of the background task started by `AuditWriter::spawn`.
///
/// The loop works as follows:
/// - Collect a batch of queued events with [`next_audit_batch`].
/// - Write the batch with `flush_batch` on a connection that is opened lazily
///   via `open_drainer_conn` and reused across batches.
/// - Exit once every sender has been dropped and the queue is drained, so
///   events enqueued before shutdown are still written.
///
/// Failures are logged and only the affected batch is dropped; the loop itself
/// keeps running:
/// - open failure: the batch is dropped and the open is retried with the next
///   batch;
/// - flush failure: the batch is dropped and the same connection is reused.
///
/// NOTE(review): the SQLite calls below are synchronous and run directly on the
/// async runtime's worker thread. The plain-SQLite path sets
/// `busy_timeout = 5000`, so a contended write can stall that worker for
/// seconds. Consider moving the DB work into `spawn_blocking` if this shows up.
async fn audit_drainer_loop(
    mut rx: mpsc::Receiver<AuditEvent>,
    db_path: PathBuf,
    key: Option<KeyMaterial>,
) {
    let linger = Duration::from_millis(AUDIT_BATCH_FLUSH_MAX_MILLIS);
    // Opened on the first batch (and again after a failed open), then reused.
    let mut conn: Option<Connection> = None;

    while let Some(batch) = next_audit_batch(&mut rx, linger).await {
        let c = match conn {
            Some(ref mut open) => open,
            None => match open_drainer_conn(&db_path, key.as_ref()) {
                Ok(fresh) => conn.insert(fresh),
                Err(e) => {
                    tracing::error!(
                        error = %e,
                        path = %db_path.display(),
                        dropped = batch.len(),
                        "audit drainer: failed to open connection, dropping batch"
                    );
                    continue;
                }
            },
        };
        if let Err(e) = flush_batch(c, &batch) {
            tracing::error!(
                error = %e,
                dropped = batch.len(),
                "audit drainer: flush failed, dropping batch"
            );
        }
    }
    tracing::debug!("audit drainer: mpsc closed, exiting");
}

/// Wait for the next queued event, then keep collecting until one of:
/// - the batch holds [`AUDIT_BATCH_FLUSH_MAX_EVENTS`] events;
/// - `linger` has elapsed since the first event was received;
/// - the channel closes.
///
/// Returns `None` only when the channel is closed and fully drained. A batch
/// cut short by channel closure is still returned so it gets flushed.
async fn next_audit_batch(
    rx: &mut mpsc::Receiver<AuditEvent>,
    linger: Duration,
) -> Option<Vec<AuditEvent>> {
    let first = rx.recv().await?;
    let mut batch = Vec::with_capacity(AUDIT_BATCH_FLUSH_MAX_EVENTS);
    batch.push(first);

    let deadline = tokio::time::Instant::now() + linger;
    while batch.len() < AUDIT_BATCH_FLUSH_MAX_EVENTS {
        // `recv` is cancel-safe, so timing out here never loses an event.
        match tokio::time::timeout_at(deadline, rx.recv()).await {
            Ok(Some(event)) => batch.push(event),
            // Channel closed (`Ok(None)`) or linger window elapsed (`Err`):
            // flush what we have.
            Ok(None) | Err(_) => break,
        }
    }
    Some(batch)
}
/// Open the drainer's dedicated connection to `db_path`.
///
/// - With key material, the open is delegated to `open_sqlcipher` (defined in
///   `crate::init`).
/// - Without key material, a plain SQLite connection is opened and configured
///   with WAL journaling and a 5000 ms busy timeout, so writers wait on a
///   locked database instead of failing immediately.
///
/// # Errors
/// Returns a storage error if the open or the pragma batch fails.
fn open_drainer_conn(db_path: &std::path::Path, key: Option<&KeyMaterial>) -> Result<Connection> {
    match key {
        None => {
            let plain = Connection::open(db_path).map_err(|e| {
                Error::storage(format!("audit drainer open {}: {e}", db_path.display()))
            })?;
            plain
                .execute_batch(
                    "PRAGMA journal_mode = wal;\n\
                     PRAGMA busy_timeout = 5000;",
                )
                .map_err(|e| Error::storage(format!("audit drainer pragmas: {e}")))?;
            Ok(plain)
        }
        Some(material) => open_sqlcipher(db_path, material),
    }
}
/// Write every event in `batch` inside one `BEGIN IMMEDIATE` transaction.
///
/// Either all rows commit or none do. If any insert fails, the error is
/// returned before `commit`, and dropping the uncommitted `Transaction` rolls it
/// back (rusqlite's default drop behavior).
///
/// # Errors
/// Returns a storage error from BEGIN, from any row insert, or from COMMIT.
fn flush_batch(conn: &mut Connection, batch: &[AuditEvent]) -> Result<()> {
    let tx = conn
        .transaction_with_behavior(TransactionBehavior::Immediate)
        .map_err(|e| Error::storage(format!("audit drainer BEGIN IMMEDIATE: {e}")))?;

    batch
        .iter()
        .try_for_each(|event| insert_audit_row_in_tx(&tx, event))?;

    tx.commit()
        .map_err(|e| Error::storage(format!("audit drainer COMMIT: {e}")))?;
    Ok(())
}
/// Delete every `audit_events` row whose `ts_ms` is strictly less than
/// `cutoff_ms`, and return how many rows were removed.
///
/// The delete runs in its own `BEGIN IMMEDIATE` transaction. Repeating a purge
/// with the same cutoff removes nothing further. `audit_events_admin` is not
/// touched.
///
/// # Errors
/// Returns a storage error from BEGIN, the DELETE, or COMMIT.
pub fn purge_older_than(conn: &mut Connection, cutoff_ms: i64) -> Result<usize> {
    const PURGE_SQL: &str = "DELETE FROM audit_events WHERE ts_ms < ?";

    let tx = conn
        .transaction_with_behavior(TransactionBehavior::Immediate)
        .map_err(|e| Error::storage(format!("BEGIN IMMEDIATE for purge: {e}")))?;

    let deleted = tx
        .execute(PURGE_SQL, params![cutoff_ms])
        .map_err(|e| Error::storage(format!("DELETE audit_events: {e}")))?;

    tx.commit()
        .map_err(|e| Error::storage(format!("COMMIT purge: {e}")))?;
    Ok(deleted)
}
#[cfg(test)]
mod tests {
    use super::*;

    fn open_in_memory_with_audit() -> Connection {
        let mut conn = Connection::open_in_memory().expect("open in-memory db");
        crate::migration::run_migrations(&mut conn).expect("migrations");
        conn
    }

    /// Create `file_name` inside a fresh temp dir, apply WAL / foreign-key /
    /// busy-timeout pragmas, and run all migrations. The setup connection is
    /// closed before returning.
    ///
    /// The returned `TempDir` owns the directory: keep it bound (e.g. `_tmp`,
    /// not `_`) for as long as the path is used.
    fn migrated_file_db(file_name: &str) -> (tempfile::TempDir, PathBuf) {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join(file_name);
        let mut conn = Connection::open(&db_path).unwrap();
        conn.execute_batch(
            "PRAGMA journal_mode = wal;
             PRAGMA foreign_keys = ON;
             PRAGMA busy_timeout = 5000;",
        )
        .unwrap();
        crate::migration::run_migrations(&mut conn).unwrap();
        (tmp, db_path)
    }

    #[test]
    fn audit_operation_display_matches_canonical_string() {
        assert_eq!(AuditOperation::MemoryRemember.to_string(), "memory.remember");
        assert_eq!(AuditOperation::MemoryRecall.to_string(), "memory.recall");
        assert_eq!(AuditOperation::TenantCreate.to_string(), "tenant.create");
    }

    #[test]
    fn audit_result_check_constraint_rejects_unknown_value() {
        let conn = open_in_memory_with_audit();
        let res = conn.execute(
            "INSERT INTO audit_events (ts_ms, operation, result) VALUES (?, ?, ?)",
            params![0i64, "memory.remember", "bogus"],
        );
        assert!(res.is_err(), "result='bogus' must violate CHECK");
    }

    #[test]
    fn insert_then_select_round_trip() {
        let mut conn = open_in_memory_with_audit();
        let tx = conn.transaction().unwrap();
        let event = AuditEvent::ok_now(
            Some("alice".into()),
            AuditOperation::MemoryRemember,
            Some("00000000-0000-0000-0000-000000000001".into()),
        );
        insert_audit_row_in_tx(&tx, &event).unwrap();
        tx.commit().unwrap();
        let (op, principal, target, result): (String, Option<String>, Option<String>, String) =
            conn.query_row(
                "SELECT operation, principal_subject, target_id, result FROM audit_events",
                [],
                |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?, r.get(3)?)),
            )
            .unwrap();
        assert_eq!(op, "memory.remember");
        assert_eq!(principal.as_deref(), Some("alice"));
        assert_eq!(
            target.as_deref(),
            Some("00000000-0000-0000-0000-000000000001")
        );
        assert_eq!(result, "ok");
    }

    #[test]
    fn purge_older_than_drops_old_rows() {
        let mut conn = open_in_memory_with_audit();
        for ts in [100i64, 200, 300] {
            conn.execute(
                "INSERT INTO audit_events (ts_ms, operation, result) VALUES (?, ?, ?)",
                params![ts, "memory.remember", "ok"],
            )
            .unwrap();
        }
        let purged = purge_older_than(&mut conn, 250).unwrap();
        assert_eq!(purged, 2, "purge should drop ts=100 and ts=200");
        let remaining: i64 = conn
            .query_row("SELECT COUNT(*) FROM audit_events", [], |r| r.get(0))
            .unwrap();
        assert_eq!(remaining, 1);
    }

    #[test]
    fn purge_older_than_idempotent() {
        let mut conn = open_in_memory_with_audit();
        conn.execute(
            "INSERT INTO audit_events (ts_ms, operation, result) VALUES (100, 'memory.remember', 'ok')",
            [],
        )
        .unwrap();
        assert_eq!(purge_older_than(&mut conn, 200).unwrap(), 1);
        assert_eq!(purge_older_than(&mut conn, 200).unwrap(), 0);
    }

    #[test]
    fn noop_writer_drops_events_without_blocking() {
        let writer = AuditWriter::noop();
        for _ in 0..100 {
            writer.emit_ok(None, AuditOperation::MemoryRecall, None);
        }
    }

    #[test]
    fn migration_0005_applied_once_on_repeated_open() {
        let mut conn = Connection::open_in_memory().expect("in-memory");
        crate::migration::run_migrations(&mut conn).unwrap();
        crate::migration::run_migrations(&mut conn).unwrap();
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM schema_migrations WHERE version = 5",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(count, 1, "migration 0005 must apply at most once");
    }

    #[test]
    fn audit_table_present_and_indices_exist() {
        let conn = open_in_memory_with_audit();
        let table: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='audit_events'",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(table, 1);
        for idx in ["idx_audit_ts", "idx_audit_principal"] {
            let exists: i64 = conn
                .query_row(
                    "SELECT COUNT(*) FROM sqlite_master WHERE type='index' AND name=?",
                    params![idx],
                    |r| r.get(0),
                )
                .unwrap();
            assert_eq!(exists, 1, "missing index: {idx}");
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn batched_load_does_not_drop_under_burst() {
        let (_tmp, db_path) = migrated_file_db("burst.db");
        let (audit, shutdown) = AuditWriter::spawn(db_path.clone(), None);
        for i in 0..1000 {
            audit.emit_ok(
                Some(format!("user-{i}")),
                AuditOperation::MemoryRecall,
                None,
            );
        }
        // No sleep needed. Dropping the only sender closes the queue, and the
        // drainer keeps receiving (and flushing) every buffered event before
        // its task ends. So once `join` returns, every row must already be on
        // disk; this also asserts that shutdown drains the queue.
        drop(audit);
        shutdown.join().await;
        let conn = Connection::open(&db_path).unwrap();
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM audit_events WHERE operation = 'memory.recall'",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(count, 1000, "all 1000 audit rows must land");
    }

    #[test]
    fn ok_now_and_error_now_construct_expected_shapes() {
        let ok = AuditEvent::ok_now(
            Some("u".into()),
            AuditOperation::MemoryRecall,
            Some("tid".into()),
        );
        assert_eq!(ok.result, AuditResult::Ok);
        assert!(ok.details.is_none());
        let err = AuditEvent::error_now(None, AuditOperation::MemoryRecall, None, "boom");
        assert_eq!(err.result, AuditResult::Error);
        let details = err.details.expect("error event carries details");
        assert_eq!(details["error"], "boom");
    }

    #[test]
    fn memory_triples_extract_renders_canonical_string() {
        assert_eq!(
            AuditOperation::MemoryTriplesExtract.as_str(),
            "memory.triples_extract"
        );
        assert_eq!(
            format!("{}", AuditOperation::MemoryTriplesExtract),
            "memory.triples_extract"
        );
    }

    #[test]
    fn llm_sampling_call_renders_canonical_string() {
        assert_eq!(AuditOperation::LlmSamplingCall.as_str(), "llm.sampling_call");
        assert_eq!(
            format!("{}", AuditOperation::LlmSamplingCall),
            "llm.sampling_call"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn audit_rows_under_burst_are_ordered_by_ts_then_rowid() {
        let (_tmp, db_path) = migrated_file_db("burst-order.db");
        let (audit, shutdown) = AuditWriter::spawn(db_path.clone(), None);
        const BURST: i64 = 50;
        for i in 0..BURST {
            audit.emit_ok(None, AuditOperation::MemoryRecall, Some(i.to_string()));
        }
        // Closing the queue and joining the drainer guarantees every event is flushed.
        drop(audit);
        shutdown.join().await;
        let conn = Connection::open(&db_path).unwrap();
        let mut stmt = conn
            .prepare(
                "SELECT target_id, ts_ms, audit_id FROM audit_events
                 WHERE operation = 'memory.recall'
                 ORDER BY ts_ms ASC, audit_id ASC",
            )
            .unwrap();
        let rows: Vec<(Option<String>, i64, i64)> = stmt
            .query_map([], |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)))
            .unwrap()
            .map(|r| r.unwrap())
            .collect();
        assert_eq!(
            rows.len(),
            BURST as usize,
            "all {BURST} burst rows must land"
        );
        let target_ids: Vec<String> = rows
            .iter()
            .map(|(t, _, _)| t.clone().expect("target_id present"))
            .collect();
        let expected: Vec<String> = (0..BURST).map(|i| i.to_string()).collect();
        assert_eq!(
            target_ids, expected,
            "burst-emit order must be recoverable from (ts_ms ASC, audit_id ASC)"
        );
        let unique_ts: std::collections::HashSet<i64> = rows.iter().map(|(_, ts, _)| *ts).collect();
        assert!(
            unique_ts.len() < rows.len(),
            "burst should produce same-ms collisions (got {} unique ts for {} rows)",
            unique_ts.len(),
            rows.len()
        );
        let audit_ids: Vec<i64> = rows.iter().map(|(_, _, id)| *id).collect();
        let mut sorted = audit_ids.clone();
        sorted.sort();
        assert_eq!(
            audit_ids, sorted,
            "audit_id sequence under the (ts_ms, audit_id) sort must already be ascending"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn audit_row_ts_ms_is_monotonic_within_writer_actor() {
        let (_tmp, db_path) = migrated_file_db("monotonic.db");
        let (audit, shutdown) = AuditWriter::spawn(db_path.clone(), None);
        for _ in 0..100 {
            audit.emit_ok(None, AuditOperation::MemoryRecall, None);
        }
        // Closing the queue and joining the drainer guarantees every event is flushed.
        drop(audit);
        shutdown.join().await;
        let conn = Connection::open(&db_path).unwrap();
        let mut stmt = conn
            .prepare(
                "SELECT ts_ms FROM audit_events
                 WHERE operation = 'memory.recall'
                 ORDER BY audit_id ASC",
            )
            .unwrap();
        let ts_seq: Vec<i64> = stmt
            .query_map([], |r| r.get::<_, i64>(0))
            .unwrap()
            .map(|r| r.unwrap())
            .collect();
        assert_eq!(ts_seq.len(), 100);
        for w in ts_seq.windows(2) {
            assert!(
                w[1] >= w[0],
                "ts_ms must not decrease within an emit burst, got {} then {}",
                w[0],
                w[1]
            );
        }
    }
}