use std::sync::Arc;
use chrono::{DateTime, Utc};
use sqlx::{Row, SqlitePool, sqlite::SqliteRow};
use tracing::instrument;
use crate::{
runtimes::checkpointer::{Checkpoint, Checkpointer, CheckpointerError, Result},
runtimes::persistence::{PersistedState, PersistedVersionsSeen},
state::VersionedState,
types::NodeKind,
};
/// Filter and pagination options accepted by `SQLiteCheckpointer::query_steps`.
///
/// Every field is optional; `Default` yields a query with no filter set.
#[derive(Debug, Clone, Default)]
pub struct StepQuery {
/// Maximum number of checkpoints to return. `query_steps` uses 100 when this is
/// unset and caps the value at 1000.
pub limit: Option<u32>,
/// Number of matching rows to skip before the page starts (0 when unset).
pub offset: Option<u32>,
/// Inclusive lower bound on the step number.
pub min_step: Option<u64>,
/// Inclusive upper bound on the step number.
pub max_step: Option<u64>,
/// Keep only steps whose stored `ran_nodes_json` contains this node's encoded form.
pub ran_node: Option<NodeKind>,
/// Keep only steps whose stored `skipped_nodes_json` contains this node's encoded form.
pub skipped_node: Option<NodeKind>,
}
/// Pagination metadata returned together with a page of checkpoints.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PageInfo {
/// Number of rows matching the query's filters, ignoring limit and offset.
pub total_count: u64,
/// Number of checkpoints actually contained in the returned page.
pub page_size: u32,
/// Row offset that was applied to produce this page.
pub offset: u32,
/// Whether matching rows remain after this page (`offset + limit < total_count`).
pub has_next_page: bool,
}
/// One page of step checkpoints plus the pagination metadata describing it.
#[derive(Debug, Clone)]
pub struct StepQueryResult {
/// Checkpoints decoded from the selected `steps` rows, in query order
/// (`query_steps` orders by step, descending).
pub checkpoints: Vec<Checkpoint>,
/// Totals and offsets for the page held in `checkpoints`.
pub page_info: PageInfo,
}
/// `Checkpointer` implementation that stores sessions and steps in SQLite via `sqlx`.
pub struct SQLiteCheckpointer {
// Reference-counted handle to the sqlx connection pool used for every query.
pool: Arc<SqlitePool>,
}
// Debug formatting prints only the type name; the `pool` field is not shown.
impl std::fmt::Debug for SQLiteCheckpointer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A `debug_struct` with no fields renders as the bare name, so the name
        // can be written directly with the same result.
        f.write_str("SQLiteCheckpointer")
    }
}
impl SQLiteCheckpointer {
    /// Opens a SQLite connection pool for `database_url` and wraps it in a checkpointer.
    ///
    /// When the crate is compiled with the `sqlite-migrations` feature, the migrations
    /// embedded from `./migrations` are run against the pool before returning.
    /// The URL itself is excluded from the tracing span.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointerError::Backend`] when the pool cannot be opened or a
    /// migration fails.
    #[must_use = "checkpointer must be used to persist state"]
    #[instrument(skip(database_url))]
    pub async fn connect(database_url: &str) -> std::result::Result<Self, CheckpointerError> {
        let pool = match SqlitePool::connect(database_url).await {
            Ok(opened) => opened,
            Err(e) => {
                return Err(CheckpointerError::Backend {
                    message: format!("connect: {e}"),
                });
            }
        };

        #[cfg(feature = "sqlite-migrations")]
        {
            let migrator = sqlx::migrate!("./migrations");
            migrator
                .run(&pool)
                .await
                .map_err(|e| CheckpointerError::Backend {
                    message: format!("migration: {e}"),
                })?;
        }

        let pool = Arc::new(pool);
        Ok(Self { pool })
    }

    /// Begins a transaction on the pool, reporting failure as a backend error.
    async fn begin_tx(&self) -> Result<Tx> {
        match self.pool.begin().await {
            Ok(tx) => Ok(tx),
            Err(e) => Err(CheckpointerError::Backend {
                message: format!("begin transaction: {e}"),
            }),
        }
    }
}
#[async_trait::async_trait]
impl Checkpointer for SQLiteCheckpointer {
#[instrument(skip(self, checkpoint), err)]
async fn save(&self, checkpoint: Checkpoint) -> Result<()> {
let enc = EncodedCheckpoint::encode(&checkpoint)?;
let mut tx = self.begin_tx().await?;
exec_insert_session(
&mut tx,
&checkpoint.session_id,
checkpoint.concurrency_limit,
)
.await?;
exec_upsert_step(&mut tx, &checkpoint.session_id, checkpoint.step, &enc).await?;
tx.commit().await.map_err(|e| CheckpointerError::Backend {
message: format!("commit: {e}"),
})
}
#[instrument(skip(self, session_id), err)]
async fn load_latest(&self, session_id: &str) -> Result<Option<Checkpoint>> {
let row: Option<SqliteRow> = sqlx::query(
"SELECT id, last_step, last_state_json, last_frontier_json, \
last_versions_seen_json, concurrency_limit, updated_at \
FROM sessions WHERE id = ?1",
)
.bind(session_id)
.fetch_optional(&*self.pool)
.await
.map_err(|e| CheckpointerError::Backend {
message: format!("load_latest: {e}"),
})?;
let row = match row {
Some(r) => r,
None => return Ok(None),
};
let last_step: i64 = row.get("last_step");
let concurrency_limit: i64 = row.get("concurrency_limit");
let updated_at_str: String = row.get("updated_at");
let state_json: Option<String> =
row.try_get("last_state_json")
.map_err(|e| CheckpointerError::Backend {
message: format!("last_state_json: {e}"),
})?;
let frontier_json: Option<String> =
row.try_get("last_frontier_json")
.map_err(|e| CheckpointerError::Backend {
message: format!("last_frontier_json: {e}"),
})?;
let versions_seen_json: Option<String> =
row.try_get("last_versions_seen_json")
.map_err(|e| CheckpointerError::Backend {
message: format!("last_versions_seen_json: {e}"),
})?;
if last_step == 0 && state_json.is_none() {
return Ok(None);
}
let state = decode_state(&need_field(state_json, "last_state_json")?)?;
let frontier = decode_node_kinds(&need_field(frontier_json, "last_frontier_json")?)?;
let versions_seen = {
let pv: PersistedVersionsSeen = from_json_str(
&need_field(versions_seen_json, "last_versions_seen_json")?,
"versions_seen",
)?;
pv.0
};
let created_at = DateTime::parse_from_rfc3339(&updated_at_str)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now());
Ok(Some(Checkpoint {
session_id: session_id.to_string(),
step: last_step as u64,
state,
frontier,
versions_seen,
concurrency_limit: concurrency_limit as usize,
created_at,
ran_nodes: vec![],
skipped_nodes: vec![],
updated_channels: vec![],
}))
}
#[instrument(skip(self), err)]
async fn list_sessions(&self) -> Result<Vec<String>> {
sqlx::query("SELECT id FROM sessions ORDER BY updated_at DESC")
.fetch_all(&*self.pool)
.await
.map_err(|e| CheckpointerError::Backend {
message: format!("list_sessions: {e}"),
})
.map(|rows| rows.into_iter().map(|r| r.get::<String, _>("id")).collect())
}
}
impl SQLiteCheckpointer {
#[instrument(skip(self), err)]
pub async fn query_steps(&self, session_id: &str, query: StepQuery) -> Result<StepQueryResult> {
let limit = query.limit.unwrap_or(100).min(1_000);
let offset = query.offset.unwrap_or(0);
let mut conditions = vec!["session_id = ?1".to_string()];
let mut param = 1u32;
if query.min_step.is_some() {
param += 1;
conditions.push(format!("step >= ?{param}"));
}
if query.max_step.is_some() {
param += 1;
conditions.push(format!("step <= ?{param}"));
}
if query.ran_node.is_some() {
param += 1;
conditions.push(format!("JSON_EXTRACT(ran_nodes_json, '$') LIKE ?{param}"));
}
if query.skipped_node.is_some() {
param += 1;
conditions.push(format!(
"JSON_EXTRACT(skipped_nodes_json, '$') LIKE ?{param}"
));
}
let where_clause = conditions.join(" AND ");
let count_sql = format!("SELECT COUNT(*) AS total FROM steps WHERE {where_clause}");
let select_sql = format!(
"SELECT session_id, step, state_json, frontier_json, versions_seen_json, \
ran_nodes_json, skipped_nodes_json, updated_channels_json, created_at \
FROM steps WHERE {where_clause} \
ORDER BY step DESC LIMIT {limit} OFFSET {offset}"
);
let total_count: i64 = {
let mut q = sqlx::query(&count_sql).bind(session_id);
if let Some(v) = query.min_step {
q = q.bind(v as i64);
}
if let Some(v) = query.max_step {
q = q.bind(v as i64);
}
if let Some(ref node) = query.ran_node {
q = q.bind(format!("%{}%", node.encode()));
}
if let Some(ref node) = query.skipped_node {
q = q.bind(format!("%{}%", node.encode()));
}
q
}
.fetch_one(&*self.pool)
.await
.map_err(|e| CheckpointerError::Backend {
message: format!("count query: {e}"),
})?
.get("total");
let rows = {
let mut q = sqlx::query(&select_sql).bind(session_id);
if let Some(v) = query.min_step {
q = q.bind(v as i64);
}
if let Some(v) = query.max_step {
q = q.bind(v as i64);
}
if let Some(ref node) = query.ran_node {
q = q.bind(format!("%{}%", node.encode()));
}
if let Some(ref node) = query.skipped_node {
q = q.bind(format!("%{}%", node.encode()));
}
q
}
.fetch_all(&*self.pool)
.await
.map_err(|e| CheckpointerError::Backend {
message: format!("select query: {e}"),
})?;
let checkpoints = rows
.iter()
.map(|r| self.row_to_checkpoint(session_id, r))
.collect::<Result<Vec<_>>>()?;
Ok(StepQueryResult {
page_info: PageInfo {
total_count: total_count as u64,
page_size: checkpoints.len() as u32,
offset,
has_next_page: (offset + limit) < total_count as u32,
},
checkpoints,
})
}
#[instrument(skip(self, checkpoint), err)]
pub async fn save_with_concurrency_check(
&self,
checkpoint: Checkpoint,
expected_last_step: Option<u64>,
) -> Result<()> {
let enc = EncodedCheckpoint::encode(&checkpoint)?;
let mut tx = self.begin_tx().await?;
exec_insert_session(
&mut tx,
&checkpoint.session_id,
checkpoint.concurrency_limit,
)
.await?;
if let Some(expected) = expected_last_step {
let current: Option<i64> =
sqlx::query_scalar("SELECT last_step FROM sessions WHERE id = ?1")
.bind(&checkpoint.session_id)
.fetch_optional(&mut *tx)
.await
.map_err(|e| CheckpointerError::Backend {
message: format!("concurrency check: {e}"),
})?;
match current {
Some(actual) if actual != expected as i64 => {
return Err(CheckpointerError::Backend {
message: format!(
"concurrency conflict: expected last_step {expected}, found {actual}"
),
});
}
None if expected != 0 => {
return Err(CheckpointerError::Backend {
message: format!(
"concurrency conflict: session not found, expected step {expected}"
),
});
}
_ => {}
}
}
exec_upsert_step(&mut tx, &checkpoint.session_id, checkpoint.step, &enc).await?;
tx.commit().await.map_err(|e| CheckpointerError::Backend {
message: format!("commit: {e}"),
})
}
fn row_to_checkpoint(&self, session_id: &str, row: &SqliteRow) -> Result<Checkpoint> {
let step: i64 = row.get("step");
let created_at_str: String = row.get("created_at");
let state_json: String = row.get("state_json");
let frontier_json: String = row.get("frontier_json");
let versions_seen_json: String = row.get("versions_seen_json");
let ran_nodes_json: String = row.get("ran_nodes_json");
let skipped_nodes_json: String = row.get("skipped_nodes_json");
let updated_channels_json: Option<String> =
row.try_get("updated_channels_json").ok().flatten();
let updated_channels = match updated_channels_json {
Some(ref json) => from_json_str::<Vec<String>>(json, "updated_channels")?,
None => vec![],
};
let pv: PersistedVersionsSeen = from_json_str(&versions_seen_json, "versions_seen")?;
let created_at = DateTime::parse_from_rfc3339(&created_at_str)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now());
Ok(Checkpoint {
session_id: session_id.to_string(),
step: step as u64,
state: decode_state(&state_json)?,
frontier: decode_node_kinds(&frontier_json)?,
ran_nodes: decode_node_kinds(&ran_nodes_json)?,
skipped_nodes: decode_node_kinds(&skipped_nodes_json)?,
versions_seen: pv.0,
concurrency_limit: 1,
created_at,
updated_channels,
})
}
}
// JSON-serialized form of the checkpoint parts written to a `steps` row.
// Each field holds the JSON text bound to the column of the same name.
struct EncodedCheckpoint {
state_json: String,
frontier_json: String,
versions_seen_json: String,
ran_nodes_json: String,
skipped_nodes_json: String,
updated_channels_json: String,
}
impl EncodedCheckpoint {
fn encode(cp: &Checkpoint) -> Result<Self> {
let frontier_enc: Vec<String> = cp.frontier.iter().map(NodeKind::encode).collect();
let ran_nodes_enc: Vec<String> = cp.ran_nodes.iter().map(NodeKind::encode).collect();
let skipped_enc: Vec<String> = cp.skipped_nodes.iter().map(NodeKind::encode).collect();
Ok(Self {
state_json: to_json(&PersistedState::from(&cp.state), "state")?,
frontier_json: to_json(&frontier_enc, "frontier")?,
versions_seen_json: to_json(
&PersistedVersionsSeen(cp.versions_seen.clone()),
"versions_seen",
)?,
ran_nodes_json: to_json(&ran_nodes_enc, "ran_nodes")?,
skipped_nodes_json: to_json(&skipped_enc, "skipped_nodes")?,
updated_channels_json: to_json(&cp.updated_channels, "updated_channels")?,
})
}
}
/// Serializes `value` to a JSON string.
///
/// On failure the error is wrapped in `CheckpointerError::Other`, with `ctx`
/// naming the value being serialized.
fn to_json<T: serde::Serialize>(value: &T, ctx: &'static str) -> Result<String> {
    match serde_json::to_string(value) {
        Ok(text) => Ok(text),
        Err(e) => Err(CheckpointerError::Other {
            message: format!("{ctx} serialize: {e}"),
        }),
    }
}
/// Parses `json` into a `T`.
///
/// On failure the error is wrapped in `CheckpointerError::Other`, with `ctx`
/// naming the value being parsed.
fn from_json_str<T: serde::de::DeserializeOwned>(json: &str, ctx: &'static str) -> Result<T> {
    match serde_json::from_str(json) {
        Ok(parsed) => Ok(parsed),
        Err(e) => Err(CheckpointerError::Other {
            message: format!("{ctx} parse: {e}"),
        }),
    }
}
/// Unwraps an optional column value, failing with a "missing field" error that
/// names the column when the value is absent.
fn need_field(opt: Option<String>, name: &'static str) -> Result<String> {
    match opt {
        Some(value) => Ok(value),
        None => Err(CheckpointerError::Other {
            message: format!("missing field {name}"),
        }),
    }
}
/// Parses persisted state JSON and converts it into a `VersionedState`.
///
/// Both the JSON parse and the conversion report failures as
/// `CheckpointerError::Other`.
fn decode_state(json: &str) -> Result<VersionedState> {
    let persisted = from_json_str::<PersistedState>(json, "state")?;
    match VersionedState::try_from(persisted) {
        Ok(state) => Ok(state),
        Err(e) => Err(CheckpointerError::Other {
            message: format!("state convert: {e}"),
        }),
    }
}
/// Parses a JSON array of encoded node strings and decodes each entry with
/// `NodeKind::decode`, preserving order.
fn decode_node_kinds(json: &str) -> Result<Vec<NodeKind>> {
    let encoded = from_json_str::<Vec<String>>(json, "node_kinds")?;
    let mut kinds = Vec::with_capacity(encoded.len());
    for item in &encoded {
        kinds.push(NodeKind::decode(item));
    }
    Ok(kinds)
}
// Shorthand for the SQLite transaction type returned by `begin_tx` and passed to
// the `exec_*` statement helpers.
type Tx = sqlx::Transaction<'static, sqlx::Sqlite>;
/// Inserts a `sessions` row for `session_id` within `tx` unless one already exists.
///
/// Because the statement is `INSERT OR IGNORE`, an existing row is left untouched,
/// including its stored `concurrency_limit`. Failures are reported as
/// `CheckpointerError::Backend`.
async fn exec_insert_session(
    tx: &mut Tx,
    session_id: &str,
    concurrency_limit: usize,
) -> Result<()> {
    let limit = concurrency_limit as i64;
    let outcome =
        sqlx::query("INSERT OR IGNORE INTO sessions (id, concurrency_limit) VALUES (?1, ?2)")
            .bind(session_id)
            .bind(limit)
            .execute(&mut **tx)
            .await;
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CheckpointerError::Backend {
            message: format!("insert session: {e}"),
        }),
    }
}
/// Writes the `steps` row for `(session_id, step)` within `tx`, replacing any
/// existing row for the same key.
///
/// The eight positional parameters are bound in column order: session id, step,
/// then the six JSON columns from `enc`. Failures are reported as
/// `CheckpointerError::Backend`.
async fn exec_upsert_step(
    tx: &mut Tx,
    session_id: &str,
    step: u64,
    enc: &EncodedCheckpoint,
) -> Result<()> {
    let step_number = step as i64;
    let statement = sqlx::query(
        "INSERT OR REPLACE INTO steps \
         (session_id, step, state_json, frontier_json, versions_seen_json, \
         ran_nodes_json, skipped_nodes_json, updated_channels_json) \
         VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
    )
    .bind(session_id)
    .bind(step_number)
    .bind(&enc.state_json)
    .bind(&enc.frontier_json)
    .bind(&enc.versions_seen_json)
    .bind(&enc.ran_nodes_json)
    .bind(&enc.skipped_nodes_json)
    .bind(&enc.updated_channels_json);

    match statement.execute(&mut **tx).await {
        Ok(_) => Ok(()),
        Err(e) => Err(CheckpointerError::Backend {
            message: format!("upsert step: {e}"),
        }),
    }
}