ipfrs_tensorlogic/computation_graph.rs

//! Computation graph storage and execution
//!
//! This module provides:
//! - IPLD schema for computation graphs
//! - Graph serialization and deserialization
//! - Graph optimization (CSE, constant folding, fusion)
//! - Lazy evaluation with memoization
//! - Parallel execution support
//! - Streaming execution with backpressure
//! - Distributed graph execution

use ipfrs_core::Cid;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::{Arc, Mutex};
use thiserror::Error;

/// Errors that can occur during graph operations
#[derive(Debug, Error)]
pub enum GraphError {
    #[error("Node not found: {0}")]
    NodeNotFound(String),

    #[error("Circular dependency detected")]
    CircularDependency,

    #[error("Invalid graph structure: {0}")]
    InvalidGraph(String),

    #[error("Type mismatch: expected {expected}, got {actual}")]
    TypeMismatch { expected: String, actual: String },

    #[error("Shape mismatch: {0}")]
    ShapeMismatch(String),

    #[error("Missing input: {0}")]
    MissingInput(String),

    #[error("Execution error: {0}")]
    ExecutionError(String),
}

/// Tensor operation types
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TensorOp {
    /// Input placeholder
    Input { name: String },

    /// Constant tensor
    Constant { value_cid: String },

    /// Matrix multiplication
    MatMul,

    /// Element-wise addition
    Add,

    /// Element-wise multiplication
    Mul,

    /// Element-wise subtraction
    Sub,

    /// Element-wise division
    Div,

    /// Einsum operation with subscript notation
    Einsum { subscripts: String },

    /// Reshape operation
    Reshape { shape: Vec<i64> },

    /// Transpose operation
    Transpose { axes: Vec<usize> },

    /// Reduce sum along axes
    ReduceSum { axes: Vec<usize>, keepdims: bool },

    /// Reduce mean along axes
    ReduceMean { axes: Vec<usize>, keepdims: bool },

    /// Activation: ReLU
    ReLU,

    /// Activation: Tanh
    Tanh,

    /// Activation: Sigmoid
    Sigmoid,

    /// Activation: GELU (Gaussian Error Linear Unit)
    GELU,

    /// Activation: Softmax along axis
    Softmax { axis: i64 },

    /// Layer normalization
    LayerNorm {
        normalized_shape: Vec<usize>,
        eps: f64,
    },

    /// Batch normalization
    BatchNorm { eps: f64, momentum: f64 },

    /// Dropout (training mode)
    Dropout { p: f64 },

    /// Element-wise exponential
    Exp,

    /// Element-wise logarithm
    Log,

    /// Element-wise power
    Pow { exponent: f64 },

    /// Element-wise square root
    Sqrt,

    /// Concatenate tensors along axis
    Concat { axis: usize },

    /// Split tensor along axis
    Split { axis: usize, sections: Vec<usize> },

    /// Gather elements along axis
    Gather { axis: usize },

    /// Scatter elements along axis
    Scatter { axis: usize },

    /// Slice tensor
    Slice {
        start: Vec<i64>,
        end: Vec<i64>,
        strides: Vec<i64>,
    },

    /// Pad tensor
    Pad {
        padding: Vec<(usize, usize)>,
        mode: String,
    },

    // Fused operations for performance
    /// Fused MatMul + Add (common in linear layers)
    FusedLinear,

    /// Fused Add + ReLU
    FusedAddReLU,

    /// Fused BatchNorm + ReLU
    FusedBatchNormReLU { eps: f64, momentum: f64 },

    /// Fused LayerNorm + Dropout
    FusedLayerNormDropout {
        normalized_shape: Vec<usize>,
        eps: f64,
        dropout_p: f64,
    },
}

impl TensorOp {
    /// Get the number of inputs required by this operation
    pub fn num_inputs(&self) -> usize {
        match self {
            TensorOp::Input { .. } | TensorOp::Constant { .. } => 0,
            TensorOp::ReLU
            | TensorOp::Tanh
            | TensorOp::Sigmoid
            | TensorOp::GELU
            | TensorOp::Softmax { .. }
            | TensorOp::LayerNorm { .. }
            | TensorOp::BatchNorm { .. }
            | TensorOp::Dropout { .. }
            | TensorOp::Exp
            | TensorOp::Log
            | TensorOp::Pow { .. }
            | TensorOp::Sqrt
            | TensorOp::Reshape { .. }
            | TensorOp::Transpose { .. }
            | TensorOp::ReduceSum { .. }
            | TensorOp::ReduceMean { .. }
            | TensorOp::Slice { .. }
            | TensorOp::Pad { .. } => 1,
            TensorOp::MatMul
            | TensorOp::Add
            | TensorOp::Mul
            | TensorOp::Sub
            | TensorOp::Div
            | TensorOp::Gather { .. }
            | TensorOp::Scatter { .. }
            | TensorOp::FusedAddReLU => 2,
            TensorOp::Einsum { .. } => 2, // Simplified for now
            TensorOp::Concat { .. } | TensorOp::Split { .. } => 1, // Variadic, but simplified
            TensorOp::FusedLinear => 3,   // input, weight, bias
            TensorOp::FusedBatchNormReLU { .. } => 1,
            TensorOp::FusedLayerNormDropout { .. } => 1,
        }
    }

    /// Check if this is a pure operation (no side effects)
    pub fn is_pure(&self) -> bool {
        true // All current ops are pure
    }

    /// Infer output shape from input shapes
    pub fn infer_output_shape(
        &self,
        input_shapes: &[Vec<usize>],
    ) -> Result<Vec<usize>, GraphError> {
        match self {
            TensorOp::Input { .. } | TensorOp::Constant { .. } => Err(GraphError::InvalidGraph(
                "Cannot infer shape for input/constant nodes without explicit shape".to_string(),
            )),
            // Unary element-wise operations preserve shape
            TensorOp::ReLU
            | TensorOp::Tanh
            | TensorOp::Sigmoid
            | TensorOp::GELU
            | TensorOp::Exp
            | TensorOp::Log
            | TensorOp::Sqrt
            | TensorOp::Dropout { .. } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                Ok(input_shapes[0].clone())
            }
            // Binary element-wise operations (broadcasting rules apply)
            TensorOp::Add | TensorOp::Mul | TensorOp::Sub | TensorOp::Div => {
                if input_shapes.len() < 2 {
                    return Err(GraphError::MissingInput(
                        "Binary operation requires 2 inputs".to_string(),
                    ));
                }
                Self::broadcast_shapes(&input_shapes[0], &input_shapes[1])
            }
            TensorOp::MatMul => {
                if input_shapes.len() < 2 {
                    return Err(GraphError::MissingInput(
                        "MatMul requires 2 inputs".to_string(),
                    ));
                }
                let a = &input_shapes[0];
                let b = &input_shapes[1];
                if a.len() < 2 || b.len() < 2 {
                    return Err(GraphError::ShapeMismatch(
                        "MatMul requires at least 2D tensors".to_string(),
                    ));
                }
                let m = a[a.len() - 2];
                let k1 = a[a.len() - 1];
                let k2 = b[b.len() - 2];
                let n = b[b.len() - 1];
                if k1 != k2 {
                    return Err(GraphError::ShapeMismatch(format!(
                        "MatMul dimension mismatch: {} vs {}",
                        k1, k2
                    )));
                }
                let mut result = a[..a.len() - 2].to_vec();
                result.push(m);
                result.push(n);
                Ok(result)
            }
            TensorOp::Reshape { shape } => {
                let new_shape: Vec<usize> = shape.iter().map(|&s| s as usize).collect();
                Ok(new_shape)
            }
            TensorOp::Transpose { axes } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                let input_shape = &input_shapes[0];
                if axes.len() != input_shape.len() {
                    return Err(GraphError::ShapeMismatch(
                        "Transpose axes must match input dimensions".to_string(),
                    ));
                }
                let mut output_shape = vec![0; input_shape.len()];
                for (i, &axis) in axes.iter().enumerate() {
                    output_shape[i] = input_shape[axis];
                }
                Ok(output_shape)
            }
            TensorOp::ReduceSum { axes, keepdims } | TensorOp::ReduceMean { axes, keepdims } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                let input_shape = &input_shapes[0];
                if *keepdims {
                    let mut output_shape = input_shape.clone();
                    for &axis in axes {
                        if axis < output_shape.len() {
                            output_shape[axis] = 1;
                        }
                    }
                    Ok(output_shape)
                } else {
                    let output_shape: Vec<usize> = input_shape
                        .iter()
                        .enumerate()
                        .filter(|(i, _)| !axes.contains(i))
                        .map(|(_, &dim)| dim)
                        .collect();
                    Ok(output_shape)
                }
            }
            TensorOp::Softmax { .. } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                Ok(input_shapes[0].clone())
            }
            TensorOp::LayerNorm { .. }
            | TensorOp::BatchNorm { .. }
            | TensorOp::Pow { .. }
            | TensorOp::FusedBatchNormReLU { .. }
            | TensorOp::FusedLayerNormDropout { .. } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                Ok(input_shapes[0].clone())
            }
            TensorOp::Concat { axis } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "Concat requires at least one input".to_string(),
                    ));
                }
                let mut output_shape = input_shapes[0].clone();
                if *axis >= output_shape.len() {
                    return Err(GraphError::ShapeMismatch("Invalid concat axis".to_string()));
                }
                for shape in &input_shapes[1..] {
                    if shape.len() != output_shape.len() {
                        return Err(GraphError::ShapeMismatch(
                            "Concat inputs must have same rank".to_string(),
                        ));
                    }
                    output_shape[*axis] += shape[*axis];
                }
                Ok(output_shape)
            }
            TensorOp::Slice { start, end, .. } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                let input_shape = &input_shapes[0];
                let output_shape: Vec<usize> = start
                    .iter()
                    .zip(end.iter())
                    .map(|(&s, &e)| (e - s).max(0) as usize)
                    .collect();
                if output_shape.len() != input_shape.len() {
                    return Err(GraphError::ShapeMismatch(
                        "Slice dimensions must match input".to_string(),
                    ));
                }
                Ok(output_shape)
            }
            TensorOp::Pad { padding, .. } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                let input_shape = &input_shapes[0];
                let output_shape: Vec<usize> = input_shape
                    .iter()
                    .zip(padding.iter())
                    .map(|(&dim, &(pad_before, pad_after))| dim + pad_before + pad_after)
                    .collect();
                Ok(output_shape)
            }
            TensorOp::FusedLinear => {
                if input_shapes.len() < 3 {
                    return Err(GraphError::MissingInput(
                        "FusedLinear requires 3 inputs".to_string(),
                    ));
                }
                // Similar to MatMul + Add
                let a = &input_shapes[0];
                let b = &input_shapes[1];
                if a.len() < 2 || b.len() < 2 {
                    return Err(GraphError::ShapeMismatch(
                        "Linear requires at least 2D tensors".to_string(),
                    ));
                }
                let m = a[a.len() - 2];
                let n = b[b.len() - 1];
                let mut result = a[..a.len() - 2].to_vec();
                result.push(m);
                result.push(n);
                Ok(result)
            }
            TensorOp::FusedAddReLU => {
                if input_shapes.len() < 2 {
                    return Err(GraphError::MissingInput(
                        "FusedAddReLU requires 2 inputs".to_string(),
                    ));
                }
                Self::broadcast_shapes(&input_shapes[0], &input_shapes[1])
            }
            _ => {
                // For operations not yet implemented, preserve first input shape
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                Ok(input_shapes[0].clone())
            }
        }
    }

    /// Broadcast two shapes according to NumPy broadcasting rules
    fn broadcast_shapes(a: &[usize], b: &[usize]) -> Result<Vec<usize>, GraphError> {
        let mut result = Vec::new();
        let max_len = a.len().max(b.len());

        for i in 0..max_len {
            let dim_a = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
            let dim_b = if i < b.len() { b[b.len() - 1 - i] } else { 1 };

            if dim_a == dim_b {
                result.push(dim_a);
            } else if dim_a == 1 {
                result.push(dim_b);
            } else if dim_b == 1 {
                result.push(dim_a);
            } else {
                return Err(GraphError::ShapeMismatch(format!(
                    "Cannot broadcast shapes: {:?} and {:?}",
                    a, b
                )));
            }
        }

        result.reverse();
        Ok(result)
    }
}
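// A minimal usage sketch of `TensorOp::infer_output_shape` (test module and shape
// values are illustrative, not part of the original source): it checks MatMul shape
// propagation and NumPy-style broadcasting for element-wise Add.
#[cfg(test)]
mod tensor_op_shape_examples {
    use super::*;

    #[test]
    fn matmul_and_broadcast_shapes() {
        // (2, 3) x (3, 4) -> (2, 4)
        let out = TensorOp::MatMul
            .infer_output_shape(&[vec![2, 3], vec![3, 4]])
            .unwrap();
        assert_eq!(out, vec![2, 4]);

        // (2, 1, 4) + (3, 4) broadcasts to (2, 3, 4)
        let out = TensorOp::Add
            .infer_output_shape(&[vec![2, 1, 4], vec![3, 4]])
            .unwrap();
        assert_eq!(out, vec![2, 3, 4]);

        // Mismatched inner dimensions are rejected.
        assert!(TensorOp::MatMul
            .infer_output_shape(&[vec![2, 3], vec![5, 4]])
            .is_err());
    }
}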

/// Node in the computation graph
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Unique node ID
    pub id: String,

    /// Operation type
    pub op: TensorOp,

    /// Input node IDs
    pub inputs: Vec<String>,

    /// Output shape (if known)
    pub output_shape: Option<Vec<usize>>,

    /// Metadata
    pub metadata: HashMap<String, String>,
}

impl GraphNode {
    /// Create a new graph node
    pub fn new(id: String, op: TensorOp) -> Self {
        Self {
            id,
            op,
            inputs: Vec::new(),
            output_shape: None,
            metadata: HashMap::new(),
        }
    }

    /// Add an input node
    pub fn add_input(mut self, input_id: String) -> Self {
        self.inputs.push(input_id);
        self
    }

    /// Set output shape
    pub fn with_output_shape(mut self, shape: Vec<usize>) -> Self {
        self.output_shape = Some(shape);
        self
    }

    /// Add metadata
    pub fn add_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }
}
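// A small sketch of the builder-style `GraphNode` API above. Node IDs and the
// metadata key are arbitrary example values, not part of the original source.
#[cfg(test)]
mod graph_node_builder_example {
    use super::*;

    #[test]
    fn builder_chain() {
        let node = GraphNode::new("h1".into(), TensorOp::MatMul)
            .add_input("x".into())
            .add_input("w".into())
            .with_output_shape(vec![8, 16])
            .add_metadata("layer".into(), "hidden_1".into());

        assert_eq!(node.inputs, vec!["x", "w"]);
        assert_eq!(node.output_shape, Some(vec![8, 16]));
        assert_eq!(node.metadata.get("layer").map(String::as_str), Some("hidden_1"));
    }
}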

/// Computation graph
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputationGraph {
    /// Graph nodes (node ID -> node)
    pub nodes: HashMap<String, GraphNode>,

    /// Input node IDs
    pub inputs: Vec<String>,

    /// Output node IDs
    pub outputs: Vec<String>,

    /// Graph metadata
    pub metadata: HashMap<String, String>,

    /// Graph CID (if stored in IPFS)
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(serialize_with = "serialize_optional_cid")]
    #[serde(deserialize_with = "deserialize_optional_cid")]
    pub cid: Option<Cid>,
}

impl ComputationGraph {
    /// Create a new empty computation graph
    pub fn new() -> Self {
        Self {
            nodes: HashMap::new(),
            inputs: Vec::new(),
            outputs: Vec::new(),
            metadata: HashMap::new(),
            cid: None,
        }
    }

    /// Add a node to the graph
    pub fn add_node(&mut self, node: GraphNode) -> Result<(), GraphError> {
        let id = node.id.clone();

        // Validate inputs exist
        for input_id in &node.inputs {
            if !self.nodes.contains_key(input_id) && !self.inputs.contains(input_id) {
                return Err(GraphError::NodeNotFound(input_id.clone()));
            }
        }

        self.nodes.insert(id, node);
        Ok(())
    }

    /// Mark a node as an input
    pub fn mark_input(&mut self, node_id: String) {
        if !self.inputs.contains(&node_id) {
            self.inputs.push(node_id);
        }
    }

    /// Mark a node as an output
    pub fn mark_output(&mut self, node_id: String) {
        if !self.outputs.contains(&node_id) {
            self.outputs.push(node_id);
        }
    }

    /// Get topological order of nodes
    pub fn topological_sort(&self) -> Result<Vec<String>, GraphError> {
        let mut in_degree: HashMap<String, usize> = HashMap::new();
        let mut adj_list: HashMap<String, Vec<String>> = HashMap::new();

        // Build adjacency list and compute in-degrees
        for (node_id, node) in &self.nodes {
            in_degree.entry(node_id.clone()).or_insert(0);
            adj_list.entry(node_id.clone()).or_default();

            for input_id in &node.inputs {
                if self.nodes.contains_key(input_id) {
                    *in_degree.entry(node_id.clone()).or_insert(0) += 1;
                    adj_list
                        .entry(input_id.clone())
                        .or_default()
                        .push(node_id.clone());
                }
            }
        }

        // Kahn's algorithm
        let mut queue: VecDeque<String> = in_degree
            .iter()
            .filter(|(_, &deg)| deg == 0)
            .map(|(id, _)| id.clone())
            .collect();

        let mut result = Vec::new();

        while let Some(node_id) = queue.pop_front() {
            result.push(node_id.clone());

            if let Some(neighbors) = adj_list.get(&node_id) {
                for neighbor in neighbors {
                    if let Some(deg) = in_degree.get_mut(neighbor) {
                        *deg -= 1;
                        if *deg == 0 {
                            queue.push_back(neighbor.clone());
                        }
                    }
                }
            }
        }

        if result.len() != self.nodes.len() {
            return Err(GraphError::CircularDependency);
        }

        Ok(result)
    }

    /// Extract a subgraph containing only the specified output nodes
    pub fn extract_subgraph(&self, output_ids: &[String]) -> Result<ComputationGraph, GraphError> {
        let mut subgraph = ComputationGraph::new();
        let mut visited = HashSet::new();
        let mut queue: VecDeque<String> = output_ids.iter().cloned().collect();

        // Backward DFS to find all dependencies
        while let Some(node_id) = queue.pop_front() {
            if visited.contains(&node_id) {
                continue;
            }

            visited.insert(node_id.clone());

            if let Some(node) = self.nodes.get(&node_id) {
                for input_id in &node.inputs {
                    if !visited.contains(input_id) {
                        queue.push_back(input_id.clone());
                    }
                }
            }
        }

        // Set inputs first (before adding nodes that depend on them)
        for input_id in &self.inputs {
            if visited.contains(input_id) {
                subgraph.mark_input(input_id.clone());
            }
        }

        // Copy relevant nodes
        for node_id in &visited {
            if let Some(node) = self.nodes.get(node_id) {
                subgraph.nodes.insert(node_id.clone(), node.clone());
            }
        }

        // Set outputs
        for output_id in output_ids {
            subgraph.mark_output(output_id.clone());
        }

        Ok(subgraph)
    }

    /// Optimize the graph using common subexpression elimination
    pub fn optimize_cse(&mut self) -> usize {
        let mut optimized_count = 0;
        let mut expr_map: HashMap<String, String> = HashMap::new();

        if let Ok(sorted) = self.topological_sort() {
            for node_id in sorted {
                if let Some(node) = self.nodes.get(&node_id) {
                    // Create expression signature
                    let signature = format!("{:?}:{:?}", node.op, node.inputs);

                    if let Some(existing_id) = expr_map.get(&signature) {
                        // Found duplicate, replace references
                        for other_node in self.nodes.values_mut() {
                            for input in &mut other_node.inputs {
                                if input == &node_id {
                                    *input = existing_id.clone();
                                    optimized_count += 1;
                                }
                            }
                        }
                    } else {
                        expr_map.insert(signature, node_id.clone());
                    }
                }
            }
        }

        optimized_count
    }

    /// Count the number of nodes
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Get the number of inputs
    pub fn input_count(&self) -> usize {
        self.inputs.len()
    }

    /// Get the number of outputs
    pub fn output_count(&self) -> usize {
        self.outputs.len()
    }

    /// Propagate shapes through the graph (shape inference)
    /// This method performs a topological traversal and infers output shapes for all nodes
    pub fn propagate_shapes(&mut self) -> Result<(), GraphError> {
        // Get topological order
        let topo_order = self.topological_sort()?;

        // Propagate shapes in topological order
        for node_id in topo_order {
            if let Some(node) = self.nodes.get(&node_id).cloned() {
                // Skip if shape is already known
                if node.output_shape.is_some() {
                    continue;
                }

                // Collect input shapes
                let mut input_shapes = Vec::new();
                for input_id in &node.inputs {
                    if let Some(input_node) = self.nodes.get(input_id) {
                        if let Some(shape) = &input_node.output_shape {
                            input_shapes.push(shape.clone());
                        } else {
                            return Err(GraphError::InvalidGraph(format!(
                                "Input node {} has no shape information",
                                input_id
                            )));
                        }
                    } else {
                        return Err(GraphError::NodeNotFound(input_id.clone()));
                    }
                }

                // Infer output shape
                let output_shape = node.op.infer_output_shape(&input_shapes)?;

                // Update node with inferred shape
                if let Some(node_mut) = self.nodes.get_mut(&node_id) {
                    node_mut.output_shape = Some(output_shape);
                }
            }
        }

        Ok(())
    }

    /// Validate graph structure and shapes
    pub fn validate(&self) -> Result<(), GraphError> {
        // Check all inputs exist
        for input_id in &self.inputs {
            if !self.nodes.contains_key(input_id) {
                return Err(GraphError::NodeNotFound(format!(
                    "Input node {} not found",
                    input_id
                )));
            }
        }

        // Check all outputs exist
        for output_id in &self.outputs {
            if !self.nodes.contains_key(output_id) {
                return Err(GraphError::NodeNotFound(format!(
                    "Output node {} not found",
                    output_id
                )));
            }
        }

        // Check all node inputs exist
        for (node_id, node) in &self.nodes {
            for input_id in &node.inputs {
                if !self.nodes.contains_key(input_id) && !self.inputs.contains(input_id) {
                    return Err(GraphError::NodeNotFound(format!(
                        "Node {} references non-existent input {}",
                        node_id, input_id
                    )));
                }
            }

            // Validate expected number of inputs
            let expected_inputs = node.op.num_inputs();
            if node.inputs.len() != expected_inputs && expected_inputs > 0 {
                return Err(GraphError::InvalidGraph(format!(
                    "Node {} expects {} inputs but has {}",
                    node_id,
                    expected_inputs,
                    node.inputs.len()
                )));
            }
        }

        // Check for cycles
        self.topological_sort().map(|_| ())
    }

    /// Get memory footprint estimate for the graph
    pub fn estimate_memory(&self) -> usize {
        let mut total_bytes = 0;

        for node in self.nodes.values() {
            if let Some(shape) = &node.output_shape {
                // Assume f32 (4 bytes) for simplicity
                let elements: usize = shape.iter().product();
                total_bytes += elements * 4;
            }
        }

        total_bytes
    }
}
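// A minimal end-to-end sketch of the graph API above (node IDs and shapes are
// arbitrary example values): build a two-node graph, validate it, run shape
// inference, and check the topological order.
#[cfg(test)]
mod computation_graph_usage {
    use super::*;

    #[test]
    fn build_infer_and_sort() -> Result<(), GraphError> {
        let mut g = ComputationGraph::new();

        // Input placeholder with a known shape.
        g.add_node(
            GraphNode::new("x".into(), TensorOp::Input { name: "x".into() })
                .with_output_shape(vec![2, 3]),
        )?;
        g.mark_input("x".into());

        // y = ReLU(x)
        g.add_node(GraphNode::new("y".into(), TensorOp::ReLU).add_input("x".into()))?;
        g.mark_output("y".into());

        g.validate()?;
        g.propagate_shapes()?;
        assert_eq!(g.nodes["y"].output_shape, Some(vec![2, 3]));

        let order = g.topological_sort()?;
        assert_eq!(order.len(), 2);
        Ok(())
    }
}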

impl Default for ComputationGraph {
    fn default() -> Self {
        Self::new()
    }
}

/// Graph optimizer for applying optimizations
pub struct GraphOptimizer;

impl GraphOptimizer {
    /// Apply constant folding
    pub fn constant_folding(graph: &mut ComputationGraph) -> Result<usize, GraphError> {
        let mut folded_count = 0;

        // Simplified constant folding - in a real implementation,
        // we would evaluate constant sub-expressions
        let sorted = graph.topological_sort()?;

        for node_id in sorted {
            if let Some(node) = graph.nodes.get(&node_id) {
                // Check if all inputs are constants
                let all_const = node.inputs.iter().all(|input_id| {
                    graph
                        .nodes
                        .get(input_id)
                        .map(|n| matches!(n.op, TensorOp::Constant { .. }))
                        .unwrap_or(false)
                });

                if all_const && node.op.is_pure() {
                    // In a real implementation, we would evaluate this
                    // and replace with a constant
                    folded_count += 1;
                }
            }
        }

        Ok(folded_count)
    }

    /// Fuse consecutive operations where possible
    pub fn fusion(graph: &mut ComputationGraph) -> Result<usize, GraphError> {
        let mut fused_count = 0;
        let mut nodes_to_remove = HashSet::new();
        let mut new_nodes: HashMap<String, GraphNode> = HashMap::new();

        // Build a map of node outputs to their consumers
        let mut consumers: HashMap<String, Vec<String>> = HashMap::new();
        for (node_id, node) in &graph.nodes {
            for input in &node.inputs {
                consumers
                    .entry(input.clone())
                    .or_default()
                    .push(node_id.clone());
            }
        }

        // Pattern 1: MatMul + Add -> FusedLinear
        for (node_id, node) in &graph.nodes {
            if let TensorOp::Add = node.op {
                if node.inputs.len() == 2 {
                    // Check if one of the inputs is a MatMul
                    for input_id in &node.inputs {
                        if let Some(input_node) = graph.nodes.get(input_id) {
                            if matches!(input_node.op, TensorOp::MatMul) {
                                // Only fuse if the MatMul has a single consumer
                                if let Some(input_consumers) = consumers.get(input_id) {
                                    if input_consumers.len() == 1
                                        && !nodes_to_remove.contains(node_id)
                                    {
                                        // Create fused node
                                        let fused_id = format!("{}_fused", node_id);
                                        let fused_node = GraphNode {
                                            id: fused_id.clone(),
                                            op: TensorOp::FusedLinear,
                                            inputs: vec![
                                                input_node.inputs[0].clone(),
                                                input_node.inputs[1].clone(),
                                                node.inputs
                                                    .iter()
                                                    .find(|&id| id != input_id)
                                                    .unwrap()
                                                    .clone(),
                                            ],
                                            output_shape: node.output_shape.clone(),
                                            metadata: HashMap::new(),
                                        };
                                        new_nodes.insert(fused_id, fused_node);
                                        nodes_to_remove.insert(node_id.clone());
                                        nodes_to_remove.insert(input_id.clone());
                                        fused_count += 1;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        // Pattern 2: Add + ReLU -> FusedAddReLU
        for (node_id, node) in &graph.nodes {
            if let TensorOp::ReLU = node.op {
                if node.inputs.len() == 1 {
                    let input_id = &node.inputs[0];
                    if let Some(input_node) = graph.nodes.get(input_id) {
                        if matches!(input_node.op, TensorOp::Add) {
                            if let Some(input_consumers) = consumers.get(input_id) {
                                if input_consumers.len() == 1 && !nodes_to_remove.contains(node_id)
                                {
                                    let fused_id = format!("{}_fused", node_id);
                                    let fused_node = GraphNode {
                                        id: fused_id.clone(),
                                        op: TensorOp::FusedAddReLU,
                                        inputs: input_node.inputs.clone(),
                                        output_shape: node.output_shape.clone(),
                                        metadata: HashMap::new(),
                                    };
                                    new_nodes.insert(fused_id, fused_node);
                                    nodes_to_remove.insert(node_id.clone());
                                    nodes_to_remove.insert(input_id.clone());
                                    fused_count += 1;
                                }
                            }
                        }
                    }
                }
            }
        }

        // Pattern 3: BatchNorm + ReLU -> FusedBatchNormReLU
        for (node_id, node) in &graph.nodes {
            if let TensorOp::ReLU = node.op {
                if node.inputs.len() == 1 {
                    let input_id = &node.inputs[0];
                    if let Some(input_node) = graph.nodes.get(input_id) {
                        if let TensorOp::BatchNorm { eps, momentum } = &input_node.op {
                            if let Some(input_consumers) = consumers.get(input_id) {
                                if input_consumers.len() == 1 && !nodes_to_remove.contains(node_id)
                                {
                                    let fused_id = format!("{}_fused", node_id);
                                    let fused_node = GraphNode {
                                        id: fused_id.clone(),
                                        op: TensorOp::FusedBatchNormReLU {
                                            eps: *eps,
                                            momentum: *momentum,
                                        },
                                        inputs: input_node.inputs.clone(),
                                        output_shape: node.output_shape.clone(),
                                        metadata: HashMap::new(),
                                    };
                                    new_nodes.insert(fused_id, fused_node);
                                    nodes_to_remove.insert(node_id.clone());
                                    nodes_to_remove.insert(input_id.clone());
                                    fused_count += 1;
                                }
                            }
                        }
                    }
                }
            }
        }

        // Pattern 4: LayerNorm + Dropout -> FusedLayerNormDropout
        for (node_id, node) in &graph.nodes {
            if let TensorOp::Dropout { p } = &node.op {
                if node.inputs.len() == 1 {
                    let input_id = &node.inputs[0];
                    if let Some(input_node) = graph.nodes.get(input_id) {
                        if let TensorOp::LayerNorm {
                            normalized_shape,
                            eps,
                        } = &input_node.op
                        {
                            if let Some(input_consumers) = consumers.get(input_id) {
                                if input_consumers.len() == 1 && !nodes_to_remove.contains(node_id)
                                {
                                    let fused_id = format!("{}_fused", node_id);
                                    let fused_node = GraphNode {
                                        id: fused_id.clone(),
                                        op: TensorOp::FusedLayerNormDropout {
                                            normalized_shape: normalized_shape.clone(),
                                            eps: *eps,
                                            dropout_p: *p,
                                        },
                                        inputs: input_node.inputs.clone(),
                                        output_shape: node.output_shape.clone(),
                                        metadata: HashMap::new(),
                                    };
                                    new_nodes.insert(fused_id, fused_node);
                                    nodes_to_remove.insert(node_id.clone());
                                    nodes_to_remove.insert(input_id.clone());
                                    fused_count += 1;
                                }
                            }
                        }
                    }
                }
            }
        }

        // Apply the fusion by removing old nodes and adding new ones
        graph.nodes.retain(|id, _| !nodes_to_remove.contains(id));
        graph.nodes.extend(new_nodes);

        // Update references to removed nodes
        // Build a mapping from old node IDs to fused node IDs
        let mut replacements: HashMap<String, String> = HashMap::new();
        for removed_id in &nodes_to_remove {
            let fused_id = format!("{}_fused", removed_id);
            if graph.nodes.contains_key(&fused_id) {
                replacements.insert(removed_id.clone(), fused_id);
            }
        }

        // Apply replacements to node inputs
        let node_ids: Vec<String> = graph.nodes.keys().cloned().collect();
        for node_id in node_ids {
            if let Some(node) = graph.nodes.get_mut(&node_id) {
                for input in &mut node.inputs {
                    if let Some(replacement) = replacements.get(input) {
                        *input = replacement.clone();
                    }
                }
            }
        }

        // Also retarget graph outputs that pointed at a node that was fused away,
        // so later passes (e.g. dead-node removal) still see valid output IDs.
        for output in &mut graph.outputs {
            if let Some(replacement) = replacements.get(output) {
                *output = replacement.clone();
            }
        }

        Ok(fused_count)
    }

    /// Remove dead nodes (nodes not connected to outputs)
    pub fn remove_dead_nodes(graph: &mut ComputationGraph) -> Result<usize, GraphError> {
        let subgraph = graph.extract_subgraph(&graph.outputs.clone())?;
        let removed = graph.nodes.len() - subgraph.nodes.len();

        *graph = subgraph;

        Ok(removed)
    }

    /// Apply all optimizations
    pub fn optimize_all(graph: &mut ComputationGraph) -> Result<(), GraphError> {
        // Apply optimizations multiple times until convergence
        let mut prev_count = graph.node_count();

        for _ in 0..10 {
            Self::constant_folding(graph)?;
            graph.optimize_cse();
            Self::fusion(graph)?;
            Self::remove_dead_nodes(graph)?;

            let curr_count = graph.node_count();
            if curr_count == prev_count {
                break;
            }
            prev_count = curr_count;
        }

        Ok(())
    }
}
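// A minimal sketch of dead-node elimination (node IDs are arbitrary example values):
// a node that no graph output depends on is dropped by `remove_dead_nodes`.
#[cfg(test)]
mod graph_optimizer_example {
    use super::*;

    #[test]
    fn dead_nodes_are_removed() -> Result<(), GraphError> {
        let mut g = ComputationGraph::new();
        g.add_node(
            GraphNode::new("x".into(), TensorOp::Input { name: "x".into() })
                .with_output_shape(vec![4]),
        )?;
        g.mark_input("x".into());
        g.add_node(GraphNode::new("y".into(), TensorOp::ReLU).add_input("x".into()))?;
        g.add_node(GraphNode::new("unused".into(), TensorOp::Exp).add_input("x".into()))?;
        g.mark_output("y".into());

        let removed = GraphOptimizer::remove_dead_nodes(&mut g)?;
        assert_eq!(removed, 1);
        assert!(!g.nodes.contains_key("unused"));
        Ok(())
    }
}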

/// Lazy evaluation cache
#[derive(Debug, Clone)]
pub struct LazyCache {
    /// Cached results (node ID -> cached value)
    cache: HashMap<String, Vec<f32>>,

    /// Cache size limit (in number of entries)
    max_size: usize,

    /// Access order for LRU eviction
    access_order: VecDeque<String>,
}

impl LazyCache {
    /// Create a new lazy cache
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: HashMap::new(),
            max_size,
            access_order: VecDeque::new(),
        }
    }

    /// Get a cached value
    pub fn get(&mut self, node_id: &str) -> Option<&Vec<f32>> {
        if self.cache.contains_key(node_id) {
            // Update access order
            self.access_order.retain(|id| id != node_id);
            self.access_order.push_back(node_id.to_string());

            self.cache.get(node_id)
        } else {
            None
        }
    }

    /// Insert a value into the cache
    pub fn insert(&mut self, node_id: String, value: Vec<f32>) {
        // Evict if necessary
        while self.cache.len() >= self.max_size && !self.access_order.is_empty() {
            if let Some(evict_id) = self.access_order.pop_front() {
                self.cache.remove(&evict_id);
            }
        }

        self.cache.insert(node_id.clone(), value);
        self.access_order.push_back(node_id);
    }

    /// Clear the cache
    pub fn clear(&mut self) {
        self.cache.clear();
        self.access_order.clear();
    }

    /// Get cache size
    pub fn size(&self) -> usize {
        self.cache.len()
    }

    /// Get cache hit ratio (if we track statistics)
    pub fn hit_ratio(&self) -> f32 {
        // Simplified - would need counters for hits/misses
        0.0
    }
}
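// A small sketch of the LRU behaviour above (cache keys are arbitrary): once
// `max_size` is reached, the least recently used entry is evicted on insert.
#[cfg(test)]
mod lazy_cache_example {
    use super::*;

    #[test]
    fn lru_eviction() {
        let mut cache = LazyCache::new(2);
        cache.insert("a".into(), vec![1.0]);
        cache.insert("b".into(), vec![2.0]);

        // Touch "a" so that "b" becomes the least recently used entry.
        assert!(cache.get("a").is_some());

        cache.insert("c".into(), vec![3.0]);
        assert_eq!(cache.size(), 2);
        assert!(cache.get("b").is_none());
        assert!(cache.get("a").is_some());
    }
}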

/// Execution batch containing independent nodes that can run in parallel
#[derive(Debug, Clone)]
pub struct ExecutionBatch {
    /// Node IDs in this batch
    pub node_ids: Vec<String>,
    /// Batch level in the dependency graph
    pub level: usize,
}

impl ExecutionBatch {
    /// Create a new execution batch
    pub fn new(level: usize) -> Self {
        Self {
            node_ids: Vec::new(),
            level,
        }
    }

    /// Add a node to the batch
    pub fn add_node(&mut self, node_id: String) {
        self.node_ids.push(node_id);
    }

    /// Get the number of nodes in the batch
    pub fn size(&self) -> usize {
        self.node_ids.len()
    }
}

/// Batch scheduler for identifying independent nodes
pub struct BatchScheduler;

impl BatchScheduler {
    /// Create execution batches from a computation graph
    /// Returns batches where all nodes in each batch can execute in parallel
    pub fn create_batches(graph: &ComputationGraph) -> Result<Vec<ExecutionBatch>, GraphError> {
        let sorted = graph.topological_sort()?;
        let mut batches: Vec<ExecutionBatch> = Vec::new();
        let mut node_to_level: HashMap<String, usize> = HashMap::new();

        // Assign levels to each node based on dependencies
        for node_id in &sorted {
            let max_input_level = if let Some(node) = graph.nodes.get(node_id) {
                node.inputs
                    .iter()
                    .filter_map(|input_id| node_to_level.get(input_id))
                    .max()
                    .copied()
                    .unwrap_or(0)
            } else {
                0
            };

            let level = if graph.inputs.contains(node_id) {
                0
            } else {
                max_input_level + 1
            };

            node_to_level.insert(node_id.clone(), level);

            // Add node to appropriate batch
            while batches.len() <= level {
                batches.push(ExecutionBatch::new(batches.len()));
            }
            batches[level].add_node(node_id.clone());
        }

        Ok(batches)
    }
}
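// A minimal sketch of level-based batching (diamond-shaped example graph with
// arbitrary node IDs): nodes with no dependency on each other land in the same
// batch and can therefore run in parallel.
#[cfg(test)]
mod batch_scheduler_example {
    use super::*;

    #[test]
    fn independent_nodes_share_a_batch() -> Result<(), GraphError> {
        let mut g = ComputationGraph::new();
        g.add_node(GraphNode::new("a".into(), TensorOp::Input { name: "a".into() }))?;
        g.mark_input("a".into());
        g.add_node(GraphNode::new("b".into(), TensorOp::ReLU).add_input("a".into()))?;
        g.add_node(GraphNode::new("c".into(), TensorOp::Exp).add_input("a".into()))?;
        g.add_node(
            GraphNode::new("d".into(), TensorOp::Add)
                .add_input("b".into())
                .add_input("c".into()),
        )?;
        g.mark_output("d".into());

        let batches = BatchScheduler::create_batches(&g)?;
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[1].size(), 2); // "b" and "c" are independent
        Ok(())
    }
}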

/// Parallel executor for computation graphs
pub struct ParallelExecutor {
    /// Number of threads to use (None = use rayon default)
    thread_count: Option<usize>,
}

impl ParallelExecutor {
    /// Create a new parallel executor
    pub fn new(thread_count: Option<usize>) -> Self {
        Self { thread_count }
    }

    /// Execute a computation graph in parallel
    /// This is a simplified version that tracks execution order
    pub fn execute(&self, graph: &ComputationGraph) -> Result<Vec<String>, GraphError> {
        let batches = BatchScheduler::create_batches(graph)?;
        let mut executed = Vec::new();

        // Build a dedicated rayon thread pool if a thread count was requested;
        // otherwise `par_iter` falls back to the global pool.
        let pool = match self.thread_count {
            Some(threads) => Some(
                rayon::ThreadPoolBuilder::new()
                    .num_threads(threads)
                    .build()
                    .map_err(|e| GraphError::ExecutionError(e.to_string()))?,
            ),
            None => None,
        };

        // Execute each batch in parallel
        for batch in batches {
            let run = || -> Vec<String> {
                batch
                    .node_ids
                    .par_iter()
                    .map(|node_id| {
                        // In a real implementation, this would execute the actual operation.
                        // For now, we just return the node ID to track execution order.
                        node_id.clone()
                    })
                    .collect()
            };

            // Run inside the dedicated pool when one was built.
            let batch_results = match &pool {
                Some(pool) => pool.install(run),
                None => run(),
            };

            executed.extend(batch_results);
        }

        Ok(executed)
    }

    /// Execute a batch of nodes in parallel with a custom function
    pub fn execute_batch<F>(
        &self,
        batch: &ExecutionBatch,
        graph: &ComputationGraph,
        executor_fn: F,
    ) -> Result<Vec<(String, Vec<f32>)>, GraphError>
    where
        F: Fn(&GraphNode) -> Result<Vec<f32>, GraphError> + Sync + Send,
    {
        let results: Result<Vec<(String, Vec<f32>)>, GraphError> = batch
            .node_ids
            .par_iter()
            .map(|node_id| {
                let node = graph
                    .nodes
                    .get(node_id)
                    .ok_or_else(|| GraphError::NodeNotFound(node_id.clone()))?;
                let result = executor_fn(node)?;
                Ok((node_id.clone(), result))
            })
            .collect();

        results
    }
}
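// A small sketch of the parallel executor on a trivial graph (node IDs are
// arbitrary): every node is visited exactly once, batch by batch.
#[cfg(test)]
mod parallel_executor_example {
    use super::*;

    #[test]
    fn executes_every_node_once() -> Result<(), GraphError> {
        let mut g = ComputationGraph::new();
        g.add_node(GraphNode::new("a".into(), TensorOp::Input { name: "a".into() }))?;
        g.mark_input("a".into());
        g.add_node(GraphNode::new("b".into(), TensorOp::ReLU).add_input("a".into()))?;
        g.mark_output("b".into());

        // None = use rayon's global thread pool.
        let executor = ParallelExecutor::new(None);
        let executed = executor.execute(&g)?;
        assert_eq!(executed.len(), g.node_count());
        Ok(())
    }
}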

/// Stream chunk for streaming execution
#[derive(Debug, Clone)]
pub struct StreamChunk {
    /// Chunk data (node ID -> tensor data)
    pub data: HashMap<String, Vec<f32>>,
    /// Chunk index
    pub index: usize,
    /// Total number of chunks
    pub total_chunks: usize,
}

impl StreamChunk {
    /// Create a new stream chunk
    pub fn new(index: usize, total_chunks: usize) -> Self {
        Self {
            data: HashMap::new(),
            index,
            total_chunks,
        }
    }

    /// Add data for a node
    pub fn add_data(&mut self, node_id: String, data: Vec<f32>) {
        self.data.insert(node_id, data);
    }

    /// Check if this is the last chunk
    pub fn is_last(&self) -> bool {
        self.index == self.total_chunks - 1
    }
}

/// Streaming executor for processing data in chunks
pub struct StreamingExecutor {
    /// Chunk size (number of elements per chunk)
    chunk_size: usize,
    /// Maximum number of chunks to buffer
    max_buffer_size: usize,
    /// Current buffer
    buffer: Arc<Mutex<VecDeque<StreamChunk>>>,
}

impl StreamingExecutor {
    /// Create a new streaming executor
    pub fn new(chunk_size: usize, max_buffer_size: usize) -> Self {
        Self {
            chunk_size,
            max_buffer_size,
            buffer: Arc::new(Mutex::new(VecDeque::new())),
        }
    }

    /// Split input data into chunks
    pub fn create_chunks(&self, data: Vec<f32>, node_id: &str) -> Vec<StreamChunk> {
        let total_elements = data.len();
        let total_chunks = total_elements.div_ceil(self.chunk_size);
        let mut chunks = Vec::new();

        for (i, chunk_data) in data.chunks(self.chunk_size).enumerate() {
            let mut chunk = StreamChunk::new(i, total_chunks);
            chunk.add_data(node_id.to_string(), chunk_data.to_vec());
            chunks.push(chunk);
        }

        chunks
    }

    /// Execute graph on a stream chunk
    pub fn execute_chunk(
        &self,
        _graph: &ComputationGraph,
        chunk: StreamChunk,
    ) -> Result<StreamChunk, GraphError> {
        // In a real implementation, this would:
        // 1. Execute the graph operations on the chunk data
        // 2. Apply backpressure if buffer is full
        // 3. Return the processed chunk

        // For now, return the chunk as-is
        Ok(chunk)
    }

    /// Process a stream of chunks through the graph
    pub fn process_stream(
        &self,
        graph: &ComputationGraph,
        chunks: Vec<StreamChunk>,
    ) -> Result<Vec<StreamChunk>, GraphError> {
        let mut results = Vec::new();

        for chunk in chunks {
            // Check backpressure
            {
                let buffer = self.buffer.lock().unwrap();
                if buffer.len() >= self.max_buffer_size {
                    // In a real implementation, we would wait or apply backpressure
                    // For now, we just continue
                }
            }

            // Execute chunk
            let result = self.execute_chunk(graph, chunk)?;

            // Add to buffer
            {
                let mut buffer = self.buffer.lock().unwrap();
                buffer.push_back(result.clone());

                // Keep buffer size in check
                while buffer.len() > self.max_buffer_size {
                    buffer.pop_front();
                }
            }

            results.push(result);
        }

        Ok(results)
    }

    /// Get the current buffer size
    pub fn buffer_size(&self) -> usize {
        self.buffer.lock().unwrap().len()
    }

    /// Clear the buffer
    pub fn clear_buffer(&self) {
        self.buffer.lock().unwrap().clear();
    }

    /// Get chunk size
    pub fn chunk_size(&self) -> usize {
        self.chunk_size
    }

    /// Get max buffer size
    pub fn max_buffer_size(&self) -> usize {
        self.max_buffer_size
    }
}
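// A minimal chunking sketch (chunk size, buffer size, and the "input" key are
// arbitrary example values): 10 values with chunk_size 4 give chunks of 4, 4 and 2,
// and the pass-through executor keeps the chunk count unchanged.
#[cfg(test)]
mod streaming_executor_example {
    use super::*;

    #[test]
    fn chunking_and_processing() -> Result<(), GraphError> {
        let executor = StreamingExecutor::new(4, 8);
        let chunks = executor.create_chunks((0..10).map(|v| v as f32).collect(), "input");
        assert_eq!(chunks.len(), 3);
        assert!(chunks[2].is_last());
        assert_eq!(chunks[2].data["input"].len(), 2);

        let graph = ComputationGraph::new();
        let results = executor.process_stream(&graph, chunks)?;
        assert_eq!(results.len(), 3);
        Ok(())
    }
}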

// Distributed graph execution for multi-node computation
//
// This section provides infrastructure for distributing computation graphs
// across multiple nodes in an IPFS network.

/// Node assignment for distributed execution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeAssignment {
    /// Node ID in the computation graph
    pub node_id: String,
    /// Worker ID (peer ID or node identifier)
    pub worker_id: String,
    /// Execution priority
    pub priority: usize,
}

/// Graph partition for a single worker
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphPartition {
    /// Worker ID that will execute this partition
    pub worker_id: String,
    /// Nodes assigned to this worker
    pub nodes: Vec<String>,
    /// Input dependencies from other partitions
    pub external_inputs: HashMap<String, String>, // node_id -> source_worker_id
    /// Output nodes consumed by other partitions
    pub external_outputs: Vec<String>,
    /// Subgraph for this partition
    #[serde(skip)]
    pub subgraph: Option<ComputationGraph>,
}

impl GraphPartition {
    /// Create a new graph partition
    pub fn new(worker_id: String) -> Self {
        Self {
            worker_id,
            nodes: Vec::new(),
            external_inputs: HashMap::new(),
            external_outputs: Vec::new(),
            subgraph: None,
        }
    }

    /// Add a node to this partition
    pub fn add_node(&mut self, node_id: String) {
        if !self.nodes.contains(&node_id) {
            self.nodes.push(node_id);
        }
    }

    /// Add an external input dependency
    pub fn add_external_input(&mut self, node_id: String, source_worker_id: String) {
        self.external_inputs.insert(node_id, source_worker_id);
    }

    /// Mark a node as an external output
    pub fn mark_external_output(&mut self, node_id: String) {
        if !self.external_outputs.contains(&node_id) {
            self.external_outputs.push(node_id);
        }
    }

    /// Get the number of nodes in this partition
    pub fn size(&self) -> usize {
        self.nodes.len()
    }
}
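// A small sketch of partition bookkeeping (worker and node IDs are arbitrary
// example values): duplicate node IDs are ignored and cross-partition
// dependencies are tracked explicitly.
#[cfg(test)]
mod graph_partition_example {
    use super::*;

    #[test]
    fn partition_bookkeeping() {
        let mut partition = GraphPartition::new("worker-1".into());
        partition.add_node("matmul_0".into());
        partition.add_node("matmul_0".into()); // duplicate is ignored
        assert_eq!(partition.size(), 1);

        partition.add_external_input("embed_0".into(), "worker-0".into());
        partition.mark_external_output("matmul_0".into());
        assert_eq!(
            partition.external_inputs.get("embed_0").map(String::as_str),
            Some("worker-0")
        );
        assert_eq!(partition.external_outputs, vec!["matmul_0"]);
    }
}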

/// Distributed executor for multi-node graph execution
pub struct DistributedExecutor {
    /// Worker assignments
    assignments: HashMap<String, NodeAssignment>,
    /// Graph partitions by worker ID
    partitions: HashMap<String, GraphPartition>,
    /// Communication timeout (milliseconds)
    timeout_ms: u64,
}

impl DistributedExecutor {
    /// Create a new distributed executor
    pub fn new() -> Self {
        Self {
            assignments: HashMap::new(),
            partitions: HashMap::new(),
            timeout_ms: 30000, // 30 seconds default
        }
    }

    /// Set communication timeout
    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = timeout_ms;
        self
    }

    /// Partition a graph across multiple workers
    /// Uses a simple round-robin strategy for now
    pub fn partition_graph(
        &mut self,
        graph: &ComputationGraph,
        worker_ids: &[String],
    ) -> Result<(), GraphError> {
        if worker_ids.is_empty() {
            return Err(GraphError::InvalidGraph("No workers available".to_string()));
        }

        // Get topological order
        let sorted = graph.topological_sort()?;

        // Create partitions for each worker
        for worker_id in worker_ids {
            self.partitions
                .insert(worker_id.clone(), GraphPartition::new(worker_id.clone()));
        }

        // Assign nodes to workers in round-robin fashion
        for (idx, node_id) in sorted.iter().enumerate() {
            let worker_id = &worker_ids[idx % worker_ids.len()];
            let assignment = NodeAssignment {
                node_id: node_id.clone(),
                worker_id: worker_id.clone(),
                priority: idx,
            };

            self.assignments.insert(node_id.clone(), assignment);
            if let Some(partition) = self.partitions.get_mut(worker_id) {
                partition.add_node(node_id.clone());
            }
        }

        // Identify cross-partition dependencies
        for (node_id, node) in &graph.nodes {
            if let Some(assignment) = self.assignments.get(node_id) {
                for input_id in &node.inputs {
                    if let Some(input_assignment) = self.assignments.get(input_id) {
                        if input_assignment.worker_id != assignment.worker_id {
                            // Cross-partition dependency
                            if let Some(partition) = self.partitions.get_mut(&assignment.worker_id)
                            {
                                partition.add_external_input(
                                    input_id.clone(),
                                    input_assignment.worker_id.clone(),
                                );
                            }
                            if let Some(source_partition) =
                                self.partitions.get_mut(&input_assignment.worker_id)
                            {
                                source_partition.mark_external_output(input_id.clone());
                            }
                        }
                    }
                }
            }
        }

        // Build subgraphs for each partition
        for partition in self.partitions.values_mut() {
            let mut subgraph = ComputationGraph::new();

            // Add nodes belonging to this partition
            for node_id in &partition.nodes {
                if let Some(node) = graph.nodes.get(node_id) {
                    subgraph.nodes.insert(node_id.clone(), node.clone());
                }
            }

            // Mark inputs and outputs
            for input_id in partition.external_inputs.keys() {
                if subgraph.nodes.contains_key(input_id) || graph.inputs.contains(input_id) {
                    subgraph.mark_input(input_id.clone());
                }
            }

            for output_id in &partition.external_outputs {
                if subgraph.nodes.contains_key(output_id) {
                    subgraph.mark_output(output_id.clone());
                }
            }

            // Also include original graph inputs if they're in this partition
            for input_id in &graph.inputs {
                if partition.nodes.contains(input_id) {
                    subgraph.mark_input(input_id.clone());
                }
            }

            // Include original graph outputs if they're in this partition
1624            for output_id in &graph.outputs {
1625                if partition.nodes.contains(output_id) {
1626                    subgraph.mark_output(output_id.clone());
1627                }
1628            }
1629
1630            partition.subgraph = Some(subgraph);
1631        }
1632
1633        Ok(())
1634    }
1635
1636    /// Get partition for a specific worker
1637    pub fn get_partition(&self, worker_id: &str) -> Option<&GraphPartition> {
1638        self.partitions.get(worker_id)
1639    }
1640
1641    /// Get all partitions
1642    pub fn get_partitions(&self) -> &HashMap<String, GraphPartition> {
1643        &self.partitions
1644    }
1645
1646    /// Get node assignment
1647    pub fn get_assignment(&self, node_id: &str) -> Option<&NodeAssignment> {
1648        self.assignments.get(node_id)
1649    }
1650
1651    /// Execute a distributed graph
1652    /// NOTE: this is a stub; actual remote execution awaits integration with ipfrs-network.
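    ///
    /// Until then, the partitioning produced by `partition_graph` already describes
    /// the data movement a coordinator would need to schedule. A minimal sketch
    /// (no networking, just inspecting the plan; `exec` is a partitioned executor):
    ///
    /// ```ignore
    /// for (worker_id, partition) in exec.get_partitions() {
    ///     for (node_id, source_worker) in &partition.external_inputs {
    ///         // A real coordinator would fetch `node_id`'s result from
    ///         // `source_worker` before running this partition's subgraph.
    ///         println!("{worker_id} needs {node_id} from {source_worker}");
    ///     }
    /// }
    /// ```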
1653    pub fn execute_distributed(
1654        &self,
1655        _graph: &ComputationGraph,
1656    ) -> Result<HashMap<String, Vec<f32>>, GraphError> {
1657        // This is a placeholder for distributed execution
1658        // When ipfrs-network is integrated, this will:
1659        // 1. Send subgraphs to respective workers
1660        // 2. Coordinate data transfer between workers
1661        // 3. Collect results from workers
1662        // 4. Assemble final output
1663
1664        Err(GraphError::ExecutionError(
1665            "Distributed execution requires ipfrs-network integration".to_string(),
1666        ))
1667    }
1668
1669    /// Estimate communication cost for a partition
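    ///
    /// The cost is counted as the number of external inputs plus external outputs.
    /// A minimal sketch (the worker name and count are hypothetical):
    ///
    /// ```ignore
    /// // A partition that receives one tensor and sends one tensor costs 2.
    /// assert_eq!(exec.estimate_communication_cost("worker1"), 2);
    /// ```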
1670    pub fn estimate_communication_cost(&self, worker_id: &str) -> usize {
1671        if let Some(partition) = self.partitions.get(worker_id) {
1672            partition.external_inputs.len() + partition.external_outputs.len()
1673        } else {
1674            0
1675        }
1676    }
1677
1678    /// Get total number of workers
1679    pub fn worker_count(&self) -> usize {
1680        self.partitions.len()
1681    }
1682
1683    /// Get timeout in milliseconds
1684    pub fn timeout(&self) -> u64 {
1685        self.timeout_ms
1686    }
1687}
1688
1689impl Default for DistributedExecutor {
1690    fn default() -> Self {
1691        Self::new()
1692    }
1693}
1694
1695// Helper functions for serializing/deserializing Option<Cid>
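//
// These helpers are intended to be wired into serde's field attributes. A minimal
// sketch (the `CachedTensor` struct and its field are hypothetical, shown only to
// illustrate the wiring):
//
//     #[derive(Serialize, Deserialize)]
//     struct CachedTensor {
//         #[serde(
//             serialize_with = "serialize_optional_cid",
//             deserialize_with = "deserialize_optional_cid"
//         )]
//         value_cid: Option<Cid>,
//     }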
1696fn serialize_optional_cid<S>(cid: &Option<Cid>, serializer: S) -> Result<S::Ok, S::Error>
1697where
1698    S: serde::Serializer,
1699{
1700    use serde::Serialize;
1701    match cid {
1702        Some(c) => Some(c.to_string()).serialize(serializer),
1703        None => None::<String>.serialize(serializer),
1704    }
1705}
1706
1707fn deserialize_optional_cid<'de, D>(deserializer: D) -> Result<Option<Cid>, D::Error>
1708where
1709    D: serde::Deserializer<'de>,
1710{
1711    use serde::Deserialize;
1712    let opt = Option::<String>::deserialize(deserializer)?;
1713    opt.map(|s| s.parse().map_err(serde::de::Error::custom))
1714        .transpose()
1715}
1716
1717#[cfg(test)]
1718mod tests {
1719    use super::*;
1720
1721    #[test]
1722    fn test_tensor_op() {
1723        let add = TensorOp::Add;
1724        assert_eq!(add.num_inputs(), 2);
1725        assert!(add.is_pure());
1726
1727        let relu = TensorOp::ReLU;
1728        assert_eq!(relu.num_inputs(), 1);
1729    }
1730
1731    #[test]
1732    fn test_graph_node() {
1733        let node = GraphNode::new("node1".to_string(), TensorOp::Add)
1734            .add_input("input1".to_string())
1735            .add_input("input2".to_string())
1736            .with_output_shape(vec![10, 20]);
1737
1738        assert_eq!(node.inputs.len(), 2);
1739        assert_eq!(node.output_shape, Some(vec![10, 20]));
1740    }
1741
1742    #[test]
1743    fn test_computation_graph() {
1744        let mut graph = ComputationGraph::new();
1745
1746        let input1 = GraphNode::new(
1747            "input1".to_string(),
1748            TensorOp::Input {
1749                name: "x".to_string(),
1750            },
1751        );
1752
1753        let input2 = GraphNode::new(
1754            "input2".to_string(),
1755            TensorOp::Input {
1756                name: "y".to_string(),
1757            },
1758        );
1759
1760        graph.add_node(input1).unwrap();
1761        graph.add_node(input2).unwrap();
1762        graph.mark_input("input1".to_string());
1763        graph.mark_input("input2".to_string());
1764
1765        let add = GraphNode::new("add1".to_string(), TensorOp::Add)
1766            .add_input("input1".to_string())
1767            .add_input("input2".to_string());
1768
1769        graph.add_node(add).unwrap();
1770        graph.mark_output("add1".to_string());
1771
1772        assert_eq!(graph.node_count(), 3);
1773        assert_eq!(graph.input_count(), 2);
1774        assert_eq!(graph.output_count(), 1);
1775    }
1776
1777    #[test]
1778    fn test_topological_sort() {
1779        let mut graph = ComputationGraph::new();
1780
1781        let input1 = GraphNode::new(
1782            "a".to_string(),
1783            TensorOp::Input {
1784                name: "x".to_string(),
1785            },
1786        );
1787        graph.add_node(input1).unwrap();
1788
1789        let b = GraphNode::new("b".to_string(), TensorOp::ReLU).add_input("a".to_string());
1790        graph.add_node(b).unwrap();
1791
1792        let c = GraphNode::new("c".to_string(), TensorOp::Tanh).add_input("b".to_string());
1793        graph.add_node(c).unwrap();
1794
1795        let sorted = graph.topological_sort().unwrap();
1796
1797        // Check that 'a' comes before 'b', and 'b' comes before 'c'
1798        let pos_a = sorted.iter().position(|x| x == "a").unwrap();
1799        let pos_b = sorted.iter().position(|x| x == "b").unwrap();
1800        let pos_c = sorted.iter().position(|x| x == "c").unwrap();
1801
1802        assert!(pos_a < pos_b);
1803        assert!(pos_b < pos_c);
1804    }
1805
1806    #[test]
1807    fn test_subgraph_extraction() {
1808        let mut graph = ComputationGraph::new();
1809
1810        let a = GraphNode::new(
1811            "a".to_string(),
1812            TensorOp::Input {
1813                name: "x".to_string(),
1814            },
1815        );
1816
1817        graph.add_node(a).unwrap();
1818        graph.mark_input("a".to_string());
1819
1820        let b = GraphNode::new("b".to_string(), TensorOp::ReLU).add_input("a".to_string());
1821        let c = GraphNode::new("c".to_string(), TensorOp::Tanh).add_input("a".to_string());
1822
1823        graph.add_node(b).unwrap();
1824        graph.add_node(c).unwrap();
1825
1826        let subgraph = graph.extract_subgraph(&["b".to_string()]).unwrap();
1827
1828        assert_eq!(subgraph.node_count(), 2); // Should have 'a' and 'b'
1829        assert!(subgraph.nodes.contains_key("a"));
1830        assert!(subgraph.nodes.contains_key("b"));
1831        assert!(!subgraph.nodes.contains_key("c"));
1832    }
1833
1834    #[test]
1835    fn test_cse_optimization() {
1836        let mut graph = ComputationGraph::new();
1837
1838        let a = GraphNode::new(
1839            "a".to_string(),
1840            TensorOp::Input {
1841                name: "x".to_string(),
1842            },
1843        );
1844        let b = GraphNode::new(
1845            "b".to_string(),
1846            TensorOp::Input {
1847                name: "y".to_string(),
1848            },
1849        );
1850
1851        // Create two identical Add operations
1852        let add1 = GraphNode::new("add1".to_string(), TensorOp::Add)
1853            .add_input("a".to_string())
1854            .add_input("b".to_string());
1855
1856        let add2 = GraphNode::new("add2".to_string(), TensorOp::Add)
1857            .add_input("a".to_string())
1858            .add_input("b".to_string());
1859
1860        graph.add_node(a).unwrap();
1861        graph.add_node(b).unwrap();
1862        graph.add_node(add1).unwrap();
1863        graph.add_node(add2).unwrap();
1864
1865        // CSE should detect these as duplicates
1866        let _optimized = graph.optimize_cse();
1867        // Note: a stricter test would also verify that one of the duplicate
1868        // Add nodes is actually eliminated or remapped by the CSE pass.
1869    }
1870
1871    #[test]
1872    fn test_lazy_cache() {
1873        let mut cache = LazyCache::new(2);
1874
1875        cache.insert("node1".to_string(), vec![1.0, 2.0]);
1876        cache.insert("node2".to_string(), vec![3.0, 4.0]);
1877
1878        assert_eq!(cache.size(), 2);
1879        assert!(cache.get("node1").is_some());
1880
1881        // Adding a third item should evict the least recently used
1882        cache.insert("node3".to_string(), vec![5.0, 6.0]);
1883        assert_eq!(cache.size(), 2);
1884    }
1885
1886    #[test]
1887    fn test_graph_optimizer() {
1888        let mut graph = ComputationGraph::new();
1889
1890        let input = GraphNode::new(
1891            "input".to_string(),
1892            TensorOp::Input {
1893                name: "x".to_string(),
1894            },
1895        );
1896
1897        graph.add_node(input).unwrap();
1898        graph.mark_input("input".to_string());
1899
1900        let relu =
1901            GraphNode::new("relu".to_string(), TensorOp::ReLU).add_input("input".to_string());
1902
1903        // Add a dead node (not connected to output)
1904        let dead =
1905            GraphNode::new("dead".to_string(), TensorOp::Tanh).add_input("input".to_string());
1906
1907        graph.add_node(relu).unwrap();
1908        graph.add_node(dead).unwrap();
1909        graph.mark_output("relu".to_string());
1910
1911        let removed = GraphOptimizer::remove_dead_nodes(&mut graph).unwrap();
1912
1913        assert_eq!(removed, 1);
1914        assert!(!graph.nodes.contains_key("dead"));
1915    }
1916
1917    #[test]
1918    fn test_batch_scheduler() {
1919        let mut graph = ComputationGraph::new();
1920
1921        // Create a simple graph: a -> b, a -> c, (b,c) -> d
1922        let a = GraphNode::new(
1923            "a".to_string(),
1924            TensorOp::Input {
1925                name: "x".to_string(),
1926            },
1927        );
1928        graph.add_node(a).unwrap();
1929        graph.mark_input("a".to_string());
1930
1931        let b = GraphNode::new("b".to_string(), TensorOp::ReLU).add_input("a".to_string());
1932        let c = GraphNode::new("c".to_string(), TensorOp::Tanh).add_input("a".to_string());
1933
1934        graph.add_node(b).unwrap();
1935        graph.add_node(c).unwrap();
1936
1937        let d = GraphNode::new("d".to_string(), TensorOp::Add)
1938            .add_input("b".to_string())
1939            .add_input("c".to_string());
1940        graph.add_node(d).unwrap();
1941        graph.mark_output("d".to_string());
1942
1943        let batches = BatchScheduler::create_batches(&graph).unwrap();
1944
1945        // Batch 0: a (input)
1946        // Batch 1: b, c (both depend only on a)
1947        // Batch 2: d (depends on b and c)
1948        assert_eq!(batches.len(), 3);
1949        assert_eq!(batches[0].size(), 1); // a
1950        assert_eq!(batches[1].size(), 2); // b, c
1951        assert_eq!(batches[2].size(), 1); // d
1952    }
1953
1954    #[test]
1955    fn test_parallel_executor() {
1956        let mut graph = ComputationGraph::new();
1957
1958        let input1 = GraphNode::new(
1959            "input1".to_string(),
1960            TensorOp::Input {
1961                name: "x".to_string(),
1962            },
1963        );
1964        let input2 = GraphNode::new(
1965            "input2".to_string(),
1966            TensorOp::Input {
1967                name: "y".to_string(),
1968            },
1969        );
1970
1971        graph.add_node(input1).unwrap();
1972        graph.add_node(input2).unwrap();
1973        graph.mark_input("input1".to_string());
1974        graph.mark_input("input2".to_string());
1975
1976        let add = GraphNode::new("add".to_string(), TensorOp::Add)
1977            .add_input("input1".to_string())
1978            .add_input("input2".to_string());
1979
1980        graph.add_node(add).unwrap();
1981        graph.mark_output("add".to_string());
1982
1983        let executor = ParallelExecutor::new(Some(2));
1984        let result = executor.execute(&graph).unwrap();
1985
1986        // All nodes should be executed
1987        assert_eq!(result.len(), 3);
1988    }
1989
1990    #[test]
1991    fn test_execution_batch() {
1992        let mut batch = ExecutionBatch::new(0);
1993        batch.add_node("node1".to_string());
1994        batch.add_node("node2".to_string());
1995
1996        assert_eq!(batch.size(), 2);
1997        assert_eq!(batch.level, 0);
1998        assert!(batch.node_ids.contains(&"node1".to_string()));
1999    }
2000
2001    #[test]
2002    fn test_streaming_executor() {
2003        let executor = StreamingExecutor::new(100, 10);
2004
2005        // Create test data
2006        let data: Vec<f32> = (0..250).map(|i| i as f32).collect();
2007        let chunks = executor.create_chunks(data.clone(), "test_node");
2008
2009        // Should create 3 chunks (100, 100, 50)
2010        assert_eq!(chunks.len(), 3);
2011        assert_eq!(chunks[0].data["test_node"].len(), 100);
2012        assert_eq!(chunks[1].data["test_node"].len(), 100);
2013        assert_eq!(chunks[2].data["test_node"].len(), 50);
2014        assert!(chunks[2].is_last());
2015
2016        assert_eq!(executor.chunk_size(), 100);
2017        assert_eq!(executor.max_buffer_size(), 10);
2018    }
2019
2020    #[test]
2021    fn test_stream_chunk() {
2022        let mut chunk = StreamChunk::new(0, 5);
2023        chunk.add_data("node1".to_string(), vec![1.0, 2.0, 3.0]);
2024        chunk.add_data("node2".to_string(), vec![4.0, 5.0, 6.0]);
2025
2026        assert_eq!(chunk.index, 0);
2027        assert_eq!(chunk.total_chunks, 5);
2028        assert!(!chunk.is_last());
2029        assert_eq!(chunk.data.len(), 2);
2030
2031        let last_chunk = StreamChunk::new(4, 5);
2032        assert!(last_chunk.is_last());
2033    }
2034
2035    #[test]
2036    fn test_streaming_process_stream() {
2037        let graph = ComputationGraph::new();
2038        let executor = StreamingExecutor::new(100, 5);
2039
2040        let data: Vec<f32> = (0..300).map(|i| i as f32).collect();
2041        let chunks = executor.create_chunks(data, "input");
2042
2043        let results = executor.process_stream(&graph, chunks).unwrap();
2044
2045        assert_eq!(results.len(), 3);
2046        assert!(executor.buffer_size() <= executor.max_buffer_size());
2047
2048        executor.clear_buffer();
2049        assert_eq!(executor.buffer_size(), 0);
2050    }
2051
2052    #[test]
2053    fn test_distributed_executor_creation() {
2054        let executor = DistributedExecutor::new();
2055        assert_eq!(executor.worker_count(), 0);
2056        assert_eq!(executor.timeout(), 30000);
2057
2058        let executor_custom = DistributedExecutor::new().with_timeout(60000);
2059        assert_eq!(executor_custom.timeout(), 60000);
2060    }
2061
2062    #[test]
2063    fn test_graph_partitioning() {
2064        let mut graph = ComputationGraph::new();
2065
2066        // Create a simple chain: input -> a -> b -> c, with c as the output
2067        let input = GraphNode::new(
2068            "input".to_string(),
2069            TensorOp::Input {
2070                name: "x".to_string(),
2071            },
2072        );
2073        graph.add_node(input).unwrap();
2074        graph.mark_input("input".to_string());
2075
2076        let a = GraphNode::new("a".to_string(), TensorOp::ReLU).add_input("input".to_string());
2077        let b = GraphNode::new("b".to_string(), TensorOp::Tanh).add_input("a".to_string());
2078        let c = GraphNode::new("c".to_string(), TensorOp::Sigmoid).add_input("b".to_string());
2079
2080        graph.add_node(a).unwrap();
2081        graph.add_node(b).unwrap();
2082        graph.add_node(c).unwrap();
2083        graph.mark_output("c".to_string());
2084
2085        // Partition across 2 workers
2086        let mut executor = DistributedExecutor::new();
2087        let workers = vec!["worker1".to_string(), "worker2".to_string()];
2088        executor.partition_graph(&graph, &workers).unwrap();
2089
2090        assert_eq!(executor.worker_count(), 2);
2091
2092        // Check that partitions were created
2093        let partition1 = executor.get_partition("worker1");
2094        let partition2 = executor.get_partition("worker2");
2095
2096        assert!(partition1.is_some());
2097        assert!(partition2.is_some());
2098
2099        // Each partition should have nodes
2100        let p1 = partition1.unwrap();
2101        let p2 = partition2.unwrap();
2102
2103        assert!(p1.size() > 0);
2104        assert!(p2.size() > 0);
2105
2106        // Total nodes across partitions should match graph
2107        assert_eq!(p1.size() + p2.size(), 4); // input, a, b, c
2108    }
2109
2110    #[test]
2111    fn test_cross_partition_dependencies() {
2112        let mut graph = ComputationGraph::new();
2113
2114        // Create a graph with cross-partition dependencies
2115        let input1 = GraphNode::new(
2116            "input1".to_string(),
2117            TensorOp::Input {
2118                name: "x".to_string(),
2119            },
2120        );
2121        let input2 = GraphNode::new(
2122            "input2".to_string(),
2123            TensorOp::Input {
2124                name: "y".to_string(),
2125            },
2126        );
2127
2128        graph.add_node(input1).unwrap();
2129        graph.add_node(input2).unwrap();
2130        graph.mark_input("input1".to_string());
2131        graph.mark_input("input2".to_string());
2132
2133        let a = GraphNode::new("a".to_string(), TensorOp::ReLU).add_input("input1".to_string());
2134        let b = GraphNode::new("b".to_string(), TensorOp::Tanh).add_input("input2".to_string());
2135        let c = GraphNode::new("c".to_string(), TensorOp::Add)
2136            .add_input("a".to_string())
2137            .add_input("b".to_string());
2138
2139        graph.add_node(a).unwrap();
2140        graph.add_node(b).unwrap();
2141        graph.add_node(c).unwrap();
2142        graph.mark_output("c".to_string());
2143
2144        // Partition across 3 workers
2145        let mut executor = DistributedExecutor::new();
2146        let workers = vec![
2147            "worker1".to_string(),
2148            "worker2".to_string(),
2149            "worker3".to_string(),
2150        ];
2151        executor.partition_graph(&graph, &workers).unwrap();
2152
2153        // Check communication costs
2154        let cost1 = executor.estimate_communication_cost("worker1");
2155        let cost2 = executor.estimate_communication_cost("worker2");
2156        let cost3 = executor.estimate_communication_cost("worker3");
2157
2158        // At least one partition should have external dependencies
2159        assert!(cost1 > 0 || cost2 > 0 || cost3 > 0);
2160    }
2161
2162    #[test]
2163    fn test_graph_partition_struct() {
2164        let mut partition = GraphPartition::new("worker1".to_string());
2165
2166        partition.add_node("node1".to_string());
2167        partition.add_node("node2".to_string());
2168        partition.add_node("node1".to_string()); // Duplicate should be ignored
2169
2170        assert_eq!(partition.size(), 2);
2171
2172        partition.add_external_input("input1".to_string(), "worker2".to_string());
2173        partition.mark_external_output("output1".to_string());
2174
2175        assert_eq!(partition.external_inputs.len(), 1);
2176        assert_eq!(partition.external_outputs.len(), 1);
2177    }
2178
2179    #[test]
2180    fn test_node_assignment() {
2181        let assignment = NodeAssignment {
2182            node_id: "node1".to_string(),
2183            worker_id: "worker1".to_string(),
2184            priority: 5,
2185        };
2186
2187        assert_eq!(assignment.node_id, "node1");
2188        assert_eq!(assignment.worker_id, "worker1");
2189        assert_eq!(assignment.priority, 5);
2190    }
2191
2192    #[test]
2193    fn test_distributed_partition_no_workers() {
2194        let graph = ComputationGraph::new();
2195        let mut executor = DistributedExecutor::new();
2196        let workers: Vec<String> = vec![];
2197
2198        let result = executor.partition_graph(&graph, &workers);
2199        assert!(result.is_err());
2200    }
2201
2202    #[test]
2203    fn test_shape_inference_matmul() {
2204        let op = TensorOp::MatMul;
2205        let input_shapes = vec![vec![2, 3, 4], vec![2, 4, 5]];
2206        let output_shape = op.infer_output_shape(&input_shapes).unwrap();
2207        assert_eq!(output_shape, vec![2, 3, 5]);
2208    }
2209
2210    #[test]
2211    fn test_shape_inference_add_broadcast() {
2212        let op = TensorOp::Add;
2213        let input_shapes = vec![vec![3, 1, 4], vec![3, 2, 4]];
2214        let output_shape = op.infer_output_shape(&input_shapes).unwrap();
2215        assert_eq!(output_shape, vec![3, 2, 4]);
2216    }
2217
2218    #[test]
2219    fn test_shape_inference_reduce_sum() {
2220        let op = TensorOp::ReduceSum {
2221            axes: vec![1],
2222            keepdims: false,
2223        };
2224        let input_shapes = vec![vec![2, 3, 4]];
2225        let output_shape = op.infer_output_shape(&input_shapes).unwrap();
2226        assert_eq!(output_shape, vec![2, 4]);
2227    }
2228
2229    #[test]
2230    fn test_shape_inference_reduce_sum_keepdims() {
2231        let op = TensorOp::ReduceSum {
2232            axes: vec![1],
2233            keepdims: true,
2234        };
2235        let input_shapes = vec![vec![2, 3, 4]];
2236        let output_shape = op.infer_output_shape(&input_shapes).unwrap();
2237        assert_eq!(output_shape, vec![2, 1, 4]);
2238    }
2239
2240    #[test]
2241    fn test_shape_inference_transpose() {
2242        let op = TensorOp::Transpose {
2243            axes: vec![0, 2, 1],
2244        };
2245        let input_shapes = vec![vec![2, 3, 4]];
2246        let output_shape = op.infer_output_shape(&input_shapes).unwrap();
2247        assert_eq!(output_shape, vec![2, 4, 3]);
2248    }
2249
2250    #[test]
2251    fn test_shape_inference_concat() {
2252        let op = TensorOp::Concat { axis: 1 };
2253        let input_shapes = vec![vec![2, 3, 4], vec![2, 5, 4]];
2254        let output_shape = op.infer_output_shape(&input_shapes).unwrap();
2255        assert_eq!(output_shape, vec![2, 8, 4]);
2256    }
2257
2258    #[test]
2259    fn test_shape_inference_reshape() {
2260        let op = TensorOp::Reshape { shape: vec![6, 4] };
2261        let input_shapes = vec![vec![2, 3, 4]];
2262        let output_shape = op.infer_output_shape(&input_shapes).unwrap();
2263        assert_eq!(output_shape, vec![6, 4]);
2264    }
2265
2266    #[test]
2267    fn test_graph_shape_propagation() {
2268        let mut graph = ComputationGraph::new();
2269
2270        // Input: [2, 3]
2271        let mut input = GraphNode::new(
2272            "input".to_string(),
2273            TensorOp::Input {
2274                name: "x".to_string(),
2275            },
2276        );
2277        input.output_shape = Some(vec![2, 3]);
2278        graph.add_node(input).unwrap();
2279        graph.mark_input("input".to_string());
2280
2281        // Weight: [3, 4]
2282        let mut weight = GraphNode::new(
2283            "weight".to_string(),
2284            TensorOp::Constant {
2285                value_cid: "cid1".to_string(),
2286            },
2287        );
2288        weight.output_shape = Some(vec![3, 4]);
2289        graph.add_node(weight).unwrap();
2290
2291        // MatMul: should be [2, 4]
2292        let matmul = GraphNode::new("matmul".to_string(), TensorOp::MatMul)
2293            .add_input("input".to_string())
2294            .add_input("weight".to_string());
2295        graph.add_node(matmul).unwrap();
2296
2297        // ReLU: should be [2, 4]
2298        let relu =
2299            GraphNode::new("relu".to_string(), TensorOp::ReLU).add_input("matmul".to_string());
2300        graph.add_node(relu).unwrap();
2301        graph.mark_output("relu".to_string());
2302
2303        // Propagate shapes
2304        graph.propagate_shapes().unwrap();
2305
2306        // Check inferred shapes
2307        assert_eq!(
2308            graph.nodes.get("matmul").unwrap().output_shape,
2309            Some(vec![2, 4])
2310        );
2311        assert_eq!(
2312            graph.nodes.get("relu").unwrap().output_shape,
2313            Some(vec![2, 4])
2314        );
2315    }
2316
2317    #[test]
2318    fn test_graph_validation() {
2319        let mut graph = ComputationGraph::new();
2320
2321        let input = GraphNode::new(
2322            "input".to_string(),
2323            TensorOp::Input {
2324                name: "x".to_string(),
2325            },
2326        )
2327        .with_output_shape(vec![2, 3]);
2328        graph.add_node(input).unwrap();
2329        graph.mark_input("input".to_string());
2330
2331        let relu =
2332            GraphNode::new("relu".to_string(), TensorOp::ReLU).add_input("input".to_string());
2333        graph.add_node(relu).unwrap();
2334        graph.mark_output("relu".to_string());
2335
2336        // Should validate successfully
2337        assert!(graph.validate().is_ok());
2338    }
2339
2340    #[test]
2341    fn test_graph_validation_missing_input() {
2342        let mut graph = ComputationGraph::new();
2343
2344        let relu =
2345            GraphNode::new("relu".to_string(), TensorOp::ReLU).add_input("nonexistent".to_string());
2346
2347        // Should fail because input doesn't exist
2348        assert!(graph.add_node(relu).is_err());
2349    }
2350
2351    #[test]
2352    fn test_estimate_memory() {
2353        let mut graph = ComputationGraph::new();
2354
2355        let mut input = GraphNode::new(
2356            "input".to_string(),
2357            TensorOp::Input {
2358                name: "x".to_string(),
2359            },
2360        );
2361        input.output_shape = Some(vec![10, 20]); // 200 elements * 4 bytes = 800 bytes
2362        graph.add_node(input).unwrap();
2363
2364        let mut weight = GraphNode::new(
2365            "weight".to_string(),
2366            TensorOp::Constant {
2367                value_cid: "cid1".to_string(),
2368            },
2369        );
2370        weight.output_shape = Some(vec![20, 30]); // 600 elements * 4 bytes = 2400 bytes
2371        graph.add_node(weight).unwrap();
2372
2373        let memory = graph.estimate_memory();
2374        assert_eq!(memory, 800 + 2400); // 3200 bytes total
2375    }
2376
2377    #[test]
2378    fn test_broadcast_shapes_same() {
2379        let result = TensorOp::broadcast_shapes(&[2, 3, 4], &[2, 3, 4]).unwrap();
2380        assert_eq!(result, vec![2, 3, 4]);
2381    }
2382
2383    #[test]
2384    fn test_broadcast_shapes_scalar() {
2385        let result = TensorOp::broadcast_shapes(&[2, 3, 4], &[1]).unwrap();
2386        assert_eq!(result, vec![2, 3, 4]);
2387    }
2388
2389    #[test]
2390    fn test_broadcast_shapes_incompatible() {
2391        let result = TensorOp::broadcast_shapes(&[2, 3, 4], &[2, 5, 4]);
2392        assert!(result.is_err());
2393    }
2394}