1#![allow(unknown_lints, unnecessary_transmutes)]
27
28use std::{
29 collections::{HashMap, VecDeque},
30 ops::{Deref, DerefMut},
31 rc::Rc,
32 sync::atomic::{AtomicUsize, Ordering},
33};
34
35use analyses::{AnalysisCache, dominance::DomFrontiers, liveness::Liveness, writes::Writes};
36use cubecl_common::{CubeDim, ExecutionMode};
37use cubecl_core::post_processing::checked_io::CheckedIoProcessor;
38use cubecl_ir::{
39 self as core, Allocator, Branch, Id, Item, Operation, Operator, Processor, Scope, Variable,
40 VariableKind,
41};
42use gvn::GvnPass;
43use passes::{
44 CompositeMerge, ConstEval, ConstOperandSimplify, CopyPropagateArray, CopyTransform,
45 EliminateConstBranches, EliminateDeadBlocks, EliminateDeadPhi, EliminateUnusedVariables,
46 EmptyBranchToSelect, InlineAssignments, MergeBlocks, MergeSameExpressions, OptimizerPass,
47 ReduceStrength, RemoveIndexScalar,
48};
49use petgraph::{
50 Direction,
51 dot::{Config, Dot},
52 prelude::StableDiGraph,
53 visit::EdgeRef,
54};
55
56mod analyses;
57mod block;
58mod control_flow;
59mod debug;
60mod gvn;
61mod instructions;
62mod passes;
63mod phi_frontiers;
64mod transformers;
65mod version;
66
67pub use analyses::uniformity::Uniformity;
68pub use block::*;
69pub use control_flow::*;
70pub use petgraph::graph::{EdgeIndex, NodeIndex};
71pub use transformers::*;
72pub use version::PhiInstruction;
73
/// A counter with a small interface whose clones all share one underlying count.
///
/// The count lives in an `Rc<AtomicUsize>`, so cloning the counter yields another handle to the
/// same value rather than an independent copy. Because of the `Rc`, handles cannot be sent to
/// other threads.
#[derive(Debug, Default, Clone)]
pub struct AtomicCounter {
    /// Shared storage for the count.
    inner: Rc<AtomicUsize>,
}

impl AtomicCounter {
    /// Creates a counter whose initial count is `val`.
    pub fn new(val: usize) -> Self {
        let inner = Rc::new(AtomicUsize::new(val));
        Self { inner }
    }

    /// Adds one to the count and returns the value it had *before* the increment.
    pub fn inc(&self) -> usize {
        let Self { inner } = self;
        inner.fetch_add(1, Ordering::AcqRel)
    }

    /// Returns the current count without modifying it.
    pub fn get(&self) -> usize {
        let Self { inner } = self;
        inner.load(Ordering::Acquire)
    }
}
98
99#[derive(Debug, Clone)]
100pub struct ConstArray {
101 pub id: Id,
102 pub length: u32,
103 pub item: Item,
104 pub values: Vec<core::Variable>,
105}
106
107#[derive(Default, Debug, Clone)]
108struct Program {
109 pub const_arrays: Vec<ConstArray>,
110 pub variables: HashMap<Id, Item>,
111 pub graph: StableDiGraph<BasicBlock, u32>,
112 root: NodeIndex,
113}
114
115impl Deref for Program {
116 type Target = StableDiGraph<BasicBlock, u32>;
117
118 fn deref(&self) -> &Self::Target {
119 &self.graph
120 }
121}
122
123impl DerefMut for Program {
124 fn deref_mut(&mut self) -> &mut Self::Target {
125 &mut self.graph
126 }
127}
128
129type VarId = (Id, u16);
130
131#[derive(Debug, Clone)]
133pub struct Optimizer {
134 program: Program,
136 pub allocator: Allocator,
138 analysis_cache: Rc<AnalysisCache>,
140 current_block: Option<NodeIndex>,
142 loop_break: VecDeque<NodeIndex>,
144 pub ret: NodeIndex,
146 pub root_scope: Scope,
148 pub(crate) cube_dim: CubeDim,
150 pub(crate) mode: ExecutionMode,
152 pub(crate) transformers: Vec<Rc<dyn IrTransformer>>,
153}
154
155impl Default for Optimizer {
156 fn default() -> Self {
157 Self {
158 program: Default::default(),
159 allocator: Default::default(),
160 current_block: Default::default(),
161 loop_break: Default::default(),
162 ret: Default::default(),
163 root_scope: Scope::root(false),
164 cube_dim: Default::default(),
165 mode: Default::default(),
166 analysis_cache: Default::default(),
167 transformers: Default::default(),
168 }
169 }
170}
171
172impl Optimizer {
173 pub fn new(
176 expand: Scope,
177 cube_dim: CubeDim,
178 mode: ExecutionMode,
179 transformers: Vec<Rc<dyn IrTransformer>>,
180 ) -> Self {
181 let mut opt = Self {
182 root_scope: expand.clone(),
183 cube_dim,
184 mode,
185 allocator: expand.allocator.clone(),
186 transformers,
187 ..Default::default()
188 };
189 opt.run_opt();
190
191 opt
192 }
193
194 fn run_opt(&mut self) {
196 self.parse_graph(self.root_scope.clone());
197 self.split_critical_edges();
198 self.apply_pre_ssa_passes();
199 self.exempt_index_assign_locals();
200 self.ssa_transform();
201 self.apply_post_ssa_passes();
202
203 let arrays_prop = AtomicCounter::new(0);
207 CopyPropagateArray.apply_post_ssa(self, arrays_prop.clone());
208 if arrays_prop.get() > 0 {
209 self.invalidate_analysis::<Liveness>();
210 self.ssa_transform();
211 self.apply_post_ssa_passes();
212 }
213
214 let gvn_count = AtomicCounter::new(0);
215 GvnPass.apply_post_ssa(self, gvn_count.clone());
216 ReduceStrength.apply_post_ssa(self, gvn_count.clone());
217 CopyTransform.apply_post_ssa(self, gvn_count.clone());
218
219 if gvn_count.get() > 0 {
220 self.apply_post_ssa_passes();
221 }
222
223 MergeBlocks.apply_post_ssa(self, AtomicCounter::new(0));
224 }
225
226 pub fn entry(&self) -> NodeIndex {
228 self.program.root
229 }
230
231 fn parse_graph(&mut self, scope: Scope) {
232 let entry = self.program.add_node(BasicBlock::default());
233 self.program.root = entry;
234 self.current_block = Some(entry);
235 self.ret = self.program.add_node(BasicBlock::default());
236 *self.program[self.ret].control_flow.borrow_mut() = ControlFlow::Return;
237 self.parse_scope(scope);
238 if let Some(current_block) = self.current_block {
239 self.program.add_edge(current_block, self.ret, 0);
240 }
241 self.invalidate_structure();
244 }
245
246 fn apply_pre_ssa_passes(&mut self) {
247 let mut passes = vec![CompositeMerge];
249 loop {
250 let counter = AtomicCounter::default();
251
252 for pass in &mut passes {
253 pass.apply_pre_ssa(self, counter.clone());
254 }
255
256 if counter.get() == 0 {
257 break;
258 }
259 }
260 }
261
262 fn apply_post_ssa_passes(&mut self) {
263 let mut passes: Vec<Box<dyn OptimizerPass>> = vec![
265 Box::new(InlineAssignments),
266 Box::new(EliminateUnusedVariables),
267 Box::new(ConstOperandSimplify),
268 Box::new(MergeSameExpressions),
269 Box::new(ConstEval),
270 Box::new(RemoveIndexScalar),
271 Box::new(EliminateConstBranches),
272 Box::new(EmptyBranchToSelect),
273 Box::new(EliminateDeadBlocks),
274 Box::new(EliminateDeadPhi),
275 ];
276
277 loop {
278 let counter = AtomicCounter::default();
279 for pass in &mut passes {
280 pass.apply_post_ssa(self, counter.clone());
281 }
282
283 if counter.get() == 0 {
284 break;
285 }
286 }
287 }
288
289 fn exempt_index_assign_locals(&mut self) {
292 for node in self.node_ids() {
293 let ops = self.program[node].ops.clone();
294 for op in ops.borrow().values() {
295 if let Operation::Operator(Operator::IndexAssign(_)) = &op.operation {
296 if let VariableKind::LocalMut { id } = &op.out().kind {
297 self.program.variables.remove(id);
298 }
299 }
300 }
301 }
302 }
303
304 pub fn node_ids(&self) -> Vec<NodeIndex> {
306 self.program.node_indices().collect()
307 }
308
309 fn ssa_transform(&mut self) {
310 self.place_phi_nodes();
311 self.version_program();
312 self.program.variables.clear();
313 self.invalidate_analysis::<Writes>();
314 self.invalidate_analysis::<DomFrontiers>();
315 }
316
317 pub(crate) fn current_block_mut(&mut self) -> &mut BasicBlock {
319 &mut self.program[self.current_block.unwrap()]
320 }
321
322 pub fn predecessors(&self, block: NodeIndex) -> Vec<NodeIndex> {
324 self.program
325 .edges_directed(block, Direction::Incoming)
326 .map(|it| it.source())
327 .collect()
328 }
329
330 pub fn successors(&self, block: NodeIndex) -> Vec<NodeIndex> {
332 self.program
333 .edges_directed(block, Direction::Outgoing)
334 .map(|it| it.target())
335 .collect()
336 }
337
338 #[track_caller]
340 pub fn block(&self, block: NodeIndex) -> &BasicBlock {
341 &self.program[block]
342 }
343
344 #[track_caller]
346 pub fn block_mut(&mut self, block: NodeIndex) -> &mut BasicBlock {
347 &mut self.program[block]
348 }
349
350 pub fn parse_scope(&mut self, mut scope: Scope) -> bool {
352 let checked_io: Box<dyn Processor> = Box::new(CheckedIoProcessor::new(self.mode));
353 let processed = scope.process([checked_io]);
354
355 for var in processed.variables {
356 if let VariableKind::LocalMut { id } = var.kind {
357 self.program.variables.insert(id, var.item);
358 }
359 }
360
361 for (var, values) in scope.const_arrays.clone() {
362 let VariableKind::ConstantArray { id, length } = var.kind else {
363 unreachable!()
364 };
365 self.program.const_arrays.push(ConstArray {
366 id,
367 length,
368 item: var.item,
369 values,
370 });
371 }
372
373 let is_break = processed.instructions.contains(&Branch::Break.into());
374
375 for mut instruction in processed.instructions {
376 let mut removed = false;
377 for transform in self.transformers.iter() {
378 match transform.maybe_transform(&mut scope, &instruction) {
379 TransformAction::Ignore => {}
380 TransformAction::Replace(replacement) => {
381 self.current_block_mut()
382 .ops
383 .borrow_mut()
384 .extend(replacement);
385 removed = true;
386 break;
387 }
388 TransformAction::Remove => {
389 removed = true;
390 break;
391 }
392 }
393 }
394 if removed {
395 continue;
396 }
397 match &mut instruction.operation {
398 Operation::Branch(branch) => self.parse_control_flow(branch.clone()),
399 _ => {
400 self.current_block_mut().ops.borrow_mut().push(instruction);
401 }
402 }
403 }
404
405 is_break
406 }
407
408 pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<Id> {
410 match variable.kind {
411 core::VariableKind::LocalMut { id } if !variable.item.elem.is_atomic() => Some(id),
412 _ => None,
413 }
414 }
415
416 pub(crate) fn ret(&mut self) -> NodeIndex {
417 if self.program[self.ret].block_use.contains(&BlockUse::Merge) {
418 let new_ret = self.program.add_node(BasicBlock::default());
419 self.program.add_edge(new_ret, self.ret, 0);
420 self.ret = new_ret;
421 self.invalidate_structure();
422 new_ret
423 } else {
424 self.ret
425 }
426 }
427
428 pub fn const_arrays(&self) -> Vec<ConstArray> {
429 self.program.const_arrays.clone()
430 }
431
432 pub fn dot_viz(&self) -> Dot<'_, &StableDiGraph<BasicBlock, u32>> {
433 Dot::with_config(&self.program, &[Config::EdgeNoLabel])
434 }
435}
436
437pub fn visit_noop(_opt: &mut Optimizer, _var: &mut Variable) {}
439
#[cfg(test)]
mod test {
    use cubecl_core as cubecl;
    use cubecl_core::cube;
    use cubecl_core::prelude::*;
    use cubecl_ir::{Elem, ExpandElement, Item, UIntKind, Variable, VariableKind};

    use crate::Optimizer;

    // Kernel in which `x + 4` is computed both inside a conditional and unconditionally.
    #[allow(unused)]
    #[cube(launch)]
    fn pre_kernel(x: u32, cond: u32, out: &mut Array<u32>) {
        let mut y = 0;
        let mut z = 0;
        if cond == 0 {
            y = x + 4;
        }
        z = x + 4;
        out[0] = y;
        out[1] = z;
    }

    /// Expands `pre_kernel`, runs the optimizer on it and prints the result for manual
    /// inspection; nothing is asserted, which is why the test is ignored.
    #[test]
    #[ignore = "no good way to assert opt is applied"]
    fn test_pre() {
        let mut ctx = Scope::root(false);

        // Every kernel argument is a plain `u32` variable; only its kind differs.
        let u32_operand = |kind: VariableKind| {
            ExpandElement::Plain(Variable::new(kind, Item::new(Elem::UInt(UIntKind::U32))))
        };
        let x = u32_operand(VariableKind::GlobalScalar(0));
        let cond = u32_operand(VariableKind::GlobalScalar(1));
        let arr = u32_operand(VariableKind::GlobalOutputArray(0));

        pre_kernel::expand(&mut ctx, x.into(), cond.into(), arr.into());
        let opt = Optimizer::new(ctx, CubeDim::default(), ExecutionMode::Checked, vec![]);
        println!("{opt}");
    }
}