use std::path::{Path, PathBuf};
use aa_core::identity::AgentId;
use async_trait::async_trait;
use sqlx::SqlitePool;
use super::agent::{AgentFilter, AgentRecord};
use super::audit::{AuditEvent, AuditFilter};
use super::backend::StorageBackend;
use super::error::{StorageError, StorageResult};
use super::health::StorageHealth;
use super::metric::{Metric, MetricPoint, MetricQuery};
use super::policy::{PolicyDocument, PolicyMeta, PolicyVersion};
use super::retention::{RetentionPolicy, RetentionStats};
/// DDL statements that define the SQLite schema, listed in the order they are executed
/// by `migrate`.
///
/// Every statement carries `IF NOT EXISTS`, so re-running the full list against an
/// already-initialised database does not fail on objects that are already present.
/// Timestamps are stored as TEXT (the code in this file binds RFC 3339 strings) and
/// booleans as INTEGER 0/1.
const SCHEMA: &[&str] = &[
// Audit log. A row is identified by the (ts, event_id) pair; `payload` holds optional
// JSON text.
"CREATE TABLE IF NOT EXISTS audit_events (
ts TEXT NOT NULL,
event_id TEXT NOT NULL,
agent_id TEXT NOT NULL,
team_id TEXT,
action TEXT NOT NULL,
decision TEXT NOT NULL,
dry_run INTEGER NOT NULL DEFAULT 0,
shadow_decision TEXT,
matched_rule_id TEXT,
payload TEXT,
PRIMARY KEY (ts, event_id)
)",
// Secondary indexes for the agent and time-range audit filters.
"CREATE INDEX IF NOT EXISTS idx_audit_agent ON audit_events(agent_id)",
"CREATE INDEX IF NOT EXISTS idx_audit_ts ON audit_events(ts)",
// One row per registered agent; `metadata` is a JSON object stored as text.
"CREATE TABLE IF NOT EXISTS agent_registry (
agent_id TEXT PRIMARY KEY,
team_id TEXT,
org_id TEXT,
metadata TEXT NOT NULL DEFAULT '{}',
registered_at TEXT NOT NULL,
last_seen_at TEXT NOT NULL,
enforcement_mode TEXT NOT NULL DEFAULT 'enforce'
)",
// Versioned policy documents; (name, version) is unique and `is_active` flags the
// version currently in force.
"CREATE TABLE IF NOT EXISTS policy_versions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
version INTEGER NOT NULL,
document TEXT NOT NULL,
created_at TEXT NOT NULL,
is_active INTEGER NOT NULL DEFAULT 0,
UNIQUE(name, version)
)",
// Raw metric samples; `labels` is a JSON object stored as text.
"CREATE TABLE IF NOT EXISTS metrics (
ts TEXT NOT NULL,
agent_id TEXT NOT NULL,
metric TEXT NOT NULL,
value REAL NOT NULL,
labels TEXT NOT NULL DEFAULT '{}'
)",
];
/// Settings needed to open a [`SqliteBackend`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SqliteConfig {
/// Location of the database file. `SqliteBackend::open` expands a leading `~`
/// component and creates missing parent directories before connecting.
pub path: PathBuf,
}
/// [`StorageBackend`] implementation that keeps all data in a single SQLite database.
pub struct SqliteBackend {
// Connection pool shared by every query issued through this backend.
pool: SqlitePool,
}
impl SqliteBackend {
    /// Opens the SQLite database described by `config`, creating the file if it is missing.
    ///
    /// The configured path is passed through `expand_tilde`, missing parent directories
    /// are created, and the database is switched to WAL journaling before the pool is
    /// returned. No tables are created here; call `migrate` for that.
    ///
    /// The path is handed to sqlx as a filename rather than being spliced into a
    /// `sqlite://` URL, so characters such as `?`, `#` or `%` in the path cannot be
    /// misread as URL syntax and non-UTF-8 paths are not mangled.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::ConnectionFailed`] when the parent directory cannot be
    /// created, the database cannot be opened, or the WAL pragma is rejected.
    pub async fn open(config: &SqliteConfig) -> StorageResult<Self> {
        let path = expand_tilde(&config.path);

        // A bare file name has an empty parent; there is nothing to create in that case.
        if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) {
            std::fs::create_dir_all(parent).map_err(|e| {
                StorageError::ConnectionFailed(format!(
                    "failed to create parent directory {}: {e}",
                    parent.display()
                ))
            })?;
        }

        // Read-write, create-if-missing: the same semantics as the former `mode=rwc` URL.
        let options = sqlx::sqlite::SqliteConnectOptions::new()
            .filename(&path)
            .create_if_missing(true);
        let pool = SqlitePool::connect_with(options)
            .await
            .map_err(|e| StorageError::ConnectionFailed(e.to_string()))?;

        sqlx::query("PRAGMA journal_mode=WAL")
            .execute(&pool)
            .await
            .map_err(|e| StorageError::ConnectionFailed(format!("WAL pragma: {e}")))?;

        Ok(Self { pool })
    }

    /// Borrows the underlying connection pool.
    // Only the tests in this file call this accessor, so non-test builds would
    // otherwise warn about it being unused.
    #[allow(dead_code)]
    pub(crate) fn pool(&self) -> &SqlitePool {
        &self.pool
    }
}
/// Renders an [`AgentId`] as UUID text, the representation used in the `agent_id` columns.
fn agent_id_to_text(id: &AgentId) -> String {
    let raw = *id.as_bytes();
    let as_uuid = uuid::Uuid::from_bytes(raw);
    as_uuid.to_string()
}
/// Parses UUID text read from the database back into an [`AgentId`].
///
/// # Errors
///
/// Returns [`StorageError::QueryFailed`] when `s` is not a valid UUID.
fn agent_id_from_text(s: &str) -> StorageResult<AgentId> {
    match uuid::Uuid::parse_str(s) {
        Ok(parsed) => Ok(AgentId::from_bytes(*parsed.as_bytes())),
        Err(e) => Err(StorageError::QueryFailed(format!("invalid agent_id {s}: {e}"))),
    }
}
/// Converts one `audit_events` row into an [`AuditEvent`].
///
/// Text columns holding structured values (timestamp, UUIDs, JSON payload) are parsed
/// back into their typed form; `dry_run` is stored as an integer and treated as true
/// when non-zero.
///
/// # Errors
///
/// Returns [`StorageError::QueryFailed`] naming the offending column when a column is
/// missing, has an unexpected type, or fails to parse.
fn row_to_audit_event(row: &sqlx::sqlite::SqliteRow) -> StorageResult<AuditEvent> {
    // Reads a single column and labels any failure with the column name.
    fn column<'r, T>(row: &'r sqlx::sqlite::SqliteRow, name: &str) -> StorageResult<T>
    where
        T: sqlx::Decode<'r, sqlx::Sqlite> + sqlx::Type<sqlx::Sqlite>,
    {
        use sqlx::Row;
        row.try_get(name)
            .map_err(|e| StorageError::QueryFailed(format!("{name} column: {e}")))
    }

    let raw_ts: String = column(row, "ts")?;
    let ts = chrono::DateTime::parse_from_rfc3339(&raw_ts)
        .map_err(|e| StorageError::QueryFailed(format!("ts parse: {e}")))?
        .with_timezone(&chrono::Utc);

    let raw_event_id: String = column(row, "event_id")?;
    let event_id = uuid::Uuid::parse_str(&raw_event_id)
        .map_err(|e| StorageError::QueryFailed(format!("event_id parse: {e}")))?;

    let raw_agent_id: String = column(row, "agent_id")?;
    let agent_id = agent_id_from_text(&raw_agent_id)?;

    let dry_run_flag: i64 = column(row, "dry_run")?;

    // A NULL payload stays `None`; anything else must be valid JSON.
    let payload = match column::<Option<String>>(row, "payload")? {
        Some(text) => Some(
            serde_json::from_str::<serde_json::Value>(&text)
                .map_err(|e| StorageError::QueryFailed(format!("payload parse: {e}")))?,
        ),
        None => None,
    };

    let team_id = column(row, "team_id")?;
    let action = column(row, "action")?;
    let decision = column(row, "decision")?;
    let shadow_decision = column(row, "shadow_decision")?;
    let matched_rule_id = column(row, "matched_rule_id")?;

    Ok(AuditEvent {
        ts,
        event_id,
        agent_id,
        team_id,
        action,
        decision,
        dry_run: dry_run_flag != 0,
        shadow_decision,
        matched_rule_id,
        payload,
    })
}
/// Appends the `WHERE` clause implied by `filter` to an `audit_events` query.
///
/// Each populated filter field adds one condition; conditions are joined with `AND`.
/// The time window is half-open: `from` is inclusive, `to` is exclusive. Nothing is
/// appended when no field is set. `limit` and `offset` are not handled here.
fn push_audit_where<'q>(qb: &mut sqlx::QueryBuilder<'q, sqlx::Sqlite>, filter: &'q AuditFilter) {
    // Yields " WHERE " for the first condition and " AND " for every later one.
    fn joiner(first: &mut bool) -> &'static str {
        if std::mem::replace(first, false) {
            " WHERE "
        } else {
            " AND "
        }
    }

    let mut first = true;

    if let Some(agent_id) = &filter.agent_id {
        qb.push(joiner(&mut first))
            .push("agent_id = ")
            .push_bind(agent_id_to_text(agent_id));
    }
    if let Some(team_id) = &filter.team_id {
        qb.push(joiner(&mut first))
            .push("team_id = ")
            .push_bind(team_id.clone());
    }
    if let Some(from) = filter.from {
        qb.push(joiner(&mut first))
            .push("ts >= ")
            .push_bind(from.to_rfc3339());
    }
    if let Some(to) = filter.to {
        qb.push(joiner(&mut first))
            .push("ts < ")
            .push_bind(to.to_rfc3339());
    }
    if filter.dry_run_only {
        qb.push(joiner(&mut first)).push("dry_run = 1");
    }
}
/// Converts one `agent_registry` row into an [`AgentRecord`].
///
/// `metadata` is stored as JSON text and decoded into a string map; the two timestamp
/// columns are stored as RFC 3339 text and normalised to UTC.
///
/// # Errors
///
/// Returns [`StorageError::QueryFailed`] naming the offending column when a column is
/// missing, has an unexpected type, or fails to parse.
fn row_to_agent_record(row: &sqlx::sqlite::SqliteRow) -> StorageResult<AgentRecord> {
    use sqlx::Row;

    let agent_id_text: String = row
        .try_get("agent_id")
        .map_err(|e| StorageError::QueryFailed(format!("agent_id column: {e}")))?;
    let agent_id = agent_id_from_text(&agent_id_text)?;

    let metadata_text: String = row
        .try_get("metadata")
        .map_err(|e| StorageError::QueryFailed(format!("metadata column: {e}")))?;
    let metadata: std::collections::BTreeMap<String, String> = serde_json::from_str(&metadata_text)
        .map_err(|e| StorageError::QueryFailed(format!("metadata parse: {e}")))?;

    let registered_at_text: String = row
        .try_get("registered_at")
        .map_err(|e| StorageError::QueryFailed(format!("registered_at column: {e}")))?;
    // The argument here had been corrupted into a stray `®` character; it must be a
    // borrow of the text read just above.
    let registered_at = chrono::DateTime::parse_from_rfc3339(&registered_at_text)
        .map_err(|e| StorageError::QueryFailed(format!("registered_at parse: {e}")))?
        .with_timezone(&chrono::Utc);

    let last_seen_at_text: String = row
        .try_get("last_seen_at")
        .map_err(|e| StorageError::QueryFailed(format!("last_seen_at column: {e}")))?;
    let last_seen_at = chrono::DateTime::parse_from_rfc3339(&last_seen_at_text)
        .map_err(|e| StorageError::QueryFailed(format!("last_seen_at parse: {e}")))?
        .with_timezone(&chrono::Utc);

    Ok(AgentRecord {
        agent_id,
        team_id: row
            .try_get("team_id")
            .map_err(|e| StorageError::QueryFailed(format!("team_id column: {e}")))?,
        org_id: row
            .try_get("org_id")
            .map_err(|e| StorageError::QueryFailed(format!("org_id column: {e}")))?,
        metadata,
        registered_at,
        last_seen_at,
        enforcement_mode: row
            .try_get("enforcement_mode")
            .map_err(|e| StorageError::QueryFailed(format!("enforcement_mode column: {e}")))?,
    })
}
/// Appends the `WHERE` clause implied by `filter` to an `agent_registry` query.
///
/// Each populated filter field adds one condition; conditions are joined with `AND`.
/// Nothing is appended when no field is set. `limit` and `offset` are not handled here.
///
/// `name_contains` is matched as a literal substring of the `name` entry in the
/// metadata JSON: the LIKE metacharacters `%` and `_` (and the escape character
/// itself) are escaped so that they only match themselves.
fn push_agent_where<'q>(qb: &mut sqlx::QueryBuilder<'q, sqlx::Sqlite>, filter: &'q AgentFilter) {
    let mut started = false;
    let mut connective = move |qb: &mut sqlx::QueryBuilder<'q, sqlx::Sqlite>| {
        qb.push(if started { " AND " } else { " WHERE " });
        started = true;
    };

    if let Some(team_id) = filter.team_id.as_ref() {
        connective(qb);
        qb.push("team_id = ").push_bind(team_id.clone());
    }
    if let Some(org_id) = filter.org_id.as_ref() {
        connective(qb);
        qb.push("org_id = ").push_bind(org_id.clone());
    }
    if let Some(name_contains) = filter.name_contains.as_ref() {
        connective(qb);

        // Build `%<escaped needle>%`, prefixing every LIKE metacharacter with `\`.
        let needle = name_contains.to_string();
        let mut pattern = String::with_capacity(needle.len() + 2);
        pattern.push('%');
        for ch in needle.chars() {
            if matches!(ch, '\\' | '%' | '_') {
                pattern.push('\\');
            }
            pattern.push(ch);
        }
        pattern.push('%');

        qb.push("json_extract(metadata, '$.name') LIKE ")
            .push_bind(pattern)
            .push(" ESCAPE '\\'");
    }
}
#[async_trait]
impl StorageBackend for SqliteBackend {
/// Inserts one audit event as a new `audit_events` row.
///
/// The payload, when present, is serialised to JSON text; `dry_run` is stored as 0/1.
///
/// # Errors
///
/// Returns [`StorageError::QueryFailed`] when the payload cannot be serialised or the
/// insert is rejected by the database.
async fn append_audit_event(&self, event: &AuditEvent) -> StorageResult<()> {
    let payload_json = event
        .payload
        .as_ref()
        .map(serde_json::to_string)
        .transpose()
        .map_err(|e| StorageError::QueryFailed(format!("payload serialize: {e}")))?;

    let insert = sqlx::query(
        "INSERT INTO audit_events \
         (ts, event_id, agent_id, team_id, action, decision, dry_run, shadow_decision, matched_rule_id, payload) \
         VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
    )
    .bind(event.ts.to_rfc3339())
    .bind(event.event_id.to_string())
    .bind(agent_id_to_text(&event.agent_id))
    .bind(event.team_id.clone())
    .bind(&event.action)
    .bind(&event.decision)
    .bind(i64::from(event.dry_run))
    .bind(event.shadow_decision.clone())
    .bind(event.matched_rule_id.clone())
    .bind(payload_json);

    match insert.execute(&self.pool).await {
        Ok(_) => Ok(()),
        Err(e) => Err(StorageError::QueryFailed(e.to_string())),
    }
}
/// Returns the audit events matching `filter`, newest first.
///
/// Rows are ordered by `ts` descending with `event_id` as a tiebreak, so repeated
/// paged queries see a stable order even when several events share a timestamp.
/// `offset` is honoured whether or not `limit` is set: SQLite only accepts `OFFSET`
/// after `LIMIT`, so an unbounded `LIMIT -1` is emitted when only `offset` is given.
///
/// # Errors
///
/// Returns [`StorageError::QueryFailed`] when the query fails or a row cannot be decoded.
async fn query_audit_events(&self, filter: AuditFilter) -> StorageResult<Vec<AuditEvent>> {
    let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(
        "SELECT ts, event_id, agent_id, team_id, action, decision, dry_run, \
         shadow_decision, matched_rule_id, payload FROM audit_events",
    );
    push_audit_where(&mut qb, &filter);
    qb.push(" ORDER BY ts DESC, event_id DESC");

    match (filter.limit, filter.offset) {
        (Some(limit), offset) => {
            qb.push(" LIMIT ").push_bind(i64::from(limit));
            if let Some(offset) = offset {
                qb.push(" OFFSET ").push_bind(i64::from(offset));
            }
        }
        // A negative LIMIT means "no upper bound" in SQLite.
        (None, Some(offset)) => {
            qb.push(" LIMIT -1 OFFSET ").push_bind(i64::from(offset));
        }
        (None, None) => {}
    }

    let rows = qb
        .build()
        .fetch_all(&self.pool)
        .await
        .map_err(|e| StorageError::QueryFailed(e.to_string()))?;
    rows.iter().map(row_to_audit_event).collect()
}
/// Counts the audit events matching `filter`.
///
/// Only the conditions produced by `push_audit_where` apply; `limit` and `offset`
/// are not consulted.
///
/// # Errors
///
/// Returns [`StorageError::QueryFailed`] when the query fails or the count does not
/// fit in a `u64`.
async fn count_audit_events(&self, filter: AuditFilter) -> StorageResult<u64> {
    let mut builder = sqlx::QueryBuilder::<sqlx::Sqlite>::new("SELECT COUNT(*) FROM audit_events");
    push_audit_where(&mut builder, &filter);

    let row: (i64,) = builder
        .build_query_as()
        .fetch_one(&self.pool)
        .await
        .map_err(|e| StorageError::QueryFailed(e.to_string()))?;

    match u64::try_from(row.0) {
        Ok(total) => Ok(total),
        Err(e) => Err(StorageError::QueryFailed(format!("count overflow: {e}"))),
    }
}
/// Inserts `record`, replacing any existing row with the same agent id.
///
/// Every column is written from `record`, so a replace overwrites the stored row in full.
///
/// # Errors
///
/// Returns [`StorageError::QueryFailed`] when the metadata cannot be serialised or the
/// statement fails.
async fn upsert_agent(&self, record: AgentRecord) -> StorageResult<()> {
    let metadata_json = match serde_json::to_string(&record.metadata) {
        Ok(text) => text,
        Err(e) => return Err(StorageError::QueryFailed(format!("metadata serialize: {e}"))),
    };

    let statement = sqlx::query(
        "INSERT OR REPLACE INTO agent_registry \
         (agent_id, team_id, org_id, metadata, registered_at, last_seen_at, enforcement_mode) \
         VALUES (?, ?, ?, ?, ?, ?, ?)",
    )
    .bind(agent_id_to_text(&record.agent_id))
    .bind(record.team_id)
    .bind(record.org_id)
    .bind(metadata_json)
    .bind(record.registered_at.to_rfc3339())
    .bind(record.last_seen_at.to_rfc3339())
    .bind(record.enforcement_mode);

    statement
        .execute(&self.pool)
        .await
        .map(|_| ())
        .map_err(|e| StorageError::QueryFailed(e.to_string()))
}
async fn get_agent(&self, id: &AgentId) -> StorageResult<Option<AgentRecord>> {
let row = sqlx::query(
"SELECT agent_id, team_id, org_id, metadata, registered_at, last_seen_at, enforcement_mode \
FROM agent_registry WHERE agent_id = ?",
)
.bind(agent_id_to_text(id))
.fetch_optional(&self.pool)
.await
.map_err(|e| StorageError::QueryFailed(e.to_string()))?;
row.as_ref().map(row_to_agent_record).transpose()
}
/// Lists the agents matching `filter`, ordered by agent id.
///
/// `offset` is honoured whether or not `limit` is set: SQLite only accepts `OFFSET`
/// after `LIMIT`, so an unbounded `LIMIT -1` is emitted when only `offset` is given.
///
/// # Errors
///
/// Returns [`StorageError::QueryFailed`] when the query fails or a row cannot be decoded.
async fn list_agents(&self, filter: AgentFilter) -> StorageResult<Vec<AgentRecord>> {
    let mut qb = sqlx::QueryBuilder::<sqlx::Sqlite>::new(
        "SELECT agent_id, team_id, org_id, metadata, registered_at, last_seen_at, \
         enforcement_mode FROM agent_registry",
    );
    push_agent_where(&mut qb, &filter);
    qb.push(" ORDER BY agent_id");

    match (filter.limit, filter.offset) {
        (Some(limit), offset) => {
            qb.push(" LIMIT ").push_bind(i64::from(limit));
            if let Some(offset) = offset {
                qb.push(" OFFSET ").push_bind(i64::from(offset));
            }
        }
        // A negative LIMIT means "no upper bound" in SQLite.
        (None, Some(offset)) => {
            qb.push(" LIMIT -1 OFFSET ").push_bind(i64::from(offset));
        }
        (None, None) => {}
    }

    let rows = qb
        .build()
        .fetch_all(&self.pool)
        .await
        .map_err(|e| StorageError::QueryFailed(e.to_string()))?;
    rows.iter().map(row_to_agent_record).collect()
}
/// Removes the agent with the given id.
///
/// # Errors
///
/// Returns [`StorageError::NotFound`] (carrying the id text) when no row was deleted,
/// and [`StorageError::QueryFailed`] when the statement itself fails.
async fn delete_agent(&self, id: &AgentId) -> StorageResult<()> {
    let id_text = agent_id_to_text(id);

    let outcome = sqlx::query("DELETE FROM agent_registry WHERE agent_id = ?")
        .bind(&id_text)
        .execute(&self.pool)
        .await
        .map_err(|e| StorageError::QueryFailed(e.to_string()))?;

    if outcome.rows_affected() > 0 {
        Ok(())
    } else {
        Err(StorageError::NotFound(id_text))
    }
}
/// Stores `doc` as the next version of the policy named `doc.name`.
///
/// The version number is one more than the highest stored version for that name
/// (starting at 1), computed and inserted inside a single transaction. The new row is
/// saved inactive.
///
/// # Errors
///
/// Returns [`StorageError::Conflict`] when the `(name, version)` uniqueness constraint
/// is violated, and [`StorageError::QueryFailed`] when the document is not UTF-8, the
/// transaction cannot be started or committed, a statement fails, or the version does
/// not fit in a `u32`.
async fn save_policy(&self, doc: PolicyDocument) -> StorageResult<PolicyVersion> {
    let body = match std::str::from_utf8(&doc.bytes) {
        Ok(text) => text.to_owned(),
        Err(e) => {
            return Err(StorageError::QueryFailed(format!(
                "document bytes not UTF-8: {e}"
            )))
        }
    };

    let mut tx = self
        .pool
        .begin()
        .await
        .map_err(|e| StorageError::QueryFailed(format!("begin tx: {e}")))?;

    let next: (i64,) =
        sqlx::query_as("SELECT COALESCE(MAX(version), 0) + 1 FROM policy_versions WHERE name = ?")
            .bind(&doc.name)
            .fetch_one(&mut *tx)
            .await
            .map_err(|e| StorageError::QueryFailed(format!("compute next version: {e}")))?;
    let next_version = next.0;

    let created_at = chrono::Utc::now();
    let inserted = sqlx::query(
        "INSERT INTO policy_versions (name, version, document, created_at, is_active) \
         VALUES (?, ?, ?, ?, 0)",
    )
    .bind(&doc.name)
    .bind(next_version)
    .bind(&body)
    .bind(created_at.to_rfc3339())
    .execute(&mut *tx)
    .await;

    if let Err(e) = inserted {
        return Err(match e {
            // Another writer claimed the same version number first.
            sqlx::Error::Database(db) if db.is_unique_violation() => {
                StorageError::Conflict(format!("{}@{next_version}", doc.name))
            }
            other => StorageError::QueryFailed(other.to_string()),
        });
    }

    tx.commit()
        .await
        .map_err(|e| StorageError::QueryFailed(format!("commit tx: {e}")))?;

    let version = u32::try_from(next_version)
        .map_err(|e| StorageError::QueryFailed(format!("version overflow: {e}")))?;

    let meta = PolicyMeta {
        name: doc.name.clone(),
        version,
        created_at,
        is_active: false,
    };
    Ok(PolicyVersion { meta, document: doc })
}
/// Fetches the document of the version currently flagged active for `name`, if any.
///
/// # Errors
///
/// Returns [`StorageError::QueryFailed`] when the query fails.
async fn get_active_policy(&self, name: &str) -> StorageResult<Option<PolicyDocument>> {
    let found: Option<(String,)> =
        sqlx::query_as("SELECT document FROM policy_versions WHERE name = ? AND is_active = 1 LIMIT 1")
            .bind(name)
            .fetch_optional(&self.pool)
            .await
            .map_err(|e| StorageError::QueryFailed(e.to_string()))?;

    match found {
        Some((document,)) => Ok(Some(PolicyDocument {
            name: name.to_owned(),
            bytes: document.into_bytes(),
        })),
        None => Ok(None),
    }
}
/// Lists the metadata of every stored version of `name`, highest version first.
///
/// # Errors
///
/// Returns [`StorageError::QueryFailed`] when the query fails, a version does not fit
/// in a `u32`, or a `created_at` value is not valid RFC 3339.
async fn list_policy_versions(&self, name: &str) -> StorageResult<Vec<PolicyMeta>> {
    let rows: Vec<(i64, String, i64)> = sqlx::query_as(
        "SELECT version, created_at, is_active FROM policy_versions \
         WHERE name = ? ORDER BY version DESC",
    )
    .bind(name)
    .fetch_all(&self.pool)
    .await
    .map_err(|e| StorageError::QueryFailed(e.to_string()))?;

    let mut metas = Vec::with_capacity(rows.len());
    for (raw_version, raw_created_at, active_flag) in rows {
        let version = u32::try_from(raw_version)
            .map_err(|e| StorageError::QueryFailed(format!("version overflow: {e}")))?;
        let created_at = chrono::DateTime::parse_from_rfc3339(&raw_created_at)
            .map_err(|e| StorageError::QueryFailed(format!("created_at parse: {e}")))?
            .with_timezone(&chrono::Utc);
        metas.push(PolicyMeta {
            name: name.to_owned(),
            version,
            created_at,
            is_active: active_flag != 0,
        });
    }
    Ok(metas)
}
/// Makes `version` the single active version of the policy `name`.
///
/// Inside one transaction the target version is checked for existence, any currently
/// active version of the same name is deactivated, and the target is activated.
///
/// # Errors
///
/// Returns [`StorageError::NotFound`] when the `(name, version)` pair does not exist,
/// and [`StorageError::QueryFailed`] when the transaction or a statement fails.
async fn rollback_policy(&self, name: &str, version: u32) -> StorageResult<()> {
    let target = i64::from(version);

    let mut tx = self
        .pool
        .begin()
        .await
        .map_err(|e| StorageError::QueryFailed(format!("begin tx: {e}")))?;

    let probe: Option<(i64,)> =
        sqlx::query_as("SELECT 1 FROM policy_versions WHERE name = ? AND version = ?")
            .bind(name)
            .bind(target)
            .fetch_optional(&mut *tx)
            .await
            .map_err(|e| StorageError::QueryFailed(e.to_string()))?;
    match probe {
        Some(_) => {}
        // Returning here drops the transaction without committing it.
        None => return Err(StorageError::NotFound(format!("{name}@{version}"))),
    }

    sqlx::query("UPDATE policy_versions SET is_active = 0 WHERE name = ? AND is_active = 1")
        .bind(name)
        .execute(&mut *tx)
        .await
        .map_err(|e| StorageError::QueryFailed(e.to_string()))?;

    sqlx::query("UPDATE policy_versions SET is_active = 1 WHERE name = ? AND version = ?")
        .bind(name)
        .bind(target)
        .execute(&mut *tx)
        .await
        .map_err(|e| StorageError::QueryFailed(e.to_string()))?;

    match tx.commit().await {
        Ok(()) => Ok(()),
        Err(e) => Err(StorageError::QueryFailed(format!("commit tx: {e}"))),
    }
}
/// Appends one metric sample to the `metrics` table; labels are stored as JSON text.
///
/// # Errors
///
/// Returns [`StorageError::QueryFailed`] when the labels cannot be serialised or the
/// insert fails.
async fn record_metric(&self, m: Metric) -> StorageResult<()> {
    let labels_json = match serde_json::to_string(&m.labels) {
        Ok(text) => text,
        Err(e) => return Err(StorageError::QueryFailed(format!("labels serialize: {e}"))),
    };

    let insert = sqlx::query("INSERT INTO metrics (ts, agent_id, metric, value, labels) VALUES (?, ?, ?, ?, ?)")
        .bind(m.ts.to_rfc3339())
        .bind(agent_id_to_text(&m.agent_id))
        .bind(&m.metric)
        .bind(m.value)
        .bind(labels_json);

    insert
        .execute(&self.pool)
        .await
        .map(|_| ())
        .map_err(|e| StorageError::QueryFailed(e.to_string()))
}
/// Returns the raw metric samples matching `q`, oldest first.
///
/// Agent, metric name and the half-open `[from, to)` time window narrow the result;
/// `limit` caps the number of rows. Bucketing is not implemented for this backend: a
/// requested `bucket` only produces a warning and raw samples are returned.
///
/// # Errors
///
/// Returns [`StorageError::QueryFailed`] when the query fails or a stored timestamp is
/// not valid RFC 3339.
async fn query_metrics(&self, q: MetricQuery) -> StorageResult<Vec<MetricPoint>> {
    // Yields " WHERE " for the first condition and " AND " for every later one.
    fn joiner(first: &mut bool) -> &'static str {
        if std::mem::replace(first, false) {
            " WHERE "
        } else {
            " AND "
        }
    }

    if q.bucket.is_some() {
        tracing::warn!("query_metrics(bucket) ignored on SQLite backend; raw samples returned");
    }

    let mut builder = sqlx::QueryBuilder::<sqlx::Sqlite>::new("SELECT ts, value FROM metrics");
    let mut first = true;

    if let Some(agent_id) = q.agent_id.as_ref() {
        builder
            .push(joiner(&mut first))
            .push("agent_id = ")
            .push_bind(agent_id_to_text(agent_id));
    }
    if let Some(metric) = q.metric.as_ref() {
        builder
            .push(joiner(&mut first))
            .push("metric = ")
            .push_bind(metric.clone());
    }
    if let Some(from) = q.from {
        builder
            .push(joiner(&mut first))
            .push("ts >= ")
            .push_bind(from.to_rfc3339());
    }
    if let Some(to) = q.to {
        builder
            .push(joiner(&mut first))
            .push("ts < ")
            .push_bind(to.to_rfc3339());
    }

    builder.push(" ORDER BY ts ASC");
    if let Some(limit) = q.limit {
        builder.push(" LIMIT ").push_bind(i64::from(limit));
    }

    let samples: Vec<(String, f64)> = builder
        .build_query_as()
        .fetch_all(&self.pool)
        .await
        .map_err(|e| StorageError::QueryFailed(e.to_string()))?;

    let mut points = Vec::with_capacity(samples.len());
    for (raw_ts, value) in samples {
        let ts = chrono::DateTime::parse_from_rfc3339(&raw_ts)
            .map_err(|e| StorageError::QueryFailed(format!("ts parse: {e}")))?
            .with_timezone(&chrono::Utc);
        points.push(MetricPoint { ts, value });
    }
    Ok(points)
}
/// Executes every statement in `SCHEMA`, in order, stopping at the first failure.
///
/// # Errors
///
/// Returns [`StorageError::MigrationFailed`] with the database error text.
async fn migrate(&self) -> StorageResult<()> {
    for ddl in SCHEMA.iter().copied() {
        if let Err(e) = sqlx::query(ddl).execute(&self.pool).await {
            return Err(StorageError::MigrationFailed(e.to_string()));
        }
    }
    Ok(())
}
/// Applies `policy` to the audit log and reports what was (or would be) removed.
///
/// Audit events older than `hot_days + warm_days` are deleted, or merely counted when
/// `policy.dry_run` is set. `hot_rows` is the number of events newer than `hot_days`.
/// This backend has no archive support: an `Archive` cold action is logged and handled
/// as a drop. Compression, archival and freed-byte figures are always reported as zero.
///
/// The two day counts are widened to `i64` before being added so that large
/// configured values cannot overflow the narrower field type, and the cutoff
/// timestamps are computed with checked arithmetic instead of panicking.
///
/// # Errors
///
/// Returns [`StorageError::RetentionError`] when a retention window is too large to
/// express as a timestamp or the delete fails, and [`StorageError::QueryFailed`] when
/// a count query fails.
async fn apply_retention(&self, policy: &RetentionPolicy) -> StorageResult<RetentionStats> {
    if matches!(policy.cold_action, crate::storage::ColdAction::Archive) {
        tracing::warn!(
            archive_url = ?policy.archive_url,
            "archive cold_action not supported on SQLite backend — falling back to drop"
        );
    }

    let now = chrono::Utc::now();
    let hot_days = i64::from(policy.hot_days);
    let retained_days = hot_days + i64::from(policy.warm_days);

    // `now - duration` panics when the result is out of range; report an error instead.
    let cutoff = |days: i64| {
        now.checked_sub_signed(chrono::Duration::days(days))
            .ok_or_else(|| {
                StorageError::RetentionError(format!(
                    "retention window of {days} days is out of range"
                ))
            })
    };
    let cold_ts = cutoff(retained_days)?.to_rfc3339();
    let hot_ts = cutoff(hot_days)?.to_rfc3339();

    let dropped_rows: u64 = if policy.dry_run {
        let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM audit_events WHERE ts < ?")
            .bind(&cold_ts)
            .fetch_one(&self.pool)
            .await
            .map_err(|e| StorageError::QueryFailed(e.to_string()))?;
        u64::try_from(count).unwrap_or(0)
    } else {
        let result = sqlx::query("DELETE FROM audit_events WHERE ts < ?")
            .bind(&cold_ts)
            .execute(&self.pool)
            .await
            .map_err(|e| StorageError::RetentionError(e.to_string()))?;
        result.rows_affected()
    };

    let (hot_count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM audit_events WHERE ts >= ?")
        .bind(hot_ts)
        .fetch_one(&self.pool)
        .await
        .map_err(|e| StorageError::QueryFailed(e.to_string()))?;

    Ok(RetentionStats {
        hot_rows: u64::try_from(hot_count).unwrap_or(0),
        compressed_rows: 0,
        archived_rows: 0,
        dropped_rows,
        freed_bytes: 0,
        ran_at: chrono::Utc::now(),
    })
}
/// Probes the database and reports latency plus per-table row counts.
///
/// The measured latency covers a `SELECT 1` round trip and the three `COUNT(*)`
/// queries, in milliseconds, saturating at `u32::MAX`.
///
/// # Errors
///
/// Returns [`StorageError::ConnectionFailed`] when the probe query fails and
/// [`StorageError::QueryFailed`] when a row count cannot be obtained.
async fn healthcheck(&self) -> StorageResult<StorageHealth> {
    // Counts the rows of one table; callers below only pass fixed table names.
    async fn table_len(pool: &SqlitePool, table: &str) -> StorageResult<u64> {
        let sql = format!("SELECT COUNT(*) FROM {table}");
        let row: (i64,) = sqlx::query_as(&sql)
            .fetch_one(pool)
            .await
            .map_err(|e| StorageError::QueryFailed(e.to_string()))?;
        Ok(u64::try_from(row.0).unwrap_or(0))
    }

    let started = std::time::Instant::now();

    sqlx::query("SELECT 1")
        .execute(&self.pool)
        .await
        .map_err(|e| StorageError::ConnectionFailed(e.to_string()))?;

    let audit_events = table_len(&self.pool, "audit_events").await?;
    let agents = table_len(&self.pool, "agent_registry").await?;
    let policy_versions = table_len(&self.pool, "policy_versions").await?;
    let row_counts = super::health::RowCounts {
        audit_events,
        agents,
        policy_versions,
    };

    let elapsed_ms = started.elapsed().as_millis();
    let latency_ms = u32::try_from(elapsed_ms).unwrap_or(u32::MAX);

    Ok(StorageHealth {
        status: super::health::HealthStatus::Ok,
        backend: "sqlite",
        latency_ms,
        row_counts,
        timescale: None,
    })
}
}
/// Replaces a leading `~` path component with the current user's home directory.
///
/// The path is returned unchanged when it does not start with a `~` component or when
/// no home directory can be determined.
fn expand_tilde(path: &Path) -> PathBuf {
    path.strip_prefix("~")
        .ok()
        .and_then(|rest| dirs::home_dir().map(|home| home.join(rest)))
        .unwrap_or_else(|| path.to_path_buf())
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
/// Absolute and relative paths without a leading `~` pass through untouched.
#[test]
fn expand_tilde_leaves_non_tilde_path_unchanged() {
    for raw in ["/tmp/sqlite-skeleton/db.sqlite", "data/db.sqlite"] {
        assert_eq!(expand_tilde(Path::new(raw)), PathBuf::from(raw));
    }
}
/// Opens a backend on a fresh database inside a new temporary directory.
///
/// The `TempDir` is returned alongside the backend so the caller keeps it alive for
/// the duration of the test.
async fn open_temp_backend() -> (TempDir, SqliteBackend) {
    let tmp = TempDir::new().expect("tempdir");
    let config = SqliteConfig {
        path: tmp.path().join("test.db"),
    };
    let backend = SqliteBackend::open(&config).await.expect("open should succeed");
    (tmp, backend)
}
/// `open` creates missing parent directories and leaves the database in WAL mode.
#[tokio::test]
async fn open_creates_parent_dir_and_enables_wal() {
    let tmp = TempDir::new().expect("tempdir");
    let db_file = tmp.path().join("nested").join("dir").join("test.db");
    let config = SqliteConfig {
        path: db_file.clone(),
    };
    let backend = SqliteBackend::open(&config).await.expect("open should succeed");

    assert!(db_file.exists(), "database file should be created");
    let parent = db_file.parent().expect("parent");
    assert!(parent.exists(), "parent dir should be created");

    let journal: (String,) = sqlx::query_as("PRAGMA journal_mode")
        .fetch_one(backend.pool())
        .await
        .expect("journal_mode probe");
    assert_eq!(journal.0.to_lowercase(), "wal", "WAL pragma should stick");
}
/// After `migrate`, every table and index named in the schema is present.
#[tokio::test]
async fn migrate_creates_all_expected_tables_and_indexes() {
    use std::collections::BTreeSet;

    let (_tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate should succeed");

    let rows: Vec<(String, String)> = sqlx::query_as(
        "SELECT type, name FROM sqlite_master \
         WHERE type IN ('table', 'index') AND name NOT LIKE 'sqlite_%'",
    )
    .fetch_all(backend.pool())
    .await
    .expect("sqlite_master probe");
    let actual: BTreeSet<(String, String)> = rows.into_iter().collect();

    let mut expected: BTreeSet<(String, String)> = BTreeSet::new();
    for table in ["audit_events", "agent_registry", "policy_versions", "metrics"] {
        expected.insert(("table".to_owned(), table.to_owned()));
    }
    for index in ["idx_audit_agent", "idx_audit_ts"] {
        expected.insert(("index".to_owned(), index.to_owned()));
    }

    for entry in &expected {
        assert!(actual.contains(entry), "missing schema entry: {entry:?}");
    }
}
/// Running `migrate` several times succeeds and leaves exactly one copy of each object.
#[tokio::test]
async fn migrate_is_idempotent_across_repeated_calls() {
    let (_tmp, backend) = open_temp_backend().await;

    for label in [
        "first migrate",
        "second migrate should be a no-op",
        "third migrate should still be a no-op",
    ] {
        backend.migrate().await.expect(label);
    }

    let object_count: (i64,) = sqlx::query_as(
        "SELECT COUNT(*) FROM sqlite_master \
         WHERE type IN ('table', 'index') AND name NOT LIKE 'sqlite_%'",
    )
    .fetch_one(backend.pool())
    .await
    .expect("count probe");
    assert_eq!(object_count.0, 6, "exactly 4 tables + 2 indexes expected");
}
/// Three audit events, one hour apart, spanning two agents and two teams.
///
/// The middle event is the only dry-run one and the only one without a payload.
fn sample_events() -> Vec<AuditEvent> {
    let at = |stamp: &str| {
        chrono::DateTime::parse_from_rfc3339(stamp)
            .unwrap()
            .with_timezone(&chrono::Utc)
    };
    let agent_a = AgentId::from_bytes([1; 16]);
    let agent_b = AgentId::from_bytes([2; 16]);

    let first = AuditEvent {
        ts: at("2026-05-21T10:00:00Z"),
        event_id: uuid::Uuid::from_u128(1),
        agent_id: agent_a,
        team_id: Some("team-x".into()),
        action: "tool_call".into(),
        decision: "allow".into(),
        dry_run: false,
        shadow_decision: None,
        matched_rule_id: Some("rule-1".into()),
        payload: Some(serde_json::json!({"tool": "fetch"})),
    };
    let second = AuditEvent {
        ts: at("2026-05-21T11:00:00Z"),
        event_id: uuid::Uuid::from_u128(2),
        agent_id: agent_a,
        team_id: Some("team-x".into()),
        action: "policy_decision".into(),
        decision: "deny".into(),
        dry_run: true,
        shadow_decision: Some("allow".into()),
        matched_rule_id: None,
        payload: None,
    };
    let third = AuditEvent {
        ts: at("2026-05-21T12:00:00Z"),
        event_id: uuid::Uuid::from_u128(3),
        agent_id: agent_b,
        team_id: Some("team-y".into()),
        action: "tool_call".into(),
        decision: "allow".into(),
        dry_run: false,
        shadow_decision: None,
        matched_rule_id: Some("rule-2".into()),
        payload: Some(serde_json::json!({"k": [1, 2, 3]})),
    };

    vec![first, second, third]
}
/// Opens a migrated temporary backend and appends every event from `sample_events`.
///
/// Returns the temp dir guard, the backend, and the events that were inserted.
async fn migrated_backend_with_samples() -> (TempDir, SqliteBackend, Vec<AuditEvent>) {
    let (tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate");

    let seeded = sample_events();
    for event in seeded.iter() {
        backend.append_audit_event(event).await.expect("append");
    }
    (tmp, backend, seeded)
}
#[tokio::test]
async fn audit_round_trip_preserves_all_columns_including_payload() {
let (_tmp, backend, events) = migrated_backend_with_samples().await;
let mut out = backend.query_audit_events(AuditFilter::default()).await.expect("query");
out.reverse();
assert_eq!(out, events, "all columns + payload must round-trip");
}
/// Each filter field (agent, team, dry-run flag, time window) narrows results on its own.
#[tokio::test]
async fn audit_filter_dimensions_independently_narrow_results() {
    let (_tmp, backend, _events) = migrated_backend_with_samples().await;
    let agent_a = AgentId::from_bytes([1; 16]);
    let agent_b = AgentId::from_bytes([2; 16]);
    let utc = |stamp: &str| {
        chrono::DateTime::parse_from_rfc3339(stamp)
            .unwrap()
            .with_timezone(&chrono::Utc)
    };

    // Agent filter.
    let agent_filter = AuditFilter {
        agent_id: Some(agent_a),
        ..AuditFilter::default()
    };
    let by_a = backend.query_audit_events(agent_filter).await.expect("agent filter");
    assert_eq!(by_a.len(), 2);
    assert!(by_a.iter().all(|e| e.agent_id == agent_a));

    // Team filter.
    let team_filter = AuditFilter {
        team_id: Some("team-y".into()),
        ..AuditFilter::default()
    };
    let by_team_y = backend.query_audit_events(team_filter).await.expect("team filter");
    assert_eq!(by_team_y.len(), 1);
    assert_eq!(by_team_y[0].agent_id, agent_b);

    // Dry-run filter.
    let dry_filter = AuditFilter {
        dry_run_only: true,
        ..AuditFilter::default()
    };
    let only_dry = backend.query_audit_events(dry_filter).await.expect("dry_run filter");
    assert_eq!(only_dry.len(), 1);
    assert!(only_dry[0].dry_run);

    // Time-window filter: only the 11:00 event lies inside [10:30, 11:30).
    let window_filter = AuditFilter {
        from: Some(utc("2026-05-21T10:30:00Z")),
        to: Some(utc("2026-05-21T11:30:00Z")),
        ..AuditFilter::default()
    };
    let in_window = backend
        .query_audit_events(window_filter)
        .await
        .expect("time-range filter");
    assert_eq!(in_window.len(), 1);
    assert_eq!(in_window[0].event_id, uuid::Uuid::from_u128(2));
}
/// Two consecutive pages of size two cover the three sample events without overlap.
#[tokio::test]
async fn audit_query_limit_and_offset_produce_disjoint_pages() {
    let (_tmp, backend, _events) = migrated_backend_with_samples().await;
    let page_filter = |offset| AuditFilter {
        limit: Some(2),
        offset: Some(offset),
        ..AuditFilter::default()
    };

    let first = backend.query_audit_events(page_filter(0)).await.expect("page 1");
    let second = backend.query_audit_events(page_filter(2)).await.expect("page 2");

    assert_eq!(first.len(), 2);
    assert_eq!(second.len(), 1);

    // No event id may appear on both pages.
    let overlap = first
        .iter()
        .any(|a| second.iter().any(|b| a.event_id == b.event_id));
    assert!(!overlap);
}
/// `count_audit_events` agrees with the number of rows `query_audit_events` returns.
#[tokio::test]
async fn audit_count_matches_query_result_size() {
    let (_tmp, backend, _events) = migrated_backend_with_samples().await;

    let everything = backend
        .count_audit_events(AuditFilter::default())
        .await
        .expect("count all");
    assert_eq!(everything, 3);

    let agent_a = AgentId::from_bytes([1; 16]);
    let scoped_filter = || AuditFilter {
        agent_id: Some(agent_a),
        ..AuditFilter::default()
    };

    let scoped_count = backend
        .count_audit_events(scoped_filter())
        .await
        .expect("count scoped");
    let scoped_rows = backend
        .query_audit_events(scoped_filter())
        .await
        .expect("query scoped");
    assert_eq!(scoped_count, scoped_rows.len() as u64);
}
/// Builds an agent record whose id is `seed` repeated sixteen times.
///
/// `name` is stored under the `"name"` metadata key; both timestamps are fixed at
/// 2026-05-21T09:00:00Z and the enforcement mode is `"enforce"`.
fn sample_agent(seed: u8, team: &str, org: &str, name: &str) -> AgentRecord {
    let fixed_ts = chrono::DateTime::parse_from_rfc3339("2026-05-21T09:00:00Z")
        .unwrap()
        .with_timezone(&chrono::Utc);
    let metadata = std::collections::BTreeMap::from([("name".to_owned(), name.to_owned())]);

    AgentRecord {
        agent_id: AgentId::from_bytes([seed; 16]),
        team_id: Some(team.to_owned()),
        org_id: Some(org.to_owned()),
        metadata,
        registered_at: fixed_ts,
        last_seen_at: fixed_ts,
        enforcement_mode: "enforce".into(),
    }
}
#[tokio::test]
async fn agent_upsert_is_idempotent_and_updates_existing_row() {
let (_tmp, backend) = open_temp_backend().await;
backend.migrate().await.expect("migrate");
let mut rec = sample_agent(7, "team-x", "org-1", "Alpha");
backend.upsert_agent(rec.clone()).await.expect("first upsert");
let later = chrono::DateTime::parse_from_rfc3339("2026-05-21T12:00:00Z")
.unwrap()
.with_timezone(&chrono::Utc);
rec.last_seen_at = later;
backend.upsert_agent(rec.clone()).await.expect("second upsert");
let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM agent_registry")
.fetch_one(backend.pool())
.await
.expect("count");
assert_eq!(count, 1, "upsert should not duplicate the row");
let fetched = backend.get_agent(&rec.agent_id).await.expect("get").expect("present");
assert_eq!(fetched.last_seen_at, later);
}
/// Looking up an id that was never registered yields `None` rather than an error.
#[tokio::test]
async fn agent_get_returns_none_for_unknown_id() {
    let (_tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate");

    let unknown = AgentId::from_bytes([0xff; 16]);
    let lookup = backend.get_agent(&unknown).await.expect("get");
    assert!(lookup.is_none());
}
/// `list_agents` narrows by team, by org, and by a substring of the metadata name.
#[tokio::test]
async fn agent_list_filters_by_team_org_and_name_substring() {
    let (_tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate");

    let seeds = [
        (1, "team-x", "org-1", "Alpha"),
        (2, "team-x", "org-2", "Bravo"),
        (3, "team-y", "org-1", "Alpha-prime"),
    ];
    for (seed, team, org, name) in seeds {
        backend
            .upsert_agent(sample_agent(seed, team, org, name))
            .await
            .expect("seed");
    }

    let team_filter = AgentFilter {
        team_id: Some("team-x".into()),
        ..AgentFilter::default()
    };
    let team_x = backend.list_agents(team_filter).await.expect("by team");
    assert_eq!(team_x.len(), 2);

    let org_filter = AgentFilter {
        org_id: Some("org-1".into()),
        ..AgentFilter::default()
    };
    let org_1 = backend.list_agents(org_filter).await.expect("by org");
    assert_eq!(org_1.len(), 2);

    let name_filter = AgentFilter {
        name_contains: Some("Alpha".into()),
        ..AgentFilter::default()
    };
    let named_alpha = backend.list_agents(name_filter).await.expect("by name substring");
    assert_eq!(named_alpha.len(), 2);
    for record in &named_alpha {
        assert!(record.metadata.get("name").unwrap().contains("Alpha"));
    }
}
/// Deleting removes the row; deleting the same id again reports `NotFound`.
#[tokio::test]
async fn agent_delete_removes_row_and_second_delete_returns_not_found() {
    let (_tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate");

    let seeded = sample_agent(9, "team-x", "org-1", "Bravo");
    let agent_id = seeded.agent_id;
    backend.upsert_agent(seeded.clone()).await.expect("seed");

    backend.delete_agent(&agent_id).await.expect("first delete");
    let after_delete = backend.get_agent(&agent_id).await.expect("get");
    assert!(after_delete.is_none(), "row should be removed");

    let second_attempt = backend.delete_agent(&agent_id).await;
    let err = second_attempt.expect_err("second delete must report NotFound");
    assert!(matches!(err, StorageError::NotFound(_)));
}
/// Builds a policy document named `name` whose bytes are the UTF-8 encoding of `body`.
fn policy_doc(name: &str, body: &str) -> PolicyDocument {
    let bytes = body.as_bytes().to_vec();
    let name = name.to_owned();
    PolicyDocument { name, bytes }
}
/// Successive saves get versions 1, 2, ...; none is active; listing is newest first.
#[tokio::test]
async fn policy_save_assigns_monotonic_versions_and_lists_desc() {
    let (_tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate");

    let first = backend
        .save_policy(policy_doc("guard", "rules: v1"))
        .await
        .expect("save v1");
    let second = backend
        .save_policy(policy_doc("guard", "rules: v2"))
        .await
        .expect("save v2");

    assert_eq!(first.meta.version, 1);
    assert_eq!(second.meta.version, 2);
    assert!(!first.meta.is_active, "save must not auto-activate");
    assert!(!second.meta.is_active);

    let versions = backend.list_policy_versions("guard").await.expect("list versions");
    assert_eq!(versions.len(), 2);
    assert_eq!(versions[0].version, 2, "list must be DESC");
    assert_eq!(versions[1].version, 1);
}
/// A saved policy has no active version until `rollback_policy` activates one.
#[tokio::test]
async fn policy_get_active_is_none_until_rollback_activates() {
    let (_tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate");

    backend
        .save_policy(policy_doc("guard", "rules: v1"))
        .await
        .expect("save v1");

    let before = backend.get_active_policy("guard").await.expect("get_active");
    assert!(before.is_none(), "fresh save must not activate the row");

    backend.rollback_policy("guard", 1).await.expect("activate v1");

    let after = backend.get_active_policy("guard").await.expect("get_active");
    let active = after.expect("present");
    assert_eq!(active.name, "guard");
    assert_eq!(active.bytes, b"rules: v1");
}
/// Activating v2 and then v1 of the same policy must leave exactly one
/// active version, and it must be v1.
#[tokio::test]
async fn policy_rollback_enforces_single_active_per_name() {
    let (_tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate");

    // Two saves under one name produce versions 1 and 2.
    for body in ["rules: v1", "rules: v2"] {
        let doc = policy_doc("guard", body);
        backend.save_policy(doc).await.expect("save");
    }

    backend.rollback_policy("guard", 2).await.expect("activate v2");
    backend.rollback_policy("guard", 1).await.expect("activate v1");

    let versions = backend.list_policy_versions("guard").await.expect("list");
    let active: Vec<_> = versions.into_iter().filter(|meta| meta.is_active).collect();
    assert_eq!(active.len(), 1);
    assert_eq!(active[0].version, 1);
}
/// Rolling back to a version that was never saved must fail with
/// `StorageError::NotFound`.
#[tokio::test]
async fn policy_rollback_missing_returns_not_found() {
    let (_tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate");

    // No policy has been saved in this test, so version 99 cannot exist.
    let outcome = backend.rollback_policy("guard", 99).await;
    let err = outcome.expect_err("rollback of missing version must fail");
    assert!(matches!(err, StorageError::NotFound(_)));
}
/// Test fixture: builds a `Metric` called `name` with the given `value`,
/// timestamped at `ts_iso` (RFC 3339, converted to UTC) and attributed to the
/// agent id made of sixteen `seed` bytes. Every sample carries the single
/// label `provider=openai`.
///
/// Panics if `ts_iso` is not a valid RFC 3339 timestamp.
fn sample_metric(seed: u8, name: &str, value: f64, ts_iso: &str) -> Metric {
    let labels =
        std::collections::BTreeMap::from([("provider".to_owned(), "openai".to_owned())]);
    let ts = chrono::DateTime::parse_from_rfc3339(ts_iso)
        .unwrap()
        .with_timezone(&chrono::Utc);
    let agent_id = AgentId::from_bytes([seed; 16]);
    Metric {
        ts,
        agent_id,
        metric: name.to_owned(),
        value,
        labels,
    }
}
/// Three recorded samples must all come back from an unfiltered
/// (`MetricQuery::default()`) query.
#[tokio::test]
async fn metric_record_and_query_round_trip_without_filter() {
    let (_tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate");

    // (agent seed, metric name, value, timestamp)
    let samples = [
        (1, "tokens_used", 100.0, "2026-05-21T10:00:00Z"),
        (1, "tokens_used", 200.0, "2026-05-21T11:00:00Z"),
        (2, "events_per_sec", 1.5, "2026-05-21T12:00:00Z"),
    ];
    for (seed, name, value, ts) in samples {
        let metric = sample_metric(seed, name, value, ts);
        backend.record_metric(metric).await.expect("record");
    }

    let everything = backend.query_metrics(MetricQuery::default()).await.expect("query");
    assert_eq!(everything.len(), 3);
}
/// Recorded metrics can be narrowed by agent + metric name and, separately,
/// by a `from`/`to` time window.
///
/// Seeds three samples: two `tokens_used` points for agent seed 1 (10:00 and
/// 11:00) and one `events_per_sec` point for agent seed 2 (12:00).
#[tokio::test]
async fn metric_filter_by_agent_metric_and_time_range_narrows_results() {
    // Parses an RFC 3339 literal into a UTC timestamp for the window bounds.
    let utc = |iso: &str| {
        chrono::DateTime::parse_from_rfc3339(iso)
            .unwrap()
            .with_timezone(&chrono::Utc)
    };

    let (_tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate");
    for m in [
        sample_metric(1, "tokens_used", 100.0, "2026-05-21T10:00:00Z"),
        sample_metric(1, "tokens_used", 200.0, "2026-05-21T11:00:00Z"),
        sample_metric(2, "events_per_sec", 1.5, "2026-05-21T12:00:00Z"),
    ] {
        backend.record_metric(m).await.expect("record");
    }

    // Agent + metric-name filter: only the two `tokens_used` samples of agent 1.
    let agent_a = AgentId::from_bytes([1; 16]);
    let scoped = backend
        .query_metrics(MetricQuery {
            agent_id: Some(agent_a),
            metric: Some("tokens_used".into()),
            ..MetricQuery::default()
        })
        .await
        .expect("scoped");
    assert_eq!(scoped.len(), 2);
    // Compare the exact set of values, independent of result order, so a
    // result that contains the same sample twice cannot slip through.
    let mut scoped_values: Vec<_> = scoped.iter().map(|p| p.value).collect();
    scoped_values.sort_by(|a, b| a.total_cmp(b));
    assert_eq!(scoped_values, [100.0, 200.0]);

    // Time window 10:30..11:30 contains only the 11:00 sample.
    let windowed = backend
        .query_metrics(MetricQuery {
            from: Some(utc("2026-05-21T10:30:00Z")),
            to: Some(utc("2026-05-21T11:30:00Z")),
            ..MetricQuery::default()
        })
        .await
        .expect("window");
    assert_eq!(windowed.len(), 1);
    assert_eq!(windowed[0].value, 200.0);
}
/// Querying with a `bucket` set must still succeed and hand back the single
/// recorded sample with its original value.
///
/// NOTE(review): the warning mentioned in the test name is not asserted here;
/// only the returned points are checked.
#[tokio::test]
async fn metric_query_bucket_emits_warning_and_returns_raw_samples() {
    let (_tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate");

    let sample = sample_metric(1, "tokens_used", 42.0, "2026-05-21T10:00:00Z");
    backend.record_metric(sample).await.expect("record");

    let bucketed = MetricQuery {
        bucket: Some("1 hour".into()),
        ..MetricQuery::default()
    };
    let points = backend.query_metrics(bucketed).await.expect("query with bucket");
    assert_eq!(points.len(), 1);
    assert_eq!(points[0].value, 42.0);
}
/// Test fixture: builds an "allow" `tool_call` audit event timestamped
/// `days_ago` days before the current UTC time. `seed` determines both the
/// event id (as a `u128`) and the agent id (sixteen `seed` bytes); all
/// optional fields are left unset and `dry_run` is false.
fn seed_dated_event(seed: u8, days_ago: i64) -> AuditEvent {
    let now = chrono::Utc::now();
    let ts = now - chrono::Duration::days(days_ago);
    let event_id = uuid::Uuid::from_u128(u128::from(seed));
    let agent_id = AgentId::from_bytes([seed; 16]);
    AuditEvent {
        ts,
        event_id,
        agent_id,
        team_id: None,
        action: "tool_call".into(),
        decision: "allow".into(),
        dry_run: false,
        shadow_decision: None,
        matched_rule_id: None,
        payload: None,
    }
}
/// With events aged 0, 100 and 365 days and a 30/60-day hot/warm policy using
/// `ColdAction::Drop`, retention must drop the two old rows and keep the
/// fresh one.
#[tokio::test]
async fn retention_deletes_rows_older_than_cold_threshold() {
    let (_tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate");

    // (agent/event seed, age in days)
    let seeded = [(1, 0), (2, 100), (3, 365)];
    for (seed, days) in seeded {
        let event = seed_dated_event(seed, days);
        backend.append_audit_event(&event).await.expect("seed");
    }

    let policy = RetentionPolicy {
        hot_days: 30,
        warm_days: 60,
        cold_action: crate::storage::ColdAction::Drop,
        archive_url: None,
        dry_run: false,
    };
    let stats = backend.apply_retention(&policy).await.expect("retention");
    assert_eq!(stats.dropped_rows, 2);
    assert_eq!(stats.compressed_rows, 0, "SQLite has no compression");

    let remaining = backend.count_audit_events(AuditFilter::default()).await.expect("count");
    assert_eq!(remaining, 1, "only the fresh row should survive");
}
/// With `dry_run: true`, retention must report the one row it would drop
/// while leaving both seeded rows in place.
#[tokio::test]
async fn retention_dry_run_reports_drop_count_without_deleting() {
    let (_tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate");

    // One fresh event and one 200-day-old event.
    let seeded = [(1, 0), (2, 200)];
    for (seed, days) in seeded {
        let event = seed_dated_event(seed, days);
        backend.append_audit_event(&event).await.expect("seed");
    }

    let policy = RetentionPolicy {
        hot_days: 30,
        warm_days: 60,
        cold_action: crate::storage::ColdAction::Drop,
        archive_url: None,
        dry_run: true,
    };
    let stats = backend.apply_retention(&policy).await.expect("retention dry_run");
    assert_eq!(stats.dropped_rows, 1);

    let remaining = backend.count_audit_events(AuditFilter::default()).await.expect("count");
    assert_eq!(remaining, 2, "dry_run must not delete any rows");
}
/// With `ColdAction::Archive`, a 200-day-old row is expected to be counted as
/// dropped (not archived) and to be gone from the table afterwards.
///
/// NOTE(review): the warning mentioned in the test name is not asserted here;
/// only the stats and the remaining row count are checked.
#[tokio::test]
async fn retention_archive_falls_back_to_drop_with_warn() {
    let (_tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate");

    let stale = seed_dated_event(7, 200);
    backend.append_audit_event(&stale).await.expect("seed");

    let policy = RetentionPolicy {
        hot_days: 30,
        warm_days: 60,
        cold_action: crate::storage::ColdAction::Archive,
        archive_url: Some("s3://bucket/aasm".into()),
        dry_run: false,
    };
    let stats = backend.apply_retention(&policy).await.expect("retention");
    assert_eq!(stats.dropped_rows, 1);
    assert_eq!(stats.archived_rows, 0);

    let remaining = backend.count_audit_events(AuditFilter::default()).await.expect("count");
    assert_eq!(remaining, 0);
}
/// After seeding one audit event, one agent and one policy version, the
/// healthcheck must report `Ok`, identify the backend as "sqlite", and count
/// exactly one row in each of those three tables.
#[tokio::test]
async fn healthcheck_reports_ok_and_correct_row_counts() {
    let (_tmp, backend) = open_temp_backend().await;
    backend.migrate().await.expect("migrate");

    // Seed one row per counted table.
    let event = seed_dated_event(1, 0);
    backend.append_audit_event(&event).await.expect("append");
    let agent = sample_agent(2, "team-x", "org-1", "Alpha");
    backend.upsert_agent(agent).await.expect("upsert");
    let doc = policy_doc("guard", "rules: v1");
    backend.save_policy(doc).await.expect("save");

    let health = backend.healthcheck().await.expect("healthcheck");
    assert!(matches!(health.status, crate::storage::HealthStatus::Ok));
    assert_eq!(health.backend, "sqlite");

    let counts = &health.row_counts;
    assert_eq!(counts.audit_events, 1);
    assert_eq!(counts.agents, 1);
    assert_eq!(counts.policy_versions, 1);
}
}