1use std::mem::transmute;
2
3use crate::{BasicBlock, BlockUse, NodeIndex, Optimizer};
4use cubecl_core::ir::{
5 BinaryOperator, Branch, ConstantScalarValue, Elem, If, IfElse, Instruction, Item, Loop,
6 Operation, Operator, RangeLoop, Switch, Variable, VariableKind,
7};
8use petgraph::visit::EdgeRef;
9
/// Structured control-flow terminator attached to a basic block.
///
/// The `parse_*` methods on `Optimizer` store one of these in a block's
/// `control_flow` cell when they lower a structured construct, and
/// `split_critical_edges` rewrites branch targets of `IfElse`/`Switch` after
/// it inserts blocks on critical edges.
#[derive(Default, Debug, Clone)]
pub enum ControlFlow {
    /// Two-way conditional, produced for both a plain `if` and an `if`/`else`.
    IfElse {
        /// Condition taken from the source `if`.
        cond: Variable,
        /// Block holding the `if` body.
        then: NodeIndex,
        /// Block holding the `else` body or, for a plain `if`, the block that
        /// execution falls through to when the condition does not hold.
        or_else: NodeIndex,
        /// Block where the arms rejoin. `None` when parsing an arm reported a
        /// `break`; the program's return block when an arm left no live block
        /// (e.g. it ended in a return).
        merge: Option<NodeIndex>,
    },
    /// Multi-way branch on an integer value.
    Switch {
        /// Value being switched on.
        value: Variable,
        /// Block holding the default case.
        default: NodeIndex,
        /// `(case value, case block)` pairs in source order. Signed case
        /// values are stored as the `u32` bit pattern of the value cast to `i32`.
        branches: Vec<(u32, NodeIndex)>,
        /// Join block after the switch. `None` if any case or the default
        /// breaks; the return block if any of them returns.
        merge: Option<NodeIndex>,
    },
    /// Unconditional loop, stored on the loop header. The header has a
    /// single edge into `body`, so the loop is left via `break` (or a return).
    Loop {
        /// First block of the loop body.
        body: NodeIndex,
        /// Dedicated block that branches back to the header.
        continue_target: NodeIndex,
        /// Block following the loop; `break` inside the body jumps here.
        merge: NodeIndex,
    },
    /// Header of a range (`for`) loop, which tests its condition on entry and
    /// after every iteration.
    LoopBreak {
        /// Boolean computed in the header: `i < end`, or `i <= end` for an
        /// inclusive range. NOTE(review): despite the name this holds the
        /// "stay in the loop" condition — confirm codegen branches to `body`
        /// when it is true.
        break_cond: Variable,
        /// First block of the loop body.
        body: NodeIndex,
        /// Block that branches back to the header: the last body block, or a
        /// fresh block appended after it when that block is a merge block.
        continue_target: NodeIndex,
        /// Block following the loop; reached when the condition fails or via
        /// `break`.
        merge: NodeIndex,
    },
    /// Return terminator. NOTE(review): not constructed by the parsing code
    /// in this file; presumably set elsewhere.
    Return,
    /// No structured control flow; the default value of this enum.
    #[default]
    None,
}
50
impl Optimizer {
    /// Lowers a single structured `Branch` into CFG nodes and edges, starting
    /// from `self.current_block`.
    ///
    /// `Return` and `Break` terminate the current block: it is taken out of
    /// `self.current_block` (leaving `None`) and linked to the return block or
    /// to the exit block of the innermost enclosing loop, respectively.
    ///
    /// # Panics
    ///
    /// Panics on a `Break` with no enclosing loop, and (through `unwrap`) when
    /// there is no current block.
    pub(crate) fn parse_control_flow(&mut self, branch: Branch) {
        match branch {
            Branch::If(if_) => self.parse_if(*if_),
            Branch::IfElse(if_else) => self.parse_if_else(if_else),
            Branch::Switch(switch) => self.parse_switch(*switch),
            Branch::RangeLoop(range_loop) => {
                self.parse_for_loop(*range_loop);
            }
            Branch::Loop(loop_) => self.parse_loop(*loop_),
            Branch::Return => {
                let current_block = self.current_block.take().unwrap();
                let ret = self.ret();
                self.program.add_edge(current_block, ret, ());
            }
            Branch::Break => {
                let current_block = self.current_block.take().unwrap();
                // The loop parsers push their exit block here while parsing the body.
                let loop_break = self.loop_break.back().expect("Can't break outside loop");
                self.program.add_edge(current_block, *loop_break, ());
            }
        }
    }

    /// Lowers a plain `if` (no `else`).
    ///
    /// The current block becomes the selection header and branches to a new
    /// `then` block and to a new `next` block. `next` is both the false target
    /// and the block where parsing continues afterwards.
    pub(crate) fn parse_if(&mut self, if_: If) {
        let current_block = self.current_block.unwrap();
        let then = self.program.add_node(BasicBlock::default());
        let next = self.program.add_node(BasicBlock::default());
        let mut merge = next;

        self.program.add_edge(current_block, then, ());
        self.program.add_edge(current_block, next, ());

        self.current_block = Some(then);
        // NOTE(review): the returned flag is used as "the body contains a
        // `break`" — inferred from its use below; confirm in `parse_scope`.
        let is_break = self.parse_scope(if_.scope);

        if let Some(current_block) = self.current_block {
            // The body fell through: rejoin at `next`.
            self.program.add_edge(current_block, next, ());
        } else {
            // No live block after the body (it returned): the return block
            // serves as the merge.
            merge = self.ret;
        }

        // A `break` in the body means the construct gets no merge block.
        let merge = if is_break { None } else { Some(merge) };

        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
            cond: if_.cond,
            then,
            or_else: next,
            merge,
        };
        if let Some(merge) = merge {
            self.program[merge].block_use.push(BlockUse::Merge);
        }
        self.current_block = Some(next);
    }

    /// Lowers an `if`/`else`.
    ///
    /// The current block branches to new `then` and `or_else` blocks. Arms
    /// that fall through rejoin at a new `next` block, where parsing continues.
    pub(crate) fn parse_if_else(&mut self, if_else: Box<IfElse>) {
        let current_block = self.current_block.unwrap();
        let then = self.program.add_node(BasicBlock::default());
        let or_else = self.program.add_node(BasicBlock::default());
        let next = self.program.add_node(BasicBlock::default());
        let mut merge = next;

        self.program.add_edge(current_block, then, ());
        self.program.add_edge(current_block, or_else, ());

        self.current_block = Some(then);
        let is_break = self.parse_scope(if_else.scope_if);

        if let Some(current_block) = self.current_block {
            self.program.add_edge(current_block, next, ());
        } else {
            // The `if` arm left no live block (it returned).
            merge = self.ret;
        }

        self.current_block = Some(or_else);
        // `parse_scope` must stay on the left of `||` so the else arm is always
        // parsed, even when the `if` arm already broke.
        let is_break = self.parse_scope(if_else.scope_else) || is_break;

        if let Some(current_block) = self.current_block {
            self.program.add_edge(current_block, next, ());
        } else {
            // The `else` arm left no live block (it returned).
            merge = self.ret;
        }

        // A `break` in either arm means the construct gets no merge block.
        let merge = if is_break { None } else { Some(merge) };
        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
            cond: if_else.cond,
            then,
            or_else,
            merge,
        };
        if let Some(merge) = merge {
            self.program[merge].block_use.push(BlockUse::Merge);
        }
        self.current_block = Some(next);
    }

    /// Lowers a `switch`.
    ///
    /// Creates one block per case (in case order) and then a default block.
    /// All are branched to from the current block, and fall-through paths
    /// rejoin at a new `next` block, where parsing continues.
    ///
    /// # Panics
    ///
    /// Panics if a case value is not a constant, or is a constant that is
    /// neither a signed nor an unsigned integer.
    pub(crate) fn parse_switch(&mut self, switch: Switch) {
        let current_block = self.current_block.unwrap();
        let next = self.program.add_node(BasicBlock::default());

        // The closure mutates `self`; `collect` drives it eagerly, in case order.
        let branches = switch
            .cases
            .into_iter()
            .map(|(val, case)| {
                let case_id = self.program.add_node(BasicBlock::default());
                self.program.add_edge(current_block, case_id, ());
                self.current_block = Some(case_id);
                let is_break = self.parse_scope(case);
                // A case "returns" when no live block remains and it did not break.
                let is_ret = if let Some(current_block) = self.current_block {
                    self.program.add_edge(current_block, next, ());
                    false
                } else {
                    !is_break
                };
                let val = match val.as_const().expect("Switch value must be constant") {
                    ConstantScalarValue::Int(val, _) => unsafe {
                        // SAFETY: `i32` and `u32` have the same size and every bit
                        // pattern is a valid `u32`; this only reinterprets the
                        // bits of `val as i32` (same result as `as u32`).
                        transmute::<i32, u32>(val as i32)
                    },
                    ConstantScalarValue::UInt(val, _) => val as u32,
                    _ => unreachable!("Switch cases must be integer"),
                };
                (val, case_id, is_break, is_ret)
            })
            .collect::<Vec<_>>();

        let is_break_branch = branches.iter().any(|it| it.2);
        let mut is_ret = branches.iter().any(|it| it.3);
        // Keep only `(case value, case block)` for the control-flow record.
        let branches = branches
            .into_iter()
            .map(|it| (it.0, it.1))
            .collect::<Vec<_>>();

        let default = self.program.add_node(BasicBlock::default());
        self.program.add_edge(current_block, default, ());
        self.current_block = Some(default);
        let is_break_def = self.parse_scope(switch.scope_default);

        if let Some(current_block) = self.current_block {
            self.program.add_edge(current_block, next, ());
        } else {
            // Overwrites instead of OR-ing. That only differs when the default
            // breaks, and then `merge` is `None` below regardless.
            is_ret = !is_break_def;
        }

        // NOTE(review): unlike `parse_if`/`parse_if_else`, the return block is
        // not tagged `BlockUse::Merge` when used as the merge here — confirm
        // whether that asymmetry is intended.
        let merge = if is_break_def || is_break_branch {
            None
        } else if is_ret {
            Some(self.ret)
        } else {
            self.program[next].block_use.push(BlockUse::Merge);
            Some(next)
        };

        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::Switch {
            value: switch.value,
            default,
            branches,
            merge,
        };

        self.current_block = Some(next);
    }

    /// Lowers an unconditional `loop`.
    ///
    /// Builds `current -> header -> body`, parses the body with `next`
    /// registered as the `break` target, and closes the loop through a
    /// dedicated continue-target block that branches back to `header`. No
    /// edge goes from `header` to `next`; it is reached only via `break`.
    /// Parsing continues in `next`.
    fn parse_loop(&mut self, loop_: Loop) {
        let current_block = self.current_block.unwrap();
        let header = self.program.add_node(BasicBlock::default());
        self.program.add_edge(current_block, header, ());

        let body = self.program.add_node(BasicBlock::default());
        let next = self.program.add_node(BasicBlock::default());

        self.program.add_edge(header, body, ());

        // `break` inside the body targets `next`.
        self.loop_break.push_back(next);

        self.current_block = Some(body);
        self.parse_scope(loop_.scope);
        let continue_target = self.program.add_node(BasicBlock::default());
        self.program[continue_target]
            .block_use
            .push(BlockUse::ContinueTarget);

        self.loop_break.pop_back();

        // Only link the body to the continue target if control can reach the
        // end of the body.
        if let Some(current_block) = self.current_block {
            self.program.add_edge(current_block, continue_target, ());
        }

        self.program.add_edge(continue_target, header, ());

        *self.program[header].control_flow.borrow_mut() = ControlFlow::Loop {
            body,
            continue_target,
            merge: next,
        };
        self.program[next].block_use.push(BlockUse::Merge);
        self.current_block = Some(next);
    }

    /// Lowers a range (`for`) loop over a mutable local counter `i`.
    ///
    /// - Emits `i = start` in the current block.
    /// - Builds a header that evaluates `i < end` (`<=` when inclusive) and
    ///   branches to `body` or `next`.
    /// - Appends `i = i + step` to the last body block, with `step` defaulting
    ///   to 1.
    /// - Continues parsing in `next`.
    ///
    /// # Panics
    ///
    /// Panics if `i` is not a `LocalMut`, or if the body leaves no live block
    /// to loop back from ("For loop has no loopback path").
    fn parse_for_loop(&mut self, range_loop: RangeLoop) {
        let step = range_loop.step.unwrap_or(1.into());

        let i_id = match range_loop.i.kind {
            VariableKind::LocalMut { id, .. } => id,
            _ => unreachable!(),
        };
        let i = range_loop.i;
        self.program.variables.insert(i_id, i.item);

        let assign = Instruction::new(Operation::Copy(range_loop.start), i);
        self.current_block_mut().ops.borrow_mut().push(assign);

        let current_block = self.current_block.unwrap();
        let header = self.program.add_node(BasicBlock::default());
        self.program.add_edge(current_block, header, ());

        let body = self.program.add_node(BasicBlock::default());
        let next = self.program.add_node(BasicBlock::default());

        self.program.add_edge(header, body, ());
        self.program.add_edge(header, next, ());

        // `break` inside the body targets `next`.
        self.loop_break.push_back(next);

        self.current_block = Some(body);
        self.parse_scope(range_loop.scope);

        self.loop_break.pop_back();

        let current_block = self.current_block.expect("For loop has no loopback path");

        // A block that already merges an inner construct gets a separate
        // continue target appended after it instead of doubling as one.
        let continue_target = if self.program[current_block]
            .block_use
            .contains(&BlockUse::Merge)
        {
            let target = self.program.add_node(BasicBlock::default());
            self.program.add_edge(current_block, target, ());
            target
        } else {
            current_block
        };

        self.program.add_edge(continue_target, header, ());

        self.program[continue_target]
            .block_use
            .push(BlockUse::ContinueTarget);
        self.program[next].block_use.push(BlockUse::Merge);
        self.current_block = Some(next);

        // `i` is reassigned in the body, so the header needs a phi for it
        // (presumably joining the initial and incremented values).
        self.insert_phi(header, i_id, range_loop.start.item);
        {
            // Loop condition, evaluated in the header on entry and on every
            // back edge from the continue target.
            let op = match range_loop.inclusive {
                true => Operator::LowerEqual,
                false => Operator::Lower,
            };
            let tmp = *self
                .allocator
                .create_local_restricted(Item::new(Elem::Bool));
            self.program[header].ops.borrow_mut().push(Instruction::new(
                op(BinaryOperator {
                    lhs: i,
                    rhs: range_loop.end,
                }),
                tmp,
            ));

            *self.program[header].control_flow.borrow_mut() = ControlFlow::LoopBreak {
                break_cond: tmp,
                body,
                continue_target,
                merge: next,
            };
        }
        // The increment goes into the last body block, which precedes (or is)
        // the continue target on the path back to the header.
        self.program[current_block]
            .ops
            .borrow_mut()
            .push(Instruction::new(
                Operator::Add(BinaryOperator { lhs: i, rhs: step }),
                i,
            ));
    }

    /// Splits every critical edge by routing it through a new empty block.
    ///
    /// A critical edge runs from a block with more than one successor to a
    /// block with more than one predecessor. After splitting `block -> succ`
    /// into `block -> new -> succ`, the phi entries in `succ` and the
    /// branch targets recorded in `block`'s control flow are updated to name
    /// `new`.
    pub(crate) fn split_critical_edges(&mut self) {
        for block in self.node_ids() {
            let successors = self.program.edges(block);
            let successors = successors.map(|edge| (edge.id(), edge.target()));
            let successors: Vec<_> = successors.collect();

            if successors.len() > 1 {
                // Splitting keeps each successor's predecessor count unchanged
                // (one incoming edge is swapped for another), so computing the
                // critical set up front is sound.
                let crit = successors
                    .iter()
                    .filter(|(_, b)| self.predecessors(*b).len() > 1)
                    .collect::<Vec<_>>();
                for (edge, successor) in crit {
                    // NOTE(review): edge ids in `crit` are reused after earlier
                    // `remove_edge` calls; this is only sound if the graph keeps
                    // edge indices stable on removal (e.g. petgraph `StableGraph`).
                    self.program.remove_edge(*edge);
                    let new_block = self.program.add_node(BasicBlock::default());
                    self.program.add_edge(block, new_block, ());
                    self.program.add_edge(new_block, *successor, ());
                    self.invalidate_structure();
                    update_phi(self, *successor, block, new_block);
                    update_control_flow(self, block, *successor, new_block);
                }
            }
        }
    }
}
361
362fn update_control_flow(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
363 let update = |id: &mut NodeIndex| {
364 if *id == from {
365 *id = to
366 }
367 };
368
369 match &mut *opt.program[block].control_flow.borrow_mut() {
370 ControlFlow::IfElse { then, or_else, .. } => {
371 update(then);
372 update(or_else);
373 }
374 ControlFlow::Switch {
375 default, branches, ..
376 } => {
377 update(default);
378
379 for branch in branches {
380 update(&mut branch.1);
381 }
382 }
383 _ => {}
384 }
385}
386
387fn update_phi(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
388 for phi in opt.program[block].phi_nodes.borrow_mut().iter_mut() {
389 for entry in phi.entries.iter_mut() {
390 if entry.block == from {
391 entry.block = to;
392 }
393 }
394 }
395}