1pub(crate) mod schema;
2
3use anyhow::{Context, Result};
4use chrono::Utc;
5use rusqlite::{Connection, params};
6use std::path::PathBuf;
7use uuid::Uuid;
8
9fn db_path() -> Result<PathBuf> {
10 let dot_dir = crate::config::Config::data_dir();
11 std::fs::create_dir_all(&dot_dir).context("Could not create dot data directory")?;
12 Ok(dot_dir.join("dot.db"))
13}
14
/// Row of the `conversations` table without its messages.
///
/// Produced by the listing queries, which select exactly these columns.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConversationSummary {
    /// Primary key of the conversation.
    pub id: String,
    /// Optional title; new conversations are inserted with a NULL title.
    pub title: Option<String>,
    /// Model identifier stored with the conversation.
    pub model: String,
    /// Provider identifier stored with the conversation.
    pub provider: String,
    /// Working directory recorded for the conversation; a NULL column value
    /// is read back as an empty string.
    pub cwd: String,
    /// Creation timestamp as stored in the database (written as RFC 3339).
    pub created_at: String,
    /// Last-update timestamp as stored in the database (written as RFC 3339).
    pub updated_at: String,
}
25
26#[derive(Debug, Clone)]
27pub struct Conversation {
28 pub id: String,
29 pub title: Option<String>,
30 pub model: String,
31 pub provider: String,
32 pub cwd: String,
33 pub created_at: String,
34 pub updated_at: String,
35 pub messages: Vec<DbMessage>,
36 pub last_input_tokens: u32,
37}
38
/// Row of the `messages` table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbMessage {
    /// Primary key of the message.
    pub id: String,
    /// Id of the conversation this message belongs to.
    pub conversation_id: String,
    /// Role string stored with the message (for example `user`).
    pub role: String,
    /// Message body.
    pub content: String,
    /// Token count stored for the message; new messages are inserted with 0.
    pub token_count: u32,
    /// Creation timestamp as stored in the database (written as RFC 3339).
    pub created_at: String,
}
48
/// Row of the `tool_calls` table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbToolCall {
    /// Primary key of the tool call.
    pub id: String,
    /// Id of the message the tool call is attached to.
    pub message_id: String,
    /// Tool name.
    pub name: String,
    /// Tool input as stored text.
    pub input: String,
    /// Tool output; `None` while the column is still NULL (tool calls are
    /// inserted with a NULL output).
    pub output: Option<String>,
    /// Whether the stored `is_error` flag is non-zero.
    pub is_error: bool,
    /// Creation timestamp as stored in the database (written as RFC 3339).
    pub created_at: String,
}
59
/// Row of the `tasks` table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskRecord {
    /// Primary key of the task.
    pub id: String,
    /// Prompt text the task was created with.
    pub prompt: String,
    /// Status string; tasks are inserted as `running` and later overwritten
    /// with whatever status is passed on completion.
    pub status: String,
    /// Session id recorded on completion, if any.
    pub session_id: Option<String>,
    /// Process id recorded when the task was created, if any.
    pub pid: Option<i64>,
    /// Output recorded on completion, if any.
    pub output: Option<String>,
    /// Working directory recorded for the task; a NULL column value is read
    /// back as an empty string.
    pub cwd: String,
    /// Creation timestamp as stored in the database (written as RFC 3339).
    pub created_at: String,
    /// Completion timestamp, if the task has been completed.
    pub completed_at: Option<String>,
}
72
73pub struct Db {
74 conn: Connection,
75}
76
77impl Db {
78 pub fn open() -> Result<Self> {
79 let path = db_path()?;
80 tracing::debug!("Opening database at {:?}", path);
81 let conn = Connection::open(&path)
82 .with_context(|| format!("Failed to open database at {:?}", path))?;
83 let db = Db { conn };
84 db.init()?;
85 Ok(db)
86 }
87
88 pub fn init(&self) -> Result<()> {
89 self.conn
90 .execute_batch(&format!(
91 "{}\n;\n{}\n;\n{}\n;\n{}",
92 schema::CREATE_CONVERSATIONS,
93 schema::CREATE_MESSAGES,
94 schema::CREATE_TOOL_CALLS,
95 schema::CREATE_TASKS,
96 ))
97 .context("Failed to initialize database schema")?;
98
99 let _ = self.conn.execute(
100 "ALTER TABLE conversations ADD COLUMN cwd TEXT NOT NULL DEFAULT ''",
101 [],
102 );
103 let _ = self.conn.execute(
104 "ALTER TABLE conversations ADD COLUMN last_input_tokens INTEGER NOT NULL DEFAULT 0",
105 [],
106 );
107
108 tracing::debug!("Database schema initialized");
109 Ok(())
110 }
111
112 pub fn create_conversation(&self, model: &str, provider: &str, cwd: &str) -> Result<String> {
113 let id = Uuid::new_v4().to_string();
114 self.create_conversation_with_id(&id, model, provider, cwd)?;
115 Ok(id)
116 }
117
118 pub fn create_conversation_with_id(
119 &self,
120 id: &str,
121 model: &str,
122 provider: &str,
123 cwd: &str,
124 ) -> Result<()> {
125 let now = Utc::now().to_rfc3339();
126 self.conn
127 .execute(
128 "INSERT INTO conversations (id, title, model, provider, cwd, created_at, updated_at) \
129 VALUES (?1, NULL, ?2, ?3, ?4, ?5, ?6)",
130 params![id, model, provider, cwd, now, now],
131 )
132 .context("Failed to create conversation")?;
133 tracing::debug!("Created conversation {}", id);
134 Ok(())
135 }
136
137 pub fn list_conversations(&self, limit: usize) -> Result<Vec<ConversationSummary>> {
138 let mut stmt = self
139 .conn
140 .prepare(
141 "SELECT id, title, model, provider, cwd, created_at, updated_at \
142 FROM conversations ORDER BY updated_at DESC LIMIT ?1",
143 )
144 .context("Failed to prepare list_conversations query")?;
145
146 let rows = stmt
147 .query_map(params![limit as i64], |row| {
148 Ok(ConversationSummary {
149 id: row.get(0)?,
150 title: row.get(1)?,
151 model: row.get(2)?,
152 provider: row.get(3)?,
153 cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
154 created_at: row.get(5)?,
155 updated_at: row.get(6)?,
156 })
157 })
158 .context("Failed to list conversations")?;
159
160 let mut conversations = Vec::new();
161 for row in rows {
162 conversations.push(row.context("Failed to read conversation row")?);
163 }
164 Ok(conversations)
165 }
166
167 pub fn list_conversations_for_cwd(
168 &self,
169 cwd: &str,
170 limit: usize,
171 ) -> Result<Vec<ConversationSummary>> {
172 let mut stmt = self
173 .conn
174 .prepare(
175 "SELECT id, title, model, provider, cwd, created_at, updated_at \
176 FROM conversations WHERE cwd = ?1 ORDER BY updated_at DESC LIMIT ?2",
177 )
178 .context("Failed to prepare list_conversations_for_cwd query")?;
179
180 let rows = stmt
181 .query_map(params![cwd, limit as i64], |row| {
182 Ok(ConversationSummary {
183 id: row.get(0)?,
184 title: row.get(1)?,
185 model: row.get(2)?,
186 provider: row.get(3)?,
187 cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
188 created_at: row.get(5)?,
189 updated_at: row.get(6)?,
190 })
191 })
192 .context("Failed to list conversations for cwd")?;
193
194 let mut conversations = Vec::new();
195 for row in rows {
196 conversations.push(row.context("Failed to read conversation row")?);
197 }
198 Ok(conversations)
199 }
200
201 pub fn get_conversation(&self, id: &str) -> Result<Conversation> {
202 let (summary, last_input_tokens) = self
203 .conn
204 .query_row(
205 "SELECT id, title, model, provider, cwd, created_at, updated_at, last_input_tokens \
206 FROM conversations WHERE id = ?1",
207 params![id],
208 |row| {
209 Ok((
210 ConversationSummary {
211 id: row.get(0)?,
212 title: row.get(1)?,
213 model: row.get(2)?,
214 provider: row.get(3)?,
215 cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
216 created_at: row.get(5)?,
217 updated_at: row.get(6)?,
218 },
219 row.get::<_, i64>(7).unwrap_or(0) as u32,
220 ))
221 },
222 )
223 .context("Failed to get conversation")?;
224
225 let messages = self.get_messages(id)?;
226 Ok(Conversation {
227 id: summary.id,
228 title: summary.title,
229 model: summary.model,
230 provider: summary.provider,
231 cwd: summary.cwd,
232 created_at: summary.created_at,
233 updated_at: summary.updated_at,
234 messages,
235 last_input_tokens,
236 })
237 }
238
239 pub fn update_conversation_title(&self, id: &str, title: &str) -> Result<()> {
240 let now = Utc::now().to_rfc3339();
241 self.conn
242 .execute(
243 "UPDATE conversations SET title = ?1, updated_at = ?2 WHERE id = ?3",
244 params![title, now, id],
245 )
246 .context("Failed to update conversation title")?;
247 Ok(())
248 }
249
250 pub fn delete_conversation(&self, id: &str) -> Result<()> {
251 self.conn
252 .execute(
253 "DELETE FROM tool_calls WHERE message_id IN \
254 (SELECT id FROM messages WHERE conversation_id = ?1)",
255 params![id],
256 )
257 .context("Failed to delete tool calls for conversation")?;
258
259 self.conn
260 .execute(
261 "DELETE FROM messages WHERE conversation_id = ?1",
262 params![id],
263 )
264 .context("Failed to delete messages for conversation")?;
265
266 self.conn
267 .execute("DELETE FROM conversations WHERE id = ?1", params![id])
268 .context("Failed to delete conversation")?;
269
270 tracing::debug!("Deleted conversation {}", id);
271 Ok(())
272 }
273
274 pub fn truncate_messages(&self, conversation_id: &str, keep: usize) -> Result<()> {
275 let ids: Vec<String> = {
276 let mut stmt = self
277 .conn
278 .prepare(
279 "SELECT id FROM messages WHERE conversation_id = ?1 ORDER BY created_at ASC",
280 )
281 .context("Failed to prepare truncate query")?;
282 let rows = stmt
283 .query_map(params![conversation_id], |row| row.get::<_, String>(0))
284 .context("Failed to query messages for truncation")?;
285 let mut all = Vec::new();
286 for row in rows {
287 all.push(row.context("Failed to read message id")?);
288 }
289 all
290 };
291 let to_delete = &ids[keep.min(ids.len())..];
292 for id in to_delete {
293 self.conn
294 .execute("DELETE FROM tool_calls WHERE message_id = ?1", params![id])
295 .context("Failed to delete tool calls for truncated message")?;
296 self.conn
297 .execute("DELETE FROM messages WHERE id = ?1", params![id])
298 .context("Failed to delete truncated message")?;
299 }
300 Ok(())
301 }
302
303 pub fn add_message(&self, conversation_id: &str, role: &str, content: &str) -> Result<String> {
304 let id = Uuid::new_v4().to_string();
305 let now = Utc::now().to_rfc3339();
306 self.conn
307 .execute(
308 "INSERT INTO messages \
309 (id, conversation_id, role, content, token_count, created_at) \
310 VALUES (?1, ?2, ?3, ?4, 0, ?5)",
311 params![id, conversation_id, role, content, now],
312 )
313 .context("Failed to add message")?;
314
315 self.conn
316 .execute(
317 "UPDATE conversations SET updated_at = ?1 WHERE id = ?2",
318 params![now, conversation_id],
319 )
320 .context("Failed to update conversation timestamp")?;
321
322 tracing::debug!("Added message {} to conversation {}", id, conversation_id);
323 Ok(id)
324 }
325
326 pub fn get_messages(&self, conversation_id: &str) -> Result<Vec<DbMessage>> {
327 let mut stmt = self
328 .conn
329 .prepare(
330 "SELECT id, conversation_id, role, content, token_count, created_at \
331 FROM messages WHERE conversation_id = ?1 ORDER BY created_at ASC",
332 )
333 .context("Failed to prepare get_messages query")?;
334
335 let rows = stmt
336 .query_map(params![conversation_id], |row| {
337 Ok(DbMessage {
338 id: row.get(0)?,
339 conversation_id: row.get(1)?,
340 role: row.get(2)?,
341 content: row.get(3)?,
342 token_count: row.get::<_, i64>(4)? as u32,
343 created_at: row.get(5)?,
344 })
345 })
346 .context("Failed to get messages")?;
347
348 let mut messages = Vec::new();
349 for row in rows {
350 messages.push(row.context("Failed to read message row")?);
351 }
352 Ok(messages)
353 }
354
355 pub fn update_last_input_tokens(&self, conversation_id: &str, tokens: u32) -> Result<()> {
356 self.conn
357 .execute(
358 "UPDATE conversations SET last_input_tokens = ?1 WHERE id = ?2",
359 params![tokens as i64, conversation_id],
360 )
361 .context("Failed to update last_input_tokens")?;
362 Ok(())
363 }
364
365 pub fn update_message_tokens(&self, id: &str, tokens: u32) -> Result<()> {
366 self.conn
367 .execute(
368 "UPDATE messages SET token_count = ?1 WHERE id = ?2",
369 params![tokens as i64, id],
370 )
371 .context("Failed to update message tokens")?;
372 Ok(())
373 }
374
375 pub fn add_tool_call(
376 &self,
377 message_id: &str,
378 tool_id: &str,
379 name: &str,
380 input: &str,
381 ) -> Result<()> {
382 let now = Utc::now().to_rfc3339();
383 self.conn
384 .execute(
385 "INSERT INTO tool_calls \
386 (id, message_id, name, input, output, is_error, created_at) \
387 VALUES (?1, ?2, ?3, ?4, NULL, 0, ?5)",
388 params![tool_id, message_id, name, input, now],
389 )
390 .context("Failed to add tool call")?;
391 tracing::debug!("Added tool call {} for message {}", tool_id, message_id);
392 Ok(())
393 }
394
395 pub fn update_tool_result(&self, tool_id: &str, output: &str, is_error: bool) -> Result<()> {
396 self.conn
397 .execute(
398 "UPDATE tool_calls SET output = ?1, is_error = ?2 WHERE id = ?3",
399 params![output, is_error as i64, tool_id],
400 )
401 .context("Failed to update tool result")?;
402 Ok(())
403 }
404
405 pub fn get_tool_calls(&self, message_id: &str) -> Result<Vec<DbToolCall>> {
406 let mut stmt = self
407 .conn
408 .prepare(
409 "SELECT id, message_id, name, input, output, is_error, created_at \
410 FROM tool_calls WHERE message_id = ?1 ORDER BY created_at ASC",
411 )
412 .context("Failed to prepare get_tool_calls query")?;
413
414 let rows = stmt
415 .query_map(params![message_id], |row| {
416 Ok(DbToolCall {
417 id: row.get(0)?,
418 message_id: row.get(1)?,
419 name: row.get(2)?,
420 input: row.get(3)?,
421 output: row.get(4)?,
422 is_error: row.get::<_, i64>(5)? != 0,
423 created_at: row.get(6)?,
424 })
425 })
426 .context("Failed to get tool calls")?;
427
428 let mut calls = Vec::new();
429 for row in rows {
430 calls.push(row.context("Failed to read tool call row")?);
431 }
432 Ok(calls)
433 }
434
435 pub fn get_user_message_history(&self, limit: usize) -> Result<Vec<String>> {
436 let mut stmt = self
437 .conn
438 .prepare(
439 "SELECT content FROM messages WHERE role = 'user' \
440 ORDER BY created_at DESC LIMIT ?1",
441 )
442 .context("Failed to prepare user history query")?;
443
444 let rows = stmt
445 .query_map(params![limit as i64], |row| row.get::<_, String>(0))
446 .context("Failed to query user history")?;
447
448 let mut messages = Vec::new();
449 for row in rows {
450 messages.push(row.context("Failed to read history row")?);
451 }
452 messages.reverse();
453 Ok(messages)
454 }
455
456 pub fn create_task(&self, id: &str, prompt: &str, pid: u32, cwd: &str) -> Result<()> {
457 let now = Utc::now().to_rfc3339();
458 self.conn
459 .execute(
460 "INSERT INTO tasks (id, prompt, status, pid, cwd, created_at) \
461 VALUES (?1, ?2, 'running', ?3, ?4, ?5)",
462 params![id, prompt, pid as i64, cwd, now],
463 )
464 .context("creating task")?;
465 Ok(())
466 }
467
468 pub fn complete_task(
469 &self,
470 id: &str,
471 status: &str,
472 session_id: Option<&str>,
473 output: &str,
474 ) -> Result<()> {
475 let now = Utc::now().to_rfc3339();
476 self.conn
477 .execute(
478 "UPDATE tasks SET status = ?1, session_id = ?2, output = ?3, completed_at = ?4 WHERE id = ?5",
479 params![status, session_id, output, now, id],
480 )
481 .context("completing task")?;
482 Ok(())
483 }
484
485 pub fn list_tasks(&self, limit: usize) -> Result<Vec<TaskRecord>> {
486 let mut stmt = self
487 .conn
488 .prepare(
489 "SELECT id, prompt, status, session_id, pid, output, cwd, created_at, completed_at \
490 FROM tasks ORDER BY created_at DESC LIMIT ?1",
491 )
492 .context("preparing list_tasks")?;
493 let rows = stmt
494 .query_map(params![limit as i64], |row| {
495 Ok(TaskRecord {
496 id: row.get(0)?,
497 prompt: row.get(1)?,
498 status: row.get(2)?,
499 session_id: row.get(3)?,
500 pid: row.get(4)?,
501 output: row.get(5)?,
502 cwd: row.get::<_, Option<String>>(6)?.unwrap_or_default(),
503 created_at: row.get(7)?,
504 completed_at: row.get(8)?,
505 })
506 })
507 .context("listing tasks")?;
508 let mut tasks = Vec::new();
509 for row in rows {
510 tasks.push(row.context("reading task row")?);
511 }
512 Ok(tasks)
513 }
514
515 pub fn get_task(&self, id: &str) -> Result<TaskRecord> {
516 self.conn
517 .query_row(
518 "SELECT id, prompt, status, session_id, pid, output, cwd, created_at, completed_at \
519 FROM tasks WHERE id = ?1",
520 params![id],
521 |row| {
522 Ok(TaskRecord {
523 id: row.get(0)?,
524 prompt: row.get(1)?,
525 status: row.get(2)?,
526 session_id: row.get(3)?,
527 pid: row.get(4)?,
528 output: row.get(5)?,
529 cwd: row.get::<_, Option<String>>(6)?.unwrap_or_default(),
530 created_at: row.get(7)?,
531 completed_at: row.get(8)?,
532 })
533 },
534 )
535 .context("getting task")
536 }
537}