// cubecl_opt/lib.rs

1//! # `CubeCL` Optimizer
2//!
3//! A library that parses `CubeCL` IR into a
4//! [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph), transforms it to
5//! [static single-assignment form](https://en.wikipedia.org/wiki/Static_single-assignment_form)
6//! and runs various optimizations on it.
7//! The order of operations is as follows:
8//!
9//! 1. Parse root scope recursively into a [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph)
10//! 2. Run optimizations that must be done before SSA transformation
11//! 3. Analyze variable liveness
12//! 4. Transform the graph to [pruned SSA](https://en.wikipedia.org/wiki/Static_single-assignment_form#Pruned_SSA) form
13//! 5. Run post-SSA optimizations and analyses in a loop until no more improvements are found
14//! 6. Speed
15//!
16//! The output is represented as a [`petgraph`] graph of [`BasicBlock`]s terminated by [`ControlFlow`].
17//! This can then be compiled into actual executable code by walking the graph and generating all
18//! phi nodes, instructions and branches.
19//!
20//! # Representing [`PhiInstruction`] in non-SSA languages
21//!
22//! Phi instructions can be simulated by generating a mutable variable for each phi, then assigning
23//! `value` to it in each relevant `block`.
24//!
25
26#![allow(unknown_lints, unnecessary_transmutes)]
27
28use std::{
29    collections::{HashMap, VecDeque},
30    ops::{Deref, DerefMut},
31    rc::Rc,
32    sync::atomic::{AtomicUsize, Ordering},
33};
34
35use analyses::{AnalysisCache, dominance::DomFrontiers, liveness::Liveness, writes::Writes};
36use cubecl_core::CubeDim;
37use cubecl_ir::{
38    self as core, Allocator, Branch, Id, Operation, Operator, Processor, Scope, Type, Variable,
39    VariableKind,
40};
41use gvn::GvnPass;
42use passes::{
43    CompositeMerge, ConstEval, ConstOperandSimplify, CopyTransform, DisaggregateArray,
44    EliminateConstBranches, EliminateDeadBlocks, EliminateDeadPhi, EliminateUnusedVariables,
45    EmptyBranchToSelect, InlineAssignments, MergeBlocks, MergeSameExpressions, OptimizerPass,
46    ReduceStrength, RemoveIndexScalar,
47};
48use petgraph::{
49    Direction,
50    dot::{Config, Dot},
51    prelude::StableDiGraph,
52    visit::EdgeRef,
53};
54
55mod analyses;
56mod block;
57mod control_flow;
58mod debug;
59mod gvn;
60mod instructions;
61mod passes;
62mod phi_frontiers;
63mod transformers;
64mod version;
65
66pub use analyses::uniformity::Uniformity;
67pub use block::*;
68pub use control_flow::*;
69pub use petgraph::graph::{EdgeIndex, NodeIndex};
70pub use transformers::*;
71pub use version::PhiInstruction;
72
73pub use crate::analyses::liveness::shared::{SharedLiveness, SharedMemory};
74
/// An atomic counter with a simplified interface.
///
/// Clones are shallow: every clone shares the same underlying count, so a pass
/// and its caller can both observe increments through their own handle.
#[derive(Clone, Debug, Default)]
pub struct AtomicCounter {
    inner: Rc<AtomicUsize>,
}

impl AtomicCounter {
    /// Creates a new counter with `val` as its initial value.
    pub fn new(val: usize) -> Self {
        let inner = Rc::new(AtomicUsize::new(val));
        Self { inner }
    }

    /// Gets the value of the counter without incrementing it.
    pub fn get(&self) -> usize {
        self.inner.load(Ordering::Acquire)
    }

    /// Increments the counter and returns the last count.
    pub fn inc(&self) -> usize {
        self.inner.fetch_add(1, Ordering::AcqRel)
    }
}
99
/// A constant array collected while parsing a scope (see `Optimizer::parse_scope`).
#[derive(Debug, Clone)]
pub struct ConstArray {
    /// Id of the backing `ConstantArray` variable.
    pub id: Id,
    /// Number of elements; already multiplied by the unroll factor at parse time.
    pub length: usize,
    /// Element type of the array.
    pub item: Type,
    /// The constant values stored in the array.
    pub values: Vec<core::Variable>,
}
107
/// The program under optimization: the control flow graph plus parse-time state.
#[derive(Default, Debug, Clone)]
pub struct Program {
    /// Constant arrays collected from parsed scopes.
    pub const_arrays: Vec<ConstArray>,
    /// Types of mutable locals; entries in this map are subject to SSA transformation
    /// (see `exempt_index_assign_locals`), and it is cleared after `ssa_transform`.
    pub variables: HashMap<Id, Type>,
    /// The control flow graph of [`BasicBlock`]s.
    pub graph: StableDiGraph<BasicBlock, u32>,
    // Entry node of the graph; exposed read-only via `Optimizer::entry`.
    root: NodeIndex,
}
115
// Convenience: lets graph methods (`node_indices`, `add_node`, ...) be called
// directly on `Program`.
impl Deref for Program {
    type Target = StableDiGraph<BasicBlock, u32>;

    fn deref(&self) -> &Self::Target {
        &self.graph
    }
}

impl DerefMut for Program {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.graph
    }
}
129
/// Key identifying a variable: its `Id` plus a `u16` discriminator
/// (presumably the SSA version — see the `version` module; confirm there).
type VarId = (Id, u16);
131
/// An optimizer that applies various analyses and optimization passes to the IR.
#[derive(Debug, Clone)]
pub struct Optimizer {
    /// The overall program state
    pub program: Program,
    /// Allocator for kernel
    pub allocator: Allocator,
    /// Analyses with persistent state
    analysis_cache: Rc<AnalysisCache>,
    /// The current block while parsing
    current_block: Option<NodeIndex>,
    /// The current loop's break target
    loop_break: VecDeque<NodeIndex>,
    /// The single return block
    pub ret: NodeIndex,
    /// Root scope to allocate variables on
    pub root_scope: Scope,
    /// The `CubeDim` used for range analysis
    pub(crate) cube_dim: CubeDim,
    /// IR transformers consulted for every instruction while parsing scopes.
    pub(crate) transformers: Vec<Rc<dyn IrTransformer>>,
    /// Scope processors run via `Scope::process` before a scope is parsed.
    pub(crate) processors: Rc<Vec<Box<dyn Processor>>>,
}
154
// Needed for WGPU server
// SAFETY: `Optimizer` contains `Rc`s, whose reference counts are NOT atomic, so it is
// not automatically `Send`/`Sync`. These impls assert that an `Optimizer` is never
// accessed (or its `Rc`s cloned/dropped) from more than one thread at a time.
// NOTE(review): this invariant must be upheld by the WGPU server — confirm there.
unsafe impl Send for Optimizer {}
unsafe impl Sync for Optimizer {}
158
impl Default for Optimizer {
    // Empty program/graph, a fresh root scope, and a 1x1x1 cube dim.
    fn default() -> Self {
        Self {
            program: Default::default(),
            allocator: Default::default(),
            current_block: Default::default(),
            loop_break: Default::default(),
            ret: Default::default(),
            root_scope: Scope::root(false),
            cube_dim: CubeDim::new_1d(1),
            analysis_cache: Default::default(),
            transformers: Default::default(),
            processors: Default::default(),
        }
    }
}
175
176impl Optimizer {
177    /// Create a new optimizer with the scope, `CubeDim` and execution mode passed into the compiler.
178    /// Parses the scope and runs several optimization and analysis loops.
179    pub fn new(
180        expand: Scope,
181        cube_dim: CubeDim,
182        transformers: Vec<Rc<dyn IrTransformer>>,
183        processors: Vec<Box<dyn Processor>>,
184    ) -> Self {
185        let mut opt = Self {
186            root_scope: expand.clone(),
187            cube_dim,
188            allocator: expand.allocator.clone(),
189            transformers,
190            processors: Rc::new(processors),
191            ..Default::default()
192        };
193        opt.run_opt();
194
195        opt
196    }
197
    /// Create a new optimizer that runs only the passes needed for shared memory
    /// liveness analysis (parse, critical-edge split, SSA transform, [`SharedLiveness`]),
    /// skipping the full optimization pipeline. No transformers or processors are used.
    pub fn shared_only(expand: Scope, cube_dim: CubeDim) -> Self {
        let mut opt = Self {
            root_scope: expand.clone(),
            cube_dim,
            allocator: expand.allocator.clone(),
            transformers: Vec::new(),
            processors: Rc::new(Vec::new()),
            ..Default::default()
        };
        opt.run_shared_only();

        opt
    }
213
    /// Run all optimizations
    ///
    /// Cheap passes run to a fixed point first; the expensive one-shot passes
    /// (array disaggregation, GVN and friends) then run once each, followed by
    /// another cheap fixed-point round whenever they changed anything.
    fn run_opt(&mut self) {
        self.parse_graph(self.root_scope.clone());
        self.split_critical_edges();
        self.transform_ssa_and_merge_composites();
        self.apply_post_ssa_passes();

        // Special expensive passes that should only run once.
        // Need more optimization rounds in between.

        let arrays_prop = AtomicCounter::new(0);
        log::debug!("Applying {}", DisaggregateArray.name());
        DisaggregateArray.apply_post_ssa(self, arrays_prop.clone());
        // Disaggregation introduces new variables, so liveness is stale and the
        // program must be re-converted to SSA before optimizing again.
        if arrays_prop.get() > 0 {
            self.invalidate_analysis::<Liveness>();
            self.ssa_transform();
            self.apply_post_ssa_passes();
        }

        // GVN, strength reduction and copy transform share one change counter;
        // a single extra cheap round runs if any of them changed something.
        let gvn_count = AtomicCounter::new(0);
        log::debug!("Applying {}", GvnPass.name());
        GvnPass.apply_post_ssa(self, gvn_count.clone());
        log::debug!("Applying {}", ReduceStrength.name());
        ReduceStrength.apply_post_ssa(self, gvn_count.clone());
        log::debug!("Applying {}", CopyTransform.name());
        CopyTransform.apply_post_ssa(self, gvn_count.clone());

        if gvn_count.get() > 0 {
            self.apply_post_ssa_passes();
        }

        self.split_free();
        self.analysis::<SharedLiveness>();

        // Final block merge; its change count is discarded.
        log::debug!("Applying {}", MergeBlocks.name());
        MergeBlocks.apply_post_ssa(self, AtomicCounter::new(0));
    }
251
    /// Run only the shared memory analysis: parse, split critical edges,
    /// SSA-transform, then compute [`SharedLiveness`]. No optimization passes.
    fn run_shared_only(&mut self) {
        self.parse_graph(self.root_scope.clone());
        self.split_critical_edges();
        self.transform_ssa_and_merge_composites();
        self.split_free();
        self.analysis::<SharedLiveness>();
    }
260
    /// The entry block of the program (the root node created by `parse_graph`).
    pub fn entry(&self) -> NodeIndex {
        self.program.root
    }
265
    /// Parse the root scope into the control flow graph: create the entry block and
    /// the single return block, parse the scope, and connect any still-open block to
    /// the return.
    fn parse_graph(&mut self, scope: Scope) {
        let entry = self.program.add_node(BasicBlock::default());
        self.program.root = entry;
        self.current_block = Some(entry);
        self.ret = self.program.add_node(BasicBlock::default());
        *self.program[self.ret].control_flow.borrow_mut() = ControlFlow::Return;
        self.parse_scope(scope);
        // If parsing left a block open (no terminator), fall through to the return block.
        if let Some(current_block) = self.current_block {
            self.program.add_edge(current_block, self.ret, 0);
        }
        // Analyses shouldn't have run at this point, but just in case they have, invalidate
        // all analyses that depend on the graph
        self.invalidate_structure();
    }
280
281    fn apply_post_ssa_passes(&mut self) {
282        // Passes that run regardless of execution mode
283        let mut passes: Vec<Box<dyn OptimizerPass>> = vec![
284            Box::new(InlineAssignments),
285            Box::new(EliminateUnusedVariables),
286            Box::new(ConstOperandSimplify),
287            Box::new(MergeSameExpressions),
288            Box::new(ConstEval),
289            Box::new(RemoveIndexScalar),
290            Box::new(EliminateConstBranches),
291            Box::new(EmptyBranchToSelect),
292            Box::new(EliminateDeadBlocks),
293            Box::new(EliminateDeadPhi),
294        ];
295
296        log::debug!("Applying post-SSA passes");
297        loop {
298            let counter = AtomicCounter::default();
299            for pass in &mut passes {
300                log::debug!("Applying {}", pass.name());
301                pass.apply_post_ssa(self, counter.clone());
302            }
303
304            if counter.get() == 0 {
305                break;
306            }
307        }
308    }
309
310    /// Remove non-constant index vectors from SSA transformation because they currently must be
311    /// mutated
312    fn exempt_index_assign_locals(&mut self) {
313        for node in self.node_ids() {
314            let ops = self.program[node].ops.clone();
315            for op in ops.borrow().values() {
316                if let Operation::Operator(Operator::IndexAssign(_)) = &op.operation
317                    && let VariableKind::LocalMut { id } = &op.out().kind
318                {
319                    self.program.variables.remove(id);
320                }
321            }
322        }
323    }
324
325    /// A set of node indices for all blocks in the program
326    pub fn node_ids(&self) -> Vec<NodeIndex> {
327        self.program.node_indices().collect()
328    }
329
330    fn transform_ssa_and_merge_composites(&mut self) {
331        self.exempt_index_assign_locals();
332        self.ssa_transform();
333
334        let mut done = false;
335        while !done {
336            let changes = AtomicCounter::new(0);
337            CompositeMerge.apply_post_ssa(self, changes.clone());
338            if changes.get() > 0 {
339                self.exempt_index_assign_locals();
340                self.ssa_transform();
341            } else {
342                done = true;
343            }
344        }
345    }
346
    /// Convert the program to SSA form: place phi nodes, version every variable, then
    /// clear the variable map and invalidate analyses made stale by the rewrite.
    fn ssa_transform(&mut self) {
        self.place_phi_nodes();
        self.version_program();
        self.program.variables.clear();
        self.invalidate_analysis::<Writes>();
        self.invalidate_analysis::<DomFrontiers>();
    }
354
    /// Mutable reference to the current basic block
    ///
    /// # Panics
    /// Panics if `current_block` is `None` (i.e. outside of parsing).
    pub(crate) fn current_block_mut(&mut self) -> &mut BasicBlock {
        &mut self.program[self.current_block.unwrap()]
    }
359
360    /// List of predecessor IDs of the `block`
361    pub fn predecessors(&self, block: NodeIndex) -> Vec<NodeIndex> {
362        self.program
363            .edges_directed(block, Direction::Incoming)
364            .map(|it| it.source())
365            .collect()
366    }
367
368    /// List of successor IDs of the `block`
369    pub fn successors(&self, block: NodeIndex) -> Vec<NodeIndex> {
370        self.program
371            .edges_directed(block, Direction::Outgoing)
372            .map(|it| it.target())
373            .collect()
374    }
375
    /// Reference to the [`BasicBlock`] with ID `block`
    ///
    /// # Panics
    /// Panics (attributed to the caller via `#[track_caller]`) if `block` is not in the graph.
    #[track_caller]
    pub fn block(&self, block: NodeIndex) -> &BasicBlock {
        &self.program[block]
    }

    /// Mutable reference to the [`BasicBlock`] with ID `block`
    ///
    /// # Panics
    /// Panics (attributed to the caller via `#[track_caller]`) if `block` is not in the graph.
    #[track_caller]
    pub fn block_mut(&mut self, block: NodeIndex) -> &mut BasicBlock {
        &mut self.program[block]
    }
387
    /// Recursively parse a scope into the graph
    ///
    /// Registers the scope's mutable locals and constant arrays with the program, lets
    /// each registered [`IrTransformer`] rewrite or drop instructions, then appends the
    /// surviving instructions to the current block (branches are parsed into control
    /// flow instead).
    ///
    /// Returns `true` if the scope's instructions contain a `Break`.
    pub fn parse_scope(&mut self, mut scope: Scope) -> bool {
        let processed = scope.process(self.processors.iter().map(|it| &**it));

        // Record mutable locals; entries in this map are SSA-transformed later
        // (see `exempt_index_assign_locals`).
        for var in processed.variables {
            if let VariableKind::LocalMut { id } = var.kind {
                self.program.variables.insert(id, var.ty);
            }
        }

        for (var, values) in scope.const_arrays.clone() {
            let VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } = var.kind
            else {
                unreachable!()
            };
            self.program.const_arrays.push(ConstArray {
                id,
                // Unrolling multiplies the element count.
                length: length * unroll_factor,
                item: var.ty,
                values,
            });
        }

        // Computed before transformers run, so a transformer cannot hide a `Break`.
        let is_break = processed.instructions.contains(&Branch::Break.into());

        for mut instruction in processed.instructions {
            let mut removed = false;
            // The first transformer that replaces or removes the instruction wins.
            for transform in self.transformers.iter() {
                match transform.maybe_transform(&mut scope, &instruction) {
                    TransformAction::Ignore => {}
                    TransformAction::Replace(replacement) => {
                        self.current_block_mut()
                            .ops
                            .borrow_mut()
                            .extend(replacement);
                        removed = true;
                        break;
                    }
                    TransformAction::Remove => {
                        removed = true;
                        break;
                    }
                }
            }
            if removed {
                continue;
            }
            match &mut instruction.operation {
                // Branches terminate the current block and open new ones.
                Operation::Branch(branch) => self.parse_control_flow(branch.clone()),
                _ => {
                    self.current_block_mut().ops.borrow_mut().push(instruction);
                }
            }
        }

        is_break
    }
449
450    /// Gets the `id` and `depth` of the variable if it's a `Local` and not atomic, `None` otherwise.
451    pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<Id> {
452        match variable.kind {
453            core::VariableKind::LocalMut { id } if !variable.ty.is_atomic() => Some(id),
454            _ => None,
455        }
456    }
457
    /// The return block to branch to. If the current return block already serves as a
    /// merge block, a fresh block is inserted before it and becomes the new `self.ret`,
    /// keeping return and merge control flow separate.
    pub(crate) fn ret(&mut self) -> NodeIndex {
        if self.program[self.ret].block_use.contains(&BlockUse::Merge) {
            let new_ret = self.program.add_node(BasicBlock::default());
            self.program.add_edge(new_ret, self.ret, 0);
            self.ret = new_ret;
            // The graph changed shape, so structural analyses are stale.
            self.invalidate_structure();
            new_ret
        } else {
            self.ret
        }
    }
469
    /// All constant arrays collected while parsing, cloned out of the program.
    pub fn const_arrays(&self) -> Vec<ConstArray> {
        self.program.const_arrays.clone()
    }

    /// Graphviz DOT rendering of the control flow graph (edge labels omitted),
    /// for debugging.
    pub fn dot_viz(&self) -> Dot<'_, &StableDiGraph<BasicBlock, u32>> {
        Dot::with_config(&self.program, &[Config::EdgeNoLabel])
    }
477}
478
/// A visitor that does nothing.
///
/// For use where a variable-visitor callback is required but no action should be taken.
pub fn visit_noop(_opt: &mut Optimizer, _var: &mut Variable) {}
481
#[cfg(test)]
mod test {
    use cubecl_core as cubecl;
    use cubecl_core::cube;
    use cubecl_core::prelude::*;
    use cubecl_ir::{ElemType, ExpandElement, Type, UIntKind, Variable, VariableKind};

    use crate::Optimizer;

    // Kernel containing a duplicated expression (`x + 4`) and a conditional
    // assignment (which forces a phi); used to manually inspect optimizer output.
    #[allow(unused)]
    #[cube(launch)]
    fn pre_kernel(x: u32, cond: u32, out: &mut Array<u32>) {
        let mut y = 0;
        let mut z = 0;
        if cond == 0 {
            y = x + 4;
        }
        z = x + 4;
        out[0] = y;
        out[1] = z;
    }

    #[test_log::test]
    #[ignore = "no good way to assert opt is applied"]
    fn test_pre() {
        // Build the kernel's expansion by hand with plain global scalar/array handles.
        let mut ctx = Scope::root(false);
        let x = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(0),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));
        let cond = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(1),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));
        let arr = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalOutputArray(0),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));

        pre_kernel::expand(&mut ctx, x.into(), cond.into(), arr.into());
        // Run the full optimizer pipeline; output is inspected manually.
        let opt = Optimizer::new(ctx, CubeDim::new_1d(1), vec![], vec![]);
        println!("{opt}")
    }
}
525}