1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{BasicBlock, BlockUse, NodeIndex, Optimizer};
6use cubecl_ir::{
7 Arithmetic, BinaryOperator, Branch, Comparison, ConstantValue, ElemType, If, IfElse,
8 Instruction, Loop, Marker, Operation, RangeLoop, Switch, Type, Variable, VariableKind,
9};
10use petgraph::{Direction, graph::EdgeIndex, visit::EdgeRef};
11use stable_vec::StableVec;
12
/// Terminator of a basic block, as recorded by the graph parser in this module.
///
/// The parser stores one of these in a block's `control_flow` cell. The actual successors of
/// the block are also tracked as graph edges; the fields here name the structured targets.
#[derive(Default, Debug, Clone)]
pub enum ControlFlow {
    /// Two-way conditional branch, produced for both `if` and `if`/`else`.
    IfElse {
        /// Condition the branch is selected on.
        cond: Variable,
        /// Block entered for the `then` arm.
        then: NodeIndex,
        /// Block entered for the `else` arm. For a plain `if` this is the block that follows
        /// the conditional.
        or_else: NodeIndex,
        /// Block where the arms rejoin. `None` when parsing an arm reported a `break`; the
        /// return block when an arm left no current block to fall through from.
        merge: Option<NodeIndex>,
    },
    /// Multi-way branch on an integer value.
    Switch {
        /// Value the cases are matched against.
        value: Variable,
        /// Block entered when no case matches.
        default: NodeIndex,
        /// `(case value, block)` pairs. Signed constants are stored as their `u32` bit
        /// pattern.
        branches: Vec<(u32, NodeIndex)>,
        /// Block where the cases rejoin, with the same `None` / return-block rules as
        /// `IfElse::merge`.
        merge: Option<NodeIndex>,
    },
    /// Unconditional loop; stored on the loop header block.
    Loop {
        /// First block of the loop body.
        body: NodeIndex,
        /// Block that branches back to the header.
        continue_target: NodeIndex,
        /// Block following the loop, reached through `break`.
        merge: NodeIndex,
    },
    /// Conditional loop, produced for range (`for`) loops; stored on the loop header block.
    LoopBreak {
        /// Condition computed in the header. For range loops parsed here this is the
        /// comparison `i < end` (or `i <= end` when inclusive), i.e. it is true while the loop
        /// should keep running. NOTE(review): confirm how backends interpret it.
        break_cond: Variable,
        /// First block of the loop body.
        body: NodeIndex,
        /// Block that branches back to the header.
        continue_target: NodeIndex,
        /// Block following the loop.
        merge: NodeIndex,
    },
    /// Function return. NOTE(review): not constructed in this file; presumably assigned to the
    /// dedicated return block elsewhere.
    Return,
    /// No explicit terminator recorded (the default). Presumably the block falls through to
    /// its single successor edge — confirm against the backends.
    #[default]
    None,
}
53
54impl Optimizer {
55 pub(crate) fn parse_control_flow(&mut self, branch: Branch) {
56 match branch {
57 Branch::If(if_) => self.parse_if(*if_),
58 Branch::IfElse(if_else) => self.parse_if_else(if_else),
59 Branch::Switch(switch) => self.parse_switch(*switch),
60 Branch::RangeLoop(range_loop) => {
61 self.parse_for_loop(*range_loop);
62 }
63 Branch::Loop(loop_) => self.parse_loop(*loop_),
64 Branch::Return => {
65 let current_block = self.current_block.take().unwrap();
66 let ret = self.ret();
67 self.program.add_edge(current_block, ret, 0);
68 }
69 Branch::Break => {
70 let current_block = self.current_block.take().unwrap();
71 let loop_break = self.loop_break.back().expect("Can't break outside loop");
72 self.program.add_edge(current_block, *loop_break, 0);
73 }
74 }
75 }
76
77 pub(crate) fn parse_if(&mut self, if_: If) {
78 let current_block = self.current_block.unwrap();
79 let then = self.program.add_node(BasicBlock::default());
80 let next = self.program.add_node(BasicBlock::default());
81 let mut merge = next;
82
83 self.program.add_edge(current_block, then, 0);
84 self.program.add_edge(current_block, next, 0);
85
86 self.current_block = Some(then);
87 let is_break = self.parse_scope(if_.scope);
88
89 if let Some(current_block) = self.current_block {
90 self.program.add_edge(current_block, next, 0);
91 } else {
92 merge = self.ret;
94 }
95
96 let merge = if is_break { None } else { Some(merge) };
97
98 *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
99 cond: if_.cond,
100 then,
101 or_else: next,
102 merge,
103 };
104 if let Some(merge) = merge {
105 self.program[merge].block_use.push(BlockUse::Merge);
106 }
107 self.current_block = Some(next);
108 }
109
110 pub(crate) fn parse_if_else(&mut self, if_else: Box<IfElse>) {
111 let current_block = self.current_block.unwrap();
112 let then = self.program.add_node(BasicBlock::default());
113 let or_else = self.program.add_node(BasicBlock::default());
114 let next = self.program.add_node(BasicBlock::default());
115 let mut merge = next;
116
117 self.program.add_edge(current_block, then, 0);
118 self.program.add_edge(current_block, or_else, 0);
119
120 self.current_block = Some(then);
121 let is_break = self.parse_scope(if_else.scope_if);
122
123 if let Some(current_block) = self.current_block {
124 self.program.add_edge(current_block, next, 0);
125 } else {
126 merge = self.ret;
128 }
129
130 self.current_block = Some(or_else);
131 let is_break = self.parse_scope(if_else.scope_else) || is_break;
132
133 if let Some(current_block) = self.current_block {
134 self.program.add_edge(current_block, next, 0);
135 } else {
136 merge = self.ret;
138 }
139
140 let merge = if is_break { None } else { Some(merge) };
141 *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
142 cond: if_else.cond,
143 then,
144 or_else,
145 merge,
146 };
147 if let Some(merge) = merge {
148 self.program[merge].block_use.push(BlockUse::Merge);
149 }
150 self.current_block = Some(next);
151 }
152
153 pub(crate) fn parse_switch(&mut self, switch: Switch) {
154 let current_block = self.current_block.unwrap();
155 let next = self.program.add_node(BasicBlock::default());
156
157 let branches = switch
158 .cases
159 .into_iter()
160 .map(|(val, case)| {
161 let case_id = self.program.add_node(BasicBlock::default());
162 self.program.add_edge(current_block, case_id, 0);
163 self.current_block = Some(case_id);
164 let is_break = self.parse_scope(case);
165 let is_ret = if let Some(current_block) = self.current_block {
166 self.program.add_edge(current_block, next, 0);
167 false
168 } else {
169 !is_break
170 };
171 let val = match val.as_const().expect("Switch value must be constant") {
172 ConstantValue::Int(val) => unsafe { transmute::<i32, u32>(val as i32) },
173 ConstantValue::UInt(val) => val as u32,
174 _ => unreachable!("Switch cases must be integer"),
175 };
176 (val, case_id, is_break, is_ret)
177 })
178 .collect::<Vec<_>>();
179
180 let is_break_branch = branches.iter().any(|it| it.2);
181 let mut is_ret = branches.iter().any(|it| it.3);
182 let branches = branches
183 .into_iter()
184 .map(|it| (it.0, it.1))
185 .collect::<Vec<_>>();
186
187 let default = self.program.add_node(BasicBlock::default());
188 self.program.add_edge(current_block, default, 0);
189 self.current_block = Some(default);
190 let is_break_def = self.parse_scope(switch.scope_default);
191
192 if let Some(current_block) = self.current_block {
193 self.program.add_edge(current_block, next, 0);
194 } else {
195 is_ret = !is_break_def;
196 }
197
198 let merge = if is_break_def || is_break_branch {
199 None
200 } else if is_ret {
201 Some(self.ret)
202 } else {
203 self.program[next].block_use.push(BlockUse::Merge);
204 Some(next)
205 };
206
207 *self.program[current_block].control_flow.borrow_mut() = ControlFlow::Switch {
208 value: switch.value,
209 default,
210 branches,
211 merge,
212 };
213
214 self.current_block = Some(next);
215 }
216
217 fn parse_loop(&mut self, loop_: Loop) {
218 let current_block = self.current_block.unwrap();
219 let header = self.program.add_node(BasicBlock::default());
220 self.program.add_edge(current_block, header, 0);
221
222 let body = self.program.add_node(BasicBlock::default());
223 let next = self.program.add_node(BasicBlock::default());
224
225 self.program.add_edge(header, body, 0);
226
227 self.loop_break.push_back(next);
228
229 self.current_block = Some(body);
230 self.parse_scope(loop_.scope);
231 let continue_target = self.program.add_node(BasicBlock::default());
232 self.program[continue_target]
233 .block_use
234 .push(BlockUse::ContinueTarget);
235
236 self.loop_break.pop_back();
237
238 if let Some(current_block) = self.current_block {
239 self.program.add_edge(current_block, continue_target, 0);
240 }
241
242 self.program.add_edge(continue_target, header, 0);
243
244 *self.program[header].control_flow.borrow_mut() = ControlFlow::Loop {
245 body,
246 continue_target,
247 merge: next,
248 };
249 self.program[next].block_use.push(BlockUse::Merge);
250 self.current_block = Some(next);
251 }
252
253 fn parse_for_loop(&mut self, range_loop: RangeLoop) {
254 let step = range_loop.step.unwrap_or(1.into());
255
256 let i_id = match range_loop.i.kind {
257 VariableKind::LocalMut { id, .. } => id,
258 _ => unreachable!(),
259 };
260 let i = range_loop.i;
261 self.program.variables.insert(i_id, i.ty);
262
263 let assign = Instruction::new(Operation::Copy(range_loop.start), i);
264 self.current_block_mut().ops.borrow_mut().push(assign);
265
266 let current_block = self.current_block.unwrap();
267 let header = self.program.add_node(BasicBlock::default());
268 self.program.add_edge(current_block, header, 0);
269
270 let body = self.program.add_node(BasicBlock::default());
271 let next = self.program.add_node(BasicBlock::default());
272
273 self.program.add_edge(header, body, 0);
274 self.program.add_edge(header, next, 0);
275
276 self.loop_break.push_back(next);
277
278 self.current_block = Some(body);
279 self.parse_scope(range_loop.scope);
280
281 self.loop_break.pop_back();
282
283 let current_block = self.current_block.expect("For loop has no loopback path");
284
285 let continue_target = if self.program[current_block]
286 .block_use
287 .contains(&BlockUse::Merge)
288 {
289 let target = self.program.add_node(BasicBlock::default());
290 self.program.add_edge(current_block, target, 0);
291 target
292 } else {
293 current_block
294 };
295
296 self.program.add_edge(continue_target, header, 0);
297
298 self.program[continue_target]
299 .block_use
300 .push(BlockUse::ContinueTarget);
301 self.program[next].block_use.push(BlockUse::Merge);
302 self.current_block = Some(next);
303
304 self.insert_phi(header, i_id, range_loop.start.ty);
306 {
307 let op = match range_loop.inclusive {
308 true => Comparison::LowerEqual,
309 false => Comparison::Lower,
310 };
311 let tmp = *self.allocator.create_local(Type::scalar(ElemType::Bool));
312 self.program[header].ops.borrow_mut().push(Instruction::new(
313 op(BinaryOperator {
314 lhs: i,
315 rhs: range_loop.end,
316 }),
317 tmp,
318 ));
319
320 *self.program[header].control_flow.borrow_mut() = ControlFlow::LoopBreak {
321 break_cond: tmp,
322 body,
323 continue_target,
324 merge: next,
325 };
326 }
327 self.program[current_block]
328 .ops
329 .borrow_mut()
330 .push(Instruction::new(
331 Arithmetic::Add(BinaryOperator { lhs: i, rhs: step }),
332 i,
333 ));
334 }
335
336 pub(crate) fn split_critical_edges(&mut self) {
337 for block in self.node_ids() {
338 let successors = self.program.edges(block);
339 let successors = successors.map(|edge| (edge.id(), edge.target()));
340 let successors: Vec<_> = successors.collect();
341
342 if successors.len() > 1 {
343 let crit = successors
344 .iter()
345 .filter(|(_, b)| self.predecessors(*b).len() > 1)
346 .collect::<Vec<_>>();
347 for (edge, successor) in crit {
348 self.program.remove_edge(*edge);
349 let new_block = self.program.add_node(BasicBlock::default());
350 self.program.add_edge(block, new_block, 0);
351 self.program.add_edge(new_block, *successor, 0);
352 self.invalidate_structure();
353 update_phi(self, *successor, block, new_block);
354 update_control_flow(self, block, *successor, new_block);
355 }
356 }
357 }
358 }
359
360 pub(crate) fn split_free(&mut self) {
363 let mut splits = 0;
364 while self.split_free_inner() {
365 splits += 1;
366 }
367 if splits > 0 {
368 self.invalidate_structure();
369 }
370 }
371
372 fn split_free_inner(&mut self) -> bool {
373 let is_free =
374 |inst: &Instruction| matches!(inst.operation, Operation::Marker(Marker::Free(_)));
375
376 for block in self.node_ids() {
377 let ops = self.block(block).ops.clone();
378 let len = ops.borrow().num_elements();
379 let idx = ops.borrow().values().position(is_free);
380 if let Some(idx) = idx {
381 if idx > 0 {
383 self.split_block_after(block, idx - 1);
384 return true;
385 }
386 if idx < len - 1 {
387 self.split_block_after(block, idx);
388 return true;
389 }
390 }
391 }
392
393 false
394 }
395
396 fn split_block_after(&mut self, block: NodeIndex, idx: usize) -> NodeIndex {
398 let successors = self.successors(block);
399 let edges: Vec<EdgeIndex> = self
400 .program
401 .edges_directed(block, Direction::Outgoing)
402 .map(|it| it.id())
403 .collect();
404 for edge in edges {
405 self.program.remove_edge(edge);
406 }
407
408 let ops = self.block(block).ops.take();
409 let before: Vec<_> = ops.values().take(idx + 1).cloned().collect();
410 let after: Vec<_> = ops.values().skip(idx + 1).cloned().collect();
411 *self.block(block).ops.borrow_mut() = StableVec::from_iter(before);
412
413 let new_block = BasicBlock::default();
414 new_block.control_flow.swap(&self.block(block).control_flow);
415 new_block.ops.borrow_mut().extend(after);
416 let new_block = self.program.graph.add_node(new_block);
417
418 self.program.add_edge(block, new_block, 0);
419 for successor in successors {
420 self.program.add_edge(new_block, successor, 0);
421 }
422 new_block
423 }
424}
425
426fn update_control_flow(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
427 let update = |id: &mut NodeIndex| {
428 if *id == from {
429 *id = to
430 }
431 };
432
433 match &mut *opt.program[block].control_flow.borrow_mut() {
434 ControlFlow::IfElse { then, or_else, .. } => {
435 update(then);
436 update(or_else);
437 }
438 ControlFlow::Switch {
439 default, branches, ..
440 } => {
441 update(default);
442
443 for branch in branches {
444 update(&mut branch.1);
445 }
446 }
447 _ => {}
448 }
449}
450
451fn update_phi(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
452 for phi in opt.program[block].phi_nodes.borrow_mut().iter_mut() {
453 for entry in phi.entries.iter_mut() {
454 if entry.block == from {
455 entry.block = to;
456 }
457 }
458 }
459}