1use sea_orm::sea_query::OnConflict;
2use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, EntityTrait, Set, Statement};
3use std::time::Duration;
4use tokio::task::JoinHandle;
5use uuid::Uuid;
6
7use crate::error::DurableError;
8use durable_db::entity::executor_heartbeat;
9
/// Timing knobs for an executor's liveness heartbeat and stale-task recovery.
///
/// The type is `Clone` so a single configuration can be handed to both
/// `Executor::start_heartbeat` and `Executor::start_recovery_loop`, each of
/// which takes its configuration by value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeartbeatConfig {
    /// Period between heartbeat writes performed by the heartbeat loop.
    pub heartbeat_interval: Duration,
    /// Maximum age of an executor's last heartbeat before its running tasks
    /// are considered abandoned. The recovery loop also uses this value as
    /// the period between its sweeps.
    pub staleness_threshold: Duration,
}
20
21impl Default for HeartbeatConfig {
22 fn default() -> Self {
23 Self {
24 heartbeat_interval: Duration::from_secs(60),
25 staleness_threshold: Duration::from_secs(180),
26 }
27 }
28}
29
30pub struct RecoveredTask {
32 pub id: Uuid,
33 pub name: String,
34 pub kind: String,
35}
36
/// A worker's handle onto the durable task tables.
///
/// Bundles the database connection with the identifier under which this
/// worker records heartbeats and is excluded from stale-task recovery.
pub struct Executor {
    /// Connection used for every statement this executor issues.
    db: DatabaseConnection,
    /// Identifier written to the heartbeat table for this executor.
    executor_id: String,
}
42
43impl Executor {
44 pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
45 Self { db, executor_id }
46 }
47
48 pub fn db(&self) -> &DatabaseConnection {
49 &self.db
50 }
51
52 pub fn executor_id(&self) -> &str {
53 &self.executor_id
54 }
55
56 pub async fn heartbeat(&self) -> Result<(), DurableError> {
58 let now = chrono::Utc::now().into();
59 let model = executor_heartbeat::ActiveModel {
60 executor_id: Set(self.executor_id.clone()),
61 last_seen: Set(now),
62 };
63
64 executor_heartbeat::Entity::insert(model)
65 .on_conflict(
66 OnConflict::column(executor_heartbeat::Column::ExecutorId)
67 .update_column(executor_heartbeat::Column::LastSeen)
68 .to_owned(),
69 )
70 .exec(&self.db)
71 .await
72 .map_err(DurableError::from)?;
73
74 Ok(())
75 }
76
77 pub fn start_heartbeat(&self, config: HeartbeatConfig) -> JoinHandle<()> {
81 let db = self.db.clone();
82 let executor_id = self.executor_id.clone();
83 let interval = config.heartbeat_interval;
84
85 tokio::spawn(async move {
86 let executor = Executor::new(db, executor_id);
87 let mut ticker = tokio::time::interval(interval);
88 loop {
89 ticker.tick().await;
90 if let Err(e) = executor.heartbeat().await {
91 tracing::warn!("heartbeat write failed: {e}");
92 }
93 }
94 })
95 }
96
97 pub async fn recover_stale_tasks(
102 &self,
103 staleness_threshold: Duration,
104 ) -> Result<u64, DurableError> {
105 let threshold_secs = staleness_threshold.as_secs();
106 let sql = format!(
107 "UPDATE durable.task SET status = 'PENDING', started_at = NULL \
108 WHERE status = 'RUNNING' \
109 AND executor_id IN ( \
110 SELECT executor_id FROM durable.executor_heartbeat \
111 WHERE last_seen < now() - interval '{threshold_secs} seconds' \
112 ) \
113 AND executor_id != '{}'",
114 self.executor_id
115 );
116 let result = self
117 .db
118 .execute(Statement::from_string(DbBackend::Postgres, sql))
119 .await
120 .map_err(DurableError::from)?;
121 Ok(result.rows_affected())
122 }
123
124 pub fn start_recovery_loop(&self, config: HeartbeatConfig) -> JoinHandle<()> {
127 let db = self.db.clone();
128 let executor_id = self.executor_id.clone();
129 let threshold = config.staleness_threshold;
130
131 tokio::spawn(async move {
132 let executor = Executor::new(db, executor_id);
133 let mut ticker = tokio::time::interval(threshold);
134 loop {
135 ticker.tick().await;
136 match executor.recover_stale_tasks(threshold).await {
137 Ok(n) if n > 0 => {
138 tracing::info!("recovered {n} stale tasks from dead workers");
139 }
140 Ok(_) => {}
141 Err(e) => {
142 tracing::warn!("recover_stale_tasks failed: {e}");
143 }
144 }
145 }
146 })
147 }
148
149 pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
159 let sql = "
160 UPDATE durable.task
161 SET status = 'PENDING', started_at = NULL, deadline_epoch_ms = NULL
162 WHERE status = 'RUNNING'
163 AND started_at IS NOT NULL
164 AND (
165 (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0))
166 OR
167 (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000)
168 )
169 RETURNING id, name, kind
170 ";
171
172 let rows = self
173 .db
174 .query_all(Statement::from_string(DbBackend::Postgres, sql.to_string()))
175 .await?;
176
177 let mut recovered = Vec::with_capacity(rows.len());
178 for row in rows {
179 let id: Uuid = row
180 .try_get_by_index(0)
181 .map_err(|e| DurableError::custom(e.to_string()))?;
182 let name: String = row
183 .try_get_by_index(1)
184 .map_err(|e| DurableError::custom(e.to_string()))?;
185 let kind: String = row
186 .try_get_by_index(2)
187 .map_err(|e| DurableError::custom(e.to_string()))?;
188 recovered.push(RecoveredTask { id, name, kind });
189 }
190
191 Ok(recovered)
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// The stock configuration beats every 60s and treats 180s of silence as
    /// stale.
    #[test]
    fn test_heartbeat_config_default() {
        let HeartbeatConfig {
            heartbeat_interval,
            staleness_threshold,
        } = HeartbeatConfig::default();

        let one_minute = Duration::from_secs(60);
        let three_minutes = Duration::from_secs(180);

        assert_eq!(
            heartbeat_interval, one_minute,
            "default heartbeat_interval should be 60s"
        );
        assert_eq!(
            staleness_threshold, three_minutes,
            "default staleness_threshold should be 180s"
        );
    }

    /// Explicitly supplied timings are stored unchanged.
    #[test]
    fn test_heartbeat_config_custom() {
        let interval = Duration::from_secs(30);
        let threshold = Duration::from_secs(90);

        let config = HeartbeatConfig {
            heartbeat_interval: interval,
            staleness_threshold: threshold,
        };

        assert_eq!(config.heartbeat_interval, Duration::from_secs(30));
        assert_eq!(config.staleness_threshold, Duration::from_secs(90));
    }
}