1use async_trait::async_trait;
32use chrono::{DateTime, Utc};
33use dashmap::DashMap;
34use sqlx::postgres::PgPool;
35use sqlx::Row;
36use std::sync::Arc;
37use std::time::Duration;
38use tokio::sync::Notify;
39
40use crate::durable::{WorkflowEngine, INTERRUPT_SIGNAL};
41use crate::error::ClusterError;
42
/// Poll interval used by [`SqlWorkflowEngine::new`] when the caller does not pick one.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(100);
45
46pub struct SqlWorkflowEngine {
62 pool: PgPool,
63 poll_interval: Duration,
65 deferred_notifiers: DashMap<(String, String, String), Arc<Notify>>,
68 timer_notifiers: DashMap<(String, String, String), Arc<Notify>>,
70}
71
72impl SqlWorkflowEngine {
73 pub fn new(pool: PgPool) -> Self {
77 Self {
78 pool,
79 poll_interval: DEFAULT_POLL_INTERVAL,
80 deferred_notifiers: DashMap::new(),
81 timer_notifiers: DashMap::new(),
82 }
83 }
84
85 pub fn with_poll_interval(pool: PgPool, poll_interval: Duration) -> Self {
91 Self {
92 pool,
93 poll_interval,
94 deferred_notifiers: DashMap::new(),
95 timer_notifiers: DashMap::new(),
96 }
97 }
98
99 fn get_deferred_notifier(
101 &self,
102 workflow_name: &str,
103 execution_id: &str,
104 name: &str,
105 ) -> Arc<Notify> {
106 let key = (
107 workflow_name.to_string(),
108 execution_id.to_string(),
109 name.to_string(),
110 );
111 self.deferred_notifiers
112 .entry(key)
113 .or_insert_with(|| Arc::new(Notify::new()))
114 .clone()
115 }
116
117 fn get_timer_notifier(
119 &self,
120 workflow_name: &str,
121 execution_id: &str,
122 name: &str,
123 ) -> Arc<Notify> {
124 let key = (
125 workflow_name.to_string(),
126 execution_id.to_string(),
127 name.to_string(),
128 );
129 self.timer_notifiers
130 .entry(key)
131 .or_insert_with(|| Arc::new(Notify::new()))
132 .clone()
133 }
134
135 #[tracing::instrument(level = "debug", skip(self))]
140 pub async fn cleanup(&self, older_than: Duration) -> Result<u64, ClusterError> {
141 let cutoff =
142 Utc::now() - chrono::Duration::from_std(older_than).unwrap_or(chrono::TimeDelta::MAX);
143
144 let timers_deleted = sqlx::query(
146 "DELETE FROM cluster_workflow_timers WHERE fired = TRUE AND created_at < $1",
147 )
148 .bind(cutoff)
149 .execute(&self.pool)
150 .await
151 .map_err(|e| ClusterError::PersistenceError {
152 reason: format!("workflow engine timer cleanup failed: {e}"),
153 source: Some(Box::new(e)),
154 })?
155 .rows_affected();
156
157 let deferred_deleted = sqlx::query(
159 "DELETE FROM cluster_workflow_deferred WHERE resolved = TRUE AND resolved_at < $1",
160 )
161 .bind(cutoff)
162 .execute(&self.pool)
163 .await
164 .map_err(|e| ClusterError::PersistenceError {
165 reason: format!("workflow engine deferred cleanup failed: {e}"),
166 source: Some(Box::new(e)),
167 })?
168 .rows_affected();
169
170 Ok(timers_deleted + deferred_deleted)
171 }
172}
173
174#[async_trait]
175impl WorkflowEngine for SqlWorkflowEngine {
176 #[tracing::instrument(level = "debug", skip(self))]
177 async fn sleep(
178 &self,
179 workflow_name: &str,
180 execution_id: &str,
181 name: &str,
182 duration: Duration,
183 ) -> Result<(), ClusterError> {
184 let fire_at =
185 Utc::now() + chrono::Duration::from_std(duration).unwrap_or(chrono::TimeDelta::MAX);
186
187 let existing: Option<(bool, DateTime<Utc>)> = sqlx::query(
189 "SELECT fired, fire_at FROM cluster_workflow_timers
190 WHERE workflow_name = $1 AND execution_id = $2 AND timer_name = $3",
191 )
192 .bind(workflow_name)
193 .bind(execution_id)
194 .bind(name)
195 .fetch_optional(&self.pool)
196 .await
197 .map_err(|e| ClusterError::PersistenceError {
198 reason: format!("workflow engine sleep check failed: {e}"),
199 source: Some(Box::new(e)),
200 })?
201 .map(|row| {
202 let fired: bool = row.get("fired");
203 let fire_at: DateTime<Utc> = row.get("fire_at");
204 (fired, fire_at)
205 });
206
207 match existing {
208 Some((true, _)) => {
209 return Ok(());
211 }
212 Some((false, existing_fire_at)) => {
213 self.wait_for_timer(workflow_name, execution_id, name, existing_fire_at)
216 .await?;
217 }
218 None => {
219 sqlx::query(
221 "INSERT INTO cluster_workflow_timers (workflow_name, execution_id, timer_name, fire_at)
222 VALUES ($1, $2, $3, $4)",
223 )
224 .bind(workflow_name)
225 .bind(execution_id)
226 .bind(name)
227 .bind(fire_at)
228 .execute(&self.pool)
229 .await
230 .map_err(|e| ClusterError::PersistenceError {
231 reason: format!("workflow engine sleep create failed: {e}"),
232 source: Some(Box::new(e)),
233 })?;
234
235 self.wait_for_timer(workflow_name, execution_id, name, fire_at)
237 .await?;
238 }
239 }
240
241 Ok(())
242 }
243
244 #[tracing::instrument(level = "debug", skip(self))]
245 async fn await_deferred(
246 &self,
247 workflow_name: &str,
248 execution_id: &str,
249 name: &str,
250 ) -> Result<Vec<u8>, ClusterError> {
251 let existing: Option<(bool, Option<Vec<u8>>)> = sqlx::query(
253 "SELECT resolved, value FROM cluster_workflow_deferred
254 WHERE workflow_name = $1 AND execution_id = $2 AND deferred_name = $3",
255 )
256 .bind(workflow_name)
257 .bind(execution_id)
258 .bind(name)
259 .fetch_optional(&self.pool)
260 .await
261 .map_err(|e| ClusterError::PersistenceError {
262 reason: format!("workflow engine await_deferred check failed: {e}"),
263 source: Some(Box::new(e)),
264 })?
265 .map(|row| {
266 let resolved: bool = row.get("resolved");
267 let value: Option<Vec<u8>> = row.get("value");
268 (resolved, value)
269 });
270
271 match existing {
272 Some((true, Some(value))) => {
273 return Ok(value);
275 }
276 Some((true, None)) => {
277 return Err(ClusterError::PersistenceError {
279 reason: format!(
280 "deferred value resolved but missing: {}/{}/{}",
281 workflow_name, execution_id, name
282 ),
283 source: None,
284 });
285 }
286 Some((false, _)) => {
287 }
289 None => {
290 sqlx::query(
292 "INSERT INTO cluster_workflow_deferred (workflow_name, execution_id, deferred_name, resolved)
293 VALUES ($1, $2, $3, FALSE)
294 ON CONFLICT (workflow_name, execution_id, deferred_name) DO NOTHING",
295 )
296 .bind(workflow_name)
297 .bind(execution_id)
298 .bind(name)
299 .execute(&self.pool)
300 .await
301 .map_err(|e| ClusterError::PersistenceError {
302 reason: format!("workflow engine await_deferred create failed: {e}"),
303 source: Some(Box::new(e)),
304 })?;
305 }
306 }
307
308 let notifier = self.get_deferred_notifier(workflow_name, execution_id, name);
310
311 loop {
312 let row: Option<(bool, Option<Vec<u8>>)> = sqlx::query(
314 "SELECT resolved, value FROM cluster_workflow_deferred
315 WHERE workflow_name = $1 AND execution_id = $2 AND deferred_name = $3",
316 )
317 .bind(workflow_name)
318 .bind(execution_id)
319 .bind(name)
320 .fetch_optional(&self.pool)
321 .await
322 .map_err(|e| ClusterError::PersistenceError {
323 reason: format!("workflow engine await_deferred poll failed: {e}"),
324 source: Some(Box::new(e)),
325 })?
326 .map(|r| {
327 let resolved: bool = r.get("resolved");
328 let value: Option<Vec<u8>> = r.get("value");
329 (resolved, value)
330 });
331
332 match row {
333 Some((true, Some(value))) => return Ok(value),
334 Some((true, None)) => {
335 return Err(ClusterError::PersistenceError {
336 reason: format!(
337 "deferred value resolved but missing: {}/{}/{}",
338 workflow_name, execution_id, name
339 ),
340 source: None,
341 });
342 }
343 _ => {
344 tokio::select! {
346 _ = notifier.notified() => {
347 }
349 _ = tokio::time::sleep(self.poll_interval) => {
350 }
352 }
353 }
354 }
355 }
356 }
357
358 #[tracing::instrument(level = "debug", skip(self, value))]
359 async fn resolve_deferred(
360 &self,
361 workflow_name: &str,
362 execution_id: &str,
363 name: &str,
364 value: Vec<u8>,
365 ) -> Result<(), ClusterError> {
366 sqlx::query(
368 "INSERT INTO cluster_workflow_deferred (workflow_name, execution_id, deferred_name, value, resolved, resolved_at)
369 VALUES ($1, $2, $3, $4, TRUE, NOW())
370 ON CONFLICT (workflow_name, execution_id, deferred_name)
371 DO UPDATE SET value = $4, resolved = TRUE, resolved_at = NOW()",
372 )
373 .bind(workflow_name)
374 .bind(execution_id)
375 .bind(name)
376 .bind(&value)
377 .execute(&self.pool)
378 .await
379 .map_err(|e| ClusterError::PersistenceError {
380 reason: format!("workflow engine resolve_deferred failed: {e}"),
381 source: Some(Box::new(e)),
382 })?;
383
384 let key = (
386 workflow_name.to_string(),
387 execution_id.to_string(),
388 name.to_string(),
389 );
390 if let Some(notifier) = self.deferred_notifiers.get(&key) {
391 notifier.notify_waiters();
392 }
393
394 Ok(())
395 }
396
397 #[tracing::instrument(level = "debug", skip(self))]
398 async fn on_interrupt(
399 &self,
400 workflow_name: &str,
401 execution_id: &str,
402 ) -> Result<(), ClusterError> {
403 let _ = self
405 .await_deferred(workflow_name, execution_id, INTERRUPT_SIGNAL)
406 .await?;
407 Ok(())
408 }
409}
410
411impl SqlWorkflowEngine {
412 #[tracing::instrument(level = "debug", skip(self))]
414 async fn wait_for_timer(
415 &self,
416 workflow_name: &str,
417 execution_id: &str,
418 name: &str,
419 fire_at: DateTime<Utc>,
420 ) -> Result<(), ClusterError> {
421 let notifier = self.get_timer_notifier(workflow_name, execution_id, name);
422
423 loop {
424 let now = Utc::now();
425
426 if now >= fire_at {
427 let result = sqlx::query(
429 "UPDATE cluster_workflow_timers
430 SET fired = TRUE
431 WHERE workflow_name = $1 AND execution_id = $2 AND timer_name = $3 AND fired = FALSE",
432 )
433 .bind(workflow_name)
434 .bind(execution_id)
435 .bind(name)
436 .execute(&self.pool)
437 .await
438 .map_err(|e| ClusterError::PersistenceError {
439 reason: format!("workflow engine timer fire failed: {e}"),
440 source: Some(Box::new(e)),
441 })?;
442
443 if result.rows_affected() > 0 {
444 return Ok(());
446 }
447
448 let fired: bool = sqlx::query(
450 "SELECT fired FROM cluster_workflow_timers
451 WHERE workflow_name = $1 AND execution_id = $2 AND timer_name = $3",
452 )
453 .bind(workflow_name)
454 .bind(execution_id)
455 .bind(name)
456 .fetch_optional(&self.pool)
457 .await
458 .map_err(|e| ClusterError::PersistenceError {
459 reason: format!("workflow engine timer check failed: {e}"),
460 source: Some(Box::new(e)),
461 })?
462 .map(|r| r.get("fired"))
463 .unwrap_or(true); if fired {
466 return Ok(());
467 }
468 continue;
470 }
471
472 let remaining = (fire_at - now).to_std().unwrap_or(Duration::ZERO);
474 let wait_time = remaining.min(self.poll_interval);
475
476 tokio::select! {
478 _ = notifier.notified() => {
479 }
481 _ = tokio::time::sleep(wait_time) => {
482 }
484 }
485 }
486 }
487
488 #[cfg(test)]
490 pub fn notify_timer(&self, workflow_name: &str, execution_id: &str, name: &str) {
491 let key = (
492 workflow_name.to_string(),
493 execution_id.to_string(),
494 name.to_string(),
495 );
496 if let Some(notifier) = self.timer_notifiers.get(&key) {
497 notifier.notify_waiters();
498 }
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::SqlWorkflowEngine;

    /// Compiles only when `T` is both `Send` and `Sync`.
    fn require_send_sync<T>()
    where
        T: Send + Sync,
    {
    }

    /// Compile-time guard: the engine must stay shareable across threads.
    #[test]
    fn sql_workflow_engine_is_send_sync() {
        require_send_sync::<SqlWorkflowEngine>();
    }
}