Skip to main content

dot/db/
mod.rs

1pub(crate) mod schema;
2
3use anyhow::{Context, Result};
4use chrono::Utc;
5use rusqlite::{Connection, params};
6use std::path::PathBuf;
7use uuid::Uuid;
8
9fn db_path() -> Result<PathBuf> {
10    let dot_dir = crate::config::Config::data_dir();
11    std::fs::create_dir_all(&dot_dir).context("Could not create dot data directory")?;
12    Ok(dot_dir.join("dot.db"))
13}
14
/// Lightweight view of a conversation row (no messages), as returned by the
/// conversation listing queries.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConversationSummary {
    /// Conversation id (a UUID v4 string when created via `Db::create_conversation`).
    pub id: String,
    /// Human-readable title; `None` until one is set with `Db::update_conversation_title`.
    pub title: Option<String>,
    /// Model name the conversation was created with.
    pub model: String,
    /// Provider name the conversation was created with.
    pub provider: String,
    /// Working directory recorded at creation; empty string for legacy rows.
    pub cwd: String,
    /// Creation time as an RFC 3339 UTC timestamp.
    pub created_at: String,
    /// Last-modified time as an RFC 3339 UTC timestamp.
    pub updated_at: String,
}
25
26#[derive(Debug, Clone)]
27pub struct Conversation {
28    pub id: String,
29    pub title: Option<String>,
30    pub model: String,
31    pub provider: String,
32    pub cwd: String,
33    pub created_at: String,
34    pub updated_at: String,
35    pub messages: Vec<DbMessage>,
36    pub last_input_tokens: u32,
37}
38
/// One stored message belonging to a conversation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbMessage {
    /// Message id (UUID v4 string generated by `Db::add_message`).
    pub id: String,
    /// Id of the owning conversation.
    pub conversation_id: String,
    /// Role string as stored (e.g. `"user"`).
    pub role: String,
    /// Message content as stored; opaque to the database layer.
    pub content: String,
    /// Token count; 0 until set with `Db::update_message_tokens`.
    pub token_count: u32,
    /// Creation time, RFC 3339 UTC.
    pub created_at: String,
}
48
/// One recorded tool invocation attached to a message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbToolCall {
    /// Tool-call id as supplied to `Db::add_tool_call`.
    pub id: String,
    /// Id of the message that issued the call.
    pub message_id: String,
    /// Tool name.
    pub name: String,
    /// Tool input as stored (opaque string).
    pub input: String,
    /// Tool output; `None` until `Db::update_tool_result` records it.
    pub output: Option<String>,
    /// Whether the recorded output is an error result.
    pub is_error: bool,
    /// Creation time, RFC 3339 UTC.
    pub created_at: String,
}
59
/// A background task row.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskRecord {
    /// Task id as supplied to `Db::create_task`.
    pub id: String,
    /// Prompt the task was started with.
    pub prompt: String,
    /// `"running"` on creation; afterwards whatever `Db::complete_task` stored.
    pub status: String,
    /// Session id recorded on completion, if any.
    pub session_id: Option<String>,
    /// Process id recorded at creation.
    pub pid: Option<i64>,
    /// Output recorded on completion, if any.
    pub output: Option<String>,
    /// Working directory recorded at creation (empty if NULL in the row).
    pub cwd: String,
    /// Creation time, RFC 3339 UTC.
    pub created_at: String,
    /// Completion time, RFC 3339 UTC; `None` while not completed.
    pub completed_at: Option<String>,
}
72
73pub struct Db {
74    conn: Connection,
75}
76
77impl Db {
78    pub fn open() -> Result<Self> {
79        let path = db_path()?;
80        tracing::debug!("Opening database at {:?}", path);
81        let conn = Connection::open(&path)
82            .with_context(|| format!("Failed to open database at {:?}", path))?;
83        let db = Db { conn };
84        db.init()?;
85        Ok(db)
86    }
87
88    pub fn init(&self) -> Result<()> {
89        self.conn
90            .execute_batch(&format!(
91                "{}\n;\n{}\n;\n{}\n;\n{}",
92                schema::CREATE_CONVERSATIONS,
93                schema::CREATE_MESSAGES,
94                schema::CREATE_TOOL_CALLS,
95                schema::CREATE_TASKS,
96            ))
97            .context("Failed to initialize database schema")?;
98
99        let _ = self.conn.execute(
100            "ALTER TABLE conversations ADD COLUMN cwd TEXT NOT NULL DEFAULT ''",
101            [],
102        );
103        let _ = self.conn.execute(
104            "ALTER TABLE conversations ADD COLUMN last_input_tokens INTEGER NOT NULL DEFAULT 0",
105            [],
106        );
107
108        tracing::debug!("Database schema initialized");
109        Ok(())
110    }
111
112    pub fn create_conversation(&self, model: &str, provider: &str, cwd: &str) -> Result<String> {
113        let id = Uuid::new_v4().to_string();
114        self.create_conversation_with_id(&id, model, provider, cwd)?;
115        Ok(id)
116    }
117
118    pub fn create_conversation_with_id(
119        &self,
120        id: &str,
121        model: &str,
122        provider: &str,
123        cwd: &str,
124    ) -> Result<()> {
125        let now = Utc::now().to_rfc3339();
126        self.conn
127            .execute(
128                "INSERT INTO conversations (id, title, model, provider, cwd, created_at, updated_at) \
129                 VALUES (?1, NULL, ?2, ?3, ?4, ?5, ?6)",
130                params![id, model, provider, cwd, now, now],
131            )
132            .context("Failed to create conversation")?;
133        tracing::debug!("Created conversation {}", id);
134        Ok(())
135    }
136
137    pub fn list_conversations(&self, limit: usize) -> Result<Vec<ConversationSummary>> {
138        let mut stmt = self
139            .conn
140            .prepare(
141                "SELECT id, title, model, provider, cwd, created_at, updated_at \
142                 FROM conversations ORDER BY updated_at DESC LIMIT ?1",
143            )
144            .context("Failed to prepare list_conversations query")?;
145
146        let rows = stmt
147            .query_map(params![limit as i64], |row| {
148                Ok(ConversationSummary {
149                    id: row.get(0)?,
150                    title: row.get(1)?,
151                    model: row.get(2)?,
152                    provider: row.get(3)?,
153                    cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
154                    created_at: row.get(5)?,
155                    updated_at: row.get(6)?,
156                })
157            })
158            .context("Failed to list conversations")?;
159
160        let mut conversations = Vec::new();
161        for row in rows {
162            conversations.push(row.context("Failed to read conversation row")?);
163        }
164        Ok(conversations)
165    }
166
167    pub fn list_conversations_for_cwd(
168        &self,
169        cwd: &str,
170        limit: usize,
171    ) -> Result<Vec<ConversationSummary>> {
172        let mut stmt = self
173            .conn
174            .prepare(
175                "SELECT id, title, model, provider, cwd, created_at, updated_at \
176                 FROM conversations WHERE cwd = ?1 ORDER BY updated_at DESC LIMIT ?2",
177            )
178            .context("Failed to prepare list_conversations_for_cwd query")?;
179
180        let rows = stmt
181            .query_map(params![cwd, limit as i64], |row| {
182                Ok(ConversationSummary {
183                    id: row.get(0)?,
184                    title: row.get(1)?,
185                    model: row.get(2)?,
186                    provider: row.get(3)?,
187                    cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
188                    created_at: row.get(5)?,
189                    updated_at: row.get(6)?,
190                })
191            })
192            .context("Failed to list conversations for cwd")?;
193
194        let mut conversations = Vec::new();
195        for row in rows {
196            conversations.push(row.context("Failed to read conversation row")?);
197        }
198        Ok(conversations)
199    }
200
201    pub fn get_conversation(&self, id: &str) -> Result<Conversation> {
202        let (summary, last_input_tokens) = self
203            .conn
204            .query_row(
205                "SELECT id, title, model, provider, cwd, created_at, updated_at, last_input_tokens \
206                 FROM conversations WHERE id = ?1",
207                params![id],
208                |row| {
209                    Ok((
210                        ConversationSummary {
211                            id: row.get(0)?,
212                            title: row.get(1)?,
213                            model: row.get(2)?,
214                            provider: row.get(3)?,
215                            cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
216                            created_at: row.get(5)?,
217                            updated_at: row.get(6)?,
218                        },
219                        row.get::<_, i64>(7).unwrap_or(0) as u32,
220                    ))
221                },
222            )
223            .context("Failed to get conversation")?;
224
225        let messages = self.get_messages(id)?;
226        Ok(Conversation {
227            id: summary.id,
228            title: summary.title,
229            model: summary.model,
230            provider: summary.provider,
231            cwd: summary.cwd,
232            created_at: summary.created_at,
233            updated_at: summary.updated_at,
234            messages,
235            last_input_tokens,
236        })
237    }
238
239    pub fn update_conversation_title(&self, id: &str, title: &str) -> Result<()> {
240        let now = Utc::now().to_rfc3339();
241        self.conn
242            .execute(
243                "UPDATE conversations SET title = ?1, updated_at = ?2 WHERE id = ?3",
244                params![title, now, id],
245            )
246            .context("Failed to update conversation title")?;
247        Ok(())
248    }
249
250    pub fn delete_conversation(&self, id: &str) -> Result<()> {
251        self.conn
252            .execute(
253                "DELETE FROM tool_calls WHERE message_id IN \
254                 (SELECT id FROM messages WHERE conversation_id = ?1)",
255                params![id],
256            )
257            .context("Failed to delete tool calls for conversation")?;
258
259        self.conn
260            .execute(
261                "DELETE FROM messages WHERE conversation_id = ?1",
262                params![id],
263            )
264            .context("Failed to delete messages for conversation")?;
265
266        self.conn
267            .execute("DELETE FROM conversations WHERE id = ?1", params![id])
268            .context("Failed to delete conversation")?;
269
270        tracing::debug!("Deleted conversation {}", id);
271        Ok(())
272    }
273
274    pub fn truncate_messages(&self, conversation_id: &str, keep: usize) -> Result<()> {
275        let ids: Vec<String> = {
276            let mut stmt = self
277                .conn
278                .prepare(
279                    "SELECT id FROM messages WHERE conversation_id = ?1 ORDER BY created_at ASC",
280                )
281                .context("Failed to prepare truncate query")?;
282            let rows = stmt
283                .query_map(params![conversation_id], |row| row.get::<_, String>(0))
284                .context("Failed to query messages for truncation")?;
285            let mut all = Vec::new();
286            for row in rows {
287                all.push(row.context("Failed to read message id")?);
288            }
289            all
290        };
291        let to_delete = &ids[keep.min(ids.len())..];
292        for id in to_delete {
293            self.conn
294                .execute("DELETE FROM tool_calls WHERE message_id = ?1", params![id])
295                .context("Failed to delete tool calls for truncated message")?;
296            self.conn
297                .execute("DELETE FROM messages WHERE id = ?1", params![id])
298                .context("Failed to delete truncated message")?;
299        }
300        Ok(())
301    }
302
303    pub fn add_message(&self, conversation_id: &str, role: &str, content: &str) -> Result<String> {
304        let id = Uuid::new_v4().to_string();
305        let now = Utc::now().to_rfc3339();
306        self.conn
307            .execute(
308                "INSERT INTO messages \
309                 (id, conversation_id, role, content, token_count, created_at) \
310                 VALUES (?1, ?2, ?3, ?4, 0, ?5)",
311                params![id, conversation_id, role, content, now],
312            )
313            .context("Failed to add message")?;
314
315        self.conn
316            .execute(
317                "UPDATE conversations SET updated_at = ?1 WHERE id = ?2",
318                params![now, conversation_id],
319            )
320            .context("Failed to update conversation timestamp")?;
321
322        tracing::debug!("Added message {} to conversation {}", id, conversation_id);
323        Ok(id)
324    }
325
326    pub fn get_messages(&self, conversation_id: &str) -> Result<Vec<DbMessage>> {
327        let mut stmt = self
328            .conn
329            .prepare(
330                "SELECT id, conversation_id, role, content, token_count, created_at \
331                 FROM messages WHERE conversation_id = ?1 ORDER BY created_at ASC",
332            )
333            .context("Failed to prepare get_messages query")?;
334
335        let rows = stmt
336            .query_map(params![conversation_id], |row| {
337                Ok(DbMessage {
338                    id: row.get(0)?,
339                    conversation_id: row.get(1)?,
340                    role: row.get(2)?,
341                    content: row.get(3)?,
342                    token_count: row.get::<_, i64>(4)? as u32,
343                    created_at: row.get(5)?,
344                })
345            })
346            .context("Failed to get messages")?;
347
348        let mut messages = Vec::new();
349        for row in rows {
350            messages.push(row.context("Failed to read message row")?);
351        }
352        Ok(messages)
353    }
354
355    pub fn update_last_input_tokens(&self, conversation_id: &str, tokens: u32) -> Result<()> {
356        self.conn
357            .execute(
358                "UPDATE conversations SET last_input_tokens = ?1 WHERE id = ?2",
359                params![tokens as i64, conversation_id],
360            )
361            .context("Failed to update last_input_tokens")?;
362        Ok(())
363    }
364
365    pub fn update_message_tokens(&self, id: &str, tokens: u32) -> Result<()> {
366        self.conn
367            .execute(
368                "UPDATE messages SET token_count = ?1 WHERE id = ?2",
369                params![tokens as i64, id],
370            )
371            .context("Failed to update message tokens")?;
372        Ok(())
373    }
374
375    pub fn add_tool_call(
376        &self,
377        message_id: &str,
378        tool_id: &str,
379        name: &str,
380        input: &str,
381    ) -> Result<()> {
382        let now = Utc::now().to_rfc3339();
383        self.conn
384            .execute(
385                "INSERT INTO tool_calls \
386                 (id, message_id, name, input, output, is_error, created_at) \
387                 VALUES (?1, ?2, ?3, ?4, NULL, 0, ?5)",
388                params![tool_id, message_id, name, input, now],
389            )
390            .context("Failed to add tool call")?;
391        tracing::debug!("Added tool call {} for message {}", tool_id, message_id);
392        Ok(())
393    }
394
395    pub fn update_tool_result(&self, tool_id: &str, output: &str, is_error: bool) -> Result<()> {
396        self.conn
397            .execute(
398                "UPDATE tool_calls SET output = ?1, is_error = ?2 WHERE id = ?3",
399                params![output, is_error as i64, tool_id],
400            )
401            .context("Failed to update tool result")?;
402        Ok(())
403    }
404
405    pub fn get_tool_calls(&self, message_id: &str) -> Result<Vec<DbToolCall>> {
406        let mut stmt = self
407            .conn
408            .prepare(
409                "SELECT id, message_id, name, input, output, is_error, created_at \
410                 FROM tool_calls WHERE message_id = ?1 ORDER BY created_at ASC",
411            )
412            .context("Failed to prepare get_tool_calls query")?;
413
414        let rows = stmt
415            .query_map(params![message_id], |row| {
416                Ok(DbToolCall {
417                    id: row.get(0)?,
418                    message_id: row.get(1)?,
419                    name: row.get(2)?,
420                    input: row.get(3)?,
421                    output: row.get(4)?,
422                    is_error: row.get::<_, i64>(5)? != 0,
423                    created_at: row.get(6)?,
424                })
425            })
426            .context("Failed to get tool calls")?;
427
428        let mut calls = Vec::new();
429        for row in rows {
430            calls.push(row.context("Failed to read tool call row")?);
431        }
432        Ok(calls)
433    }
434
435    pub fn get_user_message_history(&self, limit: usize) -> Result<Vec<String>> {
436        let mut stmt = self
437            .conn
438            .prepare(
439                "SELECT content FROM messages WHERE role = 'user' \
440                 ORDER BY created_at DESC LIMIT ?1",
441            )
442            .context("Failed to prepare user history query")?;
443
444        let rows = stmt
445            .query_map(params![limit as i64], |row| row.get::<_, String>(0))
446            .context("Failed to query user history")?;
447
448        let mut messages = Vec::new();
449        for row in rows {
450            messages.push(row.context("Failed to read history row")?);
451        }
452        messages.reverse();
453        Ok(messages)
454    }
455
456    pub fn create_task(&self, id: &str, prompt: &str, pid: u32, cwd: &str) -> Result<()> {
457        let now = Utc::now().to_rfc3339();
458        self.conn
459            .execute(
460                "INSERT INTO tasks (id, prompt, status, pid, cwd, created_at) \
461                 VALUES (?1, ?2, 'running', ?3, ?4, ?5)",
462                params![id, prompt, pid as i64, cwd, now],
463            )
464            .context("creating task")?;
465        Ok(())
466    }
467
468    pub fn complete_task(
469        &self,
470        id: &str,
471        status: &str,
472        session_id: Option<&str>,
473        output: &str,
474    ) -> Result<()> {
475        let now = Utc::now().to_rfc3339();
476        self.conn
477            .execute(
478                "UPDATE tasks SET status = ?1, session_id = ?2, output = ?3, completed_at = ?4 WHERE id = ?5",
479                params![status, session_id, output, now, id],
480            )
481            .context("completing task")?;
482        Ok(())
483    }
484
485    pub fn list_tasks(&self, limit: usize) -> Result<Vec<TaskRecord>> {
486        let mut stmt = self
487            .conn
488            .prepare(
489                "SELECT id, prompt, status, session_id, pid, output, cwd, created_at, completed_at \
490                 FROM tasks ORDER BY created_at DESC LIMIT ?1",
491            )
492            .context("preparing list_tasks")?;
493        let rows = stmt
494            .query_map(params![limit as i64], |row| {
495                Ok(TaskRecord {
496                    id: row.get(0)?,
497                    prompt: row.get(1)?,
498                    status: row.get(2)?,
499                    session_id: row.get(3)?,
500                    pid: row.get(4)?,
501                    output: row.get(5)?,
502                    cwd: row.get::<_, Option<String>>(6)?.unwrap_or_default(),
503                    created_at: row.get(7)?,
504                    completed_at: row.get(8)?,
505                })
506            })
507            .context("listing tasks")?;
508        let mut tasks = Vec::new();
509        for row in rows {
510            tasks.push(row.context("reading task row")?);
511        }
512        Ok(tasks)
513    }
514
515    pub fn get_task(&self, id: &str) -> Result<TaskRecord> {
516        self.conn
517            .query_row(
518                "SELECT id, prompt, status, session_id, pid, output, cwd, created_at, completed_at \
519                 FROM tasks WHERE id = ?1",
520                params![id],
521                |row| {
522                    Ok(TaskRecord {
523                        id: row.get(0)?,
524                        prompt: row.get(1)?,
525                        status: row.get(2)?,
526                        session_id: row.get(3)?,
527                        pid: row.get(4)?,
528                        output: row.get(5)?,
529                        cwd: row.get::<_, Option<String>>(6)?.unwrap_or_default(),
530                        created_at: row.get(7)?,
531                        completed_at: row.get(8)?,
532                    })
533                },
534            )
535            .context("getting task")
536    }
537}