forge_runtime/workflow/
scheduler.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::PgPool;
5use tokio_util::sync::CancellationToken;
6use uuid::Uuid;
7
8use super::event_store::EventStore;
9use super::executor::WorkflowExecutor;
10use forge_core::Result;
11
/// Tunables for [`WorkflowScheduler`].
///
/// `PartialEq`/`Eq` are derived so whole configurations can be compared
/// directly (for example in tests) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkflowSchedulerConfig {
    /// Period handed to `tokio::time::interval` by the scheduler loop.
    pub poll_interval: Duration,
    /// Bound to the SQL `LIMIT` of each polling query, i.e. the maximum number
    /// of workflow runs picked up per query.
    pub batch_size: i32,
    /// When `true`, each poll also looks for waiting runs whose awaited event
    /// has an unconsumed row in `forge_workflow_events`.
    pub process_events: bool,
}
22
23impl Default for WorkflowSchedulerConfig {
24 fn default() -> Self {
25 Self {
26 poll_interval: Duration::from_secs(1),
27 batch_size: 100,
28 process_events: true,
29 }
30 }
31}
32
33pub struct WorkflowScheduler {
38 pool: PgPool,
39 executor: Arc<WorkflowExecutor>,
40 #[allow(dead_code)]
41 event_store: Arc<EventStore>,
42 config: WorkflowSchedulerConfig,
43}
44
45impl WorkflowScheduler {
46 pub fn new(
48 pool: PgPool,
49 executor: Arc<WorkflowExecutor>,
50 event_store: Arc<EventStore>,
51 config: WorkflowSchedulerConfig,
52 ) -> Self {
53 Self {
54 pool,
55 executor,
56 event_store,
57 config,
58 }
59 }
60
61 pub async fn run(&self, shutdown: CancellationToken) {
63 let mut interval = tokio::time::interval(self.config.poll_interval);
64
65 tracing::info!(
66 poll_interval = ?self.config.poll_interval,
67 batch_size = self.config.batch_size,
68 "Workflow scheduler started"
69 );
70
71 loop {
72 tokio::select! {
73 _ = interval.tick() => {
74 if let Err(e) = self.process_ready_workflows().await {
75 tracing::error!(error = %e, "Failed to process ready workflows");
76 }
77 }
78 _ = shutdown.cancelled() => {
79 tracing::info!("Workflow scheduler shutting down");
80 break;
81 }
82 }
83 }
84 }
85
86 async fn process_ready_workflows(&self) -> Result<()> {
88 let workflows: Vec<(Uuid, Option<String>)> = sqlx::query_as(
90 r#"
91 SELECT id, waiting_for_event FROM forge_workflow_runs
92 WHERE status = 'waiting' AND (
93 (wake_at IS NOT NULL AND wake_at <= NOW())
94 OR (event_timeout_at IS NOT NULL AND event_timeout_at <= NOW())
95 )
96 ORDER BY COALESCE(wake_at, event_timeout_at) ASC
97 LIMIT $1
98 FOR UPDATE SKIP LOCKED
99 "#,
100 )
101 .bind(self.config.batch_size)
102 .fetch_all(&self.pool)
103 .await
104 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
105
106 let count = workflows.len();
107 if count > 0 {
108 tracing::debug!(count = count, "Processing ready workflows");
109 }
110
111 for (workflow_id, waiting_for_event) in workflows {
112 if waiting_for_event.is_some() {
113 self.resume_with_timeout(workflow_id).await;
115 } else {
116 self.resume_workflow(workflow_id).await;
118 }
119 }
120
121 if self.config.process_events {
123 self.process_event_wakeups().await?;
124 }
125
126 Ok(())
127 }
128
129 async fn process_event_wakeups(&self) -> Result<()> {
131 let workflows: Vec<(Uuid, String)> = sqlx::query_as(
134 r#"
135 SELECT wr.id, wr.waiting_for_event
136 FROM forge_workflow_runs wr
137 WHERE wr.status = 'waiting'
138 AND wr.waiting_for_event IS NOT NULL
139 AND EXISTS (
140 SELECT 1 FROM forge_workflow_events we
141 WHERE we.correlation_id = wr.id::text
142 AND we.event_name = wr.waiting_for_event
143 AND we.consumed_at IS NULL
144 )
145 LIMIT $1
146 FOR UPDATE OF wr SKIP LOCKED
147 "#,
148 )
149 .bind(self.config.batch_size)
150 .fetch_all(&self.pool)
151 .await
152 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
153
154 for (workflow_id, _event_name) in workflows {
155 self.resume_with_event(workflow_id).await;
156 }
157
158 Ok(())
159 }
160
161 async fn resume_workflow(&self, workflow_run_id: Uuid) {
163 if let Err(e) = sqlx::query(
165 r#"
166 UPDATE forge_workflow_runs
167 SET wake_at = NULL, suspended_at = NULL, status = 'running'
168 WHERE id = $1
169 "#,
170 )
171 .bind(workflow_run_id)
172 .execute(&self.pool)
173 .await
174 {
175 tracing::error!(
176 workflow_run_id = %workflow_run_id,
177 error = %e,
178 "Failed to clear wake state"
179 );
180 return;
181 }
182
183 if let Err(e) = self.executor.resume_from_sleep(workflow_run_id).await {
185 tracing::error!(
186 workflow_run_id = %workflow_run_id,
187 error = %e,
188 "Failed to resume workflow"
189 );
190 } else {
191 tracing::info!(
192 workflow_run_id = %workflow_run_id,
193 "Resumed workflow after timer"
194 );
195 }
196 }
197
198 async fn resume_with_timeout(&self, workflow_run_id: Uuid) {
200 if let Err(e) = sqlx::query(
202 r#"
203 UPDATE forge_workflow_runs
204 SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
205 WHERE id = $1
206 "#,
207 )
208 .bind(workflow_run_id)
209 .execute(&self.pool)
210 .await
211 {
212 tracing::error!(
213 workflow_run_id = %workflow_run_id,
214 error = %e,
215 "Failed to clear waiting state"
216 );
217 return;
218 }
219
220 if let Err(e) = self.executor.resume(workflow_run_id).await {
222 tracing::error!(
223 workflow_run_id = %workflow_run_id,
224 error = %e,
225 "Failed to resume workflow after timeout"
226 );
227 } else {
228 tracing::info!(
229 workflow_run_id = %workflow_run_id,
230 "Resumed workflow after event timeout"
231 );
232 }
233 }
234
235 async fn resume_with_event(&self, workflow_run_id: Uuid) {
237 if let Err(e) = sqlx::query(
239 r#"
240 UPDATE forge_workflow_runs
241 SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
242 WHERE id = $1
243 "#,
244 )
245 .bind(workflow_run_id)
246 .execute(&self.pool)
247 .await
248 {
249 tracing::error!(
250 workflow_run_id = %workflow_run_id,
251 error = %e,
252 "Failed to clear waiting state for event"
253 );
254 return;
255 }
256
257 if let Err(e) = self.executor.resume(workflow_run_id).await {
259 tracing::error!(
260 workflow_run_id = %workflow_run_id,
261 error = %e,
262 "Failed to resume workflow after event"
263 );
264 } else {
265 tracing::info!(
266 workflow_run_id = %workflow_run_id,
267 "Resumed workflow after event received"
268 );
269 }
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration polls every second, takes up to 100 runs per
    /// query, and has event processing switched on.
    #[test]
    fn test_scheduler_config_default() {
        let WorkflowSchedulerConfig {
            poll_interval,
            batch_size,
            process_events,
        } = WorkflowSchedulerConfig::default();

        assert_eq!(poll_interval, Duration::from_secs(1));
        assert_eq!(batch_size, 100);
        assert!(process_events);
    }
}