use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tinyagents::graph::ClosureStateReducer;
use tinyagents::{Command, CompiledGraph, GraphBuilder, NodeContext, NodeResult};
/// Graph state accumulated over one run of the fan-out graph.
///
/// `values`, `forks` and `workers` grow in lockstep: each merged
/// `AgentUpdate::Work` appends exactly one entry to each of them.
#[derive(Debug, Clone, Default, Eq, PartialEq)]
struct AgentState {
    /// Value contributed by each merged worker, in merge order.
    values: Vec<i32>,
    /// Fork branch index reported by each merged worker.
    forks: Vec<usize>,
    /// Node name of each merged worker.
    workers: Vec<String>,
    /// Sum written by the aggregating node; `None` until it has run.
    total: Option<i32>,
}
/// Update emitted by a graph node and folded into [`AgentState`] by the reducer.
#[derive(Debug, Clone)]
enum AgentUpdate {
    /// One worker's contribution: its value, its fork index and its node name.
    Work {
        value: i32,
        fork: usize,
        worker: String,
    },
    /// Final sum, stored into `AgentState::total`.
    Total(i32),
}
/// Concurrency probe handed to worker nodes.
///
/// Both counters sit behind `Arc`, so every clone shares (and updates) the same
/// pair of counters.
#[derive(Clone)]
struct Inflight {
    /// How many probed calls are in progress right now.
    current: Arc<AtomicUsize>,
    /// Highest value `current` has reached so far.
    max: Arc<AtomicUsize>,
}
impl Inflight {
    /// Creates a probe with both counters at zero.
    fn new() -> Self {
        Self {
            current: Arc::new(AtomicUsize::new(0)),
            max: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Highest number of simultaneously in-progress `run` calls seen so far.
    fn max_observed(&self) -> usize {
        self.max.load(Ordering::SeqCst)
    }

    /// Marks this call as in flight, records the new high-water mark, sleeps for
    /// `sleep`, and then returns `value` unchanged.
    ///
    /// The in-flight slot is released by a drop guard rather than by a statement
    /// after the `.await`. The count is therefore decremented even if the returned
    /// future is dropped before the sleep finishes (e.g. a cancelled branch).
    /// Otherwise `current` would stay permanently inflated and skew later
    /// measurements on the same probe.
    async fn run<T>(&self, sleep: Duration, value: T) -> T {
        /// Decrements the borrowed counter when dropped.
        struct Release<'a>(&'a AtomicUsize);

        impl Drop for Release<'_> {
            fn drop(&mut self) {
                self.0.fetch_sub(1, Ordering::SeqCst);
            }
        }

        let now = self.current.fetch_add(1, Ordering::SeqCst) + 1;
        // Created immediately after the increment so every exit path releases the slot.
        let _release = Release(&self.current);
        self.max.fetch_max(now, Ordering::SeqCst);
        tokio::time::sleep(sleep).await;
        value
    }
}
/// Branch index of the fork this node is running in.
///
/// Returns `usize::MAX` as a "no fork" marker when `ctx.fork` is absent.
fn fork_index(ctx: &NodeContext) -> usize {
    ctx.fork
        .as_ref()
        .map_or(usize::MAX, |fork| fork.branch_index)
}
/// Builds the fan-out test graph.
///
/// `dispatch` routes to every worker through a [`Command`] goto. Each worker
/// reports its value, fork index and name through [`Inflight::run`], and every
/// worker has an edge into `aggregate`, which writes the sum of
/// `AgentState::values` as the total.
///
/// `parallel` is forwarded to `GraphBuilder::with_parallel`. `inflight` is
/// cloned into every worker so the caller can inspect observed concurrency
/// after the run.
fn fanout_graph(parallel: bool, inflight: Inflight) -> CompiledGraph<AgentState, AgentUpdate> {
    // (node name, contributed value, simulated work in ms). The sleeps shrink
    // down the list, so under parallel execution the workers finish in reverse
    // declaration order.
    let workers = [
        ("worker-a", 1, 80u64),
        ("worker-b", 2, 60),
        ("worker-c", 4, 40),
        ("worker-d", 8, 20),
    ];
    // Derive the dispatch targets and the edges into `aggregate` from `workers`
    // instead of repeating the names. Adding or renaming a worker then cannot
    // leave routing, nodes and edges out of sync.
    let targets = workers.map(|(name, _, _)| name);

    let mut builder = GraphBuilder::<AgentState, AgentUpdate>::new()
        .with_parallel(parallel)
        // `Work` appends one entry to each per-worker vector; `Total` overwrites the total.
        .set_reducer(ClosureStateReducer::new(
            |mut s: AgentState, u: AgentUpdate| {
                match u {
                    AgentUpdate::Work {
                        value,
                        fork,
                        worker,
                    } => {
                        s.values.push(value);
                        s.forks.push(fork);
                        s.workers.push(worker);
                    }
                    AgentUpdate::Total(t) => s.total = Some(t),
                }
                Ok(s)
            },
        ))
        // `targets` is a `Copy` array of `&'static str`, so each invocation copies it.
        .add_node("dispatch", move |_s: AgentState, _c: NodeContext| async move {
            Ok(NodeResult::Command(Command::default().with_goto(targets)))
        });

    for (name, value, sleep_ms) in workers {
        let inflight = inflight.clone();
        builder = builder.add_node(name, move |_s: AgentState, c: NodeContext| {
            // Clone per invocation: the node closure may be called more than once.
            let inflight = inflight.clone();
            let fork = fork_index(&c);
            async move {
                let update = AgentUpdate::Work {
                    value,
                    fork,
                    worker: name.to_string(),
                };
                Ok(NodeResult::Update(
                    inflight.run(Duration::from_millis(sleep_ms), update).await,
                ))
            }
        });
    }

    builder = builder
        .add_node("aggregate", |s: AgentState, _c: NodeContext| async move {
            Ok(NodeResult::Update(AgentUpdate::Total(
                s.values.iter().sum(),
            )))
        })
        .set_entry("dispatch")
        .mark_command_routing("dispatch");
    for name in targets {
        builder = builder.add_edge(name, "aggregate");
    }
    builder = builder.set_finish("aggregate");
    builder.compile().expect("graph compiles")
}
#[tokio::test]
async fn parallel_fanout_merges_all_branches_and_downstream_sees_them() {
    // One parallel run, with a probe counting how many workers overlap.
    let probe = Inflight::new();
    let compiled = fanout_graph(true, probe.clone());
    let run = compiled.run(AgentState::default()).await.unwrap();

    assert_eq!(
        probe.max_observed(),
        4,
        "parallel mode should run every branch concurrently"
    );

    // Merged state follows branch order, and `aggregate` saw all four values.
    let state = &run.state;
    assert_eq!(state.values, vec![1, 2, 4, 8]);
    assert_eq!(
        state.workers,
        vec!["worker-a", "worker-b", "worker-c", "worker-d"],
    );
    assert_eq!(state.forks, vec![0, 1, 2, 3]);
    assert_eq!(state.total, Some(15));
    assert_eq!(run.steps, 3);

    let expected_nodes = [
        "dispatch",
        "worker-a",
        "worker-b",
        "worker-c",
        "worker-d",
        "aggregate",
    ];
    for node in expected_nodes {
        let seen = run.visited.iter().any(|n| n.as_str() == node);
        assert!(
            seen,
            "expected `{node}` in visited history {:?}",
            run.visited,
        );
    }
}
#[tokio::test]
async fn sequential_mode_runs_one_branch_at_a_time_without_fork_identity() {
    let probe = Inflight::new();
    let compiled = fanout_graph(false, probe.clone());
    let run = compiled.run(AgentState::default()).await.unwrap();

    assert_eq!(
        probe.max_observed(),
        1,
        "sequential mode must serialize branches"
    );
    // Without fork context every worker reports the `usize::MAX` marker.
    assert_eq!(run.state.forks, vec![usize::MAX; 4]);
}
#[tokio::test]
async fn parallel_and_sequential_reach_the_same_final_state() {
    // Run parallel first, then sequential; each run gets its own fresh probe.
    let (parallel, sequential) = (
        fanout_graph(true, Inflight::new())
            .run(AgentState::default())
            .await
            .unwrap(),
        fanout_graph(false, Inflight::new())
            .run(AgentState::default())
            .await
            .unwrap(),
    );

    // Everything except fork identity must agree between the two modes.
    assert_eq!(parallel.state.values, sequential.state.values);
    assert_eq!(parallel.state.workers, sequential.state.workers);
    assert_eq!(parallel.state.total, sequential.state.total);
    assert_eq!(parallel.steps, sequential.steps);
    // Fork indices are expected to differ between the modes.
    assert_ne!(parallel.state.forks, sequential.state.forks);
}
#[tokio::test]
async fn parallel_merge_order_is_reproducible_across_runs() {
    const ATTEMPTS: usize = 5;
    let expected_values: Vec<i32> = vec![1, 2, 4, 8];
    let expected_forks: Vec<usize> = vec![0, 1, 2, 3];

    // Workers finish in a different order than they are declared; the merge
    // order must nevertheless be identical on every attempt.
    for _ in 0..ATTEMPTS {
        let compiled = fanout_graph(true, Inflight::new());
        let run = compiled.run(AgentState::default()).await.unwrap();
        assert_eq!(run.state.values, expected_values);
        assert_eq!(run.state.forks, expected_forks);
        assert_eq!(run.state.total, Some(15));
    }
}