use async_trait::async_trait;
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use sqlx::postgres::PgPool;
use sqlx::Row;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Notify;
use crate::durable::{WorkflowEngine, INTERRUPT_SIGNAL};
use crate::error::ClusterError;
/// Poll interval (100 ms) applied when the engine is built with `SqlWorkflowEngine::new`.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(100);
/// Workflow engine that keeps timer and deferred-value state in PostgreSQL.
///
/// State lives in the `cluster_workflow_timers` and `cluster_workflow_deferred`
/// tables. Waiters re-check the database on a fixed interval and can also be
/// woken early, within this process, through per-key [`Notify`] handles.
pub struct SqlWorkflowEngine {
/// Connection pool every query issued by this engine runs against.
pool: PgPool,
/// Longest time a waiter sleeps between two database checks.
poll_interval: Duration,
/// In-process wake-up handles for deferred values, keyed by
/// `(workflow_name, execution_id, deferred name)`.
deferred_notifiers: DashMap<(String, String, String), Arc<Notify>>,
/// In-process wake-up handles for timers, keyed by
/// `(workflow_name, execution_id, timer name)`.
timer_notifiers: DashMap<(String, String, String), Arc<Notify>>,
}
impl SqlWorkflowEngine {
pub fn new(pool: PgPool) -> Self {
Self {
pool,
poll_interval: DEFAULT_POLL_INTERVAL,
deferred_notifiers: DashMap::new(),
timer_notifiers: DashMap::new(),
}
}
/// Creates an engine over `pool` with a caller-chosen polling interval.
///
/// `poll_interval` is the longest a waiter sleeps between database checks.
/// Both notifier registries start out empty.
pub fn with_poll_interval(pool: PgPool, poll_interval: Duration) -> Self {
    let deferred_notifiers = DashMap::new();
    let timer_notifiers = DashMap::new();
    Self {
        pool,
        poll_interval,
        deferred_notifiers,
        timer_notifiers,
    }
}
/// Returns the notifier shared by everyone waiting on one deferred value,
/// registering a fresh one the first time the key is seen.
fn get_deferred_notifier(
    &self,
    workflow_name: &str,
    execution_id: &str,
    name: &str,
) -> Arc<Notify> {
    let registry_key = (
        workflow_name.to_owned(),
        execution_id.to_owned(),
        name.to_owned(),
    );
    let slot = self
        .deferred_notifiers
        .entry(registry_key)
        .or_insert_with(|| Arc::new(Notify::new()));
    // Hand out another handle to the same `Notify`; the map keeps its own.
    Arc::clone(slot.value())
}
/// Returns the notifier shared by everyone waiting on one timer,
/// registering a fresh one the first time the key is seen.
fn get_timer_notifier(
    &self,
    workflow_name: &str,
    execution_id: &str,
    name: &str,
) -> Arc<Notify> {
    let registry_key = (
        workflow_name.to_owned(),
        execution_id.to_owned(),
        name.to_owned(),
    );
    let slot = self
        .timer_notifiers
        .entry(registry_key)
        .or_insert_with(|| Arc::new(Notify::new()));
    // Hand out another handle to the same `Notify`; the map keeps its own.
    Arc::clone(slot.value())
}
/// Deletes finished bookkeeping rows that are older than `older_than`.
///
/// Removes fired timers whose `created_at` precedes the cutoff and resolved
/// deferred values whose `resolved_at` precedes it, and returns the combined
/// number of deleted rows. The two deletes run as separate statements.
///
/// If `older_than` is so large that the cutoff cannot be represented, no row
/// can be older than it and the call returns `Ok(0)` without touching the
/// database.
///
/// # Errors
///
/// Returns [`ClusterError::PersistenceError`] when either delete fails.
#[tracing::instrument(level = "debug", skip(self))]
pub async fn cleanup(&self, older_than: Duration) -> Result<u64, ClusterError> {
    // `DateTime - TimeDelta` panics on overflow, so the cutoff is computed
    // with checked arithmetic instead of the `-` operator.
    let cutoff = chrono::Duration::from_std(older_than)
        .ok()
        .and_then(|age| Utc::now().checked_sub_signed(age));
    let cutoff = match cutoff {
        Some(cutoff) => cutoff,
        // The cutoff lies before the earliest representable instant.
        None => return Ok(0),
    };

    let timers_deleted = sqlx::query(
        "DELETE FROM cluster_workflow_timers WHERE fired = TRUE AND created_at < $1",
    )
    .bind(cutoff)
    .execute(&self.pool)
    .await
    .map_err(|e| ClusterError::PersistenceError {
        reason: format!("workflow engine timer cleanup failed: {e}"),
        source: Some(Box::new(e)),
    })?
    .rows_affected();

    let deferred_deleted = sqlx::query(
        "DELETE FROM cluster_workflow_deferred WHERE resolved = TRUE AND resolved_at < $1",
    )
    .bind(cutoff)
    .execute(&self.pool)
    .await
    .map_err(|e| ClusterError::PersistenceError {
        reason: format!("workflow engine deferred cleanup failed: {e}"),
        source: Some(Box::new(e)),
    })?
    .rows_affected();

    Ok(timers_deleted + deferred_deleted)
}
}
#[async_trait]
impl WorkflowEngine for SqlWorkflowEngine {
#[tracing::instrument(level = "debug", skip(self))]
async fn sleep(
&self,
workflow_name: &str,
execution_id: &str,
name: &str,
duration: Duration,
) -> Result<(), ClusterError> {
let fire_at =
Utc::now() + chrono::Duration::from_std(duration).unwrap_or(chrono::TimeDelta::MAX);
let existing: Option<(bool, DateTime<Utc>)> = sqlx::query(
"SELECT fired, fire_at FROM cluster_workflow_timers
WHERE workflow_name = $1 AND execution_id = $2 AND timer_name = $3",
)
.bind(workflow_name)
.bind(execution_id)
.bind(name)
.fetch_optional(&self.pool)
.await
.map_err(|e| ClusterError::PersistenceError {
reason: format!("workflow engine sleep check failed: {e}"),
source: Some(Box::new(e)),
})?
.map(|row| {
let fired: bool = row.get("fired");
let fire_at: DateTime<Utc> = row.get("fire_at");
(fired, fire_at)
});
match existing {
Some((true, _)) => {
return Ok(());
}
Some((false, existing_fire_at)) => {
self.wait_for_timer(workflow_name, execution_id, name, existing_fire_at)
.await?;
}
None => {
sqlx::query(
"INSERT INTO cluster_workflow_timers (workflow_name, execution_id, timer_name, fire_at)
VALUES ($1, $2, $3, $4)",
)
.bind(workflow_name)
.bind(execution_id)
.bind(name)
.bind(fire_at)
.execute(&self.pool)
.await
.map_err(|e| ClusterError::PersistenceError {
reason: format!("workflow engine sleep create failed: {e}"),
source: Some(Box::new(e)),
})?;
self.wait_for_timer(workflow_name, execution_id, name, fire_at)
.await?;
}
}
Ok(())
}
#[tracing::instrument(level = "debug", skip(self))]
async fn await_deferred(
&self,
workflow_name: &str,
execution_id: &str,
name: &str,
) -> Result<Vec<u8>, ClusterError> {
let existing: Option<(bool, Option<Vec<u8>>)> = sqlx::query(
"SELECT resolved, value FROM cluster_workflow_deferred
WHERE workflow_name = $1 AND execution_id = $2 AND deferred_name = $3",
)
.bind(workflow_name)
.bind(execution_id)
.bind(name)
.fetch_optional(&self.pool)
.await
.map_err(|e| ClusterError::PersistenceError {
reason: format!("workflow engine await_deferred check failed: {e}"),
source: Some(Box::new(e)),
})?
.map(|row| {
let resolved: bool = row.get("resolved");
let value: Option<Vec<u8>> = row.get("value");
(resolved, value)
});
match existing {
Some((true, Some(value))) => {
return Ok(value);
}
Some((true, None)) => {
return Err(ClusterError::PersistenceError {
reason: format!(
"deferred value resolved but missing: {}/{}/{}",
workflow_name, execution_id, name
),
source: None,
});
}
Some((false, _)) => {
}
None => {
sqlx::query(
"INSERT INTO cluster_workflow_deferred (workflow_name, execution_id, deferred_name, resolved)
VALUES ($1, $2, $3, FALSE)
ON CONFLICT (workflow_name, execution_id, deferred_name) DO NOTHING",
)
.bind(workflow_name)
.bind(execution_id)
.bind(name)
.execute(&self.pool)
.await
.map_err(|e| ClusterError::PersistenceError {
reason: format!("workflow engine await_deferred create failed: {e}"),
source: Some(Box::new(e)),
})?;
}
}
let notifier = self.get_deferred_notifier(workflow_name, execution_id, name);
loop {
let row: Option<(bool, Option<Vec<u8>>)> = sqlx::query(
"SELECT resolved, value FROM cluster_workflow_deferred
WHERE workflow_name = $1 AND execution_id = $2 AND deferred_name = $3",
)
.bind(workflow_name)
.bind(execution_id)
.bind(name)
.fetch_optional(&self.pool)
.await
.map_err(|e| ClusterError::PersistenceError {
reason: format!("workflow engine await_deferred poll failed: {e}"),
source: Some(Box::new(e)),
})?
.map(|r| {
let resolved: bool = r.get("resolved");
let value: Option<Vec<u8>> = r.get("value");
(resolved, value)
});
match row {
Some((true, Some(value))) => return Ok(value),
Some((true, None)) => {
return Err(ClusterError::PersistenceError {
reason: format!(
"deferred value resolved but missing: {}/{}/{}",
workflow_name, execution_id, name
),
source: None,
});
}
_ => {
tokio::select! {
_ = notifier.notified() => {
}
_ = tokio::time::sleep(self.poll_interval) => {
}
}
}
}
}
}
/// Stores `value` as the resolution of the deferred
/// `(workflow_name, execution_id, name)` and wakes local waiters.
///
/// The row is upserted: a missing row is created as resolved, an existing
/// one has its value, `resolved` flag and `resolved_at` overwritten.
/// Afterwards any notifier registered in this process for the key is
/// signalled.
///
/// # Errors
///
/// Returns [`ClusterError::PersistenceError`] when the upsert fails.
#[tracing::instrument(level = "debug", skip(self, value))]
async fn resolve_deferred(
    &self,
    workflow_name: &str,
    execution_id: &str,
    name: &str,
    value: Vec<u8>,
) -> Result<(), ClusterError> {
    let upsert = sqlx::query(
        "INSERT INTO cluster_workflow_deferred (workflow_name, execution_id, deferred_name, value, resolved, resolved_at)
VALUES ($1, $2, $3, $4, TRUE, NOW())
ON CONFLICT (workflow_name, execution_id, deferred_name)
DO UPDATE SET value = $4, resolved = TRUE, resolved_at = NOW()",
    )
    .bind(workflow_name)
    .bind(execution_id)
    .bind(name)
    .bind(&value)
    .execute(&self.pool)
    .await;

    if let Err(e) = upsert {
        return Err(ClusterError::PersistenceError {
            reason: format!("workflow engine resolve_deferred failed: {e}"),
            source: Some(Box::new(e)),
        });
    }

    // Only waiters in this process have a notifier; look it up without
    // creating one.
    let waiter_key = (
        workflow_name.to_owned(),
        execution_id.to_owned(),
        name.to_owned(),
    );
    if let Some(registered) = self.deferred_notifiers.get(&waiter_key) {
        registered.notify_waiters();
    }
    Ok(())
}
/// Completes once the deferred named [`INTERRUPT_SIGNAL`] has been resolved
/// for the given execution; the resolved payload is discarded.
///
/// # Errors
///
/// Propagates any error from `await_deferred`.
#[tracing::instrument(level = "debug", skip(self))]
async fn on_interrupt(
    &self,
    workflow_name: &str,
    execution_id: &str,
) -> Result<(), ClusterError> {
    self.await_deferred(workflow_name, execution_id, INTERRUPT_SIGNAL)
        .await
        .map(|_payload| ())
}
}
impl SqlWorkflowEngine {
#[tracing::instrument(level = "debug", skip(self))]
async fn wait_for_timer(
&self,
workflow_name: &str,
execution_id: &str,
name: &str,
fire_at: DateTime<Utc>,
) -> Result<(), ClusterError> {
let notifier = self.get_timer_notifier(workflow_name, execution_id, name);
loop {
let now = Utc::now();
if now >= fire_at {
let result = sqlx::query(
"UPDATE cluster_workflow_timers
SET fired = TRUE
WHERE workflow_name = $1 AND execution_id = $2 AND timer_name = $3 AND fired = FALSE",
)
.bind(workflow_name)
.bind(execution_id)
.bind(name)
.execute(&self.pool)
.await
.map_err(|e| ClusterError::PersistenceError {
reason: format!("workflow engine timer fire failed: {e}"),
source: Some(Box::new(e)),
})?;
if result.rows_affected() > 0 {
return Ok(());
}
let fired: bool = sqlx::query(
"SELECT fired FROM cluster_workflow_timers
WHERE workflow_name = $1 AND execution_id = $2 AND timer_name = $3",
)
.bind(workflow_name)
.bind(execution_id)
.bind(name)
.fetch_optional(&self.pool)
.await
.map_err(|e| ClusterError::PersistenceError {
reason: format!("workflow engine timer check failed: {e}"),
source: Some(Box::new(e)),
})?
.map(|r| r.get("fired"))
.unwrap_or(true);
if fired {
return Ok(());
}
continue;
}
let remaining = (fire_at - now).to_std().unwrap_or(Duration::ZERO);
let wait_time = remaining.min(self.poll_interval);
tokio::select! {
_ = notifier.notified() => {
}
_ = tokio::time::sleep(wait_time) => {
}
}
}
}
/// Test-only hook: wakes tasks currently waiting on the given timer's
/// notifier. Does nothing when no notifier is registered for the key.
#[cfg(test)]
pub fn notify_timer(&self, workflow_name: &str, execution_id: &str, name: &str) {
    let lookup = (
        workflow_name.to_owned(),
        execution_id.to_owned(),
        name.to_owned(),
    );
    if let Some(registered) = self.timer_notifiers.get(&lookup) {
        registered.notify_waiters();
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that the engine can be shared across threads.
    #[test]
    fn sql_workflow_engine_is_send_sync() {
        fn require_send_sync<T>()
        where
            T: Send + Sync,
        {
        }

        require_send_sync::<SqlWorkflowEngine>();
    }
}