1use std::collections::{HashMap, HashSet, VecDeque};
2
3use cubecl_ir::Id;
4use petgraph::graph::NodeIndex;
5
6use crate::{Optimizer, analyses::post_order::PostOrder};
7
8use super::Analysis;
9
10pub struct Liveness {
11 live_vars: HashMap<NodeIndex, HashSet<Id>>,
12}
13
14#[derive(Clone)]
15struct BlockSets {
16 generated: HashSet<Id>,
17 kill: HashSet<Id>,
18}
19
20struct State {
21 worklist: VecDeque<NodeIndex>,
22 block_sets: HashMap<NodeIndex, BlockSets>,
23}
24
25impl Analysis for Liveness {
26 fn init(opt: &mut Optimizer) -> Self {
27 let mut this = Self::empty(opt);
28 this.analyze_liveness(opt);
29 this
30 }
31}
32
33impl Liveness {
34 pub fn empty(opt: &Optimizer) -> Self {
35 let live_vars = opt
36 .node_ids()
37 .iter()
38 .map(|it| (*it, HashSet::new()))
39 .collect();
40 Self { live_vars }
41 }
42
43 pub fn at_block(&self, block: NodeIndex) -> &HashSet<Id> {
44 &self.live_vars[&block]
45 }
46
47 pub fn is_dead(&self, node: NodeIndex, var: Id) -> bool {
48 !self.at_block(node).contains(&var)
49 }
50
51 pub fn analyze_liveness(&mut self, opt: &mut Optimizer) {
53 let mut state = State {
54 worklist: VecDeque::from(opt.analysis::<PostOrder>().forward()),
55 block_sets: HashMap::new(),
56 };
57 while let Some(block) = state.worklist.pop_front() {
58 self.analyze_block(opt, block, &mut state);
59 }
60 }
61
62 fn analyze_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
63 let BlockSets { generated, kill } = block_sets(opt, block, state);
64
65 let mut live_vars = generated.clone();
66
67 for successor in opt.successors(block) {
68 let successor = &self.live_vars[&successor];
69 live_vars.extend(successor.difference(kill));
70 }
71
72 if live_vars != self.live_vars[&block] {
73 state.worklist.extend(opt.predecessors(block));
74 self.live_vars.insert(block, live_vars);
75 }
76 }
77}
78
79fn block_sets<'a>(opt: &mut Optimizer, block: NodeIndex, state: &'a mut State) -> &'a BlockSets {
80 let block_sets = state.block_sets.entry(block);
81 block_sets.or_insert_with(|| calculate_block_sets(opt, block))
82}
83
84fn calculate_block_sets(opt: &mut Optimizer, block: NodeIndex) -> BlockSets {
85 let mut generated = HashSet::new();
86 let mut kill = HashSet::new();
87
88 let ops = opt.program[block].ops.clone();
89
90 let control_flow = opt.program[block].control_flow.clone();
91 opt.visit_control_flow(&mut control_flow.borrow_mut(), |opt, var| {
92 if let Some(id) = opt.local_variable_id(var) {
93 generated.insert(id);
94 }
95 });
96 for op in ops.borrow_mut().values_mut().rev() {
97 opt.visit_out(&mut op.out, |opt, var| {
99 if let Some(id) = opt.local_variable_id(var) {
100 kill.insert(id);
101 generated.remove(&id);
102 }
103 });
104 opt.visit_operation(&mut op.operation, &mut op.out, |opt, var| {
105 if let Some(id) = opt.local_variable_id(var) {
106 generated.insert(id);
107 }
108 });
109 }
110
111 BlockSets { generated, kill }
112}
113
114pub mod shared {
116 use cubecl_ir::{Marker, Operation, Type, Variable, VariableKind};
117
118 use crate::Uniformity;
119
120 use super::*;
121
122 #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
125 pub enum SharedMemory {
126 Array {
127 id: Id,
128 length: usize,
129 ty: Type,
130 align: usize,
131 },
132 Value {
133 id: Id,
134 ty: Type,
135 align: usize,
136 },
137 }
138
139 impl SharedMemory {
140 pub fn id(&self) -> u32 {
142 match self {
143 SharedMemory::Array { id, .. } => *id,
144 SharedMemory::Value { id, .. } => *id,
145 }
146 }
147
148 pub fn size(&self) -> usize {
150 match self {
151 SharedMemory::Array { length, ty, .. } => length * ty.size(),
152 SharedMemory::Value { ty, .. } => ty.size(),
153 }
154 }
155
156 pub fn align(&self) -> usize {
157 match self {
158 SharedMemory::Array { align, .. } => *align,
159 SharedMemory::Value { align, .. } => *align,
160 }
161 }
162 }
163
164 #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
166 pub struct SmemAllocation {
167 pub smem: SharedMemory,
169 pub offset: usize,
171 }
172
173 #[derive(Default, Clone)]
181 pub struct SharedLiveness {
182 live_vars: HashMap<NodeIndex, HashSet<Id>>,
183 pub shared_memories: HashMap<Id, SharedMemory>,
186 pub allocations: HashMap<Id, SmemAllocation>,
189 }
190
191 impl Analysis for SharedLiveness {
192 fn init(opt: &mut Optimizer) -> Self {
193 let mut this = Self::empty(opt);
194 this.analyze_liveness(opt);
195 this.uniformize_liveness(opt);
196 this.allocate_slices(opt);
197 this
198 }
199 }
200
201 impl SharedLiveness {
202 pub fn empty(opt: &Optimizer) -> Self {
203 let live_vars = opt
204 .node_ids()
205 .iter()
206 .map(|it| (*it, HashSet::new()))
207 .collect();
208 Self {
209 live_vars,
210 shared_memories: Default::default(),
211 allocations: Default::default(),
212 }
213 }
214
215 pub fn at_block(&self, block: NodeIndex) -> &HashSet<Id> {
216 &self.live_vars[&block]
217 }
218
219 fn is_live(&self, node: NodeIndex, var: Id) -> bool {
220 self.at_block(node).contains(&var)
221 }
222
223 fn analyze_liveness(&mut self, opt: &mut Optimizer) {
225 let mut state = State {
226 worklist: VecDeque::from(opt.analysis::<PostOrder>().reverse()),
227 block_sets: HashMap::new(),
228 };
229 while let Some(block) = state.worklist.pop_front() {
230 self.analyze_block(opt, block, &mut state);
231 }
232 }
233
234 fn uniformize_liveness(&mut self, opt: &mut Optimizer) {
237 let mut state = State {
238 worklist: VecDeque::from(opt.analysis::<PostOrder>().forward()),
239 block_sets: HashMap::new(),
240 };
241 while let Some(block) = state.worklist.pop_front() {
242 self.uniformize_block(opt, block, &mut state);
243 }
244 }
245
246 fn allocate_slices(&mut self, opt: &mut Optimizer) {
249 for block in opt.node_ids() {
250 for live_smem in self.at_block(block).clone() {
251 if !self.allocations.contains_key(&live_smem) {
252 let smem = self.shared_memories[&live_smem];
253 let offset = self.allocate_slice(block, smem.size(), smem.align());
254 self.allocations
255 .insert(smem.id(), SmemAllocation { smem, offset });
256 }
257 }
258 }
259 }
260
261 fn allocate_slice(&mut self, block: NodeIndex, size: usize, align: usize) -> usize {
270 let live_slices = self.live_slices(block);
271 if live_slices.is_empty() {
272 return 0;
273 }
274
275 for i in 0..live_slices.len() - 1 {
276 let slice_0 = &live_slices[i];
277 let slice_1 = &live_slices[i + 1];
278 let end_0 = (slice_0.offset + slice_0.smem.size()).next_multiple_of(align);
279 let gap = slice_1.offset.saturating_sub(end_0);
280 if gap >= size {
281 return end_0;
282 }
283 }
284 let last_slice = &live_slices[live_slices.len() - 1];
285 (last_slice.offset + last_slice.smem.size()).next_multiple_of(align)
286 }
287
288 fn live_slices(&mut self, block: NodeIndex) -> Vec<SmemAllocation> {
290 let mut live_slices = self
291 .allocations
292 .iter()
293 .filter(|(k, _)| self.is_live(block, **k))
294 .map(|it| *it.1)
295 .collect::<Vec<_>>();
296 live_slices.sort_by_key(|it| it.offset);
297 live_slices
298 }
299
300 fn analyze_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
301 let BlockSets { generated, kill } = self.block_sets(opt, block, state);
302
303 let mut live_vars = generated.clone();
304
305 for predecessor in opt.predecessors(block) {
306 let predecessor = &self.live_vars[&predecessor];
307 live_vars.extend(predecessor.difference(kill));
308 }
309
310 if live_vars != self.live_vars[&block] {
311 state.worklist.extend(opt.successors(block));
312 self.live_vars.insert(block, live_vars);
313 }
314 }
315
316 fn uniformize_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
317 let mut live_vars = self.live_vars[&block].clone();
318 let uniformity = opt.analysis::<Uniformity>();
319
320 for successor in opt.successors(block) {
321 if !uniformity.is_block_uniform(successor) {
322 let successor = &self.live_vars[&successor];
323 live_vars.extend(successor);
324 }
325 }
326
327 if live_vars != self.live_vars[&block] {
328 state.worklist.extend(opt.predecessors(block));
329 self.live_vars.insert(block, live_vars);
330 }
331 }
332
333 fn block_sets<'a>(
334 &mut self,
335 opt: &mut Optimizer,
336 block: NodeIndex,
337 state: &'a mut State,
338 ) -> &'a BlockSets {
339 let block_sets = state.block_sets.entry(block);
340 block_sets.or_insert_with(|| self.calculate_block_sets(opt, block))
341 }
342
343 fn calculate_block_sets(&mut self, opt: &mut Optimizer, block: NodeIndex) -> BlockSets {
346 let mut generated = HashSet::new();
347 let mut kill = HashSet::new();
348
349 let ops = opt.program[block].ops.clone();
350
351 for op in ops.borrow_mut().values_mut() {
352 opt.visit_out(&mut op.out, |_, var| {
353 if let Some(smem) = shared_memory(var) {
354 generated.insert(smem.id());
355 self.shared_memories.insert(smem.id(), smem);
356 }
357 });
358 opt.visit_operation(&mut op.operation, &mut op.out, |_, var| {
359 if let Some(smem) = shared_memory(var) {
360 generated.insert(smem.id());
361 self.shared_memories.insert(smem.id(), smem);
362 }
363 });
364
365 if let Operation::Marker(Marker::Free(Variable {
366 kind: VariableKind::SharedArray { id, .. } | VariableKind::Shared { id },
367 ..
368 })) = &op.operation
369 {
370 kill.insert(*id);
371 generated.remove(id);
372 }
373 }
374
375 BlockSets { generated, kill }
376 }
377 }
378
379 fn shared_memory(var: &Variable) -> Option<SharedMemory> {
380 match var.kind {
381 VariableKind::SharedArray {
382 id,
383 length,
384 unroll_factor,
385 alignment,
386 } => Some(SharedMemory::Array {
387 id,
388 length: length * unroll_factor,
389 ty: var.ty,
390 align: alignment.unwrap_or_else(|| var.ty.size()),
391 }),
392 VariableKind::Shared { id } => Some(SharedMemory::Value {
393 id,
394 ty: var.ty,
395 align: var.ty.size(),
396 }),
397 _ => None,
398 }
399 }
400}