use async_trait::async_trait;
use chrono::{DateTime, Datelike, Duration as ChronoDuration, Timelike, Utc, Weekday};
use floxide_core::{error::FloxideError, ActionType, DefaultAction, Node, NodeId, NodeOutcome};
use std::collections::HashMap;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, warn};
use uuid::Uuid;
/// Describes when a timer node should fire.
///
/// Calendar-based variants (`Daily`, `Weekly`, `Monthly`) are evaluated against
/// `Utc::now()`, i.e. their hour/minute fields are UTC.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Schedule {
    /// Fire once at the given instant.
    Once(DateTime<Utc>),
    /// Fire after the given duration, measured from the moment the next
    /// execution time is computed.
    Interval(Duration),
    /// Fire every day at `(hour, minute)`.
    Daily(u32, u32),
    /// Fire every week on the given weekday at `(hour, minute)`.
    Weekly(Weekday, u32, u32),
    /// Fire every month at `(day, hour, minute)`.
    Monthly(u32, u32, u32),
    /// A cron expression; `Schedule::next_execution` currently rejects it as
    /// not yet implemented.
    Cron(String),
}
impl Schedule {
    /// Computes the next instant (UTC) at which this schedule fires, relative to
    /// the current wall-clock time.
    ///
    /// * `Once(t)` returns `t`, provided it is still strictly in the future.
    /// * `Interval(d)` returns now + `d`.
    /// * `Daily(h, m)` returns the next `h:m:00` (today if still ahead, otherwise tomorrow).
    /// * `Weekly(w, h, m)` returns the next `h:m:00` that falls on weekday `w`.
    /// * `Monthly(d, h, m)` returns the next `h:m:00` on day `d` of a month. Months
    ///   that have no day `d` (e.g. day 31 in April, or day 29 in a non-leap
    ///   February) are skipped.
    ///
    /// # Errors
    ///
    /// Returns `FloxideError::Other` in these cases:
    /// * a `Once` time is not in the future;
    /// * an interval cannot be converted, or adding it to now overflows;
    /// * the hour or minute is out of range;
    /// * a monthly day is outside `1..=31`;
    /// * the schedule is a `Cron` expression (not yet implemented).
    pub fn next_execution(&self) -> Result<DateTime<Utc>, FloxideError> {
        let now = Utc::now();
        match self {
            Schedule::Once(time) => {
                if *time <= now {
                    Err(FloxideError::Other(
                        "Scheduled time has already passed".to_string(),
                    ))
                } else {
                    Ok(*time)
                }
            }
            Schedule::Interval(duration) => {
                let step = ChronoDuration::from_std(*duration).map_err(|e| {
                    FloxideError::Other(format!("Failed to convert duration: {}", e))
                })?;
                // `DateTime + Duration` panics on overflow; report it as an error instead.
                now.checked_add_signed(step).ok_or_else(|| {
                    FloxideError::Other("Interval is too large to schedule".to_string())
                })
            }
            Schedule::Daily(hour, minute) => {
                let today = Self::at_time_of_day(now, *hour, *minute)?;
                Ok(if today <= now {
                    today + ChronoDuration::days(1)
                } else {
                    today
                })
            }
            Schedule::Weekly(weekday, hour, minute) => {
                let today = Self::at_time_of_day(now, *hour, *minute)?;
                // Days from today's weekday to the target weekday, in 0..=6.
                let days_ahead = (weekday.num_days_from_monday() + 7
                    - now.weekday().num_days_from_monday())
                    % 7;
                // Same weekday but the time has already passed: wait a full week.
                let days_ahead = if days_ahead == 0 && today <= now {
                    7
                } else {
                    days_ahead
                };
                Ok(today + ChronoDuration::days(i64::from(days_ahead)))
            }
            Schedule::Monthly(day, hour, minute) => {
                Self::next_monthly(now, *day, *hour, *minute)
            }
            Schedule::Cron(_expression) => Err(FloxideError::Other(
                "Cron expressions are not yet implemented".to_string(),
            )),
        }
    }

    /// Returns how long to wait from now until [`Schedule::next_execution`].
    ///
    /// # Errors
    ///
    /// Propagates any error from `next_execution`. Also returns an error if the
    /// computed instant is not strictly in the future.
    pub fn duration_until_next(&self) -> Result<Duration, FloxideError> {
        let next = self.next_execution()?;
        let remaining = next.signed_duration_since(Utc::now());
        // `to_std` keeps full precision (no millisecond truncation) and rejects
        // negative spans. A zero span is treated as "already passed", like a negative one.
        match remaining.to_std() {
            Ok(wait) if !wait.is_zero() => Ok(wait),
            _ => Err(FloxideError::Other(
                "Scheduled time is in the past".to_string(),
            )),
        }
    }

    /// Returns `now` with its time of day replaced by `hour:minute:00.000000000`.
    fn at_time_of_day(
        now: DateTime<Utc>,
        hour: u32,
        minute: u32,
    ) -> Result<DateTime<Utc>, FloxideError> {
        now.with_hour(hour)
            .and_then(|dt| dt.with_minute(minute))
            .and_then(|dt| dt.with_second(0))
            .and_then(|dt| dt.with_nanosecond(0))
            .ok_or_else(|| FloxideError::Other("Invalid hour or minute".to_string()))
    }

    /// Finds the earliest day `day` of a month at `hour:minute` strictly after
    /// `now`, skipping months that do not contain that day.
    fn next_monthly(
        now: DateTime<Utc>,
        day: u32,
        hour: u32,
        minute: u32,
    ) -> Result<DateTime<Utc>, FloxideError> {
        let at_time = Self::at_time_of_day(now, hour, minute)?;
        if !(1..=31).contains(&day) {
            return Err(FloxideError::Other(format!(
                "Invalid day of month: {}",
                day
            )));
        }
        // Months are numbered `year * 12 + month0`, so stepping forward rolls the
        // year over naturally. Any two consecutive months contain every day in
        // 1..=31, so scanning a year ahead always finds a candidate.
        let first_month = now.year() * 12 + now.month0() as i32;
        (0..=12)
            .filter_map(|offset| {
                let month_index = first_month + offset;
                // Move to day 1 first so changing year/month can never produce a
                // non-existent date (e.g. Jan 31 -> Feb 31).
                at_time
                    .with_day(1)
                    .and_then(|dt| dt.with_year(month_index.div_euclid(12)))
                    // rem_euclid(12) is always in 0..12, so the cast is lossless.
                    .and_then(|dt| dt.with_month0(month_index.rem_euclid(12) as u32))
                    .and_then(|dt| dt.with_day(day))
            })
            .find(|candidate| *candidate > now)
            .ok_or_else(|| {
                FloxideError::Other(format!("Invalid day {} for next month", day))
            })
    }
}
/// A unit of work that runs according to a [`Schedule`].
///
/// The action returned from [`TimerNode::execute_on_schedule`] is what callers
/// (such as `TimerWorkflow`) use to pick the next route.
#[async_trait]
pub trait TimerNode<Context, Action>: Send + Sync
where
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static + Default + Debug,
{
    /// Identifier of this node.
    fn id(&self) -> NodeId;

    /// Schedule describing when this node should run.
    fn schedule(&self) -> Schedule;

    /// Performs the node's work on `ctx` and returns the resulting action.
    async fn execute_on_schedule(&self, ctx: &mut Context) -> Result<Action, FloxideError>;
}
/// A [`TimerNode`] whose work is a plain synchronous closure.
pub struct SimpleTimer<F>
where
    F: Send + Sync + 'static,
{
    /// Identifier returned by `TimerNode::id`.
    id: NodeId,
    /// Closure invoked each time the timer runs.
    action: F,
    /// When the timer should run.
    schedule: Schedule,
}
impl<F> SimpleTimer<F>
where
    F: Send + Sync + 'static,
{
    /// Creates a timer identified by a freshly generated v4 UUID string.
    pub fn new(schedule: Schedule, action: F) -> Self {
        Self::with_id(Uuid::new_v4().to_string(), schedule, action)
    }

    /// Creates a timer with an explicit identifier.
    pub fn with_id(id: impl Into<String>, schedule: Schedule, action: F) -> Self {
        let id: String = id.into();
        Self {
            id,
            schedule,
            action,
        }
    }
}
#[async_trait]
impl<Context, Action, F> TimerNode<Context, Action> for SimpleTimer<F>
where
Context: Send + Sync + 'static,
Action: ActionType + Send + Sync + 'static + Default + Debug,
F: Fn(&mut Context) -> Result<Action, FloxideError> + Send + Sync + 'static,
{
fn schedule(&self) -> Schedule {
self.schedule.clone()
}
async fn execute_on_schedule(&self, ctx: &mut Context) -> Result<Action, FloxideError> {
(self.action)(ctx)
}
fn id(&self) -> NodeId {
self.id.clone()
}
}
/// A graph of [`TimerNode`]s executed one at a time. The next node is chosen by
/// the action each node returns.
pub struct TimerWorkflow<Context, Action>
where
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static + Default + Debug,
{
    /// Id of the node executed first.
    initial_node: NodeId,
    /// Every registered node, keyed by its id.
    nodes: HashMap<NodeId, Arc<dyn TimerNode<Context, Action>>>,
    /// Transition table: `(source node id, returned action)` -> next node id.
    routes: HashMap<(NodeId, Action), NodeId>,
    /// Action that ends the workflow when any node returns it.
    termination_action: Action,
}
impl<Context, Action> TimerWorkflow<Context, Action>
where
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static + Default + Debug,
{
    /// Creates a workflow that starts at `initial_node`. It stops as soon as any
    /// node returns `termination_action`.
    pub fn new(
        initial_node: Arc<dyn TimerNode<Context, Action>>,
        termination_action: Action,
    ) -> Self {
        let id = initial_node.id();
        let mut nodes = HashMap::new();
        nodes.insert(id.clone(), initial_node);
        Self {
            nodes,
            routes: HashMap::new(),
            initial_node: id,
            termination_action,
        }
    }

    /// Registers `node` under its own id, replacing any node previously stored
    /// under the same id.
    pub fn add_node(&mut self, node: Arc<dyn TimerNode<Context, Action>>) {
        let id = node.id();
        self.nodes.insert(id, node);
    }

    /// Routes execution to `to_id` whenever node `from_id` returns `action`.
    ///
    /// Setting a route for an existing `(from_id, action)` pair overwrites it. A
    /// route registered for `Action::default()` is used as the fallback for
    /// `from_id` when no exact route matches.
    pub fn set_route(&mut self, from_id: &NodeId, action: Action, to_id: &NodeId) {
        self.routes.insert((from_id.clone(), action), to_id.clone());
    }

    /// Runs the workflow from the initial node until it terminates.
    ///
    /// Each step does the following:
    /// 1. Sleeps until the current node's next scheduled execution. If that time
    ///    cannot be computed (e.g. a `Once` time that has already passed), it logs
    ///    a warning and waits 5 seconds instead.
    /// 2. Runs the node. A failure is logged and replaced by `Action::default()`.
    /// 3. Stops if the action equals the termination action. Otherwise it follows
    ///    the route for `(node, action)`, falling back to the route for
    ///    `(node, Action::default())`.
    ///
    /// Returns `Ok(())` on termination, or when no route matches (a warning is
    /// logged in that case).
    ///
    /// # Errors
    ///
    /// Returns an error only if the current node id has no registered node.
    pub async fn execute(&self, ctx: &mut Context) -> Result<(), FloxideError> {
        // Delay used when a node's next execution time cannot be computed.
        const FALLBACK_WAIT: Duration = Duration::from_secs(5);

        let mut current_node_id = self.initial_node.clone();
        loop {
            let node = self.nodes.get(&current_node_id).ok_or_else(|| {
                FloxideError::Other(format!("Node not found: {}", current_node_id))
            })?;

            let wait_duration = match node.schedule().duration_until_next() {
                Ok(duration) => duration,
                Err(e) => {
                    warn!(
                        "Failed to calculate next execution time for node {}: {}",
                        current_node_id, e
                    );
                    FALLBACK_WAIT
                }
            };
            debug!(
                "Waiting {:?} until next execution of node {}",
                wait_duration, current_node_id
            );
            sleep(wait_duration).await;

            // Node failures are deliberately non-fatal: log and route on the default action.
            let action = match node.execute_on_schedule(ctx).await {
                Ok(action) => action,
                Err(e) => {
                    warn!("Error executing node {}: {}", current_node_id, e);
                    Action::default()
                }
            };

            if action == self.termination_action {
                debug!("Workflow terminated by node {}", current_node_id);
                break;
            }

            if let Some(next_node_id) = self
                .routes
                .get(&(current_node_id.clone(), action.clone()))
            {
                debug!(
                    "Moving from node {} to node {}",
                    current_node_id, next_node_id
                );
                current_node_id = next_node_id.clone();
            } else if let Some(next_node_id) = self
                .routes
                .get(&(current_node_id.clone(), Action::default()))
            {
                debug!(
                    "No route found for action {:?}, using default route to node {}",
                    action, next_node_id
                );
                current_node_id = next_node_id.clone();
            } else {
                warn!(
                    "No route found for node {} with action {:?} and no default route",
                    current_node_id, action
                );
                break;
            }
        }
        Ok(())
    }
}
/// Wraps a [`TimerNode`] so it can be used wherever a core [`Node`] is expected.
pub struct TimerNodeAdapter<Context, Action>
where
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static + Default + Debug,
{
    /// Id reported to the surrounding workflow; may differ from the wrapped node's id.
    id: NodeId,
    /// The wrapped timer node.
    node: Arc<dyn TimerNode<Context, Action>>,
    /// When `true`, the schedule is not waited on and the node runs as soon as
    /// it is processed.
    execute_immediately: bool,
}
impl<Context, Action> TimerNodeAdapter<Context, Action>
where
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static + Default + Debug,
{
    /// Wraps `node`, reusing the node's own id for the adapter.
    pub fn new(node: Arc<dyn TimerNode<Context, Action>>, execute_immediately: bool) -> Self {
        let own_id = node.id();
        Self::with_id(node, own_id, execute_immediately)
    }

    /// Wraps `node` under an explicit adapter id.
    pub fn with_id(
        node: Arc<dyn TimerNode<Context, Action>>,
        id: impl Into<String>,
        execute_immediately: bool,
    ) -> Self {
        let id: String = id.into();
        Self {
            id,
            node,
            execute_immediately,
        }
    }
}
#[async_trait]
impl<Context, Action> Node<Context, Action> for TimerNodeAdapter<Context, Action>
where
Context: Send + Sync + 'static,
Action: ActionType + Send + Sync + 'static + Default + Debug,
{
type Output = ();
fn id(&self) -> NodeId {
self.id.clone()
}
async fn process(
&self,
ctx: &mut Context,
) -> Result<NodeOutcome<Self::Output, Action>, FloxideError> {
if self.execute_immediately {
let action = self.node.execute_on_schedule(ctx).await?;
Ok(NodeOutcome::RouteToAction(action))
} else {
let wait_duration = self.node.schedule().duration_until_next()?;
debug!(
"Waiting {:?} before executing node {}",
wait_duration, self.id
);
sleep(wait_duration).await;
let action = self.node.execute_on_schedule(ctx).await?;
Ok(NodeOutcome::RouteToAction(action))
}
}
}
/// A [`Node`] that runs a whole [`TimerWorkflow`]. When that workflow finishes
/// without error, it routes to a fixed completion action.
pub struct NestedTimerWorkflow<Context, Action>
where
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static + Default + Debug,
{
    /// Id reported to the surrounding workflow.
    id: NodeId,
    /// The inner workflow run by `process`.
    workflow: Arc<TimerWorkflow<Context, Action>>,
    /// Action returned once the inner workflow completes successfully.
    complete_action: Action,
    /// Explicit marker for the `Context` / `Action` type parameters.
    _phantom: PhantomData<(Context, Action)>,
}
impl<Context, Action> NestedTimerWorkflow<Context, Action>
where
    Context: Send + Sync + 'static,
    Action: ActionType + Send + Sync + 'static + Default + Debug,
{
    /// Wraps `workflow` under a freshly generated v4 UUID string id.
    pub fn new(workflow: Arc<TimerWorkflow<Context, Action>>, complete_action: Action) -> Self {
        Self::with_id(workflow, Uuid::new_v4().to_string(), complete_action)
    }

    /// Wraps `workflow` under an explicit id.
    pub fn with_id(
        workflow: Arc<TimerWorkflow<Context, Action>>,
        id: impl Into<String>,
        complete_action: Action,
    ) -> Self {
        let id: String = id.into();
        Self {
            id,
            workflow,
            complete_action,
            _phantom: PhantomData,
        }
    }
}
#[async_trait]
impl<Context, Action> Node<Context, Action> for NestedTimerWorkflow<Context, Action>
where
Context: Send + Sync + 'static,
Action: ActionType + Send + Sync + 'static + Default + Debug,
{
type Output = ();
fn id(&self) -> NodeId {
self.id.clone()
}
async fn process(
&self,
ctx: &mut Context,
) -> Result<NodeOutcome<Self::Output, Action>, FloxideError> {
let result = self.workflow.execute(ctx).await;
match result {
Ok(_) => Ok(NodeOutcome::RouteToAction(self.complete_action.clone())),
Err(e) => Err(e),
}
}
}
/// Extra named actions for action types used with timer nodes.
pub trait TimerActionExt: ActionType {
    /// Returns this action type's "complete" action.
    fn complete() -> Self;

    /// Returns this action type's "retry" action.
    fn retry() -> Self;
}
/// For [`DefaultAction`], timer actions are `Custom` variants with `timer_`-prefixed names.
impl TimerActionExt for DefaultAction {
    fn complete() -> Self {
        Self::Custom(String::from("timer_complete"))
    }

    fn retry() -> Self {
        Self::Custom(String::from("timer_retry"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_schedule_next_execution() {
        let future_time = Utc::now() + ChronoDuration::hours(1);
        let once_schedule = Schedule::Once(future_time);
        let next = once_schedule.next_execution().unwrap();
        assert_eq!(next, future_time);

        let interval_schedule = Schedule::Interval(Duration::from_secs(60));
        let next = interval_schedule.next_execution().unwrap();
        let diff = (next - Utc::now()).num_seconds();
        assert!(diff > 0 && diff <= 61);

        let now = Utc::now();
        let future_hour = (now.hour() + 1) % 24;
        let daily_schedule = Schedule::Daily(future_hour, 0);
        let next = daily_schedule.next_execution().unwrap();
        assert!(next > now);
        assert_eq!(next.hour(), future_hour);
        assert_eq!(next.minute(), 0);
    }

    #[test]
    fn test_weekly_schedule_targets_requested_weekday() {
        let now = Utc::now();
        let target = now.weekday().succ();
        let next = Schedule::Weekly(target, 12, 30).next_execution().unwrap();
        assert!(next > now);
        assert_eq!(next.weekday(), target);
        assert_eq!((next.hour(), next.minute(), next.second()), (12, 30, 0));
        assert!(next - now <= ChronoDuration::days(7));
    }

    #[test]
    fn test_monthly_schedule_first_of_month() {
        let now = Utc::now();
        let next = Schedule::Monthly(1, 0, 0).next_execution().unwrap();
        assert!(next > now);
        assert_eq!(next.day(), 1);
        assert_eq!((next.hour(), next.minute()), (0, 0));
    }

    #[test]
    fn test_invalid_schedules_are_rejected() {
        assert!(Schedule::Daily(24, 0).next_execution().is_err());
        assert!(Schedule::Weekly(Weekday::Mon, 0, 60).next_execution().is_err());
        assert!(Schedule::Monthly(0, 0, 0).next_execution().is_err());
        assert!(Schedule::Monthly(32, 0, 0).next_execution().is_err());
        assert!(Schedule::Cron("* * * * *".to_string())
            .next_execution()
            .is_err());
        assert!(Schedule::Once(Utc::now() - ChronoDuration::seconds(1))
            .next_execution()
            .is_err());
    }

    #[test]
    fn test_duration_until_next_interval() {
        let wait = Schedule::Interval(Duration::from_secs(60))
            .duration_until_next()
            .unwrap();
        assert!(wait <= Duration::from_secs(60));
        assert!(wait > Duration::from_secs(59));
    }

    #[tokio::test]
    async fn test_simple_timer() {
        let mut ctx = "test_context".to_string();
        let timer = SimpleTimer::new(
            Schedule::Once(Utc::now() + ChronoDuration::milliseconds(100)),
            |ctx: &mut String| {
                *ctx = format!("{}_executed", ctx);
                Ok(DefaultAction::Next)
            },
        );
        let action = timer.execute_on_schedule(&mut ctx).await.unwrap();
        assert_eq!(action, DefaultAction::Next);
        assert_eq!(ctx, "test_context_executed");
    }

    #[tokio::test]
    async fn test_timer_workflow() {
        let mut ctx = 0;
        let timer1 = Arc::new(SimpleTimer::with_id(
            "timer1",
            Schedule::Once(Utc::now() + ChronoDuration::milliseconds(100)),
            |ctx: &mut i32| {
                *ctx += 1;
                Ok(DefaultAction::Next)
            },
        ));
        let timer2 = Arc::new(SimpleTimer::with_id(
            "timer2",
            Schedule::Once(Utc::now() + ChronoDuration::milliseconds(200)),
            |ctx: &mut i32| {
                *ctx += 2;
                Ok(DefaultAction::Custom("terminate".to_string()))
            },
        ));
        let mut workflow = TimerWorkflow::new(
            timer1.clone(),
            DefaultAction::Custom("terminate".to_string()),
        );
        workflow.add_node(timer2.clone());
        workflow.set_route(&timer1.id(), DefaultAction::Next, &timer2.id());
        let handle = tokio::spawn(async move {
            workflow.execute(&mut ctx).await.unwrap();
            ctx
        });
        let result = tokio::time::timeout(Duration::from_secs(1), handle)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result, 3);
    }
}