//! durable/executor.rs — executor heartbeat loop and stale/timed-out task recovery.
1use sea_orm::sea_query::OnConflict;
2use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, EntityTrait, Set, Statement};
3use std::time::Duration;
4use tokio::task::JoinHandle;
5use uuid::Uuid;
6
7use crate::error::DurableError;
8use durable_db::entity::executor_heartbeat;
9
/// Configuration for the executor heartbeat loop and recovery.
///
/// `staleness_threshold` should be strictly greater than `heartbeat_interval`
/// (recommended: 3x), otherwise a healthy executor can be declared dead
/// between two of its own heartbeats. Defaults: 60s interval, 180s threshold.
pub struct HeartbeatConfig {
    /// How often the executor writes a heartbeat row. Default: 60s.
    pub heartbeat_interval: Duration,
    /// Tasks owned by executors whose last heartbeat is older than this are
    /// eligible for recovery by another executor. Default: 180s.
    pub staleness_threshold: Duration,
}
20
21impl Default for HeartbeatConfig {
22    fn default() -> Self {
23        Self {
24            heartbeat_interval: Duration::from_secs(60),
25            staleness_threshold: Duration::from_secs(180),
26        }
27    }
28}
29
/// Information about a task that was recovered from a stale RUNNING state.
///
/// Produced by `Executor::recover_stale_tasks` and `Executor::recover` from
/// the `RETURNING id, parent_id, name, kind, handler` clause.
pub struct RecoveredTask {
    /// Task primary key.
    pub id: Uuid,
    /// Parent task id, if this task is a child (e.g. a workflow step).
    pub parent_id: Option<Uuid>,
    /// Human-readable task name.
    pub name: String,
    /// Task kind discriminator (e.g. `STEP` — see `reset_orphaned_steps`).
    pub kind: String,
    /// The static handler function name (from `#[durable::workflow]`).
    /// Used to look up the registered resume function on recovery.
    /// Falls back to `name` when `None` (backwards compat).
    pub handler: Option<String>,
}
41
/// Parse a row from a `RETURNING id, parent_id, name, kind, handler` clause.
///
/// Columns are read positionally (indices 0..=4), so this must stay in sync
/// with the RETURNING column order in the recovery queries.
///
/// NOTE(review): for the nullable columns (`parent_id`, `handler`) the
/// `.ok().flatten()` pattern maps *any* driver error to `None`, not just
/// SQL NULL — presumably intentional for backwards compat, but a genuine
/// decode failure would be silently dropped. Confirm this is desired.
fn parse_recovered_row(row: &sea_orm::QueryResult) -> Result<RecoveredTask, DurableError> {
    let id: Uuid = row
        .try_get_by_index(0)
        .map_err(|e| DurableError::custom(e.to_string()))?;
    let parent_id: Option<Uuid> = row.try_get_by_index(1).ok().flatten();
    let name: String = row
        .try_get_by_index(2)
        .map_err(|e| DurableError::custom(e.to_string()))?;
    let kind: String = row
        .try_get_by_index(3)
        .map_err(|e| DurableError::custom(e.to_string()))?;
    let handler: Option<String> = row.try_get_by_index(4).ok().flatten();
    Ok(RecoveredTask {
        id,
        parent_id,
        name,
        kind,
        handler,
    })
}
63
/// Executor configuration for a durable worker.
///
/// Holds the database connection and the unique id this worker writes into
/// heartbeat rows and claims tasks under.
pub struct Executor {
    // Cloned cheaply into the spawned heartbeat task (connection pool handle).
    db: DatabaseConnection,
    // Unique, stable identifier for this worker process.
    executor_id: String,
}
69
70impl Executor {
71    pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
72        Self { db, executor_id }
73    }
74
75    pub fn db(&self) -> &DatabaseConnection {
76        &self.db
77    }
78
79    pub fn executor_id(&self) -> &str {
80        &self.executor_id
81    }
82
83    /// Write (or upsert) a heartbeat row for this executor.
84    pub async fn heartbeat(&self) -> Result<(), DurableError> {
85        let now = chrono::Utc::now().into();
86        let model = executor_heartbeat::ActiveModel {
87            executor_id: Set(self.executor_id.clone()),
88            last_seen: Set(now),
89        };
90
91        executor_heartbeat::Entity::insert(model)
92            .on_conflict(
93                OnConflict::column(executor_heartbeat::Column::ExecutorId)
94                    .update_column(executor_heartbeat::Column::LastSeen)
95                    .to_owned(),
96            )
97            .exec(&self.db)
98            .await
99            .map_err(DurableError::from)?;
100
101        Ok(())
102    }
103
104    /// Spawn a background tokio task that calls `heartbeat()` every `config.heartbeat_interval`.
105    ///
106    /// Returns a `JoinHandle` so the caller can abort it on graceful shutdown.
107    pub fn start_heartbeat(&self, config: &HeartbeatConfig) -> JoinHandle<()> {
108        let db = self.db.clone();
109        let executor_id = self.executor_id.clone();
110        let interval = config.heartbeat_interval;
111
112        tokio::spawn(async move {
113            let executor = Executor::new(db, executor_id);
114            let mut ticker = tokio::time::interval(interval);
115            loop {
116                ticker.tick().await;
117                if let Err(e) = executor.heartbeat().await {
118                    tracing::warn!("heartbeat write failed: {e}");
119                }
120            }
121        })
122    }
123
124    /// Claim stale tasks from dead workers in a single atomic operation.
125    /// Uses `FOR UPDATE SKIP LOCKED` so multiple executors can run this
126    /// concurrently without blocking each other — each executor claims a
127    /// disjoint set of tasks.
128    ///
129    /// Catches three cases:
130    /// 1. Tasks whose `executor_id` has a stale heartbeat (> staleness_threshold)
131    /// 2. Tasks whose `executor_id` has no heartbeat row at all
132    /// 3. Orphaned tasks with NULL `executor_id` that have been RUNNING too long
133    ///
134    /// Claimed tasks go directly to RUNNING (no intermediate PENDING state)
135    /// with `started_at` reset and `executor_id` set to this executor.
136    pub async fn recover_stale_tasks(
137        &self,
138        staleness_threshold: Duration,
139    ) -> Result<Vec<RecoveredTask>, DurableError> {
140        let threshold_secs = staleness_threshold.as_secs();
141        let sql = format!(
142            "WITH claimable AS ( \
143                 SELECT id FROM durable.task \
144                 WHERE status = 'RUNNING' \
145                   AND ( \
146                       (executor_id IN ( \
147                           SELECT executor_id FROM durable.executor_heartbeat \
148                           WHERE last_seen < now() - interval '{threshold_secs} seconds' \
149                       ) AND executor_id != '{eid}') \
150                       OR \
151                       (executor_id IS NOT NULL \
152                        AND executor_id != '{eid}' \
153                        AND executor_id NOT IN ( \
154                            SELECT executor_id FROM durable.executor_heartbeat \
155                        )) \
156                       OR \
157                       (executor_id IS NULL \
158                        AND started_at IS NOT NULL \
159                        AND started_at < now() - interval '{threshold_secs} seconds') \
160                   ) \
161                 FOR UPDATE SKIP LOCKED \
162             ) \
163             UPDATE durable.task \
164             SET status = 'RUNNING', started_at = now(), executor_id = '{eid}' \
165             WHERE id IN (SELECT id FROM claimable) \
166             RETURNING id, parent_id, name, kind, handler",
167            eid = self.executor_id
168        );
169        let rows = self
170            .db
171            .query_all(Statement::from_string(DbBackend::Postgres, sql))
172            .await?;
173
174        rows.iter().map(parse_recovered_row).collect()
175    }
176
177    /// Claim timed-out tasks in a single atomic operation.
178    /// Uses `FOR UPDATE SKIP LOCKED` for safe concurrent claiming.
179    ///
180    /// A task is considered timed-out if:
181    /// - `status = 'RUNNING'` and `started_at IS NOT NULL`
182    /// - `timeout_ms` has elapsed, or `deadline_epoch_ms` has passed
183    ///
184    /// Tasks without `timeout_ms` set are never considered timed-out.
185    pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
186        let sql = format!(
187            "WITH claimable AS ( \
188                 SELECT id FROM durable.task \
189                 WHERE status = 'RUNNING' \
190                   AND started_at IS NOT NULL \
191                   AND ( \
192                     (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0)) \
193                     OR \
194                     (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000) \
195                   ) \
196                 FOR UPDATE SKIP LOCKED \
197             ) \
198             UPDATE durable.task \
199             SET status = 'RUNNING', started_at = now(), deadline_epoch_ms = NULL, \
200                 executor_id = '{eid}' \
201             WHERE id IN (SELECT id FROM claimable) \
202             RETURNING id, parent_id, name, kind, handler",
203            eid = self.executor_id
204        );
205
206        let rows = self
207            .db
208            .query_all(Statement::from_string(DbBackend::Postgres, sql))
209            .await?;
210
211        rows.iter().map(parse_recovered_row).collect()
212    }
213
214    /// Reset RUNNING child steps of recovered workflows back to PENDING.
215    ///
216    /// After a crash, TX1 in `step()` may have committed (setting the step to
217    /// RUNNING) but the closure never completed. These orphaned steps block
218    /// re-execution because `find_or_create_task` sees RUNNING and returns
219    /// `StepLocked`. This method resets them so the recovered workflow can
220    /// re-execute its steps from scratch.
221    pub async fn reset_orphaned_steps(
222        &self,
223        recovered_ids: &[Uuid],
224    ) -> Result<u64, DurableError> {
225        if recovered_ids.is_empty() {
226            return Ok(0);
227        }
228        let id_list: String = recovered_ids
229            .iter()
230            .map(|id| format!("'{id}'"))
231            .collect::<Vec<_>>()
232            .join(",");
233        let sql = format!(
234            "UPDATE durable.task \
235             SET status = 'PENDING', started_at = NULL \
236             WHERE parent_id IN ({id_list}) \
237               AND status = 'RUNNING' \
238               AND kind = 'STEP'"
239        );
240        let result = self
241            .db
242            .execute(Statement::from_string(DbBackend::Postgres, sql))
243            .await?;
244        Ok(result.rows_affected())
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config must match the documented 60s / 180s values.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            heartbeat_interval,
            staleness_threshold,
        } = HeartbeatConfig::default();
        assert_eq!(
            heartbeat_interval,
            Duration::from_secs(60),
            "default heartbeat_interval should be 60s"
        );
        assert_eq!(
            staleness_threshold,
            Duration::from_secs(180),
            "default staleness_threshold should be 180s"
        );
    }

    /// Explicitly-set fields are stored unchanged.
    #[test]
    fn test_heartbeat_config_custom() {
        let custom = HeartbeatConfig {
            heartbeat_interval: Duration::from_secs(30),
            staleness_threshold: Duration::from_secs(90),
        };
        assert_eq!(custom.heartbeat_interval, Duration::from_secs(30));
        assert_eq!(custom.staleness_threshold, Duration::from_secs(90));
    }
}