forge_runtime/workflow/
scheduler.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::PgPool;
5use tokio_util::sync::CancellationToken;
6use uuid::Uuid;
7
8use super::event_store::EventStore;
9use super::executor::WorkflowExecutor;
10use forge_core::Result;
11
/// Tunable settings for [`WorkflowScheduler`].
///
/// Derives `PartialEq`/`Eq` so configurations can be compared directly
/// (for example in tests or when detecting configuration changes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkflowSchedulerConfig {
    /// Base polling period. The scheduler's fallback poll timer fires at ten
    /// times this value.
    pub poll_interval: Duration,
    /// Upper bound on the number of rows fetched by each scheduler query
    /// (bound as the SQL `LIMIT`).
    pub batch_size: i32,
    /// When `true`, every processing pass also wakes workflows whose awaited
    /// event has arrived.
    pub process_events: bool,
}
22
23impl Default for WorkflowSchedulerConfig {
24 fn default() -> Self {
25 Self {
26 poll_interval: Duration::from_secs(1),
27 batch_size: 100,
28 process_events: true,
29 }
30 }
31}
32
33pub struct WorkflowScheduler {
40 pool: PgPool,
41 executor: Arc<WorkflowExecutor>,
42 event_store: Arc<EventStore>,
43 config: WorkflowSchedulerConfig,
44}
45
46impl WorkflowScheduler {
47 pub fn new(
49 pool: PgPool,
50 executor: Arc<WorkflowExecutor>,
51 event_store: Arc<EventStore>,
52 config: WorkflowSchedulerConfig,
53 ) -> Self {
54 Self {
55 pool,
56 executor,
57 event_store,
58 config,
59 }
60 }
61
62 pub async fn run(&self, shutdown: CancellationToken) {
70 let fallback_interval = self.config.poll_interval * 10;
71 let mut interval = tokio::time::interval(fallback_interval);
72 let mut cleanup_interval = tokio::time::interval(Duration::from_secs(3600));
73
74 let mut listener = match sqlx::postgres::PgListener::connect_with(&self.pool).await {
76 Ok(mut l) => {
77 if let Err(e) = l.listen("forge_workflow_wakeup").await {
78 tracing::warn!(error = %e, "Failed to listen on workflow wakeup channel, using poll-only mode");
79 }
80 Some(l)
81 }
82 Err(e) => {
83 tracing::warn!(error = %e, "Failed to create workflow wakeup listener, using poll-only mode");
84 None
85 }
86 };
87
88 tracing::debug!(
89 poll_interval = ?fallback_interval,
90 batch_size = self.config.batch_size,
91 notify_enabled = listener.is_some(),
92 "Workflow scheduler started"
93 );
94
95 loop {
96 tokio::select! {
97 _ = interval.tick() => {
98 if let Err(e) = self.process_ready_workflows().await {
99 tracing::warn!(error = %e, "Failed to process ready workflows");
100 }
101 }
102 notification = async {
103 match listener.as_mut() {
104 Some(l) => l.recv().await,
105 None => std::future::pending().await,
106 }
107 } => {
108 match notification {
109 Ok(_) => {
110 if let Err(e) = self.process_ready_workflows().await {
111 tracing::warn!(error = %e, "Failed to process workflows after wakeup");
112 }
113 }
114 Err(e) => {
115 tracing::debug!(error = %e, "Workflow wakeup listener error, will retry on next poll");
116 }
117 }
118 }
119 _ = cleanup_interval.tick() => {
120 let cutoff = chrono::Utc::now() - chrono::Duration::hours(24);
122 match self.event_store.cleanup_consumed_events(cutoff).await {
123 Ok(count) if count > 0 => {
124 tracing::debug!(count, "Cleaned up consumed workflow events");
125 }
126 Err(e) => {
127 tracing::debug!(error = %e, "Failed to clean up consumed events");
128 }
129 _ => {}
130 }
131 }
132 _ = shutdown.cancelled() => {
133 tracing::debug!("Workflow scheduler shutting down");
134 break;
135 }
136 }
137 }
138 }
139
140 async fn process_ready_workflows(&self) -> Result<()> {
142 let workflows: Vec<(Uuid, Option<String>)> = sqlx::query_as(
144 r#"
145 SELECT id, waiting_for_event FROM forge_workflow_runs
146 WHERE status = 'waiting' AND (
147 (wake_at IS NOT NULL AND wake_at <= NOW())
148 OR (event_timeout_at IS NOT NULL AND event_timeout_at <= NOW())
149 )
150 ORDER BY COALESCE(wake_at, event_timeout_at) ASC
151 LIMIT $1
152 FOR UPDATE SKIP LOCKED
153 "#,
154 )
155 .bind(self.config.batch_size)
156 .fetch_all(&self.pool)
157 .await
158 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
159
160 let count = workflows.len();
161 if count > 0 {
162 tracing::trace!(count, "Processing ready workflows");
163 }
164
165 for (workflow_id, waiting_for_event) in workflows {
166 if waiting_for_event.is_some() {
167 self.resume_with_timeout(workflow_id).await;
169 } else {
170 self.resume_workflow(workflow_id).await;
172 }
173 }
174
175 if self.config.process_events {
177 self.process_event_wakeups().await?;
178 }
179
180 Ok(())
181 }
182
183 async fn process_event_wakeups(&self) -> Result<()> {
185 let workflows: Vec<(Uuid, String)> = sqlx::query_as(
188 r#"
189 SELECT wr.id, wr.waiting_for_event
190 FROM forge_workflow_runs wr
191 WHERE wr.status = 'waiting'
192 AND wr.waiting_for_event IS NOT NULL
193 AND EXISTS (
194 SELECT 1 FROM forge_workflow_events we
195 WHERE we.correlation_id = wr.id::text
196 AND we.event_name = wr.waiting_for_event
197 AND we.consumed_at IS NULL
198 )
199 LIMIT $1
200 FOR UPDATE OF wr SKIP LOCKED
201 "#,
202 )
203 .bind(self.config.batch_size)
204 .fetch_all(&self.pool)
205 .await
206 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
207
208 for (workflow_id, event_name) in workflows {
209 match self
211 .event_store
212 .consume_event(&event_name, &workflow_id.to_string(), workflow_id)
213 .await
214 {
215 Ok(Some(_event)) => {
216 self.resume_with_event(workflow_id).await;
217 }
218 Ok(None) => {
219 tracing::debug!(
220 workflow_run_id = %workflow_id,
221 event_name = %event_name,
222 "Event already consumed, skipping wakeup"
223 );
224 }
225 Err(e) => {
226 tracing::warn!(
227 workflow_run_id = %workflow_id,
228 error = %e,
229 "Failed to consume workflow event"
230 );
231 }
232 }
233 }
234
235 Ok(())
236 }
237
238 async fn resume_workflow(&self, workflow_run_id: Uuid) {
240 if let Err(e) = sqlx::query(
242 r#"
243 UPDATE forge_workflow_runs
244 SET wake_at = NULL, suspended_at = NULL, status = 'running'
245 WHERE id = $1
246 "#,
247 )
248 .bind(workflow_run_id)
249 .execute(&self.pool)
250 .await
251 {
252 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear wake state");
253 return;
254 }
255
256 if let Err(e) = self.executor.resume_from_sleep(workflow_run_id).await {
258 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow");
259 } else {
260 tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after timer");
261 }
262 }
263
264 async fn resume_with_timeout(&self, workflow_run_id: Uuid) {
266 if let Err(e) = sqlx::query(
268 r#"
269 UPDATE forge_workflow_runs
270 SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
271 WHERE id = $1
272 "#,
273 )
274 .bind(workflow_run_id)
275 .execute(&self.pool)
276 .await
277 {
278 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state");
279 return;
280 }
281
282 if let Err(e) = self.executor.resume(workflow_run_id).await {
284 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after timeout");
285 } else {
286 tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event timeout");
287 }
288 }
289
290 async fn resume_with_event(&self, workflow_run_id: Uuid) {
292 if let Err(e) = sqlx::query(
294 r#"
295 UPDATE forge_workflow_runs
296 SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
297 WHERE id = $1
298 "#,
299 )
300 .bind(workflow_run_id)
301 .execute(&self.pool)
302 .await
303 {
304 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to clear waiting state for event");
305 return;
306 }
307
308 if let Err(e) = self.executor.resume(workflow_run_id).await {
310 tracing::warn!(workflow_run_id = %workflow_run_id, error = %e, "Failed to resume workflow after event");
311 } else {
312 tracing::debug!(workflow_run_id = %workflow_run_id, "Workflow resumed after event");
313 }
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration polls every second, fetches up to 100 rows
    /// per query, and has event processing enabled.
    #[test]
    fn test_scheduler_config_default() {
        let WorkflowSchedulerConfig {
            poll_interval,
            batch_size,
            process_events,
        } = WorkflowSchedulerConfig::default();

        assert_eq!(poll_interval, Duration::from_secs(1));
        assert_eq!(batch_size, 100);
        assert!(process_events);
    }
}
328}