1use sea_orm::sea_query::OnConflict;
2use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, EntityTrait, Set, Statement};
3use std::time::Duration;
4use tokio::task::JoinHandle;
5use uuid::Uuid;
6
7use crate::error::DurableError;
8use durable_db::entity::executor_heartbeat;
9
/// Timing knobs for executor liveness reporting and stale-task recovery.
///
/// Derives the common traits so a configuration can be logged, duplicated and
/// compared in tests without field-by-field code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeartbeatConfig {
    /// Period between heartbeat writes; `Executor::start_heartbeat` uses it as
    /// the tick interval of its background loop.
    pub heartbeat_interval: Duration,
    /// Age after which a heartbeat (or an ownerless task's `started_at`) is
    /// treated as stale; `Executor::start_recovery_loop` uses it both as the
    /// loop period and as the threshold passed to `recover_stale_tasks`.
    pub staleness_threshold: Duration,
}
20
21impl Default for HeartbeatConfig {
22 fn default() -> Self {
23 Self {
24 heartbeat_interval: Duration::from_secs(60),
25 staleness_threshold: Duration::from_secs(180),
26 }
27 }
28}
29
30pub struct RecoveredTask {
32 pub id: Uuid,
33 pub name: String,
34 pub kind: String,
35}
36
37pub struct Executor {
39 db: DatabaseConnection,
40 executor_id: String,
41}
42
43impl Executor {
44 pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
45 Self { db, executor_id }
46 }
47
48 pub fn db(&self) -> &DatabaseConnection {
49 &self.db
50 }
51
52 pub fn executor_id(&self) -> &str {
53 &self.executor_id
54 }
55
56 pub async fn heartbeat(&self) -> Result<(), DurableError> {
58 let now = chrono::Utc::now().into();
59 let model = executor_heartbeat::ActiveModel {
60 executor_id: Set(self.executor_id.clone()),
61 last_seen: Set(now),
62 };
63
64 executor_heartbeat::Entity::insert(model)
65 .on_conflict(
66 OnConflict::column(executor_heartbeat::Column::ExecutorId)
67 .update_column(executor_heartbeat::Column::LastSeen)
68 .to_owned(),
69 )
70 .exec(&self.db)
71 .await
72 .map_err(DurableError::from)?;
73
74 Ok(())
75 }
76
77 pub fn start_heartbeat(&self, config: &HeartbeatConfig) -> JoinHandle<()> {
81 let db = self.db.clone();
82 let executor_id = self.executor_id.clone();
83 let interval = config.heartbeat_interval;
84
85 tokio::spawn(async move {
86 let executor = Executor::new(db, executor_id);
87 let mut ticker = tokio::time::interval(interval);
88 loop {
89 ticker.tick().await;
90 if let Err(e) = executor.heartbeat().await {
91 tracing::warn!("heartbeat write failed: {e}");
92 }
93 }
94 })
95 }
96
97 pub async fn recover_stale_tasks(
106 &self,
107 staleness_threshold: Duration,
108 ) -> Result<Vec<RecoveredTask>, DurableError> {
109 let threshold_secs = staleness_threshold.as_secs();
110 let sql = format!(
111 "UPDATE durable.task \
112 SET status = 'PENDING', started_at = NULL, \
113 executor_id = '{}' \
114 WHERE status = 'RUNNING' \
115 AND ( \
116 (executor_id IN ( \
117 SELECT executor_id FROM durable.executor_heartbeat \
118 WHERE last_seen < now() - interval '{threshold_secs} seconds' \
119 ) AND executor_id != '{}') \
120 OR \
121 (executor_id IS NOT NULL \
122 AND executor_id != '{}' \
123 AND executor_id NOT IN ( \
124 SELECT executor_id FROM durable.executor_heartbeat \
125 )) \
126 OR \
127 (executor_id IS NULL \
128 AND started_at IS NOT NULL \
129 AND started_at < now() - interval '{threshold_secs} seconds') \
130 ) \
131 RETURNING id, name, kind",
132 self.executor_id, self.executor_id, self.executor_id
133 );
134 let rows = self
135 .db
136 .query_all(Statement::from_string(DbBackend::Postgres, sql))
137 .await?;
138
139 let mut recovered = Vec::with_capacity(rows.len());
140 for row in rows {
141 let id: Uuid = row
142 .try_get_by_index(0)
143 .map_err(|e| DurableError::custom(e.to_string()))?;
144 let name: String = row
145 .try_get_by_index(1)
146 .map_err(|e| DurableError::custom(e.to_string()))?;
147 let kind: String = row
148 .try_get_by_index(2)
149 .map_err(|e| DurableError::custom(e.to_string()))?;
150 recovered.push(RecoveredTask { id, name, kind });
151 }
152
153 Ok(recovered)
154 }
155
156 pub fn start_recovery_loop(&self, config: &HeartbeatConfig) -> JoinHandle<()> {
159 let db = self.db.clone();
160 let executor_id = self.executor_id.clone();
161 let threshold = config.staleness_threshold;
162
163 tokio::spawn(async move {
164 let executor = Executor::new(db, executor_id);
165 let mut ticker = tokio::time::interval(threshold);
166 loop {
167 ticker.tick().await;
168 match executor.recover_stale_tasks(threshold).await {
169 Ok(ref recovered) if !recovered.is_empty() => {
170 tracing::info!(
171 "recovered {} stale tasks from dead workers",
172 recovered.len()
173 );
174 }
175 Ok(_) => {}
176 Err(e) => {
177 tracing::warn!("recover_stale_tasks failed: {e}");
178 }
179 }
180 }
181 })
182 }
183
184 pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
194 let sql = format!(
195 "UPDATE durable.task \
196 SET status = 'PENDING', started_at = NULL, deadline_epoch_ms = NULL, \
197 executor_id = '{}' \
198 WHERE status = 'RUNNING' \
199 AND started_at IS NOT NULL \
200 AND ( \
201 (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0)) \
202 OR \
203 (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000) \
204 ) \
205 RETURNING id, name, kind",
206 self.executor_id
207 );
208
209 let rows = self
210 .db
211 .query_all(Statement::from_string(DbBackend::Postgres, sql.to_string()))
212 .await?;
213
214 let mut recovered = Vec::with_capacity(rows.len());
215 for row in rows {
216 let id: Uuid = row
217 .try_get_by_index(0)
218 .map_err(|e| DurableError::custom(e.to_string()))?;
219 let name: String = row
220 .try_get_by_index(1)
221 .map_err(|e| DurableError::custom(e.to_string()))?;
222 let kind: String = row
223 .try_get_by_index(2)
224 .map_err(|e| DurableError::custom(e.to_string()))?;
225 recovered.push(RecoveredTask { id, name, kind });
226 }
227
228 Ok(recovered)
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// The stock configuration beats every 60s and treats 180s as stale.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            heartbeat_interval,
            staleness_threshold,
        } = HeartbeatConfig::default();

        let expected_interval = Duration::from_secs(60);
        let expected_threshold = Duration::from_secs(180);

        assert_eq!(
            heartbeat_interval, expected_interval,
            "default heartbeat_interval should be 60s"
        );
        assert_eq!(
            staleness_threshold, expected_threshold,
            "default staleness_threshold should be 180s"
        );
    }

    /// Explicitly supplied durations are stored unchanged.
    #[test]
    fn test_heartbeat_config_custom() {
        let interval = Duration::from_secs(30);
        let threshold = Duration::from_secs(90);

        let config = HeartbeatConfig {
            heartbeat_interval: interval,
            staleness_threshold: threshold,
        };

        assert_eq!(config.heartbeat_interval, interval);
        assert_eq!(config.staleness_threshold, threshold);
    }
}