use crate::edge::Edge;
use crate::graph::Graph;
use crate::graph::GraphExecutionError;
use crate::node::Node;
use std::any::Any;
use std::sync::Arc;
/// A pending graph-input binding: `(external_name, node, port, initial_value)`.
///
/// The optional initial value is type-erased so that bindings carrying values of
/// different concrete types can be stored side by side in one `Vec`.
type InputBinding = (String, String, String, Option<Arc<dyn Any + Send + Sync>>);

/// A pending graph-output binding: `(external_name, node, port)`.
type OutputBinding = (String, String, String);

/// Collects nodes, edges and graph I/O bindings, which `GraphBuilder::build`
/// later assembles into a [`Graph`].
pub struct GraphBuilder {
    /// Name handed to `Graph::new` when the graph is built.
    name: String,
    /// Registered nodes, in insertion order, each paired with its node name.
    nodes: Vec<(String, Box<dyn Node>)>,
    /// Port-to-port connections between registered nodes.
    edges: Vec<Edge>,
    /// Node ports to expose as graph inputs, optionally with an initial value.
    input_bindings: Vec<InputBinding>,
    /// Node ports to expose as graph outputs.
    output_bindings: Vec<OutputBinding>,
}
impl GraphBuilder {
    /// Creates an empty builder for a graph called `name`.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            nodes: Vec::new(),
            edges: Vec::new(),
            input_bindings: Vec::new(),
            output_bindings: Vec::new(),
        }
    }

    /// Registers `node` under `name`.
    ///
    /// Nodes are only handed to the underlying graph in [`GraphBuilder::build`], so
    /// any error `Graph::add_node` reports surfaces from `build`, not from here.
    ///
    /// # Panics
    ///
    /// Panics if `name` is `"graph"`, which is reserved for the graph's own I/O
    /// namespace.
    #[must_use]
    pub fn add_node(mut self, name: impl Into<String>, node: Box<dyn Node>) -> Self {
        let name_str = name.into();
        if name_str == "graph" {
            panic!(
                "Node name 'graph' is reserved for graph I/O namespace. Use 'graph.input_name' for graph inputs and 'graph.output_name' for graph outputs."
            );
        }
        self.nodes.push((name_str, node));
        self
    }

    /// Connects output port `source.source_port` to input port `target.target_port`.
    ///
    /// Node and port names are not validated here; errors from `Graph::add_edge`
    /// surface from [`GraphBuilder::build`].
    ///
    /// # Panics
    ///
    /// Panics if `source.source_port` already has an outgoing edge, because fan-out
    /// (one output port feeding several inputs) is not supported.
    #[must_use]
    pub fn connect(
        mut self,
        source: &str,
        source_port: &str,
        target: &str,
        target_port: &str,
    ) -> Self {
        if self
            .edges
            .iter()
            .any(|e| e.source_node() == source && e.source_port() == source_port)
        {
            panic!(
                "Fan-out not supported: output port '{}.{}' is already connected. Each output port can only connect to one input port.",
                source, source_port
            );
        }
        self.edges.push(Edge {
            source_node: source.to_string(),
            source_port: source_port.to_string(),
            target_node: target.to_string(),
            target_port: target_port.to_string(),
        });
        self
    }

    /// Exposes `node.port` as a graph input named `external_name`.
    ///
    /// If `value` is `Some`, it is type-erased into an `Arc<dyn Any + Send + Sync>`
    /// and queued on a dedicated input channel during [`GraphBuilder::build`].
    #[must_use]
    pub fn input<T: Send + Sync + 'static>(
        mut self,
        external_name: impl Into<String>,
        node: &str,
        port: &str,
        value: Option<T>,
    ) -> Self {
        let value_arc = value.map(|v| Arc::new(v) as Arc<dyn Any + Send + Sync>);
        self.input_bindings.push((
            external_name.into(),
            node.to_string(),
            port.to_string(),
            value_arc,
        ));
        self
    }

    /// Exposes `node.port` as a graph output named `external_name`.
    #[must_use]
    pub fn output(mut self, external_name: impl Into<String>, node: &str, port: &str) -> Self {
        self.output_bindings
            .push((external_name.into(), node.to_string(), port.to_string()));
        self
    }

    /// Assembles the collected nodes, edges and I/O bindings into a [`Graph`].
    ///
    /// Items are applied in this order: nodes, then edges, then input bindings,
    /// then output bindings.
    ///
    /// # Errors
    ///
    /// Returns an error if any of `Graph::add_node`, `Graph::add_edge`,
    /// `Graph::expose_input_port`, `Graph::connect_input_channel` or
    /// `Graph::expose_output_port` fails. It also returns an error if an initial
    /// input value cannot be queued, which happens when the graph has already
    /// dropped the input channel's receiver.
    pub fn build(self) -> Result<Graph, GraphExecutionError> {
        let mut graph = Graph::new(self.name);
        for (name, node) in self.nodes {
            graph.add_node(name, node).map_err(Self::wrap_error)?;
        }
        for edge in self.edges {
            graph.add_edge(edge).map_err(Self::wrap_error)?;
        }
        for (external, node, port, value) in self.input_bindings {
            graph
                .expose_input_port(&node, &port, &external)
                .map_err(Self::wrap_error)?;
            if let Some(val) = value {
                // A freshly created channel of capacity 1 always has room for exactly
                // one message, so `try_send` succeeds synchronously. No Tokio runtime
                // is needed (unlike `tokio::spawn`, which panics outside a runtime).
                // Dropping `tx` at the end of this scope closes the channel once the
                // value has been read.
                let (tx, rx) = tokio::sync::mpsc::channel(1);
                graph
                    .connect_input_channel(&external, rx)
                    .map_err(Self::wrap_error)?;
                tx.try_send(val).map_err(|e| {
                    Self::wrap_error(format!(
                        "failed to queue initial value for graph input '{external}': {e}"
                    ))
                })?;
            }
        }
        for (external, node, port) in self.output_bindings {
            graph
                .expose_output_port(&node, &port, &external)
                .map_err(Self::wrap_error)?;
        }
        Ok(graph)
    }

    /// Converts a `String` error returned by `Graph`'s methods into a
    /// [`GraphExecutionError`].
    fn wrap_error(message: String) -> GraphExecutionError {
        Box::new(std::io::Error::other(message))
    }
}