cubecl_opt/
version.rs

1use std::{
2    collections::{HashMap, HashSet},
3    mem::take,
4};
5
6use cubecl_ir::{Id, Instruction, Type, Variable, VariableKind};
7use petgraph::visit::EdgeRef;
8
9use crate::{ControlFlow, EdgeIndex, NodeIndex};
10
11use super::Optimizer;
12
/// The state required by the SSA transform.
///
/// `versions` is owned: every successor that gets visited receives its own copy, so it only
/// reflects writes made along the path currently being walked. The remaining fields are
/// mutable borrows that are handed down unchanged, so they are shared by the whole traversal.
#[derive(Debug)]
pub struct SsaState<'a> {
    /// The version of each variable that is current on the path being walked.
    versions: HashMap<Id, u16>,
    /// Blocks recorded as visited; a successor found in this set is not versioned again.
    visited_blocks: &'a mut HashSet<NodeIndex>,
    /// Edges whose phi entries in the target block have already been versioned.
    visited_edges: &'a mut HashSet<EdgeIndex>,
    /// The highest version handed out so far for each variable, across all paths. Because this
    /// is shared, two paths never hand out the same version of a variable.
    max_versions: &'a mut HashMap<Id, u16>,
}
21
/// An entry in the phi instruction. Contains the variable that should be used when control
/// arrives from `block`.
#[derive(Debug, Clone, PartialEq)]
pub struct PhiEntry {
    /// The predecessor block this entry applies to.
    pub block: NodeIndex,
    /// The value the phi node takes when control arrives from `block`.
    pub value: Variable,
}
29
/// A phi node that picks its value based on the `BasicBlock` that came immediately before.
/// For more information, see <https://en.wikipedia.org/wiki/Static_single-assignment_form>
///
/// # Example
/// ```ignore
/// if cond {
///     result = "heads";
/// } else {
///     result = "tails";
/// }
/// ```
/// would translate to the following SSA graph:
/// ```ignore
/// bb1: {
///     branch if cond { bb2 } else { bb3 };
/// }
///
/// bb2: {
///     let result.v1 = "heads";
///     branch bb4;
/// }
///
/// bb3: {
///     let result.v2 = "tails";
///     branch bb4;
/// }
///
/// bb4: {
///     let result.v3 = phi [bb2: result.v1] [bb3: result.v2];
/// }
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct PhiInstruction {
    /// The out variable for the phi instruction (the value it defines).
    pub out: Variable,
    /// The `block`-`value` pairs for the phi instruction. The versioning pass looks up the
    /// entry whose `block` matches each incoming edge's source, so an entry is expected for
    /// every predecessor that can reach this block.
    pub entries: Vec<PhiEntry>,
}
68
69impl Optimizer {
70    /// Version all variables in the program so they are each assigned to exactly once.
71    pub(crate) fn version_program(&mut self) {
72        let versions: HashMap<_, _> = self.program.variables.keys().map(|key| (*key, 0)).collect();
73        let mut visited_blocks = HashSet::new();
74        let mut visited_edges = HashSet::new();
75        let mut max_versions = versions.clone();
76        let initial_state = SsaState {
77            versions,
78            visited_blocks: &mut visited_blocks,
79            visited_edges: &mut visited_edges,
80            max_versions: &mut max_versions,
81        };
82        self.version_block(self.entry(), initial_state);
83    }
84
85    fn version_block(&mut self, block: NodeIndex, mut state: SsaState<'_>) {
86        self.version_block_ops(block, &mut state);
87
88        let edges: Vec<_> = self
89            .program
90            .edges(block)
91            .map(|it| (it.id(), it.target()))
92            .collect();
93        let state = &mut state;
94        for (edge_id, target) in edges {
95            let edge_visited = state.visited_edges.contains(&edge_id);
96            state.visited_edges.insert(edge_id);
97            let block_visited = state.visited_blocks.contains(&target);
98            state.visited_blocks.insert(block);
99
100            let new_state = SsaState {
101                versions: state.versions.clone(),
102                visited_blocks: state.visited_blocks,
103                visited_edges: state.visited_edges,
104                max_versions: state.max_versions,
105            };
106
107            if !edge_visited {
108                self.version_phi(target, block, &new_state);
109            }
110            if !block_visited {
111                self.version_block(target, new_state);
112            }
113        }
114    }
115
116    /// Version the phi entry for this edge
117    fn version_phi(&mut self, target: NodeIndex, source: NodeIndex, state: &SsaState<'_>) {
118        let phi = self.program[target].phi_nodes.clone();
119        for node in phi.borrow_mut().iter_mut() {
120            let entry = node
121                .entries
122                .iter_mut()
123                .find(|it| it.block == source)
124                .unwrap();
125            if let Some((id, item, _)) = as_versioned(entry.value)
126                && self.program.variables.contains_key(&id)
127            {
128                let version = state.versions[&id];
129                entry.value = Variable::new(VariableKind::Versioned { id, version }, item);
130            }
131        }
132    }
133
134    /// Version the operations for this block
135    fn version_block_ops(&mut self, block: NodeIndex, state: &mut SsaState<'_>) {
136        for phi in self.program[block].phi_nodes.borrow_mut().iter_mut() {
137            if let Some((id, item, _)) = as_versioned(phi.out)
138                && self.program.variables.contains_key(&id)
139            {
140                let version = state.versions.get_mut(&id).unwrap();
141                let max_version = state.max_versions.get_mut(&id).unwrap();
142                *max_version += 1;
143                *version = *max_version;
144                phi.out = Variable::new(
145                    VariableKind::Versioned {
146                        id,
147                        version: *version,
148                    },
149                    item,
150                );
151            }
152        }
153
154        let mut ops = take(&mut *self.program[block].ops.borrow_mut());
155        for operation in ops.values_mut() {
156            self.version_reads(operation, state);
157            self.version_writes(operation, state);
158        }
159        *self.program[block].ops.borrow_mut() = ops;
160        match &mut *self.program[block].control_flow.borrow_mut() {
161            super::ControlFlow::IfElse { cond, .. } => self.version_read(cond, state),
162            super::ControlFlow::LoopBreak { break_cond, .. } => {
163                self.version_read(break_cond, state)
164            }
165            ControlFlow::Switch { value, .. } => self.version_read(value, state),
166            _ => {}
167        }
168    }
169
170    fn version_reads(&mut self, op: &mut Instruction, state: &mut SsaState<'_>) {
171        self.visit_operation(&mut op.operation, &mut op.out, |opt, var| {
172            opt.version_read(var, state)
173        });
174    }
175
176    fn version_writes(&mut self, op: &mut Instruction, state: &mut SsaState<'_>) {
177        self.visit_out(&mut op.out, |_, var| match var.kind {
178            VariableKind::LocalMut { id } | VariableKind::Versioned { id, .. } => {
179                if let Some(version) = state.versions.get_mut(&id) {
180                    let max_version = state.max_versions.get_mut(&id).unwrap();
181                    *max_version += 1;
182                    *version = *max_version;
183                    *var = Variable::new(
184                        VariableKind::Versioned {
185                            id,
186                            version: *version,
187                        },
188                        var.ty,
189                    )
190                }
191            }
192            _ => {}
193        });
194    }
195
196    fn version_read(&self, var: &mut Variable, state: &mut SsaState<'_>) {
197        match var.kind {
198            VariableKind::LocalMut { id } | VariableKind::Versioned { id, .. } => {
199                if self.program.variables.contains_key(&id)
200                    && let Some(version) = state.versions.get(&id)
201                {
202                    *var = Variable::new(
203                        VariableKind::Versioned {
204                            id,
205                            version: *version,
206                        },
207                        var.ty,
208                    )
209                }
210            }
211            _ => {}
212        }
213    }
214}
215
216fn as_versioned(var: Variable) -> Option<(Id, Type, u16)> {
217    match var.kind {
218        VariableKind::Versioned { id, version } => Some((id, var.ty, version)),
219        _ => None,
220    }
221}