agtrace_index/
db.rs

1use rusqlite::Connection;
2use std::path::Path;
3
4use crate::{Result, queries, records::*, schema};
5
6pub struct Database {
7    conn: Connection,
8}
9
10impl Database {
11    pub fn open(db_path: &Path) -> Result<Self> {
12        let conn = Connection::open(db_path)?;
13        let db = Self { conn };
14        schema::init_schema(&db.conn)?;
15        Ok(db)
16    }
17
18    pub fn open_in_memory() -> Result<Self> {
19        let conn = Connection::open_in_memory()?;
20        let db = Self { conn };
21        schema::init_schema(&db.conn)?;
22        Ok(db)
23    }
24
25    // Project operations
26    pub fn insert_or_update_project(&self, project: &ProjectRecord) -> Result<()> {
27        queries::project::insert_or_update(&self.conn, project)
28    }
29
30    pub fn get_project(&self, hash: &str) -> Result<Option<ProjectRecord>> {
31        queries::project::get(&self.conn, hash)
32    }
33
34    pub fn list_projects(&self) -> Result<Vec<ProjectRecord>> {
35        queries::project::list(&self.conn)
36    }
37
38    pub fn count_sessions_for_project(&self, project_hash: &str) -> Result<usize> {
39        queries::project::count_sessions(&self.conn, project_hash)
40    }
41
42    // Session operations
43    pub fn insert_or_update_session(&self, session: &SessionRecord) -> Result<()> {
44        queries::session::insert_or_update(&self.conn, session)
45    }
46
47    pub fn get_session_by_id(&self, session_id: &str) -> Result<Option<SessionSummary>> {
48        queries::session::get_by_id(&self.conn, session_id)
49    }
50
51    pub fn list_sessions(
52        &self,
53        project_hash: Option<&agtrace_types::ProjectHash>,
54        provider: Option<&str>,
55        order: agtrace_types::SessionOrder,
56        limit: Option<usize>,
57        top_level_only: bool,
58    ) -> Result<Vec<SessionSummary>> {
59        queries::session::list(
60            &self.conn,
61            project_hash,
62            provider,
63            order,
64            limit,
65            top_level_only,
66        )
67    }
68
69    pub fn find_session_by_prefix(&self, prefix: &str) -> Result<Option<String>> {
70        queries::session::find_by_prefix(&self.conn, prefix)
71    }
72
73    pub fn get_child_sessions(&self, parent_session_id: &str) -> Result<Vec<SessionSummary>> {
74        queries::session::get_children(&self.conn, parent_session_id)
75    }
76
77    // Log file operations
78    pub fn insert_or_update_log_file(&self, log_file: &LogFileRecord) -> Result<()> {
79        queries::log_file::insert_or_update(&self.conn, log_file)
80    }
81
82    pub fn get_session_files(&self, session_id: &str) -> Result<Vec<LogFileRecord>> {
83        queries::log_file::get_session_files(&self.conn, session_id)
84    }
85
86    pub fn get_all_log_files(&self) -> Result<Vec<LogFileRecord>> {
87        queries::log_file::get_all(&self.conn)
88    }
89
90    // Utility operations
91    pub fn vacuum(&self) -> Result<()> {
92        self.conn.execute("VACUUM", [])?;
93        println!("Database vacuumed successfully");
94        Ok(())
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::SCHEMA_VERSION;
    use rusqlite::params;

    /// Hash shared by every fixture project in this module.
    const PROJECT_HASH: &str = "abc123";

    /// Builds the project record that session fixtures reference.
    fn project_fixture() -> ProjectRecord {
        ProjectRecord {
            hash: agtrace_types::ProjectHash::from(PROJECT_HASH),
            root_path: Some("/path/to/project".to_string()),
            last_scanned_at: Some("2025-12-10T10:00:00Z".to_string()),
        }
    }

    /// Opens an in-memory database that already contains `project_fixture()`.
    fn db_with_project() -> Database {
        let db = Database::open_in_memory().unwrap();
        db.insert_or_update_project(&project_fixture()).unwrap();
        db
    }

    /// Builds a valid, top-level `claude` session under the fixture project.
    ///
    /// `end_ts` and `snippet` are left empty; tests override them with
    /// struct-update syntax where needed.
    fn session_fixture(id: &str, start_ts: &str) -> SessionRecord {
        SessionRecord {
            id: id.to_string(),
            project_hash: agtrace_types::ProjectHash::from(PROJECT_HASH),
            provider: "claude".to_string(),
            start_ts: Some(start_ts.to_string()),
            end_ts: None,
            snippet: None,
            is_valid: true,
            parent_session_id: None,
            spawned_by: None,
        }
    }

    /// Lists the fixture project's sessions in default order, capped at `limit`.
    fn list_project_sessions(db: &Database, limit: usize) -> Vec<SessionSummary> {
        db.list_sessions(
            Some(&agtrace_types::ProjectHash::from(PROJECT_HASH)),
            None,
            agtrace_types::SessionOrder::default(),
            Some(limit),
            false,
        )
        .unwrap()
    }

    /// Reads `PRAGMA user_version`, the value the schema tests compare
    /// against `SCHEMA_VERSION`.
    fn user_version(conn: &Connection) -> i32 {
        conn.query_row("PRAGMA user_version", [], |row| row.get(0))
            .unwrap()
    }

    #[test]
    fn test_schema_initialization() {
        let db = Database::open_in_memory().unwrap();

        assert!(db.list_projects().unwrap().is_empty());
    }

    #[test]
    fn test_insert_project() {
        let db = db_with_project();

        let retrieved = db
            .get_project(PROJECT_HASH)
            .unwrap()
            .expect("project should exist after insert");
        assert_eq!(retrieved.hash, agtrace_types::ProjectHash::from(PROJECT_HASH));
        assert_eq!(retrieved.root_path, Some("/path/to/project".to_string()));
    }

    #[test]
    fn test_insert_session_with_fk() {
        let db = db_with_project();

        let session = SessionRecord {
            end_ts: Some("2025-12-10T10:15:00Z".to_string()),
            snippet: Some("Test session".to_string()),
            ..session_fixture("session-001", "2025-12-10T10:05:00Z")
        };
        db.insert_or_update_session(&session).unwrap();

        let sessions = list_project_sessions(&db, 10);
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].id, "session-001");
        assert_eq!(sessions[0].provider, "claude");
    }

    #[test]
    fn test_insert_log_file() {
        let db = db_with_project();
        db.insert_or_update_session(&session_fixture("session-001", "2025-12-10T10:05:00Z"))
            .unwrap();

        let log_file = LogFileRecord {
            path: "/path/to/log.jsonl".to_string(),
            session_id: "session-001".to_string(),
            role: "main".to_string(),
            file_size: Some(1024),
            mod_time: Some("2025-12-10T10:05:00Z".to_string()),
        };
        db.insert_or_update_log_file(&log_file).unwrap();

        let files = db.get_session_files("session-001").unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].path, "/path/to/log.jsonl");
        assert_eq!(files[0].role, "main");
    }

    #[test]
    fn test_list_sessions_query() {
        let db = db_with_project();

        for i in 1..=5 {
            let session = SessionRecord {
                snippet: Some(format!("Session {}", i)),
                ..session_fixture(
                    &format!("session-{:03}", i),
                    &format!("2025-12-10T10:{:02}:00Z", i),
                )
            };
            db.insert_or_update_session(&session).unwrap();
        }

        assert_eq!(list_project_sessions(&db, 10).len(), 5);
        assert_eq!(list_project_sessions(&db, 3).len(), 3);
    }

    #[test]
    fn test_count_sessions_for_project() {
        let db = db_with_project();

        for i in 1..=3 {
            let session = session_fixture(
                &format!("session-{:03}", i),
                &format!("2025-12-10T10:{:02}:00Z", i),
            );
            db.insert_or_update_session(&session).unwrap();
        }

        assert_eq!(db.count_sessions_for_project(PROJECT_HASH).unwrap(), 3);
    }

    #[test]
    fn test_schema_version_set_on_init() {
        let db = Database::open_in_memory().unwrap();

        assert_eq!(user_version(&db.conn), SCHEMA_VERSION);
    }

    #[test]
    fn test_schema_rebuild_on_version_mismatch() {
        let conn = Connection::open_in_memory().unwrap();

        // Simulate a database written by an incompatible schema version.
        conn.execute_batch(
            r#"
            CREATE TABLE projects (hash TEXT PRIMARY KEY);
            CREATE TABLE sessions (id TEXT PRIMARY KEY);
            PRAGMA user_version = 999;
            "#,
        )
        .unwrap();

        conn.execute(
            "INSERT INTO projects (hash) VALUES (?1)",
            params!["old_data"],
        )
        .unwrap();

        let db = Database { conn };
        schema::init_schema(&db.conn).unwrap();

        assert_eq!(user_version(&db.conn), SCHEMA_VERSION);

        // The rebuild must have discarded the stale row.
        let count: i64 = db
            .conn
            .query_row("SELECT COUNT(*) FROM projects", [], |row| row.get(0))
            .unwrap();
        assert_eq!(count, 0);
    }

    #[test]
    fn test_schema_preserved_on_version_match() {
        let db = db_with_project();

        // Re-running init on an up-to-date schema must not drop data.
        schema::init_schema(&db.conn).unwrap();

        let retrieved = db
            .get_project(PROJECT_HASH)
            .unwrap()
            .expect("project must survive re-init when the schema version matches");
        assert_eq!(
            retrieved.hash,
            agtrace_types::ProjectHash::from(PROJECT_HASH)
        );
    }
}