use crate::graph;
use crate::graph::Graph;
use crate::node::{InputStreams, Node, NodeExecutionError, OutputStreams};
use async_trait::async_trait;
use std::any::Any;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::{Stream, StreamExt};
/// Test double for a source node: it owns a fixed list of integers and exposes a single
/// `"out"` port through which they are emitted.
struct MockProducerNode {
    /// Name reported by `Node::name`.
    name: String,
    /// Values emitted, in order, each time the node executes.
    data: Vec<i32>,
    /// Declared output ports; `new` always sets this to `["out"]`.
    output_port_names: Vec<String>,
}
impl MockProducerNode {
    /// Creates a producer called `name` that will emit `data` in order.
    fn new(name: String, data: Vec<i32>) -> Self {
        let output_port_names = vec![String::from("out")];
        Self {
            output_port_names,
            data,
            name,
        }
    }
}
#[async_trait]
impl Node for MockProducerNode {
    fn name(&self) -> &str {
        &self.name
    }
    fn set_name(&mut self, name: &str) {
        self.name = name.to_string();
    }
    /// A producer is a pure source and declares no input ports.
    fn input_port_names(&self) -> &[String] {
        &[]
    }
    fn output_port_names(&self) -> &[String] {
        &self.output_port_names
    }
    fn has_input_port(&self, _name: &str) -> bool {
        false
    }
    /// Answers from the declared output ports so this cannot drift from `output_port_names`.
    fn has_output_port(&self, name: &str) -> bool {
        self.output_port_names.iter().any(|port| port == name)
    }
    /// Emits every value of `data`, in order, on the `"out"` port.
    ///
    /// The values are pushed from a `tokio::spawn`ed task into a bounded `mpsc` channel
    /// (capacity 10). This future therefore resolves right away with the receiving end
    /// wrapped as a stream. Inputs are ignored.
    fn execute(
        &self,
        _inputs: InputStreams,
    ) -> Pin<
        Box<dyn std::future::Future<Output = Result<OutputStreams, NodeExecutionError>> + Send + '_>,
    > {
        // Cloned so the spawned task owns its data and does not borrow `self`.
        let data = self.data.clone();
        Box::pin(async move {
            let (tx, rx) = mpsc::channel(10);
            tokio::spawn(async move {
                for item in data {
                    // `send` only fails once the receiver has been dropped. Nothing will read
                    // any further items, so stop rather than attempting every remaining send.
                    if tx
                        .send(Arc::new(item) as Arc<dyn Any + Send + Sync>)
                        .await
                        .is_err()
                    {
                        break;
                    }
                }
            });
            let mut outputs = HashMap::new();
            outputs.insert(
                "out".to_string(),
                Box::pin(ReceiverStream::new(rx))
                    as Pin<Box<dyn Stream<Item = Arc<dyn Any + Send + Sync>> + Send>>,
            );
            Ok(outputs)
        })
    }
}
/// Test double for a pass-through node with one `"in"` port and one `"out"` port.
struct MockTransformNode {
    /// Name reported by `Node::name`.
    name: String,
    /// Declared input ports; `new` always sets this to `["in"]`.
    input_port_names: Vec<String>,
    /// Declared output ports; `new` always sets this to `["out"]`.
    output_port_names: Vec<String>,
}
impl MockTransformNode {
    /// Creates a pass-through node called `name`.
    fn new(name: String) -> Self {
        let input_port_names = vec![String::from("in")];
        let output_port_names = vec![String::from("out")];
        Self {
            output_port_names,
            input_port_names,
            name,
        }
    }
}
#[async_trait]
impl Node for MockTransformNode {
    fn name(&self) -> &str {
        self.name.as_str()
    }
    fn set_name(&mut self, name: &str) {
        self.name = String::from(name);
    }
    fn input_port_names(&self) -> &[String] {
        self.input_port_names.as_slice()
    }
    fn output_port_names(&self) -> &[String] {
        self.output_port_names.as_slice()
    }
    fn has_input_port(&self, name: &str) -> bool {
        matches!(name, "in")
    }
    fn has_output_port(&self, name: &str) -> bool {
        matches!(name, "out")
    }
    /// Forwards every item from the `"in"` stream to the `"out"` stream unchanged.
    ///
    /// Fails with `"Missing 'in' input"` when no `"in"` stream was supplied. The output
    /// stream ends when the input stream ends.
    fn execute(
        &self,
        mut inputs: InputStreams,
    ) -> Pin<
        Box<dyn std::future::Future<Output = Result<OutputStreams, NodeExecutionError>> + Send + '_>,
    > {
        Box::pin(async move {
            let source = inputs.remove("in").ok_or("Missing 'in' input")?;
            let forwarded: Pin<Box<dyn Stream<Item = Arc<dyn Any + Send + Sync>> + Send>> =
                Box::pin(async_stream::stream! {
                    let mut source = source;
                    while let Some(item) = source.next().await {
                        yield item;
                    }
                });
            let mut outputs = HashMap::new();
            outputs.insert(String::from("out"), forwarded);
            Ok(outputs)
        })
    }
}
/// Test double for a terminal node: a single `"in"` port and no outputs.
struct MockSinkNode {
    /// Name reported by `Node::name`.
    name: String,
    /// Declared input ports; `new` always sets this to `["in"]`.
    input_port_names: Vec<String>,
}
impl MockSinkNode {
    /// Creates a sink called `name`.
    fn new(name: String) -> Self {
        let input_port_names = vec![String::from("in")];
        Self {
            input_port_names,
            name,
        }
    }
}
#[async_trait]
impl Node for MockSinkNode {
    fn name(&self) -> &str {
        self.name.as_str()
    }
    fn set_name(&mut self, name: &str) {
        self.name = String::from(name);
    }
    fn input_port_names(&self) -> &[String] {
        self.input_port_names.as_slice()
    }
    /// A sink is terminal and declares no output ports.
    fn output_port_names(&self) -> &[String] {
        &[]
    }
    fn has_input_port(&self, name: &str) -> bool {
        matches!(name, "in")
    }
    fn has_output_port(&self, _name: &str) -> bool {
        false
    }
    /// Requires an `"in"` stream and produces no outputs.
    ///
    /// Fails with `"Missing 'in' input"` when no `"in"` stream was supplied. The stream is
    /// held but never polled, so no items are consumed. It is dropped when this future completes.
    fn execute(
        &self,
        mut inputs: InputStreams,
    ) -> Pin<
        Box<dyn std::future::Future<Output = Result<OutputStreams, NodeExecutionError>> + Send + '_>,
    > {
        Box::pin(async move {
            let _upstream = inputs.remove("in").ok_or("Missing 'in' input")?;
            let no_outputs = HashMap::new();
            Ok(no_outputs)
        })
    }
}
/// Builds `producer -> transform -> sink` with the `graph!` macro.
/// Checks the resulting topology, then runs the graph to completion.
#[tokio::test]
async fn test_graph_macro_simple_linear_pipeline() {
    let mut graph: Graph = graph! {
        producer: MockProducerNode::new("producer".to_string(), vec![1, 2, 3]),
        transform: MockTransformNode::new("transform".to_string()),
        sink: MockSinkNode::new("sink".to_string()),
        producer.out => transform.in,
        transform.out => sink.in
    };
    // Topology: three nodes joined by two edges.
    assert_eq!(graph.node_count(), 3);
    assert_eq!(graph.edge_count(), 2);
    for node in ["producer", "transform", "sink"] {
        assert!(graph.has_node(node));
    }
    for (src, src_port, dst, dst_port) in [
        ("producer", "out", "transform", "in"),
        ("transform", "out", "sink", "in"),
    ] {
        assert!(
            graph
                .find_edge_by_nodes_and_ports(src, src_port, dst, dst_port)
                .is_some()
        );
    }
    // Execution: starting and finishing the run both succeed.
    assert!(Graph::execute(&mut graph).await.is_ok());
    assert!(graph.wait_for_completion().await.is_ok());
}
/// Connecting one output port to two consumers must panic with "Fan-out not supported".
///
/// Despite its name, this test goes through `GraphBuilder` rather than the `graph!` macro.
/// NOTE(review): the panic is expected from `connect` or from `build().unwrap()`. Which one
/// raises it is decided by `GraphBuilder`, which is not visible here.
#[test]
#[should_panic(expected = "Fan-out not supported")]
fn test_graph_macro_fan_out_rejected() {
    use crate::graph_builder::GraphBuilder;
    let source = MockProducerNode::new("source".to_string(), vec![1, 2, 3]);
    let first_filter = MockTransformNode::new("filter1".to_string());
    let second_filter = MockTransformNode::new("filter2".to_string());
    let builder = GraphBuilder::new("test")
        .add_node("source", Box::new(source))
        .add_node("filter1", Box::new(first_filter))
        .add_node("filter2", Box::new(second_filter))
        // Both edges leave `source.out`; the second one creates the fan-out.
        .connect("source", "out", "filter1", "in")
        .connect("source", "out", "filter2", "in");
    let _graph = builder.build().unwrap();
}
/// A `graph.<port>: <value> => node.port` entry in `graph!` should yield a graph-level
/// input port (`config`, asserted below). The remaining edge and the `result` output port
/// are added by hand afterwards.
#[tokio::test]
async fn test_graph_macro_graph_io_with_values() {
    let mut graph: Graph = graph! {
        transform: MockTransformNode::new("transform".to_string()),
        sink: MockSinkNode::new("sink".to_string()),
        graph.config: 42i32 => transform.in
    };
    let transform_to_sink = crate::edge::Edge {
        source_node: "transform".to_string(),
        source_port: "out".to_string(),
        target_node: "sink".to_string(),
        target_port: "in".to_string(),
    };
    graph.add_edge(transform_to_sink).unwrap();
    graph
        .expose_output_port("transform", "out", "result")
        .unwrap();
    assert_eq!(graph.node_count(), 2);
    for node in ["transform", "sink"] {
        assert!(graph.has_node(node));
    }
    assert!(graph.has_input_port("config"));
    assert!(graph.has_output_port("result"));
}
/// Same as the valued variant, but `graph.config => transform.in` carries no initial value.
/// The `config` input port should still be created.
#[tokio::test]
async fn test_graph_macro_graph_io_without_values() {
    let mut graph: Graph = graph! {
        transform: MockTransformNode::new("transform".to_string()),
        graph.config => transform.in
    };
    graph
        .expose_output_port("transform", "out", "result")
        .unwrap();
    // A single node, reachable through one graph-level input and one graph-level output.
    assert_eq!(graph.node_count(), 1);
    assert!(graph.has_node("transform"));
    let (input_port, output_port) = ("config", "result");
    assert!(graph.has_input_port(input_port));
    assert!(graph.has_output_port(output_port));
}
/// Declaring a node named `graph` in the `graph!` macro must panic with
/// "Node name 'graph' is reserved".
///
/// Presumably the name is reserved because `graph.<port>` is the macro's syntax for
/// graph-level ports (see `graph.config` in the other tests). NOTE(review): confirm.
#[test]
#[should_panic(expected = "Node name 'graph' is reserved")]
fn test_graph_macro_invalid_graph_node_name() {
    // The macro invocation is the whole test; the resulting graph is never used.
    let _graph: Graph = graph! {
        graph: MockProducerNode::new("graph".to_string(), vec![1, 2, 3])
    };
}
/// Combines a macro-built `producer -> transform -> sink` chain with graph-level ports.
/// The ports are exposed afterwards on `transform`. The test checks the topology and
/// ports, then runs the graph.
#[tokio::test]
async fn test_graph_macro_combined_patterns() {
    let mut graph: Graph = graph! {
        producer: MockProducerNode::new("producer".to_string(), vec![1, 2, 3]),
        transform: MockTransformNode::new("transform".to_string()),
        sink: MockSinkNode::new("sink".to_string()),
        producer.out => transform.in,
        transform.out => sink.in
    };
    // Expose both of `transform`'s ports at graph level, on top of the internal edges.
    graph
        .expose_input_port("transform", "in", "config")
        .unwrap();
    graph
        .expose_output_port("transform", "out", "output")
        .unwrap();
    assert_eq!(graph.node_count(), 3);
    for node in ["producer", "transform", "sink"] {
        assert!(graph.has_node(node));
    }
    for (src, src_port, dst, dst_port) in [
        ("producer", "out", "transform", "in"),
        ("transform", "out", "sink", "in"),
    ] {
        assert!(
            graph
                .find_edge_by_nodes_and_ports(src, src_port, dst, dst_port)
                .is_some()
        );
    }
    assert!(graph.has_input_port("config"));
    assert!(graph.has_output_port("output"));
    // Execution: starting and finishing the run both succeed.
    assert!(Graph::execute(&mut graph).await.is_ok());
    assert!(graph.wait_for_completion().await.is_ok());
}