1use std::{
27 collections::{HashMap, VecDeque},
28 ops::{Deref, DerefMut},
29 rc::Rc,
30 sync::atomic::{AtomicUsize, Ordering},
31};
32
33use analyses::{AnalysisCache, dominance::DomFrontiers, liveness::Liveness, writes::Writes};
34use cubecl_common::{CubeDim, ExecutionMode};
35use cubecl_ir::{
36 self as core, Allocator, Branch, Id, Item, Operation, Operator, Scope, Variable, VariableKind,
37};
38use gvn::GvnPass;
39use passes::{
40 CompositeMerge, ConstEval, ConstOperandSimplify, CopyPropagateArray, CopyTransform,
41 EliminateConstBranches, EliminateDeadBlocks, EliminateDeadPhi, EliminateUnusedVariables,
42 EmptyBranchToSelect, InBoundsToUnchecked, InlineAssignments, MergeBlocks, MergeSameExpressions,
43 OptimizerPass, ReduceStrength, RemoveIndexScalar,
44};
45use petgraph::{Direction, prelude::StableDiGraph, visit::EdgeRef};
46
47mod analyses;
48mod block;
49mod control_flow;
50mod debug;
51mod gvn;
52mod instructions;
53mod passes;
54mod phi_frontiers;
55mod transformers;
56mod version;
57
58pub use analyses::uniformity::Uniformity;
59pub use block::*;
60pub use control_flow::*;
61pub use petgraph::graph::{EdgeIndex, NodeIndex};
62pub use transformers::*;
63pub use version::PhiInstruction;
64
/// A counter handle that can be cloned cheaply; every clone refers to the same
/// underlying number, so an increment made through one handle is visible
/// through all the others.
#[derive(Clone, Debug, Default)]
pub struct AtomicCounter {
    inner: Rc<AtomicUsize>,
}

impl AtomicCounter {
    /// Creates a counter whose initial value is `val`.
    pub fn new(val: usize) -> Self {
        let inner = Rc::new(AtomicUsize::new(val));
        Self { inner }
    }

    /// Adds one to the counter and returns the value it held *before* the
    /// increment.
    pub fn inc(&self) -> usize {
        let shared: &AtomicUsize = &self.inner;
        shared.fetch_add(1, Ordering::AcqRel)
    }

    /// Returns the current value of the counter.
    pub fn get(&self) -> usize {
        let shared: &AtomicUsize = &self.inner;
        shared.load(Ordering::Acquire)
    }
}
89
/// A constant array gathered from a scope's `const_arrays` while parsing.
#[derive(Debug, Clone)]
pub struct ConstArray {
    /// Id taken from the array's `VariableKind::ConstantArray` variable.
    pub id: Id,
    /// Length taken from the array's `VariableKind::ConstantArray` variable.
    pub length: u32,
    /// Item of the array variable.
    pub item: Item,
    /// The values that were paired with the array variable in the scope.
    pub values: Vec<core::Variable>,
}
97
/// Internal program state: the block graph plus the data collected while
/// parsing scopes into it.
#[derive(Default, Debug, Clone)]
struct Program {
    /// Constant arrays collected from parsed scopes.
    pub const_arrays: Vec<ConstArray>,
    /// Items of `LocalMut` variables keyed by their id. Filled while parsing
    /// scopes and cleared at the end of each SSA transform.
    pub variables: HashMap<Id, Item>,
    /// The graph itself: nodes are basic blocks, edges carry no payload.
    pub graph: StableDiGraph<BasicBlock, ()>,
    /// Index of the entry block.
    root: NodeIndex,
}

// `Program` dereferences to its graph so graph methods (`add_node`, `add_edge`,
// indexing, ...) can be called on the program directly.
impl Deref for Program {
    type Target = StableDiGraph<BasicBlock, ()>;

    fn deref(&self) -> &Self::Target {
        &self.graph
    }
}

// Mutable counterpart of the `Deref` impl above.
impl DerefMut for Program {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.graph
    }
}
119
// NOTE(review): a variable `Id` paired with a `u16`; the `u16` is presumably
// the variable's SSA version — confirm against the versioning code.
type VarId = (Id, u16);
121
/// Optimizer state: the program's block graph together with everything needed
/// to parse a scope into it and run optimization passes over it.
#[derive(Debug, Clone)]
pub struct Optimizer {
    /// The block graph and the data collected while parsing.
    program: Program,
    /// Allocator taken (cloned) from the root scope in `Optimizer::new`.
    pub allocator: Allocator,
    /// Shared cache of analysis results.
    analysis_cache: Rc<AnalysisCache>,
    /// Block that parsed instructions are currently appended to, if any.
    current_block: Option<NodeIndex>,
    // NOTE(review): presumably the break targets of the loops currently being
    // parsed — confirm against the control-flow parsing code.
    loop_break: VecDeque<NodeIndex>,
    /// Index of the current return block. It is created in `parse_graph` with
    /// `ControlFlow::Return`; `ret()` may replace it with a fresh block that
    /// has an edge into the previous one.
    pub ret: NodeIndex,
    /// The scope this optimizer was created from.
    pub root_scope: Scope,
    /// Cube dimensions passed to `Optimizer::new`.
    pub(crate) cube_dim: CubeDim,
    /// Execution mode passed to `Optimizer::new`; `Checked` additionally
    /// enables the `InBoundsToUnchecked` pass.
    pub(crate) mode: ExecutionMode,
    /// IR transformers consulted for every instruction in `parse_scope`.
    pub(crate) transformers: Vec<Rc<dyn IrTransformer>>,
}
145
146impl Default for Optimizer {
147 fn default() -> Self {
148 Self {
149 program: Default::default(),
150 allocator: Default::default(),
151 current_block: Default::default(),
152 loop_break: Default::default(),
153 ret: Default::default(),
154 root_scope: Scope::root(false),
155 cube_dim: Default::default(),
156 mode: Default::default(),
157 analysis_cache: Default::default(),
158 transformers: Default::default(),
159 }
160 }
161}
162
163impl Optimizer {
164 pub fn new(
167 expand: Scope,
168 cube_dim: CubeDim,
169 mode: ExecutionMode,
170 transformers: Vec<Rc<dyn IrTransformer>>,
171 ) -> Self {
172 let mut opt = Self {
173 root_scope: expand.clone(),
174 cube_dim,
175 mode,
176 allocator: expand.allocator.clone(),
177 transformers,
178 ..Default::default()
179 };
180 opt.run_opt();
181
182 opt
183 }
184
/// Runs the complete pipeline: parses the root scope into the block graph,
/// converts it to SSA and applies the optimization passes.
fn run_opt(&mut self) {
    // Build the graph from a clone of the root scope, then run the pre-SSA
    // steps and the first SSA conversion followed by the regular passes.
    self.parse_graph(self.root_scope.clone());
    self.split_critical_edges();
    self.apply_pre_ssa_passes();
    self.exempt_index_assign_locals();
    self.ssa_transform();
    self.apply_post_ssa_passes();

    // `CopyPropagateArray` is applied a single time, outside the fixed-point
    // loop. If it reports any change, liveness is invalidated, SSA is rebuilt
    // and the regular post-SSA passes run again.
    let arrays_prop = AtomicCounter::new(0);
    CopyPropagateArray.apply_post_ssa(self, arrays_prop.clone());
    if arrays_prop.get() > 0 {
        self.invalidate_analysis::<Liveness>();
        self.ssa_transform();
        self.apply_post_ssa_passes();
    }

    // These three passes also run a single time and share one change counter.
    let gvn_count = AtomicCounter::new(0);
    GvnPass.apply_post_ssa(self, gvn_count.clone());
    ReduceStrength.apply_post_ssa(self, gvn_count.clone());
    CopyTransform.apply_post_ssa(self, gvn_count.clone());

    // Re-run the regular passes only if any of the three changed something.
    if gvn_count.get() > 0 {
        self.apply_post_ssa_passes();
    }

    // Final step; its change count is not inspected.
    MergeBlocks.apply_post_ssa(self, AtomicCounter::new(0));
}
216
/// Returns the index of the program's entry block.
pub fn entry(&self) -> NodeIndex {
    self.program.root
}
221
222 fn parse_graph(&mut self, scope: Scope) {
223 let entry = self.program.add_node(BasicBlock::default());
224 self.program.root = entry;
225 self.current_block = Some(entry);
226 self.ret = self.program.add_node(BasicBlock::default());
227 *self.program[self.ret].control_flow.borrow_mut() = ControlFlow::Return;
228 self.parse_scope(scope);
229 if let Some(current_block) = self.current_block {
230 self.program.add_edge(current_block, self.ret, ());
231 }
232 self.invalidate_structure();
235 }
236
237 fn apply_pre_ssa_passes(&mut self) {
238 let mut passes = vec![CompositeMerge];
240 loop {
241 let counter = AtomicCounter::default();
242
243 for pass in &mut passes {
244 pass.apply_pre_ssa(self, counter.clone());
245 }
246
247 if counter.get() == 0 {
248 break;
249 }
250 }
251 }
252
253 fn apply_post_ssa_passes(&mut self) {
254 let mut passes: Vec<Box<dyn OptimizerPass>> = vec![
256 Box::new(InlineAssignments),
257 Box::new(EliminateUnusedVariables),
258 Box::new(ConstOperandSimplify),
259 Box::new(MergeSameExpressions),
260 Box::new(ConstEval),
261 Box::new(RemoveIndexScalar),
262 Box::new(EliminateConstBranches),
263 Box::new(EmptyBranchToSelect),
264 Box::new(EliminateDeadBlocks),
265 Box::new(EliminateDeadPhi),
266 ];
267
268 loop {
269 let counter = AtomicCounter::default();
270 for pass in &mut passes {
271 pass.apply_post_ssa(self, counter.clone());
272 }
273
274 if counter.get() == 0 {
275 break;
276 }
277 }
278
279 if matches!(self.mode, ExecutionMode::Checked) {
281 InBoundsToUnchecked.apply_post_ssa(self, AtomicCounter::new(0));
282 }
283 }
284
285 fn exempt_index_assign_locals(&mut self) {
288 for node in self.node_ids() {
289 let ops = self.program[node].ops.clone();
290 for op in ops.borrow().values() {
291 if let Operation::Operator(Operator::IndexAssign(_)) = &op.operation {
292 if let VariableKind::LocalMut { id } = &op.out().kind {
293 self.program.variables.remove(id);
294 }
295 }
296 }
297 }
298 }
299
/// Returns the indices of all blocks currently in the program graph,
/// collected into a `Vec`.
pub fn node_ids(&self) -> Vec<NodeIndex> {
    self.program.node_indices().collect()
}
304
/// Performs the SSA transform: places phi nodes and versions the program,
/// then clears the tracked variables and invalidates the `Writes` and
/// `DomFrontiers` analyses.
fn ssa_transform(&mut self) {
    self.place_phi_nodes();
    self.version_program();
    self.program.variables.clear();
    self.invalidate_analysis::<Writes>();
    self.invalidate_analysis::<DomFrontiers>();
}
312
313 pub(crate) fn current_block_mut(&mut self) -> &mut BasicBlock {
315 &mut self.program[self.current_block.unwrap()]
316 }
317
318 pub fn predecessors(&self, block: NodeIndex) -> Vec<NodeIndex> {
320 self.program
321 .edges_directed(block, Direction::Incoming)
322 .map(|it| it.source())
323 .collect()
324 }
325
326 pub fn successors(&self, block: NodeIndex) -> Vec<NodeIndex> {
328 self.program
329 .edges_directed(block, Direction::Outgoing)
330 .map(|it| it.target())
331 .collect()
332 }
333
/// Returns a shared reference to the block stored at index `block`.
///
/// Uses graph indexing; `#[track_caller]` makes a panic from this function
/// itself report the caller's location.
#[track_caller]
pub fn block(&self, block: NodeIndex) -> &BasicBlock {
    &self.program[block]
}
339
/// Returns a mutable reference to the block stored at index `block`.
///
/// Uses graph indexing; `#[track_caller]` makes a panic from this function
/// itself report the caller's location.
#[track_caller]
pub fn block_mut(&mut self, block: NodeIndex) -> &mut BasicBlock {
    &mut self.program[block]
}
345
346 pub fn parse_scope(&mut self, mut scope: Scope) -> bool {
348 let processed = scope.process();
349
350 for var in processed.variables {
351 if let VariableKind::LocalMut { id } = var.kind {
352 self.program.variables.insert(id, var.item);
353 }
354 }
355
356 for (var, values) in scope.const_arrays.clone() {
357 let VariableKind::ConstantArray { id, length } = var.kind else {
358 unreachable!()
359 };
360 self.program.const_arrays.push(ConstArray {
361 id,
362 length,
363 item: var.item,
364 values,
365 });
366 }
367
368 let is_break = processed.instructions.contains(&Branch::Break.into());
369
370 for mut instruction in processed.instructions {
371 let mut removed = false;
372 for transform in self.transformers.iter() {
373 match transform.maybe_transform(&mut scope, &instruction) {
374 TransformAction::Ignore => {}
375 TransformAction::Replace(replacement) => {
376 self.current_block_mut()
377 .ops
378 .borrow_mut()
379 .extend(replacement);
380 removed = true;
381 break;
382 }
383 TransformAction::Remove => {
384 removed = true;
385 break;
386 }
387 }
388 }
389 if removed {
390 continue;
391 }
392 match &mut instruction.operation {
393 Operation::Branch(branch) => self.parse_control_flow(branch.clone()),
394 _ => {
395 self.current_block_mut().ops.borrow_mut().push(instruction);
396 }
397 }
398 }
399
400 is_break
401 }
402
403 pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<Id> {
405 match variable.kind {
406 core::VariableKind::LocalMut { id } if !variable.item.elem.is_atomic() => Some(id),
407 _ => None,
408 }
409 }
410
411 pub(crate) fn ret(&mut self) -> NodeIndex {
412 if self.program[self.ret].block_use.contains(&BlockUse::Merge) {
413 let new_ret = self.program.add_node(BasicBlock::default());
414 self.program.add_edge(new_ret, self.ret, ());
415 self.ret = new_ret;
416 self.invalidate_structure();
417 new_ret
418 } else {
419 self.ret
420 }
421 }
422
/// Returns a clone of the constant arrays collected while parsing.
pub fn const_arrays(&self) -> Vec<ConstArray> {
    self.program.const_arrays.clone()
}
426}
427
/// A visitor that does nothing: both arguments are ignored.
pub fn visit_noop(_opt: &mut Optimizer, _var: &mut Variable) {}
430
#[cfg(test)]
mod test {
    use cubecl_core as cubecl;
    use cubecl_core::cube;
    use cubecl_core::prelude::*;
    use cubecl_ir::{Elem, ExpandElement, Item, UIntKind, Variable, VariableKind};

    use crate::Optimizer;

    // Kernel that computes `x + 4` twice: once inside a conditional branch
    // (into `y`) and once unconditionally (into `z`).
    #[allow(unused)]
    #[cube(launch)]
    fn pre_kernel(x: u32, cond: u32, out: &mut Array<u32>) {
        let mut y = 0;
        let mut z = 0;
        if cond == 0 {
            y = x + 4;
        }
        z = x + 4;
        out[0] = y;
        out[1] = z;
    }

    // Expands `pre_kernel` into a root scope, runs the optimizer on it in
    // checked mode with no transformers, and prints the result. The test is
    // ignored by default (see the `ignore` reason) and asserts nothing.
    #[test]
    #[ignore = "no good way to assert opt is applied"]
    fn test_pre() {
        let mut ctx = Scope::root(false);
        // Two scalar inputs and one output array, all `u32`.
        let x = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(0),
            Item::new(Elem::UInt(UIntKind::U32)),
        ));
        let cond = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(1),
            Item::new(Elem::UInt(UIntKind::U32)),
        ));
        let arr = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalOutputArray(0),
            Item::new(Elem::UInt(UIntKind::U32)),
        ));

        pre_kernel::expand(&mut ctx, x.into(), cond.into(), arr.into());
        let opt = Optimizer::new(ctx, CubeDim::default(), ExecutionMode::Checked, vec![]);
        println!("{opt}")
    }
}