use std::sync::Arc;
use chrono::{DateTime, Utc};
use serde_json::Value;
use sqlx::{PgPool, Row, postgres::PgRow};
use tracing::instrument;
use crate::{
runtimes::checkpointer::{Checkpoint, Checkpointer, CheckpointerError, Result},
runtimes::persistence::{PersistedState, PersistedVersionsSeen},
state::VersionedState,
types::NodeKind,
};
/// Filter and pagination options for `PostgresCheckpointer::query_steps`.
///
/// Every field is optional; the derived `Default` leaves all of them unset.
#[derive(Debug, Clone, Default)]
pub struct StepQuery {
    /// Maximum number of rows to return. `query_steps` uses 100 when unset and caps the value at 1000.
    pub limit: Option<u32>,
    /// Number of matching rows to skip. `query_steps` uses 0 when unset.
    pub offset: Option<u32>,
    /// Inclusive lower bound on the step number (`st.step >= min_step`).
    pub min_step: Option<u64>,
    /// Inclusive upper bound on the step number (`st.step <= max_step`).
    pub max_step: Option<u64>,
    /// Keep only steps whose `ran_nodes_json` array contains this node's encoded form.
    pub ran_node: Option<NodeKind>,
    /// Keep only steps whose `skipped_nodes_json` array contains this node's encoded form.
    pub skipped_node: Option<NodeKind>,
}
/// Pagination metadata returned alongside a page of checkpoints by `query_steps`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PageInfo {
    /// Number of rows matching the filters, taken from a separate `COUNT(*)` query.
    pub total_count: u64,
    /// Number of checkpoints actually present in this page.
    pub page_size: u32,
    /// Offset that was applied to produce this page.
    pub offset: u32,
    /// Whether rows remain past this page, i.e. `offset + limit` is below `total_count`.
    pub has_next_page: bool,
}
/// One page of results produced by `PostgresCheckpointer::query_steps`.
#[derive(Debug, Clone)]
pub struct StepQueryResult {
    /// Checkpoints decoded from the selected rows, in the query's order (step descending).
    pub checkpoints: Vec<Checkpoint>,
    /// Pagination metadata describing this page.
    pub page_info: PageInfo,
}
/// Checkpointer implementation that persists checkpoints through a Postgres connection pool.
pub struct PostgresCheckpointer {
    /// Reference-counted handle to the sqlx Postgres pool used for every query.
    pool: Arc<PgPool>,
}
/// `Debug` prints only the type name; the pool field is intentionally not shown.
impl std::fmt::Debug for PostgresCheckpointer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A debug struct with no fields renders as just its name, so write it directly.
        f.write_str("PostgresCheckpointer")
    }
}
impl PostgresCheckpointer {
/// Opens a connection pool for `database_url` and wraps it in a checkpointer.
///
/// When the `postgres-migrations` feature is enabled, the migrations under
/// `./migrations/postgres` are run against the new pool before returning.
///
/// # Errors
/// Returns `CheckpointerError::Backend` if connecting fails or (with the feature enabled)
/// if running the migrations fails.
#[must_use = "checkpointer must be used to persist state"]
#[instrument(skip(database_url))]
pub async fn connect(database_url: &str) -> std::result::Result<Self, CheckpointerError> {
    let pool = match PgPool::connect(database_url).await {
        Ok(pool) => pool,
        Err(err) => {
            return Err(CheckpointerError::Backend {
                message: format!("connect: {err}"),
            });
        }
    };

    #[cfg(feature = "postgres-migrations")]
    {
        let migrated = sqlx::migrate!("./migrations/postgres").run(&pool).await;
        if let Err(err) = migrated {
            return Err(CheckpointerError::Backend {
                message: format!("migration: {err}"),
            });
        }
    }

    let pool = Arc::new(pool);
    Ok(Self { pool })
}
/// Starts a new transaction on the pool, mapping failures to `CheckpointerError::Backend`.
async fn begin_tx(&self) -> Result<Tx> {
    match self.pool.begin().await {
        Ok(txn) => Ok(txn),
        Err(err) => Err(CheckpointerError::Backend {
            message: format!("begin transaction: {err}"),
        }),
    }
}
}
#[async_trait::async_trait]
impl Checkpointer for PostgresCheckpointer {
/// Persists `checkpoint` in a single transaction.
///
/// The transaction ensures the session row exists, upserts the step row, refreshes the
/// session's "latest" columns, and then commits.
///
/// # Errors
/// Returns an error if encoding the checkpoint fails or if any database statement or the
/// commit fails.
#[instrument(skip(self, checkpoint), err)]
async fn save(&self, checkpoint: Checkpoint) -> Result<()> {
    let encoded = EncodedCheckpoint::encode(&checkpoint)?;
    let session_id = checkpoint.session_id.as_str();
    let step = checkpoint.step;

    let mut txn = self.begin_tx().await?;
    exec_upsert_session(&mut txn, session_id, checkpoint.concurrency_limit).await?;
    exec_upsert_step(&mut txn, session_id, step, &encoded).await?;
    exec_update_latest(&mut txn, session_id, step, &encoded).await?;

    if let Err(err) = txn.commit().await {
        return Err(CheckpointerError::Backend {
            message: format!("commit: {err}"),
        });
    }
    Ok(())
}
#[instrument(skip(self, session_id), err)]
async fn load_latest(&self, session_id: &str) -> Result<Option<Checkpoint>> {
let row: Option<PgRow> = sqlx::query(
"SELECT id, last_step, last_state_json, last_frontier_json, \
last_versions_seen_json, concurrency_limit, updated_at \
FROM sessions WHERE id = $1",
)
.bind(session_id)
.fetch_optional(&*self.pool)
.await
.map_err(|e| CheckpointerError::Backend {
message: format!("load_latest: {e}"),
})?;
let row = match row {
Some(r) => r,
None => return Ok(None),
};
let last_step: i64 = row.get("last_step");
let concurrency_limit: i64 = row.get("concurrency_limit");
let updated_at: DateTime<Utc> = row.get("updated_at");
let state_json: Option<Value> =
row.try_get("last_state_json")
.map_err(|e| CheckpointerError::Backend {
message: format!("last_state_json: {e}"),
})?;
let frontier_json: Option<Value> =
row.try_get("last_frontier_json")
.map_err(|e| CheckpointerError::Backend {
message: format!("last_frontier_json: {e}"),
})?;
let versions_seen_json: Option<Value> =
row.try_get("last_versions_seen_json")
.map_err(|e| CheckpointerError::Backend {
message: format!("last_versions_seen_json: {e}"),
})?;
if last_step == 0 && state_json.is_none() {
return Ok(None);
}
let state = decode_state(need_field(state_json, "last_state_json")?)?;
let frontier = decode_node_kinds(need_field(frontier_json, "last_frontier_json")?)?;
let versions_seen = {
let pv: PersistedVersionsSeen = from_json_value(
need_field(versions_seen_json, "last_versions_seen_json")?,
"versions_seen",
)?;
pv.0
};
Ok(Some(Checkpoint {
session_id: session_id.to_string(),
step: last_step as u64,
state,
frontier,
versions_seen,
concurrency_limit: concurrency_limit as usize,
created_at: updated_at,
ran_nodes: vec![],
skipped_nodes: vec![],
updated_channels: vec![],
}))
}
/// Returns the ids of all sessions, most recently updated first.
///
/// # Errors
/// Returns `CheckpointerError::Backend` if the query fails.
#[instrument(skip(self), err)]
async fn list_sessions(&self) -> Result<Vec<String>> {
    let fetched = sqlx::query("SELECT id FROM sessions ORDER BY updated_at DESC")
        .fetch_all(&*self.pool)
        .await;
    let rows = match fetched {
        Ok(rows) => rows,
        Err(err) => {
            return Err(CheckpointerError::Backend {
                message: format!("list_sessions: {err}"),
            });
        }
    };

    let mut ids = Vec::with_capacity(rows.len());
    for row in rows {
        ids.push(row.get::<String, _>("id"));
    }
    Ok(ids)
}
}
impl PostgresCheckpointer {
/// Returns one page of the checkpoints stored for `session_id`, newest step first.
///
/// `query.limit` defaults to 100 and is capped at 1000; `query.offset` defaults to 0.
/// The optional step bounds are inclusive, and the optional node filters keep only steps
/// whose `ran_nodes_json` / `skipped_nodes_json` array contains the encoded node.
///
/// The total is obtained with a separate `COUNT(*)` query that is not run in the same
/// transaction as the page query, so it may differ slightly under concurrent writes.
///
/// # Errors
/// Returns `CheckpointerError::Backend` if either query fails or the count cannot be
/// decoded, and propagates any error from decoding a row into a `Checkpoint`.
#[instrument(skip(self), err)]
pub async fn query_steps(&self, session_id: &str, query: StepQuery) -> Result<StepQueryResult> {
    let limit = query.limit.unwrap_or(100).min(1_000);
    let offset = query.offset.unwrap_or(0);

    // Steps are bound as i64. Saturate rather than `as`-cast so that a bound above
    // i64::MAX (e.g. u64::MAX used as "no upper bound") does not wrap to a negative value.
    let min_step = query.min_step.map(|v| i64::try_from(v).unwrap_or(i64::MAX));
    let max_step = query.max_step.map(|v| i64::try_from(v).unwrap_or(i64::MAX));

    // Build the WHERE clause; placeholders are numbered in the same order the values
    // are bound below.
    let mut conditions = vec!["st.session_id = $1".to_string()];
    let mut param = 1u32;
    if min_step.is_some() {
        param += 1;
        conditions.push(format!("st.step >= ${param}"));
    }
    if max_step.is_some() {
        param += 1;
        conditions.push(format!("st.step <= ${param}"));
    }
    if query.ran_node.is_some() {
        param += 1;
        conditions.push(format!("st.ran_nodes_json @> ${param}::jsonb"));
    }
    if query.skipped_node.is_some() {
        param += 1;
        conditions.push(format!("st.skipped_nodes_json @> ${param}::jsonb"));
    }
    let where_clause = conditions.join(" AND ");

    let count_sql = format!("SELECT COUNT(*) AS total FROM steps st WHERE {where_clause}");
    // `limit` and `offset` are plain u32 values, so formatting them into the SQL is safe.
    let select_sql = format!(
        "SELECT st.session_id, st.step, st.state_json, st.frontier_json, \
         st.versions_seen_json, st.ran_nodes_json, st.skipped_nodes_json, \
         st.updated_channels_json, st.created_at, s.concurrency_limit \
         FROM steps st \
         JOIN sessions s ON s.id = st.session_id \
         WHERE {where_clause} \
         ORDER BY st.step DESC \
         LIMIT {limit} OFFSET {offset}"
    );

    let mut count_q = sqlx::query(&count_sql).bind(session_id);
    if let Some(v) = min_step {
        count_q = count_q.bind(v);
    }
    if let Some(v) = max_step {
        count_q = count_q.bind(v);
    }
    if let Some(ref node) = query.ran_node {
        count_q = count_q.bind(serde_json::json!([node.encode()]));
    }
    if let Some(ref node) = query.skipped_node {
        count_q = count_q.bind(serde_json::json!([node.encode()]));
    }
    let count_row = count_q
        .fetch_one(&*self.pool)
        .await
        .map_err(|e| CheckpointerError::Backend {
            message: format!("count query: {e}"),
        })?;
    let raw_total: i64 = count_row
        .try_get("total")
        .map_err(|e| CheckpointerError::Backend {
            message: format!("count query: {e}"),
        })?;
    // COUNT(*) is never negative; fall back to 0 rather than wrapping if it somehow is.
    let total_count = u64::try_from(raw_total).unwrap_or(0);

    let mut select_q = sqlx::query(&select_sql).bind(session_id);
    if let Some(v) = min_step {
        select_q = select_q.bind(v);
    }
    if let Some(v) = max_step {
        select_q = select_q.bind(v);
    }
    if let Some(ref node) = query.ran_node {
        select_q = select_q.bind(serde_json::json!([node.encode()]));
    }
    if let Some(ref node) = query.skipped_node {
        select_q = select_q.bind(serde_json::json!([node.encode()]));
    }
    let rows = select_q
        .fetch_all(&*self.pool)
        .await
        .map_err(|e| CheckpointerError::Backend {
            message: format!("select query: {e}"),
        })?;

    let checkpoints = rows
        .iter()
        .map(|r| self.row_to_checkpoint(session_id, r))
        .collect::<Result<Vec<_>>>()?;

    // Widen to u64 before adding: `offset + limit` can exceed u32::MAX, and the total
    // must not be truncated to 32 bits.
    let has_next_page = u64::from(offset) + u64::from(limit) < total_count;

    Ok(StepQueryResult {
        page_info: PageInfo {
            total_count,
            // At most `limit` (<= 1000) rows are returned, so this conversion cannot fail.
            page_size: u32::try_from(checkpoints.len()).unwrap_or(u32::MAX),
            offset,
            has_next_page,
        },
        checkpoints,
    })
}
/// Persists `checkpoint` like `save`, optionally verifying the session's current step first.
///
/// When `expected_last_step` is `Some`, the session row is read with `FOR UPDATE` inside
/// the transaction and its `last_step` must equal the expected value; otherwise the call
/// returns an error before anything else is written and the transaction is dropped
/// uncommitted. When it is `None`, no check is performed.
///
/// # Errors
/// Returns `CheckpointerError::Backend` on a concurrency conflict or when a database
/// statement or the commit fails, and propagates encoding errors.
#[instrument(skip(self, checkpoint), err)]
pub async fn save_with_concurrency_check(
    &self,
    checkpoint: Checkpoint,
    expected_last_step: Option<u64>,
) -> Result<()> {
    let encoded = EncodedCheckpoint::encode(&checkpoint)?;
    let session_id = checkpoint.session_id.as_str();
    let step = checkpoint.step;

    let mut txn = self.begin_tx().await?;
    // Make sure the session row exists so the locking read below always finds it.
    exec_upsert_session(&mut txn, session_id, checkpoint.concurrency_limit).await?;

    if let Some(expected) = expected_last_step {
        let lookup = sqlx::query_scalar::<_, i64>(
            "SELECT last_step FROM sessions WHERE id = $1 FOR UPDATE",
        )
        .bind(session_id)
        .fetch_one(&mut *txn)
        .await;
        let actual = match lookup {
            Ok(value) => value,
            Err(err) => {
                return Err(CheckpointerError::Backend {
                    message: format!("concurrency check: {err}"),
                });
            }
        };
        let is_expected = actual == expected as i64;
        if !is_expected {
            return Err(CheckpointerError::Backend {
                message: format!(
                    "concurrency conflict: expected last_step {expected}, found {actual}"
                ),
            });
        }
    }

    exec_upsert_step(&mut txn, session_id, step, &encoded).await?;
    exec_update_latest(&mut txn, session_id, step, &encoded).await?;

    match txn.commit().await {
        Ok(()) => Ok(()),
        Err(err) => Err(CheckpointerError::Backend {
            message: format!("commit: {err}"),
        }),
    }
}
fn row_to_checkpoint(&self, session_id: &str, row: &PgRow) -> Result<Checkpoint> {
let step: i64 = row.get("step");
let created_at: DateTime<Utc> = row.get("created_at");
let concurrency_limit: i64 = row.get("concurrency_limit");
let updated_channels_json: Option<Value> =
row.try_get("updated_channels_json")
.map_err(|e| CheckpointerError::Backend {
message: format!("updated_channels_json: {e}"),
})?;
let updated_channels: Vec<String> = match updated_channels_json {
None => vec![],
Some(v) => v
.as_array()
.ok_or_else(|| CheckpointerError::Other {
message: "updated_channels_json: expected array".to_string(),
})?
.iter()
.filter_map(Value::as_str)
.map(str::to_owned)
.collect(),
};
let versions_seen = {
let pv: PersistedVersionsSeen =
from_json_value(row.get("versions_seen_json"), "versions_seen")?;
pv.0
};
Ok(Checkpoint {
session_id: session_id.to_string(),
step: step as u64,
state: decode_state(row.get("state_json"))?,
frontier: decode_node_kinds(row.get("frontier_json"))?,
ran_nodes: decode_node_kinds(row.get("ran_nodes_json"))?,
skipped_nodes: decode_node_kinds(row.get("skipped_nodes_json"))?,
versions_seen,
concurrency_limit: concurrency_limit as usize,
created_at,
updated_channels,
})
}
}
/// JSON-serialized pieces of a `Checkpoint`, ready to be bound to `::jsonb` parameters.
struct EncodedCheckpoint {
    /// Serialized `PersistedState` built from the checkpoint's state.
    state_json: String,
    /// Serialized array of encoded frontier node kinds.
    frontier_json: String,
    /// Serialized `PersistedVersionsSeen` wrapper around the checkpoint's versions-seen data.
    versions_seen_json: String,
    /// Serialized array of encoded node kinds from `ran_nodes`.
    ran_nodes_json: String,
    /// Serialized array of encoded node kinds from `skipped_nodes`.
    skipped_nodes_json: String,
    /// Serialized form of the checkpoint's `updated_channels`.
    updated_channels_json: String,
}
impl EncodedCheckpoint {
fn encode(cp: &Checkpoint) -> Result<Self> {
let frontier_enc: Vec<String> = cp.frontier.iter().map(NodeKind::encode).collect();
let ran_nodes_enc: Vec<String> = cp.ran_nodes.iter().map(NodeKind::encode).collect();
let skipped_enc: Vec<String> = cp.skipped_nodes.iter().map(NodeKind::encode).collect();
Ok(Self {
state_json: to_json(&PersistedState::from(&cp.state), "state")?,
frontier_json: to_json(&frontier_enc, "frontier")?,
versions_seen_json: to_json(
&PersistedVersionsSeen(cp.versions_seen.clone()),
"versions_seen",
)?,
ran_nodes_json: to_json(&ran_nodes_enc, "ran_nodes")?,
skipped_nodes_json: to_json(&skipped_enc, "skipped_nodes")?,
updated_channels_json: to_json(&cp.updated_channels, "updated_channels")?,
})
}
}
/// Serializes `value` to a JSON string; `ctx` names the value in the error message.
fn to_json<T: serde::Serialize>(value: &T, ctx: &'static str) -> Result<String> {
    match serde_json::to_string(value) {
        Ok(text) => Ok(text),
        Err(err) => Err(CheckpointerError::Other {
            message: format!("{ctx} serialize: {err}"),
        }),
    }
}
/// Deserializes a JSON `value` into `T`; `ctx` names the value in the error message.
fn from_json_value<T: serde::de::DeserializeOwned>(value: Value, ctx: &'static str) -> Result<T> {
    match serde_json::from_value(value) {
        Ok(parsed) => Ok(parsed),
        Err(err) => Err(CheckpointerError::Other {
            message: format!("{ctx} parse: {err}"),
        }),
    }
}
/// Unwraps an optional JSON column, returning a "missing field" error when it is `None`.
fn need_field(opt: Option<Value>, name: &'static str) -> Result<Value> {
    match opt {
        Some(value) => Ok(value),
        None => Err(CheckpointerError::Other {
            message: format!("missing field {name}"),
        }),
    }
}
/// Parses a JSON value as `PersistedState` and converts it into a `VersionedState`.
fn decode_state(v: Value) -> Result<VersionedState> {
    let persisted = from_json_value::<PersistedState>(v, "state")?;
    match VersionedState::try_from(persisted) {
        Ok(state) => Ok(state),
        Err(err) => Err(CheckpointerError::Other {
            message: format!("state convert: {err}"),
        }),
    }
}
/// Decodes a JSON array of encoded node-kind strings into `NodeKind` values.
///
/// # Errors
/// Returns `CheckpointerError::Other` if `v` is not an array or if any element is not a
/// string. Non-string elements are reported rather than skipped, so a corrupted list is
/// never silently restored with nodes missing.
fn decode_node_kinds(v: Value) -> Result<Vec<NodeKind>> {
    let entries = match v.as_array() {
        Some(entries) => entries,
        None => {
            return Err(CheckpointerError::Other {
                message: "expected JSON array of node kinds".to_string(),
            });
        }
    };

    entries
        .iter()
        .enumerate()
        .map(|(idx, entry)| {
            entry
                .as_str()
                .map(NodeKind::decode)
                .ok_or_else(|| CheckpointerError::Other {
                    message: format!("node kind at index {idx}: expected string, found {entry}"),
                })
        })
        .collect()
}
/// Shorthand for the Postgres transaction type passed to the `exec_*` helpers.
type Tx = sqlx::Transaction<'static, sqlx::Postgres>;
/// Inserts a `sessions` row for `session_id` if none exists yet.
///
/// An existing row is left untouched (`ON CONFLICT (id) DO NOTHING`), so its stored
/// `concurrency_limit` is not changed by this call.
async fn exec_upsert_session(
    tx: &mut Tx,
    session_id: &str,
    concurrency_limit: usize,
) -> Result<()> {
    let limit = concurrency_limit as i64;
    let outcome = sqlx::query(
        "INSERT INTO sessions (id, concurrency_limit) VALUES ($1, $2) \
         ON CONFLICT (id) DO NOTHING",
    )
    .bind(session_id)
    .bind(limit)
    .execute(&mut **tx)
    .await;

    match outcome {
        Ok(_) => Ok(()),
        Err(err) => Err(CheckpointerError::Backend {
            message: format!("upsert session: {err}"),
        }),
    }
}
/// Inserts the `steps` row for `(session_id, step)`, or overwrites its JSON columns when
/// a row with that key already exists.
async fn exec_upsert_step(
    tx: &mut Tx,
    session_id: &str,
    step: u64,
    enc: &EncodedCheckpoint,
) -> Result<()> {
    let step_number = step as i64;
    // The SQL literal's continuation lines are kept flush-left so the statement text is unchanged.
    let outcome = sqlx::query(
        "INSERT INTO steps (
session_id, step,
state_json, frontier_json, versions_seen_json,
ran_nodes_json, skipped_nodes_json, updated_channels_json
) VALUES ($1, $2, $3::jsonb, $4::jsonb, $5::jsonb, $6::jsonb, $7::jsonb, $8::jsonb)
ON CONFLICT (session_id, step) DO UPDATE SET
state_json = EXCLUDED.state_json,
frontier_json = EXCLUDED.frontier_json,
versions_seen_json = EXCLUDED.versions_seen_json,
ran_nodes_json = EXCLUDED.ran_nodes_json,
skipped_nodes_json = EXCLUDED.skipped_nodes_json,
updated_channels_json = EXCLUDED.updated_channels_json",
    )
    .bind(session_id)
    .bind(step_number)
    .bind(&enc.state_json)
    .bind(&enc.frontier_json)
    .bind(&enc.versions_seen_json)
    .bind(&enc.ran_nodes_json)
    .bind(&enc.skipped_nodes_json)
    .bind(&enc.updated_channels_json)
    .execute(&mut **tx)
    .await;

    match outcome {
        Ok(_) => Ok(()),
        Err(err) => Err(CheckpointerError::Backend {
            message: format!("upsert step: {err}"),
        }),
    }
}
/// Refreshes the "latest checkpoint" columns on the `sessions` row.
///
/// `updated_at` is always set to `NOW()`. The step and JSON columns are replaced only
/// when the stored `last_step` is less than or equal to `step`; otherwise they keep
/// their current values.
async fn exec_update_latest(
    tx: &mut Tx,
    session_id: &str,
    step: u64,
    enc: &EncodedCheckpoint,
) -> Result<()> {
    let step_number = step as i64;
    // The SQL literal's continuation lines are kept flush-left so the statement text is unchanged.
    let outcome = sqlx::query(
        "UPDATE sessions SET
updated_at = NOW(),
last_step = CASE WHEN last_step <= $2 THEN $2 ELSE last_step END,
last_state_json = CASE WHEN last_step <= $2 THEN $3::jsonb ELSE last_state_json END,
last_frontier_json = CASE WHEN last_step <= $2 THEN $4::jsonb ELSE last_frontier_json END,
last_versions_seen_json = CASE WHEN last_step <= $2 THEN $5::jsonb ELSE last_versions_seen_json END
WHERE id = $1",
    )
    .bind(session_id)
    .bind(step_number)
    .bind(&enc.state_json)
    .bind(&enc.frontier_json)
    .bind(&enc.versions_seen_json)
    .execute(&mut **tx)
    .await;

    match outcome {
        Ok(_) => Ok(()),
        Err(err) => Err(CheckpointerError::Backend {
            message: format!("update session latest: {err}"),
        }),
    }
}