cubecl_opt/
control_flow.rs

1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{BasicBlock, BlockUse, NodeIndex, Optimizer};
6use cubecl_ir::{
7    Arithmetic, BinaryOperator, Branch, Comparison, ConstantScalarValue, ElemType, If, IfElse,
8    Instruction, Loop, Operation, RangeLoop, Switch, Type, Variable, VariableKind,
9};
10use petgraph::{Direction, graph::EdgeIndex, visit::EdgeRef};
11use stable_vec::StableVec;
12
13/// Control flow that terminates a block
14#[derive(Default, Debug, Clone)]
15pub enum ControlFlow {
16    /// An if or if-else branch that should be structured if applicable.
17    IfElse {
18        cond: Variable,
19        then: NodeIndex,
20        or_else: NodeIndex,
21        merge: Option<NodeIndex>,
22    },
23    /// A switch branch that paths based on `value`
24    Switch {
25        value: Variable,
26        default: NodeIndex,
27        branches: Vec<(u32, NodeIndex)>,
28        merge: Option<NodeIndex>,
29    },
30    /// A loop with a header (the block that contains this variant), a `body` and a `continue target`.
31    /// `merge` is the block that gets executed as soon as the loop terminates.
32    Loop {
33        body: NodeIndex,
34        continue_target: NodeIndex,
35        merge: NodeIndex,
36    },
37    /// A loop with a header (the block that contains this variant), a `body` and a `continue target`.
38    /// `merge` is the block that gets executed as soon as the loop terminates. The header contains
39    /// the break condition.
40    LoopBreak {
41        break_cond: Variable,
42        body: NodeIndex,
43        continue_target: NodeIndex,
44        merge: NodeIndex,
45    },
46    /// A return statement. This should only occur once in the program and all other returns should
47    /// instead branch to this single return block.
48    Return,
49    /// No special control flow. The block must have exactly one edge that should be followed.
50    #[default]
51    None,
52}
53
54impl Optimizer {
55    pub(crate) fn parse_control_flow(&mut self, branch: Branch) {
56        match branch {
57            Branch::If(if_) => self.parse_if(*if_),
58            Branch::IfElse(if_else) => self.parse_if_else(if_else),
59            Branch::Switch(switch) => self.parse_switch(*switch),
60            Branch::RangeLoop(range_loop) => {
61                self.parse_for_loop(*range_loop);
62            }
63            Branch::Loop(loop_) => self.parse_loop(*loop_),
64            Branch::Return => {
65                let current_block = self.current_block.take().unwrap();
66                let ret = self.ret();
67                self.program.add_edge(current_block, ret, 0);
68            }
69            Branch::Break => {
70                let current_block = self.current_block.take().unwrap();
71                let loop_break = self.loop_break.back().expect("Can't break outside loop");
72                self.program.add_edge(current_block, *loop_break, 0);
73            }
74        }
75    }
76
77    pub(crate) fn parse_if(&mut self, if_: If) {
78        let current_block = self.current_block.unwrap();
79        let then = self.program.add_node(BasicBlock::default());
80        let next = self.program.add_node(BasicBlock::default());
81        let mut merge = next;
82
83        self.program.add_edge(current_block, then, 0);
84        self.program.add_edge(current_block, next, 0);
85
86        self.current_block = Some(then);
87        let is_break = self.parse_scope(if_.scope);
88
89        if let Some(current_block) = self.current_block {
90            self.program.add_edge(current_block, next, 0);
91        } else {
92            // Returned
93            merge = self.ret;
94        }
95
96        let merge = if is_break { None } else { Some(merge) };
97
98        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
99            cond: if_.cond,
100            then,
101            or_else: next,
102            merge,
103        };
104        if let Some(merge) = merge {
105            self.program[merge].block_use.push(BlockUse::Merge);
106        }
107        self.current_block = Some(next);
108    }
109
110    pub(crate) fn parse_if_else(&mut self, if_else: Box<IfElse>) {
111        let current_block = self.current_block.unwrap();
112        let then = self.program.add_node(BasicBlock::default());
113        let or_else = self.program.add_node(BasicBlock::default());
114        let next = self.program.add_node(BasicBlock::default());
115        let mut merge = next;
116
117        self.program.add_edge(current_block, then, 0);
118        self.program.add_edge(current_block, or_else, 0);
119
120        self.current_block = Some(then);
121        let is_break = self.parse_scope(if_else.scope_if);
122
123        if let Some(current_block) = self.current_block {
124            self.program.add_edge(current_block, next, 0);
125        } else {
126            // Returned
127            merge = self.ret;
128        }
129
130        self.current_block = Some(or_else);
131        let is_break = self.parse_scope(if_else.scope_else) || is_break;
132
133        if let Some(current_block) = self.current_block {
134            self.program.add_edge(current_block, next, 0);
135        } else {
136            // Returned
137            merge = self.ret;
138        }
139
140        let merge = if is_break { None } else { Some(merge) };
141        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
142            cond: if_else.cond,
143            then,
144            or_else,
145            merge,
146        };
147        if let Some(merge) = merge {
148            self.program[merge].block_use.push(BlockUse::Merge);
149        }
150        self.current_block = Some(next);
151    }
152
153    pub(crate) fn parse_switch(&mut self, switch: Switch) {
154        let current_block = self.current_block.unwrap();
155        let next = self.program.add_node(BasicBlock::default());
156
157        let branches = switch
158            .cases
159            .into_iter()
160            .map(|(val, case)| {
161                let case_id = self.program.add_node(BasicBlock::default());
162                self.program.add_edge(current_block, case_id, 0);
163                self.current_block = Some(case_id);
164                let is_break = self.parse_scope(case);
165                let is_ret = if let Some(current_block) = self.current_block {
166                    self.program.add_edge(current_block, next, 0);
167                    false
168                } else {
169                    !is_break
170                };
171                let val = match val.as_const().expect("Switch value must be constant") {
172                    ConstantScalarValue::Int(val, _) => unsafe {
173                        transmute::<i32, u32>(val as i32)
174                    },
175                    ConstantScalarValue::UInt(val, _) => val as u32,
176                    _ => unreachable!("Switch cases must be integer"),
177                };
178                (val, case_id, is_break, is_ret)
179            })
180            .collect::<Vec<_>>();
181
182        let is_break_branch = branches.iter().any(|it| it.2);
183        let mut is_ret = branches.iter().any(|it| it.3);
184        let branches = branches
185            .into_iter()
186            .map(|it| (it.0, it.1))
187            .collect::<Vec<_>>();
188
189        let default = self.program.add_node(BasicBlock::default());
190        self.program.add_edge(current_block, default, 0);
191        self.current_block = Some(default);
192        let is_break_def = self.parse_scope(switch.scope_default);
193
194        if let Some(current_block) = self.current_block {
195            self.program.add_edge(current_block, next, 0);
196        } else {
197            is_ret = !is_break_def;
198        }
199
200        let merge = if is_break_def || is_break_branch {
201            None
202        } else if is_ret {
203            Some(self.ret)
204        } else {
205            self.program[next].block_use.push(BlockUse::Merge);
206            Some(next)
207        };
208
209        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::Switch {
210            value: switch.value,
211            default,
212            branches,
213            merge,
214        };
215
216        self.current_block = Some(next);
217    }
218
219    fn parse_loop(&mut self, loop_: Loop) {
220        let current_block = self.current_block.unwrap();
221        let header = self.program.add_node(BasicBlock::default());
222        self.program.add_edge(current_block, header, 0);
223
224        let body = self.program.add_node(BasicBlock::default());
225        let next = self.program.add_node(BasicBlock::default());
226
227        self.program.add_edge(header, body, 0);
228
229        self.loop_break.push_back(next);
230
231        self.current_block = Some(body);
232        self.parse_scope(loop_.scope);
233        let continue_target = self.program.add_node(BasicBlock::default());
234        self.program[continue_target]
235            .block_use
236            .push(BlockUse::ContinueTarget);
237
238        self.loop_break.pop_back();
239
240        if let Some(current_block) = self.current_block {
241            self.program.add_edge(current_block, continue_target, 0);
242        }
243
244        self.program.add_edge(continue_target, header, 0);
245
246        *self.program[header].control_flow.borrow_mut() = ControlFlow::Loop {
247            body,
248            continue_target,
249            merge: next,
250        };
251        self.program[next].block_use.push(BlockUse::Merge);
252        self.current_block = Some(next);
253    }
254
255    fn parse_for_loop(&mut self, range_loop: RangeLoop) {
256        let step = range_loop.step.unwrap_or(1.into());
257
258        let i_id = match range_loop.i.kind {
259            VariableKind::LocalMut { id, .. } => id,
260            _ => unreachable!(),
261        };
262        let i = range_loop.i;
263        self.program.variables.insert(i_id, i.ty);
264
265        let assign = Instruction::new(Operation::Copy(range_loop.start), i);
266        self.current_block_mut().ops.borrow_mut().push(assign);
267
268        let current_block = self.current_block.unwrap();
269        let header = self.program.add_node(BasicBlock::default());
270        self.program.add_edge(current_block, header, 0);
271
272        let body = self.program.add_node(BasicBlock::default());
273        let next = self.program.add_node(BasicBlock::default());
274
275        self.program.add_edge(header, body, 0);
276        self.program.add_edge(header, next, 0);
277
278        self.loop_break.push_back(next);
279
280        self.current_block = Some(body);
281        self.parse_scope(range_loop.scope);
282
283        self.loop_break.pop_back();
284
285        let current_block = self.current_block.expect("For loop has no loopback path");
286
287        let continue_target = if self.program[current_block]
288            .block_use
289            .contains(&BlockUse::Merge)
290        {
291            let target = self.program.add_node(BasicBlock::default());
292            self.program.add_edge(current_block, target, 0);
293            target
294        } else {
295            current_block
296        };
297
298        self.program.add_edge(continue_target, header, 0);
299
300        self.program[continue_target]
301            .block_use
302            .push(BlockUse::ContinueTarget);
303        self.program[next].block_use.push(BlockUse::Merge);
304        self.current_block = Some(next);
305
306        // For loop constructs
307        self.insert_phi(header, i_id, range_loop.start.ty);
308        {
309            let op = match range_loop.inclusive {
310                true => Comparison::LowerEqual,
311                false => Comparison::Lower,
312            };
313            let tmp = *self.allocator.create_local(Type::scalar(ElemType::Bool));
314            self.program[header].ops.borrow_mut().push(Instruction::new(
315                op(BinaryOperator {
316                    lhs: i,
317                    rhs: range_loop.end,
318                }),
319                tmp,
320            ));
321
322            *self.program[header].control_flow.borrow_mut() = ControlFlow::LoopBreak {
323                break_cond: tmp,
324                body,
325                continue_target,
326                merge: next,
327            };
328        }
329        self.program[current_block]
330            .ops
331            .borrow_mut()
332            .push(Instruction::new(
333                Arithmetic::Add(BinaryOperator { lhs: i, rhs: step }),
334                i,
335            ));
336    }
337
338    pub(crate) fn split_critical_edges(&mut self) {
339        for block in self.node_ids() {
340            let successors = self.program.edges(block);
341            let successors = successors.map(|edge| (edge.id(), edge.target()));
342            let successors: Vec<_> = successors.collect();
343
344            if successors.len() > 1 {
345                let crit = successors
346                    .iter()
347                    .filter(|(_, b)| self.predecessors(*b).len() > 1)
348                    .collect::<Vec<_>>();
349                for (edge, successor) in crit {
350                    self.program.remove_edge(*edge);
351                    let new_block = self.program.add_node(BasicBlock::default());
352                    self.program.add_edge(block, new_block, 0);
353                    self.program.add_edge(new_block, *successor, 0);
354                    self.invalidate_structure();
355                    update_phi(self, *successor, block, new_block);
356                    update_control_flow(self, block, *successor, new_block);
357                }
358            }
359        }
360    }
361
362    /// Split blocks at a `free` call because we only track liveness at a block level
363    /// It's easier than doing liveness per-instruction, and free calls are rare anyways
364    pub(crate) fn split_free(&mut self) {
365        let mut splits = 0;
366        while self.split_free_inner() {
367            splits += 1;
368        }
369        if splits > 0 {
370            self.invalidate_structure();
371        }
372    }
373
374    fn split_free_inner(&mut self) -> bool {
375        let is_free = |inst: &Instruction| matches!(inst.operation, Operation::Free(_));
376
377        for block in self.node_ids() {
378            let ops = self.block(block).ops.clone();
379            let len = ops.borrow().num_elements();
380            let idx = ops.borrow().values().position(is_free);
381            if let Some(idx) = idx {
382                // Separate free into its own block. They can be merged again later.
383                if idx > 0 {
384                    self.split_block_after(block, idx - 1);
385                    return true;
386                }
387                if idx < len - 1 {
388                    self.split_block_after(block, idx);
389                    return true;
390                }
391            }
392        }
393
394        false
395    }
396
397    /// Split block after `idx` and return the new block
398    fn split_block_after(&mut self, block: NodeIndex, idx: usize) -> NodeIndex {
399        let successors = self.successors(block);
400        let edges: Vec<EdgeIndex> = self
401            .program
402            .edges_directed(block, Direction::Outgoing)
403            .map(|it| it.id())
404            .collect();
405        for edge in edges {
406            self.program.remove_edge(edge);
407        }
408
409        let ops = self.block(block).ops.take();
410        let before: Vec<_> = ops.values().take(idx + 1).cloned().collect();
411        let after: Vec<_> = ops.values().skip(idx + 1).cloned().collect();
412        *self.block(block).ops.borrow_mut() = StableVec::from_iter(before);
413
414        let new_block = BasicBlock::default();
415        new_block.control_flow.swap(&self.block(block).control_flow);
416        new_block.ops.borrow_mut().extend(after);
417        let new_block = self.program.graph.add_node(new_block);
418
419        self.program.add_edge(block, new_block, 0);
420        for successor in successors {
421            self.program.add_edge(new_block, successor, 0);
422        }
423        new_block
424    }
425}
426
427fn update_control_flow(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
428    let update = |id: &mut NodeIndex| {
429        if *id == from {
430            *id = to
431        }
432    };
433
434    match &mut *opt.program[block].control_flow.borrow_mut() {
435        ControlFlow::IfElse { then, or_else, .. } => {
436            update(then);
437            update(or_else);
438        }
439        ControlFlow::Switch {
440            default, branches, ..
441        } => {
442            update(default);
443
444            for branch in branches {
445                update(&mut branch.1);
446            }
447        }
448        _ => {}
449    }
450}
451
452fn update_phi(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
453    for phi in opt.program[block].phi_nodes.borrow_mut().iter_mut() {
454        for entry in phi.entries.iter_mut() {
455            if entry.block == from {
456                entry.block = to;
457            }
458        }
459    }
460}