// durable/executor.rs — executor heartbeat loop and stale-task recovery.

1use std::time::Duration;
2
3use sea_orm::sea_query::OnConflict;
4use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, EntityTrait, Set, Statement};
5use tokio::task::JoinHandle;
6use uuid::Uuid;
7
8use crate::ctx::TS;
9use crate::error::DurableError;
10use durable_db::entity::executor_heartbeat;
11
/// Configuration for the executor heartbeat loop and recovery.
///
/// `staleness_threshold` should be strictly greater than
/// `heartbeat_interval` (recommended: 3x), otherwise a healthy executor
/// can be declared dead between two of its own heartbeats.
/// Defaults: 60s interval, 180s threshold.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeartbeatConfig {
    /// How often the executor writes a heartbeat row. Default: 60s.
    pub heartbeat_interval: Duration,
    /// Tasks owned by executors whose last heartbeat is older than this
    /// are eligible for recovery (re-claimed by a live executor).
    /// Default: 180s.
    pub staleness_threshold: Duration,
}
22
23impl Default for HeartbeatConfig {
24    fn default() -> Self {
25        Self {
26            heartbeat_interval: Duration::from_secs(60),
27            staleness_threshold: Duration::from_secs(180),
28        }
29    }
30}
31
32/// Information about a task that was recovered from a stale RUNNING state.
33pub struct RecoveredTask {
34    pub id: Uuid,
35    pub parent_id: Option<Uuid>,
36    pub name: String,
37    pub kind: String,
38    /// The static handler function name (from `#[durable::workflow]`).
39    /// Used to look up the registered resume function on recovery.
40    /// Falls back to `name` when `None` (backwards compat).
41    pub handler: Option<String>,
42}
43
44/// Parse a row from a `RETURNING id, parent_id, name, kind, handler, handler` clause.
45fn parse_recovered_row(row: &sea_orm::QueryResult) -> Result<RecoveredTask, DurableError> {
46    let id: Uuid = row
47        .try_get_by_index(0)
48        .map_err(|e| DurableError::custom(e.to_string()))?;
49    let parent_id: Option<Uuid> = row.try_get_by_index(1).ok().flatten();
50    let name: String = row
51        .try_get_by_index(2)
52        .map_err(|e| DurableError::custom(e.to_string()))?;
53    let kind: String = row
54        .try_get_by_index(3)
55        .map_err(|e| DurableError::custom(e.to_string()))?;
56    let handler: Option<String> = row.try_get_by_index(4).ok().flatten();
57    Ok(RecoveredTask {
58        id,
59        parent_id,
60        name,
61        kind,
62        handler,
63    })
64}
65
66/// Executor configuration for a durable worker.
67pub struct Executor {
68    db: DatabaseConnection,
69    executor_id: String,
70}
71
impl Executor {
    /// Create an executor bound to a database connection and a unique id.
    pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
        Self { db, executor_id }
    }

    /// Borrow the underlying database connection.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// This executor's unique identifier.
    pub fn executor_id(&self) -> &str {
        &self.executor_id
    }

    /// Write (or upsert) a heartbeat row for this executor.
    ///
    /// Uses `ON CONFLICT (executor_id) DO UPDATE last_seen`, so repeated
    /// calls keep exactly one row per executor fresh.
    pub async fn heartbeat(&self) -> Result<(), DurableError> {
        let now = chrono::Utc::now().into();
        let model = executor_heartbeat::ActiveModel {
            executor_id: Set(self.executor_id.clone()),
            last_seen: Set(now),
        };

        executor_heartbeat::Entity::insert(model)
            .on_conflict(
                OnConflict::column(executor_heartbeat::Column::ExecutorId)
                    .update_column(executor_heartbeat::Column::LastSeen)
                    .to_owned(),
            )
            .exec(&self.db)
            .await
            .map_err(DurableError::from)?;

        Ok(())
    }

    /// Spawn a background tokio task that calls `heartbeat()` every `config.heartbeat_interval`.
    ///
    /// Returns a `JoinHandle` so the caller can abort it on graceful shutdown.
    ///
    /// Panic isolation: each heartbeat write runs in its own spawned child
    /// task, so a panic inside `heartbeat()` (e.g. from a driver bug)
    /// surfaces as a `JoinError` on the child handle instead of killing the
    /// loop. On panic the loop logs an error and retries on the next tick.
    pub fn start_heartbeat(&self, config: &HeartbeatConfig) -> JoinHandle<()> {
        // Clone what the 'static task needs; `self` is not moved in.
        let db = self.db.clone();
        let executor_id = self.executor_id.clone();
        let interval = config.heartbeat_interval;

        tokio::spawn(async move {
            let executor = Executor::new(db, executor_id);
            let mut ticker = tokio::time::interval(interval);
            // Skip (rather than burst) ticks missed while a write was slow.
            ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
            loop {
                ticker.tick().await;
                // Spawn each heartbeat write in a child task so that a panic
                // inside `heartbeat()` is caught by the JoinHandle rather than
                // killing the entire loop. Without this, a driver-level panic
                // silently stops all heartbeats and causes false dead-worker
                // detection.
                let hb_db = executor.db.clone();
                let hb_eid = executor.executor_id.clone();
                let child =
                    tokio::spawn(async move { Executor::new(hb_db, hb_eid).heartbeat().await });
                match child.await {
                    Ok(Ok(())) => {}
                    Ok(Err(e)) => {
                        tracing::warn!("heartbeat write failed: {e}");
                    }
                    Err(join_err) => {
                        // JoinError is either a panic or a cancellation.
                        tracing::error!(
                            "heartbeat task panicked, will retry on next tick: {join_err}"
                        );
                    }
                }
            }
        })
    }

    /// Claim stale tasks from dead workers in a single atomic operation.
    /// Uses `FOR UPDATE SKIP LOCKED` so multiple executors can run this
    /// concurrently without blocking each other — each executor claims a
    /// disjoint set of tasks.
    ///
    /// Catches three cases:
    /// 1. Tasks whose `executor_id` has a stale heartbeat (> staleness_threshold)
    /// 2. Tasks whose `executor_id` has no heartbeat row at all
    /// 3. Orphaned tasks with NULL `executor_id` that have been RUNNING too long
    ///
    /// Claimed tasks go directly to RUNNING (no intermediate PENDING state)
    /// with `started_at` reset and `executor_id` set to this executor.
    ///
    /// NOTE(review): `{eid}` is interpolated into the SQL unescaped. Safe
    /// only while executor ids are operator-controlled and quote-free —
    /// consider a bind parameter (`Statement::from_sql_and_values`).
    pub async fn recover_stale_tasks(
        &self,
        staleness_threshold: Duration,
    ) -> Result<Vec<RecoveredTask>, DurableError> {
        let threshold_secs = staleness_threshold.as_secs();
        // `{TS}` is a SQL fragment constant from `crate::ctx` appended after
        // the status predicates — presumably a tenant/scope filter; confirm
        // against its definition.
        let sql = format!(
            "WITH claimable AS ( \
                 SELECT id FROM durable.task \
                 WHERE status = 'RUNNING'{TS} \
                   AND ( \
                       (executor_id IN ( \
                           SELECT executor_id FROM durable.executor_heartbeat \
                           WHERE last_seen < now() - interval '{threshold_secs} seconds' \
                       ) AND executor_id != '{eid}') \
                       OR \
                       (executor_id IS NOT NULL \
                        AND executor_id != '{eid}' \
                        AND executor_id NOT IN ( \
                            SELECT executor_id FROM durable.executor_heartbeat \
                        )) \
                       OR \
                       (executor_id IS NULL \
                        AND started_at IS NOT NULL \
                        AND started_at < now() - interval '{threshold_secs} seconds') \
                   ) \
                 FOR UPDATE SKIP LOCKED \
             ) \
             UPDATE durable.task \
             SET status = 'RUNNING'{TS}, started_at = now(), executor_id = '{eid}' \
             WHERE id IN (SELECT id FROM claimable) \
             RETURNING id, parent_id, name, kind, handler",
            eid = self.executor_id
        );
        let rows = self
            .db
            .query_all(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        rows.iter().map(parse_recovered_row).collect()
    }

    /// Claim timed-out tasks in a single atomic operation.
    /// Uses `FOR UPDATE SKIP LOCKED` for safe concurrent claiming.
    ///
    /// A task is considered timed-out if:
    /// - `status = 'RUNNING'` and `started_at IS NOT NULL`
    /// - `timeout_ms` has elapsed, or `deadline_epoch_ms` has passed
    ///
    /// Tasks without `timeout_ms` set are never considered timed-out.
    /// Claiming clears `deadline_epoch_ms` and resets `started_at`, so the
    /// task gets a fresh timeout window under this executor.
    pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
        // Same interpolation caveat as `recover_stale_tasks` for `{eid}`.
        let sql = format!(
            "WITH claimable AS ( \
                 SELECT id FROM durable.task \
                 WHERE status = 'RUNNING'{TS} \
                   AND started_at IS NOT NULL \
                   AND ( \
                     (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0)) \
                     OR \
                     (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000) \
                   ) \
                 FOR UPDATE SKIP LOCKED \
             ) \
             UPDATE durable.task \
             SET status = 'RUNNING'{TS}, started_at = now(), deadline_epoch_ms = NULL, \
                 executor_id = '{eid}' \
             WHERE id IN (SELECT id FROM claimable) \
             RETURNING id, parent_id, name, kind, handler",
            eid = self.executor_id
        );

        let rows = self
            .db
            .query_all(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        rows.iter().map(parse_recovered_row).collect()
    }

    /// Reset RUNNING child steps of recovered workflows back to PENDING.
    ///
    /// After a crash, TX1 in `step()` may have committed (setting the step to
    /// RUNNING) but the closure never completed. These orphaned steps block
    /// re-execution because `find_or_create_task` sees RUNNING and returns
    /// `StepLocked`. This method resets them so the recovered workflow can
    /// re-execute its steps from scratch.
    ///
    /// Returns the number of step rows reset.
    pub async fn reset_orphaned_steps(&self, recovered_ids: &[Uuid]) -> Result<u64, DurableError> {
        if recovered_ids.is_empty() {
            return Ok(0);
        }
        // Safe to interpolate: `Uuid`'s Display is hex-and-dashes only.
        let id_list: String = recovered_ids
            .iter()
            .map(|id| format!("'{id}'"))
            .collect::<Vec<_>>()
            .join(",");
        let sql = format!(
            "UPDATE durable.task \
             SET status = 'PENDING'{TS}, started_at = NULL \
             WHERE parent_id IN ({id_list}) \
               AND status = 'RUNNING'{TS} \
               AND kind = 'STEP'"
        );
        let result = self
            .db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(result.rows_affected())
    }
}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// The documented defaults (60s interval / 180s threshold) must not
    /// drift silently.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            heartbeat_interval,
            staleness_threshold,
        } = HeartbeatConfig::default();
        assert_eq!(
            heartbeat_interval,
            Duration::from_secs(60),
            "default heartbeat_interval should be 60s"
        );
        assert_eq!(
            staleness_threshold,
            Duration::from_secs(180),
            "default staleness_threshold should be 180s"
        );
    }

    /// Caller-supplied values are stored verbatim.
    #[test]
    fn test_heartbeat_config_custom() {
        let interval = Duration::from_secs(30);
        let threshold = Duration::from_secs(90);
        let config = HeartbeatConfig {
            heartbeat_interval: interval,
            staleness_threshold: threshold,
        };
        assert_eq!(config.heartbeat_interval, interval);
        assert_eq!(config.staleness_threshold, threshold);
    }
}