//! Durable executor: heartbeat upkeep and recovery of stale or timed-out tasks.

use std::time::Duration;

use sea_orm::sea_query::OnConflict;
use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, EntityTrait, Set, Statement};
use tokio::task::JoinHandle;
use uuid::Uuid;

use crate::error::DurableError;
use durable_db::entity::executor_heartbeat;

/// Configuration for the executor heartbeat loop and recovery.
///
/// `staleness_threshold` should be strictly greater than `heartbeat_interval`
/// (recommended: 3x). Default is 60s interval and 180s threshold.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeartbeatConfig {
    /// How often the executor writes a heartbeat. Default: 60s.
    pub heartbeat_interval: Duration,
    /// RUNNING tasks whose executor's last heartbeat is older than this are
    /// considered stale and may be reclaimed by a live executor. Default: 180s.
    pub staleness_threshold: Duration,
}

impl Default for HeartbeatConfig {
    fn default() -> Self {
        Self {
            heartbeat_interval: Duration::from_secs(60),
            staleness_threshold: Duration::from_secs(180),
        }
    }
}

30/// Information about a task that was recovered from a stale RUNNING state.
31pub struct RecoveredTask {
32    pub id: Uuid,
33    pub parent_id: Option<Uuid>,
34    pub name: String,
35    pub kind: String,
36}
37
38/// Parse a row from a `RETURNING id, parent_id, name, kind` clause.
39fn parse_recovered_row(row: &sea_orm::QueryResult) -> Result<RecoveredTask, DurableError> {
40    let id: Uuid = row
41        .try_get_by_index(0)
42        .map_err(|e| DurableError::custom(e.to_string()))?;
43    let parent_id: Option<Uuid> = row.try_get_by_index(1).ok().flatten();
44    let name: String = row
45        .try_get_by_index(2)
46        .map_err(|e| DurableError::custom(e.to_string()))?;
47    let kind: String = row
48        .try_get_by_index(3)
49        .map_err(|e| DurableError::custom(e.to_string()))?;
50    Ok(RecoveredTask {
51        id,
52        parent_id,
53        name,
54        kind,
55    })
56}
57
/// Executor configuration for a durable worker.
///
/// Holds the database handle and the identifier this executor writes into
/// heartbeat rows and into the tasks it claims.
pub struct Executor {
    // Connection used for heartbeat writes and recovery queries.
    db: DatabaseConnection,
    // Unique id of this executor; stamped on claimed task rows.
    executor_id: String,
}

64impl Executor {
65    pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
66        Self { db, executor_id }
67    }
68
69    pub fn db(&self) -> &DatabaseConnection {
70        &self.db
71    }
72
73    pub fn executor_id(&self) -> &str {
74        &self.executor_id
75    }
76
77    /// Write (or upsert) a heartbeat row for this executor.
78    pub async fn heartbeat(&self) -> Result<(), DurableError> {
79        let now = chrono::Utc::now().into();
80        let model = executor_heartbeat::ActiveModel {
81            executor_id: Set(self.executor_id.clone()),
82            last_seen: Set(now),
83        };
84
85        executor_heartbeat::Entity::insert(model)
86            .on_conflict(
87                OnConflict::column(executor_heartbeat::Column::ExecutorId)
88                    .update_column(executor_heartbeat::Column::LastSeen)
89                    .to_owned(),
90            )
91            .exec(&self.db)
92            .await
93            .map_err(DurableError::from)?;
94
95        Ok(())
96    }
97
98    /// Spawn a background tokio task that calls `heartbeat()` every `config.heartbeat_interval`.
99    ///
100    /// Returns a `JoinHandle` so the caller can abort it on graceful shutdown.
101    pub fn start_heartbeat(&self, config: &HeartbeatConfig) -> JoinHandle<()> {
102        let db = self.db.clone();
103        let executor_id = self.executor_id.clone();
104        let interval = config.heartbeat_interval;
105
106        tokio::spawn(async move {
107            let executor = Executor::new(db, executor_id);
108            let mut ticker = tokio::time::interval(interval);
109            loop {
110                ticker.tick().await;
111                if let Err(e) = executor.heartbeat().await {
112                    tracing::warn!("heartbeat write failed: {e}");
113                }
114            }
115        })
116    }
117
118    /// Claim stale tasks from dead workers in a single atomic operation.
119    /// Uses `FOR UPDATE SKIP LOCKED` so multiple executors can run this
120    /// concurrently without blocking each other — each executor claims a
121    /// disjoint set of tasks.
122    ///
123    /// Catches three cases:
124    /// 1. Tasks whose `executor_id` has a stale heartbeat (> staleness_threshold)
125    /// 2. Tasks whose `executor_id` has no heartbeat row at all
126    /// 3. Orphaned tasks with NULL `executor_id` that have been RUNNING too long
127    ///
128    /// Claimed tasks go directly to RUNNING (no intermediate PENDING state)
129    /// with `started_at` reset and `executor_id` set to this executor.
130    pub async fn recover_stale_tasks(
131        &self,
132        staleness_threshold: Duration,
133    ) -> Result<Vec<RecoveredTask>, DurableError> {
134        let threshold_secs = staleness_threshold.as_secs();
135        let sql = format!(
136            "WITH claimable AS ( \
137                 SELECT id FROM durable.task \
138                 WHERE status = 'RUNNING' \
139                   AND ( \
140                       (executor_id IN ( \
141                           SELECT executor_id FROM durable.executor_heartbeat \
142                           WHERE last_seen < now() - interval '{threshold_secs} seconds' \
143                       ) AND executor_id != '{eid}') \
144                       OR \
145                       (executor_id IS NOT NULL \
146                        AND executor_id != '{eid}' \
147                        AND executor_id NOT IN ( \
148                            SELECT executor_id FROM durable.executor_heartbeat \
149                        )) \
150                       OR \
151                       (executor_id IS NULL \
152                        AND started_at IS NOT NULL \
153                        AND started_at < now() - interval '{threshold_secs} seconds') \
154                   ) \
155                 FOR UPDATE SKIP LOCKED \
156             ) \
157             UPDATE durable.task \
158             SET status = 'RUNNING', started_at = now(), executor_id = '{eid}' \
159             WHERE id IN (SELECT id FROM claimable) \
160             RETURNING id, parent_id, name, kind",
161            eid = self.executor_id
162        );
163        let rows = self
164            .db
165            .query_all(Statement::from_string(DbBackend::Postgres, sql))
166            .await?;
167
168        rows.iter().map(parse_recovered_row).collect()
169    }
170
171    /// Claim timed-out tasks in a single atomic operation.
172    /// Uses `FOR UPDATE SKIP LOCKED` for safe concurrent claiming.
173    ///
174    /// A task is considered timed-out if:
175    /// - `status = 'RUNNING'` and `started_at IS NOT NULL`
176    /// - `timeout_ms` has elapsed, or `deadline_epoch_ms` has passed
177    ///
178    /// Tasks without `timeout_ms` set are never considered timed-out.
179    pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
180        let sql = format!(
181            "WITH claimable AS ( \
182                 SELECT id FROM durable.task \
183                 WHERE status = 'RUNNING' \
184                   AND started_at IS NOT NULL \
185                   AND ( \
186                     (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0)) \
187                     OR \
188                     (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000) \
189                   ) \
190                 FOR UPDATE SKIP LOCKED \
191             ) \
192             UPDATE durable.task \
193             SET status = 'RUNNING', started_at = now(), deadline_epoch_ms = NULL, \
194                 executor_id = '{eid}' \
195             WHERE id IN (SELECT id FROM claimable) \
196             RETURNING id, parent_id, name, kind",
197            eid = self.executor_id
198        );
199
200        let rows = self
201            .db
202            .query_all(Statement::from_string(DbBackend::Postgres, sql))
203            .await?;
204
205        rows.iter().map(parse_recovered_row).collect()
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults documented on HeartbeatConfig: 60s interval, 180s threshold.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            heartbeat_interval,
            staleness_threshold,
        } = HeartbeatConfig::default();
        assert_eq!(
            heartbeat_interval,
            Duration::from_secs(60),
            "default heartbeat_interval should be 60s"
        );
        assert_eq!(
            staleness_threshold,
            Duration::from_secs(180),
            "default staleness_threshold should be 180s"
        );
    }

    // Struct-literal construction must carry custom values through unchanged.
    #[test]
    fn test_heartbeat_config_custom() {
        let cfg = HeartbeatConfig {
            heartbeat_interval: Duration::from_secs(30),
            staleness_threshold: Duration::from_secs(90),
        };
        assert_eq!(cfg.staleness_threshold, Duration::from_secs(90));
        assert_eq!(cfg.heartbeat_interval, Duration::from_secs(30));
    }
}