durable/executor.rs

use sea_orm::sea_query::OnConflict;
use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, EntityTrait, Set, Statement};
use std::time::Duration;
use tokio::task::JoinHandle;
use uuid::Uuid;

use crate::error::DurableError;
use durable_db::entity::executor_heartbeat;

/// Configuration for the executor heartbeat loop and recovery.
///
/// `staleness_threshold` should be strictly greater than `heartbeat_interval`
/// (recommended: 3x). Defaults are a 60s interval and a 180s threshold.
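///
/// A minimal sketch of a custom configuration that keeps the recommended
/// 3x ratio (the values here are illustrative):
///
/// ```ignore
/// let config = HeartbeatConfig {
///     heartbeat_interval: Duration::from_secs(20),
///     staleness_threshold: Duration::from_secs(60),
/// };
/// ```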
pub struct HeartbeatConfig {
    /// How often the executor writes a heartbeat. Default: 60s.
    pub heartbeat_interval: Duration,
    /// Tasks from executors whose last heartbeat is older than this are
    /// reset to PENDING. Default: 180s.
    pub staleness_threshold: Duration,
}

impl Default for HeartbeatConfig {
    fn default() -> Self {
        Self {
            heartbeat_interval: Duration::from_secs(60),
            staleness_threshold: Duration::from_secs(180),
        }
    }
}

/// Information about a task that was recovered from a stale RUNNING state.
pub struct RecoveredTask {
    pub id: Uuid,
    pub parent_id: Option<Uuid>,
    pub name: String,
    pub kind: String,
}

/// Parse a row from a `RETURNING id, parent_id, name, kind` clause.
fn parse_recovered_row(row: &sea_orm::QueryResult) -> Result<RecoveredTask, DurableError> {
    let id: Uuid = row
        .try_get_by_index(0)
        .map_err(|e| DurableError::custom(e.to_string()))?;
    // parent_id is a nullable column; a failed read is treated the same as NULL.
    let parent_id: Option<Uuid> = row.try_get_by_index(1).ok().flatten();
    let name: String = row
        .try_get_by_index(2)
        .map_err(|e| DurableError::custom(e.to_string()))?;
    let kind: String = row
        .try_get_by_index(3)
        .map_err(|e| DurableError::custom(e.to_string()))?;
    Ok(RecoveredTask {
        id,
        parent_id,
        name,
        kind,
    })
}

/// A durable worker's executor handle: a database connection plus the unique
/// id under which this process registers heartbeats and claims tasks.
pub struct Executor {
    db: DatabaseConnection,
    executor_id: String,
}

impl Executor {
    pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
        Self { db, executor_id }
    }

    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    pub fn executor_id(&self) -> &str {
        &self.executor_id
    }

    /// Write (or upsert) a heartbeat row for this executor.
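    ///
    /// Uses `INSERT ... ON CONFLICT (executor_id) DO UPDATE` semantics, so
    /// repeated calls simply bump `last_seen`. Normally this is driven by
    /// [`Self::start_heartbeat`] rather than called directly.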
    pub async fn heartbeat(&self) -> Result<(), DurableError> {
        let now = chrono::Utc::now().into();
        let model = executor_heartbeat::ActiveModel {
            executor_id: Set(self.executor_id.clone()),
            last_seen: Set(now),
        };

        executor_heartbeat::Entity::insert(model)
            .on_conflict(
                OnConflict::column(executor_heartbeat::Column::ExecutorId)
                    .update_column(executor_heartbeat::Column::LastSeen)
                    .to_owned(),
            )
            .exec(&self.db)
            .await
            .map_err(DurableError::from)?;

        Ok(())
    }

    /// Spawn a background tokio task that calls [`Self::heartbeat`] every
    /// `config.heartbeat_interval`.
    ///
    /// Returns a `JoinHandle` so the caller can abort it on graceful shutdown.
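    ///
    /// A minimal lifecycle sketch (assumes `executor` is a connected
    /// [`Executor`] and `shutdown` is some application shutdown future):
    ///
    /// ```ignore
    /// let handle = executor.start_heartbeat(&HeartbeatConfig::default());
    /// shutdown.await;
    /// handle.abort(); // stop the heartbeat loop on graceful shutdown
    /// ```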
    pub fn start_heartbeat(&self, config: &HeartbeatConfig) -> JoinHandle<()> {
        let db = self.db.clone();
        let executor_id = self.executor_id.clone();
        let interval = config.heartbeat_interval;

        tokio::spawn(async move {
            let executor = Executor::new(db, executor_id);
            let mut ticker = tokio::time::interval(interval);
            loop {
                ticker.tick().await;
                if let Err(e) = executor.heartbeat().await {
                    tracing::warn!("heartbeat write failed: {e}");
                }
            }
        })
    }

    /// Reset RUNNING tasks from dead workers to PENDING. Catches three cases:
    ///
    /// 1. Tasks whose `executor_id` has a stale heartbeat (older than `staleness_threshold`)
    /// 2. Tasks whose `executor_id` has no heartbeat row at all (the executor never registered or was cleaned up)
    /// 3. Orphaned tasks with a NULL `executor_id` that have been RUNNING too long
    ///
    /// Tasks claimed under this executor's own `executor_id` are excluded;
    /// the current process handles those itself. Returns the recovered tasks.
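    ///
    /// Meant to run periodically alongside the heartbeat loop; a sketch
    /// (requeueing the recovered tasks is application-specific):
    ///
    /// ```ignore
    /// let config = HeartbeatConfig::default();
    /// let recovered = executor.recover_stale_tasks(config.staleness_threshold).await?;
    /// for task in &recovered {
    ///     tracing::info!(id = %task.id, name = %task.name, "recovered stale task");
    /// }
    /// ```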
    pub async fn recover_stale_tasks(
        &self,
        staleness_threshold: Duration,
    ) -> Result<Vec<RecoveredTask>, DurableError> {
        let threshold_secs = staleness_threshold.as_secs();
        // Note: executor_id comes from this process, not user input, and
        // threshold_secs is a plain integer, so direct interpolation into
        // the SQL string is tolerated here.
        let sql = format!(
            "UPDATE durable.task \
             SET status = 'PENDING', started_at = NULL, \
                 executor_id = '{}' \
             WHERE status = 'RUNNING' \
               AND ( \
                   (executor_id IN ( \
                       SELECT executor_id FROM durable.executor_heartbeat \
                       WHERE last_seen < now() - interval '{threshold_secs} seconds' \
                   ) AND executor_id != '{}') \
                   OR \
                   (executor_id IS NOT NULL \
                    AND executor_id != '{}' \
                    AND executor_id NOT IN ( \
                        SELECT executor_id FROM durable.executor_heartbeat \
                    )) \
                   OR \
                   (executor_id IS NULL \
                    AND started_at IS NOT NULL \
                    AND started_at < now() - interval '{threshold_secs} seconds') \
               ) \
             RETURNING id, parent_id, name, kind",
            self.executor_id, self.executor_id, self.executor_id
        );
        let rows = self
            .db
            .query_all(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        rows.iter().map(parse_recovered_row).collect()
    }

    /// Scan for timed-out tasks and reset them to PENDING.
    ///
    /// A task is considered stale if:
    /// - `status = 'RUNNING'`
    /// - `timeout_ms IS NOT NULL` or `deadline_epoch_ms IS NOT NULL`
    /// - `started_at IS NOT NULL`
    /// - The timeout or deadline has elapsed
    ///
    /// Tasks with neither `timeout_ms` nor `deadline_epoch_ms` set are never
    /// considered stale.
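    ///
    /// Sketch of a periodic timeout sweep (the 30s cadence is illustrative):
    ///
    /// ```ignore
    /// let mut ticker = tokio::time::interval(Duration::from_secs(30));
    /// loop {
    ///     ticker.tick().await;
    ///     for task in executor.recover().await? {
    ///         tracing::warn!(id = %task.id, kind = %task.kind, "task timed out; reset to PENDING");
    ///     }
    /// }
    /// ```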
    pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
        let sql = format!(
            "UPDATE durable.task \
             SET status = 'PENDING', started_at = NULL, deadline_epoch_ms = NULL, \
                 executor_id = '{}' \
             WHERE status = 'RUNNING' \
               AND started_at IS NOT NULL \
               AND ( \
                 (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0)) \
                 OR \
                 (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000) \
               ) \
             RETURNING id, parent_id, name, kind",
            self.executor_id
        );

        let rows = self
            .db
            .query_all(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        rows.iter().map(parse_recovered_row).collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_heartbeat_config_default() {
        let config = HeartbeatConfig::default();
        assert_eq!(
            config.heartbeat_interval,
            Duration::from_secs(60),
            "default heartbeat_interval should be 60s"
        );
        assert_eq!(
            config.staleness_threshold,
            Duration::from_secs(180),
            "default staleness_threshold should be 180s"
        );
    }

    #[test]
    fn test_heartbeat_config_custom() {
        let config = HeartbeatConfig {
            heartbeat_interval: Duration::from_secs(30),
            staleness_threshold: Duration::from_secs(90),
        };
        assert_eq!(config.heartbeat_interval, Duration::from_secs(30));
        assert_eq!(config.staleness_threshold, Duration::from_secs(90));
    }
}