Skip to main content

brainos_storage/
sqlite.rs

//! SQLite storage backend.
//!
//! Provides connection management, schema migrations,
//! and typed CRUD operations for all Brain data:
//! - Episodes (conversations)
//! - Semantic facts (user model, extracted knowledge)
//! - Sessions (conversation grouping)
//! - Scheduled intents and the notification outbox
//! - Bulk export/import of facts and episodes

9use std::path::Path;
10use std::sync::{Arc, Mutex};
11
12use rusqlite::Connection;
13use thiserror::Error;
14use tracing::info;
15use uuid::Uuid;
16
17#[cfg(feature = "encryption")]
18use crate::encryption::Encryptor;
19
20mod migrations;
21
22#[cfg(test)]
23mod tests;
24
25/// Errors from the SQLite storage layer.
26#[derive(Debug, Error)]
27pub enum SqliteError {
28    #[error("SQLite error: {0}")]
29    Rusqlite(#[from] rusqlite::Error),
30
31    #[error("Lock poisoned")]
32    LockPoisoned,
33
34    #[error("Migration failed: {0}")]
35    Migration(String),
36}
37
38/// A semantic fact for export/import operations.
39#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
40pub struct ExportedFact {
41    pub id: String,
42    pub namespace: String,
43    pub category: String,
44    pub subject: String,
45    pub predicate: String,
46    pub object: String,
47    pub confidence: f64,
48    pub source_episode_id: Option<String>,
49}
50
51/// An episodic memory entry for export/import operations.
52#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
53pub struct ExportedEpisode {
54    pub id: String,
55    pub session_id: String,
56    pub session_channel: String,
57    #[serde(default = "default_namespace")]
58    pub namespace: String,
59    pub role: String,
60    pub content: String,
61    pub timestamp: String,
62    pub importance: f64,
63    pub reinforcement_count: i32,
64}
65
/// Namespace assigned to deserialized episodes that carry none (`"personal"`).
fn default_namespace() -> String {
    String::from("personal")
}
69
/// One entry of the notification outbox, awaiting (or past) delivery.
#[derive(Clone, Debug)]
pub struct Notification {
    /// Primary key of the outbox row.
    pub id: String,
    /// Message body to deliver.
    pub content: String,
    /// Delivery priority; pending notifications are read highest-first.
    pub priority: i32,
    /// Identifier of whatever enqueued the notification.
    pub triggered_by: String,
    /// Creation timestamp as stored in the database.
    pub created_at: String,
    /// Delivery timestamp; `None` while the notification is still pending.
    pub delivered_at: Option<String>,
    /// Optional target channel.
    pub channel: Option<String>,
}
81
82/// Thread-safe SQLite connection wrapper.
83///
84/// Uses a `Mutex<Connection>` — sufficient for our single-process,
85/// moderate-write workload. If we ever need concurrent writers,
86/// switch to `r2d2` or WAL mode (already enabled).
87///
88/// When an `Encryptor` is set, `content` columns are transparently
89/// encrypted on write and decrypted on read by the store layers.
90#[derive(Clone)]
91pub struct SqlitePool {
92    conn: Arc<Mutex<Connection>>,
93    #[cfg(feature = "encryption")]
94    encryptor: Option<Arc<Encryptor>>,
95}
96
/// A scheduling intent that is only persisted; nothing in this crate's
/// storage layer executes it.
#[derive(Clone, Debug)]
pub struct ScheduledIntent {
    /// Generated UUID primary key.
    pub id: String,
    /// Human-readable description of the intent.
    pub description: String,
    /// Optional cron expression, stored verbatim (not validated here).
    pub cron: Option<String>,
    /// Namespace the intent belongs to.
    pub namespace: String,
    /// Creation timestamp as stored in the database.
    pub created_at: String,
    /// Lifecycle status, e.g. `"scheduled"` or `"cancelled"`.
    pub status: String,
    /// Optional opaque metadata string.
    pub metadata: Option<String>,
}
108
109impl SqlitePool {
110    /// Open a new SQLite database at the given path.
111    ///
112    /// - Creates the file if it doesn't exist
113    /// - Enables WAL mode for concurrent reads
114    /// - Enables foreign keys
115    /// - Runs all schema migrations
116    pub fn open(path: &Path) -> Result<Self, SqliteError> {
117        // Ensure parent directory exists
118        if let Some(parent) = path.parent() {
119            std::fs::create_dir_all(parent).map_err(|e| {
120                SqliteError::Migration(format!("Cannot create directory {}: {e}", parent.display()))
121            })?;
122        }
123
124        let conn = Connection::open(path)?;
125
126        // Performance and safety pragmas — foreign_keys enforced by SQLite.
127        conn.execute_batch(
128            "
129            PRAGMA journal_mode = WAL;
130            PRAGMA synchronous = NORMAL;
131            PRAGMA foreign_keys = ON;
132            PRAGMA busy_timeout = 5000;
133            PRAGMA cache_size = -8000;
134            ",
135        )?;
136
137        let pool = Self {
138            conn: Arc::new(Mutex::new(conn)),
139            #[cfg(feature = "encryption")]
140            encryptor: None,
141        };
142
143        // Run migrations
144        pool.migrate()?;
145
146        info!("SQLite database opened at {}", path.display());
147        Ok(pool)
148    }
149
150    /// Open an in-memory database (for testing).
151    pub fn open_memory() -> Result<Self, SqliteError> {
152        let conn = Connection::open_in_memory()?;
153        conn.execute_batch(
154            "
155            PRAGMA journal_mode = WAL;
156            PRAGMA foreign_keys = ON;
157            ",
158        )?;
159
160        let pool = Self {
161            conn: Arc::new(Mutex::new(conn)),
162            #[cfg(feature = "encryption")]
163            encryptor: None,
164        };
165
166        pool.migrate()?;
167        Ok(pool)
168    }
169
170    /// Execute a closure with an exclusive lock on the connection.
171    pub fn with_conn<F, T>(&self, f: F) -> Result<T, SqliteError>
172    where
173        F: FnOnce(&Connection) -> Result<T, SqliteError>,
174    {
175        let conn = self.conn.lock().map_err(|_| SqliteError::LockPoisoned)?;
176        f(&conn)
177    }
178
179    /// Attach an encryptor to this pool (builder pattern).
180    ///
181    /// Once set, `encrypt_content` / `decrypt_content` are active on all
182    /// store layers that use this pool.
183    #[cfg(feature = "encryption")]
184    pub fn with_encryptor(mut self, enc: Encryptor) -> Self {
185        self.encryptor = Some(Arc::new(enc));
186        self
187    }
188
189    /// Returns true if an encryptor is active.
190    pub fn is_encrypted(&self) -> bool {
191        #[cfg(feature = "encryption")]
192        {
193            self.encryptor.is_some()
194        }
195        #[cfg(not(feature = "encryption"))]
196        {
197            false
198        }
199    }
200
201    /// Encrypt a string if encryption is enabled, otherwise return as-is.
202    pub fn encrypt_content(&self, plaintext: &str) -> String {
203        #[cfg(feature = "encryption")]
204        {
205            if let Some(enc) = &self.encryptor {
206                return enc
207                    .encrypt_string(plaintext)
208                    .unwrap_or_else(|_| plaintext.to_string());
209            }
210        }
211        plaintext.to_string()
212    }
213
214    /// Decrypt a string if encryption is enabled.
215    ///
216    /// Falls back to returning the input unchanged if decryption fails
217    /// (e.g. legacy plaintext rows written before encryption was enabled).
218    pub fn decrypt_content(&self, maybe_ciphertext: &str) -> String {
219        #[cfg(feature = "encryption")]
220        {
221            if let Some(enc) = &self.encryptor {
222                return enc
223                    .decrypt_string(maybe_ciphertext)
224                    .unwrap_or_else(|_| maybe_ciphertext.to_string());
225            }
226        }
227        maybe_ciphertext.to_string()
228    }
229
230    /// Try to decrypt a string, returning `None` if decryption fails.
231    ///
232    /// Unlike `decrypt_content`, this does NOT fall back to returning raw
233    /// ciphertext. Use this at read boundaries to filter out rows that
234    /// were encrypted with a different key or are corrupted.
235    pub fn try_decrypt_content(&self, maybe_ciphertext: &str) -> Option<String> {
236        #[cfg(feature = "encryption")]
237        {
238            if let Some(enc) = &self.encryptor {
239                return enc.decrypt_string(maybe_ciphertext).ok();
240            }
241        }
242        Some(maybe_ciphertext.to_string())
243    }
244
245    /// Flush the WAL file into the main database file.
246    ///
247    /// Should be called on graceful shutdown to ensure all committed writes are
248    /// fully persisted and the WAL file is clean. Uses `TRUNCATE` mode which
249    /// also resets the WAL to zero size.
250    pub fn wal_checkpoint(&self) -> Result<(), SqliteError> {
251        self.with_conn(|conn| {
252            conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
253            Ok(())
254        })
255    }
256
257    /// Persist a scheduled intent and return its generated ID.
258    pub fn insert_scheduled_intent(
259        &self,
260        description: &str,
261        cron: Option<&str>,
262        namespace: &str,
263        metadata: Option<&str>,
264    ) -> Result<String, SqliteError> {
265        let id = Uuid::new_v4().to_string();
266        self.with_conn(|conn| {
267            conn.execute(
268                "INSERT INTO scheduled_intents (id, description, cron, namespace, metadata)
269                 VALUES (?1, ?2, ?3, ?4, ?5)",
270                rusqlite::params![id, description, cron, namespace, metadata],
271            )?;
272            Ok(())
273        })?;
274        Ok(id)
275    }
276
277    /// List scheduled intents, optionally filtered by namespace.
278    pub fn list_scheduled_intents(
279        &self,
280        namespace: Option<&str>,
281    ) -> Result<Vec<ScheduledIntent>, SqliteError> {
282        self.with_conn(|conn| {
283            let mut intents = Vec::new();
284            if let Some(ns) = namespace {
285                let mut stmt = conn.prepare(
286                    "SELECT id, description, cron, namespace, created_at, status, metadata
287                     FROM scheduled_intents
288                     WHERE namespace = ?1 OR namespace LIKE ?2
289                     ORDER BY created_at DESC",
290                )?;
291                let prefix = format!("{}/%", ns);
292                let rows = stmt.query_map([ns, &prefix], |row| {
293                    Ok(ScheduledIntent {
294                        id: row.get(0)?,
295                        description: row.get(1)?,
296                        cron: row.get(2)?,
297                        namespace: row.get(3)?,
298                        created_at: row.get(4)?,
299                        status: row.get(5)?,
300                        metadata: row.get(6)?,
301                    })
302                })?;
303                for row in rows {
304                    intents.push(row?);
305                }
306            } else {
307                let mut stmt = conn.prepare(
308                    "SELECT id, description, cron, namespace, created_at, status, metadata
309                     FROM scheduled_intents
310                     ORDER BY created_at DESC",
311                )?;
312                let rows = stmt.query_map([], |row| {
313                    Ok(ScheduledIntent {
314                        id: row.get(0)?,
315                        description: row.get(1)?,
316                        cron: row.get(2)?,
317                        namespace: row.get(3)?,
318                        created_at: row.get(4)?,
319                        status: row.get(5)?,
320                        metadata: row.get(6)?,
321                    })
322                })?;
323                for row in rows {
324                    intents.push(row?);
325                }
326            }
327            Ok(intents)
328        })
329    }
330
331    /// Update a scheduled intent status. Returns true when a row was updated.
332    pub fn update_scheduled_intent_status(
333        &self,
334        id: &str,
335        status: &str,
336    ) -> Result<bool, SqliteError> {
337        self.with_conn(|conn| {
338            let affected = conn.execute(
339                "UPDATE scheduled_intents SET status = ?2 WHERE id = ?1",
340                rusqlite::params![id, status],
341            )?;
342            Ok(affected > 0)
343        })
344    }
345
346    /// Cancel a scheduled intent (set status to "cancelled").
347    pub fn cancel_scheduled_intent(&self, id: &str) -> Result<bool, SqliteError> {
348        self.update_scheduled_intent_status(id, "cancelled")
349    }
350
351    /// Return all scheduled intents with status `"scheduled"` (i.e. pending execution).
352    pub fn due_scheduled_intents(&self) -> Result<Vec<ScheduledIntent>, SqliteError> {
353        self.with_conn(|conn| {
354            let mut stmt = conn.prepare(
355                "SELECT id, description, cron, namespace, created_at, status, metadata
356                 FROM scheduled_intents
357                 WHERE status = 'scheduled'
358                 ORDER BY created_at ASC",
359            )?;
360            let rows = stmt.query_map([], |row| {
361                Ok(ScheduledIntent {
362                    id: row.get(0)?,
363                    description: row.get(1)?,
364                    cron: row.get(2)?,
365                    namespace: row.get(3)?,
366                    created_at: row.get(4)?,
367                    status: row.get(5)?,
368                    metadata: row.get(6)?,
369                })
370            })?;
371            Ok(rows.filter_map(|r| r.ok()).collect())
372        })
373    }
374
375    /// Insert a notification into the outbox for later delivery.
376    pub fn insert_notification(
377        &self,
378        content: &str,
379        priority: i32,
380        triggered_by: &str,
381        channel: Option<&str>,
382    ) -> Result<String, SqliteError> {
383        let id = Uuid::new_v4().to_string();
384        self.with_conn(|conn| {
385            conn.execute(
386                "INSERT INTO notification_outbox (id, content, priority, triggered_by, channel)
387                 VALUES (?1, ?2, ?3, ?4, ?5)",
388                rusqlite::params![id, content, priority, triggered_by, channel],
389            )?;
390            Ok(())
391        })?;
392        Ok(id)
393    }
394
395    /// Fetch all pending (undelivered) notifications, ordered by priority then age.
396    pub fn pending_notifications(&self, limit: usize) -> Result<Vec<Notification>, SqliteError> {
397        self.with_conn(|conn| {
398            let mut stmt = conn.prepare(
399                "SELECT id, content, priority, triggered_by, created_at, delivered_at, channel
400                 FROM notification_outbox
401                 WHERE delivered_at IS NULL
402                 ORDER BY priority DESC, created_at ASC
403                 LIMIT ?1",
404            )?;
405            let rows = stmt
406                .query_map([limit as i64], |row| {
407                    Ok(Notification {
408                        id: row.get(0)?,
409                        content: row.get(1)?,
410                        priority: row.get(2)?,
411                        triggered_by: row.get(3)?,
412                        created_at: row.get(4)?,
413                        delivered_at: row.get(5)?,
414                        channel: row.get(6)?,
415                    })
416                })?
417                .collect::<Result<Vec<_>, _>>()?;
418            Ok(rows)
419        })
420    }
421
422    /// Mark a notification as delivered (sets `delivered_at` to now).
423    pub fn mark_notification_delivered(&self, id: &str) -> Result<bool, SqliteError> {
424        self.with_conn(|conn| {
425            let affected = conn.execute(
426                "UPDATE notification_outbox SET delivered_at = datetime('now') WHERE id = ?1 AND delivered_at IS NULL",
427                [id],
428            )?;
429            Ok(affected > 0)
430        })
431    }
432
433    /// Mark multiple notifications as delivered in a single UPDATE.
434    /// Returns the count of notifications actually marked delivered.
435    pub fn mark_notifications_delivered(&self, ids: &[String]) -> Result<usize, SqliteError> {
436        if ids.is_empty() {
437            return Ok(0);
438        }
439        let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
440        let sql = format!(
441            "UPDATE notification_outbox SET delivered_at = datetime('now') \
442             WHERE delivered_at IS NULL AND id IN ({placeholders})"
443        );
444        self.with_conn(|conn| {
445            let params: Vec<&dyn rusqlite::types::ToSql> = ids
446                .iter()
447                .map(|id| id as &dyn rusqlite::types::ToSql)
448                .collect();
449            let affected = conn.execute(&sql, params.as_slice())?;
450            Ok(affected)
451        })
452    }
453
454    /// Prune old delivered notifications and stale undelivered ones.
455    pub fn prune_notifications(&self, max_age_days: u32) -> Result<usize, SqliteError> {
456        self.with_conn(|conn| {
457            let deleted = conn.execute(
458                "DELETE FROM notification_outbox
459                 WHERE (delivered_at IS NOT NULL AND created_at < datetime('now', ?1))
460                    OR created_at < datetime('now', ?1)",
461                [format!("-{max_age_days} days")],
462            )?;
463            Ok(deleted)
464        })
465    }
466
467    // ── Export / Import ──────────────────────────────────────────────────────
468
469    /// Export all semantic facts ordered by ID.
470    pub fn export_all_facts(&self) -> Result<Vec<ExportedFact>, SqliteError> {
471        self.with_conn(|conn| {
472            let mut stmt = conn.prepare(
473                "SELECT id, namespace, category, subject, predicate, object,
474                        confidence, source_episode_id
475                 FROM semantic_facts
476                 ORDER BY id ASC",
477            )?;
478            let rows = stmt
479                .query_map([], |row| {
480                    Ok(ExportedFact {
481                        id: row.get(0)?,
482                        namespace: row.get(1)?,
483                        category: row.get(2)?,
484                        subject: row.get(3)?,
485                        predicate: row.get(4)?,
486                        object: row.get(5)?,
487                        confidence: row.get(6)?,
488                        source_episode_id: row.get(7)?,
489                    })
490                })?
491                .collect::<Result<Vec<_>, _>>()?;
492            Ok(rows)
493        })
494    }
495
496    /// Export all episodes with session info, ordered by timestamp.
497    pub fn export_all_episodes(&self) -> Result<Vec<ExportedEpisode>, SqliteError> {
498        self.with_conn(|conn| {
499            let mut stmt = conn.prepare(
500                "SELECT e.id, e.session_id, COALESCE(s.channel, 'cli'),
501                        e.namespace, e.role, e.content, e.timestamp,
502                        e.importance, e.reinforcement_count
503                 FROM episodes e
504                 LEFT JOIN sessions s ON s.id = e.session_id
505                 ORDER BY e.timestamp ASC",
506            )?;
507            let rows = stmt
508                .query_map([], |row| {
509                    Ok(ExportedEpisode {
510                        id: row.get(0)?,
511                        session_id: row.get(1)?,
512                        session_channel: row.get(2)?,
513                        namespace: row.get(3)?,
514                        role: row.get(4)?,
515                        content: row.get(5)?,
516                        timestamp: row.get(6)?,
517                        importance: row.get(7)?,
518                        reinforcement_count: row.get(8)?,
519                    })
520                })?
521                .collect::<Result<Vec<_>, _>>()?;
522            Ok(rows)
523        })
524    }
525
526    /// Import facts (ON CONFLICT DO NOTHING). Returns (imported_count, new_indices).
527    pub fn import_facts(&self, facts: &[ExportedFact]) -> Result<(usize, Vec<usize>), SqliteError> {
528        self.with_conn(|conn| {
529            let mut imported = 0usize;
530            let mut new_indices = Vec::new();
531            for (idx, f) in facts.iter().enumerate() {
532                let n = conn.execute(
533                    "INSERT INTO semantic_facts
534                        (id, namespace, category, subject, predicate, object,
535                         confidence, source_episode_id)
536                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
537                     ON CONFLICT(id) DO NOTHING",
538                    rusqlite::params![
539                        f.id,
540                        f.namespace,
541                        f.category,
542                        f.subject,
543                        f.predicate,
544                        f.object,
545                        f.confidence,
546                        f.source_episode_id
547                    ],
548                )?;
549                if n > 0 {
550                    new_indices.push(idx);
551                }
552                imported += n;
553            }
554            Ok((imported, new_indices))
555        })
556    }
557
558    /// Import episodes (ON CONFLICT DO NOTHING). Returns count of newly imported episodes.
559    pub fn import_episodes(&self, episodes: &[ExportedEpisode]) -> Result<usize, SqliteError> {
560        self.with_conn(|conn| {
561            // Ensure sessions exist first
562            let mut sessions: std::collections::HashMap<String, String> =
563                std::collections::HashMap::new();
564            for ep in episodes {
565                sessions
566                    .entry(ep.session_id.clone())
567                    .or_insert_with(|| ep.session_channel.clone());
568            }
569            for (sid, channel) in &sessions {
570                conn.execute(
571                    "INSERT INTO sessions (id, channel) VALUES (?1, ?2)
572                     ON CONFLICT(id) DO NOTHING",
573                    rusqlite::params![sid, channel],
574                )?;
575            }
576
577            let mut imported = 0usize;
578            for e in episodes {
579                let n = conn.execute(
580                    "INSERT INTO episodes
581                        (id, session_id, namespace, role, content, timestamp,
582                         importance, reinforcement_count)
583                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
584                     ON CONFLICT(id) DO NOTHING",
585                    rusqlite::params![
586                        e.id,
587                        e.session_id,
588                        e.namespace,
589                        e.role,
590                        e.content,
591                        e.timestamp,
592                        e.importance,
593                        e.reinforcement_count
594                    ],
595                )?;
596                imported += n;
597            }
598            Ok(imported)
599        })
600    }
601
602    /// Get table row counts for status display.
603    pub fn table_stats(&self) -> Result<Vec<(String, i64)>, SqliteError> {
604        self.with_conn(|conn| {
605            let mut stats = Vec::new();
606            // Whitelist approach: each match arm is both the table name and its SQL,
607            // preventing any possibility of SQL injection via table name interpolation.
608            for table in &[
609                "sessions",
610                "episodes",
611                "semantic_facts",
612                "episode_promotions",
613                "scheduled_intents",
614                "notification_outbox",
615                "user_profile",
616                "procedures",
617                "audit_log",
618            ] {
619                let sql = match *table {
620                    "sessions" => "SELECT COUNT(*) FROM sessions",
621                    "episodes" => "SELECT COUNT(*) FROM episodes",
622                    "semantic_facts" => "SELECT COUNT(*) FROM semantic_facts",
623                    "episode_promotions" => "SELECT COUNT(*) FROM episode_promotions",
624                    "scheduled_intents" => "SELECT COUNT(*) FROM scheduled_intents",
625                    "notification_outbox" => "SELECT COUNT(*) FROM notification_outbox",
626                    "user_profile" => "SELECT COUNT(*) FROM user_profile",
627                    "procedures" => "SELECT COUNT(*) FROM procedures",
628                    "audit_log" => "SELECT COUNT(*) FROM audit_log",
629                    _ => continue,
630                };
631                let count: i64 = conn.query_row(sql, [], |row| row.get(0)).unwrap_or(0);
632                stats.push((table.to_string(), count));
633            }
634
635            Ok(stats)
636        })
637    }
638}