Skip to main content

axonml_jit/
optimize.rs

1//! Graph Optimization
2//!
3//! Provides optimization passes for computation graphs.
4
5use crate::ir::{Graph, NodeId, Op};
6use rustc_hash::{FxHashMap, FxHashSet};
7
/// Optimization passes available.
///
/// Passes are applied in the order they are added to the `Optimizer`;
/// ordering matters (e.g. simplification can expose dead code for DCE).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationPass {
    /// Fold constant expressions at compile time (e.g. `x * 0` -> `0`).
    ConstantFolding,
    /// Remove nodes that do not contribute to any registered output.
    DeadCodeElimination,
    /// Fuse consecutive elementwise operations.
    ElementwiseFusion,
    /// Deduplicate structurally identical expressions.
    CommonSubexpressionElimination,
    /// Algebraic simplifications (x * 1 = x, x + 0 = x, etc).
    AlgebraicSimplification,
    /// Strength reduction (expensive ops -> cheaper ops).
    StrengthReduction,
}
24
/// Graph optimizer.
///
/// Holds an ordered pipeline of passes; `optimize` applies them sequentially.
pub struct Optimizer {
    // Passes run in insertion order.
    passes: Vec<OptimizationPass>,
}
29
30impl Optimizer {
31    /// Creates a new optimizer with no passes.
32    pub fn new() -> Self {
33        Self { passes: Vec::new() }
34    }
35
36    /// Creates an optimizer with default passes.
37    pub fn default_passes() -> Self {
38        Self {
39            passes: vec![
40                OptimizationPass::ConstantFolding,
41                OptimizationPass::AlgebraicSimplification,
42                OptimizationPass::DeadCodeElimination,
43                OptimizationPass::CommonSubexpressionElimination,
44            ],
45        }
46    }
47
48    /// Adds an optimization pass.
49    pub fn add_pass(&mut self, pass: OptimizationPass) {
50        self.passes.push(pass);
51    }
52
53    /// Runs all optimization passes on the graph.
54    pub fn optimize(&self, mut graph: Graph) -> Graph {
55        for pass in &self.passes {
56            graph = self.run_pass(graph, *pass);
57        }
58        graph
59    }
60
61    fn run_pass(&self, graph: Graph, pass: OptimizationPass) -> Graph {
62        match pass {
63            OptimizationPass::ConstantFolding => constant_folding(graph),
64            OptimizationPass::DeadCodeElimination => dead_code_elimination(graph),
65            OptimizationPass::ElementwiseFusion => elementwise_fusion(graph),
66            OptimizationPass::CommonSubexpressionElimination => cse(graph),
67            OptimizationPass::AlgebraicSimplification => algebraic_simplification(graph),
68            OptimizationPass::StrengthReduction => strength_reduction(graph),
69        }
70    }
71}
72
impl Default for Optimizer {
    // The default optimizer carries the standard pass pipeline,
    // not an empty one — see `Optimizer::default_passes`.
    fn default() -> Self {
        Self::default_passes()
    }
}
78
79/// Constant folding: evaluate constant expressions at compile time.
80fn constant_folding(graph: Graph) -> Graph {
81    // For now, just identify constant nodes
82    // Full implementation would evaluate constant subgraphs
83    let mut new_graph = Graph::new();
84    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
85    let mut constants: FxHashMap<NodeId, f64> = FxHashMap::default();
86
87    for node in graph.nodes() {
88        // Track constant values
89        if let Op::Constant { value } = &node.op {
90            constants.insert(node.id, *value);
91        }
92
93        // Try to fold binary ops with constants
94        let new_op = match &node.op {
95            Op::MulScalar { input, scalar } if *scalar == 1.0 => {
96                // x * 1 = x
97                let new_input = node_map.get(input).copied().unwrap_or(*input);
98                node_map.insert(node.id, new_input);
99                continue;
100            }
101            Op::MulScalar { input: _, scalar } if *scalar == 0.0 => {
102                // x * 0 = 0
103                Op::Constant { value: 0.0 }
104            }
105            Op::AddScalar { input, scalar } if *scalar == 0.0 => {
106                // x + 0 = x
107                let new_input = node_map.get(input).copied().unwrap_or(*input);
108                node_map.insert(node.id, new_input);
109                continue;
110            }
111            other => remap_op(other, &node_map),
112        };
113
114        let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
115        node_map.insert(node.id, new_id);
116    }
117
118    // Remap inputs and outputs
119    for (name, id) in graph.inputs() {
120        if let Some(&new_id) = node_map.get(id) {
121            new_graph.register_input(name, new_id);
122        }
123    }
124    for (name, id) in graph.outputs() {
125        if let Some(&new_id) = node_map.get(id) {
126            new_graph.register_output(name, new_id);
127        }
128    }
129
130    new_graph
131}
132
133/// Dead code elimination: remove nodes that don't contribute to outputs.
134fn dead_code_elimination(graph: Graph) -> Graph {
135    // Find all nodes reachable from outputs
136    let mut live_nodes: FxHashSet<NodeId> = FxHashSet::default();
137    let mut worklist: Vec<NodeId> = graph.outputs().values().copied().collect();
138
139    while let Some(id) = worklist.pop() {
140        if live_nodes.insert(id) {
141            let node = graph.node(id);
142            for input_id in node.op.inputs() {
143                worklist.push(input_id);
144            }
145        }
146    }
147
148    // Rebuild graph with only live nodes
149    let mut new_graph = Graph::new();
150    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
151
152    for node in graph.nodes() {
153        if !live_nodes.contains(&node.id) {
154            continue;
155        }
156
157        let new_op = remap_op(&node.op, &node_map);
158        let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
159        node_map.insert(node.id, new_id);
160    }
161
162    // Remap inputs and outputs
163    for (name, id) in graph.inputs() {
164        if let Some(&new_id) = node_map.get(id) {
165            new_graph.register_input(name, new_id);
166        }
167    }
168    for (name, id) in graph.outputs() {
169        if let Some(&new_id) = node_map.get(id) {
170            new_graph.register_output(name, new_id);
171        }
172    }
173
174    new_graph
175}
176
/// Elementwise fusion: combine consecutive elementwise ops into kernels.
///
/// Placeholder: currently returns the graph unchanged. A full implementation
/// would identify fusible elementwise sequences and replace them with
/// FusedElementwise nodes.
fn elementwise_fusion(graph: Graph) -> Graph {
    graph
}
184
185/// Common subexpression elimination.
186fn cse(graph: Graph) -> Graph {
187    // Hash-based CSE
188    let mut new_graph = Graph::new();
189    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
190    let mut expr_map: FxHashMap<String, NodeId> = FxHashMap::default();
191
192    for node in graph.nodes() {
193        let remapped_op = remap_op(&node.op, &node_map);
194        let expr_key = format!("{:?}", remapped_op);
195
196        if let Some(&existing_id) = expr_map.get(&expr_key) {
197            // Reuse existing node
198            node_map.insert(node.id, existing_id);
199        } else {
200            let new_id = new_graph.add_node(remapped_op, node.dtype, node.shape.clone());
201            node_map.insert(node.id, new_id);
202            expr_map.insert(expr_key, new_id);
203        }
204    }
205
206    // Remap inputs and outputs
207    for (name, id) in graph.inputs() {
208        if let Some(&new_id) = node_map.get(id) {
209            new_graph.register_input(name, new_id);
210        }
211    }
212    for (name, id) in graph.outputs() {
213        if let Some(&new_id) = node_map.get(id) {
214            new_graph.register_output(name, new_id);
215        }
216    }
217
218    new_graph
219}
220
/// Algebraic simplifications.
///
/// Applies local identities while rebuilding the graph:
/// - `x * 1 = x` and `x + 0 = x` drop the node and alias it to its input;
/// - `--x = x` collapses double negation.
///
/// NOTE(review): `x + 0 = x` is not exact under IEEE-754 (`-0.0 + 0.0`
/// yields `+0.0`); confirm the IR permits fast-math-style rewrites.
fn algebraic_simplification(graph: Graph) -> Graph {
    let mut new_graph = Graph::new();
    // Maps old node ids to ids in the rebuilt graph.
    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();

    for node in graph.nodes() {
        let simplified_op = match &node.op {
            // x * 1 = x: alias the node to its (remapped) input, emit nothing.
            Op::MulScalar { input, scalar } if *scalar == 1.0 => {
                let new_input = node_map.get(input).copied().unwrap_or(*input);
                node_map.insert(node.id, new_input);
                continue;
            }
            // x + 0 = x: alias the node to its (remapped) input, emit nothing.
            Op::AddScalar { input, scalar } if *scalar == 0.0 => {
                let new_input = node_map.get(input).copied().unwrap_or(*input);
                node_map.insert(node.id, new_input);
                continue;
            }
            // x - 0 = x (via AddScalar with -0)
            // x / 1 = x (via MulScalar with 1)
            // --x = x
            Op::Neg { input } => {
                let actual_input = node_map.get(input).copied().unwrap_or(*input);
                // Look the remapped input up in the NEW graph: if it is itself
                // a Neg, its `inner` id is already a new-graph id, so this
                // node can be aliased straight to it and skipped.
                if let Some(input_node) = new_graph.nodes().iter().find(|n| n.id == actual_input) {
                    if let Op::Neg { input: inner } = &input_node.op {
                        node_map.insert(node.id, *inner);
                        continue;
                    }
                }
                Op::Neg {
                    input: actual_input,
                }
            }
            other => remap_op(other, &node_map),
        };

        let new_id = new_graph.add_node(simplified_op, node.dtype, node.shape.clone());
        node_map.insert(node.id, new_id);
    }

    // Re-register inputs and outputs under their new ids.
    for (name, id) in graph.inputs() {
        if let Some(&new_id) = node_map.get(id) {
            new_graph.register_input(name, new_id);
        }
    }
    for (name, id) in graph.outputs() {
        if let Some(&new_id) = node_map.get(id) {
            new_graph.register_output(name, new_id);
        }
    }

    new_graph
}
276
277/// Strength reduction: replace expensive ops with cheaper equivalents.
278fn strength_reduction(graph: Graph) -> Graph {
279    let mut new_graph = Graph::new();
280    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
281
282    for node in graph.nodes() {
283        let reduced_op = match &node.op {
284            // x^2 -> x * x
285            Op::Pow { .. } => {
286                // Check if exp is constant 2
287                // For now, just pass through
288                remap_op(&node.op, &node_map)
289            }
290            // x / c -> x * (1/c) for constant c
291            Op::Div { .. } => {
292                // Would need to check if rhs is constant
293                remap_op(&node.op, &node_map)
294            }
295            other => remap_op(other, &node_map),
296        };
297
298        let new_id = new_graph.add_node(reduced_op, node.dtype, node.shape.clone());
299        node_map.insert(node.id, new_id);
300    }
301
302    // Remap inputs and outputs
303    for (name, id) in graph.inputs() {
304        if let Some(&new_id) = node_map.get(id) {
305            new_graph.register_input(name, new_id);
306        }
307    }
308    for (name, id) in graph.outputs() {
309        if let Some(&new_id) = node_map.get(id) {
310            new_graph.register_output(name, new_id);
311        }
312    }
313
314    new_graph
315}
316
/// Remaps node IDs in an operation using the provided mapping.
///
/// Produces a copy of `op` in which every input id is looked up in
/// `node_map`; ids absent from the map are kept as-is. Non-id payloads
/// (names, scalars, shapes, axes, dtypes) are cloned/copied unchanged.
///
/// The match is deliberately exhaustive with no catch-all arm, so adding a
/// new `Op` variant forces a compile error here.
fn remap_op(op: &Op, node_map: &FxHashMap<NodeId, NodeId>) -> Op {
    // Identity fallback: unmapped ids pass through untouched.
    let remap = |id: &NodeId| node_map.get(id).copied().unwrap_or(*id);

    match op {
        // Graph boundary ops.
        Op::Input { name } => Op::Input { name: name.clone() },
        Op::Output { name, input } => Op::Output {
            name: name.clone(),
            input: remap(input),
        },
        Op::Constant { value } => Op::Constant { value: *value },

        // Binary elementwise ops.
        Op::Add { lhs, rhs } => Op::Add {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Sub { lhs, rhs } => Op::Sub {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Mul { lhs, rhs } => Op::Mul {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Div { lhs, rhs } => Op::Div {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Pow { base, exp } => Op::Pow {
            base: remap(base),
            exp: remap(exp),
        },
        Op::Max { lhs, rhs } => Op::Max {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Min { lhs, rhs } => Op::Min {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },

        // Unary elementwise ops.
        Op::Neg { input } => Op::Neg {
            input: remap(input),
        },
        Op::Abs { input } => Op::Abs {
            input: remap(input),
        },
        Op::Sqrt { input } => Op::Sqrt {
            input: remap(input),
        },
        Op::Exp { input } => Op::Exp {
            input: remap(input),
        },
        Op::Log { input } => Op::Log {
            input: remap(input),
        },
        Op::Sin { input } => Op::Sin {
            input: remap(input),
        },
        Op::Cos { input } => Op::Cos {
            input: remap(input),
        },
        Op::Tanh { input } => Op::Tanh {
            input: remap(input),
        },

        // Activation functions.
        Op::Relu { input } => Op::Relu {
            input: remap(input),
        },
        Op::Sigmoid { input } => Op::Sigmoid {
            input: remap(input),
        },
        Op::Gelu { input } => Op::Gelu {
            input: remap(input),
        },
        Op::Silu { input } => Op::Silu {
            input: remap(input),
        },

        // Scalar ops (scalar payload copied unchanged).
        Op::AddScalar { input, scalar } => Op::AddScalar {
            input: remap(input),
            scalar: *scalar,
        },
        Op::MulScalar { input, scalar } => Op::MulScalar {
            input: remap(input),
            scalar: *scalar,
        },

        // Reductions.
        Op::Sum { input } => Op::Sum {
            input: remap(input),
        },
        Op::SumAxis {
            input,
            axis,
            keepdim,
        } => Op::SumAxis {
            input: remap(input),
            axis: *axis,
            keepdim: *keepdim,
        },
        Op::Mean { input } => Op::Mean {
            input: remap(input),
        },
        Op::MeanAxis {
            input,
            axis,
            keepdim,
        } => Op::MeanAxis {
            input: remap(input),
            axis: *axis,
            keepdim: *keepdim,
        },
        Op::MaxAxis {
            input,
            axis,
            keepdim,
        } => Op::MaxAxis {
            input: remap(input),
            axis: *axis,
            keepdim: *keepdim,
        },

        // Shape manipulation.
        Op::Reshape { input, shape } => Op::Reshape {
            input: remap(input),
            shape: shape.clone(),
        },
        Op::Transpose { input, dim0, dim1 } => Op::Transpose {
            input: remap(input),
            dim0: *dim0,
            dim1: *dim1,
        },
        Op::Squeeze { input, dim } => Op::Squeeze {
            input: remap(input),
            dim: *dim,
        },
        Op::Unsqueeze { input, dim } => Op::Unsqueeze {
            input: remap(input),
            dim: *dim,
        },
        Op::Broadcast { input, shape } => Op::Broadcast {
            input: remap(input),
            shape: shape.clone(),
        },

        // Linear algebra.
        Op::MatMul { lhs, rhs } => Op::MatMul {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },

        // Comparisons.
        Op::Gt { lhs, rhs } => Op::Gt {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Lt { lhs, rhs } => Op::Lt {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Eq { lhs, rhs } => Op::Eq {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },

        // Ternary selection.
        Op::Where { condition, x, y } => Op::Where {
            condition: remap(condition),
            x: remap(x),
            y: remap(y),
        },

        // Layout / type conversion.
        Op::Cast { input, dtype } => Op::Cast {
            input: remap(input),
            dtype: *dtype,
        },
        Op::Contiguous { input } => Op::Contiguous {
            input: remap(input),
        },
    }
}
494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    #[test]
    fn test_dead_code_elimination() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let _unused = a.mul(&b); // This should be eliminated
            let c = a.add(&b);
            tracer.output("result", c)
        });

        // Bind mutably up front (was: a needless `let opt = optimizer`
        // rebinding), matching the style of the other tests.
        let mut optimizer = Optimizer::new();
        optimizer.add_pass(OptimizationPass::DeadCodeElimination);
        let optimized = optimizer.optimize(graph);

        // The unused Mul node must not survive DCE.
        let has_mul = optimized
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::Mul { .. }));
        assert!(!has_mul);
    }

    #[test]
    fn test_algebraic_simplification() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3]);
            let y = x.mul_scalar(1.0); // Should be simplified to x
            tracer.output("y", y)
        });

        let mut optimizer = Optimizer::new();
        optimizer.add_pass(OptimizationPass::AlgebraicSimplification);
        let optimized = optimizer.optimize(graph);

        // MulScalar(1.0) should be eliminated
        let has_mul_scalar = optimized
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::MulScalar { .. }));
        assert!(!has_mul_scalar);
    }

    #[test]
    fn test_constant_folding() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3]);
            let y = x.mul_scalar(0.0); // Should become constant 0
            tracer.output("y", y)
        });

        let mut optimizer = Optimizer::new();
        optimizer.add_pass(OptimizationPass::ConstantFolding);
        let optimized = optimizer.optimize(graph);

        // Should have a Constant node
        let has_constant = optimized
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::Constant { .. }));
        assert!(has_constant);
    }
}