use std::borrow::Cow;
use std::collections::HashMap;
use std::hash::Hash;
use std::str::FromStr;
use std::sync::Arc;
use cron::Schedule as CronSchedule;
use tokio::sync::{Notify, RwLock, mpsc, watch};
use tokio::task::JoinHandle;
use tokio::time::Duration;
use crate::error::{CanoError, CanoResult};
use crate::workflow::Workflow;
use super::loops::{driver_task, spawn_cron_loop, spawn_every_loop};
use super::{
BackoffPolicy, FlowData, FlowInfo, ParsedSchedule, RunningScheduler, SchedulerCommand, Status,
};
/// Builder that collects workflow registrations before the scheduler is launched.
///
/// Flows are added with [`Scheduler::every`], [`Scheduler::cron`], [`Scheduler::manual`] and
/// their convenience wrappers, then launched together by [`Scheduler::start`], which consumes
/// the builder.
pub struct Scheduler<TState, TResourceKey = Cow<'static, str>>
where
    TState: Clone + std::fmt::Debug + Eq + Hash + Send + Sync + 'static,
    TResourceKey: Eq + Hash + Send + Sync + 'static,
{
    /// Registered flows, keyed by their unique id.
    workflows: HashMap<String, FlowData<TState, TResourceKey>>,
    /// Flow ids in registration order; validation, setup and task spawning follow this order.
    flow_order: Vec<String>,
}
impl<TState, TResourceKey> Scheduler<TState, TResourceKey>
where
    TState: Clone + Send + Sync + 'static + std::fmt::Debug + std::hash::Hash + Eq,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    /// Creates an empty scheduler with no registered flows.
    pub fn new() -> Self {
        Self {
            workflows: HashMap::new(),
            flow_order: Vec::new(),
        }
    }

    /// Registers `workflow` under `id` to run repeatedly, once every `interval`.
    ///
    /// `initial_state` is handed to the flow's interval loop when [`Scheduler::start`] runs.
    ///
    /// # Errors
    ///
    /// Returns [`CanoError::Configuration`] if a flow named `id` is already registered.
    pub fn every(
        &mut self,
        id: &str,
        workflow: Workflow<TState, TResourceKey>,
        initial_state: TState,
        interval: Duration,
    ) -> CanoResult<()> {
        // NOTE(review): a zero `interval` is accepted here; how the loop reacts to it is decided
        // by `spawn_every_loop`, which is not visible from this file — confirm it is handled.
        self.add_flow_internal(id, workflow, initial_state, ParsedSchedule::Every(interval))
    }

    /// Convenience wrapper for [`Scheduler::every`] with an interval in whole seconds.
    ///
    /// # Errors
    ///
    /// Same as [`Scheduler::every`].
    pub fn every_seconds(
        &mut self,
        id: &str,
        workflow: Workflow<TState, TResourceKey>,
        initial_state: TState,
        seconds: u64,
    ) -> CanoResult<()> {
        self.every(id, workflow, initial_state, Duration::from_secs(seconds))
    }

    /// Convenience wrapper for [`Scheduler::every`] with an interval in whole minutes.
    ///
    /// # Errors
    ///
    /// Returns [`CanoError::Configuration`] if `minutes * 60` does not fit in `u64` seconds, or
    /// if a flow named `id` is already registered.
    pub fn every_minutes(
        &mut self,
        id: &str,
        workflow: Workflow<TState, TResourceKey>,
        initial_state: TState,
        minutes: u64,
    ) -> CanoResult<()> {
        let interval = Self::interval_from_units(minutes, 60, "minutes")?;
        self.every(id, workflow, initial_state, interval)
    }

    /// Convenience wrapper for [`Scheduler::every`] with an interval in whole hours.
    ///
    /// # Errors
    ///
    /// Returns [`CanoError::Configuration`] if `hours * 3600` does not fit in `u64` seconds, or
    /// if a flow named `id` is already registered.
    pub fn every_hours(
        &mut self,
        id: &str,
        workflow: Workflow<TState, TResourceKey>,
        initial_state: TState,
        hours: u64,
    ) -> CanoResult<()> {
        let interval = Self::interval_from_units(hours, 3600, "hours")?;
        self.every(id, workflow, initial_state, interval)
    }

    /// Converts `count` units of `unit_secs` seconds each into a [`Duration`].
    ///
    /// Checked multiplication turns an oversized `count` into a configuration error instead of
    /// panicking (debug builds) or silently wrapping to a tiny interval (release builds, where
    /// overflow checks are off by default). `unit` is only used in the error message.
    fn interval_from_units(count: u64, unit_secs: u64, unit: &str) -> CanoResult<Duration> {
        count
            .checked_mul(unit_secs)
            .map(Duration::from_secs)
            .ok_or_else(|| {
                CanoError::Configuration(format!(
                    "Interval of {count} {unit} exceeds the maximum of {} seconds",
                    u64::MAX
                ))
            })
    }

    /// Registers `workflow` under `id` to run on the cron schedule described by `expr`
    /// (for example `"0 */5 * * * *"`), parsed with the `cron` crate's `Schedule`.
    ///
    /// # Errors
    ///
    /// Returns [`CanoError::Configuration`] if `expr` cannot be parsed, or if a flow named `id`
    /// is already registered.
    pub fn cron(
        &mut self,
        id: &str,
        workflow: Workflow<TState, TResourceKey>,
        initial_state: TState,
        expr: &str,
    ) -> CanoResult<()> {
        let parsed = CronSchedule::from_str(expr)
            .map_err(|e| CanoError::Configuration(format!("Invalid cron expression: {e}")))?;
        self.add_flow_internal(
            id,
            workflow,
            initial_state,
            ParsedSchedule::Cron(Box::new(parsed)),
        )
    }

    /// Registers `workflow` under `id` without a timer.
    ///
    /// No loop task is spawned for this flow by [`Scheduler::start`].
    ///
    /// # Errors
    ///
    /// Returns [`CanoError::Configuration`] if a flow named `id` is already registered.
    pub fn manual(
        &mut self,
        id: &str,
        workflow: Workflow<TState, TResourceKey>,
        initial_state: TState,
    ) -> CanoResult<()> {
        self.add_flow_internal(id, workflow, initial_state, ParsedSchedule::Manual)
    }

    /// Shared registration path for every schedule kind.
    ///
    /// Creates an `Idle` [`FlowInfo`] with zeroed counters and the default backoff policy.
    /// It then records `id` in registration order.
    fn add_flow_internal(
        &mut self,
        id: &str,
        workflow: Workflow<TState, TResourceKey>,
        initial_state: TState,
        schedule: ParsedSchedule,
    ) -> CanoResult<()> {
        if self.workflows.contains_key(id) {
            return Err(CanoError::Configuration(format!(
                "Flow '{id}' already exists"
            )));
        }
        let info = Arc::new(RwLock::new(FlowInfo {
            id: id.to_string(),
            status: Status::Idle,
            run_count: 0,
            last_run: None,
            failure_streak: 0,
            next_eligible: None,
        }));
        self.workflows.insert(
            id.to_string(),
            FlowData {
                workflow: Arc::new(workflow),
                initial_state,
                schedule,
                info,
                policy: Arc::new(BackoffPolicy::default()),
            },
        );
        self.flow_order.push(id.to_string());
        Ok(())
    }

    /// Number of registered flows.
    pub fn len(&self) -> usize {
        self.flow_order.len()
    }

    /// Returns `true` when no flows are registered.
    pub fn is_empty(&self) -> bool {
        self.flow_order.is_empty()
    }

    /// Returns `true` if a flow named `id` is registered.
    pub fn contains(&self, id: &str) -> bool {
        self.workflows.contains_key(id)
    }

    /// Replaces the backoff policy of flow `id`.
    ///
    /// Flows start with `BackoffPolicy::default()`. This must be called before
    /// [`Scheduler::start`], which consumes the builder.
    ///
    /// # Errors
    ///
    /// Returns [`CanoError::Configuration`] if no flow named `id` is registered.
    pub fn set_backoff(&mut self, id: &str, policy: BackoffPolicy) -> CanoResult<()> {
        let flow = self.workflows.get_mut(id).ok_or_else(|| {
            CanoError::Configuration(format!("Flow '{id}' not found — cannot set backoff"))
        })?;
        flow.policy = Arc::new(policy);
        Ok(())
    }

    /// Validates every flow, sets up their resources and launches the scheduler.
    ///
    /// Steps, all in registration order:
    /// 1. every workflow and its initial state are validated before any resource is touched;
    /// 2. each workflow's resources are set up;
    /// 3. one loop task is spawned per `every`/`cron` flow (manual flows get none);
    /// 4. a driver task is spawned that owns the flows and receives [`SchedulerCommand`]s.
    ///
    /// # Errors
    ///
    /// Returns the first validation error, or the first `setup_all` error. In the setup case,
    /// flows whose setup already succeeded are torn down in reverse order before returning.
    pub async fn start(self) -> CanoResult<RunningScheduler<TState, TResourceKey>> {
        let Self {
            workflows,
            flow_order,
        } = self;

        // Borrowed, ordered view of the flows; only needed for validation and setup.
        let ordered: Vec<&FlowData<TState, TResourceKey>> = flow_order
            .iter()
            .filter_map(|id| workflows.get(id))
            .collect();

        // Validate everything first so no resources are acquired for a misconfigured set.
        for flow in &ordered {
            flow.workflow
                .validate()
                .and_then(|_| flow.workflow.validate_initial_state(&flow.initial_state))?;
        }

        for (idx, flow) in ordered.iter().enumerate() {
            if let Err(e) = flow.workflow.resources.setup_all().await {
                // Roll back the flows whose setup already succeeded, newest first. The failing
                // flow itself is not torn down here.
                // NOTE(review): presumably `setup_all` cleans up its own partial progress on
                // error — confirm.
                for prior in ordered[..idx].iter().rev() {
                    let len = prior.workflow.resources.lifecycle_len();
                    prior.workflow.resources.teardown_range(0..len).await;
                }
                return Err(e);
            }
        }
        // End the borrows of `workflows` before it is moved into the driver task.
        drop(ordered);

        let (command_tx, command_rx) = mpsc::channel::<SchedulerCommand>(64);
        let stop_notify = Arc::new(Notify::new());
        let running = Arc::new(RwLock::new(true));

        // Status handles shared with the returned `RunningScheduler`.
        let flows_view: Arc<HashMap<String, Arc<RwLock<FlowInfo>>>> = Arc::new(
            workflows
                .iter()
                .map(|(id, fd)| (id.clone(), Arc::clone(&fd.info)))
                .collect(),
        );
        let flow_order_view = Arc::new(flow_order.clone());

        // Spawn one loop task per timed flow, in registration order. The list is built locally
        // and only then wrapped in a lock, since nothing else can observe it yet.
        let mut tasks: Vec<JoinHandle<()>> = Vec::with_capacity(flow_order.len());
        for id in &flow_order {
            let Some(fd) = workflows.get(id) else {
                continue;
            };
            let workflow = Arc::clone(&fd.workflow);
            let initial_state = fd.initial_state.clone();
            let info = Arc::clone(&fd.info);
            let policy = Arc::clone(&fd.policy);
            let running_clone = Arc::clone(&running);
            let notify_clone = Arc::clone(&stop_notify);
            match &fd.schedule {
                ParsedSchedule::Every(interval) => {
                    tasks.push(tokio::spawn(spawn_every_loop(
                        workflow,
                        initial_state,
                        info,
                        policy,
                        running_clone,
                        notify_clone,
                        *interval,
                    )));
                }
                ParsedSchedule::Cron(cron_schedule) => {
                    tasks.push(tokio::spawn(spawn_cron_loop(
                        workflow,
                        initial_state,
                        info,
                        policy,
                        running_clone,
                        notify_clone,
                        cron_schedule.clone(),
                    )));
                }
                ParsedSchedule::Manual => {
                    // No timer loop. NOTE(review): presumably manual flows are triggered through
                    // `SchedulerCommand`s handled by `driver_task` — confirm.
                }
            }
        }
        let scheduler_tasks: Arc<RwLock<Vec<JoinHandle<()>>>> = Arc::new(RwLock::new(tasks));

        let (result_tx, result_rx) = watch::channel::<Option<CanoResult<()>>>(None);
        let driver_handle = tokio::spawn(driver_task(
            command_rx,
            workflows,
            flow_order,
            Arc::clone(&running),
            Arc::clone(&stop_notify),
            Arc::clone(&scheduler_tasks),
            result_tx,
        ));
        Ok(RunningScheduler {
            command_tx,
            flows: flows_view,
            flow_order: flow_order_view,
            result_rx,
            scheduler_tasks,
            driver_handle: Arc::new(driver_handle),
            liveness: Arc::new(()),
            _marker: std::marker::PhantomData,
        })
    }
}
impl<TState, TResourceKey> Default for Scheduler<TState, TResourceKey>
where
TState: Clone + Send + Sync + 'static + std::fmt::Debug + std::hash::Hash + Eq,
TResourceKey: Hash + Eq + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    //! Registration-phase tests for `Scheduler` (nothing here calls `start`).
    use super::*;
    use crate::scheduler::test_support::*;

    #[tokio::test(flavor = "multi_thread")]
    async fn test_scheduler_creation() {
        let scheduler: Scheduler<TestState> = Scheduler::new();
        assert!(scheduler.is_empty());
        assert_eq!(scheduler.len(), 0);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_scheduler_default_is_empty() {
        let scheduler: Scheduler<TestState> = Scheduler::default();
        assert!(scheduler.is_empty());
        assert_eq!(scheduler.len(), 0);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_workflow_every_seconds() {
        let mut scheduler = Scheduler::<TestState>::new();
        let workflow = create_test_workflow();
        scheduler
            .every_seconds("test_task", workflow, TestState::Start, 5)
            .unwrap();
        assert_eq!(scheduler.len(), 1);
        assert!(scheduler.contains("test_task"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_workflow_every_minutes() {
        let mut scheduler: Scheduler<TestState> = Scheduler::<TestState>::new();
        let workflow = create_test_workflow();
        scheduler
            .every_minutes("test_task", workflow, TestState::Start, 2)
            .unwrap();
        assert!(scheduler.contains("test_task"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_workflow_every_hours() {
        let mut scheduler: Scheduler<TestState> = Scheduler::<TestState>::new();
        let workflow = create_test_workflow();
        scheduler
            .every_hours("test_task", workflow, TestState::Start, 1)
            .unwrap();
        assert!(scheduler.contains("test_task"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_workflow_every_duration() {
        let mut scheduler: Scheduler<TestState> = Scheduler::<TestState>::new();
        let workflow = create_test_workflow();
        scheduler
            .every(
                "test_task",
                workflow,
                TestState::Start,
                Duration::from_millis(100),
            )
            .unwrap();
        assert!(scheduler.contains("test_task"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_workflow_cron() {
        let mut scheduler: Scheduler<TestState> = Scheduler::<TestState>::new();
        let workflow = create_test_workflow();
        scheduler
            .cron("test_task", workflow, TestState::Start, "0 */5 * * * *")
            .unwrap();
        assert!(scheduler.contains("test_task"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_workflow_cron_invalid() {
        let mut scheduler: Scheduler<TestState> = Scheduler::<TestState>::new();
        let workflow = create_test_workflow();
        let err = scheduler
            .cron("test_task", workflow, TestState::Start, "invalid cron")
            .unwrap_err();
        assert!(matches!(err, CanoError::Configuration(_)));
        // A rejected expression must not leave a half-registered flow behind.
        assert!(!scheduler.contains("test_task"));
        assert!(scheduler.is_empty());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_workflow_manual() {
        let mut scheduler: Scheduler<TestState> = Scheduler::<TestState>::new();
        let workflow = create_test_workflow();
        scheduler
            .manual("test_task", workflow, TestState::Start)
            .unwrap();
        assert!(scheduler.contains("test_task"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_mixed_schedule_kinds_all_registered() {
        let mut scheduler = Scheduler::<TestState>::new();
        scheduler
            .every_seconds("every_task", create_test_workflow(), TestState::Start, 5)
            .unwrap();
        scheduler
            .cron(
                "cron_task",
                create_test_workflow(),
                TestState::Start,
                "0 */5 * * * *",
            )
            .unwrap();
        scheduler
            .manual("manual_task", create_test_workflow(), TestState::Start)
            .unwrap();
        assert_eq!(scheduler.len(), 3);
        assert!(scheduler.contains("every_task"));
        assert!(scheduler.contains("cron_task"));
        assert!(scheduler.contains("manual_task"));
        assert!(!scheduler.contains("unknown_task"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_duplicate_workflow_id() {
        let mut scheduler: Scheduler<TestState> = Scheduler::<TestState>::new();
        let workflow1 = create_test_workflow();
        let workflow2 = create_test_workflow();
        scheduler
            .every_seconds("test_task", workflow1, TestState::Start, 5)
            .unwrap();
        let err = scheduler
            .every_seconds("test_task", workflow2, TestState::Start, 10)
            .unwrap_err();
        assert!(matches!(err, CanoError::Configuration(_)));
        // The rejected duplicate must not be appended to the registration order.
        assert_eq!(scheduler.len(), 1);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_set_backoff_known_flow() {
        let mut scheduler = Scheduler::<TestState>::new();
        scheduler
            .manual("test_task", create_test_workflow(), TestState::Start)
            .unwrap();
        scheduler
            .set_backoff("test_task", BackoffPolicy::default())
            .unwrap();
        assert_eq!(scheduler.len(), 1);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_set_backoff_unknown_flow() {
        let mut scheduler = Scheduler::<TestState>::new();
        let err = scheduler
            .set_backoff("missing_task", BackoffPolicy::default())
            .unwrap_err();
        assert!(matches!(err, CanoError::Configuration(_)));
        assert!(scheduler.is_empty());
    }
}