forge_runtime/workflow/scheduler.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::PgPool;
5use tokio_util::sync::CancellationToken;
6use uuid::Uuid;
7
8use super::event_store::EventStore;
9use super::executor::WorkflowExecutor;
10use forge_core::Result;
11
/// Tuning knobs for the durable-workflow scheduler.
#[derive(Clone, Debug)]
pub struct WorkflowSchedulerConfig {
    /// Delay between successive polls for workflows that are ready to resume.
    pub poll_interval: Duration,
    /// Upper bound on rows fetched by each candidate query in a single poll
    /// (bound as the SQL `LIMIT`).
    pub batch_size: i32,
    /// When `true`, each poll also resumes workflows whose awaited event has arrived.
    pub process_events: bool,
}

impl Default for WorkflowSchedulerConfig {
    /// One-second polling, batches of 100, event wakeups enabled.
    fn default() -> Self {
        let poll_interval = Duration::from_millis(1_000);
        let batch_size = 100;
        Self {
            process_events: true,
            poll_interval,
            batch_size,
        }
    }
}
32
33/// Scheduler for durable workflows.
34///
35/// Polls the database for suspended workflows that are ready to resume
36/// (timer expired or event received) and triggers their execution.
37pub struct WorkflowScheduler {
38    pool: PgPool,
39    executor: Arc<WorkflowExecutor>,
40    #[allow(dead_code)]
41    event_store: Arc<EventStore>,
42    config: WorkflowSchedulerConfig,
43}
44
45impl WorkflowScheduler {
46    /// Create a new workflow scheduler.
47    pub fn new(
48        pool: PgPool,
49        executor: Arc<WorkflowExecutor>,
50        event_store: Arc<EventStore>,
51        config: WorkflowSchedulerConfig,
52    ) -> Self {
53        Self {
54            pool,
55            executor,
56            event_store,
57            config,
58        }
59    }
60
61    /// Run the scheduler until shutdown.
62    pub async fn run(&self, shutdown: CancellationToken) {
63        let mut interval = tokio::time::interval(self.config.poll_interval);
64
65        tracing::info!(
66            poll_interval = ?self.config.poll_interval,
67            batch_size = self.config.batch_size,
68            "Workflow scheduler started"
69        );
70
71        loop {
72            tokio::select! {
73                _ = interval.tick() => {
74                    if let Err(e) = self.process_ready_workflows().await {
75                        tracing::error!(error = %e, "Failed to process ready workflows");
76                    }
77                }
78                _ = shutdown.cancelled() => {
79                    tracing::info!("Workflow scheduler shutting down");
80                    break;
81                }
82            }
83        }
84    }
85
86    /// Process workflows that are ready to resume.
87    async fn process_ready_workflows(&self) -> Result<()> {
88        // Query for workflows ready to wake (timer or event timeout)
89        let workflows: Vec<(Uuid, Option<String>)> = sqlx::query_as(
90            r#"
91            SELECT id, waiting_for_event FROM forge_workflow_runs
92            WHERE status = 'waiting' AND (
93                (wake_at IS NOT NULL AND wake_at <= NOW())
94                OR (event_timeout_at IS NOT NULL AND event_timeout_at <= NOW())
95            )
96            ORDER BY COALESCE(wake_at, event_timeout_at) ASC
97            LIMIT $1
98            FOR UPDATE SKIP LOCKED
99            "#,
100        )
101        .bind(self.config.batch_size)
102        .fetch_all(&self.pool)
103        .await
104        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
105
106        let count = workflows.len();
107        if count > 0 {
108            tracing::debug!(count = count, "Processing ready workflows");
109        }
110
111        for (workflow_id, waiting_for_event) in workflows {
112            if waiting_for_event.is_some() {
113                // Event timeout - resume with timeout error
114                self.resume_with_timeout(workflow_id).await;
115            } else {
116                // Timer expired - normal resume
117                self.resume_workflow(workflow_id).await;
118            }
119        }
120
121        // Also check for workflows waiting for events that now have events
122        if self.config.process_events {
123            self.process_event_wakeups().await?;
124        }
125
126        Ok(())
127    }
128
129    /// Process workflows that have pending events.
130    async fn process_event_wakeups(&self) -> Result<()> {
131        // Find workflows waiting for events that have matching events
132        // Use a subquery to avoid DISTINCT with FOR UPDATE
133        let workflows: Vec<(Uuid, String)> = sqlx::query_as(
134            r#"
135            SELECT wr.id, wr.waiting_for_event
136            FROM forge_workflow_runs wr
137            WHERE wr.status = 'waiting'
138                AND wr.waiting_for_event IS NOT NULL
139                AND EXISTS (
140                    SELECT 1 FROM forge_workflow_events we
141                    WHERE we.correlation_id = wr.id::text
142                    AND we.event_name = wr.waiting_for_event
143                    AND we.consumed_at IS NULL
144                )
145            LIMIT $1
146            FOR UPDATE OF wr SKIP LOCKED
147            "#,
148        )
149        .bind(self.config.batch_size)
150        .fetch_all(&self.pool)
151        .await
152        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
153
154        for (workflow_id, _event_name) in workflows {
155            self.resume_with_event(workflow_id).await;
156        }
157
158        Ok(())
159    }
160
161    /// Resume a workflow after timer expiry.
162    async fn resume_workflow(&self, workflow_run_id: Uuid) {
163        // Clear wake state
164        if let Err(e) = sqlx::query(
165            r#"
166            UPDATE forge_workflow_runs
167            SET wake_at = NULL, suspended_at = NULL, status = 'running'
168            WHERE id = $1
169            "#,
170        )
171        .bind(workflow_run_id)
172        .execute(&self.pool)
173        .await
174        {
175            tracing::error!(
176                workflow_run_id = %workflow_run_id,
177                error = %e,
178                "Failed to clear wake state"
179            );
180            return;
181        }
182
183        // Resume execution - use resume_from_sleep so ctx.sleep() returns immediately
184        if let Err(e) = self.executor.resume_from_sleep(workflow_run_id).await {
185            tracing::error!(
186                workflow_run_id = %workflow_run_id,
187                error = %e,
188                "Failed to resume workflow"
189            );
190        } else {
191            tracing::info!(
192                workflow_run_id = %workflow_run_id,
193                "Resumed workflow after timer"
194            );
195        }
196    }
197
198    /// Resume a workflow after event timeout.
199    async fn resume_with_timeout(&self, workflow_run_id: Uuid) {
200        // Clear waiting state
201        if let Err(e) = sqlx::query(
202            r#"
203            UPDATE forge_workflow_runs
204            SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
205            WHERE id = $1
206            "#,
207        )
208        .bind(workflow_run_id)
209        .execute(&self.pool)
210        .await
211        {
212            tracing::error!(
213                workflow_run_id = %workflow_run_id,
214                error = %e,
215                "Failed to clear waiting state"
216            );
217            return;
218        }
219
220        // Resume execution - the workflow will get a timeout error
221        if let Err(e) = self.executor.resume(workflow_run_id).await {
222            tracing::error!(
223                workflow_run_id = %workflow_run_id,
224                error = %e,
225                "Failed to resume workflow after timeout"
226            );
227        } else {
228            tracing::info!(
229                workflow_run_id = %workflow_run_id,
230                "Resumed workflow after event timeout"
231            );
232        }
233    }
234
235    /// Resume a workflow that received an event.
236    async fn resume_with_event(&self, workflow_run_id: Uuid) {
237        // Clear waiting state
238        if let Err(e) = sqlx::query(
239            r#"
240            UPDATE forge_workflow_runs
241            SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
242            WHERE id = $1
243            "#,
244        )
245        .bind(workflow_run_id)
246        .execute(&self.pool)
247        .await
248        {
249            tracing::error!(
250                workflow_run_id = %workflow_run_id,
251                error = %e,
252                "Failed to clear waiting state for event"
253            );
254            return;
255        }
256
257        // Resume execution
258        if let Err(e) = self.executor.resume(workflow_run_id).await {
259            tracing::error!(
260                workflow_run_id = %workflow_run_id,
261                error = %e,
262                "Failed to resume workflow after event"
263            );
264        } else {
265            tracing::info!(
266                workflow_run_id = %workflow_run_id,
267                "Resumed workflow after event received"
268            );
269        }
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config polls every second, in batches of 100, with event wakeups on.
    #[test]
    fn test_scheduler_config_default() {
        let WorkflowSchedulerConfig {
            poll_interval,
            batch_size,
            process_events,
        } = WorkflowSchedulerConfig::default();

        assert_eq!(poll_interval, Duration::from_secs(1));
        assert_eq!(batch_size, 100);
        assert!(process_events);
    }
}