1use anyhow::{Context, Result};
7use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
8use std::path::Path;
9use std::str::FromStr;
10
11pub use crate::persistence::{
13 CompactedStats, Message, Persistence, Role, SessionInfo, SessionUsage,
14};
15
/// SQLite-backed store for sessions, messages and per-session metadata.
///
/// Construct with [`Database::init`] (database file at
/// `<koda_config_dir>/db/koda.db`) or [`Database::open`] (explicit file path);
/// both run the schema migrations before returning. The query API is provided
/// by this type's `Persistence` implementation.
#[derive(Debug, Clone)]
pub struct Database {
    /// Connection pool used by every query issued through this handle.
    pool: SqlitePool,
}
21
22pub fn config_dir() -> Result<std::path::PathBuf> {
24 let base = std::env::var("XDG_CONFIG_HOME")
25 .ok()
26 .map(std::path::PathBuf::from)
27 .or_else(|| {
28 std::env::var("HOME")
29 .ok()
30 .map(|h| std::path::PathBuf::from(h).join(".config"))
31 })
32 .ok_or_else(|| {
33 anyhow::anyhow!("Cannot determine config directory (set HOME or XDG_CONFIG_HOME)")
34 })?;
35 Ok(base.join("koda"))
36}
37
38impl Database {
39 pub async fn init(koda_config_dir: &Path) -> Result<Self> {
46 let db_dir = koda_config_dir.join("db");
47 std::fs::create_dir_all(&db_dir)
48 .with_context(|| format!("Failed to create DB dir: {}", db_dir.display()))?;
49
50 let db_path = db_dir.join("koda.db");
51
52 Self::open(&db_path).await
53 }
54
55 pub async fn open(db_path: &Path) -> Result<Self> {
57 let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
58
59 let options = SqliteConnectOptions::from_str(&db_url)?
60 .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
61 .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
62 .foreign_keys(true)
63 .create_if_missing(true);
64
65 let pool = SqlitePoolOptions::new()
66 .max_connections(5)
67 .connect_with(options)
68 .await
69 .with_context(|| format!("Failed to connect to database: {db_url}"))?;
70
71 Self::migrate(&pool).await?;
73 Ok(Self { pool })
74 }
75
76 async fn migrate(pool: &SqlitePool) -> Result<()> {
78 sqlx::query(
79 "CREATE TABLE IF NOT EXISTS sessions (
80 id TEXT PRIMARY KEY,
81 created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
82 agent_name TEXT NOT NULL
83 );",
84 )
85 .execute(pool)
86 .await?;
87
88 sqlx::query(
89 "CREATE TABLE IF NOT EXISTS messages (
90 id INTEGER PRIMARY KEY AUTOINCREMENT,
91 session_id TEXT NOT NULL,
92 role TEXT NOT NULL,
93 content TEXT,
94 tool_calls TEXT,
95 tool_call_id TEXT,
96 prompt_tokens INTEGER,
97 completion_tokens INTEGER,
98 created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
99 FOREIGN KEY(session_id) REFERENCES sessions(id)
100 );",
101 )
102 .execute(pool)
103 .await?;
104
105 sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages(session_id);")
106 .execute(pool)
107 .await?;
108
109 sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_role_id ON messages(role, id DESC);")
110 .execute(pool)
111 .await?;
112
113 for col in &[
115 "cache_read_tokens",
116 "cache_creation_tokens",
117 "thinking_tokens",
118 ] {
119 let sql = format!("ALTER TABLE messages ADD COLUMN {col} INTEGER");
120 if let Err(e) = sqlx::query(&sql).execute(pool).await {
122 let msg = e.to_string();
123 if !msg.contains("duplicate column name") {
124 return Err(e.into());
125 }
126 }
127 }
128
129 for (col, col_type) in &[("agent_name", "TEXT")] {
131 let sql = format!("ALTER TABLE messages ADD COLUMN {col} {col_type}");
132 if let Err(e) = sqlx::query(&sql).execute(pool).await {
133 let msg = e.to_string();
134 if !msg.contains("duplicate column name") {
135 return Err(e.into());
136 }
137 }
138 }
139
140 sqlx::query(
142 "CREATE TABLE IF NOT EXISTS session_metadata (
143 session_id TEXT NOT NULL,
144 key TEXT NOT NULL,
145 value TEXT NOT NULL,
146 updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
147 PRIMARY KEY(session_id, key),
148 FOREIGN KEY(session_id) REFERENCES sessions(id)
149 );",
150 )
151 .execute(pool)
152 .await?;
153
154 let sql = "ALTER TABLE sessions ADD COLUMN project_root TEXT";
156 if let Err(e) = sqlx::query(sql).execute(pool).await {
157 let msg = e.to_string();
158 if !msg.contains("duplicate column name") {
159 return Err(e.into());
160 }
161 }
162
163 let sql = "ALTER TABLE messages ADD COLUMN compacted_at TEXT";
165 if let Err(e) = sqlx::query(sql).execute(pool).await {
166 let msg = e.to_string();
167 if !msg.contains("duplicate column name") {
168 return Err(e.into());
169 }
170 }
171
172 let sql = "ALTER TABLE sessions ADD COLUMN last_accessed_at TEXT";
174 if let Err(e) = sqlx::query(sql).execute(pool).await {
175 let msg = e.to_string();
176 if !msg.contains("duplicate column name") {
177 return Err(e.into());
178 }
179 }
180
181 Ok(())
182 }
183}
184
185fn prune_mismatched_tool_calls(messages: &mut Vec<Message>) {
196 if messages.is_empty() {
197 return;
198 }
199
200 let mut tool_call_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
201 let mut tool_return_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
202
203 for msg in messages.iter() {
204 if msg.role == Role::Assistant {
205 if let Some(ref tc_json) = msg.tool_calls
206 && let Ok(calls) = serde_json::from_str::<Vec<serde_json::Value>>(tc_json)
207 {
208 for call in &calls {
209 if let Some(id) = call.get("id").and_then(|v| v.as_str()) {
210 tool_call_ids.insert(id.to_string());
211 }
212 }
213 }
214 } else if msg.role == Role::Tool
215 && let Some(ref id) = msg.tool_call_id
216 {
217 tool_return_ids.insert(id.clone());
218 }
219 }
220
221 let mismatched: std::collections::HashSet<&String> = tool_call_ids
222 .symmetric_difference(&tool_return_ids)
223 .collect();
224
225 if mismatched.is_empty() {
226 return;
227 }
228
229 messages.retain(|msg| {
230 if msg.role == Role::Tool
232 && let Some(ref id) = msg.tool_call_id
233 && mismatched.contains(id)
234 {
235 return false;
236 }
237 if msg.role == Role::Assistant
239 && let Some(ref tc_json) = msg.tool_calls
240 && let Ok(calls) = serde_json::from_str::<Vec<serde_json::Value>>(tc_json)
241 {
242 let has_mismatched = calls.iter().any(|call| {
243 call.get("id")
244 .and_then(|v| v.as_str())
245 .is_some_and(|id| mismatched.contains(&id.to_string()))
246 });
247 if has_mismatched {
248 return false;
249 }
250 }
251 true
252 });
253}
254
255#[async_trait::async_trait]
256impl Persistence for Database {
257 async fn create_session(&self, agent_name: &str, project_root: &Path) -> Result<String> {
259 let id = uuid::Uuid::new_v4().to_string();
260 let root = project_root.to_string_lossy().to_string();
261 sqlx::query("INSERT INTO sessions (id, agent_name, project_root) VALUES (?, ?, ?)")
262 .bind(&id)
263 .bind(agent_name)
264 .bind(&root)
265 .execute(&self.pool)
266 .await?;
267 tracing::info!("Created session: {id} (project: {root})");
268 Ok(id)
269 }
270
271 async fn insert_message(
273 &self,
274 session_id: &str,
275 role: &Role,
276 content: Option<&str>,
277 tool_calls: Option<&str>,
278 tool_call_id: Option<&str>,
279 usage: Option<&crate::providers::TokenUsage>,
280 ) -> Result<i64> {
281 self.insert_message_with_agent(
282 session_id,
283 role,
284 content,
285 tool_calls,
286 tool_call_id,
287 usage,
288 None,
289 )
290 .await
291 }
292
293 #[allow(clippy::too_many_arguments)]
295 async fn insert_message_with_agent(
296 &self,
297 session_id: &str,
298 role: &Role,
299 content: Option<&str>,
300 tool_calls: Option<&str>,
301 tool_call_id: Option<&str>,
302 usage: Option<&crate::providers::TokenUsage>,
303 agent_name: Option<&str>,
304 ) -> Result<i64> {
305 let result = sqlx::query(
306 "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, \
307 prompt_tokens, completion_tokens, cache_read_tokens, cache_creation_tokens, \
308 thinking_tokens, agent_name)
309 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
310 )
311 .bind(session_id)
312 .bind(role.as_str())
313 .bind(content)
314 .bind(tool_calls)
315 .bind(tool_call_id)
316 .bind(usage.map(|u| u.prompt_tokens))
317 .bind(usage.map(|u| u.completion_tokens))
318 .bind(usage.map(|u| u.cache_read_tokens))
319 .bind(usage.map(|u| u.cache_creation_tokens))
320 .bind(usage.map(|u| u.thinking_tokens))
321 .bind(agent_name)
322 .execute(&self.pool)
323 .await?;
324
325 sqlx::query("UPDATE sessions SET last_accessed_at = datetime('now') WHERE id = ?")
327 .bind(session_id)
328 .execute(&self.pool)
329 .await?;
330
331 Ok(result.last_insert_rowid())
332 }
333
334 async fn load_context(&self, session_id: &str) -> Result<Vec<Message>> {
340 let mut messages: Vec<Message> = sqlx::query_as::<_, MessageRow>(
341 "SELECT id, session_id, role, content, tool_calls, tool_call_id,
342 prompt_tokens, completion_tokens,
343 cache_read_tokens, cache_creation_tokens, thinking_tokens
344 FROM messages
345 WHERE session_id = ? AND compacted_at IS NULL
346 ORDER BY id ASC",
347 )
348 .bind(session_id)
349 .fetch_all(&self.pool)
350 .await?
351 .into_iter()
352 .map(|r| r.into())
353 .collect();
354
355 prune_mismatched_tool_calls(&mut messages);
359
360 Ok(messages)
361 }
362 async fn load_all_messages(&self, session_id: &str) -> Result<Vec<Message>> {
365 let rows: Vec<Message> = sqlx::query_as::<_, MessageRow>(
366 "SELECT id, session_id, role, content, tool_calls, tool_call_id,
367 prompt_tokens, completion_tokens,
368 cache_read_tokens, cache_creation_tokens, thinking_tokens
369 FROM messages
370 WHERE session_id = ?
371 ORDER BY id ASC",
372 )
373 .bind(session_id)
374 .fetch_all(&self.pool)
375 .await?
376 .into_iter()
377 .map(|r| r.into())
378 .collect();
379 Ok(rows)
380 }
381
382 async fn recent_user_messages(&self, limit: i64) -> Result<Vec<String>> {
385 let rows: Vec<(String,)> = sqlx::query_as(
386 "SELECT content FROM messages
387 WHERE role = 'user' AND content IS NOT NULL AND content != ''
388 ORDER BY id DESC LIMIT ?",
389 )
390 .bind(limit)
391 .fetch_all(&self.pool)
392 .await?;
393
394 Ok(rows.into_iter().map(|r| r.0).collect())
395 }
396
397 async fn session_token_usage(&self, session_id: &str) -> Result<SessionUsage> {
399 let row: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(
400 "SELECT
401 COALESCE(SUM(prompt_tokens), 0),
402 COALESCE(SUM(completion_tokens), 0),
403 COALESCE(SUM(cache_read_tokens), 0),
404 COALESCE(SUM(cache_creation_tokens), 0),
405 COALESCE(SUM(thinking_tokens), 0),
406 COUNT(*)
407 FROM messages
408 WHERE session_id = ?
409 AND (prompt_tokens IS NOT NULL OR completion_tokens IS NOT NULL)",
410 )
411 .bind(session_id)
412 .fetch_one(&self.pool)
413 .await?;
414 Ok(SessionUsage {
415 prompt_tokens: row.0,
416 completion_tokens: row.1,
417 cache_read_tokens: row.2,
418 cache_creation_tokens: row.3,
419 thinking_tokens: row.4,
420 api_calls: row.5,
421 })
422 }
423
424 async fn session_usage_by_agent(
426 &self,
427 session_id: &str,
428 ) -> Result<Vec<(String, SessionUsage)>> {
429 let rows: Vec<(String, i64, i64, i64, i64, i64, i64)> = sqlx::query_as(
430 "SELECT
431 COALESCE(agent_name, 'main'),
432 COALESCE(SUM(prompt_tokens), 0),
433 COALESCE(SUM(completion_tokens), 0),
434 COALESCE(SUM(cache_read_tokens), 0),
435 COALESCE(SUM(cache_creation_tokens), 0),
436 COALESCE(SUM(thinking_tokens), 0),
437 COUNT(*)
438 FROM messages
439 WHERE session_id = ?
440 AND (prompt_tokens IS NOT NULL OR completion_tokens IS NOT NULL)
441 GROUP BY COALESCE(agent_name, 'main')
442 ORDER BY COALESCE(SUM(prompt_tokens), 0) + COALESCE(SUM(completion_tokens), 0) DESC",
443 )
444 .bind(session_id)
445 .fetch_all(&self.pool)
446 .await?;
447 Ok(rows
448 .into_iter()
449 .map(|r| {
450 (
451 r.0,
452 SessionUsage {
453 prompt_tokens: r.1,
454 completion_tokens: r.2,
455 cache_read_tokens: r.3,
456 cache_creation_tokens: r.4,
457 thinking_tokens: r.5,
458 api_calls: r.6,
459 },
460 )
461 })
462 .collect())
463 }
464
465 async fn list_sessions(&self, limit: i64, project_root: &Path) -> Result<Vec<SessionInfo>> {
467 let root = project_root.to_string_lossy().to_string();
468 let rows: Vec<SessionInfoRow> = sqlx::query_as(
469 "SELECT s.id, s.agent_name, s.created_at,
470 COUNT(m.id) as message_count,
471 COALESCE(SUM(m.prompt_tokens), 0) + COALESCE(SUM(m.completion_tokens), 0) as total_tokens
472 FROM sessions s
473 LEFT JOIN messages m ON m.session_id = s.id
474 WHERE s.project_root = ? OR s.project_root IS NULL
475 GROUP BY s.id
476 ORDER BY s.created_at DESC, s.rowid DESC
477 LIMIT ?",
478 )
479 .bind(&root)
480 .bind(limit)
481 .fetch_all(&self.pool)
482 .await?;
483 Ok(rows.into_iter().map(|r| r.into()).collect())
484 }
485
486 async fn last_assistant_message(&self, session_id: &str) -> Result<String> {
488 let row: Option<(String,)> = sqlx::query_as(
489 "SELECT content FROM messages
490 WHERE session_id = ? AND role = 'assistant' AND content IS NOT NULL
491 ORDER BY id DESC LIMIT 1",
492 )
493 .bind(session_id)
494 .fetch_optional(&self.pool)
495 .await?;
496 Ok(row.map(|r| r.0).unwrap_or_default())
497 }
498
499 async fn last_user_message(&self, session_id: &str) -> Result<String> {
501 let row: Option<(String,)> = sqlx::query_as(
502 "SELECT content FROM messages
503 WHERE session_id = ? AND role = 'user' AND content IS NOT NULL
504 ORDER BY id DESC LIMIT 1",
505 )
506 .bind(session_id)
507 .fetch_optional(&self.pool)
508 .await?;
509 Ok(row.map(|r| r.0).unwrap_or_default())
510 }
511
512 async fn delete_session(&self, session_id: &str) -> Result<bool> {
514 let mut tx = self.pool.begin().await?;
515
516 sqlx::query("DELETE FROM messages WHERE session_id = ?")
517 .bind(session_id)
518 .execute(&mut *tx)
519 .await?;
520
521 sqlx::query("DELETE FROM session_metadata WHERE session_id = ?")
522 .bind(session_id)
523 .execute(&mut *tx)
524 .await?;
525
526 let result = sqlx::query("DELETE FROM sessions WHERE id = ?")
527 .bind(session_id)
528 .execute(&mut *tx)
529 .await?;
530
531 tx.commit().await?;
532
533 sqlx::query("PRAGMA incremental_vacuum")
535 .execute(&self.pool)
536 .await?;
537
538 Ok(result.rows_affected() > 0)
539 }
540
541 async fn compact_session(
549 &self,
550 session_id: &str,
551 summary: &str,
552 preserve_count: usize,
553 ) -> Result<usize> {
554 let mut tx = self.pool.begin().await?;
555
556 let all_ids: Vec<(i64,)> = sqlx::query_as(
558 "SELECT id FROM messages WHERE session_id = ? AND compacted_at IS NULL ORDER BY id ASC",
559 )
560 .bind(session_id)
561 .fetch_all(&mut *tx)
562 .await?;
563
564 let total = all_ids.len();
565 if total == 0 {
566 tx.commit().await?;
567 return Ok(0);
568 }
569
570 let keep_from = total.saturating_sub(preserve_count);
572 let ids_to_archive: Vec<i64> = all_ids[..keep_from].iter().map(|r| r.0).collect();
573 let archived_count = ids_to_archive.len();
574
575 if archived_count == 0 {
576 tx.commit().await?;
577 return Ok(0);
578 }
579
580 for chunk in ids_to_archive.chunks(500) {
582 let placeholders: String = chunk.iter().map(|_| "?").collect::<Vec<_>>().join(",");
583 let sql = format!(
584 "UPDATE messages SET compacted_at = datetime('now') \
585 WHERE session_id = ? AND id IN ({placeholders})"
586 );
587 let mut query = sqlx::query(&sql).bind(session_id);
588 for id in chunk {
589 query = query.bind(id);
590 }
591 query.execute(&mut *tx).await?;
592 }
593
594 sqlx::query(
596 "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, prompt_tokens, completion_tokens)
597 VALUES (?, 'system', ?, NULL, NULL, NULL, NULL)",
598 )
599 .bind(session_id)
600 .bind(summary)
601 .execute(&mut *tx)
602 .await?;
603
604 let continuation = "Your context was compacted. The previous message contains a summary of our earlier conversation. \
606 Do not mention the summary or that compaction occurred. \
607 Continue the conversation naturally based on the summarized context.";
608 sqlx::query(
609 "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, prompt_tokens, completion_tokens)
610 VALUES (?, 'assistant', ?, NULL, NULL, NULL, NULL)",
611 )
612 .bind(session_id)
613 .bind(continuation)
614 .execute(&mut *tx)
615 .await?;
616
617 tx.commit().await?;
618
619 Ok(archived_count)
620 }
621
622 async fn has_pending_tool_calls(&self, session_id: &str) -> Result<bool> {
625 let last_msg: Option<(String, Option<String>)> = sqlx::query_as(
628 "SELECT role, tool_calls FROM messages
629 WHERE session_id = ? AND compacted_at IS NULL
630 ORDER BY id DESC LIMIT 1",
631 )
632 .bind(session_id)
633 .fetch_optional(&self.pool)
634 .await?;
635
636 Ok(matches!(last_msg, Some((role, Some(_))) if role == "assistant"))
637 }
638
639 async fn compacted_stats(&self) -> Result<CompactedStats> {
641 let row: (i64, i64, i64, Option<String>) = sqlx::query_as(
642 "SELECT
643 COUNT(*),
644 COUNT(DISTINCT session_id),
645 COALESCE(SUM(LENGTH(content) + LENGTH(COALESCE(tool_calls,''))), 0),
646 MIN(compacted_at)
647 FROM messages
648 WHERE compacted_at IS NOT NULL",
649 )
650 .fetch_one(&self.pool)
651 .await?;
652
653 Ok(CompactedStats {
654 message_count: row.0,
655 session_count: row.1,
656 size_bytes: row.2,
657 oldest: row.3,
658 })
659 }
660
661 async fn purge_compacted(&self, min_age_days: u32) -> Result<usize> {
664 let result = if min_age_days == 0 {
665 sqlx::query("DELETE FROM messages WHERE compacted_at IS NOT NULL")
666 .execute(&self.pool)
667 .await?
668 } else {
669 sqlx::query(
670 "DELETE FROM messages
671 WHERE compacted_at IS NOT NULL
672 AND compacted_at < datetime('now', ?)",
673 )
674 .bind(format!("-{min_age_days} days"))
675 .execute(&self.pool)
676 .await?
677 };
678
679 let deleted = result.rows_affected() as usize;
680
681 sqlx::query("VACUUM").execute(&self.pool).await?;
683
684 tracing::info!("Purged {deleted} compacted messages (>{min_age_days} days old)");
685 Ok(deleted)
686 }
687
688 async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>> {
690 let row: Option<(String,)> =
691 sqlx::query_as("SELECT value FROM session_metadata WHERE session_id = ? AND key = ?")
692 .bind(session_id)
693 .bind(key)
694 .fetch_optional(&self.pool)
695 .await?;
696 Ok(row.map(|r| r.0))
697 }
698
699 async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> {
701 sqlx::query(
702 "INSERT INTO session_metadata (session_id, key, value, updated_at)
703 VALUES (?, ?, ?, CURRENT_TIMESTAMP)
704 ON CONFLICT(session_id, key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at",
705 )
706 .bind(session_id)
707 .bind(key)
708 .bind(value)
709 .execute(&self.pool)
710 .await?;
711 Ok(())
712 }
713
714 async fn get_todo(&self, session_id: &str) -> Result<Option<String>> {
716 self.get_metadata(session_id, "todo").await
717 }
718
719 async fn set_todo(&self, session_id: &str, content: &str) -> Result<()> {
721 self.set_metadata(session_id, "todo", content).await
722 }
723}
724
/// Raw row shape returned by the message-loading SELECTs (`load_context`,
/// `load_all_messages`). It is converted into [`Message`] via `From`.
#[derive(sqlx::FromRow)]
struct MessageRow {
    id: i64,
    session_id: String,
    /// Role as stored in the `role` TEXT column; parsed into `Role` on conversion.
    role: String,
    content: Option<String>,
    /// Tool calls as stored text (parsed elsewhere as a JSON array).
    tool_calls: Option<String>,
    tool_call_id: Option<String>,
    // Token columns are nullable: messages without usage leave them NULL.
    prompt_tokens: Option<i64>,
    completion_tokens: Option<i64>,
    cache_read_tokens: Option<i64>,
    cache_creation_tokens: Option<i64>,
    thinking_tokens: Option<i64>,
}
740
/// Row shape produced by the `list_sessions` aggregate query. It is converted
/// into [`SessionInfo`] via `From`.
#[derive(Debug, Clone, sqlx::FromRow)]
struct SessionInfoRow {
    id: String,
    agent_name: String,
    /// `sessions.created_at` as returned by SQLite, read as text.
    created_at: String,
    /// `COUNT(m.id)` over the LEFT JOIN; compacted messages are included.
    message_count: i64,
    /// Sum of prompt and completion tokens across the session's messages.
    total_tokens: i64,
}
750
751impl From<SessionInfoRow> for SessionInfo {
752 fn from(r: SessionInfoRow) -> Self {
753 Self {
754 id: r.id,
755 agent_name: r.agent_name,
756 created_at: r.created_at,
757 message_count: r.message_count,
758 total_tokens: r.total_tokens,
759 }
760 }
761}
762
763impl From<MessageRow> for Message {
764 fn from(r: MessageRow) -> Self {
765 Self {
766 id: r.id,
767 session_id: r.session_id,
768 role: r.role.parse().unwrap_or(Role::User),
769 content: r.content,
770 tool_calls: r.tool_calls,
771 tool_call_id: r.tool_call_id,
772 prompt_tokens: r.prompt_tokens,
773 completion_tokens: r.completion_tokens,
774 cache_read_tokens: r.cache_read_tokens,
775 cache_creation_tokens: r.cache_creation_tokens,
776 thinking_tokens: r.thinking_tokens,
777 }
778 }
779}
780
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a fresh database file inside a new temporary directory.
    ///
    /// The `TempDir` is returned so each test keeps it alive for its whole
    /// duration; the directory is cleaned up when it is dropped.
    async fn setup() -> (Database, TempDir) {
        let tmp = TempDir::new().unwrap();
        let db_path = tmp.path().join("test.db");
        let db = Database::open(&db_path).await.unwrap();
        (db, tmp)
    }

    #[tokio::test]
    async fn test_create_session() {
        let (db, tmp) = setup().await;
        let id = db.create_session("default", tmp.path()).await.unwrap();
        assert!(!id.is_empty());
    }

    #[tokio::test]
    async fn test_insert_and_load_messages() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        db.insert_message(&session, &Role::User, Some("hello"), None, None, None)
            .await
            .unwrap();
        db.insert_message(
            &session,
            &Role::Assistant,
            Some("hi there!"),
            None,
            None,
            None,
        )
        .await
        .unwrap();

        let msgs = db.load_context(&session).await.unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].role, Role::User);
        assert_eq!(msgs[1].role, Role::Assistant);
    }

    #[tokio::test]
    async fn test_load_context_returns_all_active_messages() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        for i in 0..20 {
            let content = format!("Message number {i}");
            db.insert_message(&session, &Role::User, Some(&content), None, None, None)
                .await
                .unwrap();
        }

        let msgs = db.load_context(&session).await.unwrap();
        assert_eq!(msgs.len(), 20, "Should load all 20 messages");

        // Oldest first.
        assert!(msgs[0].content.as_ref().unwrap().contains("number 0"));
        assert!(msgs[19].content.as_ref().unwrap().contains("number 19"));
    }

    #[tokio::test]
    async fn test_load_all_messages_includes_compacted() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        for i in 0..4 {
            db.insert_message(&session, &Role::User, Some(&format!("msg {i}")), None, None, None)
                .await
                .unwrap();
        }
        db.compact_session(&session, "summary", 1).await.unwrap();

        // Active: 1 preserved message + summary + continuation hint.
        assert_eq!(db.load_context(&session).await.unwrap().len(), 3);
        // All: 4 originals (3 now compacted) + summary + continuation hint.
        assert_eq!(db.load_all_messages(&session).await.unwrap().len(), 6);
    }

    #[tokio::test]
    async fn test_sessions_are_isolated() {
        let (db, tmp) = setup().await;
        let s1 = db.create_session("agent-a", tmp.path()).await.unwrap();
        let s2 = db.create_session("agent-b", tmp.path()).await.unwrap();

        db.insert_message(&s1, &Role::User, Some("session 1"), None, None, None)
            .await
            .unwrap();
        db.insert_message(&s2, &Role::User, Some("session 2"), None, None, None)
            .await
            .unwrap();

        let msgs1 = db.load_context(&s1).await.unwrap();
        let msgs2 = db.load_context(&s2).await.unwrap();

        assert_eq!(msgs1.len(), 1);
        assert_eq!(msgs2.len(), 1);
        assert_eq!(msgs1[0].content.as_deref().unwrap(), "session 1");
        assert_eq!(msgs2[0].content.as_deref().unwrap(), "session 2");
    }

    #[tokio::test]
    async fn test_session_token_usage() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        db.insert_message(&session, &Role::User, Some("q1"), None, None, None)
            .await
            .unwrap();
        let usage1 = crate::providers::TokenUsage {
            prompt_tokens: 100,
            completion_tokens: 50,
            ..Default::default()
        };
        db.insert_message(
            &session,
            &Role::Assistant,
            Some("a1"),
            None,
            None,
            Some(&usage1),
        )
        .await
        .unwrap();
        db.insert_message(&session, &Role::User, Some("q2"), None, None, None)
            .await
            .unwrap();
        let usage2 = crate::providers::TokenUsage {
            prompt_tokens: 200,
            completion_tokens: 80,
            ..Default::default()
        };
        db.insert_message(
            &session,
            &Role::Assistant,
            Some("a2"),
            None,
            None,
            Some(&usage2),
        )
        .await
        .unwrap();

        let u = db.session_token_usage(&session).await.unwrap();
        assert_eq!(u.prompt_tokens, 300);
        assert_eq!(u.completion_tokens, 130);
        assert_eq!(u.api_calls, 2);
    }

    #[tokio::test]
    async fn test_session_usage_by_agent() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        let main_usage = crate::providers::TokenUsage {
            prompt_tokens: 10,
            completion_tokens: 5,
            ..Default::default()
        };
        db.insert_message(
            &session,
            &Role::Assistant,
            Some("main reply"),
            None,
            None,
            Some(&main_usage),
        )
        .await
        .unwrap();
        let sub_usage = crate::providers::TokenUsage {
            prompt_tokens: 100,
            completion_tokens: 50,
            ..Default::default()
        };
        db.insert_message_with_agent(
            &session,
            &Role::Assistant,
            Some("sub reply"),
            None,
            None,
            Some(&sub_usage),
            Some("scout"),
        )
        .await
        .unwrap();

        let by_agent = db.session_usage_by_agent(&session).await.unwrap();
        assert_eq!(by_agent.len(), 2);
        // Ordered by prompt + completion tokens, highest first.
        assert_eq!(by_agent[0].0, "scout");
        assert_eq!(by_agent[0].1.prompt_tokens, 100);
        // Messages without an agent name are grouped under "main".
        assert_eq!(by_agent[1].0, "main");
        assert_eq!(by_agent[1].1.api_calls, 1);
    }

    #[tokio::test]
    async fn test_list_sessions() {
        let (db, tmp) = setup().await;
        db.create_session("agent-a", tmp.path()).await.unwrap();
        db.create_session("agent-b", tmp.path()).await.unwrap();
        db.create_session("agent-c", tmp.path()).await.unwrap();

        let sessions = db.list_sessions(10, tmp.path()).await.unwrap();
        assert_eq!(sessions.len(), 3);
        // Newest first (ties on created_at broken by insertion order).
        assert_eq!(sessions[0].agent_name, "agent-c");
    }

    #[tokio::test]
    async fn test_delete_session() {
        let (db, tmp) = setup().await;
        let s1 = db.create_session("default", tmp.path()).await.unwrap();
        db.insert_message(&s1, &Role::User, Some("hello"), None, None, None)
            .await
            .unwrap();

        assert!(db.delete_session(&s1).await.unwrap());

        let sessions = db.list_sessions(10, tmp.path()).await.unwrap();
        assert!(sessions.is_empty());

        // A second delete finds nothing.
        assert!(!db.delete_session(&s1).await.unwrap());
    }

    #[tokio::test]
    async fn test_compact_session() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        for i in 0..10 {
            let role = if i % 2 == 0 {
                &Role::User
            } else {
                &Role::Assistant
            };
            db.insert_message(&session, role, Some(&format!("msg {i}")), None, None, None)
                .await
                .unwrap();
        }

        let archived = db
            .compact_session(&session, "Summary of conversation", 2)
            .await
            .unwrap();
        assert_eq!(archived, 8);

        // 2 preserved messages + the summary + the continuation hint.
        let msgs = db.load_context(&session).await.unwrap();
        assert_eq!(msgs.len(), 4);

        let system_msgs: Vec<_> = msgs.iter().filter(|m| m.role == Role::System).collect();
        assert_eq!(system_msgs.len(), 1);
        assert!(
            system_msgs[0]
                .content
                .as_ref()
                .unwrap()
                .contains("Summary of conversation")
        );

        let assistant_msgs: Vec<_> = msgs.iter().filter(|m| m.role == Role::Assistant).collect();
        assert!(
            assistant_msgs
                .iter()
                .any(|m| m.content.as_deref().unwrap_or("").contains("compacted")),
            "Expected a continuation hint from assistant"
        );

        let preserved: Vec<_> = msgs
            .iter()
            .filter(|m| m.content.as_deref().is_some_and(|c| c.starts_with("msg ")))
            .collect();
        assert_eq!(preserved.len(), 2);
    }

    #[tokio::test]
    async fn test_compact_preserves_zero() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        for i in 0..6 {
            let role = if i % 2 == 0 {
                &Role::User
            } else {
                &Role::Assistant
            };
            db.insert_message(&session, role, Some(&format!("msg {i}")), None, None, None)
                .await
                .unwrap();
        }

        let archived = db
            .compact_session(&session, "Full summary", 0)
            .await
            .unwrap();
        assert_eq!(archived, 6);

        // Only the summary and the continuation hint remain active.
        let msgs = db.load_context(&session).await.unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs.iter().filter(|m| m.role == Role::System).count(), 1);
        assert_eq!(msgs.iter().filter(|m| m.role == Role::Assistant).count(), 1);
    }

    #[tokio::test]
    async fn test_compacted_stats_and_purge() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        for i in 0..4 {
            db.insert_message(&session, &Role::User, Some(&format!("msg {i}")), None, None, None)
                .await
                .unwrap();
        }
        assert_eq!(db.compact_session(&session, "summary", 0).await.unwrap(), 4);

        let stats = db.compacted_stats().await.unwrap();
        assert_eq!(stats.message_count, 4);
        assert_eq!(stats.session_count, 1);

        // Just compacted, so nothing is older than a day yet.
        assert_eq!(db.purge_compacted(1).await.unwrap(), 0);
        // Age 0 purges every compacted message.
        assert_eq!(db.purge_compacted(0).await.unwrap(), 4);

        assert_eq!(db.compacted_stats().await.unwrap().message_count, 0);
        // The active summary and continuation hint are untouched.
        assert_eq!(db.load_context(&session).await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_has_pending_tool_calls() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        // Empty session.
        assert!(!db.has_pending_tool_calls(&session).await.unwrap());

        db.insert_message(&session, &Role::User, Some("hello"), None, None, None)
            .await
            .unwrap();
        assert!(!db.has_pending_tool_calls(&session).await.unwrap());

        // Assistant requested a tool and no result has arrived yet.
        db.insert_message(
            &session,
            &Role::Assistant,
            None,
            Some(r#"[{"id":"tc1","name":"Read","arguments":"{}"}]"#),
            None,
            None,
        )
        .await
        .unwrap();
        assert!(db.has_pending_tool_calls(&session).await.unwrap());

        // The tool result resolves it.
        db.insert_message(
            &session,
            &Role::Tool,
            Some("file contents"),
            None,
            Some("tc1"),
            None,
        )
        .await
        .unwrap();
        assert!(!db.has_pending_tool_calls(&session).await.unwrap());
    }

    #[tokio::test]
    async fn test_prune_mismatched_tool_calls() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        // A complete call/result pair.
        db.insert_message(&session, &Role::User, Some("hello"), None, None, None)
            .await
            .unwrap();
        db.insert_message(
            &session,
            &Role::Assistant,
            Some("Let me read that."),
            Some(r#"[{"id":"tc1","name":"Read","arguments":"{}"}]"#),
            None,
            None,
        )
        .await
        .unwrap();
        db.insert_message(
            &session,
            &Role::Tool,
            Some("file contents"),
            None,
            Some("tc1"),
            None,
        )
        .await
        .unwrap();

        // A call whose result never arrived.
        db.insert_message(
            &session,
            &Role::Assistant,
            Some("I'll edit the file."),
            Some(r#"[{"id":"tc2","name":"Edit","arguments":"{}"}]"#),
            None,
            None,
        )
        .await
        .unwrap();

        let msgs = db.load_context(&session).await.unwrap();

        let first_asst = msgs
            .iter()
            .find(|m| m.content.as_deref() == Some("Let me read that."))
            .unwrap();
        assert!(
            first_asst.tool_calls.is_some(),
            "completed tool_calls should be preserved"
        );

        let orphaned = msgs
            .iter()
            .find(|m| m.content.as_deref() == Some("I'll edit the file."));
        assert!(
            orphaned.is_none(),
            "orphaned assistant message should be dropped by prune_mismatched_tool_calls"
        );
    }

    #[test]
    fn test_prune_mismatched_tool_calls_unit() {
        fn msg(
            role: &str,
            content: Option<&str>,
            tool_calls: Option<&str>,
            tool_call_id: Option<&str>,
        ) -> Message {
            Message {
                id: 0,
                session_id: String::new(),
                role: role.parse().unwrap_or(Role::User),
                content: content.map(Into::into),
                tool_calls: tool_calls.map(Into::into),
                tool_call_id: tool_call_id.map(Into::into),
                prompt_tokens: None,
                completion_tokens: None,
                cache_read_tokens: None,
                cache_creation_tokens: None,
                thinking_tokens: None,
            }
        }

        // Empty input is a no-op.
        let mut empty: Vec<Message> = vec![];
        prune_mismatched_tool_calls(&mut empty);
        assert!(empty.is_empty());

        // No tool traffic at all.
        let mut msgs = vec![msg("user", Some("hi"), None, None)];
        prune_mismatched_tool_calls(&mut msgs);
        assert_eq!(msgs.len(), 1);

        // Unanswered call.
        let mut msgs = vec![
            msg("user", Some("hi"), None, None),
            msg(
                "assistant",
                Some("doing it"),
                Some(r#"[{"id":"t1"}]"#),
                None,
            ),
        ];
        prune_mismatched_tool_calls(&mut msgs);
        assert_eq!(msgs.len(), 1, "orphaned assistant should be dropped");
        assert_eq!(msgs[0].role, Role::User);

        // Complete pair.
        let mut msgs = vec![
            msg("user", Some("hi"), None, None),
            msg("assistant", None, Some(r#"[{"id":"t1"}]"#), None),
            msg("tool", Some("ok"), None, Some("t1")),
        ];
        prune_mismatched_tool_calls(&mut msgs);
        assert_eq!(msgs.len(), 3, "complete pair should be preserved");
        assert!(msgs[1].tool_calls.is_some());
    }

    #[tokio::test]
    async fn test_session_metadata_and_todo() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        // Nothing stored yet.
        assert!(db.get_todo(&session).await.unwrap().is_none());
        assert!(
            db.get_metadata(&session, "anything")
                .await
                .unwrap()
                .is_none()
        );

        db.set_todo(&session, "- [ ] Task 1\n- [x] Task 2")
            .await
            .unwrap();
        let todo = db.get_todo(&session).await.unwrap().unwrap();
        assert!(todo.contains("Task 1"));
        assert!(todo.contains("Task 2"));

        // Overwrite (upsert).
        db.set_todo(&session, "- [x] Task 1\n- [x] Task 2")
            .await
            .unwrap();
        let todo = db.get_todo(&session).await.unwrap().unwrap();
        assert!(todo.starts_with("- [x] Task 1"));

        db.set_metadata(&session, "custom_key", "custom_value")
            .await
            .unwrap();
        assert_eq!(
            db.get_metadata(&session, "custom_key")
                .await
                .unwrap()
                .unwrap(),
            "custom_value"
        );
    }

    #[tokio::test]
    async fn test_token_usage_empty_session() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        let u = db.session_token_usage(&session).await.unwrap();
        assert_eq!(u.prompt_tokens, 0);
        assert_eq!(u.completion_tokens, 0);
        assert_eq!(u.api_calls, 0);
    }

    #[tokio::test]
    async fn test_last_assistant_message() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        // No assistant message yet: empty string, not an error.
        let msg = db.last_assistant_message(&session).await.unwrap();
        assert_eq!(msg, "");

        db.insert_message(&session, &Role::User, Some("question 1"), None, None, None)
            .await
            .unwrap();
        db.insert_message(
            &session,
            &Role::Assistant,
            Some("answer 1"),
            None,
            None,
            None,
        )
        .await
        .unwrap();
        db.insert_message(&session, &Role::User, Some("question 2"), None, None, None)
            .await
            .unwrap();
        db.insert_message(
            &session,
            &Role::Assistant,
            Some("answer 2"),
            None,
            None,
            None,
        )
        .await
        .unwrap();

        let msg = db.last_assistant_message(&session).await.unwrap();
        assert_eq!(msg, "answer 2");
    }

    #[tokio::test]
    async fn test_last_assistant_message_skips_tool_calls() {
        let (db, tmp) = setup().await;
        let session = db.create_session("default", tmp.path()).await.unwrap();

        db.insert_message(
            &session,
            &Role::User,
            Some("do something"),
            None,
            None,
            None,
        )
        .await
        .unwrap();
        // Tool-call-only assistant turn (no content).
        db.insert_message(
            &session,
            &Role::Assistant,
            None,
            Some("[{\"id\":\"1\"}]"),
            None,
            None,
        )
        .await
        .unwrap();
        db.insert_message(
            &session,
            &Role::Tool,
            Some("tool result"),
            None,
            Some("1"),
            None,
        )
        .await
        .unwrap();
        db.insert_message(&session, &Role::Assistant, Some("Done!"), None, None, None)
            .await
            .unwrap();

        let msg = db.last_assistant_message(&session).await.unwrap();
        assert_eq!(msg, "Done!");
    }
}