agent_sdk_store_postgres/
checkpoint.rs1use agent_sdk_core::{
2 AgentError, CheckpointPrunePolicy, CheckpointPruneReport, CheckpointSaveOutcome,
3 CheckpointStore, JournalCursor, RunCheckpoint, RunId,
4};
5use serde_json::Value;
6
7use crate::{
8 PostgresStoreClient,
9 util::{decode_row, json_value},
10};
11
12#[derive(Clone)]
13pub struct PostgresCheckpointStore {
14 client: PostgresStoreClient,
15}
16
17impl PostgresCheckpointStore {
18 pub fn new(client: PostgresStoreClient) -> Self {
19 Self { client }
20 }
21}
22
23impl CheckpointStore for PostgresCheckpointStore {
24 fn save(
25 &self,
26 checkpoint: RunCheckpoint,
27 latest_journal_seq: u64,
28 ) -> Result<CheckpointSaveOutcome, AgentError> {
29 checkpoint.validate_against_latest_seq(latest_journal_seq)?;
30 let terminal_checkpoint = checkpoint.pending_side_effects.is_empty()
31 && checkpoint.pending_approvals.is_empty()
32 && checkpoint.loop_state == "terminal";
33 self.client.execute(
34 format!("insert into {} (store_scope, run_id, checkpoint_id, covers_journal_seq, checkpoint_json) values ($1, $2, $3, $4, $5) on conflict (store_scope, run_id, checkpoint_id) do update set covers_journal_seq = excluded.covers_journal_seq, checkpoint_json = excluded.checkpoint_json", self.client.table("agent_sdk_checkpoints")),
35 vec![
36 self.client.scope(),
37 Value::String(checkpoint.run_id.as_str().to_string()),
38 Value::String(checkpoint.checkpoint_id.clone()),
39 Value::from(checkpoint.covers_journal_seq),
40 json_value(&checkpoint)?,
41 ],
42 )?;
43 Ok(CheckpointSaveOutcome {
44 checkpoint_ref: checkpoint.checkpoint_id,
45 covers_journal_seq: checkpoint.covers_journal_seq,
46 terminal_checkpoint,
47 })
48 }
49
50 fn load_latest(&self, run_id: &RunId) -> Result<Option<RunCheckpoint>, AgentError> {
51 let response = self.client.execute(
52 format!("select checkpoint_json from {} where store_scope = $1 and run_id = $2 order by covers_journal_seq desc limit 1", self.client.table("agent_sdk_checkpoints")),
53 vec![self.client.scope(), Value::String(run_id.as_str().to_string())],
54 )?;
55 response
56 .rows
57 .into_iter()
58 .next()
59 .map(|row| decode_row(row, "checkpoint_json"))
60 .transpose()
61 }
62
63 fn load_at_or_before(
64 &self,
65 run_id: &RunId,
66 cursor: &JournalCursor,
67 ) -> Result<Option<RunCheckpoint>, AgentError> {
68 let seq = cursor
69 .as_str()
70 .strip_prefix("journal.")
71 .unwrap_or(cursor.as_str());
72 let response = self.client.execute(
73 format!("select checkpoint_json from {} where store_scope = $1 and run_id = $2 and covers_journal_seq <= $3 order by covers_journal_seq desc limit 1", self.client.table("agent_sdk_checkpoints")),
74 vec![self.client.scope(), Value::String(run_id.as_str().to_string()), Value::String(seq.to_string())],
75 )?;
76 response
77 .rows
78 .into_iter()
79 .next()
80 .map(|row| decode_row(row, "checkpoint_json"))
81 .transpose()
82 }
83
84 fn prune(
85 &self,
86 run_id: &RunId,
87 policy: CheckpointPrunePolicy,
88 ) -> Result<CheckpointPruneReport, AgentError> {
89 let response = self.client.execute(
90 format!(
91 "delete from {} where store_scope = $1 and run_id = $2 and covers_journal_seq < $3",
92 self.client.table("agent_sdk_checkpoints")
93 ),
94 vec![
95 self.client.scope(),
96 Value::String(run_id.as_str().to_string()),
97 Value::from(policy.prune_covered_before),
98 ],
99 )?;
100 Ok(CheckpointPruneReport {
101 run_id: run_id.clone(),
102 pruned_count: response.affected as usize,
103 retained_count: 0,
104 preserved_terminal_checkpoint: None,
105 })
106 }
107}