1pub(crate) mod schema;
2
3use anyhow::{Context, Result};
4use chrono::Utc;
5use rusqlite::{params, Connection};
6use std::path::PathBuf;
7use uuid::Uuid;
8
9fn db_path() -> Result<PathBuf> {
10 let dot_dir = crate::config::Config::data_dir();
11 std::fs::create_dir_all(&dot_dir).context("Could not create dot data directory")?;
12 Ok(dot_dir.join("dot.db"))
13}
14
/// One row of the `conversations` table without its messages.
///
/// Derives `PartialEq`/`Eq` so values can be compared directly (for example
/// in tests or when detecting changes between two listings).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConversationSummary {
    /// Conversation identifier; `Db::create_conversation` generates a UUID v4 string.
    pub id: String,
    /// Human-readable title; `None` until `Db::update_conversation_title` sets one.
    pub title: Option<String>,
    /// Model name supplied when the conversation was created.
    pub model: String,
    /// Provider name supplied when the conversation was created.
    pub provider: String,
    /// Working directory supplied at creation; empty when the stored value is NULL.
    pub cwd: String,
    /// Creation timestamp, written as an RFC 3339 UTC string.
    pub created_at: String,
    /// Last-modified timestamp (RFC 3339 UTC); refreshed on title change and on new messages.
    pub updated_at: String,
}
25
/// A conversation row together with all of its messages.
///
/// Built by `Db::get_conversation`. Derives `PartialEq`/`Eq` (which requires
/// the same on [`DbMessage`], defined right below) so whole conversations can
/// be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Conversation {
    /// Conversation identifier; `Db::create_conversation` generates a UUID v4 string.
    pub id: String,
    /// Human-readable title; `None` until one has been set.
    pub title: Option<String>,
    /// Model name supplied when the conversation was created.
    pub model: String,
    /// Provider name supplied when the conversation was created.
    pub provider: String,
    /// Working directory supplied at creation; empty when the stored value is NULL.
    pub cwd: String,
    /// Creation timestamp, written as an RFC 3339 UTC string.
    pub created_at: String,
    /// Last-modified timestamp (RFC 3339 UTC).
    pub updated_at: String,
    /// Messages of this conversation, ordered by `created_at` ascending.
    pub messages: Vec<DbMessage>,
    /// Value last stored through `Db::update_last_input_tokens` (0 by default).
    pub last_input_tokens: u32,
}

/// One row of the `messages` table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbMessage {
    /// Message identifier; `Db::add_message` generates a UUID v4 string.
    pub id: String,
    /// Identifier of the conversation this message belongs to.
    pub conversation_id: String,
    /// Role of the author as stored by the caller (the history query filters on `'user'`).
    pub role: String,
    /// Message text as stored by the caller.
    pub content: String,
    /// Token count; inserted as 0 and later set via `Db::update_message_tokens`.
    pub token_count: u32,
    /// Creation timestamp, written as an RFC 3339 UTC string.
    pub created_at: String,
}
48
/// One row of the `tool_calls` table.
///
/// Derives `PartialEq`/`Eq` so recorded calls can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbToolCall {
    /// Tool-call identifier, supplied by the caller of `Db::add_tool_call`.
    pub id: String,
    /// Identifier of the message that issued this call.
    pub message_id: String,
    /// Name of the invoked tool.
    pub name: String,
    /// Tool input as stored by the caller.
    pub input: String,
    /// Tool output; `None` until `Db::update_tool_result` records one.
    pub output: Option<String>,
    /// Whether the recorded result was flagged as an error (false on insert).
    pub is_error: bool,
    /// Creation timestamp, written as an RFC 3339 UTC string.
    pub created_at: String,
}
59
/// One row of the `tasks` table.
///
/// Derives `PartialEq`/`Eq` so task snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskRecord {
    /// Task identifier, supplied by the caller of `Db::create_task`.
    pub id: String,
    /// Prompt the task was started with.
    pub prompt: String,
    /// Status string; `'running'` on creation, replaced by `Db::complete_task`.
    pub status: String,
    /// Session identifier recorded on completion, if any.
    pub session_id: Option<String>,
    /// Process id recorded on creation, if any.
    pub pid: Option<i64>,
    /// Output recorded on completion, if any.
    pub output: Option<String>,
    /// Working directory recorded on creation; empty when the stored value is NULL.
    pub cwd: String,
    /// Creation timestamp, written as an RFC 3339 UTC string.
    pub created_at: String,
    /// Completion timestamp (RFC 3339 UTC); `None` while the task has not completed.
    pub completed_at: Option<String>,
}
72
/// Handle to the application's SQLite database.
///
/// Obtained through `Db::open`, which opens the connection and initializes
/// the schema.
pub struct Db {
    // Underlying rusqlite connection; every query issued by `Db` goes through it.
    conn: Connection,
}
76
77impl Db {
78 pub fn open() -> Result<Self> {
79 let path = db_path()?;
80 tracing::debug!("Opening database at {:?}", path);
81 let conn = Connection::open(&path)
82 .with_context(|| format!("Failed to open database at {:?}", path))?;
83 let db = Db { conn };
84 db.init()?;
85 Ok(db)
86 }
87
88 pub fn init(&self) -> Result<()> {
89 self.conn
90 .execute_batch(&format!(
91 "{}\n;\n{}\n;\n{}\n;\n{}",
92 schema::CREATE_CONVERSATIONS,
93 schema::CREATE_MESSAGES,
94 schema::CREATE_TOOL_CALLS,
95 schema::CREATE_TASKS,
96 ))
97 .context("Failed to initialize database schema")?;
98
99 let _ = self.conn.execute(
100 "ALTER TABLE conversations ADD COLUMN cwd TEXT NOT NULL DEFAULT ''",
101 [],
102 );
103 let _ = self.conn.execute(
104 "ALTER TABLE conversations ADD COLUMN last_input_tokens INTEGER NOT NULL DEFAULT 0",
105 [],
106 );
107
108 tracing::debug!("Database schema initialized");
109 Ok(())
110 }
111
112 pub fn create_conversation(&self, model: &str, provider: &str, cwd: &str) -> Result<String> {
113 let id = Uuid::new_v4().to_string();
114 let now = Utc::now().to_rfc3339();
115 self.conn
116 .execute(
117 "INSERT INTO conversations (id, title, model, provider, cwd, created_at, updated_at) \
118 VALUES (?1, NULL, ?2, ?3, ?4, ?5, ?6)",
119 params![id, model, provider, cwd, now, now],
120 )
121 .context("Failed to create conversation")?;
122 tracing::debug!("Created conversation {}", id);
123 Ok(id)
124 }
125
126 pub fn list_conversations(&self, limit: usize) -> Result<Vec<ConversationSummary>> {
127 let mut stmt = self
128 .conn
129 .prepare(
130 "SELECT id, title, model, provider, cwd, created_at, updated_at \
131 FROM conversations ORDER BY updated_at DESC LIMIT ?1",
132 )
133 .context("Failed to prepare list_conversations query")?;
134
135 let rows = stmt
136 .query_map(params![limit as i64], |row| {
137 Ok(ConversationSummary {
138 id: row.get(0)?,
139 title: row.get(1)?,
140 model: row.get(2)?,
141 provider: row.get(3)?,
142 cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
143 created_at: row.get(5)?,
144 updated_at: row.get(6)?,
145 })
146 })
147 .context("Failed to list conversations")?;
148
149 let mut conversations = Vec::new();
150 for row in rows {
151 conversations.push(row.context("Failed to read conversation row")?);
152 }
153 Ok(conversations)
154 }
155
156 pub fn list_conversations_for_cwd(
157 &self,
158 cwd: &str,
159 limit: usize,
160 ) -> Result<Vec<ConversationSummary>> {
161 let mut stmt = self
162 .conn
163 .prepare(
164 "SELECT id, title, model, provider, cwd, created_at, updated_at \
165 FROM conversations WHERE cwd = ?1 ORDER BY updated_at DESC LIMIT ?2",
166 )
167 .context("Failed to prepare list_conversations_for_cwd query")?;
168
169 let rows = stmt
170 .query_map(params![cwd, limit as i64], |row| {
171 Ok(ConversationSummary {
172 id: row.get(0)?,
173 title: row.get(1)?,
174 model: row.get(2)?,
175 provider: row.get(3)?,
176 cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
177 created_at: row.get(5)?,
178 updated_at: row.get(6)?,
179 })
180 })
181 .context("Failed to list conversations for cwd")?;
182
183 let mut conversations = Vec::new();
184 for row in rows {
185 conversations.push(row.context("Failed to read conversation row")?);
186 }
187 Ok(conversations)
188 }
189
190 pub fn get_conversation(&self, id: &str) -> Result<Conversation> {
191 let (summary, last_input_tokens) = self
192 .conn
193 .query_row(
194 "SELECT id, title, model, provider, cwd, created_at, updated_at, last_input_tokens \
195 FROM conversations WHERE id = ?1",
196 params![id],
197 |row| {
198 Ok((
199 ConversationSummary {
200 id: row.get(0)?,
201 title: row.get(1)?,
202 model: row.get(2)?,
203 provider: row.get(3)?,
204 cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
205 created_at: row.get(5)?,
206 updated_at: row.get(6)?,
207 },
208 row.get::<_, i64>(7).unwrap_or(0) as u32,
209 ))
210 },
211 )
212 .context("Failed to get conversation")?;
213
214 let messages = self.get_messages(id)?;
215 Ok(Conversation {
216 id: summary.id,
217 title: summary.title,
218 model: summary.model,
219 provider: summary.provider,
220 cwd: summary.cwd,
221 created_at: summary.created_at,
222 updated_at: summary.updated_at,
223 messages,
224 last_input_tokens,
225 })
226 }
227
228 pub fn update_conversation_title(&self, id: &str, title: &str) -> Result<()> {
229 let now = Utc::now().to_rfc3339();
230 self.conn
231 .execute(
232 "UPDATE conversations SET title = ?1, updated_at = ?2 WHERE id = ?3",
233 params![title, now, id],
234 )
235 .context("Failed to update conversation title")?;
236 Ok(())
237 }
238
239 pub fn delete_conversation(&self, id: &str) -> Result<()> {
240 self.conn
241 .execute(
242 "DELETE FROM tool_calls WHERE message_id IN \
243 (SELECT id FROM messages WHERE conversation_id = ?1)",
244 params![id],
245 )
246 .context("Failed to delete tool calls for conversation")?;
247
248 self.conn
249 .execute(
250 "DELETE FROM messages WHERE conversation_id = ?1",
251 params![id],
252 )
253 .context("Failed to delete messages for conversation")?;
254
255 self.conn
256 .execute("DELETE FROM conversations WHERE id = ?1", params![id])
257 .context("Failed to delete conversation")?;
258
259 tracing::debug!("Deleted conversation {}", id);
260 Ok(())
261 }
262
263 pub fn truncate_messages(&self, conversation_id: &str, keep: usize) -> Result<()> {
264 let ids: Vec<String> = {
265 let mut stmt = self
266 .conn
267 .prepare(
268 "SELECT id FROM messages WHERE conversation_id = ?1 ORDER BY created_at ASC",
269 )
270 .context("Failed to prepare truncate query")?;
271 let rows = stmt
272 .query_map(params![conversation_id], |row| row.get::<_, String>(0))
273 .context("Failed to query messages for truncation")?;
274 let mut all = Vec::new();
275 for row in rows {
276 all.push(row.context("Failed to read message id")?);
277 }
278 all
279 };
280 let to_delete = &ids[keep.min(ids.len())..];
281 for id in to_delete {
282 self.conn
283 .execute("DELETE FROM tool_calls WHERE message_id = ?1", params![id])
284 .context("Failed to delete tool calls for truncated message")?;
285 self.conn
286 .execute("DELETE FROM messages WHERE id = ?1", params![id])
287 .context("Failed to delete truncated message")?;
288 }
289 Ok(())
290 }
291
292 pub fn add_message(&self, conversation_id: &str, role: &str, content: &str) -> Result<String> {
293 let id = Uuid::new_v4().to_string();
294 let now = Utc::now().to_rfc3339();
295 self.conn
296 .execute(
297 "INSERT INTO messages \
298 (id, conversation_id, role, content, token_count, created_at) \
299 VALUES (?1, ?2, ?3, ?4, 0, ?5)",
300 params![id, conversation_id, role, content, now],
301 )
302 .context("Failed to add message")?;
303
304 self.conn
305 .execute(
306 "UPDATE conversations SET updated_at = ?1 WHERE id = ?2",
307 params![now, conversation_id],
308 )
309 .context("Failed to update conversation timestamp")?;
310
311 tracing::debug!("Added message {} to conversation {}", id, conversation_id);
312 Ok(id)
313 }
314
315 pub fn get_messages(&self, conversation_id: &str) -> Result<Vec<DbMessage>> {
316 let mut stmt = self
317 .conn
318 .prepare(
319 "SELECT id, conversation_id, role, content, token_count, created_at \
320 FROM messages WHERE conversation_id = ?1 ORDER BY created_at ASC",
321 )
322 .context("Failed to prepare get_messages query")?;
323
324 let rows = stmt
325 .query_map(params![conversation_id], |row| {
326 Ok(DbMessage {
327 id: row.get(0)?,
328 conversation_id: row.get(1)?,
329 role: row.get(2)?,
330 content: row.get(3)?,
331 token_count: row.get::<_, i64>(4)? as u32,
332 created_at: row.get(5)?,
333 })
334 })
335 .context("Failed to get messages")?;
336
337 let mut messages = Vec::new();
338 for row in rows {
339 messages.push(row.context("Failed to read message row")?);
340 }
341 Ok(messages)
342 }
343
344 pub fn update_last_input_tokens(&self, conversation_id: &str, tokens: u32) -> Result<()> {
345 self.conn
346 .execute(
347 "UPDATE conversations SET last_input_tokens = ?1 WHERE id = ?2",
348 params![tokens as i64, conversation_id],
349 )
350 .context("Failed to update last_input_tokens")?;
351 Ok(())
352 }
353
354 pub fn update_message_tokens(&self, id: &str, tokens: u32) -> Result<()> {
355 self.conn
356 .execute(
357 "UPDATE messages SET token_count = ?1 WHERE id = ?2",
358 params![tokens as i64, id],
359 )
360 .context("Failed to update message tokens")?;
361 Ok(())
362 }
363
364 pub fn add_tool_call(
365 &self,
366 message_id: &str,
367 tool_id: &str,
368 name: &str,
369 input: &str,
370 ) -> Result<()> {
371 let now = Utc::now().to_rfc3339();
372 self.conn
373 .execute(
374 "INSERT INTO tool_calls \
375 (id, message_id, name, input, output, is_error, created_at) \
376 VALUES (?1, ?2, ?3, ?4, NULL, 0, ?5)",
377 params![tool_id, message_id, name, input, now],
378 )
379 .context("Failed to add tool call")?;
380 tracing::debug!("Added tool call {} for message {}", tool_id, message_id);
381 Ok(())
382 }
383
384 pub fn update_tool_result(&self, tool_id: &str, output: &str, is_error: bool) -> Result<()> {
385 self.conn
386 .execute(
387 "UPDATE tool_calls SET output = ?1, is_error = ?2 WHERE id = ?3",
388 params![output, is_error as i64, tool_id],
389 )
390 .context("Failed to update tool result")?;
391 Ok(())
392 }
393
394 pub fn get_tool_calls(&self, message_id: &str) -> Result<Vec<DbToolCall>> {
395 let mut stmt = self
396 .conn
397 .prepare(
398 "SELECT id, message_id, name, input, output, is_error, created_at \
399 FROM tool_calls WHERE message_id = ?1 ORDER BY created_at ASC",
400 )
401 .context("Failed to prepare get_tool_calls query")?;
402
403 let rows = stmt
404 .query_map(params![message_id], |row| {
405 Ok(DbToolCall {
406 id: row.get(0)?,
407 message_id: row.get(1)?,
408 name: row.get(2)?,
409 input: row.get(3)?,
410 output: row.get(4)?,
411 is_error: row.get::<_, i64>(5)? != 0,
412 created_at: row.get(6)?,
413 })
414 })
415 .context("Failed to get tool calls")?;
416
417 let mut calls = Vec::new();
418 for row in rows {
419 calls.push(row.context("Failed to read tool call row")?);
420 }
421 Ok(calls)
422 }
423
424 pub fn get_user_message_history(&self, limit: usize) -> Result<Vec<String>> {
425 let mut stmt = self
426 .conn
427 .prepare(
428 "SELECT content FROM messages WHERE role = 'user' \
429 ORDER BY created_at DESC LIMIT ?1",
430 )
431 .context("Failed to prepare user history query")?;
432
433 let rows = stmt
434 .query_map(params![limit as i64], |row| row.get::<_, String>(0))
435 .context("Failed to query user history")?;
436
437 let mut messages = Vec::new();
438 for row in rows {
439 messages.push(row.context("Failed to read history row")?);
440 }
441 messages.reverse();
442 Ok(messages)
443 }
444
445 pub fn create_task(&self, id: &str, prompt: &str, pid: u32, cwd: &str) -> Result<()> {
446 let now = Utc::now().to_rfc3339();
447 self.conn
448 .execute(
449 "INSERT INTO tasks (id, prompt, status, pid, cwd, created_at) \
450 VALUES (?1, ?2, 'running', ?3, ?4, ?5)",
451 params![id, prompt, pid as i64, cwd, now],
452 )
453 .context("creating task")?;
454 Ok(())
455 }
456
457 pub fn complete_task(
458 &self,
459 id: &str,
460 status: &str,
461 session_id: Option<&str>,
462 output: &str,
463 ) -> Result<()> {
464 let now = Utc::now().to_rfc3339();
465 self.conn
466 .execute(
467 "UPDATE tasks SET status = ?1, session_id = ?2, output = ?3, completed_at = ?4 WHERE id = ?5",
468 params![status, session_id, output, now, id],
469 )
470 .context("completing task")?;
471 Ok(())
472 }
473
474 pub fn list_tasks(&self, limit: usize) -> Result<Vec<TaskRecord>> {
475 let mut stmt = self
476 .conn
477 .prepare(
478 "SELECT id, prompt, status, session_id, pid, output, cwd, created_at, completed_at \
479 FROM tasks ORDER BY created_at DESC LIMIT ?1",
480 )
481 .context("preparing list_tasks")?;
482 let rows = stmt
483 .query_map(params![limit as i64], |row| {
484 Ok(TaskRecord {
485 id: row.get(0)?,
486 prompt: row.get(1)?,
487 status: row.get(2)?,
488 session_id: row.get(3)?,
489 pid: row.get(4)?,
490 output: row.get(5)?,
491 cwd: row.get::<_, Option<String>>(6)?.unwrap_or_default(),
492 created_at: row.get(7)?,
493 completed_at: row.get(8)?,
494 })
495 })
496 .context("listing tasks")?;
497 let mut tasks = Vec::new();
498 for row in rows {
499 tasks.push(row.context("reading task row")?);
500 }
501 Ok(tasks)
502 }
503
504 pub fn get_task(&self, id: &str) -> Result<TaskRecord> {
505 self.conn
506 .query_row(
507 "SELECT id, prompt, status, session_id, pid, output, cwd, created_at, completed_at \
508 FROM tasks WHERE id = ?1",
509 params![id],
510 |row| {
511 Ok(TaskRecord {
512 id: row.get(0)?,
513 prompt: row.get(1)?,
514 status: row.get(2)?,
515 session_id: row.get(3)?,
516 pid: row.get(4)?,
517 output: row.get(5)?,
518 cwd: row.get::<_, Option<String>>(6)?.unwrap_or_default(),
519 created_at: row.get(7)?,
520 completed_at: row.get(8)?,
521 })
522 },
523 )
524 .context("getting task")
525 }
526}