Skip to main content

entrenar/autograd/
graph_opt.rs

1//! Graph-level optimizations for computation graphs
2//!
3//! Implements constant folding, dead code elimination, and common subexpression
4//! elimination (CSE) during graph construction time. Inspired by JAX's dispatch
5//! constant folding (`jax/_src/dispatch.py:637-646`) and LLVM optimization passes.
6//!
7//! ## How It Works
8//!
9//! 1. Tensors are wrapped in `TracedValue` which tracks whether they are compile-time
10//!    constants or runtime-dynamic values
11//! 2. Binary operations check if both inputs are constant and fold immediately
12//! 3. Identity element optimizations eliminate no-op operations (x+0, x*1, x*0)
13//! 4. Shape tracking enables early validation of dimension mismatches
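//! 5. Graph optimization passes (constant folding, DCE, CSE) are applied to the built graph by `GraphOptimizer`, repeating until a fixpoint or an iteration cap is reached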
15
16use ndarray::Array1;
17use std::collections::{HashMap, HashSet};
18
/// Unique identifier for a node in the computation graph.
///
/// `ComputeGraph` assigns each new node the current length of its node list as its
/// id, so an id doubles as a direct index into that list.
pub type NodeId = usize;
21
/// A value that may be constant (known at graph construction time) or dynamic
/// (computed at execution time).
///
/// Constants carry their data, so an operation on two constants can be folded
/// eagerly. Dynamic values only carry the id of the graph node that will produce them.
#[derive(Debug, Clone)]
pub enum TracedValue {
    /// Compile-time constant — can be folded
    Constant(Array1<f32>),
    /// Runtime-computed value (symbolic reference to a graph node)
    Dynamic(NodeId),
}
31
32impl TracedValue {
33    /// Returns true if this value is a compile-time constant
34    pub fn is_constant(&self) -> bool {
35        matches!(self, TracedValue::Constant(_))
36    }
37
38    /// Returns the constant value if this is a constant
39    pub fn as_constant(&self) -> Option<&Array1<f32>> {
40        match self {
41            TracedValue::Constant(v) => Some(v),
42            TracedValue::Dynamic(_) => None,
43        }
44    }
45
46    /// Returns the node id if this is a dynamic value
47    pub fn node_id(&self) -> Option<NodeId> {
48        match self {
49            TracedValue::Constant(_) => None,
50            TracedValue::Dynamic(id) => Some(*id),
51        }
52    }
53}
54
/// Tensor with constant tracking for graph construction
///
/// Pairs a `TracedValue` with a logical shape. Tensors built with
/// `TracedTensor::constant` are 1-D (`shape == [len]`); placeholders may carry any
/// shape.
#[derive(Debug, Clone)]
pub struct TracedTensor {
    /// The actual or symbolic value
    value: TracedValue,
    /// Shape (always known, even for dynamic values)
    shape: Vec<usize>,
}
63
64impl TracedTensor {
65    /// Create a constant tensor (known at graph construction time)
66    pub fn constant(data: Array1<f32>) -> Self {
67        let shape = vec![data.len()];
68        Self { value: TracedValue::Constant(data), shape }
69    }
70
71    /// Create a dynamic (placeholder) tensor
72    pub fn placeholder(shape: Vec<usize>, node_id: NodeId) -> Self {
73        Self { value: TracedValue::Dynamic(node_id), shape }
74    }
75
76    /// Check if this tensor is a compile-time constant
77    pub fn is_constant(&self) -> bool {
78        self.value.is_constant()
79    }
80
81    /// Get the value
82    pub fn value(&self) -> &TracedValue {
83        &self.value
84    }
85
86    /// Get the shape
87    pub fn shape(&self) -> &[usize] {
88        &self.shape
89    }
90}
91
/// Type of operation in the computation graph
///
/// Only `Add`, `Mul`, `Sum` and `Scale` can be evaluated by the constant-folding
/// pass; every other operation always stays a graph node.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpType {
    /// Element-wise addition of two inputs
    Add,
    /// Element-wise multiplication of two inputs
    Mul,
    /// Multiply the first input by a one-element second input (the factor)
    Scale,
    /// Reduce a single input to a one-element sum
    Sum,
    Matmul,
    Relu,
    Gelu,
    Softmax,
    LayerNorm,
    Attention,
    /// Leaf node whose value lives in `GraphNode::constant_value`
    Constant,
}
107
/// A node in the computation graph
///
/// Nodes are never physically deleted. Optimization passes set the private
/// `removed` flag instead, and removed nodes are skipped by
/// `ComputeGraph::topological_order` and `ComputeGraph::active_node_count`.
#[derive(Debug, Clone)]
pub struct GraphNode {
    /// Unique node identifier (assigned by `ComputeGraph` as the node's index)
    pub id: NodeId,
    /// Operation type
    pub op_type: OpType,
    /// Input node IDs, in operand order
    pub input_ids: Vec<NodeId>,
    /// Output shape as declared when the node was added (not re-derived)
    pub output_shape: Vec<usize>,
    /// Constant value (if this node is a constant)
    pub constant_value: Option<Array1<f32>>,
    /// Whether this node has been eliminated
    removed: bool,
}
124
125impl GraphNode {
126    /// Check if this node holds a constant value
127    pub fn is_constant(&self) -> bool {
128        self.constant_value.is_some()
129    }
130
131    /// Check if this node has been removed by an optimization pass
132    pub fn is_removed(&self) -> bool {
133        self.removed
134    }
135
136    /// Mark this node as removed
137    pub fn mark_removed(&mut self) {
138        self.removed = true;
139    }
140}
141
/// Computation graph with optimization support
///
/// Nodes are stored in insertion order and addressed by their index (`NodeId`).
/// Optimization passes mark nodes as removed rather than deleting them, so ids
/// stay stable for the life of the graph.
pub struct ComputeGraph {
    /// All nodes in the graph, indexed by `NodeId` (including removed ones)
    nodes: Vec<GraphNode>,
    /// Output node IDs (roots of the graph); dead-code elimination keeps only
    /// nodes reachable from these
    output_ids: Vec<NodeId>,
}
149
150impl ComputeGraph {
151    /// Create a new empty computation graph
152    pub fn new() -> Self {
153        Self { nodes: Vec::new(), output_ids: Vec::new() }
154    }
155
156    /// Add a constant node to the graph
157    pub fn add_constant(&mut self, data: Array1<f32>) -> NodeId {
158        let id = self.nodes.len();
159        let shape = vec![data.len()];
160        self.nodes.push(GraphNode {
161            id,
162            op_type: OpType::Constant,
163            input_ids: Vec::new(),
164            output_shape: shape,
165            constant_value: Some(data),
166            removed: false,
167        });
168        id
169    }
170
171    /// Add an operation node to the graph
172    pub fn add_op(
173        &mut self,
174        op_type: OpType,
175        input_ids: Vec<NodeId>,
176        output_shape: Vec<usize>,
177    ) -> NodeId {
178        let id = self.nodes.len();
179        self.nodes.push(GraphNode {
180            id,
181            op_type,
182            input_ids,
183            output_shape,
184            constant_value: None,
185            removed: false,
186        });
187        id
188    }
189
190    /// Mark a node as an output of the graph
191    pub fn mark_output(&mut self, node_id: NodeId) {
192        self.output_ids.push(node_id);
193    }
194
195    /// Get a node by ID
196    pub fn node(&self, id: NodeId) -> &GraphNode {
197        &self.nodes[id]
198    }
199
200    /// Get a mutable reference to a node by ID
201    pub fn node_mut(&mut self, id: NodeId) -> &mut GraphNode {
202        &mut self.nodes[id]
203    }
204
205    /// Get the number of nodes
206    pub fn len(&self) -> usize {
207        self.nodes.len()
208    }
209
210    /// Check if the graph is empty
211    pub fn is_empty(&self) -> bool {
212        self.nodes.is_empty()
213    }
214
215    /// Count non-removed nodes
216    pub fn active_node_count(&self) -> usize {
217        self.nodes.iter().filter(|n| !n.is_removed()).count()
218    }
219
220    /// Get output node IDs
221    pub fn output_ids(&self) -> &[NodeId] {
222        &self.output_ids
223    }
224
225    /// Compute topological order of non-removed nodes
226    pub fn topological_order(&self) -> Vec<NodeId> {
227        let (in_degree, adjacency) = self.build_graph_maps();
228        Self::kahns_algorithm(in_degree, &adjacency)
229    }
230
231    /// Build in-degree counts and adjacency lists for non-removed nodes
232    fn build_graph_maps(&self) -> (HashMap<NodeId, usize>, HashMap<NodeId, Vec<NodeId>>) {
233        let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
234        let mut adjacency: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
235
236        for node in &self.nodes {
237            if node.is_removed() {
238                continue;
239            }
240            in_degree.entry(node.id).or_insert(0);
241            for &input_id in &node.input_ids {
242                if !self.nodes[input_id].is_removed() {
243                    adjacency.entry(input_id).or_default().push(node.id);
244                    *in_degree.entry(node.id).or_insert(0) += 1;
245                }
246            }
247        }
248
249        (in_degree, adjacency)
250    }
251
252    /// Run Kahn's algorithm to produce a topological ordering
253    fn kahns_algorithm(
254        mut in_degree: HashMap<NodeId, usize>,
255        adjacency: &HashMap<NodeId, Vec<NodeId>>,
256    ) -> Vec<NodeId> {
257        let mut queue: Vec<NodeId> =
258            in_degree.iter().filter(|(_, &deg)| deg == 0).map(|(&id, _)| id).collect();
259        queue.sort_unstable_by(|a, b| b.cmp(a)); // Descending so pop() yields smallest first
260
261        let mut order = Vec::new();
262        let empty = Vec::new();
263        while let Some(id) = queue.pop() {
264            order.push(id);
265            for &neighbor in adjacency.get(&id).unwrap_or(&empty) {
266                let Some(deg) = in_degree.get_mut(&neighbor) else {
267                    continue;
268                };
269                *deg -= 1;
270                if *deg == 0 {
271                    queue.push(neighbor);
272                    queue.sort_unstable_by(|a, b| b.cmp(a));
273                }
274            }
275        }
276
277        order
278    }
279
280    /// Replace all uses of `old_id` with `new_id` in the graph
281    pub fn replace_uses(&mut self, old_id: NodeId, new_id: NodeId) {
282        for node in &mut self.nodes {
283            for input_id in &mut node.input_ids {
284                if *input_id == old_id {
285                    *input_id = new_id;
286                }
287            }
288        }
289        for output_id in &mut self.output_ids {
290            if *output_id == old_id {
291                *output_id = new_id;
292            }
293        }
294    }
295}
296
297impl Default for ComputeGraph {
298    fn default() -> Self {
299        Self::new()
300    }
301}
302
303/// Get or create a graph node for a traced value.
304/// Dynamic values already have a node; constants are materialized into the graph.
305fn ensure_graph_node(value: &TracedValue, graph: &mut ComputeGraph) -> NodeId {
306    match value {
307        TracedValue::Dynamic(id) => *id,
308        TracedValue::Constant(data) => graph.add_constant(data.clone()),
309    }
310}
311
312/// Perform a traced binary operation with constant folding
313///
314/// If both inputs are constants, the operation is evaluated immediately.
315/// Otherwise, a graph node is created for deferred execution.
316pub fn traced_binary_op<F>(
317    a: &TracedTensor,
318    b: &TracedTensor,
319    op: F,
320    op_type: OpType,
321    graph: &mut ComputeGraph,
322) -> TracedTensor
323where
324    F: Fn(&Array1<f32>, &Array1<f32>) -> Array1<f32>,
325{
326    // Both constant: fold immediately
327    if let (Some(a_const), Some(b_const)) = (a.value.as_constant(), b.value.as_constant()) {
328        let result = op(a_const, b_const);
329        return TracedTensor::constant(result);
330    }
331
332    // Try identity element optimizations
333    if let Some(folded) = try_identity_fold(a, b, op_type) {
334        return folded;
335    }
336
337    // At least one is dynamic: create graph node
338    let a_node = ensure_graph_node(&a.value, graph);
339    let b_node = ensure_graph_node(&b.value, graph);
340
341    let output_shape = a.shape.clone(); // Assumes same shape for now
342    let node_id = graph.add_op(op_type, vec![a_node, b_node], output_shape.clone());
343
344    TracedTensor::placeholder(output_shape, node_id)
345}
346
347/// Try to fold operations with identity elements
348///
349/// - `x + 0 = x`
350/// - `0 + x = x`
351/// - `x * 1 = x`
352/// - `1 * x = x`
353/// - `x * 0 = 0`
354/// - `0 * x = 0`
355fn try_identity_fold(a: &TracedTensor, b: &TracedTensor, op_type: OpType) -> Option<TracedTensor> {
356    match op_type {
357        OpType::Add => try_additive_identity(a, b),
358        OpType::Mul => try_multiplicative_identity(a, b),
359        _ => None,
360    }
361}
362
363/// Additive identity: x + 0 = x, 0 + x = x
364fn try_additive_identity(a: &TracedTensor, b: &TracedTensor) -> Option<TracedTensor> {
365    if b.value.as_constant().is_some_and(is_zeros) {
366        return Some(a.clone());
367    }
368    if a.value.as_constant().is_some_and(is_zeros) {
369        return Some(b.clone());
370    }
371    None
372}
373
374/// Multiplicative identity/annihilator: x*1=x, 1*x=x, x*0=0, 0*x=0
375fn try_multiplicative_identity(a: &TracedTensor, b: &TracedTensor) -> Option<TracedTensor> {
376    // Check b as the constant operand (x * 1, x * 0)
377    if let Some(result) = try_mul_const(b, a) {
378        return Some(result);
379    }
380    // Check a as the constant operand (1 * x, 0 * x)
381    try_mul_const(a, b)
382}
383
384/// If `maybe_const` is a multiplicative identity (1) or annihilator (0),
385/// return the folded result. `other` is the non-constant operand.
386fn try_mul_const(maybe_const: &TracedTensor, other: &TracedTensor) -> Option<TracedTensor> {
387    let c = maybe_const.value.as_constant()?;
388    if is_ones(c) {
389        return Some(other.clone());
390    }
391    if is_zeros(c) {
392        return Some(TracedTensor::constant(Array1::zeros(other.shape[0])));
393    }
394    None
395}
396
397/// Check if all elements are zero
398fn is_zeros(arr: &Array1<f32>) -> bool {
399    arr.iter().all(|&x| x == 0.0)
400}
401
402/// Check if all elements are one
403fn is_ones(arr: &Array1<f32>) -> bool {
404    arr.iter().all(|&x| (x - 1.0).abs() < f32::EPSILON)
405}
406
/// Shape tracker for early validation of dimension mismatches
///
/// Maps node ids to known shapes. Shapes are added explicitly with `register`,
/// and the `infer_*` methods record the inferred output shape on success.
pub struct ShapeTracker {
    /// Known shape per node id
    shapes: HashMap<NodeId, Vec<usize>>,
}
411
412/// Error type for shape validation
413#[derive(Debug, Clone, PartialEq, Eq)]
414pub enum ShapeError {
415    /// Input shape not found
416    UnknownInput(NodeId),
417    /// Dimension mismatch between operands
418    DimMismatch { expected: usize, got: usize },
419    /// Insufficient dimensions for operation
420    InsufficientDims { required: usize, got: usize },
421}
422
423impl std::fmt::Display for ShapeError {
424    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
425        match self {
426            ShapeError::UnknownInput(id) => write!(f, "unknown input node {id}"),
427            ShapeError::DimMismatch { expected, got } => {
428                write!(f, "dimension mismatch: expected {expected}, got {got}")
429            }
430            ShapeError::InsufficientDims { required, got } => {
431                write!(f, "insufficient dims: need {required}, have {got}")
432            }
433        }
434    }
435}
436
437impl std::error::Error for ShapeError {}
438
439impl ShapeTracker {
440    /// Create a new shape tracker
441    pub fn new() -> Self {
442        Self { shapes: HashMap::new() }
443    }
444
445    /// Register a known shape for a node
446    pub fn register(&mut self, node_id: NodeId, shape: Vec<usize>) {
447        self.shapes.insert(node_id, shape);
448    }
449
450    /// Get the shape for a node
451    pub fn get(&self, node_id: NodeId) -> Option<&[usize]> {
452        self.shapes.get(&node_id).map(Vec::as_slice)
453    }
454
455    /// Look up a node's shape, returning an error if not registered
456    fn require_shape(&self, node_id: NodeId) -> Result<Vec<usize>, ShapeError> {
457        self.shapes.get(&node_id).cloned().ok_or(ShapeError::UnknownInput(node_id))
458    }
459
460    /// Validate that a shape has at least `min` dimensions
461    fn require_min_dims(shape: &[usize], min: usize) -> Result<(), ShapeError> {
462        if shape.len() < min {
463            return Err(ShapeError::InsufficientDims { required: min, got: shape.len() });
464        }
465        Ok(())
466    }
467
468    /// Store an output shape and return a clone
469    fn store_output(&mut self, output_id: NodeId, shape: Vec<usize>) -> Vec<usize> {
470        self.shapes.insert(output_id, shape.clone());
471        shape
472    }
473
474    /// Infer output shape for an element-wise binary operation
475    pub fn infer_elementwise(
476        &mut self,
477        output_id: NodeId,
478        a_id: NodeId,
479        b_id: NodeId,
480    ) -> Result<Vec<usize>, ShapeError> {
481        let a_shape = self.require_shape(a_id)?;
482        let b_shape = self.require_shape(b_id)?;
483
484        if a_shape != b_shape {
485            return Err(ShapeError::DimMismatch {
486                expected: a_shape.iter().product(),
487                got: b_shape.iter().product(),
488            });
489        }
490
491        Ok(self.store_output(output_id, a_shape))
492    }
493
494    /// Infer output shape for a matmul operation
495    pub fn infer_matmul(
496        &mut self,
497        output_id: NodeId,
498        a_id: NodeId,
499        b_id: NodeId,
500    ) -> Result<Vec<usize>, ShapeError> {
501        let a_shape = self.require_shape(a_id)?;
502        let b_shape = self.require_shape(b_id)?;
503
504        Self::require_min_dims(&a_shape, 2)?;
505        Self::require_min_dims(&b_shape, 2)?;
506
507        let k1 = a_shape[a_shape.len() - 1];
508        let k2 = b_shape[b_shape.len() - 2];
509
510        if k1 != k2 {
511            return Err(ShapeError::DimMismatch { expected: k1, got: k2 });
512        }
513
514        let m = a_shape[a_shape.len() - 2];
515        let n = b_shape[b_shape.len() - 1];
516        Ok(self.store_output(output_id, vec![m, n]))
517    }
518
519    /// Infer output shape for a sum (reduction) operation
520    pub fn infer_sum(
521        &mut self,
522        output_id: NodeId,
523        input_id: NodeId,
524    ) -> Result<Vec<usize>, ShapeError> {
525        self.require_shape(input_id)?;
526        Ok(self.store_output(output_id, vec![1]))
527    }
528
529    /// Get the number of tracked shapes
530    pub fn len(&self) -> usize {
531        self.shapes.len()
532    }
533
534    /// Check if no shapes are tracked
535    pub fn is_empty(&self) -> bool {
536        self.shapes.is_empty()
537    }
538}
539
540impl Default for ShapeTracker {
541    fn default() -> Self {
542        Self::new()
543    }
544}
545
/// Trait for graph optimization passes
///
/// Passes are driven by `GraphOptimizer`, which runs all passes repeatedly and
/// stops once an iteration in which every pass reports zero changes completes
/// (or its iteration cap is hit).
pub trait OptimizationPass {
    /// Name of the optimization pass (key used in `OptimizationReport::pass_changes`)
    fn name(&self) -> &'static str;

    /// Run the pass on the graph, returning the number of changes made
    ///
    /// Should return 0 when the graph was left unchanged; otherwise the
    /// optimizer's fixpoint loop only stops at its iteration cap.
    fn run(&self, graph: &mut ComputeGraph) -> usize;
}
554
/// Constant folding pass — evaluates operations with all-constant inputs at
/// graph construction time.
///
/// Folded nodes are rewritten in place into `OpType::Constant` nodes with their
/// inputs cleared. Former input nodes may become unreachable, in which case
/// `DeadCodeElimination` removes them.
pub struct ConstantFolding;
558
559/// Try to evaluate a foldable operation with all-constant inputs.
560/// Returns `None` if the operation cannot be folded.
561fn try_eval_constant_op(op_type: OpType, inputs: &[&Array1<f32>]) -> Option<Array1<f32>> {
562    match (op_type, inputs) {
563        (OpType::Add, [a, b]) => Some(*a + *b),
564        (OpType::Mul, [a, b]) => Some(*a * *b),
565        (OpType::Sum, [a]) => Some(Array1::from(vec![a.sum()])),
566        (OpType::Scale, [a, b]) if b.len() == 1 => Some(*a * b[0]),
567        _ => None,
568    }
569}
570
571impl ConstantFolding {
572    /// Attempt to fold a single node to a constant value.
573    /// Returns `Some(result)` if the node can be folded, `None` otherwise.
574    fn try_fold_node(graph: &ComputeGraph, node_id: NodeId) -> Option<Array1<f32>> {
575        let node = &graph.nodes[node_id];
576        if node.is_removed() || node.is_constant() {
577            return None;
578        }
579
580        let all_const = node.input_ids.iter().all(|&id| graph.nodes[id].is_constant());
581        if !all_const {
582            return None;
583        }
584
585        let inputs: Vec<&Array1<f32>> = node
586            .input_ids
587            .iter()
588            .map(|&id| {
589                graph.nodes[id]
590                    .constant_value
591                    .as_ref()
592                    .expect("all inputs verified as constants above")
593            })
594            .collect();
595
596        try_eval_constant_op(node.op_type, &inputs)
597    }
598}
599
600impl OptimizationPass for ConstantFolding {
601    fn name(&self) -> &'static str {
602        "constant_folding"
603    }
604
605    fn run(&self, graph: &mut ComputeGraph) -> usize {
606        let mut changes = 0;
607        let order = graph.topological_order();
608
609        for node_id in order {
610            if let Some(result) = Self::try_fold_node(graph, node_id) {
611                let node_mut = &mut graph.nodes[node_id];
612                node_mut.constant_value = Some(result);
613                node_mut.op_type = OpType::Constant;
614                node_mut.input_ids.clear();
615                changes += 1;
616            }
617        }
618
619        changes
620    }
621}
622
/// Dead code elimination pass — removes nodes not reachable from outputs.
///
/// Nodes are only marked as removed, so ids stay valid. If no output has been
/// marked on the graph, every node is unreachable and is therefore removed.
pub struct DeadCodeElimination;
625
626impl DeadCodeElimination {
627    /// Find all nodes reachable from outputs via DFS
628    fn find_reachable(graph: &ComputeGraph) -> HashSet<NodeId> {
629        let mut reachable = HashSet::new();
630        let mut stack: Vec<NodeId> = graph.output_ids.clone();
631
632        while let Some(id) = stack.pop() {
633            if !reachable.insert(id) {
634                continue;
635            }
636            if !graph.nodes[id].is_removed() {
637                stack.extend_from_slice(&graph.nodes[id].input_ids);
638            }
639        }
640
641        reachable
642    }
643}
644
645impl OptimizationPass for DeadCodeElimination {
646    fn name(&self) -> &'static str {
647        "dce"
648    }
649
650    fn run(&self, graph: &mut ComputeGraph) -> usize {
651        let reachable = Self::find_reachable(graph);
652        let mut changes = 0;
653
654        for id in 0..graph.nodes.len() {
655            if !reachable.contains(&id) && !graph.nodes[id].is_removed() {
656                graph.nodes[id].mark_removed();
657                changes += 1;
658            }
659        }
660
661        changes
662    }
663}
664
665/// Key for identifying structurally equivalent expressions (for CSE)
666#[derive(Debug, Clone, PartialEq, Eq, Hash)]
667struct ExprKey {
668    op_type: OpType,
669    input_ids: Vec<NodeId>,
670}
671
672impl ExprKey {
673    fn from_node(node: &GraphNode) -> Self {
674        Self { op_type: node.op_type, input_ids: node.input_ids.clone() }
675    }
676}
677
/// Common subexpression elimination pass — deduplicates identical computations.
///
/// Non-constant nodes with equal `ExprKey`s are treated as duplicates. Each later
/// duplicate (in topological order) has its uses redirected to the first
/// occurrence and is then marked removed. Constant nodes are never deduplicated.
pub struct CommonSubexprElimination;
680
681impl OptimizationPass for CommonSubexprElimination {
682    fn name(&self) -> &'static str {
683        "cse"
684    }
685
686    fn run(&self, graph: &mut ComputeGraph) -> usize {
687        let mut changes = 0;
688        let mut expr_to_node: HashMap<ExprKey, NodeId> = HashMap::new();
689
690        let order = graph.topological_order();
691        for node_id in order {
692            let node = &graph.nodes[node_id];
693            if node.is_removed() || node.op_type == OpType::Constant {
694                continue;
695            }
696
697            let key = ExprKey::from_node(node);
698
699            if let Some(&existing_id) = expr_to_node.get(&key) {
700                // Found duplicate: replace uses and remove
701                graph.replace_uses(node_id, existing_id);
702                graph.nodes[node_id].mark_removed();
703                changes += 1;
704            } else {
705                expr_to_node.insert(key, node_id);
706            }
707        }
708
709        changes
710    }
711}
712
/// Graph optimizer that runs multiple passes until a fixpoint is reached
pub struct GraphOptimizer {
    /// Passes, run in this order on every iteration
    passes: Vec<Box<dyn OptimizationPass>>,
    /// Upper bound on iterations, in case the passes never stop reporting changes
    max_iterations: usize,
}
718
719impl GraphOptimizer {
720    /// Create a new optimizer with the default set of passes
721    pub fn new() -> Self {
722        let mut opt = Self { passes: Vec::new(), max_iterations: 10 };
723        opt.passes.push(Box::new(ConstantFolding));
724        opt.passes.push(Box::new(DeadCodeElimination));
725        opt.passes.push(Box::new(CommonSubexprElimination));
726        opt
727    }
728
729    /// Set the maximum number of optimization iterations
730    pub fn with_max_iterations(mut self, max: usize) -> Self {
731        self.max_iterations = max;
732        self
733    }
734
735    /// Run all passes until fixpoint or max iterations
736    pub fn optimize(&self, graph: &mut ComputeGraph) -> OptimizationReport {
737        let mut report = OptimizationReport {
738            iterations: 0,
739            total_changes: 0,
740            pass_changes: HashMap::new(),
741            initial_nodes: graph.active_node_count(),
742            final_nodes: 0,
743        };
744
745        for _ in 0..self.max_iterations {
746            let mut iter_changes = 0;
747            for pass in &self.passes {
748                let changes = pass.run(graph);
749                if changes > 0 {
750                    *report.pass_changes.entry(pass.name()).or_insert(0) += changes;
751                }
752                iter_changes += changes;
753            }
754
755            report.iterations += 1;
756            report.total_changes += iter_changes;
757
758            if iter_changes == 0 {
759                break; // Fixpoint reached
760            }
761        }
762
763        report.final_nodes = graph.active_node_count();
764        report
765    }
766}
767
768impl Default for GraphOptimizer {
769    fn default() -> Self {
770        Self::new()
771    }
772}
773
/// Report of optimization results
#[derive(Debug, Clone)]
pub struct OptimizationReport {
    /// Number of optimization iterations run
    pub iterations: usize,
    /// Total changes across all passes and iterations
    pub total_changes: usize,
    /// Changes per pass name
    pub pass_changes: HashMap<&'static str, usize>,
    /// Number of active nodes before optimization
    pub initial_nodes: usize,
    /// Number of active nodes after optimization
    pub final_nodes: usize,
}

impl OptimizationReport {
    /// Fraction of the initial active nodes that were eliminated.
    ///
    /// 0.0 means nothing was removed and 1.0 means everything was removed. An
    /// initially empty graph reports 0.0.
    pub fn reduction_ratio(&self) -> f64 {
        match self.initial_nodes {
            0 => 0.0,
            initial => {
                let kept = self.final_nodes as f64 / initial as f64;
                1.0 - kept
            }
        }
    }
}
798
799#[cfg(test)]
800mod tests {
801    use super::*;
802
803    // --- TracedValue tests ---
804
805    #[test]
806    fn test_traced_value_constant() {
807        let val = TracedValue::Constant(Array1::from(vec![1.0, 2.0]));
808        assert!(val.is_constant());
809        assert_eq!(val.as_constant().expect("operation should succeed").len(), 2);
810        assert_eq!(val.node_id(), None);
811    }
812
813    #[test]
814    fn test_traced_value_dynamic() {
815        let val = TracedValue::Dynamic(42);
816        assert!(!val.is_constant());
817        assert!(val.as_constant().is_none());
818        assert_eq!(val.node_id(), Some(42));
819    }
820
821    // --- TracedTensor tests ---
822
823    #[test]
824    fn test_traced_tensor_constant() {
825        let t = TracedTensor::constant(Array1::from(vec![1.0, 2.0, 3.0]));
826        assert!(t.is_constant());
827        assert_eq!(t.shape(), &[3]);
828    }
829
830    #[test]
831    fn test_traced_tensor_placeholder() {
832        let t = TracedTensor::placeholder(vec![4, 4], 7);
833        assert!(!t.is_constant());
834        assert_eq!(t.shape(), &[4, 4]);
835        assert_eq!(t.value().node_id(), Some(7));
836    }
837
838    // --- Identity folding tests ---
839
840    #[test]
841    fn test_add_with_zero_folds() {
842        let x = TracedTensor::placeholder(vec![3], 0);
843        let zero = TracedTensor::constant(Array1::zeros(3));
844
845        // x + 0 = x
846        let result = try_identity_fold(&x, &zero, OpType::Add);
847        assert!(result.is_some());
848        assert!(!result.expect("operation should succeed").is_constant()); // Should return x (dynamic)
849
850        // 0 + x = x
851        let result = try_identity_fold(&zero, &x, OpType::Add);
852        assert!(result.is_some());
853        assert!(!result.expect("operation should succeed").is_constant()); // Should return x (dynamic)
854    }
855
856    #[test]
857    fn test_mul_with_one_folds() {
858        let x = TracedTensor::placeholder(vec![3], 0);
859        let one = TracedTensor::constant(Array1::ones(3));
860
861        // x * 1 = x
862        let result = try_identity_fold(&x, &one, OpType::Mul);
863        assert!(result.is_some());
864        assert!(!result.expect("operation should succeed").is_constant());
865
866        // 1 * x = x
867        let result = try_identity_fold(&one, &x, OpType::Mul);
868        assert!(result.is_some());
869        assert!(!result.expect("operation should succeed").is_constant());
870    }
871
872    #[test]
873    fn test_mul_with_zero_annihilates() {
874        let x = TracedTensor::placeholder(vec![3], 0);
875        let zero = TracedTensor::constant(Array1::zeros(3));
876
877        // x * 0 = 0
878        let result = try_identity_fold(&x, &zero, OpType::Mul);
879        assert!(result.is_some());
880        let t = result.expect("operation should succeed");
881        assert!(t.is_constant());
882        assert!(is_zeros(t.value().as_constant().expect("operation should succeed")));
883
884        // 0 * x = 0
885        let result = try_identity_fold(&zero, &x, OpType::Mul);
886        assert!(result.is_some());
887        let t = result.expect("operation should succeed");
888        assert!(t.is_constant());
889        assert!(is_zeros(t.value().as_constant().expect("operation should succeed")));
890    }
891
892    #[test]
893    fn test_no_identity_fold_for_nonidentity() {
894        let a = TracedTensor::constant(Array1::from(vec![2.0, 3.0]));
895        let b = TracedTensor::placeholder(vec![2], 0);
896
897        assert!(try_identity_fold(&a, &b, OpType::Add).is_none());
898        assert!(try_identity_fold(&a, &b, OpType::Mul).is_none());
899    }
900
901    // --- Traced binary op tests ---
902
903    #[test]
904    fn test_traced_binary_op_both_constant() {
905        let mut graph = ComputeGraph::new();
906        let a = TracedTensor::constant(Array1::from(vec![1.0, 2.0, 3.0]));
907        let b = TracedTensor::constant(Array1::from(vec![4.0, 5.0, 6.0]));
908
909        let result = traced_binary_op(&a, &b, |x, y| x + y, OpType::Add, &mut graph);
910        assert!(result.is_constant());
911        let data = result.value().as_constant().expect("operation should succeed");
912        assert_eq!(data.as_slice().expect("operation should succeed"), &[5.0, 7.0, 9.0]);
913        // No graph nodes created
914        assert_eq!(graph.len(), 0);
915    }
916
917    #[test]
918    fn test_traced_binary_op_one_dynamic() {
919        let mut graph = ComputeGraph::new();
920        let a = TracedTensor::placeholder(vec![3], graph.add_constant(Array1::from(vec![0.0; 3])));
921        let b = TracedTensor::constant(Array1::from(vec![4.0, 5.0, 6.0]));
922
923        let result = traced_binary_op(&a, &b, |x, y| x + y, OpType::Add, &mut graph);
924        // b is lifted to a constant node, then an add node is created
925        assert!(!result.is_constant());
926    }
927
928    #[test]
929    fn test_traced_binary_op_identity_fold() {
930        let mut graph = ComputeGraph::new();
931        let x_id = graph.add_constant(Array1::from(vec![1.0, 2.0]));
932        let x = TracedTensor::placeholder(vec![2], x_id);
933        let zero = TracedTensor::constant(Array1::zeros(2));
934
935        let result = traced_binary_op(&x, &zero, |a, b| a + b, OpType::Add, &mut graph);
936        // Should fold: x + 0 = x, no new node
937        assert!(!result.is_constant());
938        assert_eq!(result.value().node_id(), Some(x_id));
939    }
940
941    // --- ComputeGraph tests ---
942
943    #[test]
944    fn test_compute_graph_empty() {
945        let graph = ComputeGraph::new();
946        assert!(graph.is_empty());
947        assert_eq!(graph.len(), 0);
948        assert_eq!(graph.active_node_count(), 0);
949    }
950
951    #[test]
952    fn test_compute_graph_add_nodes() {
953        let mut graph = ComputeGraph::new();
954        let c1 = graph.add_constant(Array1::from(vec![1.0]));
955        let c2 = graph.add_constant(Array1::from(vec![2.0]));
956        let add = graph.add_op(OpType::Add, vec![c1, c2], vec![1]);
957
958        assert_eq!(graph.len(), 3);
959        assert_eq!(graph.active_node_count(), 3);
960        assert!(graph.node(c1).is_constant());
961        assert!(!graph.node(add).is_constant());
962    }
963
964    #[test]
965    fn test_compute_graph_topological_order() {
966        let mut graph = ComputeGraph::new();
967        let c1 = graph.add_constant(Array1::from(vec![1.0]));
968        let c2 = graph.add_constant(Array1::from(vec![2.0]));
969        let add = graph.add_op(OpType::Add, vec![c1, c2], vec![1]);
970        graph.mark_output(add);
971
972        let order = graph.topological_order();
973        // c1 and c2 should come before add
974        let add_pos = order.iter().position(|&x| x == add).expect("operation should succeed");
975        let c1_pos = order.iter().position(|&x| x == c1).expect("operation should succeed");
976        let c2_pos = order.iter().position(|&x| x == c2).expect("operation should succeed");
977        assert!(c1_pos < add_pos);
978        assert!(c2_pos < add_pos);
979    }
980
981    #[test]
982    fn test_compute_graph_replace_uses() {
983        let mut graph = ComputeGraph::new();
984        let c1 = graph.add_constant(Array1::from(vec![1.0]));
985        let c2 = graph.add_constant(Array1::from(vec![2.0]));
986        let add = graph.add_op(OpType::Add, vec![c1, c2], vec![1]);
987        graph.mark_output(add);
988
989        // Replace c1 with c2
990        let c3 = graph.add_constant(Array1::from(vec![3.0]));
991        graph.replace_uses(c1, c3);
992
993        assert_eq!(graph.node(add).input_ids, vec![c3, c2]);
994    }
995
996    // --- Constant folding pass tests ---
997
998    #[test]
999    fn test_constant_folding_add() {
1000        let mut graph = ComputeGraph::new();
1001        let c1 = graph.add_constant(Array1::from(vec![1.0, 2.0]));
1002        let c2 = graph.add_constant(Array1::from(vec![3.0, 4.0]));
1003        let add = graph.add_op(OpType::Add, vec![c1, c2], vec![2]);
1004        graph.mark_output(add);
1005
1006        let pass = ConstantFolding;
1007        let changes = pass.run(&mut graph);
1008
1009        assert_eq!(changes, 1);
1010        assert!(graph.node(add).is_constant());
1011        let result = graph.node(add).constant_value.as_ref().expect("operation should succeed");
1012        assert_eq!(result.as_slice().expect("operation should succeed"), &[4.0, 6.0]);
1013    }
1014
1015    #[test]
1016    fn test_constant_folding_mul() {
1017        let mut graph = ComputeGraph::new();
1018        let c1 = graph.add_constant(Array1::from(vec![2.0, 3.0]));
1019        let c2 = graph.add_constant(Array1::from(vec![4.0, 5.0]));
1020        let mul = graph.add_op(OpType::Mul, vec![c1, c2], vec![2]);
1021        graph.mark_output(mul);
1022
1023        let pass = ConstantFolding;
1024        let changes = pass.run(&mut graph);
1025
1026        assert_eq!(changes, 1);
1027        let result = graph.node(mul).constant_value.as_ref().expect("operation should succeed");
1028        assert_eq!(result.as_slice().expect("operation should succeed"), &[8.0, 15.0]);
1029    }
1030
1031    #[test]
1032    fn test_constant_folding_sum() {
1033        let mut graph = ComputeGraph::new();
1034        let c1 = graph.add_constant(Array1::from(vec![1.0, 2.0, 3.0]));
1035        let sum = graph.add_op(OpType::Sum, vec![c1], vec![1]);
1036        graph.mark_output(sum);
1037
1038        let pass = ConstantFolding;
1039        let changes = pass.run(&mut graph);
1040
1041        assert_eq!(changes, 1);
1042        let result = graph.node(sum).constant_value.as_ref().expect("operation should succeed");
1043        assert_eq!(result.as_slice().expect("operation should succeed"), &[6.0]);
1044    }
1045
1046    #[test]
1047    fn test_constant_folding_chain() {
1048        let mut graph = ComputeGraph::new();
1049        let c1 = graph.add_constant(Array1::from(vec![1.0, 2.0]));
1050        let c2 = graph.add_constant(Array1::from(vec![3.0, 4.0]));
1051        let add = graph.add_op(OpType::Add, vec![c1, c2], vec![2]);
1052        let c3 = graph.add_constant(Array1::from(vec![2.0, 2.0]));
1053        let mul = graph.add_op(OpType::Mul, vec![add, c3], vec![2]);
1054        graph.mark_output(mul);
1055
1056        let optimizer = GraphOptimizer::new();
1057        let report = optimizer.optimize(&mut graph);
1058
1059        // Both add and mul should be folded
1060        assert!(report.total_changes >= 2);
1061        assert!(graph.node(mul).is_constant());
1062        let result = graph.node(mul).constant_value.as_ref().expect("operation should succeed");
1063        assert_eq!(result.as_slice().expect("operation should succeed"), &[8.0, 12.0]);
1064    }
1065
1066    #[test]
1067    fn test_constant_folding_skips_dynamic() {
1068        let mut graph = ComputeGraph::new();
1069        let c1 = graph.add_constant(Array1::from(vec![1.0]));
1070        // Node 1 is "dynamic" (no constant value)
1071        let dyn_node = graph.add_op(OpType::Relu, vec![c1], vec![1]);
1072        let c2 = graph.add_constant(Array1::from(vec![2.0]));
1073        let add = graph.add_op(OpType::Add, vec![dyn_node, c2], vec![1]);
1074        graph.mark_output(add);
1075
1076        let pass = ConstantFolding;
1077        let changes = pass.run(&mut graph);
1078
1079        // ReLU is not foldable, so add can't be folded either
1080        assert_eq!(changes, 0);
1081    }
1082
1083    // --- Dead code elimination tests ---
1084
1085    #[test]
1086    fn test_dce_removes_unreachable() {
1087        let mut graph = ComputeGraph::new();
1088        let c1 = graph.add_constant(Array1::from(vec![1.0]));
1089        let c2 = graph.add_constant(Array1::from(vec![2.0]));
1090        let _dead = graph.add_op(OpType::Add, vec![c1, c2], vec![1]); // Dead
1091        let c3 = graph.add_constant(Array1::from(vec![3.0]));
1092        graph.mark_output(c3);
1093
1094        let pass = DeadCodeElimination;
1095        let changes = pass.run(&mut graph);
1096
1097        assert_eq!(changes, 3); // c1, c2, and _dead are unreachable
1098        assert!(graph.node(c1).is_removed());
1099        assert!(graph.node(c2).is_removed());
1100        assert!(!graph.node(c3).is_removed());
1101    }
1102
1103    #[test]
1104    fn test_dce_preserves_reachable() {
1105        let mut graph = ComputeGraph::new();
1106        let c1 = graph.add_constant(Array1::from(vec![1.0]));
1107        let c2 = graph.add_constant(Array1::from(vec![2.0]));
1108        let add = graph.add_op(OpType::Add, vec![c1, c2], vec![1]);
1109        graph.mark_output(add);
1110
1111        let pass = DeadCodeElimination;
1112        let changes = pass.run(&mut graph);
1113
1114        assert_eq!(changes, 0); // Everything is reachable
1115    }
1116
1117    // --- CSE tests ---
1118
1119    #[test]
1120    fn test_cse_deduplicates() {
1121        let mut graph = ComputeGraph::new();
1122        let c1 = graph.add_constant(Array1::from(vec![1.0]));
1123        let c2 = graph.add_constant(Array1::from(vec![2.0]));
1124        let add1 = graph.add_op(OpType::Add, vec![c1, c2], vec![1]);
1125        let add2 = graph.add_op(OpType::Add, vec![c1, c2], vec![1]); // Duplicate
1126        let mul = graph.add_op(OpType::Mul, vec![add1, add2], vec![1]);
1127        graph.mark_output(mul);
1128
1129        let pass = CommonSubexprElimination;
1130        let changes = pass.run(&mut graph);
1131
1132        assert_eq!(changes, 1); // add2 eliminated
1133        assert!(graph.node(add2).is_removed());
1134        // mul should now reference add1 for both inputs
1135        assert_eq!(graph.node(mul).input_ids, vec![add1, add1]);
1136    }
1137
1138    #[test]
1139    fn test_cse_no_false_positive() {
1140        let mut graph = ComputeGraph::new();
1141        let c1 = graph.add_constant(Array1::from(vec![1.0]));
1142        let c2 = graph.add_constant(Array1::from(vec![2.0]));
1143        let c3 = graph.add_constant(Array1::from(vec![3.0]));
1144        let add1 = graph.add_op(OpType::Add, vec![c1, c2], vec![1]);
1145        let add2 = graph.add_op(OpType::Add, vec![c1, c3], vec![1]); // Different inputs
1146        let mul = graph.add_op(OpType::Mul, vec![add1, add2], vec![1]);
1147        graph.mark_output(mul);
1148
1149        let pass = CommonSubexprElimination;
1150        let changes = pass.run(&mut graph);
1151
1152        assert_eq!(changes, 0); // Different expressions, no dedup
1153    }
1154
1155    // --- GraphOptimizer tests ---
1156
1157    #[test]
1158    fn test_optimizer_full_pipeline() {
1159        let mut graph = ComputeGraph::new();
1160
1161        // Build: (a + b) * (a + b) where a, b are constants
1162        // Should fold to a single constant
1163        let a = graph.add_constant(Array1::from(vec![1.0, 2.0]));
1164        let b = graph.add_constant(Array1::from(vec![3.0, 4.0]));
1165        let add1 = graph.add_op(OpType::Add, vec![a, b], vec![2]);
1166        let add2 = graph.add_op(OpType::Add, vec![a, b], vec![2]); // Duplicate
1167        let mul = graph.add_op(OpType::Mul, vec![add1, add2], vec![2]);
1168        graph.mark_output(mul);
1169
1170        let optimizer = GraphOptimizer::new();
1171        let report = optimizer.optimize(&mut graph);
1172
1173        assert!(report.total_changes > 0);
1174        assert!(report.final_nodes < report.initial_nodes);
1175    }
1176
1177    #[test]
1178    fn test_optimizer_report_reduction_ratio() {
1179        let report = OptimizationReport {
1180            iterations: 1,
1181            total_changes: 5,
1182            pass_changes: HashMap::new(),
1183            initial_nodes: 10,
1184            final_nodes: 5,
1185        };
1186        assert!((report.reduction_ratio() - 0.5).abs() < f64::EPSILON);
1187    }
1188
1189    #[test]
1190    fn test_optimizer_report_empty_graph() {
1191        let report = OptimizationReport {
1192            iterations: 0,
1193            total_changes: 0,
1194            pass_changes: HashMap::new(),
1195            initial_nodes: 0,
1196            final_nodes: 0,
1197        };
1198        assert!((report.reduction_ratio() - 0.0).abs() < f64::EPSILON);
1199    }
1200
1201    #[test]
1202    fn test_optimizer_max_iterations() {
1203        let optimizer = GraphOptimizer::new().with_max_iterations(1);
1204        let mut graph = ComputeGraph::new();
1205        let c1 = graph.add_constant(Array1::from(vec![1.0]));
1206        graph.mark_output(c1);
1207
1208        let report = optimizer.optimize(&mut graph);
1209        assert!(report.iterations <= 1);
1210    }
1211
1212    // --- ShapeTracker tests ---
1213
1214    #[test]
1215    fn test_shape_tracker_register_and_get() {
1216        let mut tracker = ShapeTracker::new();
1217        tracker.register(0, vec![3, 4]);
1218        assert_eq!(tracker.get(0), Some(&[3, 4][..]));
1219        assert_eq!(tracker.get(1), None);
1220    }
1221
1222    #[test]
1223    fn test_shape_tracker_elementwise() {
1224        let mut tracker = ShapeTracker::new();
1225        tracker.register(0, vec![5]);
1226        tracker.register(1, vec![5]);
1227
1228        let result = tracker.infer_elementwise(2, 0, 1);
1229        assert!(result.is_ok());
1230        assert_eq!(result.expect("operation should succeed"), vec![5]);
1231        assert_eq!(tracker.get(2), Some(&[5][..]));
1232    }
1233
1234    #[test]
1235    fn test_shape_tracker_elementwise_mismatch() {
1236        let mut tracker = ShapeTracker::new();
1237        tracker.register(0, vec![3]);
1238        tracker.register(1, vec![5]);
1239
1240        let result = tracker.infer_elementwise(2, 0, 1);
1241        assert!(result.is_err());
1242        match result.unwrap_err() {
1243            ShapeError::DimMismatch { .. } => {}
1244            other => panic!("expected DimMismatch, got {other:?}"),
1245        }
1246    }
1247
1248    #[test]
1249    fn test_shape_tracker_matmul() {
1250        let mut tracker = ShapeTracker::new();
1251        tracker.register(0, vec![3, 4]);
1252        tracker.register(1, vec![4, 5]);
1253
1254        let result = tracker.infer_matmul(2, 0, 1);
1255        assert!(result.is_ok());
1256        assert_eq!(result.expect("operation should succeed"), vec![3, 5]);
1257    }
1258
1259    #[test]
1260    fn test_shape_tracker_matmul_mismatch() {
1261        let mut tracker = ShapeTracker::new();
1262        tracker.register(0, vec![3, 4]);
1263        tracker.register(1, vec![5, 6]);
1264
1265        let result = tracker.infer_matmul(2, 0, 1);
1266        assert!(result.is_err());
1267    }
1268
1269    #[test]
1270    fn test_shape_tracker_matmul_insufficient_dims() {
1271        let mut tracker = ShapeTracker::new();
1272        tracker.register(0, vec![4]);
1273        tracker.register(1, vec![4, 5]);
1274
1275        let result = tracker.infer_matmul(2, 0, 1);
1276        assert!(result.is_err());
1277        match result.unwrap_err() {
1278            ShapeError::InsufficientDims { required: 2, got: 1 } => {}
1279            other => panic!("expected InsufficientDims, got {other:?}"),
1280        }
1281    }
1282
1283    #[test]
1284    fn test_shape_tracker_sum() {
1285        let mut tracker = ShapeTracker::new();
1286        tracker.register(0, vec![10]);
1287
1288        let result = tracker.infer_sum(1, 0);
1289        assert!(result.is_ok());
1290        assert_eq!(result.expect("operation should succeed"), vec![1]);
1291    }
1292
1293    #[test]
1294    fn test_shape_tracker_unknown_input() {
1295        let mut tracker = ShapeTracker::new();
1296        let result = tracker.infer_sum(1, 99);
1297        assert!(result.is_err());
1298        match result.unwrap_err() {
1299            ShapeError::UnknownInput(99) => {}
1300            other => panic!("expected UnknownInput(99), got {other:?}"),
1301        }
1302    }
1303
1304    #[test]
1305    fn test_shape_tracker_len() {
1306        let mut tracker = ShapeTracker::new();
1307        assert!(tracker.is_empty());
1308        assert_eq!(tracker.len(), 0);
1309
1310        tracker.register(0, vec![3]);
1311        assert!(!tracker.is_empty());
1312        assert_eq!(tracker.len(), 1);
1313    }
1314
1315    // --- Helper function tests ---
1316
1317    #[test]
1318    fn test_is_zeros() {
1319        assert!(is_zeros(&Array1::zeros(5)));
1320        assert!(!is_zeros(&Array1::ones(5)));
1321        assert!(!is_zeros(&Array1::from(vec![0.0, 0.0, 1.0])));
1322        assert!(is_zeros(&Array1::from(vec![])));
1323    }
1324
1325    #[test]
1326    fn test_is_ones() {
1327        assert!(is_ones(&Array1::ones(5)));
1328        assert!(!is_ones(&Array1::zeros(5)));
1329        assert!(!is_ones(&Array1::from(vec![1.0, 1.0, 2.0])));
1330        assert!(is_ones(&Array1::from(vec![])));
1331    }
1332
1333    // --- ShapeError Display tests ---
1334
1335    #[test]
1336    fn test_shape_error_display() {
1337        let err = ShapeError::UnknownInput(42);
1338        assert_eq!(format!("{err}"), "unknown input node 42");
1339
1340        let err = ShapeError::DimMismatch { expected: 3, got: 5 };
1341        assert_eq!(format!("{err}"), "dimension mismatch: expected 3, got 5");
1342
1343        let err = ShapeError::InsufficientDims { required: 2, got: 1 };
1344        assert_eq!(format!("{err}"), "insufficient dims: need 2, have 1");
1345    }
1346
1347    // --- GraphNode tests ---
1348
1349    #[test]
1350    fn test_graph_node_mark_removed() {
1351        let mut node = GraphNode {
1352            id: 0,
1353            op_type: OpType::Add,
1354            input_ids: vec![],
1355            output_shape: vec![1],
1356            constant_value: None,
1357            removed: false,
1358        };
1359        assert!(!node.is_removed());
1360        node.mark_removed();
1361        assert!(node.is_removed());
1362    }
1363
1364    // --- OpType tests for match arm coverage ---
1365
1366    #[test]
1367    fn test_op_type_variants() {
1368        let ops = [
1369            OpType::Add,
1370            OpType::Mul,
1371            OpType::Scale,
1372            OpType::Sum,
1373            OpType::Matmul,
1374            OpType::Relu,
1375            OpType::Gelu,
1376            OpType::Softmax,
1377            OpType::LayerNorm,
1378            OpType::Attention,
1379            OpType::Constant,
1380        ];
1381
1382        for op in &ops {
1383            match op {
1384                OpType::Add => assert_eq!(*op, OpType::Add),
1385                OpType::Mul => assert_eq!(*op, OpType::Mul),
1386                OpType::Scale => assert_eq!(*op, OpType::Scale),
1387                OpType::Sum => assert_eq!(*op, OpType::Sum),
1388                OpType::Matmul => assert_eq!(*op, OpType::Matmul),
1389                OpType::Relu => assert_eq!(*op, OpType::Relu),
1390                OpType::Gelu => assert_eq!(*op, OpType::Gelu),
1391                OpType::Softmax => assert_eq!(*op, OpType::Softmax),
1392                OpType::LayerNorm => assert_eq!(*op, OpType::LayerNorm),
1393                OpType::Attention => assert_eq!(*op, OpType::Attention),
1394                OpType::Constant => assert_eq!(*op, OpType::Constant),
1395            }
1396        }
1397    }
1398
1399    // --- Integration: realistic graph optimization scenario ---
1400
1401    #[test]
1402    fn test_mlp_init_with_zero_bias() {
1403        // Simulating: output = (input * weights) + bias where bias = 0
1404        // The bias addition should be eliminated by constant folding + identity fold
1405        let mut graph = ComputeGraph::new();
1406
1407        // Input (dynamic) and weights (dynamic) — represented as placeholders
1408        let input = graph.add_op(OpType::Relu, vec![], vec![4]); // Simulating dynamic input
1409        let weights = graph.add_constant(Array1::from(vec![0.5; 4]));
1410        let matmul = graph.add_op(OpType::Mul, vec![input, weights], vec![4]);
1411
1412        // Bias = 0 (constant)
1413        let bias = graph.add_constant(Array1::zeros(4));
1414        let output = graph.add_op(OpType::Add, vec![matmul, bias], vec![4]);
1415        graph.mark_output(output);
1416
1417        let initial_active = graph.active_node_count();
1418
1419        let optimizer = GraphOptimizer::new();
1420        let report = optimizer.optimize(&mut graph);
1421
1422        // The bias addition can't be fully eliminated since matmul is dynamic,
1423        // but DCE should handle any dead nodes
1424        assert!(report.iterations > 0);
1425        assert!(graph.active_node_count() <= initial_active);
1426    }
1427
1428    #[test]
1429    fn test_repeated_subexpression_elimination() {
1430        // Build: z = (a + b) * (a + b) with CSE
1431        let mut graph = ComputeGraph::new();
1432        let a = graph.add_constant(Array1::from(vec![1.0, 2.0]));
1433        let b = graph.add_constant(Array1::from(vec![3.0, 4.0]));
1434        let add1 = graph.add_op(OpType::Add, vec![a, b], vec![2]);
1435        let add2 = graph.add_op(OpType::Add, vec![a, b], vec![2]);
1436        let mul = graph.add_op(OpType::Mul, vec![add1, add2], vec![2]);
1437        let sum = graph.add_op(OpType::Sum, vec![mul], vec![1]);
1438        graph.mark_output(sum);
1439
1440        let optimizer = GraphOptimizer::new();
1441        let report = optimizer.optimize(&mut graph);
1442
1443        // CSE should eliminate add2, constant folding should fold everything
1444        assert!(report.total_changes > 0);
1445    }
1446
1447    // --- Default impls ---
1448
1449    #[test]
1450    fn test_compute_graph_default() {
1451        let graph = ComputeGraph::default();
1452        assert!(graph.is_empty());
1453    }
1454
1455    #[test]
1456    fn test_shape_tracker_default() {
1457        let tracker = ShapeTracker::default();
1458        assert!(tracker.is_empty());
1459    }
1460
1461    #[test]
1462    fn test_graph_optimizer_default() {
1463        let optimizer = GraphOptimizer::default();
1464        let mut graph = ComputeGraph::new();
1465        let c = graph.add_constant(Array1::from(vec![1.0]));
1466        graph.mark_output(c);
1467        let report = optimizer.optimize(&mut graph);
1468        assert_eq!(report.iterations, 1); // One pass, no changes, done
1469    }
1470}