use std::collections::{HashMap, HashSet};

use crate::error::{Result, RustAgentsError};
use crate::graph::{Node, StateGraph};
use crate::language::types::{Blueprint, ChannelSpec, END, EdgeSpec, NodeSpec, Program, Routing};
pub fn compile(program: &Program) -> Result<Vec<Blueprint>> {
program.graphs.iter().map(compile_graph).collect()
}
fn compile_graph(graph: &crate::language::types::GraphDecl) -> Result<Blueprint> {
let compile_err = |msg: String| RustAgentsError::Compile(msg);
let mut node_names: HashSet<&str> = HashSet::new();
for node in &graph.nodes {
if !node_names.insert(node.name.as_str()) {
return Err(compile_err(format!(
"duplicate node `{}` in graph `{}`",
node.name, graph.name
)));
}
}
let target_ok = |target: &str| target == END || node_names.contains(target);
let start = graph
.start
.clone()
.ok_or_else(|| compile_err(format!("graph `{}` has no `start` node", graph.name)))?;
if !node_names.contains(start.as_str()) {
return Err(compile_err(format!(
"start node `{start}` is not defined in graph `{}`",
graph.name
)));
}
let mut edges = Vec::new();
for edge in &graph.edges {
if !node_names.contains(edge.from.as_str()) {
return Err(compile_err(format!(
"edge source `{}` does not exist in graph `{}`",
edge.from, graph.name
)));
}
if !target_ok(&edge.to) {
return Err(compile_err(format!(
"edge target `{}` does not exist in graph `{}`",
edge.to, graph.name
)));
}
edges.push(EdgeSpec {
from: edge.from.clone(),
to: edge.to.clone(),
});
}
let nodes_with_static_edge: HashSet<&str> =
graph.edges.iter().map(|e| e.from.as_str()).collect();
let mut nodes = Vec::new();
for node in &graph.nodes {
let has_routes = !node.routes.is_empty();
let has_next = node.next.is_some();
let has_static_edge = nodes_with_static_edge.contains(node.name.as_str());
if has_routes && (has_next || has_static_edge) {
return Err(compile_err(format!(
"node `{}` mixes static routing (`next`/edge) with command routing (`routes`); use one or the other",
node.name
)));
}
let mut seen_labels: HashSet<&str> = HashSet::new();
for route in &node.routes {
if !seen_labels.insert(route.label.as_str()) {
return Err(compile_err(format!(
"duplicate route label `{}` on node `{}`",
route.label, node.name
)));
}
if !target_ok(&route.target) {
return Err(compile_err(format!(
"route target `{}` on node `{}` does not exist",
route.target, node.name
)));
}
}
if let Some(next) = &node.next
&& !target_ok(next)
{
return Err(compile_err(format!(
"next target `{next}` on node `{}` does not exist",
node.name
)));
}
let routing = if has_routes {
Routing::Conditional(
node.routes
.iter()
.map(|r| (r.label.clone(), r.target.clone()))
.collect(),
)
} else if let Some(next) = &node.next {
if next == END {
Routing::Terminal
} else {
Routing::Next(next.clone())
}
} else if let Some(edge) = graph.edges.iter().find(|e| e.from == node.name) {
if edge.to == END {
Routing::Terminal
} else {
Routing::Next(edge.to.clone())
}
} else {
Routing::Terminal
};
nodes.push(NodeSpec {
name: node.name.clone(),
kind: node.kind.clone().unwrap_or_else(|| "model".to_string()),
model: node.model.clone(),
prompt: node.prompt.clone(),
tools: node.tools.clone(),
routing,
});
}
let channels = graph
.channels
.iter()
.map(|c| ChannelSpec {
name: c.name.clone(),
reducer: c.reducer.clone(),
})
.collect();
Ok(Blueprint {
graph_id: graph.name.clone(),
start,
channels,
nodes,
edges,
defaults: graph.defaults.clone(),
})
}
/// Allow-list of model names and tool names that a blueprint may reference.
///
/// Start from [`CapabilityResolver::new`] and chain the `allow_*` builder methods,
/// or fill both lists at once with [`CapabilityResolver::from_lists`]. Query the
/// result with [`CapabilityResolver::model_allowed`] and
/// [`CapabilityResolver::tool_allowed`].
#[derive(Clone, Debug, Default)]
pub struct CapabilityResolver {
    models: HashSet<String>,
    tools: HashSet<String>,
}
impl CapabilityResolver {
    /// Creates a resolver that permits no models and no tools.
    pub fn new() -> Self {
        Self {
            models: HashSet::new(),
            tools: HashSet::new(),
        }
    }

    /// Creates a resolver permitting exactly the given model and tool names.
    ///
    /// Duplicate names are collapsed.
    pub fn from_lists<M, T>(models: M, tools: T) -> Self
    where
        M: IntoIterator<Item = String>,
        T: IntoIterator<Item = String>,
    {
        let mut resolver = Self::new();
        resolver.models.extend(models);
        resolver.tools.extend(tools);
        resolver
    }

    /// Builder step: additionally permits the model called `name`.
    pub fn allow_model(mut self, name: impl Into<String>) -> Self {
        let owned: String = name.into();
        self.models.insert(owned);
        self
    }

    /// Builder step: additionally permits the tool called `name`.
    pub fn allow_tool(mut self, name: impl Into<String>) -> Self {
        let owned: String = name.into();
        self.tools.insert(owned);
        self
    }

    /// Returns `true` if the model called `name` has been permitted.
    pub fn model_allowed(&self, name: &str) -> bool {
        let permitted = &self.models;
        permitted.contains(name)
    }

    /// Returns `true` if the tool called `name` has been permitted.
    pub fn tool_allowed(&self, name: &str) -> bool {
        let permitted = &self.tools;
        permitted.contains(name)
    }
}
/// Verifies that every model and tool referenced by `blueprint` is permitted by
/// `allow`.
///
/// Nodes are checked in blueprint order. Within a node, the model (if any) is
/// checked before the tools, and the tools are checked in listed order.
///
/// # Errors
///
/// Returns [`RustAgentsError::Capability`] naming the first unknown model or tool
/// encountered.
pub fn bind_capabilities(blueprint: &Blueprint, allow: &CapabilityResolver) -> Result<()> {
    for spec in &blueprint.nodes {
        let denied_model = spec.model.as_ref().filter(|m| !allow.model_allowed(m));
        if let Some(model) = denied_model {
            return Err(RustAgentsError::Capability(format!(
                "node `{}` references unknown model `{model}`",
                spec.name
            )));
        }

        let denied_tool = spec.tools.iter().find(|t| !allow.tool_allowed(t));
        if let Some(tool) = denied_tool {
            return Err(RustAgentsError::Capability(format!(
                "node `{}` references unknown tool `{tool}`",
                spec.name
            )));
        }
    }
    Ok(())
}
/// Turns a compiled [`NodeSpec`] into a runnable graph [`Node`].
///
/// [`build_graph`] calls [`NodeFactory::make`] once per node, in blueprint order,
/// and aborts on the first error returned.
pub trait NodeFactory<State> {
    /// Builds the runtime node described by `spec`.
    ///
    /// # Errors
    ///
    /// Implementations return an error when they cannot construct the node;
    /// [`build_graph`] propagates it to its caller.
    fn make(&self, spec: &NodeSpec) -> Result<Node<State>>;
}
/// Assembles a [`StateGraph`] from a compiled blueprint, using `factory` to
/// materialise each node.
///
/// The graph's start node is taken from `blueprint.start`. Then, for each node in
/// blueprint order, the node is built, added, and wired by its [`Routing`]:
/// - `Next` adds a single edge;
/// - `Conditional` adds conditional edges for every route whose target is not [`END`];
/// - `Terminal` marks the node as an end node.
///
/// NOTE(review): routes that target [`END`] are dropped from the conditional map.
/// This presumably relies on `StateGraph` treating an unmapped label as "end".
/// Confirm that against the graph runtime.
///
/// # Errors
///
/// Returns the first error produced by [`NodeFactory::make`].
pub fn build_graph<State, F>(blueprint: &Blueprint, factory: &F) -> Result<StateGraph<State>>
where
    State: Clone + Send + 'static,
    F: NodeFactory<State>,
{
    let initial: StateGraph<State> = StateGraph::new().set_start(&blueprint.start);
    blueprint.nodes.iter().try_fold(
        initial,
        |builder, spec| -> Result<StateGraph<State>> {
            let built = factory.make(spec)?;
            let builder = builder.add_node(built);
            let wired = match &spec.routing {
                Routing::Terminal => builder.add_end(&spec.name),
                Routing::Next(target) => builder.add_edge(&spec.name, target),
                Routing::Conditional(routes) => builder.add_conditional_edges(
                    &spec.name,
                    routes
                        .iter()
                        .filter(|(_, target)| target != END)
                        .map(|(label, target)| (label.clone(), target.clone())),
                ),
            };
            Ok(wired)
        },
    )
}