1#![allow(unknown_lints, unnecessary_transmutes)]
27
28use std::{
29 collections::{HashMap, VecDeque},
30 ops::{Deref, DerefMut},
31 rc::Rc,
32 sync::atomic::{AtomicUsize, Ordering},
33};
34
35use analyses::{AnalysisCache, dominance::DomFrontiers, liveness::Liveness, writes::Writes};
36use cubecl_common::CubeDim;
37use cubecl_ir::{
38 self as core, Allocator, Branch, Id, Operation, Operator, Processor, Scope, Type, Variable,
39 VariableKind,
40};
41use gvn::GvnPass;
42use passes::{
43 CompositeMerge, ConstEval, ConstOperandSimplify, CopyPropagateArray, CopyTransform,
44 EliminateConstBranches, EliminateDeadBlocks, EliminateDeadPhi, EliminateUnusedVariables,
45 EmptyBranchToSelect, InlineAssignments, MergeBlocks, MergeSameExpressions, OptimizerPass,
46 ReduceStrength, RemoveIndexScalar,
47};
48use petgraph::{
49 Direction,
50 dot::{Config, Dot},
51 prelude::StableDiGraph,
52 visit::EdgeRef,
53};
54
55mod analyses;
56mod block;
57mod control_flow;
58mod debug;
59mod gvn;
60mod instructions;
61mod passes;
62mod phi_frontiers;
63mod transformers;
64mod version;
65
66pub use analyses::uniformity::Uniformity;
67pub use block::*;
68pub use control_flow::*;
69pub use petgraph::graph::{EdgeIndex, NodeIndex};
70pub use transformers::*;
71pub use version::PhiInstruction;
72
73pub use crate::analyses::liveness::shared::SharedLiveness;
74
/// A cheaply-clonable shared counter used by optimizer passes to report how
/// many mutations they performed; clones observe the same underlying value.
///
/// NOTE(review): the inner pointer is `Rc` (not `Arc`), so despite the atomic
/// this type is `!Send` and effectively single-threaded.
#[derive(Clone, Debug, Default)]
pub struct AtomicCounter {
    inner: Rc<AtomicUsize>,
}

impl AtomicCounter {
    /// Creates a counter starting at `val`.
    pub fn new(val: usize) -> Self {
        let inner = Rc::new(AtomicUsize::new(val));
        Self { inner }
    }

    /// Adds one to the counter, returning the value it held *before* the add.
    pub fn inc(&self) -> usize {
        self.inner.fetch_add(1, Ordering::AcqRel)
    }

    /// Returns the counter's current value.
    pub fn get(&self) -> usize {
        self.inner.load(Ordering::Acquire)
    }
}
99
/// A compile-time constant array lifted out of the root scope's `const_arrays`
/// while parsing (see `Optimizer::parse_scope`).
#[derive(Debug, Clone)]
pub struct ConstArray {
    /// Id of the constant-array variable.
    pub id: Id,
    /// Length in elements; already multiplied by the unroll factor at parse time.
    pub length: u32,
    /// Type of the array's elements.
    pub item: Type,
    /// The constant values backing the array.
    pub values: Vec<core::Variable>,
}
107
/// The parsed program: a control-flow graph of basic blocks plus the variable
/// and constant-array tables gathered while parsing the scope.
#[derive(Default, Debug, Clone)]
pub struct Program {
    /// Constant arrays collected from the scope during parsing.
    pub const_arrays: Vec<ConstArray>,
    /// Types of mutable locals that are candidates for the SSA transform;
    /// cleared once versioning runs (see `Optimizer::ssa_transform`).
    pub variables: HashMap<Id, Type>,
    /// Control-flow graph; nodes are basic blocks, edges carry a `u32` weight
    /// (always 0 where edges are added in this file).
    pub graph: StableDiGraph<BasicBlock, u32>,
    // Entry block of the graph, set by `Optimizer::parse_graph`.
    root: NodeIndex,
}
115
/// Lets a `Program` be used directly wherever its underlying graph is needed
/// (e.g. `program.node_indices()`).
impl Deref for Program {
    type Target = StableDiGraph<BasicBlock, u32>;

    fn deref(&self) -> &Self::Target {
        &self.graph
    }
}
123
/// Mutable counterpart of the `Deref` impl: forwards to the underlying graph.
impl DerefMut for Program {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.graph
    }
}
129
// A versioned variable id — presumably `(base id, SSA version)`; the second
// component is assigned by the `version` module (not visible here — confirm).
type VarId = (Id, u16);
131
/// Drives parsing of a [`Scope`] into a CFG and runs the optimization
/// pipeline over it (see `run_opt` / `run_shared_only`).
#[derive(Debug, Clone)]
pub struct Optimizer {
    /// The program being optimized: CFG plus variable/const-array tables.
    pub program: Program,
    /// Variable allocator shared with the root scope.
    pub allocator: Allocator,
    // Cache of analyses (liveness, dominance, writes, ...); entries are
    // invalidated via `invalidate_analysis` when passes change the IR.
    analysis_cache: Rc<AnalysisCache>,
    // Block currently being appended to during parsing; `None` once control
    // flow has terminated the current path.
    current_block: Option<NodeIndex>,
    // Break targets of enclosing loops — presumably maintained by the
    // control-flow parsing in `control_flow` (not visible here).
    loop_break: VecDeque<NodeIndex>,
    /// The single return block every exit path is wired to.
    pub ret: NodeIndex,
    /// The scope this optimizer was built from (kept for re-parsing).
    pub root_scope: Scope,
    /// Cube dimensions the kernel is compiled for.
    pub(crate) cube_dim: CubeDim,
    /// Instruction-level transformers applied while parsing (`parse_scope`).
    pub(crate) transformers: Vec<Rc<dyn IrTransformer>>,
    /// Scope processors run before instructions are ingested.
    pub(crate) processors: Rc<Vec<Box<dyn Processor>>>,
}
154
155impl Default for Optimizer {
156 fn default() -> Self {
157 Self {
158 program: Default::default(),
159 allocator: Default::default(),
160 current_block: Default::default(),
161 loop_break: Default::default(),
162 ret: Default::default(),
163 root_scope: Scope::root(false),
164 cube_dim: Default::default(),
165 analysis_cache: Default::default(),
166 transformers: Default::default(),
167 processors: Default::default(),
168 }
169 }
170}
171
impl Optimizer {
    /// Builds an optimizer from an expanded scope and runs the full
    /// optimization pipeline (`run_opt`) on it.
    ///
    /// `transformers` rewrite individual instructions during parsing;
    /// `processors` pre-process each scope before ingestion.
    pub fn new(
        expand: Scope,
        cube_dim: CubeDim,
        transformers: Vec<Rc<dyn IrTransformer>>,
        processors: Vec<Box<dyn Processor>>,
    ) -> Self {
        let mut opt = Self {
            root_scope: expand.clone(),
            cube_dim,
            allocator: expand.allocator.clone(),
            transformers,
            processors: Rc::new(processors),
            ..Default::default()
        };
        opt.run_opt();

        opt
    }

    /// Builds an optimizer that only runs the minimal pipeline needed for the
    /// shared-memory liveness analysis (`run_shared_only`) — no transformers,
    /// no processors, no optimization passes.
    pub fn shared_only(expand: Scope, cube_dim: CubeDim) -> Self {
        let mut opt = Self {
            root_scope: expand.clone(),
            cube_dim,
            allocator: expand.allocator.clone(),
            transformers: Vec::new(),
            processors: Rc::new(Vec::new()),
            ..Default::default()
        };
        opt.run_shared_only();

        opt
    }

    /// Full pipeline: parse → pre-SSA passes → SSA transform → post-SSA
    /// passes → array copy propagation → GVN/strength reduction → block merge.
    /// The ordering here is load-bearing; passes assume the IR state produced
    /// by the previous step.
    fn run_opt(&mut self) {
        self.parse_graph(self.root_scope.clone());
        self.split_critical_edges();
        self.apply_pre_ssa_passes();
        self.exempt_index_assign_locals();
        self.ssa_transform();
        self.apply_post_ssa_passes();

        // Promoting arrays to locals creates new SSA candidates, so if it did
        // anything we must redo the SSA transform and post-SSA cleanup.
        let arrays_prop = AtomicCounter::new(0);
        CopyPropagateArray.apply_post_ssa(self, arrays_prop.clone());
        if arrays_prop.get() > 0 {
            self.invalidate_analysis::<Liveness>();
            self.ssa_transform();
            self.apply_post_ssa_passes();
        }

        // Global value numbering + strength reduction + copy transform share a
        // counter: one more cleanup round only if any of them changed the IR.
        let gvn_count = AtomicCounter::new(0);
        GvnPass.apply_post_ssa(self, gvn_count.clone());
        ReduceStrength.apply_post_ssa(self, gvn_count.clone());
        CopyTransform.apply_post_ssa(self, gvn_count.clone());

        if gvn_count.get() > 0 {
            self.apply_post_ssa_passes();
        }

        self.split_free();
        // Computed eagerly so the result is cached for backends that need it.
        self.analysis::<SharedLiveness>();

        MergeBlocks.apply_post_ssa(self, AtomicCounter::new(0));
    }

    /// Minimal pipeline for `shared_only`: just enough structure (CFG + SSA)
    /// to compute shared-memory liveness, with no optimization passes.
    fn run_shared_only(&mut self) {
        self.parse_graph(self.root_scope.clone());
        self.split_critical_edges();
        self.exempt_index_assign_locals();
        self.ssa_transform();
        self.split_free();
        self.analysis::<SharedLiveness>();
    }

    /// Entry block of the program's CFG.
    pub fn entry(&self) -> NodeIndex {
        self.program.root
    }

    /// Creates the entry and return blocks, parses `scope` into the graph,
    /// and wires the last open block to the return block.
    fn parse_graph(&mut self, scope: Scope) {
        let entry = self.program.add_node(BasicBlock::default());
        self.program.root = entry;
        self.current_block = Some(entry);
        self.ret = self.program.add_node(BasicBlock::default());
        *self.program[self.ret].control_flow.borrow_mut() = ControlFlow::Return;
        self.parse_scope(scope);
        // If parsing left a block open (no terminating control flow), fall
        // through to the return block.
        if let Some(current_block) = self.current_block {
            self.program.add_edge(current_block, self.ret, 0);
        }
        self.invalidate_structure();
    }

    /// Runs the pre-SSA passes to a fixed point (until a full sweep reports
    /// zero mutations).
    fn apply_pre_ssa_passes(&mut self) {
        let mut passes = vec![CompositeMerge];
        loop {
            let counter = AtomicCounter::default();

            for pass in &mut passes {
                pass.apply_pre_ssa(self, counter.clone());
            }

            if counter.get() == 0 {
                break;
            }
        }
    }

    /// Runs the post-SSA cleanup/simplification passes to a fixed point.
    fn apply_post_ssa_passes(&mut self) {
        let mut passes: Vec<Box<dyn OptimizerPass>> = vec![
            Box::new(InlineAssignments),
            Box::new(EliminateUnusedVariables),
            Box::new(ConstOperandSimplify),
            Box::new(MergeSameExpressions),
            Box::new(ConstEval),
            Box::new(RemoveIndexScalar),
            Box::new(EliminateConstBranches),
            Box::new(EmptyBranchToSelect),
            Box::new(EliminateDeadBlocks),
            Box::new(EliminateDeadPhi),
        ];

        loop {
            let counter = AtomicCounter::default();
            for pass in &mut passes {
                pass.apply_post_ssa(self, counter.clone());
            }

            if counter.get() == 0 {
                break;
            }
        }
    }

    /// Removes locals that are written through `IndexAssign` from the SSA
    /// candidate table, so the SSA transform leaves them alone (indexed
    /// writes can't be versioned like plain assignments).
    fn exempt_index_assign_locals(&mut self) {
        for node in self.node_ids() {
            // Clone the Rc handle so we can borrow ops while `self` stays
            // usable inside the loop body.
            let ops = self.program[node].ops.clone();
            for op in ops.borrow().values() {
                if let Operation::Operator(Operator::IndexAssign(_)) = &op.operation
                    && let VariableKind::LocalMut { id } = &op.out().kind
                {
                    self.program.variables.remove(id);
                }
            }
        }
    }

    /// Snapshot of all node indices, so callers can mutate the graph while
    /// iterating.
    pub fn node_ids(&self) -> Vec<NodeIndex> {
        self.program.node_indices().collect()
    }

    /// Converts the program to SSA form (phi placement + versioning) and
    /// invalidates the analyses that the transform stales.
    fn ssa_transform(&mut self) {
        self.place_phi_nodes();
        self.version_program();
        // Once versioned, the pre-SSA variable table no longer applies.
        self.program.variables.clear();
        self.invalidate_analysis::<Writes>();
        self.invalidate_analysis::<DomFrontiers>();
    }

    /// Mutable reference to the block currently being built during parsing.
    /// Panics if no block is open (`current_block` is `None`).
    pub(crate) fn current_block_mut(&mut self) -> &mut BasicBlock {
        &mut self.program[self.current_block.unwrap()]
    }

    /// All CFG predecessors of `block`.
    pub fn predecessors(&self, block: NodeIndex) -> Vec<NodeIndex> {
        self.program
            .edges_directed(block, Direction::Incoming)
            .map(|it| it.source())
            .collect()
    }

    /// All CFG successors of `block`.
    pub fn successors(&self, block: NodeIndex) -> Vec<NodeIndex> {
        self.program
            .edges_directed(block, Direction::Outgoing)
            .map(|it| it.target())
            .collect()
    }

    /// The block at `block`. Panics (at the caller's location) if the index
    /// is stale.
    #[track_caller]
    pub fn block(&self, block: NodeIndex) -> &BasicBlock {
        &self.program[block]
    }

    /// Mutable access to the block at `block`. Panics (at the caller's
    /// location) if the index is stale.
    #[track_caller]
    pub fn block_mut(&mut self, block: NodeIndex) -> &mut BasicBlock {
        &mut self.program[block]
    }

    /// Ingests a scope into the CFG: registers its locals and const arrays,
    /// runs each instruction through the transformers, and appends the
    /// survivors to the current block (branches recurse into
    /// `parse_control_flow`). Returns whether the scope ends in a `break`.
    pub fn parse_scope(&mut self, mut scope: Scope) -> bool {
        let processed = scope.process(self.processors.iter().map(|it| &**it));

        // Register mutable locals as SSA candidates.
        for var in processed.variables {
            if let VariableKind::LocalMut { id } = var.kind {
                self.program.variables.insert(id, var.ty);
            }
        }

        for (var, values) in scope.const_arrays.clone() {
            let VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } = var.kind
            else {
                // const_arrays only ever holds ConstantArray variables.
                unreachable!()
            };
            self.program.const_arrays.push(ConstArray {
                id,
                // Store the post-unroll length.
                length: length * unroll_factor,
                item: var.ty,
                values,
            });
        }

        let is_break = processed.instructions.contains(&Branch::Break.into());

        for mut instruction in processed.instructions {
            // First transformer to act on the instruction wins; `Replace`
            // splices its replacement straight into the current block.
            let mut removed = false;
            for transform in self.transformers.iter() {
                match transform.maybe_transform(&mut scope, &instruction) {
                    TransformAction::Ignore => {}
                    TransformAction::Replace(replacement) => {
                        self.current_block_mut()
                            .ops
                            .borrow_mut()
                            .extend(replacement);
                        removed = true;
                        break;
                    }
                    TransformAction::Remove => {
                        removed = true;
                        break;
                    }
                }
            }
            if removed {
                continue;
            }
            match &mut instruction.operation {
                Operation::Branch(branch) => self.parse_control_flow(branch.clone()),
                _ => {
                    self.current_block_mut().ops.borrow_mut().push(instruction);
                }
            }
        }

        is_break
    }

    /// Id of `variable` if it's a mutable local eligible for SSA versioning;
    /// atomics are excluded because they must stay in memory.
    pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<Id> {
        match variable.kind {
            core::VariableKind::LocalMut { id } if !variable.ty.is_atomic() => Some(id),
            _ => None,
        }
    }

    /// The return block, creating a fresh one (wired before the old) when the
    /// current return block is already used as a merge target — a block can't
    /// serve as both.
    pub(crate) fn ret(&mut self) -> NodeIndex {
        if self.program[self.ret].block_use.contains(&BlockUse::Merge) {
            let new_ret = self.program.add_node(BasicBlock::default());
            self.program.add_edge(new_ret, self.ret, 0);
            self.ret = new_ret;
            self.invalidate_structure();
            new_ret
        } else {
            self.ret
        }
    }

    /// The constant arrays collected while parsing.
    pub fn const_arrays(&self) -> Vec<ConstArray> {
        self.program.const_arrays.clone()
    }

    /// Graphviz representation of the CFG (edge labels suppressed), for
    /// debugging.
    pub fn dot_viz(&self) -> Dot<'_, &StableDiGraph<BasicBlock, u32>> {
        Dot::with_config(&self.program, &[Config::EdgeNoLabel])
    }
}
469
/// A variable visitor that does nothing — for callers that must supply a
/// visitor callback but have no work to do.
pub fn visit_noop(_opt: &mut Optimizer, _var: &mut Variable) {}
472
#[cfg(test)]
mod test {
    use cubecl_core as cubecl;
    use cubecl_core::cube;
    use cubecl_core::prelude::*;
    use cubecl_ir::{ElemType, ExpandElement, Type, UIntKind, Variable, VariableKind};

    use crate::Optimizer;

    // Small kernel with a conditionally-assigned local (`y`) and an
    // unconditionally-assigned one (`z`); both compute `x + 4`, which gives
    // the expression-merging / GVN passes something visible to optimize.
    #[allow(unused)]
    #[cube(launch)]
    fn pre_kernel(x: u32, cond: u32, out: &mut Array<u32>) {
        let mut y = 0;
        let mut z = 0;
        if cond == 0 {
            y = x + 4;
        }
        z = x + 4;
        out[0] = y;
        out[1] = z;
    }

    // Smoke test: expand the kernel and run the full optimizer pipeline,
    // printing the result for manual inspection. Ignored because there is no
    // assertion — it only checks the pipeline doesn't panic.
    #[test]
    #[ignore = "no good way to assert opt is applied"]
    fn test_pre() {
        // NOTE(review): `Scope` and `CubeDim` are presumably re-exported by
        // `cubecl_core::prelude::*` — confirm if this test is re-enabled.
        let mut ctx = Scope::root(false);
        let x = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(0),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));
        let cond = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(1),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));
        let arr = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalOutputArray(0),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));

        pre_kernel::expand(&mut ctx, x.into(), cond.into(), arr.into());
        let opt = Optimizer::new(ctx, CubeDim::default(), vec![], vec![]);
        println!("{opt}")
    }
}