//! durable/executor.rs — executor heartbeats and stale-task recovery.
use std::time::Duration;

use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, Statement, Value};
use tokio::task::JoinHandle;
use uuid::Uuid;

use crate::error::DurableError;

8/// Configuration for the executor heartbeat loop and recovery.
9///
10/// The staleness_threshold should be strictly greater than heartbeat_interval
11/// (recommended: 3x). Default is 60s interval and 180s threshold.
12pub struct HeartbeatConfig {
13    /// How often the executor writes a heartbeat. Default: 60s.
14    pub heartbeat_interval: Duration,
15    /// Tasks from executors whose last heartbeat is older than this are reset to PENDING. Default: 180s.
16    pub staleness_threshold: Duration,
17}
18
19impl Default for HeartbeatConfig {
20    fn default() -> Self {
21        Self {
22            heartbeat_interval: Duration::from_secs(60),
23            staleness_threshold: Duration::from_secs(180),
24        }
25    }
26}
27
28/// Information about a task that was recovered from a stale RUNNING state.
29pub struct RecoveredTask {
30    pub id: Uuid,
31    pub name: String,
32    pub kind: String,
33}
34
35/// Executor configuration for a durable worker.
36pub struct Executor {
37    db: DatabaseConnection,
38    executor_id: String,
39}
40
41impl Executor {
42    pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
43        Self { db, executor_id }
44    }
45
46    pub fn db(&self) -> &DatabaseConnection {
47        &self.db
48    }
49
50    pub fn executor_id(&self) -> &str {
51        &self.executor_id
52    }
53
54    /// Write (or upsert) a heartbeat row for this executor.
55    pub async fn heartbeat(&self) -> Result<(), DurableError> {
56        let sql = format!(
57            "INSERT INTO durable.executor_heartbeat (executor_id, last_seen) \
58             VALUES ('{}', now()) \
59             ON CONFLICT (executor_id) DO UPDATE SET last_seen = now()",
60            self.executor_id
61        );
62        self.db
63            .execute(Statement::from_string(DbBackend::Postgres, sql))
64            .await
65            .map_err(DurableError::from)?;
66        Ok(())
67    }
68
69    /// Spawn a background tokio task that calls `heartbeat()` every `config.heartbeat_interval`.
70    ///
71    /// Returns a `JoinHandle` so the caller can abort it on graceful shutdown.
72    pub fn start_heartbeat(&self, config: HeartbeatConfig) -> JoinHandle<()> {
73        let db = self.db.clone();
74        let executor_id = self.executor_id.clone();
75        let interval = config.heartbeat_interval;
76
77        tokio::spawn(async move {
78            let executor = Executor::new(db, executor_id);
79            let mut ticker = tokio::time::interval(interval);
80            loop {
81                ticker.tick().await;
82                if let Err(e) = executor.heartbeat().await {
83                    tracing::warn!("heartbeat write failed: {e}");
84                }
85            }
86        })
87    }
88
89    /// Reset RUNNING tasks from dead workers (stale heartbeat > staleness_threshold) to PENDING.
90    ///
91    /// Excludes own executor_id — own tasks are handled by startup reset.
92    /// Returns the number of rows affected.
93    pub async fn recover_stale_tasks(
94        &self,
95        staleness_threshold: Duration,
96    ) -> Result<u64, DurableError> {
97        let threshold_secs = staleness_threshold.as_secs();
98        let sql = format!(
99            "UPDATE durable.task SET status = 'PENDING', started_at = NULL \
100             WHERE status = 'RUNNING' \
101               AND executor_id IN ( \
102                   SELECT executor_id FROM durable.executor_heartbeat \
103                   WHERE last_seen < now() - interval '{threshold_secs} seconds' \
104               ) \
105               AND executor_id != '{}'",
106            self.executor_id
107        );
108        let result = self
109            .db
110            .execute(Statement::from_string(DbBackend::Postgres, sql))
111            .await
112            .map_err(DurableError::from)?;
113        Ok(result.rows_affected())
114    }
115
116    /// Spawn a background tokio task that calls `recover_stale_tasks()` every
117    /// `config.staleness_threshold` interval.
118    pub fn start_recovery_loop(&self, config: HeartbeatConfig) -> JoinHandle<()> {
119        let db = self.db.clone();
120        let executor_id = self.executor_id.clone();
121        let threshold = config.staleness_threshold;
122
123        tokio::spawn(async move {
124            let executor = Executor::new(db, executor_id);
125            let mut ticker = tokio::time::interval(threshold);
126            loop {
127                ticker.tick().await;
128                match executor.recover_stale_tasks(threshold).await {
129                    Ok(n) if n > 0 => {
130                        tracing::info!("recovered {n} stale tasks from dead workers");
131                    }
132                    Ok(_) => {}
133                    Err(e) => {
134                        tracing::warn!("recover_stale_tasks failed: {e}");
135                    }
136                }
137            }
138        })
139    }
140
141    /// Scan for stale tasks by timeout and reset them to PENDING.
142    ///
143    /// A task is considered stale if:
144    /// - `status = 'RUNNING'`
145    /// - `timeout_ms IS NOT NULL` or `deadline_epoch_ms IS NOT NULL`
146    /// - `started_at IS NOT NULL`
147    /// - The timeout/deadline has elapsed
148    ///
149    /// Tasks without `timeout_ms` set are never considered stale.
150    pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
151        let sql = "
152            UPDATE durable.task
153            SET status = 'PENDING', started_at = NULL, deadline_epoch_ms = NULL
154            WHERE status = 'RUNNING'
155              AND started_at IS NOT NULL
156              AND (
157                (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0))
158                OR
159                (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000)
160              )
161            RETURNING id, name, kind
162        ";
163
164        let rows = self
165            .db
166            .query_all(Statement::from_string(DbBackend::Postgres, sql.to_string()))
167            .await?;
168
169        let mut recovered = Vec::with_capacity(rows.len());
170        for row in rows {
171            let id: Uuid = row
172                .try_get_by_index(0)
173                .map_err(|e| DurableError::custom(e.to_string()))?;
174            let name: String = row
175                .try_get_by_index(1)
176                .map_err(|e| DurableError::custom(e.to_string()))?;
177            let kind: String = row
178                .try_get_by_index(2)
179                .map_err(|e| DurableError::custom(e.to_string()))?;
180            recovered.push(RecoveredTask { id, name, kind });
181        }
182
183        Ok(recovered)
184    }
185}
186
187#[cfg(test)]
188mod tests {
189    use super::*;
190
191    #[test]
192    fn test_heartbeat_config_default() {
193        let config = HeartbeatConfig::default();
194        assert_eq!(
195            config.heartbeat_interval,
196            Duration::from_secs(60),
197            "default heartbeat_interval should be 60s"
198        );
199        assert_eq!(
200            config.staleness_threshold,
201            Duration::from_secs(180),
202            "default staleness_threshold should be 180s"
203        );
204    }
205
206    #[test]
207    fn test_heartbeat_config_custom() {
208        let config = HeartbeatConfig {
209            heartbeat_interval: Duration::from_secs(30),
210            staleness_threshold: Duration::from_secs(90),
211        };
212        assert_eq!(config.heartbeat_interval, Duration::from_secs(30));
213        assert_eq!(config.staleness_threshold, Duration::from_secs(90));
214    }
215}