1use std::path::Path;
10#[cfg(feature = "encryption")]
11use std::sync::Arc;
12
13use r2d2::Pool;
14use r2d2_sqlite::SqliteConnectionManager;
15use rusqlite::Connection;
16use thiserror::Error;
17use tracing::info;
18use uuid::Uuid;
19
20#[cfg(feature = "encryption")]
21use crate::encryption::Encryptor;
22
23mod migrations;
24
25#[cfg(test)]
26mod tests;
27
28#[derive(Debug, Error)]
30pub enum SqliteError {
31 #[error("SQLite error: {0}")]
32 Rusqlite(#[from] rusqlite::Error),
33
34 #[error("connection pool unavailable: {0}")]
35 Pool(String),
36
37 #[deprecated(
43 note = "No longer emitted after the r2d2 pool migration. Match `SqliteError::Pool` instead. Kept so older downstream `match` arms still compile."
44 )]
45 #[error("Lock poisoned")]
46 LockPoisoned,
47
48 #[error("Migration failed: {0}")]
49 Migration(String),
50
51 #[error(
57 "database schema v{found} is newer than this build supports (v{supported}); \
58 upgrade brain, or re-open with the downgrade override if you accept the risk"
59 )]
60 SchemaTooNew { found: i64, supported: i64 },
61
62 #[error("pre-migration backup failed: {0}")]
66 Backup(String),
67}
68
69impl From<r2d2::Error> for SqliteError {
70 fn from(e: r2d2::Error) -> Self {
71 SqliteError::Pool(e.to_string())
72 }
73}
74
75#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
77pub struct ExportedFact {
78 pub id: String,
79 pub namespace: String,
80 pub category: String,
81 pub subject: String,
82 pub predicate: String,
83 pub object: String,
84 pub confidence: f64,
85 pub source_episode_id: Option<String>,
86}
87
88#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
90pub struct ExportedEpisode {
91 pub id: String,
92 pub session_id: String,
93 pub session_channel: String,
94 #[serde(default = "default_namespace")]
95 pub namespace: String,
96 pub role: String,
97 pub content: String,
98 pub timestamp: String,
99 pub importance: f64,
100 pub reinforcement_count: i32,
101}
102
/// Serde fallback for `ExportedEpisode::namespace` when the field is missing.
fn default_namespace() -> String {
    String::from("personal")
}
106
/// One row of the `notification_outbox` table.
#[derive(Clone, Debug)]
pub struct Notification {
    /// Primary key (a UUID string when created through `insert_notification`).
    pub id: String,
    /// Text of the notification.
    pub content: String,
    /// Priority; `pending_notifications` returns higher values first.
    pub priority: i32,
    /// Value of the `triggered_by` column.
    pub triggered_by: String,
    /// Creation timestamp as stored in the database (text).
    pub created_at: String,
    /// Delivery timestamp; `None` while the notification is still pending.
    pub delivered_at: Option<String>,
    /// Optional channel recorded with the notification.
    pub channel: Option<String>,
}
118
119#[derive(Clone)]
130pub struct SqlitePool {
131 pool: Pool<SqliteConnectionManager>,
132 #[cfg(feature = "encryption")]
133 encryptor: Option<Arc<Encryptor>>,
134}
135
/// PRAGMAs executed on every new connection to a file-backed database:
/// WAL journaling, `synchronous = NORMAL`, foreign-key enforcement, a
/// `busy_timeout` of 5000 (milliseconds) and a `cache_size` of -8000 (a
/// negative value is interpreted by SQLite as a size in KiB).
const FILE_PRAGMAS: &str = "
    PRAGMA journal_mode = WAL;
    PRAGMA synchronous = NORMAL;
    PRAGMA foreign_keys = ON;
    PRAGMA busy_timeout = 5000;
    PRAGMA cache_size = -8000;
";

/// PRAGMAs executed on the connection of an in-memory database: only
/// foreign-key enforcement is switched on.
const MEMORY_PRAGMAS: &str = "
    PRAGMA foreign_keys = ON;
";

/// Upper bound on pooled connections for file-backed databases.
const FILE_POOL_SIZE: u32 = 8;
158
/// One row of the `scheduled_intents` table.
#[derive(Clone, Debug)]
pub struct ScheduledIntent {
    /// Primary key (a UUID string when created through `insert_scheduled_intent`).
    pub id: String,
    /// Free-text description of the intent.
    pub description: String,
    /// Optional value of the `cron` column.
    pub cron: Option<String>,
    /// Namespace the intent belongs to.
    pub namespace: String,
    /// Creation timestamp as stored in the database (text).
    pub created_at: String,
    /// Status text, e.g. `scheduled` or `cancelled`.
    pub status: String,
    /// Optional value of the `metadata` column.
    pub metadata: Option<String>,
}
170
171impl SqlitePool {
172 pub fn open(path: &Path) -> Result<Self, SqliteError> {
181 Self::open_with(path, false)
182 }
183
184 pub fn open_with(path: &Path, allow_downgrade: bool) -> Result<Self, SqliteError> {
190 if let Some(parent) = path.parent() {
192 std::fs::create_dir_all(parent).map_err(|e| {
193 SqliteError::Migration(format!("Cannot create directory {}: {e}", parent.display()))
194 })?;
195 }
196
197 let manager = SqliteConnectionManager::file(path)
198 .with_init(|c: &mut Connection| c.execute_batch(FILE_PRAGMAS));
199 let pool = Pool::builder().max_size(FILE_POOL_SIZE).build(manager)?;
200
201 let p = Self {
202 pool,
203 #[cfg(feature = "encryption")]
204 encryptor: None,
205 };
206
207 p.reconcile_schema_version(path, allow_downgrade)?;
211
212 p.migrate()?;
215
216 info!(
217 "SQLite database opened at {} (pool size {FILE_POOL_SIZE})",
218 path.display()
219 );
220 Ok(p)
221 }
222
223 fn reconcile_schema_version(
233 &self,
234 path: &Path,
235 allow_downgrade: bool,
236 ) -> Result<(), SqliteError> {
237 let found = self.schema_version()?;
238 let supported = Self::latest_schema_version();
239
240 if found > supported {
241 if allow_downgrade {
242 info!(
243 "Opening schema v{found} with an older build (supports v{supported}) — \
244 downgrade override active"
245 );
246 } else {
247 return Err(SqliteError::SchemaTooNew { found, supported });
248 }
249 }
250
251 if found > 0 && found < supported {
252 self.backup_before_migration(path, found)?;
253 }
254
255 Ok(())
256 }
257
258 fn backup_before_migration(&self, path: &Path, version: i64) -> Result<(), SqliteError> {
263 let file_name = path
264 .file_name()
265 .and_then(|n| n.to_str())
266 .unwrap_or("brain.db");
267 let backup = path.with_file_name(format!("{file_name}.bak-v{version}"));
268
269 if backup.exists() {
271 std::fs::remove_file(&backup)
272 .map_err(|e| SqliteError::Backup(format!("{}: {e}", backup.display())))?;
273 }
274
275 let target = backup.to_string_lossy().to_string();
276 self.with_conn(|conn| {
277 conn.execute("VACUUM INTO ?1", rusqlite::params![target])?;
278 Ok(())
279 })
280 .map_err(|e| SqliteError::Backup(e.to_string()))?;
281
282 info!(
283 "Pre-migration backup written to {} (schema v{version})",
284 backup.display()
285 );
286 Ok(())
287 }
288
289 pub fn open_memory() -> Result<Self, SqliteError> {
296 let manager = SqliteConnectionManager::memory()
297 .with_init(|c: &mut Connection| c.execute_batch(MEMORY_PRAGMAS));
298 let pool = Pool::builder().max_size(1).build(manager)?;
299
300 let p = Self {
301 pool,
302 #[cfg(feature = "encryption")]
303 encryptor: None,
304 };
305
306 p.migrate()?;
307 Ok(p)
308 }
309
310 pub fn with_conn<F, T>(&self, f: F) -> Result<T, SqliteError>
317 where
318 F: FnOnce(&Connection) -> Result<T, SqliteError>,
319 {
320 let conn = self.pool.get()?;
321 f(&conn)
322 }
323
324 pub fn open_connections(&self) -> u32 {
329 self.pool.state().connections
330 }
331
332 #[cfg(feature = "encryption")]
337 pub fn with_encryptor(mut self, enc: Encryptor) -> Self {
338 self.encryptor = Some(Arc::new(enc));
339 self
340 }
341
342 pub fn is_encrypted(&self) -> bool {
344 #[cfg(feature = "encryption")]
345 {
346 self.encryptor.is_some()
347 }
348 #[cfg(not(feature = "encryption"))]
349 {
350 false
351 }
352 }
353
354 pub fn encrypt_content(&self, plaintext: &str) -> String {
356 #[cfg(feature = "encryption")]
357 {
358 if let Some(enc) = &self.encryptor {
359 return enc
360 .encrypt_string(plaintext)
361 .unwrap_or_else(|_| plaintext.to_string());
362 }
363 }
364 plaintext.to_string()
365 }
366
367 pub fn decrypt_content(&self, maybe_ciphertext: &str) -> String {
372 #[cfg(feature = "encryption")]
373 {
374 if let Some(enc) = &self.encryptor {
375 return enc
376 .decrypt_string(maybe_ciphertext)
377 .unwrap_or_else(|_| maybe_ciphertext.to_string());
378 }
379 }
380 maybe_ciphertext.to_string()
381 }
382
383 pub fn try_decrypt_content(&self, maybe_ciphertext: &str) -> Option<String> {
389 #[cfg(feature = "encryption")]
390 {
391 if let Some(enc) = &self.encryptor {
392 return enc.decrypt_string(maybe_ciphertext).ok();
393 }
394 }
395 Some(maybe_ciphertext.to_string())
396 }
397
398 pub fn wal_checkpoint(&self) -> Result<(), SqliteError> {
404 self.with_conn(|conn| {
405 conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
406 Ok(())
407 })
408 }
409
410 pub fn insert_scheduled_intent(
412 &self,
413 description: &str,
414 cron: Option<&str>,
415 namespace: &str,
416 metadata: Option<&str>,
417 ) -> Result<String, SqliteError> {
418 let id = Uuid::new_v4().to_string();
419 self.with_conn(|conn| {
420 conn.execute(
421 "INSERT INTO scheduled_intents (id, description, cron, namespace, metadata)
422 VALUES (?1, ?2, ?3, ?4, ?5)",
423 rusqlite::params![id, description, cron, namespace, metadata],
424 )?;
425 Ok(())
426 })?;
427 Ok(id)
428 }
429
430 pub fn list_scheduled_intents(
432 &self,
433 namespace: Option<&str>,
434 ) -> Result<Vec<ScheduledIntent>, SqliteError> {
435 self.with_conn(|conn| {
436 let mut intents = Vec::new();
437 if let Some(ns) = namespace {
438 let mut stmt = conn.prepare(
439 "SELECT id, description, cron, namespace, created_at, status, metadata
440 FROM scheduled_intents
441 WHERE namespace = ?1 OR namespace LIKE ?2
442 ORDER BY created_at DESC",
443 )?;
444 let prefix = format!("{}/%", ns);
445 let rows = stmt.query_map([ns, &prefix], |row| {
446 Ok(ScheduledIntent {
447 id: row.get(0)?,
448 description: row.get(1)?,
449 cron: row.get(2)?,
450 namespace: row.get(3)?,
451 created_at: row.get(4)?,
452 status: row.get(5)?,
453 metadata: row.get(6)?,
454 })
455 })?;
456 for row in rows {
457 intents.push(row?);
458 }
459 } else {
460 let mut stmt = conn.prepare(
461 "SELECT id, description, cron, namespace, created_at, status, metadata
462 FROM scheduled_intents
463 ORDER BY created_at DESC",
464 )?;
465 let rows = stmt.query_map([], |row| {
466 Ok(ScheduledIntent {
467 id: row.get(0)?,
468 description: row.get(1)?,
469 cron: row.get(2)?,
470 namespace: row.get(3)?,
471 created_at: row.get(4)?,
472 status: row.get(5)?,
473 metadata: row.get(6)?,
474 })
475 })?;
476 for row in rows {
477 intents.push(row?);
478 }
479 }
480 Ok(intents)
481 })
482 }
483
484 pub fn update_scheduled_intent_status(
486 &self,
487 id: &str,
488 status: &str,
489 ) -> Result<bool, SqliteError> {
490 self.with_conn(|conn| {
491 let affected = conn.execute(
492 "UPDATE scheduled_intents SET status = ?2 WHERE id = ?1",
493 rusqlite::params![id, status],
494 )?;
495 Ok(affected > 0)
496 })
497 }
498
499 pub fn cancel_scheduled_intent(&self, id: &str) -> Result<bool, SqliteError> {
501 self.update_scheduled_intent_status(id, "cancelled")
502 }
503
504 pub fn due_scheduled_intents(&self) -> Result<Vec<ScheduledIntent>, SqliteError> {
506 self.with_conn(|conn| {
507 let mut stmt = conn.prepare(
508 "SELECT id, description, cron, namespace, created_at, status, metadata
509 FROM scheduled_intents
510 WHERE status = 'scheduled'
511 ORDER BY created_at ASC",
512 )?;
513 let rows = stmt.query_map([], |row| {
514 Ok(ScheduledIntent {
515 id: row.get(0)?,
516 description: row.get(1)?,
517 cron: row.get(2)?,
518 namespace: row.get(3)?,
519 created_at: row.get(4)?,
520 status: row.get(5)?,
521 metadata: row.get(6)?,
522 })
523 })?;
524 Ok(rows.filter_map(|r| r.ok()).collect())
525 })
526 }
527
528 pub fn insert_notification(
530 &self,
531 content: &str,
532 priority: i32,
533 triggered_by: &str,
534 channel: Option<&str>,
535 ) -> Result<String, SqliteError> {
536 let id = Uuid::new_v4().to_string();
537 self.with_conn(|conn| {
538 conn.execute(
539 "INSERT INTO notification_outbox (id, content, priority, triggered_by, channel)
540 VALUES (?1, ?2, ?3, ?4, ?5)",
541 rusqlite::params![id, content, priority, triggered_by, channel],
542 )?;
543 Ok(())
544 })?;
545 Ok(id)
546 }
547
548 pub fn pending_notifications(&self, limit: usize) -> Result<Vec<Notification>, SqliteError> {
550 self.with_conn(|conn| {
551 let mut stmt = conn.prepare(
552 "SELECT id, content, priority, triggered_by, created_at, delivered_at, channel
553 FROM notification_outbox
554 WHERE delivered_at IS NULL
555 ORDER BY priority DESC, created_at ASC
556 LIMIT ?1",
557 )?;
558 let rows = stmt
559 .query_map([limit as i64], |row| {
560 Ok(Notification {
561 id: row.get(0)?,
562 content: row.get(1)?,
563 priority: row.get(2)?,
564 triggered_by: row.get(3)?,
565 created_at: row.get(4)?,
566 delivered_at: row.get(5)?,
567 channel: row.get(6)?,
568 })
569 })?
570 .collect::<Result<Vec<_>, _>>()?;
571 Ok(rows)
572 })
573 }
574
575 pub fn mark_notification_delivered(&self, id: &str) -> Result<bool, SqliteError> {
577 self.with_conn(|conn| {
578 let affected = conn.execute(
579 "UPDATE notification_outbox SET delivered_at = datetime('now') WHERE id = ?1 AND delivered_at IS NULL",
580 [id],
581 )?;
582 Ok(affected > 0)
583 })
584 }
585
586 pub fn mark_notifications_delivered(&self, ids: &[String]) -> Result<usize, SqliteError> {
589 if ids.is_empty() {
590 return Ok(0);
591 }
592 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
593 let sql = format!(
594 "UPDATE notification_outbox SET delivered_at = datetime('now') \
595 WHERE delivered_at IS NULL AND id IN ({placeholders})"
596 );
597 self.with_conn(|conn| {
598 let params: Vec<&dyn rusqlite::types::ToSql> = ids
599 .iter()
600 .map(|id| id as &dyn rusqlite::types::ToSql)
601 .collect();
602 let affected = conn.execute(&sql, params.as_slice())?;
603 Ok(affected)
604 })
605 }
606
607 pub fn prune_notifications(&self, max_age_days: u32) -> Result<usize, SqliteError> {
609 self.with_conn(|conn| {
610 let deleted = conn.execute(
611 "DELETE FROM notification_outbox
612 WHERE (delivered_at IS NOT NULL AND created_at < datetime('now', ?1))
613 OR created_at < datetime('now', ?1)",
614 [format!("-{max_age_days} days")],
615 )?;
616 Ok(deleted)
617 })
618 }
619
620 pub fn export_all_facts(&self) -> Result<Vec<ExportedFact>, SqliteError> {
624 self.with_conn(|conn| {
625 let mut stmt = conn.prepare(
626 "SELECT id, namespace, category, subject, predicate, object,
627 confidence, source_episode_id
628 FROM semantic_facts
629 ORDER BY id ASC",
630 )?;
631 let rows = stmt
632 .query_map([], |row| {
633 Ok(ExportedFact {
634 id: row.get(0)?,
635 namespace: row.get(1)?,
636 category: row.get(2)?,
637 subject: row.get(3)?,
638 predicate: row.get(4)?,
639 object: row.get(5)?,
640 confidence: row.get(6)?,
641 source_episode_id: row.get(7)?,
642 })
643 })?
644 .collect::<Result<Vec<_>, _>>()?;
645 Ok(rows)
646 })
647 }
648
649 pub fn export_all_episodes(&self) -> Result<Vec<ExportedEpisode>, SqliteError> {
651 self.with_conn(|conn| {
652 let mut stmt = conn.prepare(
653 "SELECT e.id, e.session_id, COALESCE(s.channel, 'cli'),
654 e.namespace, e.role, e.content, e.timestamp,
655 e.importance, e.reinforcement_count
656 FROM episodes e
657 LEFT JOIN sessions s ON s.id = e.session_id
658 ORDER BY e.timestamp ASC",
659 )?;
660 let rows = stmt
661 .query_map([], |row| {
662 Ok(ExportedEpisode {
663 id: row.get(0)?,
664 session_id: row.get(1)?,
665 session_channel: row.get(2)?,
666 namespace: row.get(3)?,
667 role: row.get(4)?,
668 content: row.get(5)?,
669 timestamp: row.get(6)?,
670 importance: row.get(7)?,
671 reinforcement_count: row.get(8)?,
672 })
673 })?
674 .collect::<Result<Vec<_>, _>>()?;
675 Ok(rows)
676 })
677 }
678
679 pub fn import_facts(&self, facts: &[ExportedFact]) -> Result<(usize, Vec<usize>), SqliteError> {
681 self.with_conn(|conn| {
682 let mut imported = 0usize;
683 let mut new_indices = Vec::new();
684 for (idx, f) in facts.iter().enumerate() {
685 let n = conn.execute(
686 "INSERT INTO semantic_facts
687 (id, namespace, category, subject, predicate, object,
688 confidence, source_episode_id)
689 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
690 ON CONFLICT(id) DO NOTHING",
691 rusqlite::params![
692 f.id,
693 f.namespace,
694 f.category,
695 f.subject,
696 f.predicate,
697 f.object,
698 f.confidence,
699 f.source_episode_id
700 ],
701 )?;
702 if n > 0 {
703 new_indices.push(idx);
704 }
705 imported += n;
706 }
707 Ok((imported, new_indices))
708 })
709 }
710
711 pub fn import_episodes(&self, episodes: &[ExportedEpisode]) -> Result<usize, SqliteError> {
713 self.with_conn(|conn| {
714 let mut sessions: std::collections::HashMap<String, String> =
716 std::collections::HashMap::new();
717 for ep in episodes {
718 sessions
719 .entry(ep.session_id.clone())
720 .or_insert_with(|| ep.session_channel.clone());
721 }
722 for (sid, channel) in &sessions {
723 conn.execute(
724 "INSERT INTO sessions (id, channel) VALUES (?1, ?2)
725 ON CONFLICT(id) DO NOTHING",
726 rusqlite::params![sid, channel],
727 )?;
728 }
729
730 let mut imported = 0usize;
731 for e in episodes {
732 let n = conn.execute(
733 "INSERT INTO episodes
734 (id, session_id, namespace, role, content, timestamp,
735 importance, reinforcement_count)
736 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
737 ON CONFLICT(id) DO NOTHING",
738 rusqlite::params![
739 e.id,
740 e.session_id,
741 e.namespace,
742 e.role,
743 e.content,
744 e.timestamp,
745 e.importance,
746 e.reinforcement_count
747 ],
748 )?;
749 imported += n;
750 }
751 Ok(imported)
752 })
753 }
754
755 pub fn table_stats(&self) -> Result<Vec<(String, i64)>, SqliteError> {
757 self.with_conn(|conn| {
758 let mut stats = Vec::new();
759 for table in &[
762 "sessions",
763 "episodes",
764 "semantic_facts",
765 "episode_promotions",
766 "scheduled_intents",
767 "notification_outbox",
768 "user_profile",
769 "procedures",
770 "audit_log",
771 ] {
772 let sql = match *table {
773 "sessions" => "SELECT COUNT(*) FROM sessions",
774 "episodes" => "SELECT COUNT(*) FROM episodes",
775 "semantic_facts" => "SELECT COUNT(*) FROM semantic_facts",
776 "episode_promotions" => "SELECT COUNT(*) FROM episode_promotions",
777 "scheduled_intents" => "SELECT COUNT(*) FROM scheduled_intents",
778 "notification_outbox" => "SELECT COUNT(*) FROM notification_outbox",
779 "user_profile" => "SELECT COUNT(*) FROM user_profile",
780 "procedures" => "SELECT COUNT(*) FROM procedures",
781 "audit_log" => "SELECT COUNT(*) FROM audit_log",
782 _ => continue,
783 };
784 let count: i64 = conn.query_row(sql, [], |row| row.get(0)).unwrap_or(0);
785 stats.push((table.to_string(), count));
786 }
787
788 Ok(stats)
789 })
790 }
791}