1use anyhow::{Context, Result};
9use rusqlite::Connection;
10use std::path::Path;
11
12use super::migrations::{SCHEMA, migrate};
13
/// SQL script, embedded at compile time via `include_str!`, that installs the
/// R1-M2 CHECK-constraint triggers. `apply_check_constraint_triggers` executes
/// it when the sentinel trigger `memories_ck_tier_ins` is not yet present.
const CHECK_CONSTRAINT_TRIGGERS_SQLITE: &str =
    include_str!("../../migrations/sqlite/0023_v07_check_constraints.sql");
29
/// Transaction-control statement that starts a write transaction right away.
/// SQLite's `IMMEDIATE` mode takes the write lock at `BEGIN` instead of
/// deferring it until the first write statement.
pub const SQL_BEGIN_IMMEDIATE: &str = "BEGIN IMMEDIATE";
/// Transaction-control statement that commits the open transaction.
pub const SQL_COMMIT: &str = "COMMIT";
/// Transaction-control statement that rolls back the open transaction.
pub const SQL_ROLLBACK: &str = "ROLLBACK";
39
/// `PRAGMA mmap_size` value (256 MiB) that `open()` applies when no override
/// has been registered through [`set_db_mmap_size`].
pub const DEFAULT_DB_MMAP_SIZE_BYTES: i64 = 256 << 20;

/// Process-wide override for `PRAGMA mmap_size`. It can be written at most once.
static DB_MMAP_SIZE_BYTES: std::sync::OnceLock<i64> = std::sync::OnceLock::new();

/// Register the `PRAGMA mmap_size` value, in bytes, for connections opened
/// from now on.
///
/// Only the first call has any effect. Later calls are ignored without
/// error. The value is not validated.
pub fn set_db_mmap_size(bytes: i64) {
    // `OnceLock::set` returns the value back as `Err` when one is already
    // stored. First-caller-wins is intended, so that case is simply dropped.
    DB_MMAP_SIZE_BYTES.set(bytes).ok();
}

/// Effective mmap size: the registered override if there is one, otherwise
/// [`DEFAULT_DB_MMAP_SIZE_BYTES`].
fn db_mmap_size() -> i64 {
    DB_MMAP_SIZE_BYTES
        .get()
        .copied()
        .unwrap_or(DEFAULT_DB_MMAP_SIZE_BYTES)
}
81
82pub fn open(path: &Path) -> Result<Connection> {
83 let conn = Connection::open(path).context("failed to open database")?;
84 apply_sqlcipher_key(&conn)?;
85 conn.pragma_update(None, "journal_mode", "WAL")?;
86 conn.pragma_update(None, "busy_timeout", 5000)?;
87 conn.pragma_update(None, "synchronous", "NORMAL")?;
88 conn.pragma_update(None, "mmap_size", db_mmap_size())?;
91 conn.pragma_update(None, "foreign_keys", "ON")?;
92 conn.execute_batch(SCHEMA)
93 .context("failed to initialize schema")?;
94 migrate(&conn)?;
95 apply_check_constraint_triggers(&conn)
96 .context("failed to apply R1-M2 CHECK-constraint triggers")?;
97 Ok(conn)
98}
99
100fn apply_check_constraint_triggers(conn: &Connection) -> Result<()> {
118 let already_installed: bool = conn
122 .query_row(
123 "SELECT EXISTS(SELECT 1 FROM sqlite_master \
124 WHERE type = 'trigger' AND name = 'memories_ck_tier_ins')",
125 [],
126 |r| r.get::<_, i64>(0).map(|n| n != 0),
127 )
128 .unwrap_or(false);
129 if already_installed {
130 return Ok(());
131 }
132
133 let count_violations =
138 |sql: &str| -> i64 { conn.query_row(sql, [], |r| r.get::<_, i64>(0)).unwrap_or(0) };
139 let bad_tier = count_violations(
140 "SELECT COUNT(*) FROM memories WHERE tier NOT IN ('short', 'mid', 'long')",
141 );
142 let bad_priority =
143 count_violations("SELECT COUNT(*) FROM memories WHERE priority < 1 OR priority > 10");
144 let bad_confidence = count_violations(
145 "SELECT COUNT(*) FROM memories WHERE confidence < 0.0 OR confidence > 1.0",
146 );
147 let bad_relation = count_violations(
148 "SELECT COUNT(*) FROM memory_links \
149 WHERE relation NOT IN ('related_to', 'supersedes', 'contradicts', 'derived_from', 'reflects_on', 'derives_from')",
150 );
151 let bad_attest = count_violations(
152 "SELECT COUNT(*) FROM memory_links \
153 WHERE attest_level IS NOT NULL \
154 AND attest_level NOT IN ('unsigned', 'self_signed', 'peer_attested')",
155 );
156 let total_bad = bad_tier + bad_priority + bad_confidence + bad_relation + bad_attest;
157 if total_bad > 0 {
158 tracing::warn!(
159 target: "ai_memory::storage::checks",
160 "R1-M2 CHECK trigger install: \
161 pre-existing constraint violations detected — \
162 memories.tier={bad_tier}, memories.priority={bad_priority}, \
163 memories.confidence={bad_confidence}, \
164 memory_links.relation={bad_relation}, \
165 memory_links.attest_level={bad_attest}. \
166 Triggers will still install; future writes that touch these \
167 rows will fail loudly until the values are repaired."
168 );
169 }
170
171 conn.execute_batch("BEGIN IMMEDIATE")?;
172 let result = (|| -> Result<()> {
173 conn.execute_batch(CHECK_CONSTRAINT_TRIGGERS_SQLITE)
174 .context("apply CHECK-constraint triggers")?;
175 Ok(())
176 })();
177 match result {
178 Ok(()) => {
179 conn.execute_batch("COMMIT")?;
180 Ok(())
181 }
182 Err(e) => {
183 let _ = conn.execute_batch("ROLLBACK");
184 Err(e)
185 }
186 }
187}
188
/// Unlock a SQLCipher-encrypted database with the passphrase from the
/// `AI_MEMORY_DB_PASSPHRASE` environment variable.
///
/// # Errors
///
/// - `StorageError::SqlcipherMissingPassphrase` if the variable is unset, or
///   is set but not valid Unicode: any `Err` from `std::env::var` takes the
///   `else` branch.
/// - A context-wrapped error if `PRAGMA key` fails, or if the verification
///   read of `sqlite_master` fails. The latter presumably indicates a wrong
///   passphrase or an unencrypted file.
#[cfg(feature = "sqlcipher")]
fn apply_sqlcipher_key(conn: &Connection) -> Result<()> {
    let Ok(passphrase) = std::env::var("AI_MEMORY_DB_PASSPHRASE") else {
        return Err(anyhow::Error::new(
            super::error::StorageError::SqlcipherMissingPassphrase,
        ));
    };
    // Double any single quotes and wrap the result in quotes.
    // NOTE(review): `pragma_update` may already render a text value as a
    // quoted SQL literal. If so, these extra quotes and doubled apostrophes
    // become part of the key bytes. Existing encrypted databases depend on
    // whatever bytes this produces, so confirm before changing this.
    let escaped = passphrase.replace('\'', "''");
    conn.pragma_update(None, "key", format!("'{escaped}'"))
        .context("PRAGMA key failed (wrong passphrase or unencrypted DB?)")?;
    // Force a read of the schema table so a bad key is reported here, with
    // a clear message, rather than at the first real query.
    conn.query_row("SELECT count(*) FROM sqlite_master", [], |r| {
        r.get::<_, i64>(0)
    })
    .context("SQLCipher unlock verification failed — wrong passphrase?")?;
    Ok(())
}
220
221#[cfg(not(feature = "sqlcipher"))]
222#[allow(clippy::unnecessary_wraps)]
223fn apply_sqlcipher_key(_conn: &Connection) -> Result<()> {
224 Ok(())
225}
226
/// Integration tests for `open()`, run against temporary on-disk database files.
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh file is created, and migrations stamp a positive schema version.
    #[test]
    fn open_round_trip_creates_db_and_runs_migrations() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        let conn = open(tmp.path()).expect("open initial");
        let v: i64 = conn
            .query_row(
                "SELECT COALESCE(MAX(version), 0) FROM schema_version",
                [],
                |r| r.get(0),
            )
            .expect("schema_version readable");
        assert!(v > 0, "expected positive schema version, got {v}");
    }

    /// Re-opening the same file must not install the sentinel trigger twice.
    /// The first connection stays alive while the second one opens.
    #[test]
    fn open_twice_is_idempotent_for_check_triggers() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        let _conn1 = open(tmp.path()).expect("first open");
        let conn2 = open(tmp.path()).expect("re-open idempotent");
        let n: i64 = conn2
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master \
                 WHERE type = 'trigger' AND name = 'memories_ck_tier_ins'",
                [],
                |r| r.get(0),
            )
            .expect("trigger query");
        assert_eq!(n, 1, "sentinel trigger must be installed exactly once");
    }

    /// `open()` switches the journal to WAL. The comparison is
    /// case-insensitive because the pragma reports the mode as text.
    #[test]
    fn open_applies_wal_journal_mode() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        let conn = open(tmp.path()).expect("open");
        let mode: String = conn
            .query_row("PRAGMA journal_mode", [], |r| r.get(0))
            .expect("journal_mode");
        assert_eq!(mode.to_lowercase(), "wal");
    }

    /// `open()` applies `DEFAULT_DB_MMAP_SIZE_BYTES` when no override is set.
    /// NOTE(review): this assumes nothing in the test process called
    /// `set_db_mmap_size`, and that the linked SQLite build does not cap
    /// mmap_size below 256 MiB. Confirm if this test becomes flaky.
    #[test]
    fn open_applies_default_mmap_size() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        let conn = open(tmp.path()).expect("open");
        let mmap: i64 = conn
            .query_row("PRAGMA mmap_size", [], |r| r.get(0))
            .expect("mmap_size");
        assert_eq!(
            mmap, DEFAULT_DB_MMAP_SIZE_BYTES,
            "open() must apply the P1-proven 256 MiB mmap_size default"
        );
    }

    /// `open()` turns on foreign-key enforcement for the connection.
    #[test]
    fn open_enables_foreign_keys() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        let conn = open(tmp.path()).expect("open");
        let fk: i32 = conn
            .query_row("PRAGMA foreign_keys", [], |r| r.get(0))
            .expect("foreign_keys");
        assert_eq!(fk, 1, "open() must enable foreign_keys");
    }

    /// True iff exactly one index named `name` exists. Query errors count as absent.
    fn index_present(conn: &Connection, name: &str) -> bool {
        let n: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type = 'index' AND name = ?1",
                rusqlite::params![name],
                |r| r.get(0),
            )
            .unwrap_or(0);
        n == 1
    }

    /// True iff `table` has a column named `column`.
    /// Returns false if the PRAGMA cannot be prepared. `table` is interpolated
    /// into the SQL, which is acceptable only because callers pass literals.
    fn column_present(conn: &Connection, table: &str, column: &str) -> bool {
        let sql = format!("PRAGMA table_info({table})");
        let mut stmt = match conn.prepare(&sql) {
            Ok(s) => s,
            Err(_) => return false,
        };
        let mut rows = stmt.query([]).expect("PRAGMA query");
        while let Some(row) = rows.next().expect("PRAGMA next") {
            // Column 1 of `PRAGMA table_info` output is the column name.
            let name: String = row.get(1).expect("col name");
            if name == column {
                return true;
            }
        }
        false
    }

    /// Simulate a DB stamped at v34 that lacks the later `memories` columns,
    /// indexes and tables. Re-opening it must re-add all of them.
    #[test]
    fn open_succeeds_on_legacy_pre_v36_memories_shape() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        {
            let conn = open(tmp.path()).expect("seed: fresh open");
            // Drop the indexes first: SQLite refuses to DROP COLUMN a column
            // that an index still references.
            for ix in [
                "idx_memories_atom_of",
                "idx_memories_atomised_into",
                "idx_personas_by_entity",
                "idx_memories_source_uri",
                "idx_memories_confidence_source",
                "idx_memories_mentioned_entity",
            ] {
                conn.execute(&format!("DROP INDEX IF EXISTS {ix}"), [])
                    .expect("drop index");
            }
            for col in [
                "mentioned_entity_id",
                "confidence_decayed_at",
                "confidence_signals",
                "confidence_source",
                "source_span",
                "source_uri",
                "citations",
                "persona_version",
                "entity_id",
                "atom_of",
                "atomised_into",
            ] {
                conn.execute(&format!("ALTER TABLE memories DROP COLUMN {col}"), [])
                    .unwrap_or_else(|e| panic!("DROP COLUMN {col}: {e}"));
            }
            conn.execute("DROP TABLE IF EXISTS confidence_shadow_observations", [])
                .expect("drop shadow table");
            conn.execute("DROP TABLE IF EXISTS signed_events_dlq", [])
                .expect("drop dlq");
            // Re-stamp the version so the migrate ladder replays from v34.
            conn.execute("DELETE FROM schema_version", [])
                .expect("clear version");
            conn.execute("INSERT INTO schema_version (version) VALUES (34)", [])
                .expect("stamp v34");
        }

        let conn = open(tmp.path()).expect("legacy-upgrade open must succeed");

        // NOTE(review): 42 presumably tracks CURRENT_SCHEMA_VERSION at the
        // time this test was written. TODO confirm, and consider referencing
        // the constant directly.
        let v: i64 = conn
            .query_row(
                "SELECT COALESCE(MAX(version), 0) FROM schema_version",
                [],
                |r| r.get(0),
            )
            .expect("read schema_version");
        assert!(
            v >= 42,
            "migrate ladder must reach CURRENT_SCHEMA_VERSION; got {v}"
        );

        for col in [
            "atom_of",
            "atomised_into",
            "entity_id",
            "persona_version",
            "citations",
            "source_uri",
            "source_span",
            "confidence_source",
            "confidence_signals",
            "confidence_decayed_at",
            "mentioned_entity_id",
        ] {
            assert!(
                column_present(&conn, "memories", col),
                "memories.{col} must be ALTER-added by the migrate ladder"
            );
        }

        // `idx_personas_by_entity` is dropped above but is not asserted here.
        for ix in [
            "idx_memories_atom_of",
            "idx_memories_atomised_into",
            "idx_memories_source_uri",
            "idx_memories_confidence_source",
            "idx_memories_mentioned_entity",
            "idx_shadow_obs_namespace_source_observed",
        ] {
            assert!(
                index_present(&conn, ix),
                "index {ix} must exist after legacy upgrade"
            );
        }
    }

    /// Simulate a DB stamped at v40 whose shadow-observation table lacks the
    /// `source` column and its compound index. The v41 migration arm must
    /// restore both.
    #[test]
    fn open_succeeds_on_legacy_pre_v41_shadow_shape() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        {
            let conn = open(tmp.path()).expect("seed: fresh open");
            // The index references `source`, so it must go before the column.
            conn.execute(
                "DROP INDEX IF EXISTS idx_shadow_obs_namespace_source_observed",
                [],
            )
            .expect("drop compound shadow index");
            conn.execute(
                "ALTER TABLE confidence_shadow_observations DROP COLUMN source",
                [],
            )
            .expect("drop shadow.source");
            conn.execute("DELETE FROM schema_version", [])
                .expect("clear version");
            conn.execute("INSERT INTO schema_version (version) VALUES (40)", [])
                .expect("stamp v40");
        }

        let conn = open(tmp.path()).expect("v40 legacy-upgrade open must succeed");
        assert!(
            column_present(&conn, "confidence_shadow_observations", "source"),
            "v41 migrate arm must ALTER-add shadow.source"
        );
        assert!(
            index_present(&conn, "idx_shadow_obs_namespace_source_observed"),
            "v41 compound shadow index must be re-attached"
        );
    }

    /// An INSERT with an invalid `tier` must fail.
    /// NOTE(review): this only checks `is_err()`. A different failure, such as
    /// a missing NOT NULL column, would also pass. Consider matching the error text.
    #[test]
    fn check_trigger_rejects_bad_tier_insert() {
        let tmp = tempfile::NamedTempFile::new().expect("tempfile");
        let conn = open(tmp.path()).expect("open");
        let now = chrono::Utc::now().to_rfc3339();
        let res = conn.execute(
            "INSERT INTO memories \
             (id, tier, namespace, title, content, tags, priority, confidence, \
              source, access_count, created_at, updated_at, metadata, reflection_depth) \
             VALUES (?1, 'NOT_A_TIER', 'test', 't', 'c', '[]', 5, 1.0, \
                     'src', 0, ?2, ?2, '{}', 0)",
            rusqlite::params!["bad-tier-id", now],
        );
        assert!(
            res.is_err(),
            "INSERT with bad tier must be rejected by R1-M2 trigger"
        );
    }
}