1use crate::core::data::{NodeId, Port, PortData, PortId};
4use crate::core::error::{GraphError, Result};
5use petgraph::algo::toposort;
6use petgraph::graph::{DiGraph, NodeIndex};
7use petgraph::Direction;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12pub type NodeFunction =
14 Arc<dyn Fn(&HashMap<PortId, PortData>) -> Result<HashMap<PortId, PortData>> + Send + Sync>;
15
16#[derive(Clone)]
18pub struct NodeConfig {
19 pub id: NodeId,
21 pub name: String,
23 pub description: Option<String>,
25 pub input_ports: Vec<Port>,
27 pub output_ports: Vec<Port>,
29 pub function: NodeFunction,
31}
32
33impl NodeConfig {
34 pub fn new(
36 id: impl Into<NodeId>,
37 name: impl Into<String>,
38 input_ports: Vec<Port>,
39 output_ports: Vec<Port>,
40 function: NodeFunction,
41 ) -> Self {
42 Self {
43 id: id.into(),
44 name: name.into(),
45 description: None,
46 input_ports,
47 output_ports,
48 function,
49 }
50 }
51
52 pub fn with_description(mut self, description: impl Into<String>) -> Self {
54 self.description = Some(description.into());
55 self
56 }
57}
58
59#[derive(Clone)]
61pub struct Node {
62 pub config: NodeConfig,
64 pub inputs: HashMap<PortId, PortData>,
66 pub outputs: HashMap<PortId, PortData>,
68}
69
70impl Node {
71 pub fn new(config: NodeConfig) -> Self {
73 Self {
74 config,
75 inputs: HashMap::new(),
76 outputs: HashMap::new(),
77 }
78 }
79
80 pub fn set_input(&mut self, port_id: impl Into<PortId>, data: PortData) {
82 self.inputs.insert(port_id.into(), data);
83 }
84
85 pub fn get_output(&self, port_id: &str) -> Option<&PortData> {
87 self.outputs.get(port_id)
88 }
89
90 pub fn execute(&mut self) -> Result<()> {
92 for port in &self.config.input_ports {
94 if port.required && !self.inputs.contains_key(&port.broadcast_name) {
95 return Err(GraphError::MissingInput {
96 node: self.config.id.clone(),
97 port: port.broadcast_name.clone(),
98 });
99 }
100 }
101
102 let mut impl_inputs = HashMap::new();
104 for port in &self.config.input_ports {
105 if let Some(data) = self.inputs.get(&port.broadcast_name) {
106 impl_inputs.insert(port.impl_name.clone(), data.clone());
107 }
108 }
109
110 let impl_outputs = (self.config.function)(&impl_inputs)?;
112
113 self.outputs.clear();
115 for port in &self.config.output_ports {
116 if let Some(data) = impl_outputs.get(&port.impl_name) {
117 self.outputs
118 .insert(port.broadcast_name.clone(), data.clone());
119 }
120 }
121
122 Ok(())
123 }
124
125 pub fn clear_inputs(&mut self) {
127 self.inputs.clear();
128 }
129
130 pub fn clear_outputs(&mut self) {
132 self.outputs.clear();
133 }
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct Edge {
139 pub from_node: NodeId,
141 pub from_port: PortId,
143 pub to_node: NodeId,
145 pub to_port: PortId,
147}
148
149impl Edge {
150 pub fn new(
152 from_node: impl Into<NodeId>,
153 from_port: impl Into<PortId>,
154 to_node: impl Into<NodeId>,
155 to_port: impl Into<PortId>,
156 ) -> Self {
157 Self {
158 from_node: from_node.into(),
159 from_port: from_port.into(),
160 to_node: to_node.into(),
161 to_port: to_port.into(),
162 }
163 }
164}
165
166pub type MergeFunction = Arc<dyn Fn(Vec<&PortData>) -> Result<PortData> + Send + Sync>;
168
169pub struct MergeConfig {
171 pub branches: Vec<String>,
173 pub port: String,
175 pub merge_fn: Option<MergeFunction>,
177}
178
179impl MergeConfig {
180 pub fn new(branches: Vec<String>, port: String) -> Self {
182 Self {
183 branches,
184 port,
185 merge_fn: None,
186 }
187 }
188
189 pub fn with_merge_fn(mut self, merge_fn: MergeFunction) -> Self {
191 self.merge_fn = Some(merge_fn);
192 self
193 }
194}
195
196pub type VariantFunction = Arc<dyn Fn(usize) -> PortData + Send + Sync>;
198
199pub struct VariantConfig {
201 pub name_prefix: String,
203 pub count: usize,
205 pub variant_fn: VariantFunction,
207 pub param_name: String,
209 pub parallel: bool,
211}
212
213impl VariantConfig {
214 pub fn new(
216 name_prefix: impl Into<String>,
217 count: usize,
218 param_name: impl Into<String>,
219 variant_fn: VariantFunction,
220 ) -> Self {
221 Self {
222 name_prefix: name_prefix.into(),
223 count,
224 variant_fn,
225 param_name: param_name.into(),
226 parallel: true,
227 }
228 }
229
230 pub fn with_parallel(mut self, parallel: bool) -> Self {
232 self.parallel = parallel;
233 self
234 }
235}
236
237#[derive(Clone)]
239pub struct Graph {
240 graph: DiGraph<Node, Edge>,
242 node_indices: HashMap<NodeId, NodeIndex>,
244 branches: HashMap<String, Graph>,
246 node_order: Vec<NodeId>,
248 strict_edge_mapping: bool,
250}
251
252impl Graph {
253 pub fn new() -> Self {
255 Self {
256 graph: DiGraph::new(),
257 node_indices: HashMap::new(),
258 branches: HashMap::new(),
259 node_order: Vec::new(),
260 strict_edge_mapping: false,
261 }
262 }
263
264 pub fn with_strict_edges() -> Self {
268 Self {
269 graph: DiGraph::new(),
270 node_indices: HashMap::new(),
271 branches: HashMap::new(),
272 node_order: Vec::new(),
273 strict_edge_mapping: true,
274 }
275 }
276
277 pub fn set_strict_edge_mapping(&mut self, strict: bool) {
279 self.strict_edge_mapping = strict;
280 }
281
282 pub fn add(&mut self, node: Node) -> Result<()> {
284 let node_id = node.config.id.clone();
285
286 if self.node_indices.contains_key(&node_id) {
287 return Err(GraphError::InvalidGraph(format!(
288 "Node with ID '{}' already exists",
289 node_id
290 )));
291 }
292
293 let index = self.graph.add_node(node);
294 self.node_indices.insert(node_id.clone(), index);
295
296 if !self.strict_edge_mapping && !self.node_order.is_empty() {
298 self.auto_connect_to_previous(&node_id)?;
299 }
300
301 self.node_order.push(node_id);
302 Ok(())
303 }
304
305 fn auto_connect_to_previous(&mut self, new_node_id: &str) -> Result<()> {
307 let edges_to_add = if let Some(prev_node_id) = self.node_order.last().cloned() {
308 let prev_node = self.get_node(&prev_node_id)?;
309 let new_node = self.get_node(new_node_id)?;
310
311 let mut edges = Vec::new();
312 for out_port in &prev_node.config.output_ports {
314 for in_port in &new_node.config.input_ports {
315 let should_connect = out_port.broadcast_name == in_port.broadcast_name
317 || (prev_node.config.output_ports.len() == 1
318 && new_node.config.input_ports.len() == 1);
319
320 if should_connect {
321 edges.push(Edge::new(
322 &prev_node_id,
323 &out_port.broadcast_name,
324 new_node_id,
325 &in_port.broadcast_name,
326 ));
327 break; }
329 }
330 }
331 edges
332 } else {
333 Vec::new()
334 };
335
336 for edge in edges_to_add {
338 self.add_edge(edge)?;
339 }
340
341 Ok(())
342 }
343
344 #[deprecated(since = "0.2.0", note = "Use `add` instead")]
346 pub fn add_node(&mut self, node: Node) -> Result<()> {
347 self.add(node)
348 }
349
350 pub fn add_edge(&mut self, edge: Edge) -> Result<()> {
352 let from_idx = self
353 .node_indices
354 .get(&edge.from_node)
355 .ok_or_else(|| GraphError::NodeNotFound(edge.from_node.clone()))?;
356 let to_idx = self
357 .node_indices
358 .get(&edge.to_node)
359 .ok_or_else(|| GraphError::NodeNotFound(edge.to_node.clone()))?;
360
361 let from_node = &self.graph[*from_idx];
363 if !from_node
364 .config
365 .output_ports
366 .iter()
367 .any(|p| p.broadcast_name == edge.from_port)
368 {
369 return Err(GraphError::PortError(format!(
370 "Output port '{}' not found on node '{}'",
371 edge.from_port, edge.from_node
372 )));
373 }
374
375 let to_node = &self.graph[*to_idx];
377 if !to_node
378 .config
379 .input_ports
380 .iter()
381 .any(|p| p.broadcast_name == edge.to_port)
382 {
383 return Err(GraphError::PortError(format!(
384 "Input port '{}' not found on node '{}'",
385 edge.to_port, edge.to_node
386 )));
387 }
388
389 self.graph.add_edge(*from_idx, *to_idx, edge);
390 Ok(())
391 }
392
393 pub fn get_node(&self, node_id: &str) -> Result<&Node> {
395 let idx = self
396 .node_indices
397 .get(node_id)
398 .ok_or_else(|| GraphError::NodeNotFound(node_id.to_string()))?;
399 Ok(&self.graph[*idx])
400 }
401
402 pub fn get_node_mut(&mut self, node_id: &str) -> Result<&mut Node> {
404 let idx = self
405 .node_indices
406 .get(node_id)
407 .ok_or_else(|| GraphError::NodeNotFound(node_id.to_string()))?;
408 Ok(&mut self.graph[*idx])
409 }
410
411 pub fn validate(&self) -> Result<()> {
413 match toposort(&self.graph, None) {
414 Ok(_) => Ok(()),
415 Err(cycle) => {
416 let node = &self.graph[cycle.node_id()];
417 Err(GraphError::CycleDetected(node.config.id.clone()))
418 }
419 }
420 }
421
422 pub fn topological_order(&self) -> Result<Vec<NodeId>> {
424 let sorted = toposort(&self.graph, None).map_err(|cycle| {
425 let node = &self.graph[cycle.node_id()];
426 GraphError::CycleDetected(node.config.id.clone())
427 })?;
428
429 Ok(sorted
430 .into_iter()
431 .map(|idx| self.graph[idx].config.id.clone())
432 .collect())
433 }
434
435 pub fn nodes(&self) -> Vec<&Node> {
437 self.graph
438 .node_indices()
439 .map(|idx| &self.graph[idx])
440 .collect()
441 }
442
443 pub fn edges(&self) -> Vec<&Edge> {
445 self.graph
446 .edge_indices()
447 .map(|idx| &self.graph[idx])
448 .collect()
449 }
450
451 pub fn node_count(&self) -> usize {
453 self.graph.node_count()
454 }
455
456 pub fn edge_count(&self) -> usize {
458 self.graph.edge_count()
459 }
460
461 pub fn incoming_edges(&self, node_id: &str) -> Result<Vec<&Edge>> {
463 let idx = self
464 .node_indices
465 .get(node_id)
466 .ok_or_else(|| GraphError::NodeNotFound(node_id.to_string()))?;
467
468 Ok(self
469 .graph
470 .edges_directed(*idx, Direction::Incoming)
471 .map(|e| e.weight())
472 .collect())
473 }
474
475 pub fn outgoing_edges(&self, node_id: &str) -> Result<Vec<&Edge>> {
477 let idx = self
478 .node_indices
479 .get(node_id)
480 .ok_or_else(|| GraphError::NodeNotFound(node_id.to_string()))?;
481
482 Ok(self
483 .graph
484 .edges_directed(*idx, Direction::Outgoing)
485 .map(|e| e.weight())
486 .collect())
487 }
488
489 pub fn auto_connect(&mut self) -> Result<usize> {
500 let mut edges_created = 0;
501 let node_ids: Vec<NodeId> = self.nodes().iter().map(|n| n.config.id.clone()).collect();
502
503 for from_node_id in &node_ids {
504 let from_node = self.get_node(from_node_id)?;
505 let output_ports: Vec<PortId> = from_node
506 .config
507 .output_ports
508 .iter()
509 .map(|p| p.broadcast_name.clone())
510 .collect();
511
512 for to_node_id in &node_ids {
513 if from_node_id == to_node_id {
514 continue;
515 }
516
517 let to_node = self.get_node(to_node_id)?;
518 let input_ports: Vec<PortId> = to_node
519 .config
520 .input_ports
521 .iter()
522 .map(|p| p.broadcast_name.clone())
523 .collect();
524
525 for output_port in &output_ports {
527 for input_port in &input_ports {
528 if output_port == input_port {
529 let edge_exists = self.edges().iter().any(|e| {
531 e.from_node == *from_node_id
532 && e.from_port == *output_port
533 && e.to_node == *to_node_id
534 && e.to_port == *input_port
535 });
536
537 if !edge_exists {
538 let edge = Edge::new(
539 from_node_id.clone(),
540 output_port.clone(),
541 to_node_id.clone(),
542 input_port.clone(),
543 );
544 self.add_edge(edge)?;
545 edges_created += 1;
546 }
547 }
548 }
549 }
550 }
551 }
552
553 Ok(edges_created)
554 }
555
556 pub fn with_auto_connect(mut self) -> Result<Self> {
559 self.auto_connect()?;
560 Ok(self)
561 }
562
563 pub fn create_branch(&mut self, name: impl Into<String>) -> Result<&mut Graph> {
565 let name = name.into();
566 if self.branches.contains_key(&name) {
567 return Err(GraphError::InvalidGraph(format!(
568 "Branch '{}' already exists",
569 name
570 )));
571 }
572 self.branches.insert(name.clone(), Graph::new());
573 Ok(self.branches.get_mut(&name).unwrap())
574 }
575
576 pub fn get_branch(&self, name: &str) -> Result<&Graph> {
578 self.branches
579 .get(name)
580 .ok_or_else(|| GraphError::InvalidGraph(format!("Branch '{}' not found", name)))
581 }
582
583 pub fn get_branch_mut(&mut self, name: &str) -> Result<&mut Graph> {
585 self.branches
586 .get_mut(name)
587 .ok_or_else(|| GraphError::InvalidGraph(format!("Branch '{}' not found", name)))
588 }
589
590 pub fn branch_names(&self) -> Vec<String> {
592 self.branches.keys().cloned().collect()
593 }
594
595 pub fn has_branch(&self, name: &str) -> bool {
597 self.branches.contains_key(name)
598 }
599
600 pub fn merge(&mut self, node_id: impl Into<NodeId>, config: MergeConfig) -> Result<()> {
605 for branch_name in &config.branches {
607 if !self.has_branch(branch_name) {
608 return Err(GraphError::InvalidGraph(format!(
609 "Branch '{}' not found for merge operation",
610 branch_name
611 )));
612 }
613 }
614
615 let branch_names = config.branches.clone();
616
617 let merge_fn = config.merge_fn.unwrap_or_else(|| {
619 Arc::new(|inputs: Vec<&PortData>| -> Result<PortData> {
621 Ok(PortData::List(inputs.iter().map(|&d| d.clone()).collect()))
622 })
623 });
624
625 let input_ports: Vec<Port> = branch_names
627 .iter()
628 .map(|name| Port::new(name.clone(), format!("Input from {}", name)))
629 .collect();
630
631 let node_config = NodeConfig::new(
633 node_id,
634 "Merge Node",
635 input_ports,
636 vec![Port::new("merged", "Merged Output")],
637 Arc::new(move |inputs: &HashMap<PortId, PortData>| {
638 let mut collected_inputs = Vec::new();
640 for branch_name in &branch_names {
641 if let Some(data) = inputs.get(branch_name.as_str()) {
642 collected_inputs.push(data);
643 }
644 }
645
646 let merged = merge_fn(collected_inputs)?;
648
649 let mut outputs = HashMap::new();
650 outputs.insert("merged".to_string(), merged);
651 Ok(outputs)
652 }),
653 );
654
655 self.add(Node::new(node_config))
656 }
657
658 pub fn create_variants(&mut self, config: VariantConfig) -> Result<Vec<String>> {
666 let mut branch_names = Vec::new();
667
668 for i in 0..config.count {
669 let branch_name = format!("{}_{}", config.name_prefix, i);
670
671 if self.has_branch(&branch_name) {
673 return Err(GraphError::InvalidGraph(format!(
674 "Variant branch '{}' already exists",
675 branch_name
676 )));
677 }
678
679 let branch = self.create_branch(&branch_name)?;
681
682 let param_value = (config.variant_fn)(i);
684 let param_name = config.param_name.clone();
685
686 let source_config = NodeConfig::new(
687 format!("{}_source", branch_name),
688 format!("Variant Source {}", i),
689 vec![],
690 vec![Port::new(¶m_name, "Variant Parameter")],
691 Arc::new(move |_: &HashMap<PortId, PortData>| {
695 let mut outputs = HashMap::new();
696 outputs.insert(param_name.clone(), param_value.clone());
697 Ok(outputs)
698 }),
699 );
700
701 branch.add(Node::new(source_config))?;
702 branch_names.push(branch_name);
703 }
704
705 Ok(branch_names)
706 }
707}
708
709impl Default for Graph {
710 fn default() -> Self {
711 Self::new()
712 }
713}
714
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::data::PortData;

    /// Node function shared by the tests: doubles an `Int` received on `input` and publishes
    /// it on `output`. Any other input produces no outputs.
    fn dummy_function(inputs: &HashMap<PortId, PortData>) -> Result<HashMap<PortId, PortData>> {
        let doubled = match inputs.get("input") {
            Some(PortData::Int(val)) => Some(PortData::Int(val * 2)),
            _ => None,
        };
        Ok(doubled
            .map(|value| ("output".to_string(), value))
            .into_iter()
            .collect())
    }

    /// Wraps [`dummy_function`] in a node with the given identity and ports.
    fn dummy_node(id: &str, name: &str, inputs: Vec<Port>, outputs: Vec<Port>) -> Node {
        Node::new(NodeConfig::new(
            id,
            name,
            inputs,
            outputs,
            Arc::new(dummy_function),
        ))
    }

    #[test]
    fn test_graph_creation() {
        let graph = Graph::new();
        assert_eq!((graph.node_count(), graph.edge_count()), (0, 0));
    }

    #[test]
    fn test_add_node() {
        let mut graph = Graph::new();
        let node = dummy_node(
            "node1",
            "Node 1",
            vec![Port::new("input", "Input")],
            vec![Port::new("output", "Output")],
        );

        assert!(graph.add(node).is_ok());
        assert_eq!(graph.node_count(), 1);
    }

    #[test]
    fn test_duplicate_node_id() {
        let mut graph = Graph::new();
        let original = dummy_node("node1", "Node 1", vec![], vec![]);
        let duplicate = dummy_node("node1", "Node 1 Duplicate", vec![], vec![]);

        assert!(graph.add(original).is_ok());
        assert!(graph.add(duplicate).is_err());
    }

    #[test]
    fn test_add_edge() {
        let mut graph = Graph::with_strict_edges();
        graph
            .add(dummy_node(
                "node1",
                "Node 1",
                vec![],
                vec![Port::new("output", "Output")],
            ))
            .unwrap();
        graph
            .add(dummy_node(
                "node2",
                "Node 2",
                vec![Port::new("input", "Input")],
                vec![],
            ))
            .unwrap();

        assert!(graph
            .add_edge(Edge::new("node1", "output", "node2", "input"))
            .is_ok());
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_topological_order() {
        let mut graph = Graph::new();

        // A three-node chain: only the first lacks inputs, only the last lacks outputs.
        for i in 1..=3 {
            let inputs = if i > 1 {
                vec![Port::new("input", "Input")]
            } else {
                Vec::new()
            };
            let outputs = if i < 3 {
                vec![Port::new("output", "Output")]
            } else {
                Vec::new()
            };
            graph
                .add(dummy_node(
                    &format!("node{}", i),
                    &format!("Node {}", i),
                    inputs,
                    outputs,
                ))
                .unwrap();
        }

        for (from, to) in [("node1", "node2"), ("node2", "node3")].iter() {
            graph
                .add_edge(Edge::new(*from, "output", *to, "input"))
                .unwrap();
        }

        let order = graph.topological_order().unwrap();
        assert_eq!(order, ["node1", "node2", "node3"]);
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = Graph::new();

        for (id, name) in [("node1", "Node 1"), ("node2", "Node 2")].iter() {
            graph
                .add(dummy_node(
                    id,
                    name,
                    vec![Port::new("input", "Input")],
                    vec![Port::new("output", "Output")],
                ))
                .unwrap();
        }

        graph
            .add_edge(Edge::new("node1", "output", "node2", "input"))
            .unwrap();
        graph
            .add_edge(Edge::new("node2", "output", "node1", "input"))
            .unwrap();

        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_create_branch() {
        let mut graph = Graph::new();

        assert!(graph.create_branch("branch_a").is_ok());
        assert!(graph.has_branch("branch_a"));
        assert_eq!(graph.branch_names(), vec!["branch_a".to_string()]);
    }

    #[test]
    fn test_duplicate_branch_name() {
        let mut graph = Graph::new();

        graph.create_branch("branch_a").unwrap();
        assert!(graph.create_branch("branch_a").is_err());
    }

    #[test]
    fn test_branch_isolation() {
        let mut graph = Graph::new();

        graph
            .create_branch("branch_a")
            .unwrap()
            .add(dummy_node(
                "node_a",
                "Node A",
                vec![],
                vec![Port::new("output", "Output")],
            ))
            .unwrap();
        graph
            .create_branch("branch_b")
            .unwrap()
            .add(dummy_node(
                "node_b",
                "Node B",
                vec![],
                vec![Port::new("output", "Output")],
            ))
            .unwrap();

        let branch_a = graph.get_branch("branch_a").unwrap();
        let branch_b = graph.get_branch("branch_b").unwrap();

        assert_eq!(branch_a.node_count(), 1);
        assert_eq!(branch_b.node_count(), 1);

        // Neither branch can see the other's node.
        assert!(branch_a.get_node("node_b").is_err());
        assert!(branch_b.get_node("node_a").is_err());
    }

    #[test]
    fn test_get_nonexistent_branch() {
        assert!(Graph::new().get_branch("nonexistent").is_err());
    }

    #[test]
    fn test_merge_basic() {
        let mut graph = Graph::new();
        for name in ["branch_a", "branch_b"].iter() {
            graph.create_branch(*name).unwrap();
        }

        let merge_config = MergeConfig::new(
            vec!["branch_a".to_string(), "branch_b".to_string()],
            "output".to_string(),
        );

        assert!(graph.merge("merge_node", merge_config).is_ok());
        assert_eq!(graph.node_count(), 1);
        assert!(graph.get_node("merge_node").is_ok());
    }

    #[test]
    fn test_merge_with_nonexistent_branch() {
        let mut graph = Graph::new();
        graph.create_branch("branch_a").unwrap();

        let merge_config = MergeConfig::new(
            vec!["branch_a".to_string(), "nonexistent".to_string()],
            "output".to_string(),
        );

        assert!(graph.merge("merge_node", merge_config).is_err());
    }

    #[test]
    fn test_merge_with_custom_function() {
        let mut graph = Graph::new();
        for name in ["branch_a", "branch_b"].iter() {
            graph.create_branch(*name).unwrap();
        }

        // Largest Int among the inputs; non-Int values are ignored.
        let max_merge = Arc::new(|inputs: Vec<&PortData>| -> Result<PortData> {
            let max_val = inputs.iter().fold(i64::MIN, |acc, data| match data {
                PortData::Int(val) => acc.max(*val),
                _ => acc,
            });
            Ok(PortData::Int(max_val))
        });

        let merge_config = MergeConfig::new(
            vec!["branch_a".to_string(), "branch_b".to_string()],
            "output".to_string(),
        )
        .with_merge_fn(max_merge);

        assert!(graph.merge("merge_node", merge_config).is_ok());
    }

    #[test]
    fn test_create_variants() {
        let mut graph = Graph::new();
        let variant_fn = Arc::new(|i: usize| PortData::Int(i as i64 * 10));

        let branch_names = graph
            .create_variants(VariantConfig::new("test_variant", 3, "param", variant_fn))
            .expect("variant creation should succeed");

        assert_eq!(
            branch_names,
            ["test_variant_0", "test_variant_1", "test_variant_2"]
        );

        for name in &branch_names {
            assert!(graph.has_branch(name));
            assert_eq!(graph.get_branch(name).unwrap().node_count(), 1);
        }
    }

    #[test]
    fn test_variants_with_parallelization_flag() {
        let mut graph = Graph::new();
        let variant_fn = Arc::new(|i: usize| PortData::Float(i as f64 * 0.5));
        let config =
            VariantConfig::new("param_sweep", 5, "learning_rate", variant_fn).with_parallel(false);

        let branch_names = graph
            .create_variants(config)
            .expect("variant creation should succeed");

        assert_eq!(branch_names.len(), 5);
    }

    #[test]
    fn test_duplicate_variant_branch() {
        let mut graph = Graph::new();
        let variant_fn = Arc::new(|i: usize| PortData::Int(i as i64));

        graph
            .create_variants(VariantConfig::new("test", 2, "param", variant_fn.clone()))
            .unwrap();

        let second_attempt =
            graph.create_variants(VariantConfig::new("test", 2, "param", variant_fn));
        assert!(second_attempt.is_err());
    }

    #[test]
    fn test_implicit_edge_mapping() {
        // Default graphs wire each new node to the one added right before it.
        let mut graph = Graph::new();

        let nodes = vec![
            dummy_node(
                "source",
                "Source",
                vec![],
                vec![Port::new("output", "Output")],
            ),
            dummy_node(
                "processor",
                "Processor",
                vec![Port::new("output", "Input")],
                vec![Port::new("result", "Result")],
            ),
            dummy_node(
                "sink",
                "Sink",
                vec![Port::new("result", "Input")],
                vec![],
            ),
        ];
        for node in nodes {
            graph.add(node).unwrap();
        }

        assert_eq!(graph.edge_count(), 2);
        assert_eq!(graph.node_count(), 3);
    }

    #[test]
    fn test_strict_edge_mapping() {
        // Strict graphs never create edges on `add`.
        let mut graph = Graph::with_strict_edges();

        graph
            .add(dummy_node(
                "source",
                "Source",
                vec![],
                vec![Port::new("output", "Output")],
            ))
            .unwrap();
        graph
            .add(dummy_node(
                "sink",
                "Sink",
                vec![Port::new("output", "Input")],
                vec![],
            ))
            .unwrap();

        assert_eq!(graph.edge_count(), 0);
        assert_eq!(graph.node_count(), 2);
    }

    #[test]
    fn test_auto_connect() {
        let mut graph = Graph::with_strict_edges();

        let nodes = vec![
            dummy_node("source", "Source", vec![], vec![Port::new("data", "Data")]),
            dummy_node(
                "processor",
                "Processor",
                vec![Port::new("data", "Data")],
                vec![Port::new("result", "Result")],
            ),
            dummy_node(
                "sink",
                "Sink",
                vec![Port::new("result", "Result")],
                vec![],
            ),
        ];
        for node in nodes {
            graph.add(node).unwrap();
        }
        assert_eq!(graph.edge_count(), 0);

        let edges_created = graph.auto_connect().unwrap();
        assert_eq!(edges_created, 2);
        assert_eq!(graph.edge_count(), 2);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_auto_connect_parallel_branches() {
        let mut graph = Graph::with_strict_edges();

        // source fans out to branch1 and branch2, which both feed merger.
        let nodes = vec![
            dummy_node(
                "source",
                "Source",
                vec![],
                vec![Port::new("value", "Value")],
            ),
            dummy_node(
                "branch1",
                "Branch 1",
                vec![Port::new("value", "Value")],
                vec![Port::new("out1", "Output 1")],
            ),
            dummy_node(
                "branch2",
                "Branch 2",
                vec![Port::new("value", "Value")],
                vec![Port::new("out2", "Output 2")],
            ),
            dummy_node(
                "merger",
                "Merger",
                vec![Port::new("out1", "Input 1"), Port::new("out2", "Input 2")],
                vec![],
            ),
        ];
        for node in nodes {
            graph.add(node).unwrap();
        }

        let edges_created = graph.auto_connect().unwrap();
        assert_eq!(edges_created, 4);
        assert_eq!(graph.edge_count(), 4);
        assert!(graph.validate().is_ok());
    }
}