graph_sp/core/
graph.rs

1//! Graph structure and node definitions for the DAG execution engine.
2
3use crate::core::data::{NodeId, Port, PortData, PortId};
4use crate::core::error::{GraphError, Result};
5use petgraph::algo::toposort;
6use petgraph::graph::{DiGraph, NodeIndex};
7use petgraph::Direction;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11
/// Function type for node execution.
///
/// A shared, thread-safe (`Send + Sync`) callable that receives a map of port data
/// keyed by port ID and returns a map of produced port data, or an error.
/// `Node::execute` invokes it with keys taken from each port's `impl_name` and
/// looks up its results by `impl_name` as well.
pub type NodeFunction =
    Arc<dyn Fn(&HashMap<PortId, PortData>) -> Result<HashMap<PortId, PortData>> + Send + Sync>;
15
/// Configuration for a node in the graph.
///
/// Bundles a node's identity, its declared ports and the function that is run
/// when the node executes. Cloning shares the function (it is behind an `Arc`).
#[derive(Clone)]
pub struct NodeConfig {
    /// Unique identifier for the node (`Graph::add` rejects duplicates)
    pub id: NodeId,
    /// Human-readable name
    pub name: String,
    /// Node description; `None` until set via `with_description`
    pub description: Option<String>,
    /// Input ports, in declaration order
    pub input_ports: Vec<Port>,
    /// Output ports, in declaration order
    pub output_ports: Vec<Port>,
    /// Execution function
    pub function: NodeFunction,
}
32
33impl NodeConfig {
34    /// Create a new node configuration
35    pub fn new(
36        id: impl Into<NodeId>,
37        name: impl Into<String>,
38        input_ports: Vec<Port>,
39        output_ports: Vec<Port>,
40        function: NodeFunction,
41    ) -> Self {
42        Self {
43            id: id.into(),
44            name: name.into(),
45            description: None,
46            input_ports,
47            output_ports,
48            function,
49        }
50    }
51
52    /// Set the description for this node
53    pub fn with_description(mut self, description: impl Into<String>) -> Self {
54        self.description = Some(description.into());
55        self
56    }
57}
58
/// Represents a node in the execution graph.
///
/// Holds the node's static configuration together with the data currently
/// attached to its ports. Both maps are keyed by a port's `broadcast_name`.
#[derive(Clone)]
pub struct Node {
    /// Node configuration
    pub config: NodeConfig,
    /// Current input data, keyed by input port `broadcast_name`
    pub inputs: HashMap<PortId, PortData>,
    /// Current output data, keyed by output port `broadcast_name`;
    /// rebuilt on every successful `execute`
    pub outputs: HashMap<PortId, PortData>,
}
69
70impl Node {
71    /// Create a new node from a configuration
72    pub fn new(config: NodeConfig) -> Self {
73        Self {
74            config,
75            inputs: HashMap::new(),
76            outputs: HashMap::new(),
77        }
78    }
79
80    /// Set input data for a port
81    pub fn set_input(&mut self, port_id: impl Into<PortId>, data: PortData) {
82        self.inputs.insert(port_id.into(), data);
83    }
84
85    /// Get output data from a port
86    pub fn get_output(&self, port_id: &str) -> Option<&PortData> {
87        self.outputs.get(port_id)
88    }
89
90    /// Execute the node's function
91    pub fn execute(&mut self) -> Result<()> {
92        // Validate required inputs
93        for port in &self.config.input_ports {
94            if port.required && !self.inputs.contains_key(&port.broadcast_name) {
95                return Err(GraphError::MissingInput {
96                    node: self.config.id.clone(),
97                    port: port.broadcast_name.clone(),
98                });
99            }
100        }
101
102        // Map inputs from broadcast_name to impl_name for the function
103        let mut impl_inputs = HashMap::new();
104        for port in &self.config.input_ports {
105            if let Some(data) = self.inputs.get(&port.broadcast_name) {
106                impl_inputs.insert(port.impl_name.clone(), data.clone());
107            }
108        }
109
110        // Execute the function with impl_name keys
111        let impl_outputs = (self.config.function)(&impl_inputs)?;
112
113        // Map outputs from impl_name back to broadcast_name
114        self.outputs.clear();
115        for port in &self.config.output_ports {
116            if let Some(data) = impl_outputs.get(&port.impl_name) {
117                self.outputs
118                    .insert(port.broadcast_name.clone(), data.clone());
119            }
120        }
121
122        Ok(())
123    }
124
125    /// Clear input data
126    pub fn clear_inputs(&mut self) {
127        self.inputs.clear();
128    }
129
130    /// Clear output data
131    pub fn clear_outputs(&mut self) {
132        self.outputs.clear();
133    }
134}
135
/// Represents an edge connecting two nodes.
///
/// An edge runs from an output port of `from_node` to an input port of
/// `to_node`. `Graph::add_edge` checks both port IDs against the ports'
/// `broadcast_name`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edge {
    /// Source node ID
    pub from_node: NodeId,
    /// Source port ID (an output port of `from_node`)
    pub from_port: PortId,
    /// Target node ID
    pub to_node: NodeId,
    /// Target port ID (an input port of `to_node`)
    pub to_port: PortId,
}
148
149impl Edge {
150    /// Create a new edge
151    pub fn new(
152        from_node: impl Into<NodeId>,
153        from_port: impl Into<PortId>,
154        to_node: impl Into<NodeId>,
155        to_port: impl Into<PortId>,
156    ) -> Self {
157        Self {
158            from_node: from_node.into(),
159            from_port: from_port.into(),
160            to_node: to_node.into(),
161            to_port: to_port.into(),
162        }
163    }
164}
165
/// Merge function type for combining outputs from multiple branches.
///
/// Receives borrowed port data (the merge node built by `Graph::merge` passes
/// one entry per branch that supplied data, in branch order) and returns the
/// single combined value or an error.
pub type MergeFunction = Arc<dyn Fn(Vec<&PortData>) -> Result<PortData> + Send + Sync>;
168
/// Configuration for merging branch outputs
pub struct MergeConfig {
    /// Branches to merge; each name must exist on the graph when `Graph::merge` is called
    pub branches: Vec<String>,
    /// Output port name on each branch to merge
    // NOTE(review): `Graph::merge` in this file never reads this field — confirm
    // whether another component consumes it.
    pub port: String,
    /// Custom merge function (default: collect into list)
    pub merge_fn: Option<MergeFunction>,
}
178
179impl MergeConfig {
180    /// Create a new merge configuration
181    pub fn new(branches: Vec<String>, port: String) -> Self {
182        Self {
183            branches,
184            port,
185            merge_fn: None,
186        }
187    }
188
189    /// Set a custom merge function
190    pub fn with_merge_fn(mut self, merge_fn: MergeFunction) -> Self {
191        self.merge_fn = Some(merge_fn);
192        self
193    }
194}
195
/// Variant function type for generating parameter variations.
///
/// Called with the zero-based variant index and returns the parameter value
/// for that variant.
pub type VariantFunction = Arc<dyn Fn(usize) -> PortData + Send + Sync>;
198
/// Configuration for creating variants (config sweeps)
pub struct VariantConfig {
    /// Name prefix for variant branches; branch `i` is named `{name_prefix}_{i}`
    pub name_prefix: String,
    /// Number of variants to create
    pub count: usize,
    /// Function to generate variant parameter values, called once per variant index
    pub variant_fn: VariantFunction,
    /// Parameter name to vary; used as the output port name of each variant's source node
    pub param_name: String,
    /// Whether to enable parallel execution (default: true)
    // NOTE(review): `Graph::create_variants` in this file does not read this flag —
    // presumably it is consumed by the executor; confirm.
    pub parallel: bool,
}
212
213impl VariantConfig {
214    /// Create a new variant configuration
215    pub fn new(
216        name_prefix: impl Into<String>,
217        count: usize,
218        param_name: impl Into<String>,
219        variant_fn: VariantFunction,
220    ) -> Self {
221        Self {
222            name_prefix: name_prefix.into(),
223            count,
224            variant_fn,
225            param_name: param_name.into(),
226            parallel: true,
227        }
228    }
229
230    /// Set parallelization flag
231    pub fn with_parallel(mut self, parallel: bool) -> Self {
232        self.parallel = parallel;
233        self
234    }
235}
236
/// The main graph structure representing a DAG.
///
/// Nodes and edges live in a petgraph `DiGraph`; a side table maps node IDs to
/// graph indices. Acyclicity is not enforced on insertion — call `validate` or
/// `topological_order` to detect cycles.
#[derive(Clone)]
pub struct Graph {
    /// Internal graph structure
    graph: DiGraph<Node, Edge>,
    /// Map from node ID to graph index
    node_indices: HashMap<NodeId, NodeIndex>,
    /// Named branches (subgraphs), each an independent `Graph`
    branches: HashMap<String, Graph>,
    /// Track node addition order for implicit mapping
    node_order: Vec<NodeId>,
    /// Whether to use strict edge mapping (explicit add_edge required)
    strict_edge_mapping: bool,
}
251
252impl Graph {
253    /// Create a new empty graph
254    pub fn new() -> Self {
255        Self {
256            graph: DiGraph::new(),
257            node_indices: HashMap::new(),
258            branches: HashMap::new(),
259            node_order: Vec::new(),
260            strict_edge_mapping: false,
261        }
262    }
263
264    /// Create a new graph with strict edge mapping enabled
265    /// When enabled, edges must be explicitly added with add_edge()
266    /// When disabled (default), edges are automatically created based on node order
267    pub fn with_strict_edges() -> Self {
268        Self {
269            graph: DiGraph::new(),
270            node_indices: HashMap::new(),
271            branches: HashMap::new(),
272            node_order: Vec::new(),
273            strict_edge_mapping: true,
274        }
275    }
276
    /// Set strict edge mapping mode.
    ///
    /// The flag is consulted by `add` each time a node is inserted, so it only
    /// affects nodes added after this call; existing edges are left untouched.
    pub fn set_strict_edge_mapping(&mut self, strict: bool) {
        self.strict_edge_mapping = strict;
    }
281
282    /// Add a node to the graph
283    pub fn add(&mut self, node: Node) -> Result<()> {
284        let node_id = node.config.id.clone();
285
286        if self.node_indices.contains_key(&node_id) {
287            return Err(GraphError::InvalidGraph(format!(
288                "Node with ID '{}' already exists",
289                node_id
290            )));
291        }
292
293        let index = self.graph.add_node(node);
294        self.node_indices.insert(node_id.clone(), index);
295
296        // Implicit edge mapping: connect to previous node if not in strict mode
297        if !self.strict_edge_mapping && !self.node_order.is_empty() {
298            self.auto_connect_to_previous(&node_id)?;
299        }
300
301        self.node_order.push(node_id);
302        Ok(())
303    }
304
305    /// Automatically connect the new node to the previous node based on port names
306    fn auto_connect_to_previous(&mut self, new_node_id: &str) -> Result<()> {
307        let edges_to_add = if let Some(prev_node_id) = self.node_order.last().cloned() {
308            let prev_node = self.get_node(&prev_node_id)?;
309            let new_node = self.get_node(new_node_id)?;
310
311            let mut edges = Vec::new();
312            // Match output ports from previous node to input ports of new node
313            for out_port in &prev_node.config.output_ports {
314                for in_port in &new_node.config.input_ports {
315                    // Connect if port names match or if they're the only ports
316                    let should_connect = out_port.broadcast_name == in_port.broadcast_name
317                        || (prev_node.config.output_ports.len() == 1
318                            && new_node.config.input_ports.len() == 1);
319
320                    if should_connect {
321                        edges.push(Edge::new(
322                            &prev_node_id,
323                            &out_port.broadcast_name,
324                            new_node_id,
325                            &in_port.broadcast_name,
326                        ));
327                        break; // Only connect first matching port
328                    }
329                }
330            }
331            edges
332        } else {
333            Vec::new()
334        };
335
336        // Add all collected edges
337        for edge in edges_to_add {
338            self.add_edge(edge)?;
339        }
340
341        Ok(())
342    }
343
    /// Alias for add() for backward compatibility.
    ///
    /// Forwards directly to `add`, so it has the same behaviour and errors.
    #[deprecated(since = "0.2.0", note = "Use `add` instead")]
    pub fn add_node(&mut self, node: Node) -> Result<()> {
        self.add(node)
    }
349
350    /// Add an edge to the graph
351    pub fn add_edge(&mut self, edge: Edge) -> Result<()> {
352        let from_idx = self
353            .node_indices
354            .get(&edge.from_node)
355            .ok_or_else(|| GraphError::NodeNotFound(edge.from_node.clone()))?;
356        let to_idx = self
357            .node_indices
358            .get(&edge.to_node)
359            .ok_or_else(|| GraphError::NodeNotFound(edge.to_node.clone()))?;
360
361        // Check if the output port exists
362        let from_node = &self.graph[*from_idx];
363        if !from_node
364            .config
365            .output_ports
366            .iter()
367            .any(|p| p.broadcast_name == edge.from_port)
368        {
369            return Err(GraphError::PortError(format!(
370                "Output port '{}' not found on node '{}'",
371                edge.from_port, edge.from_node
372            )));
373        }
374
375        // Check if the input port exists
376        let to_node = &self.graph[*to_idx];
377        if !to_node
378            .config
379            .input_ports
380            .iter()
381            .any(|p| p.broadcast_name == edge.to_port)
382        {
383            return Err(GraphError::PortError(format!(
384                "Input port '{}' not found on node '{}'",
385                edge.to_port, edge.to_node
386            )));
387        }
388
389        self.graph.add_edge(*from_idx, *to_idx, edge);
390        Ok(())
391    }
392
393    /// Get a node by ID
394    pub fn get_node(&self, node_id: &str) -> Result<&Node> {
395        let idx = self
396            .node_indices
397            .get(node_id)
398            .ok_or_else(|| GraphError::NodeNotFound(node_id.to_string()))?;
399        Ok(&self.graph[*idx])
400    }
401
402    /// Get a mutable reference to a node by ID
403    pub fn get_node_mut(&mut self, node_id: &str) -> Result<&mut Node> {
404        let idx = self
405            .node_indices
406            .get(node_id)
407            .ok_or_else(|| GraphError::NodeNotFound(node_id.to_string()))?;
408        Ok(&mut self.graph[*idx])
409    }
410
411    /// Validate the graph (check for cycles)
412    pub fn validate(&self) -> Result<()> {
413        match toposort(&self.graph, None) {
414            Ok(_) => Ok(()),
415            Err(cycle) => {
416                let node = &self.graph[cycle.node_id()];
417                Err(GraphError::CycleDetected(node.config.id.clone()))
418            }
419        }
420    }
421
422    /// Get a topological ordering of the nodes
423    pub fn topological_order(&self) -> Result<Vec<NodeId>> {
424        let sorted = toposort(&self.graph, None).map_err(|cycle| {
425            let node = &self.graph[cycle.node_id()];
426            GraphError::CycleDetected(node.config.id.clone())
427        })?;
428
429        Ok(sorted
430            .into_iter()
431            .map(|idx| self.graph[idx].config.id.clone())
432            .collect())
433    }
434
435    /// Get all nodes in the graph
436    pub fn nodes(&self) -> Vec<&Node> {
437        self.graph
438            .node_indices()
439            .map(|idx| &self.graph[idx])
440            .collect()
441    }
442
443    /// Get all edges in the graph
444    pub fn edges(&self) -> Vec<&Edge> {
445        self.graph
446            .edge_indices()
447            .map(|idx| &self.graph[idx])
448            .collect()
449    }
450
    /// Get the number of nodes stored directly in this graph.
    ///
    /// Nodes held by branches are not counted; branches are separate `Graph` values.
    pub fn node_count(&self) -> usize {
        self.graph.node_count()
    }
455
    /// Get the number of edges stored directly in this graph.
    ///
    /// This includes edges created implicitly by `add` as well as explicit
    /// `add_edge` calls; edges held by branches are not counted.
    pub fn edge_count(&self) -> usize {
        self.graph.edge_count()
    }
460
461    /// Get incoming edges for a node
462    pub fn incoming_edges(&self, node_id: &str) -> Result<Vec<&Edge>> {
463        let idx = self
464            .node_indices
465            .get(node_id)
466            .ok_or_else(|| GraphError::NodeNotFound(node_id.to_string()))?;
467
468        Ok(self
469            .graph
470            .edges_directed(*idx, Direction::Incoming)
471            .map(|e| e.weight())
472            .collect())
473    }
474
475    /// Get outgoing edges for a node
476    pub fn outgoing_edges(&self, node_id: &str) -> Result<Vec<&Edge>> {
477        let idx = self
478            .node_indices
479            .get(node_id)
480            .ok_or_else(|| GraphError::NodeNotFound(node_id.to_string()))?;
481
482        Ok(self
483            .graph
484            .edges_directed(*idx, Direction::Outgoing)
485            .map(|e| e.weight())
486            .collect())
487    }
488
489    /// Automatically connect nodes based on matching port names
490    /// This enables implicit edge mapping without explicit add_edge() calls
491    ///
492    /// # Matching Strategy
493    /// - Connects output ports to input ports with the same name
494    /// - Only creates edges if the port names match exactly
495    /// - Respects topological ordering to avoid cycles
496    ///
497    /// # Returns
498    /// The number of edges created
499    pub fn auto_connect(&mut self) -> Result<usize> {
500        let mut edges_created = 0;
501        let node_ids: Vec<NodeId> = self.nodes().iter().map(|n| n.config.id.clone()).collect();
502
503        for from_node_id in &node_ids {
504            let from_node = self.get_node(from_node_id)?;
505            let output_ports: Vec<PortId> = from_node
506                .config
507                .output_ports
508                .iter()
509                .map(|p| p.broadcast_name.clone())
510                .collect();
511
512            for to_node_id in &node_ids {
513                if from_node_id == to_node_id {
514                    continue;
515                }
516
517                let to_node = self.get_node(to_node_id)?;
518                let input_ports: Vec<PortId> = to_node
519                    .config
520                    .input_ports
521                    .iter()
522                    .map(|p| p.broadcast_name.clone())
523                    .collect();
524
525                // Find matching port names
526                for output_port in &output_ports {
527                    for input_port in &input_ports {
528                        if output_port == input_port {
529                            // Check if edge already exists
530                            let edge_exists = self.edges().iter().any(|e| {
531                                e.from_node == *from_node_id
532                                    && e.from_port == *output_port
533                                    && e.to_node == *to_node_id
534                                    && e.to_port == *input_port
535                            });
536
537                            if !edge_exists {
538                                let edge = Edge::new(
539                                    from_node_id.clone(),
540                                    output_port.clone(),
541                                    to_node_id.clone(),
542                                    input_port.clone(),
543                                );
544                                self.add_edge(edge)?;
545                                edges_created += 1;
546                            }
547                        }
548                    }
549                }
550            }
551        }
552
553        Ok(edges_created)
554    }
555
556    /// Build a graph with strict mode disabled - uses implicit edge mapping
557    /// This is a convenience method that calls auto_connect() after all nodes are added
558    pub fn with_auto_connect(mut self) -> Result<Self> {
559        self.auto_connect()?;
560        Ok(self)
561    }
562
563    /// Create a new branch (subgraph) with the given name
564    pub fn create_branch(&mut self, name: impl Into<String>) -> Result<&mut Graph> {
565        let name = name.into();
566        if self.branches.contains_key(&name) {
567            return Err(GraphError::InvalidGraph(format!(
568                "Branch '{}' already exists",
569                name
570            )));
571        }
572        self.branches.insert(name.clone(), Graph::new());
573        Ok(self.branches.get_mut(&name).unwrap())
574    }
575
576    /// Get a reference to a branch by name
577    pub fn get_branch(&self, name: &str) -> Result<&Graph> {
578        self.branches
579            .get(name)
580            .ok_or_else(|| GraphError::InvalidGraph(format!("Branch '{}' not found", name)))
581    }
582
583    /// Get a mutable reference to a branch by name
584    pub fn get_branch_mut(&mut self, name: &str) -> Result<&mut Graph> {
585        self.branches
586            .get_mut(name)
587            .ok_or_else(|| GraphError::InvalidGraph(format!("Branch '{}' not found", name)))
588    }
589
    /// Get all branch names.
    ///
    /// The names are cloned out of a `HashMap`, so their order is unspecified.
    pub fn branch_names(&self) -> Vec<String> {
        self.branches.keys().cloned().collect()
    }
594
    /// Check if a branch with the given name exists directly on this graph.
    pub fn has_branch(&self, name: &str) -> bool {
        self.branches.contains_key(name)
    }
599
600    /// Create a merge node that combines outputs from multiple branches
601    ///
602    /// The merge node will collect outputs from the specified branches and combine them
603    /// using the provided merge function (or collect into a list by default).
604    pub fn merge(&mut self, node_id: impl Into<NodeId>, config: MergeConfig) -> Result<()> {
605        // Validate that all branches exist
606        for branch_name in &config.branches {
607            if !self.has_branch(branch_name) {
608                return Err(GraphError::InvalidGraph(format!(
609                    "Branch '{}' not found for merge operation",
610                    branch_name
611                )));
612            }
613        }
614
615        let branch_names = config.branches.clone();
616
617        // Create the merge function
618        let merge_fn = config.merge_fn.unwrap_or_else(|| {
619            // Default merge function: collect into a list
620            Arc::new(|inputs: Vec<&PortData>| -> Result<PortData> {
621                Ok(PortData::List(inputs.iter().map(|&d| d.clone()).collect()))
622            })
623        });
624
625        // Create input ports - one for each branch
626        let input_ports: Vec<Port> = branch_names
627            .iter()
628            .map(|name| Port::new(name.clone(), format!("Input from {}", name)))
629            .collect();
630
631        // Create a merge node
632        let node_config = NodeConfig::new(
633            node_id,
634            "Merge Node",
635            input_ports,
636            vec![Port::new("merged", "Merged Output")],
637            Arc::new(move |inputs: &HashMap<PortId, PortData>| {
638                // Collect inputs in branch order
639                let mut collected_inputs = Vec::new();
640                for branch_name in &branch_names {
641                    if let Some(data) = inputs.get(branch_name.as_str()) {
642                        collected_inputs.push(data);
643                    }
644                }
645
646                // Apply merge function
647                let merged = merge_fn(collected_inputs)?;
648
649                let mut outputs = HashMap::new();
650                outputs.insert("merged".to_string(), merged);
651                Ok(outputs)
652            }),
653        );
654
655        self.add(Node::new(node_config))
656    }
657
658    /// Create variant branches for config sweeps
659    ///
660    /// This creates multiple isolated branches, each with a different parameter value.
661    /// Variants can be used for hyperparameter sweeps, A/B testing, or any scenario
662    /// where you want to run the same computation with different inputs.
663    ///
664    /// Returns the names of the created variant branches.
665    pub fn create_variants(&mut self, config: VariantConfig) -> Result<Vec<String>> {
666        let mut branch_names = Vec::new();
667
668        for i in 0..config.count {
669            let branch_name = format!("{}_{}", config.name_prefix, i);
670
671            // Check if branch already exists
672            if self.has_branch(&branch_name) {
673                return Err(GraphError::InvalidGraph(format!(
674                    "Variant branch '{}' already exists",
675                    branch_name
676                )));
677            }
678
679            // Create the branch
680            let branch = self.create_branch(&branch_name)?;
681
682            // Add a source node to the branch with the variant parameter
683            let param_value = (config.variant_fn)(i);
684            let param_name = config.param_name.clone();
685
686            let source_config = NodeConfig::new(
687                format!("{}_source", branch_name),
688                format!("Variant Source {}", i),
689                vec![],
690                vec![Port::new(&param_name, "Variant Parameter")],
691                // Note: param_name and param_value must be cloned into the closure
692                // because the closure is moved into an Arc and needs to own these values
693                // to ensure they remain valid for the lifetime of the node function
694                Arc::new(move |_: &HashMap<PortId, PortData>| {
695                    let mut outputs = HashMap::new();
696                    outputs.insert(param_name.clone(), param_value.clone());
697                    Ok(outputs)
698                }),
699            );
700
701            branch.add(Node::new(source_config))?;
702            branch_names.push(branch_name);
703        }
704
705        Ok(branch_names)
706    }
707}
708
/// The default graph is the same as `Graph::new()`: empty, with implicit edge mapping.
impl Default for Graph {
    fn default() -> Self {
        Self::new()
    }
}
714
715#[cfg(test)]
716mod tests {
717    use super::*;
718    use crate::core::data::PortData;
719
720    fn dummy_function(inputs: &HashMap<PortId, PortData>) -> Result<HashMap<PortId, PortData>> {
721        let mut outputs = HashMap::new();
722        if let Some(PortData::Int(val)) = inputs.get("input") {
723            outputs.insert("output".to_string(), PortData::Int(val * 2));
724        }
725        Ok(outputs)
726    }
727
728    #[test]
729    fn test_graph_creation() {
730        let graph = Graph::new();
731        assert_eq!(graph.node_count(), 0);
732        assert_eq!(graph.edge_count(), 0);
733    }
734
735    #[test]
736    fn test_add_node() {
737        let mut graph = Graph::new();
738
739        let config = NodeConfig::new(
740            "node1",
741            "Node 1",
742            vec![Port::new("input", "Input")],
743            vec![Port::new("output", "Output")],
744            Arc::new(dummy_function),
745        );
746
747        let node = Node::new(config);
748        assert!(graph.add(node).is_ok());
749        assert_eq!(graph.node_count(), 1);
750    }
751
752    #[test]
753    fn test_duplicate_node_id() {
754        let mut graph = Graph::new();
755
756        let config1 = NodeConfig::new("node1", "Node 1", vec![], vec![], Arc::new(dummy_function));
757
758        let config2 = NodeConfig::new(
759            "node1",
760            "Node 1 Duplicate",
761            vec![],
762            vec![],
763            Arc::new(dummy_function),
764        );
765
766        assert!(graph.add(Node::new(config1)).is_ok());
767        assert!(graph.add(Node::new(config2)).is_err());
768    }
769
770    #[test]
771    fn test_add_edge() {
772        let mut graph = Graph::with_strict_edges();
773
774        let config1 = NodeConfig::new(
775            "node1",
776            "Node 1",
777            vec![],
778            vec![Port::new("output", "Output")],
779            Arc::new(dummy_function),
780        );
781
782        let config2 = NodeConfig::new(
783            "node2",
784            "Node 2",
785            vec![Port::new("input", "Input")],
786            vec![],
787            Arc::new(dummy_function),
788        );
789
790        graph.add(Node::new(config1)).unwrap();
791        graph.add(Node::new(config2)).unwrap();
792
793        let edge = Edge::new("node1", "output", "node2", "input");
794        assert!(graph.add_edge(edge).is_ok());
795        assert_eq!(graph.edge_count(), 1);
796    }
797
798    #[test]
799    fn test_topological_order() {
800        let mut graph = Graph::new();
801
802        // Create a simple linear graph: node1 -> node2 -> node3
803        for i in 1..=3 {
804            let outputs = if i < 3 {
805                vec![Port::new("output", "Output")]
806            } else {
807                vec![]
808            };
809            let inputs = if i > 1 {
810                vec![Port::new("input", "Input")]
811            } else {
812                vec![]
813            };
814
815            let config = NodeConfig::new(
816                format!("node{}", i),
817                format!("Node {}", i),
818                inputs,
819                outputs,
820                Arc::new(dummy_function),
821            );
822            graph.add(Node::new(config)).unwrap();
823        }
824
825        graph
826            .add_edge(Edge::new("node1", "output", "node2", "input"))
827            .unwrap();
828        graph
829            .add_edge(Edge::new("node2", "output", "node3", "input"))
830            .unwrap();
831
832        let order = graph.topological_order().unwrap();
833        assert_eq!(order.len(), 3);
834        assert_eq!(order[0], "node1");
835        assert_eq!(order[1], "node2");
836        assert_eq!(order[2], "node3");
837    }
838
839    #[test]
840    fn test_cycle_detection() {
841        let mut graph = Graph::new();
842
843        // Create a cycle: node1 -> node2 -> node1
844        let config1 = NodeConfig::new(
845            "node1",
846            "Node 1",
847            vec![Port::new("input", "Input")],
848            vec![Port::new("output", "Output")],
849            Arc::new(dummy_function),
850        );
851
852        let config2 = NodeConfig::new(
853            "node2",
854            "Node 2",
855            vec![Port::new("input", "Input")],
856            vec![Port::new("output", "Output")],
857            Arc::new(dummy_function),
858        );
859
860        graph.add(Node::new(config1)).unwrap();
861        graph.add(Node::new(config2)).unwrap();
862
863        graph
864            .add_edge(Edge::new("node1", "output", "node2", "input"))
865            .unwrap();
866        graph
867            .add_edge(Edge::new("node2", "output", "node1", "input"))
868            .unwrap();
869
870        assert!(graph.validate().is_err());
871    }
872
873    #[test]
874    fn test_create_branch() {
875        let mut graph = Graph::new();
876
877        // Create a branch
878        let branch = graph.create_branch("branch_a");
879        assert!(branch.is_ok());
880
881        // Verify branch exists
882        assert!(graph.has_branch("branch_a"));
883        assert_eq!(graph.branch_names().len(), 1);
884        assert_eq!(graph.branch_names()[0], "branch_a");
885    }
886
887    #[test]
888    fn test_duplicate_branch_name() {
889        let mut graph = Graph::new();
890
891        graph.create_branch("branch_a").unwrap();
892        let result = graph.create_branch("branch_a");
893        assert!(result.is_err());
894    }
895
/// A node added to one branch must not be reachable through a sibling branch.
#[test]
fn test_branch_isolation() {
    let mut graph = Graph::new();

    // One output-only node per branch; distinct ids let the lookups tell them apart.
    let cfg_a = NodeConfig::new(
        "node_a",
        "Node A",
        Vec::new(),
        vec![Port::new("output", "Output")],
        Arc::new(dummy_function),
    );
    let cfg_b = NodeConfig::new(
        "node_b",
        "Node B",
        Vec::new(),
        vec![Port::new("output", "Output")],
        Arc::new(dummy_function),
    );

    // Create each branch and immediately populate it with its own node.
    graph
        .create_branch("branch_a")
        .unwrap()
        .add(Node::new(cfg_a))
        .unwrap();
    graph
        .create_branch("branch_b")
        .unwrap()
        .add(Node::new(cfg_b))
        .unwrap();

    let stored_a = graph.get_branch("branch_a").unwrap();
    let stored_b = graph.get_branch("branch_b").unwrap();

    // Each branch holds exactly the single node that was added to it.
    assert_eq!(stored_a.node_count(), 1);
    assert_eq!(stored_b.node_count(), 1);

    // Looking up the other branch's node must fail on both sides.
    assert!(stored_a.get_node("node_b").is_err());
    assert!(stored_b.get_node("node_a").is_err());
}
937
/// Asking a graph without branches for a branch by name yields an error.
#[test]
fn test_get_nonexistent_branch() {
    let empty = Graph::new();
    let lookup = empty.get_branch("nonexistent");
    assert!(lookup.is_err());
}
943
/// Merging two existing branches succeeds and adds one merge node to the graph.
#[test]
fn test_merge_basic() {
    let mut graph = Graph::new();

    // Both branches referenced by the merge must exist beforehand.
    for name in ["branch_a", "branch_b"] {
        graph.create_branch(name).unwrap();
    }

    // Merge configuration naming the two branches and the port "output".
    let sources = vec![String::from("branch_a"), String::from("branch_b")];
    let merge_config = MergeConfig::new(sources, String::from("output"));

    // The merge call itself must succeed.
    assert!(graph.merge("merge_node", merge_config).is_ok());

    // Exactly one node exists afterwards, retrievable under the merge id.
    assert_eq!(graph.node_count(), 1);
    assert!(graph.get_node("merge_node").is_ok());
}
966
/// A merge that references a branch which was never created must fail.
#[test]
fn test_merge_with_nonexistent_branch() {
    let mut graph = Graph::new();

    // Only one of the two referenced branches actually exists.
    graph.create_branch("branch_a").unwrap();

    let sources = vec![String::from("branch_a"), String::from("nonexistent")];
    let merge_config = MergeConfig::new(sources, String::from("output"));

    assert!(graph.merge("merge_node", merge_config).is_err());
}
981
/// A merge configured with a caller-supplied combining closure is accepted.
#[test]
fn test_merge_with_custom_function() {
    let mut graph = Graph::new();

    for name in ["branch_a", "branch_b"] {
        graph.create_branch(name).unwrap();
    }

    // Combiner: the largest `PortData::Int` among the inputs. Other variants are
    // skipped, and `i64::MIN` is produced when no integer input is present.
    let pick_max = Arc::new(|inputs: Vec<&PortData>| -> Result<PortData> {
        let largest = inputs
            .into_iter()
            .filter_map(|data| match data {
                PortData::Int(val) => Some(*val),
                _ => None,
            })
            .fold(i64::MIN, |best, val| best.max(val));
        Ok(PortData::Int(largest))
    });

    let sources = vec![String::from("branch_a"), String::from("branch_b")];
    let merge_config =
        MergeConfig::new(sources, String::from("output")).with_merge_fn(pick_max);

    assert!(graph.merge("merge_node", merge_config).is_ok());
}
1009
/// `create_variants` returns one indexed branch name per variant, and every
/// returned branch exists in the graph holding exactly one node.
#[test]
fn test_create_variants() {
    let mut graph = Graph::new();

    // Variant `i` is parameterised with the integer `i * 10`.
    let value_for = Arc::new(|i: usize| PortData::Int(i as i64 * 10));
    let outcome = graph.create_variants(VariantConfig::new(
        "test_variant",
        3,
        "param",
        value_for,
    ));
    assert!(outcome.is_ok());

    // Names are "<prefix>_<index>" in index order.
    let branch_names = outcome.unwrap();
    assert_eq!(branch_names.len(), 3);
    assert_eq!(branch_names[0], "test_variant_0");
    assert_eq!(branch_names[1], "test_variant_1");
    assert_eq!(branch_names[2], "test_variant_2");

    // Every reported branch is registered and contains a single node.
    for name in branch_names.iter() {
        assert!(graph.has_branch(name));
        assert_eq!(graph.get_branch(name).unwrap().node_count(), 1);
    }
}
1034
/// Variant creation still yields all requested branches when the config is
/// built with `with_parallel(false)`.
#[test]
fn test_variants_with_parallelization_flag() {
    let mut graph = Graph::new();

    // Variant `i` carries the float `i * 0.5`.
    let value_for = Arc::new(|i: usize| PortData::Float(i as f64 * 0.5));
    let sweep =
        VariantConfig::new("param_sweep", 5, "learning_rate", value_for).with_parallel(false);

    let outcome = graph.create_variants(sweep);
    assert!(outcome.is_ok());

    // One branch name per requested variant.
    assert_eq!(outcome.unwrap().len(), 5);
}
1049
/// Creating the same set of variants twice must fail on the second attempt.
#[test]
fn test_duplicate_variant_branch() {
    let mut graph = Graph::new();

    // Two identical configurations sharing one value generator.
    let value_for = Arc::new(|i: usize| PortData::Int(i as i64));
    let first = VariantConfig::new("test", 2, "param", value_for.clone());
    let second = VariantConfig::new("test", 2, "param", value_for);

    // The initial creation goes through.
    graph.create_variants(first).unwrap();

    // Repeating it with the same prefix and count is rejected.
    assert!(graph.create_variants(second).is_err());
}
1065
/// With a graph from `Graph::new()`, adding nodes whose input port names match
/// the previously added node's output port names produces edges without any
/// explicit connection call.
#[test]
fn test_implicit_edge_mapping() {
    let mut graph = Graph::new();

    let source_cfg = NodeConfig::new(
        "source",
        "Source",
        Vec::new(),
        vec![Port::new("output", "Output")],
        Arc::new(dummy_function),
    );

    // Input port id "output" matches the source's output port id.
    let processor_cfg = NodeConfig::new(
        "processor",
        "Processor",
        vec![Port::new("output", "Input")],
        vec![Port::new("result", "Result")],
        Arc::new(dummy_function),
    );

    // Input port id "result" matches the processor's output port id.
    let sink_cfg = NodeConfig::new(
        "sink",
        "Sink",
        vec![Port::new("result", "Input")],
        Vec::new(),
        Arc::new(dummy_function),
    );

    // Insert in pipeline order; no explicit edge is ever added.
    for cfg in [source_cfg, processor_cfg, sink_cfg] {
        graph.add(Node::new(cfg)).unwrap();
    }

    // Expect source->processor and processor->sink.
    assert_eq!(graph.edge_count(), 2);
    assert_eq!(graph.node_count(), 3);
}
1104
/// With a graph from `Graph::with_strict_edges()`, adding nodes with matching
/// port names must not create any edge on its own.
#[test]
fn test_strict_edge_mapping() {
    let mut graph = Graph::with_strict_edges();

    let source_cfg = NodeConfig::new(
        "source",
        "Source",
        Vec::new(),
        vec![Port::new("output", "Output")],
        Arc::new(dummy_function),
    );

    // Same port id as the source's output, which would match by name.
    let sink_cfg = NodeConfig::new(
        "sink",
        "Sink",
        vec![Port::new("output", "Input")],
        Vec::new(),
        Arc::new(dummy_function),
    );

    for cfg in [source_cfg, sink_cfg] {
        graph.add(Node::new(cfg)).unwrap();
    }

    // Both nodes are present, but nothing was wired between them.
    assert_eq!(graph.edge_count(), 0);
    assert_eq!(graph.node_count(), 2);
}
1134
/// On a strict-edges graph, `auto_connect` wires a three-node chain by matching
/// port names and reports how many edges it created.
#[test]
fn test_auto_connect() {
    let mut graph = Graph::with_strict_edges();

    let source_cfg = NodeConfig::new(
        "source",
        "Source",
        Vec::new(),
        vec![Port::new("data", "Data")],
        Arc::new(dummy_function),
    );

    // Input "data" matches the source's output port.
    let processor_cfg = NodeConfig::new(
        "processor",
        "Processor",
        vec![Port::new("data", "Data")],
        vec![Port::new("result", "Result")],
        Arc::new(dummy_function),
    );

    // Input "result" matches the processor's output port.
    let sink_cfg = NodeConfig::new(
        "sink",
        "Sink",
        vec![Port::new("result", "Result")],
        Vec::new(),
        Arc::new(dummy_function),
    );

    for cfg in [source_cfg, processor_cfg, sink_cfg] {
        graph.add(Node::new(cfg)).unwrap();
    }

    // Nothing is connected before the explicit auto-connect call.
    assert_eq!(graph.edge_count(), 0);

    // The call reports two new edges, and the graph agrees.
    assert_eq!(graph.auto_connect().unwrap(), 2);
    assert_eq!(graph.edge_count(), 2);

    // The resulting graph passes validation.
    assert!(graph.validate().is_ok());
}
1179
/// On a strict-edges graph, `auto_connect` handles one output feeding two
/// consumers and two outputs feeding one consumer.
#[test]
fn test_auto_connect_parallel_branches() {
    let mut graph = Graph::with_strict_edges();

    // Single producer exposing the output port "value".
    let source_cfg = NodeConfig::new(
        "source",
        "Source",
        Vec::new(),
        vec![Port::new("value", "Value")],
        Arc::new(dummy_function),
    );

    // Two consumers of "value", each with its own distinct output port.
    let first_branch_cfg = NodeConfig::new(
        "branch1",
        "Branch 1",
        vec![Port::new("value", "Value")],
        vec![Port::new("out1", "Output 1")],
        Arc::new(dummy_function),
    );
    let second_branch_cfg = NodeConfig::new(
        "branch2",
        "Branch 2",
        vec![Port::new("value", "Value")],
        vec![Port::new("out2", "Output 2")],
        Arc::new(dummy_function),
    );

    // Final node whose two inputs are named after the branch outputs.
    let merger_cfg = NodeConfig::new(
        "merger",
        "Merger",
        vec![Port::new("out1", "Input 1"), Port::new("out2", "Input 2")],
        Vec::new(),
        Arc::new(dummy_function),
    );

    for cfg in [source_cfg, first_branch_cfg, second_branch_cfg, merger_cfg] {
        graph.add(Node::new(cfg)).unwrap();
    }

    // Two fan-out edges plus two fan-in edges are expected.
    assert_eq!(graph.auto_connect().unwrap(), 4);
    assert_eq!(graph.edge_count(), 4);

    // The resulting graph passes validation.
    assert!(graph.validate().is_ok());
}
1232}