cubecl_opt/
block.rs

1use std::{cell::RefCell, rc::Rc};
2
3use cubecl_ir::{Instruction, Variable};
4use stable_vec::StableVec;
5
6use crate::{ControlFlow, Optimizer, version::PhiInstruction};
7
/// Tags describing how a [`BasicBlock`] is used by the surrounding control flow.
///
/// A block keeps these in its `block_use` list, so a single block may carry more than one tag.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BlockUse {
    /// The block serves as a continue target.
    /// NOTE(review): presumably the continue target of a loop construct — confirm against the
    /// pass that assigns this tag.
    ContinueTarget,
    /// The block serves as a merge block.
    /// NOTE(review): presumably where diverging control flow rejoins — confirm against the pass
    /// that assigns this tag.
    Merge,
}
13
14/// A basic block of instructions interrupted by control flow. Phi nodes are assumed to come before
15/// any instructions. See <https://en.wikipedia.org/wiki/Basic_block>
16#[derive(Default, Debug, Clone)]
17pub struct BasicBlock {
18    pub(crate) block_use: Vec<BlockUse>,
19    /// The phi nodes that are required to be generated at the start of this block.
20    pub phi_nodes: Rc<RefCell<Vec<PhiInstruction>>>,
21    /// A stable list of operations performed in this block.
22    pub ops: Rc<RefCell<StableVec<Instruction>>>,
23    /// The control flow that terminates this block.
24    pub control_flow: Rc<RefCell<ControlFlow>>,
25}
26
27impl Optimizer {
28    /// Visit all operations in the program with the specified read and write visitors.
29    pub fn visit_all(
30        &mut self,
31        mut visit_read: impl FnMut(&mut Self, &mut Variable) + Clone,
32        mut visit_write: impl FnMut(&mut Self, &mut Variable) + Clone,
33    ) {
34        for node in self.program.node_indices().collect::<Vec<_>>() {
35            let phi = self.program[node].phi_nodes.clone();
36            let ops = self.program[node].ops.clone();
37            let control_flow = self.program[node].control_flow.clone();
38
39            for phi in phi.borrow_mut().iter_mut() {
40                for elem in &mut phi.entries {
41                    visit_read(self, &mut elem.value);
42                }
43                visit_write(self, &mut phi.out);
44            }
45            for op in ops.borrow_mut().values_mut() {
46                self.visit_instruction(op, visit_read.clone(), visit_write.clone());
47            }
48            match &mut *control_flow.borrow_mut() {
49                ControlFlow::IfElse { cond, .. } => visit_read(self, cond),
50                ControlFlow::LoopBreak { break_cond, .. } => visit_read(self, break_cond),
51                ControlFlow::Switch { value, .. } => visit_read(self, value),
52                _ => {}
53            };
54        }
55    }
56}