use rustagents::language::compiler::{
CapabilityResolver, NodeFactory, bind_capabilities, build_graph, compile,
};
use rustagents::language::parser::parse_str;
use rustagents::language::types::{NodeSpec, Routing};
use rustagents::{Node, NodeOutput, Result, RustAgentsError};
/// DSL source for the `support_agent` graph used by the tests in this file.
///
/// It declares an `agent` node (kind `agent`, model `"default"`, tools
/// `lookup_user` and `create_ticket`) with conditional routes
/// `tool_call -> tools` and `final -> END`, and a `tool_executor` node `tools`
/// whose `next` is `agent`, forming a tool loop that starts at `agent`.
/// Graph-level `defaults` and two `channel` declarations are included as well.
const SUPPORT_AGENT: &str = r#"
// A support workflow with a tool loop.
graph support_agent {
start agent
defaults {
recursion_limit 50
backoff "exponential"
checkpoint inherit
}
channel messages messages
channel tool_calls append
node agent {
kind agent
model "default"
system "Resolve support requests using tools when useful."
tools ["lookup_user", "create_ticket"]
routes {
tool_call -> tools
final -> END
}
}
node tools {
kind tool_executor
next agent
}
}
"#;
/// Parses and compiles [`SUPPORT_AGENT`], checks the resulting blueprint node by
/// node (including both declared routes on `agent`), and confirms its
/// capabilities bind against a resolver that allows everything it references.
#[test]
fn compiles_support_agent_blueprint_structure() {
    let program = parse_str(SUPPORT_AGENT).expect("source parses");
    assert_eq!(program.graphs.len(), 1);

    let blueprint = compile(&program).expect("program compiles").remove(0);
    assert_eq!(blueprint.graph_id, "support_agent");
    assert_eq!(blueprint.start, "agent");
    assert_eq!(blueprint.nodes.len(), 2);

    // Nodes are expected in declaration order: `agent` first, then `tools`.
    let agent = &blueprint.nodes[0];
    assert_eq!(agent.name, "agent");
    assert_eq!(agent.kind, "agent");
    assert_eq!(agent.model.as_deref(), Some("default"));
    assert_eq!(agent.tools, vec!["lookup_user", "create_ticket"]);
    match &agent.routing {
        Routing::Conditional(routes) => {
            assert_eq!(routes.len(), 2);
            assert!(routes.contains(&("tool_call".to_string(), "tools".to_string())));
            // The `final -> END` route must survive compilation too. Only the
            // label is pinned; how `END` is represented as a target is left to
            // the compiler.
            assert!(
                routes.iter().any(|(label, _)| label == "final"),
                "missing `final` route in {routes:?}"
            );
        }
        other => panic!("expected conditional routing on `agent`, got {other:?}"),
    }

    let tools = &blueprint.nodes[1];
    assert_eq!(tools.name, "tools");
    assert_eq!(tools.kind, "tool_executor");
    assert_eq!(tools.routing, Routing::Next("agent".to_string()));

    // Every model and tool the blueprint references is allowed, so binding succeeds.
    let resolver = CapabilityResolver::new()
        .allow_model("default")
        .allow_tool("lookup_user")
        .allow_tool("create_ticket");
    bind_capabilities(&blueprint, &resolver).expect("capabilities resolve");
}
/// Binding must fail with a `Capability` error naming the offending tool when
/// the resolver does not allow every tool the blueprint references.
#[test]
fn bind_capabilities_rejects_unknown_tool() {
    // `create_ticket` is deliberately left off the allow-list.
    let resolver = CapabilityResolver::new()
        .allow_model("default")
        .allow_tool("lookup_user");

    let program = parse_str(SUPPORT_AGENT).unwrap();
    let mut blueprints = compile(&program).unwrap();
    let blueprint = blueprints.remove(0);

    let failure =
        bind_capabilities(&blueprint, &resolver).expect_err("create_ticket is not allowed");
    match failure {
        RustAgentsError::Capability(msg) => assert!(msg.contains("create_ticket"), "{msg}"),
        other => panic!("expected Capability error, got {other:?}"),
    }
}
/// A graph that never declares `start` is rejected by the compiler.
#[test]
fn missing_start_is_a_compile_error() {
    let source = "graph no_start { node a { kind model } }";
    let outcome = compile(&parse_str(source).expect("parses"));
    match outcome.expect_err("a graph without `start` cannot compile") {
        RustAgentsError::Compile(_) => {}
        other => panic!("got {other:?}"),
    }
}
/// Declaring two nodes with the same name is a compile error whose message
/// mentions the duplication.
#[test]
fn duplicate_node_is_a_compile_error() {
    let program = parse_str("graph dupes { start a node a { kind model } node a { kind model } }")
        .expect("parses");
    let msg = match compile(&program).expect_err("duplicate node names cannot compile") {
        RustAgentsError::Compile(msg) => msg,
        other => panic!("expected Compile error, got {other:?}"),
    };
    assert!(msg.contains("duplicate"), "{msg}");
}
/// A conditional route whose target names a node that does not exist is a
/// compile error, and the diagnostic names the missing target.
#[test]
fn unknown_route_target_is_a_compile_error() {
    // `kind model` keeps node `a` otherwise well-formed (matching the other
    // tests), so the dangling `ghost` target is the only thing wrong here and
    // the test cannot pass because of some unrelated compile error.
    let src = "graph bad_route { start a node a { kind model routes { go -> ghost } } }";
    let program = parse_str(src).expect("parses");
    let err = compile(&program).expect_err("routing to a missing node cannot compile");
    match err {
        RustAgentsError::Compile(msg) => assert!(msg.contains("ghost"), "{msg}"),
        other => panic!("expected Compile error, got {other:?}"),
    }
}
/// Graph state used by the [`TraceFactory`] test nodes.
#[derive(Default, Debug, Clone)]
struct TraceState {
    /// Names of the nodes whose handlers ran, in execution order.
    trail: Vec<String>,
    /// Number of times a node with conditional routing has executed.
    agent_visits: u32,
}
/// Test [`NodeFactory`] whose nodes append their own name to
/// [`TraceState::trail`] and choose an output purely from the spec's routing.
///
/// - terminal nodes end the run;
/// - `next` nodes continue along their fixed edge;
/// - conditionally-routed nodes take the `tool_call` route on their first
///   visit and end the run on the second.
struct TraceFactory;

impl NodeFactory<TraceState> for TraceFactory {
    fn make(&self, spec: &NodeSpec) -> Result<Node<TraceState>> {
        // The routing shape is fixed per node, so decide it once here rather
        // than cloning the spec's `Routing` into every invocation.
        #[derive(Clone, Copy)]
        enum Step {
            Finish,
            Advance,
            Branch,
        }

        let step = match &spec.routing {
            Routing::Terminal => Step::Finish,
            Routing::Next(_) => Step::Advance,
            Routing::Conditional(_) => Step::Branch,
        };
        let label = spec.name.clone();

        Ok(Node::new(label.clone(), move |mut state: TraceState| {
            let visited = label.clone();
            async move {
                state.trail.push(visited);
                Ok(match step {
                    Step::Finish => NodeOutput::end(state),
                    Step::Advance => NodeOutput::continue_with(state),
                    Step::Branch => {
                        state.agent_visits += 1;
                        if state.agent_visits < 2 {
                            NodeOutput::route(state, "tool_call")
                        } else {
                            NodeOutput::end(state)
                        }
                    }
                })
            }
        }))
    }
}
/// End-to-end run of the compiled support graph built with [`TraceFactory`]:
/// execution goes agent → tools → agent and stops on the second agent visit.
#[tokio::test]
async fn build_graph_runs_to_end() {
    let program = parse_str(SUPPORT_AGENT).unwrap();
    let blueprint = compile(&program).unwrap().remove(0);
    let graph = build_graph(&blueprint, &TraceFactory).expect("graph builds");

    let run = graph.run(TraceState::default()).await.expect("graph runs");

    // The runtime's own record of visited nodes and the trail the nodes wrote
    // into state must agree on the same path.
    let expected_path = vec!["agent", "tools", "agent"];
    assert_eq!(run.visited, expected_path);
    assert_eq!(run.state.trail, expected_path);
    assert_eq!(run.state.agent_visits, 2);
}