use rustc_hash::FxHashMap;
use std::sync::Arc;
use super::edges::{ConditionalEdge, EdgePredicate};
use crate::node::Node;
use crate::reducers::{Reducer, ReducerRegistry};
use crate::runtimes::{EventBusConfig, RuntimeConfig};
use crate::types::{ChannelType, NodeKind};
/// Owned components of a [`GraphBuilder`], in the order produced by
/// `GraphBuilder::into_parts`:
///
/// 0. registered nodes keyed by id,
/// 1. unconditional adjacency lists (`from -> [to, ...]`),
/// 2. conditional edges,
/// 3. runtime configuration,
/// 4. reducer registry.
type GraphParts = (
    FxHashMap<NodeKind, Arc<dyn Node>>,
    FxHashMap<NodeKind, Vec<NodeKind>>,
    Vec<ConditionalEdge>,
    RuntimeConfig,
    ReducerRegistry,
);
/// Incrementally assembles a graph definition: nodes, unconditional and
/// conditional edges, runtime configuration and per-channel reducers.
///
/// Every configuration method takes `self` by value and hands it back, so a
/// builder is configured by chaining calls starting from `GraphBuilder::new()`.
pub struct GraphBuilder {
    /// Nodes registered through `add_node`; the virtual `Start`/`End` kinds
    /// are rejected there and therefore never appear as keys.
    nodes: FxHashMap<NodeKind, Arc<dyn Node>>,
    /// Unconditional adjacency lists: each source maps to the targets added
    /// for it, in insertion order and without de-duplication.
    edges: FxHashMap<NodeKind, Vec<NodeKind>>,
    /// Conditional edges, kept in the order they were added.
    conditional_edges: Vec<ConditionalEdge>,
    /// Runtime settings, including the event-bus configuration.
    runtime_config: RuntimeConfig,
    /// Reducers registered per channel.
    reducer_registry: ReducerRegistry,
}
impl Default for GraphBuilder {
fn default() -> Self {
Self::new()
}
}
impl GraphBuilder {
#[must_use]
pub fn new() -> Self {
Self {
nodes: FxHashMap::default(),
edges: FxHashMap::default(),
conditional_edges: Vec::new(),
runtime_config: RuntimeConfig::default(),
reducer_registry: ReducerRegistry::default(),
}
}
#[must_use]
pub fn add_conditional_edge(mut self, from: NodeKind, predicate: EdgePredicate) -> Self {
self.conditional_edges
.push(ConditionalEdge::new(from, predicate));
self
}
#[must_use]
pub fn add_node(mut self, id: NodeKind, node: impl Node + 'static) -> Self {
match id {
NodeKind::Start | NodeKind::End => {
tracing::warn!(
?id,
"Ignoring registration of virtual node kind (Start/End are virtual)"
);
}
_ => {
self.nodes.insert(id, Arc::new(node));
}
}
self
}
#[must_use]
pub fn add_edge(mut self, from: NodeKind, to: NodeKind) -> Self {
self.edges.entry(from).or_default().push(to);
self
}
#[must_use]
pub fn with_runtime_config(mut self, runtime_config: RuntimeConfig) -> Self {
self.runtime_config = runtime_config;
self
}
#[must_use]
pub fn with_event_bus_config(mut self, config: EventBusConfig) -> Self {
let mut runtime_config = self.runtime_config.clone();
runtime_config.event_bus = config;
self.runtime_config = runtime_config;
self
}
#[must_use]
pub fn with_reducer(mut self, channel: ChannelType, reducer: Arc<dyn Reducer>) -> Self {
self.reducer_registry.register(channel, reducer);
self
}
#[must_use]
pub fn with_reducer_registry(mut self, registry: ReducerRegistry) -> Self {
self.reducer_registry = registry;
self
}
pub fn nodes(&self) -> super::iteration::NodesIter<'_> {
super::iteration::NodesIter::new(self.nodes.keys())
}
pub fn edges(&self) -> super::iteration::EdgesIter<'_> {
super::iteration::EdgesIter::new(&self.edges)
}
#[must_use]
pub fn node_count(&self) -> usize {
self.nodes.len()
}
#[must_use]
pub fn edge_count(&self) -> usize {
self.edges.values().map(|v| v.len()).sum()
}
#[must_use]
pub fn topological_sort(&self) -> Vec<crate::types::NodeKind> {
super::iteration::topological_sort(&self.edges)
}
#[cfg(feature = "petgraph-compat")]
#[cfg_attr(docsrs, doc(cfg(feature = "petgraph-compat")))]
#[must_use]
pub fn to_petgraph(&self) -> super::petgraph_compat::PetgraphConversion {
super::petgraph_compat::to_petgraph(&self.edges)
}
#[cfg(feature = "petgraph-compat")]
#[cfg_attr(docsrs, doc(cfg(feature = "petgraph-compat")))]
#[must_use]
pub fn to_dot(&self) -> String {
super::petgraph_compat::to_dot(&self.edges)
}
#[cfg(feature = "petgraph-compat")]
#[cfg_attr(docsrs, doc(cfg(feature = "petgraph-compat")))]
#[must_use]
pub fn is_cyclic_petgraph(&self) -> bool {
super::petgraph_compat::is_cyclic(&self.edges)
}
pub(super) fn into_parts(self) -> GraphParts {
(
self.nodes,
self.edges,
self.conditional_edges,
self.runtime_config,
self.reducer_registry,
)
}
pub(super) fn nodes_ref(&self) -> &FxHashMap<NodeKind, Arc<dyn Node>> {
&self.nodes
}
pub(super) fn edges_ref(&self) -> &FxHashMap<NodeKind, Vec<NodeKind>> {
&self.edges
}
pub(super) fn conditional_edges_ref(&self) -> &Vec<ConditionalEdge> {
&self.conditional_edges
}
}