cubecl_opt/
control_flow.rs

1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{BasicBlock, BlockUse, NodeIndex, Optimizer};
6use cubecl_ir::{
7    Arithmetic, BinaryOperator, Branch, Comparison, ConstantValue, ElemType, If, IfElse,
8    Instruction, Loop, Marker, Operation, RangeLoop, Switch, Type, Variable, VariableKind,
9};
10use petgraph::{Direction, graph::EdgeIndex, visit::EdgeRef};
11use stable_vec::StableVec;
12
/// Control flow that terminates a block.
///
/// Every block carries one of these alongside its outgoing CFG edges; the variant records how
/// those edges should be interpreted (and, for structured constructs, which blocks act as merge
/// and continue targets).
#[derive(Default, Debug, Clone)]
pub enum ControlFlow {
    /// An if or if-else branch that should be structured if applicable.
    ///
    /// `then` / `or_else` are the two successors. For a one-armed `if`, `or_else` is the
    /// fallthrough block. `merge` is where the arms reconverge. The parsers in this file set it
    /// to the return block when an arm returns. They leave it `None` when `parse_scope` reports
    /// a break in an arm.
    IfElse {
        cond: Variable,
        then: NodeIndex,
        or_else: NodeIndex,
        merge: Option<NodeIndex>,
    },
    /// A switch branch that paths based on `value`.
    ///
    /// `branches` maps each case value to its block. Signed case constants are stored as their
    /// 32-bit two's-complement bit pattern. Blocks whose case does not match go to `default`.
    /// `merge` follows the same convention as in `IfElse`.
    Switch {
        value: Variable,
        default: NodeIndex,
        branches: Vec<(u32, NodeIndex)>,
        merge: Option<NodeIndex>,
    },
    /// A loop with a header (the block that contains this variant), a `body` and a `continue target`.
    /// `merge` is the block that gets executed as soon as the loop terminates.
    ///
    /// The header has no conditional exit. The loop is left only through `break` edges into
    /// `merge`.
    Loop {
        body: NodeIndex,
        continue_target: NodeIndex,
        merge: NodeIndex,
    },
    /// A loop with a header (the block that contains this variant), a `body` and a `continue target`.
    /// `merge` is the block that gets executed as soon as the loop terminates. The header contains
    /// the break condition.
    ///
    /// NOTE(review): despite its name, the only producer in this file (`parse_for_loop`) stores
    /// `i < end` (or `i <= end`) in `break_cond`. That value is true while iterations remain.
    /// Confirm consumers branch to `body` when it is true.
    LoopBreak {
        break_cond: Variable,
        body: NodeIndex,
        continue_target: NodeIndex,
        merge: NodeIndex,
    },
    /// A return statement. This should only occur once in the program and all other returns should
    /// instead branch to this single return block.
    Return,
    /// No special control flow. The block must have exactly one edge that should be followed.
    #[default]
    None,
}
53
54impl Optimizer {
55    pub(crate) fn parse_control_flow(&mut self, branch: Branch) {
56        match branch {
57            Branch::If(if_) => self.parse_if(*if_),
58            Branch::IfElse(if_else) => self.parse_if_else(if_else),
59            Branch::Switch(switch) => self.parse_switch(*switch),
60            Branch::RangeLoop(range_loop) => {
61                self.parse_for_loop(*range_loop);
62            }
63            Branch::Loop(loop_) => self.parse_loop(*loop_),
64            Branch::Return => {
65                let current_block = self.current_block.take().unwrap();
66                let ret = self.ret();
67                self.program.add_edge(current_block, ret, 0);
68            }
69            Branch::Break => {
70                let current_block = self.current_block.take().unwrap();
71                let loop_break = self.loop_break.back().expect("Can't break outside loop");
72                self.program.add_edge(current_block, *loop_break, 0);
73            }
74        }
75    }
76
77    pub(crate) fn parse_if(&mut self, if_: If) {
78        let current_block = self.current_block.unwrap();
79        let then = self.program.add_node(BasicBlock::default());
80        let next = self.program.add_node(BasicBlock::default());
81        let mut merge = next;
82
83        self.program.add_edge(current_block, then, 0);
84        self.program.add_edge(current_block, next, 0);
85
86        self.current_block = Some(then);
87        let is_break = self.parse_scope(if_.scope);
88
89        if let Some(current_block) = self.current_block {
90            self.program.add_edge(current_block, next, 0);
91        } else {
92            // Returned
93            merge = self.ret;
94        }
95
96        let merge = if is_break { None } else { Some(merge) };
97
98        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
99            cond: if_.cond,
100            then,
101            or_else: next,
102            merge,
103        };
104        if let Some(merge) = merge {
105            self.program[merge].block_use.push(BlockUse::Merge);
106        }
107        self.current_block = Some(next);
108    }
109
110    pub(crate) fn parse_if_else(&mut self, if_else: Box<IfElse>) {
111        let current_block = self.current_block.unwrap();
112        let then = self.program.add_node(BasicBlock::default());
113        let or_else = self.program.add_node(BasicBlock::default());
114        let next = self.program.add_node(BasicBlock::default());
115        let mut merge = next;
116
117        self.program.add_edge(current_block, then, 0);
118        self.program.add_edge(current_block, or_else, 0);
119
120        self.current_block = Some(then);
121        let is_break = self.parse_scope(if_else.scope_if);
122
123        if let Some(current_block) = self.current_block {
124            self.program.add_edge(current_block, next, 0);
125        } else {
126            // Returned
127            merge = self.ret;
128        }
129
130        self.current_block = Some(or_else);
131        let is_break = self.parse_scope(if_else.scope_else) || is_break;
132
133        if let Some(current_block) = self.current_block {
134            self.program.add_edge(current_block, next, 0);
135        } else {
136            // Returned
137            merge = self.ret;
138        }
139
140        let merge = if is_break { None } else { Some(merge) };
141        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::IfElse {
142            cond: if_else.cond,
143            then,
144            or_else,
145            merge,
146        };
147        if let Some(merge) = merge {
148            self.program[merge].block_use.push(BlockUse::Merge);
149        }
150        self.current_block = Some(next);
151    }
152
153    pub(crate) fn parse_switch(&mut self, switch: Switch) {
154        let current_block = self.current_block.unwrap();
155        let next = self.program.add_node(BasicBlock::default());
156
157        let branches = switch
158            .cases
159            .into_iter()
160            .map(|(val, case)| {
161                let case_id = self.program.add_node(BasicBlock::default());
162                self.program.add_edge(current_block, case_id, 0);
163                self.current_block = Some(case_id);
164                let is_break = self.parse_scope(case);
165                let is_ret = if let Some(current_block) = self.current_block {
166                    self.program.add_edge(current_block, next, 0);
167                    false
168                } else {
169                    !is_break
170                };
171                let val = match val.as_const().expect("Switch value must be constant") {
172                    ConstantValue::Int(val) => unsafe { transmute::<i32, u32>(val as i32) },
173                    ConstantValue::UInt(val) => val as u32,
174                    _ => unreachable!("Switch cases must be integer"),
175                };
176                (val, case_id, is_break, is_ret)
177            })
178            .collect::<Vec<_>>();
179
180        let is_break_branch = branches.iter().any(|it| it.2);
181        let mut is_ret = branches.iter().any(|it| it.3);
182        let branches = branches
183            .into_iter()
184            .map(|it| (it.0, it.1))
185            .collect::<Vec<_>>();
186
187        let default = self.program.add_node(BasicBlock::default());
188        self.program.add_edge(current_block, default, 0);
189        self.current_block = Some(default);
190        let is_break_def = self.parse_scope(switch.scope_default);
191
192        if let Some(current_block) = self.current_block {
193            self.program.add_edge(current_block, next, 0);
194        } else {
195            is_ret = !is_break_def;
196        }
197
198        let merge = if is_break_def || is_break_branch {
199            None
200        } else if is_ret {
201            Some(self.ret)
202        } else {
203            self.program[next].block_use.push(BlockUse::Merge);
204            Some(next)
205        };
206
207        *self.program[current_block].control_flow.borrow_mut() = ControlFlow::Switch {
208            value: switch.value,
209            default,
210            branches,
211            merge,
212        };
213
214        self.current_block = Some(next);
215    }
216
217    fn parse_loop(&mut self, loop_: Loop) {
218        let current_block = self.current_block.unwrap();
219        let header = self.program.add_node(BasicBlock::default());
220        self.program.add_edge(current_block, header, 0);
221
222        let body = self.program.add_node(BasicBlock::default());
223        let next = self.program.add_node(BasicBlock::default());
224
225        self.program.add_edge(header, body, 0);
226
227        self.loop_break.push_back(next);
228
229        self.current_block = Some(body);
230        self.parse_scope(loop_.scope);
231        let continue_target = self.program.add_node(BasicBlock::default());
232        self.program[continue_target]
233            .block_use
234            .push(BlockUse::ContinueTarget);
235
236        self.loop_break.pop_back();
237
238        if let Some(current_block) = self.current_block {
239            self.program.add_edge(current_block, continue_target, 0);
240        }
241
242        self.program.add_edge(continue_target, header, 0);
243
244        *self.program[header].control_flow.borrow_mut() = ControlFlow::Loop {
245            body,
246            continue_target,
247            merge: next,
248        };
249        self.program[next].block_use.push(BlockUse::Merge);
250        self.current_block = Some(next);
251    }
252
253    fn parse_for_loop(&mut self, range_loop: RangeLoop) {
254        let step = range_loop.step.unwrap_or(1.into());
255
256        let i_id = match range_loop.i.kind {
257            VariableKind::LocalMut { id, .. } => id,
258            _ => unreachable!(),
259        };
260        let i = range_loop.i;
261        self.program.variables.insert(i_id, i.ty);
262
263        let assign = Instruction::new(Operation::Copy(range_loop.start), i);
264        self.current_block_mut().ops.borrow_mut().push(assign);
265
266        let current_block = self.current_block.unwrap();
267        let header = self.program.add_node(BasicBlock::default());
268        self.program.add_edge(current_block, header, 0);
269
270        let body = self.program.add_node(BasicBlock::default());
271        let next = self.program.add_node(BasicBlock::default());
272
273        self.program.add_edge(header, body, 0);
274        self.program.add_edge(header, next, 0);
275
276        self.loop_break.push_back(next);
277
278        self.current_block = Some(body);
279        self.parse_scope(range_loop.scope);
280
281        self.loop_break.pop_back();
282
283        let current_block = self.current_block.expect("For loop has no loopback path");
284
285        let continue_target = if self.program[current_block]
286            .block_use
287            .contains(&BlockUse::Merge)
288        {
289            let target = self.program.add_node(BasicBlock::default());
290            self.program.add_edge(current_block, target, 0);
291            target
292        } else {
293            current_block
294        };
295
296        self.program.add_edge(continue_target, header, 0);
297
298        self.program[continue_target]
299            .block_use
300            .push(BlockUse::ContinueTarget);
301        self.program[next].block_use.push(BlockUse::Merge);
302        self.current_block = Some(next);
303
304        // For loop constructs
305        self.insert_phi(header, i_id, range_loop.start.ty);
306        {
307            let op = match range_loop.inclusive {
308                true => Comparison::LowerEqual,
309                false => Comparison::Lower,
310            };
311            let tmp = *self.allocator.create_local(Type::scalar(ElemType::Bool));
312            self.program[header].ops.borrow_mut().push(Instruction::new(
313                op(BinaryOperator {
314                    lhs: i,
315                    rhs: range_loop.end,
316                }),
317                tmp,
318            ));
319
320            *self.program[header].control_flow.borrow_mut() = ControlFlow::LoopBreak {
321                break_cond: tmp,
322                body,
323                continue_target,
324                merge: next,
325            };
326        }
327        self.program[current_block]
328            .ops
329            .borrow_mut()
330            .push(Instruction::new(
331                Arithmetic::Add(BinaryOperator { lhs: i, rhs: step }),
332                i,
333            ));
334    }
335
336    pub(crate) fn split_critical_edges(&mut self) {
337        for block in self.node_ids() {
338            let successors = self.program.edges(block);
339            let successors = successors.map(|edge| (edge.id(), edge.target()));
340            let successors: Vec<_> = successors.collect();
341
342            if successors.len() > 1 {
343                let crit = successors
344                    .iter()
345                    .filter(|(_, b)| self.predecessors(*b).len() > 1)
346                    .collect::<Vec<_>>();
347                for (edge, successor) in crit {
348                    self.program.remove_edge(*edge);
349                    let new_block = self.program.add_node(BasicBlock::default());
350                    self.program.add_edge(block, new_block, 0);
351                    self.program.add_edge(new_block, *successor, 0);
352                    self.invalidate_structure();
353                    update_phi(self, *successor, block, new_block);
354                    update_control_flow(self, block, *successor, new_block);
355                }
356            }
357        }
358    }
359
360    /// Split blocks at a `free` call because we only track liveness at a block level
361    /// It's easier than doing liveness per-instruction, and free calls are rare anyways
362    pub(crate) fn split_free(&mut self) {
363        let mut splits = 0;
364        while self.split_free_inner() {
365            splits += 1;
366        }
367        if splits > 0 {
368            self.invalidate_structure();
369        }
370    }
371
372    fn split_free_inner(&mut self) -> bool {
373        let is_free =
374            |inst: &Instruction| matches!(inst.operation, Operation::Marker(Marker::Free(_)));
375
376        for block in self.node_ids() {
377            let ops = self.block(block).ops.clone();
378            let len = ops.borrow().num_elements();
379            let idx = ops.borrow().values().position(is_free);
380            if let Some(idx) = idx {
381                // Separate free into its own block. They can be merged again later.
382                if idx > 0 {
383                    self.split_block_after(block, idx - 1);
384                    return true;
385                }
386                if idx < len - 1 {
387                    self.split_block_after(block, idx);
388                    return true;
389                }
390            }
391        }
392
393        false
394    }
395
396    /// Split block after `idx` and return the new block
397    fn split_block_after(&mut self, block: NodeIndex, idx: usize) -> NodeIndex {
398        let successors = self.successors(block);
399        let edges: Vec<EdgeIndex> = self
400            .program
401            .edges_directed(block, Direction::Outgoing)
402            .map(|it| it.id())
403            .collect();
404        for edge in edges {
405            self.program.remove_edge(edge);
406        }
407
408        let ops = self.block(block).ops.take();
409        let before: Vec<_> = ops.values().take(idx + 1).cloned().collect();
410        let after: Vec<_> = ops.values().skip(idx + 1).cloned().collect();
411        *self.block(block).ops.borrow_mut() = StableVec::from_iter(before);
412
413        let new_block = BasicBlock::default();
414        new_block.control_flow.swap(&self.block(block).control_flow);
415        new_block.ops.borrow_mut().extend(after);
416        let new_block = self.program.graph.add_node(new_block);
417
418        self.program.add_edge(block, new_block, 0);
419        for successor in successors {
420            self.program.add_edge(new_block, successor, 0);
421        }
422        new_block
423    }
424}
425
426fn update_control_flow(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
427    let update = |id: &mut NodeIndex| {
428        if *id == from {
429            *id = to
430        }
431    };
432
433    match &mut *opt.program[block].control_flow.borrow_mut() {
434        ControlFlow::IfElse { then, or_else, .. } => {
435            update(then);
436            update(or_else);
437        }
438        ControlFlow::Switch {
439            default, branches, ..
440        } => {
441            update(default);
442
443            for branch in branches {
444                update(&mut branch.1);
445            }
446        }
447        _ => {}
448    }
449}
450
451fn update_phi(opt: &mut Optimizer, block: NodeIndex, from: NodeIndex, to: NodeIndex) {
452    for phi in opt.program[block].phi_nodes.borrow_mut().iter_mut() {
453        for entry in phi.entries.iter_mut() {
454            if entry.block == from {
455                entry.block = to;
456            }
457        }
458    }
459}