use anyhow::{Context, Result};
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use std::path::Path;
use std::str::FromStr;
pub use crate::persistence::{
CompactedStats, Message, Persistence, Role, SessionInfo, SessionUsage,
};
/// SQLite-backed store for sessions, messages and per-session metadata.
///
/// All queries issued through this type go through the wrapped connection
/// pool.
#[derive(Debug, Clone)]
pub struct Database {
    /// Connection pool shared by every operation on this handle.
    pool: SqlitePool,
}
/// Returns the directory in which koda keeps its configuration.
///
/// Resolution order:
/// 1. `$XDG_CONFIG_HOME/koda` when `XDG_CONFIG_HOME` is set and non-empty;
/// 2. `$HOME/.config/koda` when `HOME` is set and non-empty.
///
/// Values are read with [`std::env::var_os`], so paths that are not valid
/// UTF-8 are accepted as well.
///
/// # Errors
///
/// Returns an error when neither variable provides a usable value.
pub fn config_dir() -> Result<std::path::PathBuf> {
    use std::path::PathBuf;

    // An empty variable is handled like an unset one: joining onto an empty
    // path would silently produce a directory relative to the current
    // working directory.
    let non_empty = |name: &str| std::env::var_os(name).filter(|value| !value.is_empty());

    let base = non_empty("XDG_CONFIG_HOME")
        .map(PathBuf::from)
        .or_else(|| non_empty("HOME").map(|home| PathBuf::from(home).join(".config")))
        .ok_or_else(|| {
            anyhow::anyhow!("Cannot determine config directory (set HOME or XDG_CONFIG_HOME)")
        })?;
    Ok(base.join("koda"))
}
impl Database {
pub async fn init(koda_config_dir: &Path) -> Result<Self> {
let db_dir = koda_config_dir.join("db");
std::fs::create_dir_all(&db_dir)
.with_context(|| format!("Failed to create DB dir: {}", db_dir.display()))?;
let db_path = db_dir.join("koda.db");
Self::open(&db_path).await
}
/// Connects to the SQLite file at `db_path` and runs the schema migration.
///
/// The connection is configured with WAL journaling, incremental
/// auto-vacuum and foreign-key enforcement, and the file is created when it
/// does not exist. The pool holds at most five connections.
///
/// # Errors
///
/// Fails when the connection URL cannot be parsed, the pool cannot connect,
/// or the migration fails.
pub async fn open(db_path: &Path) -> Result<Self> {
    use sqlx::sqlite::{SqliteAutoVacuum, SqliteJournalMode};

    // NOTE(review): the path is interpolated into the URL without escaping,
    // so a path containing characters such as `?` may not parse as intended.
    let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
    let connect_options = SqliteConnectOptions::from_str(&db_url)?
        .journal_mode(SqliteJournalMode::Wal)
        .auto_vacuum(SqliteAutoVacuum::Incremental)
        .foreign_keys(true)
        .create_if_missing(true);

    let pool_builder = SqlitePoolOptions::new().max_connections(5);
    let pool = pool_builder
        .connect_with(connect_options)
        .await
        .with_context(|| format!("Failed to connect to database: {db_url}"))?;

    Self::migrate(&pool).await?;
    Ok(Self { pool })
}
async fn migrate(pool: &SqlitePool) -> Result<()> {
sqlx::query(
"CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
agent_name TEXT NOT NULL
);",
)
.execute(pool)
.await?;
sqlx::query(
"CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT,
tool_calls TEXT,
tool_call_id TEXT,
prompt_tokens INTEGER,
completion_tokens INTEGER,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(session_id) REFERENCES sessions(id)
);",
)
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages(session_id);")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_role_id ON messages(role, id DESC);")
.execute(pool)
.await?;
for col in &[
"cache_read_tokens",
"cache_creation_tokens",
"thinking_tokens",
] {
let sql = format!("ALTER TABLE messages ADD COLUMN {col} INTEGER");
if let Err(e) = sqlx::query(&sql).execute(pool).await {
let msg = e.to_string();
if !msg.contains("duplicate column name") {
return Err(e.into());
}
}
}
for (col, col_type) in &[("agent_name", "TEXT")] {
let sql = format!("ALTER TABLE messages ADD COLUMN {col} {col_type}");
if let Err(e) = sqlx::query(&sql).execute(pool).await {
let msg = e.to_string();
if !msg.contains("duplicate column name") {
return Err(e.into());
}
}
}
sqlx::query(
"CREATE TABLE IF NOT EXISTS session_metadata (
session_id TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY(session_id, key),
FOREIGN KEY(session_id) REFERENCES sessions(id)
);",
)
.execute(pool)
.await?;
let sql = "ALTER TABLE sessions ADD COLUMN project_root TEXT";
if let Err(e) = sqlx::query(sql).execute(pool).await {
let msg = e.to_string();
if !msg.contains("duplicate column name") {
return Err(e.into());
}
}
let sql = "ALTER TABLE messages ADD COLUMN compacted_at TEXT";
if let Err(e) = sqlx::query(sql).execute(pool).await {
let msg = e.to_string();
if !msg.contains("duplicate column name") {
return Err(e.into());
}
}
let sql = "ALTER TABLE sessions ADD COLUMN last_accessed_at TEXT";
if let Err(e) = sqlx::query(sql).execute(pool).await {
let msg = e.to_string();
if !msg.contains("duplicate column name") {
return Err(e.into());
}
}
Ok(())
}
}
fn prune_mismatched_tool_calls(messages: &mut Vec<Message>) {
if messages.is_empty() {
return;
}
let mut tool_call_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
let mut tool_return_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
for msg in messages.iter() {
if msg.role == Role::Assistant {
if let Some(ref tc_json) = msg.tool_calls
&& let Ok(calls) = serde_json::from_str::<Vec<serde_json::Value>>(tc_json)
{
for call in &calls {
if let Some(id) = call.get("id").and_then(|v| v.as_str()) {
tool_call_ids.insert(id.to_string());
}
}
}
} else if msg.role == Role::Tool
&& let Some(ref id) = msg.tool_call_id
{
tool_return_ids.insert(id.clone());
}
}
let mismatched: std::collections::HashSet<&String> = tool_call_ids
.symmetric_difference(&tool_return_ids)
.collect();
if mismatched.is_empty() {
return;
}
messages.retain(|msg| {
if msg.role == Role::Tool
&& let Some(ref id) = msg.tool_call_id
&& mismatched.contains(id)
{
return false;
}
if msg.role == Role::Assistant
&& let Some(ref tc_json) = msg.tool_calls
&& let Ok(calls) = serde_json::from_str::<Vec<serde_json::Value>>(tc_json)
{
let has_mismatched = calls.iter().any(|call| {
call.get("id")
.and_then(|v| v.as_str())
.is_some_and(|id| mismatched.contains(&id.to_string()))
});
if has_mismatched {
return false;
}
}
true
});
}
#[async_trait::async_trait]
impl Persistence for Database {
/// Inserts a new session row and returns its id.
///
/// The id is a freshly generated version-4 UUID; `project_root` is stored as
/// a lossily converted string.
async fn create_session(&self, agent_name: &str, project_root: &Path) -> Result<String> {
    let id = uuid::Uuid::new_v4().to_string();
    let root = project_root.to_string_lossy().into_owned();

    let insert =
        sqlx::query("INSERT INTO sessions (id, agent_name, project_root) VALUES (?, ?, ?)")
            .bind(&id)
            .bind(agent_name)
            .bind(&root);
    insert.execute(&self.pool).await?;

    tracing::info!("Created session: {id} (project: {root})");
    Ok(id)
}
/// Stores a message that is not attributed to a named agent.
///
/// Forwards to `insert_message_with_agent` with no agent name and returns
/// the row id it reports.
async fn insert_message(
    &self,
    session_id: &str,
    role: &Role,
    content: Option<&str>,
    tool_calls: Option<&str>,
    tool_call_id: Option<&str>,
    usage: Option<&crate::providers::TokenUsage>,
) -> Result<i64> {
    let no_agent: Option<&str> = None;
    let inserted = self.insert_message_with_agent(
        session_id,
        role,
        content,
        tool_calls,
        tool_call_id,
        usage,
        no_agent,
    );
    inserted.await
}
/// Stores a message for `session_id` and refreshes the session's
/// `last_accessed_at` timestamp.
///
/// Token counts are written only when `usage` is provided; otherwise the
/// corresponding columns are NULL. `agent_name` is stored as given.
///
/// Both statements run in a single transaction: if refreshing the session
/// fails, the message is not persisted either, so an `Err` always means
/// "nothing was stored" and a retry cannot duplicate the message.
///
/// Returns the row id of the inserted message.
// The argument list is dictated by the `Persistence` trait.
#[allow(clippy::too_many_arguments)]
async fn insert_message_with_agent(
    &self,
    session_id: &str,
    role: &Role,
    content: Option<&str>,
    tool_calls: Option<&str>,
    tool_call_id: Option<&str>,
    usage: Option<&crate::providers::TokenUsage>,
    agent_name: Option<&str>,
) -> Result<i64> {
    let mut tx = self.pool.begin().await?;

    let result = sqlx::query(
        "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, \
         prompt_tokens, completion_tokens, cache_read_tokens, cache_creation_tokens, \
         thinking_tokens, agent_name)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
    )
    .bind(session_id)
    .bind(role.as_str())
    .bind(content)
    .bind(tool_calls)
    .bind(tool_call_id)
    .bind(usage.map(|u| u.prompt_tokens))
    .bind(usage.map(|u| u.completion_tokens))
    .bind(usage.map(|u| u.cache_read_tokens))
    .bind(usage.map(|u| u.cache_creation_tokens))
    .bind(usage.map(|u| u.thinking_tokens))
    .bind(agent_name)
    .execute(&mut *tx)
    .await?;

    sqlx::query("UPDATE sessions SET last_accessed_at = datetime('now') WHERE id = ?")
        .bind(session_id)
        .execute(&mut *tx)
        .await?;

    tx.commit().await?;
    Ok(result.last_insert_rowid())
}
async fn load_context(&self, session_id: &str) -> Result<Vec<Message>> {
let mut messages: Vec<Message> = sqlx::query_as::<_, MessageRow>(
"SELECT id, session_id, role, content, tool_calls, tool_call_id,
prompt_tokens, completion_tokens,
cache_read_tokens, cache_creation_tokens, thinking_tokens
FROM messages
WHERE session_id = ? AND compacted_at IS NULL
ORDER BY id ASC",
)
.bind(session_id)
.fetch_all(&self.pool)
.await?
.into_iter()
.map(|r| r.into())
.collect();
prune_mismatched_tool_calls(&mut messages);
Ok(messages)
}
async fn load_all_messages(&self, session_id: &str) -> Result<Vec<Message>> {
let rows: Vec<Message> = sqlx::query_as::<_, MessageRow>(
"SELECT id, session_id, role, content, tool_calls, tool_call_id,
prompt_tokens, completion_tokens,
cache_read_tokens, cache_creation_tokens, thinking_tokens
FROM messages
WHERE session_id = ?
ORDER BY id ASC",
)
.bind(session_id)
.fetch_all(&self.pool)
.await?
.into_iter()
.map(|r| r.into())
.collect();
Ok(rows)
}
/// Returns the content of the most recent non-empty user messages, newest
/// first, across all sessions.
///
/// At most `limit` entries are returned.
async fn recent_user_messages(&self, limit: i64) -> Result<Vec<String>> {
    let query = sqlx::query_as(
        "SELECT content FROM messages
WHERE role = 'user' AND content IS NOT NULL AND content != ''
ORDER BY id DESC LIMIT ?",
    )
    .bind(limit);

    let rows: Vec<(String,)> = query.fetch_all(&self.pool).await?;
    Ok(rows.into_iter().map(|(content,)| content).collect())
}
/// Sums the token counters of a session.
///
/// Only messages that recorded a prompt or completion token count are
/// considered; their number is reported as `api_calls`. All sums are zero
/// when no such message exists.
async fn session_token_usage(&self, session_id: &str) -> Result<SessionUsage> {
    let totals: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(
        "SELECT
COALESCE(SUM(prompt_tokens), 0),
COALESCE(SUM(completion_tokens), 0),
COALESCE(SUM(cache_read_tokens), 0),
COALESCE(SUM(cache_creation_tokens), 0),
COALESCE(SUM(thinking_tokens), 0),
COUNT(*)
FROM messages
WHERE session_id = ?
AND (prompt_tokens IS NOT NULL OR completion_tokens IS NOT NULL)",
    )
    .bind(session_id)
    .fetch_one(&self.pool)
    .await?;

    let (
        prompt_tokens,
        completion_tokens,
        cache_read_tokens,
        cache_creation_tokens,
        thinking_tokens,
        api_calls,
    ) = totals;

    Ok(SessionUsage {
        prompt_tokens,
        completion_tokens,
        cache_read_tokens,
        cache_creation_tokens,
        thinking_tokens,
        api_calls,
    })
}
/// Breaks a session's token usage down by agent.
///
/// Messages without an agent name are grouped under `main`. Only messages
/// that recorded a prompt or completion token count are considered. Entries
/// are ordered by prompt plus completion tokens, largest first.
async fn session_usage_by_agent(
    &self,
    session_id: &str,
) -> Result<Vec<(String, SessionUsage)>> {
    // Agent name, the five token sums, then the number of counted messages.
    type AgentUsageRow = (String, i64, i64, i64, i64, i64, i64);

    let rows: Vec<AgentUsageRow> = sqlx::query_as(
        "SELECT
COALESCE(agent_name, 'main'),
COALESCE(SUM(prompt_tokens), 0),
COALESCE(SUM(completion_tokens), 0),
COALESCE(SUM(cache_read_tokens), 0),
COALESCE(SUM(cache_creation_tokens), 0),
COALESCE(SUM(thinking_tokens), 0),
COUNT(*)
FROM messages
WHERE session_id = ?
AND (prompt_tokens IS NOT NULL OR completion_tokens IS NOT NULL)
GROUP BY COALESCE(agent_name, 'main')
ORDER BY COALESCE(SUM(prompt_tokens), 0) + COALESCE(SUM(completion_tokens), 0) DESC",
    )
    .bind(session_id)
    .fetch_all(&self.pool)
    .await?;

    let mut per_agent = Vec::with_capacity(rows.len());
    for (
        agent,
        prompt_tokens,
        completion_tokens,
        cache_read_tokens,
        cache_creation_tokens,
        thinking_tokens,
        api_calls,
    ) in rows
    {
        let usage = SessionUsage {
            prompt_tokens,
            completion_tokens,
            cache_read_tokens,
            cache_creation_tokens,
            thinking_tokens,
            api_calls,
        };
        per_agent.push((agent, usage));
    }
    Ok(per_agent)
}
async fn list_sessions(&self, limit: i64, project_root: &Path) -> Result<Vec<SessionInfo>> {
let root = project_root.to_string_lossy().to_string();
let rows: Vec<SessionInfoRow> = sqlx::query_as(
"SELECT s.id, s.agent_name, s.created_at,
COUNT(m.id) as message_count,
COALESCE(SUM(m.prompt_tokens), 0) + COALESCE(SUM(m.completion_tokens), 0) as total_tokens
FROM sessions s
LEFT JOIN messages m ON m.session_id = s.id
WHERE s.project_root = ? OR s.project_root IS NULL
GROUP BY s.id
ORDER BY s.created_at DESC, s.rowid DESC
LIMIT ?",
)
.bind(&root)
.bind(limit)
.fetch_all(&self.pool)
.await?;
Ok(rows.into_iter().map(|r| r.into()).collect())
}
/// Returns the content of the newest assistant message that has content.
///
/// Yields an empty string when the session has no such message.
async fn last_assistant_message(&self, session_id: &str) -> Result<String> {
    let newest: Option<(String,)> = sqlx::query_as(
        "SELECT content FROM messages
WHERE session_id = ? AND role = 'assistant' AND content IS NOT NULL
ORDER BY id DESC LIMIT 1",
    )
    .bind(session_id)
    .fetch_optional(&self.pool)
    .await?;

    Ok(match newest {
        Some((content,)) => content,
        None => String::new(),
    })
}
/// Returns the content of the newest user message that has content.
///
/// Yields an empty string when the session has no such message.
async fn last_user_message(&self, session_id: &str) -> Result<String> {
    let newest: Option<(String,)> = sqlx::query_as(
        "SELECT content FROM messages
WHERE session_id = ? AND role = 'user' AND content IS NOT NULL
ORDER BY id DESC LIMIT 1",
    )
    .bind(session_id)
    .fetch_optional(&self.pool)
    .await?;

    Ok(match newest {
        Some((content,)) => content,
        None => String::new(),
    })
}
/// Deletes a session together with its messages and metadata.
///
/// The three deletes run in one transaction; afterwards an incremental
/// vacuum is requested. Returns `true` when a session row was removed and
/// `false` when no session with that id existed.
async fn delete_session(&self, session_id: &str) -> Result<bool> {
    let mut tx = self.pool.begin().await?;

    // Dependent rows go first: both tables reference `sessions(id)`.
    for sql in [
        "DELETE FROM messages WHERE session_id = ?",
        "DELETE FROM session_metadata WHERE session_id = ?",
    ] {
        sqlx::query(sql).bind(session_id).execute(&mut *tx).await?;
    }

    let removed_sessions = sqlx::query("DELETE FROM sessions WHERE id = ?")
        .bind(session_id)
        .execute(&mut *tx)
        .await?
        .rows_affected();
    tx.commit().await?;

    // NOTE(review): a failure here surfaces as an error even though the
    // deletion above is already committed.
    sqlx::query("PRAGMA incremental_vacuum")
        .execute(&self.pool)
        .await?;

    Ok(removed_sessions > 0)
}
/// Archives the older part of a session and replaces it with a summary.
///
/// All active messages except the newest `preserve_count` are marked as
/// compacted. Then two messages are appended: a `system` message holding
/// `summary` and an `assistant` message with a fixed continuation hint.
/// Everything happens in one transaction.
///
/// Returns the number of archived messages; when that is zero nothing is
/// inserted.
async fn compact_session(
    &self,
    session_id: &str,
    summary: &str,
    preserve_count: usize,
) -> Result<usize> {
    let mut tx = self.pool.begin().await?;

    let active: Vec<(i64,)> = sqlx::query_as(
        "SELECT id FROM messages WHERE session_id = ? AND compacted_at IS NULL ORDER BY id ASC",
    )
    .bind(session_id)
    .fetch_all(&mut *tx)
    .await?;

    // The oldest `archived_count` entries are archived; this is zero both
    // for an empty session and when everything has to be preserved.
    let archived_count = active.len().saturating_sub(preserve_count);
    if archived_count == 0 {
        tx.commit().await?;
        return Ok(0);
    }
    let ids_to_archive: Vec<i64> = active
        .iter()
        .take(archived_count)
        .map(|&(id,)| id)
        .collect();

    // Update in batches so a statement never carries more than 501
    // bound parameters.
    for chunk in ids_to_archive.chunks(500) {
        let placeholders = vec!["?"; chunk.len()].join(",");
        let sql = format!(
            "UPDATE messages SET compacted_at = datetime('now') \
             WHERE session_id = ? AND id IN ({placeholders})"
        );
        let mut update = sqlx::query(&sql).bind(session_id);
        for id in chunk {
            update = update.bind(id);
        }
        update.execute(&mut *tx).await?;
    }

    sqlx::query(
        "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, prompt_tokens, completion_tokens)
VALUES (?, 'system', ?, NULL, NULL, NULL, NULL)",
    )
    .bind(session_id)
    .bind(summary)
    .execute(&mut *tx)
    .await?;

    let continuation = "Your context was compacted. The previous message contains a summary of our earlier conversation. \
                        Do not mention the summary or that compaction occurred. \
                        Continue the conversation naturally based on the summarized context.";
    sqlx::query(
        "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, prompt_tokens, completion_tokens)
VALUES (?, 'assistant', ?, NULL, NULL, NULL, NULL)",
    )
    .bind(session_id)
    .bind(continuation)
    .execute(&mut *tx)
    .await?;

    tx.commit().await?;
    Ok(archived_count)
}
/// Reports whether the newest active message is an assistant message that
/// carries tool calls.
///
/// Returns `false` for a session without active messages.
async fn has_pending_tool_calls(&self, session_id: &str) -> Result<bool> {
    let newest: Option<(String, Option<String>)> = sqlx::query_as(
        "SELECT role, tool_calls FROM messages
WHERE session_id = ? AND compacted_at IS NULL
ORDER BY id DESC LIMIT 1",
    )
    .bind(session_id)
    .fetch_optional(&self.pool)
    .await?;

    let pending = match newest {
        Some((role, Some(_))) => role == "assistant",
        _ => false,
    };
    Ok(pending)
}
/// Summarises the compacted (archived) messages across all sessions.
///
/// Reports how many messages are compacted, how many sessions they belong
/// to, the combined length of their `content` and `tool_calls`, and the
/// earliest `compacted_at` value (`None` when nothing is compacted).
///
/// Both columns are nullable. Each is coalesced to an empty string before
/// measuring, because `LENGTH(NULL)` is NULL and would otherwise make the
/// whole row vanish from the sum, losing the `tool_calls` size of messages
/// that have no content.
async fn compacted_stats(&self) -> Result<CompactedStats> {
    let (message_count, session_count, size_bytes, oldest): (i64, i64, i64, Option<String>) =
        sqlx::query_as(
            "SELECT
COUNT(*),
COUNT(DISTINCT session_id),
COALESCE(SUM(LENGTH(COALESCE(content,'')) + LENGTH(COALESCE(tool_calls,''))), 0),
MIN(compacted_at)
FROM messages
WHERE compacted_at IS NOT NULL",
        )
        .fetch_one(&self.pool)
        .await?;

    Ok(CompactedStats {
        message_count,
        session_count,
        size_bytes,
        oldest,
    })
}
/// Permanently deletes compacted messages and vacuums the database.
///
/// With `min_age_days == 0` every compacted message is removed; otherwise
/// only those compacted more than `min_age_days` days ago.
///
/// Returns the number of deleted messages.
async fn purge_compacted(&self, min_age_days: u32) -> Result<usize> {
    let delete = if min_age_days == 0 {
        sqlx::query("DELETE FROM messages WHERE compacted_at IS NOT NULL")
    } else {
        sqlx::query(
            "DELETE FROM messages
WHERE compacted_at IS NOT NULL
AND compacted_at < datetime('now', ?)",
        )
        .bind(format!("-{min_age_days} days"))
    };
    let result = delete.execute(&self.pool).await?;
    let deleted = result.rows_affected() as usize;

    sqlx::query("VACUUM").execute(&self.pool).await?;

    tracing::info!("Purged {deleted} compacted messages (>{min_age_days} days old)");
    Ok(deleted)
}
/// Looks up the metadata value stored for `key` in a session.
///
/// Returns `None` when the session has no entry for that key.
async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>> {
    let lookup =
        sqlx::query_as("SELECT value FROM session_metadata WHERE session_id = ? AND key = ?")
            .bind(session_id)
            .bind(key);

    let found: Option<(String,)> = lookup.fetch_optional(&self.pool).await?;
    Ok(found.map(|(value,)| value))
}
/// Stores `value` under `key` for a session, replacing any previous value.
///
/// `updated_at` is set to the current timestamp on both insert and update.
async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> {
    let upsert = sqlx::query(
        "INSERT INTO session_metadata (session_id, key, value, updated_at)
VALUES (?, ?, ?, CURRENT_TIMESTAMP)
ON CONFLICT(session_id, key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at",
    )
    .bind(session_id)
    .bind(key)
    .bind(value);

    upsert.execute(&self.pool).await?;
    Ok(())
}
/// Returns the session's todo list, stored as metadata under the `todo` key.
async fn get_todo(&self, session_id: &str) -> Result<Option<String>> {
    const TODO_KEY: &str = "todo";
    self.get_metadata(session_id, TODO_KEY).await
}
/// Replaces the session's todo list, stored as metadata under the `todo` key.
async fn set_todo(&self, session_id: &str, content: &str) -> Result<()> {
    const TODO_KEY: &str = "todo";
    self.set_metadata(session_id, TODO_KEY, content).await
}
}
/// Raw `messages` row as selected by the message-loading queries.
///
/// Converted into [`Message`] via `From`; the role is still an unparsed
/// string at this stage.
#[derive(sqlx::FromRow)]
struct MessageRow {
    /// Auto-incremented primary key.
    id: i64,
    /// Id of the session the message belongs to.
    session_id: String,
    /// Role as stored in the database.
    role: String,
    /// Message text, if any.
    content: Option<String>,
    /// Tool calls requested by the message, as stored text.
    tool_calls: Option<String>,
    /// Id of the tool call this message answers.
    tool_call_id: Option<String>,
    // Token counters; NULL when no usage was recorded for the message.
    prompt_tokens: Option<i64>,
    completion_tokens: Option<i64>,
    cache_read_tokens: Option<i64>,
    cache_creation_tokens: Option<i64>,
    thinking_tokens: Option<i64>,
}
/// Raw result row of the session listing query.
///
/// Converted into [`SessionInfo`] via `From`.
#[derive(Debug, Clone, sqlx::FromRow)]
struct SessionInfoRow {
    /// Session id.
    id: String,
    /// Name of the agent the session was created for.
    agent_name: String,
    /// Creation timestamp as returned by the database.
    created_at: String,
    /// Number of messages joined to the session.
    message_count: i64,
    /// Sum of prompt and completion tokens over the session's messages.
    total_tokens: i64,
}
impl From<SessionInfoRow> for SessionInfo {
fn from(r: SessionInfoRow) -> Self {
Self {
id: r.id,
agent_name: r.agent_name,
created_at: r.created_at,
message_count: r.message_count,
total_tokens: r.total_tokens,
}
}
}
impl From<MessageRow> for Message {
fn from(r: MessageRow) -> Self {
Self {
id: r.id,
session_id: r.session_id,
role: r.role.parse().unwrap_or(Role::User),
content: r.content,
tool_calls: r.tool_calls,
tool_call_id: r.tool_call_id,
prompt_tokens: r.prompt_tokens,
completion_tokens: r.completion_tokens,
cache_read_tokens: r.cache_read_tokens,
cache_creation_tokens: r.cache_creation_tokens,
thinking_tokens: r.thinking_tokens,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
/// Opens a fresh database inside a new temporary directory.
///
/// The `TempDir` is returned alongside the database so the directory stays
/// alive for the duration of the test.
async fn setup() -> (Database, TempDir) {
    let tmp = TempDir::new().unwrap();
    let db = Database::open(&tmp.path().join("test.db"))
        .await
        .unwrap();
    (db, tmp)
}
/// Creating a session yields a non-empty id.
#[tokio::test]
async fn test_create_session() {
    let (db, tmp) = setup().await;
    let session_id = db.create_session("default", tmp.path()).await.unwrap();
    assert!(!session_id.is_empty());
}
/// Inserted messages come back from `load_context` in insertion order.
#[tokio::test]
async fn test_insert_and_load_messages() {
    let (db, tmp) = setup().await;
    let session = db.create_session("default", tmp.path()).await.unwrap();

    for (role, text) in [(Role::User, "hello"), (Role::Assistant, "hi there!")] {
        db.insert_message(&session, &role, Some(text), None, None, None)
            .await
            .unwrap();
    }

    let loaded = db.load_context(&session).await.unwrap();
    assert_eq!(loaded.len(), 2);
    assert_eq!(loaded[0].role, Role::User);
    assert_eq!(loaded[1].role, Role::Assistant);
}
/// `load_context` returns every active message, not just a recent window.
#[tokio::test]
async fn test_load_context_returns_all_active_messages() {
    let (db, tmp) = setup().await;
    let session = db.create_session("default", tmp.path()).await.unwrap();

    for n in 0..20 {
        let text = format!("Message number {n}");
        db.insert_message(&session, &Role::User, Some(&text), None, None, None)
            .await
            .unwrap();
    }

    let loaded = db.load_context(&session).await.unwrap();
    assert_eq!(loaded.len(), 20, "Should load all 20 messages");

    let first = loaded[0].content.as_ref().unwrap();
    let last = loaded[19].content.as_ref().unwrap();
    assert!(first.contains("number 0"));
    assert!(last.contains("number 19"));
}
/// Messages of one session never show up in another session's context.
#[tokio::test]
async fn test_sessions_are_isolated() {
    let (db, tmp) = setup().await;
    let first = db.create_session("agent-a", tmp.path()).await.unwrap();
    let second = db.create_session("agent-b", tmp.path()).await.unwrap();

    for (session, text) in [(&first, "session 1"), (&second, "session 2")] {
        db.insert_message(session, &Role::User, Some(text), None, None, None)
            .await
            .unwrap();
    }

    let first_msgs = db.load_context(&first).await.unwrap();
    let second_msgs = db.load_context(&second).await.unwrap();
    assert_eq!(first_msgs.len(), 1);
    assert_eq!(second_msgs.len(), 1);
    assert_eq!(first_msgs[0].content.as_deref().unwrap(), "session 1");
    assert_eq!(second_msgs[0].content.as_deref().unwrap(), "session 2");
}
/// Token usage is summed over the messages that recorded usage.
#[tokio::test]
async fn test_session_token_usage() {
    let (db, tmp) = setup().await;
    let session = db.create_session("default", tmp.path()).await.unwrap();

    // Two question/answer rounds; only the assistant replies carry usage.
    let rounds = [("q1", "a1", 100, 50), ("q2", "a2", 200, 80)];
    for (question, answer, prompt_tokens, completion_tokens) in rounds {
        db.insert_message(&session, &Role::User, Some(question), None, None, None)
            .await
            .unwrap();

        let usage = crate::providers::TokenUsage {
            prompt_tokens,
            completion_tokens,
            ..Default::default()
        };
        db.insert_message(
            &session,
            &Role::Assistant,
            Some(answer),
            None,
            None,
            Some(&usage),
        )
        .await
        .unwrap();
    }

    let totals = db.session_token_usage(&session).await.unwrap();
    assert_eq!(totals.prompt_tokens, 300);
    assert_eq!(totals.completion_tokens, 130);
    assert_eq!(totals.api_calls, 2);
}
/// Sessions are listed newest first.
#[tokio::test]
async fn test_list_sessions() {
    let (db, tmp) = setup().await;
    for agent in ["agent-a", "agent-b", "agent-c"] {
        db.create_session(agent, tmp.path()).await.unwrap();
    }

    let listed = db.list_sessions(10, tmp.path()).await.unwrap();
    assert_eq!(listed.len(), 3);
    assert_eq!(listed[0].agent_name, "agent-c");
}
/// Deleting removes the session; deleting it again reports `false`.
#[tokio::test]
async fn test_delete_session() {
    let (db, tmp) = setup().await;
    let session = db.create_session("default", tmp.path()).await.unwrap();
    db.insert_message(&session, &Role::User, Some("hello"), None, None, None)
        .await
        .unwrap();

    let removed = db.delete_session(&session).await.unwrap();
    assert!(removed);

    let remaining = db.list_sessions(10, tmp.path()).await.unwrap();
    assert!(remaining.is_empty());

    let removed_again = db.delete_session(&session).await.unwrap();
    assert!(!removed_again);
}
/// Compaction archives the old messages and leaves summary, continuation
/// hint and the preserved tail in the active context.
#[tokio::test]
async fn test_compact_session() {
    let (db, tmp) = setup().await;
    let session = db.create_session("default", tmp.path()).await.unwrap();

    for i in 0..10 {
        let role = if i % 2 == 0 {
            Role::User
        } else {
            Role::Assistant
        };
        let text = format!("msg {i}");
        db.insert_message(&session, &role, Some(&text), None, None, None)
            .await
            .unwrap();
    }

    let archived = db
        .compact_session(&session, "Summary of conversation", 2)
        .await
        .unwrap();
    assert_eq!(archived, 8);

    // Summary + continuation hint + the two preserved messages.
    let active = db.load_context(&session).await.unwrap();
    assert_eq!(active.len(), 4);

    let system_msgs: Vec<_> = active
        .iter()
        .filter(|m| m.role == Role::System)
        .collect();
    assert_eq!(system_msgs.len(), 1);
    let summary_text = system_msgs[0].content.as_ref().unwrap();
    assert!(summary_text.contains("Summary of conversation"));

    let has_hint = active
        .iter()
        .filter(|m| m.role == Role::Assistant)
        .any(|m| m.content.as_deref().unwrap_or("").contains("compacted"));
    assert!(has_hint, "Expected a continuation hint from assistant");

    let preserved = active
        .iter()
        .filter(|m| m.content.as_deref().is_some_and(|c| c.starts_with("msg ")))
        .count();
    assert_eq!(preserved, 2);
}
#[tokio::test]
async fn test_compact_preserves_zero() {
let (db, _tmp) = setup().await;
let session = db.create_session("default", _tmp.path()).await.unwrap();
for i in 0..6 {
let role = if i % 2 == 0 {
&Role::User
} else {
&Role::Assistant
};
db.insert_message(&session, role, Some(&format!("msg {i}")), None, None, None)
.await
.unwrap();
}
let deleted = db
.compact_session(&session, "Full summary", 0)
.await
.unwrap();
assert_eq!(deleted, 6);
let msgs = db.load_context(&session).await.unwrap();
assert_eq!(msgs.len(), 2); assert_eq!(msgs.iter().filter(|m| m.role == Role::System).count(), 1);
assert_eq!(msgs.iter().filter(|m| m.role == Role::Assistant).count(), 1);
}
/// Tool calls are pending exactly while an assistant tool-call message is
/// the newest message of the session.
#[tokio::test]
async fn test_has_pending_tool_calls() {
    let (db, tmp) = setup().await;
    let session = db.create_session("default", tmp.path()).await.unwrap();

    // Empty session.
    let pending = db.has_pending_tool_calls(&session).await.unwrap();
    assert!(!pending);

    // A plain user message.
    db.insert_message(&session, &Role::User, Some("hello"), None, None, None)
        .await
        .unwrap();
    let pending = db.has_pending_tool_calls(&session).await.unwrap();
    assert!(!pending);

    // The assistant asks for a tool call.
    let tool_calls = r#"[{"id":"tc1","name":"Read","arguments":"{}"}]"#;
    db.insert_message(
        &session,
        &Role::Assistant,
        None,
        Some(tool_calls),
        None,
        None,
    )
    .await
    .unwrap();
    let pending = db.has_pending_tool_calls(&session).await.unwrap();
    assert!(pending);

    // The tool answers.
    db.insert_message(
        &session,
        &Role::Tool,
        Some("file contents"),
        None,
        Some("tc1"),
        None,
    )
    .await
    .unwrap();
    let pending = db.has_pending_tool_calls(&session).await.unwrap();
    assert!(!pending);
}
/// `load_context` keeps answered tool calls and drops an assistant message
/// whose tool call was never answered.
#[tokio::test]
async fn test_prune_mismatched_tool_calls() {
    let (db, tmp) = setup().await;
    let session = db.create_session("default", tmp.path()).await.unwrap();

    // (role, content, tool_calls, tool_call_id); `tc2` never gets a result.
    let history = [
        (Role::User, "hello", None, None),
        (
            Role::Assistant,
            "Let me read that.",
            Some(r#"[{"id":"tc1","name":"Read","arguments":"{}"}]"#),
            None,
        ),
        (Role::Tool, "file contents", None, Some("tc1")),
        (
            Role::Assistant,
            "I'll edit the file.",
            Some(r#"[{"id":"tc2","name":"Edit","arguments":"{}"}]"#),
            None,
        ),
    ];
    for (role, content, tool_calls, tool_call_id) in history {
        db.insert_message(
            &session,
            &role,
            Some(content),
            tool_calls,
            tool_call_id,
            None,
        )
        .await
        .unwrap();
    }

    let loaded = db.load_context(&session).await.unwrap();
    let find_by_content = |text: &str| loaded.iter().find(|m| m.content.as_deref() == Some(text));

    let completed = find_by_content("Let me read that.").unwrap();
    assert!(
        completed.tool_calls.is_some(),
        "completed tool_calls should be preserved"
    );

    let orphaned = find_by_content("I'll edit the file.");
    assert!(
        orphaned.is_none(),
        "orphaned assistant message should be dropped by prune_mismatched_tool_calls"
    );
}
/// Exercises `prune_mismatched_tool_calls` directly on in-memory messages.
#[test]
fn test_prune_mismatched_tool_calls_unit() {
    /// Builds a message with only the fields relevant to pruning filled in.
    fn msg(
        role: &str,
        content: Option<&str>,
        tool_calls: Option<&str>,
        tool_call_id: Option<&str>,
    ) -> Message {
        Message {
            id: 0,
            session_id: String::new(),
            role: role.parse().unwrap_or(Role::User),
            content: content.map(str::to_owned),
            tool_calls: tool_calls.map(str::to_owned),
            tool_call_id: tool_call_id.map(str::to_owned),
            prompt_tokens: None,
            completion_tokens: None,
            cache_read_tokens: None,
            cache_creation_tokens: None,
            thinking_tokens: None,
        }
    }

    // Empty input stays empty.
    let mut none: Vec<Message> = Vec::new();
    prune_mismatched_tool_calls(&mut none);
    assert!(none.is_empty());

    // A conversation without tool traffic is untouched.
    let mut plain = vec![msg("user", Some("hi"), None, None)];
    prune_mismatched_tool_calls(&mut plain);
    assert_eq!(plain.len(), 1);

    // An unanswered tool call removes the assistant message.
    let mut unanswered = vec![
        msg("user", Some("hi"), None, None),
        msg(
            "assistant",
            Some("doing it"),
            Some(r#"[{"id":"t1"}]"#),
            None,
        ),
    ];
    prune_mismatched_tool_calls(&mut unanswered);
    assert_eq!(unanswered.len(), 1, "orphaned assistant should be dropped");
    assert_eq!(unanswered[0].role, Role::User);

    // A call together with its result is kept as is.
    let mut answered = vec![
        msg("user", Some("hi"), None, None),
        msg("assistant", None, Some(r#"[{"id":"t1"}]"#), None),
        msg("tool", Some("ok"), None, Some("t1")),
    ];
    prune_mismatched_tool_calls(&mut answered);
    assert_eq!(answered.len(), 3, "complete pair should be preserved");
    assert!(answered[1].tool_calls.is_some());
}
/// Metadata and todo values start absent, can be written, and are
/// overwritten on a second write.
#[tokio::test]
async fn test_session_metadata_and_todo() {
    let (db, tmp) = setup().await;
    let session = db.create_session("default", tmp.path()).await.unwrap();

    // Nothing stored yet.
    let initial_todo = db.get_todo(&session).await.unwrap();
    assert!(initial_todo.is_none());
    let initial_meta = db.get_metadata(&session, "anything").await.unwrap();
    assert!(initial_meta.is_none());

    // First write.
    db.set_todo(&session, "- [ ] Task 1\n- [x] Task 2")
        .await
        .unwrap();
    let todo = db.get_todo(&session).await.unwrap().unwrap();
    assert!(todo.contains("Task 1"));
    assert!(todo.contains("Task 2"));

    // Second write replaces the first.
    db.set_todo(&session, "- [x] Task 1\n- [x] Task 2")
        .await
        .unwrap();
    let todo = db.get_todo(&session).await.unwrap().unwrap();
    assert!(todo.starts_with("- [x] Task 1"));

    // Arbitrary keys round-trip as well.
    db.set_metadata(&session, "custom_key", "custom_value")
        .await
        .unwrap();
    let custom = db
        .get_metadata(&session, "custom_key")
        .await
        .unwrap()
        .unwrap();
    assert_eq!(custom, "custom_value");
}
/// A session without messages reports zero usage.
#[tokio::test]
async fn test_token_usage_empty_session() {
    let (db, tmp) = setup().await;
    let session = db.create_session("default", tmp.path()).await.unwrap();

    let totals = db.session_token_usage(&session).await.unwrap();
    assert_eq!(totals.prompt_tokens, 0);
    assert_eq!(totals.completion_tokens, 0);
    assert_eq!(totals.api_calls, 0);
}
/// The last assistant message is empty for a new session and otherwise the
/// newest assistant reply.
#[tokio::test]
async fn test_last_assistant_message() {
    let (db, tmp) = setup().await;
    let session = db.create_session("default", tmp.path()).await.unwrap();

    let before = db.last_assistant_message(&session).await.unwrap();
    assert_eq!(before, "");

    let exchange = [
        (Role::User, "question 1"),
        (Role::Assistant, "answer 1"),
        (Role::User, "question 2"),
        (Role::Assistant, "answer 2"),
    ];
    for (role, text) in exchange {
        db.insert_message(&session, &role, Some(text), None, None, None)
            .await
            .unwrap();
    }

    let after = db.last_assistant_message(&session).await.unwrap();
    assert_eq!(after, "answer 2");
}
/// Assistant messages without content (pure tool calls) are skipped when
/// looking up the last assistant message.
#[tokio::test]
async fn test_last_assistant_message_skips_tool_calls() {
    let (db, tmp) = setup().await;
    let session = db.create_session("default", tmp.path()).await.unwrap();

    // (role, content, tool_calls, tool_call_id)
    let history = [
        (Role::User, Some("do something"), None, None),
        (Role::Assistant, None, Some("[{\"id\":\"1\"}]"), None),
        (Role::Tool, Some("tool result"), None, Some("1")),
        (Role::Assistant, Some("Done!"), None, None),
    ];
    for (role, content, tool_calls, tool_call_id) in history {
        db.insert_message(&session, &role, content, tool_calls, tool_call_id, None)
            .await
            .unwrap();
    }

    let last = db.last_assistant_message(&session).await.unwrap();
    assert_eq!(last, "Done!");
}
}