use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use tracing::{debug, error, info, warn};
use crate::action::{ActionType, DefaultAction};
use crate::error::FloxideError;
use crate::node::{Node, NodeId, NodeOutcome};
/// Errors produced while running a [`Workflow`].
#[derive(Debug, thiserror::Error)]
pub enum WorkflowError {
    /// The workflow's start node could not be found.
    ///
    /// NOTE(review): not constructed anywhere in this part of the file; the
    /// meaning above is inferred from the name — confirm against callers.
    #[error("Initial node not found: {0}")]
    InitialNodeNotFound(NodeId),
    /// Execution routed to a node id that has no registered node.
    #[error("Node not found: {0}")]
    NodeNotFound(NodeId),
    /// A node's outcome could not be routed to a next node.
    ///
    /// Raised when a node is skipped without a default route, or when a routed
    /// action has no matching edge, default-action edge, or default route. The
    /// payload is either the unhandled action's `Debug` form or a fixed
    /// description of the situation.
    #[error("Action not handled: {0}")]
    ActionNotHandled(String),
    /// A framework error surfaced during execution.
    ///
    /// This covers errors returned by a node's `process` call and cycle
    /// detection (`FloxideError::WorkflowCycleDetected`). `#[from]` lets `?`
    /// convert a `FloxideError` into this variant.
    #[error("Node execution error: {0}")]
    NodeExecution(#[from] FloxideError),
}
/// A directed graph of nodes executed one at a time against a shared context.
///
/// `Context` is the mutable state passed to every node, `A` is the action type
/// nodes use to pick an outgoing edge, and `Output` is the value returned by the
/// node that ends the run.
pub struct Workflow<Context, A = DefaultAction, Output = ()>
where
    A: ActionType,
{
    /// Id of the node that execution begins at.
    start_node: NodeId,
    /// Every registered node, keyed by its id.
    ///
    /// Nodes are held behind `Arc`, so cloning the workflow shares them.
    pub(crate) nodes: HashMap<NodeId, Arc<dyn Node<Context, A, Output = Output>>>,
    /// Action-specific transitions: `(from, action) -> to`.
    edges: HashMap<(NodeId, A), NodeId>,
    /// Fallback transitions: `from -> to`.
    ///
    /// These are followed after a `Success` or `Skipped` outcome, and when a
    /// routed action has no matching edge.
    default_routes: HashMap<NodeId, NodeId>,
    /// When `false`, visiting any node a second time aborts execution.
    allow_cycles: bool,
    /// Maximum number of visits allowed per node; `0` means no limit.
    cycle_limit: usize,
}
impl<Context, A, Output> Workflow<Context, A, Output>
where
    Context: Send + Sync + 'static,
    A: ActionType + Debug + Default + Clone + Send + Sync + 'static,
    Output: Send + Sync + 'static + std::fmt::Debug,
{
    /// Creates a workflow whose execution starts at `start_node`.
    ///
    /// The start node is registered automatically. Cycles are disallowed and no
    /// per-node cycle limit is set until configured otherwise.
    pub fn new<N>(start_node: N) -> Self
    where
        N: Node<Context, A, Output = Output> + 'static,
    {
        let id = start_node.id();
        let mut nodes = HashMap::new();
        nodes.insert(
            id.clone(),
            Arc::new(start_node) as Arc<dyn Node<Context, A, Output = Output>>,
        );
        Self {
            start_node: id,
            nodes,
            edges: HashMap::new(),
            default_routes: HashMap::new(),
            allow_cycles: false,
            cycle_limit: 0,
        }
    }

    /// Registers `node` under its own id.
    ///
    /// Any node previously registered under the same id is replaced.
    pub fn add_node<N>(&mut self, node: N) -> &mut Self
    where
        N: Node<Context, A, Output = Output> + 'static,
    {
        let id = node.id();
        self.nodes.insert(
            id,
            Arc::new(node) as Arc<dyn Node<Context, A, Output = Output>>,
        );
        self
    }

    /// Routes `from` to `to` when `from` returns `NodeOutcome::RouteToAction(action)`.
    pub fn connect(&mut self, from: &NodeId, action: A, to: &NodeId) -> &mut Self {
        self.edges.insert((from.clone(), action), to.clone());
        self
    }

    /// Sets the node that follows `from` by default.
    ///
    /// The route is taken after a `Success` or `Skipped` outcome, and as the
    /// last fallback when a routed action has no matching edge.
    pub fn set_default_route(&mut self, from: &NodeId, to: &NodeId) -> &mut Self {
        self.default_routes.insert(from.clone(), to.clone());
        self
    }

    /// Looks up a registered node by id.
    pub fn get_node(&self, id: NodeId) -> Option<&dyn Node<Context, A, Output = Output>> {
        self.nodes.get(&id).map(|node| node.as_ref())
    }

    /// Enables or disables revisiting nodes during execution.
    pub fn allow_cycles(&mut self, allow: bool) -> &mut Self {
        self.allow_cycles = allow;
        self
    }

    /// Caps how many times any single node may run; `0` disables the cap.
    pub fn set_cycle_limit(&mut self, limit: usize) -> &mut Self {
        self.cycle_limit = limit;
        self
    }

    /// Runs the workflow from the start node until a node ends the run.
    ///
    /// A `Success` outcome follows the node's default route. If the node has no
    /// default route, the workflow finishes and returns that node's output. A
    /// `Skipped` outcome must have a default route. A `RouteToAction` outcome is
    /// resolved by `route_action`.
    ///
    /// # Errors
    /// - `WorkflowError::NodeExecution(FloxideError::WorkflowCycleDetected)` when a
    ///   node is revisited with cycles disallowed, or visited more than the
    ///   configured cycle limit.
    /// - `WorkflowError::NodeNotFound` when routing reaches an unregistered id.
    /// - `WorkflowError::NodeExecution` for errors returned by a node.
    /// - `WorkflowError::ActionNotHandled` when an outcome cannot be routed.
    pub async fn execute(&self, ctx: &mut Context) -> Result<Output, WorkflowError> {
        let mut current_node_id = self.start_node.clone();
        let mut visit_counts: HashMap<NodeId, usize> = HashMap::new();

        info!(start_node = %current_node_id, "Starting workflow execution");
        debug!("Node connections:");
        for ((from, action), to) in &self.edges {
            debug!(from = %from, action = ?action, to = %to, "Connection");
        }
        debug!("Default routes:");
        for (from, to) in &self.default_routes {
            debug!(from = %from, to = %to, "Default route");
        }

        loop {
            // Count the visit before running the node, so a rejected revisit
            // never executes the node again.
            let visit_count = {
                let count = visit_counts.entry(current_node_id.clone()).or_insert(0);
                *count += 1;
                *count
            };
            if !self.allow_cycles && visit_count > 1 {
                error!(
                    node_id = %current_node_id,
                    "Cycle detected in workflow execution"
                );
                return Err(WorkflowError::NodeExecution(
                    FloxideError::WorkflowCycleDetected,
                ));
            }
            if self.cycle_limit > 0 && visit_count > self.cycle_limit {
                error!(
                    node_id = %current_node_id,
                    visit_count = %visit_count,
                    limit = %self.cycle_limit,
                    "Cycle limit exceeded in workflow execution"
                );
                return Err(WorkflowError::NodeExecution(
                    FloxideError::WorkflowCycleDetected,
                ));
            }

            let node = self.nodes.get(&current_node_id).ok_or_else(|| {
                error!(node_id = %current_node_id, "Node not found in workflow");
                WorkflowError::NodeNotFound(current_node_id.clone())
            })?;

            debug!(node_id = %current_node_id, visit_count = %visit_count, "Executing node");
            let outcome = node
                .process(ctx)
                .await
                .map_err(WorkflowError::NodeExecution)?;

            current_node_id = match outcome {
                NodeOutcome::Success(output) => {
                    info!(node_id = %current_node_id, "Node completed successfully with Success outcome");
                    match self.default_routes.get(&current_node_id) {
                        Some(next) => {
                            debug!(
                                node_id = %current_node_id,
                                next_node = %next,
                                "Following default route"
                            );
                            next.clone()
                        }
                        None => {
                            debug!(node_id = %current_node_id, "Workflow execution completed");
                            return Ok(output);
                        }
                    }
                }
                NodeOutcome::Skipped => {
                    warn!(node_id = %current_node_id, "Node was skipped");
                    match self.default_routes.get(&current_node_id) {
                        Some(next) => {
                            debug!(
                                node_id = %current_node_id,
                                next_node = %next,
                                "Following default route after skip"
                            );
                            next.clone()
                        }
                        None => {
                            warn!(node_id = %current_node_id, "Node was skipped but no default route exists");
                            return Err(WorkflowError::ActionNotHandled(
                                "Skipped node without default route".into(),
                            ));
                        }
                    }
                }
                NodeOutcome::RouteToAction(action) => {
                    info!(node_id = %current_node_id, action = %action.name(), action_debug = ?action, "Node completed with RouteToAction outcome");
                    self.route_action(&current_node_id, &action)?
                }
            };
        }
    }

    /// Resolves the node that runs after `node_id` returned `RouteToAction(action)`.
    ///
    /// Candidates are tried in this order:
    /// 1. the edge registered for `(node_id, action)`;
    /// 2. when `action` is not `A::default()`, the edge for `(node_id, A::default())`;
    /// 3. the default route of `node_id`.
    ///
    /// # Errors
    /// Returns `WorkflowError::ActionNotHandled` when no candidate exists. The
    /// message is the action's `Debug` form, or `"Default action not handled"`
    /// when `action` is itself the default action.
    fn route_action(&self, node_id: &NodeId, action: &A) -> Result<NodeId, WorkflowError> {
        debug!(node_id = %node_id, action = ?action, "Node routed to action");

        if let Some(next) = self.edges.get(&(node_id.clone(), action.clone())) {
            debug!(
                node_id = %node_id,
                action = ?action,
                next_node = %next,
                "Following edge for action"
            );
            return Ok(next.clone());
        }

        let is_default_action = *action == A::default();
        if !is_default_action {
            if let Some(next) = self.edges.get(&(node_id.clone(), A::default())) {
                debug!(
                    node_id = %node_id,
                    next_node = %next,
                    "No edge for action, following default action"
                );
                return Ok(next.clone());
            }
        }

        if let Some(next) = self.default_routes.get(node_id) {
            if is_default_action {
                debug!(
                    node_id = %node_id,
                    next_node = %next,
                    "No edge for default action, following default route"
                );
            } else {
                debug!(
                    node_id = %node_id,
                    next_node = %next,
                    "No edge for action or default action, following default route"
                );
            }
            return Ok(next.clone());
        }

        if is_default_action {
            error!(
                node_id = %node_id,
                action = ?action,
                "No edge found for default action and no default route"
            );
            self.log_routing_table();
            Err(WorkflowError::ActionNotHandled(
                "Default action not handled".into(),
            ))
        } else {
            error!(
                node_id = %node_id,
                action = ?action,
                "No edge found for action and no default route"
            );
            self.log_routing_table();
            Err(WorkflowError::ActionNotHandled(format!("{:?}", action)))
        }
    }

    /// Dumps every edge and default route at `error` level to help diagnose a routing failure.
    fn log_routing_table(&self) {
        error!(
            "Available edges: {:?}",
            self.edges
                .iter()
                .map(|((from, act), to)| format!("{} -[{:?}]-> {}", from, act, to))
                .collect::<Vec<_>>()
        );
        error!(
            "Default routes: {:?}",
            self.default_routes
                .iter()
                .map(|(from, to)| format!("{} -> {}", from, to))
                .collect::<Vec<_>>()
        );
    }
}
impl<Context, A, Output> Clone for Workflow<Context, A, Output>
where
    Context: Send + Sync + 'static,
    A: ActionType + Clone + Send + Sync + 'static,
    Output: Send + Sync + 'static,
{
    /// Copies the workflow's routing configuration.
    ///
    /// This impl is written by hand rather than derived because
    /// `#[derive(Clone)]` would also require `Context: Clone` and
    /// `Output: Clone`. Node instances are stored behind `Arc`, so the clone
    /// shares them with the original instead of duplicating them.
    fn clone(&self) -> Self {
        let Self {
            start_node,
            nodes,
            edges,
            default_routes,
            allow_cycles,
            cycle_limit,
        } = self;

        Self {
            start_node: start_node.clone(),
            nodes: nodes.clone(),
            edges: edges.clone(),
            default_routes: default_routes.clone(),
            allow_cycles: *allow_cycles,
            cycle_limit: *cycle_limit,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::closure;

    #[derive(Debug, Clone)]
    struct TestContext {
        value: i32,
        visited: Vec<String>,
    }

    /// Action type with an explicit `Default` variant.
    ///
    /// The explicit variant lets tests exercise the default-action fallback
    /// used by `Workflow::execute`.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum TestAction {
        Default,
        Route1,
        Route2,
    }

    impl Default for TestAction {
        fn default() -> Self {
            Self::Default
        }
    }

    impl ActionType for TestAction {
        fn name(&self) -> &str {
            match self {
                Self::Default => "default",
                Self::Route1 => "route1",
                Self::Route2 => "route2",
            }
        }
    }

    /// Builds the graph `start -> loop -(Next)-> loop -(Error)-> end`.
    ///
    /// `start` adds 1, `loop` doubles `value` and keeps routing `Next` while
    /// `value <= 100`, and `end` subtracts 10.
    fn cyclic_workflow(allow_cycles: bool, cycle_limit: usize) -> Workflow<TestContext> {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value += 1;
            ctx.visited.push("start".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });
        let loop_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value *= 2;
            ctx.visited.push("loop".to_string());
            if ctx.value <= 100 {
                Ok((
                    ctx,
                    NodeOutcome::<(), DefaultAction>::RouteToAction(DefaultAction::Next),
                ))
            } else {
                Ok((
                    ctx,
                    NodeOutcome::<(), DefaultAction>::RouteToAction(DefaultAction::Error),
                ))
            }
        });
        let end_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value -= 10;
            ctx.visited.push("end".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });
        let mut workflow = Workflow::new(start_node);
        let start_id = workflow.start_node.clone();
        let loop_id = loop_node.id();
        let end_id = end_node.id();
        workflow
            .add_node(loop_node)
            .add_node(end_node)
            .set_default_route(&start_id, &loop_id)
            .connect(&loop_id, DefaultAction::Next, &loop_id)
            .connect(&loop_id, DefaultAction::Error, &end_id)
            .allow_cycles(allow_cycles)
            .set_cycle_limit(cycle_limit);
        workflow
    }

    #[tokio::test]
    async fn test_simple_linear_workflow() {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value += 1;
            ctx.visited.push("start".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });
        let middle_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value *= 2;
            ctx.visited.push("middle".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });
        let end_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value -= 3;
            ctx.visited.push("end".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });
        let mut workflow = Workflow::new(start_node);
        let start_id = workflow.start_node.clone();
        let middle_id = middle_node.id();
        let end_id = end_node.id();
        workflow
            .add_node(middle_node)
            .add_node(end_node)
            .set_default_route(&start_id, &middle_id)
            .set_default_route(&middle_id, &end_id);
        let mut ctx = TestContext {
            value: 10,
            visited: vec![],
        };
        let result = workflow.execute(&mut ctx).await;
        assert!(result.is_ok());
        // (10 + 1) * 2 - 3
        assert_eq!(ctx.value, 19);
        assert_eq!(ctx.visited, vec!["start", "middle", "end"]);
    }

    #[tokio::test]
    async fn test_workflow_with_routing() {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("start".to_string());
            if ctx.value > 5 {
                Ok((
                    ctx,
                    NodeOutcome::<(), TestAction>::RouteToAction(TestAction::Route1),
                ))
            } else {
                Ok((
                    ctx,
                    NodeOutcome::<(), TestAction>::RouteToAction(TestAction::Route2),
                ))
            }
        });
        let path1_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value += 100;
            ctx.visited.push("path1".to_string());
            Ok((ctx, NodeOutcome::<(), TestAction>::Success(())))
        });
        let path2_node = closure::node(|mut ctx: TestContext| async move {
            ctx.value *= 10;
            ctx.visited.push("path2".to_string());
            Ok((ctx, NodeOutcome::<(), TestAction>::Success(())))
        });
        let mut workflow = Workflow::<_, TestAction, _>::new(start_node);
        let start_id = workflow.start_node.clone();
        let path1_id = path1_node.id();
        let path2_id = path2_node.id();
        workflow
            .add_node(path1_node)
            .add_node(path2_node)
            .connect(&start_id, TestAction::Route1, &path1_id)
            .connect(&start_id, TestAction::Route2, &path2_id);

        let mut ctx1 = TestContext {
            value: 10,
            visited: vec![],
        };
        let result1 = workflow.execute(&mut ctx1).await;
        assert!(result1.is_ok());
        assert_eq!(ctx1.value, 110);
        assert_eq!(ctx1.visited, vec!["start", "path1"]);

        let mut ctx2 = TestContext {
            value: 3,
            visited: vec![],
        };
        let result2 = workflow.execute(&mut ctx2).await;
        assert!(result2.is_ok());
        assert_eq!(ctx2.value, 30);
        assert_eq!(ctx2.visited, vec!["start", "path2"]);
    }

    #[tokio::test]
    async fn test_unmatched_action_follows_default_action_edge() {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("start".to_string());
            Ok((
                ctx,
                NodeOutcome::<(), TestAction>::RouteToAction(TestAction::Route2),
            ))
        });
        let fallback_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("fallback".to_string());
            Ok((ctx, NodeOutcome::<(), TestAction>::Success(())))
        });
        let mut workflow = Workflow::<_, TestAction, _>::new(start_node);
        let start_id = workflow.start_node.clone();
        let fallback_id = fallback_node.id();
        // No edge for Route2: execution must fall back to the default-action edge.
        workflow
            .add_node(fallback_node)
            .connect(&start_id, TestAction::Default, &fallback_id);
        let mut ctx = TestContext {
            value: 0,
            visited: vec![],
        };
        let result = workflow.execute(&mut ctx).await;
        assert!(result.is_ok());
        assert_eq!(ctx.visited, vec!["start", "fallback"]);
    }

    #[tokio::test]
    async fn test_unmatched_action_follows_default_route() {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("start".to_string());
            Ok((
                ctx,
                NodeOutcome::<(), TestAction>::RouteToAction(TestAction::Route2),
            ))
        });
        let fallback_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("fallback".to_string());
            Ok((ctx, NodeOutcome::<(), TestAction>::Success(())))
        });
        let mut workflow = Workflow::<_, TestAction, _>::new(start_node);
        let start_id = workflow.start_node.clone();
        let fallback_id = fallback_node.id();
        // No action edges at all: execution must fall back to the default route.
        workflow
            .add_node(fallback_node)
            .set_default_route(&start_id, &fallback_id);
        let mut ctx = TestContext {
            value: 0,
            visited: vec![],
        };
        let result = workflow.execute(&mut ctx).await;
        assert!(result.is_ok());
        assert_eq!(ctx.visited, vec!["start", "fallback"]);
    }

    #[tokio::test]
    async fn test_unhandled_action_returns_error() {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("start".to_string());
            Ok((
                ctx,
                NodeOutcome::<(), TestAction>::RouteToAction(TestAction::Route1),
            ))
        });
        let workflow = Workflow::<_, TestAction, _>::new(start_node);
        let mut ctx = TestContext {
            value: 0,
            visited: vec![],
        };
        let result = workflow.execute(&mut ctx).await;
        match result {
            Err(WorkflowError::ActionNotHandled(message)) => assert_eq!(message, "Route1"),
            other => panic!("Expected ActionNotHandled error, got {:?}", other),
        }
        assert_eq!(ctx.visited, vec!["start"]);
    }

    #[tokio::test]
    async fn test_workflow_with_skipped_node() {
        let start_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("start".to_string());
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });
        let skip_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("skip_check".to_string());
            if ctx.value > 5 {
                Ok((ctx, NodeOutcome::<(), DefaultAction>::Skipped))
            } else {
                ctx.value *= 2;
                Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
            }
        });
        let end_node = closure::node(|mut ctx: TestContext| async move {
            ctx.visited.push("end".to_string());
            ctx.value += 5;
            Ok((ctx, NodeOutcome::<(), DefaultAction>::Success(())))
        });
        let mut workflow = Workflow::new(start_node);
        let start_id = workflow.start_node.clone();
        let skip_id = skip_node.id();
        let end_id = end_node.id();
        workflow
            .add_node(skip_node)
            .add_node(end_node)
            .set_default_route(&start_id, &skip_id)
            .set_default_route(&skip_id, &end_id);

        let mut ctx1 = TestContext {
            value: 10,
            visited: vec![],
        };
        let result1 = workflow.execute(&mut ctx1).await;
        assert!(result1.is_ok());
        assert_eq!(ctx1.value, 15);
        assert_eq!(ctx1.visited, vec!["start", "skip_check", "end"]);

        let mut ctx2 = TestContext {
            value: 3,
            visited: vec![],
        };
        let result2 = workflow.execute(&mut ctx2).await;
        assert!(result2.is_ok());
        assert_eq!(ctx2.value, 11);
        assert_eq!(ctx2.visited, vec!["start", "skip_check", "end"]);
    }

    #[tokio::test]
    async fn test_cyclic_workflow() {
        let workflow = cyclic_workflow(true, 10);
        let mut ctx = TestContext {
            value: 3,
            visited: vec![],
        };
        let result = workflow.execute(&mut ctx).await;
        assert!(result.is_ok());
        // 3 + 1 = 4, doubled five times to 128, minus 10.
        assert_eq!(ctx.value, 118);
        assert_eq!(
            ctx.visited,
            vec!["start", "loop", "loop", "loop", "loop", "loop", "end"]
        );
    }

    #[tokio::test]
    async fn test_cycle_detected_when_cycles_disallowed() {
        let workflow = cyclic_workflow(false, 0);
        let mut ctx = TestContext {
            value: 3,
            visited: vec![],
        };
        let result = workflow.execute(&mut ctx).await;
        assert!(
            matches!(
                result,
                Err(WorkflowError::NodeExecution(FloxideError::WorkflowCycleDetected))
            ),
            "Expected WorkflowCycleDetected error, got {:?}",
            result
        );
        // The second visit to `loop` is rejected before the node runs.
        assert_eq!(ctx.visited, vec!["start", "loop"]);
    }

    #[tokio::test]
    async fn test_cycle_limit_exceeded() {
        // `loop` needs five visits to finish; a limit of 3 aborts on the fourth.
        let workflow = cyclic_workflow(true, 3);
        let mut ctx = TestContext {
            value: 3,
            visited: vec![],
        };
        let result = workflow.execute(&mut ctx).await;
        assert!(
            matches!(
                result,
                Err(WorkflowError::NodeExecution(FloxideError::WorkflowCycleDetected))
            ),
            "Expected WorkflowCycleDetected error, got {:?}",
            result
        );
        assert_eq!(ctx.value, 32);
        assert_eq!(ctx.visited, vec!["start", "loop", "loop", "loop"]);
    }
}