agent_sdk_store_sqlite/
checkpoint.rs1use std::path::{Path, PathBuf};
2
3use agent_sdk_core::{
4 AgentError, CheckpointPrunePolicy, CheckpointPruneReport, CheckpointSaveOutcome,
5 CheckpointStore, JournalCursor, RunCheckpoint, RunId,
6};
7use rusqlite::params;
8
9use crate::util::{decode, encode, open, sqlite_error};
10
11const SCHEMA: &str = "
12CREATE TABLE IF NOT EXISTS checkpoints (
13 run_id TEXT NOT NULL,
14 checkpoint_id TEXT NOT NULL,
15 latest_journal_seq INTEGER NOT NULL,
16 checkpoint_json TEXT NOT NULL,
17 PRIMARY KEY (run_id, checkpoint_id)
18);
19CREATE INDEX IF NOT EXISTS idx_checkpoints_latest
20ON checkpoints(run_id, latest_journal_seq);
21";
22
23#[derive(Clone, Debug)]
24pub struct SqliteCheckpointStore {
26 path: PathBuf,
27}
28
29impl SqliteCheckpointStore {
30 pub fn open(path: impl AsRef<Path>) -> Result<Self, AgentError> {
32 crate::util::init(path.as_ref(), SCHEMA)?;
33 Ok(Self {
34 path: path.as_ref().to_path_buf(),
35 })
36 }
37}
38
39impl CheckpointStore for SqliteCheckpointStore {
40 fn save(
41 &self,
42 checkpoint: RunCheckpoint,
43 latest_journal_seq: u64,
44 ) -> Result<CheckpointSaveOutcome, AgentError> {
45 checkpoint.validate_against_latest_seq(latest_journal_seq)?;
46 let checkpoint_ref = checkpoint.checkpoint_id.clone();
47 let covers_journal_seq = checkpoint.covers_journal_seq;
48 let terminal_checkpoint = checkpoint.pending_side_effects.is_empty()
49 && checkpoint.pending_approvals.is_empty()
50 && checkpoint.loop_state == "terminal";
51 let connection = open(&self.path)?;
52 connection
53 .execute(
54 "INSERT OR REPLACE INTO checkpoints
55 (run_id, checkpoint_id, latest_journal_seq, checkpoint_json)
56 VALUES (?1, ?2, ?3, ?4)",
57 params![
58 checkpoint.run_id.as_str(),
59 checkpoint.checkpoint_id,
60 covers_journal_seq as i64,
61 encode(&checkpoint)?,
62 ],
63 )
64 .map_err(sqlite_error)?;
65 Ok(CheckpointSaveOutcome {
66 checkpoint_ref,
67 covers_journal_seq,
68 terminal_checkpoint,
69 })
70 }
71
72 fn load_latest(&self, run_id: &RunId) -> Result<Option<RunCheckpoint>, AgentError> {
73 let connection = open(&self.path)?;
74 let mut statement = connection
75 .prepare(
76 "SELECT checkpoint_json FROM checkpoints
77 WHERE run_id = ?1 ORDER BY latest_journal_seq DESC LIMIT 1",
78 )
79 .map_err(sqlite_error)?;
80 let mut rows = statement
81 .query_map(params![run_id.as_str()], |row| row.get::<_, String>(0))
82 .map_err(sqlite_error)?;
83 match rows.next() {
84 Some(row) => Ok(Some(decode(&row.map_err(sqlite_error)?)?)),
85 None => Ok(None),
86 }
87 }
88
89 fn load_at_or_before(
90 &self,
91 run_id: &RunId,
92 cursor: &JournalCursor,
93 ) -> Result<Option<RunCheckpoint>, AgentError> {
94 let latest_seq = cursor
95 .as_str()
96 .strip_prefix("journal.")
97 .unwrap_or(cursor.as_str())
98 .parse::<u64>()
99 .unwrap_or_default();
100 let connection = open(&self.path)?;
101 let mut statement = connection
102 .prepare(
103 "SELECT checkpoint_json FROM checkpoints
104 WHERE run_id = ?1 AND latest_journal_seq <= ?2
105 ORDER BY latest_journal_seq DESC LIMIT 1",
106 )
107 .map_err(sqlite_error)?;
108 let mut rows = statement
109 .query_map(params![run_id.as_str(), latest_seq as i64], |row| {
110 row.get::<_, String>(0)
111 })
112 .map_err(sqlite_error)?;
113 match rows.next() {
114 Some(row) => Ok(Some(decode(&row.map_err(sqlite_error)?)?)),
115 None => Ok(None),
116 }
117 }
118
119 fn prune(
120 &self,
121 run_id: &RunId,
122 policy: CheckpointPrunePolicy,
123 ) -> Result<CheckpointPruneReport, AgentError> {
124 let connection = open(&self.path)?;
125 let keep_latest = self
126 .load_latest(run_id)?
127 .map(|checkpoint| checkpoint.checkpoint_id);
128 let mut statement = connection
129 .prepare(
130 "SELECT checkpoint_id FROM checkpoints
131 WHERE run_id = ?1 AND latest_journal_seq < ?2",
132 )
133 .map_err(sqlite_error)?;
134 let rows = statement
135 .query_map(
136 params![run_id.as_str(), policy.prune_covered_before as i64],
137 |row| row.get::<_, String>(0),
138 )
139 .map_err(sqlite_error)?;
140 let retained_before = connection
141 .query_row(
142 "SELECT COUNT(*) FROM checkpoints WHERE run_id = ?1",
143 params![run_id.as_str()],
144 |row| row.get::<_, i64>(0),
145 )
146 .map_err(sqlite_error)? as usize;
147 let mut deleted = 0;
148 for row in rows {
149 let checkpoint_id = row.map_err(sqlite_error)?;
150 if policy.preserve_latest_terminal && Some(checkpoint_id.clone()) == keep_latest {
151 continue;
152 }
153 deleted += connection
154 .execute(
155 "DELETE FROM checkpoints WHERE run_id = ?1 AND checkpoint_id = ?2",
156 params![run_id.as_str(), checkpoint_id],
157 )
158 .map_err(sqlite_error)?;
159 }
160 Ok(CheckpointPruneReport {
161 run_id: run_id.clone(),
162 pruned_count: deleted,
163 retained_count: retained_before.saturating_sub(deleted),
164 preserved_terminal_checkpoint: if policy.preserve_latest_terminal {
165 keep_latest
166 } else {
167 None
168 },
169 })
170 }
171}