cubecl_opt/analyses/
liveness.rs

use std::collections::{HashMap, HashSet, VecDeque};

use cubecl_ir::Id;
use petgraph::graph::NodeIndex;

use crate::{Optimizer, analyses::post_order::PostOrder};

use super::Analysis;

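/// Block-level variable liveness. Stores, for each block, the set of local variable ids that
/// are live on entry to that block.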
pub struct Liveness {
    live_vars: HashMap<NodeIndex, HashSet<Id>>,
}

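/// The `gen`/`kill` sets of a single block: `generated` holds the ids that become live in the
/// block, `kill` holds the ids whose liveness ends there.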
#[derive(Clone)]
struct BlockSets {
    generated: HashSet<Id>,
    kill: HashSet<Id>,
}

struct State {
    worklist: VecDeque<NodeIndex>,
    block_sets: HashMap<NodeIndex, BlockSets>,
}

impl Analysis for Liveness {
    fn init(opt: &mut Optimizer) -> Self {
        let mut this = Self::empty(opt);
        this.analyze_liveness(opt);
        this
    }
}

impl Liveness {
    pub fn empty(opt: &Optimizer) -> Self {
        let live_vars = opt
            .node_ids()
            .iter()
            .map(|it| (*it, HashSet::new()))
            .collect();
        Self { live_vars }
    }

    pub fn at_block(&self, block: NodeIndex) -> &HashSet<Id> {
        &self.live_vars[&block]
    }

    pub fn is_dead(&self, node: NodeIndex, var: Id) -> bool {
        !self.at_block(node).contains(&var)
    }

    /// Do a conservative block-level liveness analysis
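    ///
    /// For each block `B`, `live(B) = gen(B) ∪ (live_out(B) \ kill(B))`, where `live_out(B)` is
    /// the union of `live(S)` over all successors `S`. Blocks are re-queued on a worklist until
    /// the sets reach a fixed point.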
    pub fn analyze_liveness(&mut self, opt: &mut Optimizer) {
        let mut state = State {
            worklist: VecDeque::from(opt.analysis::<PostOrder>().forward()),
            block_sets: HashMap::new(),
        };
        while let Some(block) = state.worklist.pop_front() {
            self.analyze_block(opt, block, &mut state);
        }
    }

    fn analyze_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
        let BlockSets { generated, kill } = block_sets(opt, block, state);

        let mut live_vars = generated.clone();

        for successor in opt.successors(block) {
            let successor = &self.live_vars[&successor];
            live_vars.extend(successor.difference(kill));
        }

        if live_vars != self.live_vars[&block] {
            state.worklist.extend(opt.predecessors(block));
            self.live_vars.insert(block, live_vars);
        }
    }
}

fn block_sets<'a>(opt: &mut Optimizer, block: NodeIndex, state: &'a mut State) -> &'a BlockSets {
    let block_sets = state.block_sets.entry(block);
    block_sets.or_insert_with(|| calculate_block_sets(opt, block))
}

fn calculate_block_sets(opt: &mut Optimizer, block: NodeIndex) -> BlockSets {
    let mut generated = HashSet::new();
    let mut kill = HashSet::new();

    let ops = opt.program[block].ops.clone();

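    // Scan instructions in reverse: a write removes the variable from `generated`, so the final
    // set only contains upward-exposed reads (reads not preceded by a write in this block).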
    for op in ops.borrow_mut().values_mut().rev() {
        // Writes (kills) must be processed before reads, so a variable that is both read and
        // written by the same instruction still counts as live-in
        opt.visit_out(&mut op.out, |opt, var| {
            if let Some(id) = opt.local_variable_id(var) {
                kill.insert(id);
                generated.remove(&id);
            }
        });
        opt.visit_operation(&mut op.operation, &mut op.out, |opt, var| {
            if let Some(id) = opt.local_variable_id(var) {
                generated.insert(id);
            }
        });
    }

    BlockSets { generated, kill }
}

/// Shared memory liveness analysis and allocation
pub mod shared {
    use cubecl_ir::{Operation, Type, Variable, VariableKind};

    use crate::Uniformity;

    use super::*;

    /// A shared memory instance, carrying all the information contained in the `VariableKind`,
    /// but with a non-optional `align`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct SharedMemory {
        pub id: Id,
        pub length: u32,
        pub ty: Type,
        pub align: u32,
    }

    impl SharedMemory {
        /// The byte size of this shared memory
        pub fn size(&self) -> u32 {
            self.length * self.ty.size() as u32
        }
    }

    /// A specific allocation of shared memory at some `offset`
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct SmemAllocation {
        /// The shared memory being allocated
        pub smem: SharedMemory,
        /// The offset in the shared memory buffer
        pub offset: u32,
    }

    /// Shared liveness works the other way around from normal liveness, since shared memory
    /// lives forever by default. Any use (read or write) marks it as live, and only `free`
    /// changes its state to dead.
    ///
    /// It also handles allocation of slices to each shared memory object, using the analyzed
    /// liveness. `allocations` contains a specific slice allocation for each shared memory, while
    /// ensuring no shared memories that exist at the same time can overlap.
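    ///
    /// Because a shared memory becomes live at its first use and stays live until `free`, the
    /// sets are propagated forwards (from predecessors to successors) rather than backwards.
    ///
    /// A minimal usage sketch, assuming the analysis is requested through the optimizer like the
    /// other analyses in this file (`smem_id` is a hypothetical shared memory id):
    ///
    /// ```ignore
    /// let shared = opt.analysis::<SharedLiveness>();
    /// // Offset of this shared memory within the shared memory buffer
    /// let offset = shared.allocations[&smem_id].offset;
    /// ```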
    #[derive(Default, Clone)]
    pub struct SharedLiveness {
        live_vars: HashMap<NodeIndex, HashSet<Id>>,
        /// Map of all shared memories by their ID. Populated during the first pass with all
        /// accessed shared memories.
        pub shared_memories: HashMap<Id, SharedMemory>,
        /// Map of allocations for each shared memory by its ID. Populated after the analysis, and
        /// should contain all memories from `shared_memories`.
        pub allocations: HashMap<Id, SmemAllocation>,
    }

    impl Analysis for SharedLiveness {
        fn init(opt: &mut Optimizer) -> Self {
            let mut this = Self::empty(opt);
            this.analyze_liveness(opt);
            this.uniformize_liveness(opt);
            this.allocate_slices(opt);
            this
        }
    }

    impl SharedLiveness {
        pub fn empty(opt: &Optimizer) -> Self {
            let live_vars = opt
                .node_ids()
                .iter()
                .map(|it| (*it, HashSet::new()))
                .collect();
            Self {
                live_vars,
                shared_memories: Default::default(),
                allocations: Default::default(),
            }
        }

        pub fn at_block(&self, block: NodeIndex) -> &HashSet<Id> {
            &self.live_vars[&block]
        }

        fn is_live(&self, node: NodeIndex, var: Id) -> bool {
            self.at_block(node).contains(&var)
        }

        /// Do a conservative block-level liveness analysis
        fn analyze_liveness(&mut self, opt: &mut Optimizer) {
            let mut state = State {
                worklist: VecDeque::from(opt.analysis::<PostOrder>().reverse()),
                block_sets: HashMap::new(),
            };
            while let Some(block) = state.worklist.pop_front() {
                self.analyze_block(opt, block, &mut state);
            }
        }

        /// Extend divergent liveness to the preceding uniform block. Shared memory is always
        /// uniformly declared, so it must be allocated before the branch.
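        ///
        /// For example, a shared memory that is only used inside a non-uniform branch is also
        /// marked live in the block preceding the branch, so its slice is reserved before
        /// control flow diverges.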
        fn uniformize_liveness(&mut self, opt: &mut Optimizer) {
            let mut state = State {
                worklist: VecDeque::from(opt.analysis::<PostOrder>().forward()),
                block_sets: HashMap::new(),
            };
            while let Some(block) = state.worklist.pop_front() {
                self.uniformize_block(opt, block, &mut state);
            }
        }

        /// Allocate slices while ensuring no concurrent shared memory slices overlap.
        /// See also [`Self::allocate_slice`]
        fn allocate_slices(&mut self, opt: &mut Optimizer) {
            for block in opt.node_ids() {
                for live_smem in self.at_block(block).clone() {
                    if !self.allocations.contains_key(&live_smem) {
                        let smem = self.shared_memories[&live_smem];
                        let offset = self.allocate_slice(block, smem.size(), smem.align);
                        self.allocations
                            .insert(smem.id, SmemAllocation { smem, offset });
                    }
                }
            }
        }

        /// Finds a valid offset for a specific slice, taking into account ranges that are already
        /// in use.
        ///
        /// Essentially the same as the global memory pool, looking for a free slice first, then
        /// extending the pool if there isn't one. Note that this linear algorithm isn't optimal
        /// for offline allocations where we know all allocations beforehand, but should be good
        /// enough for our current purposes. It may produce larger-than-required allocations in
        /// some cases. Optimal allocation would require a far more complex algorithm.
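        ///
        /// For example, with live slices occupying `[0, 32)` and `[64, 128)`, a request for 16
        /// bytes with 16-byte alignment fits in the gap and is placed at offset 32, while a
        /// request for 48 bytes does not fit and is appended at offset 128 instead.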
        fn allocate_slice(&mut self, block: NodeIndex, size: u32, align: u32) -> u32 {
            let live_slices = self.live_slices(block);
            if live_slices.is_empty() {
                return 0;
            }

            for i in 0..live_slices.len() - 1 {
                let slice_0 = &live_slices[i];
                let slice_1 = &live_slices[i + 1];
                let end_0 = (slice_0.offset + slice_0.smem.size()).next_multiple_of(align);
                // Saturate so the gap is treated as empty when aligning the end of `slice_0`
                // overshoots the start of `slice_1`, instead of underflowing
                let gap = slice_1.offset.saturating_sub(end_0);
                if gap >= size {
                    return end_0;
                }
            }
            let last_slice = &live_slices[live_slices.len() - 1];
            (last_slice.offset + last_slice.smem.size()).next_multiple_of(align)
        }

        /// List of allocations that are currently live
        fn live_slices(&mut self, block: NodeIndex) -> Vec<SmemAllocation> {
            let mut live_slices = self
                .allocations
                .iter()
                .filter(|(k, _)| self.is_live(block, **k))
                .map(|it| *it.1)
                .collect::<Vec<_>>();
            live_slices.sort_by_key(|it| it.offset);
            live_slices
        }

        fn analyze_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
            let BlockSets { generated, kill } = self.block_sets(opt, block, state);

            let mut live_vars = generated.clone();

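            // Forward propagation: anything live out of a predecessor remains live here unless
            // this block frees (kills) it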
            for predecessor in opt.predecessors(block) {
                let predecessor = &self.live_vars[&predecessor];
                live_vars.extend(predecessor.difference(kill));
            }

            if live_vars != self.live_vars[&block] {
                state.worklist.extend(opt.successors(block));
                self.live_vars.insert(block, live_vars);
            }
        }

        fn uniformize_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
            let mut live_vars = self.live_vars[&block].clone();
            let uniformity = opt.analysis::<Uniformity>();

            for successor in opt.successors(block) {
                if !uniformity.is_block_uniform(successor) {
                    let successor = &self.live_vars[&successor];
                    live_vars.extend(successor);
                }
            }

            if live_vars != self.live_vars[&block] {
                state.worklist.extend(opt.predecessors(block));
                self.live_vars.insert(block, live_vars);
            }
        }

        fn block_sets<'a>(
            &mut self,
            opt: &mut Optimizer,
            block: NodeIndex,
            state: &'a mut State,
        ) -> &'a BlockSets {
            let block_sets = state.block_sets.entry(block);
            block_sets.or_insert_with(|| self.calculate_block_sets(opt, block))
        }

        /// Any use makes a shared memory live (`generated`), while `free` kills it (`kill`).
        /// Also collects all shared memories into a map.
        fn calculate_block_sets(&mut self, opt: &mut Optimizer, block: NodeIndex) -> BlockSets {
            let mut generated = HashSet::new();
            let mut kill = HashSet::new();

            let ops = opt.program[block].ops.clone();

            for op in ops.borrow_mut().values_mut() {
                opt.visit_out(&mut op.out, |_, var| {
                    if let Some(smem) = shared_memory(var) {
                        generated.insert(smem.id);
                        self.shared_memories.insert(smem.id, smem);
                    }
                });
                opt.visit_operation(&mut op.operation, &mut op.out, |_, var| {
                    if let Some(smem) = shared_memory(var) {
                        generated.insert(smem.id);
                        self.shared_memories.insert(smem.id, smem);
                    }
                });

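                // `Free` ends the lifetime: the shared memory is killed here and no longer
                // counts as generated by this block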
                if let Operation::Free(Variable {
                    kind: VariableKind::SharedMemory { id, .. },
                    ..
                }) = &op.operation
                {
                    kill.insert(*id);
                    generated.remove(id);
                }
            }

            BlockSets { generated, kill }
        }
    }

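    /// Extracts a [`SharedMemory`] descriptor from a variable, if it refers to shared memory.
    /// The length is scaled by the unroll factor, and a missing alignment defaults to the size
    /// of the element type.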
    fn shared_memory(var: &Variable) -> Option<SharedMemory> {
        if let Variable {
            kind:
                VariableKind::SharedMemory {
                    id,
                    length,
                    unroll_factor,
                    alignment,
                },
            ..
        } = *var
        {
            Some(SharedMemory {
                id,
                length: length * unroll_factor,
                ty: var.ty,
                align: alignment.unwrap_or_else(|| var.ty.size() as u32),
            })
        } else {
            None
        }
    }
}