mod queries;
#[cfg(test)]
mod tests;
use anyhow::{Context, Result};
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use std::path::Path;
use std::str::FromStr;
pub use crate::persistence::{
CompactedStats, Message, Persistence, Role, SessionInfo, SessionUsage,
};
/// Handle to the application's SQLite database.
///
/// Wraps an sqlx connection pool; `Clone` is cheap because the pool is
/// internally reference-counted, so this type can be shared freely.
#[derive(Debug, Clone)]
pub struct Database {
    // Shared pool; crate-visible so sibling modules (e.g. `queries`) can
    // run their own statements against it.
    pub(crate) pool: SqlitePool,
}
/// Returns the koda configuration directory: `$XDG_CONFIG_HOME/koda`,
/// falling back to `$HOME/.config/koda`.
///
/// Per the XDG Base Directory spec, an *empty* `XDG_CONFIG_HOME` must be
/// treated as unset; an empty `HOME` is likewise ignored rather than
/// producing a relative `.config` path.
///
/// # Errors
/// Fails when neither variable yields a usable base directory.
pub fn config_dir() -> Result<std::path::PathBuf> {
    let base = std::env::var("XDG_CONFIG_HOME")
        .ok()
        // XDG spec: empty value means "unset", fall through to $HOME.
        .filter(|v| !v.is_empty())
        .map(std::path::PathBuf::from)
        .or_else(|| {
            std::env::var("HOME")
                .ok()
                .filter(|h| !h.is_empty())
                .map(|h| std::path::PathBuf::from(h).join(".config"))
        })
        .ok_or_else(|| {
            anyhow::anyhow!("Cannot determine config directory (set HOME or XDG_CONFIG_HOME)")
        })?;
    Ok(base.join("koda"))
}
impl Database {
    /// Opens (creating if necessary) the database at `<config>/db/koda.db`.
    ///
    /// # Errors
    /// Fails if the `db` directory cannot be created or the database
    /// cannot be opened/migrated.
    pub async fn init(koda_config_dir: &Path) -> Result<Self> {
        let db_dir = koda_config_dir.join("db");
        std::fs::create_dir_all(&db_dir)
            .with_context(|| format!("Failed to create DB dir: {}", db_dir.display()))?;
        let db_path = db_dir.join("koda.db");
        Self::open(&db_path).await
    }

    /// Opens the SQLite database at `db_path` (creating the file if
    /// missing), configures WAL + incremental auto-vacuum + foreign keys,
    /// and runs schema migrations.
    pub async fn open(db_path: &Path) -> Result<Self> {
        // Build options directly from the path rather than via a
        // `sqlite:...` URL: URL parsing would misinterpret paths that
        // contain `?`, `#`, or `%`. `create_if_missing(true)` replaces the
        // old `mode=rwc` query parameter.
        let options = SqliteConnectOptions::new()
            .filename(db_path)
            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
            .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
            .foreign_keys(true)
            .create_if_missing(true);
        let pool = SqlitePoolOptions::new()
            .max_connections(5)
            .connect_with(options)
            .await
            .with_context(|| format!("Failed to connect to database: {}", db_path.display()))?;
        Self::migrate(&pool).await?;
        Ok(Self { pool })
    }

    /// Executes an `ALTER TABLE ... ADD COLUMN` statement, treating a
    /// "duplicate column name" error as success.
    ///
    /// SQLite has no `ADD COLUMN IF NOT EXISTS`, so re-running an additive
    /// migration surfaces as exactly this error; anything else propagates.
    async fn add_column_if_missing(pool: &SqlitePool, sql: &str) -> Result<()> {
        if let Err(e) = sqlx::query(sql).execute(pool).await {
            if !e.to_string().contains("duplicate column name") {
                return Err(e.into());
            }
        }
        Ok(())
    }

    /// Creates tables/indexes and applies additive column migrations.
    ///
    /// Idempotent: every statement is either `IF NOT EXISTS` or tolerant
    /// of already-applied changes, so this runs safely on every startup.
    /// Migration order is preserved from the original implementation.
    async fn migrate(pool: &SqlitePool) -> Result<()> {
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS sessions (
                id TEXT PRIMARY KEY,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                agent_name TEXT NOT NULL
            );",
        )
        .execute(pool)
        .await?;
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS messages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT,
                tool_calls TEXT,
                tool_call_id TEXT,
                prompt_tokens INTEGER,
                completion_tokens INTEGER,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY(session_id) REFERENCES sessions(id)
            );",
        )
        .execute(pool)
        .await?;
        sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages(session_id);")
            .execute(pool)
            .await?;
        sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_role_id ON messages(role, id DESC);")
            .execute(pool)
            .await?;
        // Additive token-accounting columns (later schema revisions).
        for col in &[
            "cache_read_tokens",
            "cache_creation_tokens",
            "thinking_tokens",
        ] {
            let sql = format!("ALTER TABLE messages ADD COLUMN {col} INTEGER");
            Self::add_column_if_missing(pool, &sql).await?;
        }
        Self::add_column_if_missing(pool, "ALTER TABLE messages ADD COLUMN agent_name TEXT")
            .await?;
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS session_metadata (
                session_id TEXT NOT NULL,
                key TEXT NOT NULL,
                value TEXT NOT NULL,
                updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY(session_id, key),
                FOREIGN KEY(session_id) REFERENCES sessions(id)
            );",
        )
        .execute(pool)
        .await?;
        Self::add_column_if_missing(pool, "ALTER TABLE sessions ADD COLUMN project_root TEXT")
            .await?;
        Self::add_column_if_missing(pool, "ALTER TABLE messages ADD COLUMN compacted_at TEXT")
            .await?;
        Self::add_column_if_missing(pool, "ALTER TABLE sessions ADD COLUMN last_accessed_at TEXT")
            .await?;
        // NOTE(review): unlike messages/session_metadata, owned_files has no
        // FOREIGN KEY to sessions(id) — adding one now would not affect
        // already-created tables; confirm whether orphan rows are intended.
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS owned_files (
                session_id TEXT NOT NULL,
                path TEXT NOT NULL,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY(session_id, path)
            );",
        )
        .execute(pool)
        .await?;
        Ok(())
    }
}
impl Database {
pub async fn insert_owned_file(&self, session_id: &str, path: &Path) -> Result<()> {
sqlx::query("INSERT OR IGNORE INTO owned_files (session_id, path) VALUES (?, ?)")
.bind(session_id)
.bind(path.to_string_lossy().as_ref())
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn delete_owned_file(&self, session_id: &str, path: &Path) -> Result<()> {
sqlx::query("DELETE FROM owned_files WHERE session_id = ? AND path = ?")
.bind(session_id)
.bind(path.to_string_lossy().as_ref())
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn load_owned_files(
&self,
session_id: &str,
) -> Result<std::collections::HashSet<std::path::PathBuf>> {
let rows: Vec<(String,)> =
sqlx::query_as("SELECT path FROM owned_files WHERE session_id = ?")
.bind(session_id)
.fetch_all(&self.pool)
.await?;
Ok(rows
.into_iter()
.map(|(p,)| std::path::PathBuf::from(p))
.collect())
}
pub async fn load_messages_before(
&self,
session_id: &str,
before_id: i64,
limit: i64,
) -> Result<Vec<Message>> {
let rows: Vec<MessageRow> = sqlx::query_as(
"SELECT id, session_id, role, content, tool_calls, tool_call_id,
prompt_tokens, completion_tokens,
cache_read_tokens, cache_creation_tokens, thinking_tokens
FROM messages
WHERE session_id = ? AND id < ? AND compacted_at IS NULL
ORDER BY id DESC
LIMIT ?",
)
.bind(session_id)
.bind(before_id)
.bind(limit)
.fetch_all(&self.pool)
.await?;
let mut messages: Vec<Message> = rows.into_iter().map(|r| r.into()).collect();
messages.reverse();
Ok(messages)
}
}
/// Raw row image of the `messages` table, mapped by column name via
/// `sqlx::FromRow`; converted to the public `Message` type via `From`.
#[derive(sqlx::FromRow)]
pub(crate) struct MessageRow {
    pub id: i64,
    pub session_id: String,
    // Stored as text; parsed into `Role` during conversion to `Message`.
    pub role: String,
    pub content: Option<String>,
    // Serialized tool-call payload; format is opaque at this layer.
    pub tool_calls: Option<String>,
    pub tool_call_id: Option<String>,
    // Token counters are nullable: columns were added by later migrations,
    // so rows written before then have NULLs here.
    pub prompt_tokens: Option<i64>,
    pub completion_tokens: Option<i64>,
    pub cache_read_tokens: Option<i64>,
    pub cache_creation_tokens: Option<i64>,
    pub thinking_tokens: Option<i64>,
}
/// Row shape for session-summary queries (session columns plus aggregated
/// message statistics), mapped by column name via `sqlx::FromRow` and
/// converted to the public `SessionInfo` type via `From`.
#[derive(Debug, Clone, sqlx::FromRow)]
pub(crate) struct SessionInfoRow {
    pub id: String,
    pub agent_name: String,
    // Kept as the raw DB text (DATETIME stored as string in SQLite).
    pub created_at: String,
    // Aggregates computed by the query, not stored columns.
    pub message_count: i64,
    pub total_tokens: i64,
}
impl From<SessionInfoRow> for SessionInfo {
fn from(r: SessionInfoRow) -> Self {
Self {
id: r.id,
agent_name: r.agent_name,
created_at: r.created_at,
message_count: r.message_count,
total_tokens: r.total_tokens,
}
}
}
impl From<MessageRow> for Message {
fn from(r: MessageRow) -> Self {
Self {
id: r.id,
session_id: r.session_id,
role: r.role.parse().unwrap_or(Role::User),
content: r.content,
tool_calls: r.tool_calls,
tool_call_id: r.tool_call_id,
prompt_tokens: r.prompt_tokens,
completion_tokens: r.completion_tokens,
cache_read_tokens: r.cache_read_tokens,
cache_creation_tokens: r.cache_creation_tokens,
thinking_tokens: r.thinking_tokens,
}
}
}