cubecl_opt/
lib.rs

1//! # CubeCL Optimizer
2//!
3//! A library that parses CubeCL IR into a
4//! [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph), transforms it to
5//! [static single-assignment form](https://en.wikipedia.org/wiki/Static_single-assignment_form)
6//! and runs various optimizations on it.
7//! The order of operations is as follows:
8//!
9//! 1. Parse root scope recursively into a [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph)
10//! 2. Run optimizations that must be done before SSA transformation
11//! 3. Analyze variable liveness
12//! 4. Transform the graph to [pruned SSA](https://en.wikipedia.org/wiki/Static_single-assignment_form#Pruned_SSA) form
13//! 5. Run post-SSA optimizations and analyses in a loop until no more improvements are found
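//! 6. Run the expensive passes that only execute once (`CopyPropagateArray`, `GvnPass`,
//!    `ReduceStrength`, `CopyTransform`), repeating the post-SSA loop if they made changes
//! 7. Compute shared memory liveness and merge blocks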
15//!
16//! The output is represented as a [`petgraph`] graph of [`BasicBlock`]s terminated by [`ControlFlow`].
17//! This can then be compiled into actual executable code by walking the graph and generating all
18//! phi nodes, instructions and branches.
19//!
20//! # Representing [`PhiInstruction`] in non-SSA languages
21//!
22//! Phi instructions can be simulated by generating a mutable variable for each phi, then assigning
23//! `value` to it in each relevant `block`.
24//!
25
26#![allow(unknown_lints, unnecessary_transmutes)]
27
28use std::{
29    collections::{HashMap, VecDeque},
30    ops::{Deref, DerefMut},
31    rc::Rc,
32    sync::atomic::{AtomicUsize, Ordering},
33};
34
35use analyses::{AnalysisCache, dominance::DomFrontiers, liveness::Liveness, writes::Writes};
36use cubecl_common::CubeDim;
37use cubecl_ir::{
38    self as core, Allocator, Branch, Id, Operation, Operator, Processor, Scope, Type, Variable,
39    VariableKind,
40};
41use gvn::GvnPass;
42use passes::{
43    CompositeMerge, ConstEval, ConstOperandSimplify, CopyPropagateArray, CopyTransform,
44    EliminateConstBranches, EliminateDeadBlocks, EliminateDeadPhi, EliminateUnusedVariables,
45    EmptyBranchToSelect, InlineAssignments, MergeBlocks, MergeSameExpressions, OptimizerPass,
46    ReduceStrength, RemoveIndexScalar,
47};
48use petgraph::{
49    Direction,
50    dot::{Config, Dot},
51    prelude::StableDiGraph,
52    visit::EdgeRef,
53};
54
55mod analyses;
56mod block;
57mod control_flow;
58mod debug;
59mod gvn;
60mod instructions;
61mod passes;
62mod phi_frontiers;
63mod transformers;
64mod version;
65
66pub use analyses::uniformity::Uniformity;
67pub use block::*;
68pub use control_flow::*;
69pub use petgraph::graph::{EdgeIndex, NodeIndex};
70pub use transformers::*;
71pub use version::PhiInstruction;
72
73pub use crate::analyses::liveness::shared::SharedLiveness;
74
/// A shared, reference-counted counter with a simplified interface.
///
/// Clones share the same underlying count, so a copy can be handed to a pass and the owner can
/// read the total afterwards. The default counter starts at zero.
#[derive(Clone, Debug, Default)]
pub struct AtomicCounter {
    inner: Rc<AtomicUsize>,
}

impl AtomicCounter {
    /// Returns a counter whose count starts at `val`.
    pub fn new(val: usize) -> Self {
        Self {
            inner: AtomicUsize::new(val).into(),
        }
    }

    /// Adds one to the count and returns the value it held *before* the increment.
    pub fn inc(&self) -> usize {
        let counter: &AtomicUsize = &self.inner;
        counter.fetch_add(1, Ordering::AcqRel)
    }

    /// Reads the current count without modifying it.
    pub fn get(&self) -> usize {
        AtomicUsize::load(&self.inner, Ordering::Acquire)
    }
}
99
100#[derive(Debug, Clone)]
101pub struct ConstArray {
102    pub id: Id,
103    pub length: u32,
104    pub item: Type,
105    pub values: Vec<core::Variable>,
106}
107
108#[derive(Default, Debug, Clone)]
109pub struct Program {
110    pub const_arrays: Vec<ConstArray>,
111    pub variables: HashMap<Id, Type>,
112    pub graph: StableDiGraph<BasicBlock, u32>,
113    root: NodeIndex,
114}
115
116impl Deref for Program {
117    type Target = StableDiGraph<BasicBlock, u32>;
118
119    fn deref(&self) -> &Self::Target {
120        &self.graph
121    }
122}
123
124impl DerefMut for Program {
125    fn deref_mut(&mut self) -> &mut Self::Target {
126        &mut self.graph
127    }
128}
129
/// NOTE(review): presumably a variable `Id` paired with its SSA version number — confirm
/// against the `version` module.
type VarId = (Id, u16);
131
/// An optimizer that applies various analyses and optimization passes to the IR.
///
/// Construct it with [`Optimizer::new`] (full pipeline) or [`Optimizer::shared_only`]
/// (shared memory analysis only); both parse the root scope into `program` immediately.
#[derive(Debug, Clone)]
pub struct Optimizer {
    /// The overall program state
    pub program: Program,
    /// Allocator for kernel (a clone of the root scope's allocator when built via the constructors)
    pub allocator: Allocator,
    /// Analyses with persistent state
    analysis_cache: Rc<AnalysisCache>,
    /// The current block while parsing; `current_block_mut` panics when this is `None`
    current_block: Option<NodeIndex>,
    /// The current loop's break target
    // NOTE(review): presumably a stack of break targets for nested loops — confirm in
    // `control_flow`.
    loop_break: VecDeque<NodeIndex>,
    /// The single return block; `ret()` may replace it with a new block that jumps to the old one
    pub ret: NodeIndex,
    /// Root scope to allocate variables on
    pub root_scope: Scope,
    /// The `CubeDim` used for range analysis
    pub(crate) cube_dim: CubeDim,
    /// IR transformers consulted for every instruction while parsing
    pub(crate) transformers: Vec<Rc<dyn IrTransformer>>,
    /// Processors passed to `Scope::process` for every parsed scope
    pub(crate) processors: Rc<Vec<Box<dyn Processor>>>,
}
154
155impl Default for Optimizer {
156    fn default() -> Self {
157        Self {
158            program: Default::default(),
159            allocator: Default::default(),
160            current_block: Default::default(),
161            loop_break: Default::default(),
162            ret: Default::default(),
163            root_scope: Scope::root(false),
164            cube_dim: Default::default(),
165            analysis_cache: Default::default(),
166            transformers: Default::default(),
167            processors: Default::default(),
168        }
169    }
170}
171
172impl Optimizer {
173    /// Create a new optimizer with the scope, `CubeDim` and execution mode passed into the compiler.
174    /// Parses the scope and runs several optimization and analysis loops.
175    pub fn new(
176        expand: Scope,
177        cube_dim: CubeDim,
178        transformers: Vec<Rc<dyn IrTransformer>>,
179        processors: Vec<Box<dyn Processor>>,
180    ) -> Self {
181        let mut opt = Self {
182            root_scope: expand.clone(),
183            cube_dim,
184            allocator: expand.allocator.clone(),
185            transformers,
186            processors: Rc::new(processors),
187            ..Default::default()
188        };
189        opt.run_opt();
190
191        opt
192    }
193
194    /// Create a new optimizer with the scope, `CubeDim` and execution mode passed into the compiler.
195    /// Parses the scope and runs several optimization and analysis loops.
196    pub fn shared_only(expand: Scope, cube_dim: CubeDim) -> Self {
197        let mut opt = Self {
198            root_scope: expand.clone(),
199            cube_dim,
200            allocator: expand.allocator.clone(),
201            transformers: Vec::new(),
202            processors: Rc::new(Vec::new()),
203            ..Default::default()
204        };
205        opt.run_shared_only();
206
207        opt
208    }
209
210    /// Run all optimizations
211    fn run_opt(&mut self) {
212        self.parse_graph(self.root_scope.clone());
213        self.split_critical_edges();
214        self.transform_ssa_and_merge_composites();
215        self.apply_post_ssa_passes();
216
217        // Special expensive passes that should only run once.
218        // Need more optimization rounds in between.
219
220        let arrays_prop = AtomicCounter::new(0);
221        CopyPropagateArray.apply_post_ssa(self, arrays_prop.clone());
222        if arrays_prop.get() > 0 {
223            self.invalidate_analysis::<Liveness>();
224            self.ssa_transform();
225            self.apply_post_ssa_passes();
226        }
227
228        let gvn_count = AtomicCounter::new(0);
229        GvnPass.apply_post_ssa(self, gvn_count.clone());
230        ReduceStrength.apply_post_ssa(self, gvn_count.clone());
231        CopyTransform.apply_post_ssa(self, gvn_count.clone());
232
233        if gvn_count.get() > 0 {
234            self.apply_post_ssa_passes();
235        }
236
237        self.split_free();
238        self.analysis::<SharedLiveness>();
239
240        MergeBlocks.apply_post_ssa(self, AtomicCounter::new(0));
241    }
242
243    /// Run only the shared memory analysis
244    fn run_shared_only(&mut self) {
245        self.parse_graph(self.root_scope.clone());
246        self.split_critical_edges();
247        self.transform_ssa_and_merge_composites();
248        self.split_free();
249        self.analysis::<SharedLiveness>();
250    }
251
252    /// The entry block of the program
253    pub fn entry(&self) -> NodeIndex {
254        self.program.root
255    }
256
257    fn parse_graph(&mut self, scope: Scope) {
258        let entry = self.program.add_node(BasicBlock::default());
259        self.program.root = entry;
260        self.current_block = Some(entry);
261        self.ret = self.program.add_node(BasicBlock::default());
262        *self.program[self.ret].control_flow.borrow_mut() = ControlFlow::Return;
263        self.parse_scope(scope);
264        if let Some(current_block) = self.current_block {
265            self.program.add_edge(current_block, self.ret, 0);
266        }
267        // Analyses shouldn't have run at this point, but just in case they have, invalidate
268        // all analyses that depend on the graph
269        self.invalidate_structure();
270    }
271
272    fn apply_post_ssa_passes(&mut self) {
273        // Passes that run regardless of execution mode
274        let mut passes: Vec<Box<dyn OptimizerPass>> = vec![
275            Box::new(InlineAssignments),
276            Box::new(EliminateUnusedVariables),
277            Box::new(ConstOperandSimplify),
278            Box::new(MergeSameExpressions),
279            Box::new(ConstEval),
280            Box::new(RemoveIndexScalar),
281            Box::new(EliminateConstBranches),
282            Box::new(EmptyBranchToSelect),
283            Box::new(EliminateDeadBlocks),
284            Box::new(EliminateDeadPhi),
285        ];
286
287        loop {
288            let counter = AtomicCounter::default();
289            for pass in &mut passes {
290                pass.apply_post_ssa(self, counter.clone());
291            }
292
293            if counter.get() == 0 {
294                break;
295            }
296        }
297    }
298
299    /// Remove non-constant index vectors from SSA transformation because they currently must be
300    /// mutated
301    fn exempt_index_assign_locals(&mut self) {
302        for node in self.node_ids() {
303            let ops = self.program[node].ops.clone();
304            for op in ops.borrow().values() {
305                if let Operation::Operator(Operator::IndexAssign(_)) = &op.operation
306                    && let VariableKind::LocalMut { id } = &op.out().kind
307                {
308                    self.program.variables.remove(id);
309                }
310            }
311        }
312    }
313
314    /// A set of node indices for all blocks in the program
315    pub fn node_ids(&self) -> Vec<NodeIndex> {
316        self.program.node_indices().collect()
317    }
318
319    fn transform_ssa_and_merge_composites(&mut self) {
320        self.exempt_index_assign_locals();
321        self.ssa_transform();
322
323        let mut done = false;
324        while !done {
325            let changes = AtomicCounter::new(0);
326            CompositeMerge.apply_post_ssa(self, changes.clone());
327            if changes.get() > 0 {
328                self.exempt_index_assign_locals();
329                self.ssa_transform();
330            } else {
331                done = true;
332            }
333        }
334    }
335
336    fn ssa_transform(&mut self) {
337        self.place_phi_nodes();
338        self.version_program();
339        self.program.variables.clear();
340        self.invalidate_analysis::<Writes>();
341        self.invalidate_analysis::<DomFrontiers>();
342    }
343
344    /// Mutable reference to the current basic block
345    pub(crate) fn current_block_mut(&mut self) -> &mut BasicBlock {
346        &mut self.program[self.current_block.unwrap()]
347    }
348
349    /// List of predecessor IDs of the `block`
350    pub fn predecessors(&self, block: NodeIndex) -> Vec<NodeIndex> {
351        self.program
352            .edges_directed(block, Direction::Incoming)
353            .map(|it| it.source())
354            .collect()
355    }
356
357    /// List of successor IDs of the `block`
358    pub fn successors(&self, block: NodeIndex) -> Vec<NodeIndex> {
359        self.program
360            .edges_directed(block, Direction::Outgoing)
361            .map(|it| it.target())
362            .collect()
363    }
364
365    /// Reference to the [`BasicBlock`] with ID `block`
366    #[track_caller]
367    pub fn block(&self, block: NodeIndex) -> &BasicBlock {
368        &self.program[block]
369    }
370
371    /// Reference to the [`BasicBlock`] with ID `block`
372    #[track_caller]
373    pub fn block_mut(&mut self, block: NodeIndex) -> &mut BasicBlock {
374        &mut self.program[block]
375    }
376
377    /// Recursively parse a scope into the graph
378    pub fn parse_scope(&mut self, mut scope: Scope) -> bool {
379        let processed = scope.process(self.processors.iter().map(|it| &**it));
380
381        for var in processed.variables {
382            if let VariableKind::LocalMut { id } = var.kind {
383                self.program.variables.insert(id, var.ty);
384            }
385        }
386
387        for (var, values) in scope.const_arrays.clone() {
388            let VariableKind::ConstantArray {
389                id,
390                length,
391                unroll_factor,
392            } = var.kind
393            else {
394                unreachable!()
395            };
396            self.program.const_arrays.push(ConstArray {
397                id,
398                length: length * unroll_factor,
399                item: var.ty,
400                values,
401            });
402        }
403
404        let is_break = processed.instructions.contains(&Branch::Break.into());
405
406        for mut instruction in processed.instructions {
407            let mut removed = false;
408            for transform in self.transformers.iter() {
409                match transform.maybe_transform(&mut scope, &instruction) {
410                    TransformAction::Ignore => {}
411                    TransformAction::Replace(replacement) => {
412                        self.current_block_mut()
413                            .ops
414                            .borrow_mut()
415                            .extend(replacement);
416                        removed = true;
417                        break;
418                    }
419                    TransformAction::Remove => {
420                        removed = true;
421                        break;
422                    }
423                }
424            }
425            if removed {
426                continue;
427            }
428            match &mut instruction.operation {
429                Operation::Branch(branch) => self.parse_control_flow(branch.clone()),
430                _ => {
431                    self.current_block_mut().ops.borrow_mut().push(instruction);
432                }
433            }
434        }
435
436        is_break
437    }
438
439    /// Gets the `id` and `depth` of the variable if it's a `Local` and not atomic, `None` otherwise.
440    pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<Id> {
441        match variable.kind {
442            core::VariableKind::LocalMut { id } if !variable.ty.is_atomic() => Some(id),
443            _ => None,
444        }
445    }
446
447    pub(crate) fn ret(&mut self) -> NodeIndex {
448        if self.program[self.ret].block_use.contains(&BlockUse::Merge) {
449            let new_ret = self.program.add_node(BasicBlock::default());
450            self.program.add_edge(new_ret, self.ret, 0);
451            self.ret = new_ret;
452            self.invalidate_structure();
453            new_ret
454        } else {
455            self.ret
456        }
457    }
458
459    pub fn const_arrays(&self) -> Vec<ConstArray> {
460        self.program.const_arrays.clone()
461    }
462
463    pub fn dot_viz(&self) -> Dot<'_, &StableDiGraph<BasicBlock, u32>> {
464        Dot::with_config(&self.program, &[Config::EdgeNoLabel])
465    }
466}
467
468/// A visitor that does nothing.
469pub fn visit_noop(_opt: &mut Optimizer, _var: &mut Variable) {}
470
// Tests that run the optimizer pipeline on small expanded kernels.
#[cfg(test)]
mod test {
    use cubecl_core as cubecl;
    use cubecl_core::cube;
    use cubecl_core::prelude::*;
    use cubecl_ir::{ElemType, ExpandElement, Type, UIntKind, Variable, VariableKind};

    use crate::Optimizer;

    // Kernel computing `x + 4` on one branch (into `y`) and then again unconditionally
    // (into `z`), so the second computation is redundant along the `cond == 0` path.
    // NOTE(review): the name suggests this targets partial redundancy elimination (PRE) —
    // confirm which pass it is meant to cover.
    #[allow(unused)]
    #[cube(launch)]
    fn pre_kernel(x: u32, cond: u32, out: &mut Array<u32>) {
        let mut y = 0;
        // This initial value is overwritten below before `z` is ever read.
        let mut z = 0;
        if cond == 0 {
            y = x + 4;
        }
        z = x + 4;
        out[0] = y;
        out[1] = z;
    }

    // Expands `pre_kernel` with two scalar `u32` globals and a `u32` output array, runs the
    // full optimizer with the default `CubeDim` and no transformers/processors, and prints
    // the result. Nothing is asserted, so the test is ignored; run it manually and inspect
    // the printed program.
    #[test]
    #[ignore = "no good way to assert opt is applied"]
    fn test_pre() {
        let mut ctx = Scope::root(false);
        let x = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(0),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));
        let cond = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(1),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));
        let arr = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalOutputArray(0),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));

        pre_kernel::expand(&mut ctx, x.into(), cond.into(), arr.into());
        let opt = Optimizer::new(ctx, CubeDim::default(), vec![], vec![]);
        println!("{opt}")
    }
}