use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::pin::Pin;
use async_trait::async_trait;
use tracing::warn;
/// Marker trait for values that can flow through a workflow as its state.
///
/// Every type that is `Clone`, `Send`, `Sync` and `'static` implements it
/// automatically through the blanket impl that follows.
pub trait State: Clone + Send + Sync + 'static {}

impl<T> State for T where T: Clone + Send + Sync + 'static {}
/// A single named step in a [`Workflow`].
///
/// A node takes the current state by value and returns the state handed to the
/// next step. Returning `Err` aborts the whole run.
#[async_trait]
pub trait Node<S>: Send + Sync
where
    S: State,
{
    /// Name under which the node is registered; edges refer to nodes by this name.
    fn name(&self) -> &str;

    /// Runs this step on `state` and returns the updated state.
    async fn execute(&self, state: S) -> anyhow::Result<S>;
}
/// Boxed closure stored by [`FnNode`]: it takes the state and returns a pinned,
/// `Send` future that resolves to the updated state.
type AsyncNodeFn<S> =
    Box<dyn Fn(S) -> Pin<Box<dyn Future<Output = anyhow::Result<S>> + Send>> + Send + Sync>;

/// A [`Node`] backed by a closure instead of a dedicated type.
pub struct FnNode<S>
where
    S: State,
{
    /// Name reported through [`Node::name`].
    name: String,
    /// Closure invoked each time the node executes.
    func: AsyncNodeFn<S>,
}
impl<S: State> FnNode<S> {
pub fn new(name: &str, func: impl Fn(S) -> Pin<Box<dyn Future<Output = anyhow::Result<S>> + Send>> + Send + Sync + 'static) -> Self {
Self {
name: name.to_string(),
func: Box::new(func),
}
}
}
#[async_trait]
impl<S> Node<S> for FnNode<S>
where
    S: State,
{
    /// Returns the name passed to [`FnNode::new`].
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Calls the wrapped closure and awaits the future it produces.
    async fn execute(&self, state: S) -> anyhow::Result<S> {
        let pending = (self.func)(state);
        pending.await
    }
}
/// Where control goes after a node finishes executing.
pub enum EdgeTarget<S>
where
    S: State,
{
    /// Continue with the node registered under this name.
    Node(String),
    /// Ask the router, given the current state, for the next node's name.
    /// Returning `"__end__"` stops the workflow.
    Conditional(Box<dyn Fn(&S) -> String + Send + Sync>),
    /// Stop the workflow and return the current state.
    End,
}

/// Name a conditional router returns to end the workflow.
const END_SENTINEL: &str = "__end__";
/// An executable graph of named nodes, produced by [`WorkflowBuilder::build`].
pub struct Workflow<S>
where
    S: State,
{
    /// Registered nodes, keyed by node name.
    nodes: HashMap<String, Box<dyn Node<S>>>,
    /// The single outgoing edge of each node, keyed by the source node's name.
    edges: HashMap<String, EdgeTarget<S>>,
    /// Name of the first node to execute.
    entry: String,
    /// Upper bound on node executions per run.
    max_steps: usize,
}
impl<S: State> Workflow<S> {
    /// Executes the workflow from its entry node and returns the final state.
    ///
    /// Each step runs the current node and then follows that node's outgoing edge:
    /// - no edge registered, or [`EdgeTarget::End`]: stop and return the state;
    /// - [`EdgeTarget::Node`]: continue with the named node;
    /// - [`EdgeTarget::Conditional`]: ask the router for the next node name.
    ///   The name `"__end__"` stops the workflow.
    ///
    /// At most `max_steps` nodes are executed. When that budget runs out, a warning
    /// is logged and the current state is returned as `Ok`, so cyclic graphs cannot
    /// run forever.
    ///
    /// # Errors
    ///
    /// Returns an error if a node name reached during execution is not registered.
    /// If a node's `execute` fails, that error is propagated unchanged.
    pub async fn run(&self, initial_state: S) -> anyhow::Result<S> {
        let mut state = initial_state;
        let mut current = self.entry.clone();
        // Names of nodes executed so far; used only to log possible loops.
        let mut visited: HashSet<&str> = HashSet::new();
        for step in 0..self.max_steps {
            // `get_key_value` returns the map's own key, so `visited` can borrow it
            // instead of allocating a copy of the name on every step.
            let (name, node) = self
                .nodes
                .get_key_value(&current)
                .ok_or_else(|| anyhow::anyhow!("node '{current}' not found in workflow"))?;
            if !visited.insert(name.as_str()) {
                warn!(node = %current, step, "workflow revisited node (possible loop)");
            }
            state = node.execute(state).await?;
            // A node with no outgoing edge is implicitly terminal.
            let Some(edge) = self.edges.get(&current) else {
                return Ok(state);
            };
            match edge {
                EdgeTarget::End => return Ok(state),
                EdgeTarget::Node(next) => current.clone_from(next),
                EdgeTarget::Conditional(router) => {
                    let target = router(&state);
                    if target == END_SENTINEL {
                        return Ok(state);
                    }
                    current = target;
                }
            }
        }
        warn!(
            max_steps = self.max_steps,
            next = %current,
            "workflow reached max steps limit"
        );
        Ok(state)
    }
}
/// Fluent builder that assembles nodes and edges into a [`Workflow`].
pub struct WorkflowBuilder<S>
where
    S: State,
{
    /// Nodes registered so far, keyed by node name.
    nodes: HashMap<String, Box<dyn Node<S>>>,
    /// Outgoing edge per source node name; a later registration replaces an earlier one.
    edges: HashMap<String, EdgeTarget<S>>,
    /// Entry node name, if one has been set.
    entry: Option<String>,
    /// Step budget handed to the built workflow.
    max_steps: usize,
}
impl<S: State> Default for WorkflowBuilder<S> {
fn default() -> Self {
Self::new()
}
}
impl<S: State> WorkflowBuilder<S> {
pub fn new() -> Self {
Self {
nodes: HashMap::new(),
edges: HashMap::new(),
entry: None,
max_steps: 100,
}
}
pub fn add_node(mut self, node: impl Node<S> + 'static) -> Self {
let name = node.name().to_string();
self.nodes.insert(name, Box::new(node));
self
}
pub fn add_edge(mut self, from: &str, to: &str) -> Self {
self.edges.insert(from.to_string(), EdgeTarget::Node(to.to_string()));
self
}
pub fn add_conditional_edge(mut self, from: &str, router: impl Fn(&S) -> String + Send + Sync + 'static) -> Self {
self.edges.insert(from.to_string(), EdgeTarget::Conditional(Box::new(router)));
self
}
pub fn set_entry(mut self, name: &str) -> Self {
self.entry = Some(name.to_string());
self
}
pub fn set_end(mut self, from: &str) -> Self {
self.edges.insert(from.to_string(), EdgeTarget::End);
self
}
pub fn max_steps(mut self, max: usize) -> Self {
self.max_steps = max;
self
}
pub fn build(self) -> anyhow::Result<Workflow<S>> {
let entry = self.entry.ok_or_else(|| anyhow::anyhow!("workflow has no entry node — call set_entry()"))?;
if !self.nodes.contains_key(&entry) {
anyhow::bail!("entry node '{entry}' not found in registered nodes");
}
Ok(Workflow {
nodes: self.nodes,
edges: self.edges,
entry,
max_steps: self.max_steps,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a node that appends its own name to a `Vec<String>` state, so the
    /// final state records the order in which nodes executed.
    fn tracking_node(name: &str) -> FnNode<Vec<String>> {
        let n = name.to_string();
        FnNode::new(name, move |mut state: Vec<String>| {
            let n = n.clone();
            Box::pin(async move {
                state.push(n);
                Ok(state)
            })
        })
    }

    #[tokio::test]
    async fn test_linear_workflow() {
        let wf = WorkflowBuilder::new()
            .add_node(tracking_node("a"))
            .add_node(tracking_node("b"))
            .add_node(tracking_node("c"))
            .add_edge("a", "b")
            .add_edge("b", "c")
            .set_entry("a")
            .set_end("c")
            .build()
            .unwrap();
        let result = wf.run(vec![]).await.unwrap();
        assert_eq!(result, vec!["a", "b", "c"]);
    }

    #[tokio::test]
    async fn test_conditional_edge() {
        let wf = WorkflowBuilder::new()
            .add_node(tracking_node("start"))
            .add_node(tracking_node("left"))
            .add_node(tracking_node("right"))
            .add_conditional_edge("start", |state: &Vec<String>| {
                if state.len() == 1 {
                    "right".to_string()
                } else {
                    "left".to_string()
                }
            })
            .set_entry("start")
            .set_end("left")
            .set_end("right")
            .build()
            .unwrap();
        let result = wf.run(vec![]).await.unwrap();
        assert_eq!(result, vec!["start", "right"]);
    }

    #[tokio::test]
    async fn test_conditional_edge_end_sentinel() {
        let wf = WorkflowBuilder::new()
            .add_node(tracking_node("start"))
            .add_node(tracking_node("never"))
            .add_conditional_edge("start", |_state: &Vec<String>| END_SENTINEL.to_string())
            .set_entry("start")
            .build()
            .unwrap();
        let result = wf.run(vec![]).await.unwrap();
        assert_eq!(result, vec!["start"]);
    }

    #[tokio::test]
    async fn test_terminates_at_end() {
        let wf = WorkflowBuilder::new()
            .add_node(tracking_node("only"))
            .set_entry("only")
            .set_end("only")
            .build()
            .unwrap();
        let result = wf.run(vec![]).await.unwrap();
        assert_eq!(result, vec!["only"]);
    }

    #[tokio::test]
    async fn test_node_without_outgoing_edge_terminates() {
        let wf = WorkflowBuilder::new()
            .add_node(tracking_node("a"))
            .add_node(tracking_node("b"))
            .add_edge("a", "b")
            .set_entry("a")
            .build()
            .unwrap();
        let result = wf.run(vec![]).await.unwrap();
        assert_eq!(result, vec!["a", "b"]);
    }

    // Builder validation is synchronous, so no async runtime is needed.
    #[test]
    fn test_missing_entry_node_error() {
        let res = WorkflowBuilder::<Vec<String>>::new().build();
        let err = res.err().expect("expected build error");
        assert!(err.to_string().contains("no entry node"));
    }

    #[test]
    fn test_unregistered_entry_node_error() {
        let res = WorkflowBuilder::new()
            .add_node(tracking_node("a"))
            .set_entry("missing")
            .build();
        let err = res.err().expect("expected build error");
        assert!(err.to_string().contains("entry node 'missing' not found"));
    }

    #[tokio::test]
    async fn test_dangling_edge_fails_at_runtime() {
        let wf = WorkflowBuilder::new()
            .add_node(tracking_node("a"))
            .add_edge("a", "ghost")
            .set_entry("a")
            .build()
            .unwrap();
        let err = wf.run(vec![]).await.unwrap_err();
        assert!(err.to_string().contains("node 'ghost' not found"));
    }

    #[tokio::test]
    async fn test_max_steps_prevents_infinite_loop() {
        let wf = WorkflowBuilder::new()
            .add_node(tracking_node("a"))
            .add_node(tracking_node("b"))
            .add_edge("a", "b")
            .add_edge("b", "a")
            .set_entry("a")
            .max_steps(6)
            .build()
            .unwrap();
        let result = wf.run(vec![]).await.unwrap();
        assert_eq!(result, vec!["a", "b", "a", "b", "a", "b"]);
    }

    #[tokio::test]
    async fn test_state_mutation() {
        let add_ten = FnNode::new("add_ten", |state: i64| Box::pin(async move { Ok(state + 10) }));
        let double = FnNode::new("double", |state: i64| Box::pin(async move { Ok(state * 2) }));
        let wf = WorkflowBuilder::new()
            .add_node(add_ten)
            .add_node(double)
            .add_edge("add_ten", "double")
            .set_entry("add_ten")
            .set_end("double")
            .build()
            .unwrap();
        let result = wf.run(5).await.unwrap();
        assert_eq!(result, 30);
    }

    #[tokio::test]
    async fn test_node_failure_propagates() {
        let fail_node = FnNode::new("fail", |_state: Vec<String>| {
            Box::pin(async { anyhow::bail!("something went wrong") })
        });
        let wf = WorkflowBuilder::new()
            .add_node(fail_node)
            .set_entry("fail")
            .set_end("fail")
            .build()
            .unwrap();
        let err = wf.run(vec![]).await.unwrap_err();
        assert!(err.to_string().contains("something went wrong"));
    }

    #[tokio::test]
    async fn test_fn_node_closure() {
        let prefix = "hello_".to_string();
        let node = FnNode::new("greet", move |state: String| {
            let prefix = prefix.clone();
            Box::pin(async move { Ok(format!("{prefix}{state}")) })
        });
        let wf = WorkflowBuilder::new()
            .add_node(node)
            .set_entry("greet")
            .set_end("greet")
            .build()
            .unwrap();
        let result = wf.run("world".to_string()).await.unwrap();
        assert_eq!(result, "hello_world");
    }
}