mod types;
pub use types::{
GraphAssertions, GraphEventRecorder, GraphRun, RetryCountingNode, StreamCollector,
};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use serde_json::Value;
use crate::graph::builder::{NodeContext, NodeFuture};
use crate::graph::command::{Command, Interrupt, NodeResult, Send as SendPacket};
use crate::graph::compiled::{CompiledGraph, GraphExecution, StateSnapshot};
use crate::graph::recursion::ChildRun;
use crate::graph::stream::{CollectingSink, GraphEvent, GraphEventSink};
use crate::harness::ids::{GraphId, NodeId, RunId};
use crate::harness::usage::UsageTotals;
use crate::{Result, TinyAgentsError};
/// Builds a node handler that ignores both its state and its context.
///
/// Every invocation resolves to `NodeResult::Command` wrapping an empty
/// `Command::new()`, which makes this a convenient placeholder node when a test
/// only cares about graph topology.
pub fn noop_node<State, Update>()
-> impl Fn(State, NodeContext) -> NodeFuture<Update> + Send + Sync + 'static
where
    State: Send + 'static,
    Update: Send + 'static,
{
    |_state: State, _ctx: NodeContext| -> NodeFuture<Update> {
        Box::pin(async move {
            let empty = Command::new();
            Ok(NodeResult::Command(empty))
        })
    }
}
/// Builds a node handler that replays `updates` one per invocation.
///
/// The first call yields the first update, the second call the second, and so
/// on. Once the script is exhausted, every further call repeats the final
/// update. All handler clones share one cursor.
///
/// # Errors
///
/// Every invocation resolves to `TinyAgentsError::Graph` when `updates` was
/// empty.
pub fn scripted_update_node<State, Update>(
    updates: impl IntoIterator<Item = Update>,
) -> impl Fn(State, NodeContext) -> NodeFuture<Update> + Send + Sync + 'static
where
    State: Send + 'static,
    Update: Clone + Send + Sync + 'static,
{
    let script: Arc<Vec<Update>> = Arc::new(updates.into_iter().collect());
    let cursor = Arc::new(AtomicUsize::new(0));
    move |_state, _ctx| -> NodeFuture<Update> {
        let script = Arc::clone(&script);
        let cursor = Arc::clone(&cursor);
        Box::pin(async move {
            let Some(last) = script.len().checked_sub(1) else {
                return Err(TinyAgentsError::Graph(
                    "scripted_update_node has no scripted updates".to_string(),
                ));
            };
            // Clamp the cursor so an over-long run keeps replaying the final update.
            let step = cursor.fetch_add(1, Ordering::Relaxed).min(last);
            Ok(NodeResult::Update(script[step].clone()))
        })
    }
}
pub fn scripted_route_node<State, Update, I, R, N>(
routes: I,
) -> impl Fn(State, NodeContext) -> NodeFuture<Update> + Send + Sync + 'static
where
State: Send + 'static,
Update: Send + 'static,
I: IntoIterator<Item = R>,
R: IntoIterator<Item = N>,
N: Into<NodeId>,
{
let routes: Arc<Vec<Vec<NodeId>>> = Arc::new(
routes
.into_iter()
.map(|r| r.into_iter().map(Into::into).collect())
.collect(),
);
let idx = Arc::new(AtomicUsize::new(0));
move |_state, _ctx| -> NodeFuture<Update> {
let routes = routes.clone();
let idx = idx.clone();
Box::pin(async move {
if routes.is_empty() {
return Err(TinyAgentsError::Graph(
"scripted_route_node has no scripted routes".to_string(),
));
}
let i = idx.fetch_add(1, Ordering::Relaxed).min(routes.len() - 1);
Ok(NodeResult::Command(Command::goto(routes[i].clone())))
})
}
}
pub fn fanout_node<State, Update>(
target: impl Into<NodeId>,
args: impl IntoIterator<Item = Value>,
) -> impl Fn(State, NodeContext) -> NodeFuture<Update> + Send + Sync + 'static
where
State: Send + 'static,
Update: Send + 'static,
{
let target = target.into();
let args: Arc<Vec<Value>> = Arc::new(args.into_iter().collect());
move |_state, _ctx| -> NodeFuture<Update> {
let target = target.clone();
let args = args.clone();
Box::pin(async move {
let sends: Vec<SendPacket> = args
.iter()
.map(|a| SendPacket::new(target.clone(), a.clone()))
.collect();
Ok(NodeResult::Command(Command::send(sends)))
})
}
}
/// Builds a node handler that always fails.
///
/// # Errors
///
/// Every invocation resolves to `TinyAgentsError::Graph` carrying `message`.
pub fn failing_node<State, Update>(
    message: impl Into<String>,
) -> impl Fn(State, NodeContext) -> NodeFuture<Update> + Send + Sync + 'static
where
    State: Send + 'static,
    Update: Send + 'static,
{
    let template: String = message.into();
    move |_state, _ctx| -> NodeFuture<Update> {
        let text = template.clone();
        Box::pin(async move { Err(TinyAgentsError::Graph(text)) })
    }
}
/// Builds a node handler that interrupts on first entry and updates on resume.
///
/// When `ctx.resume` is `None`, the handler emits an `Interrupt` tagged with
/// the current node id and carrying `payload`. When `ctx.resume` is `Some`, it
/// returns `on_resume` as the node's update. The resume value itself is
/// ignored.
pub fn interrupting_node<State, Update>(
    payload: Value,
    on_resume: Update,
) -> impl Fn(State, NodeContext) -> NodeFuture<Update> + Send + Sync + 'static
where
    State: Send + 'static,
    Update: Clone + Send + Sync + 'static,
{
    move |_state, ctx: NodeContext| -> NodeFuture<Update> {
        let payload = payload.clone();
        let resumed = on_resume.clone();
        Box::pin(async move {
            if ctx.resume.is_some() {
                return Ok(NodeResult::Update(resumed));
            }
            let interrupt = Interrupt::new(ctx.node_id.clone(), payload);
            Ok(NodeResult::Interrupt(interrupt))
        })
    }
}
/// Wraps a compiled child graph as a boxed node handler for tests.
///
/// This is a thin delegate to `crate::graph::subgraph::shared_subgraph_node`;
/// see that function for how the child graph is executed.
pub fn subgraph_test_node<State>(
    child: CompiledGraph<State, State>,
) -> Box<dyn Fn(State, NodeContext) -> NodeFuture<State> + Send + Sync>
where
    State: Clone + Send + Sync + 'static,
{
    use crate::graph::subgraph::shared_subgraph_node;

    shared_subgraph_node(child)
}
/// Builds a node handler that pretends to run a sub-agent named `agent`.
///
/// When the node context carries a child-run sink (`ctx.child_runs`), the
/// handler records one `ChildRun` with these fields:
/// - graph id `agent:<agent>`;
/// - a fresh `subagent-fake-<seq>` run id;
/// - the supplied `usage`;
/// - a root run id taken from `ctx.root_run_id`, or `ctx.run_id` when that is
///   absent.
///
/// It then always returns `update`. No child run is recorded when there is no
/// sink.
pub fn subagent_fake_node<State, Update>(
    agent: impl Into<String>,
    update: Update,
    usage: UsageTotals,
) -> impl Fn(State, NodeContext) -> NodeFuture<Update> + Send + Sync + 'static
where
    State: Send + 'static,
    Update: Clone + Send + Sync + 'static,
{
    let agent: String = agent.into();
    move |_state, ctx: NodeContext| -> NodeFuture<Update> {
        let agent = agent.clone();
        let result = update.clone();
        Box::pin(async move {
            if let Some(sink) = ctx.child_runs.as_ref() {
                // Fall back to this run when the node is not nested under a root run.
                let root_run_id = ctx
                    .root_run_id
                    .clone()
                    .unwrap_or_else(|| ctx.run_id.clone());
                let run_id = RunId::new(format!(
                    "subagent-fake-{}",
                    crate::graph::compiled::next_seq()
                ));
                let child = ChildRun {
                    node: ctx.node_id.clone(),
                    graph_id: GraphId::new(format!("agent:{agent}")),
                    run_id,
                    root_run_id,
                    usage,
                };
                sink.record(child);
            }
            Ok(NodeResult::Update(result))
        })
    }
}
impl RetryCountingNode {
pub fn new(fail_times: usize) -> Self {
Self {
attempts: Arc::new(AtomicUsize::new(0)),
fail_times,
}
}
pub fn attempts(&self) -> usize {
self.attempts.load(Ordering::Relaxed)
}
pub fn handler<State, Update>(
&self,
success: Update,
) -> impl Fn(State, NodeContext) -> NodeFuture<Update> + Send + Sync + 'static
where
State: Send + 'static,
Update: Clone + Send + Sync + 'static,
{
let attempts = self.attempts.clone();
let fail_times = self.fail_times;
move |_state, _ctx| -> NodeFuture<Update> {
let attempts = attempts.clone();
let success = success.clone();
Box::pin(async move {
let n = attempts.fetch_add(1, Ordering::Relaxed) + 1;
if n <= fail_times {
Err(TinyAgentsError::Graph(format!(
"retry_counting_node: attempt {n} of {fail_times} failing"
)))
} else {
Ok(NodeResult::Update(success))
}
})
}
}
}
impl GraphEventRecorder {
    /// Creates a recorder backed by a fresh `CollectingSink`.
    pub fn new() -> Self {
        let sink = CollectingSink::new();
        Self { sink }
    }

    /// Returns a clone of the underlying sink as a trait object, ready to be
    /// handed to a graph.
    pub fn sink(&self) -> Arc<dyn GraphEventSink> {
        let shared = self.sink.clone();
        Arc::new(shared)
    }

    /// Returns the events the underlying sink currently holds.
    pub fn events(&self) -> Vec<GraphEvent> {
        self.sink.events()
    }

    /// Returns each recorded event's `kind()` as an owned string, in event
    /// order.
    pub fn kinds(&self) -> Vec<String> {
        let recorded = self.events();
        recorded
            .iter()
            .map(|event| event.kind().to_string())
            .collect()
    }

    /// Captures the events recorded so far into a `StreamCollector`.
    pub fn collector(&self) -> StreamCollector {
        StreamCollector::new(self.events())
    }
}
impl StreamCollector {
    /// Wraps an already-captured list of graph events.
    pub fn new(events: Vec<GraphEvent>) -> Self {
        Self { events }
    }

    /// Returns all captured events, in the order they were supplied.
    pub fn events(&self) -> &[GraphEvent] {
        self.events.as_slice()
    }

    /// Returns the nodes named by `NodeCompleted` events, in event order.
    pub fn node_order(&self) -> Vec<NodeId> {
        self.pick(|event| match event {
            GraphEvent::NodeCompleted { node, .. } => Some(node.clone()),
            _ => None,
        })
    }

    /// Returns the nodes named by `StateUpdated` events, in event order.
    pub fn updates(&self) -> Vec<NodeId> {
        self.pick(|event| match event {
            GraphEvent::StateUpdated { node, .. } => Some(node.clone()),
            _ => None,
        })
    }

    /// Returns `(node, target)` pairs from `RouteSelected` events, in event
    /// order.
    pub fn routes(&self) -> Vec<(NodeId, NodeId)> {
        self.pick(|event| match event {
            GraphEvent::RouteSelected { node, target } => Some((node.clone(), target.clone())),
            _ => None,
        })
    }

    /// Returns the interrupts carried by `InterruptEmitted` events, in event
    /// order.
    pub fn interrupts(&self) -> Vec<Interrupt> {
        self.pick(|event| match event {
            GraphEvent::InterruptEmitted { interrupt } => Some(interrupt.clone()),
            _ => None,
        })
    }

    /// Counts `CheckpointSaved` events.
    pub fn checkpoint_count(&self) -> usize {
        self.events
            .iter()
            .filter(|event| matches!(event, GraphEvent::CheckpointSaved { .. }))
            .count()
    }

    /// Returns `(name, data)` pairs from `Custom` events, in event order.
    pub fn custom(&self) -> Vec<(String, Value)> {
        self.pick(|event| match event {
            GraphEvent::Custom { name, data } => Some((name.clone(), data.clone())),
            _ => None,
        })
    }

    /// Projects every event through `select`, keeping the `Some` results in
    /// order.
    fn pick<T>(&self, select: impl FnMut(&GraphEvent) -> Option<T>) -> Vec<T> {
        self.events.iter().filter_map(select).collect()
    }
}
impl<State> GraphRun<State> {
    /// Wraps an execution with no recorded events and no state history.
    pub fn new(execution: GraphExecution<State>) -> Self {
        let events = Vec::new();
        let history = Vec::new();
        Self {
            execution,
            events,
            history,
        }
    }

    /// Replaces the recorded events, returning the updated run.
    pub fn with_events(mut self, events: Vec<GraphEvent>) -> Self {
        self.events = events;
        self
    }

    /// Replaces the recorded state history, returning the updated run.
    pub fn with_history(mut self, history: Vec<StateSnapshot<State>>) -> Self {
        self.history = history;
        self
    }

    /// Builds a `StreamCollector` over a copy of this run's events.
    pub fn collector(&self) -> StreamCollector {
        let snapshot = self.events.to_vec();
        StreamCollector::new(snapshot)
    }
}
pub async fn run_recorded<State, Update>(
graph: &CompiledGraph<State, Update>,
thread: Option<&str>,
state: State,
) -> Result<GraphRun<State>>
where
State: Clone + Send + Sync + 'static,
Update: Send + 'static,
{
let recorder = GraphEventRecorder::new();
let graph = graph.clone().with_event_sink(recorder.sink());
let execution = match thread {
Some(thread) => graph.run_with_thread(thread, state).await?,
None => graph.run(state).await?,
};
let history = match thread {
Some(thread) => graph
.get_state_history(thread, None)
.await
.unwrap_or_default(),
None => Vec::new(),
};
Ok(GraphRun {
execution,
events: recorder.events(),
history,
})
}
/// Starts a chain of fluent assertions against a recorded `GraphRun`.
///
/// The returned `GraphAssertions` borrows `run`. Each assertion method returns
/// `&Self`, so checks can be chained.
pub fn assert_graph<State>(run: &GraphRun<State>) -> GraphAssertions<'_, State> {
GraphAssertions { run }
}
impl<State> GraphAssertions<'_, State> {
    /// Asserts that the run visited exactly `expected`, in order.
    ///
    /// # Panics
    ///
    /// Panics when `execution.visited` differs from `expected`.
    pub fn visited<I, N>(&self, expected: I) -> &Self
    where
        I: IntoIterator<Item = N>,
        N: Into<NodeId>,
    {
        let expected: Vec<NodeId> = expected.into_iter().map(Into::into).collect();
        assert_eq!(
            self.run.execution.visited, expected,
            "assert_graph: expected visited {expected:?} but run visited {:?}",
            self.run.execution.visited
        );
        self
    }

    /// Asserts that the run routed from `from` to `to`.
    ///
    /// When any events were recorded, only a matching `RouteSelected` event
    /// counts. When no events were recorded at all, the check falls back to
    /// finding `from` immediately followed by `to` in the visited sequence.
    ///
    /// # Panics
    ///
    /// Panics when no such route is found. The message lists the recorded
    /// routes and the visited sequence so the mismatch is easy to diagnose.
    pub fn routed(&self, from: impl Into<NodeId>, to: impl Into<NodeId>) -> &Self {
        let from = from.into();
        let to = to.into();
        let by_event = self.run.events.iter().any(|e| {
            matches!(
                e,
                GraphEvent::RouteSelected { node, target }
                if *node == from && *target == to
            )
        });
        // Kept lazy: only consulted when no events were recorded.
        let by_visited = || {
            self.run
                .execution
                .visited
                .windows(2)
                .any(|w| w[0] == from && w[1] == to)
        };
        // The extra context after the original prefix is only formatted on failure.
        assert!(
            by_event || (self.run.events.is_empty() && by_visited()),
            "assert_graph: expected a route from `{from}` to `{to}` but none was found \
             (recorded routes: {:?}, visited: {:?})",
            self.run.collector().routes(),
            self.run.execution.visited
        );
        self
    }

    /// Asserts the number of checkpoints the run produced.
    ///
    /// The count comes from `CheckpointSaved` events when any events were
    /// recorded, and otherwise from the length of the state history.
    ///
    /// # Panics
    ///
    /// Panics when the count differs from `n`. The message names the source
    /// that was counted.
    pub fn checkpoint_count(&self, n: usize) -> &Self {
        let (count, source) = if self.run.events.is_empty() {
            (self.run.history.len(), "state history")
        } else {
            (
                self.run.collector().checkpoint_count(),
                "CheckpointSaved events",
            )
        };
        assert_eq!(
            count, n,
            "assert_graph: expected {n} checkpoint(s) but found {count} (counted from {source})"
        );
        self
    }

    /// Hands the full recorded state history to `f` for custom checks.
    pub fn state_history(&self, f: impl FnOnce(&[StateSnapshot<State>])) -> &Self {
        f(&self.run.history);
        self
    }

    /// Hands the first entry of the recorded state history to `f`.
    ///
    /// NOTE(review): this treats the first entry as the latest checkpoint,
    /// which assumes `get_state_history` returns newest-first. Confirm this.
    ///
    /// # Panics
    ///
    /// Panics when the run history is empty.
    pub fn checkpoint(&self, f: impl FnOnce(&StateSnapshot<State>)) -> &Self {
        let latest = self
            .run
            .history
            .first()
            .expect("assert_graph: expected a checkpoint but the run history is empty");
        f(latest);
        self
    }

    /// Asserts the run was not interrupted and ended with the `Completed` status.
    ///
    /// # Panics
    ///
    /// Panics if the run was interrupted or its status is not `Completed`.
    pub fn completed(&self) -> &Self {
        assert!(
            !self.run.execution.is_interrupted(),
            "assert_graph: expected the run to complete but it was interrupted: {:?}",
            self.run.execution.interrupts
        );
        assert_eq!(
            self.run.execution.status.status,
            crate::harness::ids::ExecutionStatus::Completed,
            "assert_graph: expected a Completed status but found {:?}",
            self.run.execution.status.status
        );
        self
    }

    /// Asserts the run was interrupted.
    ///
    /// # Panics
    ///
    /// Panics if the run was not interrupted. The message reports the status
    /// the run actually ended with.
    pub fn interrupted(&self) -> &Self {
        assert!(
            self.run.execution.is_interrupted(),
            "assert_graph: expected the run to be interrupted but it completed with status {:?}",
            self.run.execution.status.status
        );
        self
    }
}
#[cfg(test)]
mod test;