cubecl_opt/
lib.rs

1//! # CubeCL Optimizer
2//!
3//! A library that parses CubeCL IR into a
4//! [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph), transforms it to
5//! [static single-assignment form](https://en.wikipedia.org/wiki/Static_single-assignment_form)
6//! and runs various optimizations on it.
7//! The order of operations is as follows:
8//!
9//! 1. Parse root scope recursively into a [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph)
10//! 2. Run optimizations that must be done before SSA transformation
11//! 3. Analyze variable liveness
12//! 4. Transform the graph to [pruned SSA](https://en.wikipedia.org/wiki/Static_single-assignment_form#Pruned_SSA) form
13//! 5. Run post-SSA optimizations and analyses in a loop until no more improvements are found
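//! 6. Run a few expensive passes once (array copy propagation, global value numbering, strength
//!    reduction, copy transformation and block merging), re-running the post-SSA loop when they
//!    change the program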
15//!
16//! The output is represented as a [`petgraph`] graph of [`BasicBlock`]s terminated by [`ControlFlow`].
17//! This can then be compiled into actual executable code by walking the graph and generating all
18//! phi nodes, instructions and branches.
19//!
20//! # Representing [`PhiInstruction`] in non-SSA languages
21//!
22//! Phi instructions can be simulated by generating a mutable variable for each phi, then assigning
23//! `value` to it in each relevant `block`.
24//!
25
26#![allow(unknown_lints, unnecessary_transmutes)]
27
28use std::{
29    collections::{HashMap, VecDeque},
30    ops::{Deref, DerefMut},
31    rc::Rc,
32    sync::atomic::{AtomicUsize, Ordering},
33};
34
35use analyses::{AnalysisCache, dominance::DomFrontiers, liveness::Liveness, writes::Writes};
36use cubecl_common::{CubeDim, ExecutionMode};
37use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
38use cubecl_ir::{
39    self as core, Allocator, Branch, Id, Item, Operation, Operator, Processor, Scope, Variable,
40    VariableKind,
41};
42use gvn::GvnPass;
43use passes::{
44    CompositeMerge, ConstEval, ConstOperandSimplify, CopyPropagateArray, CopyTransform,
45    EliminateConstBranches, EliminateDeadBlocks, EliminateDeadPhi, EliminateUnusedVariables,
46    EmptyBranchToSelect, InlineAssignments, MergeBlocks, MergeSameExpressions, OptimizerPass,
47    ReduceStrength, RemoveIndexScalar,
48};
49use petgraph::{
50    Direction,
51    dot::{Config, Dot},
52    prelude::StableDiGraph,
53    visit::EdgeRef,
54};
55
56mod analyses;
57mod block;
58mod control_flow;
59mod debug;
60mod gvn;
61mod instructions;
62mod passes;
63mod phi_frontiers;
64mod transformers;
65mod version;
66
67pub use analyses::uniformity::Uniformity;
68pub use block::*;
69pub use control_flow::*;
70pub use petgraph::graph::{EdgeIndex, NodeIndex};
71pub use transformers::*;
72pub use version::PhiInstruction;
73
/// A cheaply clonable counter with a simplified interface.
///
/// Every clone shares the same underlying count through an `Rc`, so increments made through one
/// handle are visible through all of them.
#[derive(Clone, Debug, Default)]
pub struct AtomicCounter {
    inner: Rc<AtomicUsize>,
}

impl AtomicCounter {
    /// Builds a counter whose count starts at `val`.
    pub fn new(val: usize) -> Self {
        let shared = AtomicUsize::new(val);
        Self {
            inner: Rc::new(shared),
        }
    }

    /// Adds one to the shared count and returns the count as it was before the increment.
    pub fn inc(&self) -> usize {
        AtomicUsize::fetch_add(&self.inner, 1, Ordering::AcqRel)
    }

    /// Reads the current count without modifying it.
    pub fn get(&self) -> usize {
        AtomicUsize::load(&self.inner, Ordering::Acquire)
    }
}
98
99#[derive(Debug, Clone)]
100pub struct ConstArray {
101    pub id: Id,
102    pub length: u32,
103    pub item: Item,
104    pub values: Vec<core::Variable>,
105}
106
107#[derive(Default, Debug, Clone)]
108struct Program {
109    pub const_arrays: Vec<ConstArray>,
110    pub variables: HashMap<Id, Item>,
111    pub graph: StableDiGraph<BasicBlock, u32>,
112    root: NodeIndex,
113}
114
115impl Deref for Program {
116    type Target = StableDiGraph<BasicBlock, u32>;
117
118    fn deref(&self) -> &Self::Target {
119        &self.graph
120    }
121}
122
123impl DerefMut for Program {
124    fn deref_mut(&mut self) -> &mut Self::Target {
125        &mut self.graph
126    }
127}
128
129type VarId = (Id, u16);
130
131/// An optimizer that applies various analyses and optimization passes to the IR.
132#[derive(Debug, Clone)]
133pub struct Optimizer {
134    /// The overall program state
135    program: Program,
136    /// Allocator for kernel
137    pub allocator: Allocator,
138    /// Analyses with persistent state
139    analysis_cache: Rc<AnalysisCache>,
140    /// The current block while parsing
141    current_block: Option<NodeIndex>,
142    /// The current loop's break target
143    loop_break: VecDeque<NodeIndex>,
144    /// The single return block
145    pub ret: NodeIndex,
146    /// Root scope to allocate variables on
147    pub root_scope: Scope,
148    /// The `CubeDim` used for range analysis
149    pub(crate) cube_dim: CubeDim,
150    /// The execution mode, `Unchecked` skips bounds check optimizations.
151    pub(crate) mode: ExecutionMode,
152    pub(crate) transformers: Vec<Rc<dyn IrTransformer>>,
153}
154
155impl Default for Optimizer {
156    fn default() -> Self {
157        Self {
158            program: Default::default(),
159            allocator: Default::default(),
160            current_block: Default::default(),
161            loop_break: Default::default(),
162            ret: Default::default(),
163            root_scope: Scope::root(false),
164            cube_dim: Default::default(),
165            mode: Default::default(),
166            analysis_cache: Default::default(),
167            transformers: Default::default(),
168        }
169    }
170}
171
172impl Optimizer {
173    /// Create a new optimizer with the scope, `CubeDim` and execution mode passed into the compiler.
174    /// Parses the scope and runs several optimization and analysis loops.
175    pub fn new(
176        expand: Scope,
177        cube_dim: CubeDim,
178        mode: ExecutionMode,
179        transformers: Vec<Rc<dyn IrTransformer>>,
180    ) -> Self {
181        let mut opt = Self {
182            root_scope: expand.clone(),
183            cube_dim,
184            mode,
185            allocator: expand.allocator.clone(),
186            transformers,
187            ..Default::default()
188        };
189        opt.run_opt();
190
191        opt
192    }
193
194    /// Run all optimizations
195    fn run_opt(&mut self) {
196        self.parse_graph(self.root_scope.clone());
197        self.split_critical_edges();
198        self.apply_pre_ssa_passes();
199        self.exempt_index_assign_locals();
200        self.ssa_transform();
201        self.apply_post_ssa_passes();
202
203        // Special expensive passes that should only run once.
204        // Need more optimization rounds in between.
205
206        let arrays_prop = AtomicCounter::new(0);
207        CopyPropagateArray.apply_post_ssa(self, arrays_prop.clone());
208        if arrays_prop.get() > 0 {
209            self.invalidate_analysis::<Liveness>();
210            self.ssa_transform();
211            self.apply_post_ssa_passes();
212        }
213
214        let gvn_count = AtomicCounter::new(0);
215        GvnPass.apply_post_ssa(self, gvn_count.clone());
216        ReduceStrength.apply_post_ssa(self, gvn_count.clone());
217        CopyTransform.apply_post_ssa(self, gvn_count.clone());
218
219        if gvn_count.get() > 0 {
220            self.apply_post_ssa_passes();
221        }
222
223        MergeBlocks.apply_post_ssa(self, AtomicCounter::new(0));
224    }
225
226    /// The entry block of the program
227    pub fn entry(&self) -> NodeIndex {
228        self.program.root
229    }
230
231    fn parse_graph(&mut self, scope: Scope) {
232        let entry = self.program.add_node(BasicBlock::default());
233        self.program.root = entry;
234        self.current_block = Some(entry);
235        self.ret = self.program.add_node(BasicBlock::default());
236        *self.program[self.ret].control_flow.borrow_mut() = ControlFlow::Return;
237        self.parse_scope(scope);
238        if let Some(current_block) = self.current_block {
239            self.program.add_edge(current_block, self.ret, 0);
240        }
241        // Analyses shouldn't have run at this point, but just in case they have, invalidate
242        // all analyses that depend on the graph
243        self.invalidate_structure();
244    }
245
246    fn apply_pre_ssa_passes(&mut self) {
247        // Currently only one pre-ssa pass, but might add more
248        let mut passes = vec![CompositeMerge];
249        loop {
250            let counter = AtomicCounter::default();
251
252            for pass in &mut passes {
253                pass.apply_pre_ssa(self, counter.clone());
254            }
255
256            if counter.get() == 0 {
257                break;
258            }
259        }
260    }
261
262    fn apply_post_ssa_passes(&mut self) {
263        // Passes that run regardless of execution mode
264        let mut passes: Vec<Box<dyn OptimizerPass>> = vec![
265            Box::new(InlineAssignments),
266            Box::new(EliminateUnusedVariables),
267            Box::new(ConstOperandSimplify),
268            Box::new(MergeSameExpressions),
269            Box::new(ConstEval),
270            Box::new(RemoveIndexScalar),
271            Box::new(EliminateConstBranches),
272            Box::new(EmptyBranchToSelect),
273            Box::new(EliminateDeadBlocks),
274            Box::new(EliminateDeadPhi),
275        ];
276
277        loop {
278            let counter = AtomicCounter::default();
279            for pass in &mut passes {
280                pass.apply_post_ssa(self, counter.clone());
281            }
282
283            if counter.get() == 0 {
284                break;
285            }
286        }
287    }
288
289    /// Remove non-constant index vectors from SSA transformation because they currently must be
290    /// mutated
291    fn exempt_index_assign_locals(&mut self) {
292        for node in self.node_ids() {
293            let ops = self.program[node].ops.clone();
294            for op in ops.borrow().values() {
295                if let Operation::Operator(Operator::IndexAssign(_)) = &op.operation {
296                    if let VariableKind::LocalMut { id } = &op.out().kind {
297                        self.program.variables.remove(id);
298                    }
299                }
300            }
301        }
302    }
303
304    /// A set of node indices for all blocks in the program
305    pub fn node_ids(&self) -> Vec<NodeIndex> {
306        self.program.node_indices().collect()
307    }
308
309    fn ssa_transform(&mut self) {
310        self.place_phi_nodes();
311        self.version_program();
312        self.program.variables.clear();
313        self.invalidate_analysis::<Writes>();
314        self.invalidate_analysis::<DomFrontiers>();
315    }
316
317    /// Mutable reference to the current basic block
318    pub(crate) fn current_block_mut(&mut self) -> &mut BasicBlock {
319        &mut self.program[self.current_block.unwrap()]
320    }
321
322    /// List of predecessor IDs of the `block`
323    pub fn predecessors(&self, block: NodeIndex) -> Vec<NodeIndex> {
324        self.program
325            .edges_directed(block, Direction::Incoming)
326            .map(|it| it.source())
327            .collect()
328    }
329
330    /// List of successor IDs of the `block`
331    pub fn successors(&self, block: NodeIndex) -> Vec<NodeIndex> {
332        self.program
333            .edges_directed(block, Direction::Outgoing)
334            .map(|it| it.target())
335            .collect()
336    }
337
338    /// Reference to the [`BasicBlock`] with ID `block`
339    #[track_caller]
340    pub fn block(&self, block: NodeIndex) -> &BasicBlock {
341        &self.program[block]
342    }
343
344    /// Reference to the [`BasicBlock`] with ID `block`
345    #[track_caller]
346    pub fn block_mut(&mut self, block: NodeIndex) -> &mut BasicBlock {
347        &mut self.program[block]
348    }
349
350    /// Recursively parse a scope into the graph
351    pub fn parse_scope(&mut self, mut scope: Scope) -> bool {
352        let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.mode));
353        let processed = scope.process([checked_io]);
354
355        for var in processed.variables {
356            if let VariableKind::LocalMut { id } = var.kind {
357                self.program.variables.insert(id, var.item);
358            }
359        }
360
361        for (var, values) in scope.const_arrays.clone() {
362            let VariableKind::ConstantArray { id, length } = var.kind else {
363                unreachable!()
364            };
365            self.program.const_arrays.push(ConstArray {
366                id,
367                length,
368                item: var.item,
369                values,
370            });
371        }
372
373        let is_break = processed.instructions.contains(&Branch::Break.into());
374
375        for mut instruction in processed.instructions {
376            let mut removed = false;
377            for transform in self.transformers.iter() {
378                match transform.maybe_transform(&mut scope, &instruction) {
379                    TransformAction::Ignore => {}
380                    TransformAction::Replace(replacement) => {
381                        self.current_block_mut()
382                            .ops
383                            .borrow_mut()
384                            .extend(replacement);
385                        removed = true;
386                        break;
387                    }
388                    TransformAction::Remove => {
389                        removed = true;
390                        break;
391                    }
392                }
393            }
394            if removed {
395                continue;
396            }
397            match &mut instruction.operation {
398                Operation::Branch(branch) => self.parse_control_flow(branch.clone()),
399                _ => {
400                    self.current_block_mut().ops.borrow_mut().push(instruction);
401                }
402            }
403        }
404
405        is_break
406    }
407
408    /// Gets the `id` and `depth` of the variable if it's a `Local` and not atomic, `None` otherwise.
409    pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<Id> {
410        match variable.kind {
411            core::VariableKind::LocalMut { id } if !variable.item.elem.is_atomic() => Some(id),
412            _ => None,
413        }
414    }
415
416    pub(crate) fn ret(&mut self) -> NodeIndex {
417        if self.program[self.ret].block_use.contains(&BlockUse::Merge) {
418            let new_ret = self.program.add_node(BasicBlock::default());
419            self.program.add_edge(new_ret, self.ret, 0);
420            self.ret = new_ret;
421            self.invalidate_structure();
422            new_ret
423        } else {
424            self.ret
425        }
426    }
427
428    pub fn const_arrays(&self) -> Vec<ConstArray> {
429        self.program.const_arrays.clone()
430    }
431
432    pub fn dot_viz(&self) -> Dot<'_, &StableDiGraph<BasicBlock, u32>> {
433        Dot::with_config(&self.program, &[Config::EdgeNoLabel])
434    }
435}
436
437/// A visitor that does nothing.
438pub fn visit_noop(_opt: &mut Optimizer, _var: &mut Variable) {}
439
#[cfg(test)]
mod test {
    // NOTE(review): the `#[cube]` macro presumably expands to `cubecl::...` paths, which is why
    // `cubecl_core` is aliased here — confirm.
    use cubecl_core as cubecl;
    use cubecl_core::cube;
    use cubecl_core::prelude::*;
    use cubecl_ir::{Elem, ExpandElement, Item, UIntKind, Variable, VariableKind};

    use crate::Optimizer;

    // Kernel in which `x + 4` is computed inside a conditional branch and then again after it.
    // Presumably meant to exercise partial redundancy elimination (PRE) — TODO confirm.
    // `allow(unused)` presumably silences warnings such as the dead initial store `z = 0` and
    // the unused generated launch items — confirm.
    #[allow(unused)]
    #[cube(launch)]
    fn pre_kernel(x: u32, cond: u32, out: &mut Array<u32>) {
        let mut y = 0;
        let mut z = 0;
        if cond == 0 {
            y = x + 4;
        }
        z = x + 4;
        out[0] = y;
        out[1] = z;
    }

    // Expands `pre_kernel` into a fresh root scope with two `u32` global scalars and one `u32`
    // global output array, runs the optimizer in `Checked` mode without transformers, and prints
    // the result through `Optimizer`'s `Display` impl for manual inspection. It asserts nothing,
    // hence `#[ignore]`.
    #[test]
    #[ignore = "no good way to assert opt is applied"]
    fn test_pre() {
        let mut ctx = Scope::root(false);
        let x = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(0),
            Item::new(Elem::UInt(UIntKind::U32)),
        ));
        let cond = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(1),
            Item::new(Elem::UInt(UIntKind::U32)),
        ));
        let arr = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalOutputArray(0),
            Item::new(Elem::UInt(UIntKind::U32)),
        ));

        pre_kernel::expand(&mut ctx, x.into(), cond.into(), arr.into());
        let opt = Optimizer::new(ctx, CubeDim::default(), ExecutionMode::Checked, vec![]);
        println!("{opt}")
    }
}