brainwires_stores/session/
sqlite.rs1use std::path::{Path, PathBuf};
7use std::sync::Arc;
8
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use chrono::{DateTime, TimeZone, Utc};
12use rusqlite::{Connection, OptionalExtension, params};
13use tokio::sync::Mutex;
14
15use brainwires_core::Message;
16
17use super::{ListOptions, SessionId, SessionRecord, SessionStore};
18
19pub struct SqliteSessionStore {
22 conn: Arc<Mutex<Connection>>,
23 path: PathBuf,
24}
25
26impl SqliteSessionStore {
27 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
29 let path = path.as_ref().to_path_buf();
30 let conn = Connection::open(&path)
31 .with_context(|| format!("opening session store at {}", path.display()))?;
32 conn.execute_batch(
33 "CREATE TABLE IF NOT EXISTS sessions (
34 id TEXT PRIMARY KEY,
35 payload TEXT NOT NULL,
36 message_count INTEGER NOT NULL,
37 created_at INTEGER NOT NULL,
38 updated_at INTEGER NOT NULL
39 );",
40 )?;
41 Ok(Self {
42 conn: Arc::new(Mutex::new(conn)),
43 path,
44 })
45 }
46
47 pub fn path(&self) -> &Path {
49 &self.path
50 }
51}
52
53fn ts_to_utc(secs: i64) -> DateTime<Utc> {
54 Utc.timestamp_opt(secs, 0).single().unwrap_or_else(Utc::now)
55}
56
57#[async_trait]
58impl SessionStore for SqliteSessionStore {
59 async fn load(&self, id: &SessionId) -> Result<Option<Vec<Message>>> {
60 let conn = self.conn.lock().await;
61 let payload: Option<String> = conn
62 .query_row(
63 "SELECT payload FROM sessions WHERE id = ?1",
64 params![id.as_str()],
65 |row| row.get(0),
66 )
67 .optional()?;
68 Ok(match payload {
69 Some(s) => Some(serde_json::from_str(&s)?),
70 None => None,
71 })
72 }
73
74 async fn save(&self, id: &SessionId, messages: &[Message]) -> Result<()> {
75 let payload = serde_json::to_string(messages)?;
76 let now = Utc::now().timestamp();
77 let conn = self.conn.lock().await;
78 conn.execute(
79 "INSERT INTO sessions (id, payload, message_count, created_at, updated_at)
80 VALUES (?1, ?2, ?3, ?4, ?4)
81 ON CONFLICT(id) DO UPDATE SET
82 payload = excluded.payload,
83 message_count = excluded.message_count,
84 updated_at = excluded.updated_at",
85 params![id.as_str(), payload, messages.len() as i64, now],
86 )?;
87 Ok(())
88 }
89
90 async fn list(&self) -> Result<Vec<SessionRecord>> {
91 let conn = self.conn.lock().await;
92 let mut stmt = conn.prepare(
93 "SELECT id, message_count, created_at, updated_at FROM sessions ORDER BY updated_at ASC",
94 )?;
95 let rows = stmt.query_map(params![], |row| {
96 Ok(SessionRecord {
97 id: SessionId::new(row.get::<_, String>(0)?),
98 message_count: row.get::<_, i64>(1)? as usize,
99 created_at: ts_to_utc(row.get::<_, i64>(2)?),
100 updated_at: ts_to_utc(row.get::<_, i64>(3)?),
101 })
102 })?;
103 let mut out = Vec::new();
104 for r in rows {
105 out.push(r?);
106 }
107 Ok(out)
108 }
109
110 async fn delete(&self, id: &SessionId) -> Result<()> {
111 let conn = self.conn.lock().await;
112 conn.execute("DELETE FROM sessions WHERE id = ?1", params![id.as_str()])?;
113 Ok(())
114 }
115
116 async fn list_paginated(&self, opts: ListOptions) -> Result<Vec<SessionRecord>> {
117 let conn = self.conn.lock().await;
118 let limit_sql: i64 = opts
121 .limit
122 .map(|l| l.try_into().unwrap_or(i64::MAX))
123 .unwrap_or(-1);
124 let offset_sql: i64 = opts.offset.try_into().unwrap_or(i64::MAX);
125 let mut stmt = conn.prepare(
126 "SELECT id, message_count, created_at, updated_at FROM sessions
127 ORDER BY updated_at ASC LIMIT ?1 OFFSET ?2",
128 )?;
129 let rows = stmt.query_map(params![limit_sql, offset_sql], |row| {
130 Ok(SessionRecord {
131 id: SessionId::new(row.get::<_, String>(0)?),
132 message_count: row.get::<_, i64>(1)? as usize,
133 created_at: ts_to_utc(row.get::<_, i64>(2)?),
134 updated_at: ts_to_utc(row.get::<_, i64>(3)?),
135 })
136 })?;
137 let mut out = Vec::new();
138 for r in rows {
139 out.push(r?);
140 }
141 Ok(out)
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a store in a fresh temporary directory. The returned `TempDir`
    /// must outlive the store, or the database directory is removed.
    fn tmp_store() -> (SqliteSessionStore, tempfile::TempDir) {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("sessions.db");
        (SqliteSessionStore::open(&path).unwrap(), tmp)
    }

    #[tokio::test]
    async fn roundtrip() {
        let (store, _tmp) = tmp_store();
        let id = SessionId::new("u1");
        store.save(&id, &[Message::user("hi")]).await.unwrap();
        let loaded = store.load(&id).await.unwrap().unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].text(), Some("hi"));
    }

    #[tokio::test]
    async fn survives_reopen() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("sessions.db");
        {
            let store = SqliteSessionStore::open(&path).unwrap();
            store
                .save(&SessionId::new("persist"), &[Message::user("keep me")])
                .await
                .unwrap();
        }
        let store = SqliteSessionStore::open(&path).unwrap();
        let loaded = store
            .load(&SessionId::new("persist"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(loaded.len(), 1);
    }

    #[tokio::test]
    async fn list_and_delete() {
        let (store, _tmp) = tmp_store();
        store
            .save(&SessionId::new("a"), &[Message::user("x")])
            .await
            .unwrap();
        store
            .save(&SessionId::new("b"), &[Message::user("y")])
            .await
            .unwrap();
        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 2);
        store.delete(&SessionId::new("a")).await.unwrap();
        assert_eq!(store.list().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn load_missing_returns_none() {
        let (store, _tmp) = tmp_store();
        assert!(store.load(&SessionId::new("nope")).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn resave_overwrites_payload_and_count() {
        let (store, _tmp) = tmp_store();
        let id = SessionId::new("s");
        store.save(&id, &[Message::user("first")]).await.unwrap();
        store
            .save(&id, &[Message::user("first"), Message::user("second")])
            .await
            .unwrap();

        let loaded = store.load(&id).await.unwrap().unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[1].text(), Some("second"));

        // Upsert must not create a second row, and must refresh the count.
        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].message_count, 2);
        assert!(list[0].created_at <= list[0].updated_at);
    }

    #[tokio::test]
    async fn delete_missing_is_ok() {
        let (store, _tmp) = tmp_store();
        store.delete(&SessionId::new("ghost")).await.unwrap();
        assert!(store.list().await.unwrap().is_empty());
    }
}