Skip to main content

agent_sdk_store_postgres/
checkpoint.rs

1use agent_sdk_core::{
2    AgentError, CheckpointPrunePolicy, CheckpointPruneReport, CheckpointSaveOutcome,
3    CheckpointStore, JournalCursor, RunCheckpoint, RunId,
4};
5use serde_json::Value;
6
7use crate::{
8    PostgresStoreClient,
9    util::{decode_row, json_value},
10};
11
12#[derive(Clone)]
13pub struct PostgresCheckpointStore {
14    client: PostgresStoreClient,
15}
16
17impl PostgresCheckpointStore {
18    pub fn new(client: PostgresStoreClient) -> Self {
19        Self { client }
20    }
21}
22
23impl CheckpointStore for PostgresCheckpointStore {
24    fn save(
25        &self,
26        checkpoint: RunCheckpoint,
27        latest_journal_seq: u64,
28    ) -> Result<CheckpointSaveOutcome, AgentError> {
29        checkpoint.validate_against_latest_seq(latest_journal_seq)?;
30        let terminal_checkpoint = checkpoint.pending_side_effects.is_empty()
31            && checkpoint.pending_approvals.is_empty()
32            && checkpoint.loop_state == "terminal";
33        self.client.execute(
34            format!("insert into {} (store_scope, run_id, checkpoint_id, covers_journal_seq, checkpoint_json) values ($1, $2, $3, $4, $5) on conflict (store_scope, run_id, checkpoint_id) do update set covers_journal_seq = excluded.covers_journal_seq, checkpoint_json = excluded.checkpoint_json", self.client.table("agent_sdk_checkpoints")),
35            vec![
36                self.client.scope(),
37                Value::String(checkpoint.run_id.as_str().to_string()),
38                Value::String(checkpoint.checkpoint_id.clone()),
39                Value::from(checkpoint.covers_journal_seq),
40                json_value(&checkpoint)?,
41            ],
42        )?;
43        Ok(CheckpointSaveOutcome {
44            checkpoint_ref: checkpoint.checkpoint_id,
45            covers_journal_seq: checkpoint.covers_journal_seq,
46            terminal_checkpoint,
47        })
48    }
49
50    fn load_latest(&self, run_id: &RunId) -> Result<Option<RunCheckpoint>, AgentError> {
51        let response = self.client.execute(
52            format!("select checkpoint_json from {} where store_scope = $1 and run_id = $2 order by covers_journal_seq desc limit 1", self.client.table("agent_sdk_checkpoints")),
53            vec![self.client.scope(), Value::String(run_id.as_str().to_string())],
54        )?;
55        response
56            .rows
57            .into_iter()
58            .next()
59            .map(|row| decode_row(row, "checkpoint_json"))
60            .transpose()
61    }
62
63    fn load_at_or_before(
64        &self,
65        run_id: &RunId,
66        cursor: &JournalCursor,
67    ) -> Result<Option<RunCheckpoint>, AgentError> {
68        let seq = cursor
69            .as_str()
70            .strip_prefix("journal.")
71            .unwrap_or(cursor.as_str());
72        let response = self.client.execute(
73            format!("select checkpoint_json from {} where store_scope = $1 and run_id = $2 and covers_journal_seq <= $3 order by covers_journal_seq desc limit 1", self.client.table("agent_sdk_checkpoints")),
74            vec![self.client.scope(), Value::String(run_id.as_str().to_string()), Value::String(seq.to_string())],
75        )?;
76        response
77            .rows
78            .into_iter()
79            .next()
80            .map(|row| decode_row(row, "checkpoint_json"))
81            .transpose()
82    }
83
84    fn prune(
85        &self,
86        run_id: &RunId,
87        policy: CheckpointPrunePolicy,
88    ) -> Result<CheckpointPruneReport, AgentError> {
89        let response = self.client.execute(
90            format!(
91                "delete from {} where store_scope = $1 and run_id = $2 and covers_journal_seq < $3",
92                self.client.table("agent_sdk_checkpoints")
93            ),
94            vec![
95                self.client.scope(),
96                Value::String(run_id.as_str().to_string()),
97                Value::from(policy.prune_covered_before),
98            ],
99        )?;
100        Ok(CheckpointPruneReport {
101            run_id: run_id.clone(),
102            pruned_count: response.affected as usize,
103            retained_count: 0,
104            preserved_terminal_checkpoint: None,
105        })
106    }
107}