Skip to main content

dot/db/
mod.rs

1pub(crate) mod schema;
2
3use anyhow::{Context, Result};
4use chrono::Utc;
5use rusqlite::{Connection, params};
6use std::path::PathBuf;
7use uuid::Uuid;
8
9fn db_path() -> Result<PathBuf> {
10    let dot_dir = crate::config::Config::data_dir();
11    std::fs::create_dir_all(&dot_dir).context("Could not create dot data directory")?;
12    Ok(dot_dir.join("dot.db"))
13}
14
/// One row of the `conversations` table without its messages, as returned by
/// `Db::list_conversations` and `Db::list_conversations_for_cwd`.
#[derive(Debug, Clone)]
pub struct ConversationSummary {
    /// Conversation id; `Db::create_conversation` generates it as a UUID v4 string.
    pub id: String,
    /// Title column; inserted as NULL on creation and set by `Db::update_conversation_title`.
    pub title: Option<String>,
    /// Value of the `model` column, as passed to `Db::create_conversation`.
    pub model: String,
    /// Value of the `provider` column, as passed to `Db::create_conversation`.
    pub provider: String,
    /// Value of the `cwd` column; a NULL in the database is read back as an empty string.
    pub cwd: String,
    /// Creation timestamp, written as an RFC 3339 string in UTC.
    pub created_at: String,
    /// Last-update timestamp (RFC 3339, UTC); refreshed when the title changes or a
    /// message is added.
    pub updated_at: String,
}
25
/// A conversation row together with its messages, as returned by `Db::get_conversation`.
#[derive(Debug, Clone)]
pub struct Conversation {
    /// Conversation id; `Db::create_conversation` generates it as a UUID v4 string.
    pub id: String,
    /// Title column; inserted as NULL on creation and set by `Db::update_conversation_title`.
    pub title: Option<String>,
    /// Value of the `model` column, as passed to `Db::create_conversation`.
    pub model: String,
    /// Value of the `provider` column, as passed to `Db::create_conversation`.
    pub provider: String,
    /// Value of the `cwd` column; a NULL in the database is read back as an empty string.
    pub cwd: String,
    /// Creation timestamp, written as an RFC 3339 string in UTC.
    pub created_at: String,
    /// Last-update timestamp (RFC 3339, UTC).
    pub updated_at: String,
    /// Messages of this conversation, loaded via `Db::get_messages` (ascending `created_at`).
    pub messages: Vec<DbMessage>,
    /// Value of the `last_input_tokens` column; `Db::get_conversation` falls back to 0
    /// when the column cannot be read as an integer.
    pub last_input_tokens: u32,
}
38
/// One row of the `messages` table.
#[derive(Debug, Clone)]
pub struct DbMessage {
    /// Message id; `Db::add_message` generates it as a UUID v4 string.
    pub id: String,
    /// Id of the conversation this message belongs to.
    pub conversation_id: String,
    /// Role string supplied by the caller of `Db::add_message`
    /// (`Db::get_user_message_history` selects rows where this is `'user'`).
    pub role: String,
    /// Message body as stored.
    pub content: String,
    /// Token count; inserted as 0 and later set by `Db::update_message_tokens`.
    pub token_count: u32,
    /// Creation timestamp, written as an RFC 3339 string in UTC.
    pub created_at: String,
}
48
/// One row of the `tool_calls` table.
#[derive(Debug, Clone)]
pub struct DbToolCall {
    /// Tool call id; supplied by the caller of `Db::add_tool_call` (not generated here).
    pub id: String,
    /// Id of the message this tool call is attached to.
    pub message_id: String,
    /// Value of the `name` column, as passed to `Db::add_tool_call`.
    pub name: String,
    /// Value of the `input` column, stored as text.
    pub input: String,
    /// Tool output; NULL (`None`) until `Db::update_tool_result` records it.
    pub output: Option<String>,
    /// Error flag; stored as an integer and read back as `!= 0`.
    pub is_error: bool,
    /// Creation timestamp, written as an RFC 3339 string in UTC.
    pub created_at: String,
}
59
/// Handle to the SQLite database; every query in this module goes through the
/// wrapped connection.
pub struct Db {
    // Underlying rusqlite connection; set up by `Db::open`.
    conn: Connection,
}
63
64impl Db {
65    pub fn open() -> Result<Self> {
66        let path = db_path()?;
67        tracing::debug!("Opening database at {:?}", path);
68        let conn = Connection::open(&path)
69            .with_context(|| format!("Failed to open database at {:?}", path))?;
70        let db = Db { conn };
71        db.init()?;
72        Ok(db)
73    }
74
75    pub fn init(&self) -> Result<()> {
76        self.conn
77            .execute_batch(&format!(
78                "{}\n;\n{}\n;\n{}",
79                schema::CREATE_CONVERSATIONS,
80                schema::CREATE_MESSAGES,
81                schema::CREATE_TOOL_CALLS,
82            ))
83            .context("Failed to initialize database schema")?;
84
85        let _ = self.conn.execute(
86            "ALTER TABLE conversations ADD COLUMN cwd TEXT NOT NULL DEFAULT ''",
87            [],
88        );
89        let _ = self.conn.execute(
90            "ALTER TABLE conversations ADD COLUMN last_input_tokens INTEGER NOT NULL DEFAULT 0",
91            [],
92        );
93
94        tracing::debug!("Database schema initialized");
95        Ok(())
96    }
97
98    pub fn create_conversation(&self, model: &str, provider: &str, cwd: &str) -> Result<String> {
99        let id = Uuid::new_v4().to_string();
100        let now = Utc::now().to_rfc3339();
101        self.conn
102            .execute(
103                "INSERT INTO conversations (id, title, model, provider, cwd, created_at, updated_at) \
104                 VALUES (?1, NULL, ?2, ?3, ?4, ?5, ?6)",
105                params![id, model, provider, cwd, now, now],
106            )
107            .context("Failed to create conversation")?;
108        tracing::debug!("Created conversation {}", id);
109        Ok(id)
110    }
111
112    pub fn list_conversations(&self, limit: usize) -> Result<Vec<ConversationSummary>> {
113        let mut stmt = self
114            .conn
115            .prepare(
116                "SELECT id, title, model, provider, cwd, created_at, updated_at \
117                 FROM conversations ORDER BY updated_at DESC LIMIT ?1",
118            )
119            .context("Failed to prepare list_conversations query")?;
120
121        let rows = stmt
122            .query_map(params![limit as i64], |row| {
123                Ok(ConversationSummary {
124                    id: row.get(0)?,
125                    title: row.get(1)?,
126                    model: row.get(2)?,
127                    provider: row.get(3)?,
128                    cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
129                    created_at: row.get(5)?,
130                    updated_at: row.get(6)?,
131                })
132            })
133            .context("Failed to list conversations")?;
134
135        let mut conversations = Vec::new();
136        for row in rows {
137            conversations.push(row.context("Failed to read conversation row")?);
138        }
139        Ok(conversations)
140    }
141
142    pub fn list_conversations_for_cwd(
143        &self,
144        cwd: &str,
145        limit: usize,
146    ) -> Result<Vec<ConversationSummary>> {
147        let mut stmt = self
148            .conn
149            .prepare(
150                "SELECT id, title, model, provider, cwd, created_at, updated_at \
151                 FROM conversations WHERE cwd = ?1 ORDER BY updated_at DESC LIMIT ?2",
152            )
153            .context("Failed to prepare list_conversations_for_cwd query")?;
154
155        let rows = stmt
156            .query_map(params![cwd, limit as i64], |row| {
157                Ok(ConversationSummary {
158                    id: row.get(0)?,
159                    title: row.get(1)?,
160                    model: row.get(2)?,
161                    provider: row.get(3)?,
162                    cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
163                    created_at: row.get(5)?,
164                    updated_at: row.get(6)?,
165                })
166            })
167            .context("Failed to list conversations for cwd")?;
168
169        let mut conversations = Vec::new();
170        for row in rows {
171            conversations.push(row.context("Failed to read conversation row")?);
172        }
173        Ok(conversations)
174    }
175
176    pub fn get_conversation(&self, id: &str) -> Result<Conversation> {
177        let (summary, last_input_tokens) = self
178            .conn
179            .query_row(
180                "SELECT id, title, model, provider, cwd, created_at, updated_at, last_input_tokens \
181                 FROM conversations WHERE id = ?1",
182                params![id],
183                |row| {
184                    Ok((
185                        ConversationSummary {
186                            id: row.get(0)?,
187                            title: row.get(1)?,
188                            model: row.get(2)?,
189                            provider: row.get(3)?,
190                            cwd: row.get::<_, Option<String>>(4)?.unwrap_or_default(),
191                            created_at: row.get(5)?,
192                            updated_at: row.get(6)?,
193                        },
194                        row.get::<_, i64>(7).unwrap_or(0) as u32,
195                    ))
196                },
197            )
198            .context("Failed to get conversation")?;
199
200        let messages = self.get_messages(id)?;
201        Ok(Conversation {
202            id: summary.id,
203            title: summary.title,
204            model: summary.model,
205            provider: summary.provider,
206            cwd: summary.cwd,
207            created_at: summary.created_at,
208            updated_at: summary.updated_at,
209            messages,
210            last_input_tokens,
211        })
212    }
213
214    pub fn update_conversation_title(&self, id: &str, title: &str) -> Result<()> {
215        let now = Utc::now().to_rfc3339();
216        self.conn
217            .execute(
218                "UPDATE conversations SET title = ?1, updated_at = ?2 WHERE id = ?3",
219                params![title, now, id],
220            )
221            .context("Failed to update conversation title")?;
222        Ok(())
223    }
224
225    pub fn delete_conversation(&self, id: &str) -> Result<()> {
226        self.conn
227            .execute(
228                "DELETE FROM tool_calls WHERE message_id IN \
229                 (SELECT id FROM messages WHERE conversation_id = ?1)",
230                params![id],
231            )
232            .context("Failed to delete tool calls for conversation")?;
233
234        self.conn
235            .execute(
236                "DELETE FROM messages WHERE conversation_id = ?1",
237                params![id],
238            )
239            .context("Failed to delete messages for conversation")?;
240
241        self.conn
242            .execute("DELETE FROM conversations WHERE id = ?1", params![id])
243            .context("Failed to delete conversation")?;
244
245        tracing::debug!("Deleted conversation {}", id);
246        Ok(())
247    }
248
249    pub fn truncate_messages(&self, conversation_id: &str, keep: usize) -> Result<()> {
250        let ids: Vec<String> = {
251            let mut stmt = self
252                .conn
253                .prepare(
254                    "SELECT id FROM messages WHERE conversation_id = ?1 ORDER BY created_at ASC",
255                )
256                .context("Failed to prepare truncate query")?;
257            let rows = stmt
258                .query_map(params![conversation_id], |row| row.get::<_, String>(0))
259                .context("Failed to query messages for truncation")?;
260            let mut all = Vec::new();
261            for row in rows {
262                all.push(row.context("Failed to read message id")?);
263            }
264            all
265        };
266        let to_delete = &ids[keep.min(ids.len())..];
267        for id in to_delete {
268            self.conn
269                .execute("DELETE FROM tool_calls WHERE message_id = ?1", params![id])
270                .context("Failed to delete tool calls for truncated message")?;
271            self.conn
272                .execute("DELETE FROM messages WHERE id = ?1", params![id])
273                .context("Failed to delete truncated message")?;
274        }
275        Ok(())
276    }
277
278    pub fn add_message(&self, conversation_id: &str, role: &str, content: &str) -> Result<String> {
279        let id = Uuid::new_v4().to_string();
280        let now = Utc::now().to_rfc3339();
281        self.conn
282            .execute(
283                "INSERT INTO messages \
284                 (id, conversation_id, role, content, token_count, created_at) \
285                 VALUES (?1, ?2, ?3, ?4, 0, ?5)",
286                params![id, conversation_id, role, content, now],
287            )
288            .context("Failed to add message")?;
289
290        self.conn
291            .execute(
292                "UPDATE conversations SET updated_at = ?1 WHERE id = ?2",
293                params![now, conversation_id],
294            )
295            .context("Failed to update conversation timestamp")?;
296
297        tracing::debug!("Added message {} to conversation {}", id, conversation_id);
298        Ok(id)
299    }
300
301    pub fn get_messages(&self, conversation_id: &str) -> Result<Vec<DbMessage>> {
302        let mut stmt = self
303            .conn
304            .prepare(
305                "SELECT id, conversation_id, role, content, token_count, created_at \
306                 FROM messages WHERE conversation_id = ?1 ORDER BY created_at ASC",
307            )
308            .context("Failed to prepare get_messages query")?;
309
310        let rows = stmt
311            .query_map(params![conversation_id], |row| {
312                Ok(DbMessage {
313                    id: row.get(0)?,
314                    conversation_id: row.get(1)?,
315                    role: row.get(2)?,
316                    content: row.get(3)?,
317                    token_count: row.get::<_, i64>(4)? as u32,
318                    created_at: row.get(5)?,
319                })
320            })
321            .context("Failed to get messages")?;
322
323        let mut messages = Vec::new();
324        for row in rows {
325            messages.push(row.context("Failed to read message row")?);
326        }
327        Ok(messages)
328    }
329
330    pub fn update_last_input_tokens(&self, conversation_id: &str, tokens: u32) -> Result<()> {
331        self.conn
332            .execute(
333                "UPDATE conversations SET last_input_tokens = ?1 WHERE id = ?2",
334                params![tokens as i64, conversation_id],
335            )
336            .context("Failed to update last_input_tokens")?;
337        Ok(())
338    }
339
340    pub fn update_message_tokens(&self, id: &str, tokens: u32) -> Result<()> {
341        self.conn
342            .execute(
343                "UPDATE messages SET token_count = ?1 WHERE id = ?2",
344                params![tokens as i64, id],
345            )
346            .context("Failed to update message tokens")?;
347        Ok(())
348    }
349
350    pub fn add_tool_call(
351        &self,
352        message_id: &str,
353        tool_id: &str,
354        name: &str,
355        input: &str,
356    ) -> Result<()> {
357        let now = Utc::now().to_rfc3339();
358        self.conn
359            .execute(
360                "INSERT INTO tool_calls \
361                 (id, message_id, name, input, output, is_error, created_at) \
362                 VALUES (?1, ?2, ?3, ?4, NULL, 0, ?5)",
363                params![tool_id, message_id, name, input, now],
364            )
365            .context("Failed to add tool call")?;
366        tracing::debug!("Added tool call {} for message {}", tool_id, message_id);
367        Ok(())
368    }
369
370    pub fn update_tool_result(&self, tool_id: &str, output: &str, is_error: bool) -> Result<()> {
371        self.conn
372            .execute(
373                "UPDATE tool_calls SET output = ?1, is_error = ?2 WHERE id = ?3",
374                params![output, is_error as i64, tool_id],
375            )
376            .context("Failed to update tool result")?;
377        Ok(())
378    }
379
380    pub fn get_tool_calls(&self, message_id: &str) -> Result<Vec<DbToolCall>> {
381        let mut stmt = self
382            .conn
383            .prepare(
384                "SELECT id, message_id, name, input, output, is_error, created_at \
385                 FROM tool_calls WHERE message_id = ?1 ORDER BY created_at ASC",
386            )
387            .context("Failed to prepare get_tool_calls query")?;
388
389        let rows = stmt
390            .query_map(params![message_id], |row| {
391                Ok(DbToolCall {
392                    id: row.get(0)?,
393                    message_id: row.get(1)?,
394                    name: row.get(2)?,
395                    input: row.get(3)?,
396                    output: row.get(4)?,
397                    is_error: row.get::<_, i64>(5)? != 0,
398                    created_at: row.get(6)?,
399                })
400            })
401            .context("Failed to get tool calls")?;
402
403        let mut calls = Vec::new();
404        for row in rows {
405            calls.push(row.context("Failed to read tool call row")?);
406        }
407        Ok(calls)
408    }
409
410    pub fn get_user_message_history(&self, limit: usize) -> Result<Vec<String>> {
411        let mut stmt = self
412            .conn
413            .prepare(
414                "SELECT content FROM messages WHERE role = 'user' \
415                 ORDER BY created_at DESC LIMIT ?1",
416            )
417            .context("Failed to prepare user history query")?;
418
419        let rows = stmt
420            .query_map(params![limit as i64], |row| row.get::<_, String>(0))
421            .context("Failed to query user history")?;
422
423        let mut messages = Vec::new();
424        for row in rows {
425            messages.push(row.context("Failed to read history row")?);
426        }
427        messages.reverse();
428        Ok(messages)
429    }
430}