Skip to main content

agnt_store/
store.rs

1use agnt_core::{Message, MessageStore, StoreError, ToolLog};
2use rusqlite::{params, Connection};
3use std::sync::Mutex;
4use tracing::{debug, info};
5
6fn io_err(e: impl std::fmt::Display) -> StoreError {
7    StoreError::Io(e.to_string())
8}
9
10/// SQLite-backed session store.
11///
12/// Wraps `rusqlite::Connection` in a `Mutex` so the store is `Send + Sync`
13/// and can be shared across threads without the caller needing to add their
14/// own interior mutability.
15///
16/// The connection is opened in WAL mode with `synchronous=NORMAL` for
17/// significantly better throughput on the agent hot path. Prepared
18/// statements are cached across calls.
19pub struct Store {
20    conn: Mutex<Connection>,
21}
22
23impl Store {
24    pub fn open(path: &str) -> Result<Self, String> {
25        let conn = Connection::open(path).map_err(|e| e.to_string())?;
26
27        // WAL mode + relaxed fsync on the hot path. `journal_mode` is a
28        // PRAGMA that returns the new mode as a row, so we use `query_row`.
29        let mode: String = conn
30            .query_row("PRAGMA journal_mode=WAL", [], |r| r.get(0))
31            .map_err(|e| e.to_string())?;
32        conn.pragma_update(None, "synchronous", &"NORMAL")
33            .map_err(|e| e.to_string())?;
34        info!(path = %path, journal_mode = %mode, "agnt-store opened");
35
36        conn.execute(
37            "CREATE TABLE IF NOT EXISTS messages (
38                session TEXT NOT NULL,
39                idx     INTEGER NOT NULL,
40                json    TEXT NOT NULL,
41                PRIMARY KEY (session, idx)
42            )",
43            [],
44        )
45        .map_err(|e| e.to_string())?;
46        conn.execute(
47            "CREATE TABLE IF NOT EXISTS tool_calls (
48                session     TEXT NOT NULL,
49                ts          INTEGER NOT NULL,
50                name        TEXT NOT NULL,
51                args        TEXT NOT NULL,
52                result      TEXT NOT NULL,
53                duration_us INTEGER NOT NULL
54            )",
55            [],
56        )
57        .map_err(|e| e.to_string())?;
58        conn.execute(
59            "CREATE TABLE IF NOT EXISTS usage (
60                session           TEXT NOT NULL,
61                message_idx       INTEGER NOT NULL,
62                prompt_tokens     INTEGER,
63                completion_tokens INTEGER,
64                total_tokens      INTEGER,
65                PRIMARY KEY (session, message_idx)
66            )",
67            [],
68        )
69        .map_err(|e| e.to_string())?;
70        Ok(Self {
71            conn: Mutex::new(conn),
72        })
73    }
74
75    fn lock(&self) -> Result<std::sync::MutexGuard<'_, Connection>, String> {
76        self.conn
77            .lock()
78            .map_err(|e| format!("store mutex poisoned: {}", e))
79    }
80
81    /// Returns the current `journal_mode` PRAGMA value (e.g. "wal").
82    /// Primarily used by tests to confirm WAL is active.
83    pub fn journal_mode(&self) -> Result<String, String> {
84        let conn = self.lock()?;
85        conn.query_row("PRAGMA journal_mode", [], |r| r.get::<_, String>(0))
86            .map_err(|e| e.to_string())
87    }
88
89    pub fn log_tool(
90        &self,
91        session: &str,
92        name: &str,
93        args: &str,
94        result: &str,
95        duration_us: u64,
96    ) -> Result<(), String> {
97        let ts = std::time::SystemTime::now()
98            .duration_since(std::time::UNIX_EPOCH)
99            .map(|d| d.as_secs() as i64)
100            .unwrap_or(0);
101        let conn = self.lock()?;
102        let mut stmt = conn
103            .prepare_cached(
104                "INSERT INTO tool_calls (session, ts, name, args, result, duration_us)
105                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
106            )
107            .map_err(|e| e.to_string())?;
108        stmt.execute(params![session, ts, name, args, result, duration_us as i64])
109            .map_err(|e| e.to_string())?;
110        Ok(())
111    }
112
113    pub fn load(&self, session: &str) -> Result<Vec<Message>, String> {
114        let conn = self.lock()?;
115        let mut stmt = conn
116            .prepare_cached("SELECT json FROM messages WHERE session = ?1 ORDER BY idx")
117            .map_err(|e| e.to_string())?;
118        let rows = stmt
119            .query_map(params![session], |r| r.get::<_, String>(0))
120            .map_err(|e| e.to_string())?;
121        let mut out = Vec::new();
122        for r in rows {
123            let s = r.map_err(|e| e.to_string())?;
124            let m: Message = serde_json::from_str(&s).map_err(|e| e.to_string())?;
125            out.push(m);
126        }
127        Ok(out)
128    }
129
130    /// Append a single message. Uses a single INSERT…SELECT to compute the
131    /// next `idx` in one roundtrip instead of a separate SELECT + INSERT.
132    pub fn append(&self, session: &str, msg: &Message) -> Result<(), String> {
133        let json = serde_json::to_string(msg).map_err(|e| e.to_string())?;
134        let conn = self.lock()?;
135        let mut stmt = conn
136            .prepare_cached(
137                "INSERT INTO messages (session, idx, json)
138                 SELECT ?1, COALESCE(MAX(idx), -1) + 1, ?2
139                 FROM messages
140                 WHERE session = ?1",
141            )
142            .map_err(|e| e.to_string())?;
143        stmt.execute(params![session, json])
144            .map_err(|e| e.to_string())?;
145        Ok(())
146    }
147
148    /// Append many messages in a single transaction. Reduces fsync cost
149    /// from 2+N per turn down to 1.
150    #[tracing::instrument(skip(self, messages), fields(session = %session, count = messages.len()))]
151    pub fn append_many(&self, session: &str, messages: &[Message]) -> Result<(), String> {
152        if messages.is_empty() {
153            return Ok(());
154        }
155        // Pre-serialize outside the lock.
156        let jsons: Vec<String> = messages
157            .iter()
158            .map(serde_json::to_string)
159            .collect::<Result<_, _>>()
160            .map_err(|e| e.to_string())?;
161
162        let mut conn = self.lock()?;
163        let tx = conn.transaction().map_err(|e| e.to_string())?;
164        {
165            // Find starting idx once.
166            let mut next: i64 = tx
167                .query_row(
168                    "SELECT COALESCE(MAX(idx), -1) + 1 FROM messages WHERE session = ?1",
169                    params![session],
170                    |r| r.get(0),
171                )
172                .map_err(|e| e.to_string())?;
173            let mut stmt = tx
174                .prepare_cached(
175                    "INSERT INTO messages (session, idx, json) VALUES (?1, ?2, ?3)",
176                )
177                .map_err(|e| e.to_string())?;
178            for json in &jsons {
179                stmt.execute(params![session, next, json])
180                    .map_err(|e| e.to_string())?;
181                next += 1;
182            }
183        }
184        tx.commit().map_err(|e| e.to_string())?;
185        Ok(())
186    }
187
188    /// Run an arbitrary closure inside a single BEGIN/COMMIT transaction.
189    /// The closure receives a borrowed `rusqlite::Transaction` and can issue
190    /// multiple writes that will all commit together (or roll back on error).
191    pub fn with_transaction<F, T>(&self, f: F) -> Result<T, String>
192    where
193        F: FnOnce(&rusqlite::Transaction<'_>) -> Result<T, String>,
194    {
195        let mut conn = self.lock()?;
196        let tx = conn.transaction().map_err(|e| e.to_string())?;
197        let out = f(&tx)?;
198        tx.commit().map_err(|e| e.to_string())?;
199        Ok(out)
200    }
201
202    #[tracing::instrument(skip(self), fields(session = %session))]
203    pub fn clear(&self, session: &str) -> Result<(), String> {
204        debug!("clearing session");
205        let conn = self.lock()?;
206        {
207            let mut stmt = conn
208                .prepare_cached("DELETE FROM messages WHERE session = ?1")
209                .map_err(|e| e.to_string())?;
210            stmt.execute(params![session]).map_err(|e| e.to_string())?;
211        }
212        {
213            let mut stmt = conn
214                .prepare_cached("DELETE FROM tool_calls WHERE session = ?1")
215                .map_err(|e| e.to_string())?;
216            stmt.execute(params![session]).map_err(|e| e.to_string())?;
217        }
218        {
219            let mut stmt = conn
220                .prepare_cached("DELETE FROM usage WHERE session = ?1")
221                .map_err(|e| e.to_string())?;
222            stmt.execute(params![session]).map_err(|e| e.to_string())?;
223        }
224        Ok(())
225    }
226
227    /// Per-tool latency stats for a session: (name, count, avg_us, max_us).
228    pub fn stats(&self, session: &str) -> Result<Vec<(String, i64, i64, i64)>, String> {
229        let conn = self.lock()?;
230        let mut stmt = conn
231            .prepare_cached(
232                "SELECT name, COUNT(*), CAST(AVG(duration_us) AS INTEGER), MAX(duration_us)
233                 FROM tool_calls
234                 WHERE session = ?1
235                 GROUP BY name
236                 ORDER BY COUNT(*) DESC",
237            )
238            .map_err(|e| e.to_string())?;
239        let rows = stmt
240            .query_map(params![session], |r| {
241                Ok((
242                    r.get::<_, String>(0)?,
243                    r.get::<_, i64>(1)?,
244                    r.get::<_, i64>(2)?,
245                    r.get::<_, i64>(3)?,
246                ))
247            })
248            .map_err(|e| e.to_string())?;
249        rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())
250    }
251
252    /// Record token usage for a single message in a session.
253    pub fn log_usage(
254        &self,
255        session: &str,
256        message_idx: i64,
257        prompt: u32,
258        completion: u32,
259    ) -> Result<(), String> {
260        let total = prompt as i64 + completion as i64;
261        let conn = self.lock()?;
262        let mut stmt = conn
263            .prepare_cached(
264                "INSERT OR REPLACE INTO usage
265                    (session, message_idx, prompt_tokens, completion_tokens, total_tokens)
266                 VALUES (?1, ?2, ?3, ?4, ?5)",
267            )
268            .map_err(|e| e.to_string())?;
269        stmt.execute(params![session, message_idx, prompt as i64, completion as i64, total])
270            .map_err(|e| e.to_string())?;
271        Ok(())
272    }
273
274    /// Sum token usage across a session: `(prompt_sum, completion_sum, total_sum)`.
275    pub fn usage_total(&self, session: &str) -> Result<(i64, i64, i64), String> {
276        let conn = self.lock()?;
277        let mut stmt = conn
278            .prepare_cached(
279                "SELECT
280                    COALESCE(SUM(prompt_tokens), 0),
281                    COALESCE(SUM(completion_tokens), 0),
282                    COALESCE(SUM(total_tokens), 0)
283                 FROM usage
284                 WHERE session = ?1",
285            )
286            .map_err(|e| e.to_string())?;
287        stmt.query_row(params![session], |r| {
288            Ok((r.get::<_, i64>(0)?, r.get::<_, i64>(1)?, r.get::<_, i64>(2)?))
289        })
290        .map_err(|e| e.to_string())
291    }
292}
293
294impl MessageStore for Store {
295    fn load(&self, session: &str) -> Result<Vec<Message>, StoreError> {
296        Store::load(self, session).map_err(io_err)
297    }
298
299    fn append(&self, session: &str, message: &Message) -> Result<(), StoreError> {
300        Store::append(self, session, message).map_err(io_err)
301    }
302
303    fn log_tool(&self, session: &str, log: &ToolLog<'_>) -> Result<(), StoreError> {
304        Store::log_tool(self, session, log.name, log.args, log.result, log.duration_us)
305            .map_err(io_err)
306    }
307
308    fn clear(&self, session: &str) -> Result<(), StoreError> {
309        Store::clear(self, session).map_err(io_err)
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use agnt_core::Message;

    /// Builds a database path in the system temp directory from the test
    /// name, the process id and the current time in nanoseconds.
    fn tmp_path(name: &str) -> String {
        let dir = std::env::temp_dir();
        let pid = std::process::id();
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        dir.join(format!("agnt-store-{}-{}-{}.db", name, pid, nanos))
            .to_string_lossy()
            .into_owned()
    }

    /// Temporary database that cleans up after itself.
    ///
    /// On drop it removes the database file and the `-wal` / `-shm` sidecar
    /// files SQLite keeps next to a WAL-mode database. Because this runs from
    /// `Drop`, cleanup also happens when an assertion panics.
    ///
    /// Declare the `TempDb` *before* the `Store` opened from it: locals drop
    /// in reverse order, so the connection is closed before the files are
    /// deleted.
    struct TempDb {
        path: String,
    }

    impl TempDb {
        /// Reserves a fresh path; the file itself is created by `open`.
        fn new(name: &str) -> Self {
            Self {
                path: tmp_path(name),
            }
        }

        /// Opens a `Store` on this database, panicking on failure.
        fn open(&self) -> Store {
            Store::open(&self.path).unwrap()
        }
    }

    impl Drop for TempDb {
        fn drop(&mut self) {
            for suffix in &["", "-wal", "-shm"] {
                // Best effort: a sidecar may already be gone.
                let _ = std::fs::remove_file(format!("{}{}", self.path, suffix));
            }
        }
    }

    /// Minimal user message with the given text content.
    fn user(content: &str) -> Message {
        Message {
            role: "user".into(),
            content: Some(content.into()),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }
    }

    #[test]
    fn wal_mode_is_active() {
        let db = TempDb::new("wal");
        let store = db.open();
        let mode = store.journal_mode().unwrap().to_lowercase();
        assert_eq!(mode, "wal", "expected WAL journal mode, got {}", mode);
    }

    #[test]
    fn append_and_load_roundtrip() {
        let db = TempDb::new("append");
        let store = db.open();
        store.append("s1", &user("hello")).unwrap();
        store.append("s1", &user("world")).unwrap();
        let msgs = store.load("s1").unwrap();
        assert_eq!(msgs.len(), 2);
    }

    #[test]
    fn append_many_batches_in_one_tx() {
        let db = TempDb::new("batch");
        let store = db.open();
        let batch = vec![user("a"), user("b"), user("c")];
        store.append_many("s1", &batch).unwrap();
        // Append another single message — idx must continue after the batch.
        store.append("s1", &user("d")).unwrap();
        let msgs = store.load("s1").unwrap();
        assert_eq!(msgs.len(), 4);
    }

    #[test]
    fn append_many_empty_is_noop() {
        let db = TempDb::new("empty");
        let store = db.open();
        store.append_many("s1", &[]).unwrap();
        assert!(store.load("s1").unwrap().is_empty());
    }

    #[test]
    fn with_transaction_commits() {
        let db = TempDb::new("tx");
        let store = db.open();
        store
            .with_transaction(|tx| {
                tx.execute(
                    "INSERT INTO messages (session, idx, json) VALUES (?1, ?2, ?3)",
                    params!["s1", 0i64, "{\"role\":\"user\",\"content\":\"hi\"}"],
                )
                .map_err(|e| e.to_string())?;
                Ok(())
            })
            .unwrap();
        assert_eq!(store.load("s1").unwrap().len(), 1);
    }

    #[test]
    fn with_transaction_rolls_back_on_err() {
        let db = TempDb::new("rollback");
        let store = db.open();
        let res: Result<(), String> = store.with_transaction(|tx| {
            tx.execute(
                "INSERT INTO messages (session, idx, json) VALUES (?1, ?2, ?3)",
                params!["s1", 0i64, "{\"role\":\"user\",\"content\":\"hi\"}"],
            )
            .map_err(|e| e.to_string())?;
            Err("boom".to_string())
        });
        assert!(res.is_err());
        assert!(store.load("s1").unwrap().is_empty());
    }

    #[test]
    fn log_tool_and_stats() {
        let db = TempDb::new("tool");
        let store = db.open();
        store.log_tool("s1", "fs_read", "{}", "ok", 100).unwrap();
        store.log_tool("s1", "fs_read", "{}", "ok", 300).unwrap();
        store.log_tool("s1", "http", "{}", "ok", 500).unwrap();
        let stats = store.stats("s1").unwrap();
        assert_eq!(stats.len(), 2);
        assert_eq!(stats[0].0, "fs_read");
        assert_eq!(stats[0].1, 2);
    }

    #[test]
    fn usage_log_and_total() {
        let db = TempDb::new("usage");
        let store = db.open();
        store.log_usage("s1", 0, 100, 50).unwrap();
        store.log_usage("s1", 1, 200, 80).unwrap();
        let (p, c, t) = store.usage_total("s1").unwrap();
        assert_eq!(p, 300);
        assert_eq!(c, 130);
        assert_eq!(t, 430);

        // Different session isolated.
        let (p2, c2, t2) = store.usage_total("s2").unwrap();
        assert_eq!((p2, c2, t2), (0, 0, 0));
    }

    #[test]
    fn clear_wipes_usage_too() {
        let db = TempDb::new("clear");
        let store = db.open();
        store.append("s1", &user("a")).unwrap();
        store.log_tool("s1", "t", "{}", "ok", 1).unwrap();
        store.log_usage("s1", 0, 10, 20).unwrap();
        store.clear("s1").unwrap();
        assert!(store.load("s1").unwrap().is_empty());
        assert_eq!(store.usage_total("s1").unwrap(), (0, 0, 0));
        assert!(store.stats("s1").unwrap().is_empty());
    }
}
461}