Skip to main content

cruster/storage/
sql_workflow_runtime.rs

//! SQL-backed workflow runtime engine using PostgreSQL via sqlx.
//!
//! Provides durable `sleep()` and `await_deferred()`/`resolve_deferred()` operations
//! that survive entity restarts and runner crashes.
//!
//! This module is only available when the `sql` feature is enabled.
//!
//! # Architecture
//!
//! The engine uses two tables:
//! - `cluster_workflow_timers`: stores scheduled wake-up times for `sleep()` operations
//! - `cluster_workflow_deferred`: stores key-value pairs for deferred signal operations
//!
//! ## Timer Flow (sleep)
//!
//! 1. Workflow calls `sleep(name, duration)`
//! 2. Engine checks whether the timer already exists and has fired (idempotency)
//! 3. If no timer exists, creates a timer record with `fire_at = now + duration`;
//!    an existing, unfired timer keeps its originally stored `fire_at`
//! 4. Polls the database until the timer is due
//! 5. Marks the timer as fired and returns
//!
//! ## Deferred Value Flow (await_deferred/resolve_deferred)
//!
//! 1. Workflow calls `await_deferred(key)`
//! 2. Engine checks whether the value is already resolved (idempotency)
//! 3. If not, creates a pending record (when none exists) and polls until it is resolved
//! 4. External caller invokes `resolve_deferred(key, value)`
//! 5. Engine upserts the record with the value, sets `resolved = true`, and immediately
//!    wakes any waiters registered on the same engine instance
//! 6. Waiting workflow sees the resolved value and returns
30
31use async_trait::async_trait;
32use chrono::{DateTime, Utc};
33use dashmap::DashMap;
34use sqlx::postgres::PgPool;
35use sqlx::Row;
36use std::sync::Arc;
37use std::time::Duration;
38use tokio::sync::Notify;
39
40use crate::durable::{WorkflowEngine, INTERRUPT_SIGNAL};
41use crate::error::ClusterError;
42
/// Default pause between database checks while a timer or deferred value is
/// still pending; `SqlWorkflowEngine::with_poll_interval` overrides it.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_micros(100_000);
45
46/// PostgreSQL-backed workflow runtime engine.
47///
48/// Provides durable implementations of `sleep()`, `await_deferred()`, and
49/// `resolve_deferred()` that persist to the database and survive restarts.
50///
51/// # Example
52///
53/// ```text
54/// use sqlx::postgres::PgPool;
55/// use cruster::storage::sql_workflow_runtime::SqlWorkflowEngine;
56///
57/// let pool = PgPool::connect("postgres://...").await?;
58/// cruster::storage::migrate(&pool).await?;
59/// let engine = SqlWorkflowEngine::new(pool);
60/// ```
61pub struct SqlWorkflowEngine {
62    pool: PgPool,
63    /// Poll interval for database checks
64    poll_interval: Duration,
65    /// In-memory notifiers for immediate wake-up on resolve_deferred
66    /// Key: (workflow_name, execution_id, deferred_name)
67    deferred_notifiers: DashMap<(String, String, String), Arc<Notify>>,
68    /// In-memory notifiers for timers (for testing with short sleeps)
69    timer_notifiers: DashMap<(String, String, String), Arc<Notify>>,
70}
71
72impl SqlWorkflowEngine {
73    /// Create a new SQL workflow runtime engine with the given connection pool.
74    ///
75    /// Run [`crate::storage::migrate`] before using SQL storage backends.
76    pub fn new(pool: PgPool) -> Self {
77        Self {
78            pool,
79            poll_interval: DEFAULT_POLL_INTERVAL,
80            deferred_notifiers: DashMap::new(),
81            timer_notifiers: DashMap::new(),
82        }
83    }
84
85    /// Create a new SQL workflow runtime engine with a custom poll interval.
86    ///
87    /// The poll interval controls how often the engine checks the database
88    /// for timer/deferred state changes. Lower values provide faster response
89    /// but increase database load.
90    pub fn with_poll_interval(pool: PgPool, poll_interval: Duration) -> Self {
91        Self {
92            pool,
93            poll_interval,
94            deferred_notifiers: DashMap::new(),
95            timer_notifiers: DashMap::new(),
96        }
97    }
98
99    /// Get or create a notifier for a deferred key.
100    fn get_deferred_notifier(
101        &self,
102        workflow_name: &str,
103        execution_id: &str,
104        name: &str,
105    ) -> Arc<Notify> {
106        let key = (
107            workflow_name.to_string(),
108            execution_id.to_string(),
109            name.to_string(),
110        );
111        self.deferred_notifiers
112            .entry(key)
113            .or_insert_with(|| Arc::new(Notify::new()))
114            .clone()
115    }
116
117    /// Get or create a notifier for a timer.
118    fn get_timer_notifier(
119        &self,
120        workflow_name: &str,
121        execution_id: &str,
122        name: &str,
123    ) -> Arc<Notify> {
124        let key = (
125            workflow_name.to_string(),
126            execution_id.to_string(),
127            name.to_string(),
128        );
129        self.timer_notifiers
130            .entry(key)
131            .or_insert_with(|| Arc::new(Notify::new()))
132            .clone()
133    }
134
135    /// Clean up old completed entries from the database.
136    ///
137    /// Deletes timers that have fired and deferred values that have been resolved
138    /// older than the specified duration.
139    #[tracing::instrument(level = "debug", skip(self))]
140    pub async fn cleanup(&self, older_than: Duration) -> Result<u64, ClusterError> {
141        let cutoff =
142            Utc::now() - chrono::Duration::from_std(older_than).unwrap_or(chrono::TimeDelta::MAX);
143
144        // Clean up fired timers
145        let timers_deleted = sqlx::query(
146            "DELETE FROM cluster_workflow_timers WHERE fired = TRUE AND created_at < $1",
147        )
148        .bind(cutoff)
149        .execute(&self.pool)
150        .await
151        .map_err(|e| ClusterError::PersistenceError {
152            reason: format!("workflow engine timer cleanup failed: {e}"),
153            source: Some(Box::new(e)),
154        })?
155        .rows_affected();
156
157        // Clean up resolved deferred values
158        let deferred_deleted = sqlx::query(
159            "DELETE FROM cluster_workflow_deferred WHERE resolved = TRUE AND resolved_at < $1",
160        )
161        .bind(cutoff)
162        .execute(&self.pool)
163        .await
164        .map_err(|e| ClusterError::PersistenceError {
165            reason: format!("workflow engine deferred cleanup failed: {e}"),
166            source: Some(Box::new(e)),
167        })?
168        .rows_affected();
169
170        Ok(timers_deleted + deferred_deleted)
171    }
172}
173
174#[async_trait]
175impl WorkflowEngine for SqlWorkflowEngine {
176    #[tracing::instrument(level = "debug", skip(self))]
177    async fn sleep(
178        &self,
179        workflow_name: &str,
180        execution_id: &str,
181        name: &str,
182        duration: Duration,
183    ) -> Result<(), ClusterError> {
184        let fire_at =
185            Utc::now() + chrono::Duration::from_std(duration).unwrap_or(chrono::TimeDelta::MAX);
186
187        // Check if timer already exists
188        let existing: Option<(bool, DateTime<Utc>)> = sqlx::query(
189            "SELECT fired, fire_at FROM cluster_workflow_timers 
190             WHERE workflow_name = $1 AND execution_id = $2 AND timer_name = $3",
191        )
192        .bind(workflow_name)
193        .bind(execution_id)
194        .bind(name)
195        .fetch_optional(&self.pool)
196        .await
197        .map_err(|e| ClusterError::PersistenceError {
198            reason: format!("workflow engine sleep check failed: {e}"),
199            source: Some(Box::new(e)),
200        })?
201        .map(|row| {
202            let fired: bool = row.get("fired");
203            let fire_at: DateTime<Utc> = row.get("fire_at");
204            (fired, fire_at)
205        });
206
207        match existing {
208            Some((true, _)) => {
209                // Timer already fired - idempotent return
210                return Ok(());
211            }
212            Some((false, existing_fire_at)) => {
213                // Timer exists but hasn't fired - wait for it
214                // Use the existing fire_at time, not the new one
215                self.wait_for_timer(workflow_name, execution_id, name, existing_fire_at)
216                    .await?;
217            }
218            None => {
219                // Create new timer
220                sqlx::query(
221                    "INSERT INTO cluster_workflow_timers (workflow_name, execution_id, timer_name, fire_at)
222                     VALUES ($1, $2, $3, $4)",
223                )
224                .bind(workflow_name)
225                .bind(execution_id)
226                .bind(name)
227                .bind(fire_at)
228                .execute(&self.pool)
229                .await
230                .map_err(|e| ClusterError::PersistenceError {
231                    reason: format!("workflow engine sleep create failed: {e}"),
232                    source: Some(Box::new(e)),
233                })?;
234
235                // Wait for timer
236                self.wait_for_timer(workflow_name, execution_id, name, fire_at)
237                    .await?;
238            }
239        }
240
241        Ok(())
242    }
243
244    #[tracing::instrument(level = "debug", skip(self))]
245    async fn await_deferred(
246        &self,
247        workflow_name: &str,
248        execution_id: &str,
249        name: &str,
250    ) -> Result<Vec<u8>, ClusterError> {
251        // Check if already resolved
252        let existing: Option<(bool, Option<Vec<u8>>)> = sqlx::query(
253            "SELECT resolved, value FROM cluster_workflow_deferred 
254             WHERE workflow_name = $1 AND execution_id = $2 AND deferred_name = $3",
255        )
256        .bind(workflow_name)
257        .bind(execution_id)
258        .bind(name)
259        .fetch_optional(&self.pool)
260        .await
261        .map_err(|e| ClusterError::PersistenceError {
262            reason: format!("workflow engine await_deferred check failed: {e}"),
263            source: Some(Box::new(e)),
264        })?
265        .map(|row| {
266            let resolved: bool = row.get("resolved");
267            let value: Option<Vec<u8>> = row.get("value");
268            (resolved, value)
269        });
270
271        match existing {
272            Some((true, Some(value))) => {
273                // Already resolved - idempotent return
274                return Ok(value);
275            }
276            Some((true, None)) => {
277                // Resolved but no value - shouldn't happen but handle gracefully
278                return Err(ClusterError::PersistenceError {
279                    reason: format!(
280                        "deferred value resolved but missing: {}/{}/{}",
281                        workflow_name, execution_id, name
282                    ),
283                    source: None,
284                });
285            }
286            Some((false, _)) => {
287                // Record exists but not resolved - wait for it
288            }
289            None => {
290                // Create pending record
291                sqlx::query(
292                    "INSERT INTO cluster_workflow_deferred (workflow_name, execution_id, deferred_name, resolved)
293                     VALUES ($1, $2, $3, FALSE)
294                     ON CONFLICT (workflow_name, execution_id, deferred_name) DO NOTHING",
295                )
296                .bind(workflow_name)
297                .bind(execution_id)
298                .bind(name)
299                .execute(&self.pool)
300                .await
301                .map_err(|e| ClusterError::PersistenceError {
302                    reason: format!("workflow engine await_deferred create failed: {e}"),
303                    source: Some(Box::new(e)),
304                })?;
305            }
306        }
307
308        // Poll/wait for resolution
309        let notifier = self.get_deferred_notifier(workflow_name, execution_id, name);
310
311        loop {
312            // Check if resolved
313            let row: Option<(bool, Option<Vec<u8>>)> = sqlx::query(
314                "SELECT resolved, value FROM cluster_workflow_deferred 
315                 WHERE workflow_name = $1 AND execution_id = $2 AND deferred_name = $3",
316            )
317            .bind(workflow_name)
318            .bind(execution_id)
319            .bind(name)
320            .fetch_optional(&self.pool)
321            .await
322            .map_err(|e| ClusterError::PersistenceError {
323                reason: format!("workflow engine await_deferred poll failed: {e}"),
324                source: Some(Box::new(e)),
325            })?
326            .map(|r| {
327                let resolved: bool = r.get("resolved");
328                let value: Option<Vec<u8>> = r.get("value");
329                (resolved, value)
330            });
331
332            match row {
333                Some((true, Some(value))) => return Ok(value),
334                Some((true, None)) => {
335                    return Err(ClusterError::PersistenceError {
336                        reason: format!(
337                            "deferred value resolved but missing: {}/{}/{}",
338                            workflow_name, execution_id, name
339                        ),
340                        source: None,
341                    });
342                }
343                _ => {
344                    // Wait for notification or poll timeout
345                    tokio::select! {
346                        _ = notifier.notified() => {
347                            // Got notification, re-check immediately
348                        }
349                        _ = tokio::time::sleep(self.poll_interval) => {
350                            // Poll timeout, re-check
351                        }
352                    }
353                }
354            }
355        }
356    }
357
358    #[tracing::instrument(level = "debug", skip(self, value))]
359    async fn resolve_deferred(
360        &self,
361        workflow_name: &str,
362        execution_id: &str,
363        name: &str,
364        value: Vec<u8>,
365    ) -> Result<(), ClusterError> {
366        // Upsert the resolved value
367        sqlx::query(
368            "INSERT INTO cluster_workflow_deferred (workflow_name, execution_id, deferred_name, value, resolved, resolved_at)
369             VALUES ($1, $2, $3, $4, TRUE, NOW())
370             ON CONFLICT (workflow_name, execution_id, deferred_name) 
371             DO UPDATE SET value = $4, resolved = TRUE, resolved_at = NOW()",
372        )
373        .bind(workflow_name)
374        .bind(execution_id)
375        .bind(name)
376        .bind(&value)
377        .execute(&self.pool)
378        .await
379        .map_err(|e| ClusterError::PersistenceError {
380            reason: format!("workflow engine resolve_deferred failed: {e}"),
381            source: Some(Box::new(e)),
382        })?;
383
384        // Notify any in-memory waiters
385        let key = (
386            workflow_name.to_string(),
387            execution_id.to_string(),
388            name.to_string(),
389        );
390        if let Some(notifier) = self.deferred_notifiers.get(&key) {
391            notifier.notify_waiters();
392        }
393
394        Ok(())
395    }
396
397    #[tracing::instrument(level = "debug", skip(self))]
398    async fn on_interrupt(
399        &self,
400        workflow_name: &str,
401        execution_id: &str,
402    ) -> Result<(), ClusterError> {
403        // Wait for the special interrupt signal
404        let _ = self
405            .await_deferred(workflow_name, execution_id, INTERRUPT_SIGNAL)
406            .await?;
407        Ok(())
408    }
409}
410
411impl SqlWorkflowEngine {
412    /// Wait for a timer to fire, polling the database.
413    #[tracing::instrument(level = "debug", skip(self))]
414    async fn wait_for_timer(
415        &self,
416        workflow_name: &str,
417        execution_id: &str,
418        name: &str,
419        fire_at: DateTime<Utc>,
420    ) -> Result<(), ClusterError> {
421        let notifier = self.get_timer_notifier(workflow_name, execution_id, name);
422
423        loop {
424            let now = Utc::now();
425
426            if now >= fire_at {
427                // Timer is due - try to mark as fired (atomic check-and-set)
428                let result = sqlx::query(
429                    "UPDATE cluster_workflow_timers 
430                     SET fired = TRUE 
431                     WHERE workflow_name = $1 AND execution_id = $2 AND timer_name = $3 AND fired = FALSE",
432                )
433                .bind(workflow_name)
434                .bind(execution_id)
435                .bind(name)
436                .execute(&self.pool)
437                .await
438                .map_err(|e| ClusterError::PersistenceError {
439                    reason: format!("workflow engine timer fire failed: {e}"),
440                    source: Some(Box::new(e)),
441                })?;
442
443                if result.rows_affected() > 0 {
444                    // We successfully fired the timer
445                    return Ok(());
446                }
447
448                // Someone else fired it - check if it's actually fired
449                let fired: bool = sqlx::query(
450                    "SELECT fired FROM cluster_workflow_timers 
451                     WHERE workflow_name = $1 AND execution_id = $2 AND timer_name = $3",
452                )
453                .bind(workflow_name)
454                .bind(execution_id)
455                .bind(name)
456                .fetch_optional(&self.pool)
457                .await
458                .map_err(|e| ClusterError::PersistenceError {
459                    reason: format!("workflow engine timer check failed: {e}"),
460                    source: Some(Box::new(e)),
461                })?
462                .map(|r| r.get("fired"))
463                .unwrap_or(true); // If deleted, treat as fired
464
465                if fired {
466                    return Ok(());
467                }
468                // Rare race condition - retry
469                continue;
470            }
471
472            // Calculate how long to wait
473            let remaining = (fire_at - now).to_std().unwrap_or(Duration::ZERO);
474            let wait_time = remaining.min(self.poll_interval);
475
476            // Wait for notification or timeout
477            tokio::select! {
478                _ = notifier.notified() => {
479                    // Got notification (e.g., for testing), re-check immediately
480                }
481                _ = tokio::time::sleep(wait_time) => {
482                    // Timeout, re-check
483                }
484            }
485        }
486    }
487
488    /// Notify a timer (for testing - allows immediate wake-up).
489    #[cfg(test)]
490    pub fn notify_timer(&self, workflow_name: &str, execution_id: &str, name: &str) {
491        let key = (
492            workflow_name.to_string(),
493            execution_id.to_string(),
494            name.to_string(),
495        );
496        if let Some(notifier) = self.timer_notifiers.get(&key) {
497            notifier.notify_waiters();
498        }
499    }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only when `T` can be both moved and shared across threads.
    fn require_send_and_sync<T: Send + Sync>() {}

    #[test]
    fn sql_workflow_engine_is_send_sync() {
        require_send_and_sync::<SqlWorkflowEngine>();
    }
}