cubecl_opt/lib.rs

//! # CubeCL Optimizer
//!
//! A library that parses CubeCL IR into a
//! [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph), transforms it to
//! [static single-assignment form](https://en.wikipedia.org/wiki/Static_single-assignment_form)
//! and runs various optimizations on it.
//! The order of operations is as follows:
//!
//! 1. Parse root scope recursively into a [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph)
//! 2. Run optimizations that must be done before SSA transformation
//! 3. Analyze variable liveness
//! 4. Transform the graph to [pruned SSA](https://en.wikipedia.org/wiki/Static_single-assignment_form#Pruned_SSA) form
//! 5. Run post-SSA optimizations and analyses in a loop until no more improvements are found
//! 6. Run expensive one-shot passes (array copy propagation, GVN, strength reduction), then merge blocks
//!
//! The output is represented as a [`petgraph`] graph of [`BasicBlock`]s terminated by [`ControlFlow`].
//! This can then be compiled into actual executable code by walking the graph and generating all
//! phi nodes, instructions and branches.
//!
//! # Representing [`PhiInstruction`] in non-SSA languages
//!
//! Phi instructions can be simulated by generating a mutable variable for each phi, then assigning
//! the matching `value` to it in each relevant `block`, as sketched below.
//!
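//! As a rough illustration (hypothetical names, not actual generated output), a phi such as
//! `x = phi [a from bb1, b from bb2]` can be lowered like this:
//!
//! ```ignore
//! // declared once, before the blocks that feed the phi
//! let mut phi_x = Default::default();
//! // at the end of bb1:
//! phi_x = a;
//! // at the end of bb2:
//! phi_x = b;
//! // at the start of the block containing the phi, `x` simply reads the temporary
//! let x = phi_x;
//! ```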

#![allow(unknown_lints, unnecessary_transmutes)]

use std::{
    collections::{HashMap, VecDeque},
    ops::{Deref, DerefMut},
    rc::Rc,
    sync::atomic::{AtomicUsize, Ordering},
};

use analyses::{AnalysisCache, dominance::DomFrontiers, liveness::Liveness, writes::Writes};
use cubecl_common::CubeDim;
use cubecl_ir::{
    self as core, Allocator, Branch, Id, Operation, Operator, Processor, Scope, Type, Variable,
    VariableKind,
};
use gvn::GvnPass;
use passes::{
    CompositeMerge, ConstEval, ConstOperandSimplify, CopyPropagateArray, CopyTransform,
    EliminateConstBranches, EliminateDeadBlocks, EliminateDeadPhi, EliminateUnusedVariables,
    EmptyBranchToSelect, InlineAssignments, MergeBlocks, MergeSameExpressions, OptimizerPass,
    ReduceStrength, RemoveIndexScalar,
};
use petgraph::{
    Direction,
    dot::{Config, Dot},
    prelude::StableDiGraph,
    visit::EdgeRef,
};

mod analyses;
mod block;
mod control_flow;
mod debug;
mod gvn;
mod instructions;
mod passes;
mod phi_frontiers;
mod transformers;
mod version;

pub use analyses::uniformity::Uniformity;
pub use block::*;
pub use control_flow::*;
pub use petgraph::graph::{EdgeIndex, NodeIndex};
pub use transformers::*;
pub use version::PhiInstruction;

pub use crate::analyses::liveness::shared::SharedLiveness;

/// An atomic counter with a simplified interface.
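///
/// A minimal usage sketch (doctest marked `ignore`; the `cubecl_opt` crate path is assumed):
///
/// ```ignore
/// use cubecl_opt::AtomicCounter;
///
/// let counter = AtomicCounter::new(0);
/// assert_eq!(counter.inc(), 0); // `inc` returns the value before the increment
/// assert_eq!(counter.get(), 1);
/// ```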
#[derive(Clone, Debug, Default)]
pub struct AtomicCounter {
    inner: Rc<AtomicUsize>,
}

impl AtomicCounter {
    /// Creates a new counter with `val` as its initial value.
    pub fn new(val: usize) -> Self {
        Self {
            inner: Rc::new(AtomicUsize::new(val)),
        }
    }

    /// Increments the counter and returns the value it held before the increment.
    pub fn inc(&self) -> usize {
        self.inner.fetch_add(1, Ordering::AcqRel)
    }

    /// Gets the value of the counter without incrementing it.
    pub fn get(&self) -> usize {
        self.inner.load(Ordering::Acquire)
    }
}

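/// A constant array collected from the kernel scope, along with its compile-time values.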
#[derive(Debug, Clone)]
pub struct ConstArray {
    pub id: Id,
    pub length: u32,
    pub item: Type,
    pub values: Vec<core::Variable>,
}

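/// The parsed program: the control flow graph of [`BasicBlock`]s plus the variable and
/// constant-array bookkeeping used during optimization.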
#[derive(Default, Debug, Clone)]
pub struct Program {
    pub const_arrays: Vec<ConstArray>,
    pub variables: HashMap<Id, Type>,
    pub graph: StableDiGraph<BasicBlock, u32>,
    root: NodeIndex,
}

impl Deref for Program {
    type Target = StableDiGraph<BasicBlock, u32>;

    fn deref(&self) -> &Self::Target {
        &self.graph
    }
}

impl DerefMut for Program {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.graph
    }
}

type VarId = (Id, u16);

/// An optimizer that applies various analyses and optimization passes to the IR.
#[derive(Debug, Clone)]
pub struct Optimizer {
    /// The overall program state
    pub program: Program,
    /// Allocator for the kernel
    pub allocator: Allocator,
    /// Analyses with persistent state
    analysis_cache: Rc<AnalysisCache>,
    /// The current block while parsing
    current_block: Option<NodeIndex>,
    /// The current loop's break target
    loop_break: VecDeque<NodeIndex>,
    /// The single return block
    pub ret: NodeIndex,
    /// Root scope to allocate variables on
    pub root_scope: Scope,
    /// The `CubeDim` used for range analysis
    pub(crate) cube_dim: CubeDim,
    pub(crate) transformers: Vec<Rc<dyn IrTransformer>>,
    pub(crate) processors: Rc<Vec<Box<dyn Processor>>>,
}

impl Default for Optimizer {
    fn default() -> Self {
        Self {
            program: Default::default(),
            allocator: Default::default(),
            current_block: Default::default(),
            loop_break: Default::default(),
            ret: Default::default(),
            root_scope: Scope::root(false),
            cube_dim: Default::default(),
            analysis_cache: Default::default(),
            transformers: Default::default(),
            processors: Default::default(),
        }
    }
}

impl Optimizer {
    /// Create a new optimizer from the root scope, `CubeDim`, IR transformers and processors
    /// passed in by the compiler. Parses the scope and runs several optimization and analysis loops.
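    ///
    /// A rough usage sketch (marked `ignore`; assumes `scope` is a fully expanded kernel scope):
    ///
    /// ```ignore
    /// let opt = Optimizer::new(scope, CubeDim::default(), vec![], vec![]);
    /// println!("{opt}"); // the optimized control flow graph implements `Display`
    /// ```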
    pub fn new(
        expand: Scope,
        cube_dim: CubeDim,
        transformers: Vec<Rc<dyn IrTransformer>>,
        processors: Vec<Box<dyn Processor>>,
    ) -> Self {
        let mut opt = Self {
            root_scope: expand.clone(),
            cube_dim,
            allocator: expand.allocator.clone(),
            transformers,
            processors: Rc::new(processors),
            ..Default::default()
        };
        opt.run_opt();

        opt
    }

    /// Create an optimizer that parses the scope but only runs the transformations and analyses
    /// needed for shared memory liveness, skipping the general optimization passes.
    pub fn shared_only(expand: Scope, cube_dim: CubeDim) -> Self {
        let mut opt = Self {
            root_scope: expand.clone(),
            cube_dim,
            allocator: expand.allocator.clone(),
            transformers: Vec::new(),
            processors: Rc::new(Vec::new()),
            ..Default::default()
        };
        opt.run_shared_only();

        opt
    }

    /// Run all optimizations
    fn run_opt(&mut self) {
        self.parse_graph(self.root_scope.clone());
        self.split_critical_edges();
        self.apply_pre_ssa_passes();
        self.exempt_index_assign_locals();
        self.ssa_transform();
        self.apply_post_ssa_passes();

        // Expensive passes that should only run once; they need extra optimization rounds in
        // between.

        let arrays_prop = AtomicCounter::new(0);
        CopyPropagateArray.apply_post_ssa(self, arrays_prop.clone());
        if arrays_prop.get() > 0 {
            self.invalidate_analysis::<Liveness>();
            self.ssa_transform();
            self.apply_post_ssa_passes();
        }

        let gvn_count = AtomicCounter::new(0);
        GvnPass.apply_post_ssa(self, gvn_count.clone());
        ReduceStrength.apply_post_ssa(self, gvn_count.clone());
        CopyTransform.apply_post_ssa(self, gvn_count.clone());

        if gvn_count.get() > 0 {
            self.apply_post_ssa_passes();
        }

        self.split_free();
        self.analysis::<SharedLiveness>();

        MergeBlocks.apply_post_ssa(self, AtomicCounter::new(0));
    }

    /// Run only the shared memory analysis
    fn run_shared_only(&mut self) {
        self.parse_graph(self.root_scope.clone());
        self.split_critical_edges();
        self.exempt_index_assign_locals();
        self.ssa_transform();
        self.split_free();
        self.analysis::<SharedLiveness>();
    }

    /// The entry block of the program
    pub fn entry(&self) -> NodeIndex {
        self.program.root
    }

    fn parse_graph(&mut self, scope: Scope) {
        let entry = self.program.add_node(BasicBlock::default());
        self.program.root = entry;
        self.current_block = Some(entry);
        self.ret = self.program.add_node(BasicBlock::default());
        *self.program[self.ret].control_flow.borrow_mut() = ControlFlow::Return;
        self.parse_scope(scope);
        if let Some(current_block) = self.current_block {
            self.program.add_edge(current_block, self.ret, 0);
        }
        // Analyses shouldn't have run at this point, but just in case they have, invalidate
        // all analyses that depend on the graph
        self.invalidate_structure();
    }

    fn apply_pre_ssa_passes(&mut self) {
        // Currently only one pre-SSA pass, but more might be added
        let mut passes = vec![CompositeMerge];
        loop {
            let counter = AtomicCounter::default();

            for pass in &mut passes {
                pass.apply_pre_ssa(self, counter.clone());
            }

            if counter.get() == 0 {
                break;
            }
        }
    }

    fn apply_post_ssa_passes(&mut self) {
        // Passes that run regardless of execution mode
        let mut passes: Vec<Box<dyn OptimizerPass>> = vec![
            Box::new(InlineAssignments),
            Box::new(EliminateUnusedVariables),
            Box::new(ConstOperandSimplify),
            Box::new(MergeSameExpressions),
            Box::new(ConstEval),
            Box::new(RemoveIndexScalar),
            Box::new(EliminateConstBranches),
            Box::new(EmptyBranchToSelect),
            Box::new(EliminateDeadBlocks),
            Box::new(EliminateDeadPhi),
        ];

        loop {
            let counter = AtomicCounter::default();
            for pass in &mut passes {
                pass.apply_post_ssa(self, counter.clone());
            }

            if counter.get() == 0 {
                break;
            }
        }
    }

    /// Remove non-constant index vectors from SSA transformation because they currently must be
    /// mutated.
    fn exempt_index_assign_locals(&mut self) {
        for node in self.node_ids() {
            let ops = self.program[node].ops.clone();
            for op in ops.borrow().values() {
                if let Operation::Operator(Operator::IndexAssign(_)) = &op.operation
                    && let VariableKind::LocalMut { id } = &op.out().kind
                {
                    self.program.variables.remove(id);
                }
            }
        }
    }

    /// A set of node indices for all blocks in the program
    pub fn node_ids(&self) -> Vec<NodeIndex> {
        self.program.node_indices().collect()
    }

    fn ssa_transform(&mut self) {
        self.place_phi_nodes();
        self.version_program();
        self.program.variables.clear();
        self.invalidate_analysis::<Writes>();
        self.invalidate_analysis::<DomFrontiers>();
    }

    /// Mutable reference to the current basic block
    pub(crate) fn current_block_mut(&mut self) -> &mut BasicBlock {
        &mut self.program[self.current_block.unwrap()]
    }

    /// List of predecessor IDs of the `block`
    pub fn predecessors(&self, block: NodeIndex) -> Vec<NodeIndex> {
        self.program
            .edges_directed(block, Direction::Incoming)
            .map(|it| it.source())
            .collect()
    }

    /// List of successor IDs of the `block`
    pub fn successors(&self, block: NodeIndex) -> Vec<NodeIndex> {
        self.program
            .edges_directed(block, Direction::Outgoing)
            .map(|it| it.target())
            .collect()
    }

    /// Reference to the [`BasicBlock`] with ID `block`
    #[track_caller]
    pub fn block(&self, block: NodeIndex) -> &BasicBlock {
        &self.program[block]
    }

    /// Mutable reference to the [`BasicBlock`] with ID `block`
    #[track_caller]
    pub fn block_mut(&mut self, block: NodeIndex) -> &mut BasicBlock {
        &mut self.program[block]
    }

    /// Recursively parse a scope into the graph. Returns whether the scope contains a `Break`
    /// branch.
    pub fn parse_scope(&mut self, mut scope: Scope) -> bool {
        let processed = scope.process(self.processors.iter().map(|it| &**it));

        for var in processed.variables {
            if let VariableKind::LocalMut { id } = var.kind {
                self.program.variables.insert(id, var.ty);
            }
        }

        for (var, values) in scope.const_arrays.clone() {
            let VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } = var.kind
            else {
                unreachable!()
            };
            self.program.const_arrays.push(ConstArray {
                id,
                length: length * unroll_factor,
                item: var.ty,
                values,
            });
        }

        let is_break = processed.instructions.contains(&Branch::Break.into());

        for mut instruction in processed.instructions {
            let mut removed = false;
            for transform in self.transformers.iter() {
                match transform.maybe_transform(&mut scope, &instruction) {
                    TransformAction::Ignore => {}
                    TransformAction::Replace(replacement) => {
                        self.current_block_mut()
                            .ops
                            .borrow_mut()
                            .extend(replacement);
                        removed = true;
                        break;
                    }
                    TransformAction::Remove => {
                        removed = true;
                        break;
                    }
                }
            }
            if removed {
                continue;
            }
            match &mut instruction.operation {
                Operation::Branch(branch) => self.parse_control_flow(branch.clone()),
                _ => {
                    self.current_block_mut().ops.borrow_mut().push(instruction);
                }
            }
        }

        is_break
    }

    /// Gets the `id` of the variable if it's a mutable `Local` and not atomic, `None` otherwise.
    pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<Id> {
        match variable.kind {
            core::VariableKind::LocalMut { id } if !variable.ty.is_atomic() => Some(id),
            _ => None,
        }
    }

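    /// Returns the current return block, creating a fresh one first if the existing return block
    /// is already used as a merge block.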
    pub(crate) fn ret(&mut self) -> NodeIndex {
        if self.program[self.ret].block_use.contains(&BlockUse::Merge) {
            let new_ret = self.program.add_node(BasicBlock::default());
            self.program.add_edge(new_ret, self.ret, 0);
            self.ret = new_ret;
            self.invalidate_structure();
            new_ret
        } else {
            self.ret
        }
    }

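    /// The constant arrays collected while parsing the kernel scopes.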
    pub fn const_arrays(&self) -> Vec<ConstArray> {
        self.program.const_arrays.clone()
    }

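    /// Renders the program graph as a Graphviz [`Dot`] document with edge labels omitted.
    /// The returned value implements `Display`, so it can be printed or written to a `.dot` file.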
    pub fn dot_viz(&self) -> Dot<'_, &StableDiGraph<BasicBlock, u32>> {
        Dot::with_config(&self.program, &[Config::EdgeNoLabel])
    }
}

/// A visitor that does nothing.
pub fn visit_noop(_opt: &mut Optimizer, _var: &mut Variable) {}

#[cfg(test)]
mod test {
    use cubecl_core as cubecl;
    use cubecl_core::cube;
    use cubecl_core::prelude::*;
    use cubecl_ir::{ElemType, ExpandElement, Type, UIntKind, Variable, VariableKind};

    use crate::Optimizer;

    #[allow(unused)]
    #[cube(launch)]
    fn pre_kernel(x: u32, cond: u32, out: &mut Array<u32>) {
        let mut y = 0;
        let mut z = 0;
        if cond == 0 {
            y = x + 4;
        }
        z = x + 4;
        out[0] = y;
        out[1] = z;
    }

    #[test]
    #[ignore = "no good way to assert opt is applied"]
    fn test_pre() {
        let mut ctx = Scope::root(false);
        let x = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(0),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));
        let cond = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(1),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));
        let arr = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalOutputArray(0),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));

        pre_kernel::expand(&mut ctx, x.into(), cond.into(), arr.into());
        let opt = Optimizer::new(ctx, CubeDim::default(), vec![], vec![]);
        println!("{opt}")
    }
}