1use std::time::Duration;
2
3use sea_orm::sea_query::OnConflict;
4use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, EntityTrait, Set, Statement};
5use tokio::task::JoinHandle;
6use uuid::Uuid;
7
8use crate::ctx::TS;
9use crate::error::DurableError;
10use durable_db::entity::executor_heartbeat;
11
/// Timing knobs for executor liveness tracking.
///
/// The default is a 60s heartbeat interval and a 180s staleness threshold.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeartbeatConfig {
    /// Period between heartbeat writes performed by `Executor::start_heartbeat`.
    pub heartbeat_interval: Duration,
    /// How old an executor's last heartbeat may be before its work is treated as
    /// abandoned. Presumably fed to `Executor::recover_stale_tasks`, which takes
    /// such a threshold; verify against callers.
    pub staleness_threshold: Duration,
}

impl Default for HeartbeatConfig {
    fn default() -> Self {
        Self {
            heartbeat_interval: Duration::from_secs(60),
            staleness_threshold: Duration::from_secs(180),
        }
    }
}
31
32pub struct RecoveredTask {
34 pub id: Uuid,
35 pub parent_id: Option<Uuid>,
36 pub name: String,
37 pub kind: String,
38 pub handler: Option<String>,
42}
43
44fn parse_recovered_row(row: &sea_orm::QueryResult) -> Result<RecoveredTask, DurableError> {
46 let id: Uuid = row
47 .try_get_by_index(0)
48 .map_err(|e| DurableError::custom(e.to_string()))?;
49 let parent_id: Option<Uuid> = row.try_get_by_index(1).ok().flatten();
50 let name: String = row
51 .try_get_by_index(2)
52 .map_err(|e| DurableError::custom(e.to_string()))?;
53 let kind: String = row
54 .try_get_by_index(3)
55 .map_err(|e| DurableError::custom(e.to_string()))?;
56 let handler: Option<String> = row.try_get_by_index(4).ok().flatten();
57 Ok(RecoveredTask {
58 id,
59 parent_id,
60 name,
61 kind,
62 handler,
63 })
64}
65
66pub struct Executor {
68 db: DatabaseConnection,
69 executor_id: String,
70}
71
72impl Executor {
73 pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
74 Self { db, executor_id }
75 }
76
77 pub fn db(&self) -> &DatabaseConnection {
78 &self.db
79 }
80
81 pub fn executor_id(&self) -> &str {
82 &self.executor_id
83 }
84
85 pub async fn heartbeat(&self) -> Result<(), DurableError> {
87 let now = chrono::Utc::now().into();
88 let model = executor_heartbeat::ActiveModel {
89 executor_id: Set(self.executor_id.clone()),
90 last_seen: Set(now),
91 };
92
93 executor_heartbeat::Entity::insert(model)
94 .on_conflict(
95 OnConflict::column(executor_heartbeat::Column::ExecutorId)
96 .update_column(executor_heartbeat::Column::LastSeen)
97 .to_owned(),
98 )
99 .exec(&self.db)
100 .await
101 .map_err(DurableError::from)?;
102
103 Ok(())
104 }
105
106 pub fn start_heartbeat(&self, config: &HeartbeatConfig) -> JoinHandle<()> {
114 let db = self.db.clone();
115 let executor_id = self.executor_id.clone();
116 let interval = config.heartbeat_interval;
117
118 tokio::spawn(async move {
119 let executor = Executor::new(db, executor_id);
120 let mut ticker = tokio::time::interval(interval);
121 ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
122 loop {
123 ticker.tick().await;
124 let hb_db = executor.db.clone();
130 let hb_eid = executor.executor_id.clone();
131 let child =
132 tokio::spawn(async move { Executor::new(hb_db, hb_eid).heartbeat().await });
133 match child.await {
134 Ok(Ok(())) => {}
135 Ok(Err(e)) => {
136 tracing::warn!("heartbeat write failed: {e}");
137 }
138 Err(join_err) => {
139 tracing::error!(
141 "heartbeat task panicked, will retry on next tick: {join_err}"
142 );
143 }
144 }
145 }
146 })
147 }
148
149 pub async fn recover_stale_tasks(
162 &self,
163 staleness_threshold: Duration,
164 ) -> Result<Vec<RecoveredTask>, DurableError> {
165 let threshold_secs = staleness_threshold.as_secs();
166 let sql = format!(
167 "WITH claimable AS ( \
168 SELECT id FROM durable.task \
169 WHERE status = 'RUNNING'{TS} \
170 AND ( \
171 (executor_id IN ( \
172 SELECT executor_id FROM durable.executor_heartbeat \
173 WHERE last_seen < now() - interval '{threshold_secs} seconds' \
174 ) AND executor_id != '{eid}') \
175 OR \
176 (executor_id IS NOT NULL \
177 AND executor_id != '{eid}' \
178 AND executor_id NOT IN ( \
179 SELECT executor_id FROM durable.executor_heartbeat \
180 )) \
181 OR \
182 (executor_id IS NULL \
183 AND started_at IS NOT NULL \
184 AND started_at < now() - interval '{threshold_secs} seconds') \
185 ) \
186 FOR UPDATE SKIP LOCKED \
187 ) \
188 UPDATE durable.task \
189 SET status = 'RUNNING'{TS}, started_at = now(), executor_id = '{eid}' \
190 WHERE id IN (SELECT id FROM claimable) \
191 RETURNING id, parent_id, name, kind, handler",
192 eid = self.executor_id
193 );
194 let rows = self
195 .db
196 .query_all(Statement::from_string(DbBackend::Postgres, sql))
197 .await?;
198
199 rows.iter().map(parse_recovered_row).collect()
200 }
201
202 pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
211 let sql = format!(
212 "WITH claimable AS ( \
213 SELECT id FROM durable.task \
214 WHERE status = 'RUNNING'{TS} \
215 AND started_at IS NOT NULL \
216 AND ( \
217 (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0)) \
218 OR \
219 (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000) \
220 ) \
221 FOR UPDATE SKIP LOCKED \
222 ) \
223 UPDATE durable.task \
224 SET status = 'RUNNING'{TS}, started_at = now(), deadline_epoch_ms = NULL, \
225 executor_id = '{eid}' \
226 WHERE id IN (SELECT id FROM claimable) \
227 RETURNING id, parent_id, name, kind, handler",
228 eid = self.executor_id
229 );
230
231 let rows = self
232 .db
233 .query_all(Statement::from_string(DbBackend::Postgres, sql))
234 .await?;
235
236 rows.iter().map(parse_recovered_row).collect()
237 }
238
239 pub async fn reset_orphaned_steps(&self, recovered_ids: &[Uuid]) -> Result<u64, DurableError> {
247 if recovered_ids.is_empty() {
248 return Ok(0);
249 }
250 let id_list: String = recovered_ids
251 .iter()
252 .map(|id| format!("'{id}'"))
253 .collect::<Vec<_>>()
254 .join(",");
255 let sql = format!(
256 "UPDATE durable.task \
257 SET status = 'PENDING'{TS}, started_at = NULL \
258 WHERE parent_id IN ({id_list}) \
259 AND status = 'RUNNING'{TS} \
260 AND kind = 'STEP'"
261 );
262 let result = self
263 .db
264 .execute(Statement::from_string(DbBackend::Postgres, sql))
265 .await?;
266 Ok(result.rows_affected())
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// The stock configuration must keep its documented 60s / 180s values.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            heartbeat_interval,
            staleness_threshold,
        } = HeartbeatConfig::default();

        assert_eq!(
            heartbeat_interval,
            Duration::from_secs(60),
            "default heartbeat_interval should be 60s"
        );
        assert_eq!(
            staleness_threshold,
            Duration::from_secs(180),
            "default staleness_threshold should be 180s"
        );
    }

    /// Explicitly supplied durations are stored unchanged.
    #[test]
    fn test_heartbeat_config_custom() {
        let (interval, threshold) = (Duration::from_secs(30), Duration::from_secs(90));
        let config = HeartbeatConfig {
            heartbeat_interval: interval,
            staleness_threshold: threshold,
        };

        assert_eq!(config.heartbeat_interval, interval);
        assert_eq!(config.staleness_threshold, threshold);
    }
}