// cubecl_opt/lib.rs

1//! # CubeCL Optimizer
2//!
3//! A library that parses CubeCL IR into a
4//! [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph), transforms it to
5//! [static single-assignment form](https://en.wikipedia.org/wiki/Static_single-assignment_form)
6//! and runs various optimizations on it.
7//! The order of operations is as follows:
8//!
9//! 1. Parse root scope recursively into a [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph)
10//! 2. Run optimizations that must be done before SSA transformation
11//! 3. Analyze variable liveness
12//! 4. Transform the graph to [pruned SSA](https://en.wikipedia.org/wiki/Static_single-assignment_form#Pruned_SSA) form
13//! 5. Run post-SSA optimizations and analyses in a loop until no more improvements are found
//! 6. Apply final expensive one-shot passes (array copy propagation, global value numbering, strength reduction) and merge blocks
15//!
16//! The output is represented as a [`petgraph`] graph of [`BasicBlock`]s terminated by [`ControlFlow`].
17//! This can then be compiled into actual executable code by walking the graph and generating all
18//! phi nodes, instructions and branches.
19//!
20//! # Representing [`PhiInstruction`] in non-SSA languages
21//!
22//! Phi instructions can be simulated by generating a mutable variable for each phi, then assigning
23//! `value` to it in each relevant `block`.
24//!
25
26use std::{
27    collections::{HashMap, VecDeque},
28    ops::{Deref, DerefMut},
29    rc::Rc,
30    sync::atomic::{AtomicUsize, Ordering},
31};
32
33use analyses::{dominance::DomFrontiers, liveness::Liveness, writes::Writes, AnalysisCache};
34use cubecl_core::{
35    ir::{self as core, Allocator, Branch, Id, Operation, Operator, Variable, VariableKind},
36    CubeDim,
37};
38use cubecl_core::{
39    ir::{Item, Scope},
40    ExecutionMode,
41};
42use gvn::GvnPass;
43use passes::{
44    CompositeMerge, ConstEval, ConstOperandSimplify, CopyPropagateArray, CopyTransform,
45    EliminateConstBranches, EliminateDeadBlocks, EliminateDeadPhi, EliminateUnusedVariables,
46    EmptyBranchToSelect, InBoundsToUnchecked, InlineAssignments, MergeBlocks, MergeSameExpressions,
47    OptimizerPass, ReduceStrength, RemoveIndexScalar,
48};
49use petgraph::{prelude::StableDiGraph, visit::EdgeRef, Direction};
50
51mod analyses;
52mod block;
53mod control_flow;
54mod debug;
55mod gvn;
56mod instructions;
57mod passes;
58mod phi_frontiers;
59mod version;
60
61pub use block::*;
62pub use control_flow::*;
63pub use petgraph::graph::{EdgeIndex, NodeIndex};
64pub use version::PhiInstruction;
65
/// A shared counter exposing a minimal atomic interface.
///
/// Clones share the same underlying count, so optimization passes can report
/// how many changes they applied through a cheaply copyable handle.
#[derive(Clone, Debug, Default)]
pub struct AtomicCounter {
    count: Rc<AtomicUsize>,
}

impl AtomicCounter {
    /// Builds a counter starting at `val`.
    pub fn new(val: usize) -> Self {
        Self {
            count: Rc::new(AtomicUsize::new(val)),
        }
    }

    /// Bumps the count by one, returning the value it held beforehand.
    pub fn inc(&self) -> usize {
        self.count.fetch_add(1, Ordering::AcqRel)
    }

    /// Reads the current count without modifying it.
    pub fn get(&self) -> usize {
        self.count.load(Ordering::Acquire)
    }
}
90
/// A constant array collected while parsing a scope (see `parse_scope`),
/// exposed to the backend via [`Optimizer::const_arrays`].
#[derive(Debug, Clone)]
pub struct ConstArray {
    /// Identifier of the array variable.
    pub id: Id,
    /// Number of elements in the array.
    pub length: u32,
    /// Item (element type) of the array.
    pub item: Item,
    /// The constant values backing the array.
    pub values: Vec<core::Variable>,
}
98
/// Internal program state: collected constant arrays, the pre-SSA table of
/// mutable local variables, and the control flow graph itself.
#[derive(Default, Debug, Clone)]
struct Program {
    /// Constant arrays registered while parsing scopes.
    pub const_arrays: Vec<ConstArray>,
    /// Mutable locals by id; entries are removed to exempt a variable from the
    /// SSA transform, and the table is cleared once versioning is done.
    pub variables: HashMap<Id, Item>,
    /// The control flow graph of [`BasicBlock`]s (edges carry no data).
    pub graph: StableDiGraph<BasicBlock, ()>,
    /// Entry node of the graph, set when the root scope is parsed.
    root: NodeIndex,
}
106
107impl Deref for Program {
108    type Target = StableDiGraph<BasicBlock, ()>;
109
110    fn deref(&self) -> &Self::Target {
111        &self.graph
112    }
113}
114
115impl DerefMut for Program {
116    fn deref_mut(&mut self) -> &mut Self::Target {
117        &mut self.graph
118    }
119}
120
/// Key for a versioned variable: the variable's [`Id`] plus a secondary index
/// (presumably the SSA version — see `version.rs`; TODO confirm).
type VarId = (Id, u16);
122
/// An optimizer that applies various analyses and optimization passes to the IR.
#[derive(Debug, Clone)]
pub struct Optimizer {
    /// The overall program state: const arrays, variable table and control flow graph
    program: Program,
    /// Allocator for kernel variables, taken from the expanded scope
    pub allocator: Allocator,
    /// Analyses with persistent state; invalidated when the graph changes
    analysis_cache: Rc<AnalysisCache>,
    /// The current block while parsing (may be `None` — see `parse_graph`)
    current_block: Option<NodeIndex>,
    /// The current loop's break target (a deque so nested loops can stack)
    loop_break: VecDeque<NodeIndex>,
    /// The single return block
    pub ret: NodeIndex,
    /// Root scope to allocate variables on
    root_scope: Scope,
    /// The `CubeDim` used for range analysis
    pub(crate) cube_dim: CubeDim,
    /// The execution mode, `Unchecked` skips bounds check optimizations.
    pub(crate) mode: ExecutionMode,
}
145
146impl Default for Optimizer {
147    fn default() -> Self {
148        Self {
149            program: Default::default(),
150            allocator: Default::default(),
151            current_block: Default::default(),
152            loop_break: Default::default(),
153            ret: Default::default(),
154            root_scope: Scope::root(),
155            cube_dim: Default::default(),
156            mode: Default::default(),
157            analysis_cache: Default::default(),
158        }
159    }
160}
161
162impl Optimizer {
163    /// Create a new optimizer with the scope, `CubeDim` and execution mode passed into the compiler.
164    /// Parses the scope and runs several optimization and analysis loops.
165    pub fn new(expand: Scope, cube_dim: CubeDim, mode: ExecutionMode) -> Self {
166        let mut opt = Self {
167            root_scope: expand.clone(),
168            cube_dim,
169            mode,
170            allocator: expand.allocator.clone(),
171            ..Default::default()
172        };
173        opt.run_opt(expand);
174
175        opt
176    }
177
    /// Run all optimizations
    fn run_opt(&mut self, expand: Scope) {
        // Build the CFG, then normalize it ahead of the SSA transform.
        self.parse_graph(expand);
        self.split_critical_edges();
        self.apply_pre_ssa_passes();
        // Locals mutated through `IndexAssign` are exempted from SSA versioning.
        self.exempt_index_assign_locals();
        self.ssa_transform();
        self.apply_post_ssa_passes();

        // Special expensive passes that should only run once.
        // Need more optimization rounds in between.

        // Array copy propagation; if anything changed, liveness is stale and the
        // program must be re-versioned and re-optimized.
        let arrays_prop = AtomicCounter::new(0);
        CopyPropagateArray.apply_post_ssa(self, arrays_prop.clone());
        if arrays_prop.get() > 0 {
            self.invalidate_analysis::<Liveness>();
            self.ssa_transform();
            self.apply_post_ssa_passes();
        }

        // Global value numbering plus related passes, sharing one change counter.
        let gvn_count = AtomicCounter::new(0);
        GvnPass.apply_post_ssa(self, gvn_count.clone());
        ReduceStrength.apply_post_ssa(self, gvn_count.clone());
        CopyTransform.apply_post_ssa(self, gvn_count.clone());

        // Re-run the cheap passes only if the expensive ones changed anything.
        if gvn_count.get() > 0 {
            self.apply_post_ssa_passes();
        }

        // Final cleanup; the change count is not inspected afterwards.
        MergeBlocks.apply_post_ssa(self, AtomicCounter::new(0));
    }
209
    /// The entry (root) block of the program, set while parsing the root scope.
    pub fn entry(&self) -> NodeIndex {
        self.program.root
    }
214
    /// Parse `scope` into the control flow graph: create the entry block and the
    /// single return block, parse recursively, then connect any fall-through
    /// path to the return block.
    fn parse_graph(&mut self, scope: Scope) {
        let entry = self.program.add_node(BasicBlock::default());
        self.program.root = entry;
        self.current_block = Some(entry);
        self.ret = self.program.add_node(BasicBlock::default());
        *self.program[self.ret].control_flow.borrow_mut() = ControlFlow::Return;
        self.parse_scope(scope);
        // If parsing ended while still inside a block, fall through to return.
        if let Some(current_block) = self.current_block {
            self.program.add_edge(current_block, self.ret, ());
        }
        // Analyses shouldn't have run at this point, but just in case they have, invalidate
        // all analyses that depend on the graph
        self.invalidate_structure();
    }
229
230    fn apply_pre_ssa_passes(&mut self) {
231        // Currently only one pre-ssa pass, but might add more
232        let mut passes = vec![CompositeMerge];
233        loop {
234            let counter = AtomicCounter::default();
235
236            for pass in &mut passes {
237                pass.apply_pre_ssa(self, counter.clone());
238            }
239
240            if counter.get() == 0 {
241                break;
242            }
243        }
244    }
245
246    fn apply_post_ssa_passes(&mut self) {
247        // Passes that run regardless of execution mode
248        let mut passes: Vec<Box<dyn OptimizerPass>> = vec![
249            Box::new(InlineAssignments),
250            Box::new(EliminateUnusedVariables),
251            Box::new(ConstOperandSimplify),
252            Box::new(MergeSameExpressions),
253            Box::new(ConstEval),
254            Box::new(RemoveIndexScalar),
255            Box::new(EliminateConstBranches),
256            Box::new(EmptyBranchToSelect),
257            Box::new(EliminateDeadBlocks),
258            Box::new(EliminateDeadPhi),
259        ];
260
261        loop {
262            let counter = AtomicCounter::default();
263            for pass in &mut passes {
264                pass.apply_post_ssa(self, counter.clone());
265            }
266
267            if counter.get() == 0 {
268                break;
269            }
270        }
271
272        // Only replace indexing when checked, since all indexes are unchecked anyways in unchecked
273        if matches!(self.mode, ExecutionMode::Checked) {
274            InBoundsToUnchecked.apply_post_ssa(self, AtomicCounter::new(0));
275        }
276    }
277
    /// Remove non-constant index vectors from SSA transformation because they currently must be
    /// mutated
    fn exempt_index_assign_locals(&mut self) {
        for node in self.node_ids() {
            // Clone the ops `Rc` so it can be borrowed while `self.program` is
            // mutated inside the loop below.
            let ops = self.program[node].ops.clone();
            for op in ops.borrow().values() {
                if let Operation::Operator(Operator::IndexAssign(_)) = &op.operation {
                    // Removing the id from the variable table exempts it from versioning.
                    if let VariableKind::LocalMut { id } = &op.out().kind {
                        self.program.variables.remove(id);
                    }
                }
            }
        }
    }
292
293    /// A set of node indices for all blocks in the program
294    pub fn node_ids(&self) -> Vec<NodeIndex> {
295        self.program.node_indices().collect()
296    }
297
    /// Transform the program to pruned SSA form: place phi nodes, version every
    /// variable, then drop the now-unneeded mutable variable table.
    fn ssa_transform(&mut self) {
        self.place_phi_nodes();
        self.version_program();
        self.program.variables.clear();
        // These analyses are stale after versioning and must be recomputed.
        self.invalidate_analysis::<Writes>();
        self.invalidate_analysis::<DomFrontiers>();
    }
305
306    /// Mutable reference to the current basic block
307    pub(crate) fn current_block_mut(&mut self) -> &mut BasicBlock {
308        &mut self.program[self.current_block.unwrap()]
309    }
310
311    /// List of predecessor IDs of the `block`
312    pub fn predecessors(&self, block: NodeIndex) -> Vec<NodeIndex> {
313        self.program
314            .edges_directed(block, Direction::Incoming)
315            .map(|it| it.source())
316            .collect()
317    }
318
319    /// List of successor IDs of the `block`
320    pub fn successors(&self, block: NodeIndex) -> Vec<NodeIndex> {
321        self.program
322            .edges_directed(block, Direction::Outgoing)
323            .map(|it| it.target())
324            .collect()
325    }
326
327    /// Reference to the [`BasicBlock`] with ID `block`
328    #[track_caller]
329    pub fn block(&self, block: NodeIndex) -> &BasicBlock {
330        &self.program[block]
331    }
332
333    /// Reference to the [`BasicBlock`] with ID `block`
334    #[track_caller]
335    pub fn block_mut(&mut self, block: NodeIndex) -> &mut BasicBlock {
336        &mut self.program[block]
337    }
338
    /// Recursively parse a scope into the graph
    ///
    /// Returns `true` when the scope's operations contain a `Break` branch.
    pub fn parse_scope(&mut self, mut scope: Scope) -> bool {
        let processed = scope.process();

        // Register every mutable local so the SSA transform can version it.
        for var in processed.variables {
            if let VariableKind::LocalMut { id } = var.kind {
                self.program.variables.insert(id, var.item);
            }
        }

        // Collect constant arrays declared in this scope.
        for (var, values) in scope.const_arrays {
            let VariableKind::ConstantArray { id, length } = var.kind else {
                unreachable!()
            };
            self.program.const_arrays.push(ConstArray {
                id,
                length,
                item: var.item,
                values,
            });
        }

        let is_break = processed.operations.contains(&Branch::Break.into());

        for mut instruction in processed.operations {
            match &mut instruction.operation {
                // Branches terminate the current block and open new ones.
                Operation::Branch(branch) => self.parse_control_flow(branch.clone()),
                _ => {
                    self.current_block_mut().ops.borrow_mut().push(instruction);
                }
            }
        }

        is_break
    }
374
375    /// Gets the `id` and `depth` of the variable if it's a `Local` and not atomic, `None` otherwise.
376    pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<Id> {
377        match variable.kind {
378            core::VariableKind::LocalMut { id } if !variable.item.elem.is_atomic() => Some(id),
379            _ => None,
380        }
381    }
382
    /// The block a `Return` should branch to. If the shared return block is
    /// already used as a merge block, a fresh predecessor block is inserted and
    /// becomes the new `self.ret` (NOTE(review): presumably to keep merge and
    /// return targets distinct — confirm against `control_flow.rs`).
    pub(crate) fn ret(&mut self) -> NodeIndex {
        if self.program[self.ret].block_use.contains(&BlockUse::Merge) {
            let new_ret = self.program.add_node(BasicBlock::default());
            self.program.add_edge(new_ret, self.ret, ());
            self.ret = new_ret;
            // The graph structure changed, so structural analyses are stale.
            self.invalidate_structure();
            new_ret
        } else {
            self.ret
        }
    }
394
395    pub fn const_arrays(&self) -> Vec<ConstArray> {
396        self.program.const_arrays.clone()
397    }
398}
399
/// A visitor that does nothing.
///
/// Intended as a no-op placeholder wherever a variable-visitor callback is required.
pub fn visit_noop(_opt: &mut Optimizer, _var: &mut Variable) {}
402
#[cfg(test)]
mod test {
    use cubecl::prelude::*;
    use cubecl_core::{
        self as cubecl,
        ir::{Elem, Item, UIntKind, Variable, VariableKind},
        prelude::{Array, CubeContext, ExpandElement},
    };
    use cubecl_core::{cube, CubeDim, ExecutionMode};

    use crate::Optimizer;

    // Kernel with a conditional assignment (`y`) and an unconditional one (`z`);
    // exercises phi placement and the pre-SSA passes.
    #[allow(unused)]
    #[cube(launch)]
    fn pre_kernel(x: u32, cond: u32, out: &mut Array<u32>) {
        let mut y = 0;
        let mut z = 0;
        if cond == 0 {
            y = x + 4;
        }
        z = x + 4;
        out[0] = y;
        out[1] = z;
    }

    // Manually expands the kernel with plain global scalar/array inputs, runs the
    // optimizer pipeline and prints the resulting graph for manual inspection.
    #[test]
    #[ignore = "no good way to assert opt is applied"]
    fn test_pre() {
        let mut ctx = CubeContext::root();
        let x = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(0),
            Item::new(Elem::UInt(UIntKind::U32)),
        ));
        let cond = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(1),
            Item::new(Elem::UInt(UIntKind::U32)),
        ));
        let arr = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalOutputArray(0),
            Item::new(Elem::UInt(UIntKind::U32)),
        ));

        pre_kernel::expand(&mut ctx, x.into(), cond.into(), arr.into());
        let scope = ctx.into_scope();
        let opt = Optimizer::new(scope, CubeDim::default(), ExecutionMode::Checked);
        println!("{opt}")
    }
}