//! durable/executor.rs — executor heartbeat writing and stale-task recovery.

use std::time::Duration;

use sea_orm::sea_query::OnConflict;
use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, EntityTrait, Set, Statement};
use tokio::task::JoinHandle;
use uuid::Uuid;

use crate::error::DurableError;
use durable_db::entity::executor_heartbeat;
9
/// Configuration for the executor heartbeat loop and recovery.
///
/// `staleness_threshold` should be strictly greater than `heartbeat_interval`
/// (recommended: 3x), so a single missed heartbeat does not get a live
/// executor's tasks reset. Default is a 60s interval and 180s threshold.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeartbeatConfig {
    /// How often the executor writes a heartbeat. Default: 60s.
    pub heartbeat_interval: Duration,
    /// Tasks from executors whose last heartbeat is older than this are reset to PENDING. Default: 180s.
    pub staleness_threshold: Duration,
}
20
21impl Default for HeartbeatConfig {
22    fn default() -> Self {
23        Self {
24            heartbeat_interval: Duration::from_secs(60),
25            staleness_threshold: Duration::from_secs(180),
26        }
27    }
28}
29
30/// Information about a task that was recovered from a stale RUNNING state.
31pub struct RecoveredTask {
32    pub id: Uuid,
33    pub name: String,
34    pub kind: String,
35}
36
37/// Executor configuration for a durable worker.
38pub struct Executor {
39    db: DatabaseConnection,
40    executor_id: String,
41}
42
43impl Executor {
44    pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
45        Self { db, executor_id }
46    }
47
48    pub fn db(&self) -> &DatabaseConnection {
49        &self.db
50    }
51
52    pub fn executor_id(&self) -> &str {
53        &self.executor_id
54    }
55
56    /// Write (or upsert) a heartbeat row for this executor.
57    pub async fn heartbeat(&self) -> Result<(), DurableError> {
58        let now = chrono::Utc::now().into();
59        let model = executor_heartbeat::ActiveModel {
60            executor_id: Set(self.executor_id.clone()),
61            last_seen: Set(now),
62        };
63
64        executor_heartbeat::Entity::insert(model)
65            .on_conflict(
66                OnConflict::column(executor_heartbeat::Column::ExecutorId)
67                    .update_column(executor_heartbeat::Column::LastSeen)
68                    .to_owned(),
69            )
70            .exec(&self.db)
71            .await
72            .map_err(DurableError::from)?;
73
74        Ok(())
75    }
76
77    /// Spawn a background tokio task that calls `heartbeat()` every `config.heartbeat_interval`.
78    ///
79    /// Returns a `JoinHandle` so the caller can abort it on graceful shutdown.
80    pub fn start_heartbeat(&self, config: HeartbeatConfig) -> JoinHandle<()> {
81        let db = self.db.clone();
82        let executor_id = self.executor_id.clone();
83        let interval = config.heartbeat_interval;
84
85        tokio::spawn(async move {
86            let executor = Executor::new(db, executor_id);
87            let mut ticker = tokio::time::interval(interval);
88            loop {
89                ticker.tick().await;
90                if let Err(e) = executor.heartbeat().await {
91                    tracing::warn!("heartbeat write failed: {e}");
92                }
93            }
94        })
95    }
96
97    /// Reset RUNNING tasks from dead workers (stale heartbeat > staleness_threshold) to PENDING.
98    ///
99    /// Excludes own executor_id — own tasks are handled by startup reset.
100    /// Returns the number of rows affected.
101    pub async fn recover_stale_tasks(
102        &self,
103        staleness_threshold: Duration,
104    ) -> Result<u64, DurableError> {
105        let threshold_secs = staleness_threshold.as_secs();
106        let sql = format!(
107            "UPDATE durable.task SET status = 'PENDING', started_at = NULL \
108             WHERE status = 'RUNNING' \
109               AND executor_id IN ( \
110                   SELECT executor_id FROM durable.executor_heartbeat \
111                   WHERE last_seen < now() - interval '{threshold_secs} seconds' \
112               ) \
113               AND executor_id != '{}'",
114            self.executor_id
115        );
116        let result = self
117            .db
118            .execute(Statement::from_string(DbBackend::Postgres, sql))
119            .await
120            .map_err(DurableError::from)?;
121        Ok(result.rows_affected())
122    }
123
124    /// Spawn a background tokio task that calls `recover_stale_tasks()` every
125    /// `config.staleness_threshold` interval.
126    pub fn start_recovery_loop(&self, config: HeartbeatConfig) -> JoinHandle<()> {
127        let db = self.db.clone();
128        let executor_id = self.executor_id.clone();
129        let threshold = config.staleness_threshold;
130
131        tokio::spawn(async move {
132            let executor = Executor::new(db, executor_id);
133            let mut ticker = tokio::time::interval(threshold);
134            loop {
135                ticker.tick().await;
136                match executor.recover_stale_tasks(threshold).await {
137                    Ok(n) if n > 0 => {
138                        tracing::info!("recovered {n} stale tasks from dead workers");
139                    }
140                    Ok(_) => {}
141                    Err(e) => {
142                        tracing::warn!("recover_stale_tasks failed: {e}");
143                    }
144                }
145            }
146        })
147    }
148
149    /// Scan for stale tasks by timeout and reset them to PENDING.
150    ///
151    /// A task is considered stale if:
152    /// - `status = 'RUNNING'`
153    /// - `timeout_ms IS NOT NULL` or `deadline_epoch_ms IS NOT NULL`
154    /// - `started_at IS NOT NULL`
155    /// - The timeout/deadline has elapsed
156    ///
157    /// Tasks without `timeout_ms` set are never considered stale.
158    pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
159        let sql = "
160            UPDATE durable.task
161            SET status = 'PENDING', started_at = NULL, deadline_epoch_ms = NULL
162            WHERE status = 'RUNNING'
163              AND started_at IS NOT NULL
164              AND (
165                (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0))
166                OR
167                (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000)
168              )
169            RETURNING id, name, kind
170        ";
171
172        let rows = self
173            .db
174            .query_all(Statement::from_string(DbBackend::Postgres, sql.to_string()))
175            .await?;
176
177        let mut recovered = Vec::with_capacity(rows.len());
178        for row in rows {
179            let id: Uuid = row
180                .try_get_by_index(0)
181                .map_err(|e| DurableError::custom(e.to_string()))?;
182            let name: String = row
183                .try_get_by_index(1)
184                .map_err(|e| DurableError::custom(e.to_string()))?;
185            let kind: String = row
186                .try_get_by_index(2)
187                .map_err(|e| DurableError::custom(e.to_string()))?;
188            recovered.push(RecoveredTask { id, name, kind });
189        }
190
191        Ok(recovered)
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults must match the documented 60s/180s (3x) recommendation.
    #[test]
    fn test_heartbeat_config_default() {
        let config = HeartbeatConfig::default();
        assert_eq!(
            config.heartbeat_interval,
            Duration::from_secs(60),
            "default heartbeat_interval should be 60s"
        );
        assert_eq!(
            config.staleness_threshold,
            Duration::from_secs(180),
            "default staleness_threshold should be 180s"
        );
    }

    /// Explicit field values must be stored as given.
    #[test]
    fn test_heartbeat_config_custom() {
        let config = HeartbeatConfig {
            heartbeat_interval: Duration::from_secs(30),
            staleness_threshold: Duration::from_secs(90),
        };
        assert_eq!(config.heartbeat_interval, Duration::from_secs(30));
        assert_eq!(config.staleness_threshold, Duration::from_secs(90));
    }
}
223}