use std::sync::Arc;
use std::time::Duration;
use sqlx::PgPool;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use super::event_store::EventStore;
use super::executor::WorkflowExecutor;
use forge_core::Result;
/// Tuning knobs for the [`WorkflowScheduler`] polling loop.
#[derive(Debug, Clone)]
pub struct WorkflowSchedulerConfig {
    /// How often the scheduler polls the database for ready workflows.
    pub poll_interval: Duration,
    /// Maximum number of workflow rows fetched per polling query.
    pub batch_size: i32,
    /// Whether to also wake workflows that have a matching unconsumed event
    /// (see `process_event_wakeups`).
    pub process_events: bool,
}
impl Default for WorkflowSchedulerConfig {
fn default() -> Self {
Self {
poll_interval: Duration::from_secs(1),
batch_size: 100,
process_events: true,
}
}
}
/// Background task that polls Postgres for suspended workflow runs whose
/// wake conditions (timer, event timeout, or event arrival) are met and
/// resumes them through the executor.
pub struct WorkflowScheduler {
    /// Connection pool used for all scheduler queries.
    pool: PgPool,
    /// Executor that actually re-runs a workflow once it is woken.
    executor: Arc<WorkflowExecutor>,
    /// Currently unused (hence the allow); presumably reserved for consuming
    /// events on wakeup — TODO confirm intended use.
    #[allow(dead_code)]
    event_store: Arc<EventStore>,
    /// Polling cadence, batch sizing, and feature toggles.
    config: WorkflowSchedulerConfig,
}
impl WorkflowScheduler {
    /// Builds a scheduler over `pool` that resumes ready workflows through
    /// `executor`, using `config` for polling cadence and batch sizing.
    pub fn new(
        pool: PgPool,
        executor: Arc<WorkflowExecutor>,
        event_store: Arc<EventStore>,
        config: WorkflowSchedulerConfig,
    ) -> Self {
        Self {
            pool,
            executor,
            event_store,
            config,
        }
    }

    /// Runs the polling loop until `shutdown` is cancelled.
    ///
    /// Each tick scans for workflows whose timer or event-timeout has fired
    /// and, when `process_events` is enabled, for workflows whose awaited
    /// event has arrived. Errors are logged and the loop keeps going; only
    /// cancellation stops it.
    pub async fn run(&self, shutdown: CancellationToken) {
        let mut interval = tokio::time::interval(self.config.poll_interval);
        tracing::info!(
            poll_interval = ?self.config.poll_interval,
            batch_size = self.config.batch_size,
            "Workflow scheduler started"
        );
        loop {
            tokio::select! {
                _ = interval.tick() => {
                    if let Err(e) = self.process_ready_workflows().await {
                        tracing::error!(error = %e, "Failed to process ready workflows");
                    }
                }
                _ = shutdown.cancelled() => {
                    tracing::info!("Workflow scheduler shutting down");
                    break;
                }
            }
        }
    }

    /// Finds `waiting` workflows whose `wake_at` or `event_timeout_at` has
    /// elapsed and resumes each one, then (optionally) processes event
    /// wakeups.
    ///
    /// NOTE(review): this SELECT runs in autocommit, so the
    /// `FOR UPDATE SKIP LOCKED` row locks are released as soon as the
    /// statement completes — they do NOT protect the gap between this
    /// SELECT and the per-row claim UPDATEs. The `status = 'waiting'`
    /// guard inside the claim helpers is what actually prevents two
    /// scheduler instances from double-resuming the same workflow.
    async fn process_ready_workflows(&self) -> Result<()> {
        let workflows: Vec<(Uuid, Option<String>)> = sqlx::query_as(
            r#"
            SELECT id, waiting_for_event FROM forge_workflow_runs
            WHERE status = 'waiting' AND (
                (wake_at IS NOT NULL AND wake_at <= NOW())
                OR (event_timeout_at IS NOT NULL AND event_timeout_at <= NOW())
            )
            ORDER BY COALESCE(wake_at, event_timeout_at) ASC
            LIMIT $1
            FOR UPDATE SKIP LOCKED
            "#,
        )
        .bind(self.config.batch_size)
        .fetch_all(&self.pool)
        .await
        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;

        let count = workflows.len();
        if count > 0 {
            tracing::debug!(count = count, "Processing ready workflows");
        }

        for (workflow_id, waiting_for_event) in workflows {
            // A non-NULL waiting_for_event means this row matched via
            // event_timeout_at: the awaited event never arrived in time.
            if waiting_for_event.is_some() {
                self.resume_with_timeout(workflow_id).await;
            } else {
                self.resume_workflow(workflow_id).await;
            }
        }

        if self.config.process_events {
            self.process_event_wakeups().await?;
        }
        Ok(())
    }

    /// Finds `waiting` workflows whose awaited event has been published
    /// (correlated by run id, not yet consumed) and resumes each one.
    ///
    /// Same autocommit caveat as `process_ready_workflows`: the claim
    /// UPDATE, not the SELECT's locks, is the double-resume protection.
    async fn process_event_wakeups(&self) -> Result<()> {
        let workflows: Vec<(Uuid, String)> = sqlx::query_as(
            r#"
            SELECT wr.id, wr.waiting_for_event
            FROM forge_workflow_runs wr
            WHERE wr.status = 'waiting'
              AND wr.waiting_for_event IS NOT NULL
              AND EXISTS (
                  SELECT 1 FROM forge_workflow_events we
                  WHERE we.correlation_id = wr.id::text
                    AND we.event_name = wr.waiting_for_event
                    AND we.consumed_at IS NULL
              )
            LIMIT $1
            FOR UPDATE OF wr SKIP LOCKED
            "#,
        )
        .bind(self.config.batch_size)
        .fetch_all(&self.pool)
        .await
        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;

        for (workflow_id, _event_name) in workflows {
            self.resume_with_event(workflow_id).await;
        }
        Ok(())
    }

    /// Atomically claims a timer-suspended workflow: clears `wake_at` and
    /// flips `status` to `running`, but only while the row is still
    /// `waiting`. Returns `Ok(true)` iff this call won the claim; `Ok(false)`
    /// means another instance already took it and we must not resume.
    async fn claim_timer_wake(
        &self,
        workflow_run_id: Uuid,
    ) -> std::result::Result<bool, sqlx::Error> {
        let result = sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET wake_at = NULL, suspended_at = NULL, status = 'running'
            WHERE id = $1 AND status = 'waiting'
            "#,
        )
        .bind(workflow_run_id)
        .execute(&self.pool)
        .await?;
        Ok(result.rows_affected() > 0)
    }

    /// Like `claim_timer_wake`, but for event waits: also clears
    /// `waiting_for_event` and `event_timeout_at`. Shared by both the
    /// event-arrival and event-timeout resume paths, which previously
    /// duplicated this statement verbatim.
    async fn claim_event_wake(
        &self,
        workflow_run_id: Uuid,
    ) -> std::result::Result<bool, sqlx::Error> {
        let result = sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET waiting_for_event = NULL, event_timeout_at = NULL, suspended_at = NULL, status = 'running'
            WHERE id = $1 AND status = 'waiting'
            "#,
        )
        .bind(workflow_run_id)
        .execute(&self.pool)
        .await?;
        Ok(result.rows_affected() > 0)
    }

    /// Resumes a workflow whose sleep timer fired. Failures are logged,
    /// never propagated, so one bad row cannot stall the batch.
    async fn resume_workflow(&self, workflow_run_id: Uuid) {
        match self.claim_timer_wake(workflow_run_id).await {
            Err(e) => {
                tracing::error!(
                    workflow_run_id = %workflow_run_id,
                    error = %e,
                    "Failed to clear wake state"
                );
                return;
            }
            // Lost the claim race: another scheduler instance (or some other
            // writer) already moved this row out of 'waiting'.
            Ok(false) => {
                tracing::debug!(
                    workflow_run_id = %workflow_run_id,
                    "Workflow no longer waiting; skipping timer resume"
                );
                return;
            }
            Ok(true) => {}
        }
        if let Err(e) = self.executor.resume_from_sleep(workflow_run_id).await {
            tracing::error!(
                workflow_run_id = %workflow_run_id,
                error = %e,
                "Failed to resume workflow"
            );
        } else {
            tracing::info!(
                workflow_run_id = %workflow_run_id,
                "Resumed workflow after timer"
            );
        }
    }

    /// Resumes a workflow whose event wait timed out without the event
    /// arriving.
    async fn resume_with_timeout(&self, workflow_run_id: Uuid) {
        match self.claim_event_wake(workflow_run_id).await {
            Err(e) => {
                tracing::error!(
                    workflow_run_id = %workflow_run_id,
                    error = %e,
                    "Failed to clear waiting state"
                );
                return;
            }
            Ok(false) => {
                tracing::debug!(
                    workflow_run_id = %workflow_run_id,
                    "Workflow no longer waiting; skipping timeout resume"
                );
                return;
            }
            Ok(true) => {}
        }
        if let Err(e) = self.executor.resume(workflow_run_id).await {
            tracing::error!(
                workflow_run_id = %workflow_run_id,
                error = %e,
                "Failed to resume workflow after timeout"
            );
        } else {
            tracing::info!(
                workflow_run_id = %workflow_run_id,
                "Resumed workflow after event timeout"
            );
        }
    }

    /// Resumes a workflow whose awaited event has arrived.
    ///
    /// NOTE(review): nothing here sets `consumed_at` on the matching event
    /// row — presumably the executor does that on resume; verify, otherwise
    /// the same event will keep matching on every poll.
    async fn resume_with_event(&self, workflow_run_id: Uuid) {
        match self.claim_event_wake(workflow_run_id).await {
            Err(e) => {
                tracing::error!(
                    workflow_run_id = %workflow_run_id,
                    error = %e,
                    "Failed to clear waiting state for event"
                );
                return;
            }
            Ok(false) => {
                tracing::debug!(
                    workflow_run_id = %workflow_run_id,
                    "Workflow no longer waiting; skipping event resume"
                );
                return;
            }
            Ok(true) => {}
        }
        if let Err(e) = self.executor.resume(workflow_run_id).await {
            tracing::error!(
                workflow_run_id = %workflow_run_id,
                error = %e,
                "Failed to resume workflow after event"
            );
        } else {
            tracing::info!(
                workflow_run_id = %workflow_run_id,
                "Resumed workflow after event received"
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config polls every second, batches 100 rows, and has
    /// event processing switched on.
    #[test]
    fn test_scheduler_config_default() {
        let WorkflowSchedulerConfig {
            poll_interval,
            batch_size,
            process_events,
        } = WorkflowSchedulerConfig::default();
        assert_eq!(poll_interval, Duration::from_secs(1));
        assert_eq!(batch_size, 100);
        assert!(process_events);
    }
}