1use sea_orm::sea_query::OnConflict;
2use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, EntityTrait, Set, Statement};
3use std::time::Duration;
4use tokio::task::JoinHandle;
5use uuid::Uuid;
6
7use crate::error::DurableError;
8use crate::ctx::TS;
9use durable_db::entity::executor_heartbeat;
10
/// Timing parameters for executor liveness tracking.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeartbeatConfig {
    /// How often [`Executor::start_heartbeat`] refreshes this executor's heartbeat row.
    pub heartbeat_interval: Duration,
    /// How old an executor's last heartbeat may get before it is considered dead.
    ///
    /// NOTE(review): presumably the value callers pass to
    /// [`Executor::recover_stale_tasks`] — confirm at the call site.
    pub staleness_threshold: Duration,
}
21
22impl Default for HeartbeatConfig {
23 fn default() -> Self {
24 Self {
25 heartbeat_interval: Duration::from_secs(60),
26 staleness_threshold: Duration::from_secs(180),
27 }
28 }
29}
30
31pub struct RecoveredTask {
33 pub id: Uuid,
34 pub parent_id: Option<Uuid>,
35 pub name: String,
36 pub kind: String,
37 pub handler: Option<String>,
41}
42
43fn parse_recovered_row(row: &sea_orm::QueryResult) -> Result<RecoveredTask, DurableError> {
45 let id: Uuid = row
46 .try_get_by_index(0)
47 .map_err(|e| DurableError::custom(e.to_string()))?;
48 let parent_id: Option<Uuid> = row.try_get_by_index(1).ok().flatten();
49 let name: String = row
50 .try_get_by_index(2)
51 .map_err(|e| DurableError::custom(e.to_string()))?;
52 let kind: String = row
53 .try_get_by_index(3)
54 .map_err(|e| DurableError::custom(e.to_string()))?;
55 let handler: Option<String> = row.try_get_by_index(4).ok().flatten();
56 Ok(RecoveredTask {
57 id,
58 parent_id,
59 name,
60 kind,
61 handler,
62 })
63}
64
65pub struct Executor {
67 db: DatabaseConnection,
68 executor_id: String,
69}
70
71impl Executor {
72 pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
73 Self { db, executor_id }
74 }
75
76 pub fn db(&self) -> &DatabaseConnection {
77 &self.db
78 }
79
80 pub fn executor_id(&self) -> &str {
81 &self.executor_id
82 }
83
84 pub async fn heartbeat(&self) -> Result<(), DurableError> {
86 let now = chrono::Utc::now().into();
87 let model = executor_heartbeat::ActiveModel {
88 executor_id: Set(self.executor_id.clone()),
89 last_seen: Set(now),
90 };
91
92 executor_heartbeat::Entity::insert(model)
93 .on_conflict(
94 OnConflict::column(executor_heartbeat::Column::ExecutorId)
95 .update_column(executor_heartbeat::Column::LastSeen)
96 .to_owned(),
97 )
98 .exec(&self.db)
99 .await
100 .map_err(DurableError::from)?;
101
102 Ok(())
103 }
104
105 pub fn start_heartbeat(&self, config: &HeartbeatConfig) -> JoinHandle<()> {
109 let db = self.db.clone();
110 let executor_id = self.executor_id.clone();
111 let interval = config.heartbeat_interval;
112
113 tokio::spawn(async move {
114 let executor = Executor::new(db, executor_id);
115 let mut ticker = tokio::time::interval(interval);
116 loop {
117 ticker.tick().await;
118 if let Err(e) = executor.heartbeat().await {
119 tracing::warn!("heartbeat write failed: {e}");
120 }
121 }
122 })
123 }
124
125 pub async fn recover_stale_tasks(
138 &self,
139 staleness_threshold: Duration,
140 ) -> Result<Vec<RecoveredTask>, DurableError> {
141 let threshold_secs = staleness_threshold.as_secs();
142 let sql = format!(
143 "WITH claimable AS ( \
144 SELECT id FROM durable.task \
145 WHERE status = 'RUNNING'{TS} \
146 AND ( \
147 (executor_id IN ( \
148 SELECT executor_id FROM durable.executor_heartbeat \
149 WHERE last_seen < now() - interval '{threshold_secs} seconds' \
150 ) AND executor_id != '{eid}') \
151 OR \
152 (executor_id IS NOT NULL \
153 AND executor_id != '{eid}' \
154 AND executor_id NOT IN ( \
155 SELECT executor_id FROM durable.executor_heartbeat \
156 )) \
157 OR \
158 (executor_id IS NULL \
159 AND started_at IS NOT NULL \
160 AND started_at < now() - interval '{threshold_secs} seconds') \
161 ) \
162 FOR UPDATE SKIP LOCKED \
163 ) \
164 UPDATE durable.task \
165 SET status = 'RUNNING'{TS}, started_at = now(), executor_id = '{eid}' \
166 WHERE id IN (SELECT id FROM claimable) \
167 RETURNING id, parent_id, name, kind, handler",
168 eid = self.executor_id
169 );
170 let rows = self
171 .db
172 .query_all(Statement::from_string(DbBackend::Postgres, sql))
173 .await?;
174
175 rows.iter().map(parse_recovered_row).collect()
176 }
177
178 pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
187 let sql = format!(
188 "WITH claimable AS ( \
189 SELECT id FROM durable.task \
190 WHERE status = 'RUNNING'{TS} \
191 AND started_at IS NOT NULL \
192 AND ( \
193 (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0)) \
194 OR \
195 (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000) \
196 ) \
197 FOR UPDATE SKIP LOCKED \
198 ) \
199 UPDATE durable.task \
200 SET status = 'RUNNING'{TS}, started_at = now(), deadline_epoch_ms = NULL, \
201 executor_id = '{eid}' \
202 WHERE id IN (SELECT id FROM claimable) \
203 RETURNING id, parent_id, name, kind, handler",
204 eid = self.executor_id
205 );
206
207 let rows = self
208 .db
209 .query_all(Statement::from_string(DbBackend::Postgres, sql))
210 .await?;
211
212 rows.iter().map(parse_recovered_row).collect()
213 }
214
215 pub async fn reset_orphaned_steps(
223 &self,
224 recovered_ids: &[Uuid],
225 ) -> Result<u64, DurableError> {
226 if recovered_ids.is_empty() {
227 return Ok(0);
228 }
229 let id_list: String = recovered_ids
230 .iter()
231 .map(|id| format!("'{id}'"))
232 .collect::<Vec<_>>()
233 .join(",");
234 let sql = format!(
235 "UPDATE durable.task \
236 SET status = 'PENDING'{TS}, started_at = NULL \
237 WHERE parent_id IN ({id_list}) \
238 AND status = 'RUNNING'{TS} \
239 AND kind = 'STEP'"
240 );
241 let result = self
242 .db
243 .execute(Statement::from_string(DbBackend::Postgres, sql))
244 .await?;
245 Ok(result.rows_affected())
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// The defaults must stay at a 60s heartbeat and a 180s staleness window.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            heartbeat_interval,
            staleness_threshold,
        } = HeartbeatConfig::default();

        assert_eq!(
            heartbeat_interval,
            Duration::from_secs(60),
            "default heartbeat_interval should be 60s"
        );
        assert_eq!(
            staleness_threshold,
            Duration::from_secs(180),
            "default staleness_threshold should be 180s"
        );
    }

    /// Explicitly supplied timings are stored unchanged.
    #[test]
    fn test_heartbeat_config_custom() {
        let (interval, threshold) = (Duration::from_secs(30), Duration::from_secs(90));
        let config = HeartbeatConfig {
            heartbeat_interval: interval,
            staleness_threshold: threshold,
        };

        assert_eq!(config.heartbeat_interval, interval);
        assert_eq!(config.staleness_threshold, threshold);
    }
}