// forge_runtime/workflow/scheduler.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::PgPool;
5use tokio_util::sync::CancellationToken;
6use uuid::Uuid;
7
8use super::event_store::EventStore;
9use super::executor::WorkflowExecutor;
10use forge_core::Result;
11
/// Tunable settings for [`WorkflowScheduler`].
///
/// The stock values are provided by the `Default` implementation.
#[derive(Debug, Clone)]
pub struct WorkflowSchedulerConfig {
    /// Base polling period. The scheduler's fallback poll runs at ten times
    /// this value, relying on NOTIFY wakeups in between.
    pub poll_interval: Duration,
    /// Maximum number of workflow runs fetched per query (passed as the SQL `LIMIT`).
    pub batch_size: i32,
    /// When `true`, each pass also resumes runs whose awaited event has arrived.
    pub process_events: bool,
}
22
23impl Default for WorkflowSchedulerConfig {
24    fn default() -> Self {
25        Self {
26            poll_interval: Duration::from_secs(1),
27            batch_size: 100,
28            process_events: true,
29        }
30    }
31}
32
33/// Scheduler for durable workflows.
34///
35/// Polls the database for suspended workflows that are ready to resume
36/// (timer expired or event received) and triggers their execution.
37/// Also listens for NOTIFY events on the `forge_workflow_wakeup` channel
38/// for immediate wakeup when a workflow event is inserted.
39pub struct WorkflowScheduler {
40    pool: PgPool,
41    executor: Arc<WorkflowExecutor>,
42    event_store: Arc<EventStore>,
43    config: WorkflowSchedulerConfig,
44}
45
46impl WorkflowScheduler {
47    /// Create a new workflow scheduler.
48    pub fn new(
49        pool: PgPool,
50        executor: Arc<WorkflowExecutor>,
51        event_store: Arc<EventStore>,
52        config: WorkflowSchedulerConfig,
53    ) -> Self {
54        Self {
55            pool,
56            executor,
57            event_store,
58            config,
59        }
60    }
61
62    /// Run the scheduler until shutdown.
63    ///
64    /// Combines polling with NOTIFY-driven wakeup. When a workflow event is
65    /// inserted, the `forge_workflow_event_notify` trigger fires a NOTIFY on
66    /// the `forge_workflow_wakeup` channel, and we process immediately instead
67    /// of waiting for the next poll cycle. Polling remains as a fallback at a
68    /// longer interval (10x the base) to catch anything missed.
69    pub async fn run(&self, shutdown: CancellationToken) {
70        let fallback_interval = self.config.poll_interval * 10;
71        let mut interval = tokio::time::interval(fallback_interval);
72        let mut cleanup_interval = tokio::time::interval(Duration::from_secs(3600));
73
74        // Set up NOTIFY listener for immediate wakeup
75        let mut listener = match sqlx::postgres::PgListener::connect_with(&self.pool).await {
76            Ok(mut l) => {
77                if let Err(e) = l.listen("forge_workflow_wakeup").await {
78                    tracing::warn!(error = %e, "Failed to listen on workflow wakeup channel, using poll-only mode");
79                }
80                Some(l)
81            }
82            Err(e) => {
83                tracing::warn!(error = %e, "Failed to create workflow wakeup listener, using poll-only mode");
84                None
85            }
86        };
87
88        tracing::debug!(
89            poll_interval = ?fallback_interval,
90            batch_size = self.config.batch_size,
91            notify_enabled = listener.is_some(),
92            "Workflow scheduler started"
93        );
94
95        loop {
96            tokio::select! {
97                _ = interval.tick() => {
98                    if let Err(e) = self.process_ready_workflows().await {
99                        tracing::warn!(error = %e, "Failed to process ready workflows");
100                    }
101                }
102                notification = async {
103                    match listener.as_mut() {
104                        Some(l) => l.recv().await,
105                        None => std::future::pending().await,
106                    }
107                } => {
108                    match notification {
109                        Ok(_) => {
110                            if let Err(e) = self.process_ready_workflows().await {
111                                tracing::warn!(error = %e, "Failed to process workflows after wakeup");
112                            }
113                        }
114                        Err(e) => {
115                            tracing::debug!(error = %e, "Workflow wakeup listener error, will retry on next poll");
116                        }
117                    }
118                }
119                _ = cleanup_interval.tick() => {
120                    // Periodically clean up consumed events older than 24 hours
121                    let cutoff = chrono::Utc::now() - chrono::Duration::hours(24);
122                    match self.event_store.cleanup_consumed_events(cutoff).await {
123                        Ok(count) if count > 0 => {
124                            tracing::debug!(count, "Cleaned up consumed workflow events");
125                        }
126                        Err(e) => {
127                            tracing::debug!(error = %e, "Failed to clean up consumed events");
128                        }
129                        _ => {}
130                    }
131                }
132                _ = shutdown.cancelled() => {
133                    tracing::debug!("Workflow scheduler shutting down");
134                    break;
135                }
136            }
137        }
138    }
139
140    /// Process workflows that are ready to resume.
141    async fn process_ready_workflows(&self) -> Result<()> {
142        // Query for workflows ready to wake (timer or event timeout)
143        let workflows = sqlx::query!(
144            r#"
145            SELECT id, waiting_for_event FROM forge_workflow_runs
146            WHERE status = 'waiting' AND (
147                (wake_at IS NOT NULL AND wake_at <= NOW())
148                OR (event_timeout_at IS NOT NULL AND event_timeout_at <= NOW())
149            )
150            ORDER BY COALESCE(wake_at, event_timeout_at) ASC
151            LIMIT $1
152            FOR UPDATE SKIP LOCKED
153            "#,
154            self.config.batch_size as i64
155        )
156        .fetch_all(&self.pool)
157        .await
158        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
159
160        let count = workflows.len();
161        if count > 0 {
162            tracing::trace!(count, "Processing ready workflows");
163        }
164
165        for workflow in workflows {
166            if workflow.waiting_for_event.is_some() {
167                // Event timeout - resume with timeout error
168                self.resume_with_timeout(workflow.id).await;
169            } else {
170                // Timer expired - normal resume
171                self.resume_workflow(workflow.id).await;
172            }
173        }
174
175        // Also check for workflows waiting for events that now have events
176        if self.config.process_events {
177            self.process_event_wakeups().await?;
178        }
179
180        Ok(())
181    }
182
183    /// Process workflows that have pending events.
184    async fn process_event_wakeups(&self) -> Result<()> {
185        // Find workflows waiting for events that have matching events
186        // Use a subquery to avoid DISTINCT with FOR UPDATE
187        let workflows = sqlx::query!(
188            r#"
189            SELECT wr.id, wr.waiting_for_event
190            FROM forge_workflow_runs wr
191            WHERE wr.status = 'waiting'
192                AND wr.waiting_for_event IS NOT NULL
193                AND EXISTS (
194                    SELECT 1 FROM forge_workflow_events we
195                    WHERE we.correlation_id = wr.id::text
196                    AND we.event_name = wr.waiting_for_event
197                    AND we.consumed_at IS NULL
198                )
199            LIMIT $1
200            FOR UPDATE OF wr SKIP LOCKED
201            "#,
202            self.config.batch_size as i64
203        )
204        .fetch_all(&self.pool)
205        .await
206        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
207
208        for workflow in workflows {
209            let workflow_id = workflow.id;
210            let Some(event_name) = workflow.waiting_for_event else {
211                continue;
212            };
213            // Consume the event via event_store so it's marked as processed
214            match self
215                .event_store
216                .consume_event(&event_name, &workflow_id.to_string(), workflow_id)
217                .await
218            {
219                Ok(Some(_event)) => {
220                    self.resume_with_event(workflow_id).await;
221                }
222                Ok(None) => {
223                    tracing::debug!(
224                        workflow_run_id = %workflow_id,
225                        event_name = %event_name,
226                        "Event already consumed, skipping wakeup"
227                    );
228                }
229                Err(e) => {
230                    tracing::warn!(
231                        workflow_run_id = %workflow_id,
232                        error = %e,
233                        "Failed to consume workflow event"
234                    );
235                }
236            }
237        }
238
239        Ok(())
240    }
241
242    /// Resume a workflow after timer expiry.
243    async fn resume_workflow(&self, workflow_run_id: Uuid) {
244        // Clear wake state
245        if let Err(e) = sqlx::query(
246            r#"
247            UPDATE forge_workflow_runs
248            SET wake_at = NULL, suspended_at = NULL, status = 'running'
249            WHERE id = $1
250            "#,
251        )
252        .bind(workflow_run_id)
253        .execute(&self.pool)
254        .await
255        {
256            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear wake state");
257            return;
258        }
259
260        // Resume execution - use resume_from_sleep so ctx.sleep() returns immediately
261        if let Err(e) = self.executor.resume_from_sleep(workflow_run_id).await {
262            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow");
263        } else {
264            tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after timer");
265        }
266    }
267
268    /// Resume a workflow after event timeout.
269    async fn resume_with_timeout(&self, workflow_run_id: Uuid) {
270        // Clear waiting state
271        if let Err(e) = sqlx::query(
272            r#"
273            UPDATE forge_workflow_runs
274            SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
275            WHERE id = $1
276            "#,
277        )
278        .bind(workflow_run_id)
279        .execute(&self.pool)
280        .await
281        {
282            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state");
283            return;
284        }
285
286        // Resume execution - the workflow will get a timeout error
287        if let Err(e) = self.executor.resume(workflow_run_id).await {
288            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after timeout");
289        } else {
290            tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event timeout");
291        }
292    }
293
294    /// Resume a workflow that received an event.
295    async fn resume_with_event(&self, workflow_run_id: Uuid) {
296        // Clear waiting state
297        if let Err(e) = sqlx::query(
298            r#"
299            UPDATE forge_workflow_runs
300            SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
301            WHERE id = $1
302            "#,
303        )
304        .bind(workflow_run_id)
305        .execute(&self.pool)
306        .await
307        {
308            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state for event");
309            return;
310        }
311
312        // Resume execution
313        if let Err(e) = self.executor.resume(workflow_run_id).await {
314            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after event");
315        } else {
316            tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event");
317        }
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// The stock configuration polls every second, fetches 100 runs per
    /// query, and has event wakeups switched on.
    #[test]
    fn test_scheduler_config_default() {
        let WorkflowSchedulerConfig {
            poll_interval,
            batch_size,
            process_events,
        } = WorkflowSchedulerConfig::default();

        assert_eq!(poll_interval, Duration::from_secs(1));
        assert_eq!(batch_size, 100);
        assert!(process_events);
    }
}
332}