1use std::path::Path;
2use std::sync::{Arc, Mutex};
3
4use rusqlite::{params, Connection};
5
6use crate::migrations;
7
8pub struct SessionDb {
9 conn: Arc<Mutex<Connection>>,
10}
11
12impl SessionDb {
13 pub fn open(home_dir: &Path) -> anyhow::Result<Self> {
14 let db_path = home_dir.join("state.db");
15 if let Some(parent) = db_path.parent() {
16 std::fs::create_dir_all(parent)?;
17 }
18 let conn = Connection::open(&db_path)?;
19 migrations::run(&conn)?;
20 Ok(Self {
21 conn: Arc::new(Mutex::new(conn)),
22 })
23 }
24
25 #[allow(clippy::too_many_arguments)]
26 pub fn save_session(
27 &self,
28 id: &str,
29 source: &str,
30 model: &str,
31 started_at: f64,
32 ended_at: f64,
33 input_tokens: u32,
34 output_tokens: u32,
35 message_count: u32,
36 ) -> anyhow::Result<()> {
37 let conn = self.conn.lock().unwrap();
38 conn.execute(
39 "INSERT OR REPLACE INTO sessions
40 (id, source, model, started_at, ended_at, input_tokens, output_tokens, message_count)
41 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
42 params![
43 id,
44 source,
45 model,
46 started_at,
47 ended_at,
48 input_tokens,
49 output_tokens,
50 message_count
51 ],
52 )?;
53 Ok(())
54 }
55
56 pub fn append_messages(
57 &self,
58 session_id: &str,
59 messages: &[(String, String, String, f64)], ) -> anyhow::Result<()> {
61 let mut conn = self.conn.lock().unwrap();
62 let tx = conn.transaction()?;
63 for (id, role, content, created_at) in messages {
64 let affected = tx.execute(
65 "INSERT OR IGNORE INTO messages (id, session_id, role, content, created_at)
66 VALUES (?1, ?2, ?3, ?4, ?5)",
67 params![id, session_id, role, content, created_at],
68 )?;
69 if affected > 0 {
70 let rowid = tx.last_insert_rowid();
71 tx.execute(
72 "INSERT INTO messages_fts(rowid, content) VALUES (?1, ?2)",
73 params![rowid, content],
74 )?;
75 }
76 }
77 tx.commit()?;
78 Ok(())
79 }
80
81 pub fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<String>> {
82 let conn = self.conn.lock().unwrap();
83 let mut stmt =
84 conn.prepare("SELECT content FROM messages_fts WHERE messages_fts MATCH ?1 LIMIT ?2")?;
85 let rows = stmt.query_map([query, &limit.to_string()], |row| row.get(0))?;
86 rows.collect::<Result<Vec<String>, _>>().map_err(Into::into)
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// A `SessionDb` backed by a unique temporary directory that is removed
    /// again when the guard goes out of scope.
    struct TempDb {
        /// `Some` for the whole lifetime of the guard; taken in `drop` so the
        /// connection is closed before the directory is deleted.
        inner: Option<SessionDb>,
        dir: std::path::PathBuf,
    }

    // Test-only convenience so the guard can be used wherever a `SessionDb`
    // is expected (`db.search(..)`, `db.conn`, ...).
    impl std::ops::Deref for TempDb {
        type Target = SessionDb;

        fn deref(&self) -> &SessionDb {
            self.inner
                .as_ref()
                .expect("database is present until the guard is dropped")
        }
    }

    impl Drop for TempDb {
        fn drop(&mut self) {
            // Close the connection first so the file can be removed on
            // platforms that refuse to delete open files.
            self.inner.take();
            // Best-effort cleanup: a leftover directory must not fail a test.
            let _ = std::fs::remove_dir_all(&self.dir);
        }
    }

    /// Opens a database in a fresh, uniquely named directory under the system
    /// temp dir (on disk, not in memory).
    fn open_temp_db() -> TempDb {
        let dir = std::env::temp_dir().join(format!("garudust-test-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&dir).unwrap();
        let db = SessionDb::open(&dir).unwrap();
        TempDb {
            inner: Some(db),
            dir,
        }
    }

    #[test]
    fn save_and_retrieve_session() {
        let db = open_temp_db();
        db.save_session("s1", "test", "model-x", 1.0, 2.0, 10, 20, 3)
            .unwrap();

        // Saving the same id again must succeed and store the new values.
        db.save_session("s1", "test", "model-x", 1.0, 3.0, 10, 25, 4)
            .unwrap();

        // Only the latest values are checked; the total row count for the id
        // depends on a uniqueness constraint defined in the migrations.
        let conn = db.conn.lock().unwrap();
        let matching: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM sessions
                 WHERE id = 's1' AND ended_at = 3.0 AND output_tokens = 25 AND message_count = 4",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(matching, 1);
    }

    #[test]
    fn append_and_search_messages() {
        let db = open_temp_db();
        db.save_session("s1", "test", "gpt", 0.0, 1.0, 0, 0, 1)
            .unwrap();

        let msg_id = uuid::Uuid::new_v4().to_string();
        db.append_messages(
            "s1",
            &[(msg_id, "user".into(), "hello garudust world".into(), 0.0)],
        )
        .unwrap();

        let results = db.search("garudust", 10).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].contains("garudust"));
    }

    #[test]
    fn search_returns_empty_for_no_match() {
        let db = open_temp_db();
        db.save_session("s1", "test", "gpt", 0.0, 1.0, 0, 0, 1)
            .unwrap();
        db.append_messages(
            "s1",
            &[(
                uuid::Uuid::new_v4().to_string(),
                "user".into(),
                "hello world".into(),
                0.0,
            )],
        )
        .unwrap();

        let results = db.search("zzznomatch", 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn duplicate_message_id_is_ignored() {
        let db = open_temp_db();
        db.save_session("s1", "test", "gpt", 0.0, 1.0, 0, 0, 1)
            .unwrap();
        let msg = (
            "fixed-id".to_string(),
            "user".to_string(),
            "unique content here".to_string(),
            0.0f64,
        );
        db.append_messages("s1", std::slice::from_ref(&msg))
            .unwrap();
        // Second insert with the same id must be a no-op, including for the
        // full-text index.
        db.append_messages("s1", &[msg]).unwrap();

        let results = db.search("unique", 10).unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn search_respects_limit() {
        let db = open_temp_db();
        db.save_session("s1", "test", "gpt", 0.0, 1.0, 0, 0, 5)
            .unwrap();
        let messages: Vec<_> = (0..5)
            .map(|i| {
                (
                    uuid::Uuid::new_v4().to_string(),
                    "user".to_string(),
                    format!("searchterm entry number {i}"),
                    0.0f64,
                )
            })
            .collect();
        db.append_messages("s1", &messages).unwrap();

        let results = db.search("searchterm", 3).unwrap();
        assert_eq!(results.len(), 3);
    }
}