Skip to main content

forge_runtime/workflow/
scheduler.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::PgPool;
5use tokio_util::sync::CancellationToken;
6use uuid::Uuid;
7
8use super::event_store::EventStore;
9use super::executor::WorkflowExecutor;
10use forge_core::Result;
11
/// Tuning knobs for the workflow scheduler.
///
/// Start from [`Default`] for the stock settings and override individual
/// fields with struct-update syntax when needed.
#[derive(Debug, Clone)]
pub struct WorkflowSchedulerConfig {
    /// Delay between successive polls of the workflow-runs table.
    pub poll_interval: Duration,
    /// Upper bound on the rows each polling query fetches (bound to the SQL `LIMIT`).
    pub batch_size: i32,
    /// When `true`, each poll also looks for runs whose awaited event has arrived.
    pub process_events: bool,
}
22
23impl Default for WorkflowSchedulerConfig {
24    fn default() -> Self {
25        Self {
26            poll_interval: Duration::from_secs(1),
27            batch_size: 100,
28            process_events: true,
29        }
30    }
31}
32
33/// Scheduler for durable workflows.
34///
35/// Polls the database for suspended workflows that are ready to resume
36/// (timer expired or event received) and triggers their execution.
37pub struct WorkflowScheduler {
38    pool: PgPool,
39    executor: Arc<WorkflowExecutor>,
40    #[allow(dead_code)]
41    event_store: Arc<EventStore>,
42    config: WorkflowSchedulerConfig,
43}
44
45impl WorkflowScheduler {
46    /// Create a new workflow scheduler.
47    pub fn new(
48        pool: PgPool,
49        executor: Arc<WorkflowExecutor>,
50        event_store: Arc<EventStore>,
51        config: WorkflowSchedulerConfig,
52    ) -> Self {
53        Self {
54            pool,
55            executor,
56            event_store,
57            config,
58        }
59    }
60
61    /// Run the scheduler until shutdown.
62    pub async fn run(&self, shutdown: CancellationToken) {
63        let mut interval = tokio::time::interval(self.config.poll_interval);
64
65        tracing::debug!(
66            poll_interval = ?self.config.poll_interval,
67            batch_size = self.config.batch_size,
68            "Workflow scheduler started"
69        );
70
71        loop {
72            tokio::select! {
73                _ = interval.tick() => {
74                    if let Err(e) = self.process_ready_workflows().await {
75                        tracing::warn!(error = %e, "Failed to process ready workflows");
76                    }
77                }
78                _ = shutdown.cancelled() => {
79                    tracing::debug!("Workflow scheduler shutting down");
80                    break;
81                }
82            }
83        }
84    }
85
86    /// Process workflows that are ready to resume.
87    async fn process_ready_workflows(&self) -> Result<()> {
88        // Query for workflows ready to wake (timer or event timeout)
89        let workflows: Vec<(Uuid, Option<String>)> = sqlx::query_as(
90            r#"
91            SELECT id, waiting_for_event FROM forge_workflow_runs
92            WHERE status = 'waiting' AND (
93                (wake_at IS NOT NULL AND wake_at <= NOW())
94                OR (event_timeout_at IS NOT NULL AND event_timeout_at <= NOW())
95            )
96            ORDER BY COALESCE(wake_at, event_timeout_at) ASC
97            LIMIT $1
98            FOR UPDATE SKIP LOCKED
99            "#,
100        )
101        .bind(self.config.batch_size)
102        .fetch_all(&self.pool)
103        .await
104        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
105
106        let count = workflows.len();
107        if count > 0 {
108            tracing::trace!(count, "Processing ready workflows");
109        }
110
111        for (workflow_id, waiting_for_event) in workflows {
112            if waiting_for_event.is_some() {
113                // Event timeout - resume with timeout error
114                self.resume_with_timeout(workflow_id).await;
115            } else {
116                // Timer expired - normal resume
117                self.resume_workflow(workflow_id).await;
118            }
119        }
120
121        // Also check for workflows waiting for events that now have events
122        if self.config.process_events {
123            self.process_event_wakeups().await?;
124        }
125
126        Ok(())
127    }
128
129    /// Process workflows that have pending events.
130    async fn process_event_wakeups(&self) -> Result<()> {
131        // Find workflows waiting for events that have matching events
132        // Use a subquery to avoid DISTINCT with FOR UPDATE
133        let workflows: Vec<(Uuid, String)> = sqlx::query_as(
134            r#"
135            SELECT wr.id, wr.waiting_for_event
136            FROM forge_workflow_runs wr
137            WHERE wr.status = 'waiting'
138                AND wr.waiting_for_event IS NOT NULL
139                AND EXISTS (
140                    SELECT 1 FROM forge_workflow_events we
141                    WHERE we.correlation_id = wr.id::text
142                    AND we.event_name = wr.waiting_for_event
143                    AND we.consumed_at IS NULL
144                )
145            LIMIT $1
146            FOR UPDATE OF wr SKIP LOCKED
147            "#,
148        )
149        .bind(self.config.batch_size)
150        .fetch_all(&self.pool)
151        .await
152        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
153
154        for (workflow_id, _event_name) in workflows {
155            self.resume_with_event(workflow_id).await;
156        }
157
158        Ok(())
159    }
160
161    /// Resume a workflow after timer expiry.
162    async fn resume_workflow(&self, workflow_run_id: Uuid) {
163        // Clear wake state
164        if let Err(e) = sqlx::query(
165            r#"
166            UPDATE forge_workflow_runs
167            SET wake_at = NULL, suspended_at = NULL, status = 'running'
168            WHERE id = $1
169            "#,
170        )
171        .bind(workflow_run_id)
172        .execute(&self.pool)
173        .await
174        {
175            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear wake state");
176            return;
177        }
178
179        // Resume execution - use resume_from_sleep so ctx.sleep() returns immediately
180        if let Err(e) = self.executor.resume_from_sleep(workflow_run_id).await {
181            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow");
182        } else {
183            tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after timer");
184        }
185    }
186
187    /// Resume a workflow after event timeout.
188    async fn resume_with_timeout(&self, workflow_run_id: Uuid) {
189        // Clear waiting state
190        if let Err(e) = sqlx::query(
191            r#"
192            UPDATE forge_workflow_runs
193            SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
194            WHERE id = $1
195            "#,
196        )
197        .bind(workflow_run_id)
198        .execute(&self.pool)
199        .await
200        {
201            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state");
202            return;
203        }
204
205        // Resume execution - the workflow will get a timeout error
206        if let Err(e) = self.executor.resume(workflow_run_id).await {
207            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after timeout");
208        } else {
209            tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event timeout");
210        }
211    }
212
213    /// Resume a workflow that received an event.
214    async fn resume_with_event(&self, workflow_run_id: Uuid) {
215        // Clear waiting state
216        if let Err(e) = sqlx::query(
217            r#"
218            UPDATE forge_workflow_runs
219            SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
220            WHERE id = $1
221            "#,
222        )
223        .bind(workflow_run_id)
224        .execute(&self.pool)
225        .await
226        {
227            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state for event");
228            return;
229        }
230
231        // Resume execution
232        if let Err(e) = self.executor.resume(workflow_run_id).await {
233            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after event");
234        } else {
235            tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event");
236        }
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// The stock configuration polls every second, batches 100 runs, and
    /// processes event wakeups.
    #[test]
    fn test_scheduler_config_default() {
        let WorkflowSchedulerConfig {
            poll_interval,
            batch_size,
            process_events,
        } = WorkflowSchedulerConfig::default();

        assert_eq!(poll_interval, Duration::from_secs(1));
        assert_eq!(batch_size, 100);
        assert!(process_events);
    }
}
251}