Skip to main content

dot/db/
mod.rs

1pub(crate) mod schema;
2
3use anyhow::{Context, Result};
4use chrono::Utc;
5use rusqlite::{params, Connection};
6use std::path::PathBuf;
7use uuid::Uuid;
8
9fn db_path() -> Result<PathBuf> {
10    let dot_dir = crate::config::Config::data_dir();
11    std::fs::create_dir_all(&dot_dir).context("Could not create dot data directory")?;
12    Ok(dot_dir.join("dot.db"))
13}
14
/// A `conversations` row without its messages, as returned by
/// `Db::list_conversations` and `Db::list_conversations_for_cwd`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConversationSummary {
    /// UUID generated when the conversation was created.
    pub id: String,
    /// Title; `None` until one is stored via `Db::update_conversation_title`.
    pub title: Option<String>,
    /// Model name the conversation was created with.
    pub model: String,
    /// Provider name the conversation was created with.
    pub provider: String,
    /// Working directory recorded at creation; empty for rows that predate
    /// the `cwd` column.
    pub cwd: String,
    /// RFC 3339 creation timestamp.
    pub created_at: String,
    /// RFC 3339 timestamp of the last message or title change.
    pub updated_at: String,
}
25
/// A full conversation: the `conversations` row plus all of its messages,
/// as returned by `Db::get_conversation`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Conversation {
    /// UUID generated when the conversation was created.
    pub id: String,
    /// Title; `None` until one is stored.
    pub title: Option<String>,
    /// Model name the conversation was created with.
    pub model: String,
    /// Provider name the conversation was created with.
    pub provider: String,
    /// Working directory recorded at creation (empty for legacy rows).
    pub cwd: String,
    /// RFC 3339 creation timestamp.
    pub created_at: String,
    /// RFC 3339 timestamp of the last message or title change.
    pub updated_at: String,
    /// Messages ordered by `created_at`, oldest first.
    pub messages: Vec<DbMessage>,
    /// Value last stored by `Db::update_last_input_tokens` (0 by default).
    pub last_input_tokens: u32,
}

/// One row of the `messages` table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbMessage {
    /// UUID generated by `Db::add_message`.
    pub id: String,
    /// Id of the owning conversation.
    pub conversation_id: String,
    /// Role string as stored (e.g. `"user"`).
    pub role: String,
    /// Message body.
    pub content: String,
    /// Token count; 0 on insert, later set via `Db::update_message_tokens`.
    pub token_count: u32,
    /// RFC 3339 creation timestamp.
    pub created_at: String,
}
48
/// One row of the `tool_calls` table, attached to a message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbToolCall {
    /// Tool-call id supplied by the caller of `Db::add_tool_call`.
    pub id: String,
    /// Id of the message that issued the call.
    pub message_id: String,
    /// Tool name.
    pub name: String,
    /// Tool input as stored text.
    pub input: String,
    /// Tool output; `None` until `Db::update_tool_result` is called.
    pub output: Option<String>,
    /// Whether the recorded result was an error (stored as 0/1).
    pub is_error: bool,
    /// RFC 3339 creation timestamp.
    pub created_at: String,
}
59
/// One row of the `tasks` table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskRecord {
    /// Task id supplied to `Db::create_task`.
    pub id: String,
    /// Prompt the task was started with.
    pub prompt: String,
    /// `"running"` on creation; replaced by the status given to `Db::complete_task`.
    pub status: String,
    /// Session id recorded on completion, if any.
    pub session_id: Option<String>,
    /// Process id recorded at creation.
    pub pid: Option<i64>,
    /// Output recorded on completion.
    pub output: Option<String>,
    /// Working directory the task was started in (empty if NULL).
    pub cwd: String,
    /// RFC 3339 creation timestamp.
    pub created_at: String,
    /// RFC 3339 completion timestamp; `None` while unfinished.
    pub completed_at: Option<String>,
}
72
73pub struct Db {
74    conn: Connection,
75}
76
77impl Db {
78    pub fn open() -> Result<Self> {
79        let path = db_path()?;
80        tracing::debug!("Opening database at {:?}", path);
81        let conn = Connection::open(&path)
82            .with_context(|| format!("Failed to open database at {:?}", path))?;
83        let db = Db { conn };
84        db.init()?;
85        Ok(db)
86    }
87
88    pub fn init(&self) -> Result<()> {
89        self.conn
90            .execute_batch(&format!(
91                "{}\n;\n{}\n;\n{}\n;\n{}",
92                schema::CREATE_CONVERSATIONS,
93                schema::CREATE_MESSAGES,
94                schema::CREATE_TOOL_CALLS,
95                schema::CREATE_TASKS,
96            ))
97            .context("Failed to initialize database schema")?;
98
99        let _ = self.conn.execute(
100            "ALTER TABLE conversations ADD COLUMN cwd TEXT NOT NULL DEFAULT ''",
101            [],
102        );
103        let _ = self.conn.execute(
104            "ALTER TABLE conversations ADD COLUMN last_input_tokens INTEGER NOT NULL DEFAULT 0",
105            [],
106        );
107
108        tracing::debug!("Database schema initialized");
109        Ok(())
110    }
111
112    pub fn create_conversation(&self, model: &str, provider: &str, cwd: &str) -> Result<String> {
113        let id = Uuid::new_v4().to_string();
114        let now = Utc::now().to_rfc3339();
115        self.conn
116            .execute(
117                "INSERT INTO conversations (id, title, model, provider, cwd, created_at, updated_at) \
118                 VALUES (?1, NULL, ?2, ?3, ?4, ?5, ?6)",
119                params![id, model, provider, cwd, now, now],
120            )
121            .context("Failed to create conversation")?;
122        tracing::debug!("Created conversation {}", id);
123        Ok(id)
124    }
125
126    pub fn list_conversations(&self, limit: usize) -> Result<Vec<ConversationSummary>> {
127        let mut stmt = self
128            .conn
129            .prepare(
130                "SELECT id, title, model, provider, cwd, created_at, updated_at \
131                 FROM conversations ORDER BY updated_at DESC LIMIT ?1",
132            )
133            .context("Failed to prepare list_conversations query")?;
134
135        let rows = stmt
136            .query_map(params![limit as i64], |row| {
137                Ok(ConversationSummary {
138                    id: row.get(0)?,
139                    title: row.get(1)?,
140                    model: row.get(2)?,
141                    provider: row.get(3)?,
142                    cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
143                    created_at: row.get(5)?,
144                    updated_at: row.get(6)?,
145                })
146            })
147            .context("Failed to list conversations")?;
148
149        let mut conversations = Vec::new();
150        for row in rows {
151            conversations.push(row.context("Failed to read conversation row")?);
152        }
153        Ok(conversations)
154    }
155
156    pub fn list_conversations_for_cwd(
157        &self,
158        cwd: &str,
159        limit: usize,
160    ) -> Result<Vec<ConversationSummary>> {
161        let mut stmt = self
162            .conn
163            .prepare(
164                "SELECT id, title, model, provider, cwd, created_at, updated_at \
165                 FROM conversations WHERE cwd = ?1 ORDER BY updated_at DESC LIMIT ?2",
166            )
167            .context("Failed to prepare list_conversations_for_cwd query")?;
168
169        let rows = stmt
170            .query_map(params![cwd, limit as i64], |row| {
171                Ok(ConversationSummary {
172                    id: row.get(0)?,
173                    title: row.get(1)?,
174                    model: row.get(2)?,
175                    provider: row.get(3)?,
176                    cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
177                    created_at: row.get(5)?,
178                    updated_at: row.get(6)?,
179                })
180            })
181            .context("Failed to list conversations for cwd")?;
182
183        let mut conversations = Vec::new();
184        for row in rows {
185            conversations.push(row.context("Failed to read conversation row")?);
186        }
187        Ok(conversations)
188    }
189
190    pub fn get_conversation(&self, id: &str) -> Result<Conversation> {
191        let (summary, last_input_tokens) = self
192            .conn
193            .query_row(
194                "SELECT id, title, model, provider, cwd, created_at, updated_at, last_input_tokens \
195                 FROM conversations WHERE id = ?1",
196                params![id],
197                |row| {
198                    Ok((
199                        ConversationSummary {
200                            id: row.get(0)?,
201                            title: row.get(1)?,
202                            model: row.get(2)?,
203                            provider: row.get(3)?,
204                            cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
205                            created_at: row.get(5)?,
206                            updated_at: row.get(6)?,
207                        },
208                        row.get::<_, i64>(7).unwrap_or(0) as u32,
209                    ))
210                },
211            )
212            .context("Failed to get conversation")?;
213
214        let messages = self.get_messages(id)?;
215        Ok(Conversation {
216            id: summary.id,
217            title: summary.title,
218            model: summary.model,
219            provider: summary.provider,
220            cwd: summary.cwd,
221            created_at: summary.created_at,
222            updated_at: summary.updated_at,
223            messages,
224            last_input_tokens,
225        })
226    }
227
228    pub fn update_conversation_title(&self, id: &str, title: &str) -> Result<()> {
229        let now = Utc::now().to_rfc3339();
230        self.conn
231            .execute(
232                "UPDATE conversations SET title = ?1, updated_at = ?2 WHERE id = ?3",
233                params![title, now, id],
234            )
235            .context("Failed to update conversation title")?;
236        Ok(())
237    }
238
239    pub fn delete_conversation(&self, id: &str) -> Result<()> {
240        self.conn
241            .execute(
242                "DELETE FROM tool_calls WHERE message_id IN \
243                 (SELECT id FROM messages WHERE conversation_id = ?1)",
244                params![id],
245            )
246            .context("Failed to delete tool calls for conversation")?;
247
248        self.conn
249            .execute(
250                "DELETE FROM messages WHERE conversation_id = ?1",
251                params![id],
252            )
253            .context("Failed to delete messages for conversation")?;
254
255        self.conn
256            .execute("DELETE FROM conversations WHERE id = ?1", params![id])
257            .context("Failed to delete conversation")?;
258
259        tracing::debug!("Deleted conversation {}", id);
260        Ok(())
261    }
262
263    pub fn truncate_messages(&self, conversation_id: &str, keep: usize) -> Result<()> {
264        let ids: Vec<String> = {
265            let mut stmt = self
266                .conn
267                .prepare(
268                    "SELECT id FROM messages WHERE conversation_id = ?1 ORDER BY created_at ASC",
269                )
270                .context("Failed to prepare truncate query")?;
271            let rows = stmt
272                .query_map(params![conversation_id], |row| row.get::<_, String>(0))
273                .context("Failed to query messages for truncation")?;
274            let mut all = Vec::new();
275            for row in rows {
276                all.push(row.context("Failed to read message id")?);
277            }
278            all
279        };
280        let to_delete = &ids[keep.min(ids.len())..];
281        for id in to_delete {
282            self.conn
283                .execute("DELETE FROM tool_calls WHERE message_id = ?1", params![id])
284                .context("Failed to delete tool calls for truncated message")?;
285            self.conn
286                .execute("DELETE FROM messages WHERE id = ?1", params![id])
287                .context("Failed to delete truncated message")?;
288        }
289        Ok(())
290    }
291
292    pub fn add_message(&self, conversation_id: &str, role: &str, content: &str) -> Result<String> {
293        let id = Uuid::new_v4().to_string();
294        let now = Utc::now().to_rfc3339();
295        self.conn
296            .execute(
297                "INSERT INTO messages \
298                 (id, conversation_id, role, content, token_count, created_at) \
299                 VALUES (?1, ?2, ?3, ?4, 0, ?5)",
300                params![id, conversation_id, role, content, now],
301            )
302            .context("Failed to add message")?;
303
304        self.conn
305            .execute(
306                "UPDATE conversations SET updated_at = ?1 WHERE id = ?2",
307                params![now, conversation_id],
308            )
309            .context("Failed to update conversation timestamp")?;
310
311        tracing::debug!("Added message {} to conversation {}", id, conversation_id);
312        Ok(id)
313    }
314
315    pub fn get_messages(&self, conversation_id: &str) -> Result<Vec<DbMessage>> {
316        let mut stmt = self
317            .conn
318            .prepare(
319                "SELECT id, conversation_id, role, content, token_count, created_at \
320                 FROM messages WHERE conversation_id = ?1 ORDER BY created_at ASC",
321            )
322            .context("Failed to prepare get_messages query")?;
323
324        let rows = stmt
325            .query_map(params![conversation_id], |row| {
326                Ok(DbMessage {
327                    id: row.get(0)?,
328                    conversation_id: row.get(1)?,
329                    role: row.get(2)?,
330                    content: row.get(3)?,
331                    token_count: row.get::<_, i64>(4)? as u32,
332                    created_at: row.get(5)?,
333                })
334            })
335            .context("Failed to get messages")?;
336
337        let mut messages = Vec::new();
338        for row in rows {
339            messages.push(row.context("Failed to read message row")?);
340        }
341        Ok(messages)
342    }
343
344    pub fn update_last_input_tokens(&self, conversation_id: &str, tokens: u32) -> Result<()> {
345        self.conn
346            .execute(
347                "UPDATE conversations SET last_input_tokens = ?1 WHERE id = ?2",
348                params![tokens as i64, conversation_id],
349            )
350            .context("Failed to update last_input_tokens")?;
351        Ok(())
352    }
353
354    pub fn update_message_tokens(&self, id: &str, tokens: u32) -> Result<()> {
355        self.conn
356            .execute(
357                "UPDATE messages SET token_count = ?1 WHERE id = ?2",
358                params![tokens as i64, id],
359            )
360            .context("Failed to update message tokens")?;
361        Ok(())
362    }
363
364    pub fn add_tool_call(
365        &self,
366        message_id: &str,
367        tool_id: &str,
368        name: &str,
369        input: &str,
370    ) -> Result<()> {
371        let now = Utc::now().to_rfc3339();
372        self.conn
373            .execute(
374                "INSERT INTO tool_calls \
375                 (id, message_id, name, input, output, is_error, created_at) \
376                 VALUES (?1, ?2, ?3, ?4, NULL, 0, ?5)",
377                params![tool_id, message_id, name, input, now],
378            )
379            .context("Failed to add tool call")?;
380        tracing::debug!("Added tool call {} for message {}", tool_id, message_id);
381        Ok(())
382    }
383
384    pub fn update_tool_result(&self, tool_id: &str, output: &str, is_error: bool) -> Result<()> {
385        self.conn
386            .execute(
387                "UPDATE tool_calls SET output = ?1, is_error = ?2 WHERE id = ?3",
388                params![output, is_error as i64, tool_id],
389            )
390            .context("Failed to update tool result")?;
391        Ok(())
392    }
393
394    pub fn get_tool_calls(&self, message_id: &str) -> Result<Vec<DbToolCall>> {
395        let mut stmt = self
396            .conn
397            .prepare(
398                "SELECT id, message_id, name, input, output, is_error, created_at \
399                 FROM tool_calls WHERE message_id = ?1 ORDER BY created_at ASC",
400            )
401            .context("Failed to prepare get_tool_calls query")?;
402
403        let rows = stmt
404            .query_map(params![message_id], |row| {
405                Ok(DbToolCall {
406                    id: row.get(0)?,
407                    message_id: row.get(1)?,
408                    name: row.get(2)?,
409                    input: row.get(3)?,
410                    output: row.get(4)?,
411                    is_error: row.get::<_, i64>(5)? != 0,
412                    created_at: row.get(6)?,
413                })
414            })
415            .context("Failed to get tool calls")?;
416
417        let mut calls = Vec::new();
418        for row in rows {
419            calls.push(row.context("Failed to read tool call row")?);
420        }
421        Ok(calls)
422    }
423
424    pub fn get_user_message_history(&self, limit: usize) -> Result<Vec<String>> {
425        let mut stmt = self
426            .conn
427            .prepare(
428                "SELECT content FROM messages WHERE role = 'user' \
429                 ORDER BY created_at DESC LIMIT ?1",
430            )
431            .context("Failed to prepare user history query")?;
432
433        let rows = stmt
434            .query_map(params![limit as i64], |row| row.get::<_, String>(0))
435            .context("Failed to query user history")?;
436
437        let mut messages = Vec::new();
438        for row in rows {
439            messages.push(row.context("Failed to read history row")?);
440        }
441        messages.reverse();
442        Ok(messages)
443    }
444
445    pub fn create_task(&self, id: &str, prompt: &str, pid: u32, cwd: &str) -> Result<()> {
446        let now = Utc::now().to_rfc3339();
447        self.conn
448            .execute(
449                "INSERT INTO tasks (id, prompt, status, pid, cwd, created_at) \
450                 VALUES (?1, ?2, 'running', ?3, ?4, ?5)",
451                params![id, prompt, pid as i64, cwd, now],
452            )
453            .context("creating task")?;
454        Ok(())
455    }
456
457    pub fn complete_task(
458        &self,
459        id: &str,
460        status: &str,
461        session_id: Option<&str>,
462        output: &str,
463    ) -> Result<()> {
464        let now = Utc::now().to_rfc3339();
465        self.conn
466            .execute(
467                "UPDATE tasks SET status = ?1, session_id = ?2, output = ?3, completed_at = ?4 WHERE id = ?5",
468                params![status, session_id, output, now, id],
469            )
470            .context("completing task")?;
471        Ok(())
472    }
473
474    pub fn list_tasks(&self, limit: usize) -> Result<Vec<TaskRecord>> {
475        let mut stmt = self
476            .conn
477            .prepare(
478                "SELECT id, prompt, status, session_id, pid, output, cwd, created_at, completed_at \
479                 FROM tasks ORDER BY created_at DESC LIMIT ?1",
480            )
481            .context("preparing list_tasks")?;
482        let rows = stmt
483            .query_map(params![limit as i64], |row| {
484                Ok(TaskRecord {
485                    id: row.get(0)?,
486                    prompt: row.get(1)?,
487                    status: row.get(2)?,
488                    session_id: row.get(3)?,
489                    pid: row.get(4)?,
490                    output: row.get(5)?,
491                    cwd: row.get::<_, Option<String>>(6)?.unwrap_or_default(),
492                    created_at: row.get(7)?,
493                    completed_at: row.get(8)?,
494                })
495            })
496            .context("listing tasks")?;
497        let mut tasks = Vec::new();
498        for row in rows {
499            tasks.push(row.context("reading task row")?);
500        }
501        Ok(tasks)
502    }
503
504    pub fn get_task(&self, id: &str) -> Result<TaskRecord> {
505        self.conn
506            .query_row(
507                "SELECT id, prompt, status, session_id, pid, output, cwd, created_at, completed_at \
508                 FROM tasks WHERE id = ?1",
509                params![id],
510                |row| {
511                    Ok(TaskRecord {
512                        id: row.get(0)?,
513                        prompt: row.get(1)?,
514                        status: row.get(2)?,
515                        session_id: row.get(3)?,
516                        pid: row.get(4)?,
517                        output: row.get(5)?,
518                        cwd: row.get::<_, Option<String>>(6)?.unwrap_or_default(),
519                        created_at: row.get(7)?,
520                        completed_at: row.get(8)?,
521                    })
522                },
523            )
524            .context("getting task")
525    }
526}