Skip to main content

forge_runtime/workflow/
scheduler.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::PgPool;
5use tokio_util::sync::CancellationToken;
6use uuid::Uuid;
7
8use super::event_store::EventStore;
9use super::executor::WorkflowExecutor;
10use forge_core::Result;
11
/// Configuration for the workflow scheduler.
///
/// Implements `PartialEq`/`Eq` so configurations can be compared directly
/// (for example in tests or when detecting configuration changes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkflowSchedulerConfig {
    /// Base polling interval.
    ///
    /// `WorkflowScheduler::run` uses ten times this value as its fallback
    /// polling period, relying on NOTIFY wakeups for low-latency resumption.
    pub poll_interval: Duration,
    /// Maximum workflows to process per poll (bound as the SQL `LIMIT`).
    pub batch_size: i32,
    /// Whether to process event-based wakeups.
    pub process_events: bool,
}

impl Default for WorkflowSchedulerConfig {
    /// Returns the default configuration: a one-second base poll interval,
    /// batches of 100 workflows, and event wakeups enabled.
    fn default() -> Self {
        Self {
            poll_interval: Duration::from_secs(1),
            batch_size: 100,
            process_events: true,
        }
    }
}
32
33/// Scheduler for durable workflows.
34///
35/// Polls the database for suspended workflows that are ready to resume
36/// (timer expired or event received) and triggers their execution.
37/// Also listens for NOTIFY events on the `forge_workflow_wakeup` channel
38/// for immediate wakeup when a workflow event is inserted.
39pub struct WorkflowScheduler {
40    pool: PgPool,
41    executor: Arc<WorkflowExecutor>,
42    event_store: Arc<EventStore>,
43    config: WorkflowSchedulerConfig,
44}
45
46impl WorkflowScheduler {
47    /// Create a new workflow scheduler.
48    pub fn new(
49        pool: PgPool,
50        executor: Arc<WorkflowExecutor>,
51        event_store: Arc<EventStore>,
52        config: WorkflowSchedulerConfig,
53    ) -> Self {
54        Self {
55            pool,
56            executor,
57            event_store,
58            config,
59        }
60    }
61
62    /// Run the scheduler until shutdown.
63    ///
64    /// Combines polling with NOTIFY-driven wakeup. When a workflow event is
65    /// inserted, the `forge_workflow_event_notify` trigger fires a NOTIFY on
66    /// the `forge_workflow_wakeup` channel, and we process immediately instead
67    /// of waiting for the next poll cycle. Polling remains as a fallback at a
68    /// longer interval (10x the base) to catch anything missed.
69    pub async fn run(&self, shutdown: CancellationToken) {
70        let fallback_interval = self.config.poll_interval * 10;
71        let mut interval = tokio::time::interval(fallback_interval);
72        let mut cleanup_interval = tokio::time::interval(Duration::from_secs(3600));
73
74        // Set up NOTIFY listener for immediate wakeup
75        let mut listener = match sqlx::postgres::PgListener::connect_with(&self.pool).await {
76            Ok(mut l) => {
77                if let Err(e) = l.listen("forge_workflow_wakeup").await {
78                    tracing::warn!(error = %e, "Failed to listen on workflow wakeup channel, using poll-only mode");
79                }
80                Some(l)
81            }
82            Err(e) => {
83                tracing::warn!(error = %e, "Failed to create workflow wakeup listener, using poll-only mode");
84                None
85            }
86        };
87
88        tracing::debug!(
89            poll_interval = ?fallback_interval,
90            batch_size = self.config.batch_size,
91            notify_enabled = listener.is_some(),
92            "Workflow scheduler started"
93        );
94
95        loop {
96            tokio::select! {
97                _ = interval.tick() => {
98                    if let Err(e) = self.process_ready_workflows().await {
99                        tracing::warn!(error = %e, "Failed to process ready workflows");
100                    }
101                }
102                notification = async {
103                    match listener.as_mut() {
104                        Some(l) => l.recv().await,
105                        None => std::future::pending().await,
106                    }
107                } => {
108                    match notification {
109                        Ok(_) => {
110                            if let Err(e) = self.process_ready_workflows().await {
111                                tracing::warn!(error = %e, "Failed to process workflows after wakeup");
112                            }
113                        }
114                        Err(e) => {
115                            tracing::debug!(error = %e, "Workflow wakeup listener error, will retry on next poll");
116                        }
117                    }
118                }
119                _ = cleanup_interval.tick() => {
120                    // Periodically clean up consumed events older than 24 hours
121                    let cutoff = chrono::Utc::now() - chrono::Duration::hours(24);
122                    match self.event_store.cleanup_consumed_events(cutoff).await {
123                        Ok(count) if count > 0 => {
124                            tracing::debug!(count, "Cleaned up consumed workflow events");
125                        }
126                        Err(e) => {
127                            tracing::debug!(error = %e, "Failed to clean up consumed events");
128                        }
129                        _ => {}
130                    }
131                }
132                _ = shutdown.cancelled() => {
133                    tracing::debug!("Workflow scheduler shutting down");
134                    break;
135                }
136            }
137        }
138    }
139
140    /// Process workflows that are ready to resume.
141    async fn process_ready_workflows(&self) -> Result<()> {
142        // Query for workflows ready to wake (timer or event timeout)
143        let workflows: Vec<(Uuid, Option<String>)> = sqlx::query_as(
144            r#"
145            SELECT id, waiting_for_event FROM forge_workflow_runs
146            WHERE status = 'waiting' AND (
147                (wake_at IS NOT NULL AND wake_at <= NOW())
148                OR (event_timeout_at IS NOT NULL AND event_timeout_at <= NOW())
149            )
150            ORDER BY COALESCE(wake_at, event_timeout_at) ASC
151            LIMIT $1
152            FOR UPDATE SKIP LOCKED
153            "#,
154        )
155        .bind(self.config.batch_size)
156        .fetch_all(&self.pool)
157        .await
158        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
159
160        let count = workflows.len();
161        if count > 0 {
162            tracing::trace!(count, "Processing ready workflows");
163        }
164
165        for (workflow_id, waiting_for_event) in workflows {
166            if waiting_for_event.is_some() {
167                // Event timeout - resume with timeout error
168                self.resume_with_timeout(workflow_id).await;
169            } else {
170                // Timer expired - normal resume
171                self.resume_workflow(workflow_id).await;
172            }
173        }
174
175        // Also check for workflows waiting for events that now have events
176        if self.config.process_events {
177            self.process_event_wakeups().await?;
178        }
179
180        Ok(())
181    }
182
183    /// Process workflows that have pending events.
184    async fn process_event_wakeups(&self) -> Result<()> {
185        // Find workflows waiting for events that have matching events
186        // Use a subquery to avoid DISTINCT with FOR UPDATE
187        let workflows: Vec<(Uuid, String)> = sqlx::query_as(
188            r#"
189            SELECT wr.id, wr.waiting_for_event
190            FROM forge_workflow_runs wr
191            WHERE wr.status = 'waiting'
192                AND wr.waiting_for_event IS NOT NULL
193                AND EXISTS (
194                    SELECT 1 FROM forge_workflow_events we
195                    WHERE we.correlation_id = wr.id::text
196                    AND we.event_name = wr.waiting_for_event
197                    AND we.consumed_at IS NULL
198                )
199            LIMIT $1
200            FOR UPDATE OF wr SKIP LOCKED
201            "#,
202        )
203        .bind(self.config.batch_size)
204        .fetch_all(&self.pool)
205        .await
206        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
207
208        for (workflow_id, event_name) in workflows {
209            // Consume the event via event_store so it's marked as processed
210            match self
211                .event_store
212                .consume_event(&event_name, &workflow_id.to_string(), workflow_id)
213                .await
214            {
215                Ok(Some(_event)) => {
216                    self.resume_with_event(workflow_id).await;
217                }
218                Ok(None) => {
219                    tracing::debug!(
220                        workflow_run_id = %workflow_id,
221                        event_name = %event_name,
222                        "Event already consumed, skipping wakeup"
223                    );
224                }
225                Err(e) => {
226                    tracing::warn!(
227                        workflow_run_id = %workflow_id,
228                        error = %e,
229                        "Failed to consume workflow event"
230                    );
231                }
232            }
233        }
234
235        Ok(())
236    }
237
238    /// Resume a workflow after timer expiry.
239    async fn resume_workflow(&self, workflow_run_id: Uuid) {
240        // Clear wake state
241        if let Err(e) = sqlx::query(
242            r#"
243            UPDATE forge_workflow_runs
244            SET wake_at = NULL, suspended_at = NULL, status = 'running'
245            WHERE id = $1
246            "#,
247        )
248        .bind(workflow_run_id)
249        .execute(&self.pool)
250        .await
251        {
252            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear wake state");
253            return;
254        }
255
256        // Resume execution - use resume_from_sleep so ctx.sleep() returns immediately
257        if let Err(e) = self.executor.resume_from_sleep(workflow_run_id).await {
258            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow");
259        } else {
260            tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after timer");
261        }
262    }
263
264    /// Resume a workflow after event timeout.
265    async fn resume_with_timeout(&self, workflow_run_id: Uuid) {
266        // Clear waiting state
267        if let Err(e) = sqlx::query(
268            r#"
269            UPDATE forge_workflow_runs
270            SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
271            WHERE id = $1
272            "#,
273        )
274        .bind(workflow_run_id)
275        .execute(&self.pool)
276        .await
277        {
278            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state");
279            return;
280        }
281
282        // Resume execution - the workflow will get a timeout error
283        if let Err(e) = self.executor.resume(workflow_run_id).await {
284            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after timeout");
285        } else {
286            tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event timeout");
287        }
288    }
289
290    /// Resume a workflow that received an event.
291    async fn resume_with_event(&self, workflow_run_id: Uuid) {
292        // Clear waiting state
293        if let Err(e) = sqlx::query(
294            r#"
295            UPDATE forge_workflow_runs
296            SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
297            WHERE id = $1
298            "#,
299        )
300        .bind(workflow_run_id)
301        .execute(&self.pool)
302        .await
303        {
304            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state for event");
305            return;
306        }
307
308        // Resume execution
309        if let Err(e) = self.executor.resume(workflow_run_id).await {
310            tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after event");
311        } else {
312            tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event");
313        }
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration polls every second, processes up to 100
    /// workflows per pass, and has event wakeups enabled.
    #[test]
    fn test_scheduler_config_default() {
        let WorkflowSchedulerConfig {
            poll_interval,
            batch_size,
            process_events,
        } = WorkflowSchedulerConfig::default();

        assert_eq!(poll_interval, Duration::from_secs(1));
        assert_eq!(batch_size, 100);
        assert!(process_events);
    }
}