use std::collections::HashMap;
use std::sync::Arc;
use futures::StreamExt;
use streamweave::node::Node;
use tokio_stream::wrappers::ReceiverStream;
use super::FindStartNode;
use super::find_start::find_start_id;
use super::find_start::process_find_start_item;
/// Sends a single `u8` payload through the node's `"in"` port and checks that
/// the first poll of the `"out"` stream yields `None`.
#[tokio::test]
async fn node_execute_skips_wrong_type() {
    type AnyItem = Arc<dyn std::any::Any + Send + Sync>;
    type AnyStream = std::pin::Pin<Box<dyn futures::Stream<Item = AnyItem> + Send>>;

    // Queue one payload and drop the sender so the receiving side sees a closed channel.
    let (sender, receiver) = tokio::sync::mpsc::channel(4);
    let payload: AnyItem = Arc::new(1_u8);
    sender.send(payload).await.unwrap();
    drop(sender);

    // Wire the receiver up as the node's only input stream.
    let source: AnyStream = Box::pin(ReceiverStream::new(receiver));
    let mut inputs: streamweave::node::InputStreams = HashMap::new();
    inputs.insert("in".to_string(), source);

    let find_node = FindStartNode::new("find");
    let mut outputs = find_node.execute(inputs).await.unwrap();
    let mut out_stream = outputs.remove("out").unwrap();

    let first: Option<AnyItem> = out_stream.next().await;
    assert!(first.is_none());
}
/// Sends a parsed DOT graph through the node's `"in"` port and checks that the
/// first item on the `"out"` stream is a `String` equal to `"start"`.
#[tokio::test]
async fn node_execute_emits_start_id() {
    type AnyItem = Arc<dyn std::any::Any + Send + Sync>;
    type AnyStream = std::pin::Pin<Box<dyn futures::Stream<Item = AnyItem> + Send>>;

    let graph = crate::dot_parser::parse_dot(
        r#"digraph G { start [shape=Mdiamond] exit [shape=Msquare] start -> exit }"#,
    )
    .unwrap();

    // Queue the graph and drop the sender so the receiving side sees a closed channel.
    let (sender, receiver) = tokio::sync::mpsc::channel(4);
    let payload: AnyItem = Arc::new(graph);
    sender.send(payload).await.unwrap();
    drop(sender);

    // Wire the receiver up as the node's only input stream.
    let source: AnyStream = Box::pin(ReceiverStream::new(receiver));
    let mut inputs: streamweave::node::InputStreams = HashMap::new();
    inputs.insert("in".to_string(), source);

    let find_node = FindStartNode::new("find");
    let mut outputs = find_node.execute(inputs).await.unwrap();
    let mut out_stream = outputs.remove("out").unwrap();

    let first: Option<AnyItem> = out_stream.next().await;
    assert!(first.is_some());

    // The emitted item must downcast to the start node's id.
    let start_id = first.unwrap().downcast::<String>().unwrap();
    assert_eq!(start_id.as_str(), "start");
}
/// Checks `name`, `set_name`, `has_input_port` and `has_output_port` on a
/// freshly built `FindStartNode`.
#[test]
fn node_trait_methods() {
    let mut find_node = FindStartNode::new("find");

    let initial_name = find_node.name();
    assert_eq!(initial_name, "find");

    // Renaming must be visible through the getter.
    find_node.set_name("start");
    let renamed = find_node.name();
    assert_eq!(renamed, "start");

    let accepts_in = find_node.has_input_port("in");
    assert!(accepts_in);
    let exposes_out = find_node.has_output_port("out");
    assert!(exposes_out);
}
/// Checks that `FindStartNode::new` keeps the given name and reports an `"in"`
/// input port and an `"out"` output port.
#[test]
fn new_creates_node() {
    let created = FindStartNode::new("find");

    assert_eq!(created.name(), "find");

    let has_in = created.has_input_port("in");
    let has_out = created.has_output_port("out");
    assert!(has_in);
    assert!(has_out);
}
/// Checks that `find_start_id` returns `Some("start")` for a graph whose
/// `start` node is declared with `shape=Mdiamond`.
#[test]
fn find_start_id_returns_start_node_id() {
    let parsed = crate::dot_parser::parse_dot(
        r#"digraph G { start [shape=Mdiamond] exit [shape=Msquare] start -> exit }"#,
    )
    .unwrap();

    assert_eq!(find_start_id(&parsed), Some(String::from("start")));
}
/// Checks that `find_start_id` returns `None` for the plain `a -> b` graph.
#[test]
fn find_start_id_returns_none_when_no_start() {
    let parsed = crate::dot_parser::parse_dot(r#"digraph G { a -> b }"#).unwrap();

    assert!(find_start_id(&parsed).is_none());
}
/// Checks that `process_find_start_item` returns `Some("start")` when handed a
/// type-erased parsed graph containing a `shape=Mdiamond` node named `start`.
#[test]
fn process_find_start_item_returns_some_for_graph_with_start() {
    let parsed = crate::dot_parser::parse_dot(
        r#"digraph G { start [shape=Mdiamond] exit [shape=Msquare] start -> exit }"#,
    )
    .unwrap();

    // Erase the concrete type the same way items travel through node streams.
    let erased: Arc<dyn std::any::Any + Send + Sync> = Arc::new(parsed);

    assert_eq!(process_find_start_item(erased), Some(String::from("start")));
}
/// Checks that `process_find_start_item` returns `None` when handed a
/// type-erased `u32`.
#[test]
fn process_find_start_item_returns_none_for_wrong_type() {
    let erased: Arc<dyn std::any::Any + Send + Sync> = Arc::new(42_u32);

    assert!(process_find_start_item(erased).is_none());
}