1use agnt_core::{Message, MessageStore, StoreError, ToolLog};
2use rusqlite::{params, Connection};
3use std::sync::Mutex;
4use tracing::{debug, info};
5
6fn io_err(e: impl std::fmt::Display) -> StoreError {
7 StoreError::Io(e.to_string())
8}
9
10pub struct Store {
20 conn: Mutex<Connection>,
21}
22
23impl Store {
24 pub fn open(path: &str) -> Result<Self, String> {
25 let conn = Connection::open(path).map_err(|e| e.to_string())?;
26
27 let mode: String = conn
30 .query_row("PRAGMA journal_mode=WAL", [], |r| r.get(0))
31 .map_err(|e| e.to_string())?;
32 conn.pragma_update(None, "synchronous", &"NORMAL")
33 .map_err(|e| e.to_string())?;
34 info!(path = %path, journal_mode = %mode, "agnt-store opened");
35
36 conn.execute(
37 "CREATE TABLE IF NOT EXISTS messages (
38 session TEXT NOT NULL,
39 idx INTEGER NOT NULL,
40 json TEXT NOT NULL,
41 PRIMARY KEY (session, idx)
42 )",
43 [],
44 )
45 .map_err(|e| e.to_string())?;
46 conn.execute(
47 "CREATE TABLE IF NOT EXISTS tool_calls (
48 session TEXT NOT NULL,
49 ts INTEGER NOT NULL,
50 name TEXT NOT NULL,
51 args TEXT NOT NULL,
52 result TEXT NOT NULL,
53 duration_us INTEGER NOT NULL
54 )",
55 [],
56 )
57 .map_err(|e| e.to_string())?;
58 conn.execute(
59 "CREATE TABLE IF NOT EXISTS usage (
60 session TEXT NOT NULL,
61 message_idx INTEGER NOT NULL,
62 prompt_tokens INTEGER,
63 completion_tokens INTEGER,
64 total_tokens INTEGER,
65 PRIMARY KEY (session, message_idx)
66 )",
67 [],
68 )
69 .map_err(|e| e.to_string())?;
70 Ok(Self {
71 conn: Mutex::new(conn),
72 })
73 }
74
75 fn lock(&self) -> Result<std::sync::MutexGuard<'_, Connection>, String> {
76 self.conn
77 .lock()
78 .map_err(|e| format!("store mutex poisoned: {}", e))
79 }
80
81 pub fn journal_mode(&self) -> Result<String, String> {
84 let conn = self.lock()?;
85 conn.query_row("PRAGMA journal_mode", [], |r| r.get::<_, String>(0))
86 .map_err(|e| e.to_string())
87 }
88
89 pub fn log_tool(
90 &self,
91 session: &str,
92 name: &str,
93 args: &str,
94 result: &str,
95 duration_us: u64,
96 ) -> Result<(), String> {
97 let ts = std::time::SystemTime::now()
98 .duration_since(std::time::UNIX_EPOCH)
99 .map(|d| d.as_secs() as i64)
100 .unwrap_or(0);
101 let conn = self.lock()?;
102 let mut stmt = conn
103 .prepare_cached(
104 "INSERT INTO tool_calls (session, ts, name, args, result, duration_us)
105 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
106 )
107 .map_err(|e| e.to_string())?;
108 stmt.execute(params![session, ts, name, args, result, duration_us as i64])
109 .map_err(|e| e.to_string())?;
110 Ok(())
111 }
112
113 pub fn load(&self, session: &str) -> Result<Vec<Message>, String> {
114 let conn = self.lock()?;
115 let mut stmt = conn
116 .prepare_cached("SELECT json FROM messages WHERE session = ?1 ORDER BY idx")
117 .map_err(|e| e.to_string())?;
118 let rows = stmt
119 .query_map(params![session], |r| r.get::<_, String>(0))
120 .map_err(|e| e.to_string())?;
121 let mut out = Vec::new();
122 for r in rows {
123 let s = r.map_err(|e| e.to_string())?;
124 let m: Message = serde_json::from_str(&s).map_err(|e| e.to_string())?;
125 out.push(m);
126 }
127 Ok(out)
128 }
129
130 pub fn append(&self, session: &str, msg: &Message) -> Result<(), String> {
133 let json = serde_json::to_string(msg).map_err(|e| e.to_string())?;
134 let conn = self.lock()?;
135 let mut stmt = conn
136 .prepare_cached(
137 "INSERT INTO messages (session, idx, json)
138 SELECT ?1, COALESCE(MAX(idx), -1) + 1, ?2
139 FROM messages
140 WHERE session = ?1",
141 )
142 .map_err(|e| e.to_string())?;
143 stmt.execute(params![session, json])
144 .map_err(|e| e.to_string())?;
145 Ok(())
146 }
147
148 #[tracing::instrument(skip(self, messages), fields(session = %session, count = messages.len()))]
151 pub fn append_many(&self, session: &str, messages: &[Message]) -> Result<(), String> {
152 if messages.is_empty() {
153 return Ok(());
154 }
155 let jsons: Vec<String> = messages
157 .iter()
158 .map(serde_json::to_string)
159 .collect::<Result<_, _>>()
160 .map_err(|e| e.to_string())?;
161
162 let mut conn = self.lock()?;
163 let tx = conn.transaction().map_err(|e| e.to_string())?;
164 {
165 let mut next: i64 = tx
167 .query_row(
168 "SELECT COALESCE(MAX(idx), -1) + 1 FROM messages WHERE session = ?1",
169 params![session],
170 |r| r.get(0),
171 )
172 .map_err(|e| e.to_string())?;
173 let mut stmt = tx
174 .prepare_cached(
175 "INSERT INTO messages (session, idx, json) VALUES (?1, ?2, ?3)",
176 )
177 .map_err(|e| e.to_string())?;
178 for json in &jsons {
179 stmt.execute(params![session, next, json])
180 .map_err(|e| e.to_string())?;
181 next += 1;
182 }
183 }
184 tx.commit().map_err(|e| e.to_string())?;
185 Ok(())
186 }
187
188 pub fn with_transaction<F, T>(&self, f: F) -> Result<T, String>
192 where
193 F: FnOnce(&rusqlite::Transaction<'_>) -> Result<T, String>,
194 {
195 let mut conn = self.lock()?;
196 let tx = conn.transaction().map_err(|e| e.to_string())?;
197 let out = f(&tx)?;
198 tx.commit().map_err(|e| e.to_string())?;
199 Ok(out)
200 }
201
202 #[tracing::instrument(skip(self), fields(session = %session))]
203 pub fn clear(&self, session: &str) -> Result<(), String> {
204 debug!("clearing session");
205 let conn = self.lock()?;
206 {
207 let mut stmt = conn
208 .prepare_cached("DELETE FROM messages WHERE session = ?1")
209 .map_err(|e| e.to_string())?;
210 stmt.execute(params![session]).map_err(|e| e.to_string())?;
211 }
212 {
213 let mut stmt = conn
214 .prepare_cached("DELETE FROM tool_calls WHERE session = ?1")
215 .map_err(|e| e.to_string())?;
216 stmt.execute(params![session]).map_err(|e| e.to_string())?;
217 }
218 {
219 let mut stmt = conn
220 .prepare_cached("DELETE FROM usage WHERE session = ?1")
221 .map_err(|e| e.to_string())?;
222 stmt.execute(params![session]).map_err(|e| e.to_string())?;
223 }
224 Ok(())
225 }
226
227 pub fn stats(&self, session: &str) -> Result<Vec<(String, i64, i64, i64)>, String> {
229 let conn = self.lock()?;
230 let mut stmt = conn
231 .prepare_cached(
232 "SELECT name, COUNT(*), CAST(AVG(duration_us) AS INTEGER), MAX(duration_us)
233 FROM tool_calls
234 WHERE session = ?1
235 GROUP BY name
236 ORDER BY COUNT(*) DESC",
237 )
238 .map_err(|e| e.to_string())?;
239 let rows = stmt
240 .query_map(params![session], |r| {
241 Ok((
242 r.get::<_, String>(0)?,
243 r.get::<_, i64>(1)?,
244 r.get::<_, i64>(2)?,
245 r.get::<_, i64>(3)?,
246 ))
247 })
248 .map_err(|e| e.to_string())?;
249 rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())
250 }
251
252 pub fn log_usage(
254 &self,
255 session: &str,
256 message_idx: i64,
257 prompt: u32,
258 completion: u32,
259 ) -> Result<(), String> {
260 let total = prompt as i64 + completion as i64;
261 let conn = self.lock()?;
262 let mut stmt = conn
263 .prepare_cached(
264 "INSERT OR REPLACE INTO usage
265 (session, message_idx, prompt_tokens, completion_tokens, total_tokens)
266 VALUES (?1, ?2, ?3, ?4, ?5)",
267 )
268 .map_err(|e| e.to_string())?;
269 stmt.execute(params![session, message_idx, prompt as i64, completion as i64, total])
270 .map_err(|e| e.to_string())?;
271 Ok(())
272 }
273
274 pub fn usage_total(&self, session: &str) -> Result<(i64, i64, i64), String> {
276 let conn = self.lock()?;
277 let mut stmt = conn
278 .prepare_cached(
279 "SELECT
280 COALESCE(SUM(prompt_tokens), 0),
281 COALESCE(SUM(completion_tokens), 0),
282 COALESCE(SUM(total_tokens), 0)
283 FROM usage
284 WHERE session = ?1",
285 )
286 .map_err(|e| e.to_string())?;
287 stmt.query_row(params![session], |r| {
288 Ok((r.get::<_, i64>(0)?, r.get::<_, i64>(1)?, r.get::<_, i64>(2)?))
289 })
290 .map_err(|e| e.to_string())
291 }
292}
293
294impl MessageStore for Store {
295 fn load(&self, session: &str) -> Result<Vec<Message>, StoreError> {
296 Store::load(self, session).map_err(io_err)
297 }
298
299 fn append(&self, session: &str, message: &Message) -> Result<(), StoreError> {
300 Store::append(self, session, message).map_err(io_err)
301 }
302
303 fn log_tool(&self, session: &str, log: &ToolLog<'_>) -> Result<(), StoreError> {
304 Store::log_tool(self, session, log.name, log.args, log.result, log.duration_us)
305 .map_err(io_err)
306 }
307
308 fn clear(&self, session: &str) -> Result<(), StoreError> {
309 Store::clear(self, session).map_err(io_err)
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use agnt_core::Message;

    /// A uniquely named database path in the system temp directory. On drop it
    /// deletes the database file and its WAL sidecar files (`-wal`, `-shm`).
    ///
    /// Bind the guard *before* the `Store` that uses it. Locals drop in reverse
    /// declaration order, so the store and its connection are closed first. The
    /// files are also removed when an assertion panics.
    struct TempDb {
        path: String,
    }

    impl TempDb {
        fn new(name: &str) -> Self {
            let dir = std::env::temp_dir();
            let pid = std::process::id();
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_nanos())
                .unwrap_or(0);
            let path = dir
                .join(format!("agnt-store-{}-{}-{}.db", name, pid, nanos))
                .to_string_lossy()
                .into_owned();
            TempDb { path }
        }

        fn open(&self) -> Store {
            Store::open(&self.path).unwrap()
        }
    }

    impl Drop for TempDb {
        fn drop(&mut self) {
            // Best effort: a sidecar may never have been created, or SQLite may
            // already have removed it when the connection closed.
            for suffix in ["", "-wal", "-shm"] {
                let _ = std::fs::remove_file(format!("{}{}", self.path, suffix));
            }
        }
    }

    fn user(content: &str) -> Message {
        Message {
            role: "user".into(),
            content: Some(content.into()),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }
    }

    #[test]
    fn wal_mode_is_active() {
        let db = TempDb::new("wal");
        let store = db.open();
        let mode = store.journal_mode().unwrap().to_lowercase();
        assert_eq!(mode, "wal", "expected WAL journal mode, got {}", mode);
    }

    #[test]
    fn append_and_load_roundtrip() {
        let db = TempDb::new("append");
        let store = db.open();
        store.append("s1", &user("hello")).unwrap();
        store.append("s1", &user("world")).unwrap();
        let msgs = store.load("s1").unwrap();
        assert_eq!(msgs.len(), 2);
    }

    #[test]
    fn append_many_batches_in_one_tx() {
        let db = TempDb::new("batch");
        let store = db.open();
        let batch = vec![user("a"), user("b"), user("c")];
        store.append_many("s1", &batch).unwrap();
        store.append("s1", &user("d")).unwrap();
        let msgs = store.load("s1").unwrap();
        assert_eq!(msgs.len(), 4);
    }

    #[test]
    fn append_many_empty_is_noop() {
        let db = TempDb::new("empty");
        let store = db.open();
        store.append_many("s1", &[]).unwrap();
        assert!(store.load("s1").unwrap().is_empty());
    }

    #[test]
    fn with_transaction_commits() {
        let db = TempDb::new("tx");
        let store = db.open();
        store
            .with_transaction(|tx| {
                tx.execute(
                    "INSERT INTO messages (session, idx, json) VALUES (?1, ?2, ?3)",
                    params!["s1", 0i64, "{\"role\":\"user\",\"content\":\"hi\"}"],
                )
                .map_err(|e| e.to_string())?;
                Ok(())
            })
            .unwrap();
        assert_eq!(store.load("s1").unwrap().len(), 1);
    }

    #[test]
    fn with_transaction_rolls_back_on_err() {
        let db = TempDb::new("rollback");
        let store = db.open();
        let res: Result<(), String> = store.with_transaction(|tx| {
            tx.execute(
                "INSERT INTO messages (session, idx, json) VALUES (?1, ?2, ?3)",
                params!["s1", 0i64, "{\"role\":\"user\",\"content\":\"hi\"}"],
            )
            .map_err(|e| e.to_string())?;
            Err("boom".to_string())
        });
        assert!(res.is_err());
        assert!(store.load("s1").unwrap().is_empty());
    }

    #[test]
    fn log_tool_and_stats() {
        let db = TempDb::new("tool");
        let store = db.open();
        store.log_tool("s1", "fs_read", "{}", "ok", 100).unwrap();
        store.log_tool("s1", "fs_read", "{}", "ok", 300).unwrap();
        store.log_tool("s1", "http", "{}", "ok", 500).unwrap();
        let stats = store.stats("s1").unwrap();
        assert_eq!(stats.len(), 2);
        assert_eq!(stats[0].0, "fs_read");
        assert_eq!(stats[0].1, 2);
    }

    #[test]
    fn usage_log_and_total() {
        let db = TempDb::new("usage");
        let store = db.open();
        store.log_usage("s1", 0, 100, 50).unwrap();
        store.log_usage("s1", 1, 200, 80).unwrap();
        let (p, c, t) = store.usage_total("s1").unwrap();
        assert_eq!(p, 300);
        assert_eq!(c, 130);
        assert_eq!(t, 430);

        // A session with no usage rows sums to zero rather than erroring.
        let (p2, c2, t2) = store.usage_total("s2").unwrap();
        assert_eq!((p2, c2, t2), (0, 0, 0));
    }

    #[test]
    fn clear_wipes_usage_too() {
        let db = TempDb::new("clear");
        let store = db.open();
        store.append("s1", &user("a")).unwrap();
        store.log_tool("s1", "t", "{}", "ok", 1).unwrap();
        store.log_usage("s1", 0, 10, 20).unwrap();
        store.clear("s1").unwrap();
        assert!(store.load("s1").unwrap().is_empty());
        assert_eq!(store.usage_total("s1").unwrap(), (0, 0, 0));
        assert!(store.stats("s1").unwrap().is_empty());
    }
}