Skip to main content

brainwires_stores/session/
sqlite.rs

//! SQLite-backed [`SessionStore`] implementation.
//!
//! Messages are serialised to JSON and stored in a single table keyed by
//! [`SessionId`]. The schema is created on first open if it does not already
//! exist.
5
6use std::path::{Path, PathBuf};
7use std::sync::Arc;
8
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use chrono::{DateTime, TimeZone, Utc};
12use rusqlite::{Connection, OptionalExtension, params};
13use tokio::sync::Mutex;
14
15use brainwires_core::Message;
16
17use super::{ListOptions, SessionId, SessionRecord, SessionStore};
18
19/// Disk-backed session store. Access is serialised through a single
20/// connection — adequate for single-node agent workloads.
21pub struct SqliteSessionStore {
22    conn: Arc<Mutex<Connection>>,
23    path: PathBuf,
24}
25
26impl SqliteSessionStore {
27    /// Open (or create) the store at `path`, auto-migrating the schema.
28    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
29        let path = path.as_ref().to_path_buf();
30        let conn = Connection::open(&path)
31            .with_context(|| format!("opening session store at {}", path.display()))?;
32        conn.execute_batch(
33            "CREATE TABLE IF NOT EXISTS sessions (
34                id TEXT PRIMARY KEY,
35                payload TEXT NOT NULL,
36                message_count INTEGER NOT NULL,
37                created_at INTEGER NOT NULL,
38                updated_at INTEGER NOT NULL
39            );",
40        )?;
41        Ok(Self {
42            conn: Arc::new(Mutex::new(conn)),
43            path,
44        })
45    }
46
47    /// Path this store writes to.
48    pub fn path(&self) -> &Path {
49        &self.path
50    }
51}
52
53fn ts_to_utc(secs: i64) -> DateTime<Utc> {
54    Utc.timestamp_opt(secs, 0).single().unwrap_or_else(Utc::now)
55}
56
57#[async_trait]
58impl SessionStore for SqliteSessionStore {
59    async fn load(&self, id: &SessionId) -> Result<Option<Vec<Message>>> {
60        let conn = self.conn.lock().await;
61        let payload: Option<String> = conn
62            .query_row(
63                "SELECT payload FROM sessions WHERE id = ?1",
64                params![id.as_str()],
65                |row| row.get(0),
66            )
67            .optional()?;
68        Ok(match payload {
69            Some(s) => Some(serde_json::from_str(&s)?),
70            None => None,
71        })
72    }
73
74    async fn save(&self, id: &SessionId, messages: &[Message]) -> Result<()> {
75        let payload = serde_json::to_string(messages)?;
76        let now = Utc::now().timestamp();
77        let conn = self.conn.lock().await;
78        conn.execute(
79            "INSERT INTO sessions (id, payload, message_count, created_at, updated_at)
80             VALUES (?1, ?2, ?3, ?4, ?4)
81             ON CONFLICT(id) DO UPDATE SET
82                payload = excluded.payload,
83                message_count = excluded.message_count,
84                updated_at = excluded.updated_at",
85            params![id.as_str(), payload, messages.len() as i64, now],
86        )?;
87        Ok(())
88    }
89
90    async fn list(&self) -> Result<Vec<SessionRecord>> {
91        let conn = self.conn.lock().await;
92        let mut stmt = conn.prepare(
93            "SELECT id, message_count, created_at, updated_at FROM sessions ORDER BY updated_at ASC",
94        )?;
95        let rows = stmt.query_map(params![], |row| {
96            Ok(SessionRecord {
97                id: SessionId::new(row.get::<_, String>(0)?),
98                message_count: row.get::<_, i64>(1)? as usize,
99                created_at: ts_to_utc(row.get::<_, i64>(2)?),
100                updated_at: ts_to_utc(row.get::<_, i64>(3)?),
101            })
102        })?;
103        let mut out = Vec::new();
104        for r in rows {
105            out.push(r?);
106        }
107        Ok(out)
108    }
109
110    async fn delete(&self, id: &SessionId) -> Result<()> {
111        let conn = self.conn.lock().await;
112        conn.execute("DELETE FROM sessions WHERE id = ?1", params![id.as_str()])?;
113        Ok(())
114    }
115
116    async fn list_paginated(&self, opts: ListOptions) -> Result<Vec<SessionRecord>> {
117        let conn = self.conn.lock().await;
118        // SQLite uses i64 for LIMIT/OFFSET. -1 means "no limit" — use it when
119        // the caller passed `None`.
120        let limit_sql: i64 = opts
121            .limit
122            .map(|l| l.try_into().unwrap_or(i64::MAX))
123            .unwrap_or(-1);
124        let offset_sql: i64 = opts.offset.try_into().unwrap_or(i64::MAX);
125        let mut stmt = conn.prepare(
126            "SELECT id, message_count, created_at, updated_at FROM sessions
127             ORDER BY updated_at ASC LIMIT ?1 OFFSET ?2",
128        )?;
129        let rows = stmt.query_map(params![limit_sql, offset_sql], |row| {
130            Ok(SessionRecord {
131                id: SessionId::new(row.get::<_, String>(0)?),
132                message_count: row.get::<_, i64>(1)? as usize,
133                created_at: ts_to_utc(row.get::<_, i64>(2)?),
134                updated_at: ts_to_utc(row.get::<_, i64>(3)?),
135            })
136        })?;
137        let mut out = Vec::new();
138        for r in rows {
139            out.push(r?);
140        }
141        Ok(out)
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Open a store in a fresh temp dir. The returned `TempDir` must stay
    /// alive for the duration of the test, or the directory is removed.
    fn tmp_store() -> (SqliteSessionStore, tempfile::TempDir) {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("sessions.db");
        (SqliteSessionStore::open(&path).unwrap(), tmp)
    }

    #[tokio::test]
    async fn roundtrip() {
        let (store, _tmp) = tmp_store();
        let id = SessionId::new("u1");
        store.save(&id, &[Message::user("hi")]).await.unwrap();
        let loaded = store.load(&id).await.unwrap().unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].text(), Some("hi"));
    }

    #[tokio::test]
    async fn load_missing_returns_none() {
        let (store, _tmp) = tmp_store();
        assert!(store.load(&SessionId::new("absent")).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn save_overwrites_existing_session() {
        let (store, _tmp) = tmp_store();
        let id = SessionId::new("dup");
        store.save(&id, &[Message::user("one")]).await.unwrap();
        store
            .save(&id, &[Message::user("one"), Message::user("two")])
            .await
            .unwrap();

        let loaded = store.load(&id).await.unwrap().unwrap();
        assert_eq!(loaded.len(), 2);

        // The upsert must update the existing row, not add a second one.
        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].message_count, 2);
        assert!(list[0].created_at <= list[0].updated_at);
    }

    #[tokio::test]
    async fn survives_reopen() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("sessions.db");
        {
            let store = SqliteSessionStore::open(&path).unwrap();
            store
                .save(&SessionId::new("persist"), &[Message::user("keep me")])
                .await
                .unwrap();
        }
        let store = SqliteSessionStore::open(&path).unwrap();
        let loaded = store
            .load(&SessionId::new("persist"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(loaded.len(), 1);
    }

    #[tokio::test]
    async fn list_and_delete() {
        let (store, _tmp) = tmp_store();
        store
            .save(&SessionId::new("a"), &[Message::user("x")])
            .await
            .unwrap();
        store
            .save(&SessionId::new("b"), &[Message::user("y")])
            .await
            .unwrap();
        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 2);
        store.delete(&SessionId::new("a")).await.unwrap();
        let remaining = store.list().await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].id.as_str(), "b");
        assert!(store.load(&SessionId::new("a")).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn delete_missing_is_ok() {
        let (store, _tmp) = tmp_store();
        store.delete(&SessionId::new("never-saved")).await.unwrap();
        assert!(store.list().await.unwrap().is_empty());
    }
}