1use crate::{Database, DbResultExt};
2use roboticus_core::Result;
3use rusqlite::OptionalExtension;
4
/// One row of the `context_checkpoints` table, as returned by `load_checkpoint`.
///
/// Each field is read from the column of the same name.
#[derive(Debug, Clone)]
pub struct ContextCheckpoint {
    /// Row identifier; `save_checkpoint` generates it as a UUID v4 string.
    pub id: String,
    /// Identifier of the session this checkpoint belongs to.
    pub session_id: String,
    /// Stored verbatim; presumably a digest of the system prompt in effect
    /// when the checkpoint was taken — confirm against the caller.
    pub system_prompt_hash: String,
    /// Stored verbatim; presumably a textual summary of the agent's memory.
    pub memory_summary: String,
    /// Optional column; `None` when the stored value is SQL NULL.
    pub active_tasks: Option<String>,
    /// Optional column; `None` when the stored value is SQL NULL.
    pub conversation_digest: Option<String>,
    /// Turn counter supplied by the caller at save time.
    pub turn_count: i64,
    /// Creation timestamp as text. The INSERT in `save_checkpoint` does not set
    /// this column, so it is presumably filled by a schema default — TODO confirm.
    pub created_at: String,
}
16
17pub fn save_checkpoint(
19 db: &Database,
20 session_id: &str,
21 system_prompt_hash: &str,
22 memory_summary: &str,
23 active_tasks: Option<&str>,
24 conversation_digest: Option<&str>,
25 turn_count: i64,
26) -> Result<String> {
27 let conn = db.conn();
28 let id = uuid::Uuid::new_v4().to_string();
29 conn.execute(
30 "INSERT INTO context_checkpoints (id, session_id, system_prompt_hash, memory_summary, active_tasks, conversation_digest, turn_count) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
31 rusqlite::params![id, session_id, system_prompt_hash, memory_summary, active_tasks, conversation_digest, turn_count],
32 ).db_err()?;
33 Ok(id)
34}
35
36pub fn load_checkpoint(db: &Database, session_id: &str) -> Result<Option<ContextCheckpoint>> {
38 let conn = db.conn();
39 conn.query_row(
40 "SELECT id, session_id, system_prompt_hash, memory_summary, active_tasks, conversation_digest, turn_count, created_at \
41 FROM context_checkpoints WHERE session_id = ?1 ORDER BY created_at DESC, rowid DESC LIMIT 1",
42 [session_id],
43 |row| {
44 Ok(ContextCheckpoint {
45 id: row.get(0)?,
46 session_id: row.get(1)?,
47 system_prompt_hash: row.get(2)?,
48 memory_summary: row.get(3)?,
49 active_tasks: row.get(4)?,
50 conversation_digest: row.get(5)?,
51 turn_count: row.get(6)?,
52 created_at: row.get(7)?,
53 })
54 },
55 )
56 .optional()
57 .db_err()
58}
59
60pub fn clear_checkpoints(db: &Database, session_id: &str) -> Result<usize> {
62 let conn = db.conn();
63 let deleted = conn
64 .execute(
65 "DELETE FROM context_checkpoints WHERE session_id = ?1",
66 [session_id],
67 )
68 .db_err()?;
69 Ok(deleted)
70}
71
72pub fn prune_checkpoints(db: &Database, keep_per_session: usize) -> Result<usize> {
75 let conn = db.conn();
76 let deleted = conn
77 .execute(
78 "DELETE FROM context_checkpoints \
79 WHERE rowid NOT IN ( \
80 SELECT rowid FROM ( \
81 SELECT rowid, ROW_NUMBER() OVER (PARTITION BY session_id ORDER BY created_at DESC, rowid DESC) AS rn \
82 FROM context_checkpoints \
83 ) WHERE rn <= ?1 \
84 )",
85 [keep_per_session as i64],
86 )
87 .db_err()?;
88 Ok(deleted)
89}
90
91pub fn count_checkpoints(db: &Database, session_id: &str) -> Result<i64> {
93 let conn = db.conn();
94 conn.query_row(
95 "SELECT COUNT(*) FROM context_checkpoints WHERE session_id = ?1",
96 [session_id],
97 |row| row.get(0),
98 )
99 .db_err()
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh in-memory database for a single test.
    fn test_db() -> Database {
        Database::new(":memory:").unwrap()
    }

    /// Finds or creates the session for `test-agent` and returns its id.
    fn create_session(db: &Database) -> String {
        crate::sessions::find_or_create(db, "test-agent", None).unwrap()
    }

    /// Saves a checkpoint that leaves both optional fields unset.
    fn save_bare(db: &Database, sid: &str, hash: &str, summary: &str, turns: i64) -> String {
        save_checkpoint(db, sid, hash, summary, None, None, turns).unwrap()
    }

    #[test]
    fn save_and_load_checkpoint() {
        let db = test_db();
        let session = create_session(&db);

        let checkpoint_id = save_checkpoint(
            &db,
            &session,
            "hash123",
            "memory summary",
            Some("tasks"),
            Some("digest"),
            10,
        )
        .unwrap();
        assert!(!checkpoint_id.is_empty());

        let loaded = load_checkpoint(&db, &session).unwrap().unwrap();
        assert_eq!(loaded.session_id, session);
        assert_eq!(loaded.system_prompt_hash, "hash123");
        assert_eq!(loaded.memory_summary, "memory summary");
        assert_eq!(loaded.active_tasks.as_deref(), Some("tasks"));
        assert_eq!(loaded.conversation_digest.as_deref(), Some("digest"));
        assert_eq!(loaded.turn_count, 10);
    }

    #[test]
    fn load_checkpoint_returns_most_recent() {
        let db = test_db();
        let session = create_session(&db);
        save_bare(&db, &session, "old", "old summary", 5);
        save_bare(&db, &session, "new", "new summary", 15);

        // The later insert must win even if both rows share a timestamp.
        let latest = load_checkpoint(&db, &session).unwrap().unwrap();
        assert_eq!(latest.system_prompt_hash, "new");
        assert_eq!(latest.turn_count, 15);
    }

    #[test]
    fn load_checkpoint_no_session_returns_none() {
        let db = test_db();
        let missing = load_checkpoint(&db, "nonexistent").unwrap();
        assert!(missing.is_none());
    }

    #[test]
    fn clear_checkpoints_removes_all() {
        let db = test_db();
        let session = create_session(&db);
        save_bare(&db, &session, "h1", "s1", 1);
        save_bare(&db, &session, "h2", "s2", 2);

        let removed = clear_checkpoints(&db, &session).unwrap();
        assert_eq!(removed, 2);

        let remaining = load_checkpoint(&db, &session).unwrap();
        assert!(remaining.is_none());
    }

    #[test]
    fn count_checkpoints_accurate() {
        let db = test_db();
        let session = create_session(&db);
        assert_eq!(count_checkpoints(&db, &session).unwrap(), 0);

        save_bare(&db, &session, "h1", "s1", 1);
        assert_eq!(count_checkpoints(&db, &session).unwrap(), 1);

        save_bare(&db, &session, "h2", "s2", 2);
        assert_eq!(count_checkpoints(&db, &session).unwrap(), 2);
    }

    #[test]
    fn checkpoint_with_no_optional_fields() {
        let db = test_db();
        let session = create_session(&db);
        save_bare(&db, &session, "hash", "summary", 0);

        let loaded = load_checkpoint(&db, &session).unwrap().unwrap();
        assert!(loaded.active_tasks.is_none());
        assert!(loaded.conversation_digest.is_none());
    }

    #[test]
    fn prune_checkpoints_keeps_n_per_session() {
        let db = test_db();
        let first = create_session(&db);
        let second = crate::sessions::find_or_create(&db, "agent-b", None).unwrap();

        for turn in 0..5 {
            save_bare(&db, &first, &format!("h{turn}"), &format!("s{turn}"), turn);
        }
        for turn in 0..3 {
            save_bare(&db, &second, &format!("h{turn}"), &format!("s{turn}"), turn);
        }
        assert_eq!(count_checkpoints(&db, &first).unwrap(), 5);
        assert_eq!(count_checkpoints(&db, &second).unwrap(), 3);

        // 3 rows go from the first session and 1 from the second.
        let pruned = prune_checkpoints(&db, 2).unwrap();
        assert_eq!(pruned, 4);
        assert_eq!(count_checkpoints(&db, &first).unwrap(), 2);
        assert_eq!(count_checkpoints(&db, &second).unwrap(), 2);

        // The newest checkpoint of the first session must survive pruning.
        let survivor = load_checkpoint(&db, &first).unwrap().unwrap();
        assert_eq!(survivor.turn_count, 4);
    }
}