use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::time;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
use crate::dal::DAL;
use crate::database::universal_types::UniversalUuid;
use crate::dispatcher::{Dispatcher, TaskReadyEvent};
use crate::error::ValidationError;
use crate::models::task_execution::TaskExecution;
use crate::models::workflow_execution::WorkflowExecutionRecord;
use crate::Runtime;
use super::state_manager::StateManager;
/// Upper bound on the delay the scheduler loop applies after consecutive failed passes.
const MAX_BACKOFF: Duration = Duration::from_secs(30);
/// Consecutive-error count at which the loop logs that its "circuit" is open.
///
/// This only affects logging: below this count every error is logged at `error`
/// level, and the loop keeps retrying with backoff regardless.
const CIRCUIT_OPEN_THRESHOLD: u32 = 5;
/// Polling loop that advances active workflow executions and, when a
/// dispatcher is configured, hands ready tasks to it.
pub struct SchedulerLoop<'a> {
    /// Data-access layer used for every read and write the loop performs.
    dal: &'a DAL,
    /// Shared runtime, handed to the `StateManager` built for each pass.
    runtime: Arc<Runtime>,
    /// Identifier of this scheduler instance (reported when the loop starts).
    instance_id: Uuid,
    /// Base period between passes; also the base unit of the error backoff.
    poll_interval: Duration,
    /// Optional dispatcher; when `None`, no tasks are dispatched by this loop.
    dispatcher: Option<Arc<dyn Dispatcher>>,
    /// Optional shutdown signal observed by `run`.
    shutdown_rx: Option<tokio::sync::watch::Receiver<bool>>,
    /// Number of consecutive failed passes; drives backoff and log throttling.
    consecutive_errors: u32,
}
impl<'a> SchedulerLoop<'a> {
pub fn with_dispatcher(
dal: &'a DAL,
runtime: Arc<Runtime>,
instance_id: Uuid,
poll_interval: Duration,
dispatcher: Option<Arc<dyn Dispatcher>>,
) -> Self {
Self {
dal,
runtime,
instance_id,
poll_interval,
dispatcher,
shutdown_rx: None,
consecutive_errors: 0,
}
}
pub fn with_shutdown(mut self, shutdown_rx: tokio::sync::watch::Receiver<bool>) -> Self {
self.shutdown_rx = Some(shutdown_rx);
self
}
pub async fn run(&mut self) -> Result<(), ValidationError> {
info!(
"Starting task scheduler loop (instance: {}, poll_interval: {:?})",
self.instance_id, self.poll_interval
);
let mut interval = time::interval(self.poll_interval);
loop {
if let Some(ref mut shutdown_rx) = self.shutdown_rx {
tokio::select! {
_ = interval.tick() => {}
_ = shutdown_rx.changed() => {
info!("SchedulerLoop shutting down");
break;
}
}
} else {
interval.tick().await;
}
match self.process_active_executions().await {
Ok(_) => {
if self.consecutive_errors > 0 {
info!(
"Scheduler loop recovered after {} consecutive errors",
self.consecutive_errors
);
self.consecutive_errors = 0;
}
debug!("Scheduling loop completed successfully");
}
Err(e) => {
self.consecutive_errors += 1;
if self.consecutive_errors == CIRCUIT_OPEN_THRESHOLD {
warn!(
"Scheduler loop circuit open: {} consecutive errors — backing off (latest: {})",
self.consecutive_errors, e
);
} else if self.consecutive_errors.is_multiple_of(10) {
warn!(
"Scheduler loop still failing: {} consecutive errors (latest: {})",
self.consecutive_errors, e
);
} else if self.consecutive_errors < CIRCUIT_OPEN_THRESHOLD {
error!("Scheduling loop error: {}", e);
}
let backoff_exp = self.consecutive_errors.min(8);
let backoff = self
.poll_interval
.saturating_mul(1u32 << backoff_exp)
.min(MAX_BACKOFF);
if backoff > self.poll_interval {
time::sleep(backoff - self.poll_interval).await;
}
}
}
}
Ok(())
}
pub async fn process_active_executions(&self) -> Result<(), ValidationError> {
let active_executions = self
.dal
.workflow_execution()
.get_active_executions()
.await?;
metrics::gauge!("cloacina_active_workflows").set(active_executions.len() as f64);
if active_executions.is_empty() {
if self.dispatcher.is_some() {
self.dispatch_ready_tasks().await?;
}
return Ok(());
}
self.process_executions_batch(active_executions).await?;
if self.dispatcher.is_some() {
self.dispatch_ready_tasks().await?;
}
Ok(())
}
async fn process_executions_batch(
&self,
active_executions: Vec<WorkflowExecutionRecord>,
) -> Result<(), ValidationError> {
let execution_ids: Vec<UniversalUuid> = active_executions.iter().map(|e| e.id).collect();
let all_pending_tasks = self
.dal
.task_execution()
.get_pending_tasks_batch(execution_ids)
.await?;
let mut tasks_by_execution: HashMap<UniversalUuid, Vec<TaskExecution>> = HashMap::new();
for task in all_pending_tasks {
tasks_by_execution
.entry(task.workflow_execution_id)
.or_default()
.push(task);
}
let state_manager = StateManager::new(self.dal, self.runtime.clone());
for execution in &active_executions {
if let Some(execution_tasks) = tasks_by_execution.get(&execution.id) {
if let Err(e) = state_manager
.update_workflow_task_readiness(execution.id, execution_tasks)
.await
{
error!(
"Failed to update task readiness for workflow execution {}: {}",
execution.id, e
);
continue;
}
}
if self
.dal
.task_execution()
.check_workflow_completion(execution.id)
.await?
{
self.complete_execution(execution).await?;
}
}
Ok(())
}
async fn dispatch_ready_tasks(&self) -> Result<(), ValidationError> {
let dispatcher = match &self.dispatcher {
Some(d) => d,
None => return Ok(()),
};
let ready_tasks = self.dal.task_execution().get_ready_for_retry().await?;
for task in ready_tasks {
let event = TaskReadyEvent::new(
task.id,
task.workflow_execution_id,
task.task_name.clone(),
task.attempt,
);
if let Err(e) = dispatcher.dispatch(event).await {
warn!(
task_id = %task.id,
task_name = %task.task_name,
error = %e,
"Failed to dispatch ready task"
);
}
}
Ok(())
}
async fn complete_execution(
&self,
execution: &WorkflowExecutionRecord,
) -> Result<(), ValidationError> {
let current = self
.dal
.workflow_execution()
.get_by_id(execution.id)
.await?;
if current.status == "Completed" || current.status == "Failed" {
debug!(
"Workflow execution {} already in status '{}', skipping completion",
execution.id, current.status
);
return Ok(());
}
let all_tasks = self
.dal
.task_execution()
.get_all_tasks_for_workflow(execution.id)
.await?;
let completed_count = all_tasks.iter().filter(|t| t.status == "Completed").count();
let failed_count = all_tasks.iter().filter(|t| t.status == "Failed").count();
let skipped_count = all_tasks.iter().filter(|t| t.status == "Skipped").count();
if let Err(e) = self
.update_execution_final_context(execution.id, &all_tasks)
.await
{
warn!(
"Failed to update final context for workflow execution {}: {}",
execution.id, e
);
}
if failed_count > 0 {
let reason = format!(
"{} task(s) failed, {} completed, {} skipped",
failed_count, completed_count, skipped_count
);
self.dal
.workflow_execution()
.mark_failed(execution.id, &reason)
.await?;
metrics::counter!(
"cloacina_workflows_total",
"status" => "failed",
"reason" => "dependency_failed",
)
.increment(1);
info!(
"Workflow execution failed: {} (name: {}, {})",
execution.id, execution.workflow_name, reason
);
} else {
self.dal
.workflow_execution()
.mark_completed(execution.id)
.await?;
metrics::counter!(
"cloacina_workflows_total",
"status" => "completed",
"reason" => "ok",
)
.increment(1);
info!(
"Workflow execution completed: {} (name: {}, {} completed, {} skipped)",
execution.id, execution.workflow_name, completed_count, skipped_count
);
}
let duration = chrono::Utc::now() - execution.started_at.0;
if let Ok(secs) = duration.to_std() {
metrics::histogram!("cloacina_workflow_duration_seconds").record(secs.as_secs_f64());
}
Ok(())
}
async fn update_execution_final_context(
&self,
workflow_execution_id: UniversalUuid,
all_tasks: &[TaskExecution],
) -> Result<(), ValidationError> {
let mut final_context_id: Option<UniversalUuid> = None;
let mut latest_completion_time: Option<chrono::DateTime<chrono::Utc>> = None;
for task in all_tasks {
if task.status == "Completed" || task.status == "Skipped" {
if let Some(completed_at) = task.completed_at {
let task_namespace = crate::task::TaskNamespace::from_string(&task.task_name)
.map_err(|_| {
crate::error::ValidationError::InvalidTaskName(task.task_name.clone())
})?;
if let Ok(task_metadata) = self
.dal
.task_execution_metadata()
.get_by_workflow_and_task(workflow_execution_id, &task_namespace)
.await
{
if let Some(context_id) = task_metadata.context_id {
if latest_completion_time.is_none()
|| completed_at.0 > latest_completion_time.unwrap()
{
final_context_id = Some(context_id);
latest_completion_time = Some(completed_at.0);
}
}
}
}
}
}
if let Some(context_id) = final_context_id {
debug!(
"Updating workflow execution {} final context to context_id: {}",
workflow_execution_id, context_id
);
self.dal
.workflow_execution()
.update_final_context(workflow_execution_id, context_id)
.await?;
} else {
debug!(
"No final context found for workflow execution {} - keeping initial context",
workflow_execution_id
);
}
Ok(())
}
}