Skip to main content

durable/
scheduler.rs

1use chrono::{DateTime, Utc};
2use croner::Cron;
3use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, Statement};
4use std::str::FromStr;
5use std::time::Duration;
6use tokio::task::JoinHandle;
7use uuid::Uuid;
8
9use crate::ctx::TS;
10use crate::error::DurableError;
11
/// Configuration for the cron scheduler loop started by [`start_scheduler_loop`].
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// How often the scheduler polls for due tasks. Default: 15 s.
    pub poll_interval: Duration,
}

impl Default for SchedulerConfig {
    /// Returns a configuration that polls for due tasks every 15 seconds.
    fn default() -> Self {
        Self {
            poll_interval: Duration::from_secs(15),
        }
    }
}
25
26/// A scheduled task that is due for execution.
27#[derive(Debug)]
28pub struct DueTask {
29    pub id: Uuid,
30    pub name: String,
31    pub cron: String,
32    pub handler: Option<String>,
33    pub input: Option<serde_json::Value>,
34}
35
36/// Compute the next run time for a cron expression after `after`.
37///
38/// Returns `None` if the cron expression is invalid or has no future occurrence.
39pub fn next_run(cron_expr: &str, after: DateTime<Utc>) -> Option<DateTime<Utc>> {
40    let cron = Cron::from_str(cron_expr).ok()?;
41    cron.find_next_occurrence(&after, false).ok()
42}
43
44/// Find all scheduled tasks whose `next_run_at` is due (i.e. <= now).
45///
46/// Only considers tasks with `cron IS NOT NULL` and `status = 'PENDING'`.
47/// Uses `FOR UPDATE SKIP LOCKED` so multiple schedulers can run concurrently.
48async fn find_due_tasks(db: &DatabaseConnection) -> Result<Vec<DueTask>, DurableError> {
49    let sql = format!(
50        "SELECT id, name, cron, handler, input::text FROM durable.task \
51         WHERE cron IS NOT NULL \
52           AND next_run_at <= now() \
53           AND status = 'PENDING'{TS} \
54         FOR UPDATE SKIP LOCKED"
55    );
56
57    let rows = db
58        .query_all(Statement::from_string(DbBackend::Postgres, sql))
59        .await?;
60
61    let mut tasks = Vec::with_capacity(rows.len());
62    for row in &rows {
63        let id: Uuid = row
64            .try_get_by_index(0)
65            .map_err(|e| DurableError::custom(e.to_string()))?;
66        let name: String = row
67            .try_get_by_index(1)
68            .map_err(|e| DurableError::custom(e.to_string()))?;
69        let cron: String = row
70            .try_get_by_index(2)
71            .map_err(|e| DurableError::custom(e.to_string()))?;
72        let handler: Option<String> = row.try_get_by_index(3).ok().flatten();
73        let input_str: Option<String> = row.try_get_by_index(4).ok().flatten();
74        let input = input_str.and_then(|s| serde_json::from_str(&s).ok());
75
76        tasks.push(DueTask {
77            id,
78            name,
79            cron,
80            handler,
81            input,
82        });
83    }
84    Ok(tasks)
85}
86
87/// Process a single due scheduled task:
88/// 1. Spawn a child workflow instance under the scheduled parent
89/// 2. Advance `next_run_at` to the next cron occurrence
90///
91/// The entire operation runs in a single transaction so that sequence
92/// computation, child insert, and parent update are atomic. This prevents
93/// race conditions when multiple scheduler instances run concurrently.
94///
95/// The child workflow is dispatched via the registered handler (if any).
96async fn trigger_scheduled_task(
97    db: &DatabaseConnection,
98    task: &DueTask,
99    executor_id: &str,
100) -> Result<Uuid, DurableError> {
101    let child_id = Uuid::new_v4();
102    let now = Utc::now();
103
104    // Escape single quotes in user-supplied strings to prevent SQL injection
105    let escaped_name = task.name.replace('\'', "''");
106    let escaped_executor = executor_id.replace('\'', "''");
107
108    // Build the child name: <parent_name>_run_<seq> (seq computed inside the CTE)
109    // Serialize input (inherit from parent)
110    let input_literal = match &task.input {
111        Some(v) => {
112            let s = serde_json::to_string(v)?;
113            format!("'{}'", s.replace('\'', "''"))
114        }
115        None => "NULL".to_string(),
116    };
117
118    // Handler column (inherit from parent)
119    let (handler_col, handler_val) = match &task.handler {
120        Some(h) => (", handler", format!(", '{}'", h.replace('\'', "''"))),
121        None => ("", String::new()),
122    };
123
124    // Compute next_run_at advancement
125    let next_run_sql = match next_run(&task.cron, now) {
126        Some(next_time) => format!("'{}'::timestamptz", next_time.to_rfc3339()),
127        None => "NULL".to_string(),
128    };
129
130    // Single atomic statement: compute sequence, insert child, advance parent
131    let sql = format!(
132        "WITH next_seq AS ( \
133             SELECT COALESCE(MAX(sequence), -1) + 1 AS seq \
134             FROM durable.task WHERE parent_id = '{parent_id}' \
135         ), \
136         child AS ( \
137             INSERT INTO durable.task \
138             (id, parent_id, sequence, name, kind, status, input, executor_id, started_at{handler_col}) \
139             SELECT '{child_id}', '{parent_id}', seq, \
140                    '{escaped_name}_run_' || seq, 'WORKFLOW', \
141                    'RUNNING'{TS}, {input_literal}, '{escaped_executor}', now(){handler_val} \
142             FROM next_seq \
143             RETURNING sequence \
144         ) \
145         UPDATE durable.task SET next_run_at = {next_run_sql} \
146         WHERE id = '{parent_id}'",
147        parent_id = task.id,
148    );
149    db.execute(Statement::from_string(DbBackend::Postgres, sql))
150        .await?;
151
152    if next_run_sql == "NULL" {
153        tracing::warn!(
154            id = %task.id,
155            cron = %task.cron,
156            "cron expression has no future occurrences, cleared next_run_at"
157        );
158    }
159
160    tracing::info!(
161        scheduled_id = %task.id,
162        child_id = %child_id,
163        name = %task.name,
164        "triggered scheduled workflow instance"
165    );
166
167    Ok(child_id)
168}
169
170/// Run one tick of the scheduler: find due tasks, trigger each, dispatch handlers.
171pub async fn tick(db: &DatabaseConnection, executor_id: &str) -> Result<Vec<Uuid>, DurableError> {
172    let due = find_due_tasks(db).await?;
173    if due.is_empty() {
174        return Ok(vec![]);
175    }
176
177    tracing::info!("scheduler found {} due task(s)", due.len());
178
179    let mut child_ids = Vec::with_capacity(due.len());
180    for task in &due {
181        match trigger_scheduled_task(db, task, executor_id).await {
182            Ok(child_id) => {
183                // Auto-dispatch if handler is registered
184                let lookup_key = task.handler.as_deref().unwrap_or(&task.name);
185                if let Some(reg) = crate::find_workflow(lookup_key) {
186                    let db_inner = db.clone();
187                    let resume = reg.resume_fn;
188                    tokio::spawn(async move {
189                        crate::run_workflow_with_recovery(db_inner, child_id, resume).await;
190                    });
191                }
192                child_ids.push(child_id);
193            }
194            Err(e) => {
195                tracing::error!(
196                    id = %task.id,
197                    error = %e,
198                    "failed to trigger scheduled task"
199                );
200            }
201        }
202    }
203
204    Ok(child_ids)
205}
206
207/// Spawn a background scheduler loop that polls for due cron tasks.
208///
209/// Returns a `JoinHandle` so the caller can abort it on shutdown.
210pub fn start_scheduler_loop(
211    db: DatabaseConnection,
212    executor_id: String,
213    config: &SchedulerConfig,
214) -> JoinHandle<()> {
215    let interval = config.poll_interval;
216    tokio::spawn(async move {
217        let mut ticker = tokio::time::interval(interval);
218        loop {
219            ticker.tick().await;
220            if let Err(e) = tick(&db, &executor_id).await {
221                tracing::warn!("scheduler tick failed: {e}");
222            }
223        }
224    })
225}