runnx 0.3.0

A minimal, verifiable ONNX runtime implementation in Rust
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
//! Computational graph representation
//!
//! This module defines the graph structure for ONNX models, including
//! nodes, edges, and the overall graph representation.

use crate::{
    error::{OnnxError, Result},
    operators::OperatorType,
    tensor::Tensor,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// A node in the computational graph
///
/// A node applies one operator (`op_type`) to the tensors named in `inputs`
/// and produces the tensors named in `outputs`. Edges between nodes are
/// implicit: a node depends on whichever node lists one of its input names
/// among its outputs (see [`Graph::topological_sort`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
    /// Unique identifier for the node (uniqueness is checked by
    /// [`Graph::validate`])
    pub name: String,
    /// Type of operation this node performs, e.g. "Add" or "MatMul"; parsed
    /// into an [`OperatorType`] by [`Node::get_operator_type`]
    pub op_type: String,
    /// Names of the tensors this node consumes
    pub inputs: Vec<String>,
    /// Names of the tensors this node produces
    pub outputs: Vec<String>,
    /// Node attributes (parameters), stored as string key/value pairs
    pub attributes: HashMap<String, String>,
}

impl Node {
    /// Create a new node
    pub fn new(name: String, op_type: String, inputs: Vec<String>, outputs: Vec<String>) -> Self {
        Self {
            name,
            op_type,
            inputs,
            outputs,
            attributes: HashMap::new(),
        }
    }

    /// Add an attribute to the node
    pub fn add_attribute<K: Into<String>, V: Into<String>>(&mut self, key: K, value: V) {
        self.attributes.insert(key.into(), value.into());
    }

    /// Get the operator type as enum
    pub fn get_operator_type(&self) -> Result<OperatorType> {
        self.op_type.parse()
    }
}

/// Represents the computational graph of an ONNX model
///
/// Tensors are referenced by name. Graph inputs, initializers and node
/// outputs share one namespace, and node inputs refer into it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Graph {
    /// Graph name
    pub name: String,
    /// Nodes of the graph. The listing order is not required to be an
    /// execution order; use [`Graph::topological_sort`] to obtain one.
    pub nodes: Vec<Node>,
    /// Input tensor specifications
    pub inputs: Vec<TensorSpec>,
    /// Output tensor specifications
    pub outputs: Vec<TensorSpec>,
    /// Initial values for parameters/constants, keyed by tensor name
    pub initializers: HashMap<String, Tensor>,
}

/// Tensor specification with name and shape information
///
/// Used to describe graph inputs and outputs. [`TensorSpec::matches_tensor`]
/// checks a concrete tensor against it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSpec {
    /// Name of the tensor
    pub name: String,
    /// Size of each axis, one entry per dimension; `None` marks a dynamic
    /// dimension that accepts any size
    pub dimensions: Vec<Option<usize>>,
    /// Data type name (simplified to f32 for this implementation;
    /// [`TensorSpec::new`] sets it to `"float32"`)
    pub dtype: String,
}

impl TensorSpec {
    /// Create a new tensor specification
    pub fn new(name: String, dimensions: Vec<Option<usize>>) -> Self {
        Self {
            name,
            dimensions,
            dtype: "float32".to_string(),
        }
    }

    /// Check if the tensor spec matches a given tensor
    pub fn matches_tensor(&self, tensor: &Tensor) -> bool {
        let tensor_shape = tensor.shape();

        if self.dimensions.len() != tensor_shape.len() {
            return false;
        }

        for (spec_dim, &tensor_dim) in self.dimensions.iter().zip(tensor_shape.iter()) {
            match spec_dim {
                Some(expected) => {
                    if *expected != tensor_dim {
                        return false;
                    }
                }
                None => {
                    // Dynamic dimension, any size is acceptable
                    continue;
                }
            }
        }

        true
    }
}

impl Graph {
    /// Create a new, empty graph with the given name.
    pub fn new(name: String) -> Self {
        Self {
            name,
            nodes: Vec::new(),
            inputs: Vec::new(),
            outputs: Vec::new(),
            initializers: HashMap::new(),
        }
    }

    /// Append a node to the graph.
    ///
    /// Nodes may be added in any order. [`Graph::topological_sort`] derives
    /// an execution order from the tensor connections.
    pub fn add_node(&mut self, node: Node) {
        self.nodes.push(node);
    }

    /// Append a graph input specification.
    pub fn add_input(&mut self, input_spec: TensorSpec) {
        self.inputs.push(input_spec);
    }

    /// Append a graph output specification.
    pub fn add_output(&mut self, output_spec: TensorSpec) {
        self.outputs.push(output_spec);
    }

    /// Add an initializer (constant tensor) under `name`, replacing any
    /// existing initializer with the same name.
    pub fn add_initializer(&mut self, name: String, tensor: Tensor) {
        self.initializers.insert(name, tensor);
    }

    /// Names of the graph inputs, in declaration order.
    pub fn input_names(&self) -> Vec<&str> {
        self.inputs.iter().map(|spec| spec.name.as_str()).collect()
    }

    /// Names of the graph outputs, in declaration order.
    pub fn output_names(&self) -> Vec<&str> {
        self.outputs.iter().map(|spec| spec.name.as_str()).collect()
    }

    /// Validate the graph structure.
    ///
    /// Checks, in this order:
    /// - Duplicate node names
    /// - References to tensors not produced by any node or input/initializer
    /// - Invalid operator types
    /// - Graph outputs that are never produced
    ///
    /// Node ordering does not matter here; cycle detection is handled separately
    /// by [`Graph::topological_sort`].
    ///
    /// # Errors
    ///
    /// Returns a graph-validation error describing the first problem found.
    pub fn validate(&self) -> Result<()> {
        use std::collections::HashSet;

        // Node names must be unique.
        let mut seen_names: HashSet<&str> = HashSet::new();
        for node in &self.nodes {
            if !seen_names.insert(node.name.as_str()) {
                return Err(OnnxError::graph_validation_error(format!(
                    "Duplicate node name: {}",
                    node.name
                )));
            }
        }

        // Every tensor available anywhere in the graph: graph inputs,
        // initializers and all node outputs. Gathering them up front keeps the
        // check independent of the node listing order.
        let available_tensors: HashSet<&str> = self
            .inputs
            .iter()
            .map(|spec| spec.name.as_str())
            .chain(self.initializers.keys().map(String::as_str))
            .chain(
                self.nodes
                    .iter()
                    .flat_map(|node| node.outputs.iter().map(String::as_str)),
            )
            .collect();

        for node in &self.nodes {
            if let Some(input_name) = node
                .inputs
                .iter()
                .find(|name| !available_tensors.contains(name.as_str()))
            {
                return Err(OnnxError::graph_validation_error(format!(
                    "Node '{}' references unknown input tensor '{}'",
                    node.name, input_name
                )));
            }

            node.get_operator_type().map_err(|e| {
                OnnxError::graph_validation_error(format!(
                    "Node '{}' has invalid operator type '{}': {}",
                    node.name, node.op_type, e
                ))
            })?;
        }

        // Every declared graph output must be available somewhere.
        if let Some(output) = self
            .outputs
            .iter()
            .find(|spec| !available_tensors.contains(spec.name.as_str()))
        {
            return Err(OnnxError::graph_validation_error(format!(
                "Graph output '{}' is not produced by any node",
                output.name
            )));
        }

        Ok(())
    }

    /// Compute an execution order for the nodes (indices into `self.nodes`).
    ///
    /// Node `j` depends on node `i` when one of `i`'s outputs is one of `j`'s
    /// inputs. The order comes from Kahn's algorithm, using a min-heap of ready
    /// nodes: whenever several nodes are ready, the one with the lowest index
    /// runs first. The result is therefore deterministic, and equals the
    /// listing order `0..n` whenever that order is already valid.
    ///
    /// # Errors
    ///
    /// Returns a graph-validation error if the dependencies contain a cycle,
    /// including a node that consumes its own output.
    pub fn topological_sort(&self) -> Result<Vec<usize>> {
        use std::cmp::Reverse;
        use std::collections::BinaryHeap;

        let n = self.nodes.len();

        // tensor name -> indices of the nodes that consume it
        let mut consumers: HashMap<&str, Vec<usize>> = HashMap::new();
        for (j, node) in self.nodes.iter().enumerate() {
            for input in &node.inputs {
                consumers.entry(input.as_str()).or_default().push(j);
            }
        }

        // Producer -> consumer edges. An edge is added once per (output,
        // input) name match, so in-degrees and decrements stay in step even
        // when a node lists the same input twice.
        let mut in_degree = vec![0usize; n];
        let mut successors: Vec<Vec<usize>> = vec![Vec::new(); n];
        for (i, node) in self.nodes.iter().enumerate() {
            for output in &node.outputs {
                if let Some(deps) = consumers.get(output.as_str()) {
                    for &j in deps {
                        successors[i].push(j);
                        in_degree[j] += 1;
                    }
                }
            }
        }

        // `Reverse` turns the max-heap into a min-heap over node indices.
        let mut ready: BinaryHeap<Reverse<usize>> = (0..n)
            .filter(|&i| in_degree[i] == 0)
            .map(Reverse)
            .collect();
        let mut order = Vec::with_capacity(n);

        while let Some(Reverse(current)) = ready.pop() {
            order.push(current);
            for &next in &successors[current] {
                in_degree[next] -= 1;
                if in_degree[next] == 0 {
                    ready.push(Reverse(next));
                }
            }
        }

        // Nodes on a cycle never reach in-degree zero and are never emitted.
        if order.len() != n {
            return Err(OnnxError::graph_validation_error(
                "Graph contains cycles".to_string(),
            ));
        }

        Ok(order)
    }

    /// Group nodes into parallel execution waves.
    ///
    /// Returns a list of levels where every node in a level is independent of
    /// every other node in that level (no data edges between them). Nodes in
    /// the same level can be executed concurrently; levels must be executed in
    /// order. Node indices inside a level are ascending, and an empty graph
    /// yields no levels.
    ///
    /// # Errors
    ///
    /// Propagates the cycle error from [`Graph::topological_sort`].
    pub fn topological_levels(&self) -> Result<Vec<Vec<usize>>> {
        let n = self.nodes.len();
        if n == 0 {
            return Ok(vec![]);
        }

        // tensor_level[t] = the wave after which tensor t is available.
        // Graph inputs and initializers are available before wave 0 → level 0.
        let mut tensor_level: HashMap<&str, usize> = HashMap::new();
        for input in &self.inputs {
            tensor_level.insert(input.name.as_str(), 0);
        }
        for name in self.initializers.keys() {
            tensor_level.insert(name.as_str(), 0);
        }

        // Process nodes in topological order so producers are levelled before
        // their consumers.
        let topo_order = self.topological_sort()?;
        let mut node_level = vec![0usize; n];

        for &idx in &topo_order {
            let node = &self.nodes[idx];
            // A node's wave = max wave of its known input tensors (unknown
            // names are skipped and contribute nothing).
            let level = node
                .inputs
                .iter()
                .filter_map(|name| tensor_level.get(name.as_str()).copied())
                .max()
                .unwrap_or(0);
            node_level[idx] = level;
            // Outputs produced by this node become available at level + 1.
            for output in &node.outputs {
                tensor_level.insert(output.as_str(), level + 1);
            }
        }

        let max_level = node_level.iter().copied().max().unwrap_or(0);
        let mut levels: Vec<Vec<usize>> = vec![vec![]; max_level + 1];
        for (idx, &lvl) in node_level.iter().enumerate() {
            levels[lvl].push(idx);
        }

        Ok(levels)
    }

    /// Print a human-readable, box-drawn summary of the graph to stdout.
    ///
    /// The summary has these sections, in order:
    /// - a title box;
    /// - graph inputs and initializers;
    /// - the computation flow in topological order, falling back to listing
    ///   order when the graph has cycles;
    /// - graph outputs, statistics, and a per-operator count.
    ///
    /// Initializers and node attributes are listed in sorted key order so the
    /// output is deterministic.
    pub fn print_graph(&self) {
        // Render a spec's dimensions as e.g. "1 × 3", using "?" for dynamic axes.
        fn format_dimensions(dimensions: &[Option<usize>]) -> String {
            dimensions
                .iter()
                .map(|d| d.map_or_else(|| "?".to_string(), |v| v.to_string()))
                .collect::<Vec<_>>()
                .join(" × ")
        }

        // Title box. Widths are counted in chars, not bytes, so non-ASCII
        // graph names do not throw off the centering.
        let title = format!("GRAPH: {}", self.name);
        let title_width = title.chars().count();
        let box_width = std::cmp::max(title_width + 4, 40); // 2 spaces each side, min 40 columns
        let left = (box_width - title_width) / 2;
        let right = box_width - title_width - left;
        let horizontal = "═".repeat(box_width);

        println!("\n╔{horizontal}╗");
        println!("║{}{title}{}║", " ".repeat(left), " ".repeat(right));
        println!("╚{horizontal}╝");

        // Print inputs
        if !self.inputs.is_empty() {
            println!("\n📥 INPUTS:");
            for input in &self.inputs {
                println!(
                    "   ┌─ {} [{}] ({})",
                    input.name,
                    format_dimensions(&input.dimensions),
                    input.dtype
                );
            }
        }

        // Print initializers
        if !self.initializers.is_empty() {
            println!("\n⚙️  INITIALIZERS:");
            let sorted: std::collections::BTreeMap<&String, &Tensor> =
                self.initializers.iter().collect();
            for (name, tensor) in sorted {
                let shape_str = tensor
                    .shape()
                    .iter()
                    .map(|&d| d.to_string())
                    .collect::<Vec<_>>()
                    .join(" × ");
                println!("   ┌─ {name} [{shape_str}]");
            }
        }

        // Print computation flow
        if !self.nodes.is_empty() {
            println!("\n🔄 COMPUTATION FLOW:");

            // Try to get execution order, fall back to original order if there are cycles
            let execution_order = self.topological_sort().unwrap_or_else(|_| {
                println!("   ⚠️  Warning: Graph contains cycles, showing original order");
                (0..self.nodes.len()).collect()
            });

            for (step, &node_idx) in execution_order.iter().enumerate() {
                let node = &self.nodes[node_idx];

                println!();
                println!("   ├─ Step {}: {}", step + 1, node.name);
                println!("   │  ┌─ Operation: {}", node.op_type);

                if !node.inputs.is_empty() {
                    println!("   │  ├─ Inputs:");
                    for input in &node.inputs {
                        println!("   │  │  └─ {input}");
                    }
                }

                if !node.outputs.is_empty() {
                    println!("   │  ├─ Outputs:");
                    for output in &node.outputs {
                        println!("   │  │  └─ {output}");
                    }
                }

                if node.attributes.is_empty() {
                    println!("   │  └─ (no attributes)");
                } else {
                    println!("   │  └─ Attributes:");
                    let sorted: std::collections::BTreeMap<&String, &String> =
                        node.attributes.iter().collect();
                    for (key, value) in sorted {
                        println!("   │     └─ {key}: {value}");
                    }
                }
            }
        }

        // Print outputs
        if !self.outputs.is_empty() {
            println!();
            println!("📤 OUTPUTS:");
            for output in &self.outputs {
                println!(
                    "   └─ {} [{}] ({})",
                    output.name,
                    format_dimensions(&output.dimensions),
                    output.dtype
                );
            }
        }

        println!("\n📊 STATISTICS:");
        println!("   ├─ Total nodes: {}", self.nodes.len());
        println!("   ├─ Input tensors: {}", self.inputs.len());
        println!("   ├─ Output tensors: {}", self.outputs.len());
        println!("   └─ Initializers: {}", self.initializers.len());

        // Print operation summary (sorted by operator name)
        if !self.nodes.is_empty() {
            let mut op_counts: std::collections::BTreeMap<&str, usize> =
                std::collections::BTreeMap::new();
            for node in &self.nodes {
                *op_counts.entry(node.op_type.as_str()).or_insert(0) += 1;
            }

            println!("\n🎯 OPERATION SUMMARY:");
            // Non-empty because there is at least one node.
            let last = op_counts.len() - 1;
            for (i, (op_type, count)) in op_counts.into_iter().enumerate() {
                let branch = if i == last { "└─" } else { "├─" };
                println!("   {branch} {op_type}: {count}");
            }
        }

        println!();
    }

    /// Generate a simplified Graphviz DOT description of the graph.
    ///
    /// Node types are drawn as follows:
    /// - graph inputs: green ellipses;
    /// - initializers: blue diamonds, listed in sorted name order;
    /// - operations: rounded boxes labelled "name\n(op_type)";
    /// - graph outputs: red ellipses.
    ///
    /// Tensor and operation names share DOT's node-ID namespace. Every name is
    /// escaped (`"` and `\` are backslash-escaped), so arbitrary names yield
    /// well-formed output.
    pub fn to_dot(&self) -> String {
        // Escape a string for use between double quotes in DOT.
        fn escape(s: &str) -> String {
            let mut out = String::with_capacity(s.len());
            for ch in s.chars() {
                if matches!(ch, '"' | '\\') {
                    out.push('\\');
                }
                out.push(ch);
            }
            out
        }

        let mut dot = String::new();

        dot.push_str("digraph G {\n");
        dot.push_str("  rankdir=TB;\n");
        dot.push_str("  node [shape=box, style=rounded];\n\n");

        // Input nodes
        for input in &self.inputs {
            let id = escape(&input.name);
            dot.push_str(&format!(
                "  \"{id}\" [shape=ellipse, color=green, label=\"{id}\"];\n"
            ));
        }

        // Initializer nodes, sorted so the output is deterministic.
        let mut initializer_names: Vec<&String> = self.initializers.keys().collect();
        initializer_names.sort();
        for name in initializer_names {
            let id = escape(name);
            dot.push_str(&format!(
                "  \"{id}\" [shape=diamond, color=blue, label=\"{id}\"];\n"
            ));
        }

        // Operation nodes; `\\n` emits DOT's literal `\n` line break.
        for node in &self.nodes {
            let id = escape(&node.name);
            let op = escape(&node.op_type);
            dot.push_str(&format!("  \"{id}\" [label=\"{id}\\n({op})\"];\n"));
        }

        // Output nodes
        for output in &self.outputs {
            let id = escape(&output.name);
            dot.push_str(&format!(
                "  \"{id}\" [shape=ellipse, color=red, label=\"{id}\"];\n"
            ));
        }

        dot.push('\n');

        // Edges: input tensor -> node -> output tensor
        for node in &self.nodes {
            let id = escape(&node.name);
            for input in &node.inputs {
                dot.push_str(&format!("  \"{}\" -> \"{id}\";\n", escape(input)));
            }
            for output in &node.outputs {
                dot.push_str(&format!("  \"{id}\" -> \"{}\";\n", escape(output)));
            }
        }

        dot.push_str("}\n");
        dot
    }

    /// Create a simple linear graph for testing: `output = input · weights + bias`.
    ///
    /// `input` is 1×3, `weights` is 3×2, `bias` is 1×2 and `output` is 1×2.
    pub fn create_simple_linear() -> Self {
        let mut graph = Graph::new("simple_linear".to_string());

        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1), Some(3)]));
        graph.add_output(TensorSpec::new(
            "output".to_string(),
            vec![Some(1), Some(2)],
        ));

        // The element counts below match their shapes, so construction cannot fail.
        let weights = Tensor::from_shape_vec(&[3, 2], vec![0.5, 0.3, 0.2, 0.4, 0.1, 0.6])
            .expect("3x2 weights have exactly 6 elements");
        let bias = Tensor::from_shape_vec(&[1, 2], vec![0.1, 0.2])
            .expect("1x2 bias has exactly 2 elements");

        graph.add_initializer("weights".to_string(), weights);
        graph.add_initializer("bias".to_string(), bias);

        // MatMul node
        let matmul_node = Node::new(
            "matmul".to_string(),
            "MatMul".to_string(),
            vec!["input".to_string(), "weights".to_string()],
            vec!["matmul_output".to_string()],
        );
        graph.add_node(matmul_node);

        // Add node (bias)
        let add_node = Node::new(
            "add_bias".to_string(),
            "Add".to_string(),
            vec!["matmul_output".to_string(), "bias".to_string()],
            vec!["output".to_string()],
        );
        graph.add_node(add_node);

        graph
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a node with exactly one input and one output tensor.
    fn unary(name: &str, op: &str, input: &str, output: &str) -> Node {
        Node::new(
            name.to_string(),
            op.to_string(),
            vec![input.to_string()],
            vec![output.to_string()],
        )
    }

    #[test]
    fn test_node_creation() {
        let mut node = Node::new(
            "test_node".to_string(),
            "Add".to_string(),
            vec!["input1".to_string(), "input2".to_string()],
            vec!["output".to_string()],
        );

        assert_eq!(node.name, "test_node");
        assert_eq!(node.op_type, "Add");
        assert_eq!(node.inputs.len(), 2);
        assert_eq!(node.outputs.len(), 1);

        node.add_attribute("axis", "1");
        assert_eq!(node.attributes.get("axis"), Some(&"1".to_string()));
    }

    #[test]
    fn test_tensor_spec() {
        let spec = TensorSpec::new("test_tensor".to_string(), vec![Some(2), Some(3), None]);

        let matching_tensor = Tensor::zeros(&[2, 3, 5]); // 5 is dynamic
        let non_matching_tensor = Tensor::zeros(&[2, 4, 5]); // Wrong second dimension

        assert!(spec.matches_tensor(&matching_tensor));
        assert!(!spec.matches_tensor(&non_matching_tensor));
    }

    #[test]
    fn test_graph_creation() {
        let mut graph = Graph::new("test_graph".to_string());

        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1), Some(3)]));
        graph.add_output(TensorSpec::new(
            "output".to_string(),
            vec![Some(1), Some(1)],
        ));

        let node = Node::new(
            "relu".to_string(),
            "Relu".to_string(),
            vec!["input".to_string()],
            vec!["output".to_string()],
        );
        graph.add_node(node);

        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
        assert_eq!(graph.input_names(), vec!["input"]);
        assert_eq!(graph.output_names(), vec!["output"]);
    }

    #[test]
    fn test_graph_validation_success() {
        let mut graph = Graph::new("valid_graph".to_string());

        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1), Some(3)]));
        graph.add_output(TensorSpec::new(
            "output".to_string(),
            vec![Some(1), Some(3)],
        ));

        let node = Node::new(
            "relu".to_string(),
            "Relu".to_string(),
            vec!["input".to_string()],
            vec!["output".to_string()],
        );
        graph.add_node(node);

        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_graph_validation_failure() {
        let mut graph = Graph::new("invalid_graph".to_string());

        // Missing input declaration
        graph.add_output(TensorSpec::new(
            "output".to_string(),
            vec![Some(1), Some(3)],
        ));

        let node = Node::new(
            "relu".to_string(),
            "Relu".to_string(),
            vec!["missing_input".to_string()], // References unknown input
            vec!["output".to_string()],
        );
        graph.add_node(node);

        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_validate_accepts_nodes_listed_out_of_order() {
        let mut graph = Graph::new("out_of_order".to_string());
        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1)]));
        graph.add_output(TensorSpec::new("output".to_string(), vec![Some(1)]));
        // The consumer is listed before its producer.
        graph.add_node(unary("second", "Relu", "hidden", "output"));
        graph.add_node(unary("first", "Relu", "input", "hidden"));

        assert!(graph.validate().is_ok());
        assert_eq!(graph.topological_sort().unwrap(), vec![1, 0]);
    }

    #[test]
    fn test_validate_rejects_duplicate_node_names() {
        let mut graph = Graph::new("duplicates".to_string());
        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1)]));
        graph.add_node(unary("relu", "Relu", "input", "a"));
        graph.add_node(unary("relu", "Relu", "input", "b"));

        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_validate_rejects_unknown_operator() {
        let mut graph = Graph::new("unknown_op".to_string());
        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1)]));
        graph.add_output(TensorSpec::new("output".to_string(), vec![Some(1)]));
        graph.add_node(unary("weird", "NotARealOperator", "input", "output"));

        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_validate_rejects_unproduced_output() {
        let mut graph = Graph::new("ghost_output".to_string());
        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1)]));
        graph.add_output(TensorSpec::new("ghost".to_string(), vec![Some(1)]));
        graph.add_node(unary("relu", "Relu", "input", "output"));

        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_simple_linear_graph() {
        let graph = Graph::create_simple_linear();

        assert!(graph.validate().is_ok());
        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
        assert_eq!(graph.initializers.len(), 2);

        // Test topological sort
        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 2);
        // MatMul should come before Add
        let matmul_pos = order
            .iter()
            .position(|&i| graph.nodes[i].op_type == "MatMul")
            .unwrap();
        let add_pos = order
            .iter()
            .position(|&i| graph.nodes[i].op_type == "Add")
            .unwrap();
        assert!(matmul_pos < add_pos);
    }

    #[test]
    fn test_graph_print_functions() {
        let graph = Graph::create_simple_linear();

        // Test that print_graph doesn't panic
        graph.print_graph();

        // Test DOT format generation
        let dot_content = graph.to_dot();
        assert!(dot_content.contains("digraph G {"));
        assert!(dot_content.contains("input"));
        assert!(dot_content.contains("output"));
        assert!(dot_content.contains("MatMul"));
        assert!(dot_content.contains("Add"));
        assert!(dot_content.contains("->"));
        assert!(dot_content.ends_with("}\n"));
    }

    #[test]
    fn test_to_dot_lists_every_edge() {
        let dot = Graph::create_simple_linear().to_dot();
        for edge in [
            "\"input\" -> \"matmul\"",
            "\"weights\" -> \"matmul\"",
            "\"matmul\" -> \"matmul_output\"",
            "\"matmul_output\" -> \"add_bias\"",
            "\"bias\" -> \"add_bias\"",
            "\"add_bias\" -> \"output\"",
        ] {
            assert!(dot.contains(edge), "missing edge {edge}");
        }
    }

    #[test]
    fn test_topological_sort() {
        let mut graph = Graph::new("test_topo".to_string());

        // Create a simple chain: input -> relu -> sigmoid -> output
        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1), Some(3)]));
        graph.add_output(TensorSpec::new(
            "output".to_string(),
            vec![Some(1), Some(3)],
        ));

        let relu_node = Node::new(
            "relu".to_string(),
            "Relu".to_string(),
            vec!["input".to_string()],
            vec!["relu_out".to_string()],
        );
        graph.add_node(relu_node);

        let sigmoid_node = Node::new(
            "sigmoid".to_string(),
            "Sigmoid".to_string(),
            vec!["relu_out".to_string()],
            vec!["output".to_string()],
        );
        graph.add_node(sigmoid_node);

        let order = graph.topological_sort().unwrap();
        assert_eq!(order, vec![0, 1]); // relu first, then sigmoid
    }

    #[test]
    fn test_topological_sort_detects_cycle() {
        let mut graph = Graph::new("cycle".to_string());
        graph.add_node(unary("a", "Relu", "b_out", "a_out"));
        graph.add_node(unary("b", "Relu", "a_out", "b_out"));

        assert!(graph.topological_sort().is_err());
        assert!(graph.topological_levels().is_err());
    }

    #[test]
    fn test_topological_levels_groups_independent_nodes() {
        let mut graph = Graph::new("diamond".to_string());
        graph.add_input(TensorSpec::new("input".to_string(), vec![Some(1)]));
        graph.add_output(TensorSpec::new("output".to_string(), vec![Some(1)]));
        graph.add_node(unary("left", "Relu", "input", "left_out"));
        graph.add_node(unary("right", "Relu", "input", "right_out"));
        graph.add_node(Node::new(
            "join".to_string(),
            "Add".to_string(),
            vec!["left_out".to_string(), "right_out".to_string()],
            vec!["output".to_string()],
        ));

        assert_eq!(graph.topological_levels().unwrap(), vec![vec![0, 1], vec![2]]);
        assert!(Graph::new("empty".to_string())
            .topological_levels()
            .unwrap()
            .is_empty());
    }
}