
axonml_jit/optimize.rs

//! Graph Optimization
//!
//! Provides optimization passes for computation graphs.

use crate::ir::{Graph, NodeId, Op};
use rustc_hash::{FxHashMap, FxHashSet};

/// Optimization passes available.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationPass {
    /// Fold constant expressions.
    ConstantFolding,
    /// Remove dead (unused) code.
    DeadCodeElimination,
    /// Fuse consecutive elementwise operations.
    ElementwiseFusion,
    /// Common subexpression elimination.
    CommonSubexpressionElimination,
    /// Algebraic simplifications (x * 1 = x, x + 0 = x, etc).
    AlgebraicSimplification,
    /// Strength reduction (expensive ops -> cheaper ops).
    StrengthReduction,
}

/// Graph optimizer.
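///
/// # Example
///
/// A minimal usage sketch (the `graph` value is assumed to come from the
/// tracing API exercised in the tests at the bottom of this file):
///
/// ```ignore
/// let mut optimizer = Optimizer::new();
/// optimizer.add_pass(OptimizationPass::ConstantFolding);
/// optimizer.add_pass(OptimizationPass::DeadCodeElimination);
/// let optimized = optimizer.optimize(graph);
/// ```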
pub struct Optimizer {
    passes: Vec<OptimizationPass>,
}

impl Optimizer {
    /// Creates a new optimizer with no passes.
    pub fn new() -> Self {
        Self { passes: Vec::new() }
    }

    /// Creates an optimizer with default passes.
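    ///
    /// The passes run in the order listed: constant folding and algebraic
    /// simplification first, then dead code elimination and common
    /// subexpression elimination to clean up afterwards.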
    pub fn default_passes() -> Self {
        Self {
            passes: vec![
                OptimizationPass::ConstantFolding,
                OptimizationPass::AlgebraicSimplification,
                OptimizationPass::DeadCodeElimination,
                OptimizationPass::CommonSubexpressionElimination,
            ],
        }
    }

    /// Adds an optimization pass.
    pub fn add_pass(&mut self, pass: OptimizationPass) {
        self.passes.push(pass);
    }

    /// Runs all optimization passes on the graph.
    pub fn optimize(&self, mut graph: Graph) -> Graph {
        for pass in &self.passes {
            graph = self.run_pass(graph, *pass);
        }
        graph
    }

    fn run_pass(&self, graph: Graph, pass: OptimizationPass) -> Graph {
        match pass {
            OptimizationPass::ConstantFolding => constant_folding(graph),
            OptimizationPass::DeadCodeElimination => dead_code_elimination(graph),
            OptimizationPass::ElementwiseFusion => elementwise_fusion(graph),
            OptimizationPass::CommonSubexpressionElimination => cse(graph),
            OptimizationPass::AlgebraicSimplification => algebraic_simplification(graph),
            OptimizationPass::StrengthReduction => strength_reduction(graph),
        }
    }
}

impl Default for Optimizer {
    fn default() -> Self {
        Self::default_passes()
    }
}

/// Constant folding: evaluate constant expressions at compile time.
fn constant_folding(graph: Graph) -> Graph {
    // For now, just identify constant nodes
    // Full implementation would evaluate constant subgraphs
    let mut new_graph = Graph::new();
    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
    let mut constants: FxHashMap<NodeId, f64> = FxHashMap::default();
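    // Collected for future use: a full folding pass would consult this map when
    // every operand of an op resolves to a known constant; the current pass only
    // fills it in and does not read it back.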

    for node in graph.nodes() {
        // Track constant values
        if let Op::Constant { value } = &node.op {
            constants.insert(node.id, *value);
        }

        // Fold scalar ops against identity / absorbing constants
        let new_op = match &node.op {
            Op::MulScalar { input, scalar } if *scalar == 1.0 => {
                // x * 1 = x
                let new_input = node_map.get(input).copied().unwrap_or(*input);
                node_map.insert(node.id, new_input);
                continue;
            }
            Op::MulScalar { input: _, scalar } if *scalar == 0.0 => {
                // x * 0 = 0
                Op::Constant { value: 0.0 }
            }
            Op::AddScalar { input, scalar } if *scalar == 0.0 => {
                // x + 0 = x
                let new_input = node_map.get(input).copied().unwrap_or(*input);
                node_map.insert(node.id, new_input);
                continue;
            }
            other => remap_op(other, &node_map),
        };

        let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
        node_map.insert(node.id, new_id);
    }

    // Remap inputs and outputs
    for (name, id) in graph.inputs() {
        if let Some(&new_id) = node_map.get(id) {
            new_graph.register_input(name, new_id);
        }
    }
    for (name, id) in graph.outputs() {
        if let Some(&new_id) = node_map.get(id) {
            new_graph.register_output(name, new_id);
        }
    }

    new_graph
}

/// Dead code elimination: remove nodes that don't contribute to outputs.
fn dead_code_elimination(graph: Graph) -> Graph {
    // Find all nodes reachable from outputs
    let mut live_nodes: FxHashSet<NodeId> = FxHashSet::default();
    let mut worklist: Vec<NodeId> = graph.outputs().values().copied().collect();

    while let Some(id) = worklist.pop() {
        if live_nodes.insert(id) {
            let node = graph.node(id);
            for input_id in node.op.inputs() {
                worklist.push(input_id);
            }
        }
    }

    // Rebuild graph with only live nodes
    let mut new_graph = Graph::new();
    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();

    for node in graph.nodes() {
        if !live_nodes.contains(&node.id) {
            continue;
        }

        let new_op = remap_op(&node.op, &node_map);
        let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
        node_map.insert(node.id, new_id);
    }

    // Remap inputs and outputs
    for (name, id) in graph.inputs() {
        if let Some(&new_id) = node_map.get(id) {
            new_graph.register_input(name, new_id);
        }
    }
    for (name, id) in graph.outputs() {
        if let Some(&new_id) = node_map.get(id) {
            new_graph.register_output(name, new_id);
        }
    }

    new_graph
}

/// Elementwise fusion: combine consecutive elementwise ops into kernels.
fn elementwise_fusion(graph: Graph) -> Graph {
    // For now, just return the graph unchanged
    // Full implementation would identify fusible sequences
    // and create FusedElementwise nodes
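    //
    // A possible shape for that pass (a sketch only, not implemented here): walk
    // the nodes in order, mark unary/binary elementwise ops (Add, Mul, Relu, Exp,
    // and friends), and greedily group a node with its producer whenever the
    // producer's only consumer is that node; each group would then be lowered as
    // a single fused kernel.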
    graph
}

/// Common subexpression elimination.
fn cse(graph: Graph) -> Graph {
    // Hash-based CSE
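    // Nodes whose ops are structurally identical after input remapping produce
    // the same key (here simply the op's Debug representation), so later
    // duplicates collapse onto the first occurrence.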
    let mut new_graph = Graph::new();
    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
    let mut expr_map: FxHashMap<String, NodeId> = FxHashMap::default();

    for node in graph.nodes() {
        let remapped_op = remap_op(&node.op, &node_map);
        let expr_key = format!("{:?}", remapped_op);

        if let Some(&existing_id) = expr_map.get(&expr_key) {
            // Reuse existing node
            node_map.insert(node.id, existing_id);
        } else {
            let new_id = new_graph.add_node(remapped_op, node.dtype, node.shape.clone());
            node_map.insert(node.id, new_id);
            expr_map.insert(expr_key, new_id);
        }
    }

    // Remap inputs and outputs
    for (name, id) in graph.inputs() {
        if let Some(&new_id) = node_map.get(id) {
            new_graph.register_input(name, new_id);
        }
    }
    for (name, id) in graph.outputs() {
        if let Some(&new_id) = node_map.get(id) {
            new_graph.register_output(name, new_id);
        }
    }

    new_graph
}

/// Algebraic simplifications.
fn algebraic_simplification(graph: Graph) -> Graph {
    let mut new_graph = Graph::new();
    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();

    for node in graph.nodes() {
        let simplified_op = match &node.op {
            // x * 1 = x
            Op::MulScalar { input, scalar } if *scalar == 1.0 => {
                let new_input = node_map.get(input).copied().unwrap_or(*input);
                node_map.insert(node.id, new_input);
                continue;
            }
            // x + 0 = x
            Op::AddScalar { input, scalar } if *scalar == 0.0 => {
                let new_input = node_map.get(input).copied().unwrap_or(*input);
                node_map.insert(node.id, new_input);
                continue;
            }
            // x - 0 = x (via AddScalar with -0)
            // x / 1 = x (via MulScalar with 1)
            // --x = x
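            // (If the already-rebuilt input node is itself a Neg, map this node
            // straight to that inner operand and emit nothing new.)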
            Op::Neg { input } => {
                let actual_input = node_map.get(input).copied().unwrap_or(*input);
                if let Some(input_node) = new_graph.nodes().iter().find(|n| n.id == actual_input) {
                    if let Op::Neg { input: inner } = &input_node.op {
                        node_map.insert(node.id, *inner);
                        continue;
                    }
                }
                Op::Neg { input: actual_input }
            }
            other => remap_op(other, &node_map),
        };

        let new_id = new_graph.add_node(simplified_op, node.dtype, node.shape.clone());
        node_map.insert(node.id, new_id);
    }

    // Remap inputs and outputs
    for (name, id) in graph.inputs() {
        if let Some(&new_id) = node_map.get(id) {
            new_graph.register_input(name, new_id);
        }
    }
    for (name, id) in graph.outputs() {
        if let Some(&new_id) = node_map.get(id) {
            new_graph.register_output(name, new_id);
        }
    }

    new_graph
}

/// Strength reduction: replace expensive ops with cheaper equivalents.
fn strength_reduction(graph: Graph) -> Graph {
    let mut new_graph = Graph::new();
    let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();

    for node in graph.nodes() {
        let reduced_op = match &node.op {
            // x^2 -> x * x
            Op::Pow { .. } => {
                // Check if exp is constant 2
                // For now, just pass through
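                // (Sketch: an `exp` operand that resolves to Constant { value: 2.0 }
                // could be rewritten as Mul { lhs: base, rhs: base }.)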
                remap_op(&node.op, &node_map)
            }
            // x / c -> x * (1/c) for constant c
            Op::Div { .. } => {
                // Would need to check if rhs is constant
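                // (Sketch: an `rhs` operand that resolves to Constant { value: c }
                // could become MulScalar { input: lhs, scalar: 1.0 / c }.)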
                remap_op(&node.op, &node_map)
            }
            other => remap_op(other, &node_map),
        };

        let new_id = new_graph.add_node(reduced_op, node.dtype, node.shape.clone());
        node_map.insert(node.id, new_id);
    }

    // Remap inputs and outputs
    for (name, id) in graph.inputs() {
        if let Some(&new_id) = node_map.get(id) {
            new_graph.register_input(name, new_id);
        }
    }
    for (name, id) in graph.outputs() {
        if let Some(&new_id) = node_map.get(id) {
            new_graph.register_output(name, new_id);
        }
    }

    new_graph
}

/// Remaps node IDs in an operation using the provided mapping.
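///
/// IDs not present in `node_map` are returned unchanged. The rebuild passes above
/// fill the map as they walk `graph.nodes()`, which is assumed to yield producers
/// before their consumers.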
fn remap_op(op: &Op, node_map: &FxHashMap<NodeId, NodeId>) -> Op {
    let remap = |id: &NodeId| node_map.get(id).copied().unwrap_or(*id);

    match op {
        Op::Input { name } => Op::Input { name: name.clone() },
        Op::Output { name, input } => Op::Output { name: name.clone(), input: remap(input) },
        Op::Constant { value } => Op::Constant { value: *value },

        Op::Add { lhs, rhs } => Op::Add { lhs: remap(lhs), rhs: remap(rhs) },
        Op::Sub { lhs, rhs } => Op::Sub { lhs: remap(lhs), rhs: remap(rhs) },
        Op::Mul { lhs, rhs } => Op::Mul { lhs: remap(lhs), rhs: remap(rhs) },
        Op::Div { lhs, rhs } => Op::Div { lhs: remap(lhs), rhs: remap(rhs) },
        Op::Pow { base, exp } => Op::Pow { base: remap(base), exp: remap(exp) },
        Op::Max { lhs, rhs } => Op::Max { lhs: remap(lhs), rhs: remap(rhs) },
        Op::Min { lhs, rhs } => Op::Min { lhs: remap(lhs), rhs: remap(rhs) },

        Op::Neg { input } => Op::Neg { input: remap(input) },
        Op::Abs { input } => Op::Abs { input: remap(input) },
        Op::Sqrt { input } => Op::Sqrt { input: remap(input) },
        Op::Exp { input } => Op::Exp { input: remap(input) },
        Op::Log { input } => Op::Log { input: remap(input) },
        Op::Sin { input } => Op::Sin { input: remap(input) },
        Op::Cos { input } => Op::Cos { input: remap(input) },
        Op::Tanh { input } => Op::Tanh { input: remap(input) },

        Op::Relu { input } => Op::Relu { input: remap(input) },
        Op::Sigmoid { input } => Op::Sigmoid { input: remap(input) },
        Op::Gelu { input } => Op::Gelu { input: remap(input) },
        Op::Silu { input } => Op::Silu { input: remap(input) },

        Op::AddScalar { input, scalar } => Op::AddScalar { input: remap(input), scalar: *scalar },
        Op::MulScalar { input, scalar } => Op::MulScalar { input: remap(input), scalar: *scalar },

        Op::Sum { input } => Op::Sum { input: remap(input) },
        Op::SumAxis { input, axis, keepdim } => Op::SumAxis { input: remap(input), axis: *axis, keepdim: *keepdim },
        Op::Mean { input } => Op::Mean { input: remap(input) },
        Op::MeanAxis { input, axis, keepdim } => Op::MeanAxis { input: remap(input), axis: *axis, keepdim: *keepdim },
        Op::MaxAxis { input, axis, keepdim } => Op::MaxAxis { input: remap(input), axis: *axis, keepdim: *keepdim },

        Op::Reshape { input, shape } => Op::Reshape { input: remap(input), shape: shape.clone() },
        Op::Transpose { input, dim0, dim1 } => Op::Transpose { input: remap(input), dim0: *dim0, dim1: *dim1 },
        Op::Squeeze { input, dim } => Op::Squeeze { input: remap(input), dim: *dim },
        Op::Unsqueeze { input, dim } => Op::Unsqueeze { input: remap(input), dim: *dim },
        Op::Broadcast { input, shape } => Op::Broadcast { input: remap(input), shape: shape.clone() },

        Op::MatMul { lhs, rhs } => Op::MatMul { lhs: remap(lhs), rhs: remap(rhs) },

        Op::Gt { lhs, rhs } => Op::Gt { lhs: remap(lhs), rhs: remap(rhs) },
        Op::Lt { lhs, rhs } => Op::Lt { lhs: remap(lhs), rhs: remap(rhs) },
        Op::Eq { lhs, rhs } => Op::Eq { lhs: remap(lhs), rhs: remap(rhs) },

        Op::Where { condition, x, y } => Op::Where { condition: remap(condition), x: remap(x), y: remap(y) },

        Op::Cast { input, dtype } => Op::Cast { input: remap(input), dtype: *dtype },
        Op::Contiguous { input } => Op::Contiguous { input: remap(input) },
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    #[test]
    fn test_dead_code_elimination() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let _unused = a.mul(&b); // This should be eliminated
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let mut optimizer = Optimizer::new();
        optimizer.add_pass(OptimizationPass::DeadCodeElimination);
        let optimized = optimizer.optimize(graph);

        // Mul node should be eliminated
        let has_mul = optimized.nodes().iter().any(|n| matches!(n.op, Op::Mul { .. }));
        assert!(!has_mul);
    }

    #[test]
    fn test_algebraic_simplification() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3]);
            let y = x.mul_scalar(1.0); // Should be simplified to x
            tracer.output("y", y)
        });

        let mut optimizer = Optimizer::new();
        optimizer.add_pass(OptimizationPass::AlgebraicSimplification);
        let optimized = optimizer.optimize(graph);

        // MulScalar(1.0) should be eliminated
        let has_mul_scalar = optimized.nodes().iter().any(|n| matches!(n.op, Op::MulScalar { .. }));
        assert!(!has_mul_scalar);
    }

    #[test]
    fn test_constant_folding() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3]);
            let y = x.mul_scalar(0.0); // Should become constant 0
            tracer.output("y", y)
        });

        let mut optimizer = Optimizer::new();
        optimizer.add_pass(OptimizationPass::ConstantFolding);
        let optimized = optimizer.optimize(graph);

        // Should have a Constant node
        let has_constant = optimized.nodes().iter().any(|n| matches!(n.op, Op::Constant { .. }));
        assert!(has_constant);
    }
}