1use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, Statement};
2use std::time::Duration;
3use tokio::task::JoinHandle;
4use uuid::Uuid;
5
6use crate::error::DurableError;
7
/// Timing settings for the heartbeat writer ([`Executor::start_heartbeat`]) and the
/// stale-task recovery loop ([`Executor::start_recovery_loop`]).
///
/// The type is `Copy`, so one value can be handed to both spawners.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeartbeatConfig {
    /// How often this executor upserts its row in `durable.executor_heartbeat`.
    pub heartbeat_interval: Duration,
    /// How old another executor's `last_seen` must be before its `RUNNING` tasks are
    /// reset to `PENDING`. It is also used as the recovery loop's polling period.
    pub staleness_threshold: Duration,
}

impl Default for HeartbeatConfig {
    /// A 60 s heartbeat interval and a 180 s staleness threshold
    /// (three heartbeat intervals).
    fn default() -> Self {
        Self {
            heartbeat_interval: Duration::from_secs(60),
            staleness_threshold: Duration::from_secs(180),
        }
    }
}
27
28pub struct RecoveredTask {
30 pub id: Uuid,
31 pub name: String,
32 pub kind: String,
33}
34
35pub struct Executor {
37 db: DatabaseConnection,
38 executor_id: String,
39}
40
41impl Executor {
42 pub fn new(db: DatabaseConnection, executor_id: String) -> Self {
43 Self { db, executor_id }
44 }
45
46 pub fn db(&self) -> &DatabaseConnection {
47 &self.db
48 }
49
50 pub fn executor_id(&self) -> &str {
51 &self.executor_id
52 }
53
54 pub async fn heartbeat(&self) -> Result<(), DurableError> {
56 let sql = format!(
57 "INSERT INTO durable.executor_heartbeat (executor_id, last_seen) \
58 VALUES ('{}', now()) \
59 ON CONFLICT (executor_id) DO UPDATE SET last_seen = now()",
60 self.executor_id
61 );
62 self.db
63 .execute(Statement::from_string(DbBackend::Postgres, sql))
64 .await
65 .map_err(DurableError::from)?;
66 Ok(())
67 }
68
69 pub fn start_heartbeat(&self, config: HeartbeatConfig) -> JoinHandle<()> {
73 let db = self.db.clone();
74 let executor_id = self.executor_id.clone();
75 let interval = config.heartbeat_interval;
76
77 tokio::spawn(async move {
78 let executor = Executor::new(db, executor_id);
79 let mut ticker = tokio::time::interval(interval);
80 loop {
81 ticker.tick().await;
82 if let Err(e) = executor.heartbeat().await {
83 tracing::warn!("heartbeat write failed: {e}");
84 }
85 }
86 })
87 }
88
89 pub async fn recover_stale_tasks(
94 &self,
95 staleness_threshold: Duration,
96 ) -> Result<u64, DurableError> {
97 let threshold_secs = staleness_threshold.as_secs();
98 let sql = format!(
99 "UPDATE durable.task SET status = 'PENDING', started_at = NULL \
100 WHERE status = 'RUNNING' \
101 AND executor_id IN ( \
102 SELECT executor_id FROM durable.executor_heartbeat \
103 WHERE last_seen < now() - interval '{threshold_secs} seconds' \
104 ) \
105 AND executor_id != '{}'",
106 self.executor_id
107 );
108 let result = self
109 .db
110 .execute(Statement::from_string(DbBackend::Postgres, sql))
111 .await
112 .map_err(DurableError::from)?;
113 Ok(result.rows_affected())
114 }
115
116 pub fn start_recovery_loop(&self, config: HeartbeatConfig) -> JoinHandle<()> {
119 let db = self.db.clone();
120 let executor_id = self.executor_id.clone();
121 let threshold = config.staleness_threshold;
122
123 tokio::spawn(async move {
124 let executor = Executor::new(db, executor_id);
125 let mut ticker = tokio::time::interval(threshold);
126 loop {
127 ticker.tick().await;
128 match executor.recover_stale_tasks(threshold).await {
129 Ok(n) if n > 0 => {
130 tracing::info!("recovered {n} stale tasks from dead workers");
131 }
132 Ok(_) => {}
133 Err(e) => {
134 tracing::warn!("recover_stale_tasks failed: {e}");
135 }
136 }
137 }
138 })
139 }
140
141 pub async fn recover(&self) -> Result<Vec<RecoveredTask>, DurableError> {
151 let sql = "
152 UPDATE durable.task
153 SET status = 'PENDING', started_at = NULL, deadline_epoch_ms = NULL
154 WHERE status = 'RUNNING'
155 AND started_at IS NOT NULL
156 AND (
157 (timeout_ms IS NOT NULL AND started_at < now() - make_interval(secs => timeout_ms::double precision / 1000.0))
158 OR
159 (deadline_epoch_ms IS NOT NULL AND deadline_epoch_ms < EXTRACT(EPOCH FROM now()) * 1000)
160 )
161 RETURNING id, name, kind
162 ";
163
164 let rows = self
165 .db
166 .query_all(Statement::from_string(DbBackend::Postgres, sql.to_string()))
167 .await?;
168
169 let mut recovered = Vec::with_capacity(rows.len());
170 for row in rows {
171 let id: Uuid = row
172 .try_get_by_index(0)
173 .map_err(|e| DurableError::custom(e.to_string()))?;
174 let name: String = row
175 .try_get_by_index(1)
176 .map_err(|e| DurableError::custom(e.to_string()))?;
177 let kind: String = row
178 .try_get_by_index(2)
179 .map_err(|e| DurableError::custom(e.to_string()))?;
180 recovered.push(RecoveredTask { id, name, kind });
181 }
182
183 Ok(recovered)
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for whole-second durations used throughout these tests.
    fn secs(n: u64) -> Duration {
        Duration::from_secs(n)
    }

    #[test]
    fn test_heartbeat_config_default() {
        // Destructure so each default is checked by name.
        let HeartbeatConfig {
            heartbeat_interval,
            staleness_threshold,
        } = HeartbeatConfig::default();

        assert_eq!(
            heartbeat_interval,
            secs(60),
            "default heartbeat_interval should be 60s"
        );
        assert_eq!(
            staleness_threshold,
            secs(180),
            "default staleness_threshold should be 180s"
        );
    }

    #[test]
    fn test_heartbeat_config_custom() {
        let (interval, threshold) = (secs(30), secs(90));
        let config = HeartbeatConfig {
            heartbeat_interval: interval,
            staleness_threshold: threshold,
        };

        assert_eq!(config.heartbeat_interval, interval);
        assert_eq!(config.staleness_threshold, threshold);
    }
}