use async_trait::async_trait;
use chrono::Utc;
use sqlx::postgres::{PgPool, PgQueryResult, PgRow, Postgres};
use sqlx::{Row, Transaction};
use std::any::Any;
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::durable::{StorageTransaction, WorkflowStorage};
use crate::error::ClusterError;
/// Computes the exclusive upper bound for a prefix range scan over keys that
/// are compared with `COLLATE "C"` (byte order).
///
/// For UTF-8 text, byte order coincides with Unicode code-point order, so the
/// bound is `prefix` with its last character replaced by the next code point.
/// The surrogate gap `U+D800..=U+DFFF` is skipped because `char` cannot hold it.
/// A trailing `char::MAX` has no successor; it is dropped and the preceding
/// character is advanced instead. The result is the *smallest* string greater
/// than every string starting with `prefix`, so `prefix <= key < bound`
/// selects exactly the keys carrying the prefix.
///
/// Returns `None` for the empty prefix or for a prefix made only of
/// `char::MAX`. In both cases `key >= prefix` alone already selects exactly the
/// prefixed keys.
///
/// The name is historical. The successor is computed per character rather than
/// per byte, because bumping a single UTF-8 byte can yield invalid text. The old
/// fallback of bumping an earlier byte produced a bound that was too loose.
fn increment_last_byte(prefix: &str) -> Option<String> {
    let mut head = prefix;
    while let Some((idx, last)) = head.char_indices().next_back() {
        head = &head[..idx];
        // `char::from_u32` rejects surrogates and values past `char::MAX`.
        // The only surrogate reachable by `+1` is U+D800, so jump the gap.
        let next = match last {
            '\u{D7FF}' => Some('\u{E000}'),
            c => char::from_u32(u32::from(c) + 1),
        };
        if let Some(next) = next {
            let mut bound = String::with_capacity(head.len() + next.len_utf8());
            bound.push_str(head);
            bound.push(next);
            return Some(bound);
        }
    }
    None
}
/// PostgreSQL-backed [`WorkflowStorage`] that keeps workflow journal entries
/// in the `cluster_workflow_journal` table.
pub struct SqlWorkflowStorage {
    /// Connection pool that every storage operation draws from.
    pool: PgPool,
}

impl SqlWorkflowStorage {
    /// Wraps an existing connection pool. No queries are issued here.
    pub fn new(pool: PgPool) -> Self {
        SqlWorkflowStorage { pool }
    }
}
#[async_trait]
impl WorkflowStorage for SqlWorkflowStorage {
    /// Fetches the bytes stored under `key`.
    ///
    /// Returns `Ok(None)` when no row exists for `key`.
    #[tracing::instrument(level = "debug", skip(self))]
    async fn load(&self, key: &str) -> Result<Option<Vec<u8>>, ClusterError> {
        let row = sqlx::query("SELECT value FROM cluster_workflow_journal WHERE key = $1")
            .bind(key)
            .fetch_optional(&self.pool)
            .await
            .map_err(|e| ClusterError::PersistenceError {
                reason: format!("workflow storage load failed: {e}"),
                source: Some(Box::new(e)),
            })?;
        match row {
            Some(r) => {
                let value: Vec<u8> =
                    r.try_get("value")
                        .map_err(|e| ClusterError::PersistenceError {
                            reason: format!(
                                "workflow storage load: failed to read 'value' column: {e}"
                            ),
                            source: Some(Box::new(e)),
                        })?;
                Ok(Some(value))
            }
            None => Ok(None),
        }
    }

    /// Inserts or overwrites the entry for `key` and refreshes `updated_at`.
    #[tracing::instrument(level = "debug", skip(self, value))]
    async fn save(&self, key: &str, value: &[u8]) -> Result<(), ClusterError> {
        sqlx::query(
            "INSERT INTO cluster_workflow_journal (key, value, updated_at)\n\
             VALUES ($1, $2, NOW())\n\
             ON CONFLICT (key) DO UPDATE SET value = $2, updated_at = NOW()",
        )
        .bind(key)
        .bind(value)
        .execute(&self.pool)
        .await
        .map_err(|e| ClusterError::PersistenceError {
            reason: format!("workflow storage save failed: {e}"),
            source: Some(Box::new(e)),
        })?;
        Ok(())
    }

    /// Removes the entry for `key`. Deleting a missing key is not an error.
    #[tracing::instrument(level = "debug", skip(self))]
    async fn delete(&self, key: &str) -> Result<(), ClusterError> {
        sqlx::query("DELETE FROM cluster_workflow_journal WHERE key = $1")
            .bind(key)
            .execute(&self.pool)
            .await
            .map_err(|e| ClusterError::PersistenceError {
                reason: format!("workflow storage delete failed: {e}"),
                source: Some(Box::new(e)),
            })?;
        Ok(())
    }

    /// Stamps `completed_at = NOW()` on the row for `key`.
    ///
    /// A missing key is not reported: the UPDATE simply affects zero rows.
    #[tracing::instrument(level = "debug", skip(self))]
    async fn mark_completed(&self, key: &str) -> Result<(), ClusterError> {
        sqlx::query("UPDATE cluster_workflow_journal SET completed_at = NOW() WHERE key = $1")
            .bind(key)
            .execute(&self.pool)
            .await
            .map_err(|e| ClusterError::PersistenceError {
                reason: format!("workflow storage mark_completed failed: {e}"),
                source: Some(Box::new(e)),
            })?;
        Ok(())
    }

    /// Deletes completed entries whose `completed_at` is earlier than
    /// `now - older_than`. Returns the number of rows removed.
    ///
    /// Rows that were never marked completed are kept.
    #[tracing::instrument(level = "debug", skip(self))]
    async fn cleanup(&self, older_than: std::time::Duration) -> Result<u64, ClusterError> {
        // `DateTime - TimeDelta` panics on overflow, so use checked arithmetic.
        // If the age does not fit in a `TimeDelta`, or subtracting it underflows
        // chrono's range, the cutoff lies before any timestamp PostgreSQL can
        // store. Nothing can match, so skip the query.
        let cutoff = match chrono::Duration::from_std(older_than)
            .ok()
            .and_then(|age| Utc::now().checked_sub_signed(age))
        {
            Some(cutoff) => cutoff,
            None => return Ok(0),
        };
        let result = sqlx::query(
            "DELETE FROM cluster_workflow_journal WHERE completed_at IS NOT NULL AND completed_at < $1",
        )
        .bind(cutoff)
        .execute(&self.pool)
        .await
        .map_err(|e| ClusterError::PersistenceError {
            reason: format!("workflow storage cleanup failed: {e}"),
            source: Some(Box::new(e)),
        })?;
        Ok(result.rows_affected())
    }

    /// Lists every key that starts with `prefix`.
    ///
    /// The keys are fetched with a byte-order (`COLLATE "C"`) range scan and
    /// sorted by `ORDER BY key ASC`, which uses the column's own collation.
    #[tracing::instrument(level = "debug", skip(self))]
    async fn list_keys(&self, prefix: &str) -> Result<Vec<String>, ClusterError> {
        let rows = match increment_last_byte(prefix) {
            Some(end) => {
                sqlx::query(
                    "SELECT key FROM cluster_workflow_journal WHERE key COLLATE \"C\" >= $1 AND key COLLATE \"C\" < $2 ORDER BY key ASC",
                )
                .bind(prefix)
                .bind(end)
                .fetch_all(&self.pool)
                .await
            }
            None => {
                sqlx::query(
                    "SELECT key FROM cluster_workflow_journal WHERE key COLLATE \"C\" >= $1 ORDER BY key ASC",
                )
                .bind(prefix)
                .fetch_all(&self.pool)
                .await
            }
        }
        .map_err(|e| ClusterError::PersistenceError {
            reason: format!("workflow storage list_keys failed: {e}"),
            source: Some(Box::new(e)),
        })?;

        let mut keys = Vec::with_capacity(rows.len());
        for row in &rows {
            let key: String = row
                .try_get("key")
                .map_err(|e| ClusterError::PersistenceError {
                    reason: format!(
                        "workflow storage list_keys: failed to read 'key' column: {e}"
                    ),
                    source: Some(Box::new(e)),
                })?;
            // The range scan is only a pre-filter. Keep exactly the keys that
            // carry the prefix, even if the computed bound is looser than ideal.
            if key.starts_with(prefix) {
                keys.push(key);
            }
        }
        Ok(keys)
    }

    /// Opens a transaction on the pool and wraps it in a `SqlJournalTransaction`.
    #[tracing::instrument(level = "debug", skip(self))]
    async fn begin_transaction(&self) -> Result<Box<dyn StorageTransaction>, ClusterError> {
        let tx = self
            .pool
            .begin()
            .await
            .map_err(|e| ClusterError::PersistenceError {
                reason: format!("failed to begin transaction: {e}"),
                source: Some(Box::new(e)),
            })?;
        Ok(Box::new(SqlJournalTransaction {
            tx: Arc::new(Mutex::new(Some(tx))),
        }))
    }

    /// Unsupported for this backend: always panics.
    ///
    /// NOTE(review): presumably callers already hold this storage behind an
    /// `Arc`. Confirm that no code path reaches this method.
    fn as_arc(&self) -> Arc<dyn WorkflowStorage> {
        panic!("SqlWorkflowStorage::as_arc() should not be called")
    }

    /// Returns the underlying PostgreSQL connection pool.
    fn sql_pool(&self) -> Option<&PgPool> {
        Some(&self.pool)
    }
}
/// A journal transaction opened by `SqlWorkflowStorage::begin_transaction`.
///
/// The sqlx transaction sits behind a `tokio::sync::Mutex`. Committing or
/// rolling back takes it out and leaves `None` behind. Any later use then
/// fails with a `PersistenceError`.
pub struct SqlJournalTransaction {
    /// The live transaction, or `None` once it has been finished.
    tx: Arc<Mutex<Option<Transaction<'static, Postgres>>>>,
}
#[async_trait]
impl StorageTransaction for SqlJournalTransaction {
    /// Upserts `value` under `key` within this transaction.
    ///
    /// # Errors
    /// Returns `PersistenceError` if the transaction was already finished or
    /// if the statement fails.
    async fn save(&mut self, key: &str, value: &[u8]) -> Result<(), ClusterError> {
        let mut slot = self.tx.lock().await;
        let open = match slot.as_mut() {
            Some(open) => open,
            None => {
                return Err(ClusterError::PersistenceError {
                    reason: "transaction already committed or rolled back".to_string(),
                    source: None,
                });
            }
        };
        let outcome = sqlx::query(
            "INSERT INTO cluster_workflow_journal (key, value, updated_at)\n\
             VALUES ($1, $2, NOW())\n\
             ON CONFLICT (key) DO UPDATE SET value = $2, updated_at = NOW()",
        )
        .bind(key)
        .bind(value)
        .execute(&mut **open)
        .await;
        match outcome {
            Ok(_) => Ok(()),
            Err(e) => Err(ClusterError::PersistenceError {
                reason: format!("transaction save failed: {e}"),
                source: Some(Box::new(e)),
            }),
        }
    }

    /// Deletes the entry for `key` within this transaction.
    ///
    /// # Errors
    /// Returns `PersistenceError` if the transaction was already finished or
    /// if the statement fails.
    async fn delete(&mut self, key: &str) -> Result<(), ClusterError> {
        let mut slot = self.tx.lock().await;
        let open = match slot.as_mut() {
            Some(open) => open,
            None => {
                return Err(ClusterError::PersistenceError {
                    reason: "transaction already committed or rolled back".to_string(),
                    source: None,
                });
            }
        };
        let outcome = sqlx::query("DELETE FROM cluster_workflow_journal WHERE key = $1")
            .bind(key)
            .execute(&mut **open)
            .await;
        match outcome {
            Ok(_) => Ok(()),
            Err(e) => Err(ClusterError::PersistenceError {
                reason: format!("transaction delete failed: {e}"),
                source: Some(Box::new(e)),
            }),
        }
    }

    /// Commits the transaction, consuming this handle.
    ///
    /// The lock is held until the commit has finished.
    async fn commit(self: Box<Self>) -> Result<(), ClusterError> {
        let mut slot = self.tx.lock().await;
        let open = match slot.take() {
            Some(open) => open,
            None => {
                return Err(ClusterError::PersistenceError {
                    reason: "transaction already committed or rolled back".to_string(),
                    source: None,
                });
            }
        };
        let outcome = open.commit().await;
        match outcome {
            Ok(()) => Ok(()),
            Err(e) => Err(ClusterError::PersistenceError {
                reason: format!("transaction commit failed: {e}"),
                source: Some(Box::new(e)),
            }),
        }
    }

    /// Rolls back the transaction, consuming this handle.
    ///
    /// The lock is held until the rollback has finished.
    async fn rollback(self: Box<Self>) -> Result<(), ClusterError> {
        let mut slot = self.tx.lock().await;
        let open = match slot.take() {
            Some(open) => open,
            None => {
                return Err(ClusterError::PersistenceError {
                    reason: "transaction already committed or rolled back".to_string(),
                    source: None,
                });
            }
        };
        let outcome = open.rollback().await;
        match outcome {
            Ok(()) => Ok(()),
            Err(e) => Err(ClusterError::PersistenceError {
                reason: format!("transaction rollback failed: {e}"),
                source: Some(Box::new(e)),
            }),
        }
    }

    /// Exposes `self` for downcasting to `SqlJournalTransaction`.
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }
}
impl SqlJournalTransaction {
pub async fn execute(
&self,
query: sqlx::query::Query<'_, Postgres, sqlx::postgres::PgArguments>,
) -> Result<PgQueryResult, ClusterError> {
let mut guard = self.tx.lock().await;
let tx = guard
.as_mut()
.ok_or_else(|| ClusterError::PersistenceError {
reason: "transaction already committed or rolled back".to_string(),
source: None,
})?;
query
.execute(&mut **tx)
.await
.map_err(|e| ClusterError::PersistenceError {
reason: format!("SQL execute failed: {e}"),
source: Some(Box::new(e)),
})
}
pub async fn fetch_one<'q, O>(
&self,
query: sqlx::query::QueryAs<'q, Postgres, O, sqlx::postgres::PgArguments>,
) -> Result<O, ClusterError>
where
O: Send + Unpin + for<'r> sqlx::FromRow<'r, PgRow>,
{
let mut guard = self.tx.lock().await;
let tx = guard
.as_mut()
.ok_or_else(|| ClusterError::PersistenceError {
reason: "transaction already committed or rolled back".to_string(),
source: None,
})?;
query
.fetch_one(&mut **tx)
.await
.map_err(|e| ClusterError::PersistenceError {
reason: format!("SQL fetch_one failed: {e}"),
source: Some(Box::new(e)),
})
}
pub async fn fetch_optional<'q, O>(
&self,
query: sqlx::query::QueryAs<'q, Postgres, O, sqlx::postgres::PgArguments>,
) -> Result<Option<O>, ClusterError>
where
O: Send + Unpin + for<'r> sqlx::FromRow<'r, PgRow>,
{
let mut guard = self.tx.lock().await;
let tx = guard
.as_mut()
.ok_or_else(|| ClusterError::PersistenceError {
reason: "transaction already committed or rolled back".to_string(),
source: None,
})?;
query
.fetch_optional(&mut **tx)
.await
.map_err(|e| ClusterError::PersistenceError {
reason: format!("SQL fetch_optional failed: {e}"),
source: Some(Box::new(e)),
})
}
pub async fn fetch_all<'q, O>(
&self,
query: sqlx::query::QueryAs<'q, Postgres, O, sqlx::postgres::PgArguments>,
) -> Result<Vec<O>, ClusterError>
where
O: Send + Unpin + for<'r> sqlx::FromRow<'r, PgRow>,
{
let mut guard = self.tx.lock().await;
let tx = guard
.as_mut()
.ok_or_else(|| ClusterError::PersistenceError {
reason: "transaction already committed or rolled back".to_string(),
source: None,
})?;
query
.fetch_all(&mut **tx)
.await
.map_err(|e| ClusterError::PersistenceError {
reason: format!("SQL fetch_all failed: {e}"),
source: Some(Box::new(e)),
})
}
}
/// Upserts a journal entry (`key` → `value`) on a caller-supplied connection,
/// refreshing `updated_at`.
///
/// # Errors
/// Returns `PersistenceError` wrapping the sqlx error if the statement fails.
pub async fn save_journal_entry(
    conn: &mut sqlx::PgConnection,
    key: &str,
    value: &[u8],
) -> Result<(), ClusterError> {
    let upsert = sqlx::query(
        "INSERT INTO cluster_workflow_journal (key, value, updated_at)\n\
         VALUES ($1, $2, NOW())\n\
         ON CONFLICT (key) DO UPDATE SET value = $2, updated_at = NOW()",
    )
    .bind(key)
    .bind(value);
    upsert
        .execute(conn)
        .await
        .map(|_| ())
        .map_err(|e| ClusterError::PersistenceError {
            reason: format!("journal entry save failed: {e}"),
            source: Some(Box::new(e)),
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The storage handle must be shareable across threads.
    #[test]
    fn sql_workflow_storage_is_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<SqlWorkflowStorage>();
    }

    #[test]
    fn increment_last_byte_normal() {
        let bound = increment_last_byte("abc");
        assert_eq!(bound.as_deref(), Some("abd"));
    }

    #[test]
    fn increment_last_byte_trailing_high_byte() {
        let bound = increment_last_byte("az");
        assert_eq!(bound.as_deref(), Some("a{"));
    }

    #[test]
    fn increment_last_byte_empty() {
        assert!(increment_last_byte("").is_none());
    }

    /// A `result/` prefix scan must stop right before `result0`.
    #[test]
    fn increment_last_byte_prefix_query_semantics() {
        let bound = increment_last_byte("result/");
        assert_eq!(bound.as_deref(), Some("result0"));
    }
}