// agent_sdk_store_sqlite/checkpoint.rs

use std::path::{Path, PathBuf};

use agent_sdk_core::{
    AgentError, CheckpointPrunePolicy, CheckpointPruneReport, CheckpointSaveOutcome,
    CheckpointStore, JournalCursor, RunCheckpoint, RunId,
};
use rusqlite::{params, OptionalExtension};

use crate::util::{decode, encode, open, sqlite_error};

/// Schema for the checkpoint table; passed to `crate::util::init` when the
/// store is opened.
///
/// Rows are keyed by `(run_id, checkpoint_id)`, so saving a checkpoint with an
/// existing id for the same run replaces it.
///
/// NOTE: despite its name, the `latest_journal_seq` column holds each
/// checkpoint's `covers_journal_seq` (that is what `save` binds to it).
/// Renaming the column would require a migration for existing databases, so
/// the name is kept.
///
/// The `(run_id, latest_journal_seq)` index backs the "newest checkpoint for a
/// run" lookups, which order by `latest_journal_seq DESC`.
const SCHEMA: &str = "
CREATE TABLE IF NOT EXISTS checkpoints (
    run_id TEXT NOT NULL,
    checkpoint_id TEXT NOT NULL,
    latest_journal_seq INTEGER NOT NULL,
    checkpoint_json TEXT NOT NULL,
    PRIMARY KEY (run_id, checkpoint_id)
);
CREATE INDEX IF NOT EXISTS idx_checkpoints_latest
ON checkpoints(run_id, latest_journal_seq);
";

/// SQLite-backed checkpoint store.
///
/// The store only remembers where its database file lives; every operation
/// opens its own connection to that file.
#[derive(Clone, Debug)]
pub struct SqliteCheckpointStore {
    /// Filesystem location of the SQLite database backing this store.
    path: PathBuf,
}

29impl SqliteCheckpointStore {
30    /// Opens or creates a SQLite checkpoint store.
31    pub fn open(path: impl AsRef<Path>) -> Result<Self, AgentError> {
32        crate::util::init(path.as_ref(), SCHEMA)?;
33        Ok(Self {
34            path: path.as_ref().to_path_buf(),
35        })
36    }
37}
38
39impl CheckpointStore for SqliteCheckpointStore {
40    fn save(
41        &self,
42        checkpoint: RunCheckpoint,
43        latest_journal_seq: u64,
44    ) -> Result<CheckpointSaveOutcome, AgentError> {
45        checkpoint.validate_against_latest_seq(latest_journal_seq)?;
46        let checkpoint_ref = checkpoint.checkpoint_id.clone();
47        let covers_journal_seq = checkpoint.covers_journal_seq;
48        let terminal_checkpoint = checkpoint.pending_side_effects.is_empty()
49            && checkpoint.pending_approvals.is_empty()
50            && checkpoint.loop_state == "terminal";
51        let connection = open(&self.path)?;
52        connection
53            .execute(
54                "INSERT OR REPLACE INTO checkpoints
55                 (run_id, checkpoint_id, latest_journal_seq, checkpoint_json)
56                 VALUES (?1, ?2, ?3, ?4)",
57                params![
58                    checkpoint.run_id.as_str(),
59                    checkpoint.checkpoint_id,
60                    covers_journal_seq as i64,
61                    encode(&checkpoint)?,
62                ],
63            )
64            .map_err(sqlite_error)?;
65        Ok(CheckpointSaveOutcome {
66            checkpoint_ref,
67            covers_journal_seq,
68            terminal_checkpoint,
69        })
70    }
71
72    fn load_latest(&self, run_id: &RunId) -> Result<Option<RunCheckpoint>, AgentError> {
73        let connection = open(&self.path)?;
74        let mut statement = connection
75            .prepare(
76                "SELECT checkpoint_json FROM checkpoints
77                 WHERE run_id = ?1 ORDER BY latest_journal_seq DESC LIMIT 1",
78            )
79            .map_err(sqlite_error)?;
80        let mut rows = statement
81            .query_map(params![run_id.as_str()], |row| row.get::<_, String>(0))
82            .map_err(sqlite_error)?;
83        match rows.next() {
84            Some(row) => Ok(Some(decode(&row.map_err(sqlite_error)?)?)),
85            None => Ok(None),
86        }
87    }
88
89    fn load_at_or_before(
90        &self,
91        run_id: &RunId,
92        cursor: &JournalCursor,
93    ) -> Result<Option<RunCheckpoint>, AgentError> {
94        let latest_seq = cursor
95            .as_str()
96            .strip_prefix("journal.")
97            .unwrap_or(cursor.as_str())
98            .parse::<u64>()
99            .unwrap_or_default();
100        let connection = open(&self.path)?;
101        let mut statement = connection
102            .prepare(
103                "SELECT checkpoint_json FROM checkpoints
104                 WHERE run_id = ?1 AND latest_journal_seq <= ?2
105                 ORDER BY latest_journal_seq DESC LIMIT 1",
106            )
107            .map_err(sqlite_error)?;
108        let mut rows = statement
109            .query_map(params![run_id.as_str(), latest_seq as i64], |row| {
110                row.get::<_, String>(0)
111            })
112            .map_err(sqlite_error)?;
113        match rows.next() {
114            Some(row) => Ok(Some(decode(&row.map_err(sqlite_error)?)?)),
115            None => Ok(None),
116        }
117    }
118
119    fn prune(
120        &self,
121        run_id: &RunId,
122        policy: CheckpointPrunePolicy,
123    ) -> Result<CheckpointPruneReport, AgentError> {
124        let connection = open(&self.path)?;
125        let keep_latest = self
126            .load_latest(run_id)?
127            .map(|checkpoint| checkpoint.checkpoint_id);
128        let mut statement = connection
129            .prepare(
130                "SELECT checkpoint_id FROM checkpoints
131                 WHERE run_id = ?1 AND latest_journal_seq < ?2",
132            )
133            .map_err(sqlite_error)?;
134        let rows = statement
135            .query_map(
136                params![run_id.as_str(), policy.prune_covered_before as i64],
137                |row| row.get::<_, String>(0),
138            )
139            .map_err(sqlite_error)?;
140        let retained_before = connection
141            .query_row(
142                "SELECT COUNT(*) FROM checkpoints WHERE run_id = ?1",
143                params![run_id.as_str()],
144                |row| row.get::<_, i64>(0),
145            )
146            .map_err(sqlite_error)? as usize;
147        let mut deleted = 0;
148        for row in rows {
149            let checkpoint_id = row.map_err(sqlite_error)?;
150            if policy.preserve_latest_terminal && Some(checkpoint_id.clone()) == keep_latest {
151                continue;
152            }
153            deleted += connection
154                .execute(
155                    "DELETE FROM checkpoints WHERE run_id = ?1 AND checkpoint_id = ?2",
156                    params![run_id.as_str(), checkpoint_id],
157                )
158                .map_err(sqlite_error)?;
159        }
160        Ok(CheckpointPruneReport {
161            run_id: run_id.clone(),
162            pruned_count: deleted,
163            retained_count: retained_before.saturating_sub(deleted),
164            preserved_terminal_checkpoint: if policy.preserve_latest_terminal {
165                keep_latest
166            } else {
167                None
168            },
169        })
170    }
171}