mod types;
pub(crate) use types::{Branch, BuilderNode};
pub use types::{END, ForkId, GraphBuilder, NodeContext, NodeFuture, NodeHandler, RouterFn, START};
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::sync::Arc;
use crate::graph::command::NodeResult;
use crate::graph::compiled::CompiledGraph;
use crate::graph::reducer::{OverwriteStateReducer, StateReducer};
use crate::harness::ids::{GraphId, NodeId};
use crate::{Result, RustAgentsError};
impl<State, Update> Default for GraphBuilder<State, Update>
where
State: Clone + Send + Sync + 'static,
Update: Send + 'static,
{
fn default() -> Self {
Self::new()
}
}
impl<State, Update> GraphBuilder<State, Update>
where
    State: Clone + Send + Sync + 'static,
    Update: Send + 'static,
{
    /// Creates an empty builder.
    ///
    /// The builder starts with a generated `graph-<n>` id, no nodes, no edges,
    /// no reducer, a recursion limit of 50 and `parallel` set to `false`.
    /// A reducer must be set (see [`GraphBuilder::set_reducer`]) before
    /// [`GraphBuilder::compile`] will succeed.
    pub fn new() -> Self {
        Self {
            graph_id: GraphId::new(format!("graph-{}", crate::graph::compiled::next_seq())),
            nodes: HashMap::new(),
            edges: HashMap::new(),
            branches: HashMap::new(),
            command_nodes: HashSet::new(),
            reducer: None,
            recursion_limit: 50,
            parallel: false,
        }
    }

    /// Sets the `parallel` flag that is forwarded to the compiled graph.
    ///
    /// NOTE(review): presumably this lets the runtime execute ready nodes
    /// concurrently; confirm against `CompiledGraph`.
    pub fn with_parallel(mut self, parallel: bool) -> Self {
        self.parallel = parallel;
        self
    }

    /// Replaces the auto-generated graph id with `id`.
    pub fn with_graph_id(mut self, id: impl Into<GraphId>) -> Self {
        self.graph_id = id.into();
        self
    }

    /// Sets the recursion limit forwarded to the compiled graph (default 50).
    ///
    /// NOTE(review): presumably the maximum number of steps a run may take;
    /// confirm against `CompiledGraph`.
    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
        self.recursion_limit = limit;
        self
    }

    /// Sets the [`StateReducer`] the compiled graph uses to apply node
    /// `Update`s to the `State`.
    ///
    /// Replaces any previously configured reducer. [`GraphBuilder::compile`]
    /// fails if no reducer was ever set.
    pub fn set_reducer<R>(mut self, reducer: R) -> Self
    where
        R: StateReducer<State, Update> + 'static,
    {
        self.reducer = Some(Arc::new(reducer));
        self
    }

    /// Registers `handler` as the node called `id`.
    ///
    /// The handler receives the current state and a [`NodeContext`] and
    /// returns a future resolving to a [`NodeResult`]. Registering a second
    /// node under the same id replaces the first one.
    pub fn add_node<F, Fut>(mut self, id: impl Into<NodeId>, handler: F) -> Self
    where
        F: Fn(State, NodeContext) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<NodeResult<Update>>> + Send + 'static,
    {
        let id = id.into();
        self.nodes.insert(
            id.clone(),
            BuilderNode {
                id,
                // Type-erase the concrete future so all handlers share one type.
                handler: Arc::new(move |state, ctx| Box::pin(handler(state, ctx))),
            },
        );
        self
    }

    /// Adds the static edge `from -> to`.
    ///
    /// Each source has at most one static edge. Calling this again with the
    /// same `from` replaces the earlier target.
    pub fn add_edge(mut self, from: impl Into<NodeId>, to: impl Into<NodeId>) -> Self {
        self.edges.insert(from.into(), to.into());
        self
    }

    /// Makes `node` the entry point. Shorthand for `add_edge(START, node)`.
    pub fn set_entry(self, node: impl Into<NodeId>) -> Self {
        self.add_edge(START, node)
    }

    /// Makes `node` a finishing node. Shorthand for `add_edge(node, END)`.
    pub fn set_finish(self, node: impl Into<NodeId>) -> Self {
        self.add_edge(node, END)
    }

    /// Declares conditional edges out of `from`.
    ///
    /// `router` maps the current state to a route key, and `routes` maps each
    /// key to a target node (or [`END`]). Calling this again for the same
    /// `from` replaces the earlier routing table.
    pub fn add_conditional_edges<F, I, K, V>(
        mut self,
        from: impl Into<NodeId>,
        router: F,
        routes: I,
    ) -> Self
    where
        F: Fn(&State) -> String + Send + Sync + 'static,
        I: IntoIterator<Item = (K, V)>,
        K: Into<String>,
        V: Into<NodeId>,
    {
        let routes = routes
            .into_iter()
            .map(|(k, v)| (k.into(), v.into()))
            .collect();
        self.branches.insert(
            from.into(),
            Branch {
                router: Arc::new(router),
                routes,
            },
        );
        self
    }

    /// Marks `node` as choosing its own successor at run time instead of
    /// through static or conditional edges.
    ///
    /// NOTE(review): presumably via a command carried in its [`NodeResult`].
    /// [`GraphBuilder::compile`] rejects a marked node that also has
    /// static/conditional edges.
    pub fn mark_command_routing(mut self, node: impl Into<NodeId>) -> Self {
        self.command_nodes.insert(node.into());
        self
    }

    /// Validates the builder and turns it into a [`CompiledGraph`].
    ///
    /// # Errors
    ///
    /// * [`RustAgentsError::Validation`] in any of these cases:
    ///   - no reducer was set;
    ///   - a node is registered under the reserved [`START`]/[`END`] ids;
    ///   - `START` routes straight to `END`;
    ///   - an edge or route targets `START` or leaves `END`;
    ///   - a node has both a static edge and conditional edges;
    ///   - a command-routed node also has static/conditional edges.
    /// * [`RustAgentsError::MissingStart`] if there is no edge out of `START`.
    /// * [`RustAgentsError::MissingNode`] if an edge, route, or command-routing
    ///   declaration names a node that was never added.
    pub fn compile(mut self) -> Result<CompiledGraph<State, Update>> {
        // Taking the reducer out up front keeps "missing reducer" as the first
        // reported error without needing an `expect` later.
        let reducer = self.reducer.take().ok_or_else(|| {
            RustAgentsError::Validation(
                "no state reducer set; call set_reducer (or GraphBuilder::overwrite)".to_string(),
            )
        })?;
        self.validate_node_ids()?;
        let entry = self.validate_entry()?;
        self.validate_edges()?;
        self.validate_branches()?;
        self.validate_command_nodes()?;

        let Self {
            graph_id,
            nodes,
            edges,
            branches,
            command_nodes,
            reducer: _,
            recursion_limit,
            parallel,
        } = self;
        Ok(CompiledGraph::from_parts(
            graph_id,
            nodes,
            edges,
            branches,
            command_nodes,
            entry,
            reducer,
            recursion_limit,
            parallel,
        ))
    }

    /// Rejects nodes registered under the sentinel ids, which edges treat as
    /// graph terminals rather than as real nodes.
    fn validate_node_ids(&self) -> Result<()> {
        match self
            .nodes
            .keys()
            .find(|id| id.as_str() == START || id.as_str() == END)
        {
            Some(id) => Err(RustAgentsError::Validation(format!(
                "`{id}` is a reserved node id and cannot be registered as a node"
            ))),
            None => Ok(()),
        }
    }

    /// Returns the node that `START` leads to, checking it is a real node.
    fn validate_entry(&self) -> Result<NodeId> {
        let entry = self
            .edges
            .get(&NodeId::from(START))
            .cloned()
            .ok_or(RustAgentsError::MissingStart)?;
        if entry.as_str() == END {
            return Err(RustAgentsError::Validation(
                "START cannot route directly to END".to_string(),
            ));
        }
        self.require_target(&entry)?;
        Ok(entry)
    }

    /// Checks both endpoints of every static edge.
    fn validate_edges(&self) -> Result<()> {
        for (from, to) in &self.edges {
            self.require_source(from)?;
            self.require_target(to)?;
        }
        Ok(())
    }

    /// Checks conditional edges: each source must be a real node with no
    /// static edge, and every route target must be valid.
    fn validate_branches(&self) -> Result<()> {
        for (from, branch) in &self.branches {
            self.require_source(from)?;
            if self.edges.contains_key(from) {
                return Err(RustAgentsError::Validation(format!(
                    "node `{from}` has both a static edge and conditional edges"
                )));
            }
            for target in branch.routes.values() {
                self.require_target(target)?;
            }
        }
        Ok(())
    }

    /// Checks that command-routed nodes exist and declare no other routing.
    fn validate_command_nodes(&self) -> Result<()> {
        for node in &self.command_nodes {
            self.require_node(node)?;
            if self.edges.contains_key(node) || self.branches.contains_key(node) {
                return Err(RustAgentsError::Validation(format!(
                    "node `{node}` declares command routing but also has static/conditional edges"
                )));
            }
        }
        Ok(())
    }

    /// Validates an edge source: `END` may never be one, `START` always may,
    /// and anything else must be a registered node.
    ///
    /// The sentinel test comes first so the caller sees the descriptive
    /// validation error instead of `MissingNode` for the sentinel id.
    fn require_source(&self, from: &NodeId) -> Result<()> {
        if from.as_str() == END {
            Err(RustAgentsError::Validation(
                "END cannot be an edge source".to_string(),
            ))
        } else if from.as_str() == START {
            Ok(())
        } else {
            self.require_node(from)
        }
    }

    /// Validates an edge target: `START` may never be one, `END` always may,
    /// and anything else must be a registered node.
    fn require_target(&self, to: &NodeId) -> Result<()> {
        if to.as_str() == START {
            Err(RustAgentsError::Validation(
                "START cannot be an edge target".to_string(),
            ))
        } else if to.as_str() == END {
            Ok(())
        } else {
            self.require_node(to)
        }
    }

    /// Fails with [`RustAgentsError::MissingNode`] unless `id` was added via
    /// [`GraphBuilder::add_node`].
    fn require_node(&self, id: &NodeId) -> Result<()> {
        if self.nodes.contains_key(id) {
            Ok(())
        } else {
            Err(RustAgentsError::MissingNode(id.to_string()))
        }
    }
}
impl<State> GraphBuilder<State, State>
where
State: Clone + Send + Sync + 'static,
{
pub fn overwrite() -> Self {
Self::new().set_reducer(OverwriteStateReducer)
}
}
#[cfg(test)]
mod test;