1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{BasicBlock, BlockUse, NodeIndex, Optimizer};
6use cubecl_ir::{
7 Arithmetic, BinaryOperator, Branch, Comparison, ConstantScalarValue, Elem, If, IfElse,
8 Instruction, Item, Loop, Operation, RangeLoop, Switch, Variable, VariableKind,
9};
10use petgraph::visit::EdgeRef;
11
12#[derive(Default, Debug, Clone)]
14pub enum ControlFlow {
15 IfElse {
17 cond: Variable,
18 then: NodeIndex,
19 or_else: NodeIndex,
20 merge: Option<NodeIndex>,
21 },
22 Switch {
24 value: Variable,
25 default: NodeIndex,
26 branches: Vec<(u32, NodeIndex)>,
27 merge: Option<NodeIndex>,
28 },
29 Loop {
32 body: NodeIndex,
33 continue_target: NodeIndex,
34 merge: NodeIndex,
35 },
36 LoopBreak {
40 break_cond: Variable,
41 body: NodeIndex,
42 continue_target: NodeIndex,
43 merge: NodeIndex,
44 },
45 Return,
48 #[default]
50 None,
51}
52
53impl Optimizer {
54 pub(crate) fn parse_control_flow(&mut self, branch: Branch) {
55 match branch {
56 Branch::If(if_) => self.parse_if(*if_),
57 Branch::IfElse(if_else) => self.parse_if_else(if_else),
58 Branch::Switch(switch) => self.parse_switch(*switch),
59 Branch::RangeLoop(range_loop) => {
60 self.parse_for_loop(*range_loop);
61 }
62 Branch::Loop(loop_) => self.parse_loop(*loop_),
63 Branch::Return => {
64 let current_block = self.current_block.take().unwrap();
65 let ret = self.ret();
66 self.program.add_edge(current_block, ret, 0);
67 }
68 Branch::Break => {
69 let current_block = self.current_block.take().unwrap();
70 let loop_break = self.loop_break.back().expect("Can't break outside loop");
71 self.program.add_edge(current_block, *loop_break, 0);
72 }
73 }
74 }
75
76 pub(crate) fn parse_if(&mut self, if_: If) {
77 let current_block = self.current_block.unwrap();
78 let then = self.program.add_node(BasicBlock::default());
79 let next = self.program.add_node(BasicBlock::default());
80 let mut merge = next;
81
82 self.program.add_edge(current_block, then, 0);
83 self.program.add_edge(current_block, next, 0);
84
85 self.current_block = Some(then);
86 let is_break = self.parse_scope(if_.scope);
87
88 if let Some(current_block) = self.current_block {
89 self.program.add_edge(current_block, next, 0);
90 } else {
91 merge = self.ret;
93 }
94
95 let merge = if is_break { None } else { Some(merge) };
96
97 *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
98 cond: if_.cond,
99 then,
100 or_else: next,
101 merge,
102 };
103 if let Some(merge) = merge {
104 self.program[merge].block_use.push(BlockUse::Merge);
105 }
106 self.current_block = Some(next);
107 }
108
109 pub(crate) fn parse_if_else(&mut self, if_else: Box<IfElse>) {
110 let current_block = self.current_block.unwrap();
111 let then = self.program.add_node(BasicBlock::default());
112 let or_else = self.program.add_node(BasicBlock::default());
113 let next = self.program.add_node(BasicBlock::default());
114 let mut merge = next;
115
116 self.program.add_edge(current_block, then, 0);
117 self.program.add_edge(current_block, or_else, 0);
118
119 self.current_block = Some(then);
120 let is_break = self.parse_scope(if_else.scope_if);
121
122 if let Some(current_block) = self.current_block {
123 self.program.add_edge(current_block, next, 0);
124 } else {
125 merge = self.ret;
127 }
128
129 self.current_block = Some(or_else);
130 let is_break = self.parse_scope(if_else.scope_else) || is_break;
131
132 if let Some(current_block) = self.current_block {
133 self.program.add_edge(current_block, next, 0);
134 } else {
135 merge = self.ret;
137 }
138
139 let merge = if is_break { None } else { Some(merge) };
140 *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
141 cond: if_else.cond,
142 then,
143 or_else,
144 merge,
145 };
146 if let Some(merge) = merge {
147 self.program[merge].block_use.push(BlockUse::Merge);
148 }
149 self.current_block = Some(next);
150 }
151
152 pub(crate) fn parse_switch(&mut self, switch: Switch) {
153 let current_block = self.current_block.unwrap();
154 let next = self.program.add_node(BasicBlock::default());
155
156 let branches = switch
157 .cases
158 .into_iter()
159 .map(|(val, case)| {
160 let case_id = self.program.add_node(BasicBlock::default());
161 self.program.add_edge(current_block, case_id, 0);
162 self.current_block = Some(case_id);
163 let is_break = self.parse_scope(case);
164 let is_ret = if let Some(current_block) = self.current_block {
165 self.program.add_edge(current_block, next, 0);
166 false
167 } else {
168 !is_break
169 };
170 let val = match val.as_const().expect("Switch value must be constant") {
171 ConstantScalarValue::Int(val, _) => unsafe {
172 transmute::<i32, u32>(val as i32)
173 },
174 ConstantScalarValue::UInt(val, _) => val as u32,
175 _ => unreachable!("Switch cases must be integer"),
176 };
177 (val, case_id, is_break, is_ret)
178 })
179 .collect::<Vec<_>>();
180
181 let is_break_branch = branches.iter().any(|it| it.2);
182 let mut is_ret = branches.iter().any(|it| it.3);
183 let branches = branches
184 .into_iter()
185 .map(|it| (it.0, it.1))
186 .collect::<Vec<_>>();
187
188 let default = self.program.add_node(BasicBlock::default());
189 self.program.add_edge(current_block, default, 0);
190 self.current_block = Some(default);
191 let is_break_def = self.parse_scope(switch.scope_default);
192
193 if let Some(current_block) = self.current_block {
194 self.program.add_edge(current_block, next, 0);
195 } else {
196 is_ret = !is_break_def;
197 }
198
199 let merge = if is_break_def || is_break_branch {
200 None
201 } else if is_ret {
202 Some(self.ret)
203 } else {
204 self.program[next].block_use.push(BlockUse::Merge);
205 Some(next)
206 };
207
208 *self.program[current_block].control_flow.borrow_mut() = ControlFlow::Switch {
209 value: switch.value,
210 default,
211 branches,
212 merge,
213 };
214
215 self.current_block = Some(next);
216 }
217
218 fn parse_loop(&mut self, loop_: Loop) {
219 let current_block = self.current_block.unwrap();
220 let header = self.program.add_node(BasicBlock::default());
221 self.program.add_edge(current_block, header, 0);
222
223 let body = self.program.add_node(BasicBlock::default());
224 let next = self.program.add_node(BasicBlock::default());
225
226 self.program.add_edge(header, body, 0);
227
228 self.loop_break.push_back(next);
229
230 self.current_block = Some(body);
231 self.parse_scope(loop_.scope);
232 let continue_target = self.program.add_node(BasicBlock::default());
233 self.program[continue_target]
234 .block_use
235 .push(BlockUse::ContinueTarget);
236
237 self.loop_break.pop_back();
238
239 if let Some(current_block) = self.current_block {
240 self.program.add_edge(current_block, continue_target, 0);
241 }
242
243 self.program.add_edge(continue_target, header, 0);
244
245 *self.program[header].control_flow.borrow_mut() = ControlFlow::Loop {
246 body,
247 continue_target,
248 merge: next,
249 };
250 self.program[next].block_use.push(BlockUse::Merge);
251 self.current_block = Some(next);
252 }
253
254 fn parse_for_loop(&mut self, range_loop: RangeLoop) {
255 let step = range_loop.step.unwrap_or(1.into());
256
257 let i_id = match range_loop.i.kind {
258 VariableKind::LocalMut { id, .. } => id,
259 _ => unreachable!(),
260 };
261 let i = range_loop.i;
262 self.program.variables.insert(i_id, i.item);
263
264 let assign = Instruction::new(Operation::Copy(range_loop.start), i);
265 self.current_block_mut().ops.borrow_mut().push(assign);
266
267 let current_block = self.current_block.unwrap();
268 let header = self.program.add_node(BasicBlock::default());
269 self.program.add_edge(current_block, header, 0);
270
271 let body = self.program.add_node(BasicBlock::default());
272 let next = self.program.add_node(BasicBlock::default());
273
274 self.program.add_edge(header, body, 0);
275 self.program.add_edge(header, next, 0);
276
277 self.loop_break.push_back(next);
278
279 self.current_block = Some(body);
280 self.parse_scope(range_loop.scope);
281
282 self.loop_break.pop_back();
283
284 let current_block = self.current_block.expect("For loop has no loopback path");
285
286 let continue_target = if self.program[current_block]
287 .block_use
288 .contains(&BlockUse::Merge)
289 {
290 let target = self.program.add_node(BasicBlock::default());
291 self.program.add_edge(current_block, target, 0);
292 target
293 } else {
294 current_block
295 };
296
297 self.program.add_edge(continue_target, header, 0);
298
299 self.program[continue_target]
300 .block_use
301 .push(BlockUse::ContinueTarget);
302 self.program[next].block_use.push(BlockUse::Merge);
303 self.current_block = Some(next);
304
305 self.insert_phi(header, i_id, range_loop.start.item);
307 {
308 let op = match range_loop.inclusive {
309 true => Comparison::LowerEqual,
310 false => Comparison::Lower,
311 };
312 let tmp = *self.allocator.create_local(Item::new(Elem::Bool));
313 self.program[header].ops.borrow_mut().push(Instruction::new(
314 op(BinaryOperator {
315 lhs: i,
316 rhs: range_loop.end,
317 }),
318 tmp,
319 ));
320
321 *self.program[header].control_flow.borrow_mut() = ControlFlow::LoopBreak {
322 break_cond: tmp,
323 body,
324 continue_target,
325 merge: next,
326 };
327 }
328 self.program[current_block]
329 .ops
330 .borrow_mut()
331 .push(Instruction::new(
332 Arithmetic::Add(BinaryOperator { lhs: i, rhs: step }),
333 i,
334 ));
335 }
336
337 pub(crate) fn split_critical_edges(&mut self) {
338 for block in self.node_ids() {
339 let successors = self.program.edges(block);
340 let successors = successors.map(|edge| (edge.id(), edge.target()));
341 let successors: Vec<_> = successors.collect();
342
343 if successors.len() > 1 {
344 let crit = successors
345 .iter()
346 .filter(|(_, b)| self.predecessors(*b).len() > 1)
347 .collect::<Vec<_>>();
348 for (edge, successor) in crit {
349 self.program.remove_edge(*edge);
350 let new_block = self.program.add_node(BasicBlock::default());
351 self.program.add_edge(block, new_block, 0);
352 self.program.add_edge(new_block, *successor, 0);
353 self.invalidate_structure();
354 update_phi(self, *successor, block, new_block);
355 update_control_flow(self, block, *successor, new_block);
356 }
357 }
358 }
359 }
360}
361
362fn update_control_flow(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
363 let update = |id: &mut NodeIndex| {
364 if *id == from {
365 *id = to
366 }
367 };
368
369 match &mut *opt.program[block].control_flow.borrow_mut() {
370 ControlFlow::IfElse { then, or_else, .. } => {
371 update(then);
372 update(or_else);
373 }
374 ControlFlow::Switch {
375 default, branches, ..
376 } => {
377 update(default);
378
379 for branch in branches {
380 update(&mut branch.1);
381 }
382 }
383 _ => {}
384 }
385}
386
387fn update_phi(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
388 for phi in opt.program[block].phi_nodes.borrow_mut().iter_mut() {
389 for entry in phi.entries.iter_mut() {
390 if entry.block == from {
391 entry.block = to;
392 }
393 }
394 }
395}