use std::sync::Arc;
use tinyagents::graph::{END, NodeFuture};
use tinyagents::language::compiler::{
BoxedNode, CapabilityResolver, NodeFactory, bind_capabilities, build_graph, compile,
};
use tinyagents::language::parser::parse_str;
use tinyagents::language::types::{END as LANG_END, NodeSpec, Routing};
use tinyagents::{Command, NodeContext, NodeResult, Result};
/// DSL source for a minimal support agent used by the end-to-end test below.
///
/// It declares one graph, `support_agent`, with two nodes:
/// - `agent` (kind `agent`): routes the `tool_call` label to `tools` and the
///   `final` label to `END`.
/// - `tools` (kind `tool_executor`): always continues to `agent` via `next`.
///
/// The `default` model and the `lookup_user` / `create_ticket` tools must be
/// allowed by the capability resolver for binding to succeed.
const SUPPORT_AGENT: &str = r#"
graph support_agent {
start agent
defaults {
recursion_limit 50
checkpoint inherit
}
channel messages messages
channel tool_calls append
node agent {
kind agent
model "default"
system "Resolve support requests using tools when useful."
tools ["lookup_user", "create_ticket"]
routes {
tool_call -> tools
final -> END
}
}
node tools {
kind tool_executor
next agent
}
}
"#;
/// State threaded through the scripted support-agent graph in this test.
///
/// `Default` gives the empty starting state the test runs the graph with.
#[derive(Clone, Debug, Default)]
struct RagState {
    /// Names of the nodes that executed, in execution order.
    trail: Vec<String>,
    /// How many times an `agent` node with conditional routes has run.
    agent_turns: u32,
    /// How many times a `tool_executor` node with a `next` edge has run.
    tool_runs: u32,
}
/// Maps a route `label` to the runtime node name it should jump to.
///
/// `routes` is a node's conditional `(label, target)` table. The first entry
/// whose label matches wins. A matching target spelled as the language-level
/// `END` marker becomes the graph runtime's `END` sentinel. A label with no
/// matching entry also falls back to the runtime `END`.
fn resolve_route(routes: &[(String, String)], label: &str) -> String {
    for (route_label, route_target) in routes {
        if route_label != label {
            continue;
        }
        // Translate the DSL's terminal marker into the runtime's sentinel.
        return if route_target.as_str() == LANG_END {
            END.to_string()
        } else {
            route_target.clone()
        };
    }
    END.to_string()
}
/// Scripted `NodeFactory` that stands in for real model and tool nodes.
///
/// Every node it builds first appends its own name to `RagState::trail`. It
/// then acts on its `kind` and routing:
/// - `agent` with conditional routes increments `agent_turns`. It takes the
///   `tool_call` route on its first turn and the `final` route from the second
///   turn onwards.
/// - `tool_executor` with a `next` edge increments `tool_runs` and returns a
///   plain update, so the static edge picks the successor.
/// - Any other node with conditional routes takes the `final` route.
/// - Any other node with `next` or terminal routing returns the state with
///   only the trail entry added.
///
/// Route labels are turned into node names with `resolve_route`.
struct RagFactory;

impl NodeFactory<RagState> for RagFactory {
    fn make(&self, spec: &NodeSpec) -> Result<BoxedNode<RagState>> {
        let node_name = spec.name.clone();
        let node_kind = spec.kind.clone();
        let node_routing = spec.routing.clone();

        let node: BoxedNode<RagState> = Arc::new(
            move |mut state: RagState, _ctx: NodeContext| -> NodeFuture<RagState> {
                // The returned future must own everything it uses, so each
                // invocation works on its own copies of the captured spec data.
                let name = node_name.clone();
                let kind = node_kind.clone();
                let routing = node_routing.clone();
                Box::pin(async move {
                    state.trail.push(name);
                    let outcome = match (kind.as_str(), &routing) {
                        ("agent", Routing::Conditional(routes)) => {
                            state.agent_turns += 1;
                            // First turn asks for tools; every later turn wraps up.
                            let label = if state.agent_turns < 2 {
                                "tool_call"
                            } else {
                                "final"
                            };
                            NodeResult::Command(
                                Command::goto([resolve_route(routes, label)]).with_update(state),
                            )
                        }
                        ("tool_executor", Routing::Next(_)) => {
                            state.tool_runs += 1;
                            NodeResult::Update(state)
                        }
                        (_, Routing::Conditional(routes)) => NodeResult::Command(
                            Command::goto([resolve_route(routes, "final")]).with_update(state),
                        ),
                        (_, Routing::Terminal | Routing::Next(_)) => NodeResult::Update(state),
                    };
                    Ok(outcome)
                })
            },
        );
        Ok(node)
    }
}
/// End-to-end check of the DSL pipeline on `SUPPORT_AGENT`.
///
/// The test parses the source and compiles it to exactly one blueprint. It
/// checks the blueprint's shape and binds the allowed model and tools. It then
/// builds the graph with the scripted `RagFactory` and runs it until `END`.
///
/// The scripted agent requests tools on its first turn and finishes on its
/// second, so the expected path is `agent -> tools -> agent -> END`.
#[tokio::test]
async fn rag_source_compiles_binds_and_runs_to_end() {
    let program = parse_str(SUPPORT_AGENT).expect("source parses");

    // The source declares exactly one graph. Without this check, `remove(0)`
    // would silently ignore extra blueprints, or panic with an opaque
    // index-out-of-bounds error if the compiler produced none.
    let mut blueprints = compile(&program).expect("program compiles");
    assert_eq!(blueprints.len(), 1, "source declares exactly one graph");
    let blueprint = blueprints.remove(0);

    assert_eq!(blueprint.graph_id, "support_agent");
    assert_eq!(blueprint.start, "agent");
    assert_eq!(blueprint.nodes.len(), 2);
    assert_eq!(blueprint.nodes[0].kind, "agent");
    assert_eq!(blueprint.nodes[1].kind, "tool_executor");
    assert_eq!(
        blueprint.nodes[1].routing,
        Routing::Next("agent".to_string())
    );

    // Every model and tool named in the source must be allowed for binding to succeed.
    let resolver = CapabilityResolver::new()
        .allow_model("default")
        .allow_tool("lookup_user")
        .allow_tool("create_ticket");
    bind_capabilities(&blueprint, &resolver).expect("capabilities resolve");

    let graph = build_graph(&blueprint, &RagFactory).expect("graph builds");
    let run = graph
        .run(RagState::default())
        .await
        .expect("graph runs to END");

    let visited: Vec<String> = run.visited.iter().map(ToString::to_string).collect();
    assert_eq!(visited, vec!["agent", "tools", "agent"]);
    assert_eq!(run.state.trail, vec!["agent", "tools", "agent"]);
    assert_eq!(run.state.agent_turns, 2);
    assert_eq!(run.state.tool_runs, 1);
}