1use std::{
27 collections::{HashMap, VecDeque},
28 ops::{Deref, DerefMut},
29 rc::Rc,
30 sync::atomic::{AtomicUsize, Ordering},
31};
32
33use analyses::{dominance::DomFrontiers, liveness::Liveness, writes::Writes, AnalysisCache};
34use cubecl_core::{
35 ir::{self as core, Allocator, Branch, Id, Operation, Operator, Variable, VariableKind},
36 CubeDim,
37};
38use cubecl_core::{
39 ir::{Item, Scope},
40 ExecutionMode,
41};
42use gvn::GvnPass;
43use passes::{
44 CompositeMerge, ConstEval, ConstOperandSimplify, CopyPropagateArray, CopyTransform,
45 EliminateConstBranches, EliminateDeadBlocks, EliminateDeadPhi, EliminateUnusedVariables,
46 EmptyBranchToSelect, InBoundsToUnchecked, InlineAssignments, MergeBlocks, MergeSameExpressions,
47 OptimizerPass, ReduceStrength, RemoveIndexScalar,
48};
49use petgraph::{prelude::StableDiGraph, visit::EdgeRef, Direction};
50
51mod analyses;
52mod block;
53mod control_flow;
54mod debug;
55mod gvn;
56mod instructions;
57mod passes;
58mod phi_frontiers;
59mod version;
60
61pub use block::*;
62pub use control_flow::*;
63pub use petgraph::graph::{EdgeIndex, NodeIndex};
64pub use version::PhiInstruction;
65
/// A cheaply-cloneable counter used by optimizer passes to report how many
/// changes they applied; every clone shares the same underlying count.
///
/// Backed by `Rc`, so it is intended for single-threaded sharing between the
/// optimizer and its passes.
#[derive(Clone, Debug, Default)]
pub struct AtomicCounter {
    inner: Rc<AtomicUsize>,
}

impl AtomicCounter {
    /// Creates a counter whose initial value is `val`.
    pub fn new(val: usize) -> Self {
        let inner = Rc::new(AtomicUsize::new(val));
        Self { inner }
    }

    /// Increments the counter by one and returns the value it held *before*
    /// the increment.
    pub fn inc(&self) -> usize {
        self.inner.fetch_add(1, Ordering::AcqRel)
    }

    /// Returns the current value of the counter.
    pub fn get(&self) -> usize {
        self.inner.load(Ordering::Acquire)
    }
}
90
/// A compile-time constant array collected while parsing a scope, along with
/// the values that back it.
#[derive(Debug, Clone)]
pub struct ConstArray {
    /// Identifier of the constant-array variable.
    pub id: Id,
    /// Number of elements in the array.
    pub length: u32,
    /// Item (element type/vectorization) of the array.
    pub item: Item,
    /// The constant values stored in the array.
    pub values: Vec<core::Variable>,
}
98
/// The program under optimization: a control-flow graph of basic blocks plus
/// the pre-SSA bookkeeping gathered while parsing.
#[derive(Default, Debug, Clone)]
struct Program {
    /// Constant arrays collected from parsed scopes (see [`ConstArray`]).
    pub const_arrays: Vec<ConstArray>,
    /// Mutable local variables eligible for SSA versioning, keyed by id.
    /// Entries are removed for locals exempted from SSA (e.g. index-assigned).
    pub variables: HashMap<Id, Item>,
    /// The CFG itself; `StableDiGraph` keeps node indices valid across removals.
    pub graph: StableDiGraph<BasicBlock, ()>,
    /// Entry block of the graph.
    root: NodeIndex,
}
106
// Convenience: let `Program` be used directly as its underlying graph.
// NOTE(review): `Deref` to a non-smart-pointer target is usually discouraged,
// but here it keeps graph traversal call sites terse throughout the optimizer.
impl Deref for Program {
    type Target = StableDiGraph<BasicBlock, ()>;

    fn deref(&self) -> &Self::Target {
        &self.graph
    }
}
114
// Mutable counterpart: allows in-place graph edits (add_node, add_edge, ...)
// directly on `Program`.
impl DerefMut for Program {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.graph
    }
}
120
/// A variable id paired with a `u16` discriminant (presumably its SSA version
/// or scope depth — confirm at use sites in the versioning module).
type VarId = (Id, u16);
122
/// An optimizer that builds a control-flow graph from a [`Scope`] and applies
/// a pipeline of pre-SSA, SSA, and post-SSA passes to it.
#[derive(Debug, Clone)]
pub struct Optimizer {
    /// The CFG and pre-SSA bookkeeping being optimized.
    program: Program,
    /// Allocator shared with the source scope, used when passes create variables.
    pub allocator: Allocator,
    /// Cached analyses (dominance frontiers, liveness, writes); entries are
    /// invalidated as passes mutate the graph.
    analysis_cache: Rc<AnalysisCache>,
    /// Block currently receiving parsed instructions, if parsing is mid-block.
    current_block: Option<NodeIndex>,
    /// Break-target blocks for loops; presumably pushed/popped per loop nest —
    /// confirm in the control-flow parsing module.
    loop_break: VecDeque<NodeIndex>,
    /// The single return block of the CFG.
    pub ret: NodeIndex,
    /// The root scope this optimizer was constructed from.
    root_scope: Scope,
    /// Cube dimensions the kernel is compiled for.
    pub(crate) cube_dim: CubeDim,
    /// Execution mode; `Checked` enables the in-bounds-to-unchecked pass.
    pub(crate) mode: ExecutionMode,
}
145
146impl Default for Optimizer {
147 fn default() -> Self {
148 Self {
149 program: Default::default(),
150 allocator: Default::default(),
151 current_block: Default::default(),
152 loop_break: Default::default(),
153 ret: Default::default(),
154 root_scope: Scope::root(),
155 cube_dim: Default::default(),
156 mode: Default::default(),
157 analysis_cache: Default::default(),
158 }
159 }
160}
161
impl Optimizer {
    /// Creates a new optimizer from `expand` and immediately runs the full
    /// optimization pipeline on it.
    pub fn new(expand: Scope, cube_dim: CubeDim, mode: ExecutionMode) -> Self {
        let mut opt = Self {
            root_scope: expand.clone(),
            cube_dim,
            mode,
            // Share the scope's allocator so variables created by passes don't
            // collide with ids already handed out during expansion.
            allocator: expand.allocator.clone(),
            ..Default::default()
        };
        opt.run_opt(expand);

        opt
    }

    /// Drives the whole pipeline: parse the CFG, run pre-SSA cleanups,
    /// transform to SSA, run post-SSA passes, then array copy propagation,
    /// GVN-related passes, and finally block merging.
    fn run_opt(&mut self, expand: Scope) {
        self.parse_graph(expand);
        self.split_critical_edges();
        self.apply_pre_ssa_passes();
        self.exempt_index_assign_locals();
        self.ssa_transform();
        self.apply_post_ssa_passes();

        // Array copy propagation may turn array slots into plain locals; if it
        // changed anything, liveness is stale and the new locals still need to
        // be SSA-versioned, so redo the SSA transform and cleanup passes.
        let arrays_prop = AtomicCounter::new(0);
        CopyPropagateArray.apply_post_ssa(self, arrays_prop.clone());
        if arrays_prop.get() > 0 {
            self.invalidate_analysis::<Liveness>();
            self.ssa_transform();
            self.apply_post_ssa_passes();
        }

        // Global value numbering plus follow-up strength reduction and copy
        // transforms; all three report into the same change counter.
        let gvn_count = AtomicCounter::new(0);
        GvnPass.apply_post_ssa(self, gvn_count.clone());
        ReduceStrength.apply_post_ssa(self, gvn_count.clone());
        CopyTransform.apply_post_ssa(self, gvn_count.clone());

        // Only re-run the cleanup passes if the GVN stage changed something.
        if gvn_count.get() > 0 {
            self.apply_post_ssa_passes();
        }

        MergeBlocks.apply_post_ssa(self, AtomicCounter::new(0));
    }

    /// The entry block of the control-flow graph.
    pub fn entry(&self) -> NodeIndex {
        self.program.root
    }

    /// Builds the initial CFG from `scope`: creates the entry and return
    /// blocks, parses the scope's operations into blocks, and connects the
    /// last open block to the return block.
    fn parse_graph(&mut self, scope: Scope) {
        let entry = self.program.add_node(BasicBlock::default());
        self.program.root = entry;
        self.current_block = Some(entry);
        self.ret = self.program.add_node(BasicBlock::default());
        *self.program[self.ret].control_flow.borrow_mut() = ControlFlow::Return;
        self.parse_scope(scope);
        // If parsing left a block open (no terminator), fall through to return.
        if let Some(current_block) = self.current_block {
            self.program.add_edge(current_block, self.ret, ());
        }
        self.invalidate_structure();
    }

    /// Runs the pre-SSA passes to a fixed point: keep iterating until a full
    /// sweep reports zero changes.
    fn apply_pre_ssa_passes(&mut self) {
        let mut passes = vec![CompositeMerge];
        loop {
            let counter = AtomicCounter::default();

            for pass in &mut passes {
                pass.apply_pre_ssa(self, counter.clone());
            }

            if counter.get() == 0 {
                break;
            }
        }
    }

    /// Runs the post-SSA cleanup/simplification passes to a fixed point, then
    /// (in checked mode only) converts provably in-bounds indexing to
    /// unchecked indexing.
    fn apply_post_ssa_passes(&mut self) {
        let mut passes: Vec<Box<dyn OptimizerPass>> = vec![
            Box::new(InlineAssignments),
            Box::new(EliminateUnusedVariables),
            Box::new(ConstOperandSimplify),
            Box::new(MergeSameExpressions),
            Box::new(ConstEval),
            Box::new(RemoveIndexScalar),
            Box::new(EliminateConstBranches),
            Box::new(EmptyBranchToSelect),
            Box::new(EliminateDeadBlocks),
            Box::new(EliminateDeadPhi),
        ];

        loop {
            let counter = AtomicCounter::default();
            for pass in &mut passes {
                pass.apply_post_ssa(self, counter.clone());
            }

            if counter.get() == 0 {
                break;
            }
        }

        // Only meaningful in checked mode: unchecked mode has no bounds checks
        // to elide in the first place.
        if matches!(self.mode, ExecutionMode::Checked) {
            InBoundsToUnchecked.apply_post_ssa(self, AtomicCounter::new(0));
        }
    }

    /// Removes locals that are written through `IndexAssign` from the tracked
    /// variable map, exempting them from SSA versioning (a partial indexed
    /// write can't be represented as a whole-value SSA definition).
    fn exempt_index_assign_locals(&mut self) {
        for node in self.node_ids() {
            // Clone the Rc handle so we can borrow ops while mutating `variables`.
            let ops = self.program[node].ops.clone();
            for op in ops.borrow().values() {
                if let Operation::Operator(Operator::IndexAssign(_)) = &op.operation {
                    if let VariableKind::LocalMut { id } = &op.out().kind {
                        self.program.variables.remove(id);
                    }
                }
            }
        }
    }

    /// All node indices currently in the CFG, collected so callers can mutate
    /// the graph while iterating.
    pub fn node_ids(&self) -> Vec<NodeIndex> {
        self.program.node_indices().collect()
    }

    /// Transforms the program into SSA form: places phi nodes, versions all
    /// tracked variables, then clears the pre-SSA variable map and the
    /// analyses the transform invalidates.
    fn ssa_transform(&mut self) {
        self.place_phi_nodes();
        self.version_program();
        self.program.variables.clear();
        self.invalidate_analysis::<Writes>();
        self.invalidate_analysis::<DomFrontiers>();
    }

    /// Mutable access to the block currently being parsed.
    ///
    /// # Panics
    /// Panics if no block is currently open (`current_block` is `None`).
    pub(crate) fn current_block_mut(&mut self) -> &mut BasicBlock {
        &mut self.program[self.current_block.unwrap()]
    }

    /// The direct predecessors of `block` in the CFG.
    pub fn predecessors(&self, block: NodeIndex) -> Vec<NodeIndex> {
        self.program
            .edges_directed(block, Direction::Incoming)
            .map(|it| it.source())
            .collect()
    }

    /// The direct successors of `block` in the CFG.
    pub fn successors(&self, block: NodeIndex) -> Vec<NodeIndex> {
        self.program
            .edges_directed(block, Direction::Outgoing)
            .map(|it| it.target())
            .collect()
    }

    /// Immutable access to `block`. `#[track_caller]` so an invalid index
    /// panics at the call site, not here.
    #[track_caller]
    pub fn block(&self, block: NodeIndex) -> &BasicBlock {
        &self.program[block]
    }

    /// Mutable access to `block`. `#[track_caller]` so an invalid index
    /// panics at the call site, not here.
    #[track_caller]
    pub fn block_mut(&mut self, block: NodeIndex) -> &mut BasicBlock {
        &mut self.program[block]
    }

    /// Parses a scope's variables, constant arrays, and operations into the
    /// current block, delegating branches to the control-flow parser.
    ///
    /// Returns `true` if the scope contains a `Break` branch.
    pub fn parse_scope(&mut self, mut scope: Scope) -> bool {
        let processed = scope.process();

        // Register every mutable local so the SSA transform can version it.
        for var in processed.variables {
            if let VariableKind::LocalMut { id } = var.kind {
                self.program.variables.insert(id, var.item);
            }
        }

        // Collect the scope's constant arrays with their backing values.
        for (var, values) in scope.const_arrays {
            let VariableKind::ConstantArray { id, length } = var.kind else {
                unreachable!()
            };
            self.program.const_arrays.push(ConstArray {
                id,
                length,
                item: var.item,
                values,
            });
        }

        let is_break = processed.operations.contains(&Branch::Break.into());

        for mut instruction in processed.operations {
            match &mut instruction.operation {
                // Branches terminate/split blocks; everything else is appended
                // to the block currently being built.
                Operation::Branch(branch) => self.parse_control_flow(branch.clone()),
                _ => {
                    self.current_block_mut().ops.borrow_mut().push(instruction);
                }
            }
        }

        is_break
    }

    /// The SSA-trackable id of `variable`: `Some` only for non-atomic mutable
    /// locals (atomics must not be versioned).
    // NOTE(review): takes `&mut self` but never mutates; `&self` would suffice.
    pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<Id> {
        match variable.kind {
            core::VariableKind::LocalMut { id } if !variable.item.elem.is_atomic() => Some(id),
            _ => None,
        }
    }

    /// Returns the return block, splitting off a fresh predecessor block when
    /// the current return block doubles as a merge block (so callers get a
    /// block they can safely append to).
    pub(crate) fn ret(&mut self) -> NodeIndex {
        if self.program[self.ret].block_use.contains(&BlockUse::Merge) {
            let new_ret = self.program.add_node(BasicBlock::default());
            self.program.add_edge(new_ret, self.ret, ());
            self.ret = new_ret;
            self.invalidate_structure();
            new_ret
        } else {
            self.ret
        }
    }

    /// The constant arrays collected from the parsed scopes.
    pub fn const_arrays(&self) -> Vec<ConstArray> {
        self.program.const_arrays.clone()
    }
}
399
/// No-op variable visitor, for passes that only need one of the read/write
/// visit callbacks.
pub fn visit_noop(_opt: &mut Optimizer, _var: &mut Variable) {}
402
#[cfg(test)]
mod test {
    use cubecl::prelude::*;
    use cubecl_core::{
        self as cubecl,
        ir::{Elem, Item, UIntKind, Variable, VariableKind},
        prelude::{Array, CubeContext, ExpandElement},
    };
    use cubecl_core::{cube, CubeDim, ExecutionMode};

    use crate::Optimizer;

    // Small kernel exercising a conditional assignment (`y`) and an
    // unconditional one (`z`) of the same expression, so the optimizer has
    // phi placement and common-subexpression opportunities to work on.
    #[allow(unused)]
    #[cube(launch)]
    fn pre_kernel(x: u32, cond: u32, out: &mut Array<u32>) {
        let mut y = 0;
        let mut z = 0;
        if cond == 0 {
            y = x + 4;
        }
        z = x + 4;
        out[0] = y;
        out[1] = z;
    }

    // Smoke test: expand the kernel, run the optimizer, and print the result
    // for manual inspection (there is no assertable output yet, hence ignored).
    #[test]
    #[ignore = "no good way to assert opt is applied"]
    fn test_pre() {
        let mut ctx = CubeContext::root();
        // Hand-built scalar/array bindings standing in for launch arguments.
        let x = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(0),
            Item::new(Elem::UInt(UIntKind::U32)),
        ));
        let cond = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(1),
            Item::new(Elem::UInt(UIntKind::U32)),
        ));
        let arr = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalOutputArray(0),
            Item::new(Elem::UInt(UIntKind::U32)),
        ));

        pre_kernel::expand(&mut ctx, x.into(), cond.into(), arr.into());
        let scope = ctx.into_scope();
        let opt = Optimizer::new(scope, CubeDim::default(), ExecutionMode::Checked);
        println!("{opt}")
    }
}