forge_runtime/workflow/
scheduler.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::PgPool;
5use tokio_util::sync::CancellationToken;
6use uuid::Uuid;
7
8use super::event_store::EventStore;
9use super::executor::WorkflowExecutor;
10use forge_core::Result;
11
/// Tunables for [`WorkflowScheduler`].
#[derive(Clone, Debug)]
pub struct WorkflowSchedulerConfig {
    /// Base polling period. The scheduler's `run` loop multiplies this by 10
    /// to obtain the period of its fallback poll timer.
    pub poll_interval: Duration,
    /// Upper bound on the number of workflow runs fetched by each scheduler
    /// query; it is bound to the SQL `LIMIT` parameter.
    pub batch_size: i32,
    /// When `true`, each processing pass also looks for waiting runs that
    /// have an unconsumed matching event.
    pub process_events: bool,
}
22
23impl Default for WorkflowSchedulerConfig {
24 fn default() -> Self {
25 Self {
26 poll_interval: Duration::from_secs(1),
27 batch_size: 100,
28 process_events: true,
29 }
30 }
31}
32
33pub struct WorkflowScheduler {
40 pool: PgPool,
41 executor: Arc<WorkflowExecutor>,
42 #[allow(dead_code)]
43 event_store: Arc<EventStore>,
44 config: WorkflowSchedulerConfig,
45}
46
47impl WorkflowScheduler {
48 pub fn new(
50 pool: PgPool,
51 executor: Arc<WorkflowExecutor>,
52 event_store: Arc<EventStore>,
53 config: WorkflowSchedulerConfig,
54 ) -> Self {
55 Self {
56 pool,
57 executor,
58 event_store,
59 config,
60 }
61 }
62
63 pub async fn run(&self, shutdown: CancellationToken) {
71 let fallback_interval = self.config.poll_interval * 10;
72 let mut interval = tokio::time::interval(fallback_interval);
73
74 let mut listener = match sqlx::postgres::PgListener::connect_with(&self.pool).await {
76 Ok(mut l) => {
77 if let Err(e) = l.listen("forge_workflow_wakeup").await {
78 tracing::warn!(error = %e, "Failed to listen on workflow wakeup channel, using poll-only mode");
79 }
80 Some(l)
81 }
82 Err(e) => {
83 tracing::warn!(error = %e, "Failed to create workflow wakeup listener, using poll-only mode");
84 None
85 }
86 };
87
88 tracing::debug!(
89 poll_interval = ?fallback_interval,
90 batch_size = self.config.batch_size,
91 notify_enabled = listener.is_some(),
92 "Workflow scheduler started"
93 );
94
95 loop {
96 tokio::select! {
97 _ = interval.tick() => {
98 if let Err(e) = self.process_ready_workflows().await {
99 tracing::warn!(error = %e, "Failed to process ready workflows");
100 }
101 }
102 notification = async {
103 match listener.as_mut() {
104 Some(l) => l.recv().await,
105 None => std::future::pending().await,
106 }
107 } => {
108 match notification {
109 Ok(_) => {
110 if let Err(e) = self.process_ready_workflows().await {
111 tracing::warn!(error = %e, "Failed to process workflows after wakeup");
112 }
113 }
114 Err(e) => {
115 tracing::debug!(error = %e, "Workflow wakeup listener error, will retry on next poll");
116 }
117 }
118 }
119 _ = shutdown.cancelled() => {
120 tracing::debug!("Workflow scheduler shutting down");
121 break;
122 }
123 }
124 }
125 }
126
127 async fn process_ready_workflows(&self) -> Result<()> {
129 let workflows: Vec<(Uuid, Option<String>)> = sqlx::query_as(
131 r#"
132 SELECT id, waiting_for_event FROM forge_workflow_runs
133 WHERE status = 'waiting' AND (
134 (wake_at IS NOT NULL AND wake_at <= NOW())
135 OR (event_timeout_at IS NOT NULL AND event_timeout_at <= NOW())
136 )
137 ORDER BY COALESCE(wake_at, event_timeout_at) ASC
138 LIMIT $1
139 FOR UPDATE SKIP LOCKED
140 "#,
141 )
142 .bind(self.config.batch_size)
143 .fetch_all(&self.pool)
144 .await
145 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
146
147 let count = workflows.len();
148 if count > 0 {
149 tracing::trace!(count, "Processing ready workflows");
150 }
151
152 for (workflow_id, waiting_for_event) in workflows {
153 if waiting_for_event.is_some() {
154 self.resume_with_timeout(workflow_id).await;
156 } else {
157 self.resume_workflow(workflow_id).await;
159 }
160 }
161
162 if self.config.process_events {
164 self.process_event_wakeups().await?;
165 }
166
167 Ok(())
168 }
169
170 async fn process_event_wakeups(&self) -> Result<()> {
172 let workflows: Vec<(Uuid, String)> = sqlx::query_as(
175 r#"
176 SELECT wr.id, wr.waiting_for_event
177 FROM forge_workflow_runs wr
178 WHERE wr.status = 'waiting'
179 AND wr.waiting_for_event IS NOT NULL
180 AND EXISTS (
181 SELECT 1 FROM forge_workflow_events we
182 WHERE we.correlation_id = wr.id::text
183 AND we.event_name = wr.waiting_for_event
184 AND we.consumed_at IS NULL
185 )
186 LIMIT $1
187 FOR UPDATE OF wr SKIP LOCKED
188 "#,
189 )
190 .bind(self.config.batch_size)
191 .fetch_all(&self.pool)
192 .await
193 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
194
195 for (workflow_id, _event_name) in workflows {
196 self.resume_with_event(workflow_id).await;
197 }
198
199 Ok(())
200 }
201
202 async fn resume_workflow(&self, workflow_run_id: Uuid) {
204 if let Err(e) = sqlx::query(
206 r#"
207 UPDATE forge_workflow_runs
208 SET wake_at = NULL, suspended_at = NULL, status = 'running'
209 WHERE id = $1
210 "#,
211 )
212 .bind(workflow_run_id)
213 .execute(&self.pool)
214 .await
215 {
216 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear wake state");
217 return;
218 }
219
220 if let Err(e) = self.executor.resume_from_sleep(workflow_run_id).await {
222 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow");
223 } else {
224 tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after timer");
225 }
226 }
227
228 async fn resume_with_timeout(&self, workflow_run_id: Uuid) {
230 if let Err(e) = sqlx::query(
232 r#"
233 UPDATE forge_workflow_runs
234 SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
235 WHERE id = $1
236 "#,
237 )
238 .bind(workflow_run_id)
239 .execute(&self.pool)
240 .await
241 {
242 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state");
243 return;
244 }
245
246 if let Err(e) = self.executor.resume(workflow_run_id).await {
248 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after timeout");
249 } else {
250 tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event timeout");
251 }
252 }
253
254 async fn resume_with_event(&self, workflow_run_id: Uuid) {
256 if let Err(e) = sqlx::query(
258 r#"
259 UPDATE forge_workflow_runs
260 SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
261 WHERE id = $1
262 "#,
263 )
264 .bind(workflow_run_id)
265 .execute(&self.pool)
266 .await
267 {
268 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state for event");
269 return;
270 }
271
272 if let Err(e) = self.executor.resume(workflow_run_id).await {
274 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after event");
275 } else {
276 tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event");
277 }
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration polls every second, fetches up to 100 runs
    /// per query, and has event processing switched on.
    #[test]
    fn test_scheduler_config_default() {
        let WorkflowSchedulerConfig {
            poll_interval,
            batch_size,
            process_events,
        } = WorkflowSchedulerConfig::default();

        assert_eq!(poll_interval, Duration::from_secs(1));
        assert_eq!(batch_size, 100);
        assert!(process_events);
    }
}