mod types;
pub(crate) use types::{Branch, BuilderNode, NodeMeta};
pub use types::{
END, ForkId, GraphBuilder, GraphDefaults, NodeContext, NodeFuture, NodeHandler, Route,
RouterFn, START,
};
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::sync::Arc;
use crate::graph::command::NodeResult;
use crate::graph::compiled::CompiledGraph;
use crate::graph::reducer::{OverwriteStateReducer, StateReducer};
use crate::harness::ids::{GraphId, NodeId};
use crate::{Result, TinyAgentsError};
impl<State, Update> Default for GraphBuilder<State, Update>
where
State: Clone + Send + Sync + 'static,
Update: Send + 'static,
{
fn default() -> Self {
Self::new()
}
}
impl<State, Update> GraphBuilder<State, Update>
where
    State: Clone + Send + Sync + 'static,
    Update: Send + 'static,
{
    /// Creates an empty builder.
    ///
    /// The graph id is generated as `graph-<seq>` using `crate::graph::compiled::next_seq()`.
    /// Initial settings: recursion limit 50, `parallel` disabled, no `max_concurrency`,
    /// no `node_timeout`, and no state reducer. A reducer must be set (see
    /// [`GraphBuilder::set_reducer`]) before [`GraphBuilder::compile`] can succeed.
    pub fn new() -> Self {
        Self {
            graph_id: GraphId::new(format!("graph-{}", crate::graph::compiled::next_seq())),
            name: None,
            nodes: HashMap::new(),
            edges: HashMap::new(),
            branches: HashMap::new(),
            command_nodes: HashSet::new(),
            waiting: HashMap::new(),
            reducer: None,
            recursion_limit: 50,
            parallel: false,
            max_concurrency: None,
            node_timeout: None,
            node_meta: HashMap::new(),
        }
    }

    /// Applies every field that is `Some` in `defaults`; fields that are `None` leave the
    /// current setting untouched.
    ///
    /// A `max_concurrency` of `0` is stored as `None`, exactly as
    /// [`GraphBuilder::with_max_concurrency`] does, so both entry points agree on what a
    /// zero limit means.
    pub fn set_defaults(mut self, defaults: GraphDefaults) -> Self {
        if let Some(limit) = defaults.recursion_limit {
            self.recursion_limit = limit;
        }
        if let Some(parallel) = defaults.parallel {
            self.parallel = parallel;
        }
        if let Some(max) = defaults.max_concurrency {
            // Previously `Some(0)` was stored verbatim here, unlike `with_max_concurrency`.
            self.max_concurrency = (max > 0).then_some(max);
        }
        if let Some(timeout) = defaults.node_timeout {
            self.node_timeout = Some(timeout);
        }
        self
    }

    /// Sets the concurrency cap; `0` clears it (stored as `None`).
    pub fn with_max_concurrency(mut self, n: usize) -> Self {
        self.max_concurrency = (n > 0).then_some(n);
        self
    }

    /// Sets a per-node timeout that is handed to the compiled graph.
    pub fn with_node_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.node_timeout = Some(timeout);
        self
    }

    /// Enables or disables the `parallel` flag passed to the compiled graph.
    pub fn with_parallel(mut self, parallel: bool) -> Self {
        self.parallel = parallel;
        self
    }

    /// Replaces the auto-generated graph id.
    pub fn with_graph_id(mut self, id: impl Into<GraphId>) -> Self {
        self.graph_id = id.into();
        self
    }

    /// Sets a human-readable name for the graph.
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Sets the recursion limit passed to the compiled graph (default: 50).
    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
        self.recursion_limit = limit;
        self
    }

    /// Installs the reducer used to fold node updates into the state.
    ///
    /// Replaces any previously configured reducer.
    pub fn set_reducer<R>(mut self, reducer: R) -> Self
    where
        R: StateReducer<State, Update> + 'static,
    {
        self.reducer = Some(Arc::new(reducer));
        self
    }

    /// Registers an async node handler under `id`.
    ///
    /// Registering the same id twice replaces the earlier handler. The ids `START` and
    /// `END` are reserved; registering them makes [`GraphBuilder::compile`] fail.
    pub fn add_node<F, Fut>(mut self, id: impl Into<NodeId>, handler: F) -> Self
    where
        F: Fn(State, NodeContext) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<NodeResult<Update>>> + Send + 'static,
    {
        let id = id.into();
        self.nodes.insert(
            id.clone(),
            BuilderNode {
                id,
                // Box the concrete future so every handler has the same erased type.
                handler: Arc::new(move |state, ctx| Box::pin(handler(state, ctx))),
            },
        );
        self
    }

    /// Adds a static edge `from -> to`.
    ///
    /// Each node has at most one static successor: a later call with the same `from`
    /// replaces the earlier target.
    pub fn add_edge(mut self, from: impl Into<NodeId>, to: impl Into<NodeId>) -> Self {
        self.edges.insert(from.into(), to.into());
        self
    }

    /// Chains `nodes` with static edges (`a -> b -> c ...`).
    ///
    /// Fewer than two nodes adds no edges.
    pub fn add_sequence<I, N>(mut self, nodes: I) -> Self
    where
        I: IntoIterator<Item = N>,
        N: Into<NodeId>,
    {
        let nodes: Vec<NodeId> = nodes.into_iter().map(Into::into).collect();
        self.edges.extend(
            nodes
                .windows(2)
                .map(|pair| (pair[0].clone(), pair[1].clone())),
        );
        self
    }

    /// Adds a static edge `from -> to` and records `from` in `to`'s waiting set.
    ///
    /// As with [`GraphBuilder::add_edge`], this replaces any earlier static edge out of
    /// `from`. How the waiting set is honoured at run time is decided by `CompiledGraph`.
    pub fn add_waiting_edge(mut self, from: impl Into<NodeId>, to: impl Into<NodeId>) -> Self {
        let from = from.into();
        let to = to.into();
        self.edges.insert(from.clone(), to.clone());
        self.waiting.entry(to).or_default().insert(from);
        self
    }

    /// Shorthand for `add_edge(START, node)`.
    pub fn set_entry(self, node: impl Into<NodeId>) -> Self {
        self.add_edge(START, node)
    }

    /// Shorthand for `add_edge(node, END)`.
    pub fn set_finish(self, node: impl Into<NodeId>) -> Self {
        self.add_edge(node, END)
    }

    /// Attaches conditional edges to `from`.
    ///
    /// `router`'s result is converted to a `String`. `routes` maps such keys (also
    /// stringified) to target nodes; `END` is an allowed target. A node may not have both
    /// a static edge and conditional edges (checked by [`GraphBuilder::compile`]). A later
    /// call for the same `from` replaces the earlier branch.
    pub fn add_conditional_edges<F, R, I, K, V>(
        mut self,
        from: impl Into<NodeId>,
        router: F,
        routes: I,
    ) -> Self
    where
        F: Fn(&State) -> R + Send + Sync + 'static,
        R: ToString,
        I: IntoIterator<Item = (K, V)>,
        K: ToString,
        V: Into<NodeId>,
    {
        let routes = routes
            .into_iter()
            .map(|(k, v)| (k.to_string(), v.into()))
            .collect();
        self.branches.insert(
            from.into(),
            Branch {
                router: Arc::new(move |state| router(state).to_string()),
                routes,
            },
        );
        self
    }

    /// Declares that `node` chooses its successor via commands.
    ///
    /// Such a node must not also have static or conditional edges.
    pub fn mark_command_routing(mut self, node: impl Into<NodeId>) -> Self {
        self.command_nodes.insert(node.into());
        self
    }

    /// Marks `node` as command-routed and records its possible command destinations in
    /// the node metadata, replacing any previously recorded list.
    pub fn with_command_destinations<I, N>(
        mut self,
        node: impl Into<NodeId>,
        destinations: I,
    ) -> Self
    where
        I: IntoIterator<Item = N>,
        N: Into<NodeId>,
    {
        let node = node.into();
        self.command_nodes.insert(node.clone());
        let dests = destinations.into_iter().map(Into::into).collect();
        self.node_meta.entry(node).or_default().command_destinations = dests;
        self
    }

    /// Records a free-form `kind` label in `node`'s metadata.
    pub fn with_node_kind(mut self, node: impl Into<NodeId>, kind: impl Into<String>) -> Self {
        self.node_meta.entry(node.into()).or_default().kind = Some(kind.into());
        self
    }

    /// Inserts (or overwrites) a `key = value` pair in `node`'s metadata map.
    pub fn with_node_metadata(
        mut self,
        node: impl Into<NodeId>,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> Self {
        self.node_meta
            .entry(node.into())
            .or_default()
            .metadata
            .insert(key.into(), value.into());
        self
    }

    /// Sets the `subgraph` flag in `node`'s metadata.
    pub fn mark_subgraph(mut self, node: impl Into<NodeId>) -> Self {
        self.node_meta.entry(node.into()).or_default().subgraph = true;
        self
    }

    /// Sets the `interrupt` flag in `node`'s metadata.
    pub fn mark_interrupt(mut self, node: impl Into<NodeId>) -> Self {
        self.node_meta.entry(node.into()).or_default().interrupt = true;
        self
    }

    /// Sets the `deferred` flag in `node`'s metadata.
    pub fn mark_deferred(mut self, node: impl Into<NodeId>) -> Self {
        self.node_meta.entry(node.into()).or_default().deferred = true;
        self
    }

    /// Validates the builder and turns it into a [`CompiledGraph`].
    ///
    /// # Errors
    ///
    /// - `Validation` if no reducer was set.
    /// - `Validation` if a node was registered under the reserved id `START` or `END`.
    /// - `MissingStart` if there is no `START -> entry` edge.
    /// - `Validation` if `START` routes directly to `END`, if `START` is an edge or route
    ///   target, or if `END` is an edge source.
    /// - `Validation` if a node mixes static edges with conditional edges or with command
    ///   routing.
    /// - `MissingNode` if any edge, route, waiting edge, or command node refers to an
    ///   unregistered node.
    pub fn compile(self) -> Result<CompiledGraph<State, Update>> {
        // Cloning the `Arc` up front avoids an `expect` after the builder is destructured.
        let Some(reducer) = self.reducer.clone() else {
            return Err(TinyAgentsError::Validation(
                "no state reducer set; call set_reducer (or GraphBuilder::overwrite)".to_string(),
            ));
        };
        self.validate_reserved_ids()?;
        let entry = self
            .edges
            .get(&NodeId::from(START))
            .cloned()
            .ok_or(TinyAgentsError::MissingStart)?;
        if entry.as_str() == END {
            return Err(TinyAgentsError::Validation(
                "START cannot route directly to END".to_string(),
            ));
        }
        // The `START -> entry` edge is part of `edges`, so this also checks that the entry
        // node exists (after rejecting `START -> START`).
        self.validate_edges()?;
        self.validate_branches()?;
        self.validate_waiting()?;
        self.validate_command_nodes()?;

        let Self {
            graph_id,
            name,
            nodes,
            edges,
            branches,
            command_nodes,
            waiting,
            reducer: _,
            recursion_limit,
            parallel,
            max_concurrency,
            node_timeout,
            node_meta,
        } = self;
        Ok(CompiledGraph::from_parts(
            graph_id,
            name,
            nodes,
            edges,
            branches,
            command_nodes,
            waiting,
            entry,
            reducer,
            recursion_limit,
            parallel,
            max_concurrency,
            node_timeout,
            node_meta,
        ))
    }

    /// Rejects nodes registered under the reserved `START` / `END` ids.
    fn validate_reserved_ids(&self) -> Result<()> {
        if let Some(id) = self
            .nodes
            .keys()
            .find(|id| id.as_str() == START || id.as_str() == END)
        {
            return Err(TinyAgentsError::Validation(format!(
                "node id `{id}` is reserved for the graph's START/END markers"
            )));
        }
        Ok(())
    }

    /// Checks every static edge: reserved-marker placement first, then node existence.
    fn validate_edges(&self) -> Result<()> {
        for (from, to) in &self.edges {
            // These must run before `require_node`: START/END are never registered nodes,
            // so an existence check first would hide these errors behind `MissingNode`.
            if to.as_str() == START {
                return Err(TinyAgentsError::Validation(
                    "START cannot be an edge target".to_string(),
                ));
            }
            if from.as_str() == END {
                return Err(TinyAgentsError::Validation(
                    "END cannot be an edge source".to_string(),
                ));
            }
            if from.as_str() != START {
                self.require_node(from)?;
            }
            if to.as_str() != END {
                self.require_node(to)?;
            }
        }
        Ok(())
    }

    /// Checks conditional edges: registered source, no static edge too, valid targets.
    fn validate_branches(&self) -> Result<()> {
        for (from, branch) in &self.branches {
            self.require_node(from)?;
            if self.edges.contains_key(from) {
                return Err(TinyAgentsError::Validation(format!(
                    "node `{from}` has both a static edge and conditional edges"
                )));
            }
            for target in branch.routes.values() {
                if target.as_str() == START {
                    return Err(TinyAgentsError::Validation(format!(
                        "node `{from}` has a conditional route to START; START cannot be an edge target"
                    )));
                }
                if target.as_str() != END {
                    self.require_node(target)?;
                }
            }
        }
        Ok(())
    }

    /// Checks that every node mentioned by a waiting edge is registered.
    fn validate_waiting(&self) -> Result<()> {
        for (to, froms) in &self.waiting {
            self.require_node(to)?;
            for from in froms {
                self.require_node(from)?;
            }
        }
        Ok(())
    }

    /// Checks that command-routed nodes exist and declare no static/conditional edges.
    fn validate_command_nodes(&self) -> Result<()> {
        for node in &self.command_nodes {
            self.require_node(node)?;
            if self.edges.contains_key(node) || self.branches.contains_key(node) {
                return Err(TinyAgentsError::Validation(format!(
                    "node `{node}` declares command routing but also has static/conditional edges"
                )));
            }
        }
        Ok(())
    }

    /// Returns `MissingNode` unless `id` was registered with [`GraphBuilder::add_node`].
    fn require_node(&self, id: &NodeId) -> Result<()> {
        if self.nodes.contains_key(id) {
            Ok(())
        } else {
            Err(TinyAgentsError::MissingNode(id.to_string()))
        }
    }
}
impl<State> GraphBuilder<State, State>
where
    State: Clone + Send + Sync + 'static,
{
    /// Creates a builder whose update type is the state type itself, with
    /// [`OverwriteStateReducer`] already installed as the reducer, so `compile` does not
    /// fail for lack of one.
    pub fn overwrite() -> Self {
        let builder = Self::new();
        builder.set_reducer(OverwriteStateReducer)
    }
}
#[cfg(test)]
mod test;