Skip to main content

dot/db/
mod.rs

1mod schema;
2
3use anyhow::{Context, Result};
4use chrono::Utc;
5use rusqlite::{Connection, params};
6use std::path::PathBuf;
7use uuid::Uuid;
8
9fn db_path() -> Result<PathBuf> {
10    let dot_dir = crate::config::Config::data_dir();
11    std::fs::create_dir_all(&dot_dir).context("Could not create dot data directory")?;
12    Ok(dot_dir.join("dot.db"))
13}
14
/// A conversation row without its messages, as returned by the listing queries.
///
/// Timestamps are kept as the strings stored in the database (written by `Db`
/// using RFC 3339 formatting).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConversationSummary {
    /// Conversation id (a UUID string when created via `Db::create_conversation`).
    pub id: String,
    /// Optional title; conversations are created with no title.
    pub title: Option<String>,
    /// Model name recorded when the conversation was created.
    pub model: String,
    /// Provider name recorded when the conversation was created.
    pub provider: String,
    /// Value of the `cwd` column; empty string when the column is NULL or unset.
    pub cwd: String,
    /// Creation timestamp.
    pub created_at: String,
    /// Timestamp of the last change (title update or new message).
    pub updated_at: String,
}
25
26#[derive(Debug, Clone)]
27pub struct Conversation {
28    pub id: String,
29    pub title: Option<String>,
30    pub model: String,
31    pub provider: String,
32    pub cwd: String,
33    pub created_at: String,
34    pub updated_at: String,
35    pub messages: Vec<DbMessage>,
36}
37
/// One row of the `messages` table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbMessage {
    /// Message id (a UUID string generated by `Db::add_message`).
    pub id: String,
    /// Id of the owning conversation.
    pub conversation_id: String,
    /// Role string exactly as passed to `Db::add_message`.
    pub role: String,
    /// Message body.
    pub content: String,
    /// Token count; inserted as 0 and later set via `Db::update_message_tokens`.
    pub token_count: u32,
    /// Creation timestamp string.
    pub created_at: String,
}
47
/// One row of the `tool_calls` table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbToolCall {
    /// Tool-call id supplied by the caller of `Db::add_tool_call` (not generated here).
    pub id: String,
    /// Id of the message that issued the call.
    pub message_id: String,
    /// Tool name.
    pub name: String,
    /// Tool input as stored (opaque string).
    pub input: String,
    /// Tool output; `None` until `Db::update_tool_result` records one.
    pub output: Option<String>,
    /// Whether the recorded result was an error (stored as integer 0/1).
    pub is_error: bool,
    /// Creation timestamp string.
    pub created_at: String,
}
58
59pub struct Db {
60    conn: Connection,
61}
62
63impl Db {
64    pub fn open() -> Result<Self> {
65        let path = db_path()?;
66        tracing::debug!("Opening database at {:?}", path);
67        let conn = Connection::open(&path)
68            .with_context(|| format!("Failed to open database at {:?}", path))?;
69        let db = Db { conn };
70        db.init()?;
71        Ok(db)
72    }
73
74    pub fn init(&self) -> Result<()> {
75        self.conn
76            .execute_batch(&format!(
77                "{}\n;\n{}\n;\n{}",
78                schema::CREATE_CONVERSATIONS,
79                schema::CREATE_MESSAGES,
80                schema::CREATE_TOOL_CALLS,
81            ))
82            .context("Failed to initialize database schema")?;
83
84        let _ = self.conn.execute(
85            "ALTER TABLE conversations ADD COLUMN cwd TEXT NOT NULL DEFAULT ''",
86            [],
87        );
88
89        tracing::debug!("Database schema initialized");
90        Ok(())
91    }
92
93    pub fn create_conversation(&self, model: &str, provider: &str, cwd: &str) -> Result<String> {
94        let id = Uuid::new_v4().to_string();
95        let now = Utc::now().to_rfc3339();
96        self.conn
97            .execute(
98                "INSERT INTO conversations (id, title, model, provider, cwd, created_at, updated_at) \
99                 VALUES (?1, NULL, ?2, ?3, ?4, ?5, ?6)",
100                params![id, model, provider, cwd, now, now],
101            )
102            .context("Failed to create conversation")?;
103        tracing::debug!("Created conversation {}", id);
104        Ok(id)
105    }
106
107    pub fn list_conversations(&self, limit: usize) -> Result<Vec<ConversationSummary>> {
108        let mut stmt = self
109            .conn
110            .prepare(
111                "SELECT id, title, model, provider, cwd, created_at, updated_at \
112                 FROM conversations ORDER BY updated_at DESC LIMIT ?1",
113            )
114            .context("Failed to prepare list_conversations query")?;
115
116        let rows = stmt
117            .query_map(params![limit as i64], |row| {
118                Ok(ConversationSummary {
119                    id: row.get(0)?,
120                    title: row.get(1)?,
121                    model: row.get(2)?,
122                    provider: row.get(3)?,
123                    cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
124                    created_at: row.get(5)?,
125                    updated_at: row.get(6)?,
126                })
127            })
128            .context("Failed to list conversations")?;
129
130        let mut conversations = Vec::new();
131        for row in rows {
132            conversations.push(row.context("Failed to read conversation row")?);
133        }
134        Ok(conversations)
135    }
136
137    pub fn list_conversations_for_cwd(
138        &self,
139        cwd: &str,
140        limit: usize,
141    ) -> Result<Vec<ConversationSummary>> {
142        let mut stmt = self
143            .conn
144            .prepare(
145                "SELECT id, title, model, provider, cwd, created_at, updated_at \
146                 FROM conversations WHERE cwd = ?1 ORDER BY updated_at DESC LIMIT ?2",
147            )
148            .context("Failed to prepare list_conversations_for_cwd query")?;
149
150        let rows = stmt
151            .query_map(params![cwd, limit as i64], |row| {
152                Ok(ConversationSummary {
153                    id: row.get(0)?,
154                    title: row.get(1)?,
155                    model: row.get(2)?,
156                    provider: row.get(3)?,
157                    cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
158                    created_at: row.get(5)?,
159                    updated_at: row.get(6)?,
160                })
161            })
162            .context("Failed to list conversations for cwd")?;
163
164        let mut conversations = Vec::new();
165        for row in rows {
166            conversations.push(row.context("Failed to read conversation row")?);
167        }
168        Ok(conversations)
169    }
170
171    pub fn get_conversation(&self, id: &str) -> Result<Conversation> {
172        let summary: ConversationSummary = self
173            .conn
174            .query_row(
175                "SELECT id, title, model, provider, cwd, created_at, updated_at \
176                 FROM conversations WHERE id = ?1",
177                params![id],
178                |row| {
179                    Ok(ConversationSummary {
180                        id: row.get(0)?,
181                        title: row.get(1)?,
182                        model: row.get(2)?,
183                        provider: row.get(3)?,
184                        cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
185                        created_at: row.get(5)?,
186                        updated_at: row.get(6)?,
187                    })
188                },
189            )
190            .context("Failed to get conversation")?;
191
192        let messages = self.get_messages(id)?;
193        Ok(Conversation {
194            id: summary.id,
195            title: summary.title,
196            model: summary.model,
197            provider: summary.provider,
198            cwd: summary.cwd,
199            created_at: summary.created_at,
200            updated_at: summary.updated_at,
201            messages,
202        })
203    }
204
205    pub fn update_conversation_title(&self, id: &str, title: &str) -> Result<()> {
206        let now = Utc::now().to_rfc3339();
207        self.conn
208            .execute(
209                "UPDATE conversations SET title = ?1, updated_at = ?2 WHERE id = ?3",
210                params![title, now, id],
211            )
212            .context("Failed to update conversation title")?;
213        Ok(())
214    }
215
216    pub fn delete_conversation(&self, id: &str) -> Result<()> {
217        self.conn
218            .execute(
219                "DELETE FROM tool_calls WHERE message_id IN \
220                 (SELECT id FROM messages WHERE conversation_id = ?1)",
221                params![id],
222            )
223            .context("Failed to delete tool calls for conversation")?;
224
225        self.conn
226            .execute(
227                "DELETE FROM messages WHERE conversation_id = ?1",
228                params![id],
229            )
230            .context("Failed to delete messages for conversation")?;
231
232        self.conn
233            .execute("DELETE FROM conversations WHERE id = ?1", params![id])
234            .context("Failed to delete conversation")?;
235
236        tracing::debug!("Deleted conversation {}", id);
237        Ok(())
238    }
239
240    pub fn add_message(&self, conversation_id: &str, role: &str, content: &str) -> Result<String> {
241        let id = Uuid::new_v4().to_string();
242        let now = Utc::now().to_rfc3339();
243        self.conn
244            .execute(
245                "INSERT INTO messages \
246                 (id, conversation_id, role, content, token_count, created_at) \
247                 VALUES (?1, ?2, ?3, ?4, 0, ?5)",
248                params![id, conversation_id, role, content, now],
249            )
250            .context("Failed to add message")?;
251
252        self.conn
253            .execute(
254                "UPDATE conversations SET updated_at = ?1 WHERE id = ?2",
255                params![now, conversation_id],
256            )
257            .context("Failed to update conversation timestamp")?;
258
259        tracing::debug!("Added message {} to conversation {}", id, conversation_id);
260        Ok(id)
261    }
262
263    pub fn get_messages(&self, conversation_id: &str) -> Result<Vec<DbMessage>> {
264        let mut stmt = self
265            .conn
266            .prepare(
267                "SELECT id, conversation_id, role, content, token_count, created_at \
268                 FROM messages WHERE conversation_id = ?1 ORDER BY created_at ASC",
269            )
270            .context("Failed to prepare get_messages query")?;
271
272        let rows = stmt
273            .query_map(params![conversation_id], |row| {
274                Ok(DbMessage {
275                    id: row.get(0)?,
276                    conversation_id: row.get(1)?,
277                    role: row.get(2)?,
278                    content: row.get(3)?,
279                    token_count: row.get::<_, i64>(4)? as u32,
280                    created_at: row.get(5)?,
281                })
282            })
283            .context("Failed to get messages")?;
284
285        let mut messages = Vec::new();
286        for row in rows {
287            messages.push(row.context("Failed to read message row")?);
288        }
289        Ok(messages)
290    }
291
292    pub fn update_message_tokens(&self, id: &str, tokens: u32) -> Result<()> {
293        self.conn
294            .execute(
295                "UPDATE messages SET token_count = ?1 WHERE id = ?2",
296                params![tokens as i64, id],
297            )
298            .context("Failed to update message tokens")?;
299        Ok(())
300    }
301
302    pub fn add_tool_call(
303        &self,
304        message_id: &str,
305        tool_id: &str,
306        name: &str,
307        input: &str,
308    ) -> Result<()> {
309        let now = Utc::now().to_rfc3339();
310        self.conn
311            .execute(
312                "INSERT INTO tool_calls \
313                 (id, message_id, name, input, output, is_error, created_at) \
314                 VALUES (?1, ?2, ?3, ?4, NULL, 0, ?5)",
315                params![tool_id, message_id, name, input, now],
316            )
317            .context("Failed to add tool call")?;
318        tracing::debug!("Added tool call {} for message {}", tool_id, message_id);
319        Ok(())
320    }
321
322    pub fn update_tool_result(&self, tool_id: &str, output: &str, is_error: bool) -> Result<()> {
323        self.conn
324            .execute(
325                "UPDATE tool_calls SET output = ?1, is_error = ?2 WHERE id = ?3",
326                params![output, is_error as i64, tool_id],
327            )
328            .context("Failed to update tool result")?;
329        Ok(())
330    }
331
332    pub fn get_tool_calls(&self, message_id: &str) -> Result<Vec<DbToolCall>> {
333        let mut stmt = self
334            .conn
335            .prepare(
336                "SELECT id, message_id, name, input, output, is_error, created_at \
337                 FROM tool_calls WHERE message_id = ?1 ORDER BY created_at ASC",
338            )
339            .context("Failed to prepare get_tool_calls query")?;
340
341        let rows = stmt
342            .query_map(params![message_id], |row| {
343                Ok(DbToolCall {
344                    id: row.get(0)?,
345                    message_id: row.get(1)?,
346                    name: row.get(2)?,
347                    input: row.get(3)?,
348                    output: row.get(4)?,
349                    is_error: row.get::<_, i64>(5)? != 0,
350                    created_at: row.get(6)?,
351                })
352            })
353            .context("Failed to get tool calls")?;
354
355        let mut calls = Vec::new();
356        for row in rows {
357            calls.push(row.context("Failed to read tool call row")?);
358        }
359        Ok(calls)
360    }
361}