1#![allow(unknown_lints, unnecessary_transmutes)]
27
28use std::{
29 collections::{HashMap, VecDeque},
30 ops::{Deref, DerefMut},
31 rc::Rc,
32 sync::atomic::{AtomicUsize, Ordering},
33};
34
35use analyses::{AnalysisCache, dominance::DomFrontiers, liveness::Liveness, writes::Writes};
36use cubecl_common::CubeDim;
37use cubecl_ir::{
38 self as core, Allocator, Branch, Id, Operation, Operator, Processor, Scope, Type, Variable,
39 VariableKind,
40};
41use gvn::GvnPass;
42use passes::{
43 CompositeMerge, ConstEval, ConstOperandSimplify, CopyPropagateArray, CopyTransform,
44 EliminateConstBranches, EliminateDeadBlocks, EliminateDeadPhi, EliminateUnusedVariables,
45 EmptyBranchToSelect, InlineAssignments, MergeBlocks, MergeSameExpressions, OptimizerPass,
46 ReduceStrength, RemoveIndexScalar,
47};
48use petgraph::{
49 Direction,
50 dot::{Config, Dot},
51 prelude::StableDiGraph,
52 visit::EdgeRef,
53};
54
55mod analyses;
56mod block;
57mod control_flow;
58mod debug;
59mod gvn;
60mod instructions;
61mod passes;
62mod phi_frontiers;
63mod transformers;
64mod version;
65
66pub use analyses::uniformity::Uniformity;
67pub use block::*;
68pub use control_flow::*;
69pub use petgraph::graph::{EdgeIndex, NodeIndex};
70pub use transformers::*;
71pub use version::PhiInstruction;
72
73pub use crate::analyses::liveness::shared::SharedLiveness;
74
/// Reference-counted counter shared between optimizer passes.
///
/// Cloning yields a handle to the same underlying count, so a pass and its
/// caller can communicate how many changes were made through a shared tally.
#[derive(Clone, Debug, Default)]
pub struct AtomicCounter {
    inner: Rc<AtomicUsize>,
}

impl AtomicCounter {
    /// Creates a counter starting at `val`.
    pub fn new(val: usize) -> Self {
        let inner = Rc::new(AtomicUsize::new(val));
        Self { inner }
    }

    /// Increments the counter by one, returning the value it held beforehand.
    pub fn inc(&self) -> usize {
        self.inner.fetch_add(1, Ordering::AcqRel)
    }

    /// Reads the current value of the counter.
    pub fn get(&self) -> usize {
        self.inner.load(Ordering::Acquire)
    }
}
99
/// A constant array collected from the scope while parsing, kept so it can be
/// re-emitted alongside the optimized program.
#[derive(Debug, Clone)]
pub struct ConstArray {
    // Id of the constant array within the kernel.
    pub id: Id,
    // Element count; already multiplied by the unroll factor at parse time.
    pub length: u32,
    // Element type of the array.
    pub item: Type,
    // The constant values backing the array.
    pub values: Vec<core::Variable>,
}
107
/// The program being optimized: a control-flow graph of basic blocks plus the
/// variables and constant arrays declared by the source scope.
#[derive(Default, Debug, Clone)]
pub struct Program {
    // Constant arrays collected while parsing the scope.
    pub const_arrays: Vec<ConstArray>,
    // Mutable locals tracked for the SSA transform, keyed by id.
    // Cleared once the SSA transform has run (see `ssa_transform`).
    pub variables: HashMap<Id, Type>,
    // The underlying control-flow graph; edges carry a `u32` weight.
    pub graph: StableDiGraph<BasicBlock, u32>,
    // Entry node of the graph.
    root: NodeIndex,
}
115
116impl Deref for Program {
117 type Target = StableDiGraph<BasicBlock, u32>;
118
119 fn deref(&self) -> &Self::Target {
120 &self.graph
121 }
122}
123
124impl DerefMut for Program {
125 fn deref_mut(&mut self) -> &mut Self::Target {
126 &mut self.graph
127 }
128}
129
/// Identifier for a versioned variable: the variable id paired with a `u16`
/// counter — presumably the SSA version; see the `version` module.
type VarId = (Id, u16);
131
/// Optimizer that builds a control-flow graph from a [`Scope`], converts it to
/// SSA form, and runs analyses and optimization passes over it.
#[derive(Debug, Clone)]
pub struct Optimizer {
    // The program (CFG + variables) being optimized.
    pub program: Program,
    // Allocator taken from the root scope.
    pub allocator: Allocator,
    // Cache of lazily computed analyses; entries are invalidated when a pass
    // changes the graph.
    analysis_cache: Rc<AnalysisCache>,
    // Block instructions are currently appended to while parsing; `None` once
    // control flow has closed the block.
    current_block: Option<NodeIndex>,
    // Loop-related block queue used during parsing — presumably break/merge
    // targets; managed by the control-flow parser (not visible here).
    loop_break: VecDeque<NodeIndex>,
    // The single return block of the graph.
    pub ret: NodeIndex,
    // The scope this graph was parsed from.
    pub root_scope: Scope,
    // Cube dimensions of the kernel.
    pub(crate) cube_dim: CubeDim,
    // IR transformers consulted for every instruction during parsing.
    pub(crate) transformers: Vec<Rc<dyn IrTransformer>>,
    // Scope processors run on each scope before its instructions are parsed.
    pub(crate) processors: Rc<Vec<Box<dyn Processor>>>,
}
154
155impl Default for Optimizer {
156 fn default() -> Self {
157 Self {
158 program: Default::default(),
159 allocator: Default::default(),
160 current_block: Default::default(),
161 loop_break: Default::default(),
162 ret: Default::default(),
163 root_scope: Scope::root(false),
164 cube_dim: Default::default(),
165 analysis_cache: Default::default(),
166 transformers: Default::default(),
167 processors: Default::default(),
168 }
169 }
170}
171
172impl Optimizer {
173 pub fn new(
176 expand: Scope,
177 cube_dim: CubeDim,
178 transformers: Vec<Rc<dyn IrTransformer>>,
179 processors: Vec<Box<dyn Processor>>,
180 ) -> Self {
181 let mut opt = Self {
182 root_scope: expand.clone(),
183 cube_dim,
184 allocator: expand.allocator.clone(),
185 transformers,
186 processors: Rc::new(processors),
187 ..Default::default()
188 };
189 opt.run_opt();
190
191 opt
192 }
193
194 pub fn shared_only(expand: Scope, cube_dim: CubeDim) -> Self {
197 let mut opt = Self {
198 root_scope: expand.clone(),
199 cube_dim,
200 allocator: expand.allocator.clone(),
201 transformers: Vec::new(),
202 processors: Rc::new(Vec::new()),
203 ..Default::default()
204 };
205 opt.run_shared_only();
206
207 opt
208 }
209
    /// Runs the full optimization pipeline: parse the scope into a CFG,
    /// convert to SSA, then apply the optimization passes. The ordering of
    /// steps below is significant.
    fn run_opt(&mut self) {
        // Build the CFG and bring it into SSA form, then run the base
        // cleanup/simplification passes to a fixed point.
        self.parse_graph(self.root_scope.clone());
        self.split_critical_edges();
        self.transform_ssa_and_merge_composites();
        self.apply_post_ssa_passes();

        // Array copy propagation (see `CopyPropagateArray`). If it changed
        // anything, liveness is stale and the new locals need SSA versioning
        // plus another cleanup sweep.
        let arrays_prop = AtomicCounter::new(0);
        CopyPropagateArray.apply_post_ssa(self, arrays_prop.clone());
        if arrays_prop.get() > 0 {
            self.invalidate_analysis::<Liveness>();
            self.ssa_transform();
            self.apply_post_ssa_passes();
        }

        // Global value numbering, strength reduction, and copy transforms all
        // report into the same counter; rerun the cleanup passes only if one
        // of them changed something.
        let gvn_count = AtomicCounter::new(0);
        GvnPass.apply_post_ssa(self, gvn_count.clone());
        ReduceStrength.apply_post_ssa(self, gvn_count.clone());
        CopyTransform.apply_post_ssa(self, gvn_count.clone());

        if gvn_count.get() > 0 {
            self.apply_post_ssa_passes();
        }

        self.split_free();
        // Compute (and cache) shared-memory liveness before merging blocks.
        self.analysis::<SharedLiveness>();

        // Final block merge; the counter result is intentionally unused.
        MergeBlocks.apply_post_ssa(self, AtomicCounter::new(0));
    }
242
    /// Minimal pipeline used by [`Optimizer::shared_only`]: build the CFG,
    /// run the SSA transform, and compute shared-memory liveness — no
    /// optimization passes.
    fn run_shared_only(&mut self) {
        self.parse_graph(self.root_scope.clone());
        self.split_critical_edges();
        self.transform_ssa_and_merge_composites();
        self.split_free();
        self.analysis::<SharedLiveness>();
    }
251
    /// The entry (root) block of the control-flow graph.
    pub fn entry(&self) -> NodeIndex {
        self.program.root
    }
256
    /// Parses `scope` into the control-flow graph: creates the entry and
    /// return blocks, parses all instructions, and connects any still-open
    /// block to the return block.
    fn parse_graph(&mut self, scope: Scope) {
        let entry = self.program.add_node(BasicBlock::default());
        self.program.root = entry;
        self.current_block = Some(entry);
        self.ret = self.program.add_node(BasicBlock::default());
        *self.program[self.ret].control_flow.borrow_mut() = ControlFlow::Return;
        self.parse_scope(scope);
        // If parsing ended with an open block (no terminator), fall through
        // to the return block.
        if let Some(current_block) = self.current_block {
            self.program.add_edge(current_block, self.ret, 0);
        }
        // The graph was just (re)built, so cached structural analyses are stale.
        self.invalidate_structure();
    }
271
272 fn apply_post_ssa_passes(&mut self) {
273 let mut passes: Vec<Box<dyn OptimizerPass>> = vec![
275 Box::new(InlineAssignments),
276 Box::new(EliminateUnusedVariables),
277 Box::new(ConstOperandSimplify),
278 Box::new(MergeSameExpressions),
279 Box::new(ConstEval),
280 Box::new(RemoveIndexScalar),
281 Box::new(EliminateConstBranches),
282 Box::new(EmptyBranchToSelect),
283 Box::new(EliminateDeadBlocks),
284 Box::new(EliminateDeadPhi),
285 ];
286
287 loop {
288 let counter = AtomicCounter::default();
289 for pass in &mut passes {
290 pass.apply_post_ssa(self, counter.clone());
291 }
292
293 if counter.get() == 0 {
294 break;
295 }
296 }
297 }
298
299 fn exempt_index_assign_locals(&mut self) {
302 for node in self.node_ids() {
303 let ops = self.program[node].ops.clone();
304 for op in ops.borrow().values() {
305 if let Operation::Operator(Operator::IndexAssign(_)) = &op.operation
306 && let VariableKind::LocalMut { id } = &op.out().kind
307 {
308 self.program.variables.remove(id);
309 }
310 }
311 }
312 }
313
    /// All node indices currently in the graph, collected into a `Vec` so the
    /// graph can be mutated while iterating over them.
    pub fn node_ids(&self) -> Vec<NodeIndex> {
        self.program.node_indices().collect()
    }
318
319 fn transform_ssa_and_merge_composites(&mut self) {
320 self.exempt_index_assign_locals();
321 self.ssa_transform();
322
323 let mut done = false;
324 while !done {
325 let changes = AtomicCounter::new(0);
326 CompositeMerge.apply_post_ssa(self, changes.clone());
327 if changes.get() > 0 {
328 self.exempt_index_assign_locals();
329 self.ssa_transform();
330 } else {
331 done = true;
332 }
333 }
334 }
335
    /// Converts the program to SSA form: places phi nodes, then versions all
    /// tracked variables.
    fn ssa_transform(&mut self) {
        self.place_phi_nodes();
        self.version_program();
        // Every tracked local has now been versioned, so the pre-SSA variable
        // map is no longer meaningful.
        self.program.variables.clear();
        // Writes and dominance frontiers were consumed/changed by the
        // transform above.
        self.invalidate_analysis::<Writes>();
        self.invalidate_analysis::<DomFrontiers>();
    }
343
    /// Mutable reference to the block instructions are currently appended to.
    ///
    /// # Panics
    /// Panics if there is no current block (i.e. control flow closed it).
    pub(crate) fn current_block_mut(&mut self) -> &mut BasicBlock {
        &mut self.program[self.current_block.unwrap()]
    }
348
349 pub fn predecessors(&self, block: NodeIndex) -> Vec<NodeIndex> {
351 self.program
352 .edges_directed(block, Direction::Incoming)
353 .map(|it| it.source())
354 .collect()
355 }
356
357 pub fn successors(&self, block: NodeIndex) -> Vec<NodeIndex> {
359 self.program
360 .edges_directed(block, Direction::Outgoing)
361 .map(|it| it.target())
362 .collect()
363 }
364
    /// The basic block at `block`.
    ///
    /// # Panics
    /// Panics if `block` is not a valid node in the graph (the panic is
    /// attributed to the caller via `#[track_caller]`).
    #[track_caller]
    pub fn block(&self, block: NodeIndex) -> &BasicBlock {
        &self.program[block]
    }
370
    /// Mutable reference to the basic block at `block`.
    ///
    /// # Panics
    /// Panics if `block` is not a valid node in the graph (the panic is
    /// attributed to the caller via `#[track_caller]`).
    #[track_caller]
    pub fn block_mut(&mut self, block: NodeIndex) -> &mut BasicBlock {
        &mut self.program[block]
    }
376
    /// Parses a scope's instructions into the current block(s) of the graph.
    ///
    /// Registers the scope's mutable locals and constant arrays, runs each
    /// instruction through the registered transformers, and dispatches branch
    /// instructions to the control-flow parser.
    ///
    /// Returns whether the scope contains a `break` instruction.
    pub fn parse_scope(&mut self, mut scope: Scope) -> bool {
        // Let the registered processors pre-process the scope first.
        let processed = scope.process(self.processors.iter().map(|it| &**it));

        // Track mutable locals so the SSA transform can version them later.
        for var in processed.variables {
            if let VariableKind::LocalMut { id } = var.kind {
                self.program.variables.insert(id, var.ty);
            }
        }

        // Collect constant arrays declared in this scope.
        for (var, values) in scope.const_arrays.clone() {
            let VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } = var.kind
            else {
                unreachable!()
            };
            self.program.const_arrays.push(ConstArray {
                id,
                // Stored length accounts for unrolling.
                length: length * unroll_factor,
                item: var.ty,
                values,
            });
        }

        // Note: computed on the pre-transform instruction list.
        let is_break = processed.instructions.contains(&Branch::Break.into());

        for mut instruction in processed.instructions {
            // Give each transformer a chance to replace or drop the
            // instruction; the first transformer that acts wins.
            let mut removed = false;
            for transform in self.transformers.iter() {
                match transform.maybe_transform(&mut scope, &instruction) {
                    TransformAction::Ignore => {}
                    TransformAction::Replace(replacement) => {
                        self.current_block_mut()
                            .ops
                            .borrow_mut()
                            .extend(replacement);
                        removed = true;
                        break;
                    }
                    TransformAction::Remove => {
                        removed = true;
                        break;
                    }
                }
            }
            if removed {
                continue;
            }
            match &mut instruction.operation {
                // Branches split the graph into new blocks.
                Operation::Branch(branch) => self.parse_control_flow(branch.clone()),
                _ => {
                    self.current_block_mut().ops.borrow_mut().push(instruction);
                }
            }
        }

        is_break
    }
438
439 pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<Id> {
441 match variable.kind {
442 core::VariableKind::LocalMut { id } if !variable.ty.is_atomic() => Some(id),
443 _ => None,
444 }
445 }
446
    /// The return block new edges should target.
    ///
    /// If the current return block is already used as a merge block, a fresh
    /// block is inserted before it and becomes the new return target, keeping
    /// the two roles on separate nodes.
    pub(crate) fn ret(&mut self) -> NodeIndex {
        if self.program[self.ret].block_use.contains(&BlockUse::Merge) {
            let new_ret = self.program.add_node(BasicBlock::default());
            self.program.add_edge(new_ret, self.ret, 0);
            self.ret = new_ret;
            // Node/edge insertion changed the CFG shape.
            self.invalidate_structure();
            new_ret
        } else {
            self.ret
        }
    }
458
    /// A copy of all constant arrays collected while parsing.
    pub fn const_arrays(&self) -> Vec<ConstArray> {
        self.program.const_arrays.clone()
    }
462
    /// Graphviz DOT rendering of the control-flow graph, with edge labels
    /// suppressed, for debugging.
    pub fn dot_viz(&self) -> Dot<'_, &StableDiGraph<BasicBlock, u32>> {
        Dot::with_config(&self.program, &[Config::EdgeNoLabel])
    }
466}
467
/// Variable visitor that does nothing; usable where a visitor callback is
/// required but no action is needed.
pub fn visit_noop(_opt: &mut Optimizer, _var: &mut Variable) {}
470
#[cfg(test)]
mod test {
    use cubecl_core as cubecl;
    use cubecl_core::cube;
    use cubecl_core::prelude::*;
    use cubecl_ir::{ElemType, ExpandElement, Type, UIntKind, Variable, VariableKind};

    use crate::Optimizer;

    // Kernel computing `x + 4` twice (once conditionally, once not) —
    // presumably designed to exercise partial redundancy elimination.
    #[allow(unused)]
    #[cube(launch)]
    fn pre_kernel(x: u32, cond: u32, out: &mut Array<u32>) {
        let mut y = 0;
        let mut z = 0;
        if cond == 0 {
            y = x + 4;
        }
        z = x + 4;
        out[0] = y;
        out[1] = z;
    }

    #[test]
    #[ignore = "no good way to assert opt is applied"]
    fn test_pre() {
        // Manually expand the kernel into a scope with plain global inputs,
        // run the optimizer over it, and print the result for manual
        // inspection (ignored: there is no assertion to make).
        let mut ctx = Scope::root(false);
        let x = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(0),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));
        let cond = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(1),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));
        let arr = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalOutputArray(0),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));

        pre_kernel::expand(&mut ctx, x.into(), cond.into(), arr.into());
        let opt = Optimizer::new(ctx, CubeDim::default(), vec![], vec![]);
        println!("{opt}")
    }
}