Skip to main content

forge_runtime/workflow/
scheduler.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::PgPool;
5use tokio_util::sync::CancellationToken;
6use uuid::Uuid;
7
8use super::event_store::EventStore;
9use super::executor::WorkflowExecutor;
10use forge_core::Result;
11
/// Tunable settings for the workflow scheduler.
#[derive(Debug, Clone)]
pub struct WorkflowSchedulerConfig {
    /// Base polling period. The scheduler's fallback poll runs at ten times
    /// this value.
    pub poll_interval: Duration,
    /// Row limit (`LIMIT`) applied to each query that looks for workflows to
    /// resume.
    pub batch_size: i32,
    /// When `true`, every pass also looks for waiting workflows whose awaited
    /// event has already been recorded.
    pub process_events: bool,
}
22
23impl Default for WorkflowSchedulerConfig {
24    fn default() -> Self {
25        Self {
26            poll_interval: Duration::from_secs(1),
27            batch_size: 100,
28            process_events: true,
29        }
30    }
31}
32
33/// Scheduler for durable workflows.
34///
35/// Polls the database for suspended workflows that are ready to resume
36/// (timer expired or event received) and triggers their execution.
37/// Also listens for NOTIFY events on the `forge_workflow_wakeup` channel
38/// for immediate wakeup when a workflow event is inserted.
39pub struct WorkflowScheduler {
40    pool: PgPool,
41    executor: Arc<WorkflowExecutor>,
42    #[allow(dead_code)]
43    event_store: Arc<EventStore>,
44    config: WorkflowSchedulerConfig,
45}
46
47impl WorkflowScheduler {
48    /// Create a new workflow scheduler.
49    pub fn new(
50        pool: PgPool,
51        executor: Arc<WorkflowExecutor>,
52        event_store: Arc<EventStore>,
53        config: WorkflowSchedulerConfig,
54    ) -> Self {
55        Self {
56            pool,
57            executor,
58            event_store,
59            config,
60        }
61    }
62
63    /// Run the scheduler until shutdown.
64    ///
65    /// Combines polling with NOTIFY-driven wakeup. When a workflow event is
66    /// inserted, the `forge_workflow_event_notify` trigger fires a NOTIFY on
67    /// the `forge_workflow_wakeup` channel, and we process immediately instead
68    /// of waiting for the next poll cycle. Polling remains as a fallback at a
69    /// longer interval (10x the base) to catch anything missed.
70    pub async fn run(&self, shutdown: CancellationToken) {
71        let fallback_interval = self.config.poll_interval * 10;
72        let mut interval = tokio::time::interval(fallback_interval);
73
74        // Set up NOTIFY listener for immediate wakeup
75        let mut listener = match sqlx::postgres::PgListener::connect_with(&self.pool).await {
76            Ok(mut l) => {
77                if let Err(e) = l.listen("forge_workflow_wakeup").await {
78                    tracing::warn!(error = %e, "Failed to listen on workflow wakeup channel, using poll-only mode");
79                }
80                Some(l)
81            }
82            Err(e) => {
83                tracing::warn!(error = %e, "Failed to create workflow wakeup listener, using poll-only mode");
84                None
85            }
86        };
87
88        tracing::debug!(
89            poll_interval = ?fallback_interval,
90            batch_size = self.config.batch_size,
91            notify_enabled = listener.is_some(),
92            "Workflow scheduler started"
93        );
94
95        loop {
96            tokio::select! {
97                _ = interval.tick() => {
98                    if let Err(e) = self.process_ready_workflows().await {
99                        tracing::warn!(error = %e, "Failed to process ready workflows");
100                    }
101                }
102                notification = async {
103                    match listener.as_mut() {
104                        Some(l) => l.recv().await,
105                        None => std::future::pending().await,
106                    }
107                } => {
108                    match notification {
109                        Ok(_) => {
110                            if let Err(e) = self.process_ready_workflows().await {
111                                tracing::warn!(error = %e, "Failed to process workflows after wakeup");
112                            }
113                        }
114                        Err(e) => {
115                            tracing::debug!(error = %e, "Workflow wakeup listener error, will retry on next poll");
116                        }
117                    }
118                }
119                _ = shutdown.cancelled() => {
120                    tracing::debug!("Workflow scheduler shutting down");
121                    break;
122                }
123            }
124        }
125    }
126
127    /// Process workflows that are ready to resume.
128    async fn process_ready_workflows(&self) -> Result<()> {
129        // Query for workflows ready to wake (timer or event timeout)
130        let workflows: Vec<(Uuid, Option<String>)> = sqlx::query_as(
131            r#"
132            SELECT id, waiting_for_event FROM forge_workflow_runs
133            WHERE status = 'waiting' AND (
134                (wake_at IS NOT NULL AND wake_at <= NOW())
135                OR (event_timeout_at IS NOT NULL AND event_timeout_at <= NOW())
136            )
137            ORDER BY COALESCE(wake_at, event_timeout_at) ASC
138            LIMIT $1
139            FOR UPDATE SKIP LOCKED
140            "#,
141        )
142        .bind(self.config.batch_size)
143        .fetch_all(&self.pool)
144        .await
145        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
146
147        let count = workflows.len();
148        if count > 0 {
149            tracing::trace!(count, "Processing ready workflows");
150        }
151
152        for (workflow_id, waiting_for_event) in workflows {
153            if waiting_for_event.is_some() {
154                // Event timeout - resume with timeout error
155                self.resume_with_timeout(workflow_id).await;
156            } else {
157                // Timer expired - normal resume
158                self.resume_workflow(workflow_id).await;
159            }
160        }
161
162        // Also check for workflows waiting for events that now have events
163        if self.config.process_events {
164            self.process_event_wakeups().await?;
165        }
166
167        Ok(())
168    }
169
170    /// Process workflows that have pending events.
171    async fn process_event_wakeups(&self) -> Result<()> {
172        // Find workflows waiting for events that have matching events
173        // Use a subquery to avoid DISTINCT with FOR UPDATE
174        let workflows: Vec<(Uuid, String)> = sqlx::query_as(
175            r#"
176            SELECT wr.id, wr.waiting_for_event
177            FROM forge_workflow_runs wr
178            WHERE wr.status = 'waiting'
179                AND wr.waiting_for_event IS NOT NULL
180                AND EXISTS (
181                    SELECT 1 FROM forge_workflow_events we
182                    WHERE we.correlation_id = wr.id::text
183                    AND we.event_name = wr.waiting_for_event
184                    AND we.consumed_at IS NULL
185                )
186            LIMIT $1
187            FOR UPDATE OF wr SKIP LOCKED
188            "#,
189        )
190        .bind(self.config.batch_size)
191        .fetch_all(&self.pool)
192        .await
193        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
194
195        for (workflow_id, _event_name) in workflows {
196            self.resume_with_event(workflow_id).await;
197        }
198
199        Ok(())
200    }
201
202    /// Resume a workflow after timer expiry.
203    async fn resume_workflow(&self, workflow_run_id: Uuid) {
204        // Clear wake state
205        if let Err(e) = sqlx::query(
206            r#"
207            UPDATE forge_workflow_runs
208            SET wake_at = NULL, suspended_at = NULL, status = 'running'
209            WHERE id = $1
210            "#,
211        )
212        .bind(workflow_run_id)
213        .execute(&self.pool)
214        .await
215        {
216            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear wake state");
217            return;
218        }
219
220        // Resume execution - use resume_from_sleep so ctx.sleep() returns immediately
221        if let Err(e) = self.executor.resume_from_sleep(workflow_run_id).await {
222            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow");
223        } else {
224            tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after timer");
225        }
226    }
227
228    /// Resume a workflow after event timeout.
229    async fn resume_with_timeout(&self, workflow_run_id: Uuid) {
230        // Clear waiting state
231        if let Err(e) = sqlx::query(
232            r#"
233            UPDATE forge_workflow_runs
234            SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
235            WHERE id = $1
236            "#,
237        )
238        .bind(workflow_run_id)
239        .execute(&self.pool)
240        .await
241        {
242            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state");
243            return;
244        }
245
246        // Resume execution - the workflow will get a timeout error
247        if let Err(e) = self.executor.resume(workflow_run_id).await {
248            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after timeout");
249        } else {
250            tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event timeout");
251        }
252    }
253
254    /// Resume a workflow that received an event.
255    async fn resume_with_event(&self, workflow_run_id: Uuid) {
256        // Clear waiting state
257        if let Err(e) = sqlx::query(
258            r#"
259            UPDATE forge_workflow_runs
260            SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
261            WHERE id = $1
262            "#,
263        )
264        .bind(workflow_run_id)
265        .execute(&self.pool)
266        .await
267        {
268            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state for event");
269            return;
270        }
271
272        // Resume execution
273        if let Err(e) = self.executor.resume(workflow_run_id).await {
274            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after event");
275        } else {
276            tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event");
277        }
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration polls every second, fetches up to 100
    /// workflows per query, and has event wakeups enabled.
    #[test]
    fn test_scheduler_config_default() {
        let WorkflowSchedulerConfig {
            poll_interval,
            batch_size,
            process_events,
        } = WorkflowSchedulerConfig::default();

        assert_eq!(poll_interval, Duration::from_secs(1));
        assert_eq!(batch_size, 100);
        assert!(process_events);
    }
}