// koda_core/db/mod.rs

//! SQLite persistence layer.
//!
//! Implements the [`crate::persistence::Persistence`] trait for SQLite via sqlx.
//! Uses WAL mode for concurrent read/write access.
//!
//! ## Database location
//!
//! - **Default**: `~/.config/koda/db/koda.db`
//! - Schema is auto-migrated on startup
//! - WAL mode enables concurrent reads (main session + sub-agents)
//!
//! ## What's stored
//!
//! - **Conversation history** — all messages, tool calls, and results
//! - **Sessions** — session metadata, timestamps, model info
//! - **File ownership** — which files Koda created (for auto-approve Delete)
//! - **Progress entries** — survive compaction for persistent tracking
//! - **KV store** — settings (last provider) and API keys (#693)
//! - **Input history** — REPL command history (#693)
//!
//! ## Module layout
//!
//! - **mod.rs** — `Database` struct, init/open, schema migrations, row types
//! - **queries.rs** — `Persistence` trait implementation (all SQL queries)

pub mod queries;
#[cfg(test)]
mod tests;

use anyhow::{Context, Result};
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use std::path::Path;
use std::str::FromStr;

/// Re-export persistence types for backward compatibility.
pub use crate::persistence::{
    CompactedStats, Message, Persistence, Role, SessionInfo, SessionUsage,
};

40/// Wrapper around the SQLite connection pool.
41#[derive(Debug, Clone)]
42pub struct Database {
43    pub(crate) pool: SqlitePool,
44}
45
46/// Get the koda config directory (~/.config/koda/).
47pub fn config_dir() -> Result<std::path::PathBuf> {
48    let base = std::env::var("XDG_CONFIG_HOME")
49        .ok()
50        .map(std::path::PathBuf::from)
51        .or_else(|| {
52            std::env::var("HOME")
53                .ok()
54                .map(|h| std::path::PathBuf::from(h).join(".config"))
55        })
56        .ok_or_else(|| {
57            anyhow::anyhow!("Cannot determine config directory (set HOME or XDG_CONFIG_HOME)")
58        })?;
59    Ok(base.join("koda"))
60}
61
62impl Database {
63    /// Access the underlying connection pool (for tests and raw queries).
64    pub fn pool(&self) -> &SqlitePool {
65        &self.pool
66    }
67
68    /// Initialize the database, run migrations, and enable WAL mode.
69    ///
70    /// `koda_config_dir` is the koda configuration directory (e.g. `~/.config/koda`).
71    /// The database lives in `<koda_config_dir>/db/koda.db`.
72    ///
73    /// Production callers should pass `db::config_dir()?`; tests pass a temp dir.
74    pub async fn init(koda_config_dir: &Path) -> Result<Self> {
75        let db_dir = koda_config_dir.join("db");
76        std::fs::create_dir_all(&db_dir)
77            .with_context(|| format!("Failed to create DB dir: {}", db_dir.display()))?;
78
79        let db_path = db_dir.join("koda.db");
80
81        let db = Self::open(&db_path).await?;
82
83        // Ensure restrictive permissions — DB contains API keys and
84        // conversation history that may include secrets (#693).
85        #[cfg(unix)]
86        Self::set_db_permissions(&db_path);
87
88        Ok(db)
89    }
90
91    /// Open a database at a specific path (used by tests and init).
92    pub async fn open(db_path: &Path) -> Result<Self> {
93        let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
94
95        let options = SqliteConnectOptions::from_str(&db_url)?
96            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
97            .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
98            .foreign_keys(true)
99            .create_if_missing(true)
100            // Retry for up to 5 s when another connection holds the write
101            // lock. Without this, concurrent writes from parallel sub-agents
102            // (#595) return SQLITE_BUSY immediately and the insert is silently
103            // dropped. Individual writes are ~1 ms so the retry resolves fast.
104            .busy_timeout(std::time::Duration::from_millis(5000));
105
106        let pool = SqlitePoolOptions::new()
107            .max_connections(5)
108            .connect_with(options)
109            .await
110            .with_context(|| format!("Failed to connect to database: {db_url}"))?;
111
112        // Run schema migrations
113        Self::migrate(&pool).await?;
114        Ok(Self { pool })
115    }
116
117    /// Apply the schema (idempotent).
118    async fn migrate(pool: &SqlitePool) -> Result<()> {
119        sqlx::query(
120            "CREATE TABLE IF NOT EXISTS sessions (
121                id TEXT PRIMARY KEY,
122                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
123                agent_name TEXT NOT NULL,
124                project_root TEXT,
125                last_accessed_at TEXT,
126                title TEXT,
127                mode TEXT
128            );",
129        )
130        .execute(pool)
131        .await?;
132
133        sqlx::query(
134            "CREATE TABLE IF NOT EXISTS messages (
135                id INTEGER PRIMARY KEY AUTOINCREMENT,
136                session_id TEXT NOT NULL,
137                role TEXT NOT NULL,
138                content TEXT,
139                full_content TEXT,
140                tool_calls TEXT,
141                tool_call_id TEXT,
142                prompt_tokens INTEGER,
143                completion_tokens INTEGER,
144                cache_read_tokens INTEGER,
145                cache_creation_tokens INTEGER,
146                thinking_tokens INTEGER,
147                agent_name TEXT,
148                compacted_at TEXT,
149                completed_at DATETIME,
150                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
151                FOREIGN KEY(session_id) REFERENCES sessions(id)
152            );",
153        )
154        .execute(pool)
155        .await?;
156
157        sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages(session_id);")
158            .execute(pool)
159            .await?;
160
161        sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_role_id ON messages(role, id DESC);")
162            .execute(pool)
163            .await?;
164
165        // Session-scoped key-value metadata (e.g. todo list).
166        sqlx::query(
167            "CREATE TABLE IF NOT EXISTS session_metadata (
168                session_id TEXT NOT NULL,
169                key TEXT NOT NULL,
170                value TEXT NOT NULL,
171                updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
172                PRIMARY KEY(session_id, key),
173                FOREIGN KEY(session_id) REFERENCES sessions(id)
174            );",
175        )
176        .execute(pool)
177        .await?;
178
179        // File lifecycle tracking (#465): files created by Koda in a session.
180        sqlx::query(
181            "CREATE TABLE IF NOT EXISTS owned_files (
182                session_id TEXT NOT NULL,
183                path TEXT NOT NULL,
184                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
185                PRIMARY KEY(session_id, path)
186            );",
187        )
188        .execute(pool)
189        .await?;
190
191        // Global key-value store (#693): replaces settings.toml and keys.toml.
192        // Keys are namespaced by convention:
193        //   - `setting:*`  — last-used provider, etc.
194        //   - `apikey:*`   — API keys (GEMINI_API_KEY, etc.)
195        sqlx::query(
196            "CREATE TABLE IF NOT EXISTS kv_store (
197                key TEXT PRIMARY KEY,
198                value TEXT NOT NULL,
199                updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
200            );",
201        )
202        .execute(pool)
203        .await?;
204
205        // REPL input history (#693): replaces ~/.config/koda/history.
206        sqlx::query(
207            "CREATE TABLE IF NOT EXISTS input_history (
208                id INTEGER PRIMARY KEY AUTOINCREMENT,
209                input TEXT NOT NULL,
210                created_at DATETIME DEFAULT CURRENT_TIMESTAMP
211            );",
212        )
213        .execute(pool)
214        .await?;
215
216        Ok(())
217    }
218
219    /// Set koda.db file permissions to 0600 (owner-only).
220    ///
221    /// The DB contains API keys and conversation history that may include
222    /// secrets. Restrictive permissions prevent other local users from reading.
223    #[cfg(unix)]
224    fn set_db_permissions(db_path: &Path) {
225        use std::os::unix::fs::PermissionsExt;
226        let perms = std::fs::Permissions::from_mode(0o600);
227        if let Err(e) = std::fs::set_permissions(db_path, perms) {
228            tracing::warn!("Failed to set 0600 on {}: {e}", db_path.display());
229        }
230    }
231}

// ── File lifecycle tracking (#465) ────────────────────────────────────────────

235impl Database {
236    /// Record that Koda created a file in this session.
237    pub async fn insert_owned_file(&self, session_id: &str, path: &Path) -> Result<()> {
238        sqlx::query("INSERT OR IGNORE INTO owned_files (session_id, path) VALUES (?, ?)")
239            .bind(session_id)
240            .bind(path.to_string_lossy().as_ref())
241            .execute(&self.pool)
242            .await?;
243        Ok(())
244    }
245
246    /// Remove a file from the owned set.
247    pub async fn delete_owned_file(&self, session_id: &str, path: &Path) -> Result<()> {
248        sqlx::query("DELETE FROM owned_files WHERE session_id = ? AND path = ?")
249            .bind(session_id)
250            .bind(path.to_string_lossy().as_ref())
251            .execute(&self.pool)
252            .await?;
253        Ok(())
254    }
255
256    /// Load all owned file paths for a session (used on session resume).
257    pub async fn load_owned_files(
258        &self,
259        session_id: &str,
260    ) -> Result<std::collections::HashSet<std::path::PathBuf>> {
261        let rows: Vec<(String,)> =
262            sqlx::query_as("SELECT path FROM owned_files WHERE session_id = ?")
263                .bind(session_id)
264                .fetch_all(&self.pool)
265                .await?;
266        Ok(rows
267            .into_iter()
268            .map(|(p,)| std::path::PathBuf::from(p))
269            .collect())
270    }
271
272    /// Load a page of messages older than `before_id` (for virtual scroll).
273    ///
274    /// Returns up to `limit` messages with `id < before_id`, ordered
275    /// newest-first so the caller can reverse them for display.
276    /// Only non-compacted messages are returned.
277    pub async fn load_messages_before(
278        &self,
279        session_id: &str,
280        before_id: i64,
281        limit: i64,
282    ) -> Result<Vec<Message>> {
283        let rows: Vec<MessageRow> = sqlx::query_as(
284            "SELECT id, session_id, role, content, full_content, tool_calls, tool_call_id,
285                    prompt_tokens, completion_tokens,
286                    cache_read_tokens, cache_creation_tokens, thinking_tokens,
287                    created_at
288             FROM messages
289             WHERE session_id = ? AND id < ? AND compacted_at IS NULL
290             ORDER BY id DESC
291             LIMIT ?",
292        )
293        .bind(session_id)
294        .bind(before_id)
295        .bind(limit)
296        .fetch_all(&self.pool)
297        .await?;
298
299        // Reverse to chronological order
300        let mut messages: Vec<Message> = rows.into_iter().map(|r| r.into()).collect();
301        messages.reverse();
302        Ok(messages)
303    }
304
305    /// Seconds since the last assistant message in this session.
306    ///
307    /// Returns `None` if there are no (non-compacted) assistant messages.
308    /// Used by microcompact to decide whether the idle gap threshold is met.
309    pub async fn seconds_since_last_assistant(&self, session_id: &str) -> Result<Option<i64>> {
310        let row: Option<(i64,)> = sqlx::query_as(
311            "SELECT CAST((julianday('now') - julianday(created_at)) * 86400 AS INTEGER) \
312             FROM messages \
313             WHERE session_id = ? AND role = 'assistant' AND compacted_at IS NULL \
314             ORDER BY id DESC LIMIT 1",
315        )
316        .bind(session_id)
317        .fetch_optional(&self.pool)
318        .await?;
319        Ok(row.map(|(secs,)| secs))
320    }
321}

// ── Row types ───────────────────────────────────────────────────────────

325/// Internal row type for sqlx deserialization.
326#[derive(sqlx::FromRow)]
327pub(crate) struct MessageRow {
328    pub id: i64,
329    pub session_id: String,
330    pub role: String,
331    pub content: Option<String>,
332    pub full_content: Option<String>,
333    pub tool_calls: Option<String>,
334    pub tool_call_id: Option<String>,
335    pub prompt_tokens: Option<i64>,
336    pub completion_tokens: Option<i64>,
337    pub cache_read_tokens: Option<i64>,
338    pub cache_creation_tokens: Option<i64>,
339    pub thinking_tokens: Option<i64>,
340    pub created_at: Option<String>,
341}
342
343/// Session metadata for listing.
344#[derive(Debug, Clone, sqlx::FromRow)]
345pub(crate) struct SessionInfoRow {
346    pub id: String,
347    pub agent_name: String,
348    pub created_at: String,
349    pub message_count: i64,
350    pub total_tokens: i64,
351    pub title: Option<String>,
352    pub mode: Option<String>,
353}
354
355impl From<SessionInfoRow> for SessionInfo {
356    fn from(r: SessionInfoRow) -> Self {
357        Self {
358            id: r.id,
359            agent_name: r.agent_name,
360            created_at: r.created_at,
361            message_count: r.message_count,
362            total_tokens: r.total_tokens,
363            title: r.title,
364            mode: r.mode,
365        }
366    }
367}
368
369impl From<MessageRow> for Message {
370    fn from(r: MessageRow) -> Self {
371        Self {
372            id: r.id,
373            session_id: r.session_id,
374            role: r.role.parse().unwrap_or(Role::User),
375            content: r.content,
376            full_content: r.full_content,
377            tool_calls: r.tool_calls,
378            tool_call_id: r.tool_call_id,
379            prompt_tokens: r.prompt_tokens,
380            completion_tokens: r.completion_tokens,
381            cache_read_tokens: r.cache_read_tokens,
382            cache_creation_tokens: r.cache_creation_tokens,
383            thinking_tokens: r.thinking_tokens,
384            created_at: r.created_at,
385        }
386    }
387}

// ── Global KV store (#693) ─────────────────────────────────────────────────────

391impl Database {
392    /// Get a value from the global KV store.
393    pub async fn kv_get(&self, key: &str) -> Result<Option<String>> {
394        let row: Option<(String,)> = sqlx::query_as("SELECT value FROM kv_store WHERE key = ?")
395            .bind(key)
396            .fetch_optional(&self.pool)
397            .await?;
398        Ok(row.map(|(v,)| v))
399    }
400
401    /// Set a value in the global KV store (upsert).
402    pub async fn kv_set(&self, key: &str, value: &str) -> Result<()> {
403        sqlx::query(
404            "INSERT INTO kv_store (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)
405             ON CONFLICT(key) DO UPDATE SET value = excluded.value, updated_at = CURRENT_TIMESTAMP",
406        )
407        .bind(key)
408        .bind(value)
409        .execute(&self.pool)
410        .await?;
411        Ok(())
412    }
413
414    /// Delete a key from the global KV store.
415    pub async fn kv_delete(&self, key: &str) -> Result<()> {
416        sqlx::query("DELETE FROM kv_store WHERE key = ?")
417            .bind(key)
418            .execute(&self.pool)
419            .await?;
420        Ok(())
421    }
422
423    /// Get all KV entries matching a prefix (e.g. `"apikey:"`).
424    pub async fn kv_list_prefix(&self, prefix: &str) -> Result<Vec<(String, String)>> {
425        let pattern = format!("{prefix}%");
426        let rows: Vec<(String, String)> =
427            sqlx::query_as("SELECT key, value FROM kv_store WHERE key LIKE ?")
428                .bind(&pattern)
429                .fetch_all(&self.pool)
430                .await?;
431        Ok(rows)
432    }
433}

// ── Input history (#693) ─────────────────────────────────────────────────────

437/// Maximum number of input history entries to keep.
438const MAX_INPUT_HISTORY: i64 = 500;
439
440impl Database {
441    /// Append an input to the history.
442    pub async fn history_push(&self, input: &str) -> Result<()> {
443        sqlx::query("INSERT INTO input_history (input) VALUES (?)")
444            .bind(input)
445            .execute(&self.pool)
446            .await?;
447
448        // Trim old entries beyond the cap.
449        sqlx::query(
450            "DELETE FROM input_history WHERE id NOT IN (
451                SELECT id FROM input_history ORDER BY id DESC LIMIT ?
452            )",
453        )
454        .bind(MAX_INPUT_HISTORY)
455        .execute(&self.pool)
456        .await?;
457
458        Ok(())
459    }
460
461    /// Load all history entries, oldest first.
462    pub async fn history_load(&self) -> Result<Vec<String>> {
463        let rows: Vec<(String,)> =
464            sqlx::query_as("SELECT input FROM input_history ORDER BY id ASC")
465                .fetch_all(&self.pool)
466                .await?;
467        Ok(rows.into_iter().map(|(s,)| s).collect())
468    }
469}