1use anyhow::{Context, Result};
2use rusqlite::Connection;
3use std::path::Path;
4
5use crate::{queries, records::*, schema};
6
7pub struct Database {
8 conn: Connection,
9}
10
11impl Database {
12 pub fn open(db_path: &Path) -> Result<Self> {
13 let conn = Connection::open(db_path)
14 .with_context(|| format!("Failed to open database: {}", db_path.display()))?;
15
16 let db = Self { conn };
17 schema::init_schema(&db.conn)?;
18 Ok(db)
19 }
20
21 pub fn open_in_memory() -> Result<Self> {
22 let conn = Connection::open_in_memory()?;
23 let db = Self { conn };
24 schema::init_schema(&db.conn)?;
25 Ok(db)
26 }
27
28 pub fn insert_or_update_project(&self, project: &ProjectRecord) -> Result<()> {
30 queries::project::insert_or_update(&self.conn, project)
31 }
32
33 pub fn get_project(&self, hash: &str) -> Result<Option<ProjectRecord>> {
34 queries::project::get(&self.conn, hash)
35 }
36
37 pub fn list_projects(&self) -> Result<Vec<ProjectRecord>> {
38 queries::project::list(&self.conn)
39 }
40
41 pub fn count_sessions_for_project(&self, project_hash: &str) -> Result<usize> {
42 queries::project::count_sessions(&self.conn, project_hash)
43 }
44
45 pub fn insert_or_update_session(&self, session: &SessionRecord) -> Result<()> {
47 queries::session::insert_or_update(&self.conn, session)
48 }
49
50 pub fn get_session_by_id(&self, session_id: &str) -> Result<Option<SessionSummary>> {
51 queries::session::get_by_id(&self.conn, session_id)
52 }
53
54 pub fn list_sessions(
55 &self,
56 project_hash: Option<&agtrace_types::ProjectHash>,
57 limit: usize,
58 ) -> Result<Vec<SessionSummary>> {
59 queries::session::list(&self.conn, project_hash, limit)
60 }
61
62 pub fn find_session_by_prefix(&self, prefix: &str) -> Result<Option<String>> {
63 queries::session::find_by_prefix(&self.conn, prefix)
64 }
65
66 pub fn insert_or_update_log_file(&self, log_file: &LogFileRecord) -> Result<()> {
68 queries::log_file::insert_or_update(&self.conn, log_file)
69 }
70
71 pub fn get_session_files(&self, session_id: &str) -> Result<Vec<LogFileRecord>> {
72 queries::log_file::get_session_files(&self.conn, session_id)
73 }
74
75 pub fn get_all_log_files(&self) -> Result<Vec<LogFileRecord>> {
76 queries::log_file::get_all(&self.conn)
77 }
78
79 pub fn vacuum(&self) -> Result<()> {
81 self.conn.execute("VACUUM", [])?;
82 println!("Database vacuumed successfully");
83 Ok(())
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::SCHEMA_VERSION;
    use rusqlite::params;

    /// Hash shared by every project fixture in this module.
    const PROJECT_HASH: &str = "abc123";

    /// Builds the typed hash for the fixture project.
    fn fixture_hash() -> agtrace_types::ProjectHash {
        agtrace_types::ProjectHash::from(PROJECT_HASH)
    }

    /// Builds the project record used as the parent row in these tests.
    fn fixture_project() -> ProjectRecord {
        ProjectRecord {
            hash: fixture_hash(),
            root_path: Some("/path/to/project".to_string()),
            last_scanned_at: Some("2025-12-10T10:00:00Z".to_string()),
        }
    }

    /// Opens an in-memory database that already contains the fixture project.
    fn db_with_project() -> Database {
        let db = Database::open_in_memory().unwrap();
        db.insert_or_update_project(&fixture_project()).unwrap();
        db
    }

    /// Builds a valid, still-open (`end_ts: None`) "claude" session under the
    /// fixture project.
    fn fixture_session(id: String, start_ts: String, snippet: Option<String>) -> SessionRecord {
        SessionRecord {
            id,
            project_hash: fixture_hash(),
            provider: "claude".to_string(),
            start_ts: Some(start_ts),
            end_ts: None,
            snippet,
            is_valid: true,
        }
    }

    /// Inserts sessions numbered `1..=count`, one minute apart, optionally
    /// giving each a `"Session N"` snippet.
    fn insert_numbered_sessions(db: &Database, count: u32, with_snippet: bool) {
        for n in 1..=count {
            let snippet = if with_snippet {
                Some(format!("Session {}", n))
            } else {
                None
            };
            let record = fixture_session(
                format!("session-{:03}", n),
                format!("2025-12-10T10:{:02}:00Z", n),
                snippet,
            );
            db.insert_or_update_session(&record).unwrap();
        }
    }

    /// Reads `PRAGMA user_version` from `conn`.
    fn user_version(conn: &Connection) -> i32 {
        conn.query_row("PRAGMA user_version", [], |row| row.get(0))
            .unwrap()
    }

    #[test]
    fn test_schema_initialization() {
        let db = Database::open_in_memory().unwrap();

        // A freshly initialized database has no projects.
        let projects = db.list_projects().unwrap();
        assert_eq!(projects.len(), 0);
    }

    #[test]
    fn test_insert_project() {
        let db = Database::open_in_memory().unwrap();
        db.insert_or_update_project(&fixture_project()).unwrap();

        let retrieved = db.get_project("abc123").unwrap().unwrap();
        assert_eq!(retrieved.hash, agtrace_types::ProjectHash::from("abc123"));
        assert_eq!(retrieved.root_path, Some("/path/to/project".to_string()));
    }

    #[test]
    fn test_insert_session_with_fk() {
        let db = db_with_project();

        // Same as the shared fixture, but with an end timestamp set.
        let session = SessionRecord {
            end_ts: Some("2025-12-10T10:15:00Z".to_string()),
            ..fixture_session(
                "session-001".to_string(),
                "2025-12-10T10:05:00Z".to_string(),
                Some("Test session".to_string()),
            )
        };
        db.insert_or_update_session(&session).unwrap();

        let hash = fixture_hash();
        let sessions = db.list_sessions(Some(&hash), 10).unwrap();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].id, "session-001");
        assert_eq!(sessions[0].provider, "claude");
    }

    #[test]
    fn test_insert_log_file() {
        let db = db_with_project();

        let session = fixture_session(
            "session-001".to_string(),
            "2025-12-10T10:05:00Z".to_string(),
            None,
        );
        db.insert_or_update_session(&session).unwrap();

        let log_file = LogFileRecord {
            path: "/path/to/log.jsonl".to_string(),
            session_id: "session-001".to_string(),
            role: "main".to_string(),
            file_size: Some(1024),
            mod_time: Some("2025-12-10T10:05:00Z".to_string()),
        };
        db.insert_or_update_log_file(&log_file).unwrap();

        let files = db.get_session_files("session-001").unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].path, "/path/to/log.jsonl");
        assert_eq!(files[0].role, "main");
    }

    #[test]
    fn test_list_sessions_query() {
        let db = db_with_project();
        insert_numbered_sessions(&db, 5, true);

        let hash = fixture_hash();

        // A limit above the row count returns everything.
        let sessions = db.list_sessions(Some(&hash), 10).unwrap();
        assert_eq!(sessions.len(), 5);

        // A smaller limit truncates the result.
        let sessions_limited = db.list_sessions(Some(&hash), 3).unwrap();
        assert_eq!(sessions_limited.len(), 3);
    }

    #[test]
    fn test_count_sessions_for_project() {
        let db = db_with_project();
        insert_numbered_sessions(&db, 3, false);

        let count = db.count_sessions_for_project("abc123").unwrap();
        assert_eq!(count, 3);
    }

    #[test]
    fn test_schema_version_set_on_init() {
        let db = Database::open_in_memory().unwrap();

        assert_eq!(user_version(&db.conn), SCHEMA_VERSION);
    }

    #[test]
    fn test_schema_rebuild_on_version_mismatch() {
        let conn = Connection::open_in_memory().unwrap();

        // Simulate a database written by a different schema version.
        conn.execute_batch(
            r#"
            CREATE TABLE projects (hash TEXT PRIMARY KEY);
            CREATE TABLE sessions (id TEXT PRIMARY KEY);
            PRAGMA user_version = 999;
            "#,
        )
        .unwrap();

        conn.execute(
            "INSERT INTO projects (hash) VALUES (?1)",
            params!["old_data"],
        )
        .unwrap();

        let db = Database { conn };
        schema::init_schema(&db.conn).unwrap();

        // The version is reset and the pre-existing row is gone.
        assert_eq!(user_version(&db.conn), SCHEMA_VERSION);

        let count: i64 = db
            .conn
            .query_row("SELECT COUNT(*) FROM projects", [], |row| row.get(0))
            .unwrap();
        assert_eq!(count, 0);
    }

    #[test]
    fn test_schema_preserved_on_version_match() {
        let db = db_with_project();

        // Re-running initialization on an up-to-date schema must keep data.
        schema::init_schema(&db.conn).unwrap();

        let retrieved = db.get_project("abc123").unwrap();
        assert!(retrieved.is_some());
        assert_eq!(
            retrieved.unwrap().hash,
            agtrace_types::ProjectHash::from("abc123")
        );
    }
}