Skip to main content

roboticus_db/
checkpoint.rs

1use crate::{Database, DbResultExt};
2use roboticus_core::Result;
3use rusqlite::OptionalExtension;
4
/// One row of the `context_checkpoints` table: a saved snapshot of a
/// session's context.
///
/// Derives `PartialEq`/`Eq` so two checkpoints can be compared by value
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContextCheckpoint {
    /// Checkpoint identifier; `save_checkpoint` generates it as a UUID v4 string.
    pub id: String,
    /// Identifier of the session this checkpoint belongs to.
    pub session_id: String,
    /// Hash of the system prompt, exactly as supplied by the caller on save.
    pub system_prompt_hash: String,
    /// Memory summary text, exactly as supplied by the caller on save.
    pub memory_summary: String,
    /// Optional active-task text; `None` when the column is NULL.
    pub active_tasks: Option<String>,
    /// Optional conversation digest; `None` when the column is NULL.
    pub conversation_digest: Option<String>,
    /// Turn count recorded by the caller when the checkpoint was saved.
    pub turn_count: i64,
    /// Creation timestamp as stored in the database. The INSERT in
    /// `save_checkpoint` does not set it, so it is presumably filled by a
    /// column default — verify against the schema.
    pub created_at: String,
}
16
17/// Save a new checkpoint for a session.
18pub fn save_checkpoint(
19    db: &Database,
20    session_id: &str,
21    system_prompt_hash: &str,
22    memory_summary: &str,
23    active_tasks: Option<&str>,
24    conversation_digest: Option<&str>,
25    turn_count: i64,
26) -> Result<String> {
27    let conn = db.conn();
28    let id = uuid::Uuid::new_v4().to_string();
29    conn.execute(
30        "INSERT INTO context_checkpoints (id, session_id, system_prompt_hash, memory_summary, active_tasks, conversation_digest, turn_count) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
31        rusqlite::params![id, session_id, system_prompt_hash, memory_summary, active_tasks, conversation_digest, turn_count],
32    ).db_err()?;
33    Ok(id)
34}
35
36/// Load the most recent checkpoint for a session.
37pub fn load_checkpoint(db: &Database, session_id: &str) -> Result<Option<ContextCheckpoint>> {
38    let conn = db.conn();
39    conn.query_row(
40        "SELECT id, session_id, system_prompt_hash, memory_summary, active_tasks, conversation_digest, turn_count, created_at \
41         FROM context_checkpoints WHERE session_id = ?1 ORDER BY created_at DESC, rowid DESC LIMIT 1",
42        [session_id],
43        |row| {
44            Ok(ContextCheckpoint {
45                id: row.get(0)?,
46                session_id: row.get(1)?,
47                system_prompt_hash: row.get(2)?,
48                memory_summary: row.get(3)?,
49                active_tasks: row.get(4)?,
50                conversation_digest: row.get(5)?,
51                turn_count: row.get(6)?,
52                created_at: row.get(7)?,
53            })
54        },
55    )
56    .optional()
57    .db_err()
58}
59
60/// Delete all checkpoints for a session (used on session archive/expiry).
61pub fn clear_checkpoints(db: &Database, session_id: &str) -> Result<usize> {
62    let conn = db.conn();
63    let deleted = conn
64        .execute(
65            "DELETE FROM context_checkpoints WHERE session_id = ?1",
66            [session_id],
67        )
68        .db_err()?;
69    Ok(deleted)
70}
71
72/// Keep only the most recent `keep_per_session` checkpoints per session,
73/// deleting older ones.  Returns the total number of rows deleted.
74pub fn prune_checkpoints(db: &Database, keep_per_session: usize) -> Result<usize> {
75    let conn = db.conn();
76    let deleted = conn
77        .execute(
78            "DELETE FROM context_checkpoints \
79             WHERE rowid NOT IN ( \
80               SELECT rowid FROM ( \
81                 SELECT rowid, ROW_NUMBER() OVER (PARTITION BY session_id ORDER BY created_at DESC, rowid DESC) AS rn \
82                 FROM context_checkpoints \
83               ) WHERE rn <= ?1 \
84             )",
85            [keep_per_session as i64],
86        )
87        .db_err()?;
88    Ok(deleted)
89}
90
91/// Count checkpoints for a session.
92pub fn count_checkpoints(db: &Database, session_id: &str) -> Result<i64> {
93    let conn = db.conn();
94    conn.query_row(
95        "SELECT COUNT(*) FROM context_checkpoints WHERE session_id = ?1",
96        [session_id],
97        |row| row.get(0),
98    )
99    .db_err()
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Open a fresh in-memory database for a single test.
    fn test_db() -> Database {
        Database::new(":memory:").unwrap()
    }

    /// Find or create the "test-agent" session and return the resulting string.
    fn create_session(db: &Database) -> String {
        crate::sessions::find_or_create(db, "test-agent", None).unwrap()
    }

    /// A saved checkpoint is returned by `load_checkpoint` with every field intact.
    #[test]
    fn save_and_load_checkpoint() {
        let db = test_db();
        let session = create_session(&db);

        let saved_id = save_checkpoint(
            &db,
            &session,
            "hash123",
            "memory summary",
            Some("tasks"),
            Some("digest"),
            10,
        )
        .unwrap();
        assert!(!saved_id.is_empty());

        let loaded = load_checkpoint(&db, &session).unwrap().unwrap();
        assert_eq!(loaded.session_id, session);
        assert_eq!(loaded.system_prompt_hash, "hash123");
        assert_eq!(loaded.memory_summary, "memory summary");
        assert_eq!(loaded.active_tasks.as_deref(), Some("tasks"));
        assert_eq!(loaded.conversation_digest.as_deref(), Some("digest"));
        assert_eq!(loaded.turn_count, 10);
    }

    /// With two checkpoints saved, the later one wins.
    #[test]
    fn load_checkpoint_returns_most_recent() {
        let db = test_db();
        let session = create_session(&db);

        save_checkpoint(&db, &session, "old", "old summary", None, None, 5).unwrap();
        save_checkpoint(&db, &session, "new", "new summary", None, None, 15).unwrap();

        let latest = load_checkpoint(&db, &session).unwrap().unwrap();
        assert_eq!(latest.system_prompt_hash, "new");
        assert_eq!(latest.turn_count, 15);
    }

    /// An unknown session id yields `None` rather than an error.
    #[test]
    fn load_checkpoint_no_session_returns_none() {
        let db = test_db();

        let missing = load_checkpoint(&db, "nonexistent").unwrap();
        assert!(missing.is_none());
    }

    /// Clearing reports the number of removed rows and leaves nothing to load.
    #[test]
    fn clear_checkpoints_removes_all() {
        let db = test_db();
        let session = create_session(&db);

        save_checkpoint(&db, &session, "h1", "s1", None, None, 1).unwrap();
        save_checkpoint(&db, &session, "h2", "s2", None, None, 2).unwrap();

        let removed = clear_checkpoints(&db, &session).unwrap();
        assert_eq!(removed, 2);

        let remaining = load_checkpoint(&db, &session).unwrap();
        assert!(remaining.is_none());
    }

    /// The count tracks each save, starting from zero.
    #[test]
    fn count_checkpoints_accurate() {
        let db = test_db();
        let session = create_session(&db);

        assert_eq!(count_checkpoints(&db, &session).unwrap(), 0);

        save_checkpoint(&db, &session, "h1", "s1", None, None, 1).unwrap();
        assert_eq!(count_checkpoints(&db, &session).unwrap(), 1);

        save_checkpoint(&db, &session, "h2", "s2", None, None, 2).unwrap();
        assert_eq!(count_checkpoints(&db, &session).unwrap(), 2);
    }

    /// Optional columns saved as `None` come back as `None`.
    #[test]
    fn checkpoint_with_no_optional_fields() {
        let db = test_db();
        let session = create_session(&db);

        save_checkpoint(&db, &session, "hash", "summary", None, None, 0).unwrap();

        let loaded = load_checkpoint(&db, &session).unwrap().unwrap();
        assert!(loaded.active_tasks.is_none());
        assert!(loaded.conversation_digest.is_none());
    }

    /// Pruning trims every session down to the requested number of newest rows.
    #[test]
    fn prune_checkpoints_keeps_n_per_session() {
        let db = test_db();
        let first = create_session(&db);
        let second = crate::sessions::find_or_create(&db, "agent-b", None).unwrap();

        // Create 5 checkpoints for the first session, then 3 for the second.
        for (session, how_many) in [(&first, 5_i64), (&second, 3_i64)] {
            for turn in 0..how_many {
                save_checkpoint(
                    &db,
                    session,
                    &format!("h{turn}"),
                    &format!("s{turn}"),
                    None,
                    None,
                    turn,
                )
                .unwrap();
            }
        }
        assert_eq!(count_checkpoints(&db, &first).unwrap(), 5);
        assert_eq!(count_checkpoints(&db, &second).unwrap(), 3);

        // Keep 2 per session → should delete 3 from the first, 1 from the second.
        let pruned = prune_checkpoints(&db, 2).unwrap();
        assert_eq!(pruned, 4);
        assert_eq!(count_checkpoints(&db, &first).unwrap(), 2);
        assert_eq!(count_checkpoints(&db, &second).unwrap(), 2);

        // The newest surviving checkpoint of the first session has turn_count=4.
        let newest = load_checkpoint(&db, &first).unwrap().unwrap();
        assert_eq!(newest.turn_count, 4);
    }
}