forge_runtime/workflow/
scheduler.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::PgPool;
5use tokio_util::sync::CancellationToken;
6use uuid::Uuid;
7
8use super::event_store::EventStore;
9use super::executor::WorkflowExecutor;
10use forge_core::Result;
11
/// Tuning knobs for [`WorkflowScheduler`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkflowSchedulerConfig {
    /// How often the scheduler polls for workflow runs that are ready to resume.
    ///
    /// Must be non-zero: it is passed to `tokio::time::interval`, which panics
    /// on a zero period.
    pub poll_interval: Duration,
    /// Maximum number of runs fetched per candidate query (bound as the SQL
    /// `LIMIT`). Should be positive.
    pub batch_size: i32,
    /// Whether each poll also resumes runs whose awaited event has arrived.
    pub process_events: bool,
}

impl Default for WorkflowSchedulerConfig {
    /// Polls once per second, fetches up to 100 runs per query and processes
    /// event wake-ups.
    fn default() -> Self {
        Self {
            poll_interval: Duration::from_secs(1),
            batch_size: 100,
            process_events: true,
        }
    }
}
32
33pub struct WorkflowScheduler {
38 pool: PgPool,
39 executor: Arc<WorkflowExecutor>,
40 #[allow(dead_code)]
41 event_store: Arc<EventStore>,
42 config: WorkflowSchedulerConfig,
43}
44
45impl WorkflowScheduler {
46 pub fn new(
48 pool: PgPool,
49 executor: Arc<WorkflowExecutor>,
50 event_store: Arc<EventStore>,
51 config: WorkflowSchedulerConfig,
52 ) -> Self {
53 Self {
54 pool,
55 executor,
56 event_store,
57 config,
58 }
59 }
60
61 pub async fn run(&self, shutdown: CancellationToken) {
63 let mut interval = tokio::time::interval(self.config.poll_interval);
64
65 tracing::debug!(
66 poll_interval = ?self.config.poll_interval,
67 batch_size = self.config.batch_size,
68 "Workflow scheduler started"
69 );
70
71 loop {
72 tokio::select! {
73 _ = interval.tick() => {
74 if let Err(e) = self.process_ready_workflows().await {
75 tracing::warn!(error = %e, "Failed to process ready workflows");
76 }
77 }
78 _ = shutdown.cancelled() => {
79 tracing::debug!("Workflow scheduler shutting down");
80 break;
81 }
82 }
83 }
84 }
85
86 async fn process_ready_workflows(&self) -> Result<()> {
88 let workflows: Vec<(Uuid, Option<String>)> = sqlx::query_as(
90 r#"
91 SELECT id, waiting_for_event FROM forge_workflow_runs
92 WHERE status = 'waiting' AND (
93 (wake_at IS NOT NULL AND wake_at <= NOW())
94 OR (event_timeout_at IS NOT NULL AND event_timeout_at <= NOW())
95 )
96 ORDER BY COALESCE(wake_at, event_timeout_at) ASC
97 LIMIT $1
98 FOR UPDATE SKIP LOCKED
99 "#,
100 )
101 .bind(self.config.batch_size)
102 .fetch_all(&self.pool)
103 .await
104 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
105
106 let count = workflows.len();
107 if count > 0 {
108 tracing::trace!(count, "Processing ready workflows");
109 }
110
111 for (workflow_id, waiting_for_event) in workflows {
112 if waiting_for_event.is_some() {
113 self.resume_with_timeout(workflow_id).await;
115 } else {
116 self.resume_workflow(workflow_id).await;
118 }
119 }
120
121 if self.config.process_events {
123 self.process_event_wakeups().await?;
124 }
125
126 Ok(())
127 }
128
129 async fn process_event_wakeups(&self) -> Result<()> {
131 let workflows: Vec<(Uuid, String)> = sqlx::query_as(
134 r#"
135 SELECT wr.id, wr.waiting_for_event
136 FROM forge_workflow_runs wr
137 WHERE wr.status = 'waiting'
138 AND wr.waiting_for_event IS NOT NULL
139 AND EXISTS (
140 SELECT 1 FROM forge_workflow_events we
141 WHERE we.correlation_id = wr.id::text
142 AND we.event_name = wr.waiting_for_event
143 AND we.consumed_at IS NULL
144 )
145 LIMIT $1
146 FOR UPDATE OF wr SKIP LOCKED
147 "#,
148 )
149 .bind(self.config.batch_size)
150 .fetch_all(&self.pool)
151 .await
152 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
153
154 for (workflow_id, _event_name) in workflows {
155 self.resume_with_event(workflow_id).await;
156 }
157
158 Ok(())
159 }
160
161 async fn resume_workflow(&self, workflow_run_id: Uuid) {
163 if let Err(e) = sqlx::query(
165 r#"
166 UPDATE forge_workflow_runs
167 SET wake_at = NULL, suspended_at = NULL, status = 'running'
168 WHERE id = $1
169 "#,
170 )
171 .bind(workflow_run_id)
172 .execute(&self.pool)
173 .await
174 {
175 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear wake state");
176 return;
177 }
178
179 if let Err(e) = self.executor.resume_from_sleep(workflow_run_id).await {
181 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow");
182 } else {
183 tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after timer");
184 }
185 }
186
187 async fn resume_with_timeout(&self, workflow_run_id: Uuid) {
189 if let Err(e) = sqlx::query(
191 r#"
192 UPDATE forge_workflow_runs
193 SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
194 WHERE id = $1
195 "#,
196 )
197 .bind(workflow_run_id)
198 .execute(&self.pool)
199 .await
200 {
201 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state");
202 return;
203 }
204
205 if let Err(e) = self.executor.resume(workflow_run_id).await {
207 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after timeout");
208 } else {
209 tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event timeout");
210 }
211 }
212
213 async fn resume_with_event(&self, workflow_run_id: Uuid) {
215 if let Err(e) = sqlx::query(
217 r#"
218 UPDATE forge_workflow_runs
219 SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
220 WHERE id = $1
221 "#,
222 )
223 .bind(workflow_run_id)
224 .execute(&self.pool)
225 .await
226 {
227 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state for event");
228 return;
229 }
230
231 if let Err(e) = self.executor.resume(workflow_run_id).await {
233 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after event");
234 } else {
235 tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event");
236 }
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the default configuration so any change to it is deliberate.
    #[test]
    fn test_scheduler_config_default() {
        let WorkflowSchedulerConfig {
            poll_interval,
            batch_size,
            process_events,
        } = WorkflowSchedulerConfig::default();

        assert_eq!(poll_interval, Duration::from_secs(1));
        assert_eq!(batch_size, 100);
        assert!(process_events);
    }
}