1use anyhow::Result;
7use std::path::Path;
8
9use super::{Database, MessageRow, SessionInfoRow};
10use crate::persistence::{
11 CompactedStats, InterruptionKind, Message, Persistence, Role, SessionInfo, SessionUsage,
12};
13
14pub(crate) fn prune_mismatched_tool_calls(messages: &mut Vec<Message>) {
25 if messages.is_empty() {
26 return;
27 }
28
29 let mut tool_call_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
30 let mut tool_return_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
31
32 for msg in messages.iter() {
33 if msg.role == Role::Assistant {
34 if let Some(ref tc_json) = msg.tool_calls
35 && let Ok(calls) = serde_json::from_str::<Vec<serde_json::Value>>(tc_json)
36 {
37 for call in &calls {
38 if let Some(id) = call.get("id").and_then(|v| v.as_str()) {
39 tool_call_ids.insert(id.to_string());
40 }
41 }
42 }
43 } else if msg.role == Role::Tool
44 && let Some(ref id) = msg.tool_call_id
45 {
46 tool_return_ids.insert(id.clone());
47 }
48 }
49
50 let mismatched: std::collections::HashSet<&String> = tool_call_ids
51 .symmetric_difference(&tool_return_ids)
52 .collect();
53
54 if mismatched.is_empty() {
55 return;
56 }
57
58 messages.retain(|msg| {
59 if msg.role == Role::Tool
61 && let Some(ref id) = msg.tool_call_id
62 && mismatched.contains(id)
63 {
64 return false;
65 }
66 if msg.role == Role::Assistant
68 && let Some(ref tc_json) = msg.tool_calls
69 && let Ok(calls) = serde_json::from_str::<Vec<serde_json::Value>>(tc_json)
70 {
71 let has_mismatched = calls.iter().any(|call| {
72 call.get("id")
73 .and_then(|v| v.as_str())
74 .is_some_and(|id| mismatched.contains(&id.to_string()))
75 });
76 if has_mismatched {
77 return false;
78 }
79 }
80 true
81 });
82}
83
84pub(crate) fn prune_null_content_messages(messages: &mut Vec<Message>) {
95 messages.retain(|msg| {
96 if msg.role != Role::Assistant {
97 return true; }
99 let has_content = msg.content.as_deref().is_some_and(|c| !c.trim().is_empty());
100 let has_tool_calls = msg.tool_calls.is_some();
101 has_content || has_tool_calls
102 });
103}
104
105pub(crate) fn prune_whitespace_only_messages(messages: &mut Vec<Message>) {
112 messages.retain(|msg| {
113 if msg.role != Role::Assistant {
114 return true;
115 }
116 if msg.tool_calls.is_some() {
118 return true;
119 }
120 !matches!(msg.content.as_deref(), Some(c) if c.trim().is_empty())
122 });
123}
124
125pub fn detect_interruption(messages: &[Message]) -> Option<InterruptionKind> {
136 let last = messages.iter().rev().find(|m| m.role != Role::System)?;
137
138 match last.role {
139 Role::User => {
140 let preview = last
141 .content
142 .as_deref()
143 .unwrap_or("")
144 .chars()
145 .take(80)
146 .collect::<String>();
147 Some(InterruptionKind::Prompt(preview))
148 }
149 Role::Tool => Some(InterruptionKind::Tool),
150 _ => None,
151 }
152}
153
154#[async_trait::async_trait]
157impl Persistence for Database {
158 async fn create_session(&self, agent_name: &str, project_root: &Path) -> Result<String> {
160 let id = uuid::Uuid::new_v4().to_string();
161 let root = project_root.to_string_lossy().to_string();
162 sqlx::query("INSERT INTO sessions (id, agent_name, project_root) VALUES (?, ?, ?)")
163 .bind(&id)
164 .bind(agent_name)
165 .bind(&root)
166 .execute(&self.pool)
167 .await?;
168 tracing::info!("Created session: {id} (project: {root})");
169 Ok(id)
170 }
171
172 async fn insert_message(
174 &self,
175 session_id: &str,
176 role: &Role,
177 content: Option<&str>,
178 tool_calls: Option<&str>,
179 tool_call_id: Option<&str>,
180 usage: Option<&crate::providers::TokenUsage>,
181 ) -> Result<i64> {
182 self.insert_message_with_agent(
183 session_id,
184 role,
185 content,
186 tool_calls,
187 tool_call_id,
188 usage,
189 None,
190 )
191 .await
192 }
193
194 #[allow(clippy::too_many_arguments)]
196 async fn insert_message_with_agent(
197 &self,
198 session_id: &str,
199 role: &Role,
200 content: Option<&str>,
201 tool_calls: Option<&str>,
202 tool_call_id: Option<&str>,
203 usage: Option<&crate::providers::TokenUsage>,
204 agent_name: Option<&str>,
205 ) -> Result<i64> {
206 let result = sqlx::query(
207 "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, \
208 prompt_tokens, completion_tokens, cache_read_tokens, cache_creation_tokens, \
209 thinking_tokens, agent_name)
210 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
211 )
212 .bind(session_id)
213 .bind(role.as_str())
214 .bind(content)
215 .bind(tool_calls)
216 .bind(tool_call_id)
217 .bind(usage.map(|u| u.prompt_tokens))
218 .bind(usage.map(|u| u.completion_tokens))
219 .bind(usage.map(|u| u.cache_read_tokens))
220 .bind(usage.map(|u| u.cache_creation_tokens))
221 .bind(usage.map(|u| u.thinking_tokens))
222 .bind(agent_name)
223 .execute(&self.pool)
224 .await?;
225
226 sqlx::query("UPDATE sessions SET last_accessed_at = datetime('now') WHERE id = ?")
228 .bind(session_id)
229 .execute(&self.pool)
230 .await?;
231
232 Ok(result.last_insert_rowid())
233 }
234
235 async fn insert_tool_message_with_full(
236 &self,
237 session_id: &str,
238 content: &str,
239 tool_call_id: &str,
240 full_content: &str,
241 ) -> Result<i64> {
242 let result = sqlx::query(
243 "INSERT INTO messages (session_id, role, content, full_content, tool_call_id) \
244 VALUES (?, 'tool', ?, ?, ?)",
245 )
246 .bind(session_id)
247 .bind(content)
248 .bind(full_content)
249 .bind(tool_call_id)
250 .execute(&self.pool)
251 .await?;
252
253 sqlx::query("UPDATE sessions SET last_accessed_at = datetime('now') WHERE id = ?")
254 .bind(session_id)
255 .execute(&self.pool)
256 .await?;
257
258 Ok(result.last_insert_rowid())
259 }
260
261 async fn mark_message_complete(&self, message_id: i64) -> Result<()> {
262 sqlx::query("UPDATE messages SET completed_at = datetime('now') WHERE id = ?")
263 .bind(message_id)
264 .execute(&self.pool)
265 .await?;
266 Ok(())
267 }
268
269 async fn load_context(&self, session_id: &str) -> Result<Vec<Message>> {
278 let mut messages: Vec<Message> = sqlx::query_as::<_, MessageRow>(
279 "SELECT id, session_id, role, content, full_content, tool_calls, tool_call_id,
280 prompt_tokens, completion_tokens,
281 cache_read_tokens, cache_creation_tokens, thinking_tokens,
282 created_at
283 FROM messages
284 WHERE session_id = ? AND compacted_at IS NULL
285 ORDER BY id ASC",
286 )
287 .bind(session_id)
288 .fetch_all(&self.pool)
289 .await?
290 .into_iter()
291 .map(|r| r.into())
292 .collect();
293
294 prune_mismatched_tool_calls(&mut messages); prune_null_content_messages(&mut messages); prune_whitespace_only_messages(&mut messages); Ok(messages)
300 }
301
302 async fn load_all_messages(&self, session_id: &str) -> Result<Vec<Message>> {
305 let rows: Vec<Message> = sqlx::query_as::<_, MessageRow>(
306 "SELECT id, session_id, role, content, full_content, tool_calls, tool_call_id,
307 prompt_tokens, completion_tokens,
308 cache_read_tokens, cache_creation_tokens, thinking_tokens,
309 created_at
310 FROM messages
311 WHERE session_id = ?
312 ORDER BY id ASC",
313 )
314 .bind(session_id)
315 .fetch_all(&self.pool)
316 .await?
317 .into_iter()
318 .map(|r| r.into())
319 .collect();
320 Ok(rows)
321 }
322
323 async fn recent_user_messages(&self, limit: i64) -> Result<Vec<String>> {
326 let rows: Vec<(String,)> = sqlx::query_as(
327 "SELECT content FROM messages
328 WHERE role = 'user' AND content IS NOT NULL AND content != ''
329 ORDER BY id DESC LIMIT ?",
330 )
331 .bind(limit)
332 .fetch_all(&self.pool)
333 .await?;
334
335 Ok(rows.into_iter().map(|r| r.0).collect())
336 }
337
338 async fn session_token_usage(&self, session_id: &str) -> Result<SessionUsage> {
340 let row: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(
341 "SELECT
342 COALESCE(SUM(prompt_tokens), 0),
343 COALESCE(SUM(completion_tokens), 0),
344 COALESCE(SUM(cache_read_tokens), 0),
345 COALESCE(SUM(cache_creation_tokens), 0),
346 COALESCE(SUM(thinking_tokens), 0),
347 COUNT(*)
348 FROM messages
349 WHERE session_id = ?
350 AND (prompt_tokens IS NOT NULL OR completion_tokens IS NOT NULL)",
351 )
352 .bind(session_id)
353 .fetch_one(&self.pool)
354 .await?;
355 Ok(SessionUsage {
356 prompt_tokens: row.0,
357 completion_tokens: row.1,
358 cache_read_tokens: row.2,
359 cache_creation_tokens: row.3,
360 thinking_tokens: row.4,
361 api_calls: row.5,
362 })
363 }
364
365 async fn session_usage_by_agent(
367 &self,
368 session_id: &str,
369 ) -> Result<Vec<(String, SessionUsage)>> {
370 let rows: Vec<(String, i64, i64, i64, i64, i64, i64)> = sqlx::query_as(
371 "SELECT
372 COALESCE(agent_name, 'main'),
373 COALESCE(SUM(prompt_tokens), 0),
374 COALESCE(SUM(completion_tokens), 0),
375 COALESCE(SUM(cache_read_tokens), 0),
376 COALESCE(SUM(cache_creation_tokens), 0),
377 COALESCE(SUM(thinking_tokens), 0),
378 COUNT(*)
379 FROM messages
380 WHERE session_id = ?
381 AND (prompt_tokens IS NOT NULL OR completion_tokens IS NOT NULL)
382 GROUP BY COALESCE(agent_name, 'main')
383 ORDER BY COALESCE(SUM(prompt_tokens), 0) + COALESCE(SUM(completion_tokens), 0) DESC",
384 )
385 .bind(session_id)
386 .fetch_all(&self.pool)
387 .await?;
388 Ok(rows
389 .into_iter()
390 .map(|r| {
391 (
392 r.0,
393 SessionUsage {
394 prompt_tokens: r.1,
395 completion_tokens: r.2,
396 cache_read_tokens: r.3,
397 cache_creation_tokens: r.4,
398 thinking_tokens: r.5,
399 api_calls: r.6,
400 },
401 )
402 })
403 .collect())
404 }
405
406 async fn list_sessions(&self, limit: i64, project_root: &Path) -> Result<Vec<SessionInfo>> {
408 let root = project_root.to_string_lossy().to_string();
409 let rows: Vec<SessionInfoRow> = sqlx::query_as(
410 "SELECT s.id, s.agent_name, s.created_at,
411 COUNT(m.id) as message_count,
412 COALESCE(SUM(m.prompt_tokens), 0) + COALESCE(SUM(m.completion_tokens), 0) as total_tokens,
413 s.title, s.mode
414 FROM sessions s
415 LEFT JOIN messages m ON m.session_id = s.id
416 WHERE s.project_root = ? OR s.project_root IS NULL
417 GROUP BY s.id
418 ORDER BY s.created_at DESC, s.rowid DESC
419 LIMIT ?",
420 )
421 .bind(&root)
422 .bind(limit)
423 .fetch_all(&self.pool)
424 .await?;
425 Ok(rows.into_iter().map(|r| r.into()).collect())
426 }
427
428 async fn last_assistant_message(&self, session_id: &str) -> Result<String> {
430 let row: Option<(String,)> = sqlx::query_as(
431 "SELECT content FROM messages
432 WHERE session_id = ? AND role = 'assistant' AND content IS NOT NULL
433 ORDER BY id DESC LIMIT 1",
434 )
435 .bind(session_id)
436 .fetch_optional(&self.pool)
437 .await?;
438 Ok(row.map(|r| r.0).unwrap_or_default())
439 }
440
441 async fn last_user_message(&self, session_id: &str) -> Result<String> {
443 let row: Option<(String,)> = sqlx::query_as(
444 "SELECT content FROM messages
445 WHERE session_id = ? AND role = 'user' AND content IS NOT NULL
446 ORDER BY id DESC LIMIT 1",
447 )
448 .bind(session_id)
449 .fetch_optional(&self.pool)
450 .await?;
451 Ok(row.map(|r| r.0).unwrap_or_default())
452 }
453
454 async fn delete_session(&self, session_id: &str) -> Result<bool> {
456 let mut tx = self.pool.begin().await?;
457
458 sqlx::query("DELETE FROM messages WHERE session_id = ?")
459 .bind(session_id)
460 .execute(&mut *tx)
461 .await?;
462
463 sqlx::query("DELETE FROM session_metadata WHERE session_id = ?")
464 .bind(session_id)
465 .execute(&mut *tx)
466 .await?;
467
468 let result = sqlx::query("DELETE FROM sessions WHERE id = ?")
469 .bind(session_id)
470 .execute(&mut *tx)
471 .await?;
472
473 tx.commit().await?;
474
475 sqlx::query("PRAGMA incremental_vacuum")
477 .execute(&self.pool)
478 .await?;
479
480 Ok(result.rows_affected() > 0)
481 }
482
483 async fn set_session_title(&self, session_id: &str, title: &str) -> Result<()> {
484 sqlx::query("UPDATE sessions SET title = ? WHERE id = ?")
485 .bind(title)
486 .bind(session_id)
487 .execute(&self.pool)
488 .await?;
489 Ok(())
490 }
491
492 async fn set_session_mode(&self, session_id: &str, mode: &str) -> Result<()> {
493 sqlx::query("UPDATE sessions SET mode = ? WHERE id = ?")
494 .bind(mode)
495 .bind(session_id)
496 .execute(&self.pool)
497 .await?;
498 Ok(())
499 }
500
501 async fn get_session_mode(&self, session_id: &str) -> Result<Option<String>> {
502 let row: Option<(Option<String>,)> =
503 sqlx::query_as("SELECT mode FROM sessions WHERE id = ?")
504 .bind(session_id)
505 .fetch_optional(&self.pool)
506 .await?;
507 Ok(row.and_then(|r| r.0))
508 }
509
510 async fn get_session_idle_secs(&self, session_id: &str) -> Result<Option<i64>> {
511 let row: Option<(Option<i64>,)> = sqlx::query_as(
513 "SELECT CAST((julianday('now') - julianday(last_accessed_at)) * 86400 AS INTEGER)
514 FROM sessions WHERE id = ?",
515 )
516 .bind(session_id)
517 .fetch_optional(&self.pool)
518 .await?;
519 Ok(row.and_then(|r| r.0))
520 }
521
522 async fn compact_session(
530 &self,
531 session_id: &str,
532 summary: &str,
533 preserve_count: usize,
534 ) -> Result<usize> {
535 let mut tx = self.pool.begin().await?;
536
537 let all_ids: Vec<(i64,)> = sqlx::query_as(
539 "SELECT id FROM messages WHERE session_id = ? AND compacted_at IS NULL ORDER BY id ASC",
540 )
541 .bind(session_id)
542 .fetch_all(&mut *tx)
543 .await?;
544
545 let total = all_ids.len();
546 if total == 0 {
547 tx.commit().await?;
548 return Ok(0);
549 }
550
551 let keep_from = total.saturating_sub(preserve_count);
553 let ids_to_archive: Vec<i64> = all_ids[..keep_from].iter().map(|r| r.0).collect();
554 let archived_count = ids_to_archive.len();
555
556 if archived_count == 0 {
557 tx.commit().await?;
558 return Ok(0);
559 }
560
561 for chunk in ids_to_archive.chunks(500) {
563 let placeholders: String = chunk.iter().map(|_| "?").collect::<Vec<_>>().join(",");
564 let sql = format!(
565 "UPDATE messages SET compacted_at = datetime('now') \
566 WHERE session_id = ? AND id IN ({placeholders})"
567 );
568 let mut query = sqlx::query(&sql).bind(session_id);
569 for id in chunk {
570 query = query.bind(id);
571 }
572 query.execute(&mut *tx).await?;
573 }
574
575 sqlx::query(
577 "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, prompt_tokens, completion_tokens)
578 VALUES (?, 'system', ?, NULL, NULL, NULL, NULL)",
579 )
580 .bind(session_id)
581 .bind(summary)
582 .execute(&mut *tx)
583 .await?;
584
585 let continuation = "Your context was compacted. The previous message contains a summary of our earlier conversation. \
587 Do not mention the summary or that compaction occurred. \
588 Continue the conversation naturally based on the summarized context.";
589 sqlx::query(
590 "INSERT INTO messages (session_id, role, content, tool_calls, tool_call_id, prompt_tokens, completion_tokens)
591 VALUES (?, 'assistant', ?, NULL, NULL, NULL, NULL)",
592 )
593 .bind(session_id)
594 .bind(continuation)
595 .execute(&mut *tx)
596 .await?;
597
598 tx.commit().await?;
599
600 Ok(archived_count)
601 }
602
603 async fn has_pending_tool_calls(&self, session_id: &str) -> Result<bool> {
606 let last_msg: Option<(String, Option<String>)> = sqlx::query_as(
607 "SELECT role, tool_calls FROM messages
608 WHERE session_id = ? AND compacted_at IS NULL
609 ORDER BY id DESC LIMIT 1",
610 )
611 .bind(session_id)
612 .fetch_optional(&self.pool)
613 .await?;
614
615 Ok(matches!(last_msg, Some((role, Some(_))) if role == "assistant"))
616 }
617
618 async fn clear_message_content(&self, message_ids: &[i64], stub: &str) -> Result<()> {
619 if message_ids.is_empty() {
620 return Ok(());
621 }
622 for chunk in message_ids.chunks(500) {
623 let placeholders: String = chunk.iter().map(|_| "?").collect::<Vec<_>>().join(",");
624 let sql = format!("UPDATE messages SET content = ? WHERE id IN ({placeholders})");
625 let mut query = sqlx::query(&sql).bind(stub);
626 for id in chunk {
627 query = query.bind(id);
628 }
629 query.execute(&self.pool).await?;
630 }
631 Ok(())
632 }
633
634 async fn compacted_stats(&self) -> Result<CompactedStats> {
636 let row: (i64, i64, i64, Option<String>) = sqlx::query_as(
637 "SELECT
638 COUNT(*),
639 COUNT(DISTINCT session_id),
640 COALESCE(SUM(LENGTH(content) + LENGTH(COALESCE(tool_calls,''))), 0),
641 MIN(compacted_at)
642 FROM messages
643 WHERE compacted_at IS NOT NULL",
644 )
645 .fetch_one(&self.pool)
646 .await?;
647
648 Ok(CompactedStats {
649 message_count: row.0,
650 session_count: row.1,
651 size_bytes: row.2,
652 oldest: row.3,
653 })
654 }
655
656 async fn purge_compacted(&self, min_age_days: u32) -> Result<usize> {
659 let result = if min_age_days == 0 {
660 sqlx::query("DELETE FROM messages WHERE compacted_at IS NOT NULL")
661 .execute(&self.pool)
662 .await?
663 } else {
664 sqlx::query(
665 "DELETE FROM messages
666 WHERE compacted_at IS NOT NULL
667 AND compacted_at < datetime('now', ?)",
668 )
669 .bind(format!("-{min_age_days} days"))
670 .execute(&self.pool)
671 .await?
672 };
673
674 let deleted = result.rows_affected() as usize;
675
676 sqlx::query("VACUUM").execute(&self.pool).await?;
678
679 tracing::info!("Purged {deleted} compacted messages (>{min_age_days} days old)");
680 Ok(deleted)
681 }
682
683 async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>> {
685 let row: Option<(String,)> =
686 sqlx::query_as("SELECT value FROM session_metadata WHERE session_id = ? AND key = ?")
687 .bind(session_id)
688 .bind(key)
689 .fetch_optional(&self.pool)
690 .await?;
691 Ok(row.map(|r| r.0))
692 }
693
694 async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> {
696 sqlx::query(
697 "INSERT INTO session_metadata (session_id, key, value, updated_at)
698 VALUES (?, ?, ?, CURRENT_TIMESTAMP)
699 ON CONFLICT(session_id, key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at",
700 )
701 .bind(session_id)
702 .bind(key)
703 .bind(value)
704 .execute(&self.pool)
705 .await?;
706 Ok(())
707 }
708
709 async fn get_todo(&self, session_id: &str) -> Result<Option<String>> {
711 self.get_metadata(session_id, "todo").await
712 }
713
714 async fn set_todo(&self, session_id: &str, content: &str) -> Result<()> {
716 self.set_metadata(session_id, "todo", content).await
717 }
718}