1#![allow(unknown_lints, unnecessary_transmutes)]
27
28use std::{
29 collections::{HashMap, VecDeque},
30 ops::{Deref, DerefMut},
31 rc::Rc,
32 sync::atomic::{AtomicUsize, Ordering},
33};
34
35use analyses::{AnalysisCache, dominance::DomFrontiers, liveness::Liveness, writes::Writes};
36use cubecl_core::CubeDim;
37use cubecl_ir::{
38 self as core, Allocator, Branch, Id, Operation, Operator, Processor, Scope, Type, Variable,
39 VariableKind,
40};
41use gvn::GvnPass;
42use passes::{
43 CompositeMerge, ConstEval, ConstOperandSimplify, CopyTransform, DisaggregateArray,
44 EliminateConstBranches, EliminateDeadBlocks, EliminateDeadPhi, EliminateUnusedVariables,
45 EmptyBranchToSelect, InlineAssignments, MergeBlocks, MergeSameExpressions, OptimizerPass,
46 ReduceStrength, RemoveIndexScalar,
47};
48use petgraph::{
49 Direction,
50 dot::{Config, Dot},
51 prelude::StableDiGraph,
52 visit::EdgeRef,
53};
54
55mod analyses;
56mod block;
57mod control_flow;
58mod debug;
59mod gvn;
60mod instructions;
61mod passes;
62mod phi_frontiers;
63mod transformers;
64mod version;
65
66pub use analyses::uniformity::Uniformity;
67pub use block::*;
68pub use control_flow::*;
69pub use petgraph::graph::{EdgeIndex, NodeIndex};
70pub use transformers::*;
71pub use version::PhiInstruction;
72
73pub use crate::analyses::liveness::shared::{SharedLiveness, SharedMemory};
74
/// A shared mutation counter handed to optimizer passes so they can report
/// how many changes they made. Clones share the same underlying count via
/// `Rc`, so bumping a cloned handle is visible through the original.
///
/// Note: being `Rc`-based, sharing is per-thread; the atomic provides
/// interior mutability rather than cross-thread synchronization.
#[derive(Clone, Debug, Default)]
pub struct AtomicCounter {
    inner: Rc<AtomicUsize>,
}

impl AtomicCounter {
    /// Creates a counter starting at `val`.
    pub fn new(val: usize) -> Self {
        let inner = Rc::new(AtomicUsize::new(val));
        Self { inner }
    }

    /// Adds one to the counter, returning the value it held *before* the
    /// increment.
    pub fn inc(&self) -> usize {
        self.inner.fetch_add(1, Ordering::AcqRel)
    }

    /// Reads the counter's current value.
    pub fn get(&self) -> usize {
        self.inner.load(Ordering::Acquire)
    }
}
99
/// A compile-time constant array collected from the scope during parsing
/// (see `Optimizer::parse_scope`).
#[derive(Debug, Clone)]
pub struct ConstArray {
    /// Identifier of the constant array within the kernel.
    pub id: Id,
    /// Number of elements; already multiplied by the unroll factor at
    /// collection time.
    pub length: usize,
    /// Element type of the array.
    pub item: Type,
    /// The constant values, one per element.
    pub values: Vec<core::Variable>,
}
107
/// The kernel being optimized, represented as a control-flow graph of
/// [`BasicBlock`]s plus per-kernel bookkeeping.
#[derive(Default, Debug, Clone)]
pub struct Program {
    /// Constant arrays collected from the scope while parsing.
    pub const_arrays: Vec<ConstArray>,
    /// Types of mutable local variables keyed by id; cleared once the
    /// program has been converted to SSA form (`Optimizer::ssa_transform`).
    pub variables: HashMap<Id, Type>,
    /// The control-flow graph itself; edge weights are `u32`.
    pub graph: StableDiGraph<BasicBlock, u32>,
    /// Entry block of the graph.
    root: NodeIndex,
}
115
/// Convenience deref so `Program` can be used directly as its underlying
/// graph (e.g. `program.add_node(..)`).
impl Deref for Program {
    type Target = StableDiGraph<BasicBlock, u32>;

    fn deref(&self) -> &Self::Target {
        &self.graph
    }
}
123
/// Mutable counterpart of the `Deref` impl above.
impl DerefMut for Program {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.graph
    }
}
129
130type VarId = (Id, u16);
131
/// Optimizer that parses a [`Scope`] into a control-flow graph, converts it
/// to SSA form, and runs a pipeline of analyses and rewrite passes over it.
#[derive(Debug, Clone)]
pub struct Optimizer {
    /// The program under optimization: blocks, edges, locals, const arrays.
    pub program: Program,
    /// Allocator cloned from the root scope, used when creating variables.
    pub allocator: Allocator,
    /// Memoized analyses (dominance frontiers, liveness, writes, ...);
    /// entries are invalidated when the relevant structure changes.
    analysis_cache: Rc<AnalysisCache>,
    /// Block instructions are currently appended to while parsing; `None`
    /// once the current path has been terminated.
    current_block: Option<NodeIndex>,
    /// Break targets for loops being parsed; pushed/popped by the
    /// control-flow parsing code (in the `control_flow` module).
    loop_break: VecDeque<NodeIndex>,
    /// The return block of the graph.
    pub ret: NodeIndex,
    /// The root scope the graph was parsed from.
    pub root_scope: Scope,
    /// Launch dimensions of the kernel.
    pub(crate) cube_dim: CubeDim,
    /// IR transformers applied to each instruction during parsing.
    pub(crate) transformers: Vec<Rc<dyn IrTransformer>>,
    /// Scope processors run on each scope before its instructions are parsed.
    pub(crate) processors: Rc<Vec<Box<dyn Processor>>>,
}
154
// SAFETY: `Optimizer` holds `Rc`s (analysis cache, processors, counters),
// which are not `Send`/`Sync` by default. These impls assert the optimizer
// is never shared or mutated concurrently across threads.
// NOTE(review): soundness depends on callers upholding that invariant —
// cloning an `Optimizer` on one thread while another uses a clone that
// shares the same `Rc`s would be unsound; confirm usage patterns.
unsafe impl Send for Optimizer {}
unsafe impl Sync for Optimizer {}
158
/// An empty optimizer with no program, a fresh root scope, and a 1×1×1 cube
/// dim; primarily used as the base for struct-update syntax in the
/// constructors (`Optimizer::new`, `Optimizer::shared_only`).
impl Default for Optimizer {
    fn default() -> Self {
        Self {
            program: Default::default(),
            allocator: Default::default(),
            current_block: Default::default(),
            loop_break: Default::default(),
            ret: Default::default(),
            root_scope: Scope::root(false),
            cube_dim: CubeDim::new_1d(1),
            analysis_cache: Default::default(),
            transformers: Default::default(),
            processors: Default::default(),
        }
    }
}
175
impl Optimizer {
    /// Builds an optimizer from an expanded scope and immediately runs the
    /// full optimization pipeline (`run_opt`).
    pub fn new(
        expand: Scope,
        cube_dim: CubeDim,
        transformers: Vec<Rc<dyn IrTransformer>>,
        processors: Vec<Box<dyn Processor>>,
    ) -> Self {
        let mut opt = Self {
            root_scope: expand.clone(),
            cube_dim,
            allocator: expand.allocator.clone(),
            transformers,
            processors: Rc::new(processors),
            ..Default::default()
        };
        opt.run_opt();

        opt
    }

    /// Builds an optimizer that only performs the SSA transform and shared
    /// memory liveness analysis, skipping the rewrite passes entirely.
    pub fn shared_only(expand: Scope, cube_dim: CubeDim) -> Self {
        let mut opt = Self {
            root_scope: expand.clone(),
            cube_dim,
            allocator: expand.allocator.clone(),
            transformers: Vec::new(),
            processors: Rc::new(Vec::new()),
            ..Default::default()
        };
        opt.run_shared_only();

        opt
    }

    /// Full pipeline: parse the CFG, convert to SSA, run the post-SSA pass
    /// fixed point, array disaggregation, GVN-based passes, and finally
    /// block merging.
    fn run_opt(&mut self) {
        self.parse_graph(self.root_scope.clone());
        self.split_critical_edges();
        self.transform_ssa_and_merge_composites();
        self.apply_post_ssa_passes();

        let arrays_prop = AtomicCounter::new(0);
        log::debug!("Applying {}", DisaggregateArray.name());
        DisaggregateArray.apply_post_ssa(self, arrays_prop.clone());
        // Splitting arrays into scalars creates new locals, so liveness is
        // stale and the program must be re-SSA'd before re-running passes.
        if arrays_prop.get() > 0 {
            self.invalidate_analysis::<Liveness>();
            self.ssa_transform();
            self.apply_post_ssa_passes();
        }

        let gvn_count = AtomicCounter::new(0);
        log::debug!("Applying {}", GvnPass.name());
        GvnPass.apply_post_ssa(self, gvn_count.clone());
        log::debug!("Applying {}", ReduceStrength.name());
        ReduceStrength.apply_post_ssa(self, gvn_count.clone());
        log::debug!("Applying {}", CopyTransform.name());
        CopyTransform.apply_post_ssa(self, gvn_count.clone());

        // GVN/strength-reduction/copy changes may open new simplification
        // opportunities; re-run the cheap passes if anything changed.
        if gvn_count.get() > 0 {
            self.apply_post_ssa_passes();
        }

        self.split_free();
        self.analysis::<SharedLiveness>();

        log::debug!("Applying {}", MergeBlocks.name());
        MergeBlocks.apply_post_ssa(self, AtomicCounter::new(0));
    }

    /// Reduced pipeline used by [`Optimizer::shared_only`]: build the graph
    /// and SSA form, then run only shared-memory liveness.
    fn run_shared_only(&mut self) {
        self.parse_graph(self.root_scope.clone());
        self.split_critical_edges();
        self.transform_ssa_and_merge_composites();
        self.split_free();
        self.analysis::<SharedLiveness>();
    }

    /// The entry block of the control-flow graph.
    pub fn entry(&self) -> NodeIndex {
        self.program.root
    }

    /// Parses the root scope into the control-flow graph: creates the entry
    /// and return blocks, walks the scope, and wires the final fallthrough
    /// edge to the return block if the last path did not already terminate.
    fn parse_graph(&mut self, scope: Scope) {
        let entry = self.program.add_node(BasicBlock::default());
        self.program.root = entry;
        self.current_block = Some(entry);
        self.ret = self.program.add_node(BasicBlock::default());
        *self.program[self.ret].control_flow.borrow_mut() = ControlFlow::Return;
        self.parse_scope(scope);
        if let Some(current_block) = self.current_block {
            self.program.add_edge(current_block, self.ret, 0);
        }
        self.invalidate_structure();
    }

    /// Runs the cheap rewrite passes to a fixed point: loop until one full
    /// sweep over all passes reports zero changes.
    fn apply_post_ssa_passes(&mut self) {
        let mut passes: Vec<Box<dyn OptimizerPass>> = vec![
            Box::new(InlineAssignments),
            Box::new(EliminateUnusedVariables),
            Box::new(ConstOperandSimplify),
            Box::new(MergeSameExpressions),
            Box::new(ConstEval),
            Box::new(RemoveIndexScalar),
            Box::new(EliminateConstBranches),
            Box::new(EmptyBranchToSelect),
            Box::new(EliminateDeadBlocks),
            Box::new(EliminateDeadPhi),
        ];

        log::debug!("Applying post-SSA passes");
        loop {
            let counter = AtomicCounter::default();
            for pass in &mut passes {
                log::debug!("Applying {}", pass.name());
                pass.apply_post_ssa(self, counter.clone());
            }

            if counter.get() == 0 {
                break;
            }
        }
    }

    /// Removes from `program.variables` every mutable local that is written
    /// through `IndexAssign`, so it is not treated as an SSA-versionable
    /// scalar (it is mutated through indexing rather than whole-value
    /// assignment).
    fn exempt_index_assign_locals(&mut self) {
        for node in self.node_ids() {
            let ops = self.program[node].ops.clone();
            for op in ops.borrow().values() {
                if let Operation::Operator(Operator::IndexAssign(_)) = &op.operation
                    && let VariableKind::LocalMut { id } = &op.out().kind
                {
                    self.program.variables.remove(id);
                }
            }
        }
    }

    /// All node indices of the graph, collected so callers can mutate the
    /// graph while iterating.
    pub fn node_ids(&self) -> Vec<NodeIndex> {
        self.program.node_indices().collect()
    }

    /// Converts the program to SSA, then repeatedly merges composite values
    /// and re-runs the SSA transform until `CompositeMerge` makes no more
    /// changes.
    fn transform_ssa_and_merge_composites(&mut self) {
        self.exempt_index_assign_locals();
        self.ssa_transform();

        let mut done = false;
        while !done {
            let changes = AtomicCounter::new(0);
            CompositeMerge.apply_post_ssa(self, changes.clone());
            if changes.get() > 0 {
                self.exempt_index_assign_locals();
                self.ssa_transform();
            } else {
                done = true;
            }
        }
    }

    /// The SSA transform proper: place phi nodes, version all variables,
    /// then clear the (now consumed) local-variable table and invalidate
    /// the analyses the transform depends on.
    fn ssa_transform(&mut self) {
        self.place_phi_nodes();
        self.version_program();
        self.program.variables.clear();
        self.invalidate_analysis::<Writes>();
        self.invalidate_analysis::<DomFrontiers>();
    }

    /// Mutable access to the block currently being built during parsing.
    ///
    /// # Panics
    /// Panics if there is no current block (i.e. the path has terminated).
    pub(crate) fn current_block_mut(&mut self) -> &mut BasicBlock {
        &mut self.program[self.current_block.unwrap()]
    }

    /// All direct predecessors of `block` in the CFG.
    pub fn predecessors(&self, block: NodeIndex) -> Vec<NodeIndex> {
        self.program
            .edges_directed(block, Direction::Incoming)
            .map(|it| it.source())
            .collect()
    }

    /// All direct successors of `block` in the CFG.
    pub fn successors(&self, block: NodeIndex) -> Vec<NodeIndex> {
        self.program
            .edges_directed(block, Direction::Outgoing)
            .map(|it| it.target())
            .collect()
    }

    /// Immutable reference to the block at `block`.
    ///
    /// # Panics
    /// Panics if the index is not present in the graph.
    #[track_caller]
    pub fn block(&self, block: NodeIndex) -> &BasicBlock {
        &self.program[block]
    }

    /// Mutable reference to the block at `block`.
    ///
    /// # Panics
    /// Panics if the index is not present in the graph.
    #[track_caller]
    pub fn block_mut(&mut self, block: NodeIndex) -> &mut BasicBlock {
        &mut self.program[block]
    }

    /// Parses one scope into the graph: runs the processors, records local
    /// variables and constant arrays, applies the IR transformers to each
    /// instruction, and dispatches branches to the control-flow parser.
    ///
    /// Returns whether the scope contains a top-level `break`.
    pub fn parse_scope(&mut self, mut scope: Scope) -> bool {
        let processed = scope.process(self.processors.iter().map(|it| &**it));

        for var in processed.variables {
            if let VariableKind::LocalMut { id } = var.kind {
                self.program.variables.insert(id, var.ty);
            }
        }

        for (var, values) in scope.const_arrays.clone() {
            let VariableKind::ConstantArray {
                id,
                length,
                unroll_factor,
            } = var.kind
            else {
                unreachable!()
            };
            self.program.const_arrays.push(ConstArray {
                id,
                // Store the fully unrolled length.
                length: length * unroll_factor,
                item: var.ty,
                values,
            });
        }

        let is_break = processed.instructions.contains(&Branch::Break.into());

        for mut instruction in processed.instructions {
            let mut removed = false;
            // Give each transformer a chance to replace or drop the
            // instruction; the first one that acts wins.
            for transform in self.transformers.iter() {
                match transform.maybe_transform(&mut scope, &instruction) {
                    TransformAction::Ignore => {}
                    TransformAction::Replace(replacement) => {
                        self.current_block_mut()
                            .ops
                            .borrow_mut()
                            .extend(replacement);
                        removed = true;
                        break;
                    }
                    TransformAction::Remove => {
                        removed = true;
                        break;
                    }
                }
            }
            if removed {
                continue;
            }
            match &mut instruction.operation {
                Operation::Branch(branch) => self.parse_control_flow(branch.clone()),
                _ => {
                    self.current_block_mut().ops.borrow_mut().push(instruction);
                }
            }
        }

        is_break
    }

    /// The id of `variable` if it is a non-atomic mutable local (i.e. a
    /// candidate for SSA versioning), `None` otherwise.
    pub fn local_variable_id(&mut self, variable: &core::Variable) -> Option<Id> {
        match variable.kind {
            core::VariableKind::LocalMut { id } if !variable.ty.is_atomic() => Some(id),
            _ => None,
        }
    }

    /// The return block to jump to. If the current return block is also used
    /// as a merge block, a fresh predecessor block is inserted in front of it
    /// (and becomes the new `ret`) so the two roles stay separate.
    pub(crate) fn ret(&mut self) -> NodeIndex {
        if self.program[self.ret].block_use.contains(&BlockUse::Merge) {
            let new_ret = self.program.add_node(BasicBlock::default());
            self.program.add_edge(new_ret, self.ret, 0);
            self.ret = new_ret;
            self.invalidate_structure();
            new_ret
        } else {
            self.ret
        }
    }

    /// The constant arrays collected while parsing.
    pub fn const_arrays(&self) -> Vec<ConstArray> {
        self.program.const_arrays.clone()
    }

    /// Graphviz representation of the CFG for debugging (edge labels
    /// suppressed).
    pub fn dot_viz(&self) -> Dot<'_, &StableDiGraph<BasicBlock, u32>> {
        Dot::with_config(&self.program, &[Config::EdgeNoLabel])
    }
}
478
479pub fn visit_noop(_opt: &mut Optimizer, _var: &mut Variable) {}
481
#[cfg(test)]
mod test {
    use cubecl_core as cubecl;
    use cubecl_core::cube;
    use cubecl_core::prelude::*;
    use cubecl_ir::{ElemType, ExpandElement, Type, UIntKind, Variable, VariableKind};

    use crate::Optimizer;

    /// Kernel with a conditionally-assigned variable and a duplicated
    /// expression (`x + 4`), exercising phi insertion and value numbering.
    #[allow(unused)]
    #[cube(launch)]
    fn pre_kernel(x: u32, cond: u32, out: &mut Array<u32>) {
        let mut y = 0;
        let mut z = 0;
        if cond == 0 {
            y = x + 4;
        }
        z = x + 4;
        out[0] = y;
        out[1] = z;
    }

    /// Smoke test: expand the kernel by hand and run the optimizer. Ignored
    /// because there is currently no way to assert which optimizations ran;
    /// the printed output is inspected manually.
    #[test_log::test]
    #[ignore = "no good way to assert opt is applied"]
    fn test_pre() {
        let mut ctx = Scope::root(false);
        let x = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(0),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));
        let cond = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalScalar(1),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));
        let arr = ExpandElement::Plain(Variable::new(
            VariableKind::GlobalOutputArray(0),
            Type::scalar(ElemType::UInt(UIntKind::U32)),
        ));

        pre_kernel::expand(&mut ctx, x.into(), cond.into(), arr.into());
        let opt = Optimizer::new(ctx, CubeDim::new_1d(1), vec![], vec![]);
        println!("{opt}")
    }
}
525}