Skip to main content

agtrace_index/
db.rs

1use rusqlite::Connection;
2use std::path::Path;
3
4use crate::{Result, queries, records::*, schema};
5
6pub struct Database {
7    conn: Connection,
8}
9
10impl Database {
11    pub fn open(db_path: &Path) -> Result<Self> {
12        let conn = Connection::open(db_path)?;
13        let db = Self { conn };
14        schema::init_schema(&db.conn)?;
15        Ok(db)
16    }
17
18    pub fn open_in_memory() -> Result<Self> {
19        let conn = Connection::open_in_memory()?;
20        let db = Self { conn };
21        schema::init_schema(&db.conn)?;
22        Ok(db)
23    }
24
25    // Project operations
26    pub fn insert_or_update_project(&self, project: &ProjectRecord) -> Result<()> {
27        queries::project::insert_or_update(&self.conn, project)
28    }
29
30    pub fn get_project(&self, hash: &str) -> Result<Option<ProjectRecord>> {
31        queries::project::get(&self.conn, hash)
32    }
33
34    pub fn list_projects(&self) -> Result<Vec<ProjectRecord>> {
35        queries::project::list(&self.conn)
36    }
37
38    pub fn count_sessions_for_project(&self, project_hash: &str) -> Result<usize> {
39        queries::project::count_sessions(&self.conn, project_hash)
40    }
41
42    // Session operations
43    pub fn insert_or_update_session(&self, session: &SessionRecord) -> Result<()> {
44        queries::session::insert_or_update(&self.conn, session)
45    }
46
47    pub fn get_session_by_id(&self, session_id: &str) -> Result<Option<SessionSummary>> {
48        queries::session::get_by_id(&self.conn, session_id)
49    }
50
51    pub fn list_sessions(
52        &self,
53        project_hash: Option<&agtrace_types::ProjectHash>,
54        provider: Option<&str>,
55        order: agtrace_types::SessionOrder,
56        limit: Option<usize>,
57        top_level_only: bool,
58    ) -> Result<Vec<SessionSummary>> {
59        queries::session::list(
60            &self.conn,
61            project_hash,
62            provider,
63            order,
64            limit,
65            top_level_only,
66        )
67    }
68
69    pub fn find_session_by_prefix(&self, prefix: &str) -> Result<Option<String>> {
70        queries::session::find_by_prefix(&self.conn, prefix)
71    }
72
73    pub fn get_child_sessions(&self, parent_session_id: &str) -> Result<Vec<SessionSummary>> {
74        queries::session::get_children(&self.conn, parent_session_id)
75    }
76
77    // Log file operations
78    pub fn insert_or_update_log_file(&self, log_file: &LogFileRecord) -> Result<()> {
79        queries::log_file::insert_or_update(&self.conn, log_file)
80    }
81
82    pub fn get_session_files(&self, session_id: &str) -> Result<Vec<LogFileRecord>> {
83        queries::log_file::get_session_files(&self.conn, session_id)
84    }
85
86    pub fn get_all_log_files(&self) -> Result<Vec<LogFileRecord>> {
87        queries::log_file::get_all(&self.conn)
88    }
89
90    // Utility operations
91    pub fn vacuum(&self) -> Result<()> {
92        self.conn.execute("VACUUM", [])?;
93        println!("Database vacuumed successfully");
94        Ok(())
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::SCHEMA_VERSION;
    use rusqlite::params;

    /// Hash used by every fixture project and session in this module.
    const PROJECT_HASH: &str = "abc123";

    /// Typed form of [`PROJECT_HASH`].
    fn project_hash() -> agtrace_types::ProjectHash {
        agtrace_types::ProjectHash::from(PROJECT_HASH)
    }

    /// The single project fixture shared by the tests.
    fn sample_project() -> ProjectRecord {
        ProjectRecord {
            hash: project_hash(),
            root_path: Some("/path/to/project".to_string()),
            last_scanned_at: Some("2025-12-10T10:00:00Z".to_string()),
        }
    }

    /// A valid top-level "claude" session in the fixture project with no end
    /// timestamp and no snippet; tests override fields with struct-update
    /// syntax where they need something different.
    fn sample_session(id: impl Into<String>, start_ts: impl Into<String>) -> SessionRecord {
        SessionRecord {
            id: id.into(),
            project_hash: project_hash(),
            repository_hash: None,
            provider: "claude".to_string(),
            start_ts: Some(start_ts.into()),
            end_ts: None,
            snippet: None,
            is_valid: true,
            parent_session_id: None,
            spawned_by: None,
        }
    }

    /// Fresh in-memory database that already contains [`sample_project`].
    fn db_with_project() -> Database {
        let db = Database::open_in_memory().unwrap();
        db.insert_or_update_project(&sample_project()).unwrap();
        db
    }

    /// Lists the fixture project's sessions with the default order, no
    /// provider filter, and child sessions included.
    fn list_project_sessions(db: &Database, limit: usize) -> Vec<SessionSummary> {
        db.list_sessions(
            Some(&project_hash()),
            None,
            agtrace_types::SessionOrder::default(),
            Some(limit),
            false,
        )
        .unwrap()
    }

    /// Reads `PRAGMA user_version` from the database's connection.
    fn user_version(db: &Database) -> i32 {
        db.conn
            .query_row("PRAGMA user_version", [], |row| row.get(0))
            .unwrap()
    }

    #[test]
    fn test_schema_initialization() {
        let db = Database::open_in_memory().unwrap();

        let projects = db.list_projects().unwrap();
        assert_eq!(projects.len(), 0);
    }

    #[test]
    fn test_insert_project() {
        let db = Database::open_in_memory().unwrap();

        db.insert_or_update_project(&sample_project()).unwrap();

        let retrieved = db.get_project(PROJECT_HASH).unwrap().unwrap();
        assert_eq!(retrieved.hash, project_hash());
        assert_eq!(retrieved.root_path, Some("/path/to/project".to_string()));
    }

    #[test]
    fn test_insert_session_with_fk() {
        let db = db_with_project();

        let session = SessionRecord {
            end_ts: Some("2025-12-10T10:15:00Z".to_string()),
            snippet: Some("Test session".to_string()),
            ..sample_session("session-001", "2025-12-10T10:05:00Z")
        };

        db.insert_or_update_session(&session).unwrap();

        let sessions = list_project_sessions(&db, 10);
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].id, "session-001");
        assert_eq!(sessions[0].provider, "claude");
    }

    #[test]
    fn test_insert_log_file() {
        let db = db_with_project();

        let session = sample_session("session-001", "2025-12-10T10:05:00Z");
        db.insert_or_update_session(&session).unwrap();

        let log_file = LogFileRecord {
            path: "/path/to/log.jsonl".to_string(),
            session_id: "session-001".to_string(),
            role: "main".to_string(),
            file_size: Some(1024),
            mod_time: Some("2025-12-10T10:05:00Z".to_string()),
        };

        db.insert_or_update_log_file(&log_file).unwrap();

        let files = db.get_session_files("session-001").unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].path, "/path/to/log.jsonl");
        assert_eq!(files[0].role, "main");
    }

    #[test]
    fn test_list_sessions_query() {
        let db = db_with_project();

        for i in 1..=5 {
            let session = SessionRecord {
                snippet: Some(format!("Session {}", i)),
                ..sample_session(
                    format!("session-{:03}", i),
                    format!("2025-12-10T10:{:02}:00Z", i),
                )
            };
            db.insert_or_update_session(&session).unwrap();
        }

        // A limit above the row count returns everything...
        let sessions = list_project_sessions(&db, 10);
        assert_eq!(sessions.len(), 5);

        // ...while a smaller limit truncates the result.
        let sessions_limited = list_project_sessions(&db, 3);
        assert_eq!(sessions_limited.len(), 3);
    }

    #[test]
    fn test_count_sessions_for_project() {
        let db = db_with_project();

        for i in 1..=3 {
            let session = sample_session(
                format!("session-{:03}", i),
                format!("2025-12-10T10:{:02}:00Z", i),
            );
            db.insert_or_update_session(&session).unwrap();
        }

        let count = db.count_sessions_for_project(PROJECT_HASH).unwrap();
        assert_eq!(count, 3);
    }

    #[test]
    fn test_schema_version_set_on_init() {
        let db = Database::open_in_memory().unwrap();

        assert_eq!(user_version(&db), SCHEMA_VERSION);
    }

    #[test]
    fn test_schema_rebuild_on_version_mismatch() {
        let conn = Connection::open_in_memory().unwrap();

        // Simulate a database written by a different schema version.
        conn.execute_batch(
            r#"
            CREATE TABLE projects (hash TEXT PRIMARY KEY);
            CREATE TABLE sessions (id TEXT PRIMARY KEY);
            PRAGMA user_version = 999;
            "#,
        )
        .unwrap();

        conn.execute(
            "INSERT INTO projects (hash) VALUES (?1)",
            params!["old_data"],
        )
        .unwrap();

        // Build the handle by hand so the schema is initialised on the
        // pre-populated connection rather than on a fresh one.
        let db = Database { conn };
        schema::init_schema(&db.conn).unwrap();

        assert_eq!(user_version(&db), SCHEMA_VERSION);

        // The row inserted under the old version must be gone.
        let count: i64 = db
            .conn
            .query_row("SELECT COUNT(*) FROM projects", [], |row| row.get(0))
            .unwrap();
        assert_eq!(count, 0);
    }

    #[test]
    fn test_schema_preserved_on_version_match() {
        let db = db_with_project();

        // Re-running initialisation on an up-to-date schema must keep data.
        schema::init_schema(&db.conn).unwrap();

        let retrieved = db
            .get_project(PROJECT_HASH)
            .unwrap()
            .expect("project should survive re-running init_schema");
        assert_eq!(retrieved.hash, project_hash());
    }
}