// koda-core 0.1.13 — Core engine for the Koda AI coding agent.
//! SQLite persistence layer.
//!
//! Implements `Persistence` trait for SQLite via sqlx.
//! Uses WAL mode for concurrent access.
//!
//! ## Module layout
//!
//! - **mod.rs** — `Database` struct, init/open, schema migrations, row types
//! - **queries.rs** — `Persistence` trait implementation (all SQL queries)

mod queries;
#[cfg(test)]
mod tests;

use anyhow::{Context, Result};
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use std::path::Path;
use std::str::FromStr;

/// Re-export persistence types for backward compatibility.
pub use crate::persistence::{
    CompactedStats, Message, Persistence, Role, SessionInfo, SessionUsage,
};

/// Wrapper around the SQLite connection pool.
///
/// Cloning `Database` clones the underlying `SqlitePool` handle, so clones
/// share the same pool of connections.
#[derive(Debug, Clone)]
pub struct Database {
    // Pool used by all query methods; `pub(crate)` so the `queries` module
    // (which implements the `Persistence` trait) can reach it directly.
    pub(crate) pool: SqlitePool,
}

/// Get the koda config directory (~/.config/koda/).
/// Get the koda config directory (~/.config/koda/).
///
/// Resolution order: `$XDG_CONFIG_HOME`, then `$HOME/.config`.
/// Errors if neither environment variable is set.
pub fn config_dir() -> Result<std::path::PathBuf> {
    // Primary location: the XDG base directory, if the user set one.
    let from_xdg = std::env::var("XDG_CONFIG_HOME")
        .ok()
        .map(std::path::PathBuf::from);

    // Fallback: the conventional ~/.config under $HOME.
    let from_home = || {
        std::env::var("HOME")
            .ok()
            .map(|home| std::path::PathBuf::from(home).join(".config"))
    };

    match from_xdg.or_else(from_home) {
        Some(base) => Ok(base.join("koda")),
        None => Err(anyhow::anyhow!(
            "Cannot determine config directory (set HOME or XDG_CONFIG_HOME)"
        )),
    }
}

impl Database {
    /// Initialize the database, run migrations, and enable WAL mode.
    ///
    /// `koda_config_dir` is the koda configuration directory (e.g. `~/.config/koda`).
    /// The database lives in `<koda_config_dir>/db/koda.db`.
    ///
    /// Production callers should pass `db::config_dir()?`; tests pass a temp dir.
    pub async fn init(koda_config_dir: &Path) -> Result<Self> {
        let db_dir = koda_config_dir.join("db");
        std::fs::create_dir_all(&db_dir)
            .with_context(|| format!("Failed to create DB dir: {}", db_dir.display()))?;

        let db_path = db_dir.join("koda.db");

        Self::open(&db_path).await
    }

    /// Open a database at a specific path (used by tests and init).
    pub async fn open(db_path: &Path) -> Result<Self> {
        // Build options from the path directly rather than via a
        // `sqlite:<path>?mode=rwc` URL string: URL parsing breaks for paths
        // containing `?`, `#`, or `%`. `create_if_missing(true)` is the
        // options-API equivalent of `mode=rwc`.
        let options = SqliteConnectOptions::new()
            .filename(db_path)
            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
            .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
            .foreign_keys(true)
            .create_if_missing(true);

        let pool = SqlitePoolOptions::new()
            .max_connections(5)
            .connect_with(options)
            .await
            .with_context(|| {
                format!("Failed to connect to database: {}", db_path.display())
            })?;

        // Run schema migrations
        Self::migrate(&pool).await?;
        Ok(Self { pool })
    }

    /// Run an `ALTER TABLE … ADD COLUMN` statement, treating a
    /// "duplicate column name" error as success so additive migrations
    /// stay idempotent across restarts.
    async fn add_column_if_missing(pool: &SqlitePool, sql: &str) -> Result<()> {
        if let Err(e) = sqlx::query(sql).execute(pool).await {
            // SQLite has no `ADD COLUMN IF NOT EXISTS`; detect the
            // duplicate-column error by message and ignore it.
            if !e.to_string().contains("duplicate column name") {
                return Err(e.into());
            }
        }
        Ok(())
    }

    /// Apply the schema (idempotent).
    async fn migrate(pool: &SqlitePool) -> Result<()> {
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS sessions (
                id TEXT PRIMARY KEY,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                agent_name TEXT NOT NULL
            );",
        )
        .execute(pool)
        .await?;

        sqlx::query(
            "CREATE TABLE IF NOT EXISTS messages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT,
                tool_calls TEXT,
                tool_call_id TEXT,
                prompt_tokens INTEGER,
                completion_tokens INTEGER,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY(session_id) REFERENCES sessions(id)
            );",
        )
        .execute(pool)
        .await?;

        sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages(session_id);")
            .execute(pool)
            .await?;

        sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_role_id ON messages(role, id DESC);")
            .execute(pool)
            .await?;

        // Additive migrations for new token tracking columns (idempotent).
        for col in &[
            "cache_read_tokens",
            "cache_creation_tokens",
            "thinking_tokens",
        ] {
            let sql = format!("ALTER TABLE messages ADD COLUMN {col} INTEGER");
            Self::add_column_if_missing(pool, &sql).await?;
        }

        // Additive migration: per-message agent name.
        Self::add_column_if_missing(pool, "ALTER TABLE messages ADD COLUMN agent_name TEXT")
            .await?;

        // Session-scoped key-value metadata (e.g. todo list).
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS session_metadata (
                session_id TEXT NOT NULL,
                key TEXT NOT NULL,
                value TEXT NOT NULL,
                updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY(session_id, key),
                FOREIGN KEY(session_id) REFERENCES sessions(id)
            );",
        )
        .execute(pool)
        .await?;

        // Additive migration: add project_root to sessions
        Self::add_column_if_missing(pool, "ALTER TABLE sessions ADD COLUMN project_root TEXT")
            .await?;

        // Additive migration: add compacted_at for non-destructive compaction (#428)
        Self::add_column_if_missing(pool, "ALTER TABLE messages ADD COLUMN compacted_at TEXT")
            .await?;

        // Additive migration: track last activity per session (#429)
        Self::add_column_if_missing(pool, "ALTER TABLE sessions ADD COLUMN last_accessed_at TEXT")
            .await?;

        // File lifecycle tracking (#465): files created by Koda in a session.
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS owned_files (
                session_id TEXT NOT NULL,
                path TEXT NOT NULL,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY(session_id, path)
            );",
        )
        .execute(pool)
        .await?;

        Ok(())
    }
}

// ── File lifecycle tracking (#465) ────────────────────────────────────────────

impl Database {
    /// Record that Koda created a file in this session.
    pub async fn insert_owned_file(&self, session_id: &str, path: &Path) -> Result<()> {
        sqlx::query("INSERT OR IGNORE INTO owned_files (session_id, path) VALUES (?, ?)")
            .bind(session_id)
            .bind(path.to_string_lossy().as_ref())
            .execute(&self.pool)
            .await?;
        Ok(())
    }

    /// Remove a file from the owned set.
    pub async fn delete_owned_file(&self, session_id: &str, path: &Path) -> Result<()> {
        sqlx::query("DELETE FROM owned_files WHERE session_id = ? AND path = ?")
            .bind(session_id)
            .bind(path.to_string_lossy().as_ref())
            .execute(&self.pool)
            .await?;
        Ok(())
    }

    /// Load all owned file paths for a session (used on session resume).
    pub async fn load_owned_files(
        &self,
        session_id: &str,
    ) -> Result<std::collections::HashSet<std::path::PathBuf>> {
        let rows: Vec<(String,)> =
            sqlx::query_as("SELECT path FROM owned_files WHERE session_id = ?")
                .bind(session_id)
                .fetch_all(&self.pool)
                .await?;
        Ok(rows
            .into_iter()
            .map(|(p,)| std::path::PathBuf::from(p))
            .collect())
    }

    /// Load a page of messages older than `before_id` (for virtual scroll).
    ///
    /// Returns up to `limit` messages with `id < before_id`, ordered
    /// newest-first so the caller can reverse them for display.
    /// Only non-compacted messages are returned.
    pub async fn load_messages_before(
        &self,
        session_id: &str,
        before_id: i64,
        limit: i64,
    ) -> Result<Vec<Message>> {
        let rows: Vec<MessageRow> = sqlx::query_as(
            "SELECT id, session_id, role, content, tool_calls, tool_call_id,
                    prompt_tokens, completion_tokens,
                    cache_read_tokens, cache_creation_tokens, thinking_tokens
             FROM messages
             WHERE session_id = ? AND id < ? AND compacted_at IS NULL
             ORDER BY id DESC
             LIMIT ?",
        )
        .bind(session_id)
        .bind(before_id)
        .bind(limit)
        .fetch_all(&self.pool)
        .await?;

        // Reverse to chronological order
        let mut messages: Vec<Message> = rows.into_iter().map(|r| r.into()).collect();
        messages.reverse();
        Ok(messages)
    }
}

// ── Row types ───────────────────────────────────────────────────────────

/// Internal row type for sqlx deserialization.
///
/// Mirrors the `messages` table columns selected by the queries in this
/// module; converted to the public `Message` type via `From<MessageRow>`.
#[derive(sqlx::FromRow)]
pub(crate) struct MessageRow {
    pub id: i64,
    pub session_id: String,
    // Stored as free text; parsed into `Role` during conversion.
    pub role: String,
    pub content: Option<String>,
    // Serialized tool-call payload, if any (stored as TEXT in the DB).
    pub tool_calls: Option<String>,
    pub tool_call_id: Option<String>,
    pub prompt_tokens: Option<i64>,
    pub completion_tokens: Option<i64>,
    // Token-tracking columns added by later additive migrations, hence Option.
    pub cache_read_tokens: Option<i64>,
    pub cache_creation_tokens: Option<i64>,
    pub thinking_tokens: Option<i64>,
}

/// Session metadata for listing.
///
/// Row shape for session-list queries; converted to the public
/// `SessionInfo` type via `From<SessionInfoRow>`.
#[derive(Debug, Clone, sqlx::FromRow)]
pub(crate) struct SessionInfoRow {
    pub id: String,
    pub agent_name: String,
    // Timestamp kept as the raw TEXT/DATETIME string from SQLite.
    pub created_at: String,
    pub message_count: i64,
    pub total_tokens: i64,
}

impl From<SessionInfoRow> for SessionInfo {
    fn from(r: SessionInfoRow) -> Self {
        Self {
            id: r.id,
            agent_name: r.agent_name,
            created_at: r.created_at,
            message_count: r.message_count,
            total_tokens: r.total_tokens,
        }
    }
}

impl From<MessageRow> for Message {
    fn from(r: MessageRow) -> Self {
        Self {
            id: r.id,
            session_id: r.session_id,
            role: r.role.parse().unwrap_or(Role::User),
            content: r.content,
            tool_calls: r.tool_calls,
            tool_call_id: r.tool_call_id,
            prompt_tokens: r.prompt_tokens,
            completion_tokens: r.completion_tokens,
            cache_read_tokens: r.cache_read_tokens,
            cache_creation_tokens: r.cache_creation_tokens,
            thinking_tokens: r.thinking_tokens,
        }
    }
}