Skip to main content

koda_core/db/
mod.rs

1//! SQLite persistence layer.
2//!
3//! Implements the [`crate::persistence::Persistence`] trait for SQLite via sqlx.
4//! Uses WAL mode for concurrent read/write access.
5//!
6//! ## Database location
7//!
//! - **Default**: `~/.config/koda/db/koda.db`
9//! - Schema is auto-migrated on startup
10//! - WAL mode enables concurrent reads (main session + sub-agents)
11//!
12//! ## What's stored
13//!
14//! - **Conversation history** — all messages, tool calls, and results
15//! - **Sessions** — session metadata, timestamps, model info
16//! - **File ownership** — which files Koda created (for auto-approve Delete)
17//! - **Progress entries** — survive compaction for persistent tracking
18//! - **KV store** — settings (last provider) and API keys (#693)
19//! - **Input history** — REPL command history (#693)
20//!
21//! ## Module layout
22//!
23//! - **mod.rs** — `Database` struct, init/open, schema migrations, row types
24//! - **queries.rs** — `Persistence` trait implementation (all SQL queries)
25
26pub mod queries;
27#[cfg(test)]
28mod tests;
29
30use anyhow::{Context, Result};
31use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
32use std::path::Path;
33use std::str::FromStr;
34
35/// Re-export persistence types for backward compatibility.
36pub use crate::persistence::{
37    CompactedStats, Message, Persistence, Role, SessionInfo, SessionUsage,
38};
39
40/// Wrapper around the SQLite connection pool.
41#[derive(Debug, Clone)]
42pub struct Database {
43    pub(crate) pool: SqlitePool,
44}
45
46/// Get the koda config directory (~/.config/koda/).
47pub fn config_dir() -> Result<std::path::PathBuf> {
48    let base = std::env::var("XDG_CONFIG_HOME")
49        .ok()
50        .map(std::path::PathBuf::from)
51        .or_else(|| {
52            // Unix: $HOME/.config  (XDG Base Directory spec fallback)
53            std::env::var("HOME")
54                .ok()
55                .map(|h| std::path::PathBuf::from(h).join(".config"))
56        })
57        .or({
58            // Windows: %APPDATA%  (e.g. C:\Users\Alice\AppData\Roaming)
59            #[cfg(windows)]
60            {
61                std::env::var("APPDATA").ok().map(std::path::PathBuf::from)
62            }
63            #[cfg(not(windows))]
64            {
65                None
66            }
67        })
68        .ok_or_else(|| {
69            anyhow::anyhow!(
70                "Cannot determine config directory \
71                 (set XDG_CONFIG_HOME, HOME, or APPDATA)"
72            )
73        })?;
74    Ok(base.join("koda"))
75}
76
77impl Database {
78    /// Access the underlying connection pool (for tests and raw queries).
79    pub fn pool(&self) -> &SqlitePool {
80        &self.pool
81    }
82
83    /// Initialize the database, run migrations, and enable WAL mode.
84    ///
85    /// `koda_config_dir` is the koda configuration directory (e.g. `~/.config/koda`).
86    /// The database lives in `<koda_config_dir>/db/koda.db`.
87    ///
88    /// Production callers should pass `db::config_dir()?`; tests pass a temp dir.
89    pub async fn init(koda_config_dir: &Path) -> Result<Self> {
90        let db_dir = koda_config_dir.join("db");
91        std::fs::create_dir_all(&db_dir)
92            .with_context(|| format!("Failed to create DB dir: {}", db_dir.display()))?;
93
94        let db_path = db_dir.join("koda.db");
95
96        let db = Self::open(&db_path).await?;
97
98        // Ensure restrictive permissions — DB contains API keys and
99        // conversation history that may include secrets (#693).
100        #[cfg(unix)]
101        Self::set_db_permissions(&db_path);
102
103        Ok(db)
104    }
105
106    /// Open a database at a specific path (used by tests and init).
107    pub async fn open(db_path: &Path) -> Result<Self> {
108        let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
109
110        let options = SqliteConnectOptions::from_str(&db_url)?
111            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
112            .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
113            .foreign_keys(true)
114            .create_if_missing(true)
115            // Retry for up to 5 s when another connection holds the write
116            // lock. Without this, concurrent writes from parallel sub-agents
117            // (#595) return SQLITE_BUSY immediately and the insert is silently
118            // dropped. Individual writes are ~1 ms so the retry resolves fast.
119            .busy_timeout(std::time::Duration::from_millis(5000));
120
121        let pool = SqlitePoolOptions::new()
122            .max_connections(5)
123            .connect_with(options)
124            .await
125            .with_context(|| format!("Failed to connect to database: {db_url}"))?;
126
127        // Run schema migrations
128        Self::migrate(&pool).await?;
129        Ok(Self { pool })
130    }
131
132    /// Apply the schema (idempotent).
133    async fn migrate(pool: &SqlitePool) -> Result<()> {
134        sqlx::query(
135            "CREATE TABLE IF NOT EXISTS sessions (
136                id TEXT PRIMARY KEY,
137                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
138                agent_name TEXT NOT NULL,
139                project_root TEXT,
140                last_accessed_at TEXT,
141                title TEXT,
142                mode TEXT
143            );",
144        )
145        .execute(pool)
146        .await?;
147
148        sqlx::query(
149            "CREATE TABLE IF NOT EXISTS messages (
150                id INTEGER PRIMARY KEY AUTOINCREMENT,
151                session_id TEXT NOT NULL,
152                role TEXT NOT NULL,
153                content TEXT,
154                full_content TEXT,
155                tool_calls TEXT,
156                tool_call_id TEXT,
157                prompt_tokens INTEGER,
158                completion_tokens INTEGER,
159                cache_read_tokens INTEGER,
160                cache_creation_tokens INTEGER,
161                thinking_tokens INTEGER,
162                thinking_content TEXT,
163                agent_name TEXT,
164                compacted_at TEXT,
165                completed_at DATETIME,
166                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
167                FOREIGN KEY(session_id) REFERENCES sessions(id)
168            );",
169        )
170        .execute(pool)
171        .await?;
172
173        sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages(session_id);")
174            .execute(pool)
175            .await?;
176
177        sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_role_id ON messages(role, id DESC);")
178            .execute(pool)
179            .await?;
180
181        // Session-scoped key-value metadata (e.g. todo list).
182        sqlx::query(
183            "CREATE TABLE IF NOT EXISTS session_metadata (
184                session_id TEXT NOT NULL,
185                key TEXT NOT NULL,
186                value TEXT NOT NULL,
187                updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
188                PRIMARY KEY(session_id, key),
189                FOREIGN KEY(session_id) REFERENCES sessions(id)
190            );",
191        )
192        .execute(pool)
193        .await?;
194
195        // File lifecycle tracking (#465): files created by Koda in a session.
196        sqlx::query(
197            "CREATE TABLE IF NOT EXISTS owned_files (
198                session_id TEXT NOT NULL,
199                path TEXT NOT NULL,
200                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
201                PRIMARY KEY(session_id, path)
202            );",
203        )
204        .execute(pool)
205        .await?;
206
207        // Global key-value store (#693): replaces settings.toml and keys.toml.
208        // Keys are namespaced by convention:
209        //   - `setting:*`  — last-used provider, etc.
210        //   - `apikey:*`   — API keys (GEMINI_API_KEY, etc.)
211        sqlx::query(
212            "CREATE TABLE IF NOT EXISTS kv_store (
213                key TEXT PRIMARY KEY,
214                value TEXT NOT NULL,
215                updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
216            );",
217        )
218        .execute(pool)
219        .await?;
220
221        // REPL input history (#693): replaces ~/.config/koda/history.
222        sqlx::query(
223            "CREATE TABLE IF NOT EXISTS input_history (
224                id INTEGER PRIMARY KEY AUTOINCREMENT,
225                input TEXT NOT NULL,
226                created_at DATETIME DEFAULT CURRENT_TIMESTAMP
227            );",
228        )
229        .execute(pool)
230        .await?;
231
232        Ok(())
233    }
234
235    /// Set koda.db file permissions to 0600 (owner-only).
236    ///
237    /// The DB contains API keys and conversation history that may include
238    /// secrets. Restrictive permissions prevent other local users from reading.
239    #[cfg(unix)]
240    fn set_db_permissions(db_path: &Path) {
241        use std::os::unix::fs::PermissionsExt;
242        let perms = std::fs::Permissions::from_mode(0o600);
243        if let Err(e) = std::fs::set_permissions(db_path, perms) {
244            tracing::warn!("Failed to set 0600 on {}: {e}", db_path.display());
245        }
246    }
247}
248
249// ── File lifecycle tracking (#465) ────────────────────────────────────────────
250
251impl Database {
252    /// Record that Koda created a file in this session.
253    pub async fn insert_owned_file(&self, session_id: &str, path: &Path) -> Result<()> {
254        sqlx::query("INSERT OR IGNORE INTO owned_files (session_id, path) VALUES (?, ?)")
255            .bind(session_id)
256            .bind(path.to_string_lossy().as_ref())
257            .execute(&self.pool)
258            .await?;
259        Ok(())
260    }
261
262    /// Remove a file from the owned set.
263    pub async fn delete_owned_file(&self, session_id: &str, path: &Path) -> Result<()> {
264        sqlx::query("DELETE FROM owned_files WHERE session_id = ? AND path = ?")
265            .bind(session_id)
266            .bind(path.to_string_lossy().as_ref())
267            .execute(&self.pool)
268            .await?;
269        Ok(())
270    }
271
272    /// Load all owned file paths for a session (used on session resume).
273    pub async fn load_owned_files(
274        &self,
275        session_id: &str,
276    ) -> Result<std::collections::HashSet<std::path::PathBuf>> {
277        let rows: Vec<(String,)> =
278            sqlx::query_as("SELECT path FROM owned_files WHERE session_id = ?")
279                .bind(session_id)
280                .fetch_all(&self.pool)
281                .await?;
282        Ok(rows
283            .into_iter()
284            .map(|(p,)| std::path::PathBuf::from(p))
285            .collect())
286    }
287
288    /// Load a page of messages older than `before_id` (for virtual scroll).
289    ///
290    /// Returns up to `limit` messages with `id < before_id`, ordered
291    /// newest-first so the caller can reverse them for display.
292    /// Only non-compacted messages are returned.
293    pub async fn load_messages_before(
294        &self,
295        session_id: &str,
296        before_id: i64,
297        limit: i64,
298    ) -> Result<Vec<Message>> {
299        let rows: Vec<MessageRow> = sqlx::query_as(
300            "SELECT id, session_id, role, content, full_content, tool_calls, tool_call_id,
301                    prompt_tokens, completion_tokens,
302                    cache_read_tokens, cache_creation_tokens, thinking_tokens, thinking_content,
303                    created_at
304             FROM messages
305             WHERE session_id = ? AND id < ? AND compacted_at IS NULL
306             ORDER BY id DESC
307             LIMIT ?",
308        )
309        .bind(session_id)
310        .bind(before_id)
311        .bind(limit)
312        .fetch_all(&self.pool)
313        .await?;
314
315        // Reverse to chronological order
316        let mut messages: Vec<Message> = rows.into_iter().map(|r| r.into()).collect();
317        messages.reverse();
318        Ok(messages)
319    }
320
321    /// Seconds since the last assistant message in this session.
322    ///
323    /// Returns `None` if there are no (non-compacted) assistant messages.
324    /// Used by microcompact to decide whether the idle gap threshold is met.
325    pub async fn seconds_since_last_assistant(&self, session_id: &str) -> Result<Option<i64>> {
326        let row: Option<(i64,)> = sqlx::query_as(
327            "SELECT CAST((julianday('now') - julianday(created_at)) * 86400 AS INTEGER) \
328             FROM messages \
329             WHERE session_id = ? AND role = 'assistant' AND compacted_at IS NULL \
330             ORDER BY id DESC LIMIT 1",
331        )
332        .bind(session_id)
333        .fetch_optional(&self.pool)
334        .await?;
335        Ok(row.map(|(secs,)| secs))
336    }
337}
338
339// ── Row types ───────────────────────────────────────────────────────────
340
341/// Internal row type for sqlx deserialization.
342#[derive(sqlx::FromRow)]
343pub(crate) struct MessageRow {
344    pub id: i64,
345    pub session_id: String,
346    pub role: String,
347    pub content: Option<String>,
348    pub full_content: Option<String>,
349    pub tool_calls: Option<String>,
350    pub tool_call_id: Option<String>,
351    pub prompt_tokens: Option<i64>,
352    pub completion_tokens: Option<i64>,
353    pub cache_read_tokens: Option<i64>,
354    pub cache_creation_tokens: Option<i64>,
355    pub thinking_tokens: Option<i64>,
356    pub thinking_content: Option<String>,
357    pub created_at: Option<String>,
358}
359
360/// Session metadata for listing.
361#[derive(Debug, Clone, sqlx::FromRow)]
362pub(crate) struct SessionInfoRow {
363    pub id: String,
364    pub agent_name: String,
365    pub created_at: String,
366    pub message_count: i64,
367    pub total_tokens: i64,
368    pub title: Option<String>,
369    pub mode: Option<String>,
370}
371
372impl From<SessionInfoRow> for SessionInfo {
373    fn from(r: SessionInfoRow) -> Self {
374        Self {
375            id: r.id,
376            agent_name: r.agent_name,
377            created_at: r.created_at,
378            message_count: r.message_count,
379            total_tokens: r.total_tokens,
380            title: r.title,
381            mode: r.mode,
382        }
383    }
384}
385
386impl From<MessageRow> for Message {
387    fn from(r: MessageRow) -> Self {
388        Self {
389            id: r.id,
390            session_id: r.session_id,
391            role: r.role.parse().unwrap_or(Role::User),
392            content: r.content,
393            full_content: r.full_content,
394            tool_calls: r.tool_calls,
395            tool_call_id: r.tool_call_id,
396            prompt_tokens: r.prompt_tokens,
397            completion_tokens: r.completion_tokens,
398            cache_read_tokens: r.cache_read_tokens,
399            cache_creation_tokens: r.cache_creation_tokens,
400            thinking_tokens: r.thinking_tokens,
401            thinking_content: r.thinking_content,
402            created_at: r.created_at,
403        }
404    }
405}
406
407// ── Global KV store (#693) ─────────────────────────────────────────────────────────
408
409impl Database {
410    /// Get a value from the global KV store.
411    pub async fn kv_get(&self, key: &str) -> Result<Option<String>> {
412        let row: Option<(String,)> = sqlx::query_as("SELECT value FROM kv_store WHERE key = ?")
413            .bind(key)
414            .fetch_optional(&self.pool)
415            .await?;
416        Ok(row.map(|(v,)| v))
417    }
418
419    /// Set a value in the global KV store (upsert).
420    pub async fn kv_set(&self, key: &str, value: &str) -> Result<()> {
421        sqlx::query(
422            "INSERT INTO kv_store (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)
423             ON CONFLICT(key) DO UPDATE SET value = excluded.value, updated_at = CURRENT_TIMESTAMP",
424        )
425        .bind(key)
426        .bind(value)
427        .execute(&self.pool)
428        .await?;
429        Ok(())
430    }
431
432    /// Delete a key from the global KV store.
433    pub async fn kv_delete(&self, key: &str) -> Result<()> {
434        sqlx::query("DELETE FROM kv_store WHERE key = ?")
435            .bind(key)
436            .execute(&self.pool)
437            .await?;
438        Ok(())
439    }
440
441    /// Get all KV entries matching a prefix (e.g. `"apikey:"`).
442    pub async fn kv_list_prefix(&self, prefix: &str) -> Result<Vec<(String, String)>> {
443        let pattern = format!("{prefix}%");
444        let rows: Vec<(String, String)> =
445            sqlx::query_as("SELECT key, value FROM kv_store WHERE key LIKE ?")
446                .bind(&pattern)
447                .fetch_all(&self.pool)
448                .await?;
449        Ok(rows)
450    }
451}
452
453// ── Input history (#693) ─────────────────────────────────────────────────────────
454
455/// Maximum number of input history entries to keep.
456const MAX_INPUT_HISTORY: i64 = 500;
457
458impl Database {
459    /// Append an input to the history.
460    pub async fn history_push(&self, input: &str) -> Result<()> {
461        sqlx::query("INSERT INTO input_history (input) VALUES (?)")
462            .bind(input)
463            .execute(&self.pool)
464            .await?;
465
466        // Trim old entries beyond the cap.
467        sqlx::query(
468            "DELETE FROM input_history WHERE id NOT IN (
469                SELECT id FROM input_history ORDER BY id DESC LIMIT ?
470            )",
471        )
472        .bind(MAX_INPUT_HISTORY)
473        .execute(&self.pool)
474        .await?;
475
476        Ok(())
477    }
478
479    /// Load all history entries, oldest first.
480    pub async fn history_load(&self) -> Result<Vec<String>> {
481        let rows: Vec<(String,)> =
482            sqlx::query_as("SELECT input FROM input_history ORDER BY id ASC")
483                .fetch_all(&self.pool)
484                .await?;
485        Ok(rows.into_iter().map(|(s,)| s).collect())
486    }
487}