// agent_sdk_store_supabase/checkpoint.rs

1use agent_sdk_core::{
2    AgentError, CheckpointPrunePolicy, CheckpointPruneReport, CheckpointSaveOutcome,
3    CheckpointStore, JournalCursor, RunCheckpoint, RunId,
4};
5use serde_json::json;
6
7use crate::{client::SupabaseClient, transport::supabase_error};
8
9#[derive(Clone)]
10/// Supabase-backed checkpoint store.
11pub struct SupabaseCheckpointStore {
12    client: SupabaseClient,
13}
14
15impl SupabaseCheckpointStore {
16    pub fn new(client: SupabaseClient) -> Self {
17        Self { client }
18    }
19}
20
21impl CheckpointStore for SupabaseCheckpointStore {
22    fn save(
23        &self,
24        checkpoint: RunCheckpoint,
25        latest_journal_seq: u64,
26    ) -> Result<CheckpointSaveOutcome, AgentError> {
27        checkpoint.validate_against_latest_seq(latest_journal_seq)?;
28        let checkpoint_ref = checkpoint.checkpoint_id.clone();
29        let covers_journal_seq = checkpoint.covers_journal_seq;
30        let response = self.client.insert(
31            "agent_sdk_checkpoints",
32            &json!({
33                "store_scope": self.client.config().store_scope(),
34                "run_id": checkpoint.run_id.as_str(),
35                "checkpoint_id": checkpoint.checkpoint_id,
36                "covers_journal_seq": checkpoint.covers_journal_seq,
37                "checkpoint": checkpoint,
38            }),
39        )?;
40        if !(200..300).contains(&response.status) {
41            return Err(supabase_error(format!(
42                "supabase checkpoint save failed with status {}",
43                response.status
44            )));
45        }
46        Ok(CheckpointSaveOutcome {
47            checkpoint_ref,
48            covers_journal_seq,
49            terminal_checkpoint: false,
50        })
51    }
52
53    fn load_latest(&self, run_id: &RunId) -> Result<Option<RunCheckpoint>, AgentError> {
54        let query = format!(
55            "store_scope=eq.{}&run_id=eq.{}&select=checkpoint&order=covers_journal_seq.desc&limit=1",
56            self.client.config().store_scope(),
57            run_id.as_str()
58        );
59        load_checkpoint(self.client.select("agent_sdk_checkpoints", &query)?)
60    }
61
62    fn load_at_or_before(
63        &self,
64        run_id: &RunId,
65        cursor: &JournalCursor,
66    ) -> Result<Option<RunCheckpoint>, AgentError> {
67        let seq = cursor
68            .as_str()
69            .strip_prefix("journal.")
70            .unwrap_or(cursor.as_str());
71        let query = format!(
72            "store_scope=eq.{}&run_id=eq.{}&covers_journal_seq=lte.{}&select=checkpoint&order=covers_journal_seq.desc&limit=1",
73            self.client.config().store_scope(),
74            run_id.as_str(),
75            seq
76        );
77        load_checkpoint(self.client.select("agent_sdk_checkpoints", &query)?)
78    }
79
80    fn prune(
81        &self,
82        run_id: &RunId,
83        policy: CheckpointPrunePolicy,
84    ) -> Result<CheckpointPruneReport, AgentError> {
85        let response = self.client.rpc(
86            "agent_sdk_prune_checkpoints",
87            &json!({
88                "p_store_scope": self.client.config().store_scope(),
89                "p_run_id": run_id.as_str(),
90                "p_prune_covered_before": policy.prune_covered_before,
91                "p_preserve_latest_terminal": policy.preserve_latest_terminal,
92            }),
93        )?;
94        if !(200..300).contains(&response.status) {
95            return Err(supabase_error(format!(
96                "supabase checkpoint prune failed with status {}",
97                response.status
98            )));
99        }
100        Ok(CheckpointPruneReport {
101            run_id: run_id.clone(),
102            pruned_count: 0,
103            retained_count: 0,
104            preserved_terminal_checkpoint: None,
105        })
106    }
107}
108
109fn load_checkpoint(
110    response: crate::transport::SupabaseHttpResponse,
111) -> Result<Option<RunCheckpoint>, AgentError> {
112    if !(200..300).contains(&response.status) {
113        return Err(supabase_error(format!(
114            "supabase checkpoint read failed with status {}",
115            response.status
116        )));
117    }
118    let rows = serde_json::from_slice::<Vec<serde_json::Value>>(&response.body)
119        .map_err(|error| supabase_error(error.to_string()))?;
120    rows.into_iter()
121        .next()
122        .map(|row| {
123            serde_json::from_value::<RunCheckpoint>(row["checkpoint"].clone())
124                .map_err(|error| supabase_error(error.to_string()))
125        })
126        .transpose()
127}