//! durable/executor.rs — executor heartbeats and recovery of stale RUNNING tasks.
use std::time::Duration;

use sea_orm::sea_query::OnConflict;
use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, EntityTrait, Set, Statement};
use tokio::task::JoinHandle;
use uuid::Uuid;

use crate::error::DurableError;
use durable_db::entity::executor_heartbeat;
9
/// Tuning knobs for the executor heartbeat loop and stale-task recovery.
///
/// `staleness_threshold` should be strictly greater than `heartbeat_interval`
/// (a 3x ratio is recommended). Defaults are 60s / 180s.
pub struct HeartbeatConfig {
    /// Interval between heartbeat writes for this executor. Default: 60s.
    pub heartbeat_interval: Duration,
    /// RUNNING tasks owned by executors whose last heartbeat is older than
    /// this are reset to PENDING. Default: 180s.
    pub staleness_threshold: Duration,
}
20
21impl Default for HeartbeatConfig {
22    fn default() -> Self {
23        Self {
24            heartbeat_interval: Duration::from_secs(60),
25            staleness_threshold: Duration::from_secs(180),
26        }
27    }
28}
29
30/// Information about a task that was recovered from a stale RUNNING state.
31pub struct RecoveredTask {
32    pub id: Uuid,
33    pub name: String,
34    pub kind: String,
35}
36
37/// Executor configuration for a durable worker.
38pub struct Executor {
39    db: DatabaseConnection,
40    executor_id: String,
41}
42
43impl Executor {
44    pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
45        Self { db, executor_id }
46    }
47
48    pub fn db(&self) -> &DatabaseConnection {
49        &self.db
50    }
51
52    pub fn executor_id(&self) -> &str {
53        &self.executor_id
54    }
55
56    /// Write (or upsert) a heartbeat row for this executor.
57    pub async fn heartbeat(&self) -> Result<(), DurableError> {
58        let now = chrono::Utc::now().into();
59        let model = executor_heartbeat::ActiveModel {
60            executor_id: Set(self.executor_id.clone()),
61            last_seen: Set(now),
62        };
63
64        executor_heartbeat::Entity::insert(model)
65            .on_conflict(
66                OnConflict::column(executor_heartbeat::Column::ExecutorId)
67                    .update_column(executor_heartbeat::Column::LastSeen)
68                    .to_owned(),
69            )
70            .exec(&self.db)
71            .await
72            .map_err(DurableError::from)?;
73
74        Ok(())
75    }
76
77    /// Spawn a background tokio task that calls `heartbeat()` every `config.heartbeat_interval`.
78    ///
79    /// Returns a `JoinHandle` so the caller can abort it on graceful shutdown.
80    pub fn start_heartbeat(&self, config: &HeartbeatConfig) -> JoinHandle<()> {
81        let db = self.db.clone();
82        let executor_id = self.executor_id.clone();
83        let interval = config.heartbeat_interval;
84
85        tokio::spawn(async move {
86            let executor = Executor::new(db, executor_id);
87            let mut ticker = tokio::time::interval(interval);
88            loop {
89                ticker.tick().await;
90                if let Err(e) = executor.heartbeat().await {
91                    tracing::warn!("heartbeat write failed: {e}");
92                }
93            }
94        })
95    }
96
97    /// Reset RUNNING tasks from dead workers to PENDING. Catches three cases:
98    ///
99    /// 1. Tasks whose `executor_id` has a stale heartbeat (> staleness_threshold)
100    /// 2. Tasks whose `executor_id` has no heartbeat row at all (executor never registered or was cleaned up)
101    /// 3. Orphaned tasks with NULL `executor_id` that have been RUNNING too long
102    ///
103    /// Excludes own executor_id — own tasks are handled by the current process.
104    /// Returns the recovered tasks.
105    pub async fn recover_stale_tasks(
106        &self,
107        staleness_threshold: Duration,
108    ) -> Result<Vec<RecoveredTask>, DurableError> {
109        let threshold_secs = staleness_threshold.as_secs();
110        let sql = format!(
111            "UPDATE durable.task \
112             SET status = 'PENDING', started_at = NULL, \
113                 executor_id = '{}' \
114             WHERE status = 'RUNNING' \
115               AND ( \
116                   (executor_id IN ( \
117                       SELECT executor_id FROM durable.executor_heartbeat \
118                       WHERE last_seen < now() - interval '{threshold_secs} seconds' \
119                   ) AND executor_id != '{}') \
120                   OR \
121                   (executor_id IS NOT NULL \
122                    AND executor_id != '{}' \
123                    AND executor_id NOT IN ( \
124                        SELECT executor_id FROM durable.executor_heartbeat \
125                    )) \
126                   OR \
127                   (executor_id IS NULL \
128                    AND started_at IS NOT NULL \
129                    AND started_at < now() - interval '{threshold_secs} seconds') \
130               ) \
131             RETURNING id, name, kind",
132            self.executor_id, self.executor_id, self.executor_id
133        );
134        let rows = self
135            .db
136            .query_all(Statement::from_string(DbBackend::Postgres, sql))
137            .await?;
138
139        let mut recovered = Vec::with_capacity(rows.len());
140        for row in rows {
141            let id: Uuid = row
142                .try_get_by_index(0)
143                .map_err(|e| DurableError::custom(e.to_string()))?;
144            let name: String = row
145                .try_get_by_index(1)
146                .map_err(|e| DurableError::custom(e.to_string()))?;
147            let kind: String = row
148                .try_get_by_index(2)
149                .map_err(|e| DurableError::custom(e.to_string()))?;
150            recovered.push(RecoveredTask { id, name, kind });
151        }
152
153        Ok(recovered)
154    }
155
156    /// Spawn a background tokio task that calls `recover_stale_tasks()` every
157    /// `config.staleness_threshold` interval.
158    pub fn start_recovery_loop(&self, config: &HeartbeatConfig) -> JoinHandle<()> {
159        let db = self.db.clone();
160        let executor_id = self.executor_id.clone();
161        let threshold = config.staleness_threshold;
162
163        tokio::spawn(async move {
164            let executor = Executor::new(db, executor_id);
165            let mut ticker = tokio::time::interval(threshold);
166            loop {
167                ticker.tick().await;
168                match executor.recover_stale_tasks(threshold).await {
169                    Ok(ref recovered) if !recovered.is_empty() => {
170                        tracing::info!(
171                            "recovered {} stale tasks from dead workers",
172                            recovered.len()
173                        );
174                    }
175                    Ok(_) => {}
176                    Err(e) => {
177                        tracing::warn!("recover_stale_tasks failed: {e}");
178                    }
179                }
180            }
181        })
182    }
183
184    /// Scan for stale tasks by timeout and reset them to PENDING.
185    ///
186    /// A task is considered stale if:
187    /// - `status = 'RUNNING'`
188    /// - `timeout_ms IS NOT NULL` or `deadline_epoch_ms IS NOT NULL`
189    /// - `started_at IS NOT NULL`
190    /// - The timeout/deadline has elapsed
191    ///
192    /// Tasks without `timeout_ms` set are never considered stale.
193    pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
194        let sql = format!(
195            "UPDATE durable.task \
196             SET status = 'PENDING', started_at = NULL, deadline_epoch_ms = NULL, \
197                 executor_id = '{}' \
198             WHERE status = 'RUNNING' \
199               AND started_at IS NOT NULL \
200               AND ( \
201                 (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0)) \
202                 OR \
203                 (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000) \
204               ) \
205             RETURNING id, name, kind",
206            self.executor_id
207        );
208
209        let rows = self
210            .db
211            .query_all(Statement::from_string(DbBackend::Postgres, sql.to_string()))
212            .await?;
213
214        let mut recovered = Vec::with_capacity(rows.len());
215        for row in rows {
216            let id: Uuid = row
217                .try_get_by_index(0)
218                .map_err(|e| DurableError::custom(e.to_string()))?;
219            let name: String = row
220                .try_get_by_index(1)
221                .map_err(|e| DurableError::custom(e.to_string()))?;
222            let kind: String = row
223                .try_get_by_index(2)
224                .map_err(|e| DurableError::custom(e.to_string()))?;
225            recovered.push(RecoveredTask { id, name, kind });
226        }
227
228        Ok(recovered)
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            heartbeat_interval,
            staleness_threshold,
        } = HeartbeatConfig::default();
        assert_eq!(
            heartbeat_interval,
            Duration::from_secs(60),
            "default heartbeat_interval should be 60s"
        );
        assert_eq!(
            staleness_threshold,
            Duration::from_secs(180),
            "default staleness_threshold should be 180s"
        );
    }

    #[test]
    fn test_heartbeat_config_custom() {
        let config = HeartbeatConfig {
            heartbeat_interval: Duration::from_secs(30),
            staleness_threshold: Duration::from_secs(90),
        };
        assert_eq!(config.heartbeat_interval, Duration::from_secs(30));
        assert_eq!(config.staleness_threshold, Duration::from_secs(90));
    }
}