forge_runtime/workflow/
scheduler.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::PgPool;
5use tokio_util::sync::CancellationToken;
6use uuid::Uuid;
7
8use super::event_store::EventStore;
9use super::executor::WorkflowExecutor;
10use forge_core::Result;
11
/// Tuning knobs for the workflow scheduler.
#[derive(Debug, Clone)]
pub struct WorkflowSchedulerConfig {
    /// Base polling period from which the scheduler derives its poll cadence.
    pub poll_interval: Duration,
    /// Maximum number of runs selected per query; bound as the SQL `LIMIT`,
    /// so it should be non-negative.
    pub batch_size: i32,
    /// Whether each processing pass also resumes runs whose awaited event
    /// has been recorded.
    pub process_events: bool,
}

impl Default for WorkflowSchedulerConfig {
    /// One-second base poll period, batches of 100, event processing enabled.
    fn default() -> Self {
        let poll_interval = Duration::from_secs(1);
        Self {
            batch_size: 100,
            process_events: true,
            poll_interval,
        }
    }
}
32
33pub struct WorkflowScheduler {
40 pool: PgPool,
41 executor: Arc<WorkflowExecutor>,
42 event_store: Arc<EventStore>,
43 config: WorkflowSchedulerConfig,
44}
45
46impl WorkflowScheduler {
47 pub fn new(
49 pool: PgPool,
50 executor: Arc<WorkflowExecutor>,
51 event_store: Arc<EventStore>,
52 config: WorkflowSchedulerConfig,
53 ) -> Self {
54 Self {
55 pool,
56 executor,
57 event_store,
58 config,
59 }
60 }
61
62 pub async fn run(&self, shutdown: CancellationToken) {
70 let fallback_interval = self.config.poll_interval * 10;
71 let mut interval = tokio::time::interval(fallback_interval);
72 let mut cleanup_interval = tokio::time::interval(Duration::from_secs(3600));
73
74 let mut listener = match sqlx::postgres::PgListener::connect_with(&self.pool).await {
76 Ok(mut l) => {
77 if let Err(e) = l.listen("forge_workflow_wakeup").await {
78 tracing::warn!(error = %e, "Failed to listen on workflow wakeup channel, using poll-only mode");
79 }
80 Some(l)
81 }
82 Err(e) => {
83 tracing::warn!(error = %e, "Failed to create workflow wakeup listener, using poll-only mode");
84 None
85 }
86 };
87
88 tracing::debug!(
89 poll_interval = ?fallback_interval,
90 batch_size = self.config.batch_size,
91 notify_enabled = listener.is_some(),
92 "Workflow scheduler started"
93 );
94
95 loop {
96 tokio::select! {
97 _ = interval.tick() => {
98 if let Err(e) = self.process_ready_workflows().await {
99 tracing::warn!(error = %e, "Failed to process ready workflows");
100 }
101 }
102 notification = async {
103 match listener.as_mut() {
104 Some(l) => l.recv().await,
105 None => std::future::pending().await,
106 }
107 } => {
108 match notification {
109 Ok(_) => {
110 if let Err(e) = self.process_ready_workflows().await {
111 tracing::warn!(error = %e, "Failed to process workflows after wakeup");
112 }
113 }
114 Err(e) => {
115 tracing::debug!(error = %e, "Workflow wakeup listener error, will retry on next poll");
116 }
117 }
118 }
119 _ = cleanup_interval.tick() => {
120 let cutoff = chrono::Utc::now() - chrono::Duration::hours(24);
122 match self.event_store.cleanup_consumed_events(cutoff).await {
123 Ok(count) if count > 0 => {
124 tracing::debug!(count, "Cleaned up consumed workflow events");
125 }
126 Err(e) => {
127 tracing::debug!(error = %e, "Failed to clean up consumed events");
128 }
129 _ => {}
130 }
131 }
132 _ = shutdown.cancelled() => {
133 tracing::debug!("Workflow scheduler shutting down");
134 break;
135 }
136 }
137 }
138 }
139
140 async fn process_ready_workflows(&self) -> Result<()> {
142 let workflows = sqlx::query!(
144 r#"
145 SELECT id, waiting_for_event FROM forge_workflow_runs
146 WHERE status = 'waiting' AND (
147 (wake_at IS NOT NULL AND wake_at <= NOW())
148 OR (event_timeout_at IS NOT NULL AND event_timeout_at <= NOW())
149 )
150 ORDER BY COALESCE(wake_at, event_timeout_at) ASC
151 LIMIT $1
152 FOR UPDATE SKIP LOCKED
153 "#,
154 self.config.batch_size as i64
155 )
156 .fetch_all(&self.pool)
157 .await
158 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
159
160 let count = workflows.len();
161 if count > 0 {
162 tracing::trace!(count, "Processing ready workflows");
163 }
164
165 for workflow in workflows {
166 if workflow.waiting_for_event.is_some() {
167 self.resume_with_timeout(workflow.id).await;
169 } else {
170 self.resume_workflow(workflow.id).await;
172 }
173 }
174
175 if self.config.process_events {
177 self.process_event_wakeups().await?;
178 }
179
180 Ok(())
181 }
182
183 async fn process_event_wakeups(&self) -> Result<()> {
185 let workflows = sqlx::query!(
188 r#"
189 SELECT wr.id, wr.waiting_for_event
190 FROM forge_workflow_runs wr
191 WHERE wr.status = 'waiting'
192 AND wr.waiting_for_event IS NOT NULL
193 AND EXISTS (
194 SELECT 1 FROM forge_workflow_events we
195 WHERE we.correlation_id = wr.id::text
196 AND we.event_name = wr.waiting_for_event
197 AND we.consumed_at IS NULL
198 )
199 LIMIT $1
200 FOR UPDATE OF wr SKIP LOCKED
201 "#,
202 self.config.batch_size as i64
203 )
204 .fetch_all(&self.pool)
205 .await
206 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
207
208 for workflow in workflows {
209 let workflow_id = workflow.id;
210 let Some(event_name) = workflow.waiting_for_event else {
211 continue;
212 };
213 match self
215 .event_store
216 .consume_event(&event_name, &workflow_id.to_string(), workflow_id)
217 .await
218 {
219 Ok(Some(_event)) => {
220 self.resume_with_event(workflow_id).await;
221 }
222 Ok(None) => {
223 tracing::debug!(
224 workflow_run_id = %workflow_id,
225 event_name = %event_name,
226 "Event already consumed, skipping wakeup"
227 );
228 }
229 Err(e) => {
230 tracing::warn!(
231 workflow_run_id = %workflow_id,
232 error = %e,
233 "Failed to consume workflow event"
234 );
235 }
236 }
237 }
238
239 Ok(())
240 }
241
242 async fn resume_workflow(&self, workflow_run_id: Uuid) {
244 if let Err(e) = sqlx::query(
246 r#"
247 UPDATE forge_workflow_runs
248 SET wake_at = NULL, suspended_at = NULL, status = 'running'
249 WHERE id = $1
250 "#,
251 )
252 .bind(workflow_run_id)
253 .execute(&self.pool)
254 .await
255 {
256 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear wake state");
257 return;
258 }
259
260 if let Err(e) = self.executor.resume_from_sleep(workflow_run_id).await {
262 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow");
263 } else {
264 tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after timer");
265 }
266 }
267
268 async fn resume_with_timeout(&self, workflow_run_id: Uuid) {
270 if let Err(e) = sqlx::query(
272 r#"
273 UPDATE forge_workflow_runs
274 SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
275 WHERE id = $1
276 "#,
277 )
278 .bind(workflow_run_id)
279 .execute(&self.pool)
280 .await
281 {
282 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state");
283 return;
284 }
285
286 if let Err(e) = self.executor.resume(workflow_run_id).await {
288 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after timeout");
289 } else {
290 tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event timeout");
291 }
292 }
293
294 async fn resume_with_event(&self, workflow_run_id: Uuid) {
296 if let Err(e) = sqlx::query(
298 r#"
299 UPDATE forge_workflow_runs
300 SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
301 WHERE id = $1
302 "#,
303 )
304 .bind(workflow_run_id)
305 .execute(&self.pool)
306 .await
307 {
308 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state for event");
309 return;
310 }
311
312 if let Err(e) = self.executor.resume(workflow_run_id).await {
314 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after event");
315 } else {
316 tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event");
317 }
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration polls every second, handles up to 100 runs
    /// per query, and processes event wakeups.
    #[test]
    fn test_scheduler_config_default() {
        let WorkflowSchedulerConfig {
            poll_interval,
            batch_size,
            process_events,
        } = WorkflowSchedulerConfig::default();

        assert!(process_events);
        assert_eq!(batch_size, 100);
        assert_eq!(poll_interval, Duration::from_secs(1));
    }
}