// cubecl_opt/control_flow.rs

1use std::mem::transmute;
2
3use crate::{BasicBlock, BlockUse, NodeIndex, Optimizer};
4use cubecl_core::ir::{
5    BinaryOperator, Branch, ConstantScalarValue, Elem, If, IfElse, Instruction, Item, Loop,
6    Operation, Operator, RangeLoop, Switch, Variable, VariableKind,
7};
8use petgraph::visit::EdgeRef;
9
10/// Control flow that terminates a block
11#[derive(Default, Debug, Clone)]
12pub enum ControlFlow {
13    /// An if or if-else branch that should be structured if applicable.
14    IfElse {
15        cond: Variable,
16        then: NodeIndex,
17        or_else: NodeIndex,
18        merge: Option<NodeIndex>,
19    },
20    /// A switch branch that paths based on `value`
21    Switch {
22        value: Variable,
23        default: NodeIndex,
24        branches: Vec<(u32, NodeIndex)>,
25        merge: Option<NodeIndex>,
26    },
27    /// A loop with a header (the block that contains this variant), a `body` and a `continue target`.
28    /// `merge` is the block that gets executed as soon as the loop terminates.
29    Loop {
30        body: NodeIndex,
31        continue_target: NodeIndex,
32        merge: NodeIndex,
33    },
34    /// A loop with a header (the block that contains this variant), a `body` and a `continue target`.
35    /// `merge` is the block that gets executed as soon as the loop terminates. The header contains
36    /// the break condition.
37    LoopBreak {
38        break_cond: Variable,
39        body: NodeIndex,
40        continue_target: NodeIndex,
41        merge: NodeIndex,
42    },
43    /// A return statement. This should only occur once in the program and all other returns should
44    /// instead branch to this single return block.
45    Return,
46    /// No special control flow. The block must have exactly one edge that should be followed.
47    #[default]
48    None,
49}
50
51impl Optimizer {
52    pub(crate) fn parse_control_flow(&mut self, branch: Branch) {
53        match branch {
54            Branch::If(if_) => self.parse_if(*if_),
55            Branch::IfElse(if_else) => self.parse_if_else(if_else),
56            Branch::Switch(switch) => self.parse_switch(*switch),
57            Branch::RangeLoop(range_loop) => {
58                self.parse_for_loop(*range_loop);
59            }
60            Branch::Loop(loop_) => self.parse_loop(*loop_),
61            Branch::Return => {
62                let current_block = self.current_block.take().unwrap();
63                let ret = self.ret();
64                self.program.add_edge(current_block, ret, ());
65            }
66            Branch::Break => {
67                let current_block = self.current_block.take().unwrap();
68                let loop_break = self.loop_break.back().expect("Can't break outside loop");
69                self.program.add_edge(current_block, *loop_break, ());
70            }
71        }
72    }
73
74    pub(crate) fn parse_if(&mut self, if_: If) {
75        let current_block = self.current_block.unwrap();
76        let then = self.program.add_node(BasicBlock::default());
77        let next = self.program.add_node(BasicBlock::default());
78        let mut merge = next;
79
80        self.program.add_edge(current_block, then, ());
81        self.program.add_edge(current_block, next, ());
82
83        self.current_block = Some(then);
84        let is_break = self.parse_scope(if_.scope);
85
86        if let Some(current_block) = self.current_block {
87            self.program.add_edge(current_block, next, ());
88        } else {
89            // Returned
90            merge = self.ret;
91        }
92
93        let merge = if is_break { None } else { Some(merge) };
94
95        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
96            cond: if_.cond,
97            then,
98            or_else: next,
99            merge,
100        };
101        if let Some(merge) = merge {
102            self.program[merge].block_use.push(BlockUse::Merge);
103        }
104        self.current_block = Some(next);
105    }
106
107    pub(crate) fn parse_if_else(&mut self, if_else: Box<IfElse>) {
108        let current_block = self.current_block.unwrap();
109        let then = self.program.add_node(BasicBlock::default());
110        let or_else = self.program.add_node(BasicBlock::default());
111        let next = self.program.add_node(BasicBlock::default());
112        let mut merge = next;
113
114        self.program.add_edge(current_block, then, ());
115        self.program.add_edge(current_block, or_else, ());
116
117        self.current_block = Some(then);
118        let is_break = self.parse_scope(if_else.scope_if);
119
120        if let Some(current_block) = self.current_block {
121            self.program.add_edge(current_block, next, ());
122        } else {
123            // Returned
124            merge = self.ret;
125        }
126
127        self.current_block = Some(or_else);
128        let is_break = self.parse_scope(if_else.scope_else) || is_break;
129
130        if let Some(current_block) = self.current_block {
131            self.program.add_edge(current_block, next, ());
132        } else {
133            // Returned
134            merge = self.ret;
135        }
136
137        let merge = if is_break { None } else { Some(merge) };
138        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
139            cond: if_else.cond,
140            then,
141            or_else,
142            merge,
143        };
144        if let Some(merge) = merge {
145            self.program[merge].block_use.push(BlockUse::Merge);
146        }
147        self.current_block = Some(next);
148    }
149
150    pub(crate) fn parse_switch(&mut self, switch: Switch) {
151        let current_block = self.current_block.unwrap();
152        let next = self.program.add_node(BasicBlock::default());
153
154        let branches = switch
155            .cases
156            .into_iter()
157            .map(|(val, case)| {
158                let case_id = self.program.add_node(BasicBlock::default());
159                self.program.add_edge(current_block, case_id, ());
160                self.current_block = Some(case_id);
161                let is_break = self.parse_scope(case);
162                let is_ret = if let Some(current_block) = self.current_block {
163                    self.program.add_edge(current_block, next, ());
164                    false
165                } else {
166                    !is_break
167                };
168                let val = match val.as_const().expect("Switch value must be constant") {
169                    ConstantScalarValue::Int(val, _) => unsafe {
170                        transmute::<i32, u32>(val as i32)
171                    },
172                    ConstantScalarValue::UInt(val, _) => val as u32,
173                    _ => unreachable!("Switch cases must be integer"),
174                };
175                (val, case_id, is_break, is_ret)
176            })
177            .collect::<Vec<_>>();
178
179        let is_break_branch = branches.iter().any(|it| it.2);
180        let mut is_ret = branches.iter().any(|it| it.3);
181        let branches = branches
182            .into_iter()
183            .map(|it| (it.0, it.1))
184            .collect::<Vec<_>>();
185
186        let default = self.program.add_node(BasicBlock::default());
187        self.program.add_edge(current_block, default, ());
188        self.current_block = Some(default);
189        let is_break_def = self.parse_scope(switch.scope_default);
190
191        if let Some(current_block) = self.current_block {
192            self.program.add_edge(current_block, next, ());
193        } else {
194            is_ret = !is_break_def;
195        }
196
197        let merge = if is_break_def || is_break_branch {
198            None
199        } else if is_ret {
200            Some(self.ret)
201        } else {
202            self.program[next].block_use.push(BlockUse::Merge);
203            Some(next)
204        };
205
206        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::Switch {
207            value: switch.value,
208            default,
209            branches,
210            merge,
211        };
212
213        self.current_block = Some(next);
214    }
215
216    fn parse_loop(&mut self, loop_: Loop) {
217        let current_block = self.current_block.unwrap();
218        let header = self.program.add_node(BasicBlock::default());
219        self.program.add_edge(current_block, header, ());
220
221        let body = self.program.add_node(BasicBlock::default());
222        let next = self.program.add_node(BasicBlock::default());
223
224        self.program.add_edge(header, body, ());
225
226        self.loop_break.push_back(next);
227
228        self.current_block = Some(body);
229        self.parse_scope(loop_.scope);
230        let continue_target = self.program.add_node(BasicBlock::default());
231        self.program[continue_target]
232            .block_use
233            .push(BlockUse::ContinueTarget);
234
235        self.loop_break.pop_back();
236
237        if let Some(current_block) = self.current_block {
238            self.program.add_edge(current_block, continue_target, ());
239        }
240
241        self.program.add_edge(continue_target, header, ());
242
243        *self.program[header].control_flow.borrow_mut() = ControlFlow::Loop {
244            body,
245            continue_target,
246            merge: next,
247        };
248        self.program[next].block_use.push(BlockUse::Merge);
249        self.current_block = Some(next);
250    }
251
252    fn parse_for_loop(&mut self, range_loop: RangeLoop) {
253        let step = range_loop.step.unwrap_or(1.into());
254
255        let i_id = match range_loop.i.kind {
256            VariableKind::LocalMut { id, .. } => id,
257            _ => unreachable!(),
258        };
259        let i = range_loop.i;
260        self.program.variables.insert(i_id, i.item);
261
262        let assign = Instruction::new(Operation::Copy(range_loop.start), i);
263        self.current_block_mut().ops.borrow_mut().push(assign);
264
265        let current_block = self.current_block.unwrap();
266        let header = self.program.add_node(BasicBlock::default());
267        self.program.add_edge(current_block, header, ());
268
269        let body = self.program.add_node(BasicBlock::default());
270        let next = self.program.add_node(BasicBlock::default());
271
272        self.program.add_edge(header, body, ());
273        self.program.add_edge(header, next, ());
274
275        self.loop_break.push_back(next);
276
277        self.current_block = Some(body);
278        self.parse_scope(range_loop.scope);
279
280        self.loop_break.pop_back();
281
282        let current_block = self.current_block.expect("For loop has no loopback path");
283
284        let continue_target = if self.program[current_block]
285            .block_use
286            .contains(&BlockUse::Merge)
287        {
288            let target = self.program.add_node(BasicBlock::default());
289            self.program.add_edge(current_block, target, ());
290            target
291        } else {
292            current_block
293        };
294
295        self.program.add_edge(continue_target, header, ());
296
297        self.program[continue_target]
298            .block_use
299            .push(BlockUse::ContinueTarget);
300        self.program[next].block_use.push(BlockUse::Merge);
301        self.current_block = Some(next);
302
303        // For loop constructs
304        self.insert_phi(header, i_id, range_loop.start.item);
305        {
306            let op = match range_loop.inclusive {
307                true => Operator::LowerEqual,
308                false => Operator::Lower,
309            };
310            let tmp = *self
311                .allocator
312                .create_local_restricted(Item::new(Elem::Bool));
313            self.program[header].ops.borrow_mut().push(Instruction::new(
314                op(BinaryOperator {
315                    lhs: i,
316                    rhs: range_loop.end,
317                }),
318                tmp,
319            ));
320
321            *self.program[header].control_flow.borrow_mut() = ControlFlow::LoopBreak {
322                break_cond: tmp,
323                body,
324                continue_target,
325                merge: next,
326            };
327        }
328        self.program[current_block]
329            .ops
330            .borrow_mut()
331            .push(Instruction::new(
332                Operator::Add(BinaryOperator { lhs: i, rhs: step }),
333                i,
334            ));
335    }
336
337    pub(crate) fn split_critical_edges(&mut self) {
338        for block in self.node_ids() {
339            let successors = self.program.edges(block);
340            let successors = successors.map(|edge| (edge.id(), edge.target()));
341            let successors: Vec<_> = successors.collect();
342
343            if successors.len() > 1 {
344                let crit = successors
345                    .iter()
346                    .filter(|(_, b)| self.predecessors(*b).len() > 1)
347                    .collect::<Vec<_>>();
348                for (edge, successor) in crit {
349                    self.program.remove_edge(*edge);
350                    let new_block = self.program.add_node(BasicBlock::default());
351                    self.program.add_edge(block, new_block, ());
352                    self.program.add_edge(new_block, *successor, ());
353                    self.invalidate_structure();
354                    update_phi(self, *successor, block, new_block);
355                    update_control_flow(self, block, *successor, new_block);
356                }
357            }
358        }
359    }
360}
361
362fn update_control_flow(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
363    let update = |id: &mut NodeIndex| {
364        if *id == from {
365            *id = to
366        }
367    };
368
369    match &mut *opt.program[block].control_flow.borrow_mut() {
370        ControlFlow::IfElse { then, or_else, .. } => {
371            update(then);
372            update(or_else);
373        }
374        ControlFlow::Switch {
375            default, branches, ..
376        } => {
377            update(default);
378
379            for branch in branches {
380                update(&mut branch.1);
381            }
382        }
383        _ => {}
384    }
385}
386
387fn update_phi(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
388    for phi in opt.program[block].phi_nodes.borrow_mut().iter_mut() {
389        for entry in phi.entries.iter_mut() {
390            if entry.block == from {
391                entry.block = to;
392            }
393        }
394    }
395}