1use std::collections::{HashMap, HashSet, VecDeque};
2
3use cubecl_ir::Id;
4use petgraph::graph::NodeIndex;
5
6use crate::{Optimizer, analyses::post_order::PostOrder};
7
8use super::Analysis;
9
/// Result of the per-block liveness analysis for local variables.
pub struct Liveness {
    // Set of local variable ids live at the entry of each CFG block,
    // computed as `gen ∪ (live(successors) − kill)`.
    live_vars: HashMap<NodeIndex, HashSet<Id>>,
}
13
/// Per-block transfer-function sets used by the dataflow fixed-point
/// iteration (shared by both the local and the shared-memory analyses).
#[derive(Clone)]
struct BlockSets {
    // Ids made live ("generated") by the block itself.
    generated: HashSet<Id>,
    // Ids whose live range is ended ("killed") by the block.
    kill: HashSet<Id>,
}
19
/// Mutable state of a worklist-driven fixed-point iteration.
struct State {
    // Blocks that still need to be (re)processed.
    worklist: VecDeque<NodeIndex>,
    // Gen/kill sets memoized lazily per block, so each block's instructions
    // are only scanned once.
    block_sets: HashMap<NodeIndex, BlockSets>,
}
24
25impl Analysis for Liveness {
26 fn init(opt: &mut Optimizer) -> Self {
27 let mut this = Self::empty(opt);
28 this.analyze_liveness(opt);
29 this
30 }
31}
32
33impl Liveness {
34 pub fn empty(opt: &Optimizer) -> Self {
35 let live_vars = opt
36 .node_ids()
37 .iter()
38 .map(|it| (*it, HashSet::new()))
39 .collect();
40 Self { live_vars }
41 }
42
43 pub fn at_block(&self, block: NodeIndex) -> &HashSet<Id> {
44 &self.live_vars[&block]
45 }
46
47 pub fn is_dead(&self, node: NodeIndex, var: Id) -> bool {
48 !self.at_block(node).contains(&var)
49 }
50
51 pub fn analyze_liveness(&mut self, opt: &mut Optimizer) {
53 let mut state = State {
54 worklist: VecDeque::from(opt.analysis::<PostOrder>().forward()),
55 block_sets: HashMap::new(),
56 };
57 while let Some(block) = state.worklist.pop_front() {
58 self.analyze_block(opt, block, &mut state);
59 }
60 }
61
62 fn analyze_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
63 let BlockSets { generated, kill } = block_sets(opt, block, state);
64
65 let mut live_vars = generated.clone();
66
67 for successor in opt.successors(block) {
68 let successor = &self.live_vars[&successor];
69 live_vars.extend(successor.difference(kill));
70 }
71
72 if live_vars != self.live_vars[&block] {
73 state.worklist.extend(opt.predecessors(block));
74 self.live_vars.insert(block, live_vars);
75 }
76 }
77}
78
79fn block_sets<'a>(opt: &mut Optimizer, block: NodeIndex, state: &'a mut State) -> &'a BlockSets {
80 let block_sets = state.block_sets.entry(block);
81 block_sets.or_insert_with(|| calculate_block_sets(opt, block))
82}
83
84fn calculate_block_sets(opt: &mut Optimizer, block: NodeIndex) -> BlockSets {
85 let mut generated = HashSet::new();
86 let mut kill = HashSet::new();
87
88 let ops = opt.program[block].ops.clone();
89
90 for op in ops.borrow_mut().values_mut().rev() {
91 opt.visit_out(&mut op.out, |opt, var| {
93 if let Some(id) = opt.local_variable_id(var) {
94 kill.insert(id);
95 generated.remove(&id);
96 }
97 });
98 opt.visit_operation(&mut op.operation, &mut op.out, |opt, var| {
99 if let Some(id) = opt.local_variable_id(var) {
100 generated.insert(id);
101 }
102 });
103 }
104
105 BlockSets { generated, kill }
106}
107
108pub mod shared {
110 use cubecl_ir::{Marker, Operation, Type, Variable, VariableKind};
111
112 use crate::Uniformity;
113
114 use super::*;
115
    /// A shared memory variable found in the program.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub enum SharedMemory {
        /// A shared array of `length` elements of type `ty`.
        Array {
            id: Id,
            length: u32,
            ty: Type,
            align: u32,
        },
        /// A single shared value of type `ty`.
        Value {
            id: Id,
            ty: Type,
            align: u32,
        },
    }
132
133 impl SharedMemory {
134 pub fn id(&self) -> u32 {
136 match self {
137 SharedMemory::Array { id, .. } => *id,
138 SharedMemory::Value { id, .. } => *id,
139 }
140 }
141
142 pub fn size(&self) -> u32 {
144 match self {
145 SharedMemory::Array { length, ty, .. } => length * ty.size() as u32,
146 SharedMemory::Value { ty, .. } => ty.size() as u32,
147 }
148 }
149
150 pub fn align(&self) -> u32 {
151 match self {
152 SharedMemory::Array { align, .. } => *align,
153 SharedMemory::Value { align, .. } => *align,
154 }
155 }
156 }
157
    /// A shared memory together with the byte offset assigned to it by the
    /// allocator.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct SmemAllocation {
        /// The shared memory being allocated.
        pub smem: SharedMemory,
        /// Byte offset of the allocation.
        pub offset: u32,
    }
166
    /// Liveness analysis and offset allocation for shared memories.
    #[derive(Default, Clone)]
    pub struct SharedLiveness {
        // Shared memory ids considered live at each CFG block.
        live_vars: HashMap<NodeIndex, HashSet<Id>>,
        /// Every shared memory encountered in the program, keyed by its id.
        pub shared_memories: HashMap<Id, SharedMemory>,
        /// Final offset assignment for each shared memory, keyed by its id.
        pub allocations: HashMap<Id, SmemAllocation>,
    }
184
185 impl Analysis for SharedLiveness {
186 fn init(opt: &mut Optimizer) -> Self {
187 let mut this = Self::empty(opt);
188 this.analyze_liveness(opt);
189 this.uniformize_liveness(opt);
190 this.allocate_slices(opt);
191 this
192 }
193 }
194
    impl SharedLiveness {
        /// Creates an analysis with an empty live set for every block and no
        /// shared memories or allocations recorded yet.
        pub fn empty(opt: &Optimizer) -> Self {
            let live_vars = opt
                .node_ids()
                .iter()
                .map(|it| (*it, HashSet::new()))
                .collect();
            Self {
                live_vars,
                shared_memories: Default::default(),
                allocations: Default::default(),
            }
        }

        /// Shared memory ids that are live at `block`.
        pub fn at_block(&self, block: NodeIndex) -> &HashSet<Id> {
            &self.live_vars[&block]
        }

        // Whether the shared memory `var` is live at `node`.
        fn is_live(&self, node: NodeIndex, var: Id) -> bool {
            self.at_block(node).contains(&var)
        }

        /// Worklist fixed-point iteration, seeded with the reverse post-order.
        /// Unlike the local-variable analysis, liveness is propagated forward
        /// from predecessors (see `analyze_block`).
        fn analyze_liveness(&mut self, opt: &mut Optimizer) {
            let mut state = State {
                worklist: VecDeque::from(opt.analysis::<PostOrder>().reverse()),
                block_sets: HashMap::new(),
            };
            while let Some(block) = state.worklist.pop_front() {
                self.analyze_block(opt, block, &mut state);
            }
        }

        /// Second fixed-point pass that widens live ranges across divergent
        /// control flow (see `uniformize_block`).
        fn uniformize_liveness(&mut self, opt: &mut Optimizer) {
            let mut state = State {
                worklist: VecDeque::from(opt.analysis::<PostOrder>().forward()),
                block_sets: HashMap::new(),
            };
            while let Some(block) = state.worklist.pop_front() {
                self.uniformize_block(opt, block, &mut state);
            }
        }

        /// Walks all blocks and assigns an offset to each shared memory the
        /// first time it is seen live, fitting it around the slices live in
        /// that block.
        ///
        /// NOTE(review): the offset is chosen only against slices live in the
        /// *first* block where the memory appears live; this assumes any
        /// memory with an overlapping live range is also live in that block —
        /// TODO confirm.
        fn allocate_slices(&mut self, opt: &mut Optimizer) {
            for block in opt.node_ids() {
                for live_smem in self.at_block(block).clone() {
                    if !self.allocations.contains_key(&live_smem) {
                        let smem = self.shared_memories[&live_smem];
                        let offset = self.allocate_slice(block, smem.size(), smem.align());
                        self.allocations
                            .insert(smem.id(), SmemAllocation { smem, offset });
                    }
                }
            }
        }

        /// First-fit allocator: returns the lowest `align`-aligned offset at
        /// which `size` bytes fit without overlapping any slice live at
        /// `block`.
        fn allocate_slice(&mut self, block: NodeIndex, size: u32, align: u32) -> u32 {
            let live_slices = self.live_slices(block);
            if live_slices.is_empty() {
                return 0;
            }

            // Try the gap between each pair of consecutive live slices
            // (slices are sorted by offset).
            // NOTE(review): the gap *before* the first live slice (starting
            // at offset 0) is never considered — confirm this is intentional.
            for i in 0..live_slices.len() - 1 {
                let slice_0 = &live_slices[i];
                let slice_1 = &live_slices[i + 1];
                // Candidate start: end of `slice_0`, rounded up to the new
                // allocation's alignment.
                let end_0 = (slice_0.offset + slice_0.smem.size()).next_multiple_of(align);
                let gap = slice_1.offset.saturating_sub(end_0);
                if gap >= size {
                    return end_0;
                }
            }
            // No interior gap is large enough: append after the last slice.
            let last_slice = &live_slices[live_slices.len() - 1];
            (last_slice.offset + last_slice.smem.size()).next_multiple_of(align)
        }

        // Allocations whose shared memory is live at `block`, sorted by
        // ascending offset.
        fn live_slices(&mut self, block: NodeIndex) -> Vec<SmemAllocation> {
            let mut live_slices = self
                .allocations
                .iter()
                .filter(|(k, _)| self.is_live(block, **k))
                .map(|it| *it.1)
                .collect::<Vec<_>>();
            live_slices.sort_by_key(|it| it.offset);
            live_slices
        }

        /// Transfer function:
        /// `live(block) = gen(block) ∪ ⋃_{p ∈ preds} (live(p) − kill(block))`.
        /// When the set changes, the successors are queued for re-analysis.
        fn analyze_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
            let BlockSets { generated, kill } = self.block_sets(opt, block, state);

            let mut live_vars = generated.clone();

            for predecessor in opt.predecessors(block) {
                let predecessor = &self.live_vars[&predecessor];
                live_vars.extend(predecessor.difference(kill));
            }

            if live_vars != self.live_vars[&block] {
                state.worklist.extend(opt.successors(block));
                self.live_vars.insert(block, live_vars);
            }
        }

        /// Pulls the live set of every non-uniform successor into this
        /// block's live set; when the set grows, predecessors are re-queued.
        /// NOTE(review): presumably this keeps a shared memory from being
        /// freed/reused while a divergent branch may still access it —
        /// confirm against `Uniformity` semantics.
        fn uniformize_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
            let mut live_vars = self.live_vars[&block].clone();
            let uniformity = opt.analysis::<Uniformity>();

            for successor in opt.successors(block) {
                if !uniformity.is_block_uniform(successor) {
                    let successor = &self.live_vars[&successor];
                    live_vars.extend(successor);
                }
            }

            if live_vars != self.live_vars[&block] {
                state.worklist.extend(opt.predecessors(block));
                self.live_vars.insert(block, live_vars);
            }
        }

        // Memoized access to the gen/kill sets of `block`.
        fn block_sets<'a>(
            &mut self,
            opt: &mut Optimizer,
            block: NodeIndex,
            state: &'a mut State,
        ) -> &'a BlockSets {
            let block_sets = state.block_sets.entry(block);
            block_sets.or_insert_with(|| self.calculate_block_sets(opt, block))
        }

        /// Computes the gen/kill sets of `block` with a *forward* scan: any
        /// use or definition of a shared memory generates it (and records its
        /// metadata in `shared_memories`), while an explicit `Free` marker
        /// kills it.
        fn calculate_block_sets(&mut self, opt: &mut Optimizer, block: NodeIndex) -> BlockSets {
            let mut generated = HashSet::new();
            let mut kill = HashSet::new();

            let ops = opt.program[block].ops.clone();

            for op in ops.borrow_mut().values_mut() {
                opt.visit_out(&mut op.out, |_, var| {
                    if let Some(smem) = shared_memory(var) {
                        generated.insert(smem.id());
                        self.shared_memories.insert(smem.id(), smem);
                    }
                });
                opt.visit_operation(&mut op.operation, &mut op.out, |_, var| {
                    if let Some(smem) = shared_memory(var) {
                        generated.insert(smem.id());
                        self.shared_memories.insert(smem.id(), smem);
                    }
                });

                // An explicit free ends the memory's live range here.
                if let Operation::Marker(Marker::Free(Variable {
                    kind: VariableKind::SharedArray { id, .. } | VariableKind::Shared { id },
                    ..
                })) = &op.operation
                {
                    kill.insert(*id);
                    generated.remove(id);
                }
            }

            BlockSets { generated, kill }
        }
    }
372
373 fn shared_memory(var: &Variable) -> Option<SharedMemory> {
374 match var.kind {
375 VariableKind::SharedArray {
376 id,
377 length,
378 unroll_factor,
379 alignment,
380 } => Some(SharedMemory::Array {
381 id,
382 length: length * unroll_factor,
383 ty: var.ty,
384 align: alignment.unwrap_or_else(|| var.ty.size() as u32),
385 }),
386 VariableKind::Shared { id } => Some(SharedMemory::Value {
387 id,
388 ty: var.ty,
389 align: var.ty.size() as u32,
390 }),
391 _ => None,
392 }
393 }
394}