Skip to main content

forge_runtime/workflow/
scheduler.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::PgPool;
5use tokio_util::sync::CancellationToken;
6use uuid::Uuid;
7
8use super::event_store::EventStore;
9use super::executor::WorkflowExecutor;
10use forge_core::Result;
11
/// Tunable settings for the workflow scheduler.
#[derive(Debug, Clone)]
pub struct WorkflowSchedulerConfig {
    /// Base polling period. `WorkflowScheduler::run` ticks its fallback poll
    /// timer at ten times this value.
    pub poll_interval: Duration,
    /// Row limit applied to each query that looks for workflows to resume.
    pub batch_size: i32,
    /// When `true`, workflows waiting on an event are also checked for a
    /// matching unconsumed event after the timer/timeout pass.
    pub process_events: bool,
}

impl Default for WorkflowSchedulerConfig {
    /// One-second poll interval, batches of 100, event wakeups enabled.
    fn default() -> Self {
        let poll_interval = Duration::from_secs(1);
        let batch_size = 100;
        let process_events = true;

        WorkflowSchedulerConfig {
            poll_interval,
            batch_size,
            process_events,
        }
    }
}
32
33/// Scheduler for durable workflows.
34///
35/// Polls the database for suspended workflows that are ready to resume
36/// (timer expired or event received) and triggers their execution.
37/// Also listens for NOTIFY events on the `forge_workflow_wakeup` channel
38/// for immediate wakeup when a workflow event is inserted.
39pub struct WorkflowScheduler {
40    pool: PgPool,
41    executor: Arc<WorkflowExecutor>,
42    event_store: Arc<EventStore>,
43    config: WorkflowSchedulerConfig,
44}
45
46impl WorkflowScheduler {
47    /// Create a new workflow scheduler.
48    pub fn new(
49        pool: PgPool,
50        executor: Arc<WorkflowExecutor>,
51        event_store: Arc<EventStore>,
52        config: WorkflowSchedulerConfig,
53    ) -> Self {
54        Self {
55            pool,
56            executor,
57            event_store,
58            config,
59        }
60    }
61
62    /// Run the scheduler until shutdown.
63    ///
64    /// Combines polling with NOTIFY-driven wakeup. When a workflow event is
65    /// inserted, the `forge_workflow_event_notify` trigger fires a NOTIFY on
66    /// the `forge_workflow_wakeup` channel, and we process immediately instead
67    /// of waiting for the next poll cycle. Polling remains as a fallback at a
68    /// longer interval (10x the base) to catch anything missed.
69    pub async fn run(&self, shutdown: CancellationToken) {
70        let fallback_interval = self.config.poll_interval * 10;
71        let mut interval = tokio::time::interval(fallback_interval);
72        let mut cleanup_interval = tokio::time::interval(Duration::from_secs(3600));
73
74        // Set up NOTIFY listener for immediate wakeup
75        let mut listener = match sqlx::postgres::PgListener::connect_with(&self.pool).await {
76            Ok(mut l) => {
77                if let Err(e) = l.listen("forge_workflow_wakeup").await {
78                    tracing::warn!(error = %e, "Failed to listen on workflow wakeup channel, using poll-only mode");
79                }
80                Some(l)
81            }
82            Err(e) => {
83                tracing::warn!(error = %e, "Failed to create workflow wakeup listener, using poll-only mode");
84                None
85            }
86        };
87
88        tracing::debug!(
89            poll_interval = ?fallback_interval,
90            batch_size = self.config.batch_size,
91            notify_enabled = listener.is_some(),
92            "Workflow scheduler started"
93        );
94
95        loop {
96            tokio::select! {
97                _ = interval.tick() => {
98                    if let Err(e) = self.process_ready_workflows().await {
99                        tracing::warn!(error = %e, "Failed to process ready workflows");
100                    }
101                }
102                notification = async {
103                    match listener.as_mut() {
104                        Some(l) => l.recv().await,
105                        None => std::future::pending().await,
106                    }
107                } => {
108                    match notification {
109                        Ok(_) => {
110                            if let Err(e) = self.process_ready_workflows().await {
111                                tracing::warn!(error = %e, "Failed to process workflows after wakeup");
112                            }
113                        }
114                        Err(e) => {
115                            tracing::debug!(error = %e, "Workflow wakeup listener error, will retry on next poll");
116                        }
117                    }
118                }
119                _ = cleanup_interval.tick() => {
120                    // Periodically clean up consumed events older than 24 hours
121                    let cutoff = chrono::Utc::now() - chrono::Duration::hours(24);
122                    match self.event_store.cleanup_consumed_events(cutoff).await {
123                        Ok(count) if count > 0 => {
124                            tracing::debug!(count, "Cleaned up consumed workflow events");
125                        }
126                        Err(e) => {
127                            tracing::debug!(error = %e, "Failed to clean up consumed events");
128                        }
129                        _ => {}
130                    }
131                }
132                _ = shutdown.cancelled() => {
133                    tracing::debug!("Workflow scheduler shutting down");
134                    break;
135                }
136            }
137        }
138    }
139
140    /// Process workflows that are ready to resume.
141    async fn process_ready_workflows(&self) -> Result<()> {
142        // Query for workflows ready to wake (timer or event timeout)
143        let workflows = sqlx::query!(
144            r#"
145            SELECT id, workflow_name, workflow_version, workflow_signature, waiting_for_event
146            FROM forge_workflow_runs
147            WHERE status = 'waiting' AND (
148                (wake_at IS NOT NULL AND wake_at <= NOW())
149                OR (event_timeout_at IS NOT NULL AND event_timeout_at <= NOW())
150            )
151            ORDER BY COALESCE(wake_at, event_timeout_at) ASC
152            LIMIT $1
153            FOR UPDATE SKIP LOCKED
154            "#,
155            self.config.batch_size as i64
156        )
157        .fetch_all(&self.pool)
158        .await
159        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
160
161        let count = workflows.len();
162        if count > 0 {
163            tracing::trace!(count, "Processing ready workflows");
164        }
165
166        for workflow in workflows {
167            if workflow.waiting_for_event.is_some() {
168                // Event timeout - resume with timeout error
169                self.resume_with_timeout(workflow.id).await;
170            } else {
171                // Timer expired - normal resume
172                self.resume_workflow(workflow.id).await;
173            }
174        }
175
176        // Also check for workflows waiting for events that now have events
177        if self.config.process_events {
178            self.process_event_wakeups().await?;
179        }
180
181        Ok(())
182    }
183
184    /// Process workflows that have pending events.
185    async fn process_event_wakeups(&self) -> Result<()> {
186        // Find workflows waiting for events that have matching events
187        // Use a subquery to avoid DISTINCT with FOR UPDATE
188        let workflows = sqlx::query!(
189            r#"
190            SELECT wr.id, wr.waiting_for_event
191            FROM forge_workflow_runs wr
192            WHERE wr.status = 'waiting'
193                AND wr.waiting_for_event IS NOT NULL
194                AND EXISTS (
195                    SELECT 1 FROM forge_workflow_events we
196                    WHERE we.correlation_id = wr.id::text
197                    AND we.event_name = wr.waiting_for_event
198                    AND we.consumed_at IS NULL
199                )
200            LIMIT $1
201            FOR UPDATE OF wr SKIP LOCKED
202            "#,
203            self.config.batch_size as i64
204        )
205        .fetch_all(&self.pool)
206        .await
207        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
208
209        for workflow in workflows {
210            let workflow_id = workflow.id;
211            let Some(event_name) = workflow.waiting_for_event else {
212                continue;
213            };
214            // Consume the event via event_store so it's marked as processed
215            match self
216                .event_store
217                .consume_event(&event_name, &workflow_id.to_string(), workflow_id)
218                .await
219            {
220                Ok(Some(_event)) => {
221                    self.resume_with_event(workflow_id).await;
222                }
223                Ok(None) => {
224                    tracing::debug!(
225                        workflow_run_id = %workflow_id,
226                        event_name = %event_name,
227                        "Event already consumed, skipping wakeup"
228                    );
229                }
230                Err(e) => {
231                    tracing::warn!(
232                        workflow_run_id = %workflow_id,
233                        error = %e,
234                        "Failed to consume workflow event"
235                    );
236                }
237            }
238        }
239
240        Ok(())
241    }
242
243    /// Resume a workflow after timer expiry.
244    async fn resume_workflow(&self, workflow_run_id: Uuid) {
245        // Clear wake state
246        if let Err(e) = sqlx::query!(
247            r#"
248            UPDATE forge_workflow_runs
249            SET wake_at = NULL, suspended_at = NULL, status = 'running'
250            WHERE id = $1
251            "#,
252            workflow_run_id,
253        )
254        .execute(&self.pool)
255        .await
256        {
257            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear wake state");
258            return;
259        }
260
261        // Resume execution - use resume_from_sleep so ctx.sleep() returns immediately
262        if let Err(e) = self.executor.resume_from_sleep(workflow_run_id).await {
263            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow");
264        } else {
265            tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after timer");
266        }
267    }
268
269    /// Resume a workflow after event timeout.
270    async fn resume_with_timeout(&self, workflow_run_id: Uuid) {
271        // Clear waiting state
272        if let Err(e) = sqlx::query!(
273            r#"
274            UPDATE forge_workflow_runs
275            SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
276            WHERE id = $1
277            "#,
278            workflow_run_id,
279        )
280        .execute(&self.pool)
281        .await
282        {
283            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state");
284            return;
285        }
286
287        // Resume execution - the workflow will get a timeout error
288        if let Err(e) = self.executor.resume(workflow_run_id).await {
289            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after timeout");
290        } else {
291            tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event timeout");
292        }
293    }
294
295    /// Resume a workflow that received an event.
296    async fn resume_with_event(&self, workflow_run_id: Uuid) {
297        // Clear waiting state
298        if let Err(e) = sqlx::query!(
299            r#"
300            UPDATE forge_workflow_runs
301            SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
302            WHERE id = $1
303            "#,
304            workflow_run_id,
305        )
306        .execute(&self.pool)
307        .await
308        {
309            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state for event");
310            return;
311        }
312
313        // Resume execution
314        if let Err(e) = self.executor.resume(workflow_run_id).await {
315            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after event");
316        } else {
317            tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event");
318        }
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration polls every second, fetches batches of 100
    /// and has event wakeups enabled.
    #[test]
    fn test_scheduler_config_default() {
        let WorkflowSchedulerConfig {
            poll_interval,
            batch_size,
            process_events,
        } = WorkflowSchedulerConfig::default();

        assert_eq!(poll_interval, Duration::from_secs(1));
        assert_eq!(batch_size, 100);
        assert!(process_events);
    }
}
333}