1use sea_orm::sea_query::OnConflict;
2use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, EntityTrait, Set, Statement};
3use std::time::Duration;
4use tokio::task::JoinHandle;
5use uuid::Uuid;
6
7use crate::error::DurableError;
8use durable_db::entity::executor_heartbeat;
9
/// Timing knobs for executor liveness tracking.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so a configuration can be
/// logged, duplicated and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeartbeatConfig {
    /// Period between heartbeat writes issued by the background task that
    /// `Executor::start_heartbeat` spawns.
    pub heartbeat_interval: Duration,
    /// Age after which a heartbeat is considered stale.
    /// NOTE(review): presumably the value callers pass to
    /// `Executor::recover_stale_tasks` — confirm against call sites.
    pub staleness_threshold: Duration,
}
20
21impl Default for HeartbeatConfig {
22 fn default() -> Self {
23 Self {
24 heartbeat_interval: Duration::from_secs(60),
25 staleness_threshold: Duration::from_secs(180),
26 }
27 }
28}
29
30pub struct RecoveredTask {
32 pub id: Uuid,
33 pub parent_id: Option<Uuid>,
34 pub name: String,
35 pub kind: String,
36 pub handler: Option<String>,
40}
41
42fn parse_recovered_row(row: &sea_orm::QueryResult) -> Result<RecoveredTask, DurableError> {
44 let id: Uuid = row
45 .try_get_by_index(0)
46 .map_err(|e| DurableError::custom(e.to_string()))?;
47 let parent_id: Option<Uuid> = row.try_get_by_index(1).ok().flatten();
48 let name: String = row
49 .try_get_by_index(2)
50 .map_err(|e| DurableError::custom(e.to_string()))?;
51 let kind: String = row
52 .try_get_by_index(3)
53 .map_err(|e| DurableError::custom(e.to_string()))?;
54 let handler: Option<String> = row.try_get_by_index(4).ok().flatten();
55 Ok(RecoveredTask {
56 id,
57 parent_id,
58 name,
59 kind,
60 handler,
61 })
62}
63
64pub struct Executor {
66 db: DatabaseConnection,
67 executor_id: String,
68}
69
70impl Executor {
71 pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
72 Self { db, executor_id }
73 }
74
75 pub fn db(&self) -> &DatabaseConnection {
76 &self.db
77 }
78
79 pub fn executor_id(&self) -> &str {
80 &self.executor_id
81 }
82
83 pub async fn heartbeat(&self) -> Result<(), DurableError> {
85 let now = chrono::Utc::now().into();
86 let model = executor_heartbeat::ActiveModel {
87 executor_id: Set(self.executor_id.clone()),
88 last_seen: Set(now),
89 };
90
91 executor_heartbeat::Entity::insert(model)
92 .on_conflict(
93 OnConflict::column(executor_heartbeat::Column::ExecutorId)
94 .update_column(executor_heartbeat::Column::LastSeen)
95 .to_owned(),
96 )
97 .exec(&self.db)
98 .await
99 .map_err(DurableError::from)?;
100
101 Ok(())
102 }
103
104 pub fn start_heartbeat(&self, config: &HeartbeatConfig) -> JoinHandle<()> {
108 let db = self.db.clone();
109 let executor_id = self.executor_id.clone();
110 let interval = config.heartbeat_interval;
111
112 tokio::spawn(async move {
113 let executor = Executor::new(db, executor_id);
114 let mut ticker = tokio::time::interval(interval);
115 loop {
116 ticker.tick().await;
117 if let Err(e) = executor.heartbeat().await {
118 tracing::warn!("heartbeat write failed: {e}");
119 }
120 }
121 })
122 }
123
124 pub async fn recover_stale_tasks(
137 &self,
138 staleness_threshold: Duration,
139 ) -> Result<Vec<RecoveredTask>, DurableError> {
140 let threshold_secs = staleness_threshold.as_secs();
141 let sql = format!(
142 "WITH claimable AS ( \
143 SELECT id FROM durable.task \
144 WHERE status = 'RUNNING' \
145 AND ( \
146 (executor_id IN ( \
147 SELECT executor_id FROM durable.executor_heartbeat \
148 WHERE last_seen < now() - interval '{threshold_secs} seconds' \
149 ) AND executor_id != '{eid}') \
150 OR \
151 (executor_id IS NOT NULL \
152 AND executor_id != '{eid}' \
153 AND executor_id NOT IN ( \
154 SELECT executor_id FROM durable.executor_heartbeat \
155 )) \
156 OR \
157 (executor_id IS NULL \
158 AND started_at IS NOT NULL \
159 AND started_at < now() - interval '{threshold_secs} seconds') \
160 ) \
161 FOR UPDATE SKIP LOCKED \
162 ) \
163 UPDATE durable.task \
164 SET status = 'RUNNING', started_at = now(), executor_id = '{eid}' \
165 WHERE id IN (SELECT id FROM claimable) \
166 RETURNING id, parent_id, name, kind, handler",
167 eid = self.executor_id
168 );
169 let rows = self
170 .db
171 .query_all(Statement::from_string(DbBackend::Postgres, sql))
172 .await?;
173
174 rows.iter().map(parse_recovered_row).collect()
175 }
176
177 pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
186 let sql = format!(
187 "WITH claimable AS ( \
188 SELECT id FROM durable.task \
189 WHERE status = 'RUNNING' \
190 AND started_at IS NOT NULL \
191 AND ( \
192 (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0)) \
193 OR \
194 (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000) \
195 ) \
196 FOR UPDATE SKIP LOCKED \
197 ) \
198 UPDATE durable.task \
199 SET status = 'RUNNING', started_at = now(), deadline_epoch_ms = NULL, \
200 executor_id = '{eid}' \
201 WHERE id IN (SELECT id FROM claimable) \
202 RETURNING id, parent_id, name, kind, handler",
203 eid = self.executor_id
204 );
205
206 let rows = self
207 .db
208 .query_all(Statement::from_string(DbBackend::Postgres, sql))
209 .await?;
210
211 rows.iter().map(parse_recovered_row).collect()
212 }
213
214 pub async fn reset_orphaned_steps(
222 &self,
223 recovered_ids: &[Uuid],
224 ) -> Result<u64, DurableError> {
225 if recovered_ids.is_empty() {
226 return Ok(0);
227 }
228 let id_list: String = recovered_ids
229 .iter()
230 .map(|id| format!("'{id}'"))
231 .collect::<Vec<_>>()
232 .join(",");
233 let sql = format!(
234 "UPDATE durable.task \
235 SET status = 'PENDING', started_at = NULL \
236 WHERE parent_id IN ({id_list}) \
237 AND status = 'RUNNING' \
238 AND kind = 'STEP'"
239 );
240 let result = self
241 .db
242 .execute(Statement::from_string(DbBackend::Postgres, sql))
243 .await?;
244 Ok(result.rows_affected())
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// `HeartbeatConfig::default()` must yield a 60s interval and a 180s
    /// staleness threshold.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            heartbeat_interval,
            staleness_threshold,
        } = HeartbeatConfig::default();

        assert_eq!(
            heartbeat_interval,
            Duration::from_secs(60),
            "default heartbeat_interval should be 60s"
        );
        assert_eq!(
            staleness_threshold,
            Duration::from_secs(180),
            "default staleness_threshold should be 180s"
        );
    }

    /// Explicitly supplied values are stored unchanged.
    #[test]
    fn test_heartbeat_config_custom() {
        let interval = Duration::from_secs(30);
        let threshold = Duration::from_secs(90);

        let config = HeartbeatConfig {
            heartbeat_interval: interval,
            staleness_threshold: threshold,
        };

        assert_eq!(config.heartbeat_interval, Duration::from_secs(30));
        assert_eq!(config.staleness_threshold, Duration::from_secs(90));
    }
}