use crate::edge::Edge;
use crate::node::{InputStreams, Node, NodeExecutionError, OutputStreams};
use async_trait::async_trait;
use std::any::Any;
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::{AtomicU8, Ordering};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio_stream::StreamExt;
/// Capacity of every bounded `tokio::sync::mpsc` channel this file creates for edges and
/// exposed ports.
const DATAFLOW_CHANNEL_CAPACITY: usize = 64;
/// Boxed, thread-safe error type returned by the graph execution entry points.
pub type GraphExecutionError = Box<dyn std::error::Error + Send + Sync>;
/// Shared list of join handles for the tasks spawned by a run (guarded by a tokio mutex).
type ExecutionHandleVec = Arc<Mutex<Vec<JoinHandle<Result<(), GraphExecutionError>>>>>;
/// Shared map into which each node task puts its node back once `Node::execute` has returned.
type NodesRestoredMap = Arc<Mutex<HashMap<String, Box<dyn Node>>>>;
/// Target of an exposed graph port: the internal node and the port on that node.
#[derive(Clone, Debug)]
struct PortMapping {
/// Name of the internal node (key into the graph's node map).
node: String,
/// Name of the port on that internal node.
port: String,
}
/// A named collection of nodes and edges that can be run as a dataflow and that also
/// implements `Node`, so it can be embedded in another graph.
pub struct Graph {
/// Name of the graph, returned by `name()`.
name: String,
/// Nodes keyed by name. `execute` / `execute_internal` take the whole map out for the
/// duration of a run and it is written back afterwards.
nodes: Arc<StdMutex<HashMap<String, Box<dyn Node>>>>,
/// Directed connections between node ports.
edges: Vec<Edge>,
/// Join handles of the tasks spawned by `execute`; drained by `stop` and
/// `wait_for_completion`.
execution_handles: ExecutionHandleVec,
/// Notified by `stop`; the forwarding tasks spawned by `run_dataflow` select on it.
stop_signal: Arc<tokio::sync::Notify>,
/// Notified by `start` and `resume`. No waiter on it is visible in this file.
pause_signal: Arc<tokio::sync::Notify>,
/// State code: 0 on construction and after `stop`, 1 after `start`/`resume`, 2 after `pause`.
execution_state: Arc<AtomicU8>,
/// External input port name -> internal (node, port) that receives its data.
input_port_mapping: HashMap<String, PortMapping>,
/// External output port name -> internal (node, port) that produces its data.
output_port_mapping: HashMap<String, PortMapping>,
/// Receivers attached with `connect_input_channel`; each is taken (leaving `None`) by `execute`.
connected_input_channels:
HashMap<String, Option<tokio::sync::mpsc::Receiver<Arc<dyn Any + Send + Sync>>>>,
/// Senders attached with `connect_output_channel`, keyed by external output port name.
connected_output_channels: HashMap<String, tokio::sync::mpsc::Sender<Arc<dyn Any + Send + Sync>>>,
/// External input port names in the order they were first exposed.
input_port_names: Vec<String>,
/// External output port names in the order they were first exposed.
output_port_names: Vec<String>,
/// Restore map returned by the last `execute`; consumed by `wait_for_completion`.
nodes_restored_after_run: Option<NodesRestoredMap>,
}
impl Graph {
pub fn new(name: String) -> Self {
Self {
name,
nodes: Arc::new(StdMutex::new(HashMap::new())),
edges: Vec::new(),
execution_handles: Arc::new(Mutex::new(Vec::new())),
stop_signal: Arc::new(tokio::sync::Notify::new()),
pause_signal: Arc::new(tokio::sync::Notify::new()),
execution_state: Arc::new(AtomicU8::new(0)), input_port_mapping: HashMap::new(),
output_port_mapping: HashMap::new(),
connected_input_channels: HashMap::new(),
connected_output_channels: HashMap::new(),
input_port_names: Vec::new(),
output_port_names: Vec::new(),
nodes_restored_after_run: None,
}
}
/// Exposes input port `internal_port` of node `internal_node` as graph input `external_name`.
///
/// Re-exposing an existing `external_name` replaces its mapping and keeps its position in
/// the list of input port names.
///
/// # Errors
/// Returns a message when the node does not exist or has no such input port.
///
/// # Panics
/// Panics if the node mutex is poisoned.
pub fn expose_input_port(
    &mut self,
    internal_node: &str,
    internal_port: &str,
    external_name: &str,
) -> Result<(), String> {
    {
        let nodes = self.nodes.lock().unwrap();
        match nodes.get(internal_node) {
            None => {
                return Err(format!("Internal node '{}' does not exist", internal_node));
            }
            Some(target) if !target.has_input_port(internal_port) => {
                return Err(format!(
                    "Internal node '{}' does not have input port '{}'",
                    internal_node, internal_port
                ));
            }
            Some(_) => {}
        }
    }

    let mapping = PortMapping {
        node: internal_node.to_owned(),
        port: internal_port.to_owned(),
    };
    self.input_port_mapping
        .insert(external_name.to_owned(), mapping);

    let already_listed = self
        .input_port_names
        .iter()
        .any(|existing| existing.as_str() == external_name);
    if !already_listed {
        self.input_port_names.push(external_name.to_owned());
    }
    Ok(())
}
/// Exposes output port `internal_port` of node `internal_node` as graph output `external_name`.
///
/// Re-exposing an existing `external_name` replaces its mapping and keeps its position in
/// the list of output port names.
///
/// # Errors
/// Returns a message when the node does not exist or has no such output port.
///
/// # Panics
/// Panics if the node mutex is poisoned.
pub fn expose_output_port(
    &mut self,
    internal_node: &str,
    internal_port: &str,
    external_name: &str,
) -> Result<(), String> {
    // `None` = unknown node, `Some(false)` = node exists but lacks the port.
    let port_present = self
        .nodes
        .lock()
        .unwrap()
        .get(internal_node)
        .map(|node| node.has_output_port(internal_port));
    match port_present {
        None => {
            return Err(format!("Internal node '{}' does not exist", internal_node));
        }
        Some(false) => {
            return Err(format!(
                "Internal node '{}' does not have output port '{}'",
                internal_node, internal_port
            ));
        }
        Some(true) => {}
    }

    self.output_port_mapping.insert(
        external_name.to_owned(),
        PortMapping {
            node: internal_node.to_owned(),
            port: internal_port.to_owned(),
        },
    );
    if self
        .output_port_names
        .iter()
        .all(|existing| existing.as_str() != external_name)
    {
        self.output_port_names.push(external_name.to_owned());
    }
    Ok(())
}
/// Attaches `receiver` as the data source of exposed input port `external_port`, replacing
/// any receiver attached earlier under the same name.
///
/// # Errors
/// Returns a message when `external_port` has not been exposed.
pub fn connect_input_channel(
    &mut self,
    external_port: &str,
    receiver: tokio::sync::mpsc::Receiver<Arc<dyn Any + Send + Sync>>,
) -> Result<(), String> {
    if self.input_port_mapping.contains_key(external_port) {
        self.connected_input_channels
            .insert(external_port.to_owned(), Some(receiver));
        Ok(())
    } else {
        Err(format!(
            "External input port '{}' is not exposed",
            external_port
        ))
    }
}
/// Attaches `sender` as the destination of exposed output port `external_port`, replacing
/// any sender attached earlier under the same name.
///
/// # Errors
/// Returns a message when `external_port` has not been exposed.
pub fn connect_output_channel(
    &mut self,
    external_port: &str,
    sender: tokio::sync::mpsc::Sender<Arc<dyn Any + Send + Sync>>,
) -> Result<(), String> {
    match self.output_port_mapping.contains_key(external_port) {
        true => {
            self.connected_output_channels
                .insert(external_port.to_owned(), sender);
            Ok(())
        }
        false => Err(format!(
            "External output port '{}' is not exposed",
            external_port
        )),
    }
}
/// Returns the graph's name.
pub fn name(&self) -> &str {
    self.name.as_str()
}
/// Replaces the graph's name with a copy of `name`.
pub fn set_name(&mut self, name: &str) {
    self.name = String::from(name);
}
/// Locks the node map and returns the guard; the map stays locked until the guard is dropped.
///
/// # Panics
/// Panics if the node mutex is poisoned.
pub fn get_nodes(&self) -> std::sync::MutexGuard<'_, HashMap<String, Box<dyn Node>>> {
    let locked = self.nodes.lock();
    locked.unwrap()
}
/// Returns a guard over the whole node map when it contains `name`, otherwise `None`.
///
/// The caller looks the node up through the returned guard; the map stays locked while
/// the guard is alive.
///
/// # Panics
/// Panics if the node mutex is poisoned.
pub fn find_node_by_name(
    &self,
    name: &str,
) -> Option<std::sync::MutexGuard<'_, HashMap<String, Box<dyn Node>>>> {
    Some(self.nodes.lock().unwrap()).filter(|nodes| nodes.contains_key(name))
}
/// Returns how many nodes are currently stored in the graph.
///
/// # Panics
/// Panics if the node mutex is poisoned.
pub fn node_count(&self) -> usize {
    let nodes = self.nodes.lock().unwrap();
    nodes.len()
}
pub fn edge_count(&self) -> usize {
self.edges.len()
}
/// Reports whether a node called `name` is currently stored in the graph.
///
/// # Panics
/// Panics if the node mutex is poisoned.
pub fn has_node(&self, name: &str) -> bool {
    self.nodes.lock().unwrap().get(name).is_some()
}
/// Reports whether at least one edge runs from `source_node` to `target_node`, on any ports.
pub fn has_edge(&self, source_node: &str, target_node: &str) -> bool {
    self.edges
        .iter()
        .filter(|edge| edge.source_node() == source_node)
        .any(|edge| edge.target_node() == target_node)
}
/// Stores `node` under `name`.
///
/// # Errors
/// Returns a message, and leaves the existing node untouched, when `name` is already taken.
///
/// # Panics
/// Panics if the node mutex is poisoned.
pub fn add_node(&mut self, name: String, node: Box<dyn Node>) -> Result<(), String> {
    use std::collections::hash_map::Entry;

    let mut nodes = self.nodes.lock().unwrap();
    match nodes.entry(name) {
        Entry::Occupied(taken) => Err(format!(
            "Node with name '{}' already exists",
            taken.key()
        )),
        Entry::Vacant(slot) => {
            slot.insert(node);
            Ok(())
        }
    }
}
/// Removes the node called `name`.
///
/// # Errors
/// Returns a message when the node does not exist, or when any edge still starts or ends
/// at it (checked in that order).
///
/// # Panics
/// Panics if the node mutex is poisoned.
pub fn remove_node(&mut self, name: &str) -> Result<(), String> {
    let mut nodes = self.nodes.lock().unwrap();
    if nodes.get(name).is_none() {
        return Err(format!("Node with name '{}' does not exist", name));
    }
    for edge in &self.edges {
        if edge.source_node() == name || edge.target_node() == name {
            return Err(format!(
                "Cannot remove node '{}': it has connected edges",
                name
            ));
        }
    }
    nodes.remove(name);
    Ok(())
}
/// Returns references to all edges, in insertion order.
pub fn get_edges(&self) -> Vec<&Edge> {
    let mut refs = Vec::with_capacity(self.edges.len());
    refs.extend(self.edges.iter());
    refs
}
/// Returns the first edge whose source node/port and target node/port all match, if any.
pub fn find_edge_by_nodes_and_ports(
    &self,
    source_node: &str,
    source_port: &str,
    target_node: &str,
    target_port: &str,
) -> Option<&Edge> {
    for edge in &self.edges {
        let same_source =
            edge.source_node() == source_node && edge.source_port() == source_port;
        let same_target =
            edge.target_node() == target_node && edge.target_port() == target_port;
        if same_source && same_target {
            return Some(edge);
        }
    }
    None
}
/// Adds `edge` after validating both of its endpoints.
///
/// # Errors
/// Returns a message, checked in this order, when: the source node is missing, the target
/// node is missing, the source node lacks the output port, the target node lacks the input
/// port, or an identical edge is already present.
///
/// # Panics
/// Panics if the node mutex is poisoned.
pub fn add_edge(&mut self, edge: Edge) -> Result<(), String> {
    {
        // Scoped so the node lock is released before the edge list is touched.
        let nodes = self.nodes.lock().unwrap();
        let Some(source) = nodes.get(edge.source_node()) else {
            return Err(format!(
                "Source node '{}' does not exist",
                edge.source_node()
            ));
        };
        let Some(target) = nodes.get(edge.target_node()) else {
            return Err(format!(
                "Target node '{}' does not exist",
                edge.target_node()
            ));
        };
        if !source.has_output_port(edge.source_port()) {
            return Err(format!(
                "Source node '{}' does not have output port '{}'",
                edge.source_node(),
                edge.source_port()
            ));
        }
        if !target.has_input_port(edge.target_port()) {
            return Err(format!(
                "Target node '{}' does not have input port '{}'",
                edge.target_node(),
                edge.target_port()
            ));
        }
    }

    let duplicate = self.edges.iter().any(|known| {
        known.source_node() == edge.source_node()
            && known.source_port() == edge.source_port()
            && known.target_node() == edge.target_node()
            && known.target_port() == edge.target_port()
    });
    if duplicate {
        return Err("Edge already exists".to_string());
    }

    self.edges.push(edge);
    Ok(())
}
pub fn remove_edge(
&mut self,
source_node: &str,
source_port: &str,
target_node: &str,
target_port: &str,
) -> Result<(), String> {
let index = self
.edges
.iter()
.position(|e| {
e.source_node() == source_node
&& e.source_port() == source_port
&& e.target_node() == target_node
&& e.target_port() == target_port
})
.ok_or_else(|| "Edge not found".to_string())?;
self.edges.remove(index);
Ok(())
}
/// Starts a run of the graph using the connected external channels and returns as soon as
/// all tasks have been spawned.
///
/// Every connected input receiver is taken (its slot becomes `None`) and wrapped in a
/// stream. The node map is moved out of the graph for the duration of the run. The spawned
/// task handles replace any previously stored handles, and the restore map is kept so that
/// `wait_for_completion` can put the nodes back.
///
/// # Errors
/// Propagates any error returned by `run_dataflow`.
///
/// # Panics
/// Panics if the node mutex is poisoned.
pub async fn execute(&mut self) -> Result<(), GraphExecutionError> {
    let external_inputs: HashMap<String, crate::node::InputStream> = self
        .connected_input_channels
        .iter_mut()
        .filter_map(|(port_name, slot)| {
            let receiver = slot.take()?;
            let stream = tokio_stream::wrappers::ReceiverStream::new(receiver);
            Some((
                port_name.clone(),
                Box::pin(stream) as crate::node::InputStream,
            ))
        })
        .collect();

    let nodes = {
        let mut stored = self.nodes.lock().unwrap();
        std::mem::take(&mut *stored)
    };

    let (handles, nodes_restored) = self
        .run_dataflow(
            nodes,
            Some(external_inputs),
            Some(&self.connected_output_channels),
        )
        .await?;

    {
        let mut running = self.execution_handles.lock().await;
        running.clear();
        running.extend(handles);
    }
    self.nodes_restored_after_run = Some(nodes_restored);
    Ok(())
}
/// Wires the graph together with bounded channels and spawns the tasks that move data.
///
/// One channel of capacity `DATAFLOW_CHANNEL_CAPACITY` is created per edge and per mapped
/// external input. Each node is moved into its own task, which calls `Node::execute`, puts
/// the node into the returned restore map, and then forwards every output stream to the
/// senders of its outgoing edges plus any connected external output channels.
///
/// # Arguments
/// * `nodes` - the nodes to run, keyed by name; consumed by the spawned tasks.
/// * `external_inputs` - streams keyed by external input port name; entries without a
///   port mapping are ignored.
/// * `external_outputs` - senders keyed by external output port name.
///
/// # Returns
/// The join handles of the input-pump and node tasks, and the shared map that the node
/// tasks fill with their nodes. No path in this function returns `Err`.
async fn run_dataflow(
    &self,
    nodes: HashMap<String, Box<dyn Node>>,
    external_inputs: Option<HashMap<String, crate::node::InputStream>>,
    external_outputs: Option<
        &HashMap<String, tokio::sync::mpsc::Sender<Arc<dyn Any + Send + Sync>>>,
    >,
) -> Result<
    (
        Vec<JoinHandle<Result<(), GraphExecutionError>>>,
        NodesRestoredMap,
    ),
    GraphExecutionError,
> {
    type Payload = Arc<dyn Any + Send + Sync>;
    type PortKey = (String, String);

    let mut input_rx: HashMap<PortKey, tokio::sync::mpsc::Receiver<Payload>> = HashMap::new();
    let mut output_txs: HashMap<PortKey, Vec<tokio::sync::mpsc::Sender<Payload>>> =
        HashMap::new();
    for edge in &self.edges {
        let (tx, rx) = tokio::sync::mpsc::channel(DATAFLOW_CHANNEL_CAPACITY);
        // NOTE: a later edge (or external input) into the same (node, port) replaces the
        // receiver stored here, because the map holds one receiver per input port.
        input_rx.insert(
            (
                edge.target_node().to_string(),
                edge.target_port().to_string(),
            ),
            rx,
        );
        output_txs
            .entry((
                edge.source_node().to_string(),
                edge.source_port().to_string(),
            ))
            .or_default()
            .push(tx);
    }

    let nodes_restored: NodesRestoredMap = Arc::new(Mutex::new(HashMap::new()));
    let mut all_handles: Vec<JoinHandle<Result<(), GraphExecutionError>>> = Vec::new();

    // Pump each mapped external input stream into the channel of its internal port.
    if let Some(external_inputs) = external_inputs {
        for (external_port, stream) in external_inputs {
            let Some(mapping) = self.input_port_mapping.get(&external_port) else {
                continue;
            };
            let (tx, rx) = tokio::sync::mpsc::channel(DATAFLOW_CHANNEL_CAPACITY);
            input_rx.insert((mapping.node.clone(), mapping.port.clone()), rx);
            let stop_signal = Arc::clone(&self.stop_signal);
            let handle = tokio::spawn(async move {
                let mut stream = stream;
                // Created once, outside the loop: `notify_waiters` only reaches `Notified`
                // futures that already exist, so a per-iteration future would miss a stop
                // request that arrives while this task is suspended in `send`.
                let stopped = stop_signal.notified();
                tokio::pin!(stopped);
                loop {
                    tokio::select! {
                        _ = &mut stopped => break,
                        item = stream.next() => {
                            match item {
                                Some(item) => {
                                    if tx.send(item).await.is_err() {
                                        break;
                                    }
                                }
                                None => break,
                            }
                        }
                    }
                }
                Ok(()) as Result<(), GraphExecutionError>
            });
            all_handles.push(handle);
        }
    }

    let output_port_mapping = self.output_port_mapping.clone();
    for (node_name, node) in nodes {
        // Collect this node's input receivers into the stream map handed to `execute`.
        let input_ports: Vec<String> = input_rx
            .keys()
            .filter(|(owner, _)| owner == &node_name)
            .map(|(_, port)| port.clone())
            .collect();
        let mut input_streams: InputStreams = HashMap::new();
        for port in input_ports {
            if let Some(rx) = input_rx.remove(&(node_name.clone(), port.clone())) {
                let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
                input_streams.insert(port, Box::pin(stream) as crate::node::InputStream);
            }
        }

        let output_txs_for_node: HashMap<String, Vec<tokio::sync::mpsc::Sender<Payload>>> =
            output_txs
                .iter()
                .filter(|((owner, _), _)| owner == &node_name)
                .map(|((_, port), senders)| (port.clone(), senders.clone()))
                .collect();

        let nodes_restored = Arc::clone(&nodes_restored);
        let stop_signal = Arc::clone(&self.stop_signal);
        let output_port_mapping = output_port_mapping.clone();
        let external_outputs = external_outputs.cloned();
        let handle = tokio::spawn(async move {
            let node_outputs = match node.execute(input_streams).await {
                Ok(outputs) => outputs,
                Err(e) => {
                    // Hand the node back even on failure so the graph can recover it.
                    let _ = nodes_restored.lock().await.insert(node_name.clone(), node);
                    let error: GraphExecutionError =
                        format!("Node '{}' execution error: {}", node_name, e).into();
                    return Err(error);
                }
            };
            nodes_restored.lock().await.insert(node_name.clone(), node);

            let mut forwarder_handles = Vec::new();
            for (port_name, stream) in node_outputs {
                let mut senders = output_txs_for_node
                    .get(&port_name)
                    .cloned()
                    .unwrap_or_default();
                // Feed every external name this port is exposed under, not just the first
                // one a map scan happens to find.
                if let Some(outputs) = &external_outputs {
                    for (external_port, mapping) in &output_port_mapping {
                        if mapping.node == node_name && mapping.port == port_name {
                            if let Some(tx) = outputs.get(external_port) {
                                senders.push(tx.clone());
                            }
                        }
                    }
                }

                let stop_signal = Arc::clone(&stop_signal);
                let handle = tokio::spawn(async move {
                    let mut stream = stream;
                    // See the input pump above for why this future is created only once.
                    let stopped = stop_signal.notified();
                    tokio::pin!(stopped);
                    loop {
                        tokio::select! {
                            _ = &mut stopped => break,
                            item = stream.next() => {
                                match item {
                                    Some(item) => {
                                        for tx in &senders {
                                            // A closed receiver ends forwarding for this
                                            // whole port, including the other receivers.
                                            if tx.send(item.clone()).await.is_err() {
                                                return Ok(());
                                            }
                                        }
                                    }
                                    None => break,
                                }
                            }
                        }
                    }
                    Ok(()) as Result<(), GraphExecutionError>
                });
                forwarder_handles.push(handle);
            }

            // The node task finishes only after all of its forwarders have finished.
            for forwarder in forwarder_handles {
                let _ = forwarder.await;
            }
            Ok(())
        });
        all_handles.push(handle);
    }

    Ok((all_handles, nodes_restored))
}
/// Stores state code 1 and wakes every task currently waiting on the pause signal.
pub fn start(&self) {
    // Same code `resume` writes.
    const RUNNING: u8 = 1;
    self.execution_state.store(RUNNING, Ordering::Release);
    self.pause_signal.notify_waiters();
}
/// Stores state code 2. No signal is sent; only the state value changes.
pub fn pause(&self) {
    const PAUSED: u8 = 2;
    self.execution_state.store(PAUSED, Ordering::Release);
}
/// Stores state code 1 again and wakes every task currently waiting on the pause signal.
pub fn resume(&self) {
    // Same code `start` writes.
    const RUNNING: u8 = 1;
    self.execution_state.store(RUNNING, Ordering::Release);
    self.pause_signal.notify_waiters();
}
/// Signals all tasks waiting on the stop signal, then waits for every stored task handle.
///
/// Task results, including panics and node errors, are discarded. Afterwards the state
/// code is reset to 0 and the handle list is left empty. Always returns `Ok(())`.
pub async fn stop(&self) -> Result<(), GraphExecutionError> {
    self.stop_signal.notify_waiters();

    // Move the handles out first so the lock is not held while awaiting the tasks.
    let pending = std::mem::take(&mut *self.execution_handles.lock().await);
    for task in pending {
        let _ = task.await;
    }

    self.execution_state.store(0, Ordering::Release);
    self.execution_handles.lock().await.clear();
    Ok(())
}
/// Runs the graph to completion as a nested node and returns one output stream per exposed
/// output port.
///
/// The node map is moved out for the run and written back before returning, on success
/// and on failure alike. All tasks are awaited even after one has failed, so every node
/// that handed itself back is restored.
///
/// The returned streams are only handed out after all tasks have finished, so nothing
/// reads the exposed-output channels while the graph runs. Each of those bounded channels
/// is therefore drained by a small detached task into an unbounded buffer. Without that, a
/// graph emitting more than `DATAFLOW_CHANNEL_CAPACITY` items on an exposed port would
/// block on `send` forever.
///
/// # Errors
/// Returns the first task failure (node error or join error) after the nodes have been
/// restored, or any error returned by `run_dataflow`.
///
/// # Panics
/// Panics if the node mutex is poisoned.
async fn execute_internal(
    &self,
    external_inputs: Option<InputStreams>,
) -> Result<OutputStreams, GraphExecutionError> {
    type Payload = Arc<dyn Any + Send + Sync>;

    let nodes = {
        let mut stored = self.nodes.lock().unwrap();
        std::mem::take(&mut *stored)
    };

    let mut external_output_txs: HashMap<String, tokio::sync::mpsc::Sender<Payload>> =
        HashMap::new();
    let mut external_outputs: OutputStreams = HashMap::new();
    for external_port in self.output_port_mapping.keys() {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<Payload>(DATAFLOW_CHANNEL_CAPACITY);
        let (buffer_tx, buffer_rx) = tokio::sync::mpsc::unbounded_channel::<Payload>();
        // Drain the bounded channel while the graph runs; the task ends once every sender
        // is gone or the caller has dropped the returned stream.
        tokio::spawn(async move {
            while let Some(item) = rx.recv().await {
                if buffer_tx.send(item).is_err() {
                    break;
                }
            }
        });
        external_output_txs.insert(external_port.clone(), tx);
        let stream = Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(
            buffer_rx,
        )) as crate::node::OutputStream;
        external_outputs.insert(external_port.clone(), stream);
    }

    let (handles, nodes_restored) = self
        .run_dataflow(nodes, external_inputs, Some(&external_output_txs))
        .await?;
    // The node tasks hold their own clones; dropping ours lets the output streams end
    // as soon as the last forwarder has finished.
    drop(external_output_txs);

    let mut first_error: Option<GraphExecutionError> = None;
    for handle in handles {
        let outcome = match handle.await {
            Ok(result) => result,
            Err(join_error) => Err(join_error.into()),
        };
        if let Err(error) = outcome {
            if first_error.is_none() {
                first_error = Some(error);
            }
        }
    }

    {
        let mut restored = nodes_restored.lock().await;
        *self.nodes.lock().unwrap() = std::mem::take(&mut *restored);
    }

    match first_error {
        Some(error) => Err(error),
        None => Ok(external_outputs),
    }
}
/// Waits for every task started by `execute` and then puts the nodes back into the graph.
///
/// All handles are awaited even when one of them fails, and the nodes are restored before
/// the failure is reported. An early return on the first error would leave the graph
/// without its nodes and detach the tasks that are still running.
///
/// # Errors
/// Returns the first task failure encountered (node error or join error).
///
/// # Panics
/// Panics if the node mutex is poisoned.
pub async fn wait_for_completion(&mut self) -> Result<(), GraphExecutionError> {
    let handles = {
        let mut handles_guard = self.execution_handles.lock().await;
        std::mem::take(&mut *handles_guard)
    };

    let mut first_error: Option<GraphExecutionError> = None;
    for handle in handles {
        let outcome = match handle.await {
            Ok(result) => result,
            Err(join_error) => Err(join_error.into()),
        };
        if let Err(error) = outcome {
            if first_error.is_none() {
                first_error = Some(error);
            }
        }
    }

    if let Some(restore) = self.nodes_restored_after_run.take() {
        let mut map = restore.lock().await;
        *self.nodes.lock().unwrap() = std::mem::take(&mut *map);
    }

    match first_error {
        Some(error) => Err(error),
        None => Ok(()),
    }
}
}
/// Orders node names so that every edge's source comes before its target (Kahn's algorithm).
///
/// Nodes with no incoming edges are seeded in the order they appear in `nodes`, and
/// successors are visited in edge order, so the result is deterministic for a given input.
///
/// # Errors
/// Returns an error when two nodes share a name, when an edge refers to a node that is not
/// in `nodes`, or when the graph contains a cycle.
pub fn topological_sort(
    nodes: &[&dyn Node],
    edges: &[&Edge],
) -> Result<Vec<String>, GraphExecutionError> {
    let mut in_degree: HashMap<String, usize> = HashMap::with_capacity(nodes.len());
    let mut adjacency: HashMap<String, Vec<String>> = HashMap::with_capacity(nodes.len());
    for node in nodes {
        let name = node.name().to_string();
        if in_degree.insert(name.clone(), 0).is_some() {
            return Err(format!("Duplicate node name '{}'", name).into());
        }
        adjacency.insert(name, Vec::new());
    }

    for edge in edges {
        let source = edge.source_node().to_string();
        let target = edge.target_node().to_string();
        // Report unknown endpoints as errors instead of panicking on a missing key.
        let Some(successors) = adjacency.get_mut(&source) else {
            return Err(format!("Edge originates from unknown node '{}'", source).into());
        };
        let Some(degree) = in_degree.get_mut(&target) else {
            return Err(format!("Edge targets unknown node '{}'", target).into());
        };
        *degree += 1;
        successors.push(target);
    }

    let mut queue: VecDeque<String> = VecDeque::new();
    for node in nodes {
        if in_degree.get(node.name()).copied() == Some(0) {
            queue.push_back(node.name().to_string());
        }
    }

    let mut result = Vec::with_capacity(nodes.len());
    while let Some(node_name) = queue.pop_front() {
        if let Some(neighbors) = adjacency.get(&node_name) {
            for neighbor in neighbors {
                if let Some(degree) = in_degree.get_mut(neighbor) {
                    *degree -= 1;
                    if *degree == 0 {
                        queue.push_back(neighbor.clone());
                    }
                }
            }
        }
        result.push(node_name);
    }

    // Any node left unvisited still has an incoming edge, i.e. it sits on a cycle.
    if result.len() != nodes.len() {
        return Err("Graph contains cycles".into());
    }
    Ok(result)
}
/// Lets a `Graph` act as a single node: its exposed ports become the node's ports and
/// `execute` runs the whole graph via `execute_internal`.
#[async_trait]
impl Node for Graph {
    /// Returns the graph's name.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Replaces the graph's name with a copy of `name`.
    fn set_name(&mut self, name: &str) {
        self.name = String::from(name);
    }

    /// Names of the exposed input ports, in the order they were first exposed.
    fn input_port_names(&self) -> &[String] {
        self.input_port_names.as_slice()
    }

    /// Names of the exposed output ports, in the order they were first exposed.
    fn output_port_names(&self) -> &[String] {
        self.output_port_names.as_slice()
    }

    /// Reports whether `name` is an exposed input port.
    fn has_input_port(&self, name: &str) -> bool {
        self.input_port_names
            .iter()
            .any(|port| port.as_str() == name)
    }

    /// Reports whether `name` is an exposed output port.
    fn has_output_port(&self, name: &str) -> bool {
        self.output_port_names
            .iter()
            .any(|port| port.as_str() == name)
    }

    /// Runs the graph with `inputs` keyed by external input port name and resolves to the
    /// streams of the exposed output ports. Failures are wrapped in a
    /// `"Graph execution error: …"` message.
    fn execute(
        &self,
        inputs: InputStreams,
    ) -> Pin<Box<dyn Future<Output = Result<OutputStreams, NodeExecutionError>> + Send + '_>> {
        Box::pin(async move {
            let outcome = self.execute_internal(Some(inputs)).await;
            outcome.map_err(|cause| format!("Graph execution error: {}", cause).into())
        })
    }
}