Skip to main content

cubecl_opt/
version.rs

1use std::{
2    collections::{HashMap, HashSet},
3    mem::take,
4};
5
6use cubecl_ir::{Id, Instruction, Type, Variable, VariableKind};
7use petgraph::visit::EdgeRef;
8
9use crate::{ControlFlow, EdgeIndex, NodeIndex};
10
11use super::Optimizer;
12
/// The state required by the SSA transform
///
/// `versions` is owned, so every path of the traversal can carry its own copy of the current
/// versions. The remaining fields are mutable borrows shared by the whole traversal.
#[derive(Debug)]
pub struct SsaState<'a> {
    /// The version of each variable that is current at this point of the traversal.
    versions: HashMap<Id, u16>,
    /// Blocks recorded as visited; a successor found in this set is not versioned again.
    visited_blocks: &'a mut HashSet<NodeIndex>,
    /// CFG edges already traversed; phi entries are only filled in on an edge's first traversal.
    visited_edges: &'a mut HashSet<EdgeIndex>,
    /// Highest version handed out so far for each variable. New versions are allocated from
    /// here, so versions created on different paths never collide.
    max_versions: &'a mut HashMap<Id, u16>,
}
21
/// An entry in the phi instruction. Contains the variable that should be used when coming from
/// `block`.
#[derive(Debug, Clone, PartialEq)]
pub struct PhiEntry {
    /// The predecessor block this entry applies to.
    pub block: NodeIndex,
    /// The value the phi takes when control arrives from `block`.
    pub value: Variable,
}
29
/// A phi node that picks its value based on the `BasicBlock` that came immediately before.
/// For more information, see <https://en.wikipedia.org/wiki/Static_single-assignment_form>
///
/// # Example
/// ```ignore
/// if cond {
///     result = "heads";
/// } else {
///     result = "tails";
/// }
/// ```
/// would translate to the following SSA graph:
/// ```ignore
/// bb1: {
///     branch if cond { bb2 } else { bb3 };
/// }
///
/// bb2: {
///     let result.v1 = "heads";
///     branch bb4;
/// }
///
/// bb3: {
///     let result.v2 = "tails";
///     branch bb4;
/// }
///
/// bb4: {
///     let result.v3 = phi [bb2: result.v1] [bb3: result.v2];
/// }
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct PhiInstruction {
    /// The out variable for the phi instruction (`result.v3` in the example above)
    pub out: Variable,
    /// The set of `block`-`value` pairs for the phi instruction, one per entry shown in the
    /// brackets of the example above
    pub entries: Vec<PhiEntry>,
}
68
69impl Optimizer {
70    /// Version all variables in the program so they are each assigned to exactly once.
71    pub(crate) fn version_program(&mut self) {
72        let versions: HashMap<_, _> = self.program.variables.keys().map(|key| (*key, 0)).collect();
73        let mut visited_blocks = HashSet::new();
74        let mut visited_edges = HashSet::new();
75        let mut max_versions = versions.clone();
76        let initial_state = SsaState {
77            versions,
78            visited_blocks: &mut visited_blocks,
79            visited_edges: &mut visited_edges,
80            max_versions: &mut max_versions,
81        };
82        self.version_block(self.entry(), initial_state);
83    }
84
85    fn version_block(&mut self, block: NodeIndex, mut state: SsaState<'_>) {
86        self.version_block_ops(block, &mut state);
87
88        let edges: Vec<_> = self
89            .program
90            .edges(block)
91            .map(|it| (it.id(), it.target()))
92            .collect();
93        let state = &mut state;
94        for (edge_id, target) in edges {
95            let edge_visited = state.visited_edges.contains(&edge_id);
96            state.visited_edges.insert(edge_id);
97            let block_visited = state.visited_blocks.contains(&target);
98            state.visited_blocks.insert(block);
99
100            let new_state = SsaState {
101                versions: state.versions.clone(),
102                visited_blocks: state.visited_blocks,
103                visited_edges: state.visited_edges,
104                max_versions: state.max_versions,
105            };
106
107            if !edge_visited {
108                self.version_phi(target, block, &new_state);
109            }
110            if !block_visited {
111                self.version_block(target, new_state);
112            }
113        }
114    }
115
116    /// Version the phi entry for this edge
117    fn version_phi(&mut self, target: NodeIndex, source: NodeIndex, state: &SsaState<'_>) {
118        let phi = self.program[target].phi_nodes.clone();
119        for node in phi.borrow_mut().iter_mut() {
120            let entry = node
121                .entries
122                .iter_mut()
123                .find(|it| it.block == source)
124                .unwrap();
125            if let Some((id, item, _)) = as_versioned(entry.value)
126                && self.program.variables.contains_key(&id)
127            {
128                let version = state.versions[&id];
129                entry.value = Variable::new(VariableKind::Versioned { id, version }, item);
130            }
131        }
132    }
133
134    /// Version the operations for this block
135    fn version_block_ops(&mut self, block: NodeIndex, state: &mut SsaState<'_>) {
136        for phi in self.program[block].phi_nodes.borrow_mut().iter_mut() {
137            if let Some((id, item, _)) = as_versioned(phi.out)
138                && self.program.variables.contains_key(&id)
139            {
140                let version = state.versions.get_mut(&id).unwrap();
141                let max_version = state.max_versions.get_mut(&id).unwrap();
142                *max_version += 1;
143                *version = *max_version;
144                phi.out = Variable::new(
145                    VariableKind::Versioned {
146                        id,
147                        version: *version,
148                    },
149                    item,
150                );
151            }
152        }
153
154        let mut ops = take(&mut *self.program[block].ops.borrow_mut());
155        for operation in ops.values_mut() {
156            self.version_reads(operation, state);
157            self.version_writes(operation, state);
158        }
159        *self.program[block].ops.borrow_mut() = ops;
160        match &mut *self.program[block].control_flow.borrow_mut() {
161            ControlFlow::IfElse { cond, .. } => self.version_read(cond, state),
162            ControlFlow::LoopBreak { break_cond, .. } => self.version_read(break_cond, state),
163            ControlFlow::Switch { value, .. } => self.version_read(value, state),
164            ControlFlow::Loop { .. } => {}
165            ControlFlow::Return | ControlFlow::Unreachable | ControlFlow::None => {}
166        }
167    }
168
169    fn version_reads(&mut self, op: &mut Instruction, state: &mut SsaState<'_>) {
170        self.visit_operation(&mut op.operation, &mut op.out, |opt, var| {
171            opt.version_read(var, state)
172        });
173    }
174
175    fn version_writes(&mut self, op: &mut Instruction, state: &mut SsaState<'_>) {
176        self.visit_out(&mut op.out, |_, var| match var.kind {
177            VariableKind::LocalMut { id } | VariableKind::Versioned { id, .. } => {
178                if let Some(version) = state.versions.get_mut(&id) {
179                    let max_version = state.max_versions.get_mut(&id).unwrap();
180                    *max_version += 1;
181                    *version = *max_version;
182                    *var = Variable::new(
183                        VariableKind::Versioned {
184                            id,
185                            version: *version,
186                        },
187                        var.ty,
188                    )
189                }
190            }
191            _ => {}
192        });
193    }
194
195    fn version_read(&self, var: &mut Variable, state: &mut SsaState<'_>) {
196        match var.kind {
197            VariableKind::LocalMut { id } | VariableKind::Versioned { id, .. } => {
198                if self.program.variables.contains_key(&id)
199                    && let Some(version) = state.versions.get(&id)
200                {
201                    *var = Variable::new(
202                        VariableKind::Versioned {
203                            id,
204                            version: *version,
205                        },
206                        var.ty,
207                    )
208                }
209            }
210            _ => {}
211        }
212    }
213}
214
215fn as_versioned(var: Variable) -> Option<(Id, Type, u16)> {
216    match var.kind {
217        VariableKind::Versioned { id, version } => Some((id, var.ty, version)),
218        _ => None,
219    }
220}