use anyhow::{Context, Result};
use rusqlite::Connection;
use std::path::Path;
use super::migrations::{SCHEMA, migrate};
/// SQL batch that installs the R1-M2 CHECK-constraint triggers.
///
/// Embedded at compile time (`include_str!`) from the SQLite migrations directory
/// and executed by `apply_check_constraint_triggers` inside a write transaction.
const CHECK_CONSTRAINT_TRIGGERS_SQLITE: &str =
    include_str!("../../migrations/sqlite/0023_v07_check_constraints.sql");
/// Opens a write transaction, acquiring SQLite's write lock immediately rather
/// than on the first write statement.
pub const SQL_BEGIN_IMMEDIATE: &str = "BEGIN IMMEDIATE";
/// Commits the current transaction.
pub const SQL_COMMIT: &str = "COMMIT";
/// Rolls back the current transaction.
pub const SQL_ROLLBACK: &str = "ROLLBACK";
/// Default SQLite `mmap_size`, in bytes (256 MiB).
///
/// Used by `open` whenever no override has been registered through
/// [`set_db_mmap_size`].
pub const DEFAULT_DB_MMAP_SIZE_BYTES: i64 = 256 * 1024 * 1024;

/// Process-wide `mmap_size` override. It is written at most once.
static DB_MMAP_SIZE_BYTES: std::sync::OnceLock<i64> = std::sync::OnceLock::new();

/// Registers the `mmap_size` (in bytes) that later `open` calls will apply.
///
/// Only the first call in the process takes effect. Every subsequent call is
/// ignored without error.
pub fn set_db_mmap_size(bytes: i64) {
    // `get_or_init` stores `bytes` only while the cell is still empty.
    // That gives the same first-writer-wins outcome as `set` with its error
    // discarded.
    DB_MMAP_SIZE_BYTES.get_or_init(|| bytes);
}

/// Returns the registered `mmap_size`, or [`DEFAULT_DB_MMAP_SIZE_BYTES`] when
/// none has been set.
fn db_mmap_size() -> i64 {
    DB_MMAP_SIZE_BYTES
        .get()
        .copied()
        .unwrap_or(DEFAULT_DB_MMAP_SIZE_BYTES)
}
/// Opens (creating if necessary) the database file at `path` and makes it
/// ready for use.
///
/// Work performed, in order:
/// 1. `apply_sqlcipher_key`. It runs before any other statement; in
///    non-`sqlcipher` builds it does nothing.
/// 2. Connection pragmas: `journal_mode=WAL`, `busy_timeout=5000`,
///    `synchronous=NORMAL`, the configured `mmap_size`, and `foreign_keys=ON`.
/// 3. The base `SCHEMA` batch, then the `migrate` ladder.
/// 4. The R1-M2 CHECK-constraint triggers.
///
/// # Errors
/// Fails if the file cannot be opened, the key step fails, any pragma is
/// rejected, or schema, migration or trigger setup fails.
pub fn open(path: &Path) -> Result<Connection> {
    let db = Connection::open(path).context("failed to open database")?;

    // The key step is placed ahead of every pragma and schema statement.
    apply_sqlcipher_key(&db)?;

    db.pragma_update(None, "journal_mode", "WAL")?;
    db.pragma_update(None, "busy_timeout", 5000)?;
    db.pragma_update(None, "synchronous", "NORMAL")?;
    db.pragma_update(None, "mmap_size", db_mmap_size())?;
    db.pragma_update(None, "foreign_keys", "ON")?;

    db.execute_batch(SCHEMA).context("failed to initialize schema")?;
    migrate(&db)?;
    apply_check_constraint_triggers(&db)
        .context("failed to apply R1-M2 CHECK-constraint triggers")?;

    Ok(db)
}
fn apply_check_constraint_triggers(conn: &Connection) -> Result<()> {
let already_installed: bool = conn
.query_row(
"SELECT EXISTS(SELECT 1 FROM sqlite_master \
WHERE type = 'trigger' AND name = 'memories_ck_tier_ins')",
[],
|r| r.get::<_, i64>(0).map(|n| n != 0),
)
.unwrap_or(false);
if already_installed {
return Ok(());
}
let count_violations =
|sql: &str| -> i64 { conn.query_row(sql, [], |r| r.get::<_, i64>(0)).unwrap_or(0) };
let bad_tier = count_violations(
"SELECT COUNT(*) FROM memories WHERE tier NOT IN ('short', 'mid', 'long')",
);
let bad_priority =
count_violations("SELECT COUNT(*) FROM memories WHERE priority < 1 OR priority > 10");
let bad_confidence = count_violations(
"SELECT COUNT(*) FROM memories WHERE confidence < 0.0 OR confidence > 1.0",
);
let bad_relation = count_violations(
"SELECT COUNT(*) FROM memory_links \
WHERE relation NOT IN ('related_to', 'supersedes', 'contradicts', 'derived_from', 'reflects_on', 'derives_from')",
);
let bad_attest = count_violations(
"SELECT COUNT(*) FROM memory_links \
WHERE attest_level IS NOT NULL \
AND attest_level NOT IN ('unsigned', 'self_signed', 'peer_attested')",
);
let total_bad = bad_tier + bad_priority + bad_confidence + bad_relation + bad_attest;
if total_bad > 0 {
tracing::warn!(
target: "ai_memory::storage::checks",
"R1-M2 CHECK trigger install: \
pre-existing constraint violations detected — \
memories.tier={bad_tier}, memories.priority={bad_priority}, \
memories.confidence={bad_confidence}, \
memory_links.relation={bad_relation}, \
memory_links.attest_level={bad_attest}. \
Triggers will still install; future writes that touch these \
rows will fail loudly until the values are repaired."
);
}
conn.execute_batch("BEGIN IMMEDIATE")?;
let result = (|| -> Result<()> {
conn.execute_batch(CHECK_CONSTRAINT_TRIGGERS_SQLITE)
.context("apply CHECK-constraint triggers")?;
Ok(())
})();
match result {
Ok(()) => {
conn.execute_batch("COMMIT")?;
Ok(())
}
Err(e) => {
let _ = conn.execute_batch("ROLLBACK");
Err(e)
}
}
}
/// Unlocks a SQLCipher-encrypted database (`sqlcipher` feature builds).
///
/// The passphrase comes from the `AI_MEMORY_DB_PASSPHRASE` environment variable.
/// After issuing `PRAGMA key`, the schema table is read once. A wrong passphrase
/// (or an unencrypted file) is therefore reported while `open` is still running.
///
/// # Errors
/// - `StorageError::SqlcipherMissingPassphrase` whenever `std::env::var` fails.
///   That covers both an unset variable and a value that is not valid Unicode.
/// - A contextualised error if `PRAGMA key` or the verification query fails.
#[cfg(feature = "sqlcipher")]
fn apply_sqlcipher_key(conn: &Connection) -> Result<()> {
    let Ok(passphrase) = std::env::var("AI_MEMORY_DB_PASSPHRASE") else {
        return Err(anyhow::Error::new(
            super::error::StorageError::SqlcipherMissingPassphrase,
        ));
    };
    // Double any embedded single quotes, then wrap the result in quotes below.
    // NOTE(review): rusqlite's `pragma_update` appears to quote and escape text
    // values itself. If so, the effective key is the literal string
    // `'<passphrase with ' doubled>'` (outer quotes included), not the raw
    // passphrase. A database keyed elsewhere with the plain passphrase would then
    // not open here, and vice versa. Databases already keyed by this function
    // depend on the current form, so confirm and plan a re-key before changing it.
    let escaped = passphrase.replace('\'', "''");
    conn.pragma_update(None, "key", format!("'{escaped}'"))
        .context("PRAGMA key failed (wrong passphrase or unencrypted DB?)")?;
    // Touch the schema so a key that cannot decrypt the file surfaces now.
    conn.query_row("SELECT count(*) FROM sqlite_master", [], |r| {
        r.get::<_, i64>(0)
    })
    .context("SQLCipher unlock verification failed — wrong passphrase?")?;
    Ok(())
}

/// Non-`sqlcipher` builds: no key is applied and this always succeeds.
#[cfg(not(feature = "sqlcipher"))]
// `Result` is kept only so the signature matches the `sqlcipher` variant and
// `open` can call either build with `?`.
#[allow(clippy::unnecessary_wraps)]
fn apply_sqlcipher_key(_conn: &Connection) -> Result<()> {
    Ok(())
}
#[cfg(test)]
mod tests {
    // Each test opens its own temporary database file via `open`.
    use super::*;

    /// A fresh file goes through `open` and ends with a positive version in
    /// `schema_version`.
    #[test]
    fn open_round_trip_creates_db_and_runs_migrations() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        let conn = open(tmp.path()).expect("open initial");
        let v: i64 = conn
            .query_row(
                "SELECT COALESCE(MAX(version), 0) FROM schema_version",
                [],
                |r| r.get(0),
            )
            .expect("schema_version readable");
        assert!(v > 0, "expected positive schema version, got {v}");
    }

    /// Opening the same file twice succeeds and leaves exactly one sentinel
    /// trigger.
    #[test]
    fn open_twice_is_idempotent_for_check_triggers() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        // The first connection stays alive while the second open runs.
        let _conn1 = open(tmp.path()).expect("first open");
        let conn2 = open(tmp.path()).expect("re-open idempotent");
        let n: i64 = conn2
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master \
                 WHERE type = 'trigger' AND name = 'memories_ck_tier_ins'",
                [],
                |r| r.get(0),
            )
            .expect("trigger query");
        assert_eq!(n, 1, "sentinel trigger must be installed exactly once");
    }

    /// `open` switches the file to WAL journaling.
    #[test]
    fn open_applies_wal_journal_mode() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        let conn = open(tmp.path()).expect("open");
        let mode: String = conn
            .query_row("PRAGMA journal_mode", [], |r| r.get(0))
            .expect("journal_mode");
        assert_eq!(mode.to_lowercase(), "wal");
    }

    /// Without a `set_db_mmap_size` call, `open` applies the 256 MiB default.
    #[test]
    fn open_applies_default_mmap_size() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        let conn = open(tmp.path()).expect("open");
        let mmap: i64 = conn
            .query_row("PRAGMA mmap_size", [], |r| r.get(0))
            .expect("mmap_size");
        assert_eq!(
            mmap, DEFAULT_DB_MMAP_SIZE_BYTES,
            "open() must apply the P1-proven 256 MiB mmap_size default"
        );
    }

    /// `open` enables foreign-key enforcement on the connection.
    #[test]
    fn open_enables_foreign_keys() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        let conn = open(tmp.path()).expect("open");
        let fk: i32 = conn
            .query_row("PRAGMA foreign_keys", [], |r| r.get(0))
            .expect("foreign_keys");
        assert_eq!(fk, 1, "open() must enable foreign_keys");
    }

    /// Returns true when `sqlite_master` lists exactly one index called `name`.
    /// A failing query counts as "absent".
    fn index_present(conn: &Connection, name: &str) -> bool {
        let n: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type = 'index' AND name = ?1",
                rusqlite::params![name],
                |r| r.get(0),
            )
            .unwrap_or(0);
        n == 1
    }

    /// Returns true when `PRAGMA table_info(table)` reports a column named
    /// `column`; the column name is read from result column 1.
    ///
    /// Returns false if the pragma cannot be prepared and panics on row errors.
    /// `table` is interpolated directly into the SQL, so pass trusted
    /// test-only names.
    fn column_present(conn: &Connection, table: &str, column: &str) -> bool {
        let sql = format!("PRAGMA table_info({table})");
        let mut stmt = match conn.prepare(&sql) {
            Ok(s) => s,
            Err(_) => return false,
        };
        let mut rows = stmt.query([]).expect("PRAGMA query");
        while let Some(row) = rows.next().expect("PRAGMA next") {
            let name: String = row.get(1).expect("col name");
            if name == column {
                return true;
            }
        }
        false
    }

    /// Simulates a legacy database stamped at v34 and checks the upgrade path.
    ///
    /// Setup, starting from a freshly opened DB:
    /// - strip the newer `memories` columns and their indexes;
    /// - drop the newer tables;
    /// - reset `schema_version` to 34.
    ///
    /// Re-opening must then climb the migrate ladder and restore every column
    /// and index checked below.
    #[test]
    fn open_succeeds_on_legacy_pre_v36_memories_shape() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        {
            let conn = open(tmp.path()).expect("seed: fresh open");
            // Indexes go first: SQLite refuses to DROP COLUMN on an indexed column.
            for ix in [
                "idx_memories_atom_of",
                "idx_memories_atomised_into",
                "idx_personas_by_entity",
                "idx_memories_source_uri",
                "idx_memories_confidence_source",
                "idx_memories_mentioned_entity",
            ] {
                conn.execute(&format!("DROP INDEX IF EXISTS {ix}"), [])
                    .expect("drop index");
            }
            for col in [
                "mentioned_entity_id",
                "confidence_decayed_at",
                "confidence_signals",
                "confidence_source",
                "source_span",
                "source_uri",
                "citations",
                "persona_version",
                "entity_id",
                "atom_of",
                "atomised_into",
            ] {
                conn.execute(&format!("ALTER TABLE memories DROP COLUMN {col}"), [])
                    .unwrap_or_else(|e| panic!("DROP COLUMN {col}: {e}"));
            }
            conn.execute("DROP TABLE IF EXISTS confidence_shadow_observations", [])
                .expect("drop shadow table");
            conn.execute("DROP TABLE IF EXISTS signed_events_dlq", [])
                .expect("drop dlq");
            conn.execute("DELETE FROM schema_version", [])
                .expect("clear version");
            conn.execute("INSERT INTO schema_version (version) VALUES (34)", [])
                .expect("stamp v34");
        } // The seeding connection is dropped here, before the upgrade open.
        let conn = open(tmp.path()).expect("legacy-upgrade open must succeed");
        let v: i64 = conn
            .query_row(
                "SELECT COALESCE(MAX(version), 0) FROM schema_version",
                [],
                |r| r.get(0),
            )
            .expect("read schema_version");
        // 42 is a hard-coded floor; NOTE(review): keep in step with the
        // current schema version as the ladder grows.
        assert!(
            v >= 42,
            "migrate ladder must reach CURRENT_SCHEMA_VERSION; got {v}"
        );
        for col in [
            "atom_of",
            "atomised_into",
            "entity_id",
            "persona_version",
            "citations",
            "source_uri",
            "source_span",
            "confidence_source",
            "confidence_signals",
            "confidence_decayed_at",
            "mentioned_entity_id",
        ] {
            assert!(
                column_present(&conn, "memories", col),
                "memories.{col} must be ALTER-added by the migrate ladder"
            );
        }
        // `idx_personas_by_entity` is dropped above but not re-checked here.
        // The shadow index presumably disappeared with its table and must be
        // rebuilt.
        for ix in [
            "idx_memories_atom_of",
            "idx_memories_atomised_into",
            "idx_memories_source_uri",
            "idx_memories_confidence_source",
            "idx_memories_mentioned_entity",
            "idx_shadow_obs_namespace_source_observed",
        ] {
            assert!(
                index_present(&conn, ix),
                "index {ix} must exist after legacy upgrade"
            );
        }
    }

    /// Simulates a v40 database and checks the v41 upgrade.
    ///
    /// Setup: remove `confidence_shadow_observations.source` and its compound
    /// index, then stamp `schema_version` 40. Re-opening must restore both.
    #[test]
    fn open_succeeds_on_legacy_pre_v41_shadow_shape() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        {
            let conn = open(tmp.path()).expect("seed: fresh open");
            // The index is dropped first so the column below can be dropped.
            conn.execute(
                "DROP INDEX IF EXISTS idx_shadow_obs_namespace_source_observed",
                [],
            )
            .expect("drop compound shadow index");
            conn.execute(
                "ALTER TABLE confidence_shadow_observations DROP COLUMN source",
                [],
            )
            .expect("drop shadow.source");
            conn.execute("DELETE FROM schema_version", [])
                .expect("clear version");
            conn.execute("INSERT INTO schema_version (version) VALUES (40)", [])
                .expect("stamp v40");
        }
        let conn = open(tmp.path()).expect("v40 legacy-upgrade open must succeed");
        assert!(
            column_present(&conn, "confidence_shadow_observations", "source"),
            "v41 migrate arm must ALTER-add shadow.source"
        );
        assert!(
            index_present(&conn, "idx_shadow_obs_namespace_source_observed"),
            "v41 compound shadow index must be re-attached"
        );
    }

    /// Inserting a memory whose tier is outside the allowed set must fail.
    ///
    /// NOTE(review): only `is_err()` is checked, so any insert error satisfies
    /// this, not just the tier trigger. Consider asserting on the error text.
    #[test]
    fn check_trigger_rejects_bad_tier_insert() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        let conn = open(tmp.path()).expect("open");
        let now = chrono::Utc::now().to_rfc3339();
        let res = conn.execute(
            "INSERT INTO memories \
             (id, tier, namespace, title, content, tags, priority, confidence, \
             source, access_count, created_at, updated_at, metadata, reflection_depth) \
             VALUES (?1, 'NOT_A_TIER', 'test', 't', 'c', '[]', 5, 1.0, \
             'src', 0, ?2, ?2, '{}', 0)",
            rusqlite::params!["bad-tier-id", now],
        );
        assert!(
            res.is_err(),
            "INSERT with bad tier must be rejected by R1-M2 trigger"
        );
    }
}