1use anyhow::Result;
9use std::path::Path;
10
11use super::{Database, MessageRow, SessionInfoRow};
12use crate::persistence::{
13 CompactedStats, InterruptionKind, Message, Persistence, Role, SessionInfo, SessionUsage,
14};
15
16pub(crate) fn prune_mismatched_tool_calls(messages: &mut Vec<Message>) {
27 if messages.is_empty() {
28 return;
29 }
30
31 let mut tool_call_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
32 let mut tool_return_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
33
34 for msg in messages.iter() {
35 if msg.role == Role::Assistant {
36 if let Some(ref tc_json) = msg.tool_calls
37 && let Ok(calls) = serde_json::from_str::<Vec<serde_json::Value>>(tc_json)
38 {
39 for call in &calls {
40 if let Some(id) = call.get("id").and_then(|v| v.as_str()) {
41 tool_call_ids.insert(id.to_string());
42 }
43 }
44 }
45 } else if msg.role == Role::Tool
46 && let Some(ref id) = msg.tool_call_id
47 {
48 tool_return_ids.insert(id.clone());
49 }
50 }
51
52 let mismatched: std::collections::HashSet<&String> = tool_call_ids
53 .symmetric_difference(&tool_return_ids)
54 .collect();
55
56 if mismatched.is_empty() {
57 return;
58 }
59
60 messages.retain(|msg| {
61 if msg.role == Role::Tool
63 && let Some(ref id) = msg.tool_call_id
64 && mismatched.contains(id)
65 {
66 return false;
67 }
68 if msg.role == Role::Assistant
70 && let Some(ref tc_json) = msg.tool_calls
71 && let Ok(calls) = serde_json::from_str::<Vec<serde_json::Value>>(tc_json)
72 {
73 let has_mismatched = calls.iter().any(|call| {
74 call.get("id")
75 .and_then(|v| v.as_str())
76 .is_some_and(|id| mismatched.contains(&id.to_string()))
77 });
78 if has_mismatched {
79 return false;
80 }
81 }
82 true
83 });
84}
85
86pub(crate) fn prune_null_content_messages(messages: &mut Vec<Message>) {
97 messages.retain(|msg| {
98 if msg.role != Role::Assistant {
99 return true; }
101 let has_content = msg.content.as_deref().is_some_and(|c| !c.trim().is_empty());
102 let has_tool_calls = msg.tool_calls.is_some();
103 has_content || has_tool_calls
104 });
105}
106
107pub(crate) fn prune_whitespace_only_messages(messages: &mut Vec<Message>) {
114 messages.retain(|msg| {
115 if msg.role != Role::Assistant {
116 return true;
117 }
118 if msg.tool_calls.is_some() {
120 return true;
121 }
122 !matches!(msg.content.as_deref(), Some(c) if c.trim().is_empty())
124 });
125}
126
127pub fn detect_interruption(messages: &[Message]) -> Option<InterruptionKind> {
138 let last = messages.iter().rev().find(|m| m.role != Role::System)?;
139
140 match last.role {
141 Role::User => {
142 let preview = last
143 .content
144 .as_deref()
145 .unwrap_or("")
146 .chars()
147 .take(80)
148 .collect::<String>();
149 Some(InterruptionKind::Prompt(preview))
150 }
151 Role::Tool => Some(InterruptionKind::Tool),
152 _ => None,
153 }
154}
155
156#[async_trait::async_trait]
159impl Persistence for Database {
160 async fn create_session(&self, agent_name: &str, project_root: &Path) -> Result<String> {
162 let id = uuid::Uuid::new_v4().to_string();
163 let root = project_root.to_string_lossy().to_string();
164 sqlx::query("INSERT INTO sessions (id, agent_name, project_root) VALUES (?, ?, ?)")
165 .bind(&id)
166 .bind(agent_name)
167 .bind(&root)
168 .execute(&self.pool)
169 .await?;
170 tracing::info!("Created session: {id} (project: {root})");
171 Ok(id)
172 }
173
174 async fn insert_message(
176 &self,
177 session_id: &str,
178 role: &Role,
179 content: Option<&str>,
180 tool_calls: Option<&str>,
181 tool_call_id: Option<&str>,
182 usage: Option<&crate::providers::TokenUsage>,
183 ) -> Result<i64> {
184 self.insert_message_with_agent(
185 session_id,
186 role,
187 content,
188 tool_calls,
189 tool_call_id,
190 usage,
191 None,
192 )
193 .await
194 }
195
196 #[allow(clippy::too_many_arguments)]
198 async fn insert_message_with_agent(
199 &self,
200 session_id: &str,
201 role: &Role,
202 content: Option<&str>,
203 tool_calls: Option<&str>,
204 tool_call_id: Option<&str>,
205 usage: Option<&crate::providers::TokenUsage>,
206 agent_name: Option<&str>,
207 ) -> Result<i64> {
208 let result = sqlx::query(
209 "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, \
210 prompt_tokens, completion_tokens, cache_read_tokens, cache_creation_tokens, \
211 thinking_tokens, agent_name)
212 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
213 )
214 .bind(session_id)
215 .bind(role.as_str())
216 .bind(content)
217 .bind(tool_calls)
218 .bind(tool_call_id)
219 .bind(usage.map(|u| u.prompt_tokens))
220 .bind(usage.map(|u| u.completion_tokens))
221 .bind(usage.map(|u| u.cache_read_tokens))
222 .bind(usage.map(|u| u.cache_creation_tokens))
223 .bind(usage.map(|u| u.thinking_tokens))
224 .bind(agent_name)
225 .execute(&self.pool)
226 .await?;
227
228 sqlx::query("UPDATE sessions SET last_accessed_at = datetime('now') WHERE id = ?")
230 .bind(session_id)
231 .execute(&self.pool)
232 .await?;
233
234 Ok(result.last_insert_rowid())
235 }
236
237 async fn insert_tool_message_with_full(
238 &self,
239 session_id: &str,
240 content: &str,
241 tool_call_id: &str,
242 full_content: &str,
243 ) -> Result<i64> {
244 let result = sqlx::query(
245 "INSERT INTO messages (session_id, role, content, full_content, tool_call_id) \
246 VALUES (?, 'tool', ?, ?, ?)",
247 )
248 .bind(session_id)
249 .bind(content)
250 .bind(full_content)
251 .bind(tool_call_id)
252 .execute(&self.pool)
253 .await?;
254
255 sqlx::query("UPDATE sessions SET last_accessed_at = datetime('now') WHERE id = ?")
256 .bind(session_id)
257 .execute(&self.pool)
258 .await?;
259
260 Ok(result.last_insert_rowid())
261 }
262
263 async fn mark_message_complete(&self, message_id: i64) -> Result<()> {
264 sqlx::query("UPDATE messages SET completed_at = datetime('now') WHERE id = ?")
265 .bind(message_id)
266 .execute(&self.pool)
267 .await?;
268 Ok(())
269 }
270
271 async fn copy_messages_into_session(
284 &self,
285 dst_session: &str,
286 messages: &[Message],
287 ) -> Result<()> {
288 if messages.is_empty() {
293 return Ok(());
294 }
295
296 let mut tx = self.pool.begin().await?;
297
298 for msg in messages {
299 sqlx::query(
307 "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, completed_at) \
308 VALUES (?, ?, ?, ?, ?, \
309 CASE WHEN ? = 'assistant' THEN datetime('now') ELSE NULL END)",
310 )
311 .bind(dst_session)
312 .bind(msg.role.as_str())
313 .bind(msg.content.as_deref())
314 .bind(msg.tool_calls.as_deref())
315 .bind(msg.tool_call_id.as_deref())
316 .bind(msg.role.as_str()) .execute(&mut *tx)
318 .await?;
319 }
320
321 sqlx::query("UPDATE sessions SET last_accessed_at = datetime('now') WHERE id = ?")
326 .bind(dst_session)
327 .execute(&mut *tx)
328 .await?;
329
330 tx.commit().await?;
331 Ok(())
332 }
333
334 async fn update_message_thinking_content(&self, message_id: i64, content: &str) -> Result<()> {
340 sqlx::query("UPDATE messages SET thinking_content = ? WHERE id = ?")
341 .bind(content)
342 .bind(message_id)
343 .execute(&self.pool)
344 .await?;
345 Ok(())
346 }
347
348 async fn load_context(&self, session_id: &str) -> Result<Vec<Message>> {
358 let mut messages: Vec<Message> = sqlx::query_as::<_, MessageRow>(
359 "SELECT id, session_id, role, content, full_content, tool_calls, tool_call_id,
360 prompt_tokens, completion_tokens,
361 cache_read_tokens, cache_creation_tokens, thinking_tokens, thinking_content,
362 created_at
363 FROM messages
364 WHERE session_id = ? AND compacted_at IS NULL
365 AND (role != 'assistant' OR completed_at IS NOT NULL)
366 ORDER BY id ASC",
367 )
368 .bind(session_id)
369 .fetch_all(&self.pool)
370 .await?
371 .into_iter()
372 .map(|r| r.into())
373 .collect();
374
375 prune_mismatched_tool_calls(&mut messages); prune_null_content_messages(&mut messages); prune_whitespace_only_messages(&mut messages); Ok(messages)
381 }
382
383 async fn load_all_messages(&self, session_id: &str) -> Result<Vec<Message>> {
386 let rows: Vec<Message> = sqlx::query_as::<_, MessageRow>(
387 "SELECT id, session_id, role, content, full_content, tool_calls, tool_call_id,
388 prompt_tokens, completion_tokens,
389 cache_read_tokens, cache_creation_tokens, thinking_tokens, thinking_content,
390 created_at
391 FROM messages
392 WHERE session_id = ?
393 ORDER BY id ASC",
394 )
395 .bind(session_id)
396 .fetch_all(&self.pool)
397 .await?
398 .into_iter()
399 .map(|r| r.into())
400 .collect();
401 Ok(rows)
402 }
403
404 async fn recent_user_messages(&self, limit: i64) -> Result<Vec<String>> {
407 let rows: Vec<(String,)> = sqlx::query_as(
408 "SELECT content FROM messages
409 WHERE role = 'user' AND content IS NOT NULL AND content != ''
410 ORDER BY id DESC LIMIT ?",
411 )
412 .bind(limit)
413 .fetch_all(&self.pool)
414 .await?;
415
416 Ok(rows.into_iter().map(|r| r.0).collect())
417 }
418
419 async fn session_token_usage(&self, session_id: &str) -> Result<SessionUsage> {
421 let row: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(
422 "SELECT
423 COALESCE(SUM(prompt_tokens), 0),
424 COALESCE(SUM(completion_tokens), 0),
425 COALESCE(SUM(cache_read_tokens), 0),
426 COALESCE(SUM(cache_creation_tokens), 0),
427 COALESCE(SUM(thinking_tokens), 0),
428 COUNT(*)
429 FROM messages
430 WHERE session_id = ?
431 AND (prompt_tokens IS NOT NULL OR completion_tokens IS NOT NULL)",
432 )
433 .bind(session_id)
434 .fetch_one(&self.pool)
435 .await?;
436 Ok(SessionUsage {
437 prompt_tokens: row.0,
438 completion_tokens: row.1,
439 cache_read_tokens: row.2,
440 cache_creation_tokens: row.3,
441 thinking_tokens: row.4,
442 api_calls: row.5,
443 })
444 }
445
446 async fn session_usage_by_agent(
448 &self,
449 session_id: &str,
450 ) -> Result<Vec<(String, SessionUsage)>> {
451 let rows: Vec<(String, i64, i64, i64, i64, i64, i64)> = sqlx::query_as(
452 "SELECT
453 COALESCE(agent_name, 'main'),
454 COALESCE(SUM(prompt_tokens), 0),
455 COALESCE(SUM(completion_tokens), 0),
456 COALESCE(SUM(cache_read_tokens), 0),
457 COALESCE(SUM(cache_creation_tokens), 0),
458 COALESCE(SUM(thinking_tokens), 0),
459 COUNT(*)
460 FROM messages
461 WHERE session_id = ?
462 AND (prompt_tokens IS NOT NULL OR completion_tokens IS NOT NULL)
463 GROUP BY COALESCE(agent_name, 'main')
464 ORDER BY COALESCE(SUM(prompt_tokens), 0) + COALESCE(SUM(completion_tokens), 0) DESC",
465 )
466 .bind(session_id)
467 .fetch_all(&self.pool)
468 .await?;
469 Ok(rows
470 .into_iter()
471 .map(|r| {
472 (
473 r.0,
474 SessionUsage {
475 prompt_tokens: r.1,
476 completion_tokens: r.2,
477 cache_read_tokens: r.3,
478 cache_creation_tokens: r.4,
479 thinking_tokens: r.5,
480 api_calls: r.6,
481 },
482 )
483 })
484 .collect())
485 }
486
487 async fn list_sessions(&self, limit: i64, project_root: &Path) -> Result<Vec<SessionInfo>> {
489 let root = project_root.to_string_lossy().to_string();
490 let rows: Vec<SessionInfoRow> = sqlx::query_as(
491 "SELECT s.id, s.agent_name, s.created_at,
492 COUNT(m.id) as message_count,
493 COALESCE(SUM(m.prompt_tokens), 0) + COALESCE(SUM(m.completion_tokens), 0) as total_tokens,
494 s.title, s.mode
495 FROM sessions s
496 LEFT JOIN messages m ON m.session_id = s.id
497 WHERE s.project_root = ? OR s.project_root IS NULL
498 GROUP BY s.id
499 ORDER BY s.created_at DESC, s.rowid DESC
500 LIMIT ?",
501 )
502 .bind(&root)
503 .bind(limit)
504 .fetch_all(&self.pool)
505 .await?;
506 Ok(rows.into_iter().map(|r| r.into()).collect())
507 }
508
509 async fn last_assistant_message(&self, session_id: &str) -> Result<String> {
511 let row: Option<(String,)> = sqlx::query_as(
512 "SELECT content FROM messages
513 WHERE session_id = ? AND role = 'assistant' AND content IS NOT NULL
514 ORDER BY id DESC LIMIT 1",
515 )
516 .bind(session_id)
517 .fetch_optional(&self.pool)
518 .await?;
519 Ok(row.map(|r| r.0).unwrap_or_default())
520 }
521
522 async fn last_user_message(&self, session_id: &str) -> Result<String> {
524 let row: Option<(String,)> = sqlx::query_as(
525 "SELECT content FROM messages
526 WHERE session_id = ? AND role = 'user' AND content IS NOT NULL
527 ORDER BY id DESC LIMIT 1",
528 )
529 .bind(session_id)
530 .fetch_optional(&self.pool)
531 .await?;
532 Ok(row.map(|r| r.0).unwrap_or_default())
533 }
534
535 async fn delete_session(&self, session_id: &str) -> Result<bool> {
537 let mut tx = self.pool.begin().await?;
538
539 sqlx::query("DELETE FROM messages WHERE session_id = ?")
540 .bind(session_id)
541 .execute(&mut *tx)
542 .await?;
543
544 sqlx::query("DELETE FROM session_metadata WHERE session_id = ?")
545 .bind(session_id)
546 .execute(&mut *tx)
547 .await?;
548
549 let result = sqlx::query("DELETE FROM sessions WHERE id = ?")
550 .bind(session_id)
551 .execute(&mut *tx)
552 .await?;
553
554 tx.commit().await?;
555
556 sqlx::query("PRAGMA incremental_vacuum")
558 .execute(&self.pool)
559 .await?;
560
561 Ok(result.rows_affected() > 0)
562 }
563
564 async fn set_session_title(&self, session_id: &str, title: &str) -> Result<()> {
565 sqlx::query("UPDATE sessions SET title = ? WHERE id = ?")
566 .bind(title)
567 .bind(session_id)
568 .execute(&self.pool)
569 .await?;
570 Ok(())
571 }
572
573 async fn set_session_mode(&self, session_id: &str, mode: &str) -> Result<()> {
574 sqlx::query("UPDATE sessions SET mode = ? WHERE id = ?")
575 .bind(mode)
576 .bind(session_id)
577 .execute(&self.pool)
578 .await?;
579 Ok(())
580 }
581
582 async fn get_session_mode(&self, session_id: &str) -> Result<Option<String>> {
583 let row: Option<(Option<String>,)> =
584 sqlx::query_as("SELECT mode FROM sessions WHERE id = ?")
585 .bind(session_id)
586 .fetch_optional(&self.pool)
587 .await?;
588 Ok(row.and_then(|r| r.0))
589 }
590
591 async fn get_session_idle_secs(&self, session_id: &str) -> Result<Option<i64>> {
592 let row: Option<(Option<i64>,)> = sqlx::query_as(
594 "SELECT CAST((julianday('now') - julianday(last_accessed_at)) * 86400 AS INTEGER)
595 FROM sessions WHERE id = ?",
596 )
597 .bind(session_id)
598 .fetch_optional(&self.pool)
599 .await?;
600 Ok(row.and_then(|r| r.0))
601 }
602
603 async fn compact_session(
611 &self,
612 session_id: &str,
613 summary: &str,
614 preserve_count: usize,
615 ) -> Result<usize> {
616 let mut tx = self.pool.begin().await?;
617
618 let all_ids: Vec<(i64,)> = sqlx::query_as(
620 "SELECT id FROM messages WHERE session_id = ? AND compacted_at IS NULL ORDER BY id ASC",
621 )
622 .bind(session_id)
623 .fetch_all(&mut *tx)
624 .await?;
625
626 let total = all_ids.len();
627 if total == 0 {
628 tx.commit().await?;
629 return Ok(0);
630 }
631
632 let keep_from = total.saturating_sub(preserve_count);
634 let ids_to_archive: Vec<i64> = all_ids[..keep_from].iter().map(|r| r.0).collect();
635 let archived_count = ids_to_archive.len();
636
637 if archived_count == 0 {
638 tx.commit().await?;
639 return Ok(0);
640 }
641
642 for chunk in ids_to_archive.chunks(500) {
644 let placeholders: String = chunk.iter().map(|_| "?").collect::<Vec<_>>().join(",");
645 let sql = format!(
646 "UPDATE messages SET compacted_at = datetime('now') \
647 WHERE session_id = ? AND id IN ({placeholders})"
648 );
649 let mut query = sqlx::query(&sql).bind(session_id);
650 for id in chunk {
651 query = query.bind(id);
652 }
653 query.execute(&mut *tx).await?;
654 }
655
656 sqlx::query(
658 "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, prompt_tokens, completion_tokens)
659 VALUES (?, 'system', ?, NULL, NULL, NULL, NULL)",
660 )
661 .bind(session_id)
662 .bind(summary)
663 .execute(&mut *tx)
664 .await?;
665
666 let continuation = "Your context was compacted. The previous message contains a summary of our earlier conversation. \
669 Do not mention the summary or that compaction occurred. \
670 Continue the conversation naturally based on the summarized context.";
671 sqlx::query(
672 "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, prompt_tokens, completion_tokens, completed_at)
673 VALUES (?, 'assistant', ?, NULL, NULL, NULL, NULL, datetime('now'))",
674 )
675 .bind(session_id)
676 .bind(continuation)
677 .execute(&mut *tx)
678 .await?;
679
680 tx.commit().await?;
681
682 Ok(archived_count)
683 }
684
685 async fn has_pending_tool_calls(&self, session_id: &str) -> Result<bool> {
688 let last_msg: Option<(String, Option<String>)> = sqlx::query_as(
689 "SELECT role, tool_calls FROM messages
690 WHERE session_id = ? AND compacted_at IS NULL
691 ORDER BY id DESC LIMIT 1",
692 )
693 .bind(session_id)
694 .fetch_optional(&self.pool)
695 .await?;
696
697 Ok(matches!(last_msg, Some((role, Some(_))) if role == "assistant"))
698 }
699
700 async fn clear_message_content(&self, message_ids: &[i64], stub: &str) -> Result<()> {
701 if message_ids.is_empty() {
702 return Ok(());
703 }
704 for chunk in message_ids.chunks(500) {
705 let placeholders: String = chunk.iter().map(|_| "?").collect::<Vec<_>>().join(",");
706 let sql = format!("UPDATE messages SET content = ? WHERE id IN ({placeholders})");
707 let mut query = sqlx::query(&sql).bind(stub);
708 for id in chunk {
709 query = query.bind(id);
710 }
711 query.execute(&self.pool).await?;
712 }
713 Ok(())
714 }
715
716 async fn compacted_stats(&self) -> Result<CompactedStats> {
718 let row: (i64, i64, i64, Option<String>) = sqlx::query_as(
719 "SELECT
720 COUNT(*),
721 COUNT(DISTINCT session_id),
722 COALESCE(SUM(LENGTH(content) + LENGTH(COALESCE(tool_calls,''))), 0),
723 MIN(compacted_at)
724 FROM messages
725 WHERE compacted_at IS NOT NULL",
726 )
727 .fetch_one(&self.pool)
728 .await?;
729
730 Ok(CompactedStats {
731 message_count: row.0,
732 session_count: row.1,
733 size_bytes: row.2,
734 oldest: row.3,
735 })
736 }
737
738 async fn purge_compacted(&self, min_age_days: u32) -> Result<usize> {
741 let result = if min_age_days == 0 {
742 sqlx::query("DELETE FROM messages WHERE compacted_at IS NOT NULL")
743 .execute(&self.pool)
744 .await?
745 } else {
746 sqlx::query(
747 "DELETE FROM messages
748 WHERE compacted_at IS NOT NULL
749 AND compacted_at < datetime('now', ?)",
750 )
751 .bind(format!("-{min_age_days} days"))
752 .execute(&self.pool)
753 .await?
754 };
755
756 let deleted = result.rows_affected() as usize;
757
758 sqlx::query("VACUUM").execute(&self.pool).await?;
760
761 tracing::info!("Purged {deleted} compacted messages (>{min_age_days} days old)");
762 Ok(deleted)
763 }
764
765 async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>> {
767 let row: Option<(String,)> =
768 sqlx::query_as("SELECT value FROM session_metadata WHERE session_id = ? AND key = ?")
769 .bind(session_id)
770 .bind(key)
771 .fetch_optional(&self.pool)
772 .await?;
773 Ok(row.map(|r| r.0))
774 }
775
776 async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> {
778 sqlx::query(
779 "INSERT INTO session_metadata (session_id, key, value, updated_at)
780 VALUES (?, ?, ?, CURRENT_TIMESTAMP)
781 ON CONFLICT(session_id, key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at",
782 )
783 .bind(session_id)
784 .bind(key)
785 .bind(value)
786 .execute(&self.pool)
787 .await?;
788 Ok(())
789 }
790
791 async fn get_todo(&self, session_id: &str) -> Result<Option<String>> {
793 self.get_metadata(session_id, "todo").await
794 }
795
796 async fn set_todo(&self, session_id: &str, content: &str) -> Result<()> {
798 self.set_metadata(session_id, "todo", content).await
799 }
800}