use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::time;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
use crate::dal::DAL;
use crate::database::universal_types::UniversalUuid;
use crate::dispatcher::{Dispatcher, TaskReadyEvent};
use crate::error::ValidationError;
use crate::models::pipeline_execution::PipelineExecution;
use crate::models::task_execution::TaskExecution;
use super::state_manager::StateManager;
/// Polling scheduler that advances active pipeline executions and, when a
/// dispatcher is configured, hands ready tasks to it.
///
/// The loop borrows the [`DAL`] for `'a`, so it cannot outlive the data
/// access layer it was built from.
pub struct SchedulerLoop<'a> {
/// Data access layer used for every pipeline/task query and update.
dal: &'a DAL,
/// Identifier of this scheduler instance; in this file it is only reported
/// in the startup log line.
instance_id: Uuid,
/// Period of the polling timer created in `run`.
poll_interval: Duration,
/// Optional dispatcher; when `None`, the dispatch step is skipped.
dispatcher: Option<Arc<dyn Dispatcher>>,
}
impl<'a> SchedulerLoop<'a> {
#[allow(dead_code)]
pub fn new(dal: &'a DAL, instance_id: Uuid, poll_interval: Duration) -> Self {
Self {
dal,
instance_id,
poll_interval,
dispatcher: None,
}
}
pub fn with_dispatcher(
dal: &'a DAL,
instance_id: Uuid,
poll_interval: Duration,
dispatcher: Option<Arc<dyn Dispatcher>>,
) -> Self {
Self {
dal,
instance_id,
poll_interval,
dispatcher,
}
}
/// Runs the scheduling loop, polling once per `poll_interval`.
///
/// Every tick calls [`Self::process_active_pipelines`]. A failed cycle is
/// logged and the loop keeps going, so this future never resolves on its own;
/// it ends only when it is dropped or cancelled by the caller.
///
/// # Panics
///
/// `tokio::time::interval` panics if `poll_interval` is zero.
pub async fn run(&self) -> Result<(), ValidationError> {
    info!(
        "Starting task scheduler loop (instance: {}, poll_interval: {:?})",
        self.instance_id, self.poll_interval
    );
    let mut interval = time::interval(self.poll_interval);
    // The default `Burst` behavior replays every missed tick immediately, so a
    // single slow cycle would be followed by back-to-back polls. `Delay`
    // restarts the period from the late tick instead.
    interval.set_missed_tick_behavior(time::MissedTickBehavior::Delay);
    loop {
        interval.tick().await;
        match self.process_active_pipelines().await {
            Ok(_) => debug!("Scheduling loop completed successfully"),
            Err(e) => error!("Scheduling loop error: {}", e),
        }
    }
}
/// Executes one scheduling cycle.
///
/// Loads the active pipeline executions, processes them as a batch when there
/// are any, and then dispatches ready tasks if a dispatcher is configured.
/// The dispatch step runs even when no pipeline is active.
///
/// # Errors
///
/// Returns the first error from loading executions, batch processing, or
/// dispatching. A batch-processing error skips the dispatch step.
pub async fn process_active_pipelines(&self) -> Result<(), ValidationError> {
    let executions = self
        .dal
        .pipeline_execution()
        .get_active_executions()
        .await?;

    if !executions.is_empty() {
        self.process_pipelines_batch(executions).await?;
    }

    if self.dispatcher.is_some() {
        self.dispatch_ready_tasks().await?;
    }
    Ok(())
}
/// Processes a batch of active pipeline executions.
///
/// Pending tasks for all pipelines are fetched in one query and grouped by
/// pipeline. For each pipeline the readiness of its pending tasks is updated,
/// then the pipeline is finalized if it has completed.
///
/// Failures are isolated per pipeline: a readiness failure is logged and that
/// pipeline is skipped for this cycle; a completion-check or finalization
/// failure is logged and the remaining pipelines are still processed.
///
/// # Errors
///
/// Returns an error if the batched pending-task query fails, or — after all
/// pipelines have been attempted — the first completion/finalization error
/// encountered.
async fn process_pipelines_batch(
    &self,
    active_executions: Vec<PipelineExecution>,
) -> Result<(), ValidationError> {
    let pipeline_ids: Vec<UniversalUuid> = active_executions.iter().map(|e| e.id).collect();
    let all_pending_tasks = self
        .dal
        .task_execution()
        .get_pending_tasks_batch(pipeline_ids)
        .await?;

    // Group the pending tasks by owning pipeline so each pipeline is handled
    // with a single lookup.
    let mut tasks_by_pipeline: HashMap<UniversalUuid, Vec<TaskExecution>> = HashMap::new();
    for task in all_pending_tasks {
        tasks_by_pipeline
            .entry(task.pipeline_execution_id)
            .or_default()
            .push(task);
    }

    let state_manager = StateManager::new(self.dal);
    // Remember the first failure so the caller still sees an error, without
    // letting one broken pipeline block the others.
    let mut first_error: Option<ValidationError> = None;

    for execution in &active_executions {
        if let Some(pipeline_tasks) = tasks_by_pipeline.get(&execution.id) {
            if let Err(e) = state_manager
                .update_pipeline_task_readiness(execution.id, pipeline_tasks)
                .await
            {
                error!(
                    "Failed to update task readiness for pipeline {}: {}",
                    execution.id, e
                );
                continue;
            }
        }

        if let Err(e) = self.finalize_if_complete(execution).await {
            error!(
                "Failed to check or finalize completion for pipeline {}: {}",
                execution.id, e
            );
            if first_error.is_none() {
                first_error = Some(e);
            }
        }
    }

    first_error.map_or(Ok(()), Err)
}

/// Finalizes `execution` when the completion check reports it as complete.
async fn finalize_if_complete(
    &self,
    execution: &PipelineExecution,
) -> Result<(), ValidationError> {
    if self
        .dal
        .task_execution()
        .check_pipeline_completion(execution.id)
        .await?
    {
        self.complete_pipeline(execution).await?;
    }
    Ok(())
}
/// Sends every task returned by `get_ready_for_retry` to the dispatcher.
///
/// Does nothing when no dispatcher is configured. A dispatch failure for one
/// task is logged as a warning and does not stop the remaining tasks.
///
/// # Errors
///
/// Returns an error only if loading the ready tasks fails.
async fn dispatch_ready_tasks(&self) -> Result<(), ValidationError> {
    if let Some(dispatcher) = self.dispatcher.as_ref() {
        let pending_dispatch = self.dal.task_execution().get_ready_for_retry().await?;
        for task in pending_dispatch {
            let event = TaskReadyEvent::new(
                task.id,
                task.pipeline_execution_id,
                task.task_name.clone(),
                task.attempt,
            );
            let outcome = dispatcher.dispatch(event).await;
            if let Err(e) = outcome {
                warn!(
                    task_id = %task.id,
                    task_name = %task.task_name,
                    error = %e,
                    "Failed to dispatch ready task"
                );
            }
        }
    }
    Ok(())
}
/// Marks a pipeline execution as completed and logs a per-status summary.
///
/// Updating the final context is best-effort: a failure there is logged as a
/// warning and the pipeline is still marked completed.
///
/// # Errors
///
/// Returns an error if loading the pipeline's tasks or marking the pipeline
/// completed fails.
async fn complete_pipeline(
    &self,
    execution: &PipelineExecution,
) -> Result<(), ValidationError> {
    let tasks = self
        .dal
        .task_execution()
        .get_all_tasks_for_pipeline(execution.id)
        .await?;

    // Tally the terminal statuses in a single pass; a task has exactly one
    // status, so the branches are mutually exclusive.
    let (mut completed, mut failed, mut skipped) = (0usize, 0usize, 0usize);
    for task in tasks.iter() {
        if task.status == "Completed" {
            completed += 1;
        } else if task.status == "Failed" {
            failed += 1;
        } else if task.status == "Skipped" {
            skipped += 1;
        }
    }

    let context_update = self
        .update_pipeline_final_context(execution.id, &tasks)
        .await;
    if let Err(e) = context_update {
        warn!(
            "Failed to update final context for pipeline {}: {}",
            execution.id, e
        );
    }

    self.dal
        .pipeline_execution()
        .mark_completed(execution.id)
        .await?;

    info!(
        "Pipeline completed: {} (name: {}, {} completed, {} failed, {} skipped)",
        execution.id, execution.pipeline_name, completed, failed, skipped
    );
    Ok(())
}
/// Sets the pipeline's final context to the context of its most recently
/// finished task.
///
/// Only tasks whose status is `"Completed"` or `"Skipped"` and that have a
/// completion timestamp are considered. Among those, the task with the latest
/// completion time whose metadata carries a context id wins; on equal
/// timestamps the earlier task in `all_tasks` is kept. If no task qualifies,
/// the pipeline's context is left untouched.
///
/// # Errors
///
/// Returns `ValidationError::InvalidTaskName` if a considered task's name
/// cannot be parsed into a `TaskNamespace`, or an error if persisting the
/// final context fails. Metadata lookup failures are ignored and the task is
/// simply not considered.
async fn update_pipeline_final_context(
    &self,
    pipeline_execution_id: UniversalUuid,
    all_tasks: &[TaskExecution],
) -> Result<(), ValidationError> {
    // Completion time and context id of the best candidate seen so far.
    let mut latest: Option<(chrono::DateTime<chrono::Utc>, UniversalUuid)> = None;

    for task in all_tasks {
        if task.status != "Completed" && task.status != "Skipped" {
            continue;
        }
        let completed_at = match task.completed_at {
            Some(timestamp) => timestamp,
            None => continue,
        };

        // Parse the name before any further filtering so an invalid task name
        // is always reported.
        let task_namespace = crate::task::TaskNamespace::from_string(&task.task_name)
            .map_err(|_| {
                crate::error::ValidationError::InvalidTaskName(task.task_name.clone())
            })?;

        // A task that is not strictly newer than the current candidate can
        // never replace it, so skip the metadata query for it.
        let is_newer = latest
            .as_ref()
            .map_or(true, |(latest_time, _)| completed_at.0 > *latest_time);
        if !is_newer {
            continue;
        }

        if let Ok(task_metadata) = self
            .dal
            .task_execution_metadata()
            .get_by_pipeline_and_task(pipeline_execution_id, &task_namespace)
            .await
        {
            if let Some(context_id) = task_metadata.context_id {
                latest = Some((completed_at.0, context_id));
            }
        }
    }

    if let Some((_, context_id)) = latest {
        debug!(
            "Updating pipeline {} final context to context_id: {}",
            pipeline_execution_id, context_id
        );
        self.dal
            .pipeline_execution()
            .update_final_context(pipeline_execution_id, context_id)
            .await?;
    } else {
        debug!(
            "No final context found for pipeline {} - keeping initial context",
            pipeline_execution_id
        );
    }
    Ok(())
}
}