use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Arc;
use chrono::Utc;
use cron::Schedule as CronSchedule;
use tokio::sync::{Notify, RwLock, mpsc, watch};
use tokio::task::JoinHandle;
use tokio::time::{Duration, sleep};
use crate::error::{CanoError, CanoResult};
use crate::workflow::Workflow;
#[cfg(feature = "tracing")]
use tracing::Instrument;
use super::{BackoffPolicy, FlowData, FlowInfo, SchedulerCommand, Status};
/// Runs `workflow` repeatedly on a fixed `interval` until the scheduler stops.
///
/// The flow is dispatched once immediately (if it is dispatchable), then after each wait. Each
/// wait lasts at least `interval`, and longer while a failure-backoff window is still open
/// (see `wait_until_eligible`). Dispatch is skipped while the flow is already running or tripped.
///
/// The loop exits when `running` is cleared or a stop notification arrives on `stop_notify`.
/// A run that is already in progress is awaited to completion; shutdown is only observed
/// between runs.
pub(super) async fn spawn_every_loop<TState, TResourceKey>(
    workflow: Arc<Workflow<TState, TResourceKey>>,
    initial_state: TState,
    info: Arc<RwLock<FlowInfo>>,
    policy: Arc<BackoffPolicy>,
    running: Arc<RwLock<bool>>,
    stop_notify: Arc<Notify>,
    interval: Duration,
) where
    TState: Clone + Send + Sync + 'static + std::fmt::Debug + std::hash::Hash + Eq,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    if !*running.read().await {
        return;
    }
    if dispatchable_now(&info).await {
        execute_flow(
            Arc::clone(&workflow),
            initial_state.clone(),
            Arc::clone(&info),
            &policy,
        )
        .await;
    }
    loop {
        // Register for the stop notification *before* reading `running`. `notify_waiters`
        // only wakes `Notified` futures that already exist (polled or not). If the future
        // were created inside `select!`, a stop issued between the flag check and the
        // `select!` would be lost, and this task would sleep a whole `interval` before noticing.
        let mut stopped = std::pin::pin!(stop_notify.notified());
        if !*running.read().await {
            break;
        }
        let wait = wait_until_eligible(&info, interval).await;
        tokio::select! {
            _ = sleep(wait) => {}
            _ = stopped.as_mut() => break,
        }
        // Re-check after waking: the scheduler may have stopped while we slept.
        if !*running.read().await {
            break;
        }
        if !dispatchable_now(&info).await {
            continue;
        }
        execute_flow(
            Arc::clone(&workflow),
            initial_state.clone(),
            Arc::clone(&info),
            &policy,
        )
        .await;
    }
}
/// Runs `workflow` on every fire time of the cron `schedule` until the scheduler stops.
///
/// Each iteration sleeps until the next fire time, then dispatches the flow unless:
/// - it is still inside a failure-backoff window (`next_eligible` in the future), or
/// - it is already running or tripped.
///
/// Suppressed ticks are skipped, not deferred. The loop exits when any of these happens:
/// - `running` is cleared;
/// - a stop notification arrives on `stop_notify`;
/// - the schedule yields no further fire times.
pub(super) async fn spawn_cron_loop<TState, TResourceKey>(
    workflow: Arc<Workflow<TState, TResourceKey>>,
    initial_state: TState,
    info: Arc<RwLock<FlowInfo>>,
    policy: Arc<BackoffPolicy>,
    running: Arc<RwLock<bool>>,
    stop_notify: Arc<Notify>,
    schedule: Box<CronSchedule>,
) where
    TState: Clone + Send + Sync + 'static + std::fmt::Debug + std::hash::Hash + Eq,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    loop {
        // Register for the stop notification *before* reading `running`. `notify_waiters`
        // only wakes `Notified` futures that already exist. Creating the future inside
        // `select!` would let a stop issued in between be missed, leaving this task asleep
        // until the next cron tick, which may be hours away.
        let mut stopped = std::pin::pin!(stop_notify.notified());
        if !*running.read().await {
            break;
        }
        let now = Utc::now();
        let Some(next) = schedule.after(&now).next() else {
            break;
        };
        // `to_std` fails on a negative span; treat that as "fire immediately".
        let wait_duration = (next - now).to_std().unwrap_or(Duration::from_secs(0));
        tokio::select! {
            _ = sleep(wait_duration) => {}
            _ = stopped.as_mut() => break,
        }
        if !*running.read().await {
            break;
        }
        // Drop ticks that land inside the backoff window set by a previous failure.
        let info_snapshot = info.read().await;
        if let Some(eligible) = info_snapshot.next_eligible
            && Utc::now() < eligible
        {
            #[cfg(feature = "tracing")]
            tracing::debug!(
                flow_id = %info_snapshot.id,
                next_eligible = %eligible,
                "cron tick suppressed by backoff window"
            );
            drop(info_snapshot);
            continue;
        }
        drop(info_snapshot);
        if !dispatchable_now(&info).await {
            continue;
        }
        execute_flow(
            Arc::clone(&workflow),
            initial_state.clone(),
            Arc::clone(&info),
            &policy,
        )
        .await;
    }
}
/// Command loop and shutdown sequence for the scheduler.
///
/// Processes `SchedulerCommand`s until a `Stop` arrives or every sender is dropped:
/// - `Trigger` reserves the flow and spawns its run, tracking the handle in `scheduler_tasks`.
///   It replies with an error if the flow is unknown, already running, or tripped.
/// - `Reset` clears the failure streak and backoff window. It also sets the status to `Idle`
///   unless the flow is currently running.
///
/// Shutdown then proceeds in this order:
/// 1. Clear `running` and wake waiters on `stop_notify`.
/// 2. Await every tracked task.
/// 3. Wait up to 30 seconds for all flows to leave `Running`.
/// 4. Tear down each flow's resources in the reverse of `flow_order`.
/// 5. Publish the overall result on `result_tx`.
///
/// The result is an error only if the 30-second wait timed out.
pub(super) async fn driver_task<TState, TResourceKey>(
    mut rx: mpsc::Receiver<SchedulerCommand>,
    workflows: HashMap<String, FlowData<TState, TResourceKey>>,
    flow_order: Vec<String>,
    running: Arc<RwLock<bool>>,
    stop_notify: Arc<Notify>,
    scheduler_tasks: Arc<RwLock<Vec<JoinHandle<()>>>>,
    result_tx: watch::Sender<Option<CanoResult<()>>>,
) where
    TState: Clone + Send + Sync + 'static + std::fmt::Debug + std::hash::Hash + Eq,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    while let Some(cmd) = rx.recv().await {
        match cmd {
            SchedulerCommand::Stop => {
                rx.close();
                break;
            }
            SchedulerCommand::Trigger { id, response } => {
                let outcome = if let Some(flow) = workflows.get(&id) {
                    match reserve_flow(Arc::clone(&flow.info)).await {
                        ReserveOutcome::Reserved => {
                            let workflow = Arc::clone(&flow.workflow);
                            let initial_state = flow.initial_state.clone();
                            let info = Arc::clone(&flow.info);
                            let policy = Arc::clone(&flow.policy);
                            let handle = tokio::spawn(async move {
                                execute_reserved_flow(workflow, initial_state, info, &policy).await;
                            });
                            let mut tasks = scheduler_tasks.write().await;
                            // Prune finished runs so the handle list does not grow without bound.
                            tasks.retain(|h| !h.is_finished());
                            tasks.push(handle);
                            Ok(())
                        }
                        ReserveOutcome::AlreadyRunning => Err(CanoError::Workflow(format!(
                            "Flow '{id}' is already running"
                        ))),
                        ReserveOutcome::Tripped => Err(CanoError::Workflow(format!(
                            "Flow '{id}' is tripped — call reset_flow before triggering"
                        ))),
                    }
                } else {
                    Err(CanoError::Workflow(format!(
                        "No workflow registered with id '{id}'"
                    )))
                };
                let _ = response.send(outcome);
            }
            SchedulerCommand::Reset { id, response } => {
                let outcome = if let Some(flow) = workflows.get(&id) {
                    let mut info_guard = flow.info.write().await;
                    info_guard.failure_streak = 0;
                    info_guard.next_eligible = None;
                    // Never overwrite the status of an in-flight run; its own completion
                    // will set the final status.
                    if !matches!(info_guard.status, Status::Running) {
                        info_guard.status = Status::Idle;
                    }
                    Ok(())
                } else {
                    Err(CanoError::Workflow(format!(
                        "No workflow registered with id '{id}'"
                    )))
                };
                let _ = response.send(outcome);
            }
        }
    }
    // Commands still buffered behind `Stop` would otherwise live until this function returns,
    // keeping their callers blocked for the whole shutdown. Dropping the receiver drops them,
    // and with them their response senders, so those callers fail immediately.
    drop(rx);

    *running.write().await = false;
    stop_notify.notify_waiters();

    // Move the handles out before awaiting them so the `scheduler_tasks` lock is not held
    // across potentially long awaits. Repeat until the list stays empty, so that any handle
    // pushed while we were waiting is also awaited.
    loop {
        let handles = std::mem::take(&mut *scheduler_tasks.write().await);
        if handles.is_empty() {
            break;
        }
        for handle in handles.into_iter().rev() {
            // A task that panicked or was cancelled must not abort shutdown; its
            // `JoinError` is intentionally ignored.
            let _ = handle.await;
        }
    }

    let timeout = Duration::from_secs(30);
    let start_time = tokio::time::Instant::now();
    let result: CanoResult<()> = loop {
        let mut any_running = false;
        for fd in workflows.values() {
            if fd.info.read().await.status == Status::Running {
                any_running = true;
                break;
            }
        }
        if !any_running {
            break Ok(());
        }
        if start_time.elapsed() >= timeout {
            break Err(CanoError::Workflow(
                "Timeout waiting for workflows to complete".to_string(),
            ));
        }
        sleep(Duration::from_millis(100)).await;
    };

    // Tear down resources in reverse order, even if the wait above timed out.
    for id in flow_order.iter().rev() {
        if let Some(flow) = workflows.get(id) {
            let len = flow.workflow.resources.lifecycle_len();
            flow.workflow.resources.teardown_range(0..len).await;
        }
    }
    let _ = result_tx.send(Some(result));
}
/// Tries to reserve the flow and runs it to completion only if the reservation succeeds.
///
/// A flow that is already running or tripped is skipped silently. Callers that must report
/// why a run was refused should use `reserve_flow` and inspect its `ReserveOutcome` directly.
async fn execute_flow<TState, TResourceKey>(
    workflow: Arc<Workflow<TState, TResourceKey>>,
    initial_state: TState,
    info: Arc<RwLock<FlowInfo>>,
    policy: &BackoffPolicy,
) where
    TState: Clone + Send + Sync + 'static + std::fmt::Debug + std::hash::Hash + Eq,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    let reservation = reserve_flow(Arc::clone(&info)).await;
    if let ReserveOutcome::Reserved = reservation {
        execute_reserved_flow(workflow, initial_state, info, policy).await;
    }
}
/// Result of attempting to claim a flow's run slot with `reserve_flow`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReserveOutcome {
    /// The slot was claimed: the flow is now marked `Running` and the caller must execute it.
    Reserved,
    /// The flow was already marked `Running`; nothing was changed.
    AlreadyRunning,
    /// The flow is tripped and must be reset before it can run again; nothing was changed.
    Tripped,
}
/// Atomically claims the run slot of a flow under its write lock.
///
/// Returns `AlreadyRunning` or `Tripped` without modifying anything when the flow cannot start.
/// Otherwise it marks the flow `Running`, stamps `last_run` with the current time, increments
/// `run_count`, and returns `Reserved`.
async fn reserve_flow(info: Arc<RwLock<FlowInfo>>) -> ReserveOutcome {
    let mut guard = info.write().await;
    if matches!(guard.status, Status::Running) {
        return ReserveOutcome::AlreadyRunning;
    }
    if matches!(guard.status, Status::Tripped { .. }) {
        return ReserveOutcome::Tripped;
    }
    guard.status = Status::Running;
    guard.last_run = Some(Utc::now());
    guard.run_count += 1;
    ReserveOutcome::Reserved
}
/// Runs a flow whose slot has already been claimed by `reserve_flow`, then records the outcome.
///
/// The workflow's success value is discarded. Only success or failure (with the error) is
/// passed to `apply_outcome`, which updates the status, failure streak and backoff state on
/// `info` according to `policy`.
///
/// With the `metrics` feature, an active-flow guard is held for the whole call, and the run's
/// success flag and elapsed time are reported via `scheduler_flow_run`. With the `tracing`
/// feature, execution is wrapped in an `execute_flow` info span.
///
/// NOTE(review): `apply_outcome` is the only place here that moves the status away from
/// `Running`. If `execute_workflow` panics, that call is never reached, so the flow appears to
/// stay `Running`. Later reservations would then report `AlreadyRunning`, and shutdown would
/// wait for its timeout. Confirm whether panics are caught upstream.
async fn execute_reserved_flow<TState, TResourceKey>(
    workflow: Arc<Workflow<TState, TResourceKey>>,
    initial_state: TState,
    info: Arc<RwLock<FlowInfo>>,
    policy: &BackoffPolicy,
) where
    TState: Clone + Send + Sync + 'static + std::fmt::Debug + std::hash::Hash + Eq,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    // Bound to a named `_active` (not `_`) so the guard lives until the function returns.
    #[cfg(feature = "metrics")]
    let _active = crate::metrics::SchedulerFlowActiveGuard::new();
    // Capture the flow id before running, for labelling the metric afterwards.
    #[cfg(feature = "metrics")]
    let _flow_id = info.read().await.id.clone();
    #[cfg(feature = "metrics")]
    let _started = std::time::Instant::now();
    #[cfg(feature = "tracing")]
    let result = workflow
        .execute_workflow(initial_state)
        .instrument(tracing::info_span!("execute_flow"))
        .await;
    #[cfg(not(feature = "tracing"))]
    let result = workflow.execute_workflow(initial_state).await;
    // Record the metric before taking the `info` write lock in `apply_outcome`, so the
    // measured duration covers only the workflow execution.
    #[cfg(feature = "metrics")]
    crate::metrics::scheduler_flow_run(&_flow_id, result.is_ok(), _started.elapsed());
    apply_outcome(&info, result.map(|_| ()), policy).await;
}
async fn apply_outcome(
info: &Arc<RwLock<FlowInfo>>,
result: Result<(), CanoError>,
policy: &BackoffPolicy,
) {
let mut info_guard = info.write().await;
match result {
Ok(_) => {
info_guard.status = Status::Completed;
info_guard.failure_streak = 0;
info_guard.next_eligible = None;
}
Err(e) => {
let err_str = e.to_string();
let new_streak = info_guard.failure_streak.saturating_add(1);
info_guard.failure_streak = new_streak;
if policy.is_tripped(new_streak) {
info_guard.next_eligible = None;
info_guard.status = Status::Tripped {
streak: new_streak,
last_error: err_str,
};
#[cfg(feature = "metrics")]
crate::metrics::scheduler_flow_tripped(&info_guard.id);
} else {
let delay = policy.compute_delay(new_streak);
let until = Utc::now()
+ chrono::Duration::from_std(delay).unwrap_or(chrono::Duration::zero());
info_guard.next_eligible = Some(until);
info_guard.status = Status::Backoff {
until,
streak: new_streak,
last_error: err_str,
};
#[cfg(feature = "metrics")]
crate::metrics::scheduler_flow_backoff(&info_guard.id);
}
}
}
}
/// Reports whether the flow may be started right now.
///
/// Returns `false` while the flow is `Running` or `Tripped`, and `true` for any other status.
/// The backoff window (`next_eligible`) is *not* consulted here.
async fn dispatchable_now(info: &Arc<RwLock<FlowInfo>>) -> bool {
    match info.read().await.status {
        Status::Running | Status::Tripped { .. } => false,
        _ => true,
    }
}
/// Computes how long the interval loop should sleep before its next dispatch attempt.
///
/// This is normally just `interval`. While `next_eligible` still lies in the future (an open
/// backoff window), it is the larger of `interval` and the time remaining until that instant.
async fn wait_until_eligible(info: &Arc<RwLock<FlowInfo>>, interval: Duration) -> Duration {
    let remaining_backoff = {
        let snapshot = info.read().await;
        let now = Utc::now();
        snapshot
            .next_eligible
            .filter(|eligible| *eligible > now)
            .map(|eligible| (eligible - now).to_std().unwrap_or(Duration::from_secs(0)))
    };
    match remaining_backoff {
        Some(extra) => interval.max(extra),
        None => interval,
    }
}