1use sea_orm::sea_query::OnConflict;
2use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, EntityTrait, Set, Statement};
3use std::time::Duration;
4use tokio::task::JoinHandle;
5use uuid::Uuid;
6
7use crate::error::DurableError;
8use durable_db::entity::executor_heartbeat;
9
/// Timing parameters for executor liveness tracking.
///
/// `heartbeat_interval` is read by [`Executor::start_heartbeat`] to pace writes to
/// `durable.executor_heartbeat`. `staleness_threshold` is intended to be passed to
/// [`Executor::recover_stale_tasks`], which takes it as an explicit argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeartbeatConfig {
    /// How often the background heartbeat task refreshes this executor's `last_seen`.
    pub heartbeat_interval: Duration,
    /// How old an executor's `last_seen` must be before its tasks are considered abandoned.
    pub staleness_threshold: Duration,
}

impl Default for HeartbeatConfig {
    /// A 60 s heartbeat with a 180 s staleness threshold, i.e. an executor is declared stale
    /// only after roughly three consecutive missed heartbeats.
    fn default() -> Self {
        Self {
            heartbeat_interval: Duration::from_secs(60),
            staleness_threshold: Duration::from_secs(180),
        }
    }
}
29
30pub struct RecoveredTask {
32 pub id: Uuid,
33 pub parent_id: Option<Uuid>,
34 pub name: String,
35 pub kind: String,
36}
37
38fn parse_recovered_row(row: &sea_orm::QueryResult) -> Result<RecoveredTask, DurableError> {
40 let id: Uuid = row
41 .try_get_by_index(0)
42 .map_err(|e| DurableError::custom(e.to_string()))?;
43 let parent_id: Option<Uuid> = row.try_get_by_index(1).ok().flatten();
44 let name: String = row
45 .try_get_by_index(2)
46 .map_err(|e| DurableError::custom(e.to_string()))?;
47 let kind: String = row
48 .try_get_by_index(3)
49 .map_err(|e| DurableError::custom(e.to_string()))?;
50 Ok(RecoveredTask {
51 id,
52 parent_id,
53 name,
54 kind,
55 })
56}
57
58pub struct Executor {
60 db: DatabaseConnection,
61 executor_id: String,
62}
63
64impl Executor {
65 pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
66 Self { db, executor_id }
67 }
68
69 pub fn db(&self) -> &DatabaseConnection {
70 &self.db
71 }
72
73 pub fn executor_id(&self) -> &str {
74 &self.executor_id
75 }
76
77 pub async fn heartbeat(&self) -> Result<(), DurableError> {
79 let now = chrono::Utc::now().into();
80 let model = executor_heartbeat::ActiveModel {
81 executor_id: Set(self.executor_id.clone()),
82 last_seen: Set(now),
83 };
84
85 executor_heartbeat::Entity::insert(model)
86 .on_conflict(
87 OnConflict::column(executor_heartbeat::Column::ExecutorId)
88 .update_column(executor_heartbeat::Column::LastSeen)
89 .to_owned(),
90 )
91 .exec(&self.db)
92 .await
93 .map_err(DurableError::from)?;
94
95 Ok(())
96 }
97
98 pub fn start_heartbeat(&self, config: &HeartbeatConfig) -> JoinHandle<()> {
102 let db = self.db.clone();
103 let executor_id = self.executor_id.clone();
104 let interval = config.heartbeat_interval;
105
106 tokio::spawn(async move {
107 let executor = Executor::new(db, executor_id);
108 let mut ticker = tokio::time::interval(interval);
109 loop {
110 ticker.tick().await;
111 if let Err(e) = executor.heartbeat().await {
112 tracing::warn!("heartbeat write failed: {e}");
113 }
114 }
115 })
116 }
117
118 pub async fn recover_stale_tasks(
131 &self,
132 staleness_threshold: Duration,
133 ) -> Result<Vec<RecoveredTask>, DurableError> {
134 let threshold_secs = staleness_threshold.as_secs();
135 let sql = format!(
136 "WITH claimable AS ( \
137 SELECT id FROM durable.task \
138 WHERE status = 'RUNNING' \
139 AND ( \
140 (executor_id IN ( \
141 SELECT executor_id FROM durable.executor_heartbeat \
142 WHERE last_seen < now() - interval '{threshold_secs} seconds' \
143 ) AND executor_id != '{eid}') \
144 OR \
145 (executor_id IS NOT NULL \
146 AND executor_id != '{eid}' \
147 AND executor_id NOT IN ( \
148 SELECT executor_id FROM durable.executor_heartbeat \
149 )) \
150 OR \
151 (executor_id IS NULL \
152 AND started_at IS NOT NULL \
153 AND started_at < now() - interval '{threshold_secs} seconds') \
154 ) \
155 FOR UPDATE SKIP LOCKED \
156 ) \
157 UPDATE durable.task \
158 SET status = 'RUNNING', started_at = now(), executor_id = '{eid}' \
159 WHERE id IN (SELECT id FROM claimable) \
160 RETURNING id, parent_id, name, kind",
161 eid = self.executor_id
162 );
163 let rows = self
164 .db
165 .query_all(Statement::from_string(DbBackend::Postgres, sql))
166 .await?;
167
168 rows.iter().map(parse_recovered_row).collect()
169 }
170
171 pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
180 let sql = format!(
181 "WITH claimable AS ( \
182 SELECT id FROM durable.task \
183 WHERE status = 'RUNNING' \
184 AND started_at IS NOT NULL \
185 AND ( \
186 (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0)) \
187 OR \
188 (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000) \
189 ) \
190 FOR UPDATE SKIP LOCKED \
191 ) \
192 UPDATE durable.task \
193 SET status = 'RUNNING', started_at = now(), deadline_epoch_ms = NULL, \
194 executor_id = '{eid}' \
195 WHERE id IN (SELECT id FROM claimable) \
196 RETURNING id, parent_id, name, kind",
197 eid = self.executor_id
198 );
199
200 let rows = self
201 .db
202 .query_all(Statement::from_string(DbBackend::Postgres, sql))
203 .await?;
204
205 rows.iter().map(parse_recovered_row).collect()
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config pairs a 60 s heartbeat with a 180 s staleness threshold.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            heartbeat_interval,
            staleness_threshold,
        } = HeartbeatConfig::default();

        assert_eq!(
            heartbeat_interval,
            Duration::from_secs(60),
            "default heartbeat_interval should be 60s"
        );
        assert_eq!(
            staleness_threshold,
            Duration::from_secs(180),
            "default staleness_threshold should be 180s"
        );
    }

    /// Explicitly supplied durations are stored unchanged.
    #[test]
    fn test_heartbeat_config_custom() {
        let (interval, threshold) = (Duration::from_secs(30), Duration::from_secs(90));
        let config = HeartbeatConfig {
            heartbeat_interval: interval,
            staleness_threshold: threshold,
        };

        assert_eq!(config.heartbeat_interval, interval);
        assert_eq!(config.staleness_threshold, threshold);
    }
}