Skip to main content

koda_core/
db.rs

1//! SQLite persistence layer.
2//!
3//! Implements `Persistence` trait for SQLite via sqlx.
4//! Uses WAL mode for concurrent access.
5
6use anyhow::{Context, Result};
7use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
8use std::path::Path;
9use std::str::FromStr;
10
11/// Re-export persistence types for backward compatibility.
12pub use crate::persistence::{
13    CompactedStats, Message, Persistence, Role, SessionInfo, SessionUsage,
14};
15
/// SQLite-backed implementation of the `Persistence` trait.
///
/// Wraps the sqlx connection pool. Build one with `Database::init` (database
/// under the koda config dir) or `Database::open` (explicit file path); both
/// apply the schema migrations before returning.
#[derive(Debug, Clone)]
pub struct Database {
    /// Pool of SQLite connections shared by every query this type issues.
    pool: SqlitePool,
}
21
22/// Get the koda config directory (~/.config/koda/).
23pub fn config_dir() -> Result<std::path::PathBuf> {
24    let base = std::env::var("XDG_CONFIG_HOME")
25        .ok()
26        .map(std::path::PathBuf::from)
27        .or_else(|| {
28            std::env::var("HOME")
29                .ok()
30                .map(|h| std::path::PathBuf::from(h).join(".config"))
31        })
32        .ok_or_else(|| {
33            anyhow::anyhow!("Cannot determine config directory (set HOME or XDG_CONFIG_HOME)")
34        })?;
35    Ok(base.join("koda"))
36}
37
38impl Database {
39    /// Initialize the database, run migrations, and enable WAL mode.
40    ///
41    /// `koda_config_dir` is the koda configuration directory (e.g. `~/.config/koda`).
42    /// The database lives in `<koda_config_dir>/db/koda.db`.
43    ///
44    /// Production callers should pass `db::config_dir()?`; tests pass a temp dir.
45    pub async fn init(koda_config_dir: &Path) -> Result<Self> {
46        let db_dir = koda_config_dir.join("db");
47        std::fs::create_dir_all(&db_dir)
48            .with_context(|| format!("Failed to create DB dir: {}", db_dir.display()))?;
49
50        let db_path = db_dir.join("koda.db");
51
52        Self::open(&db_path).await
53    }
54
55    /// Open a database at a specific path (used by tests and init).
56    pub async fn open(db_path: &Path) -> Result<Self> {
57        let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
58
59        let options = SqliteConnectOptions::from_str(&db_url)?
60            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
61            .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
62            .foreign_keys(true)
63            .create_if_missing(true);
64
65        let pool = SqlitePoolOptions::new()
66            .max_connections(5)
67            .connect_with(options)
68            .await
69            .with_context(|| format!("Failed to connect to database: {db_url}"))?;
70
71        // Run schema migrations
72        Self::migrate(&pool).await?;
73        Ok(Self { pool })
74    }
75
76    /// Apply the schema (idempotent).
77    async fn migrate(pool: &SqlitePool) -> Result<()> {
78        sqlx::query(
79            "CREATE TABLE IF NOT EXISTS sessions (
80                id TEXT PRIMARY KEY,
81                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
82                agent_name TEXT NOT NULL
83            );",
84        )
85        .execute(pool)
86        .await?;
87
88        sqlx::query(
89            "CREATE TABLE IF NOT EXISTS messages (
90                id INTEGER PRIMARY KEY AUTOINCREMENT,
91                session_id TEXT NOT NULL,
92                role TEXT NOT NULL,
93                content TEXT,
94                tool_calls TEXT,
95                tool_call_id TEXT,
96                prompt_tokens INTEGER,
97                completion_tokens INTEGER,
98                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
99                FOREIGN KEY(session_id) REFERENCES sessions(id)
100            );",
101        )
102        .execute(pool)
103        .await?;
104
105        sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages(session_id);")
106            .execute(pool)
107            .await?;
108
109        sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_role_id ON messages(role, id DESC);")
110            .execute(pool)
111            .await?;
112
113        // Additive migrations for new token tracking columns (idempotent).
114        for col in &[
115            "cache_read_tokens",
116            "cache_creation_tokens",
117            "thinking_tokens",
118        ] {
119            let sql = format!("ALTER TABLE messages ADD COLUMN {col} INTEGER");
120            // Ignore "duplicate column name" errors — column already exists.
121            if let Err(e) = sqlx::query(&sql).execute(pool).await {
122                let msg = e.to_string();
123                if !msg.contains("duplicate column name") {
124                    return Err(e.into());
125                }
126            }
127        }
128
129        // Text column migrations
130        for (col, col_type) in &[("agent_name", "TEXT")] {
131            let sql = format!("ALTER TABLE messages ADD COLUMN {col} {col_type}");
132            if let Err(e) = sqlx::query(&sql).execute(pool).await {
133                let msg = e.to_string();
134                if !msg.contains("duplicate column name") {
135                    return Err(e.into());
136                }
137            }
138        }
139
140        // Session-scoped key-value metadata (e.g. todo list).
141        sqlx::query(
142            "CREATE TABLE IF NOT EXISTS session_metadata (
143                session_id TEXT NOT NULL,
144                key TEXT NOT NULL,
145                value TEXT NOT NULL,
146                updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
147                PRIMARY KEY(session_id, key),
148                FOREIGN KEY(session_id) REFERENCES sessions(id)
149            );",
150        )
151        .execute(pool)
152        .await?;
153
154        // Additive migration: add project_root to sessions
155        let sql = "ALTER TABLE sessions ADD COLUMN project_root TEXT";
156        if let Err(e) = sqlx::query(sql).execute(pool).await {
157            let msg = e.to_string();
158            if !msg.contains("duplicate column name") {
159                return Err(e.into());
160            }
161        }
162
163        // Additive migration: add compacted_at for non-destructive compaction (#428)
164        let sql = "ALTER TABLE messages ADD COLUMN compacted_at TEXT";
165        if let Err(e) = sqlx::query(sql).execute(pool).await {
166            let msg = e.to_string();
167            if !msg.contains("duplicate column name") {
168                return Err(e.into());
169            }
170        }
171
172        // Additive migration: track last activity per session (#429)
173        let sql = "ALTER TABLE sessions ADD COLUMN last_accessed_at TEXT";
174        if let Err(e) = sqlx::query(sql).execute(pool).await {
175            let msg = e.to_string();
176            if !msg.contains("duplicate column name") {
177                return Err(e.into());
178            }
179        }
180
181        Ok(())
182    }
183}
184
185// ── Private helpers ─────────────────────────────────────────────────────────
186
187/// Remove messages with mismatched tool_use / tool_result pairing (#428).
188///
189/// Uses the symmetric-difference approach (inspired by Code Puppy):
190/// collect all tool_call IDs from calls and returns, find IDs that appear
191/// in only one set, and drop any message referencing a mismatched ID.
192///
193/// Handles orphans from any source: interrupted sessions, compaction
194/// boundaries, or session resume.
195fn prune_mismatched_tool_calls(messages: &mut Vec<Message>) {
196    if messages.is_empty() {
197        return;
198    }
199
200    let mut tool_call_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
201    let mut tool_return_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
202
203    for msg in messages.iter() {
204        if msg.role == Role::Assistant {
205            if let Some(ref tc_json) = msg.tool_calls
206                && let Ok(calls) = serde_json::from_str::<Vec<serde_json::Value>>(tc_json)
207            {
208                for call in &calls {
209                    if let Some(id) = call.get("id").and_then(|v| v.as_str()) {
210                        tool_call_ids.insert(id.to_string());
211                    }
212                }
213            }
214        } else if msg.role == Role::Tool
215            && let Some(ref id) = msg.tool_call_id
216        {
217            tool_return_ids.insert(id.clone());
218        }
219    }
220
221    let mismatched: std::collections::HashSet<&String> = tool_call_ids
222        .symmetric_difference(&tool_return_ids)
223        .collect();
224
225    if mismatched.is_empty() {
226        return;
227    }
228
229    messages.retain(|msg| {
230        // Drop tool_result messages with mismatched IDs
231        if msg.role == Role::Tool
232            && let Some(ref id) = msg.tool_call_id
233            && mismatched.contains(id)
234        {
235            return false;
236        }
237        // Drop assistant messages whose tool_calls contain mismatched IDs
238        if msg.role == Role::Assistant
239            && let Some(ref tc_json) = msg.tool_calls
240            && let Ok(calls) = serde_json::from_str::<Vec<serde_json::Value>>(tc_json)
241        {
242            let has_mismatched = calls.iter().any(|call| {
243                call.get("id")
244                    .and_then(|v| v.as_str())
245                    .is_some_and(|id| mismatched.contains(&id.to_string()))
246            });
247            if has_mismatched {
248                return false;
249            }
250        }
251        true
252    });
253}
254
255#[async_trait::async_trait]
256impl Persistence for Database {
257    /// Create a new session, returning the generated session ID.
258    async fn create_session(&self, agent_name: &str, project_root: &Path) -> Result<String> {
259        let id = uuid::Uuid::new_v4().to_string();
260        let root = project_root.to_string_lossy().to_string();
261        sqlx::query("INSERT INTO sessions (id, agent_name, project_root) VALUES (?, ?, ?)")
262            .bind(&id)
263            .bind(agent_name)
264            .bind(&root)
265            .execute(&self.pool)
266            .await?;
267        tracing::info!("Created session: {id} (project: {root})");
268        Ok(id)
269    }
270
271    /// Insert a message into the conversation log.
272    async fn insert_message(
273        &self,
274        session_id: &str,
275        role: &Role,
276        content: Option<&str>,
277        tool_calls: Option<&str>,
278        tool_call_id: Option<&str>,
279        usage: Option<&crate::providers::TokenUsage>,
280    ) -> Result<i64> {
281        self.insert_message_with_agent(
282            session_id,
283            role,
284            content,
285            tool_calls,
286            tool_call_id,
287            usage,
288            None,
289        )
290        .await
291    }
292
293    /// Insert a message with an optional agent name for cost tracking.
294    #[allow(clippy::too_many_arguments)]
295    async fn insert_message_with_agent(
296        &self,
297        session_id: &str,
298        role: &Role,
299        content: Option<&str>,
300        tool_calls: Option<&str>,
301        tool_call_id: Option<&str>,
302        usage: Option<&crate::providers::TokenUsage>,
303        agent_name: Option<&str>,
304    ) -> Result<i64> {
305        let result = sqlx::query(
306            "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, \
307             prompt_tokens, completion_tokens, cache_read_tokens, cache_creation_tokens, \
308             thinking_tokens, agent_name)
309             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
310        )
311        .bind(session_id)
312        .bind(role.as_str())
313        .bind(content)
314        .bind(tool_calls)
315        .bind(tool_call_id)
316        .bind(usage.map(|u| u.prompt_tokens))
317        .bind(usage.map(|u| u.completion_tokens))
318        .bind(usage.map(|u| u.cache_read_tokens))
319        .bind(usage.map(|u| u.cache_creation_tokens))
320        .bind(usage.map(|u| u.thinking_tokens))
321        .bind(agent_name)
322        .execute(&self.pool)
323        .await?;
324
325        // Update session activity timestamp
326        sqlx::query("UPDATE sessions SET last_accessed_at = datetime('now') WHERE id = ?")
327            .bind(session_id)
328            .execute(&self.pool)
329            .await?;
330
331        Ok(result.last_insert_rowid())
332    }
333
334    /// Load active (non-compacted) messages for a session.
335    ///
336    /// Returns messages in chronological order. Compacted messages
337    /// (archived by `/compact`) are excluded — their summary replaces them.
338    /// Mismatched tool_use/tool_result pairs are pruned (#428).
339    async fn load_context(&self, session_id: &str) -> Result<Vec<Message>> {
340        let mut messages: Vec<Message> = sqlx::query_as::<_, MessageRow>(
341            "SELECT id, session_id, role, content, tool_calls, tool_call_id,
342                    prompt_tokens, completion_tokens,
343                    cache_read_tokens, cache_creation_tokens, thinking_tokens
344             FROM messages
345             WHERE session_id = ? AND compacted_at IS NULL
346             ORDER BY id ASC",
347        )
348        .bind(session_id)
349        .fetch_all(&self.pool)
350        .await?
351        .into_iter()
352        .map(|r| r.into())
353        .collect();
354
355        // Prune mismatched tool_use/tool_result pairs.
356        // Handles orphans from interrupted sessions, compaction boundaries,
357        // or session resume.
358        prune_mismatched_tool_calls(&mut messages);
359
360        Ok(messages)
361    }
362    /// Load ALL messages for a session (for RecallContext search).
363    /// Returns messages in chronological order, no truncation.
364    async fn load_all_messages(&self, session_id: &str) -> Result<Vec<Message>> {
365        let rows: Vec<Message> = sqlx::query_as::<_, MessageRow>(
366            "SELECT id, session_id, role, content, tool_calls, tool_call_id,
367    prompt_tokens, completion_tokens,
368    cache_read_tokens, cache_creation_tokens, thinking_tokens
369    FROM messages
370    WHERE session_id = ?
371    ORDER BY id ASC",
372        )
373        .bind(session_id)
374        .fetch_all(&self.pool)
375        .await?
376        .into_iter()
377        .map(|r| r.into())
378        .collect();
379        Ok(rows)
380    }
381
382    /// Load recent user messages across all sessions (for the startup banner).
383    /// Returns up to `limit` messages, newest first.
384    async fn recent_user_messages(&self, limit: i64) -> Result<Vec<String>> {
385        let rows: Vec<(String,)> = sqlx::query_as(
386            "SELECT content FROM messages
387    WHERE role = 'user' AND content IS NOT NULL AND content != ''
388    ORDER BY id DESC LIMIT ?",
389        )
390        .bind(limit)
391        .fetch_all(&self.pool)
392        .await?;
393
394        Ok(rows.into_iter().map(|r| r.0).collect())
395    }
396
397    /// Get token usage totals for a session.
398    async fn session_token_usage(&self, session_id: &str) -> Result<SessionUsage> {
399        let row: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(
400            "SELECT
401                COALESCE(SUM(prompt_tokens), 0),
402                COALESCE(SUM(completion_tokens), 0),
403                COALESCE(SUM(cache_read_tokens), 0),
404                COALESCE(SUM(cache_creation_tokens), 0),
405                COALESCE(SUM(thinking_tokens), 0),
406                COUNT(*)
407             FROM messages
408             WHERE session_id = ?
409               AND (prompt_tokens IS NOT NULL OR completion_tokens IS NOT NULL)",
410        )
411        .bind(session_id)
412        .fetch_one(&self.pool)
413        .await?;
414        Ok(SessionUsage {
415            prompt_tokens: row.0,
416            completion_tokens: row.1,
417            cache_read_tokens: row.2,
418            cache_creation_tokens: row.3,
419            thinking_tokens: row.4,
420            api_calls: row.5,
421        })
422    }
423
424    /// Get token usage broken down by agent name.
425    async fn session_usage_by_agent(
426        &self,
427        session_id: &str,
428    ) -> Result<Vec<(String, SessionUsage)>> {
429        let rows: Vec<(String, i64, i64, i64, i64, i64, i64)> = sqlx::query_as(
430            "SELECT
431                COALESCE(agent_name, 'main'),
432                COALESCE(SUM(prompt_tokens), 0),
433                COALESCE(SUM(completion_tokens), 0),
434                COALESCE(SUM(cache_read_tokens), 0),
435                COALESCE(SUM(cache_creation_tokens), 0),
436                COALESCE(SUM(thinking_tokens), 0),
437                COUNT(*)
438             FROM messages
439             WHERE session_id = ?
440               AND (prompt_tokens IS NOT NULL OR completion_tokens IS NOT NULL)
441             GROUP BY COALESCE(agent_name, 'main')
442             ORDER BY COALESCE(SUM(prompt_tokens), 0) + COALESCE(SUM(completion_tokens), 0) DESC",
443        )
444        .bind(session_id)
445        .fetch_all(&self.pool)
446        .await?;
447        Ok(rows
448            .into_iter()
449            .map(|r| {
450                (
451                    r.0,
452                    SessionUsage {
453                        prompt_tokens: r.1,
454                        completion_tokens: r.2,
455                        cache_read_tokens: r.3,
456                        cache_creation_tokens: r.4,
457                        thinking_tokens: r.5,
458                        api_calls: r.6,
459                    },
460                )
461            })
462            .collect())
463    }
464
465    /// List recent sessions for a specific project.
466    async fn list_sessions(&self, limit: i64, project_root: &Path) -> Result<Vec<SessionInfo>> {
467        let root = project_root.to_string_lossy().to_string();
468        let rows: Vec<SessionInfoRow> = sqlx::query_as(
469            "SELECT s.id, s.agent_name, s.created_at,
470                    COUNT(m.id) as message_count,
471                    COALESCE(SUM(m.prompt_tokens), 0) + COALESCE(SUM(m.completion_tokens), 0) as total_tokens
472             FROM sessions s
473             LEFT JOIN messages m ON m.session_id = s.id
474             WHERE s.project_root = ? OR s.project_root IS NULL
475             GROUP BY s.id
476             ORDER BY s.created_at DESC, s.rowid DESC
477             LIMIT ?",
478        )
479        .bind(&root)
480        .bind(limit)
481        .fetch_all(&self.pool)
482        .await?;
483        Ok(rows.into_iter().map(|r| r.into()).collect())
484    }
485
486    /// Get the last assistant text response for a session (for headless JSON output).
487    async fn last_assistant_message(&self, session_id: &str) -> Result<String> {
488        let row: Option<(String,)> = sqlx::query_as(
489            "SELECT content FROM messages
490             WHERE session_id = ? AND role = 'assistant' AND content IS NOT NULL
491             ORDER BY id DESC LIMIT 1",
492        )
493        .bind(session_id)
494        .fetch_optional(&self.pool)
495        .await?;
496        Ok(row.map(|r| r.0).unwrap_or_default())
497    }
498
499    /// Get the last user message in a session.
500    async fn last_user_message(&self, session_id: &str) -> Result<String> {
501        let row: Option<(String,)> = sqlx::query_as(
502            "SELECT content FROM messages
503             WHERE session_id = ? AND role = 'user' AND content IS NOT NULL
504             ORDER BY id DESC LIMIT 1",
505        )
506        .bind(session_id)
507        .fetch_optional(&self.pool)
508        .await?;
509        Ok(row.map(|r| r.0).unwrap_or_default())
510    }
511
512    /// Delete a session and all its messages/metadata atomically.
513    async fn delete_session(&self, session_id: &str) -> Result<bool> {
514        let mut tx = self.pool.begin().await?;
515
516        sqlx::query("DELETE FROM messages WHERE session_id = ?")
517            .bind(session_id)
518            .execute(&mut *tx)
519            .await?;
520
521        sqlx::query("DELETE FROM session_metadata WHERE session_id = ?")
522            .bind(session_id)
523            .execute(&mut *tx)
524            .await?;
525
526        let result = sqlx::query("DELETE FROM sessions WHERE id = ?")
527            .bind(session_id)
528            .execute(&mut *tx)
529            .await?;
530
531        tx.commit().await?;
532
533        // Reclaim freed pages from the deletion.
534        sqlx::query("PRAGMA incremental_vacuum")
535            .execute(&self.pool)
536            .await?;
537
538        Ok(result.rows_affected() > 0)
539    }
540
541    /// Compact a session: summarize old messages while preserving the most recent ones.
542    ///
543    /// Keeps the last `preserve_count` messages intact, deletes the rest, and
544    /// inserts a summary (as a `system` message) plus a continuation hint
545    /// (as an `assistant` message) before the preserved tail.
546    ///
547    /// Returns the number of messages that were deleted/replaced.
548    async fn compact_session(
549        &self,
550        session_id: &str,
551        summary: &str,
552        preserve_count: usize,
553    ) -> Result<usize> {
554        let mut tx = self.pool.begin().await?;
555
556        // Get active (non-compacted) message IDs ordered oldest→newest
557        let all_ids: Vec<(i64,)> = sqlx::query_as(
558            "SELECT id FROM messages WHERE session_id = ? AND compacted_at IS NULL ORDER BY id ASC",
559        )
560        .bind(session_id)
561        .fetch_all(&mut *tx)
562        .await?;
563
564        let total = all_ids.len();
565        if total == 0 {
566            tx.commit().await?;
567            return Ok(0);
568        }
569
570        // Determine which messages to archive (everything except the tail)
571        let keep_from = total.saturating_sub(preserve_count);
572        let ids_to_archive: Vec<i64> = all_ids[..keep_from].iter().map(|r| r.0).collect();
573        let archived_count = ids_to_archive.len();
574
575        if archived_count == 0 {
576            tx.commit().await?;
577            return Ok(0);
578        }
579
580        // Mark old messages as compacted (non-destructive — history preserved in DB)
581        for chunk in ids_to_archive.chunks(500) {
582            let placeholders: String = chunk.iter().map(|_| "?").collect::<Vec<_>>().join(",");
583            let sql = format!(
584                "UPDATE messages SET compacted_at = datetime('now') \
585                 WHERE session_id = ? AND id IN ({placeholders})"
586            );
587            let mut query = sqlx::query(&sql).bind(session_id);
588            for id in chunk {
589                query = query.bind(id);
590            }
591            query.execute(&mut *tx).await?;
592        }
593
594        // Insert the summary as a system message
595        sqlx::query(
596            "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, prompt_tokens, completion_tokens)
597             VALUES (?, 'system', ?, NULL, NULL, NULL, NULL)",
598        )
599        .bind(session_id)
600        .bind(summary)
601        .execute(&mut *tx)
602        .await?;
603
604        // Insert a continuation hint so the LLM knows how to behave
605        let continuation = "Your context was compacted. The previous message contains a summary of our earlier conversation. \
606            Do not mention the summary or that compaction occurred. \
607            Continue the conversation naturally based on the summarized context.";
608        sqlx::query(
609            "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, prompt_tokens, completion_tokens)
610             VALUES (?, 'assistant', ?, NULL, NULL, NULL, NULL)",
611        )
612        .bind(session_id)
613        .bind(continuation)
614        .execute(&mut *tx)
615        .await?;
616
617        tx.commit().await?;
618
619        Ok(archived_count)
620    }
621
622    /// Check if the last message in a session is a tool call awaiting a response.
623    /// Used to defer compaction during active tool execution.
624    async fn has_pending_tool_calls(&self, session_id: &str) -> Result<bool> {
625        // A pending tool call exists when the last message has role='assistant'
626        // with tool_calls set, and there's no subsequent tool response.
627        let last_msg: Option<(String, Option<String>)> = sqlx::query_as(
628            "SELECT role, tool_calls FROM messages
629             WHERE session_id = ? AND compacted_at IS NULL
630             ORDER BY id DESC LIMIT 1",
631        )
632        .bind(session_id)
633        .fetch_optional(&self.pool)
634        .await?;
635
636        Ok(matches!(last_msg, Some((role, Some(_))) if role == "assistant"))
637    }
638
639    /// Stats about compacted (archived) messages across all sessions.
640    async fn compacted_stats(&self) -> Result<CompactedStats> {
641        let row: (i64, i64, i64, Option<String>) = sqlx::query_as(
642            "SELECT
643                 COUNT(*),
644                 COUNT(DISTINCT session_id),
645                 COALESCE(SUM(LENGTH(content) + LENGTH(COALESCE(tool_calls,''))), 0),
646                 MIN(compacted_at)
647             FROM messages
648             WHERE compacted_at IS NOT NULL",
649        )
650        .fetch_one(&self.pool)
651        .await?;
652
653        Ok(CompactedStats {
654            message_count: row.0,
655            session_count: row.1,
656            size_bytes: row.2,
657            oldest: row.3,
658        })
659    }
660
661    /// Permanently delete compacted messages older than `min_age_days`.
662    /// Pass 0 to delete all compacted messages regardless of age.
663    async fn purge_compacted(&self, min_age_days: u32) -> Result<usize> {
664        let result = if min_age_days == 0 {
665            sqlx::query("DELETE FROM messages WHERE compacted_at IS NOT NULL")
666                .execute(&self.pool)
667                .await?
668        } else {
669            sqlx::query(
670                "DELETE FROM messages
671                 WHERE compacted_at IS NOT NULL
672                   AND compacted_at < datetime('now', ?)",
673            )
674            .bind(format!("-{min_age_days} days"))
675            .execute(&self.pool)
676            .await?
677        };
678
679        let deleted = result.rows_affected() as usize;
680
681        // Reclaim disk space.
682        sqlx::query("VACUUM").execute(&self.pool).await?;
683
684        tracing::info!("Purged {deleted} compacted messages (>{min_age_days} days old)");
685        Ok(deleted)
686    }
687
688    /// Get a session metadata value by key.
689    async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>> {
690        let row: Option<(String,)> =
691            sqlx::query_as("SELECT value FROM session_metadata WHERE session_id = ? AND key = ?")
692                .bind(session_id)
693                .bind(key)
694                .fetch_optional(&self.pool)
695                .await?;
696        Ok(row.map(|r| r.0))
697    }
698
699    /// Set a session metadata value (upsert).
700    async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> {
701        sqlx::query(
702            "INSERT INTO session_metadata (session_id, key, value, updated_at)
703             VALUES (?, ?, ?, CURRENT_TIMESTAMP)
704             ON CONFLICT(session_id, key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at",
705        )
706        .bind(session_id)
707        .bind(key)
708        .bind(value)
709        .execute(&self.pool)
710        .await?;
711        Ok(())
712    }
713
714    /// Get the todo list for a session (convenience wrapper).
715    async fn get_todo(&self, session_id: &str) -> Result<Option<String>> {
716        self.get_metadata(session_id, "todo").await
717    }
718
719    /// Set the todo list for a session (convenience wrapper).
720    async fn set_todo(&self, session_id: &str, content: &str) -> Result<()> {
721        self.set_metadata(session_id, "todo", content).await
722    }
723}
724
/// Internal row type for sqlx deserialization of `messages` rows.
///
/// `sqlx::FromRow` maps columns by name, so field names must match the columns
/// selected in `load_context` / `load_all_messages`. Converted into the public
/// `Message` type via `From`.
#[derive(sqlx::FromRow)]
struct MessageRow {
    /// Row ID (`messages.id`); also defines chronological order.
    id: i64,
    /// Owning session (`sessions.id`).
    session_id: String,
    /// Role as stored text (e.g. `user`, `assistant`); parsed into `Role` on conversion.
    role: String,
    /// Message text, if any.
    content: Option<String>,
    /// Assistant tool calls, stored as a JSON array string.
    tool_calls: Option<String>,
    /// For tool-result messages, the ID of the tool call being answered.
    tool_call_id: Option<String>,
    // Per-message token usage; NULL when the message was stored without usage.
    prompt_tokens: Option<i64>,
    completion_tokens: Option<i64>,
    cache_read_tokens: Option<i64>,
    cache_creation_tokens: Option<i64>,
    thinking_tokens: Option<i64>,
}
740
/// Internal row type for `list_sessions`: one session plus aggregate message stats.
///
/// Field names match the selected columns/aliases (`sqlx::FromRow` maps by name).
#[derive(Debug, Clone, sqlx::FromRow)]
struct SessionInfoRow {
    /// Session ID (`sessions.id`).
    id: String,
    /// Agent the session was created for.
    agent_name: String,
    /// `sessions.created_at`, read back as text.
    created_at: String,
    /// Number of messages in the session, compacted ones included.
    message_count: i64,
    /// Prompt + completion tokens summed over the session's messages.
    total_tokens: i64,
}
750
751impl From<SessionInfoRow> for SessionInfo {
752    fn from(r: SessionInfoRow) -> Self {
753        Self {
754            id: r.id,
755            agent_name: r.agent_name,
756            created_at: r.created_at,
757            message_count: r.message_count,
758            total_tokens: r.total_tokens,
759        }
760    }
761}
762
763impl From<MessageRow> for Message {
764    fn from(r: MessageRow) -> Self {
765        Self {
766            id: r.id,
767            session_id: r.session_id,
768            role: r.role.parse().unwrap_or(Role::User),
769            content: r.content,
770            tool_calls: r.tool_calls,
771            tool_call_id: r.tool_call_id,
772            prompt_tokens: r.prompt_tokens,
773            completion_tokens: r.completion_tokens,
774            cache_read_tokens: r.cache_read_tokens,
775            cache_creation_tokens: r.cache_creation_tokens,
776            thinking_tokens: r.thinking_tokens,
777        }
778    }
779}
780
#[cfg(test)]
mod tests {
    // Integration-style tests for the SQLite persistence layer, plus a pure
    // unit test for `prune_mismatched_tool_calls`.
    use super::*;
    use tempfile::TempDir;

    /// Open a fresh database file inside a brand-new temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the database and must be kept
    /// alive for the whole test: dropping it deletes the directory, including
    /// the SQLite file. Tests also pass its path as the directory argument of
    /// `create_session` / `list_sessions`.
    async fn setup() -> (Database, TempDir) {
        let tmp = TempDir::new().unwrap();
        let db_path = tmp.path().join("test.db");
        let db = Database::open(&db_path).await.unwrap();
        (db, tmp)
    }

    #[tokio::test]
    async fn test_create_session() {
        let (db, tmp) = setup().await;
        let id = db.create_session("default", tmp.path()).await.unwrap();
        assert!(!id.is_empty());
    }

    #[tokio::test]
    async fn test_insert_and_load_messages() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        db.insert_message(&session, &Role::User, Some("hello"), None, None, None)
            .await
            .unwrap();
        db.insert_message(
            &session,
            &Role::Assistant,
            Some("hi there!"),
            None,
            None,
            None,
        )
        .await
        .unwrap();

        let msgs = db.load_context(&session).await.unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].role, Role::User);
        assert_eq!(msgs[1].role, Role::Assistant);
    }

    #[tokio::test]
    async fn test_load_context_returns_all_active_messages() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        // Insert many messages
        for i in 0..20 {
            let content = format!("Message number {i}");
            db.insert_message(&session, &Role::User, Some(&content), None, None, None)
                .await
                .unwrap();
        }

        // Load all messages — no sliding window, no truncation
        let msgs = db.load_context(&session).await.unwrap();
        assert_eq!(msgs.len(), 20, "Should load all 20 messages");

        // Messages should be in chronological order
        assert!(msgs[0].content.as_ref().unwrap().contains("number 0"));
        assert!(msgs[19].content.as_ref().unwrap().contains("number 19"));
    }

    #[tokio::test]
    async fn test_sessions_are_isolated() {
        let (db, tmp) = setup().await;
        let s1 = db.create_session("agent-a", tmp.path()).await.unwrap();
        let s2 = db.create_session("agent-b", tmp.path()).await.unwrap();

        db.insert_message(&s1, &Role::User, Some("session 1"), None, None, None)
            .await
            .unwrap();
        db.insert_message(&s2, &Role::User, Some("session 2"), None, None, None)
            .await
            .unwrap();

        let msgs1 = db.load_context(&s1).await.unwrap();
        let msgs2 = db.load_context(&s2).await.unwrap();

        assert_eq!(msgs1.len(), 1);
        assert_eq!(msgs2.len(), 1);
        assert_eq!(msgs1[0].content.as_deref().unwrap(), "session 1");
        assert_eq!(msgs2[0].content.as_deref().unwrap(), "session 2");
    }

    #[tokio::test]
    async fn test_session_token_usage() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        db.insert_message(&session, &Role::User, Some("q1"), None, None, None)
            .await
            .unwrap();
        let usage1 = crate::providers::TokenUsage {
            prompt_tokens: 100,
            completion_tokens: 50,
            ..Default::default()
        };
        db.insert_message(
            &session,
            &Role::Assistant,
            Some("a1"),
            None,
            None,
            Some(&usage1),
        )
        .await
        .unwrap();
        db.insert_message(&session, &Role::User, Some("q2"), None, None, None)
            .await
            .unwrap();
        let usage2 = crate::providers::TokenUsage {
            prompt_tokens: 200,
            completion_tokens: 80,
            ..Default::default()
        };
        db.insert_message(
            &session,
            &Role::Assistant,
            Some("a2"),
            None,
            None,
            Some(&usage2),
        )
        .await
        .unwrap();

        // Totals are summed across both assistant turns; only messages that
        // carried usage count as API calls.
        let u = db.session_token_usage(&session).await.unwrap();
        assert_eq!(u.prompt_tokens, 300);
        assert_eq!(u.completion_tokens, 130);
        assert_eq!(u.api_calls, 2);
    }

    #[tokio::test]
    async fn test_list_sessions() {
        let (db, tmp) = setup().await;
        db.create_session("agent-a", tmp.path()).await.unwrap();
        db.create_session("agent-b", tmp.path()).await.unwrap();
        db.create_session("agent-c", tmp.path()).await.unwrap();

        let sessions = db.list_sessions(10, tmp.path()).await.unwrap();
        assert_eq!(sessions.len(), 3);
        // Most recent first
        assert_eq!(sessions[0].agent_name, "agent-c");
    }

    #[tokio::test]
    async fn test_delete_session() {
        let (db, tmp) = setup().await;
        let s1 = db.create_session("default", tmp.path()).await.unwrap();
        db.insert_message(&s1, &Role::User, Some("hello"), None, None, None)
            .await
            .unwrap();

        assert!(db.delete_session(&s1).await.unwrap());

        let sessions = db.list_sessions(10, tmp.path()).await.unwrap();
        assert!(sessions.is_empty());

        // Deleting again returns false
        assert!(!db.delete_session(&s1).await.unwrap());
    }

    #[tokio::test]
    async fn test_compact_session() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        // Insert several messages, alternating user / assistant
        for i in 0..10 {
            let role = if i % 2 == 0 {
                &Role::User
            } else {
                &Role::Assistant
            };
            db.insert_message(&session, role, Some(&format!("msg {i}")), None, None, None)
                .await
                .unwrap();
        }

        // Compact preserving the last 2 messages
        let deleted = db
            .compact_session(&session, "Summary of conversation", 2)
            .await
            .unwrap();
        assert_eq!(deleted, 8); // 10 total - 2 preserved = 8 deleted

        // Should have: summary(system) + continuation(assistant) + 2 preserved = 4
        let msgs = db.load_context(&session).await.unwrap();
        assert_eq!(msgs.len(), 4);

        // Check that the summary is a system message
        let system_msgs: Vec<_> = msgs.iter().filter(|m| m.role == Role::System).collect();
        assert_eq!(system_msgs.len(), 1);
        assert!(
            system_msgs[0]
                .content
                .as_ref()
                .unwrap()
                .contains("Summary of conversation")
        );

        // Check that there's a continuation hint as assistant
        let assistant_msgs: Vec<_> = msgs.iter().filter(|m| m.role == Role::Assistant).collect();
        assert!(
            assistant_msgs
                .iter()
                .any(|m| m.content.as_deref().unwrap_or("").contains("compacted")),
            "Expected a continuation hint from assistant"
        );

        // The 2 preserved messages must be the *last* two, still in order —
        // counting survivors alone would not catch preserving the wrong ones.
        let preserved: Vec<&str> = msgs
            .iter()
            .filter_map(|m| m.content.as_deref())
            .filter(|c| c.starts_with("msg "))
            .collect();
        assert_eq!(preserved, ["msg 8", "msg 9"]);
    }

    #[tokio::test]
    async fn test_compact_preserves_zero() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        for i in 0..6 {
            let role = if i % 2 == 0 {
                &Role::User
            } else {
                &Role::Assistant
            };
            db.insert_message(&session, role, Some(&format!("msg {i}")), None, None, None)
                .await
                .unwrap();
        }

        // Compact preserving 0 — deletes everything, inserts summary + continuation
        let deleted = db
            .compact_session(&session, "Full summary", 0)
            .await
            .unwrap();
        assert_eq!(deleted, 6);

        let msgs = db.load_context(&session).await.unwrap();
        assert_eq!(msgs.len(), 2); // summary + continuation
        assert_eq!(msgs.iter().filter(|m| m.role == Role::System).count(), 1);
        assert_eq!(msgs.iter().filter(|m| m.role == Role::Assistant).count(), 1);
    }

    #[tokio::test]
    async fn test_has_pending_tool_calls() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        // No messages → no pending
        assert!(!db.has_pending_tool_calls(&session).await.unwrap());

        // User message → no pending
        db.insert_message(&session, &Role::User, Some("hello"), None, None, None)
            .await
            .unwrap();
        assert!(!db.has_pending_tool_calls(&session).await.unwrap());

        // Assistant with tool_calls → pending!
        db.insert_message(
            &session,
            &Role::Assistant,
            None,
            Some(r#"[{"id":"tc1","name":"Read","arguments":"{}"}]"#),
            None,
            None,
        )
        .await
        .unwrap();
        assert!(db.has_pending_tool_calls(&session).await.unwrap());

        // Tool response → no longer pending
        db.insert_message(
            &session,
            &Role::Tool,
            Some("file contents"),
            None,
            Some("tc1"),
            None,
        )
        .await
        .unwrap();
        assert!(!db.has_pending_tool_calls(&session).await.unwrap());
    }

    #[tokio::test]
    async fn test_prune_mismatched_tool_calls() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        // Normal turn: user → assistant with tool_calls → tool result
        db.insert_message(&session, &Role::User, Some("hello"), None, None, None)
            .await
            .unwrap();
        db.insert_message(
            &session,
            &Role::Assistant,
            Some("Let me read that."),
            Some(r#"[{"id":"tc1","name":"Read","arguments":"{}"}]"#),
            None,
            None,
        )
        .await
        .unwrap();
        db.insert_message(
            &session,
            &Role::Tool,
            Some("file contents"),
            None,
            Some("tc1"),
            None,
        )
        .await
        .unwrap();

        // Interrupted turn: assistant with tool_calls but NO tool result
        db.insert_message(
            &session,
            &Role::Assistant,
            Some("I'll edit the file."),
            Some(r#"[{"id":"tc2","name":"Edit","arguments":"{}"}]"#),
            None,
            None,
        )
        .await
        .unwrap();

        let msgs = db.load_context(&session).await.unwrap();

        // The first assistant's tool_calls should be preserved (has tool result)
        let first_asst = msgs
            .iter()
            .find(|m| m.content.as_deref() == Some("Let me read that."))
            .unwrap();
        assert!(
            first_asst.tool_calls.is_some(),
            "completed tool_calls should be preserved"
        );

        // The orphaned assistant (tool_calls with no result) should be dropped entirely
        let orphaned = msgs
            .iter()
            .find(|m| m.content.as_deref() == Some("I'll edit the file."));
        assert!(
            orphaned.is_none(),
            "orphaned assistant message should be dropped by prune_mismatched_tool_calls"
        );
    }

    #[test]
    fn test_prune_mismatched_tool_calls_unit() {
        /// Build an in-memory `Message` with zeroed ids and no token counts;
        /// only the role, content and tool-call fields vary between cases.
        ///
        /// Takes a `Role` directly rather than parsing a string, so a typo in
        /// a fixture cannot silently turn into a different role.
        fn msg(
            role: Role,
            content: Option<&str>,
            tool_calls: Option<&str>,
            tool_call_id: Option<&str>,
        ) -> Message {
            Message {
                id: 0,
                session_id: String::new(),
                role,
                content: content.map(Into::into),
                tool_calls: tool_calls.map(Into::into),
                tool_call_id: tool_call_id.map(Into::into),
                prompt_tokens: None,
                completion_tokens: None,
                cache_read_tokens: None,
                cache_creation_tokens: None,
                thinking_tokens: None,
            }
        }

        // No messages — no crash
        let mut empty: Vec<Message> = vec![];
        prune_mismatched_tool_calls(&mut empty);
        assert!(empty.is_empty());

        // User message only — no change
        let mut msgs = vec![msg(Role::User, Some("hi"), None, None)];
        prune_mismatched_tool_calls(&mut msgs);
        assert_eq!(msgs.len(), 1);

        // Orphaned assistant with tool_calls, no result — dropped
        let mut msgs = vec![
            msg(Role::User, Some("hi"), None, None),
            msg(
                Role::Assistant,
                Some("doing it"),
                Some(r#"[{"id":"t1"}]"#),
                None,
            ),
        ];
        prune_mismatched_tool_calls(&mut msgs);
        assert_eq!(msgs.len(), 1, "orphaned assistant should be dropped");
        assert_eq!(msgs[0].role, Role::User);

        // Complete pair — preserved
        let mut msgs = vec![
            msg(Role::User, Some("hi"), None, None),
            msg(Role::Assistant, None, Some(r#"[{"id":"t1"}]"#), None),
            msg(Role::Tool, Some("ok"), None, Some("t1")),
        ];
        prune_mismatched_tool_calls(&mut msgs);
        assert_eq!(msgs.len(), 3, "complete pair should be preserved");
        assert!(msgs[1].tool_calls.is_some());
    }

    #[tokio::test]
    async fn test_session_metadata_and_todo() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        // No metadata initially
        assert!(db.get_todo(&session).await.unwrap().is_none());
        assert!(
            db.get_metadata(&session, "anything")
                .await
                .unwrap()
                .is_none()
        );

        // Set and get todo
        db.set_todo(&session, "- [ ] Task 1\n- [x] Task 2")
            .await
            .unwrap();
        let todo = db.get_todo(&session).await.unwrap().unwrap();
        assert!(todo.contains("Task 1"));
        assert!(todo.contains("Task 2"));

        // Update (upsert) replaces the value
        db.set_todo(&session, "- [x] Task 1\n- [x] Task 2")
            .await
            .unwrap();
        let todo = db.get_todo(&session).await.unwrap().unwrap();
        assert!(todo.starts_with("- [x] Task 1"));

        // Generic metadata works too
        db.set_metadata(&session, "custom_key", "custom_value")
            .await
            .unwrap();
        assert_eq!(
            db.get_metadata(&session, "custom_key")
                .await
                .unwrap()
                .unwrap(),
            "custom_value"
        );
    }

    #[tokio::test]
    async fn test_token_usage_empty_session() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        let u = db.session_token_usage(&session).await.unwrap();
        assert_eq!(u.prompt_tokens, 0);
        assert_eq!(u.completion_tokens, 0);
        assert_eq!(u.api_calls, 0);
    }

    #[tokio::test]
    async fn test_last_assistant_message() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        // Empty session returns empty string
        let msg = db.last_assistant_message(&session).await.unwrap();
        assert_eq!(msg, "");

        // Insert some messages
        db.insert_message(&session, &Role::User, Some("question 1"), None, None, None)
            .await
            .unwrap();
        db.insert_message(
            &session,
            &Role::Assistant,
            Some("answer 1"),
            None,
            None,
            None,
        )
        .await
        .unwrap();
        db.insert_message(&session, &Role::User, Some("question 2"), None, None, None)
            .await
            .unwrap();
        db.insert_message(
            &session,
            &Role::Assistant,
            Some("answer 2"),
            None,
            None,
            None,
        )
        .await
        .unwrap();

        // Should return the LAST assistant message
        let msg = db.last_assistant_message(&session).await.unwrap();
        assert_eq!(msg, "answer 2");
    }

    #[tokio::test]
    async fn test_last_assistant_message_skips_tool_calls() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        db.insert_message(
            &session,
            &Role::User,
            Some("do something"),
            None,
            None,
            None,
        )
        .await
        .unwrap();
        // Assistant with tool calls but no text content
        db.insert_message(
            &session,
            &Role::Assistant,
            None,
            Some("[{\"id\":\"1\"}]"),
            None,
            None,
        )
        .await
        .unwrap();
        db.insert_message(
            &session,
            &Role::Tool,
            Some("tool result"),
            None,
            Some("1"),
            None,
        )
        .await
        .unwrap();
        // Final text response
        db.insert_message(&session, &Role::Assistant, Some("Done!"), None, None, None)
            .await
            .unwrap();

        let msg = db.last_assistant_message(&session).await.unwrap();
        assert_eq!(msg, "Done!");
    }
}