1use anyhow::Result;
9use std::path::Path;
10
11use super::{Database, MessageRow, SessionInfoRow};
12use crate::persistence::{
13 CompactedStats, InterruptionKind, Message, Persistence, Role, SessionInfo, SessionUsage,
14};
15
16pub(crate) fn prune_mismatched_tool_calls(messages: &mut Vec<Message>) {
27 if messages.is_empty() {
28 return;
29 }
30
31 let mut tool_call_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
32 let mut tool_return_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
33
34 for msg in messages.iter() {
35 if msg.role == Role::Assistant {
36 if let Some(ref tc_json) = msg.tool_calls
37 && let Ok(calls) = serde_json::from_str::<Vec<serde_json::Value>>(tc_json)
38 {
39 for call in &calls {
40 if let Some(id) = call.get("id").and_then(|v| v.as_str()) {
41 tool_call_ids.insert(id.to_string());
42 }
43 }
44 }
45 } else if msg.role == Role::Tool
46 && let Some(ref id) = msg.tool_call_id
47 {
48 tool_return_ids.insert(id.clone());
49 }
50 }
51
52 let mismatched: std::collections::HashSet<&String> = tool_call_ids
53 .symmetric_difference(&tool_return_ids)
54 .collect();
55
56 if mismatched.is_empty() {
57 return;
58 }
59
60 messages.retain(|msg| {
61 if msg.role == Role::Tool
63 && let Some(ref id) = msg.tool_call_id
64 && mismatched.contains(id)
65 {
66 return false;
67 }
68 if msg.role == Role::Assistant
70 && let Some(ref tc_json) = msg.tool_calls
71 && let Ok(calls) = serde_json::from_str::<Vec<serde_json::Value>>(tc_json)
72 {
73 let has_mismatched = calls.iter().any(|call| {
74 call.get("id")
75 .and_then(|v| v.as_str())
76 .is_some_and(|id| mismatched.contains(&id.to_string()))
77 });
78 if has_mismatched {
79 return false;
80 }
81 }
82 true
83 });
84}
85
86pub(crate) fn prune_null_content_messages(messages: &mut Vec<Message>) {
97 messages.retain(|msg| {
98 if msg.role != Role::Assistant {
99 return true; }
101 let has_content = msg.content.as_deref().is_some_and(|c| !c.trim().is_empty());
102 let has_tool_calls = msg.tool_calls.is_some();
103 has_content || has_tool_calls
104 });
105}
106
107pub(crate) fn prune_whitespace_only_messages(messages: &mut Vec<Message>) {
114 messages.retain(|msg| {
115 if msg.role != Role::Assistant {
116 return true;
117 }
118 if msg.tool_calls.is_some() {
120 return true;
121 }
122 !matches!(msg.content.as_deref(), Some(c) if c.trim().is_empty())
124 });
125}
126
127pub fn detect_interruption(messages: &[Message]) -> Option<InterruptionKind> {
138 let last = messages.iter().rev().find(|m| m.role != Role::System)?;
139
140 match last.role {
141 Role::User => {
142 let preview = last
143 .content
144 .as_deref()
145 .unwrap_or("")
146 .chars()
147 .take(80)
148 .collect::<String>();
149 Some(InterruptionKind::Prompt(preview))
150 }
151 Role::Tool => Some(InterruptionKind::Tool),
152 _ => None,
153 }
154}
155
156#[async_trait::async_trait]
159impl Persistence for Database {
160 async fn create_session(&self, agent_name: &str, project_root: &Path) -> Result<String> {
162 let id = uuid::Uuid::new_v4().to_string();
163 let root = project_root.to_string_lossy().to_string();
164 sqlx::query("INSERT INTO sessions (id, agent_name, project_root) VALUES (?, ?, ?)")
165 .bind(&id)
166 .bind(agent_name)
167 .bind(&root)
168 .execute(&self.pool)
169 .await?;
170 tracing::info!("Created session: {id} (project: {root})");
171 Ok(id)
172 }
173
174 async fn insert_message(
176 &self,
177 session_id: &str,
178 role: &Role,
179 content: Option<&str>,
180 tool_calls: Option<&str>,
181 tool_call_id: Option<&str>,
182 usage: Option<&crate::providers::TokenUsage>,
183 ) -> Result<i64> {
184 self.insert_message_with_agent(
185 session_id,
186 role,
187 content,
188 tool_calls,
189 tool_call_id,
190 usage,
191 None,
192 )
193 .await
194 }
195
196 #[allow(clippy::too_many_arguments)]
198 async fn insert_message_with_agent(
199 &self,
200 session_id: &str,
201 role: &Role,
202 content: Option<&str>,
203 tool_calls: Option<&str>,
204 tool_call_id: Option<&str>,
205 usage: Option<&crate::providers::TokenUsage>,
206 agent_name: Option<&str>,
207 ) -> Result<i64> {
208 let result = sqlx::query(
209 "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, \
210 prompt_tokens, completion_tokens, cache_read_tokens, cache_creation_tokens, \
211 thinking_tokens, agent_name)
212 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
213 )
214 .bind(session_id)
215 .bind(role.as_str())
216 .bind(content)
217 .bind(tool_calls)
218 .bind(tool_call_id)
219 .bind(usage.map(|u| u.prompt_tokens))
220 .bind(usage.map(|u| u.completion_tokens))
221 .bind(usage.map(|u| u.cache_read_tokens))
222 .bind(usage.map(|u| u.cache_creation_tokens))
223 .bind(usage.map(|u| u.thinking_tokens))
224 .bind(agent_name)
225 .execute(&self.pool)
226 .await?;
227
228 sqlx::query("UPDATE sessions SET last_accessed_at = datetime('now') WHERE id = ?")
230 .bind(session_id)
231 .execute(&self.pool)
232 .await?;
233
234 Ok(result.last_insert_rowid())
235 }
236
237 async fn insert_tool_message_with_full(
238 &self,
239 session_id: &str,
240 content: &str,
241 tool_call_id: &str,
242 full_content: &str,
243 ) -> Result<i64> {
244 let result = sqlx::query(
245 "INSERT INTO messages (session_id, role, content, full_content, tool_call_id) \
246 VALUES (?, 'tool', ?, ?, ?)",
247 )
248 .bind(session_id)
249 .bind(content)
250 .bind(full_content)
251 .bind(tool_call_id)
252 .execute(&self.pool)
253 .await?;
254
255 sqlx::query("UPDATE sessions SET last_accessed_at = datetime('now') WHERE id = ?")
256 .bind(session_id)
257 .execute(&self.pool)
258 .await?;
259
260 Ok(result.last_insert_rowid())
261 }
262
263 async fn mark_message_complete(&self, message_id: i64) -> Result<()> {
264 sqlx::query("UPDATE messages SET completed_at = datetime('now') WHERE id = ?")
265 .bind(message_id)
266 .execute(&self.pool)
267 .await?;
268 Ok(())
269 }
270
271 async fn update_message_thinking_content(&self, message_id: i64, content: &str) -> Result<()> {
277 sqlx::query("UPDATE messages SET thinking_content = ? WHERE id = ?")
278 .bind(content)
279 .bind(message_id)
280 .execute(&self.pool)
281 .await?;
282 Ok(())
283 }
284
285 async fn load_context(&self, session_id: &str) -> Result<Vec<Message>> {
295 let mut messages: Vec<Message> = sqlx::query_as::<_, MessageRow>(
296 "SELECT id, session_id, role, content, full_content, tool_calls, tool_call_id,
297 prompt_tokens, completion_tokens,
298 cache_read_tokens, cache_creation_tokens, thinking_tokens, thinking_content,
299 created_at
300 FROM messages
301 WHERE session_id = ? AND compacted_at IS NULL
302 AND (role != 'assistant' OR completed_at IS NOT NULL)
303 ORDER BY id ASC",
304 )
305 .bind(session_id)
306 .fetch_all(&self.pool)
307 .await?
308 .into_iter()
309 .map(|r| r.into())
310 .collect();
311
312 prune_mismatched_tool_calls(&mut messages); prune_null_content_messages(&mut messages); prune_whitespace_only_messages(&mut messages); Ok(messages)
318 }
319
320 async fn load_all_messages(&self, session_id: &str) -> Result<Vec<Message>> {
323 let rows: Vec<Message> = sqlx::query_as::<_, MessageRow>(
324 "SELECT id, session_id, role, content, full_content, tool_calls, tool_call_id,
325 prompt_tokens, completion_tokens,
326 cache_read_tokens, cache_creation_tokens, thinking_tokens, thinking_content,
327 created_at
328 FROM messages
329 WHERE session_id = ?
330 ORDER BY id ASC",
331 )
332 .bind(session_id)
333 .fetch_all(&self.pool)
334 .await?
335 .into_iter()
336 .map(|r| r.into())
337 .collect();
338 Ok(rows)
339 }
340
341 async fn recent_user_messages(&self, limit: i64) -> Result<Vec<String>> {
344 let rows: Vec<(String,)> = sqlx::query_as(
345 "SELECT content FROM messages
346 WHERE role = 'user' AND content IS NOT NULL AND content != ''
347 ORDER BY id DESC LIMIT ?",
348 )
349 .bind(limit)
350 .fetch_all(&self.pool)
351 .await?;
352
353 Ok(rows.into_iter().map(|r| r.0).collect())
354 }
355
356 async fn session_token_usage(&self, session_id: &str) -> Result<SessionUsage> {
358 let row: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(
359 "SELECT
360 COALESCE(SUM(prompt_tokens), 0),
361 COALESCE(SUM(completion_tokens), 0),
362 COALESCE(SUM(cache_read_tokens), 0),
363 COALESCE(SUM(cache_creation_tokens), 0),
364 COALESCE(SUM(thinking_tokens), 0),
365 COUNT(*)
366 FROM messages
367 WHERE session_id = ?
368 AND (prompt_tokens IS NOT NULL OR completion_tokens IS NOT NULL)",
369 )
370 .bind(session_id)
371 .fetch_one(&self.pool)
372 .await?;
373 Ok(SessionUsage {
374 prompt_tokens: row.0,
375 completion_tokens: row.1,
376 cache_read_tokens: row.2,
377 cache_creation_tokens: row.3,
378 thinking_tokens: row.4,
379 api_calls: row.5,
380 })
381 }
382
383 async fn session_usage_by_agent(
385 &self,
386 session_id: &str,
387 ) -> Result<Vec<(String, SessionUsage)>> {
388 let rows: Vec<(String, i64, i64, i64, i64, i64, i64)> = sqlx::query_as(
389 "SELECT
390 COALESCE(agent_name, 'main'),
391 COALESCE(SUM(prompt_tokens), 0),
392 COALESCE(SUM(completion_tokens), 0),
393 COALESCE(SUM(cache_read_tokens), 0),
394 COALESCE(SUM(cache_creation_tokens), 0),
395 COALESCE(SUM(thinking_tokens), 0),
396 COUNT(*)
397 FROM messages
398 WHERE session_id = ?
399 AND (prompt_tokens IS NOT NULL OR completion_tokens IS NOT NULL)
400 GROUP BY COALESCE(agent_name, 'main')
401 ORDER BY COALESCE(SUM(prompt_tokens), 0) + COALESCE(SUM(completion_tokens), 0) DESC",
402 )
403 .bind(session_id)
404 .fetch_all(&self.pool)
405 .await?;
406 Ok(rows
407 .into_iter()
408 .map(|r| {
409 (
410 r.0,
411 SessionUsage {
412 prompt_tokens: r.1,
413 completion_tokens: r.2,
414 cache_read_tokens: r.3,
415 cache_creation_tokens: r.4,
416 thinking_tokens: r.5,
417 api_calls: r.6,
418 },
419 )
420 })
421 .collect())
422 }
423
424 async fn list_sessions(&self, limit: i64, project_root: &Path) -> Result<Vec<SessionInfo>> {
426 let root = project_root.to_string_lossy().to_string();
427 let rows: Vec<SessionInfoRow> = sqlx::query_as(
428 "SELECT s.id, s.agent_name, s.created_at,
429 COUNT(m.id) as message_count,
430 COALESCE(SUM(m.prompt_tokens), 0) + COALESCE(SUM(m.completion_tokens), 0) as total_tokens,
431 s.title, s.mode
432 FROM sessions s
433 LEFT JOIN messages m ON m.session_id = s.id
434 WHERE s.project_root = ? OR s.project_root IS NULL
435 GROUP BY s.id
436 ORDER BY s.created_at DESC, s.rowid DESC
437 LIMIT ?",
438 )
439 .bind(&root)
440 .bind(limit)
441 .fetch_all(&self.pool)
442 .await?;
443 Ok(rows.into_iter().map(|r| r.into()).collect())
444 }
445
446 async fn last_assistant_message(&self, session_id: &str) -> Result<String> {
448 let row: Option<(String,)> = sqlx::query_as(
449 "SELECT content FROM messages
450 WHERE session_id = ? AND role = 'assistant' AND content IS NOT NULL
451 ORDER BY id DESC LIMIT 1",
452 )
453 .bind(session_id)
454 .fetch_optional(&self.pool)
455 .await?;
456 Ok(row.map(|r| r.0).unwrap_or_default())
457 }
458
459 async fn last_user_message(&self, session_id: &str) -> Result<String> {
461 let row: Option<(String,)> = sqlx::query_as(
462 "SELECT content FROM messages
463 WHERE session_id = ? AND role = 'user' AND content IS NOT NULL
464 ORDER BY id DESC LIMIT 1",
465 )
466 .bind(session_id)
467 .fetch_optional(&self.pool)
468 .await?;
469 Ok(row.map(|r| r.0).unwrap_or_default())
470 }
471
472 async fn delete_session(&self, session_id: &str) -> Result<bool> {
474 let mut tx = self.pool.begin().await?;
475
476 sqlx::query("DELETE FROM messages WHERE session_id = ?")
477 .bind(session_id)
478 .execute(&mut *tx)
479 .await?;
480
481 sqlx::query("DELETE FROM session_metadata WHERE session_id = ?")
482 .bind(session_id)
483 .execute(&mut *tx)
484 .await?;
485
486 let result = sqlx::query("DELETE FROM sessions WHERE id = ?")
487 .bind(session_id)
488 .execute(&mut *tx)
489 .await?;
490
491 tx.commit().await?;
492
493 sqlx::query("PRAGMA incremental_vacuum")
495 .execute(&self.pool)
496 .await?;
497
498 Ok(result.rows_affected() > 0)
499 }
500
501 async fn set_session_title(&self, session_id: &str, title: &str) -> Result<()> {
502 sqlx::query("UPDATE sessions SET title = ? WHERE id = ?")
503 .bind(title)
504 .bind(session_id)
505 .execute(&self.pool)
506 .await?;
507 Ok(())
508 }
509
510 async fn set_session_mode(&self, session_id: &str, mode: &str) -> Result<()> {
511 sqlx::query("UPDATE sessions SET mode = ? WHERE id = ?")
512 .bind(mode)
513 .bind(session_id)
514 .execute(&self.pool)
515 .await?;
516 Ok(())
517 }
518
519 async fn get_session_mode(&self, session_id: &str) -> Result<Option<String>> {
520 let row: Option<(Option<String>,)> =
521 sqlx::query_as("SELECT mode FROM sessions WHERE id = ?")
522 .bind(session_id)
523 .fetch_optional(&self.pool)
524 .await?;
525 Ok(row.and_then(|r| r.0))
526 }
527
528 async fn get_session_idle_secs(&self, session_id: &str) -> Result<Option<i64>> {
529 let row: Option<(Option<i64>,)> = sqlx::query_as(
531 "SELECT CAST((julianday('now') - julianday(last_accessed_at)) * 86400 AS INTEGER)
532 FROM sessions WHERE id = ?",
533 )
534 .bind(session_id)
535 .fetch_optional(&self.pool)
536 .await?;
537 Ok(row.and_then(|r| r.0))
538 }
539
540 async fn compact_session(
548 &self,
549 session_id: &str,
550 summary: &str,
551 preserve_count: usize,
552 ) -> Result<usize> {
553 let mut tx = self.pool.begin().await?;
554
555 let all_ids: Vec<(i64,)> = sqlx::query_as(
557 "SELECT id FROM messages WHERE session_id = ? AND compacted_at IS NULL ORDER BY id ASC",
558 )
559 .bind(session_id)
560 .fetch_all(&mut *tx)
561 .await?;
562
563 let total = all_ids.len();
564 if total == 0 {
565 tx.commit().await?;
566 return Ok(0);
567 }
568
569 let keep_from = total.saturating_sub(preserve_count);
571 let ids_to_archive: Vec<i64> = all_ids[..keep_from].iter().map(|r| r.0).collect();
572 let archived_count = ids_to_archive.len();
573
574 if archived_count == 0 {
575 tx.commit().await?;
576 return Ok(0);
577 }
578
579 for chunk in ids_to_archive.chunks(500) {
581 let placeholders: String = chunk.iter().map(|_| "?").collect::<Vec<_>>().join(",");
582 let sql = format!(
583 "UPDATE messages SET compacted_at = datetime('now') \
584 WHERE session_id = ? AND id IN ({placeholders})"
585 );
586 let mut query = sqlx::query(&sql).bind(session_id);
587 for id in chunk {
588 query = query.bind(id);
589 }
590 query.execute(&mut *tx).await?;
591 }
592
593 sqlx::query(
595 "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, prompt_tokens, completion_tokens)
596 VALUES (?, 'system', ?, NULL, NULL, NULL, NULL)",
597 )
598 .bind(session_id)
599 .bind(summary)
600 .execute(&mut *tx)
601 .await?;
602
603 let continuation = "Your context was compacted. The previous message contains a summary of our earlier conversation. \
606 Do not mention the summary or that compaction occurred. \
607 Continue the conversation naturally based on the summarized context.";
608 sqlx::query(
609 "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, prompt_tokens, completion_tokens, completed_at)
610 VALUES (?, 'assistant', ?, NULL, NULL, NULL, NULL, datetime('now'))",
611 )
612 .bind(session_id)
613 .bind(continuation)
614 .execute(&mut *tx)
615 .await?;
616
617 tx.commit().await?;
618
619 Ok(archived_count)
620 }
621
622 async fn has_pending_tool_calls(&self, session_id: &str) -> Result<bool> {
625 let last_msg: Option<(String, Option<String>)> = sqlx::query_as(
626 "SELECT role, tool_calls FROM messages
627 WHERE session_id = ? AND compacted_at IS NULL
628 ORDER BY id DESC LIMIT 1",
629 )
630 .bind(session_id)
631 .fetch_optional(&self.pool)
632 .await?;
633
634 Ok(matches!(last_msg, Some((role, Some(_))) if role == "assistant"))
635 }
636
637 async fn clear_message_content(&self, message_ids: &[i64], stub: &str) -> Result<()> {
638 if message_ids.is_empty() {
639 return Ok(());
640 }
641 for chunk in message_ids.chunks(500) {
642 let placeholders: String = chunk.iter().map(|_| "?").collect::<Vec<_>>().join(",");
643 let sql = format!("UPDATE messages SET content = ? WHERE id IN ({placeholders})");
644 let mut query = sqlx::query(&sql).bind(stub);
645 for id in chunk {
646 query = query.bind(id);
647 }
648 query.execute(&self.pool).await?;
649 }
650 Ok(())
651 }
652
653 async fn compacted_stats(&self) -> Result<CompactedStats> {
655 let row: (i64, i64, i64, Option<String>) = sqlx::query_as(
656 "SELECT
657 COUNT(*),
658 COUNT(DISTINCT session_id),
659 COALESCE(SUM(LENGTH(content) + LENGTH(COALESCE(tool_calls,''))), 0),
660 MIN(compacted_at)
661 FROM messages
662 WHERE compacted_at IS NOT NULL",
663 )
664 .fetch_one(&self.pool)
665 .await?;
666
667 Ok(CompactedStats {
668 message_count: row.0,
669 session_count: row.1,
670 size_bytes: row.2,
671 oldest: row.3,
672 })
673 }
674
675 async fn purge_compacted(&self, min_age_days: u32) -> Result<usize> {
678 let result = if min_age_days == 0 {
679 sqlx::query("DELETE FROM messages WHERE compacted_at IS NOT NULL")
680 .execute(&self.pool)
681 .await?
682 } else {
683 sqlx::query(
684 "DELETE FROM messages
685 WHERE compacted_at IS NOT NULL
686 AND compacted_at < datetime('now', ?)",
687 )
688 .bind(format!("-{min_age_days} days"))
689 .execute(&self.pool)
690 .await?
691 };
692
693 let deleted = result.rows_affected() as usize;
694
695 sqlx::query("VACUUM").execute(&self.pool).await?;
697
698 tracing::info!("Purged {deleted} compacted messages (>{min_age_days} days old)");
699 Ok(deleted)
700 }
701
702 async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>> {
704 let row: Option<(String,)> =
705 sqlx::query_as("SELECT value FROM session_metadata WHERE session_id = ? AND key = ?")
706 .bind(session_id)
707 .bind(key)
708 .fetch_optional(&self.pool)
709 .await?;
710 Ok(row.map(|r| r.0))
711 }
712
713 async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> {
715 sqlx::query(
716 "INSERT INTO session_metadata (session_id, key, value, updated_at)
717 VALUES (?, ?, ?, CURRENT_TIMESTAMP)
718 ON CONFLICT(session_id, key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at",
719 )
720 .bind(session_id)
721 .bind(key)
722 .bind(value)
723 .execute(&self.pool)
724 .await?;
725 Ok(())
726 }
727
728 async fn get_todo(&self, session_id: &str) -> Result<Option<String>> {
730 self.get_metadata(session_id, "todo").await
731 }
732
733 async fn set_todo(&self, session_id: &str, content: &str) -> Result<()> {
735 self.set_metadata(session_id, "todo", content).await
736 }
737}