1//! Graph Optimization
2//!
3//! # File
4//! `crates/axonml-jit/src/optimize.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use crate::ir::{Graph, NodeId, Op};
18use rustc_hash::{FxHashMap, FxHashSet};
19
20/// Optimization passes available.
21#[derive(Debug, Clone, Copy, PartialEq, Eq)]
22pub enum OptimizationPass {
23    /// Fold constant expressions.
24    ConstantFolding,
25    /// Remove dead (unused) code.
26    DeadCodeElimination,
27    /// Fuse consecutive elementwise operations.
28    ElementwiseFusion,
29    /// Common subexpression elimination.
30    CommonSubexpressionElimination,
31    /// Algebraic simplifications (x * 1 = x, x + 0 = x, etc).
32    AlgebraicSimplification,
33    /// Strength reduction (expensive ops -> cheaper ops).
34    StrengthReduction,
35}
36
/// Graph optimizer: an ordered pipeline of [`OptimizationPass`]es applied to a
/// graph via [`Optimizer::optimize`].
pub struct Optimizer {
    // Passes run in insertion order; the same pass may appear more than once.
    passes: Vec<OptimizationPass>,
}
41
42impl Optimizer {
43    /// Creates a new optimizer with no passes.
44    pub fn new() -> Self {
45        Self { passes: Vec::new() }
46    }
47
48    /// Creates an optimizer with default passes.
49    pub fn default_passes() -> Self {
50        Self {
51            passes: vec![
52                OptimizationPass::ConstantFolding,
53                OptimizationPass::AlgebraicSimplification,
54                OptimizationPass::DeadCodeElimination,
55                OptimizationPass::CommonSubexpressionElimination,
56            ],
57        }
58    }
59
60    /// Adds an optimization pass.
61    pub fn add_pass(&mut self, pass: OptimizationPass) {
62        self.passes.push(pass);
63    }
64
65    /// Runs all optimization passes on the graph.
66    pub fn optimize(&self, mut graph: Graph) -> Graph {
67        for pass in &self.passes {
68            graph = self.run_pass(graph, *pass);
69        }
70        graph
71    }
72
73    fn run_pass(&self, graph: Graph, pass: OptimizationPass) -> Graph {
74        match pass {
75            OptimizationPass::ConstantFolding => constant_folding(graph),
76            OptimizationPass::DeadCodeElimination => dead_code_elimination(graph),
77            OptimizationPass::ElementwiseFusion => elementwise_fusion(graph),
78            OptimizationPass::CommonSubexpressionElimination => cse(graph),
79            OptimizationPass::AlgebraicSimplification => algebraic_simplification(graph),
80            OptimizationPass::StrengthReduction => strength_reduction(graph),
81        }
82    }
83}
84
85impl Default for Optimizer {
86    fn default() -> Self {
87        Self::default_passes()
88    }
89}
90
91/// Constant folding: evaluate constant expressions at compile time.
92fn constant_folding(graph: Graph) -> Graph {
93    // For now, just identify constant nodes
94    // Full implementation would evaluate constant subgraphs
95    let mut new_graph = Graph::new();
96    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
97    let mut constants: FxHashMap<NodeId, f64> = FxHashMap::default();
98
99    for node in graph.nodes() {
100        // Track constant values
101        if let Op::Constant { value } = &node.op {
102            constants.insert(node.id, *value);
103        }
104
105        // Try to fold binary ops with constants
106        let new_op = match &node.op {
107            Op::MulScalar { input, scalar } if *scalar == 1.0 => {
108                // x * 1 = x
109                let new_input = node_map.get(input).copied().unwrap_or(*input);
110                node_map.insert(node.id, new_input);
111                continue;
112            }
113            Op::MulScalar { input: _, scalar } if *scalar == 0.0 => {
114                // x * 0 = 0
115                Op::Constant { value: 0.0 }
116            }
117            Op::AddScalar { input, scalar } if *scalar == 0.0 => {
118                // x + 0 = x
119                let new_input = node_map.get(input).copied().unwrap_or(*input);
120                node_map.insert(node.id, new_input);
121                continue;
122            }
123            other => remap_op(other, &node_map),
124        };
125
126        let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
127        node_map.insert(node.id, new_id);
128    }
129
130    // Remap inputs and outputs
131    for (name, id) in graph.inputs() {
132        if let Some(&new_id) = node_map.get(id) {
133            new_graph.register_input(name, new_id);
134        }
135    }
136    for (name, id) in graph.outputs() {
137        if let Some(&new_id) = node_map.get(id) {
138            new_graph.register_output(name, new_id);
139        }
140    }
141
142    new_graph
143}
144
145/// Dead code elimination: remove nodes that don't contribute to outputs.
146fn dead_code_elimination(graph: Graph) -> Graph {
147    // Find all nodes reachable from outputs
148    let mut live_nodes: FxHashSet<NodeId> = FxHashSet::default();
149    let mut worklist: Vec<NodeId> = graph.outputs().values().copied().collect();
150
151    while let Some(id) = worklist.pop() {
152        if live_nodes.insert(id) {
153            let node = graph.node(id);
154            for input_id in node.op.inputs() {
155                worklist.push(input_id);
156            }
157        }
158    }
159
160    // Rebuild graph with only live nodes
161    let mut new_graph = Graph::new();
162    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
163
164    for node in graph.nodes() {
165        if !live_nodes.contains(&node.id) {
166            continue;
167        }
168
169        let new_op = remap_op(&node.op, &node_map);
170        let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
171        node_map.insert(node.id, new_id);
172    }
173
174    // Remap inputs and outputs
175    for (name, id) in graph.inputs() {
176        if let Some(&new_id) = node_map.get(id) {
177            new_graph.register_input(name, new_id);
178        }
179    }
180    for (name, id) in graph.outputs() {
181        if let Some(&new_id) = node_map.get(id) {
182            new_graph.register_output(name, new_id);
183        }
184    }
185
186    new_graph
187}
188
189/// Elementwise fusion: combine consecutive elementwise ops into kernels.
190fn elementwise_fusion(graph: Graph) -> Graph {
191    // For now, just return the graph unchanged
192    // Full implementation would identify fusible sequences
193    // and create FusedElementwise nodes
194    graph
195}
196
197/// Common subexpression elimination.
198fn cse(graph: Graph) -> Graph {
199    // Hash-based CSE
200    let mut new_graph = Graph::new();
201    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
202    let mut expr_map: FxHashMap<String, NodeId> = FxHashMap::default();
203
204    for node in graph.nodes() {
205        let remapped_op = remap_op(&node.op, &node_map);
206        let expr_key = format!("{:?}", remapped_op);
207
208        if let Some(&existing_id) = expr_map.get(&expr_key) {
209            // Reuse existing node
210            node_map.insert(node.id, existing_id);
211        } else {
212            let new_id = new_graph.add_node(remapped_op, node.dtype, node.shape.clone());
213            node_map.insert(node.id, new_id);
214            expr_map.insert(expr_key, new_id);
215        }
216    }
217
218    // Remap inputs and outputs
219    for (name, id) in graph.inputs() {
220        if let Some(&new_id) = node_map.get(id) {
221            new_graph.register_input(name, new_id);
222        }
223    }
224    for (name, id) in graph.outputs() {
225        if let Some(&new_id) = node_map.get(id) {
226            new_graph.register_output(name, new_id);
227        }
228    }
229
230    new_graph
231}
232
233/// Algebraic simplifications.
234fn algebraic_simplification(graph: Graph) -> Graph {
235    let mut new_graph = Graph::new();
236    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
237
238    for node in graph.nodes() {
239        let simplified_op = match &node.op {
240            // x * 1 = x
241            Op::MulScalar { input, scalar } if *scalar == 1.0 => {
242                let new_input = node_map.get(input).copied().unwrap_or(*input);
243                node_map.insert(node.id, new_input);
244                continue;
245            }
246            // x + 0 = x
247            Op::AddScalar { input, scalar } if *scalar == 0.0 => {
248                let new_input = node_map.get(input).copied().unwrap_or(*input);
249                node_map.insert(node.id, new_input);
250                continue;
251            }
252            // x - 0 = x (via AddScalar with -0)
253            // x / 1 = x (via MulScalar with 1)
254            // --x = x
255            Op::Neg { input } => {
256                let actual_input = node_map.get(input).copied().unwrap_or(*input);
257                if let Some(input_node) = new_graph.nodes().iter().find(|n| n.id == actual_input) {
258                    if let Op::Neg { input: inner } = &input_node.op {
259                        node_map.insert(node.id, *inner);
260                        continue;
261                    }
262                }
263                Op::Neg {
264                    input: actual_input,
265                }
266            }
267            other => remap_op(other, &node_map),
268        };
269
270        let new_id = new_graph.add_node(simplified_op, node.dtype, node.shape.clone());
271        node_map.insert(node.id, new_id);
272    }
273
274    // Remap inputs and outputs
275    for (name, id) in graph.inputs() {
276        if let Some(&new_id) = node_map.get(id) {
277            new_graph.register_input(name, new_id);
278        }
279    }
280    for (name, id) in graph.outputs() {
281        if let Some(&new_id) = node_map.get(id) {
282            new_graph.register_output(name, new_id);
283        }
284    }
285
286    new_graph
287}
288
289/// Strength reduction: replace expensive ops with cheaper equivalents.
290fn strength_reduction(graph: Graph) -> Graph {
291    let mut new_graph = Graph::new();
292    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
293
294    for node in graph.nodes() {
295        let reduced_op = match &node.op {
296            // x^2 -> x * x
297            Op::Pow { .. } => {
298                // Check if exp is constant 2
299                // For now, just pass through
300                remap_op(&node.op, &node_map)
301            }
302            // x / c -> x * (1/c) for constant c
303            Op::Div { .. } => {
304                // Would need to check if rhs is constant
305                remap_op(&node.op, &node_map)
306            }
307            other => remap_op(other, &node_map),
308        };
309
310        let new_id = new_graph.add_node(reduced_op, node.dtype, node.shape.clone());
311        node_map.insert(node.id, new_id);
312    }
313
314    // Remap inputs and outputs
315    for (name, id) in graph.inputs() {
316        if let Some(&new_id) = node_map.get(id) {
317            new_graph.register_input(name, new_id);
318        }
319    }
320    for (name, id) in graph.outputs() {
321        if let Some(&new_id) = node_map.get(id) {
322            new_graph.register_output(name, new_id);
323        }
324    }
325
326    new_graph
327}
328
329/// Remaps node IDs in an operation using the provided mapping.
330fn remap_op(op: &Op, node_map: &FxHashMap<NodeId, NodeId>) -> Op {
331    let remap = |id: &NodeId| node_map.get(id).copied().unwrap_or(*id);
332
333    match op {
334        Op::Input { name } => Op::Input { name: name.clone() },
335        Op::Output { name, input } => Op::Output {
336            name: name.clone(),
337            input: remap(input),
338        },
339        Op::Constant { value } => Op::Constant { value: *value },
340
341        Op::Add { lhs, rhs } => Op::Add {
342            lhs: remap(lhs),
343            rhs: remap(rhs),
344        },
345        Op::Sub { lhs, rhs } => Op::Sub {
346            lhs: remap(lhs),
347            rhs: remap(rhs),
348        },
349        Op::Mul { lhs, rhs } => Op::Mul {
350            lhs: remap(lhs),
351            rhs: remap(rhs),
352        },
353        Op::Div { lhs, rhs } => Op::Div {
354            lhs: remap(lhs),
355            rhs: remap(rhs),
356        },
357        Op::Pow { base, exp } => Op::Pow {
358            base: remap(base),
359            exp: remap(exp),
360        },
361        Op::Max { lhs, rhs } => Op::Max {
362            lhs: remap(lhs),
363            rhs: remap(rhs),
364        },
365        Op::Min { lhs, rhs } => Op::Min {
366            lhs: remap(lhs),
367            rhs: remap(rhs),
368        },
369
370        Op::Neg { input } => Op::Neg {
371            input: remap(input),
372        },
373        Op::Abs { input } => Op::Abs {
374            input: remap(input),
375        },
376        Op::Sqrt { input } => Op::Sqrt {
377            input: remap(input),
378        },
379        Op::Exp { input } => Op::Exp {
380            input: remap(input),
381        },
382        Op::Log { input } => Op::Log {
383            input: remap(input),
384        },
385        Op::Sin { input } => Op::Sin {
386            input: remap(input),
387        },
388        Op::Cos { input } => Op::Cos {
389            input: remap(input),
390        },
391        Op::Tanh { input } => Op::Tanh {
392            input: remap(input),
393        },
394
395        Op::Relu { input } => Op::Relu {
396            input: remap(input),
397        },
398        Op::Sigmoid { input } => Op::Sigmoid {
399            input: remap(input),
400        },
401        Op::Gelu { input } => Op::Gelu {
402            input: remap(input),
403        },
404        Op::Silu { input } => Op::Silu {
405            input: remap(input),
406        },
407
408        Op::AddScalar { input, scalar } => Op::AddScalar {
409            input: remap(input),
410            scalar: *scalar,
411        },
412        Op::MulScalar { input, scalar } => Op::MulScalar {
413            input: remap(input),
414            scalar: *scalar,
415        },
416
417        Op::Sum { input } => Op::Sum {
418            input: remap(input),
419        },
420        Op::SumAxis {
421            input,
422            axis,
423            keepdim,
424        } => Op::SumAxis {
425            input: remap(input),
426            axis: *axis,
427            keepdim: *keepdim,
428        },
429        Op::Mean { input } => Op::Mean {
430            input: remap(input),
431        },
432        Op::MeanAxis {
433            input,
434            axis,
435            keepdim,
436        } => Op::MeanAxis {
437            input: remap(input),
438            axis: *axis,
439            keepdim: *keepdim,
440        },
441        Op::MaxAxis {
442            input,
443            axis,
444            keepdim,
445        } => Op::MaxAxis {
446            input: remap(input),
447            axis: *axis,
448            keepdim: *keepdim,
449        },
450
451        Op::Reshape { input, shape } => Op::Reshape {
452            input: remap(input),
453            shape: shape.clone(),
454        },
455        Op::Transpose { input, dim0, dim1 } => Op::Transpose {
456            input: remap(input),
457            dim0: *dim0,
458            dim1: *dim1,
459        },
460        Op::Squeeze { input, dim } => Op::Squeeze {
461            input: remap(input),
462            dim: *dim,
463        },
464        Op::Unsqueeze { input, dim } => Op::Unsqueeze {
465            input: remap(input),
466            dim: *dim,
467        },
468        Op::Broadcast { input, shape } => Op::Broadcast {
469            input: remap(input),
470            shape: shape.clone(),
471        },
472
473        Op::MatMul { lhs, rhs } => Op::MatMul {
474            lhs: remap(lhs),
475            rhs: remap(rhs),
476        },
477
478        Op::Gt { lhs, rhs } => Op::Gt {
479            lhs: remap(lhs),
480            rhs: remap(rhs),
481        },
482        Op::Lt { lhs, rhs } => Op::Lt {
483            lhs: remap(lhs),
484            rhs: remap(rhs),
485        },
486        Op::Eq { lhs, rhs } => Op::Eq {
487            lhs: remap(lhs),
488            rhs: remap(rhs),
489        },
490
491        Op::Where { condition, x, y } => Op::Where {
492            condition: remap(condition),
493            x: remap(x),
494            y: remap(y),
495        },
496
497        Op::Cast { input, dtype } => Op::Cast {
498            input: remap(input),
499            dtype: *dtype,
500        },
501        Op::Contiguous { input } => Op::Contiguous {
502            input: remap(input),
503        },
504    }
505}
506
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    /// DCE must drop nodes whose results never reach a graph output.
    #[test]
    fn test_dead_code_elimination() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let _unused = a.mul(&b); // dead: never feeds an output
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let mut optimizer = Optimizer::new();
        optimizer.add_pass(OptimizationPass::DeadCodeElimination);
        let optimized = optimizer.optimize(graph);

        // The unused Mul node must be gone from the optimized graph.
        assert!(!optimized
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::Mul { .. })));
    }

    /// `x * 1` must simplify away, leaving no MulScalar node behind.
    #[test]
    fn test_algebraic_simplification() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3]);
            let y = x.mul_scalar(1.0); // identity: should vanish
            tracer.output("y", y)
        });

        let mut optimizer = Optimizer::new();
        optimizer.add_pass(OptimizationPass::AlgebraicSimplification);
        let optimized = optimizer.optimize(graph);

        assert!(!optimized
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::MulScalar { .. })));
    }

    /// `x * 0` must fold into a literal Constant node.
    #[test]
    fn test_constant_folding() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3]);
            let y = x.mul_scalar(0.0); // folds to Constant(0.0)
            tracer.output("y", y)
        });

        let mut optimizer = Optimizer::new();
        optimizer.add_pass(OptimizationPass::ConstantFolding);
        let optimized = optimizer.optimize(graph);

        assert!(optimized
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::Constant { .. })));
    }
}