1use std::collections::{HashMap, HashSet, VecDeque};
2
3use cubecl_ir::Id;
4use petgraph::graph::NodeIndex;
5
6use crate::{Optimizer, analyses::post_order::PostOrder};
7
8use super::Analysis;
9
/// Result of the variable liveness analysis: one set of variable ids per block.
///
/// The set stored for a block is the block's own `generated` set plus everything recorded for
/// its successors that the block does not `kill` (see `Liveness::analyze_block`).
pub struct Liveness {
    /// Variable ids recorded for each block, keyed by the block's node index.
    live_vars: HashMap<NodeIndex, HashSet<Id>>,
}
13
/// The two per-block sets that drive the data-flow iteration. They are computed once per block
/// and cached in `State::block_sets`.
#[derive(Clone)]
struct BlockSets {
    /// Ids the block itself adds to its set.
    generated: HashSet<Id>,
    /// Ids the block filters out of the sets flowing in from neighbouring blocks.
    kill: HashSet<Id>,
}
19
/// Scratch state shared by the worklist loops in this file.
struct State {
    /// Blocks that still have to be (re)processed, consumed from the front.
    worklist: VecDeque<NodeIndex>,
    /// Cache of the `generated`/`kill` sets, filled lazily the first time a block is processed.
    block_sets: HashMap<NodeIndex, BlockSets>,
}
24
25impl Analysis for Liveness {
26 fn init(opt: &mut Optimizer) -> Self {
27 let mut this = Self::empty(opt);
28 this.analyze_liveness(opt);
29 this
30 }
31}
32
33impl Liveness {
34 pub fn empty(opt: &Optimizer) -> Self {
35 let live_vars = opt
36 .node_ids()
37 .iter()
38 .map(|it| (*it, HashSet::new()))
39 .collect();
40 Self { live_vars }
41 }
42
43 pub fn at_block(&self, block: NodeIndex) -> &HashSet<Id> {
44 &self.live_vars[&block]
45 }
46
47 pub fn is_dead(&self, node: NodeIndex, var: Id) -> bool {
48 !self.at_block(node).contains(&var)
49 }
50
51 pub fn analyze_liveness(&mut self, opt: &mut Optimizer) {
53 let mut state = State {
54 worklist: VecDeque::from(opt.analysis::<PostOrder>().forward()),
55 block_sets: HashMap::new(),
56 };
57 while let Some(block) = state.worklist.pop_front() {
58 self.analyze_block(opt, block, &mut state);
59 }
60 }
61
62 fn analyze_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
63 let BlockSets { generated, kill } = block_sets(opt, block, state);
64
65 let mut live_vars = generated.clone();
66
67 for successor in opt.successors(block) {
68 let successor = &self.live_vars[&successor];
69 live_vars.extend(successor.difference(kill));
70 }
71
72 if live_vars != self.live_vars[&block] {
73 state.worklist.extend(opt.predecessors(block));
74 self.live_vars.insert(block, live_vars);
75 }
76 }
77}
78
79fn block_sets<'a>(opt: &mut Optimizer, block: NodeIndex, state: &'a mut State) -> &'a BlockSets {
80 let block_sets = state.block_sets.entry(block);
81 block_sets.or_insert_with(|| calculate_block_sets(opt, block))
82}
83
84fn calculate_block_sets(opt: &mut Optimizer, block: NodeIndex) -> BlockSets {
85 let mut generated = HashSet::new();
86 let mut kill = HashSet::new();
87
88 let ops = opt.program[block].ops.clone();
89
90 for op in ops.borrow_mut().values_mut().rev() {
91 opt.visit_out(&mut op.out, |opt, var| {
93 if let Some(id) = opt.local_variable_id(var) {
94 kill.insert(id);
95 generated.remove(&id);
96 }
97 });
98 opt.visit_operation(&mut op.operation, &mut op.out, |opt, var| {
99 if let Some(id) = opt.local_variable_id(var) {
100 generated.insert(id);
101 }
102 });
103 }
104
105 BlockSets { generated, kill }
106}
107
108pub mod shared {
110 use cubecl_ir::{Operation, Type, Variable, VariableKind};
111
112 use crate::Uniformity;
113
114 use super::*;
115
    /// Description of one shared memory, built by `shared_memory` from a variable whose kind
    /// is `VariableKind::SharedMemory`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct SharedMemory {
        /// Id taken from the originating variable.
        pub id: Id,
        /// Element count: the variable's `length` multiplied by its `unroll_factor`.
        pub length: u32,
        /// Type of the originating variable.
        pub ty: Type,
        /// Alignment used when placing the memory; `shared_memory` falls back to `ty.size()`
        /// when the variable specifies none.
        pub align: u32,
    }
125
126 impl SharedMemory {
127 pub fn size(&self) -> u32 {
129 self.length * self.ty.size() as u32
130 }
131 }
132
    /// A shared memory together with the offset that `SharedLiveness::allocate_slices`
    /// assigned to it.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct SmemAllocation {
        /// The shared memory being placed.
        pub smem: SharedMemory,
        /// Start offset of the memory, in the same unit as `SharedMemory::size`.
        pub offset: u32,
    }
141
    /// Liveness analysis for shared memories, including the offsets assigned to them.
    ///
    /// A shared memory becomes live in the block that first references it and stays live in
    /// following blocks until an `Operation::Free` on it is seen.
    #[derive(Default, Clone)]
    pub struct SharedLiveness {
        /// Ids of the shared memories recorded as live for each block.
        live_vars: HashMap<NodeIndex, HashSet<Id>>,
        /// Every shared memory encountered while scanning the blocks, keyed by id.
        pub shared_memories: HashMap<Id, SharedMemory>,
        /// Offset assignment for each shared memory that is live in at least one block.
        pub allocations: HashMap<Id, SmemAllocation>,
    }
159
160 impl Analysis for SharedLiveness {
161 fn init(opt: &mut Optimizer) -> Self {
162 let mut this = Self::empty(opt);
163 this.analyze_liveness(opt);
164 this.uniformize_liveness(opt);
165 this.allocate_slices(opt);
166 this
167 }
168 }
169
170 impl SharedLiveness {
171 pub fn empty(opt: &Optimizer) -> Self {
172 let live_vars = opt
173 .node_ids()
174 .iter()
175 .map(|it| (*it, HashSet::new()))
176 .collect();
177 Self {
178 live_vars,
179 shared_memories: Default::default(),
180 allocations: Default::default(),
181 }
182 }
183
        /// Returns the ids of the shared memories recorded as live for `block`.
        ///
        /// # Panics
        /// Panics if `block` has no entry in the map.
        pub fn at_block(&self, block: NodeIndex) -> &HashSet<Id> {
            &self.live_vars[&block]
        }
187
188 fn is_live(&self, node: NodeIndex, var: Id) -> bool {
189 self.at_block(node).contains(&var)
190 }
191
192 fn analyze_liveness(&mut self, opt: &mut Optimizer) {
194 let mut state = State {
195 worklist: VecDeque::from(opt.analysis::<PostOrder>().reverse()),
196 block_sets: HashMap::new(),
197 };
198 while let Some(block) = state.worklist.pop_front() {
199 self.analyze_block(opt, block, &mut state);
200 }
201 }
202
203 fn uniformize_liveness(&mut self, opt: &mut Optimizer) {
206 let mut state = State {
207 worklist: VecDeque::from(opt.analysis::<PostOrder>().forward()),
208 block_sets: HashMap::new(),
209 };
210 while let Some(block) = state.worklist.pop_front() {
211 self.uniformize_block(opt, block, &mut state);
212 }
213 }
214
215 fn allocate_slices(&mut self, opt: &mut Optimizer) {
218 for block in opt.node_ids() {
219 for live_smem in self.at_block(block).clone() {
220 if !self.allocations.contains_key(&live_smem) {
221 let smem = self.shared_memories[&live_smem];
222 let offset = self.allocate_slice(block, smem.size(), smem.align);
223 self.allocations
224 .insert(smem.id, SmemAllocation { smem, offset });
225 }
226 }
227 }
228 }
229
230 fn allocate_slice(&mut self, block: NodeIndex, size: u32, align: u32) -> u32 {
239 let live_slices = self.live_slices(block);
240 if live_slices.is_empty() {
241 return 0;
242 }
243
244 for i in 0..live_slices.len() - 1 {
245 let slice_0 = &live_slices[i];
246 let slice_1 = &live_slices[i + 1];
247 let end_0 = (slice_0.offset + slice_0.smem.size()).next_multiple_of(align);
248 let gap = slice_1.offset - end_0;
249 if gap >= size {
250 return end_0;
251 }
252 }
253 let last_slice = &live_slices[live_slices.len() - 1];
254 (last_slice.offset + last_slice.smem.size()).next_multiple_of(align)
255 }
256
257 fn live_slices(&mut self, block: NodeIndex) -> Vec<SmemAllocation> {
259 let mut live_slices = self
260 .allocations
261 .iter()
262 .filter(|(k, _)| self.is_live(block, **k))
263 .map(|it| *it.1)
264 .collect::<Vec<_>>();
265 live_slices.sort_by_key(|it| it.offset);
266 live_slices
267 }
268
269 fn analyze_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
270 let BlockSets { generated, kill } = self.block_sets(opt, block, state);
271
272 let mut live_vars = generated.clone();
273
274 for predecessor in opt.predecessors(block) {
275 let predecessor = &self.live_vars[&predecessor];
276 live_vars.extend(predecessor.difference(kill));
277 }
278
279 if live_vars != self.live_vars[&block] {
280 state.worklist.extend(opt.successors(block));
281 self.live_vars.insert(block, live_vars);
282 }
283 }
284
285 fn uniformize_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
286 let mut live_vars = self.live_vars[&block].clone();
287 let uniformity = opt.analysis::<Uniformity>();
288
289 for successor in opt.successors(block) {
290 if !uniformity.is_block_uniform(successor) {
291 let successor = &self.live_vars[&successor];
292 live_vars.extend(successor);
293 }
294 }
295
296 if live_vars != self.live_vars[&block] {
297 state.worklist.extend(opt.predecessors(block));
298 self.live_vars.insert(block, live_vars);
299 }
300 }
301
302 fn block_sets<'a>(
303 &mut self,
304 opt: &mut Optimizer,
305 block: NodeIndex,
306 state: &'a mut State,
307 ) -> &'a BlockSets {
308 let block_sets = state.block_sets.entry(block);
309 block_sets.or_insert_with(|| self.calculate_block_sets(opt, block))
310 }
311
312 fn calculate_block_sets(&mut self, opt: &mut Optimizer, block: NodeIndex) -> BlockSets {
315 let mut generated = HashSet::new();
316 let mut kill = HashSet::new();
317
318 let ops = opt.program[block].ops.clone();
319
320 for op in ops.borrow_mut().values_mut() {
321 opt.visit_out(&mut op.out, |_, var| {
322 if let Some(smem) = shared_memory(var) {
323 generated.insert(smem.id);
324 self.shared_memories.insert(smem.id, smem);
325 }
326 });
327 opt.visit_operation(&mut op.operation, &mut op.out, |_, var| {
328 if let Some(smem) = shared_memory(var) {
329 generated.insert(smem.id);
330 self.shared_memories.insert(smem.id, smem);
331 }
332 });
333
334 if let Operation::Free(Variable {
335 kind: VariableKind::SharedMemory { id, .. },
336 ..
337 }) = &op.operation
338 {
339 kill.insert(*id);
340 generated.remove(id);
341 }
342 }
343
344 BlockSets { generated, kill }
345 }
346 }
347
348 fn shared_memory(var: &Variable) -> Option<SharedMemory> {
349 if let Variable {
350 kind:
351 VariableKind::SharedMemory {
352 id,
353 length,
354 unroll_factor,
355 alignment,
356 },
357 ..
358 } = *var
359 {
360 Some(SharedMemory {
361 id,
362 length: length * unroll_factor,
363 ty: var.ty,
364 align: alignment.unwrap_or_else(|| var.ty.size() as u32),
365 })
366 } else {
367 None
368 }
369 }
370}