mod context_manager;
mod scheduler_loop;
pub mod stale_claim_sweeper;
mod state_manager;
mod trigger_rules;
pub use trigger_rules::{TriggerCondition, TriggerRule, ValueOperator};
use std::sync::Arc;
use std::time::Duration;
use diesel::prelude::*;
use diesel::Connection;
use tracing::info;
use uuid::Uuid;
use crate::dal::unified::models::{NewUnifiedTaskExecution, NewUnifiedWorkflowExecution};
use crate::dal::DAL;
use crate::database::schema::unified::{task_executions, workflow_executions};
use crate::database::universal_types::{UniversalTimestamp, UniversalUuid};
use crate::dispatcher::Dispatcher;
use crate::error::ValidationError;
use crate::task::TaskNamespace;
use crate::Runtime;
use crate::{Context, Database, Workflow};
use scheduler_loop::SchedulerLoop;
/// Schedules workflow executions into the database and drives the scheduling loop.
///
/// Construct with [`TaskScheduler::new`] or [`TaskScheduler::with_poll_interval`], then
/// configure with the `with_*` builder methods (`with_runtime`, `with_shutdown`,
/// `with_dispatcher`).
pub struct TaskScheduler {
// Data-access layer built from the `Database` passed at construction.
dal: DAL,
// Runtime used to look up workflow definitions by name when scheduling.
runtime: Arc<Runtime>,
// Random identifier generated per scheduler instance (`Uuid::new_v4()` at construction)
// and handed to the scheduler loop.
instance_id: Uuid,
// Polling interval forwarded to the scheduler loop (100 ms when built via `new`).
poll_interval: Duration,
// Optional dispatcher forwarded to the scheduler loop; `None` unless `with_dispatcher` is used.
dispatcher: Option<Arc<dyn Dispatcher>>,
// Optional shutdown signal; when set, a clone is given to the long-running scheduling loop.
shutdown_rx: Option<tokio::sync::watch::Receiver<bool>>,
}
impl TaskScheduler {
    /// Polling interval used by [`TaskScheduler::new`].
    const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(100);

    /// Retry budget recorded for a task whose definition cannot be looked up in its workflow.
    const DEFAULT_MAX_ATTEMPTS: i32 = 3;

    /// Creates a scheduler for `database` that polls every 100 ms.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`TaskScheduler::with_poll_interval`] (currently none).
    pub async fn new(database: Database) -> Result<Self, ValidationError> {
        Self::with_poll_interval(database, Self::DEFAULT_POLL_INTERVAL).await
    }

    /// Creates a scheduler for `database` with a custom polling interval.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept so the constructor can become fallible
    /// without breaking callers.
    pub async fn with_poll_interval(
        database: Database,
        poll_interval: Duration,
    ) -> Result<Self, ValidationError> {
        Ok(Self::with_poll_interval_sync(database, poll_interval))
    }

    /// Synchronous constructor shared by the async constructors.
    ///
    /// Uses a fresh default [`Runtime`], a random instance id, and no dispatcher or
    /// shutdown signal.
    pub(crate) fn with_poll_interval_sync(database: Database, poll_interval: Duration) -> Self {
        Self {
            // `database` is owned and not needed afterwards, so move it instead of cloning.
            dal: DAL::new(database),
            runtime: Arc::new(Runtime::new()),
            instance_id: Uuid::new_v4(),
            poll_interval,
            dispatcher: None,
            shutdown_rx: None,
        }
    }

    /// Replaces the runtime used to resolve workflows by name.
    pub fn with_runtime(mut self, runtime: Arc<Runtime>) -> Self {
        self.runtime = runtime;
        self
    }

    /// Returns the runtime used to resolve workflows by name.
    pub fn runtime(&self) -> &Arc<Runtime> {
        &self.runtime
    }

    /// Attaches a shutdown signal that is forwarded to [`TaskScheduler::run_scheduling_loop`].
    pub fn with_shutdown(mut self, shutdown_rx: tokio::sync::watch::Receiver<bool>) -> Self {
        self.shutdown_rx = Some(shutdown_rx);
        self
    }

    /// Attaches a dispatcher that is forwarded to the scheduler loop.
    pub fn with_dispatcher(mut self, dispatcher: Arc<dyn Dispatcher>) -> Self {
        self.dispatcher = Some(dispatcher);
        self
    }

    /// Returns the configured dispatcher, if any.
    pub fn dispatcher(&self) -> Option<&Arc<dyn Dispatcher>> {
        self.dispatcher.as_ref()
    }

    /// Schedules a new execution of the workflow registered as `workflow_name`.
    ///
    /// The input context is stored, then a `Pending` workflow-execution row and one
    /// `NotStarted` task-execution row per task (in topological order) are inserted in a
    /// single transaction on the active backend.
    ///
    /// # Errors
    ///
    /// - [`ValidationError::WorkflowNotFound`] if the runtime has no such workflow.
    /// - Any error from the version lookup, `topological_sort`, context storage,
    ///   connection acquisition, or the insert transaction.
    pub async fn schedule_workflow_execution(
        &self,
        workflow_name: &str,
        input_context: Context<serde_json::Value>,
    ) -> Result<Uuid, ValidationError> {
        info!("Scheduling workflow execution: {}", workflow_name);

        let workflow = self
            .runtime
            .get_workflow(workflow_name)
            .ok_or_else(|| ValidationError::WorkflowNotFound(workflow_name.to_string()))?;

        let current_version = workflow.metadata().version.clone();
        let last_version = self
            .dal
            .workflow_execution()
            .get_last_version(workflow_name)
            .await?;
        if last_version.as_ref() != Some(&current_version) {
            info!(
                "Workflow '{}' version changed: {} -> {}",
                workflow_name,
                last_version.as_deref().unwrap_or("none"),
                current_version
            );
        }

        // Resolve the whole execution plan before writing anything, so a workflow that
        // `topological_sort` rejects does not leave an orphaned context row behind.
        let task_ids = workflow.topological_sort()?;
        let task_data: Vec<(String, String, String, i32)> = task_ids
            .iter()
            .map(|task_id| {
                let max_attempts = workflow
                    .get_task(task_id)
                    .map(|task| task.retry_policy().max_attempts)
                    .unwrap_or(Self::DEFAULT_MAX_ATTEMPTS);
                (
                    task_id.to_string(),
                    self.get_task_trigger_rules(&workflow, task_id).to_string(),
                    self.get_task_configuration(&workflow, task_id).to_string(),
                    max_attempts,
                )
            })
            .collect();

        // NOTE: the context is written through the DAL outside the insert transaction
        // below, so it is not rolled back if that transaction fails.
        let stored_context = self.dal.context().create(&input_context).await?;

        let workflow_execution_id = UniversalUuid::new_v4();
        let now = UniversalTimestamp::now();
        let wf_name = workflow_name.to_string();
        let wf_version = current_version;
        crate::dispatch_backend!(
            self.dal.backend(),
            self.create_workflow_execution_postgres(
                workflow_execution_id,
                now,
                wf_name,
                wf_version,
                stored_context,
                task_data,
            )
            .await?,
            self.create_workflow_execution_sqlite(
                workflow_execution_id,
                now,
                wf_name,
                wf_version,
                stored_context,
                task_data,
            )
            .await?
        );

        info!("Workflow execution scheduled: {}", workflow_execution_id);
        Ok(workflow_execution_id.into())
    }

    /// Builds the rows inserted by the backend-specific creators.
    ///
    /// Returns the `Pending` workflow-execution row and one `NotStarted` task-execution row
    /// (first attempt) per `(task_name, trigger_rules, task_configuration, max_attempts)`
    /// entry, all stamped with `now`. Shared so both backends insert identical data.
    #[cfg(any(feature = "postgres", feature = "sqlite"))]
    fn build_execution_rows(
        workflow_execution_id: UniversalUuid,
        now: UniversalTimestamp,
        workflow_name: String,
        workflow_version: String,
        stored_context: Option<UniversalUuid>,
        task_data: Vec<(String, String, String, i32)>,
    ) -> (NewUnifiedWorkflowExecution, Vec<NewUnifiedTaskExecution>) {
        let workflow_row = NewUnifiedWorkflowExecution {
            id: workflow_execution_id,
            workflow_name,
            workflow_version,
            status: "Pending".to_string(),
            context_id: stored_context,
            started_at: now,
            created_at: now,
            updated_at: now,
        };
        let task_rows = task_data
            .into_iter()
            .map(
                |(task_name, trigger_rules, task_config, max_attempts)| NewUnifiedTaskExecution {
                    id: UniversalUuid::new_v4(),
                    workflow_execution_id,
                    task_name,
                    status: "NotStarted".to_string(),
                    attempt: 1,
                    max_attempts,
                    trigger_rules,
                    task_configuration: task_config,
                    created_at: now,
                    updated_at: now,
                },
            )
            .collect();
        (workflow_row, task_rows)
    }

    /// Inserts the workflow execution and its task executions in one PostgreSQL transaction.
    ///
    /// # Errors
    ///
    /// [`ValidationError::ConnectionPool`] if a connection cannot be obtained or the pooled
    /// interaction fails; diesel errors from the inserts are converted via `?`.
    #[cfg(feature = "postgres")]
    async fn create_workflow_execution_postgres(
        &self,
        workflow_execution_id: UniversalUuid,
        now: UniversalTimestamp,
        workflow_name: String,
        workflow_version: String,
        stored_context: Option<UniversalUuid>,
        task_data: Vec<(String, String, String, i32)>,
    ) -> Result<(), ValidationError> {
        let (workflow_row, task_rows) = Self::build_execution_rows(
            workflow_execution_id,
            now,
            workflow_name,
            workflow_version,
            stored_context,
            task_data,
        );
        let conn = self
            .dal
            .database()
            .get_postgres_connection()
            .await
            .map_err(|e| ValidationError::ConnectionPool(e.to_string()))?;
        conn.interact(move |conn| {
            conn.transaction(|conn| {
                diesel::insert_into(workflow_executions::table)
                    .values(&workflow_row)
                    .execute(conn)?;
                for task_row in &task_rows {
                    diesel::insert_into(task_executions::table)
                        .values(task_row)
                        .execute(conn)?;
                }
                Ok::<_, diesel::result::Error>(())
            })
        })
        .await
        // Outer `?`: pool/interaction failure; inner `?`: error returned by the transaction.
        .map_err(|e| ValidationError::ConnectionPool(e.to_string()))??;
        Ok(())
    }

    /// Inserts the workflow execution and its task executions in one SQLite transaction.
    ///
    /// # Errors
    ///
    /// [`ValidationError::ConnectionPool`] if a connection cannot be obtained or the pooled
    /// interaction fails; diesel errors from the inserts are converted via `?`.
    #[cfg(feature = "sqlite")]
    async fn create_workflow_execution_sqlite(
        &self,
        workflow_execution_id: UniversalUuid,
        now: UniversalTimestamp,
        workflow_name: String,
        workflow_version: String,
        stored_context: Option<UniversalUuid>,
        task_data: Vec<(String, String, String, i32)>,
    ) -> Result<(), ValidationError> {
        let (workflow_row, task_rows) = Self::build_execution_rows(
            workflow_execution_id,
            now,
            workflow_name,
            workflow_version,
            stored_context,
            task_data,
        );
        let conn = self
            .dal
            .database()
            .get_sqlite_connection()
            .await
            .map_err(|e| ValidationError::ConnectionPool(e.to_string()))?;
        conn.interact(move |conn| {
            conn.transaction(|conn| {
                diesel::insert_into(workflow_executions::table)
                    .values(&workflow_row)
                    .execute(conn)?;
                for task_row in &task_rows {
                    diesel::insert_into(task_executions::table)
                        .values(task_row)
                        .execute(conn)?;
                }
                Ok::<_, diesel::result::Error>(())
            })
        })
        .await
        // Outer `?`: pool/interaction failure; inner `?`: error returned by the transaction.
        .map_err(|e| ValidationError::ConnectionPool(e.to_string()))??;
        Ok(())
    }

    /// Builds a [`SchedulerLoop`] from this scheduler's configuration and runs it until
    /// `SchedulerLoop::run` returns.
    ///
    /// If a shutdown receiver was attached via [`TaskScheduler::with_shutdown`], a clone of
    /// it is passed to the loop.
    pub async fn run_scheduling_loop(&self) -> Result<(), ValidationError> {
        let mut scheduler_loop = SchedulerLoop::with_dispatcher(
            &self.dal,
            self.runtime.clone(),
            self.instance_id,
            self.poll_interval,
            self.dispatcher.clone(),
        );
        if let Some(ref shutdown_rx) = self.shutdown_rx {
            scheduler_loop = scheduler_loop.with_shutdown(shutdown_rx.clone());
        }
        scheduler_loop.run().await
    }

    /// Runs a single `SchedulerLoop::process_active_executions` pass.
    ///
    /// Unlike [`TaskScheduler::run_scheduling_loop`], no shutdown receiver is attached.
    pub async fn process_active_executions(&self) -> Result<(), ValidationError> {
        let scheduler_loop = SchedulerLoop::with_dispatcher(
            &self.dal,
            self.runtime.clone(),
            self.instance_id,
            self.poll_interval,
            self.dispatcher.clone(),
        );
        scheduler_loop.process_active_executions().await
    }

    /// Returns the task's trigger rules as JSON, or `{"type": "Always"}` when the task
    /// cannot be looked up in `workflow`.
    fn get_task_trigger_rules(
        &self,
        workflow: &Workflow,
        task_namespace: &TaskNamespace,
    ) -> serde_json::Value {
        workflow
            .get_task(task_namespace)
            .map(|task| task.trigger_rules())
            .unwrap_or_else(|_| serde_json::json!({"type": "Always"}))
    }

    /// Returns the task configuration stored with each task execution.
    ///
    /// Always an empty JSON object for now; the parameters are kept so per-task
    /// configuration can be added without changing call sites.
    fn get_task_configuration(
        &self,
        _workflow: &Workflow,
        _task_namespace: &TaskNamespace,
    ) -> serde_json::Value {
        serde_json::json!({})
    }
}