// agent-sdk-store-postgres 0.1.0-alpha.4
//
// Scripted Postgres-style durable store adapters for the Agent SDK.
use agent_sdk_core::{
    AgentError, CheckpointPrunePolicy, CheckpointPruneReport, CheckpointSaveOutcome,
    CheckpointStore, JournalCursor, RunCheckpoint, RunId,
};
use serde_json::Value;

use crate::{
    PostgresStoreClient,
    util::{decode_row, json_value},
};

/// Checkpoint store backed by a [`PostgresStoreClient`].
///
/// Cloning the store clones the wrapped client.
#[derive(Clone)]
pub struct PostgresCheckpointStore {
    /// Client through which every checkpoint statement is executed.
    client: PostgresStoreClient,
}

impl PostgresCheckpointStore {
    /// Builds a checkpoint store that issues its statements through `client`.
    pub fn new(client: PostgresStoreClient) -> Self {
        PostgresCheckpointStore { client }
    }
}

impl CheckpointStore for PostgresCheckpointStore {
    fn save(
        &self,
        checkpoint: RunCheckpoint,
        latest_journal_seq: u64,
    ) -> Result<CheckpointSaveOutcome, AgentError> {
        checkpoint.validate_against_latest_seq(latest_journal_seq)?;
        let terminal_checkpoint = checkpoint.pending_side_effects.is_empty()
            && checkpoint.pending_approvals.is_empty()
            && checkpoint.loop_state == "terminal";
        self.client.execute(
            format!("insert into {} (store_scope, run_id, checkpoint_id, covers_journal_seq, checkpoint_json) values ($1, $2, $3, $4, $5) on conflict (store_scope, run_id, checkpoint_id) do update set covers_journal_seq = excluded.covers_journal_seq, checkpoint_json = excluded.checkpoint_json", self.client.table("agent_sdk_checkpoints")),
            vec![
                self.client.scope(),
                Value::String(checkpoint.run_id.as_str().to_string()),
                Value::String(checkpoint.checkpoint_id.clone()),
                Value::from(checkpoint.covers_journal_seq),
                json_value(&checkpoint)?,
            ],
        )?;
        Ok(CheckpointSaveOutcome {
            checkpoint_ref: checkpoint.checkpoint_id,
            covers_journal_seq: checkpoint.covers_journal_seq,
            terminal_checkpoint,
        })
    }

    fn load_latest(&self, run_id: &RunId) -> Result<Option<RunCheckpoint>, AgentError> {
        let response = self.client.execute(
            format!("select checkpoint_json from {} where store_scope = $1 and run_id = $2 order by covers_journal_seq desc limit 1", self.client.table("agent_sdk_checkpoints")),
            vec![self.client.scope(), Value::String(run_id.as_str().to_string())],
        )?;
        response
            .rows
            .into_iter()
            .next()
            .map(|row| decode_row(row, "checkpoint_json"))
            .transpose()
    }

    fn load_at_or_before(
        &self,
        run_id: &RunId,
        cursor: &JournalCursor,
    ) -> Result<Option<RunCheckpoint>, AgentError> {
        let seq = cursor
            .as_str()
            .strip_prefix("journal.")
            .unwrap_or(cursor.as_str());
        let response = self.client.execute(
            format!("select checkpoint_json from {} where store_scope = $1 and run_id = $2 and covers_journal_seq <= $3 order by covers_journal_seq desc limit 1", self.client.table("agent_sdk_checkpoints")),
            vec![self.client.scope(), Value::String(run_id.as_str().to_string()), Value::String(seq.to_string())],
        )?;
        response
            .rows
            .into_iter()
            .next()
            .map(|row| decode_row(row, "checkpoint_json"))
            .transpose()
    }

    fn prune(
        &self,
        run_id: &RunId,
        policy: CheckpointPrunePolicy,
    ) -> Result<CheckpointPruneReport, AgentError> {
        let response = self.client.execute(
            format!(
                "delete from {} where store_scope = $1 and run_id = $2 and covers_journal_seq < $3",
                self.client.table("agent_sdk_checkpoints")
            ),
            vec![
                self.client.scope(),
                Value::String(run_id.as_str().to_string()),
                Value::from(policy.prune_covered_before),
            ],
        )?;
        Ok(CheckpointPruneReport {
            run_id: run_id.clone(),
            pruned_count: response.affected as usize,
            retained_count: 0,
            preserved_terminal_checkpoint: None,
        })
    }
}