cubecl_opt/analyses/liveness.rs

use std::collections::{HashMap, HashSet, VecDeque};

use cubecl_ir::Id;

use petgraph::graph::NodeIndex;

use crate::{Optimizer, analyses::post_order::PostOrder};

use super::Analysis;
/// Block-level liveness analysis: the set of variables that are live at each
/// block of the control flow graph.
pub struct Liveness {
    live_vars: HashMap<NodeIndex, HashSet<Id>>,
}

/// Per-block `gen`/`kill` sets used by the dataflow analyses below.
#[derive(Clone)]
struct BlockSets {
    generated: HashSet<Id>,
    kill: HashSet<Id>,
}

/// Worklist and cached block sets for the fixed-point iteration.
struct State {
    worklist: VecDeque<NodeIndex>,
    block_sets: HashMap<NodeIndex, BlockSets>,
}

impl Analysis for Liveness {
    fn init(opt: &mut Optimizer) -> Self {
        let mut this = Self::empty(opt);
        this.analyze_liveness(opt);
        this
    }
}

impl Liveness {
    pub fn empty(opt: &Optimizer) -> Self {
        let live_vars = opt
            .node_ids()
            .iter()
            .map(|it| (*it, HashSet::new()))
            .collect();
        Self { live_vars }
    }

    pub fn at_block(&self, block: NodeIndex) -> &HashSet<Id> {
        &self.live_vars[&block]
    }

    pub fn is_dead(&self, node: NodeIndex, var: Id) -> bool {
        !self.at_block(node).contains(&var)
    }

    /// Do a conservative block-level liveness analysis
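    ///
    /// A rough sketch of the backward dataflow iterated to a fixed point here,
    /// where `gen`/`kill` are the per-block sets from `calculate_block_sets`
    /// (pseudocode, not a doctest):
    ///
    /// ```text
    /// live(B) = gen(B) ∪ ⋃ { live(S) \ kill(B) | S ∈ successors(B) }
    /// ```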
    pub fn analyze_liveness(&mut self, opt: &mut Optimizer) {
        let mut state = State {
            worklist: VecDeque::from(opt.analysis::<PostOrder>().forward()),
            block_sets: HashMap::new(),
        };
        while let Some(block) = state.worklist.pop_front() {
            self.analyze_block(opt, block, &mut state);
        }
    }

    fn analyze_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
        let BlockSets { generated, kill } = block_sets(opt, block, state);

        let mut live_vars = generated.clone();

        for successor in opt.successors(block) {
            let successor = &self.live_vars[&successor];
            live_vars.extend(successor.difference(kill));
        }

        if live_vars != self.live_vars[&block] {
            state.worklist.extend(opt.predecessors(block));
            self.live_vars.insert(block, live_vars);
        }
    }
}

fn block_sets<'a>(opt: &mut Optimizer, block: NodeIndex, state: &'a mut State) -> &'a BlockSets {
    let block_sets = state.block_sets.entry(block);
    block_sets.or_insert_with(|| calculate_block_sets(opt, block))
}

fn calculate_block_sets(opt: &mut Optimizer, block: NodeIndex) -> BlockSets {
    let mut generated = HashSet::new();
    let mut kill = HashSet::new();

    let ops = opt.program[block].ops.clone();

    for op in ops.borrow_mut().values_mut().rev() {
        // Reads must be tracked after writes
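        // (within one instruction the write is visited first, so a variable that
        // is both read and written, e.g. `a = a + 1`, is put back into `generated`
        // by the read and stays live into the block)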
        opt.visit_out(&mut op.out, |opt, var| {
            if let Some(id) = opt.local_variable_id(var) {
                kill.insert(id);
                generated.remove(&id);
            }
        });
        opt.visit_operation(&mut op.operation, &mut op.out, |opt, var| {
            if let Some(id) = opt.local_variable_id(var) {
                generated.insert(id);
            }
        });
    }

    BlockSets { generated, kill }
}

/// Shared memory liveness analysis and allocation
pub mod shared {
    use cubecl_ir::{Marker, Operation, Type, Variable, VariableKind};

    use crate::Uniformity;

    use super::*;

    /// A shared memory instance, carrying all the information contained in the
    /// `VariableKind`, but with a non-optional `align`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub enum SharedMemory {
        Array {
            id: Id,
            length: u32,
            ty: Type,
            align: u32,
        },
        Value {
            id: Id,
            ty: Type,
            align: u32,
        },
    }

    impl SharedMemory {
        /// The ID of this shared memory
        pub fn id(&self) -> u32 {
            match self {
                SharedMemory::Array { id, .. } => *id,
                SharedMemory::Value { id, .. } => *id,
            }
        }

        /// The byte size of this shared memory
        pub fn size(&self) -> u32 {
            match self {
                SharedMemory::Array { length, ty, .. } => length * ty.size() as u32,
                SharedMemory::Value { ty, .. } => ty.size() as u32,
            }
        }

        /// The alignment requirement of this shared memory
        pub fn align(&self) -> u32 {
            match self {
                SharedMemory::Array { align, .. } => *align,
                SharedMemory::Value { align, .. } => *align,
            }
        }
    }

    /// A specific allocation of shared memory at some `offset`
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct SmemAllocation {
        /// The shared memory being allocated
        pub smem: SharedMemory,
        /// The offset in the shared memory buffer
        pub offset: u32,
    }

    /// Shared liveness works the other way around from normal liveness, since shared memory lives
    /// forever by default. Any use (read or write) marks it as live, while only `free` changes
    /// the state to dead.
    ///
    /// It also handles allocation of slices to each shared memory object, using the analyzed
    /// liveness. `allocations` contains a specific slice allocation for each shared memory, while
    /// ensuring that shared memories which are live at the same time never overlap.
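    ///
    /// Illustrative usage sketch (not a doctest; assumes an already built
    /// [`Optimizer`] value `opt`):
    ///
    /// ```ignore
    /// let shared = opt.analysis::<SharedLiveness>();
    /// for (id, alloc) in &shared.allocations {
    ///     println!("smem {id}: offset {}, size {}", alloc.offset, alloc.smem.size());
    /// }
    /// ```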
    #[derive(Default, Clone)]
    pub struct SharedLiveness {
        live_vars: HashMap<NodeIndex, HashSet<Id>>,
        /// Map of all shared memories by their ID. Populated during the first pass with all
        /// accessed shared memories.
        pub shared_memories: HashMap<Id, SharedMemory>,
        /// Map of allocations for each shared memory by its ID. Populated after the analysis, and
        /// should contain all memories from `shared_memories`.
        pub allocations: HashMap<Id, SmemAllocation>,
    }

    impl Analysis for SharedLiveness {
        fn init(opt: &mut Optimizer) -> Self {
            let mut this = Self::empty(opt);
            this.analyze_liveness(opt);
            this.uniformize_liveness(opt);
            this.allocate_slices(opt);
            this
        }
    }

    impl SharedLiveness {
        pub fn empty(opt: &Optimizer) -> Self {
            let live_vars = opt
                .node_ids()
                .iter()
                .map(|it| (*it, HashSet::new()))
                .collect();
            Self {
                live_vars,
                shared_memories: Default::default(),
                allocations: Default::default(),
            }
        }

        pub fn at_block(&self, block: NodeIndex) -> &HashSet<Id> {
            &self.live_vars[&block]
        }

        fn is_live(&self, node: NodeIndex, var: Id) -> bool {
            self.at_block(node).contains(&var)
        }

        /// Do a conservative block-level liveness analysis
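        ///
        /// Unlike [`Liveness`], this propagates forward through the CFG. A rough
        /// sketch of the fixed point iterated here (pseudocode, not a doctest):
        ///
        /// ```text
        /// live(B) = gen(B) ∪ ⋃ { live(P) \ kill(B) | P ∈ predecessors(B) }
        /// ```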
        fn analyze_liveness(&mut self, opt: &mut Optimizer) {
            let mut state = State {
                worklist: VecDeque::from(opt.analysis::<PostOrder>().reverse()),
                block_sets: HashMap::new(),
            };
            while let Some(block) = state.worklist.pop_front() {
                self.analyze_block(opt, block, &mut state);
            }
        }

        /// Extend divergent liveness to the preceding uniform block. Shared memory is always
        /// uniformly declared, so it must be allocated before the branch.
        fn uniformize_liveness(&mut self, opt: &mut Optimizer) {
            let mut state = State {
                worklist: VecDeque::from(opt.analysis::<PostOrder>().forward()),
                block_sets: HashMap::new(),
            };
            while let Some(block) = state.worklist.pop_front() {
                self.uniformize_block(opt, block, &mut state);
            }
        }

        /// Allocate slices while ensuring no concurrent shared memory slices overlap.
        /// See also [`allocate_slice`]
        fn allocate_slices(&mut self, opt: &mut Optimizer) {
            for block in opt.node_ids() {
                for live_smem in self.at_block(block).clone() {
                    if !self.allocations.contains_key(&live_smem) {
                        let smem = self.shared_memories[&live_smem];
                        let offset = self.allocate_slice(block, smem.size(), smem.align());
                        self.allocations
                            .insert(smem.id(), SmemAllocation { smem, offset });
                    }
                }
            }
        }

        /// Finds a valid offset for a specific slice, taking into account ranges that are already
        /// in use.
        ///
        /// Essentially the same as the global memory pool, looking for a free slice first, then
        /// extending the pool if there isn't one. Note that this linear algorithm isn't optimal
        /// for offline allocations where we know all allocations beforehand, but should be good
        /// enough for our current purposes. It may produce larger-than-required allocations in
        /// some cases. Optimal allocation would require a far more complex algorithm.
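        ///
        /// Illustrative example (made-up numbers, not a doctest):
        ///
        /// ```ignore
        /// // live allocations at this block: [0, 64) and [128, 192)
        /// allocate_slice(block, /* size */ 32, /* align */ 16); // -> 64 (fits in the gap)
        /// allocate_slice(block, /* size */ 96, /* align */ 16); // -> 192 (appended at the end)
        /// ```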
        fn allocate_slice(&mut self, block: NodeIndex, size: u32, align: u32) -> u32 {
            let live_slices = self.live_slices(block);
            if live_slices.is_empty() {
                return 0;
            }

            for i in 0..live_slices.len() - 1 {
                let slice_0 = &live_slices[i];
                let slice_1 = &live_slices[i + 1];
                let end_0 = (slice_0.offset + slice_0.smem.size()).next_multiple_of(align);
                let gap = slice_1.offset.saturating_sub(end_0);
                if gap >= size {
                    return end_0;
                }
            }
            let last_slice = &live_slices[live_slices.len() - 1];
            (last_slice.offset + last_slice.smem.size()).next_multiple_of(align)
        }

        /// List of allocations that are currently live
        fn live_slices(&mut self, block: NodeIndex) -> Vec<SmemAllocation> {
            let mut live_slices = self
                .allocations
                .iter()
                .filter(|(k, _)| self.is_live(block, **k))
                .map(|it| *it.1)
                .collect::<Vec<_>>();
            live_slices.sort_by_key(|it| it.offset);
            live_slices
        }

        fn analyze_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
            let BlockSets { generated, kill } = self.block_sets(opt, block, state);

            let mut live_vars = generated.clone();

            for predecessor in opt.predecessors(block) {
                let predecessor = &self.live_vars[&predecessor];
                live_vars.extend(predecessor.difference(kill));
            }

            if live_vars != self.live_vars[&block] {
                state.worklist.extend(opt.successors(block));
                self.live_vars.insert(block, live_vars);
            }
        }

        fn uniformize_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
            let mut live_vars = self.live_vars[&block].clone();
            let uniformity = opt.analysis::<Uniformity>();

            for successor in opt.successors(block) {
                if !uniformity.is_block_uniform(successor) {
                    let successor = &self.live_vars[&successor];
                    live_vars.extend(successor);
                }
            }

            if live_vars != self.live_vars[&block] {
                state.worklist.extend(opt.predecessors(block));
                self.live_vars.insert(block, live_vars);
            }
        }

        fn block_sets<'a>(
            &mut self,
            opt: &mut Optimizer,
            block: NodeIndex,
            state: &'a mut State,
        ) -> &'a BlockSets {
            let block_sets = state.block_sets.entry(block);
            block_sets.or_insert_with(|| self.calculate_block_sets(opt, block))
        }

        /// Any use makes a shared memory live (`generated`), while `free` kills it (`kill`).
        /// Also collects all shared memories into a map.
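        ///
        /// Illustrative sketch (pseudo-IR, not a doctest): a block that uses a
        /// shared memory and later frees it contributes nothing to `generated`
        /// for that id, since the `Free` marker removes it again:
        ///
        /// ```ignore
        /// use(smem)   // generated = {smem}
        /// Free(smem)  // kill = {smem}, generated = {}
        /// ```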
        fn calculate_block_sets(&mut self, opt: &mut Optimizer, block: NodeIndex) -> BlockSets {
            let mut generated = HashSet::new();
            let mut kill = HashSet::new();

            let ops = opt.program[block].ops.clone();

            for op in ops.borrow_mut().values_mut() {
                opt.visit_out(&mut op.out, |_, var| {
                    if let Some(smem) = shared_memory(var) {
                        generated.insert(smem.id());
                        self.shared_memories.insert(smem.id(), smem);
                    }
                });
                opt.visit_operation(&mut op.operation, &mut op.out, |_, var| {
                    if let Some(smem) = shared_memory(var) {
                        generated.insert(smem.id());
                        self.shared_memories.insert(smem.id(), smem);
                    }
                });

                if let Operation::Marker(Marker::Free(Variable {
                    kind: VariableKind::SharedArray { id, .. } | VariableKind::Shared { id },
                    ..
                })) = &op.operation
                {
                    kill.insert(*id);
                    generated.remove(id);
                }
            }

            BlockSets { generated, kill }
        }
    }

    /// Extract a [`SharedMemory`] descriptor from a variable, if it refers to shared memory.
    /// The alignment falls back to `var.ty.size()` when no explicit alignment is set.
    fn shared_memory(var: &Variable) -> Option<SharedMemory> {
        match var.kind {
            VariableKind::SharedArray {
                id,
                length,
                unroll_factor,
                alignment,
            } => Some(SharedMemory::Array {
                id,
                length: length * unroll_factor,
                ty: var.ty,
                align: alignment.unwrap_or_else(|| var.ty.size() as u32),
            }),
            VariableKind::Shared { id } => Some(SharedMemory::Value {
                id,
                ty: var.ty,
                align: var.ty.size() as u32,
            }),
            _ => None,
        }
    }
}