use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Arc;
use chrono::Utc;
use cron::Schedule as CronSchedule;
use tokio::sync::{Notify, RwLock, mpsc, watch};
use tokio::task::{AbortHandle, JoinHandle};
use tokio::time::{Duration, sleep};
use crate::error::{CanoError, CanoResult};
use crate::workflow::Workflow;
#[cfg(feature = "tracing")]
use tracing::Instrument;
use super::{BackoffPolicy, FlowData, FlowInfo, SchedulerCommand, Status};
/// Upper bound on a single sleep inside [`sleep_unless_stopped`]; the shared
/// `running` flag is re-read at least this often while sleeping.
const SHUTDOWN_POLL_CHUNK: Duration = Duration::from_millis(250);

/// Sleeps for `wait`, giving up early if the scheduler is being stopped.
///
/// Returns `true` only when the whole `wait` elapsed *and* `running` was still
/// `true` afterwards. Returns `false` as soon as `running` is observed `false` or
/// `stop_notify` fires. A zero `wait` therefore just reports the current value of
/// `running`.
///
/// The wait is split into chunks of at most [`SHUTDOWN_POLL_CHUNK`] so that a stop
/// signalled only through `running` is noticed within about one chunk.
async fn sleep_unless_stopped(
    wait: Duration,
    running: &Arc<RwLock<bool>>,
    stop_notify: &Arc<Notify>,
) -> bool {
    let mut remaining = wait;
    loop {
        // Register for the stop signal *before* reading `running`. A `Notified`
        // future receives every `notify_waiters()` issued after it was created,
        // even before it is first polled. A stop that races with the check below
        // therefore still wakes the `select!` instead of being lost.
        let stopped = stop_notify.notified();
        if !*running.read().await {
            return false;
        }
        // Only report a full sleep if `running` is still set once it is over.
        // Callers start a new run on `true`, so this must not be a stale answer.
        if remaining.is_zero() {
            return true;
        }
        let chunk = remaining.min(SHUTDOWN_POLL_CHUNK);
        tokio::select! {
            _ = sleep(chunk) => {}
            _ = stopped => return false,
        }
        remaining = remaining.saturating_sub(chunk);
    }
}
/// Drives a fixed-interval flow until the scheduler stops.
///
/// The first attempt happens immediately. Every later attempt is preceded by a
/// wait of `interval`, stretched by `wait_until_eligible` to cover any pending
/// backoff window. An attempt is skipped when the flow is currently `Running` or
/// `Tripped` (per `dispatchable_now`). Each run is awaited inline, so this loop
/// never overlaps two of its own runs.
///
/// The loop exits when `running` is observed `false` or the inter-run sleep is
/// interrupted by `stop_notify`.
pub(super) async fn spawn_every_loop<TState, TResourceKey>(
    workflow: Arc<Workflow<TState, TResourceKey>>,
    initial_state: TState,
    info: Arc<RwLock<FlowInfo>>,
    policy: Arc<BackoffPolicy>,
    running: Arc<RwLock<bool>>,
    stop_notify: Arc<Notify>,
    interval: Duration,
) where
    TState: Clone + Send + Sync + 'static + std::fmt::Debug + std::hash::Hash + Eq,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    // The very first attempt runs without waiting; every later one waits first.
    let mut first_pass = true;
    loop {
        if !*running.read().await {
            break;
        }
        if !first_pass {
            let pause = wait_until_eligible(&info, interval).await;
            let slept_fully = sleep_unless_stopped(pause, &running, &stop_notify).await;
            if !slept_fully {
                break;
            }
        }
        first_pass = false;

        if dispatchable_now(&info).await {
            execute_flow(
                Arc::clone(&workflow),
                initial_state.clone(),
                Arc::clone(&info),
                &policy,
            )
            .await;
        }
    }
}
/// Drives a cron-scheduled flow until the scheduler stops.
///
/// For each upcoming occurrence of `schedule` the loop sleeps until that
/// wall-clock time and then attempts one run. A tick is skipped, not deferred,
/// when:
/// - it lands before the flow's `next_eligible` backoff deadline, or
/// - the flow is `Running` or `Tripped`.
///
/// The loop exits when `running` is observed `false`, a sleep is interrupted by
/// `stop_notify`, or the schedule has no further occurrences.
pub(super) async fn spawn_cron_loop<TState, TResourceKey>(
    workflow: Arc<Workflow<TState, TResourceKey>>,
    initial_state: TState,
    info: Arc<RwLock<FlowInfo>>,
    policy: Arc<BackoffPolicy>,
    running: Arc<RwLock<bool>>,
    stop_notify: Arc<Notify>,
    schedule: Box<CronSchedule>,
) where
    TState: Clone + Send + Sync + 'static + std::fmt::Debug + std::hash::Hash + Eq,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    while *running.read().await {
        let scheduled_from = Utc::now();
        let next = match schedule.after(&scheduled_from).next() {
            Some(fire_at) => fire_at,
            // Schedule exhausted: nothing left to fire.
            None => break,
        };

        let until_fire = (next - scheduled_from)
            .to_std()
            .unwrap_or(Duration::from_secs(0));
        if !sleep_unless_stopped(until_fire, &running, &stop_notify).await {
            break;
        }

        // The timer sleep is measured separately from the wall clock that `next`
        // is expressed in. If the wall clock still reads earlier than `next`,
        // top up once with the remaining shortfall.
        let woke_at = Utc::now();
        if woke_at < next {
            let shortfall = (next - woke_at).to_std().unwrap_or(Duration::from_secs(0));
            if !sleep_unless_stopped(shortfall, &running, &stop_notify).await {
                break;
            }
        }

        // Scope the read guard so it is released before `dispatchable_now` locks again.
        let suppressed = {
            let snapshot = info.read().await;
            match snapshot.next_eligible {
                Some(eligible) if Utc::now() < eligible => {
                    #[cfg(feature = "tracing")]
                    tracing::debug!(
                        flow_id = %snapshot.id,
                        next_eligible = %eligible,
                        "cron tick suppressed by backoff window"
                    );
                    true
                }
                _ => false,
            }
        };
        if suppressed || !dispatchable_now(&info).await {
            continue;
        }

        execute_flow(
            Arc::clone(&workflow),
            initial_state.clone(),
            Arc::clone(&info),
            &policy,
        )
        .await;
    }
}
/// Serves scheduler commands and, once they stop, performs the orderly shutdown.
///
/// Command phase: handles `Trigger` and `Reset` until a `Stop` command arrives or
/// every sender is dropped. After `Stop` the channel is closed. Commands still
/// buffered at that point are not processed; their responders are dropped along
/// with `rx`.
///
/// Shutdown phase, in order:
/// 1. clear `running` and call `notify_waiters` on `stop_notify`;
/// 2. join every handle in `scheduler_tasks`, one at a time. While joining, that
///    task's `AbortHandle` is stored in `in_flight_drain` (presumably so another
///    party can abort a stuck drain — confirm with the caller);
/// 3. poll for up to 30 s until no flow reports `Status::Running`; on timeout
///    the result is an error;
/// 4. tear down each flow's resources in reverse `flow_order`;
/// 5. publish the result on `result_tx`.
// The driver takes ownership of eight independent pieces of scheduler state.
#[allow(clippy::too_many_arguments)]
pub(super) async fn driver_task<TState, TResourceKey>(
    mut rx: mpsc::Receiver<SchedulerCommand>,
    workflows: HashMap<Arc<str>, FlowData<TState, TResourceKey>>,
    flow_order: Vec<Arc<str>>,
    running: Arc<RwLock<bool>>,
    stop_notify: Arc<Notify>,
    scheduler_tasks: Arc<RwLock<Vec<JoinHandle<()>>>>,
    in_flight_drain: Arc<RwLock<Option<AbortHandle>>>,
    result_tx: watch::Sender<Option<CanoResult<()>>>,
) where
    TState: Clone + Send + Sync + 'static + std::fmt::Debug + std::hash::Hash + Eq,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    while let Some(cmd) = rx.recv().await {
        match cmd {
            SchedulerCommand::Stop => {
                rx.close();
                break;
            }
            SchedulerCommand::Trigger { id, response } => {
                let outcome = trigger_registered_flow(&workflows, &id, &scheduler_tasks).await;
                let _ = response.send(outcome);
            }
            SchedulerCommand::Reset { id, response } => {
                let outcome = reset_registered_flow(&workflows, &id).await;
                let _ = response.send(outcome);
            }
        }
    }

    *running.write().await = false;
    stop_notify.notify_waiters();

    join_scheduler_tasks(&scheduler_tasks, &in_flight_drain).await;
    let result = wait_for_idle_flows(&workflows, Duration::from_secs(30)).await;

    for flow in flow_order.iter().rev().filter_map(|id| workflows.get(id)) {
        let len = flow.workflow.resources.lifecycle_len();
        flow.workflow.resources.teardown_range(0..len).await;
    }

    let _ = result_tx.send(Some(result));
}

/// Handles a `Trigger` command.
///
/// Reserves the flow and runs it on a freshly spawned task. The task's handle is
/// recorded in `scheduler_tasks` after pruning handles that have already
/// finished. Returns an error when the id is unknown or the flow is `Running` or
/// `Tripped`.
async fn trigger_registered_flow<TState, TResourceKey>(
    workflows: &HashMap<Arc<str>, FlowData<TState, TResourceKey>>,
    id: &str,
    scheduler_tasks: &RwLock<Vec<JoinHandle<()>>>,
) -> CanoResult<()>
where
    TState: Clone + Send + Sync + 'static + std::fmt::Debug + std::hash::Hash + Eq,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    let Some(flow) = workflows.get(id) else {
        return Err(CanoError::Workflow(format!(
            "No workflow registered with id '{id}'"
        )));
    };
    match reserve_flow(Arc::clone(&flow.info)).await {
        ReserveOutcome::AlreadyRunning => Err(CanoError::Workflow(format!(
            "Flow '{id}' is already running"
        ))),
        ReserveOutcome::Tripped => Err(CanoError::Workflow(format!(
            "Flow '{id}' is tripped — call reset_flow before triggering"
        ))),
        ReserveOutcome::Reserved => {
            let workflow = Arc::clone(&flow.workflow);
            let initial_state = flow.initial_state.clone();
            let info = Arc::clone(&flow.info);
            let policy = Arc::clone(&flow.policy);
            let handle = tokio::spawn(async move {
                execute_reserved_flow(workflow, initial_state, info, &policy).await;
            });
            let mut tasks = scheduler_tasks.write().await;
            tasks.retain(|task| !task.is_finished());
            tasks.push(handle);
            Ok(())
        }
    }
}

/// Handles a `Reset` command.
///
/// Clears the failure streak and the backoff deadline. The status is forced back
/// to `Idle` unless the flow is currently `Running`, in which case the in-flight
/// run's outcome will set it.
async fn reset_registered_flow<TState, TResourceKey>(
    workflows: &HashMap<Arc<str>, FlowData<TState, TResourceKey>>,
    id: &str,
) -> CanoResult<()>
where
    TState: Clone + Send + Sync + 'static + std::fmt::Debug + std::hash::Hash + Eq,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    let flow = workflows.get(id).ok_or_else(|| {
        CanoError::Workflow(format!("No workflow registered with id '{id}'"))
    })?;
    let mut guard = flow.info.write().await;
    guard.failure_streak = 0;
    guard.next_eligible = None;
    if !matches!(guard.status, Status::Running) {
        guard.status = Status::Idle;
    }
    Ok(())
}

/// Joins every handle in `scheduler_tasks`, one at a time, in LIFO order.
///
/// While a handle is being awaited, its `AbortHandle` is published in
/// `in_flight_drain`; the slot is cleared again once that task ends. Join errors
/// are ignored.
async fn join_scheduler_tasks(
    scheduler_tasks: &RwLock<Vec<JoinHandle<()>>>,
    in_flight_drain: &RwLock<Option<AbortHandle>>,
) {
    loop {
        // Bind the popped handle first so the write guard is released before awaiting.
        let popped = scheduler_tasks.write().await.pop();
        let Some(handle) = popped else {
            break;
        };
        *in_flight_drain.write().await = Some(handle.abort_handle());
        let _ = handle.await;
        *in_flight_drain.write().await = None;
    }
}

/// Polls every 100 ms until no flow reports `Status::Running`.
///
/// Returns `Ok(())` once all flows are idle, or an error if `timeout` elapses first.
async fn wait_for_idle_flows<TState, TResourceKey>(
    workflows: &HashMap<Arc<str>, FlowData<TState, TResourceKey>>,
    timeout: Duration,
) -> CanoResult<()>
where
    TState: Clone + Send + Sync + 'static + std::fmt::Debug + std::hash::Hash + Eq,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    let started = tokio::time::Instant::now();
    loop {
        let mut still_running = false;
        for flow in workflows.values() {
            if flow.info.read().await.status == Status::Running {
                still_running = true;
                break;
            }
        }
        if !still_running {
            return Ok(());
        }
        if started.elapsed() >= timeout {
            return Err(CanoError::Workflow(
                "Timeout waiting for workflows to complete".to_string(),
            ));
        }
        sleep(Duration::from_millis(100)).await;
    }
}
/// Reserves the flow and, if the reservation succeeds, runs it to completion.
///
/// When the flow is already `Running` or is `Tripped`, `reserve_flow` refuses
/// and this function returns without doing anything.
async fn execute_flow<TState, TResourceKey>(
    workflow: Arc<Workflow<TState, TResourceKey>>,
    initial_state: TState,
    info: Arc<RwLock<FlowInfo>>,
    policy: &BackoffPolicy,
) where
    TState: Clone + Send + Sync + 'static + std::fmt::Debug + std::hash::Hash + Eq,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    let reservation = reserve_flow(Arc::clone(&info)).await;
    if let ReserveOutcome::Reserved = reservation {
        execute_reserved_flow(workflow, initial_state, info, policy).await;
    }
}
/// Result of `reserve_flow`: whether a flow was claimed for a new run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReserveOutcome {
    /// The flow had any status other than `Running` or `Tripped` and has now been
    /// marked `Running`.
    Reserved,
    /// The flow was already `Running`; its `FlowInfo` was left unchanged.
    AlreadyRunning,
    /// The flow is `Tripped`; it must be reset before it can run again.
    Tripped,
}
/// Atomically claims a flow for a new run under its write lock.
///
/// - Status `Running`: returns `AlreadyRunning`; `FlowInfo` is not modified.
/// - Status `Tripped`: returns `Tripped`; `FlowInfo` is not modified.
/// - Any other status: marks the flow `Running`, stamps `last_run` with the
///   current time, increments `run_count`, and returns `Reserved`.
///
/// The callers in this file follow a `Reserved` result with
/// `execute_reserved_flow`, which records the outcome.
async fn reserve_flow(info: Arc<RwLock<FlowInfo>>) -> ReserveOutcome {
    let mut guard = info.write().await;
    let refusal = match guard.status {
        Status::Running => Some(ReserveOutcome::AlreadyRunning),
        Status::Tripped { .. } => Some(ReserveOutcome::Tripped),
        _ => None,
    };
    if let Some(outcome) = refusal {
        return outcome;
    }
    guard.last_run = Some(Utc::now());
    guard.run_count += 1;
    guard.status = Status::Running;
    ReserveOutcome::Reserved
}
/// Runs a flow that `reserve_flow` has already marked `Running`, then records the
/// outcome with `apply_outcome`.
///
/// Behaviour:
/// - A panic inside the workflow future is caught and converted into a
///   `CanoError::task_execution` error, so the outcome is still recorded.
/// - If the workflow has a `total_timeout`, the current instant is paired with it
///   and passed to `execute_workflow` as the budget.
/// - With the `metrics` feature, an active-flow guard is held until this function
///   returns and the run's success and duration are reported.
async fn execute_reserved_flow<TState, TResourceKey>(
    workflow: Arc<Workflow<TState, TResourceKey>>,
    initial_state: TState,
    info: Arc<RwLock<FlowInfo>>,
    policy: &BackoffPolicy,
) where
    TState: Clone + Send + Sync + 'static + std::fmt::Debug + std::hash::Hash + Eq,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    use futures_util::FutureExt;
    use std::panic::AssertUnwindSafe;

    #[cfg(feature = "metrics")]
    let _active = crate::metrics::SchedulerFlowActiveGuard::new();
    #[cfg(feature = "metrics")]
    let _flow_id = info.read().await.id.clone();
    #[cfg(feature = "metrics")]
    let _started = std::time::Instant::now();

    let budget = workflow
        .total_timeout
        .map(|limit| (std::time::Instant::now(), limit));

    let run = workflow.execute_workflow(initial_state, budget);
    #[cfg(feature = "tracing")]
    let run = run.instrument(tracing::info_span!("execute_flow"));

    // A panicking workflow becomes an ordinary error instead of unwinding this task.
    let outcome = AssertUnwindSafe(run)
        .catch_unwind()
        .await
        .unwrap_or_else(|payload| {
            let msg = crate::workflow::panic_payload_message(&*payload);
            #[cfg(feature = "tracing")]
            tracing::error!(panic = %msg, "scheduled flow panicked");
            Err(CanoError::task_execution(format!(
                "scheduled flow panicked: {msg}"
            )))
        });

    #[cfg(feature = "metrics")]
    crate::metrics::scheduler_flow_run(&_flow_id, outcome.is_ok(), _started.elapsed());

    apply_outcome(&info, outcome.map(|_final_state| ()), policy).await;
}
async fn apply_outcome(
info: &Arc<RwLock<FlowInfo>>,
result: Result<(), CanoError>,
policy: &BackoffPolicy,
) {
let mut info_guard = info.write().await;
match result {
Ok(_) => {
info_guard.status = Status::Completed;
info_guard.failure_streak = 0;
info_guard.next_eligible = None;
}
Err(e) => {
let err_str: Arc<str> = Arc::from(e.to_string());
let new_streak = info_guard.failure_streak.saturating_add(1);
info_guard.failure_streak = new_streak;
if policy.is_tripped(new_streak) {
info_guard.next_eligible = None;
info_guard.status = Status::Tripped {
streak: new_streak,
last_error: err_str,
};
#[cfg(feature = "metrics")]
crate::metrics::scheduler_flow_tripped(&info_guard.id);
} else {
let delay = policy.compute_delay(new_streak);
let until = Utc::now()
+ chrono::Duration::from_std(delay).unwrap_or(chrono::Duration::zero());
info_guard.next_eligible = Some(until);
info_guard.status = Status::Backoff {
until,
streak: new_streak,
last_error: err_str,
};
#[cfg(feature = "metrics")]
crate::metrics::scheduler_flow_backoff(&info_guard.id);
}
}
}
}
/// Returns whether a new run may be started right now.
///
/// A run may not start while the flow is `Running` or `Tripped`; any other
/// status is dispatchable. Backoff deadlines are not considered here.
async fn dispatchable_now(info: &Arc<RwLock<FlowInfo>>) -> bool {
    match info.read().await.status {
        Status::Running | Status::Tripped { .. } => false,
        _ => true,
    }
}
/// Computes how long the fixed-interval loop should pause before its next attempt.
///
/// Normally this is `interval`. If the flow's `next_eligible` backoff deadline is
/// further away than `interval`, the time remaining until that deadline is used
/// instead. A deadline that is already reached, or absent, leaves `interval`
/// unchanged.
async fn wait_until_eligible(info: &Arc<RwLock<FlowInfo>>, interval: Duration) -> Duration {
    let snapshot = info.read().await;
    // `to_std` fails for negative spans, i.e. deadlines already in the past.
    let backoff_left = snapshot
        .next_eligible
        .and_then(|eligible| (eligible - Utc::now()).to_std().ok());
    match backoff_left {
        Some(left) => interval.max(left),
        None => interval,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Clearing `running` mid-sleep (without any notification) must end the
    /// helper within roughly one poll chunk.
    #[tokio::test]
    async fn sleep_unless_stopped_returns_early_when_running_flips() {
        let running = Arc::new(RwLock::new(true));
        let stop = Arc::new(Notify::new());
        let started = tokio::time::Instant::now();
        let sleeper = {
            let running = Arc::clone(&running);
            let stop = Arc::clone(&stop);
            tokio::spawn(async move {
                sleep_unless_stopped(Duration::from_secs(10), &running, &stop).await
            })
        };

        tokio::time::sleep(Duration::from_millis(50)).await;
        *running.write().await = false;

        let completed = sleeper.await.unwrap();
        let elapsed = started.elapsed();
        assert!(!completed, "helper must report early-exit");
        assert!(
            elapsed < SHUTDOWN_POLL_CHUNK + Duration::from_millis(150),
            "helper must observe `running=false` within ~1 chunk, got {elapsed:?}"
        );
    }

    /// A zero-length sleep still reports a stop that already happened.
    #[tokio::test]
    async fn sleep_unless_stopped_returns_false_on_zero_when_already_stopped() {
        let running = Arc::new(RwLock::new(false));
        let stop = Arc::new(Notify::new());
        let completed = sleep_unless_stopped(Duration::ZERO, &running, &stop).await;
        assert!(
            !completed,
            "zero-duration sleep must surface running=false instead of short-circuiting to true"
        );
    }

    /// `notify_waiters` alone is enough to cut a long sleep short.
    #[tokio::test]
    async fn sleep_unless_stopped_observes_notify() {
        let running = Arc::new(RwLock::new(true));
        let stop = Arc::new(Notify::new());
        let sleeper = tokio::spawn({
            let running = Arc::clone(&running);
            let stop = Arc::clone(&stop);
            async move { sleep_unless_stopped(Duration::from_secs(10), &running, &stop).await }
        });

        tokio::time::sleep(Duration::from_millis(20)).await;
        stop.notify_waiters();

        assert!(!sleeper.await.unwrap(), "notify must trigger early-exit");
    }
}