1use std::{
2 collections::{HashMap, HashSet},
3 mem::take,
4};
5
6use cubecl_ir::{Id, Instruction, Type, Variable, VariableKind};
7use petgraph::visit::EdgeRef;
8
9use crate::{ControlFlow, EdgeIndex, NodeIndex};
10
11use super::Optimizer;
12
/// Traversal state threaded through the SSA versioning pass.
///
/// `versions` is owned and is cloned for every outgoing edge, so each path through the
/// control-flow graph carries its own view of the current versions. The remaining fields are
/// mutable borrows shared by the whole traversal.
#[derive(Debug)]
pub struct SsaState<'a> {
    /// Version of each variable that is current at this point of the walk.
    versions: HashMap<Id, u16>,
    /// Blocks that have already been versioned, so they are not entered again.
    visited_blocks: &'a mut HashSet<NodeIndex>,
    /// Edges whose phi entries in the target block have already been filled in.
    visited_edges: &'a mut HashSet<EdgeIndex>,
    /// Highest version handed out so far for each variable. New versions are allocated from
    /// here, so versions created on different paths never collide.
    max_versions: &'a mut HashMap<Id, u16>,
}
21
/// One incoming value of a phi node.
#[derive(Debug, Clone, PartialEq)]
pub struct PhiEntry {
    /// Predecessor block this entry belongs to.
    pub block: NodeIndex,
    /// Value the phi node takes when control arrives from `block`.
    pub value: Variable,
}
29
/// A phi node. It defines `out` from whichever of its `entries` matches the predecessor block
/// that control arrived from.
#[derive(Debug, Clone, PartialEq)]
pub struct PhiInstruction {
    /// Variable defined by this phi node.
    pub out: Variable,
    /// Incoming values, one per predecessor block. `version_phi` looks up the entry for each
    /// predecessor edge it processes and panics if that entry is missing.
    pub entries: Vec<PhiEntry>,
}
68
69impl Optimizer {
70 pub(crate) fn version_program(&mut self) {
72 let versions: HashMap<_, _> = self.program.variables.keys().map(|key| (*key, 0)).collect();
73 let mut visited_blocks = HashSet::new();
74 let mut visited_edges = HashSet::new();
75 let mut max_versions = versions.clone();
76 let initial_state = SsaState {
77 versions,
78 visited_blocks: &mut visited_blocks,
79 visited_edges: &mut visited_edges,
80 max_versions: &mut max_versions,
81 };
82 self.version_block(self.entry(), initial_state);
83 }
84
85 fn version_block(&mut self, block: NodeIndex, mut state: SsaState<'_>) {
86 self.version_block_ops(block, &mut state);
87
88 let edges: Vec<_> = self
89 .program
90 .edges(block)
91 .map(|it| (it.id(), it.target()))
92 .collect();
93 let state = &mut state;
94 for (edge_id, target) in edges {
95 let edge_visited = state.visited_edges.contains(&edge_id);
96 state.visited_edges.insert(edge_id);
97 let block_visited = state.visited_blocks.contains(&target);
98 state.visited_blocks.insert(block);
99
100 let new_state = SsaState {
101 versions: state.versions.clone(),
102 visited_blocks: state.visited_blocks,
103 visited_edges: state.visited_edges,
104 max_versions: state.max_versions,
105 };
106
107 if !edge_visited {
108 self.version_phi(target, block, &new_state);
109 }
110 if !block_visited {
111 self.version_block(target, new_state);
112 }
113 }
114 }
115
116 fn version_phi(&mut self, target: NodeIndex, source: NodeIndex, state: &SsaState<'_>) {
118 let phi = self.program[target].phi_nodes.clone();
119 for node in phi.borrow_mut().iter_mut() {
120 let entry = node
121 .entries
122 .iter_mut()
123 .find(|it| it.block == source)
124 .unwrap();
125 if let Some((id, item, _)) = as_versioned(entry.value)
126 && self.program.variables.contains_key(&id)
127 {
128 let version = state.versions[&id];
129 entry.value = Variable::new(VariableKind::Versioned { id, version }, item);
130 }
131 }
132 }
133
134 fn version_block_ops(&mut self, block: NodeIndex, state: &mut SsaState<'_>) {
136 for phi in self.program[block].phi_nodes.borrow_mut().iter_mut() {
137 if let Some((id, item, _)) = as_versioned(phi.out)
138 && self.program.variables.contains_key(&id)
139 {
140 let version = state.versions.get_mut(&id).unwrap();
141 let max_version = state.max_versions.get_mut(&id).unwrap();
142 *max_version += 1;
143 *version = *max_version;
144 phi.out = Variable::new(
145 VariableKind::Versioned {
146 id,
147 version: *version,
148 },
149 item,
150 );
151 }
152 }
153
154 let mut ops = take(&mut *self.program[block].ops.borrow_mut());
155 for operation in ops.values_mut() {
156 self.version_reads(operation, state);
157 self.version_writes(operation, state);
158 }
159 *self.program[block].ops.borrow_mut() = ops;
160 match &mut *self.program[block].control_flow.borrow_mut() {
161 super::ControlFlow::IfElse { cond, .. } => self.version_read(cond, state),
162 super::ControlFlow::LoopBreak { break_cond, .. } => {
163 self.version_read(break_cond, state)
164 }
165 ControlFlow::Switch { value, .. } => self.version_read(value, state),
166 _ => {}
167 }
168 }
169
170 fn version_reads(&mut self, op: &mut Instruction, state: &mut SsaState<'_>) {
171 self.visit_operation(&mut op.operation, &mut op.out, |opt, var| {
172 opt.version_read(var, state)
173 });
174 }
175
176 fn version_writes(&mut self, op: &mut Instruction, state: &mut SsaState<'_>) {
177 self.visit_out(&mut op.out, |_, var| match var.kind {
178 VariableKind::LocalMut { id } | VariableKind::Versioned { id, .. } => {
179 if let Some(version) = state.versions.get_mut(&id) {
180 let max_version = state.max_versions.get_mut(&id).unwrap();
181 *max_version += 1;
182 *version = *max_version;
183 *var = Variable::new(
184 VariableKind::Versioned {
185 id,
186 version: *version,
187 },
188 var.ty,
189 )
190 }
191 }
192 _ => {}
193 });
194 }
195
196 fn version_read(&self, var: &mut Variable, state: &mut SsaState<'_>) {
197 match var.kind {
198 VariableKind::LocalMut { id } | VariableKind::Versioned { id, .. } => {
199 if self.program.variables.contains_key(&id)
200 && let Some(version) = state.versions.get(&id)
201 {
202 *var = Variable::new(
203 VariableKind::Versioned {
204 id,
205 version: *version,
206 },
207 var.ty,
208 )
209 }
210 }
211 _ => {}
212 }
213 }
214}
215
216fn as_versioned(var: Variable) -> Option<(Id, Type, u16)> {
217 match var.kind {
218 VariableKind::Versioned { id, version } => Some((id, var.ty, version)),
219 _ => None,
220 }
221}