1pub mod queries;
27#[cfg(test)]
28mod tests;
29
30use anyhow::{Context, Result};
31use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
32use std::path::Path;
33use std::str::FromStr;
34
35pub use crate::persistence::{
37 CompactedStats, Message, Persistence, Role, SessionInfo, SessionUsage,
38};
39
/// Handle to koda's SQLite storage.
///
/// Wraps a [`SqlitePool`]; every query method in this module runs against
/// that pool. The type derives `Clone`, so the handle can be duplicated.
#[derive(Debug, Clone)]
pub struct Database {
    /// Connection pool backing all queries. Crate-visible; external callers
    /// go through [`Database::pool`].
    pub(crate) pool: SqlitePool,
}
45
46pub fn config_dir() -> Result<std::path::PathBuf> {
48 let base = std::env::var("XDG_CONFIG_HOME")
49 .ok()
50 .map(std::path::PathBuf::from)
51 .or_else(|| {
52 std::env::var("HOME")
53 .ok()
54 .map(|h| std::path::PathBuf::from(h).join(".config"))
55 })
56 .ok_or_else(|| {
57 anyhow::anyhow!("Cannot determine config directory (set HOME or XDG_CONFIG_HOME)")
58 })?;
59 Ok(base.join("koda"))
60}
61
62impl Database {
63 pub fn pool(&self) -> &SqlitePool {
65 &self.pool
66 }
67
68 pub async fn init(koda_config_dir: &Path) -> Result<Self> {
75 let db_dir = koda_config_dir.join("db");
76 std::fs::create_dir_all(&db_dir)
77 .with_context(|| format!("Failed to create DB dir: {}", db_dir.display()))?;
78
79 let db_path = db_dir.join("koda.db");
80
81 let db = Self::open(&db_path).await?;
82
83 #[cfg(unix)]
86 Self::set_db_permissions(&db_path);
87
88 Ok(db)
89 }
90
91 pub async fn open(db_path: &Path) -> Result<Self> {
93 let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
94
95 let options = SqliteConnectOptions::from_str(&db_url)?
96 .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
97 .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
98 .foreign_keys(true)
99 .create_if_missing(true)
100 .busy_timeout(std::time::Duration::from_millis(5000));
105
106 let pool = SqlitePoolOptions::new()
107 .max_connections(5)
108 .connect_with(options)
109 .await
110 .with_context(|| format!("Failed to connect to database: {db_url}"))?;
111
112 Self::migrate(&pool).await?;
114 Ok(Self { pool })
115 }
116
117 async fn migrate(pool: &SqlitePool) -> Result<()> {
119 sqlx::query(
120 "CREATE TABLE IF NOT EXISTS sessions (
121 id TEXT PRIMARY KEY,
122 created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
123 agent_name TEXT NOT NULL,
124 project_root TEXT,
125 last_accessed_at TEXT,
126 title TEXT,
127 mode TEXT
128 );",
129 )
130 .execute(pool)
131 .await?;
132
133 sqlx::query(
134 "CREATE TABLE IF NOT EXISTS messages (
135 id INTEGER PRIMARY KEY AUTOINCREMENT,
136 session_id TEXT NOT NULL,
137 role TEXT NOT NULL,
138 content TEXT,
139 full_content TEXT,
140 tool_calls TEXT,
141 tool_call_id TEXT,
142 prompt_tokens INTEGER,
143 completion_tokens INTEGER,
144 cache_read_tokens INTEGER,
145 cache_creation_tokens INTEGER,
146 thinking_tokens INTEGER,
147 agent_name TEXT,
148 compacted_at TEXT,
149 completed_at DATETIME,
150 created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
151 FOREIGN KEY(session_id) REFERENCES sessions(id)
152 );",
153 )
154 .execute(pool)
155 .await?;
156
157 sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages(session_id);")
158 .execute(pool)
159 .await?;
160
161 sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_role_id ON messages(role, id DESC);")
162 .execute(pool)
163 .await?;
164
165 sqlx::query(
167 "CREATE TABLE IF NOT EXISTS session_metadata (
168 session_id TEXT NOT NULL,
169 key TEXT NOT NULL,
170 value TEXT NOT NULL,
171 updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
172 PRIMARY KEY(session_id, key),
173 FOREIGN KEY(session_id) REFERENCES sessions(id)
174 );",
175 )
176 .execute(pool)
177 .await?;
178
179 sqlx::query(
181 "CREATE TABLE IF NOT EXISTS owned_files (
182 session_id TEXT NOT NULL,
183 path TEXT NOT NULL,
184 created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
185 PRIMARY KEY(session_id, path)
186 );",
187 )
188 .execute(pool)
189 .await?;
190
191 sqlx::query(
196 "CREATE TABLE IF NOT EXISTS kv_store (
197 key TEXT PRIMARY KEY,
198 value TEXT NOT NULL,
199 updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
200 );",
201 )
202 .execute(pool)
203 .await?;
204
205 sqlx::query(
207 "CREATE TABLE IF NOT EXISTS input_history (
208 id INTEGER PRIMARY KEY AUTOINCREMENT,
209 input TEXT NOT NULL,
210 created_at DATETIME DEFAULT CURRENT_TIMESTAMP
211 );",
212 )
213 .execute(pool)
214 .await?;
215
216 Ok(())
217 }
218
219 #[cfg(unix)]
224 fn set_db_permissions(db_path: &Path) {
225 use std::os::unix::fs::PermissionsExt;
226 let perms = std::fs::Permissions::from_mode(0o600);
227 if let Err(e) = std::fs::set_permissions(db_path, perms) {
228 tracing::warn!("Failed to set 0600 on {}: {e}", db_path.display());
229 }
230 }
231}
232
233impl Database {
236 pub async fn insert_owned_file(&self, session_id: &str, path: &Path) -> Result<()> {
238 sqlx::query("INSERT OR IGNORE INTO owned_files (session_id, path) VALUES (?, ?)")
239 .bind(session_id)
240 .bind(path.to_string_lossy().as_ref())
241 .execute(&self.pool)
242 .await?;
243 Ok(())
244 }
245
246 pub async fn delete_owned_file(&self, session_id: &str, path: &Path) -> Result<()> {
248 sqlx::query("DELETE FROM owned_files WHERE session_id = ? AND path = ?")
249 .bind(session_id)
250 .bind(path.to_string_lossy().as_ref())
251 .execute(&self.pool)
252 .await?;
253 Ok(())
254 }
255
256 pub async fn load_owned_files(
258 &self,
259 session_id: &str,
260 ) -> Result<std::collections::HashSet<std::path::PathBuf>> {
261 let rows: Vec<(String,)> =
262 sqlx::query_as("SELECT path FROM owned_files WHERE session_id = ?")
263 .bind(session_id)
264 .fetch_all(&self.pool)
265 .await?;
266 Ok(rows
267 .into_iter()
268 .map(|(p,)| std::path::PathBuf::from(p))
269 .collect())
270 }
271
272 pub async fn load_messages_before(
278 &self,
279 session_id: &str,
280 before_id: i64,
281 limit: i64,
282 ) -> Result<Vec<Message>> {
283 let rows: Vec<MessageRow> = sqlx::query_as(
284 "SELECT id, session_id, role, content, full_content, tool_calls, tool_call_id,
285 prompt_tokens, completion_tokens,
286 cache_read_tokens, cache_creation_tokens, thinking_tokens,
287 created_at
288 FROM messages
289 WHERE session_id = ? AND id < ? AND compacted_at IS NULL
290 ORDER BY id DESC
291 LIMIT ?",
292 )
293 .bind(session_id)
294 .bind(before_id)
295 .bind(limit)
296 .fetch_all(&self.pool)
297 .await?;
298
299 let mut messages: Vec<Message> = rows.into_iter().map(|r| r.into()).collect();
301 messages.reverse();
302 Ok(messages)
303 }
304
305 pub async fn seconds_since_last_assistant(&self, session_id: &str) -> Result<Option<i64>> {
310 let row: Option<(i64,)> = sqlx::query_as(
311 "SELECT CAST((julianday('now') - julianday(created_at)) * 86400 AS INTEGER) \
312 FROM messages \
313 WHERE session_id = ? AND role = 'assistant' AND compacted_at IS NULL \
314 ORDER BY id DESC LIMIT 1",
315 )
316 .bind(session_id)
317 .fetch_optional(&self.pool)
318 .await?;
319 Ok(row.map(|(secs,)| secs))
320 }
321}
322
/// Raw `messages` row as decoded by sqlx.
///
/// Field names match the columns selected by the message queries in this
/// module. Convert to the public [`Message`] type with `Message::from`.
#[derive(sqlx::FromRow)]
pub(crate) struct MessageRow {
    /// Row id.
    pub id: i64,
    /// Id of the session this message belongs to.
    pub session_id: String,
    /// Role as stored text; parsed into [`Role`] during conversion.
    pub role: String,
    /// Nullable `content` column.
    pub content: Option<String>,
    /// Nullable `full_content` column.
    pub full_content: Option<String>,
    /// Nullable `tool_calls` column (stored as text).
    pub tool_calls: Option<String>,
    /// Nullable `tool_call_id` column.
    pub tool_call_id: Option<String>,
    /// Nullable `prompt_tokens` column.
    pub prompt_tokens: Option<i64>,
    /// Nullable `completion_tokens` column.
    pub completion_tokens: Option<i64>,
    /// Nullable `cache_read_tokens` column.
    pub cache_read_tokens: Option<i64>,
    /// Nullable `cache_creation_tokens` column.
    pub cache_creation_tokens: Option<i64>,
    /// Nullable `thinking_tokens` column.
    pub thinking_tokens: Option<i64>,
    /// Creation timestamp as stored text; nullable in the schema.
    pub created_at: Option<String>,
}
342
/// Raw session-summary row as decoded by sqlx.
///
/// Converted field-for-field into the public [`SessionInfo`] type with
/// `SessionInfo::from`. The query that produces these rows is not in this
/// part of the file — presumably it lives in the `queries` module; verify
/// there for how `message_count` and `total_tokens` are computed.
#[derive(Debug, Clone, sqlx::FromRow)]
pub(crate) struct SessionInfoRow {
    /// Session id.
    pub id: String,
    /// Name of the agent associated with the session.
    pub agent_name: String,
    /// Creation timestamp as text.
    pub created_at: String,
    /// Message count reported by the query.
    pub message_count: i64,
    /// Token total reported by the query.
    pub total_tokens: i64,
    /// Optional session title.
    pub title: Option<String>,
    /// Optional session mode.
    pub mode: Option<String>,
}
354
355impl From<SessionInfoRow> for SessionInfo {
356 fn from(r: SessionInfoRow) -> Self {
357 Self {
358 id: r.id,
359 agent_name: r.agent_name,
360 created_at: r.created_at,
361 message_count: r.message_count,
362 total_tokens: r.total_tokens,
363 title: r.title,
364 mode: r.mode,
365 }
366 }
367}
368
369impl From<MessageRow> for Message {
370 fn from(r: MessageRow) -> Self {
371 Self {
372 id: r.id,
373 session_id: r.session_id,
374 role: r.role.parse().unwrap_or(Role::User),
375 content: r.content,
376 full_content: r.full_content,
377 tool_calls: r.tool_calls,
378 tool_call_id: r.tool_call_id,
379 prompt_tokens: r.prompt_tokens,
380 completion_tokens: r.completion_tokens,
381 cache_read_tokens: r.cache_read_tokens,
382 cache_creation_tokens: r.cache_creation_tokens,
383 thinking_tokens: r.thinking_tokens,
384 created_at: r.created_at,
385 }
386 }
387}
388
389impl Database {
392 pub async fn kv_get(&self, key: &str) -> Result<Option<String>> {
394 let row: Option<(String,)> = sqlx::query_as("SELECT value FROM kv_store WHERE key = ?")
395 .bind(key)
396 .fetch_optional(&self.pool)
397 .await?;
398 Ok(row.map(|(v,)| v))
399 }
400
401 pub async fn kv_set(&self, key: &str, value: &str) -> Result<()> {
403 sqlx::query(
404 "INSERT INTO kv_store (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)
405 ON CONFLICT(key) DO UPDATE SET value = excluded.value, updated_at = CURRENT_TIMESTAMP",
406 )
407 .bind(key)
408 .bind(value)
409 .execute(&self.pool)
410 .await?;
411 Ok(())
412 }
413
414 pub async fn kv_delete(&self, key: &str) -> Result<()> {
416 sqlx::query("DELETE FROM kv_store WHERE key = ?")
417 .bind(key)
418 .execute(&self.pool)
419 .await?;
420 Ok(())
421 }
422
423 pub async fn kv_list_prefix(&self, prefix: &str) -> Result<Vec<(String, String)>> {
425 let pattern = format!("{prefix}%");
426 let rows: Vec<(String, String)> =
427 sqlx::query_as("SELECT key, value FROM kv_store WHERE key LIKE ?")
428 .bind(&pattern)
429 .fetch_all(&self.pool)
430 .await?;
431 Ok(rows)
432 }
433}
434
/// Maximum number of rows kept in `input_history`; `history_push` deletes
/// everything but the newest this-many entries after each insert.
const MAX_INPUT_HISTORY: i64 = 500;
439
440impl Database {
441 pub async fn history_push(&self, input: &str) -> Result<()> {
443 sqlx::query("INSERT INTO input_history (input) VALUES (?)")
444 .bind(input)
445 .execute(&self.pool)
446 .await?;
447
448 sqlx::query(
450 "DELETE FROM input_history WHERE id NOT IN (
451 SELECT id FROM input_history ORDER BY id DESC LIMIT ?
452 )",
453 )
454 .bind(MAX_INPUT_HISTORY)
455 .execute(&self.pool)
456 .await?;
457
458 Ok(())
459 }
460
461 pub async fn history_load(&self) -> Result<Vec<String>> {
463 let rows: Vec<(String,)> =
464 sqlx::query_as("SELECT input FROM input_history ORDER BY id ASC")
465 .fetch_all(&self.pool)
466 .await?;
467 Ok(rows.into_iter().map(|(s,)| s).collect())
468 }
469}