1use crate::error::{ConvoError, Result};
10use crate::types::{Message, MessageData, Part, PartData, Project, Session};
11use rusqlite::{Connection, OpenFlags, params};
12use std::path::{Path, PathBuf};
13
14pub struct DbReader {
17 conn: Connection,
18 path: PathBuf,
19}
20
21impl DbReader {
22 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
24 let path = path.as_ref();
25 if !path.exists() {
26 return Err(ConvoError::DatabaseNotFound(path.to_path_buf()));
27 }
28 let conn = Connection::open_with_flags(
29 path,
30 OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_NO_MUTEX,
31 )?;
32 Ok(Self {
33 conn,
34 path: path.to_path_buf(),
35 })
36 }
37
38 pub fn path(&self) -> &Path {
39 &self.path
40 }
41
42 pub fn list_projects(&self) -> Result<Vec<Project>> {
44 let mut stmt = self.conn.prepare(
45 "SELECT id, worktree, vcs, name, time_created, time_updated, time_initialized, sandboxes
46 FROM project
47 ORDER BY time_updated DESC",
48 )?;
49 let rows = stmt.query_map([], Self::map_project)?;
50 let mut out = Vec::new();
51 for r in rows {
52 out.push(r?);
53 }
54 Ok(out)
55 }
56
57 pub fn get_project(&self, id: &str) -> Result<Option<Project>> {
58 let mut stmt = self.conn.prepare(
59 "SELECT id, worktree, vcs, name, time_created, time_updated, time_initialized, sandboxes
60 FROM project WHERE id = ?1",
61 )?;
62 let mut rows = stmt.query_map(params![id], Self::map_project)?;
63 rows.next().transpose().map_err(ConvoError::from)
64 }
65
66 fn map_project(row: &rusqlite::Row<'_>) -> rusqlite::Result<Project> {
67 let sandboxes_json: String = row.get::<_, Option<String>>(7)?.unwrap_or("[]".to_string());
68 let sandboxes: Vec<String> = serde_json::from_str(&sandboxes_json).unwrap_or_default();
69 Ok(Project {
70 id: row.get(0)?,
71 worktree: PathBuf::from(row.get::<_, String>(1)?),
72 vcs: row.get(2)?,
73 name: row.get(3)?,
74 time_created: row.get(4)?,
75 time_updated: row.get(5)?,
76 time_initialized: row.get(6)?,
77 sandboxes,
78 })
79 }
80
81 pub fn list_sessions(&self, project_id: Option<&str>) -> Result<Vec<Session>> {
84 let sql = "SELECT id, project_id, workspace_id, parent_id, slug, directory, title,
85 version, share_url, summary_additions, summary_deletions,
86 summary_files, time_created, time_updated, time_compacting, time_archived
87 FROM session";
88 let rows: Vec<Session> = if let Some(pid) = project_id {
89 let mut stmt = self.conn.prepare(&format!(
90 "{sql} WHERE project_id = ?1 ORDER BY time_updated DESC"
91 ))?;
92 stmt.query_map(params![pid], Self::map_session)?
93 .collect::<rusqlite::Result<Vec<_>>>()?
94 } else {
95 let mut stmt = self
96 .conn
97 .prepare(&format!("{sql} ORDER BY time_updated DESC"))?;
98 stmt.query_map([], Self::map_session)?
99 .collect::<rusqlite::Result<Vec<_>>>()?
100 };
101 Ok(rows)
102 }
103
104 pub fn get_session(&self, id: &str) -> Result<Option<Session>> {
105 let sql = "SELECT id, project_id, workspace_id, parent_id, slug, directory, title,
106 version, share_url, summary_additions, summary_deletions,
107 summary_files, time_created, time_updated, time_compacting, time_archived
108 FROM session WHERE id = ?1";
109 let mut stmt = self.conn.prepare(sql)?;
110 let mut rows = stmt.query_map(params![id], Self::map_session)?;
111 rows.next().transpose().map_err(ConvoError::from)
112 }
113
114 fn map_session(row: &rusqlite::Row<'_>) -> rusqlite::Result<Session> {
115 Ok(Session {
116 id: row.get(0)?,
117 project_id: row.get(1)?,
118 workspace_id: row.get(2)?,
119 parent_id: row.get(3)?,
120 slug: row.get(4)?,
121 directory: PathBuf::from(row.get::<_, String>(5)?),
122 title: row.get(6)?,
123 version: row.get(7)?,
124 share_url: row.get(8)?,
125 summary_additions: row.get(9)?,
126 summary_deletions: row.get(10)?,
127 summary_files: row.get(11)?,
128 time_created: row.get(12)?,
129 time_updated: row.get(13)?,
130 time_compacting: row.get(14)?,
131 time_archived: row.get(15)?,
132 messages: Vec::new(),
133 })
134 }
135
136 pub fn list_messages_raw(&self, session_id: &str) -> Result<Vec<Message>> {
139 let mut stmt = self.conn.prepare(
140 "SELECT id, session_id, time_created, time_updated, data
141 FROM message
142 WHERE session_id = ?1
143 ORDER BY time_created ASC, id ASC",
144 )?;
145 let rows = stmt
146 .query_map(params![session_id], Self::map_message)?
147 .collect::<rusqlite::Result<Vec<_>>>()?;
148 Ok(rows)
149 }
150
151 fn map_message(row: &rusqlite::Row<'_>) -> rusqlite::Result<Message> {
152 let raw_data: String = row.get(4)?;
153 let data = match serde_json::from_str::<MessageData>(&raw_data) {
154 Ok(d) => d,
155 Err(e) => {
156 eprintln!(
160 "Warning: message {} has malformed data: {}",
161 row.get::<_, String>(0)?,
162 e
163 );
164 MessageData::Other
165 }
166 };
167 Ok(Message {
168 id: row.get(0)?,
169 session_id: row.get(1)?,
170 time_created: row.get(2)?,
171 time_updated: row.get(3)?,
172 data,
173 parts: Vec::new(),
174 })
175 }
176
177 pub fn list_parts_for_message(&self, message_id: &str) -> Result<Vec<Part>> {
180 let mut stmt = self.conn.prepare(
181 "SELECT id, message_id, session_id, time_created, time_updated, data
182 FROM part
183 WHERE message_id = ?1
184 ORDER BY time_created ASC, id ASC",
185 )?;
186 let rows = stmt
187 .query_map(params![message_id], Self::map_part)?
188 .collect::<rusqlite::Result<Vec<_>>>()?;
189 Ok(rows)
190 }
191
192 pub fn list_parts_for_session(&self, session_id: &str) -> Result<Vec<Part>> {
196 let mut stmt = self.conn.prepare(
197 "SELECT id, message_id, session_id, time_created, time_updated, data
198 FROM part
199 WHERE session_id = ?1
200 ORDER BY message_id ASC, time_created ASC, id ASC",
201 )?;
202 let rows = stmt
203 .query_map(params![session_id], Self::map_part)?
204 .collect::<rusqlite::Result<Vec<_>>>()?;
205 Ok(rows)
206 }
207
208 fn map_part(row: &rusqlite::Row<'_>) -> rusqlite::Result<Part> {
209 let raw_data: String = row.get(5)?;
210 let data = match serde_json::from_str::<PartData>(&raw_data) {
211 Ok(d) => d,
212 Err(e) => {
213 eprintln!(
214 "Warning: part {} has malformed data: {}",
215 row.get::<_, String>(0)?,
216 e
217 );
218 PartData::Unknown
219 }
220 };
221 Ok(Part {
222 id: row.get(0)?,
223 message_id: row.get(1)?,
224 session_id: row.get(2)?,
225 time_created: row.get(3)?,
226 time_updated: row.get(4)?,
227 data,
228 })
229 }
230
231 pub fn load_session(&self, session_id: &str) -> Result<Session> {
234 let mut session = self
235 .get_session(session_id)?
236 .ok_or_else(|| ConvoError::SessionNotFound(session_id.to_string()))?;
237 let mut messages = self.list_messages_raw(session_id)?;
238 let parts = self.list_parts_for_session(session_id)?;
239 let mut by_msg: std::collections::HashMap<String, Vec<Part>> =
241 std::collections::HashMap::new();
242 for p in parts {
243 by_msg.entry(p.message_id.clone()).or_default().push(p);
244 }
245 for m in &mut messages {
246 if let Some(ps) = by_msg.remove(&m.id) {
247 m.parts = ps;
248 }
249 }
250 session.messages = messages;
251 Ok(session)
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;
    use tempfile::NamedTempFile;

    /// Builds a temporary database mirroring the tables `DbReader` queries.
    ///
    /// Seed data:
    /// - one project (`proj1`, updated at 2000) and one session (`ses_1`);
    /// - a user message `msg_1` with one part;
    /// - an assistant message `msg_2` with three parts.
    fn fixture_db() -> NamedTempFile {
        let f = NamedTempFile::new().unwrap();
        let conn = Connection::open(f.path()).unwrap();
        conn.execute_batch(
            r#"
            CREATE TABLE project (
                id text PRIMARY KEY, worktree text NOT NULL, vcs text, name text,
                icon_url text, icon_color text,
                time_created integer NOT NULL, time_updated integer NOT NULL,
                time_initialized integer, sandboxes text NOT NULL, commands text
            );
            CREATE TABLE session (
                id text PRIMARY KEY, project_id text NOT NULL, parent_id text,
                slug text NOT NULL, directory text NOT NULL, title text NOT NULL,
                version text NOT NULL, share_url text,
                summary_additions integer, summary_deletions integer,
                summary_files integer, summary_diffs text, revert text, permission text,
                time_created integer NOT NULL, time_updated integer NOT NULL,
                time_compacting integer, time_archived integer, workspace_id text
            );
            CREATE TABLE message (
                id text PRIMARY KEY, session_id text NOT NULL,
                time_created integer NOT NULL, time_updated integer NOT NULL,
                data text NOT NULL
            );
            CREATE TABLE part (
                id text PRIMARY KEY, message_id text NOT NULL, session_id text NOT NULL,
                time_created integer NOT NULL, time_updated integer NOT NULL,
                data text NOT NULL
            );
            INSERT INTO project (id, worktree, time_created, time_updated, sandboxes)
                VALUES ('proj1', '/tmp/p', 1000, 2000, '[]');
            INSERT INTO session (id, project_id, slug, directory, title, version,
                time_created, time_updated)
                VALUES ('ses_1', 'proj1', 'slug', '/tmp/p', 'T', '1.0.0', 1000, 2000);
            INSERT INTO message (id, session_id, time_created, time_updated, data) VALUES
                ('msg_1','ses_1',1001,1001,'{"role":"user","time":{"created":1001},"agent":"build","model":{"providerID":"p","modelID":"m"}}'),
                ('msg_2','ses_1',1002,1002,'{"parentID":"msg_1","role":"assistant","mode":"build","agent":"build","path":{"cwd":"/tmp/p","root":"/tmp/p"},"cost":0.01,"tokens":{"input":5,"output":3,"reasoning":0,"cache":{"read":0,"write":0}},"modelID":"m","providerID":"p","time":{"created":1002,"completed":1003},"finish":"stop"}');
            INSERT INTO part (id, message_id, session_id, time_created, time_updated, data) VALUES
                ('prt_1','msg_1','ses_1',1001,1001,'{"type":"text","text":"hello"}'),
                ('prt_2','msg_2','ses_1',1002,1002,'{"type":"step-start","snapshot":"abc"}'),
                ('prt_3','msg_2','ses_1',1002,1002,'{"type":"text","text":"hi!"}'),
                ('prt_4','msg_2','ses_1',1003,1003,'{"type":"step-finish","reason":"stop","snapshot":"abc","tokens":{"input":5,"output":3,"reasoning":0,"cache":{"read":0,"write":0}},"cost":0.01}');
            "#,
        )
        .unwrap();
        f
    }

    /// Runs extra SQL against an existing fixture through a separate connection.
    fn exec(f: &NamedTempFile, sql: &str) {
        let conn = Connection::open(f.path()).unwrap();
        conn.execute_batch(sql).unwrap();
    }

    #[test]
    fn open_reads_projects() {
        let f = fixture_db();
        let r = DbReader::open(f.path()).unwrap();
        let ps = r.list_projects().unwrap();
        assert_eq!(ps.len(), 1);
        assert_eq!(ps[0].id, "proj1");
    }

    #[test]
    fn open_missing_file_errors() {
        let f = NamedTempFile::new().unwrap();
        // Sibling of a real temp file with a suffix nothing ever created.
        let missing = f.path().with_extension("missing");
        assert!(matches!(
            DbReader::open(&missing),
            Err(ConvoError::DatabaseNotFound(ref p)) if *p == missing
        ));
    }

    #[test]
    fn open_reads_sessions() {
        let f = fixture_db();
        let r = DbReader::open(f.path()).unwrap();
        let ss = r.list_sessions(Some("proj1")).unwrap();
        assert_eq!(ss.len(), 1);
        assert_eq!(ss[0].id, "ses_1");
    }

    #[test]
    fn list_sessions_unfiltered_and_unmatched() {
        let f = fixture_db();
        let r = DbReader::open(f.path()).unwrap();
        assert_eq!(r.list_sessions(None).unwrap().len(), 1);
        assert!(r.list_sessions(Some("other")).unwrap().is_empty());
    }

    #[test]
    fn get_by_id_hits_and_misses() {
        let f = fixture_db();
        let r = DbReader::open(f.path()).unwrap();
        assert!(r.get_session("ses_1").unwrap().is_some());
        assert!(r.get_session("nope").unwrap().is_none());
        assert!(r.get_project("proj1").unwrap().is_some());
        assert!(r.get_project("nope").unwrap().is_none());
    }

    #[test]
    fn project_sandboxes_parse_and_tolerate_garbage() {
        let f = fixture_db();
        exec(
            &f,
            r#"INSERT INTO project (id, worktree, time_created, time_updated, sandboxes) VALUES
                   ('proj2', '/tmp/q', 1000, 3000, '["/tmp/a","/tmp/b"]'),
                   ('proj3', '/tmp/r', 1000, 1500, 'not json');"#,
        );
        let r = DbReader::open(f.path()).unwrap();

        let p2 = r.get_project("proj2").unwrap().unwrap();
        assert_eq!(p2.sandboxes, vec!["/tmp/a", "/tmp/b"]);
        let p3 = r.get_project("proj3").unwrap().unwrap();
        assert!(p3.sandboxes.is_empty());

        // Most recently updated first: 3000, 2000, 1500.
        let ps = r.list_projects().unwrap();
        assert_eq!(ps.len(), 3);
        assert_eq!(ps[0].id, "proj2");
        assert_eq!(ps[1].id, "proj1");
        assert_eq!(ps[2].id, "proj3");
    }

    #[test]
    fn list_parts_for_message_only_returns_that_message() {
        let f = fixture_db();
        let r = DbReader::open(f.path()).unwrap();
        let parts = r.list_parts_for_message("msg_2").unwrap();
        assert_eq!(parts.len(), 3);
        assert!(parts.iter().all(|p| p.message_id == "msg_2"));
    }

    #[test]
    fn load_session_attaches_parts() {
        let f = fixture_db();
        let r = DbReader::open(f.path()).unwrap();
        let s = r.load_session("ses_1").unwrap();
        assert_eq!(s.messages.len(), 2);
        assert_eq!(s.messages[0].parts.len(), 1);
        assert_eq!(s.messages[1].parts.len(), 3);
        assert_eq!(s.first_user_text().as_deref(), Some("hello"));
    }

    #[test]
    fn load_session_ignores_orphan_parts() {
        let f = fixture_db();
        exec(
            &f,
            r#"INSERT INTO part (id, message_id, session_id, time_created, time_updated, data)
                   VALUES ('prt_orphan','msg_missing','ses_1',1004,1004,'{"type":"text","text":"lost"}');"#,
        );
        let r = DbReader::open(f.path()).unwrap();
        let s = r.load_session("ses_1").unwrap();
        assert_eq!(s.messages.len(), 2);
        assert_eq!(s.messages[0].parts.len(), 1);
        assert_eq!(s.messages[1].parts.len(), 3);
    }

    #[test]
    fn load_session_missing_errors() {
        let f = fixture_db();
        let r = DbReader::open(f.path()).unwrap();
        let err = r.load_session("nope").unwrap_err();
        assert!(matches!(err, ConvoError::SessionNotFound(_)));
    }

    #[test]
    fn malformed_message_falls_back_to_other() {
        let f = NamedTempFile::new().unwrap();
        let conn = Connection::open(f.path()).unwrap();
        conn.execute_batch(
            "CREATE TABLE message (id text PRIMARY KEY, session_id text NOT NULL,
                time_created integer NOT NULL, time_updated integer NOT NULL,
                data text NOT NULL);
             CREATE TABLE session (id text PRIMARY KEY, project_id text NOT NULL, slug text NOT NULL,
                directory text NOT NULL, title text NOT NULL, version text NOT NULL,
                time_created integer NOT NULL, time_updated integer NOT NULL,
                parent_id text, share_url text, summary_additions integer,
                summary_deletions integer, summary_files text, time_compacting integer,
                time_archived integer, workspace_id text);
             CREATE TABLE part (id text PRIMARY KEY, message_id text NOT NULL,
                session_id text NOT NULL, time_created integer NOT NULL,
                time_updated integer NOT NULL, data text NOT NULL);
             INSERT INTO session (id, project_id, slug, directory, title, version, time_created, time_updated)
                VALUES ('s','p','slug','/p','t','1.0.0',1,1);
             INSERT INTO message (id, session_id, time_created, time_updated, data)
                VALUES ('m','s',1,1,'{{not json}}');",
        )
        .unwrap();
        let r = DbReader::open(f.path()).unwrap();
        let msgs = r.list_messages_raw("s").unwrap();
        assert_eq!(msgs.len(), 1);
        assert!(matches!(msgs[0].data, MessageData::Other));
    }

    #[test]
    fn malformed_part_falls_back_to_unknown() {
        let f = fixture_db();
        exec(
            &f,
            "INSERT INTO part (id, message_id, session_id, time_created, time_updated, data)
                 VALUES ('prt_bad','msg_1','ses_1',1005,1005,'not json');",
        );
        let r = DbReader::open(f.path()).unwrap();
        let parts = r.list_parts_for_message("msg_1").unwrap();
        // prt_1 (1001) precedes prt_bad (1005).
        assert_eq!(parts.len(), 2);
        assert!(matches!(parts[1].data, PartData::Unknown));
    }
}