1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{BasicBlock, BlockUse, NodeIndex, Optimizer};
6use cubecl_ir::{
7 Arithmetic, BinaryOperator, Branch, Comparison, ConstantValue, ElemType, If, IfElse,
8 Instruction, Loop, Marker, Operation, RangeLoop, Switch, Type, Variable, VariableKind,
9};
10use petgraph::{Direction, graph::EdgeIndex, visit::EdgeRef};
11use stable_vec::StableVec;
12
/// Terminator of a basic block, stored in the block's `control_flow` cell.
///
/// The variants are assigned by the `parse_*` methods on `Optimizer`. The
/// successor references are rewritten by `update_control_flow` when
/// `split_critical_edges` inserts a block on an edge.
#[derive(Default, Debug, Clone)]
pub enum ControlFlow {
    /// Two-way branch on `cond`, produced by `parse_if` and `parse_if_else`.
    IfElse {
        /// Condition variable taken from the source `If`/`IfElse`.
        cond: Variable,
        /// First block of the `then` scope.
        then: NodeIndex,
        /// First block of the `else` scope. For `parse_if` (no `else`), this is
        /// the block that follows the `if`.
        or_else: NodeIndex,
        /// Block where both paths rejoin. It is the block after the construct,
        /// or the shared return block when a branch left no current block.
        /// It is `None` when `parse_scope` reported a break in a branch.
        merge: Option<NodeIndex>,
    },
    /// Multi-way branch on `value`, produced by `parse_switch`.
    Switch {
        /// Variable being switched on.
        value: Variable,
        /// First block of the default scope.
        default: NodeIndex,
        /// `(case value, first block of the case)` pairs. Signed case constants
        /// are truncated to `i32` and bit-reinterpreted as `u32`.
        branches: Vec<(u32, NodeIndex)>,
        /// Block after the switch, or the shared return block when a case
        /// returned. It is `None` when any case or the default reported a break.
        merge: Option<NodeIndex>,
    },
    /// Unconditional loop produced by `parse_loop`, stored on the loop header.
    Loop {
        /// First block of the loop body.
        body: NodeIndex,
        /// Block that jumps back to the header.
        continue_target: NodeIndex,
        /// Block after the loop. `break` statements inside the body jump here.
        merge: NodeIndex,
    },
    /// Conditional range loop produced by `parse_for_loop`, stored on the loop
    /// header.
    LoopBreak {
        /// Result of `i < end` (`i <= end` for inclusive ranges), computed in
        /// the header.
        /// NOTE(review): despite the name, this appears to be the condition for
        /// continuing into `body` — confirm against the backend that consumes it.
        break_cond: Variable,
        /// First block of the loop body.
        body: NodeIndex,
        /// Block that jumps back to the header. It is either the body's last
        /// block, or a fresh block when that last block is itself a merge block.
        continue_target: NodeIndex,
        /// Block after the loop. The header has a direct edge to it.
        merge: NodeIndex,
    },
    /// Function return. It is not assigned in this file. `Branch::Return` is
    /// lowered to an edge into the shared return block instead.
    Return,
    /// Set by `parse_control_flow` for `Branch::Unreachable`.
    Unreachable,
    /// No explicit terminator. Presumably the block simply follows its outgoing
    /// edge — confirm with the code that consumes this.
    #[default]
    None,
}
55
/// Result of `Optimizer::parse_control_flow`, telling the caller whether the
/// current block is still open.
pub(crate) enum ControlFlowAction {
    /// A structured construct was lowered. `current_block` now points at the
    /// block that follows it.
    None,
    /// The branch terminated the current block (`Unreachable`, `Return` or
    /// `Break`), and `current_block` was taken (left as `None`). Presumably the
    /// caller stops emitting the rest of the scope — confirm at the call site.
    AbortBlock,
}
60
61impl Optimizer {
62 pub(crate) fn parse_control_flow(&mut self, branch: Branch) -> ControlFlowAction {
63 match branch {
64 Branch::If(if_) => {
65 self.parse_if(*if_);
66 ControlFlowAction::None
67 }
68 Branch::IfElse(if_else) => {
69 self.parse_if_else(if_else);
70 ControlFlowAction::None
71 }
72 Branch::Switch(switch) => {
73 self.parse_switch(*switch);
74 ControlFlowAction::None
75 }
76 Branch::RangeLoop(range_loop) => {
77 self.parse_for_loop(*range_loop);
78 ControlFlowAction::None
79 }
80 Branch::Loop(loop_) => {
81 self.parse_loop(*loop_);
82 ControlFlowAction::None
83 }
84 Branch::Unreachable => {
85 let current_block = self.current_block.take().unwrap();
86 *self.program[current_block].control_flow.borrow_mut() = ControlFlow::Unreachable;
87 ControlFlowAction::AbortBlock
88 }
89 Branch::Return => {
90 let current_block = self.current_block.take().unwrap();
91 let ret = self.ret();
92 self.program.add_edge(current_block, ret, 0);
93 ControlFlowAction::AbortBlock
94 }
95 Branch::Break => {
96 let current_block = self.current_block.take().unwrap();
97 let loop_break = self.loop_break.back().expect("Can't break outside loop");
98 self.program.add_edge(current_block, *loop_break, 0);
99 ControlFlowAction::AbortBlock
100 }
101 }
102 }
103
104 pub(crate) fn parse_if(&mut self, if_: If) {
105 let current_block = self.current_block.unwrap();
106 let then = self.program.add_node(BasicBlock::default());
107 let next = self.program.add_node(BasicBlock::default());
108 let mut merge = next;
109
110 self.program.add_edge(current_block, then, 0);
111 self.program.add_edge(current_block, next, 0);
112
113 self.current_block = Some(then);
114 let is_break = self.parse_scope(if_.scope);
115
116 if let Some(current_block) = self.current_block {
117 self.program.add_edge(current_block, next, 0);
118 } else {
119 merge = self.ret;
121 }
122
123 let merge = if is_break { None } else { Some(merge) };
124
125 *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
126 cond: if_.cond,
127 then,
128 or_else: next,
129 merge,
130 };
131 if let Some(merge) = merge {
132 self.program[merge].block_use.push(BlockUse::Merge);
133 }
134 self.current_block = Some(next);
135 }
136
137 pub(crate) fn parse_if_else(&mut self, if_else: Box<IfElse>) {
138 let current_block = self.current_block.unwrap();
139 let then = self.program.add_node(BasicBlock::default());
140 let or_else = self.program.add_node(BasicBlock::default());
141 let next = self.program.add_node(BasicBlock::default());
142 let mut merge = next;
143
144 self.program.add_edge(current_block, then, 0);
145 self.program.add_edge(current_block, or_else, 0);
146
147 self.current_block = Some(then);
148 let is_break = self.parse_scope(if_else.scope_if);
149
150 if let Some(current_block) = self.current_block {
151 self.program.add_edge(current_block, next, 0);
152 } else {
153 merge = self.ret;
155 }
156
157 self.current_block = Some(or_else);
158 let is_break = self.parse_scope(if_else.scope_else) || is_break;
159
160 if let Some(current_block) = self.current_block {
161 self.program.add_edge(current_block, next, 0);
162 } else {
163 merge = self.ret;
165 }
166
167 let merge = if is_break { None } else { Some(merge) };
168 *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
169 cond: if_else.cond,
170 then,
171 or_else,
172 merge,
173 };
174 if let Some(merge) = merge {
175 self.program[merge].block_use.push(BlockUse::Merge);
176 }
177 self.current_block = Some(next);
178 }
179
180 pub(crate) fn parse_switch(&mut self, switch: Switch) {
181 let current_block = self.current_block.unwrap();
182 let next = self.program.add_node(BasicBlock::default());
183
184 let branches = switch
185 .cases
186 .into_iter()
187 .map(|(val, case)| {
188 let case_id = self.program.add_node(BasicBlock::default());
189 self.program.add_edge(current_block, case_id, 0);
190 self.current_block = Some(case_id);
191 let is_break = self.parse_scope(case);
192 let is_ret = if let Some(current_block) = self.current_block {
193 self.program.add_edge(current_block, next, 0);
194 false
195 } else {
196 !is_break
197 };
198 let val = match val.as_const().expect("Switch value must be constant") {
199 ConstantValue::Int(val) => unsafe { transmute::<i32, u32>(val as i32) },
200 ConstantValue::UInt(val) => val as u32,
201 _ => unreachable!("Switch cases must be integer"),
202 };
203 (val, case_id, is_break, is_ret)
204 })
205 .collect::<Vec<_>>();
206
207 let is_break_branch = branches.iter().any(|it| it.2);
208 let mut is_ret = branches.iter().any(|it| it.3);
209 let branches = branches
210 .into_iter()
211 .map(|it| (it.0, it.1))
212 .collect::<Vec<_>>();
213
214 let default = self.program.add_node(BasicBlock::default());
215 self.program.add_edge(current_block, default, 0);
216 self.current_block = Some(default);
217 let is_break_def = self.parse_scope(switch.scope_default);
218
219 if let Some(current_block) = self.current_block {
220 self.program.add_edge(current_block, next, 0);
221 } else {
222 is_ret = !is_break_def;
223 }
224
225 let merge = if is_break_def || is_break_branch {
226 None
227 } else if is_ret {
228 Some(self.ret)
229 } else {
230 self.program[next].block_use.push(BlockUse::Merge);
231 Some(next)
232 };
233
234 *self.program[current_block].control_flow.borrow_mut() = ControlFlow::Switch {
235 value: switch.value,
236 default,
237 branches,
238 merge,
239 };
240
241 self.current_block = Some(next);
242 }
243
244 fn parse_loop(&mut self, loop_: Loop) {
245 let current_block = self.current_block.unwrap();
246 let header = self.program.add_node(BasicBlock::default());
247 self.program.add_edge(current_block, header, 0);
248
249 let body = self.program.add_node(BasicBlock::default());
250 let next = self.program.add_node(BasicBlock::default());
251
252 self.program.add_edge(header, body, 0);
253
254 self.loop_break.push_back(next);
255
256 self.current_block = Some(body);
257 self.parse_scope(loop_.scope);
258 let continue_target = self.program.add_node(BasicBlock::default());
259 self.program[continue_target]
260 .block_use
261 .push(BlockUse::ContinueTarget);
262
263 self.loop_break.pop_back();
264
265 if let Some(current_block) = self.current_block {
266 self.program.add_edge(current_block, continue_target, 0);
267 }
268
269 self.program.add_edge(continue_target, header, 0);
270
271 *self.program[header].control_flow.borrow_mut() = ControlFlow::Loop {
272 body,
273 continue_target,
274 merge: next,
275 };
276 self.program[next].block_use.push(BlockUse::Merge);
277 self.current_block = Some(next);
278 }
279
280 fn parse_for_loop(&mut self, range_loop: RangeLoop) {
281 let step = range_loop.step.unwrap_or(1.into());
282
283 let i_id = match range_loop.i.kind {
284 VariableKind::LocalMut { id, .. } => id,
285 _ => unreachable!(),
286 };
287 let i = range_loop.i;
288 self.program.variables.insert(i_id, i.ty);
289
290 let assign = Instruction::new(Operation::Copy(range_loop.start), i);
291 self.current_block_mut().ops.borrow_mut().push(assign);
292
293 let current_block = self.current_block.unwrap();
294 let header = self.program.add_node(BasicBlock::default());
295 self.program.add_edge(current_block, header, 0);
296
297 let body = self.program.add_node(BasicBlock::default());
298 let next = self.program.add_node(BasicBlock::default());
299
300 self.program.add_edge(header, body, 0);
301 self.program.add_edge(header, next, 0);
302
303 self.loop_break.push_back(next);
304
305 self.current_block = Some(body);
306 self.parse_scope(range_loop.scope);
307
308 self.loop_break.pop_back();
309
310 let current_block = self.current_block.expect("For loop has no loopback path");
311
312 let continue_target = if self.program[current_block]
313 .block_use
314 .contains(&BlockUse::Merge)
315 {
316 let target = self.program.add_node(BasicBlock::default());
317 self.program.add_edge(current_block, target, 0);
318 target
319 } else {
320 current_block
321 };
322
323 self.program.add_edge(continue_target, header, 0);
324
325 self.program[continue_target]
326 .block_use
327 .push(BlockUse::ContinueTarget);
328 self.program[next].block_use.push(BlockUse::Merge);
329 self.current_block = Some(next);
330
331 self.insert_phi(header, i_id, range_loop.start.ty);
333 {
334 let op = match range_loop.inclusive {
335 true => Comparison::LowerEqual,
336 false => Comparison::Lower,
337 };
338 let tmp = *self.allocator.create_local(Type::scalar(ElemType::Bool));
339 self.program[header].ops.borrow_mut().push(Instruction::new(
340 op(BinaryOperator {
341 lhs: i,
342 rhs: range_loop.end,
343 }),
344 tmp,
345 ));
346
347 *self.program[header].control_flow.borrow_mut() = ControlFlow::LoopBreak {
348 break_cond: tmp,
349 body,
350 continue_target,
351 merge: next,
352 };
353 }
354 self.program[current_block]
355 .ops
356 .borrow_mut()
357 .push(Instruction::new(
358 Arithmetic::Add(BinaryOperator { lhs: i, rhs: step }),
359 i,
360 ));
361 }
362
363 pub(crate) fn split_critical_edges(&mut self) {
364 for block in self.node_ids() {
365 let successors = self.program.edges(block);
366 let successors = successors.map(|edge| (edge.id(), edge.target()));
367 let successors: Vec<_> = successors.collect();
368
369 if successors.len() > 1 {
370 let crit = successors
371 .iter()
372 .filter(|(_, b)| self.predecessors(*b).len() > 1)
373 .collect::<Vec<_>>();
374 for (edge, successor) in crit {
375 self.program.remove_edge(*edge);
376 let new_block = self.program.add_node(BasicBlock::default());
377 self.program.add_edge(block, new_block, 0);
378 self.program.add_edge(new_block, *successor, 0);
379 self.invalidate_structure();
380 update_phi(self, *successor, block, new_block);
381 update_control_flow(self, block, *successor, new_block);
382 }
383 }
384 }
385 }
386
387 pub(crate) fn split_free(&mut self) {
390 let mut splits = 0;
391 while self.split_free_inner() {
392 splits += 1;
393 }
394 if splits > 0 {
395 self.invalidate_structure();
396 }
397 }
398
399 fn split_free_inner(&mut self) -> bool {
400 let is_free =
401 |inst: &Instruction| matches!(inst.operation, Operation::Marker(Marker::Free(_)));
402
403 for block in self.node_ids() {
404 let ops = self.block(block).ops.clone();
405 let len = ops.borrow().num_elements();
406 let idx = ops.borrow().values().position(is_free);
407 if let Some(idx) = idx {
408 if idx > 0 {
410 self.split_block_after(block, idx - 1);
411 return true;
412 }
413 if idx < len - 1 {
414 self.split_block_after(block, idx);
415 return true;
416 }
417 }
418 }
419
420 false
421 }
422
423 fn split_block_after(&mut self, block: NodeIndex, idx: usize) -> NodeIndex {
425 let successors = self.successors(block);
426 let edges: Vec<EdgeIndex> = self
427 .program
428 .edges_directed(block, Direction::Outgoing)
429 .map(|it| it.id())
430 .collect();
431 for edge in edges {
432 self.program.remove_edge(edge);
433 }
434
435 let ops = self.block(block).ops.take();
436 let before: Vec<_> = ops.values().take(idx + 1).cloned().collect();
437 let after: Vec<_> = ops.values().skip(idx + 1).cloned().collect();
438 *self.block(block).ops.borrow_mut() = StableVec::from_iter(before);
439
440 let new_block = BasicBlock::default();
441 new_block.control_flow.swap(&self.block(block).control_flow);
442 new_block.ops.borrow_mut().extend(after);
443 let new_block = self.program.graph.add_node(new_block);
444
445 self.program.add_edge(block, new_block, 0);
446 for successor in successors {
447 self.program.add_edge(new_block, successor, 0);
448 }
449 new_block
450 }
451}
452
453fn update_control_flow(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
454 let update = |id: &mut NodeIndex| {
455 if *id == from {
456 *id = to
457 }
458 };
459
460 match &mut *opt.program[block].control_flow.borrow_mut() {
461 ControlFlow::IfElse { then, or_else, .. } => {
462 update(then);
463 update(or_else);
464 }
465 ControlFlow::Switch {
466 default, branches, ..
467 } => {
468 update(default);
469
470 for branch in branches {
471 update(&mut branch.1);
472 }
473 }
474 ControlFlow::Loop {
475 body,
476 continue_target,
477 merge,
478 } => {
479 update(body);
480 update(continue_target);
481 update(merge);
482 }
483 ControlFlow::LoopBreak {
484 body,
485 continue_target,
486 merge,
487 ..
488 } => {
489 update(body);
490 update(continue_target);
491 update(merge);
492 }
493 _ => {}
494 }
495}
496
497fn update_phi(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
498 for phi in opt.program[block].phi_nodes.borrow_mut().iter_mut() {
499 for entry in phi.entries.iter_mut() {
500 if entry.block == from {
501 entry.block = to;
502 }
503 }
504 }
505}