// durable/executor.rs — executor heartbeat loop and stale-task recovery.
1use sea_orm::sea_query::OnConflict;
2use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, EntityTrait, Set, Statement};
3use std::time::Duration;
4use tokio::task::JoinHandle;
5use uuid::Uuid;
6
7use crate::error::DurableError;
8use crate::ctx::TS;
9use durable_db::entity::executor_heartbeat;
10
/// Configuration for the executor heartbeat loop and recovery.
///
/// `staleness_threshold` should be strictly greater than `heartbeat_interval`
/// (recommended: 3x). Default is a 60s interval and a 180s threshold.
pub struct HeartbeatConfig {
    /// How often the executor writes a heartbeat row. Default: 60s.
    pub heartbeat_interval: Duration,
    /// Tasks owned by executors whose last heartbeat is older than this are
    /// considered stale and become eligible for recovery — they are re-claimed
    /// directly to RUNNING by `Executor::recover_stale_tasks`. Default: 180s.
    pub staleness_threshold: Duration,
}
21
22impl Default for HeartbeatConfig {
23    fn default() -> Self {
24        Self {
25            heartbeat_interval: Duration::from_secs(60),
26            staleness_threshold: Duration::from_secs(180),
27        }
28    }
29}
30
31/// Information about a task that was recovered from a stale RUNNING state.
32pub struct RecoveredTask {
33    pub id: Uuid,
34    pub parent_id: Option<Uuid>,
35    pub name: String,
36    pub kind: String,
37    /// The static handler function name (from `#[durable::workflow]`).
38    /// Used to look up the registered resume function on recovery.
39    /// Falls back to `name` when `None` (backwards compat).
40    pub handler: Option<String>,
41}
42
43/// Parse a row from a `RETURNING id, parent_id, name, kind, handler, handler` clause.
44fn parse_recovered_row(row: &sea_orm::QueryResult) -> Result<RecoveredTask, DurableError> {
45    let id: Uuid = row
46        .try_get_by_index(0)
47        .map_err(|e| DurableError::custom(e.to_string()))?;
48    let parent_id: Option<Uuid> = row.try_get_by_index(1).ok().flatten();
49    let name: String = row
50        .try_get_by_index(2)
51        .map_err(|e| DurableError::custom(e.to_string()))?;
52    let kind: String = row
53        .try_get_by_index(3)
54        .map_err(|e| DurableError::custom(e.to_string()))?;
55    let handler: Option<String> = row.try_get_by_index(4).ok().flatten();
56    Ok(RecoveredTask {
57        id,
58        parent_id,
59        name,
60        kind,
61        handler,
62    })
63}
64
65/// Executor configuration for a durable worker.
66pub struct Executor {
67    db: DatabaseConnection,
68    executor_id: String,
69}
70
71impl Executor {
72    pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
73        Self { db, executor_id }
74    }
75
76    pub fn db(&self) -> &DatabaseConnection {
77        &self.db
78    }
79
80    pub fn executor_id(&self) -> &str {
81        &self.executor_id
82    }
83
84    /// Write (or upsert) a heartbeat row for this executor.
85    pub async fn heartbeat(&self) -> Result<(), DurableError> {
86        let now = chrono::Utc::now().into();
87        let model = executor_heartbeat::ActiveModel {
88            executor_id: Set(self.executor_id.clone()),
89            last_seen: Set(now),
90        };
91
92        executor_heartbeat::Entity::insert(model)
93            .on_conflict(
94                OnConflict::column(executor_heartbeat::Column::ExecutorId)
95                    .update_column(executor_heartbeat::Column::LastSeen)
96                    .to_owned(),
97            )
98            .exec(&self.db)
99            .await
100            .map_err(DurableError::from)?;
101
102        Ok(())
103    }
104
105    /// Spawn a background tokio task that calls `heartbeat()` every `config.heartbeat_interval`.
106    ///
107    /// Returns a `JoinHandle` so the caller can abort it on graceful shutdown.
108    pub fn start_heartbeat(&self, config: &HeartbeatConfig) -> JoinHandle<()> {
109        let db = self.db.clone();
110        let executor_id = self.executor_id.clone();
111        let interval = config.heartbeat_interval;
112
113        tokio::spawn(async move {
114            let executor = Executor::new(db, executor_id);
115            let mut ticker = tokio::time::interval(interval);
116            loop {
117                ticker.tick().await;
118                if let Err(e) = executor.heartbeat().await {
119                    tracing::warn!("heartbeat write failed: {e}");
120                }
121            }
122        })
123    }
124
125    /// Claim stale tasks from dead workers in a single atomic operation.
126    /// Uses `FOR UPDATE SKIP LOCKED` so multiple executors can run this
127    /// concurrently without blocking each other — each executor claims a
128    /// disjoint set of tasks.
129    ///
130    /// Catches three cases:
131    /// 1. Tasks whose `executor_id` has a stale heartbeat (> staleness_threshold)
132    /// 2. Tasks whose `executor_id` has no heartbeat row at all
133    /// 3. Orphaned tasks with NULL `executor_id` that have been RUNNING too long
134    ///
135    /// Claimed tasks go directly to RUNNING (no intermediate PENDING state)
136    /// with `started_at` reset and `executor_id` set to this executor.
137    pub async fn recover_stale_tasks(
138        &self,
139        staleness_threshold: Duration,
140    ) -> Result<Vec<RecoveredTask>, DurableError> {
141        let threshold_secs = staleness_threshold.as_secs();
142        let sql = format!(
143            "WITH claimable AS ( \
144                 SELECT id FROM durable.task \
145                 WHERE status = 'RUNNING'{TS} \
146                   AND ( \
147                       (executor_id IN ( \
148                           SELECT executor_id FROM durable.executor_heartbeat \
149                           WHERE last_seen < now() - interval '{threshold_secs} seconds' \
150                       ) AND executor_id != '{eid}') \
151                       OR \
152                       (executor_id IS NOT NULL \
153                        AND executor_id != '{eid}' \
154                        AND executor_id NOT IN ( \
155                            SELECT executor_id FROM durable.executor_heartbeat \
156                        )) \
157                       OR \
158                       (executor_id IS NULL \
159                        AND started_at IS NOT NULL \
160                        AND started_at < now() - interval '{threshold_secs} seconds') \
161                   ) \
162                 FOR UPDATE SKIP LOCKED \
163             ) \
164             UPDATE durable.task \
165             SET status = 'RUNNING'{TS}, started_at = now(), executor_id = '{eid}' \
166             WHERE id IN (SELECT id FROM claimable) \
167             RETURNING id, parent_id, name, kind, handler",
168            eid = self.executor_id
169        );
170        let rows = self
171            .db
172            .query_all(Statement::from_string(DbBackend::Postgres, sql))
173            .await?;
174
175        rows.iter().map(parse_recovered_row).collect()
176    }
177
178    /// Claim timed-out tasks in a single atomic operation.
179    /// Uses `FOR UPDATE SKIP LOCKED` for safe concurrent claiming.
180    ///
181    /// A task is considered timed-out if:
182    /// - `status = 'RUNNING'` and `started_at IS NOT NULL`
183    /// - `timeout_ms` has elapsed, or `deadline_epoch_ms` has passed
184    ///
185    /// Tasks without `timeout_ms` set are never considered timed-out.
186    pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
187        let sql = format!(
188            "WITH claimable AS ( \
189                 SELECT id FROM durable.task \
190                 WHERE status = 'RUNNING'{TS} \
191                   AND started_at IS NOT NULL \
192                   AND ( \
193                     (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0)) \
194                     OR \
195                     (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000) \
196                   ) \
197                 FOR UPDATE SKIP LOCKED \
198             ) \
199             UPDATE durable.task \
200             SET status = 'RUNNING'{TS}, started_at = now(), deadline_epoch_ms = NULL, \
201                 executor_id = '{eid}' \
202             WHERE id IN (SELECT id FROM claimable) \
203             RETURNING id, parent_id, name, kind, handler",
204            eid = self.executor_id
205        );
206
207        let rows = self
208            .db
209            .query_all(Statement::from_string(DbBackend::Postgres, sql))
210            .await?;
211
212        rows.iter().map(parse_recovered_row).collect()
213    }
214
215    /// Reset RUNNING child steps of recovered workflows back to PENDING.
216    ///
217    /// After a crash, TX1 in `step()` may have committed (setting the step to
218    /// RUNNING) but the closure never completed. These orphaned steps block
219    /// re-execution because `find_or_create_task` sees RUNNING and returns
220    /// `StepLocked`. This method resets them so the recovered workflow can
221    /// re-execute its steps from scratch.
222    pub async fn reset_orphaned_steps(
223        &self,
224        recovered_ids: &[Uuid],
225    ) -> Result<u64, DurableError> {
226        if recovered_ids.is_empty() {
227            return Ok(0);
228        }
229        let id_list: String = recovered_ids
230            .iter()
231            .map(|id| format!("'{id}'"))
232            .collect::<Vec<_>>()
233            .join(",");
234        let sql = format!(
235            "UPDATE durable.task \
236             SET status = 'PENDING'{TS}, started_at = NULL \
237             WHERE parent_id IN ({id_list}) \
238               AND status = 'RUNNING'{TS} \
239               AND kind = 'STEP'"
240        );
241        let result = self
242            .db
243            .execute(Statement::from_string(DbBackend::Postgres, sql))
244            .await?;
245        Ok(result.rows_affected())
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` must produce the documented 60s / 180s pair.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            heartbeat_interval,
            staleness_threshold,
        } = HeartbeatConfig::default();
        assert_eq!(
            heartbeat_interval,
            Duration::from_secs(60),
            "default heartbeat_interval should be 60s"
        );
        assert_eq!(
            staleness_threshold,
            Duration::from_secs(180),
            "default staleness_threshold should be 180s"
        );
    }

    /// Explicitly constructed values are stored unchanged.
    #[test]
    fn test_heartbeat_config_custom() {
        let config = HeartbeatConfig {
            heartbeat_interval: Duration::from_secs(30),
            staleness_threshold: Duration::from_secs(90),
        };
        assert_eq!(config.heartbeat_interval, Duration::from_secs(30));
        assert_eq!(config.staleness_threshold, Duration::from_secs(90));
    }
}
277}