1use sea_orm::sea_query::OnConflict;
2use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, EntityTrait, Set, Statement};
3use std::time::Duration;
4use tokio::task::JoinHandle;
5use uuid::Uuid;
6
7use crate::error::DurableError;
8use durable_db::entity::executor_heartbeat;
9
/// Timing knobs for executor liveness tracking.
///
/// `heartbeat_interval` is consumed by `Executor::start_heartbeat`; the
/// staleness value is intended to be handed to the stale-task recovery query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeartbeatConfig {
    /// Period between two consecutive heartbeat writes.
    pub heartbeat_interval: Duration,
    /// Age after which a heartbeat is considered stale.
    pub staleness_threshold: Duration,
}
20
21impl Default for HeartbeatConfig {
22 fn default() -> Self {
23 Self {
24 heartbeat_interval: Duration::from_secs(60),
25 staleness_threshold: Duration::from_secs(180),
26 }
27 }
28}
29
30pub struct RecoveredTask {
32 pub id: Uuid,
33 pub parent_id: Option<Uuid>,
34 pub name: String,
35 pub kind: String,
36}
37
38fn parse_recovered_row(row: &sea_orm::QueryResult) -> Result<RecoveredTask, DurableError> {
40 let id: Uuid = row
41 .try_get_by_index(0)
42 .map_err(|e| DurableError::custom(e.to_string()))?;
43 let parent_id: Option<Uuid> = row.try_get_by_index(1).ok().flatten();
44 let name: String = row
45 .try_get_by_index(2)
46 .map_err(|e| DurableError::custom(e.to_string()))?;
47 let kind: String = row
48 .try_get_by_index(3)
49 .map_err(|e| DurableError::custom(e.to_string()))?;
50 Ok(RecoveredTask {
51 id,
52 parent_id,
53 name,
54 kind,
55 })
56}
57
58pub struct Executor {
60 db: DatabaseConnection,
61 executor_id: String,
62}
63
64impl Executor {
65 pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
66 Self { db, executor_id }
67 }
68
69 pub fn db(&self) -> &DatabaseConnection {
70 &self.db
71 }
72
73 pub fn executor_id(&self) -> &str {
74 &self.executor_id
75 }
76
77 pub async fn heartbeat(&self) -> Result<(), DurableError> {
79 let now = chrono::Utc::now().into();
80 let model = executor_heartbeat::ActiveModel {
81 executor_id: Set(self.executor_id.clone()),
82 last_seen: Set(now),
83 };
84
85 executor_heartbeat::Entity::insert(model)
86 .on_conflict(
87 OnConflict::column(executor_heartbeat::Column::ExecutorId)
88 .update_column(executor_heartbeat::Column::LastSeen)
89 .to_owned(),
90 )
91 .exec(&self.db)
92 .await
93 .map_err(DurableError::from)?;
94
95 Ok(())
96 }
97
98 pub fn start_heartbeat(&self, config: &HeartbeatConfig) -> JoinHandle<()> {
102 let db = self.db.clone();
103 let executor_id = self.executor_id.clone();
104 let interval = config.heartbeat_interval;
105
106 tokio::spawn(async move {
107 let executor = Executor::new(db, executor_id);
108 let mut ticker = tokio::time::interval(interval);
109 loop {
110 ticker.tick().await;
111 if let Err(e) = executor.heartbeat().await {
112 tracing::warn!("heartbeat write failed: {e}");
113 }
114 }
115 })
116 }
117
118 pub async fn recover_stale_tasks(
127 &self,
128 staleness_threshold: Duration,
129 ) -> Result<Vec<RecoveredTask>, DurableError> {
130 let threshold_secs = staleness_threshold.as_secs();
131 let sql = format!(
132 "UPDATE durable.task \
133 SET status = 'PENDING', started_at = NULL, \
134 executor_id = '{}' \
135 WHERE status = 'RUNNING' \
136 AND ( \
137 (executor_id IN ( \
138 SELECT executor_id FROM durable.executor_heartbeat \
139 WHERE last_seen < now() - interval '{threshold_secs} seconds' \
140 ) AND executor_id != '{}') \
141 OR \
142 (executor_id IS NOT NULL \
143 AND executor_id != '{}' \
144 AND executor_id NOT IN ( \
145 SELECT executor_id FROM durable.executor_heartbeat \
146 )) \
147 OR \
148 (executor_id IS NULL \
149 AND started_at IS NOT NULL \
150 AND started_at < now() - interval '{threshold_secs} seconds') \
151 ) \
152 RETURNING id, parent_id, name, kind",
153 self.executor_id, self.executor_id, self.executor_id
154 );
155 let rows = self
156 .db
157 .query_all(Statement::from_string(DbBackend::Postgres, sql))
158 .await?;
159
160 rows.iter().map(parse_recovered_row).collect()
161 }
162
163 pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
173 let sql = format!(
174 "UPDATE durable.task \
175 SET status = 'PENDING', started_at = NULL, deadline_epoch_ms = NULL, \
176 executor_id = '{}' \
177 WHERE status = 'RUNNING' \
178 AND started_at IS NOT NULL \
179 AND ( \
180 (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0)) \
181 OR \
182 (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000) \
183 ) \
184 RETURNING id, parent_id, name, kind",
185 self.executor_id
186 );
187
188 let rows = self
189 .db
190 .query_all(Statement::from_string(DbBackend::Postgres, sql))
191 .await?;
192
193 rows.iter().map(parse_recovered_row).collect()
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// The stock configuration uses a 60s interval and a 180s staleness threshold.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            heartbeat_interval,
            staleness_threshold,
        } = HeartbeatConfig::default();

        let expected_interval = Duration::from_secs(60);
        let expected_threshold = Duration::from_secs(180);

        assert_eq!(
            heartbeat_interval, expected_interval,
            "default heartbeat_interval should be 60s"
        );
        assert_eq!(
            staleness_threshold, expected_threshold,
            "default staleness_threshold should be 180s"
        );
    }

    /// Explicitly supplied values are stored unchanged.
    #[test]
    fn test_heartbeat_config_custom() {
        let interval = Duration::from_secs(30);
        let threshold = Duration::from_secs(90);

        let config = HeartbeatConfig {
            heartbeat_interval: interval,
            staleness_threshold: threshold,
        };

        assert_eq!(config.heartbeat_interval, interval);
        assert_eq!(config.staleness_threshold, threshold);
    }
}