cubecl_opt/
control_flow.rs

1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{BasicBlock, BlockUse, NodeIndex, Optimizer};
6use cubecl_ir::{
7    Arithmetic, BinaryOperator, Branch, Comparison, ConstantScalarValue, Elem, If, IfElse,
8    Instruction, Item, Loop, Operation, RangeLoop, Switch, Variable, VariableKind,
9};
10use petgraph::visit::EdgeRef;
11
/// Control flow that terminates a block.
///
/// Every variant describes how execution leaves the block that stores it. The block targets
/// referenced here mirror edges of the program graph that are added when the variant is built.
#[derive(Default, Debug, Clone)]
pub enum ControlFlow {
    /// An if or if-else branch that should be structured if applicable.
    IfElse {
        /// Condition variable the branch is taken on.
        cond: Variable,
        /// Block that holds the contents of the `if` scope.
        then: NodeIndex,
        /// Block that holds the contents of the `else` scope. For an `if` without `else` this is
        /// the block that follows the `if`.
        or_else: NodeIndex,
        /// Block where both paths join again. The parser stores `None` when a branch scope
        /// reported a break, and the return block when a branch ended without a live block.
        merge: Option<NodeIndex>,
    },
    /// A switch branch that selects a path based on `value`.
    Switch {
        /// Variable that is switched on.
        value: Variable,
        /// Block that holds the contents of the default scope.
        default: NodeIndex,
        /// Pairs of case value and the block that holds the contents of that case. Signed case
        /// values are stored as their `u32` bit pattern.
        branches: Vec<(u32, NodeIndex)>,
        /// Block where all paths join again. The parser stores `None` when any case or the
        /// default scope reported a break.
        merge: Option<NodeIndex>,
    },
    /// A loop with a header (the block that contains this variant), a `body` and a `continue target`.
    /// `merge` is the block that gets executed as soon as the loop terminates.
    Loop {
        /// First block of the loop body.
        body: NodeIndex,
        /// Block that has the edge back to the header.
        continue_target: NodeIndex,
        /// Block that follows the loop.
        merge: NodeIndex,
    },
    /// A loop with a header (the block that contains this variant), a `body` and a `continue target`.
    /// `merge` is the block that gets executed as soon as the loop terminates. The header contains
    /// the break condition.
    LoopBreak {
        /// Boolean variable computed in the header. For range loops this is the result of
        /// comparing the loop variable against the end of the range.
        break_cond: Variable,
        /// First block of the loop body.
        body: NodeIndex,
        /// Block that has the edge back to the header.
        continue_target: NodeIndex,
        /// Block that follows the loop.
        merge: NodeIndex,
    },
    /// A return statement. This should only occur once in the program and all other returns should
    /// instead branch to this single return block.
    Return,
    /// No special control flow. The block must have exactly one edge that should be followed.
    #[default]
    None,
}
52
53impl Optimizer {
54    pub(crate) fn parse_control_flow(&mut self, branch: Branch) {
55        match branch {
56            Branch::If(if_) => self.parse_if(*if_),
57            Branch::IfElse(if_else) => self.parse_if_else(if_else),
58            Branch::Switch(switch) => self.parse_switch(*switch),
59            Branch::RangeLoop(range_loop) => {
60                self.parse_for_loop(*range_loop);
61            }
62            Branch::Loop(loop_) => self.parse_loop(*loop_),
63            Branch::Return => {
64                let current_block = self.current_block.take().unwrap();
65                let ret = self.ret();
66                self.program.add_edge(current_block, ret, 0);
67            }
68            Branch::Break => {
69                let current_block = self.current_block.take().unwrap();
70                let loop_break = self.loop_break.back().expect("Can't break outside loop");
71                self.program.add_edge(current_block, *loop_break, 0);
72            }
73        }
74    }
75
76    pub(crate) fn parse_if(&mut self, if_: If) {
77        let current_block = self.current_block.unwrap();
78        let then = self.program.add_node(BasicBlock::default());
79        let next = self.program.add_node(BasicBlock::default());
80        let mut merge = next;
81
82        self.program.add_edge(current_block, then, 0);
83        self.program.add_edge(current_block, next, 0);
84
85        self.current_block = Some(then);
86        let is_break = self.parse_scope(if_.scope);
87
88        if let Some(current_block) = self.current_block {
89            self.program.add_edge(current_block, next, 0);
90        } else {
91            // Returned
92            merge = self.ret;
93        }
94
95        let merge = if is_break { None } else { Some(merge) };
96
97        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
98            cond: if_.cond,
99            then,
100            or_else: next,
101            merge,
102        };
103        if let Some(merge) = merge {
104            self.program[merge].block_use.push(BlockUse::Merge);
105        }
106        self.current_block = Some(next);
107    }
108
109    pub(crate) fn parse_if_else(&mut self, if_else: Box<IfElse>) {
110        let current_block = self.current_block.unwrap();
111        let then = self.program.add_node(BasicBlock::default());
112        let or_else = self.program.add_node(BasicBlock::default());
113        let next = self.program.add_node(BasicBlock::default());
114        let mut merge = next;
115
116        self.program.add_edge(current_block, then, 0);
117        self.program.add_edge(current_block, or_else, 0);
118
119        self.current_block = Some(then);
120        let is_break = self.parse_scope(if_else.scope_if);
121
122        if let Some(current_block) = self.current_block {
123            self.program.add_edge(current_block, next, 0);
124        } else {
125            // Returned
126            merge = self.ret;
127        }
128
129        self.current_block = Some(or_else);
130        let is_break = self.parse_scope(if_else.scope_else) || is_break;
131
132        if let Some(current_block) = self.current_block {
133            self.program.add_edge(current_block, next, 0);
134        } else {
135            // Returned
136            merge = self.ret;
137        }
138
139        let merge = if is_break { None } else { Some(merge) };
140        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
141            cond: if_else.cond,
142            then,
143            or_else,
144            merge,
145        };
146        if let Some(merge) = merge {
147            self.program[merge].block_use.push(BlockUse::Merge);
148        }
149        self.current_block = Some(next);
150    }
151
152    pub(crate) fn parse_switch(&mut self, switch: Switch) {
153        let current_block = self.current_block.unwrap();
154        let next = self.program.add_node(BasicBlock::default());
155
156        let branches = switch
157            .cases
158            .into_iter()
159            .map(|(val, case)| {
160                let case_id = self.program.add_node(BasicBlock::default());
161                self.program.add_edge(current_block, case_id, 0);
162                self.current_block = Some(case_id);
163                let is_break = self.parse_scope(case);
164                let is_ret = if let Some(current_block) = self.current_block {
165                    self.program.add_edge(current_block, next, 0);
166                    false
167                } else {
168                    !is_break
169                };
170                let val = match val.as_const().expect("Switch value must be constant") {
171                    ConstantScalarValue::Int(val, _) => unsafe {
172                        transmute::<i32, u32>(val as i32)
173                    },
174                    ConstantScalarValue::UInt(val, _) => val as u32,
175                    _ => unreachable!("Switch cases must be integer"),
176                };
177                (val, case_id, is_break, is_ret)
178            })
179            .collect::<Vec<_>>();
180
181        let is_break_branch = branches.iter().any(|it| it.2);
182        let mut is_ret = branches.iter().any(|it| it.3);
183        let branches = branches
184            .into_iter()
185            .map(|it| (it.0, it.1))
186            .collect::<Vec<_>>();
187
188        let default = self.program.add_node(BasicBlock::default());
189        self.program.add_edge(current_block, default, 0);
190        self.current_block = Some(default);
191        let is_break_def = self.parse_scope(switch.scope_default);
192
193        if let Some(current_block) = self.current_block {
194            self.program.add_edge(current_block, next, 0);
195        } else {
196            is_ret = !is_break_def;
197        }
198
199        let merge = if is_break_def || is_break_branch {
200            None
201        } else if is_ret {
202            Some(self.ret)
203        } else {
204            self.program[next].block_use.push(BlockUse::Merge);
205            Some(next)
206        };
207
208        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::Switch {
209            value: switch.value,
210            default,
211            branches,
212            merge,
213        };
214
215        self.current_block = Some(next);
216    }
217
218    fn parse_loop(&mut self, loop_: Loop) {
219        let current_block = self.current_block.unwrap();
220        let header = self.program.add_node(BasicBlock::default());
221        self.program.add_edge(current_block, header, 0);
222
223        let body = self.program.add_node(BasicBlock::default());
224        let next = self.program.add_node(BasicBlock::default());
225
226        self.program.add_edge(header, body, 0);
227
228        self.loop_break.push_back(next);
229
230        self.current_block = Some(body);
231        self.parse_scope(loop_.scope);
232        let continue_target = self.program.add_node(BasicBlock::default());
233        self.program[continue_target]
234            .block_use
235            .push(BlockUse::ContinueTarget);
236
237        self.loop_break.pop_back();
238
239        if let Some(current_block) = self.current_block {
240            self.program.add_edge(current_block, continue_target, 0);
241        }
242
243        self.program.add_edge(continue_target, header, 0);
244
245        *self.program[header].control_flow.borrow_mut() = ControlFlow::Loop {
246            body,
247            continue_target,
248            merge: next,
249        };
250        self.program[next].block_use.push(BlockUse::Merge);
251        self.current_block = Some(next);
252    }
253
254    fn parse_for_loop(&mut self, range_loop: RangeLoop) {
255        let step = range_loop.step.unwrap_or(1.into());
256
257        let i_id = match range_loop.i.kind {
258            VariableKind::LocalMut { id, .. } => id,
259            _ => unreachable!(),
260        };
261        let i = range_loop.i;
262        self.program.variables.insert(i_id, i.item);
263
264        let assign = Instruction::new(Operation::Copy(range_loop.start), i);
265        self.current_block_mut().ops.borrow_mut().push(assign);
266
267        let current_block = self.current_block.unwrap();
268        let header = self.program.add_node(BasicBlock::default());
269        self.program.add_edge(current_block, header, 0);
270
271        let body = self.program.add_node(BasicBlock::default());
272        let next = self.program.add_node(BasicBlock::default());
273
274        self.program.add_edge(header, body, 0);
275        self.program.add_edge(header, next, 0);
276
277        self.loop_break.push_back(next);
278
279        self.current_block = Some(body);
280        self.parse_scope(range_loop.scope);
281
282        self.loop_break.pop_back();
283
284        let current_block = self.current_block.expect("For loop has no loopback path");
285
286        let continue_target = if self.program[current_block]
287            .block_use
288            .contains(&BlockUse::Merge)
289        {
290            let target = self.program.add_node(BasicBlock::default());
291            self.program.add_edge(current_block, target, 0);
292            target
293        } else {
294            current_block
295        };
296
297        self.program.add_edge(continue_target, header, 0);
298
299        self.program[continue_target]
300            .block_use
301            .push(BlockUse::ContinueTarget);
302        self.program[next].block_use.push(BlockUse::Merge);
303        self.current_block = Some(next);
304
305        // For loop constructs
306        self.insert_phi(header, i_id, range_loop.start.item);
307        {
308            let op = match range_loop.inclusive {
309                true => Comparison::LowerEqual,
310                false => Comparison::Lower,
311            };
312            let tmp = *self.allocator.create_local(Item::new(Elem::Bool));
313            self.program[header].ops.borrow_mut().push(Instruction::new(
314                op(BinaryOperator {
315                    lhs: i,
316                    rhs: range_loop.end,
317                }),
318                tmp,
319            ));
320
321            *self.program[header].control_flow.borrow_mut() = ControlFlow::LoopBreak {
322                break_cond: tmp,
323                body,
324                continue_target,
325                merge: next,
326            };
327        }
328        self.program[current_block]
329            .ops
330            .borrow_mut()
331            .push(Instruction::new(
332                Arithmetic::Add(BinaryOperator { lhs: i, rhs: step }),
333                i,
334            ));
335    }
336
337    pub(crate) fn split_critical_edges(&mut self) {
338        for block in self.node_ids() {
339            let successors = self.program.edges(block);
340            let successors = successors.map(|edge| (edge.id(), edge.target()));
341            let successors: Vec<_> = successors.collect();
342
343            if successors.len() > 1 {
344                let crit = successors
345                    .iter()
346                    .filter(|(_, b)| self.predecessors(*b).len() > 1)
347                    .collect::<Vec<_>>();
348                for (edge, successor) in crit {
349                    self.program.remove_edge(*edge);
350                    let new_block = self.program.add_node(BasicBlock::default());
351                    self.program.add_edge(block, new_block, 0);
352                    self.program.add_edge(new_block, *successor, 0);
353                    self.invalidate_structure();
354                    update_phi(self, *successor, block, new_block);
355                    update_control_flow(self, block, *successor, new_block);
356                }
357            }
358        }
359    }
360}
361
362fn update_control_flow(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
363    let update = |id: &mut NodeIndex| {
364        if *id == from {
365            *id = to
366        }
367    };
368
369    match &mut *opt.program[block].control_flow.borrow_mut() {
370        ControlFlow::IfElse { then, or_else, .. } => {
371            update(then);
372            update(or_else);
373        }
374        ControlFlow::Switch {
375            default, branches, ..
376        } => {
377            update(default);
378
379            for branch in branches {
380                update(&mut branch.1);
381            }
382        }
383        _ => {}
384    }
385}
386
387fn update_phi(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
388    for phi in opt.program[block].phi_nodes.borrow_mut().iter_mut() {
389        for entry in phi.entries.iter_mut() {
390            if entry.block == from {
391                entry.block = to;
392            }
393        }
394    }
395}