Skip to main content

garudust_memory/
session_db.rs

1use std::path::Path;
2use std::sync::{Arc, Mutex};
3
4use rusqlite::{params, Connection};
5
6use crate::migrations;
7
8pub struct SessionDb {
9    conn: Arc<Mutex<Connection>>,
10}
11
12impl SessionDb {
13    pub fn open(home_dir: &Path) -> anyhow::Result<Self> {
14        let db_path = home_dir.join("state.db");
15        if let Some(parent) = db_path.parent() {
16            std::fs::create_dir_all(parent)?;
17        }
18        let conn = Connection::open(&db_path)?;
19        migrations::run(&conn)?;
20        Ok(Self {
21            conn: Arc::new(Mutex::new(conn)),
22        })
23    }
24
25    #[allow(clippy::too_many_arguments)]
26    pub fn save_session(
27        &self,
28        id: &str,
29        source: &str,
30        model: &str,
31        started_at: f64,
32        ended_at: f64,
33        input_tokens: u32,
34        output_tokens: u32,
35        message_count: u32,
36    ) -> anyhow::Result<()> {
37        let conn = self.conn.lock().unwrap();
38        conn.execute(
39            "INSERT OR REPLACE INTO sessions
40             (id, source, model, started_at, ended_at, input_tokens, output_tokens, message_count)
41             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
42            params![
43                id,
44                source,
45                model,
46                started_at,
47                ended_at,
48                input_tokens,
49                output_tokens,
50                message_count
51            ],
52        )?;
53        Ok(())
54    }
55
56    pub fn append_messages(
57        &self,
58        session_id: &str,
59        messages: &[(String, String, String, f64)], // (id, role, content_json, created_at)
60    ) -> anyhow::Result<()> {
61        let mut conn = self.conn.lock().unwrap();
62        let tx = conn.transaction()?;
63        for (id, role, content, created_at) in messages {
64            let affected = tx.execute(
65                "INSERT OR IGNORE INTO messages (id, session_id, role, content, created_at)
66                 VALUES (?1, ?2, ?3, ?4, ?5)",
67                params![id, session_id, role, content, created_at],
68            )?;
69            if affected > 0 {
70                let rowid = tx.last_insert_rowid();
71                tx.execute(
72                    "INSERT INTO messages_fts(rowid, content) VALUES (?1, ?2)",
73                    params![rowid, content],
74                )?;
75            }
76        }
77        tx.commit()?;
78        Ok(())
79    }
80
81    pub fn search(&self, query: &str, limit: usize) -> anyhow::Result<Vec<String>> {
82        let conn = self.conn.lock().unwrap();
83        let mut stmt =
84            conn.prepare("SELECT content FROM messages_fts WHERE messages_fts MATCH ?1 LIMIT ?2")?;
85        let rows = stmt.query_map([query, &limit.to_string()], |row| row.get(0))?;
86        rows.collect::<Result<Vec<String>, _>>().map_err(Into::into)
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// A [`SessionDb`] stored in a uniquely named directory under the system
    /// temp dir. The directory is removed when the value is dropped, so test
    /// runs do not accumulate `garudust-test-*` directories.
    struct TempDb {
        // `Option` so `Drop` can close the connection before deleting the
        // directory that holds its database file.
        db: Option<SessionDb>,
        dir: std::path::PathBuf,
    }

    impl std::ops::Deref for TempDb {
        type Target = SessionDb;

        fn deref(&self) -> &SessionDb {
            self.db
                .as_ref()
                .expect("the database is only taken out in Drop")
        }
    }

    impl Drop for TempDb {
        fn drop(&mut self) {
            // Close the connection first, then clean up best-effort: a leftover
            // temp dir must never turn a passing test into a failure.
            self.db = None;
            let _ = std::fs::remove_dir_all(&self.dir);
        }
    }

    /// Opens a fresh on-disk database in its own temporary directory.
    fn open_temp_db() -> TempDb {
        let dir = std::env::temp_dir().join(format!("garudust-test-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&dir).unwrap();
        let db = SessionDb::open(&dir).unwrap();
        TempDb { db: Some(db), dir }
    }

    /// Builds an `(id, role, content, created_at)` user message with a random id.
    fn user_message(content: &str) -> (String, String, String, f64) {
        (
            uuid::Uuid::new_v4().to_string(),
            "user".to_string(),
            content.to_string(),
            0.0,
        )
    }

    #[test]
    fn save_and_retrieve_session() {
        let db = open_temp_db();
        db.save_session("s1", "test", "model-x", 1.0, 2.0, 10, 20, 3)
            .unwrap();

        // A second save with the same id must replace the row: no constraint
        // error and no second row.
        db.save_session("s1", "test", "model-x", 1.0, 3.0, 10, 25, 4)
            .unwrap();

        let conn = db.conn.lock().unwrap();
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM sessions WHERE id = ?1",
                ["s1"],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(count, 1);

        let (ended_at, output_tokens, message_count): (f64, i64, i64) = conn
            .query_row(
                "SELECT ended_at, output_tokens, message_count FROM sessions WHERE id = ?1",
                ["s1"],
                |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
            )
            .unwrap();
        assert_eq!(ended_at, 3.0);
        assert_eq!(output_tokens, 25);
        assert_eq!(message_count, 4);
    }

    #[test]
    fn append_and_search_messages() {
        let db = open_temp_db();
        db.save_session("s1", "test", "gpt", 0.0, 1.0, 0, 0, 1)
            .unwrap();
        db.append_messages("s1", &[user_message("hello garudust world")])
            .unwrap();

        let results = db.search("garudust", 10).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].contains("garudust"));
    }

    #[test]
    fn search_returns_empty_for_no_match() {
        let db = open_temp_db();
        db.save_session("s1", "test", "gpt", 0.0, 1.0, 0, 0, 1)
            .unwrap();
        db.append_messages("s1", &[user_message("hello world")])
            .unwrap();

        let results = db.search("zzznomatch", 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn duplicate_message_id_is_ignored() {
        let db = open_temp_db();
        db.save_session("s1", "test", "gpt", 0.0, 1.0, 0, 0, 1)
            .unwrap();
        let msg = (
            "fixed-id".to_string(),
            "user".to_string(),
            "unique content here".to_string(),
            0.0f64,
        );
        db.append_messages("s1", std::slice::from_ref(&msg))
            .unwrap();
        // Re-sending the same id must neither error nor index it twice.
        db.append_messages("s1", &[msg]).unwrap();

        let results = db.search("unique", 10).unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn search_respects_limit() {
        let db = open_temp_db();
        db.save_session("s1", "test", "gpt", 0.0, 1.0, 0, 0, 5)
            .unwrap();
        let messages: Vec<_> = (0..5)
            .map(|i| user_message(&format!("searchterm entry number {i}")))
            .collect();
        db.append_messages("s1", &messages).unwrap();

        let results = db.search("searchterm", 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn search_with_zero_limit_returns_nothing() {
        let db = open_temp_db();
        db.save_session("s1", "test", "gpt", 0.0, 1.0, 0, 0, 1)
            .unwrap();
        db.append_messages("s1", &[user_message("searchterm here")])
            .unwrap();

        assert!(db.search("searchterm", 0).unwrap().is_empty());
    }
}