agent_sdk_store_supabase/
checkpoint.rs1use agent_sdk_core::{
2 AgentError, CheckpointPrunePolicy, CheckpointPruneReport, CheckpointSaveOutcome,
3 CheckpointStore, JournalCursor, RunCheckpoint, RunId,
4};
5use serde_json::json;
6
7use crate::{client::SupabaseClient, transport::supabase_error};
8
9#[derive(Clone)]
10pub struct SupabaseCheckpointStore {
12 client: SupabaseClient,
13}
14
15impl SupabaseCheckpointStore {
16 pub fn new(client: SupabaseClient) -> Self {
17 Self { client }
18 }
19}
20
21impl CheckpointStore for SupabaseCheckpointStore {
22 fn save(
23 &self,
24 checkpoint: RunCheckpoint,
25 latest_journal_seq: u64,
26 ) -> Result<CheckpointSaveOutcome, AgentError> {
27 checkpoint.validate_against_latest_seq(latest_journal_seq)?;
28 let checkpoint_ref = checkpoint.checkpoint_id.clone();
29 let covers_journal_seq = checkpoint.covers_journal_seq;
30 let response = self.client.insert(
31 "agent_sdk_checkpoints",
32 &json!({
33 "store_scope": self.client.config().store_scope(),
34 "run_id": checkpoint.run_id.as_str(),
35 "checkpoint_id": checkpoint.checkpoint_id,
36 "covers_journal_seq": checkpoint.covers_journal_seq,
37 "checkpoint": checkpoint,
38 }),
39 )?;
40 if !(200..300).contains(&response.status) {
41 return Err(supabase_error(format!(
42 "supabase checkpoint save failed with status {}",
43 response.status
44 )));
45 }
46 Ok(CheckpointSaveOutcome {
47 checkpoint_ref,
48 covers_journal_seq,
49 terminal_checkpoint: false,
50 })
51 }
52
53 fn load_latest(&self, run_id: &RunId) -> Result<Option<RunCheckpoint>, AgentError> {
54 let query = format!(
55 "store_scope=eq.{}&run_id=eq.{}&select=checkpoint&order=covers_journal_seq.desc&limit=1",
56 self.client.config().store_scope(),
57 run_id.as_str()
58 );
59 load_checkpoint(self.client.select("agent_sdk_checkpoints", &query)?)
60 }
61
62 fn load_at_or_before(
63 &self,
64 run_id: &RunId,
65 cursor: &JournalCursor,
66 ) -> Result<Option<RunCheckpoint>, AgentError> {
67 let seq = cursor
68 .as_str()
69 .strip_prefix("journal.")
70 .unwrap_or(cursor.as_str());
71 let query = format!(
72 "store_scope=eq.{}&run_id=eq.{}&covers_journal_seq=lte.{}&select=checkpoint&order=covers_journal_seq.desc&limit=1",
73 self.client.config().store_scope(),
74 run_id.as_str(),
75 seq
76 );
77 load_checkpoint(self.client.select("agent_sdk_checkpoints", &query)?)
78 }
79
80 fn prune(
81 &self,
82 run_id: &RunId,
83 policy: CheckpointPrunePolicy,
84 ) -> Result<CheckpointPruneReport, AgentError> {
85 let response = self.client.rpc(
86 "agent_sdk_prune_checkpoints",
87 &json!({
88 "p_store_scope": self.client.config().store_scope(),
89 "p_run_id": run_id.as_str(),
90 "p_prune_covered_before": policy.prune_covered_before,
91 "p_preserve_latest_terminal": policy.preserve_latest_terminal,
92 }),
93 )?;
94 if !(200..300).contains(&response.status) {
95 return Err(supabase_error(format!(
96 "supabase checkpoint prune failed with status {}",
97 response.status
98 )));
99 }
100 Ok(CheckpointPruneReport {
101 run_id: run_id.clone(),
102 pruned_count: 0,
103 retained_count: 0,
104 preserved_terminal_checkpoint: None,
105 })
106 }
107}
108
109fn load_checkpoint(
110 response: crate::transport::SupabaseHttpResponse,
111) -> Result<Option<RunCheckpoint>, AgentError> {
112 if !(200..300).contains(&response.status) {
113 return Err(supabase_error(format!(
114 "supabase checkpoint read failed with status {}",
115 response.status
116 )));
117 }
118 let rows = serde_json::from_slice::<Vec<serde_json::Value>>(&response.body)
119 .map_err(|error| supabase_error(error.to_string()))?;
120 rows.into_iter()
121 .next()
122 .map(|row| {
123 serde_json::from_value::<RunCheckpoint>(row["checkpoint"].clone())
124 .map_err(|error| supabase_error(error.to_string()))
125 })
126 .transpose()
127}