use std::collections::HashSet;
use crate::error::{Result, TinyAgentsError};
use crate::language::ast::{ChannelDecl, GraphDecl, NodeDecl, Program};
use crate::language::compiler::{CapabilityResolver, DEFAULT_NODE_KINDS, compile};
use crate::language::diagnostic::Diagnostic;
use crate::language::parser::parse_str;
use crate::language::source::SourceFile;
use crate::language::span::Span;
use crate::language::types::Blueprint;
use crate::registry::CapabilityRegistry;
// Diagnostic codes attached to the errors produced by `Resolver`.
/// A node references a model that is not registered or not allowed.
const CODE_UNKNOWN_MODEL: &str = "E-rag-unknown-model";
/// A node references a tool that is not registered or not allowed.
const CODE_UNKNOWN_TOOL: &str = "E-rag-unknown-tool";
/// A `subgraph`/`graph` node references a subgraph that is not registered or not allowed.
const CODE_UNKNOWN_SUBGRAPH: &str = "E-rag-unknown-subgraph";
/// A `router` node references a router that is not registered or not allowed.
const CODE_UNKNOWN_ROUTER: &str = "E-rag-unknown-router";
/// A `subagent` node references an agent missing from the agent allow-list.
const CODE_UNKNOWN_AGENT: &str = "E-rag-unknown-agent";
/// A channel references a reducer that is not registered or not allowed.
const CODE_UNKNOWN_REDUCER: &str = "E-rag-unknown-reducer";
/// A node declares a kind that is not allowed. `fold_diagnostic` maps this
/// code to a compile error; every other code becomes a capability error.
const CODE_INVALID_NODE_KIND: &str = "E-rag-invalid-node-kind";
/// Checks `.rag` programs and compiled blueprints against the capabilities a
/// host has registered.
///
/// The wrapped [`CapabilityResolver`] handles the model, tool, router,
/// subgraph, reducer and node-kind checks. This type also holds a separate
/// allow-list for the agent names referenced by `subagent` nodes.
#[derive(Clone, Debug)]
pub struct Resolver {
    /// Allow-lists for everything except agents.
    caps: CapabilityResolver,
    /// Agent names (and aliases) that `subagent` nodes may reference.
    agents: HashSet<String>,
}

impl Resolver {
    /// Builds a resolver from everything registered in `registry`.
    ///
    /// Capabilities come from [`CapabilityResolver::from_registry`]. The agent
    /// allow-list holds every agent name in `registry`, aliases included.
    pub fn from_registry<State: Send + Sync>(registry: &CapabilityRegistry<State>) -> Self {
        use crate::registry::ComponentKind;
        let agents = registry
            .names_including_aliases(ComponentKind::Agent)
            .into_iter()
            .collect();
        Self {
            caps: CapabilityResolver::from_registry(registry),
            agents,
        }
    }

    /// Wraps an existing capability set with an empty agent allow-list.
    ///
    /// Use [`Resolver::allow_agent`] to permit individual agents.
    pub fn from_capabilities(caps: CapabilityResolver) -> Self {
        Self {
            caps,
            agents: HashSet::new(),
        }
    }

    /// Returns the wrapped capability allow-lists.
    pub fn capabilities(&self) -> &CapabilityResolver {
        &self.caps
    }

    /// Adds `name` to the agent allow-list and returns the updated resolver.
    pub fn allow_agent(mut self, name: impl Into<String>) -> Self {
        self.agents.insert(name.into());
        self
    }

    /// Returns `true` if `subagent` nodes may reference the agent `name`.
    pub fn agent_allowed(&self, name: &str) -> bool {
        self.agents.contains(name)
    }

    /// Resolves every graph in `program` and returns all diagnostics found.
    ///
    /// An empty vector means every reference resolved. Graphs are visited in
    /// the order they appear in `program.graphs`. Within each graph, node
    /// diagnostics come before channel diagnostics.
    pub fn resolve_program(&self, program: &Program) -> Vec<Diagnostic> {
        let mut diagnostics = Vec::new();
        for graph in &program.graphs {
            self.resolve_graph(graph, &mut diagnostics);
        }
        diagnostics
    }

    /// Appends diagnostics for every node, then every channel, of `graph`.
    fn resolve_graph(&self, graph: &GraphDecl, out: &mut Vec<Diagnostic>) {
        for node in &graph.nodes {
            self.resolve_node(node, out);
        }
        for channel in &graph.channels {
            self.resolve_channel(channel, out);
        }
    }

    /// Appends diagnostics for the node's kind, its kind-specific target, and
    /// its tools.
    fn resolve_node(&self, node: &NodeDecl, out: &mut Vec<Diagnostic>) {
        // Nodes without an explicit kind are treated as model nodes.
        let kind = node.kind.as_deref().unwrap_or("model");
        if self.caps.node_kind_allowed(kind) {
            self.resolve_kind_target(node, kind, out);
        } else {
            // NOTE(review): the help lists the *default* kinds. If the
            // `CapabilityResolver` allows a different set, this may be
            // incomplete; confirm.
            out.push(
                Diagnostic::error(
                    format!("node `{}` has unknown kind `{kind}`", node.name),
                    node.span,
                )
                .with_code(CODE_INVALID_NODE_KIND)
                .with_primary_label("not an allowed node kind")
                .with_help(format!("allowed kinds: {}", DEFAULT_NODE_KINDS.join(", "))),
            );
            // An unknown kind gives no way to interpret the kind-specific
            // target. Checking it anyway (e.g. as a model) would add a
            // misleading follow-on diagnostic, so it is skipped.
        }
        // Tool references do not depend on the node kind, so always check them.
        for tool in &node.tools {
            Self::check_ref(
                self.caps.tool_allowed(tool),
                &node.name,
                "tool",
                tool,
                node.span,
                CODE_UNKNOWN_TOOL,
                out,
            );
        }
    }

    /// Checks the reference whose meaning depends on an allowed node `kind`.
    ///
    /// - `subgraph`/`graph` nodes reference a subgraph via `graph`, falling
    ///   back to `model` when `graph` is absent.
    /// - `router` nodes reference a router via `model`.
    /// - `subagent` nodes reference an agent via `agent`.
    /// - Every other kind references a model via `model`.
    fn resolve_kind_target(&self, node: &NodeDecl, kind: &str, out: &mut Vec<Diagnostic>) {
        match kind {
            "subgraph" | "graph" => {
                if let Some(target) = node.graph.as_deref().or(node.model.as_deref()) {
                    Self::check_ref(
                        self.caps.subgraph_allowed(target),
                        &node.name,
                        "subgraph",
                        target,
                        node.span,
                        CODE_UNKNOWN_SUBGRAPH,
                        out,
                    );
                }
            }
            "router" => {
                if let Some(target) = node.model.as_deref() {
                    Self::check_ref(
                        self.caps.router_allowed(target),
                        &node.name,
                        "router",
                        target,
                        node.span,
                        CODE_UNKNOWN_ROUTER,
                        out,
                    );
                }
            }
            "subagent" => {
                if let Some(target) = node.agent.as_deref() {
                    Self::check_ref(
                        self.agent_allowed(target),
                        &node.name,
                        "agent",
                        target,
                        node.span,
                        CODE_UNKNOWN_AGENT,
                        out,
                    );
                }
            }
            _ => {
                if let Some(target) = node.model.as_deref() {
                    Self::check_ref(
                        self.caps.model_allowed(target),
                        &node.name,
                        "model",
                        target,
                        node.span,
                        CODE_UNKNOWN_MODEL,
                        out,
                    );
                }
            }
        }
    }

    /// Appends a diagnostic if the channel's reducer is not allowed.
    fn resolve_channel(&self, channel: &ChannelDecl, out: &mut Vec<Diagnostic>) {
        if !self.caps.reducer_allowed(&channel.reducer) {
            out.push(
                Diagnostic::error(
                    format!(
                        "channel `{}` references unknown reducer `{}`",
                        channel.name, channel.reducer
                    ),
                    channel.span,
                )
                .with_code(CODE_UNKNOWN_REDUCER)
                .with_primary_label("reducer not registered or not allowed")
                .with_help("register the reducer before referencing it from `.rag`"),
            );
        }
    }

    /// Pushes an "unknown reference" diagnostic onto `out` unless `allowed`.
    ///
    /// The diagnostic is attached to `span`. `what` names the category of
    /// `target` (e.g. `"model"`, `"tool"`) and appears in the message, label
    /// and help text. `code` is attached as the diagnostic code.
    fn check_ref(
        allowed: bool,
        node: &str,
        what: &str,
        target: &str,
        span: Span,
        code: &str,
        out: &mut Vec<Diagnostic>,
    ) {
        if allowed {
            return;
        }
        out.push(
            Diagnostic::error(
                format!("node `{node}` references unknown {what} `{target}`"),
                span,
            )
            .with_code(code)
            .with_primary_label(format!("{what} not registered or not allowed"))
            .with_help(format!(
                "register `{target}` as a {what} before referencing it from `.rag`"
            )),
        );
    }

    /// Resolves `program` and fails on the first diagnostic, if any.
    ///
    /// When `source` is given, the error message is rendered against that
    /// file; otherwise it is rendered as plain text.
    ///
    /// # Errors
    ///
    /// Only the first diagnostic from [`Resolver::resolve_program`] is
    /// reported. An invalid node kind becomes `TinyAgentsError::Compile`;
    /// every other diagnostic becomes `TinyAgentsError::Capability`.
    pub fn check_program(&self, program: &Program, source: Option<&SourceFile>) -> Result<()> {
        match self.resolve_program(program).into_iter().next() {
            Some(diagnostic) => Err(fold_diagnostic(diagnostic, source)),
            None => Ok(()),
        }
    }

    /// Checks an already-compiled blueprint, stopping at the first problem.
    ///
    /// Applies the same rules as [`Resolver::resolve_program`] to a
    /// [`Blueprint`]: each node's kind, kind-specific target and tools, then
    /// each channel's reducer.
    ///
    /// # Errors
    ///
    /// - `TinyAgentsError::Compile` if a node's kind is not allowed.
    /// - `TinyAgentsError::Capability` if a node references an unknown
    ///   subgraph, router, agent, model or tool, or a channel references an
    ///   unknown reducer.
    pub fn resolve_blueprint(&self, blueprint: &Blueprint) -> Result<()> {
        for node in &blueprint.nodes {
            if !self.caps.node_kind_allowed(&node.kind) {
                return Err(TinyAgentsError::Compile(format!(
                    "node `{}` has unknown kind `{}`",
                    node.name, node.kind
                )));
            }
            match node.kind.as_str() {
                "subgraph" | "graph" => {
                    // Prefer the explicit `subgraph` target; fall back to `model`.
                    if let Some(target) = node.subgraph.as_deref().or(node.model.as_deref())
                        && !self.caps.subgraph_allowed(target)
                    {
                        return Err(unregistered("subgraph", &node.name, target));
                    }
                }
                "router" => {
                    if let Some(target) = node.model.as_deref()
                        && !self.caps.router_allowed(target)
                    {
                        return Err(unregistered("router", &node.name, target));
                    }
                }
                "subagent" => {
                    if let Some(target) = node.agent.as_deref()
                        && !self.agent_allowed(target)
                    {
                        return Err(unregistered("agent", &node.name, target));
                    }
                }
                _ => {
                    if let Some(target) = node.model.as_deref()
                        && !self.caps.model_allowed(target)
                    {
                        return Err(unregistered("model", &node.name, target));
                    }
                }
            }
            for tool in &node.tools {
                if !self.caps.tool_allowed(tool) {
                    return Err(unregistered("tool", &node.name, tool));
                }
            }
        }
        for channel in &blueprint.channels {
            if !self.caps.reducer_allowed(&channel.reducer) {
                return Err(TinyAgentsError::Capability(format!(
                    "channel `{}` references unknown reducer `{}`",
                    channel.name, channel.reducer
                )));
            }
        }
        Ok(())
    }
}
/// Converts one resolver diagnostic into the crate's error type.
///
/// When `source` is provided, the diagnostic is rendered against that file;
/// otherwise it is rendered as plain text. A diagnostic carrying
/// `CODE_INVALID_NODE_KIND` becomes `TinyAgentsError::Compile`; all others
/// become `TinyAgentsError::Capability`.
fn fold_diagnostic(diagnostic: Diagnostic, source: Option<&SourceFile>) -> TinyAgentsError {
    // Classify before rendering, exactly as the rendering step may need
    // `diagnostic` afterwards.
    let invalid_kind = matches!(diagnostic.code.as_deref(), Some(CODE_INVALID_NODE_KIND));
    let message = if let Some(file) = source {
        diagnostic.render(file)
    } else {
        diagnostic.render_plain()
    };
    if invalid_kind {
        TinyAgentsError::Compile(message)
    } else {
        TinyAgentsError::Capability(message)
    }
}
/// Builds the capability error for a node that references an unknown target.
///
/// `what` names the category of `target` (e.g. `"model"`, `"tool"`).
fn unregistered(what: &str, node: &str, target: &str) -> TinyAgentsError {
    let message = format!("node `{node}` references unknown {what} `{target}`");
    TinyAgentsError::Capability(message)
}
/// Parses `.rag` source, checks it against `registry`, and compiles it into
/// blueprints.
///
/// # Errors
///
/// Returns the parser's error if `source` does not parse. Returns the first
/// resolution problem, rendered against `source`, if a reference is unknown
/// (see [`Resolver::check_program`]). Otherwise returns whatever `compile`
/// reports.
pub fn resolve_source<State: Send + Sync>(
    source: &str,
    registry: &CapabilityRegistry<State>,
) -> Result<Vec<Blueprint>> {
    let program = parse_str(source)?;
    let file = SourceFile::anonymous(source);
    let resolver = Resolver::from_registry(registry);
    resolver.check_program(&program, Some(&file))?;
    compile(&program)
}