use std::sync::Arc;
use chrono::{DateTime, Utc};
use serde_json::Value;
use sqlx::{PgPool, Row, postgres::PgRow};
use tracing::instrument;
use crate::{
runtimes::checkpointer::{Checkpoint, Checkpointer, CheckpointerError, Result},
runtimes::persistence::{PersistedState, PersistedVersionsSeen},
state::VersionedState,
types::NodeKind,
};
use super::checkpointer_postgres_helpers::{
deserialize_json_value, require_json_field, serialize_json,
};
/// Optional filters and pagination parameters for
/// [`PostgresCheckpointer::query_steps`]. All filters are combined with AND.
#[derive(Debug, Clone, Default)]
pub struct StepQuery {
    /// Page size; defaults to 100 and is clamped to 1000 by `query_steps`.
    pub limit: Option<u32>,
    /// Number of matching rows to skip before the first returned row (default 0).
    pub offset: Option<u32>,
    /// Only include steps with step number >= this value.
    pub min_step: Option<u64>,
    /// Only include steps with step number <= this value.
    pub max_step: Option<u64>,
    /// Only include steps whose ran-nodes list contains this node.
    pub ran_node: Option<NodeKind>,
    /// Only include steps whose skipped-nodes list contains this node.
    pub skipped_node: Option<NodeKind>,
}
/// Pagination metadata accompanying a [`StepQueryResult`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PageInfo {
    /// Total number of rows matching the filters, ignoring pagination.
    pub total_count: u64,
    /// Number of rows actually returned in this page.
    pub page_size: u32,
    /// Offset that produced this page.
    pub offset: u32,
    /// True when more matching rows exist beyond this page.
    pub has_next_page: bool,
}
/// One page of checkpoints plus pagination metadata, as returned by
/// [`PostgresCheckpointer::query_steps`].
#[derive(Debug, Clone)]
pub struct StepQueryResult {
    /// Checkpoints for this page, ordered by step number descending.
    pub checkpoints: Vec<Checkpoint>,
    /// Pagination details for this page.
    pub page_info: PageInfo,
}
/// [`Checkpointer`] implementation backed by a PostgreSQL database via `sqlx`.
pub struct PostgresCheckpointer {
    // Shared connection pool; wrapped in Arc so the checkpointer can be
    // cloned/shared cheaply without duplicating connections.
    pool: Arc<PgPool>,
}
// Manual Debug impl: the connection pool is deliberately omitted from the
// output, so only the type name is rendered.
impl std::fmt::Debug for PostgresCheckpointer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("PostgresCheckpointer")
    }
}
impl PostgresCheckpointer {
    /// Connects to Postgres at `database_url` and returns a ready checkpointer.
    ///
    /// When the `postgres-migrations` feature is enabled, embedded migrations
    /// from `./migrations/postgres` are applied before the checkpointer is
    /// returned; a migration failure aborts construction. Without the feature
    /// the schema is assumed to be managed externally.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointerError::Backend`] if the connection attempt or a
    /// migration fails.
    #[must_use = "checkpointer must be used to persist state"]
    #[instrument(skip(database_url))]
    pub async fn connect(database_url: &str) -> std::result::Result<Self, CheckpointerError> {
        let pool = PgPool::connect(database_url)
            .await
            .map_err(|e| CheckpointerError::Backend {
                message: format!("connect error: {e}"),
            })?;
        #[cfg(feature = "postgres-migrations")]
        {
            if let Err(e) = sqlx::migrate!("./migrations/postgres").run(&pool).await {
                return Err(CheckpointerError::Backend {
                    message: format!("migration failure: {e}"),
                });
            }
        }
        Ok(Self {
            pool: Arc::new(pool),
        })
    }
}
#[async_trait::async_trait]
impl Checkpointer for PostgresCheckpointer {
    /// Persists `checkpoint` atomically: ensures the session row exists,
    /// upserts the step row, and refreshes the session's cached "latest
    /// checkpoint" columns — all within a single transaction.
    #[instrument(skip(self, checkpoint), err)]
    async fn save(&self, checkpoint: Checkpoint) -> Result<()> {
        // Serialize every JSON payload before opening the transaction so a
        // serialization failure can never leave a transaction dangling.
        let persisted_state = PersistedState::from(&checkpoint.state);
        let state_json = serialize_json(&persisted_state, "state")?;
        // Node kinds are stored as their string encoding (see NodeKind::encode).
        let frontier_enc: Vec<String> = checkpoint.frontier.iter().map(|k| k.encode()).collect();
        let frontier_json = serialize_json(&frontier_enc, "frontier")?;
        let persisted_vs = PersistedVersionsSeen(checkpoint.versions_seen.clone());
        let versions_seen_json = serialize_json(&persisted_vs, "versions_seen")?;
        let ran_nodes_enc: Vec<String> = checkpoint.ran_nodes.iter().map(|k| k.encode()).collect();
        let ran_nodes_json = serialize_json(&ran_nodes_enc, "ran_nodes")?;
        let skipped_nodes_enc: Vec<String> = checkpoint
            .skipped_nodes
            .iter()
            .map(|k| k.encode())
            .collect();
        let skipped_nodes_json = serialize_json(&skipped_nodes_enc, "skipped_nodes")?;
        let updated_channels_json =
            serialize_json(&checkpoint.updated_channels, "updated_channels")?;
        let mut tx = self
            .pool
            .begin()
            .await
            .map_err(|e| CheckpointerError::Backend {
                message: format!("tx begin: {e}"),
            })?;
        // Create the session row on first save; subsequent saves leave it
        // untouched (DO NOTHING), so concurrency_limit is fixed at creation.
        sqlx::query(
            r#"
INSERT INTO sessions (id, concurrency_limit)
VALUES ($1, $2)
ON CONFLICT (id) DO NOTHING
"#,
        )
        .bind(&checkpoint.session_id)
        .bind(checkpoint.concurrency_limit as i64)
        .execute(&mut *tx)
        .await
        .map_err(|e| CheckpointerError::Backend {
            message: format!("insert session: {e}"),
        })?;
        // Upsert the step row so re-saving the same (session, step) pair
        // overwrites the previous payload.
        sqlx::query(
            r#"
INSERT INTO steps (
session_id,
step,
state_json,
frontier_json,
versions_seen_json,
ran_nodes_json,
skipped_nodes_json,
updated_channels_json
) VALUES ($1, $2, $3::jsonb, $4::jsonb, $5::jsonb, $6::jsonb, $7::jsonb, $8::jsonb)
ON CONFLICT (session_id, step) DO UPDATE SET
state_json = EXCLUDED.state_json,
frontier_json = EXCLUDED.frontier_json,
versions_seen_json = EXCLUDED.versions_seen_json,
ran_nodes_json = EXCLUDED.ran_nodes_json,
skipped_nodes_json = EXCLUDED.skipped_nodes_json,
updated_channels_json = EXCLUDED.updated_channels_json
"#,
        )
        .bind(&checkpoint.session_id)
        .bind(checkpoint.step as i64)
        .bind(&state_json)
        .bind(&frontier_json)
        .bind(&versions_seen_json)
        .bind(&ran_nodes_json)
        .bind(&skipped_nodes_json)
        .bind(&updated_channels_json)
        .execute(&mut *tx)
        .await
        .map_err(|e| CheckpointerError::Backend {
            message: format!("insert step: {e}"),
        })?;
        // Advance the session's cached "latest" columns only when this step
        // is at least the stored last_step, so saving an older step
        // out-of-order cannot regress the latest pointer.
        sqlx::query(
            r#"
UPDATE sessions
SET
updated_at = NOW(),
last_step = CASE WHEN last_step <= $2 THEN $2 ELSE last_step END,
last_state_json = CASE WHEN last_step <= $2 THEN $3::jsonb ELSE last_state_json END,
last_frontier_json = CASE WHEN last_step <= $2 THEN $4::jsonb ELSE last_frontier_json END,
last_versions_seen_json = CASE WHEN last_step <= $2 THEN $5::jsonb ELSE last_versions_seen_json END
WHERE id = $1
"#,
        )
        .bind(&checkpoint.session_id)
        .bind(checkpoint.step as i64)
        .bind(&state_json)
        .bind(&frontier_json)
        .bind(&versions_seen_json)
        .execute(&mut *tx)
        .await
        .map_err(|e| CheckpointerError::Backend {
            message: format!("update session latest: {e}"),
        })?;
        tx.commit().await.map_err(|e| CheckpointerError::Backend {
            message: format!("tx commit: {e}"),
        })?;
        Ok(())
    }
    /// Loads the most recent checkpoint for `session_id` from the session
    /// row's cached "latest" columns, without touching the steps table.
    ///
    /// Returns `Ok(None)` when the session is unknown, or when the row exists
    /// but has no cached state yet.
    #[instrument(skip(self, session_id), err)]
    async fn load_latest(&self, session_id: &str) -> Result<Option<Checkpoint>> {
        let row_opt: Option<PgRow> = sqlx::query(
            r#"
SELECT
s.id,
s.last_step,
s.last_state_json,
s.last_frontier_json,
s.last_versions_seen_json,
s.concurrency_limit,
s.updated_at
FROM sessions s
WHERE s.id = $1
"#,
        )
        .bind(session_id)
        .fetch_optional(&*self.pool)
        .await
        .map_err(|e| CheckpointerError::Backend {
            message: format!("select latest: {e}"),
        })?;
        // No session row at all means no checkpoint.
        let row = match row_opt {
            Some(r) => r,
            None => return Ok(None),
        };
        let last_step: i64 = row.get("last_step");
        // The last_* JSON columns are read via try_get since they can be NULL
        // (presumably for a session row that has not cached a step yet).
        let state_json: Option<Value> =
            row.try_get("last_state_json")
                .map_err(|e| CheckpointerError::Backend {
                    message: format!("last_state_json read: {e}"),
                })?;
        let frontier_json: Option<Value> =
            row.try_get("last_frontier_json")
                .map_err(|e| CheckpointerError::Backend {
                    message: format!("last_frontier_json read: {e}"),
                })?;
        let versions_seen_json: Option<Value> =
            row.try_get("last_versions_seen_json")
                .map_err(|e| CheckpointerError::Backend {
                    message: format!("last_versions_seen_json read: {e}"),
                })?;
        let concurrency_limit: i64 = row.get("concurrency_limit");
        let updated_at: DateTime<Utc> = row.get("updated_at");
        // NOTE(review): treats "step 0 with no cached state" as a session
        // that has never checkpointed — confirm this matches how writers
        // other than `save` populate the sessions table.
        if last_step == 0 && state_json.is_none() {
            return Ok(None);
        }
        // Past this point a missing column value is corruption, not absence.
        let state_val = require_json_field(state_json, "state_json")?;
        let frontier_val = require_json_field(frontier_json, "frontier_json")?;
        let versions_seen_val = require_json_field(versions_seen_json, "versions_seen_json")?;
        let persisted_state: PersistedState = deserialize_json_value(state_val, "state")?;
        let state =
            VersionedState::try_from(persisted_state).map_err(|e| CheckpointerError::Other {
                message: format!("state convert: {e}"),
            })?;
        // NOTE(review): non-string frontier entries are silently skipped via
        // filter_map rather than reported as corruption — confirm intended.
        let frontier: Vec<NodeKind> = frontier_val
            .as_array()
            .ok_or_else(|| CheckpointerError::Other {
                message: "frontier not array".to_string(),
            })?
            .iter()
            .filter_map(|v| v.as_str())
            .map(NodeKind::decode)
            .collect();
        let persisted_vs: PersistedVersionsSeen =
            deserialize_json_value(versions_seen_val, "versions_seen")?;
        let versions_seen = persisted_vs.0;
        Ok(Some(Checkpoint {
            session_id: session_id.to_string(),
            step: last_step as u64,
            state,
            frontier,
            versions_seen,
            concurrency_limit: concurrency_limit as usize,
            // The session row only tracks updated_at; use it as created_at.
            created_at: updated_at,
            // Per-step details are not cached on the session row, so the
            // latest checkpoint is returned without them.
            ran_nodes: vec![],
            skipped_nodes: vec![],
            updated_channels: vec![],
        }))
    }
    /// Lists all known session ids, most recently updated first.
    #[instrument(skip(self), err)]
    async fn list_sessions(&self) -> Result<Vec<String>> {
        let rows = sqlx::query(
            r#"
SELECT id FROM sessions
ORDER BY updated_at DESC
"#,
        )
        .fetch_all(&*self.pool)
        .await
        .map_err(|e| CheckpointerError::Backend {
            message: format!("list sessions: {e}"),
        })?;
        Ok(rows.into_iter().map(|r| r.get::<String, _>("id")).collect())
    }
}
impl PostgresCheckpointer {
    /// Queries persisted steps for `session_id`, applying the optional
    /// filters in `query` and paginating the result.
    ///
    /// Filter conditions are ANDed together; rows are ordered by step number,
    /// newest first. `limit` defaults to 100 and is capped at 1000.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointerError::Backend`] if a query fails, or
    /// [`CheckpointerError::Other`] if stored JSON is malformed.
    #[instrument(skip(self), err)]
    pub async fn query_steps(&self, session_id: &str, query: StepQuery) -> Result<StepQueryResult> {
        // Build the WHERE clause dynamically. $1 is always the session id;
        // each present filter claims the next positional placeholder, and the
        // bind calls below must follow this exact order.
        let mut conditions = vec!["st.session_id = $1".to_string()];
        let mut param_count = 1;
        if query.min_step.is_some() {
            param_count += 1;
            conditions.push(format!("st.step >= ${param_count}"));
        }
        if query.max_step.is_some() {
            param_count += 1;
            conditions.push(format!("st.step <= ${param_count}"));
        }
        if query.ran_node.is_some() {
            param_count += 1;
            // JSONB containment: the stored array must contain the encoded node.
            conditions.push(format!("st.ran_nodes_json @> ${param_count}::jsonb"));
        }
        if query.skipped_node.is_some() {
            param_count += 1;
            conditions.push(format!("st.skipped_nodes_json @> ${param_count}::jsonb"));
        }
        let where_clause = conditions.join(" AND ");
        let count_sql = format!("SELECT COUNT(*) as total FROM steps st WHERE {where_clause}");
        // LIMIT/OFFSET are interpolated rather than bound; this is safe here
        // because both are plain integers, and the limit is clamped to 1000.
        let limit = query.limit.unwrap_or(100).min(1000);
        let offset = query.offset.unwrap_or(0);
        let select_sql = format!(
            r#"SELECT
st.session_id,
st.step,
st.state_json,
st.frontier_json,
st.versions_seen_json,
st.ran_nodes_json,
st.skipped_nodes_json,
st.updated_channels_json,
st.created_at,
s.concurrency_limit
FROM steps st
JOIN sessions s ON s.id = st.session_id
WHERE {where_clause}
ORDER BY st.step DESC
LIMIT {limit} OFFSET {offset}"#
        );
        // Bind filter values for the COUNT query in the same order the
        // placeholders were allocated above.
        let mut count_query = sqlx::query(&count_sql).bind(session_id);
        if let Some(min_step) = query.min_step {
            count_query = count_query.bind(min_step as i64);
        }
        if let Some(max_step) = query.max_step {
            count_query = count_query.bind(max_step as i64);
        }
        if let Some(ran_node) = &query.ran_node {
            // The containment operand is a one-element JSON array.
            count_query = count_query.bind(serde_json::json!([ran_node.encode()]));
        }
        if let Some(skipped_node) = &query.skipped_node {
            count_query = count_query.bind(serde_json::json!([skipped_node.encode()]));
        }
        let total_count: i64 = count_query
            .fetch_one(&*self.pool)
            .await
            .map_err(|e| CheckpointerError::Backend {
                message: format!("count query: {e}"),
            })?
            .get("total");
        // Same binding order for the page query.
        let mut select_query = sqlx::query(&select_sql).bind(session_id);
        if let Some(min_step) = query.min_step {
            select_query = select_query.bind(min_step as i64);
        }
        if let Some(max_step) = query.max_step {
            select_query = select_query.bind(max_step as i64);
        }
        if let Some(ran_node) = &query.ran_node {
            select_query = select_query.bind(serde_json::json!([ran_node.encode()]));
        }
        if let Some(skipped_node) = &query.skipped_node {
            select_query = select_query.bind(serde_json::json!([skipped_node.encode()]));
        }
        let rows =
            select_query
                .fetch_all(&*self.pool)
                .await
                .map_err(|e| CheckpointerError::Backend {
                    message: format!("select query: {e}"),
                })?;
        let mut checkpoints = Vec::new();
        for row in rows {
            let checkpoint = self.row_to_checkpoint(session_id, &row)?;
            checkpoints.push(checkpoint);
        }
        // Do the pagination math in u64: the previous u32 arithmetic could
        // both truncate a total_count above u32::MAX and overflow on
        // offset + limit.
        let page_info = PageInfo {
            total_count: total_count as u64,
            page_size: checkpoints.len() as u32,
            offset,
            has_next_page: u64::from(offset) + u64::from(limit) < total_count as u64,
        };
        Ok(StepQueryResult {
            checkpoints,
            page_info,
        })
    }
    /// Like [`Checkpointer::save`], but with an optimistic concurrency check:
    /// when `expected_last_step` is `Some`, the session's current `last_step`
    /// is locked (`FOR UPDATE`) and compared before writing; a mismatch
    /// aborts the transaction.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointerError::Backend`] on any query failure or when
    /// the concurrency check fails.
    #[instrument(skip(self, checkpoint), err)]
    pub async fn save_with_concurrency_check(
        &self,
        checkpoint: Checkpoint,
        expected_last_step: Option<u64>,
    ) -> Result<()> {
        // Serialize every JSON payload before opening the transaction so a
        // serialization failure can never leave a transaction dangling.
        let persisted_state = PersistedState::from(&checkpoint.state);
        let state_json = serialize_json(&persisted_state, "state")?;
        let frontier_enc: Vec<String> = checkpoint.frontier.iter().map(|k| k.encode()).collect();
        let frontier_json = serialize_json(&frontier_enc, "frontier")?;
        let persisted_vs = PersistedVersionsSeen(checkpoint.versions_seen.clone());
        let versions_seen_json = serialize_json(&persisted_vs, "versions_seen")?;
        let ran_nodes_enc: Vec<String> = checkpoint.ran_nodes.iter().map(|k| k.encode()).collect();
        let ran_nodes_json = serialize_json(&ran_nodes_enc, "ran_nodes")?;
        let skipped_nodes_enc: Vec<String> = checkpoint
            .skipped_nodes
            .iter()
            .map(|k| k.encode())
            .collect();
        let skipped_nodes_json = serialize_json(&skipped_nodes_enc, "skipped_nodes")?;
        let updated_channels_json =
            serialize_json(&checkpoint.updated_channels, "updated_channels")?;
        let mut tx = self
            .pool
            .begin()
            .await
            .map_err(|e| CheckpointerError::Backend {
                message: format!("tx begin: {e}"),
            })?;
        // Create the session row on first save; later saves leave it alone.
        sqlx::query(
            r#"
INSERT INTO sessions (id, concurrency_limit)
VALUES ($1, $2)
ON CONFLICT (id) DO NOTHING
"#,
        )
        .bind(&checkpoint.session_id)
        .bind(checkpoint.concurrency_limit as i64)
        .execute(&mut *tx)
        .await
        .map_err(|e| CheckpointerError::Backend {
            message: format!("insert session: {e}"),
        })?;
        if let Some(expected_step) = expected_last_step {
            // Lock the session row so no concurrent writer can move last_step
            // between this check and the writes below.
            let current_step: i64 =
                sqlx::query_scalar("SELECT last_step FROM sessions WHERE id = $1 FOR UPDATE")
                    .bind(&checkpoint.session_id)
                    .fetch_one(&mut *tx)
                    .await
                    .map_err(|e| CheckpointerError::Backend {
                        message: format!("concurrency check: {e}"),
                    })?;
            if current_step != expected_step as i64 {
                // Returning early drops `tx`, which rolls the transaction back.
                return Err(CheckpointerError::Backend {
                    message: format!(
                        "concurrency conflict: expected step {}, found {}",
                        expected_step, current_step
                    ),
                });
            }
        }
        // Upsert the step row so re-saving the same step overwrites it.
        sqlx::query(
            r#"
INSERT INTO steps (
session_id,
step,
state_json,
frontier_json,
versions_seen_json,
ran_nodes_json,
skipped_nodes_json,
updated_channels_json
) VALUES ($1, $2, $3::jsonb, $4::jsonb, $5::jsonb, $6::jsonb, $7::jsonb, $8::jsonb)
ON CONFLICT (session_id, step) DO UPDATE SET
state_json = EXCLUDED.state_json,
frontier_json = EXCLUDED.frontier_json,
versions_seen_json = EXCLUDED.versions_seen_json,
ran_nodes_json = EXCLUDED.ran_nodes_json,
skipped_nodes_json = EXCLUDED.skipped_nodes_json,
updated_channels_json = EXCLUDED.updated_channels_json
"#,
        )
        .bind(&checkpoint.session_id)
        .bind(checkpoint.step as i64)
        .bind(&state_json)
        .bind(&frontier_json)
        .bind(&versions_seen_json)
        .bind(&ran_nodes_json)
        .bind(&skipped_nodes_json)
        .bind(&updated_channels_json)
        .execute(&mut *tx)
        .await
        .map_err(|e| CheckpointerError::Backend {
            message: format!("insert step: {e}"),
        })?;
        // Advance the cached "latest" columns only when this step is at least
        // the stored last_step, so out-of-order saves cannot regress them.
        sqlx::query(
            r#"
UPDATE sessions
SET
updated_at = NOW(),
last_step = CASE WHEN last_step <= $2 THEN $2 ELSE last_step END,
last_state_json = CASE WHEN last_step <= $2 THEN $3::jsonb ELSE last_state_json END,
last_frontier_json = CASE WHEN last_step <= $2 THEN $4::jsonb ELSE last_frontier_json END,
last_versions_seen_json = CASE WHEN last_step <= $2 THEN $5::jsonb ELSE last_versions_seen_json END
WHERE id = $1
"#,
        )
        .bind(&checkpoint.session_id)
        .bind(checkpoint.step as i64)
        .bind(&state_json)
        .bind(&frontier_json)
        .bind(&versions_seen_json)
        .execute(&mut *tx)
        .await
        .map_err(|e| CheckpointerError::Backend {
            message: format!("update session latest: {e}"),
        })?;
        tx.commit().await.map_err(|e| CheckpointerError::Backend {
            message: format!("tx commit: {e}"),
        })?;
        Ok(())
    }
    /// Decodes one joined `steps`/`sessions` row into a [`Checkpoint`].
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointerError::Backend`] on column read failures and
    /// [`CheckpointerError::Other`] on malformed JSON payloads.
    fn row_to_checkpoint(&self, session_id: &str, row: &PgRow) -> Result<Checkpoint> {
        let step: i64 = row.get("step");
        let state_json: Value = row.get("state_json");
        let frontier_json: Value = row.get("frontier_json");
        let versions_seen_json: Value = row.get("versions_seen_json");
        let ran_nodes_json: Value = row.get("ran_nodes_json");
        let skipped_nodes_json: Value = row.get("skipped_nodes_json");
        // updated_channels_json is nullable, so read it via try_get.
        let updated_channels_json: Option<Value> =
            row.try_get("updated_channels_json")
                .map_err(|e| CheckpointerError::Backend {
                    message: format!("updated_channels_json read: {e}"),
                })?;
        let created_at: DateTime<Utc> = row.get("created_at");
        let concurrency_limit: i64 = row.get("concurrency_limit");
        let persisted_state: PersistedState = deserialize_json_value(state_json, "state")?;
        let state =
            VersionedState::try_from(persisted_state).map_err(|e| CheckpointerError::Other {
                message: format!("state convert: {e}"),
            })?;
        // NOTE(review): non-string entries in the node arrays are silently
        // skipped via filter_map rather than reported — confirm intended.
        let frontier: Vec<NodeKind> = frontier_json
            .as_array()
            .ok_or_else(|| CheckpointerError::Other {
                message: "frontier not array".to_string(),
            })?
            .iter()
            .filter_map(|v| v.as_str())
            .map(NodeKind::decode)
            .collect();
        let ran_nodes: Vec<NodeKind> = ran_nodes_json
            .as_array()
            .ok_or_else(|| CheckpointerError::Other {
                message: "ran_nodes not array".to_string(),
            })?
            .iter()
            .filter_map(|v| v.as_str())
            .map(NodeKind::decode)
            .collect();
        let skipped_nodes: Vec<NodeKind> = skipped_nodes_json
            .as_array()
            .ok_or_else(|| CheckpointerError::Other {
                message: "skipped_nodes not array".to_string(),
            })?
            .iter()
            .filter_map(|v| v.as_str())
            .map(NodeKind::decode)
            .collect();
        // A NULL column decodes as an empty channel list.
        let updated_channels: Vec<String> = match updated_channels_json {
            None => vec![],
            Some(v) => v
                .as_array()
                .ok_or_else(|| CheckpointerError::Other {
                    message: "updated_channels not array".to_string(),
                })?
                .iter()
                .filter_map(|v| v.as_str())
                .map(|s| s.to_string())
                .collect(),
        };
        let persisted_vs: PersistedVersionsSeen =
            deserialize_json_value(versions_seen_json, "versions_seen")?;
        let versions_seen = persisted_vs.0;
        Ok(Checkpoint {
            session_id: session_id.to_string(),
            step: step as u64,
            state,
            frontier,
            versions_seen,
            concurrency_limit: concurrency_limit as usize,
            created_at,
            ran_nodes,
            skipped_nodes,
            updated_channels,
        })
    }
}