Skip to main content

cubecl_opt/
control_flow.rs

1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{BasicBlock, BlockUse, NodeIndex, Optimizer};
6use cubecl_ir::{
7    Arithmetic, BinaryOperator, Branch, Comparison, ConstantValue, ElemType, If, IfElse,
8    Instruction, Loop, Marker, Operation, RangeLoop, Switch, Type, Variable, VariableKind,
9};
10use petgraph::{Direction, graph::EdgeIndex, visit::EdgeRef};
11use stable_vec::StableVec;
12
13/// Control flow that terminates a block
14#[derive(Default, Debug, Clone)]
15pub enum ControlFlow {
16    /// An if or if-else branch that should be structured if applicable.
17    IfElse {
18        cond: Variable,
19        then: NodeIndex,
20        or_else: NodeIndex,
21        merge: Option<NodeIndex>,
22    },
23    /// A switch branch that paths based on `value`
24    Switch {
25        value: Variable,
26        default: NodeIndex,
27        branches: Vec<(u32, NodeIndex)>,
28        merge: Option<NodeIndex>,
29    },
30    /// A loop with a header (the block that contains this variant), a `body` and a `continue target`.
31    /// `merge` is the block that gets executed as soon as the loop terminates.
32    Loop {
33        body: NodeIndex,
34        continue_target: NodeIndex,
35        merge: NodeIndex,
36    },
37    /// A loop with a header (the block that contains this variant), a `body` and a `continue target`.
38    /// `merge` is the block that gets executed as soon as the loop terminates. The header contains
39    /// the break condition.
40    LoopBreak {
41        break_cond: Variable,
42        body: NodeIndex,
43        continue_target: NodeIndex,
44        merge: NodeIndex,
45    },
46    /// A return statement. This should only occur once in the program and all other returns should
47    /// instead branch to this single return block.
48    Return,
49    /// Unreachable control flow
50    Unreachable,
51    /// No special control flow. The block must have exactly one edge that should be followed.
52    #[default]
53    None,
54}
55
/// Outcome of `Optimizer::parse_control_flow`, telling the caller what happened to the block it
/// was parsing into.
pub(crate) enum ControlFlowAction {
    /// Parsing can continue in `current_block`. It may be a new block that follows a structured
    /// construct.
    None,
    /// The branch terminated the current block (`unreachable`, `return` or `break`) and
    /// `current_block` has been cleared.
    AbortBlock,
}
60
61impl Optimizer {
62    pub(crate) fn parse_control_flow(&mut self, branch: Branch) -> ControlFlowAction {
63        match branch {
64            Branch::If(if_) => {
65                self.parse_if(*if_);
66                ControlFlowAction::None
67            }
68            Branch::IfElse(if_else) => {
69                self.parse_if_else(if_else);
70                ControlFlowAction::None
71            }
72            Branch::Switch(switch) => {
73                self.parse_switch(*switch);
74                ControlFlowAction::None
75            }
76            Branch::RangeLoop(range_loop) => {
77                self.parse_for_loop(*range_loop);
78                ControlFlowAction::None
79            }
80            Branch::Loop(loop_) => {
81                self.parse_loop(*loop_);
82                ControlFlowAction::None
83            }
84            Branch::Unreachable => {
85                let current_block = self.current_block.take().unwrap();
86                *self.program[current_block].control_flow.borrow_mut() = ControlFlow::Unreachable;
87                ControlFlowAction::AbortBlock
88            }
89            Branch::Return => {
90                let current_block = self.current_block.take().unwrap();
91                let ret = self.ret();
92                self.program.add_edge(current_block, ret, 0);
93                ControlFlowAction::AbortBlock
94            }
95            Branch::Break => {
96                let current_block = self.current_block.take().unwrap();
97                let loop_break = self.loop_break.back().expect("Can't break outside loop");
98                self.program.add_edge(current_block, *loop_break, 0);
99                ControlFlowAction::AbortBlock
100            }
101        }
102    }
103
104    pub(crate) fn parse_if(&mut self, if_: If) {
105        let current_block = self.current_block.unwrap();
106        let then = self.program.add_node(BasicBlock::default());
107        let next = self.program.add_node(BasicBlock::default());
108        let mut merge = next;
109
110        self.program.add_edge(current_block, then, 0);
111        self.program.add_edge(current_block, next, 0);
112
113        self.current_block = Some(then);
114        let is_break = self.parse_scope(if_.scope);
115
116        if let Some(current_block) = self.current_block {
117            self.program.add_edge(current_block, next, 0);
118        } else {
119            // Returned
120            merge = self.ret;
121        }
122
123        let merge = if is_break { None } else { Some(merge) };
124
125        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
126            cond: if_.cond,
127            then,
128            or_else: next,
129            merge,
130        };
131        if let Some(merge) = merge {
132            self.program[merge].block_use.push(BlockUse::Merge);
133        }
134        self.current_block = Some(next);
135    }
136
137    pub(crate) fn parse_if_else(&mut self, if_else: Box<IfElse>) {
138        let current_block = self.current_block.unwrap();
139        let then = self.program.add_node(BasicBlock::default());
140        let or_else = self.program.add_node(BasicBlock::default());
141        let next = self.program.add_node(BasicBlock::default());
142        let mut merge = next;
143
144        self.program.add_edge(current_block, then, 0);
145        self.program.add_edge(current_block, or_else, 0);
146
147        self.current_block = Some(then);
148        let is_break = self.parse_scope(if_else.scope_if);
149
150        if let Some(current_block) = self.current_block {
151            self.program.add_edge(current_block, next, 0);
152        } else {
153            // Returned
154            merge = self.ret;
155        }
156
157        self.current_block = Some(or_else);
158        let is_break = self.parse_scope(if_else.scope_else) || is_break;
159
160        if let Some(current_block) = self.current_block {
161            self.program.add_edge(current_block, next, 0);
162        } else {
163            // Returned
164            merge = self.ret;
165        }
166
167        let merge = if is_break { None } else { Some(merge) };
168        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
169            cond: if_else.cond,
170            then,
171            or_else,
172            merge,
173        };
174        if let Some(merge) = merge {
175            self.program[merge].block_use.push(BlockUse::Merge);
176        }
177        self.current_block = Some(next);
178    }
179
180    pub(crate) fn parse_switch(&mut self, switch: Switch) {
181        let current_block = self.current_block.unwrap();
182        let next = self.program.add_node(BasicBlock::default());
183
184        let branches = switch
185            .cases
186            .into_iter()
187            .map(|(val, case)| {
188                let case_id = self.program.add_node(BasicBlock::default());
189                self.program.add_edge(current_block, case_id, 0);
190                self.current_block = Some(case_id);
191                let is_break = self.parse_scope(case);
192                let is_ret = if let Some(current_block) = self.current_block {
193                    self.program.add_edge(current_block, next, 0);
194                    false
195                } else {
196                    !is_break
197                };
198                let val = match val.as_const().expect("Switch value must be constant") {
199                    ConstantValue::Int(val) => unsafe { transmute::<i32, u32>(val as i32) },
200                    ConstantValue::UInt(val) => val as u32,
201                    _ => unreachable!("Switch cases must be integer"),
202                };
203                (val, case_id, is_break, is_ret)
204            })
205            .collect::<Vec<_>>();
206
207        let is_break_branch = branches.iter().any(|it| it.2);
208        let mut is_ret = branches.iter().any(|it| it.3);
209        let branches = branches
210            .into_iter()
211            .map(|it| (it.0, it.1))
212            .collect::<Vec<_>>();
213
214        let default = self.program.add_node(BasicBlock::default());
215        self.program.add_edge(current_block, default, 0);
216        self.current_block = Some(default);
217        let is_break_def = self.parse_scope(switch.scope_default);
218
219        if let Some(current_block) = self.current_block {
220            self.program.add_edge(current_block, next, 0);
221        } else {
222            is_ret = !is_break_def;
223        }
224
225        let merge = if is_break_def || is_break_branch {
226            None
227        } else if is_ret {
228            Some(self.ret)
229        } else {
230            self.program[next].block_use.push(BlockUse::Merge);
231            Some(next)
232        };
233
234        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::Switch {
235            value: switch.value,
236            default,
237            branches,
238            merge,
239        };
240
241        self.current_block = Some(next);
242    }
243
244    fn parse_loop(&mut self, loop_: Loop) {
245        let current_block = self.current_block.unwrap();
246        let header = self.program.add_node(BasicBlock::default());
247        self.program.add_edge(current_block, header, 0);
248
249        let body = self.program.add_node(BasicBlock::default());
250        let next = self.program.add_node(BasicBlock::default());
251
252        self.program.add_edge(header, body, 0);
253
254        self.loop_break.push_back(next);
255
256        self.current_block = Some(body);
257        self.parse_scope(loop_.scope);
258        let continue_target = self.program.add_node(BasicBlock::default());
259        self.program[continue_target]
260            .block_use
261            .push(BlockUse::ContinueTarget);
262
263        self.loop_break.pop_back();
264
265        if let Some(current_block) = self.current_block {
266            self.program.add_edge(current_block, continue_target, 0);
267        }
268
269        self.program.add_edge(continue_target, header, 0);
270
271        *self.program[header].control_flow.borrow_mut() = ControlFlow::Loop {
272            body,
273            continue_target,
274            merge: next,
275        };
276        self.program[next].block_use.push(BlockUse::Merge);
277        self.current_block = Some(next);
278    }
279
280    fn parse_for_loop(&mut self, range_loop: RangeLoop) {
281        let step = range_loop.step.unwrap_or(1.into());
282
283        let i_id = match range_loop.i.kind {
284            VariableKind::LocalMut { id, .. } => id,
285            _ => unreachable!(),
286        };
287        let i = range_loop.i;
288        self.program.variables.insert(i_id, i.ty);
289
290        let assign = Instruction::new(Operation::Copy(range_loop.start), i);
291        self.current_block_mut().ops.borrow_mut().push(assign);
292
293        let current_block = self.current_block.unwrap();
294        let header = self.program.add_node(BasicBlock::default());
295        self.program.add_edge(current_block, header, 0);
296
297        let body = self.program.add_node(BasicBlock::default());
298        let next = self.program.add_node(BasicBlock::default());
299
300        self.program.add_edge(header, body, 0);
301        self.program.add_edge(header, next, 0);
302
303        self.loop_break.push_back(next);
304
305        self.current_block = Some(body);
306        self.parse_scope(range_loop.scope);
307
308        self.loop_break.pop_back();
309
310        let current_block = self.current_block.expect("For loop has no loopback path");
311
312        let continue_target = if self.program[current_block]
313            .block_use
314            .contains(&BlockUse::Merge)
315        {
316            let target = self.program.add_node(BasicBlock::default());
317            self.program.add_edge(current_block, target, 0);
318            target
319        } else {
320            current_block
321        };
322
323        self.program.add_edge(continue_target, header, 0);
324
325        self.program[continue_target]
326            .block_use
327            .push(BlockUse::ContinueTarget);
328        self.program[next].block_use.push(BlockUse::Merge);
329        self.current_block = Some(next);
330
331        // For loop constructs
332        self.insert_phi(header, i_id, range_loop.start.ty);
333        {
334            let op = match range_loop.inclusive {
335                true => Comparison::LowerEqual,
336                false => Comparison::Lower,
337            };
338            let tmp = *self.allocator.create_local(Type::scalar(ElemType::Bool));
339            self.program[header].ops.borrow_mut().push(Instruction::new(
340                op(BinaryOperator {
341                    lhs: i,
342                    rhs: range_loop.end,
343                }),
344                tmp,
345            ));
346
347            *self.program[header].control_flow.borrow_mut() = ControlFlow::LoopBreak {
348                break_cond: tmp,
349                body,
350                continue_target,
351                merge: next,
352            };
353        }
354        self.program[current_block]
355            .ops
356            .borrow_mut()
357            .push(Instruction::new(
358                Arithmetic::Add(BinaryOperator { lhs: i, rhs: step }),
359                i,
360            ));
361    }
362
363    pub(crate) fn split_critical_edges(&mut self) {
364        for block in self.node_ids() {
365            let successors = self.program.edges(block);
366            let successors = successors.map(|edge| (edge.id(), edge.target()));
367            let successors: Vec<_> = successors.collect();
368
369            if successors.len() > 1 {
370                let crit = successors
371                    .iter()
372                    .filter(|(_, b)| self.predecessors(*b).len() > 1)
373                    .collect::<Vec<_>>();
374                for (edge, successor) in crit {
375                    self.program.remove_edge(*edge);
376                    let new_block = self.program.add_node(BasicBlock::default());
377                    self.program.add_edge(block, new_block, 0);
378                    self.program.add_edge(new_block, *successor, 0);
379                    self.invalidate_structure();
380                    update_phi(self, *successor, block, new_block);
381                    update_control_flow(self, block, *successor, new_block);
382                }
383            }
384        }
385    }
386
387    /// Split blocks at a `free` call because we only track liveness at a block level
388    /// It's easier than doing liveness per-instruction, and free calls are rare anyways
389    pub(crate) fn split_free(&mut self) {
390        let mut splits = 0;
391        while self.split_free_inner() {
392            splits += 1;
393        }
394        if splits > 0 {
395            self.invalidate_structure();
396        }
397    }
398
399    fn split_free_inner(&mut self) -> bool {
400        let is_free =
401            |inst: &Instruction| matches!(inst.operation, Operation::Marker(Marker::Free(_)));
402
403        for block in self.node_ids() {
404            let ops = self.block(block).ops.clone();
405            let len = ops.borrow().num_elements();
406            let idx = ops.borrow().values().position(is_free);
407            if let Some(idx) = idx {
408                // Separate free into its own block. They can be merged again later.
409                if idx > 0 {
410                    self.split_block_after(block, idx - 1);
411                    return true;
412                }
413                if idx < len - 1 {
414                    self.split_block_after(block, idx);
415                    return true;
416                }
417            }
418        }
419
420        false
421    }
422
423    /// Split block after `idx` and return the new block
424    fn split_block_after(&mut self, block: NodeIndex, idx: usize) -> NodeIndex {
425        let successors = self.successors(block);
426        let edges: Vec<EdgeIndex> = self
427            .program
428            .edges_directed(block, Direction::Outgoing)
429            .map(|it| it.id())
430            .collect();
431        for edge in edges {
432            self.program.remove_edge(edge);
433        }
434
435        let ops = self.block(block).ops.take();
436        let before: Vec<_> = ops.values().take(idx + 1).cloned().collect();
437        let after: Vec<_> = ops.values().skip(idx + 1).cloned().collect();
438        *self.block(block).ops.borrow_mut() = StableVec::from_iter(before);
439
440        let new_block = BasicBlock::default();
441        new_block.control_flow.swap(&self.block(block).control_flow);
442        new_block.ops.borrow_mut().extend(after);
443        let new_block = self.program.graph.add_node(new_block);
444
445        self.program.add_edge(block, new_block, 0);
446        for successor in successors {
447            self.program.add_edge(new_block, successor, 0);
448        }
449        new_block
450    }
451}
452
453fn update_control_flow(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
454    let update = |id: &mut NodeIndex| {
455        if *id == from {
456            *id = to
457        }
458    };
459
460    match &mut *opt.program[block].control_flow.borrow_mut() {
461        ControlFlow::IfElse { then, or_else, .. } => {
462            update(then);
463            update(or_else);
464        }
465        ControlFlow::Switch {
466            default, branches, ..
467        } => {
468            update(default);
469
470            for branch in branches {
471                update(&mut branch.1);
472            }
473        }
474        ControlFlow::Loop {
475            body,
476            continue_target,
477            merge,
478        } => {
479            update(body);
480            update(continue_target);
481            update(merge);
482        }
483        ControlFlow::LoopBreak {
484            body,
485            continue_target,
486            merge,
487            ..
488        } => {
489            update(body);
490            update(continue_target);
491            update(merge);
492        }
493        _ => {}
494    }
495}
496
497fn update_phi(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
498    for phi in opt.program[block].phi_nodes.borrow_mut().iter_mut() {
499        for entry in phi.entries.iter_mut() {
500            if entry.block == from {
501                entry.block = to;
502            }
503        }
504    }
505}