1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{BasicBlock, BlockUse, NodeIndex, Optimizer};
6use cubecl_ir::{
7 Arithmetic, BinaryOperator, Branch, Comparison, ConstantScalarValue, ElemType, If, IfElse,
8 Instruction, Loop, Marker, Operation, RangeLoop, Switch, Type, Variable, VariableKind,
9};
10use petgraph::{Direction, graph::EdgeIndex, visit::EdgeRef};
11use stable_vec::StableVec;
12
/// Structured control-flow information stored in a basic block's `control_flow` cell.
///
/// Every `NodeIndex` refers to a block in the optimizer's program graph.
#[derive(Default, Debug, Clone)]
pub enum ControlFlow {
    /// Two-way conditional branch on `cond`.
    ///
    /// Built by `parse_if` (where `or_else` is the fall-through block that follows the `if`)
    /// and by `parse_if_else`.
    IfElse {
        /// Condition selecting between `then` and `or_else`.
        cond: Variable,
        /// Block holding the body guarded by `cond`.
        then: NodeIndex,
        /// Block taken otherwise (the `else` body, or the code after a plain `if`).
        or_else: NodeIndex,
        /// Reconvergence block; `None` when a branch scope reported a `break`.
        merge: Option<NodeIndex>,
    },
    /// Multi-way branch on `value`, built by `parse_switch`.
    Switch {
        /// Value being switched on.
        value: Variable,
        /// Block taken when no case matches.
        default: NodeIndex,
        /// `(case label reinterpreted as u32, case block)` pairs.
        branches: Vec<(u32, NodeIndex)>,
        /// Reconvergence block; `None` when any case or the default reported a `break`,
        /// the return block when a case ended without a current block, otherwise the block
        /// that follows the switch.
        merge: Option<NodeIndex>,
    },
    /// Unconditional loop built by `parse_loop`, stored on the loop header.
    ///
    /// The header only branches into `body`; the loop is left through `break` edges to `merge`.
    Loop {
        /// First block of the loop body.
        body: NodeIndex,
        /// Block that jumps back to the loop header.
        continue_target: NodeIndex,
        /// Block executed after the loop (the target of `break`).
        merge: NodeIndex,
    },
    /// Loop whose header evaluates `break_cond` and branches to either `body` or `merge`.
    ///
    /// Built by `parse_for_loop`, where `break_cond` holds `i < end` (`i <= end` when
    /// inclusive). NOTE(review): despite the name the value is true while iteration
    /// continues — confirm how the backends interpret it.
    LoopBreak {
        /// Boolean computed in the header that selects between `body` and `merge`.
        break_cond: Variable,
        /// First block of the loop body.
        body: NodeIndex,
        /// Block that jumps back to the loop header.
        continue_target: NodeIndex,
        /// Block executed after the loop.
        merge: NodeIndex,
    },
    /// NOTE(review): not constructed in this section; presumably marks a block that ends by
    /// returning.
    Return,
    /// No structured control flow (the default for a fresh block).
    #[default]
    None,
}
53
54impl Optimizer {
55 pub(crate) fn parse_control_flow(&mut self, branch: Branch) {
56 match branch {
57 Branch::If(if_) => self.parse_if(*if_),
58 Branch::IfElse(if_else) => self.parse_if_else(if_else),
59 Branch::Switch(switch) => self.parse_switch(*switch),
60 Branch::RangeLoop(range_loop) => {
61 self.parse_for_loop(*range_loop);
62 }
63 Branch::Loop(loop_) => self.parse_loop(*loop_),
64 Branch::Return => {
65 let current_block = self.current_block.take().unwrap();
66 let ret = self.ret();
67 self.program.add_edge(current_block, ret, 0);
68 }
69 Branch::Break => {
70 let current_block = self.current_block.take().unwrap();
71 let loop_break = self.loop_break.back().expect("Can't break outside loop");
72 self.program.add_edge(current_block, *loop_break, 0);
73 }
74 }
75 }
76
77 pub(crate) fn parse_if(&mut self, if_: If) {
78 let current_block = self.current_block.unwrap();
79 let then = self.program.add_node(BasicBlock::default());
80 let next = self.program.add_node(BasicBlock::default());
81 let mut merge = next;
82
83 self.program.add_edge(current_block, then, 0);
84 self.program.add_edge(current_block, next, 0);
85
86 self.current_block = Some(then);
87 let is_break = self.parse_scope(if_.scope);
88
89 if let Some(current_block) = self.current_block {
90 self.program.add_edge(current_block, next, 0);
91 } else {
92 merge = self.ret;
94 }
95
96 let merge = if is_break { None } else { Some(merge) };
97
98 *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
99 cond: if_.cond,
100 then,
101 or_else: next,
102 merge,
103 };
104 if let Some(merge) = merge {
105 self.program[merge].block_use.push(BlockUse::Merge);
106 }
107 self.current_block = Some(next);
108 }
109
110 pub(crate) fn parse_if_else(&mut self, if_else: Box<IfElse>) {
111 let current_block = self.current_block.unwrap();
112 let then = self.program.add_node(BasicBlock::default());
113 let or_else = self.program.add_node(BasicBlock::default());
114 let next = self.program.add_node(BasicBlock::default());
115 let mut merge = next;
116
117 self.program.add_edge(current_block, then, 0);
118 self.program.add_edge(current_block, or_else, 0);
119
120 self.current_block = Some(then);
121 let is_break = self.parse_scope(if_else.scope_if);
122
123 if let Some(current_block) = self.current_block {
124 self.program.add_edge(current_block, next, 0);
125 } else {
126 merge = self.ret;
128 }
129
130 self.current_block = Some(or_else);
131 let is_break = self.parse_scope(if_else.scope_else) || is_break;
132
133 if let Some(current_block) = self.current_block {
134 self.program.add_edge(current_block, next, 0);
135 } else {
136 merge = self.ret;
138 }
139
140 let merge = if is_break { None } else { Some(merge) };
141 *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
142 cond: if_else.cond,
143 then,
144 or_else,
145 merge,
146 };
147 if let Some(merge) = merge {
148 self.program[merge].block_use.push(BlockUse::Merge);
149 }
150 self.current_block = Some(next);
151 }
152
153 pub(crate) fn parse_switch(&mut self, switch: Switch) {
154 let current_block = self.current_block.unwrap();
155 let next = self.program.add_node(BasicBlock::default());
156
157 let branches = switch
158 .cases
159 .into_iter()
160 .map(|(val, case)| {
161 let case_id = self.program.add_node(BasicBlock::default());
162 self.program.add_edge(current_block, case_id, 0);
163 self.current_block = Some(case_id);
164 let is_break = self.parse_scope(case);
165 let is_ret = if let Some(current_block) = self.current_block {
166 self.program.add_edge(current_block, next, 0);
167 false
168 } else {
169 !is_break
170 };
171 let val = match val.as_const().expect("Switch value must be constant") {
172 ConstantScalarValue::Int(val, _) => unsafe {
173 transmute::<i32, u32>(val as i32)
174 },
175 ConstantScalarValue::UInt(val, _) => val as u32,
176 _ => unreachable!("Switch cases must be integer"),
177 };
178 (val, case_id, is_break, is_ret)
179 })
180 .collect::<Vec<_>>();
181
182 let is_break_branch = branches.iter().any(|it| it.2);
183 let mut is_ret = branches.iter().any(|it| it.3);
184 let branches = branches
185 .into_iter()
186 .map(|it| (it.0, it.1))
187 .collect::<Vec<_>>();
188
189 let default = self.program.add_node(BasicBlock::default());
190 self.program.add_edge(current_block, default, 0);
191 self.current_block = Some(default);
192 let is_break_def = self.parse_scope(switch.scope_default);
193
194 if let Some(current_block) = self.current_block {
195 self.program.add_edge(current_block, next, 0);
196 } else {
197 is_ret = !is_break_def;
198 }
199
200 let merge = if is_break_def || is_break_branch {
201 None
202 } else if is_ret {
203 Some(self.ret)
204 } else {
205 self.program[next].block_use.push(BlockUse::Merge);
206 Some(next)
207 };
208
209 *self.program[current_block].control_flow.borrow_mut() = ControlFlow::Switch {
210 value: switch.value,
211 default,
212 branches,
213 merge,
214 };
215
216 self.current_block = Some(next);
217 }
218
219 fn parse_loop(&mut self, loop_: Loop) {
220 let current_block = self.current_block.unwrap();
221 let header = self.program.add_node(BasicBlock::default());
222 self.program.add_edge(current_block, header, 0);
223
224 let body = self.program.add_node(BasicBlock::default());
225 let next = self.program.add_node(BasicBlock::default());
226
227 self.program.add_edge(header, body, 0);
228
229 self.loop_break.push_back(next);
230
231 self.current_block = Some(body);
232 self.parse_scope(loop_.scope);
233 let continue_target = self.program.add_node(BasicBlock::default());
234 self.program[continue_target]
235 .block_use
236 .push(BlockUse::ContinueTarget);
237
238 self.loop_break.pop_back();
239
240 if let Some(current_block) = self.current_block {
241 self.program.add_edge(current_block, continue_target, 0);
242 }
243
244 self.program.add_edge(continue_target, header, 0);
245
246 *self.program[header].control_flow.borrow_mut() = ControlFlow::Loop {
247 body,
248 continue_target,
249 merge: next,
250 };
251 self.program[next].block_use.push(BlockUse::Merge);
252 self.current_block = Some(next);
253 }
254
255 fn parse_for_loop(&mut self, range_loop: RangeLoop) {
256 let step = range_loop.step.unwrap_or(1.into());
257
258 let i_id = match range_loop.i.kind {
259 VariableKind::LocalMut { id, .. } => id,
260 _ => unreachable!(),
261 };
262 let i = range_loop.i;
263 self.program.variables.insert(i_id, i.ty);
264
265 let assign = Instruction::new(Operation::Copy(range_loop.start), i);
266 self.current_block_mut().ops.borrow_mut().push(assign);
267
268 let current_block = self.current_block.unwrap();
269 let header = self.program.add_node(BasicBlock::default());
270 self.program.add_edge(current_block, header, 0);
271
272 let body = self.program.add_node(BasicBlock::default());
273 let next = self.program.add_node(BasicBlock::default());
274
275 self.program.add_edge(header, body, 0);
276 self.program.add_edge(header, next, 0);
277
278 self.loop_break.push_back(next);
279
280 self.current_block = Some(body);
281 self.parse_scope(range_loop.scope);
282
283 self.loop_break.pop_back();
284
285 let current_block = self.current_block.expect("For loop has no loopback path");
286
287 let continue_target = if self.program[current_block]
288 .block_use
289 .contains(&BlockUse::Merge)
290 {
291 let target = self.program.add_node(BasicBlock::default());
292 self.program.add_edge(current_block, target, 0);
293 target
294 } else {
295 current_block
296 };
297
298 self.program.add_edge(continue_target, header, 0);
299
300 self.program[continue_target]
301 .block_use
302 .push(BlockUse::ContinueTarget);
303 self.program[next].block_use.push(BlockUse::Merge);
304 self.current_block = Some(next);
305
306 self.insert_phi(header, i_id, range_loop.start.ty);
308 {
309 let op = match range_loop.inclusive {
310 true => Comparison::LowerEqual,
311 false => Comparison::Lower,
312 };
313 let tmp = *self.allocator.create_local(Type::scalar(ElemType::Bool));
314 self.program[header].ops.borrow_mut().push(Instruction::new(
315 op(BinaryOperator {
316 lhs: i,
317 rhs: range_loop.end,
318 }),
319 tmp,
320 ));
321
322 *self.program[header].control_flow.borrow_mut() = ControlFlow::LoopBreak {
323 break_cond: tmp,
324 body,
325 continue_target,
326 merge: next,
327 };
328 }
329 self.program[current_block]
330 .ops
331 .borrow_mut()
332 .push(Instruction::new(
333 Arithmetic::Add(BinaryOperator { lhs: i, rhs: step }),
334 i,
335 ));
336 }
337
338 pub(crate) fn split_critical_edges(&mut self) {
339 for block in self.node_ids() {
340 let successors = self.program.edges(block);
341 let successors = successors.map(|edge| (edge.id(), edge.target()));
342 let successors: Vec<_> = successors.collect();
343
344 if successors.len() > 1 {
345 let crit = successors
346 .iter()
347 .filter(|(_, b)| self.predecessors(*b).len() > 1)
348 .collect::<Vec<_>>();
349 for (edge, successor) in crit {
350 self.program.remove_edge(*edge);
351 let new_block = self.program.add_node(BasicBlock::default());
352 self.program.add_edge(block, new_block, 0);
353 self.program.add_edge(new_block, *successor, 0);
354 self.invalidate_structure();
355 update_phi(self, *successor, block, new_block);
356 update_control_flow(self, block, *successor, new_block);
357 }
358 }
359 }
360 }
361
362 pub(crate) fn split_free(&mut self) {
365 let mut splits = 0;
366 while self.split_free_inner() {
367 splits += 1;
368 }
369 if splits > 0 {
370 self.invalidate_structure();
371 }
372 }
373
374 fn split_free_inner(&mut self) -> bool {
375 let is_free =
376 |inst: &Instruction| matches!(inst.operation, Operation::Marker(Marker::Free(_)));
377
378 for block in self.node_ids() {
379 let ops = self.block(block).ops.clone();
380 let len = ops.borrow().num_elements();
381 let idx = ops.borrow().values().position(is_free);
382 if let Some(idx) = idx {
383 if idx > 0 {
385 self.split_block_after(block, idx - 1);
386 return true;
387 }
388 if idx < len - 1 {
389 self.split_block_after(block, idx);
390 return true;
391 }
392 }
393 }
394
395 false
396 }
397
398 fn split_block_after(&mut self, block: NodeIndex, idx: usize) -> NodeIndex {
400 let successors = self.successors(block);
401 let edges: Vec<EdgeIndex> = self
402 .program
403 .edges_directed(block, Direction::Outgoing)
404 .map(|it| it.id())
405 .collect();
406 for edge in edges {
407 self.program.remove_edge(edge);
408 }
409
410 let ops = self.block(block).ops.take();
411 let before: Vec<_> = ops.values().take(idx + 1).cloned().collect();
412 let after: Vec<_> = ops.values().skip(idx + 1).cloned().collect();
413 *self.block(block).ops.borrow_mut() = StableVec::from_iter(before);
414
415 let new_block = BasicBlock::default();
416 new_block.control_flow.swap(&self.block(block).control_flow);
417 new_block.ops.borrow_mut().extend(after);
418 let new_block = self.program.graph.add_node(new_block);
419
420 self.program.add_edge(block, new_block, 0);
421 for successor in successors {
422 self.program.add_edge(new_block, successor, 0);
423 }
424 new_block
425 }
426}
427
428fn update_control_flow(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
429 let update = |id: &mut NodeIndex| {
430 if *id == from {
431 *id = to
432 }
433 };
434
435 match &mut *opt.program[block].control_flow.borrow_mut() {
436 ControlFlow::IfElse { then, or_else, .. } => {
437 update(then);
438 update(or_else);
439 }
440 ControlFlow::Switch {
441 default, branches, ..
442 } => {
443 update(default);
444
445 for branch in branches {
446 update(&mut branch.1);
447 }
448 }
449 _ => {}
450 }
451}
452
453fn update_phi(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
454 for phi in opt.program[block].phi_nodes.borrow_mut().iter_mut() {
455 for entry in phi.entries.iter_mut() {
456 if entry.block == from {
457 entry.block = to;
458 }
459 }
460 }
461}