use std::collections::HashSet;
use crate::error::{Result, TinyAgentsError};
use crate::graph::{Node, StateGraph};
use crate::language::parser::parse_str;
use crate::language::types::{Blueprint, ChannelSpec, END, EdgeSpec, NodeSpec, Program, Routing};
use crate::registry::CapabilityRegistry;
pub fn compile(program: &Program) -> Result<Vec<Blueprint>> {
program.graphs.iter().map(compile_graph).collect()
}
/// Validates one graph declaration and lowers it into a [`Blueprint`].
///
/// Checks run in the following order and the first failure is returned:
///
/// 1. node names are unique within the graph;
/// 2. a `start` node is declared and names a defined node;
/// 3. each static edge has a defined source, and a target that is either a
///    defined node or `END` (checked edge by edge, source before target);
/// 4. for each node in declaration order: `routes` are not combined with
///    `next` or a static edge, route labels are unique, and every route and
///    `next` target is a defined node or `END`.
///
/// Routing for a node is chosen with this precedence: `routes` (conditional),
/// then `next`, then the *first* static edge declared with the node as its
/// source, otherwise terminal. A `next` or edge that points at `END` is also
/// lowered to [`Routing::Terminal`]. Every validated edge is still copied into
/// the blueprint's `edges` list, including ones that lose this precedence.
///
/// Nodes that do not declare a kind are given the kind `"model"`.
///
/// # Errors
///
/// Returns [`TinyAgentsError::Compile`] describing the first violated rule.
fn compile_graph(graph: &crate::language::types::GraphDecl) -> Result<Blueprint> {
    use std::collections::HashMap;

    let compile_err = |msg: String| TinyAgentsError::Compile(msg);

    // Collect node names, rejecting the first repeated one.
    let mut node_names: HashSet<&str> = HashSet::new();
    for node in &graph.nodes {
        if !node_names.insert(node.name.as_str()) {
            return Err(compile_err(format!(
                "duplicate node `{}` in graph `{}`",
                node.name, graph.name
            )));
        }
    }

    // A routing target is valid when it is the END sentinel or a declared node.
    let target_ok = |target: &str| target == END || node_names.contains(target);

    let start = graph
        .start
        .clone()
        .ok_or_else(|| compile_err(format!("graph `{}` has no `start` node", graph.name)))?;
    if !node_names.contains(start.as_str()) {
        return Err(compile_err(format!(
            "start node `{start}` is not defined in graph `{}`",
            graph.name
        )));
    }

    let mut edges = Vec::new();
    // Target of the first static edge leaving each node. Built once here so the
    // per-node pass below does a map lookup instead of rescanning every edge.
    let mut first_static_target = HashMap::new();
    for edge in &graph.edges {
        if !node_names.contains(edge.from.as_str()) {
            return Err(compile_err(format!(
                "edge source `{}` does not exist in graph `{}`",
                edge.from, graph.name
            )));
        }
        if !target_ok(&edge.to) {
            return Err(compile_err(format!(
                "edge target `{}` does not exist in graph `{}`",
                edge.to, graph.name
            )));
        }
        edges.push(EdgeSpec {
            from: edge.from.clone(),
            to: edge.to.clone(),
        });
        // `or_insert` keeps the earliest edge when a source appears repeatedly.
        first_static_target
            .entry(edge.from.as_str())
            .or_insert(&edge.to);
    }

    let mut nodes = Vec::new();
    for node in &graph.nodes {
        let has_routes = !node.routes.is_empty();
        let static_edge_target = first_static_target.get(node.name.as_str()).copied();

        if has_routes && (node.next.is_some() || static_edge_target.is_some()) {
            return Err(compile_err(format!(
                "node `{}` mixes static routing (`next`/edge) with command routing (`routes`); use one or the other",
                node.name
            )));
        }

        let mut seen_labels: HashSet<&str> = HashSet::new();
        for route in &node.routes {
            if !seen_labels.insert(route.label.as_str()) {
                return Err(compile_err(format!(
                    "duplicate route label `{}` on node `{}`",
                    route.label, node.name
                )));
            }
            if !target_ok(&route.target) {
                return Err(compile_err(format!(
                    "route target `{}` on node `{}` does not exist",
                    route.target, node.name
                )));
            }
        }

        if let Some(next) = &node.next
            && !target_ok(next)
        {
            return Err(compile_err(format!(
                "next target `{next}` on node `{}` does not exist",
                node.name
            )));
        }

        let routing = if has_routes {
            Routing::Conditional(
                node.routes
                    .iter()
                    .map(|r| (r.label.clone(), r.target.clone()))
                    .collect(),
            )
        } else {
            // `next` takes precedence over a static edge. Whichever applies,
            // a target of END (or no target at all) makes the node terminal.
            match node.next.as_ref().or(static_edge_target) {
                Some(target) if target != END => Routing::Next(target.clone()),
                _ => Routing::Terminal,
            }
        };

        nodes.push(NodeSpec {
            name: node.name.clone(),
            kind: node.kind.clone().unwrap_or_else(|| "model".to_string()),
            model: node.model.clone(),
            prompt: node.prompt.clone(),
            tools: node.tools.clone(),
            routing,
        });
    }

    let channels = graph
        .channels
        .iter()
        .map(|c| ChannelSpec {
            name: c.name.clone(),
            reducer: c.reducer.clone(),
        })
        .collect();

    Ok(Blueprint {
        graph_id: graph.name.clone(),
        start,
        channels,
        nodes,
        edges,
        defaults: graph.defaults.clone(),
    })
}
/// Node kind names that [`CapabilityResolver::from_registry`] installs as the
/// resolver's accepted set of node kinds.
pub const DEFAULT_NODE_KINDS: &[&str] = &[
"agent",
"model",
"tool_executor",
"subgraph",
"graph",
"router",
"human",
];
/// Allow-lists of capability names used to validate the references made by a
/// compiled [`Blueprint`].
///
/// The derived `Default` leaves every set empty.
#[derive(Clone, Debug, Default)]
pub struct CapabilityResolver {
/// Model names accepted by `model_allowed`.
models: HashSet<String>,
/// Tool names accepted by `tool_allowed`.
tools: HashSet<String>,
/// Subgraph names accepted by `subgraph_allowed`.
subgraphs: HashSet<String>,
/// Router names accepted by `router_allowed`.
routers: HashSet<String>,
/// Reducer names accepted by `reducer_allowed`.
reducers: HashSet<String>,
/// Accepted node kinds; when this set is empty `node_kind_allowed` accepts
/// every kind.
node_kinds: HashSet<String>,
}
impl CapabilityResolver {
pub fn new() -> Self {
Self::default()
}
pub fn from_lists<M, T>(models: M, tools: T) -> Self
where
M: IntoIterator<Item = String>,
T: IntoIterator<Item = String>,
{
Self {
models: models.into_iter().collect(),
tools: tools.into_iter().collect(),
..Self::default()
}
}
pub fn from_registry<State: Send + Sync>(registry: &CapabilityRegistry<State>) -> Self {
use crate::registry::ComponentKind;
let collect = |kind| registry.names_including_aliases(kind).into_iter().collect();
Self {
models: collect(ComponentKind::Model),
tools: collect(ComponentKind::Tool),
subgraphs: collect(ComponentKind::Graph),
routers: collect(ComponentKind::Router),
reducers: collect(ComponentKind::Reducer),
node_kinds: DEFAULT_NODE_KINDS.iter().map(|k| (*k).to_owned()).collect(),
}
}
pub fn allow_model(mut self, name: impl Into<String>) -> Self {
self.models.insert(name.into());
self
}
pub fn allow_tool(mut self, name: impl Into<String>) -> Self {
self.tools.insert(name.into());
self
}
pub fn allow_subgraph(mut self, name: impl Into<String>) -> Self {
self.subgraphs.insert(name.into());
self
}
pub fn allow_router(mut self, name: impl Into<String>) -> Self {
self.routers.insert(name.into());
self
}
pub fn allow_reducer(mut self, name: impl Into<String>) -> Self {
self.reducers.insert(name.into());
self
}
pub fn with_node_kinds<I, S>(mut self, kinds: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.node_kinds = kinds.into_iter().map(Into::into).collect();
self
}
pub fn model_allowed(&self, name: &str) -> bool {
self.models.contains(name)
}
pub fn tool_allowed(&self, name: &str) -> bool {
self.tools.contains(name)
}
pub fn subgraph_allowed(&self, name: &str) -> bool {
self.subgraphs.contains(name)
}
pub fn router_allowed(&self, name: &str) -> bool {
self.routers.contains(name)
}
pub fn reducer_allowed(&self, name: &str) -> bool {
self.reducers.contains(name)
}
pub fn node_kind_allowed(&self, kind: &str) -> bool {
self.node_kinds.is_empty() || self.node_kinds.contains(kind)
}
pub fn bind_blueprint(&self, blueprint: &Blueprint) -> Result<()> {
for node in &blueprint.nodes {
if !self.node_kind_allowed(&node.kind) {
return Err(TinyAgentsError::Compile(format!(
"node `{}` has unknown kind `{}`",
node.name, node.kind
)));
}
match node.kind.as_str() {
"subgraph" | "graph" => {
if let Some(target) = &node.model
&& !self.subgraph_allowed(target)
{
return Err(TinyAgentsError::Capability(format!(
"node `{}` references unknown subgraph `{target}`",
node.name
)));
}
}
"router" => {
if let Some(target) = &node.model
&& !self.router_allowed(target)
{
return Err(TinyAgentsError::Capability(format!(
"node `{}` references unknown router `{target}`",
node.name
)));
}
}
_ => {
if let Some(model) = &node.model
&& !self.model_allowed(model)
{
return Err(TinyAgentsError::Capability(format!(
"node `{}` references unknown model `{model}`",
node.name
)));
}
}
}
for tool in &node.tools {
if !self.tool_allowed(tool) {
return Err(TinyAgentsError::Capability(format!(
"node `{}` references unknown tool `{tool}`",
node.name
)));
}
}
}
for channel in &blueprint.channels {
if !self.reducer_allowed(&channel.reducer) {
return Err(TinyAgentsError::Capability(format!(
"channel `{}` references unknown reducer `{}`",
channel.name, channel.reducer
)));
}
}
Ok(())
}
}
/// Checks only the model and tool references of `blueprint` against `allow`.
///
/// Every node's `model` field is looked up as a model regardless of the node
/// kind; node kinds and channel reducers are not checked here. For each node
/// the model is checked before its tools.
///
/// # Errors
///
/// Returns [`TinyAgentsError::Capability`] naming the first unknown model or
/// tool encountered.
pub fn bind_capabilities(blueprint: &Blueprint, allow: &CapabilityResolver) -> Result<()> {
    for node in &blueprint.nodes {
        match &node.model {
            Some(model) if !allow.model_allowed(model) => {
                return Err(TinyAgentsError::Capability(format!(
                    "node `{}` references unknown model `{model}`",
                    node.name
                )));
            }
            _ => {}
        }

        let unknown_tool = node.tools.iter().find(|tool| !allow.tool_allowed(tool));
        if let Some(tool) = unknown_tool {
            return Err(TinyAgentsError::Capability(format!(
                "node `{}` references unknown tool `{tool}`",
                node.name
            )));
        }
    }
    Ok(())
}
/// Validates `blueprint` against the capabilities registered in `registry`.
///
/// A [`CapabilityResolver`] is derived from the registry and the blueprint is
/// checked with [`CapabilityResolver::bind_blueprint`].
///
/// # Errors
///
/// Propagates whatever error `bind_blueprint` reports.
pub fn bind_capabilities_with_registry<State: Send + Sync>(
    blueprint: &Blueprint,
    registry: &CapabilityRegistry<State>,
) -> Result<()> {
    let resolver = CapabilityResolver::from_registry(registry);
    resolver.bind_blueprint(blueprint)
}
pub fn compile_source<State: Send + Sync>(
source: &str,
registry: &CapabilityRegistry<State>,
) -> Result<Vec<Blueprint>> {
let program = parse_str(source)?;
let blueprints = compile(&program)?;
for blueprint in &blueprints {
bind_capabilities_with_registry(blueprint, registry)?;
}
Ok(blueprints)
}
/// Turns a compiled [`NodeSpec`] into a graph [`Node`] for the given `State`.
///
/// `build_graph` calls `make` once for each node of a blueprint.
pub trait NodeFactory<State> {
/// Builds the node described by `spec`, or returns an error if it cannot be
/// constructed.
fn make(&self, spec: &NodeSpec) -> Result<Node<State>>;
}
/// Assembles a [`StateGraph`] from `blueprint`, using `factory` to build each
/// node.
///
/// The start node is set first. Then, for each node spec in order, the node is
/// created and added, followed by its wiring: a plain edge for
/// [`Routing::Next`], conditional edges for [`Routing::Conditional`] (routes
/// whose target is `END` are left out), or an end marker for
/// [`Routing::Terminal`].
///
/// # Errors
///
/// Returns the first error reported by `factory.make`.
pub fn build_graph<State, F>(blueprint: &Blueprint, factory: &F) -> Result<StateGraph<State>>
where
    State: Clone + Send + 'static,
    F: NodeFactory<State>,
{
    let seeded = StateGraph::new().set_start(&blueprint.start);

    blueprint
        .nodes
        .iter()
        .try_fold(seeded, |graph, spec| -> Result<StateGraph<State>> {
            let node = factory.make(spec)?;
            let with_node = graph.add_node(node);

            let wired = match &spec.routing {
                Routing::Terminal => with_node.add_end(&spec.name),
                Routing::Next(target) => with_node.add_edge(&spec.name, target),
                Routing::Conditional(routes) => with_node.add_conditional_edges(
                    &spec.name,
                    routes.iter().filter_map(|(label, target)| {
                        (target != END).then(|| (label.clone(), target.clone()))
                    }),
                ),
            };
            Ok(wired)
        })
}