1use chrono::{DateTime, Utc};
2use croner::Cron;
3use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, Statement};
4use std::str::FromStr;
5use std::time::Duration;
6use tokio::task::JoinHandle;
7use uuid::Uuid;
8
9use crate::ctx::TS;
10use crate::error::DurableError;
11
/// Tuning knobs for the background cron scheduler started by `start_scheduler_loop`.
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Time between scheduler polls for due cron tasks.
    pub poll_interval: Duration,
}

impl Default for SchedulerConfig {
    /// Polls for due tasks every 15 seconds.
    fn default() -> Self {
        Self {
            poll_interval: Duration::from_secs(15),
        }
    }
}
25
26#[derive(Debug)]
28pub struct DueTask {
29 pub id: Uuid,
30 pub name: String,
31 pub cron: String,
32 pub handler: Option<String>,
33 pub input: Option<serde_json::Value>,
34}
35
36pub fn next_run(cron_expr: &str, after: DateTime<Utc>) -> Option<DateTime<Utc>> {
40 let cron = Cron::from_str(cron_expr).ok()?;
41 cron.find_next_occurrence(&after, false).ok()
42}
43
44async fn find_due_tasks(db: &DatabaseConnection) -> Result<Vec<DueTask>, DurableError> {
49 let sql = format!(
50 "SELECT id, name, cron, handler, input::text FROM durable.task \
51 WHERE cron IS NOT NULL \
52 AND next_run_at <= now() \
53 AND status = 'PENDING'{TS} \
54 FOR UPDATE SKIP LOCKED"
55 );
56
57 let rows = db
58 .query_all(Statement::from_string(DbBackend::Postgres, sql))
59 .await?;
60
61 let mut tasks = Vec::with_capacity(rows.len());
62 for row in &rows {
63 let id: Uuid = row
64 .try_get_by_index(0)
65 .map_err(|e| DurableError::custom(e.to_string()))?;
66 let name: String = row
67 .try_get_by_index(1)
68 .map_err(|e| DurableError::custom(e.to_string()))?;
69 let cron: String = row
70 .try_get_by_index(2)
71 .map_err(|e| DurableError::custom(e.to_string()))?;
72 let handler: Option<String> = row.try_get_by_index(3).ok().flatten();
73 let input_str: Option<String> = row.try_get_by_index(4).ok().flatten();
74 let input = input_str.and_then(|s| serde_json::from_str(&s).ok());
75
76 tasks.push(DueTask {
77 id,
78 name,
79 cron,
80 handler,
81 input,
82 });
83 }
84 Ok(tasks)
85}
86
87async fn trigger_scheduled_task(
97 db: &DatabaseConnection,
98 task: &DueTask,
99 executor_id: &str,
100) -> Result<Uuid, DurableError> {
101 let child_id = Uuid::new_v4();
102 let now = Utc::now();
103
104 let escaped_name = task.name.replace('\'', "''");
106 let escaped_executor = executor_id.replace('\'', "''");
107
108 let input_literal = match &task.input {
111 Some(v) => {
112 let s = serde_json::to_string(v)?;
113 format!("'{}'", s.replace('\'', "''"))
114 }
115 None => "NULL".to_string(),
116 };
117
118 let (handler_col, handler_val) = match &task.handler {
120 Some(h) => (", handler", format!(", '{}'", h.replace('\'', "''"))),
121 None => ("", String::new()),
122 };
123
124 let next_run_sql = match next_run(&task.cron, now) {
126 Some(next_time) => format!("'{}'::timestamptz", next_time.to_rfc3339()),
127 None => "NULL".to_string(),
128 };
129
130 let sql = format!(
132 "WITH next_seq AS ( \
133 SELECT COALESCE(MAX(sequence), -1) + 1 AS seq \
134 FROM durable.task WHERE parent_id = '{parent_id}' \
135 ), \
136 child AS ( \
137 INSERT INTO durable.task \
138 (id, parent_id, sequence, name, kind, status, input, executor_id, started_at{handler_col}) \
139 SELECT '{child_id}', '{parent_id}', seq, \
140 '{escaped_name}_run_' || seq, 'WORKFLOW', \
141 'RUNNING'{TS}, {input_literal}, '{escaped_executor}', now(){handler_val} \
142 FROM next_seq \
143 RETURNING sequence \
144 ) \
145 UPDATE durable.task SET next_run_at = {next_run_sql} \
146 WHERE id = '{parent_id}'",
147 parent_id = task.id,
148 );
149 db.execute(Statement::from_string(DbBackend::Postgres, sql))
150 .await?;
151
152 if next_run_sql == "NULL" {
153 tracing::warn!(
154 id = %task.id,
155 cron = %task.cron,
156 "cron expression has no future occurrences, cleared next_run_at"
157 );
158 }
159
160 tracing::info!(
161 scheduled_id = %task.id,
162 child_id = %child_id,
163 name = %task.name,
164 "triggered scheduled workflow instance"
165 );
166
167 Ok(child_id)
168}
169
170pub async fn tick(db: &DatabaseConnection, executor_id: &str) -> Result<Vec<Uuid>, DurableError> {
172 let due = find_due_tasks(db).await?;
173 if due.is_empty() {
174 return Ok(vec![]);
175 }
176
177 tracing::info!("scheduler found {} due task(s)", due.len());
178
179 let mut child_ids = Vec::with_capacity(due.len());
180 for task in &due {
181 match trigger_scheduled_task(db, task, executor_id).await {
182 Ok(child_id) => {
183 let lookup_key = task.handler.as_deref().unwrap_or(&task.name);
185 if let Some(reg) = crate::find_workflow(lookup_key) {
186 let db_inner = db.clone();
187 let resume = reg.resume_fn;
188 tokio::spawn(async move {
189 crate::run_workflow_with_recovery(db_inner, child_id, resume).await;
190 });
191 }
192 child_ids.push(child_id);
193 }
194 Err(e) => {
195 tracing::error!(
196 id = %task.id,
197 error = %e,
198 "failed to trigger scheduled task"
199 );
200 }
201 }
202 }
203
204 Ok(child_ids)
205}
206
207pub fn start_scheduler_loop(
211 db: DatabaseConnection,
212 executor_id: String,
213 config: &SchedulerConfig,
214) -> JoinHandle<()> {
215 let interval = config.poll_interval;
216 tokio::spawn(async move {
217 let mut ticker = tokio::time::interval(interval);
218 loop {
219 ticker.tick().await;
220 if let Err(e) = tick(&db, &executor_id).await {
221 tracing::warn!("scheduler tick failed: {e}");
222 }
223 }
224 })
225}