cubecl_opt/
lib.rs

//! # CubeCL Optimizer
//!
//! A library that parses CubeCL IR into a
//! [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph), transforms it to
//! [static single-assignment form](https://en.wikipedia.org/wiki/Static_single-assignment_form)
//! and runs various optimizations on it.
//! The order of operations is as follows:
//!
//! 1. Parse root scope recursively into a [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph)
//! 2. Run optimizations that must be done before SSA transformation
//! 3. Analyze variable liveness
//! 4. Transform the graph to [pruned SSA](https://en.wikipedia.org/wiki/Static_single-assignment_form#Pruned_SSA) form
//! 5. Run post-SSA optimizations and analyses in a loop until no more improvements are found
//! 6. Speed
//!
//! The output is represented as a [`petgraph`] graph of [`BasicBlock`]s terminated by [`ControlFlow`].
//! This can then be compiled into actual executable code by walking the graph and generating all
//! phi nodes, instructions and branches.
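//!
//! For example, a backend might walk the optimized graph roughly like this (an illustrative
//! sketch, not a prescribed code generation strategy):
//!
//! ```ignore
//! let opt = Optimizer::new(scope, cube_dim, mode, vec![]);
//! for node in opt.node_ids() {
//!     let block = opt.block(node);
//!     // 1. emit the block's phi nodes
//!     // 2. emit each instruction in `block.ops`
//!     // 3. emit a branch or return based on `block.control_flow`
//! }
//! ```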
//!
//! # Representing [`PhiInstruction`] in non-SSA languages
//!
//! Phi instructions can be simulated by generating a mutable variable for each phi, then assigning
//! `value` to it in each relevant `block`.
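//!
//! For example, a backend without SSA support could lower a phi conceptually like this (an
//! illustrative sketch, not actual generated output):
//!
//! ```ignore
//! // SSA form in the optimizer:
//! //   bb3: out = phi [(bb1, a), (bb2, b)]
//! //
//! // Non-SSA lowering: one mutable slot per phi, assigned in each predecessor
//! // block and read wherever the phi result was used.
//! let out_slot;
//! if cond {
//!     // ... code generated for bb1
//!     out_slot = a;
//! } else {
//!     // ... code generated for bb2
//!     out_slot = b;
//! }
//! // code generated for bb3 reads `out_slot` in place of the phi result
//! ```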
//!

use std::{
    collections::{HashMap, VecDeque},
    ops::{Deref, DerefMut},
    rc::Rc,
    sync::atomic::{AtomicUsize, Ordering},
};

use analyses::{AnalysisCache, dominance::DomFrontiers, liveness::Liveness, writes::Writes};
use cubecl_common::{CubeDim, ExecutionMode};
use cubecl_ir::{
    self as core, Allocator, Branch, Id, Item, Operation, Operator, Scope, Variable, VariableKind,
};
use gvn::GvnPass;
use passes::{
    CompositeMerge, ConstEval, ConstOperandSimplify, CopyPropagateArray, CopyTransform,
    EliminateConstBranches, EliminateDeadBlocks, EliminateDeadPhi, EliminateUnusedVariables,
    EmptyBranchToSelect, InBoundsToUnchecked, InlineAssignments, MergeBlocks, MergeSameExpressions,
    OptimizerPass, ReduceStrength, RemoveIndexScalar,
};
use petgraph::{Direction, prelude::StableDiGraph, visit::EdgeRef};

mod analyses;
mod block;
mod control_flow;
mod debug;
mod gvn;
mod instructions;
mod passes;
mod phi_frontiers;
mod transformers;
mod version;

pub use analyses::uniformity::Uniformity;
pub use block::*;
pub use control_flow::*;
pub use petgraph::graph::{EdgeIndex, NodeIndex};
pub use transformers::*;
pub use version::PhiInstruction;

/// An atomic counter with a simplified interface.
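///
/// A usage sketch (illustrative):
/// ```ignore
/// let counter = AtomicCounter::new(0);
/// assert_eq!(counter.inc(), 0); // `inc` returns the previous count
/// assert_eq!(counter.get(), 1);
/// ```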
#[derive(Clone, Debug, Default)]
pub struct AtomicCounter {
    inner: Rc<AtomicUsize>,
}

impl AtomicCounter {
    /// Creates a new counter with `val` as its initial value.
    pub fn new(val: usize) -> Self {
        Self {
            inner: Rc::new(AtomicUsize::new(val)),
        }
    }

    /// Increments the counter and returns the previous count.
    pub fn inc(&self) -> usize {
        self.inner.fetch_add(1, Ordering::AcqRel)
    }

    /// Gets the value of the counter without incrementing it.
    pub fn get(&self) -> usize {
        self.inner.load(Ordering::Acquire)
    }
}

/// A constant array defined in the kernel, along with its values.
#[derive(Debug, Clone)]
pub struct ConstArray {
    pub id: Id,
    pub length: u32,
    pub item: Item,
    pub values: Vec<core::Variable>,
}

#[derive(Default, Debug, Clone)]
struct Program {
    pub const_arrays: Vec<ConstArray>,
    pub variables: HashMap<Id, Item>,
    pub graph: StableDiGraph<BasicBlock, ()>,
    root: NodeIndex,
}

impl Deref for Program {
    type Target = StableDiGraph<BasicBlock, ()>;

    fn deref(&self) -> &Self::Target {
        &self.graph
    }
}

impl DerefMut for Program {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.graph
    }
}

type VarId = (Id, u16);

/// An optimizer that applies various analyses and optimization passes to the IR.
#[derive(Debug, Clone)]
pub struct Optimizer {
    /// The overall program state
    program: Program,
    /// Allocator for the kernel
    pub allocator: Allocator,
    /// Analyses with persistent state
    analysis_cache: Rc<AnalysisCache>,
    /// The current block while parsing
    current_block: Option<NodeIndex>,
    /// The current loop's break target
    loop_break: VecDeque<NodeIndex>,
    /// The single return block
    pub ret: NodeIndex,
    /// Root scope to allocate variables on
    pub root_scope: Scope,
    /// The `CubeDim` used for range analysis
    pub(crate) cube_dim: CubeDim,
    /// The execution mode; `Unchecked` skips bounds check optimizations.
    pub(crate) mode: ExecutionMode,
    /// Transformers applied to each instruction while parsing the scope
    pub(crate) transformers: Vec<Rc<dyn IrTransformer>>,
}

impl Default for Optimizer {
    fn default() -> Self {
        Self {
            program: Default::default(),
            allocator: Default::default(),
            current_block: Default::default(),
            loop_break: Default::default(),
            ret: Default::default(),
            root_scope: Scope::root(false),
            cube_dim: Default::default(),
            mode: Default::default(),
            analysis_cache: Default::default(),
            transformers: Default::default(),
        }
    }
}

impl Optimizer {
    /// Create a new optimizer with the scope, `CubeDim` and execution mode passed into the compiler.
    /// Parses the scope and runs several optimization and analysis loops.
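    ///
    /// A usage sketch (illustrative; assumes `scope` was expanded by the CubeCL frontend):
    /// ```ignore
    /// let opt = Optimizer::new(scope, CubeDim::default(), ExecutionMode::Checked, vec![]);
    /// let entry = opt.entry();
    /// let _first_block = opt.block(entry);
    /// ```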
    pub fn new(
        expand: Scope,
        cube_dim: CubeDim,
        mode: ExecutionMode,
        transformers: Vec<Rc<dyn IrTransformer>>,
    ) -> Self {
        let mut opt = Self {
            root_scope: expand.clone(),
            cube_dim,
            mode,
            allocator: expand.allocator.clone(),
            transformers,
            ..Default::default()
        };
        opt.run_opt();

        opt
    }

    /// Run all optimizations
    fn run_opt(&mut self) {
        self.parse_graph(self.root_scope.clone());
        self.split_critical_edges();
        self.apply_pre_ssa_passes();
        self.exempt_index_assign_locals();
        self.ssa_transform();
        self.apply_post_ssa_passes();

        // Special expensive passes that should only run once, with additional rounds of the
        // cheaper passes in between.

        let arrays_prop = AtomicCounter::new(0);
        CopyPropagateArray.apply_post_ssa(self, arrays_prop.clone());
        if arrays_prop.get() > 0 {
            self.invalidate_analysis::<Liveness>();
            self.ssa_transform();
            self.apply_post_ssa_passes();
        }

        let gvn_count = AtomicCounter::new(0);
        GvnPass.apply_post_ssa(self, gvn_count.clone());
        ReduceStrength.apply_post_ssa(self, gvn_count.clone());
        CopyTransform.apply_post_ssa(self, gvn_count.clone());

        if gvn_count.get() > 0 {
            self.apply_post_ssa_passes();
        }

        MergeBlocks.apply_post_ssa(self, AtomicCounter::new(0));
    }

    /// The entry block of the program
    pub fn entry(&self) -> NodeIndex {
        self.program.root
    }

    fn parse_graph(&mut self, scope: Scope) {
        let entry = self.program.add_node(BasicBlock::default());
        self.program.root = entry;
        self.current_block = Some(entry);
        self.ret = self.program.add_node(BasicBlock::default());
        *self.program[self.ret].control_flow.borrow_mut() = ControlFlow::Return;
        self.parse_scope(scope);
        if let Some(current_block) = self.current_block {
            self.program.add_edge(current_block, self.ret, ());
        }
        // Analyses shouldn't have run at this point, but just in case they have, invalidate
        // all analyses that depend on the graph
        self.invalidate_structure();
    }

    fn apply_pre_ssa_passes(&mut self) {
        // Currently only one pre-SSA pass, but more may be added
        let mut passes = vec![CompositeMerge];
        loop {
            let counter = AtomicCounter::default();

            for pass in &mut passes {
                pass.apply_pre_ssa(self, counter.clone());
            }

            if counter.get() == 0 {
                break;
            }
        }
    }

    fn apply_post_ssa_passes(&mut self) {
        // Passes that run regardless of execution mode
        let mut passes: Vec<Box<dyn OptimizerPass>> = vec![
            Box::new(InlineAssignments),
            Box::new(EliminateUnusedVariables),
            Box::new(ConstOperandSimplify),
            Box::new(MergeSameExpressions),
            Box::new(ConstEval),
            Box::new(RemoveIndexScalar),
            Box::new(EliminateConstBranches),
            Box::new(EmptyBranchToSelect),
            Box::new(EliminateDeadBlocks),
            Box::new(EliminateDeadPhi),
        ];

        loop {
            let counter = AtomicCounter::default();
            for pass in &mut passes {
                pass.apply_post_ssa(self, counter.clone());
            }

            if counter.get() == 0 {
                break;
            }
        }

        // Only rewrite checked indexing to unchecked when running in checked mode, since in
        // unchecked mode all indexes are already unchecked
        if matches!(self.mode, ExecutionMode::Checked) {
            InBoundsToUnchecked.apply_post_ssa(self, AtomicCounter::new(0));
        }
    }

    /// Exempt locals that are written through `IndexAssign` (i.e. non-constant index vectors) from
    /// SSA transformation, since they currently must be mutated in place
    fn exempt_index_assign_locals(&mut self) {
        for node in self.node_ids() {
            let ops = self.program[node].ops.clone();
            for op in ops.borrow().values() {
                if let Operation::Operator(Operator::IndexAssign(_)) = &op.operation {
                    if let VariableKind::LocalMut { id } = &op.out().kind {
                        self.program.variables.remove(id);
                    }
                }
            }
        }
    }

    /// A set of node indices for all blocks in the program
    pub fn node_ids(&self) -> Vec<NodeIndex> {
        self.program.node_indices().collect()
    }

    fn ssa_transform(&mut self) {
        self.place_phi_nodes();
        self.version_program();
        self.program.variables.clear();
        self.invalidate_analysis::<Writes>();
        self.invalidate_analysis::<DomFrontiers>();
    }

    /// Mutable reference to the current basic block
    pub(crate) fn current_block_mut(&mut self) -> &mut BasicBlock {
        &mut self.program[self.current_block.unwrap()]
    }

    /// List of predecessor IDs of the `block`
    pub fn predecessors(&self, block: NodeIndex) -> Vec<NodeIndex> {
        self.program
            .edges_directed(block, Direction::Incoming)
            .map(|it| it.source())
            .collect()
    }

    /// List of successor IDs of the `block`
    pub fn successors(&self, block: NodeIndex) -> Vec<NodeIndex> {
        self.program
            .edges_directed(block, Direction::Outgoing)
            .map(|it| it.target())
            .collect()
    }

    /// Reference to the [`BasicBlock`] with ID `block`
    #[track_caller]
    pub fn block(&self, block: NodeIndex) -> &BasicBlock {
        &self.program[block]
    }

    /// Mutable reference to the [`BasicBlock`] with ID `block`
    #[track_caller]
    pub fn block_mut(&mut self, block: NodeIndex) -> &mut BasicBlock {
        &mut self.program[block]
    }

    /// Recursively parse a scope into the graph. Returns whether the scope contains a `Break`
    /// branch.
    pub fn parse_scope(&mut self, mut scope: Scope) -> bool {
        let processed = scope.process();

        for var in processed.variables {
            if let VariableKind::LocalMut { id } = var.kind {
                self.program.variables.insert(id, var.item);
            }
        }

        for (var, values) in scope.const_arrays.clone() {
            let VariableKind::ConstantArray { id, length } = var.kind else {
                unreachable!()
            };
            self.program.const_arrays.push(ConstArray {
                id,
                length,
                item: var.item,
                values,
            });
        }

        let is_break = processed.instructions.contains(&Branch::Break.into());

        for mut instruction in processed.instructions {
            let mut removed = false;
            for transform in self.transformers.iter() {
                match transform.maybe_transform(&mut scope, &instruction) {
                    TransformAction::Ignore => {}
                    TransformAction::Replace(replacement) => {
                        self.current_block_mut()
                            .ops
                            .borrow_mut()
                            .extend(replacement);
                        removed = true;
                        break;
                    }
                    TransformAction::Remove => {
                        removed = true;
                        break;
                    }
                }
            }
            if removed {
                continue;
            }
            match &mut instruction.operation {
                Operation::Branch(branch) => self.parse_control_flow(branch.clone()),
                _ => {
                    self.current_block_mut().ops.borrow_mut().push(instruction);
                }
            }
        }

        is_break
    }

    /// Gets the `id` of the variable if it's a mutable local (`LocalMut`) and not atomic, `None`
    /// otherwise.
    pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<Id> {
        match variable.kind {
            core::VariableKind::LocalMut { id } if !variable.item.elem.is_atomic() => Some(id),
            _ => None,
        }
    }

    /// The return block. If the current return block is already used as a merge block, a new block
    /// that branches to it is created and used instead.
    pub(crate) fn ret(&mut self) -> NodeIndex {
        if self.program[self.ret].block_use.contains(&BlockUse::Merge) {
            let new_ret = self.program.add_node(BasicBlock::default());
            self.program.add_edge(new_ret, self.ret, ());
            self.ret = new_ret;
            self.invalidate_structure();
            new_ret
        } else {
            self.ret
        }
    }

    /// The constant arrays defined in the kernel.
    pub fn const_arrays(&self) -> Vec<ConstArray> {
        self.program.const_arrays.clone()
    }
}

/// A visitor that does nothing.
pub fn visit_noop(_opt: &mut Optimizer, _var: &mut Variable) {}

#[cfg(test)]
mod test {
    use cubecl_core as cubecl;
    use cubecl_core::cube;
    use cubecl_core::prelude::*;
    use cubecl_ir::{Elem, ExpandElement, Item, UIntKind, Variable, VariableKind};

    use crate::Optimizer;

    #[allow(unused)]
    #[cube(launch)]
    fn pre_kernel(x: u32, cond: u32, out: &mut Array<u32>) {
        let mut y = 0;
        let mut z = 0;
        if cond == 0 {
            y = x + 4;
        }
        z = x + 4;
        out[0] = y;
        out[1] = z;
    }

    #[test]
    #[ignore = "no good way to assert opt is applied"]
    fn test_pre() {
        let mut ctx = Scope::root(false);
        let x = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(0),
            Item::new(Elem::UInt(UIntKind::U32)),
        ));
        let cond = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(1),
            Item::new(Elem::UInt(UIntKind::U32)),
        ));
        let arr = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalOutputArray(0),
            Item::new(Elem::UInt(UIntKind::U32)),
        ));

        pre_kernel::expand(&mut ctx, x.into(), cond.into(), arr.into());
        let opt = Optimizer::new(ctx, CubeDim::default(), ExecutionMode::Checked, vec![]);
        println!("{opt}")
    }
}