// axonml_jit — crates/axonml-jit/src/optimize.rs
1//! Graph Optimization — Six-Pass IR Transformation Pipeline
2//!
3//! Implements `Optimizer`, the driver that runs an ordered list of
4//! `OptimizationPass` transformations over a `Graph`. `OptimizationPass`
5//! enumerates ConstantFolding, DeadCodeElimination, ElementwiseFusion,
6//! CommonSubexpressionElimination, AlgebraicSimplification, and
7//! StrengthReduction. `default_passes` seeds the pipeline with folding, algebra,
8//! DCE, and CSE. Each pass is a standalone `fn(Graph) -> Graph`:
9//! `constant_folding` rewrites `x * 1 -> x`, `x * 0 -> Constant(0)`, and
10//! `x + 0 -> x` while tracking known-constant node values;
11//! `dead_code_elimination` walks backward from the registered outputs to build
12//! a `FxHashSet` of live nodes and rebuilds the graph keeping only those;
13//! `elementwise_fusion` is a no-op stub; `cse` uses a debug-formatted op string
14//! as the hash key in an `FxHashMap<String, NodeId>` to deduplicate structurally
15//! identical subexpressions; `algebraic_simplification` collapses `x * 1`,
16//! `x + 0`, and double negation (`--x -> x`); `strength_reduction` is a
17//! pass-through placeholder for Pow/Div cheapening. All passes share `remap_op`,
18//! a dense helper that rewrites every `Op` variant's `NodeId` references through
19//! a `FxHashMap` produced during graph rebuilding. Tests verify DCE removes an
20//! unused Mul, algebraic simplification eliminates `MulScalar(1.0)`, and
21//! constant folding materializes a Constant for `MulScalar(0.0)`.
22//!
23//! # File
24//! `crates/axonml-jit/src/optimize.rs`
25//!
26//! # Author
27//! Andrew Jewell Sr. — AutomataNexus LLC
28//! ORCID: 0009-0005-2158-7060
29//!
30//! # Updated
31//! April 16, 2026 11:15 PM EST
32//!
33//! # Disclaimer
34//! Use at own risk. This software is provided "as is", without warranty of any
35//! kind, express or implied. The author and AutomataNexus shall not be held
36//! liable for any damages arising from the use of this software.
37
38// =============================================================================
39// Imports
40// =============================================================================
41
use crate::ir::{Graph, NodeId, Op};
use rustc_hash::{FxHashMap, FxHashSet};
44
45// =============================================================================
46// OptimizationPass
47// =============================================================================
48
/// Optimization passes available.
///
/// Passes carry no configuration of their own; the [`Optimizer`] applies
/// them in the order they were added via `add_pass` / `default_passes`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationPass {
    /// Fold constant expressions (e.g. `x * 0`, `x * 1`, `x + 0`).
    ConstantFolding,
    /// Remove dead (unused) code that does not reach any registered output.
    DeadCodeElimination,
    /// Fuse consecutive elementwise operations. Currently a no-op stub.
    ElementwiseFusion,
    /// Common subexpression elimination (structural deduplication).
    CommonSubexpressionElimination,
    /// Algebraic simplifications (x * 1 = x, x + 0 = x, --x = x, etc).
    AlgebraicSimplification,
    /// Strength reduction (expensive ops -> cheaper ops). Currently a stub.
    StrengthReduction,
}
65
66// =============================================================================
67// Optimizer Driver
68// =============================================================================
69
/// Graph optimizer.
///
/// Holds an ordered list of passes; [`Optimizer::optimize`] applies them
/// sequentially, threading the rewritten graph from one pass to the next.
pub struct Optimizer {
    // Passes run in insertion order; duplicates are allowed and run twice.
    passes: Vec<OptimizationPass>,
}
74
75impl Optimizer {
76    /// Creates a new optimizer with no passes.
77    pub fn new() -> Self {
78        Self { passes: Vec::new() }
79    }
80
81    /// Creates an optimizer with default passes.
82    pub fn default_passes() -> Self {
83        Self {
84            passes: vec![
85                OptimizationPass::ConstantFolding,
86                OptimizationPass::AlgebraicSimplification,
87                OptimizationPass::DeadCodeElimination,
88                OptimizationPass::CommonSubexpressionElimination,
89            ],
90        }
91    }
92
93    /// Adds an optimization pass.
94    pub fn add_pass(&mut self, pass: OptimizationPass) {
95        self.passes.push(pass);
96    }
97
98    /// Runs all optimization passes on the graph.
99    pub fn optimize(&self, mut graph: Graph) -> Graph {
100        for pass in &self.passes {
101            graph = self.run_pass(graph, *pass);
102        }
103        graph
104    }
105
106    fn run_pass(&self, graph: Graph, pass: OptimizationPass) -> Graph {
107        match pass {
108            OptimizationPass::ConstantFolding => constant_folding(graph),
109            OptimizationPass::DeadCodeElimination => dead_code_elimination(graph),
110            OptimizationPass::ElementwiseFusion => elementwise_fusion(graph),
111            OptimizationPass::CommonSubexpressionElimination => cse(graph),
112            OptimizationPass::AlgebraicSimplification => algebraic_simplification(graph),
113            OptimizationPass::StrengthReduction => strength_reduction(graph),
114        }
115    }
116}
117
// `Default` mirrors `default_passes` so `Optimizer::default()` yields the
// standard pipeline rather than an empty (no-op) optimizer.
impl Default for Optimizer {
    fn default() -> Self {
        Self::default_passes()
    }
}
123
124// =============================================================================
125// Optimizer Passes
126// =============================================================================
127
128// -----------------------------------------------------------------------------
129// Constant Folding
130// -----------------------------------------------------------------------------
131
132/// Constant folding: evaluate constant expressions at compile time.
133fn constant_folding(graph: Graph) -> Graph {
134    // For now, just identify constant nodes
135    // Full implementation would evaluate constant subgraphs
136    let mut new_graph = Graph::new();
137    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
138    let mut constants: FxHashMap<NodeId, f64> = FxHashMap::default();
139
140    for node in graph.nodes() {
141        // Track constant values
142        if let Op::Constant { value } = &node.op {
143            constants.insert(node.id, *value);
144        }
145
146        // Try to fold binary ops with constants
147        let new_op = match &node.op {
148            Op::MulScalar { input, scalar } if *scalar == 1.0 => {
149                // x * 1 = x
150                let new_input = node_map.get(input).copied().unwrap_or(*input);
151                node_map.insert(node.id, new_input);
152                continue;
153            }
154            Op::MulScalar { input: _, scalar } if *scalar == 0.0 => {
155                // x * 0 = 0
156                Op::Constant { value: 0.0 }
157            }
158            Op::AddScalar { input, scalar } if *scalar == 0.0 => {
159                // x + 0 = x
160                let new_input = node_map.get(input).copied().unwrap_or(*input);
161                node_map.insert(node.id, new_input);
162                continue;
163            }
164            other => remap_op(other, &node_map),
165        };
166
167        let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
168        node_map.insert(node.id, new_id);
169    }
170
171    // Remap inputs and outputs
172    for (name, id) in graph.inputs() {
173        if let Some(&new_id) = node_map.get(id) {
174            new_graph.register_input(name, new_id);
175        }
176    }
177    for (name, id) in graph.outputs() {
178        if let Some(&new_id) = node_map.get(id) {
179            new_graph.register_output(name, new_id);
180        }
181    }
182
183    new_graph
184}
185
186// -----------------------------------------------------------------------------
187// Dead Code Elimination
188// -----------------------------------------------------------------------------
189
190/// Dead code elimination: remove nodes that don't contribute to outputs.
191fn dead_code_elimination(graph: Graph) -> Graph {
192    // Find all nodes reachable from outputs
193    let mut live_nodes: FxHashSet<NodeId> = FxHashSet::default();
194    let mut worklist: Vec<NodeId> = graph.outputs().values().copied().collect();
195
196    while let Some(id) = worklist.pop() {
197        if live_nodes.insert(id) {
198            let node = graph.node(id);
199            for input_id in node.op.inputs() {
200                worklist.push(input_id);
201            }
202        }
203    }
204
205    // Rebuild graph with only live nodes
206    let mut new_graph = Graph::new();
207    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
208
209    for node in graph.nodes() {
210        if !live_nodes.contains(&node.id) {
211            continue;
212        }
213
214        let new_op = remap_op(&node.op, &node_map);
215        let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
216        node_map.insert(node.id, new_id);
217    }
218
219    // Remap inputs and outputs
220    for (name, id) in graph.inputs() {
221        if let Some(&new_id) = node_map.get(id) {
222            new_graph.register_input(name, new_id);
223        }
224    }
225    for (name, id) in graph.outputs() {
226        if let Some(&new_id) = node_map.get(id) {
227            new_graph.register_output(name, new_id);
228        }
229    }
230
231    new_graph
232}
233
234// -----------------------------------------------------------------------------
235// Elementwise Fusion
236// -----------------------------------------------------------------------------
237
/// Elementwise fusion: combine consecutive elementwise ops into kernels.
///
/// Currently a pass-through stub: the graph is returned unchanged. A full
/// implementation would identify fusible elementwise sequences and replace
/// them with fused kernel nodes.
fn elementwise_fusion(graph: Graph) -> Graph {
    // For now, just return the graph unchanged
    // Full implementation would identify fusible sequences
    // and create FusedElementwise nodes
    graph
}
245
246// -----------------------------------------------------------------------------
247// Common Subexpression Elimination
248// -----------------------------------------------------------------------------
249
250/// Common subexpression elimination.
251fn cse(graph: Graph) -> Graph {
252    // Hash-based CSE
253    let mut new_graph = Graph::new();
254    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
255    let mut expr_map: FxHashMap<String, NodeId> = FxHashMap::default();
256
257    for node in graph.nodes() {
258        let remapped_op = remap_op(&node.op, &node_map);
259        let expr_key = format!("{:?}", remapped_op);
260
261        if let Some(&existing_id) = expr_map.get(&expr_key) {
262            // Reuse existing node
263            node_map.insert(node.id, existing_id);
264        } else {
265            let new_id = new_graph.add_node(remapped_op, node.dtype, node.shape.clone());
266            node_map.insert(node.id, new_id);
267            expr_map.insert(expr_key, new_id);
268        }
269    }
270
271    // Remap inputs and outputs
272    for (name, id) in graph.inputs() {
273        if let Some(&new_id) = node_map.get(id) {
274            new_graph.register_input(name, new_id);
275        }
276    }
277    for (name, id) in graph.outputs() {
278        if let Some(&new_id) = node_map.get(id) {
279            new_graph.register_output(name, new_id);
280        }
281    }
282
283    new_graph
284}
285
286// -----------------------------------------------------------------------------
287// Algebraic Simplification
288// -----------------------------------------------------------------------------
289
/// Algebraic simplifications.
///
/// Rewrites while rebuilding the graph in node order:
/// * `x * 1 -> x` and `x + 0 -> x` — the node is forwarded to its input
///   and no new node is emitted;
/// * `--x -> x` — a `Neg` whose (already-rebuilt) input is itself a `Neg`
///   forwards to the inner operand.
///
/// Relies on nodes being visited in an order where a node's inputs have
/// already been processed, so `node_map` lookups resolve to new-graph ids.
fn algebraic_simplification(graph: Graph) -> Graph {
    let mut new_graph = Graph::new();
    // Old node id -> id in the rebuilt graph (or forwarded replacement).
    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();

    for node in graph.nodes() {
        let simplified_op = match &node.op {
            // x * 1 = x
            Op::MulScalar { input, scalar } if *scalar == 1.0 => {
                let new_input = node_map.get(input).copied().unwrap_or(*input);
                node_map.insert(node.id, new_input);
                continue;
            }
            // x + 0 = x
            Op::AddScalar { input, scalar } if *scalar == 0.0 => {
                let new_input = node_map.get(input).copied().unwrap_or(*input);
                node_map.insert(node.id, new_input);
                continue;
            }
            // x - 0 = x (via AddScalar with -0)
            // x / 1 = x (via MulScalar with 1)
            // --x = x
            Op::Neg { input } => {
                // `actual_input` is the input's id in the NEW graph, so the
                // inner-Neg check below inspects the already-simplified node.
                let actual_input = node_map.get(input).copied().unwrap_or(*input);
                // NOTE(review): linear scan of the new graph per Neg node —
                // O(n) lookup; an id-indexed accessor would make this O(1).
                if let Some(input_node) = new_graph.nodes().iter().find(|n| n.id == actual_input) {
                    if let Op::Neg { input: inner } = &input_node.op {
                        // `inner` is already a new-graph id, safe to forward.
                        node_map.insert(node.id, *inner);
                        continue;
                    }
                }
                Op::Neg {
                    input: actual_input,
                }
            }
            other => remap_op(other, &node_map),
        };

        let new_id = new_graph.add_node(simplified_op, node.dtype, node.shape.clone());
        node_map.insert(node.id, new_id);
    }

    // Remap inputs and outputs
    for (name, id) in graph.inputs() {
        if let Some(&new_id) = node_map.get(id) {
            new_graph.register_input(name, new_id);
        }
    }
    for (name, id) in graph.outputs() {
        if let Some(&new_id) = node_map.get(id) {
            new_graph.register_output(name, new_id);
        }
    }

    new_graph
}
345
346// -----------------------------------------------------------------------------
347// Strength Reduction
348// -----------------------------------------------------------------------------
349
350/// Strength reduction: replace expensive ops with cheaper equivalents.
351fn strength_reduction(graph: Graph) -> Graph {
352    let mut new_graph = Graph::new();
353    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
354
355    for node in graph.nodes() {
356        let reduced_op = match &node.op {
357            // x^2 -> x * x
358            Op::Pow { .. } => {
359                // Check if exp is constant 2
360                // For now, just pass through
361                remap_op(&node.op, &node_map)
362            }
363            // x / c -> x * (1/c) for constant c
364            Op::Div { .. } => {
365                // Would need to check if rhs is constant
366                remap_op(&node.op, &node_map)
367            }
368            other => remap_op(other, &node_map),
369        };
370
371        let new_id = new_graph.add_node(reduced_op, node.dtype, node.shape.clone());
372        node_map.insert(node.id, new_id);
373    }
374
375    // Remap inputs and outputs
376    for (name, id) in graph.inputs() {
377        if let Some(&new_id) = node_map.get(id) {
378            new_graph.register_input(name, new_id);
379        }
380    }
381    for (name, id) in graph.outputs() {
382        if let Some(&new_id) = node_map.get(id) {
383            new_graph.register_output(name, new_id);
384        }
385    }
386
387    new_graph
388}
389
390// =============================================================================
391// Helpers
392// =============================================================================
393
/// Remaps node IDs in an operation using the provided mapping.
///
/// Produces a copy of `op` in which every `NodeId` operand has been looked
/// up through `node_map`. Ids absent from the map pass through unchanged —
/// the rebuild loops above rely on this fallback while a graph is only
/// partially rebuilt. Non-id payloads (names, scalars, shapes, axes,
/// dtypes) are copied verbatim.
fn remap_op(op: &Op, node_map: &FxHashMap<NodeId, NodeId>) -> Op {
    // Identity fallback: unmapped ids are kept as-is.
    let remap = |id: &NodeId| node_map.get(id).copied().unwrap_or(*id);

    match op {
        // Graph boundary ops and literals.
        Op::Input { name } => Op::Input { name: name.clone() },
        Op::Output { name, input } => Op::Output {
            name: name.clone(),
            input: remap(input),
        },
        Op::Constant { value } => Op::Constant { value: *value },

        // Binary elementwise arithmetic.
        Op::Add { lhs, rhs } => Op::Add {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Sub { lhs, rhs } => Op::Sub {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Mul { lhs, rhs } => Op::Mul {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Div { lhs, rhs } => Op::Div {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Pow { base, exp } => Op::Pow {
            base: remap(base),
            exp: remap(exp),
        },
        Op::Max { lhs, rhs } => Op::Max {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Min { lhs, rhs } => Op::Min {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },

        // Unary math.
        Op::Neg { input } => Op::Neg {
            input: remap(input),
        },
        Op::Abs { input } => Op::Abs {
            input: remap(input),
        },
        Op::Sqrt { input } => Op::Sqrt {
            input: remap(input),
        },
        Op::Exp { input } => Op::Exp {
            input: remap(input),
        },
        Op::Log { input } => Op::Log {
            input: remap(input),
        },
        Op::Sin { input } => Op::Sin {
            input: remap(input),
        },
        Op::Cos { input } => Op::Cos {
            input: remap(input),
        },
        Op::Tanh { input } => Op::Tanh {
            input: remap(input),
        },

        // Activations.
        Op::Relu { input } => Op::Relu {
            input: remap(input),
        },
        Op::Sigmoid { input } => Op::Sigmoid {
            input: remap(input),
        },
        Op::Gelu { input } => Op::Gelu {
            input: remap(input),
        },
        Op::Silu { input } => Op::Silu {
            input: remap(input),
        },

        // Scalar broadcast ops (scalar operand copied verbatim).
        Op::AddScalar { input, scalar } => Op::AddScalar {
            input: remap(input),
            scalar: *scalar,
        },
        Op::MulScalar { input, scalar } => Op::MulScalar {
            input: remap(input),
            scalar: *scalar,
        },

        // Reductions.
        Op::Sum { input } => Op::Sum {
            input: remap(input),
        },
        Op::SumAxis {
            input,
            axis,
            keepdim,
        } => Op::SumAxis {
            input: remap(input),
            axis: *axis,
            keepdim: *keepdim,
        },
        Op::Mean { input } => Op::Mean {
            input: remap(input),
        },
        Op::MeanAxis {
            input,
            axis,
            keepdim,
        } => Op::MeanAxis {
            input: remap(input),
            axis: *axis,
            keepdim: *keepdim,
        },
        Op::MaxAxis {
            input,
            axis,
            keepdim,
        } => Op::MaxAxis {
            input: remap(input),
            axis: *axis,
            keepdim: *keepdim,
        },

        // Shape manipulation.
        Op::Reshape { input, shape } => Op::Reshape {
            input: remap(input),
            shape: shape.clone(),
        },
        Op::Transpose { input, dim0, dim1 } => Op::Transpose {
            input: remap(input),
            dim0: *dim0,
            dim1: *dim1,
        },
        Op::Squeeze { input, dim } => Op::Squeeze {
            input: remap(input),
            dim: *dim,
        },
        Op::Unsqueeze { input, dim } => Op::Unsqueeze {
            input: remap(input),
            dim: *dim,
        },
        Op::Broadcast { input, shape } => Op::Broadcast {
            input: remap(input),
            shape: shape.clone(),
        },

        // Linear algebra.
        Op::MatMul { lhs, rhs } => Op::MatMul {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },

        // Comparisons.
        Op::Gt { lhs, rhs } => Op::Gt {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Lt { lhs, rhs } => Op::Lt {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Eq { lhs, rhs } => Op::Eq {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },

        // Selection.
        Op::Where { condition, x, y } => Op::Where {
            condition: remap(condition),
            x: remap(x),
            y: remap(y),
        },

        // Dtype and layout.
        Op::Cast { input, dtype } => Op::Cast {
            input: remap(input),
            dtype: *dtype,
        },
        Op::Contiguous { input } => Op::Contiguous {
            input: remap(input),
        },
    }
}
571
572// =============================================================================
573// Tests
574// =============================================================================
575
576#[cfg(test)]
577mod tests {
578    use super::*;
579    use crate::trace::trace;
580
581    #[test]
582    fn test_dead_code_elimination() {
583        let graph = trace(|tracer| {
584            let a = tracer.input("a", &[2, 3]);
585            let b = tracer.input("b", &[2, 3]);
586            let _unused = a.mul(&b); // This should be eliminated
587            let c = a.add(&b);
588            tracer.output("result", c)
589        });
590
591        let optimizer = Optimizer::new();
592        let mut opt = optimizer;
593        opt.add_pass(OptimizationPass::DeadCodeElimination);
594        let optimized = opt.optimize(graph);
595
596        // Mul node should be eliminated
597        let has_mul = optimized
598            .nodes()
599            .iter()
600            .any(|n| matches!(n.op, Op::Mul { .. }));
601        assert!(!has_mul);
602    }
603
604    #[test]
605    fn test_algebraic_simplification() {
606        let graph = trace(|tracer| {
607            let x = tracer.input("x", &[2, 3]);
608            let y = x.mul_scalar(1.0); // Should be simplified to x
609            tracer.output("y", y)
610        });
611
612        let mut optimizer = Optimizer::new();
613        optimizer.add_pass(OptimizationPass::AlgebraicSimplification);
614        let optimized = optimizer.optimize(graph);
615
616        // MulScalar(1.0) should be eliminated
617        let has_mul_scalar = optimized
618            .nodes()
619            .iter()
620            .any(|n| matches!(n.op, Op::MulScalar { .. }));
621        assert!(!has_mul_scalar);
622    }
623
624    #[test]
625    fn test_constant_folding() {
626        let graph = trace(|tracer| {
627            let x = tracer.input("x", &[2, 3]);
628            let y = x.mul_scalar(0.0); // Should become constant 0
629            tracer.output("y", y)
630        });
631
632        let mut optimizer = Optimizer::new();
633        optimizer.add_pass(OptimizationPass::ConstantFolding);
634        let optimized = optimizer.optimize(graph);
635
636        // Should have a Constant node
637        let has_constant = optimized
638            .nodes()
639            .iter()
640            .any(|n| matches!(n.op, Op::Constant { .. }));
641        assert!(has_constant);
642    }
643}