Skip to main content

cubecl_opt/analyses/
liveness.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2
3use cubecl_ir::Id;
4use petgraph::graph::NodeIndex;
5
6use crate::{Optimizer, analyses::post_order::PostOrder};
7
8use super::Analysis;
9
10pub struct Liveness {
11    live_vars: HashMap<NodeIndex, HashSet<Id>>,
12}
13
14#[derive(Clone)]
15struct BlockSets {
16    generated: HashSet<Id>,
17    kill: HashSet<Id>,
18}
19
20struct State {
21    worklist: VecDeque<NodeIndex>,
22    block_sets: HashMap<NodeIndex, BlockSets>,
23}
24
25impl Analysis for Liveness {
26    fn init(opt: &mut Optimizer) -> Self {
27        let mut this = Self::empty(opt);
28        this.analyze_liveness(opt);
29        this
30    }
31}
32
33impl Liveness {
34    pub fn empty(opt: &Optimizer) -> Self {
35        let live_vars = opt
36            .node_ids()
37            .iter()
38            .map(|it| (*it, HashSet::new()))
39            .collect();
40        Self { live_vars }
41    }
42
43    pub fn at_block(&self, block: NodeIndex) -> &HashSet<Id> {
44        &self.live_vars[&block]
45    }
46
47    pub fn is_dead(&self, node: NodeIndex, var: Id) -> bool {
48        !self.at_block(node).contains(&var)
49    }
50
51    /// Do a conservative block level liveness analysis
52    pub fn analyze_liveness(&mut self, opt: &mut Optimizer) {
53        let mut state = State {
54            worklist: VecDeque::from(opt.analysis::<PostOrder>().forward()),
55            block_sets: HashMap::new(),
56        };
57        while let Some(block) = state.worklist.pop_front() {
58            self.analyze_block(opt, block, &mut state);
59        }
60    }
61
62    fn analyze_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
63        let BlockSets { generated, kill } = block_sets(opt, block, state);
64
65        let mut live_vars = generated.clone();
66
67        for successor in opt.successors(block) {
68            let successor = &self.live_vars[&successor];
69            live_vars.extend(successor.difference(kill));
70        }
71
72        if live_vars != self.live_vars[&block] {
73            state.worklist.extend(opt.predecessors(block));
74            self.live_vars.insert(block, live_vars);
75        }
76    }
77}
78
79fn block_sets<'a>(opt: &mut Optimizer, block: NodeIndex, state: &'a mut State) -> &'a BlockSets {
80    let block_sets = state.block_sets.entry(block);
81    block_sets.or_insert_with(|| calculate_block_sets(opt, block))
82}
83
84fn calculate_block_sets(opt: &mut Optimizer, block: NodeIndex) -> BlockSets {
85    let mut generated = HashSet::new();
86    let mut kill = HashSet::new();
87
88    let ops = opt.program[block].ops.clone();
89
90    let control_flow = opt.program[block].control_flow.clone();
91    opt.visit_control_flow(&mut control_flow.borrow_mut(), |opt, var| {
92        if let Some(id) = opt.local_variable_id(var) {
93            generated.insert(id);
94        }
95    });
96    for op in ops.borrow_mut().values_mut().rev() {
97        // Reads must be tracked after writes
98        opt.visit_out(&mut op.out, |opt, var| {
99            if let Some(id) = opt.local_variable_id(var) {
100                kill.insert(id);
101                generated.remove(&id);
102            }
103        });
104        opt.visit_operation(&mut op.operation, &mut op.out, |opt, var| {
105            if let Some(id) = opt.local_variable_id(var) {
106                generated.insert(id);
107            }
108        });
109    }
110
111    BlockSets { generated, kill }
112}
113
114/// Shared memory liveness analysis and allocation
115pub mod shared {
116    use cubecl_ir::{Marker, Operation, Type, Variable, VariableKind};
117
118    use crate::Uniformity;
119
120    use super::*;
121
122    /// A shared memory instance, all the information contained in the `VariableKind`, but with
123    /// a non-optional `align`.
124    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
125    pub enum SharedMemory {
126        Array {
127            id: Id,
128            length: usize,
129            ty: Type,
130            align: usize,
131        },
132        Value {
133            id: Id,
134            ty: Type,
135            align: usize,
136        },
137    }
138
139    impl SharedMemory {
140        /// The byte size of this shared memory
141        pub fn id(&self) -> u32 {
142            match self {
143                SharedMemory::Array { id, .. } => *id,
144                SharedMemory::Value { id, .. } => *id,
145            }
146        }
147
148        /// The byte size of this shared memory
149        pub fn size(&self) -> usize {
150            match self {
151                SharedMemory::Array { length, ty, .. } => length * ty.size(),
152                SharedMemory::Value { ty, .. } => ty.size(),
153            }
154        }
155
156        pub fn align(&self) -> usize {
157            match self {
158                SharedMemory::Array { align, .. } => *align,
159                SharedMemory::Value { align, .. } => *align,
160            }
161        }
162    }
163
164    /// A specific allocation of shared memory at some `offset`
165    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
166    pub struct SmemAllocation {
167        /// The shared memory being allocated
168        pub smem: SharedMemory,
169        /// The offset in the shared memory buffer
170        pub offset: usize,
171    }
172
173    /// Shared liveness works the other way around from normal liveness, since shared memory lives
174    /// forever by default. So any use (read or write) inserts it as live, while only `free` changes
175    /// the state to dead.
176    ///
177    /// It also handles allocation of slices to each shared memory object, using the analyzed
178    /// liveness. `allocations` contains a specific slice allocation for each shared memory, while
179    /// ensuring no shared memories that exist at the same time can overlap.
180    #[derive(Default, Clone)]
181    pub struct SharedLiveness {
182        live_vars: HashMap<NodeIndex, HashSet<Id>>,
183        /// Map of all shared memories by their ID. Populated during the first pass with all
184        /// accessed shared memories.
185        pub shared_memories: HashMap<Id, SharedMemory>,
186        /// Map of allocations for each shared memory by its ID. Populated after the analysis, and
187        /// should contain all memories from `shared_memories`.
188        pub allocations: HashMap<Id, SmemAllocation>,
189    }
190
191    impl Analysis for SharedLiveness {
192        fn init(opt: &mut Optimizer) -> Self {
193            let mut this = Self::empty(opt);
194            this.analyze_liveness(opt);
195            this.uniformize_liveness(opt);
196            this.allocate_slices(opt);
197            this
198        }
199    }
200
201    impl SharedLiveness {
202        pub fn empty(opt: &Optimizer) -> Self {
203            let live_vars = opt
204                .node_ids()
205                .iter()
206                .map(|it| (*it, HashSet::new()))
207                .collect();
208            Self {
209                live_vars,
210                shared_memories: Default::default(),
211                allocations: Default::default(),
212            }
213        }
214
215        pub fn at_block(&self, block: NodeIndex) -> &HashSet<Id> {
216            &self.live_vars[&block]
217        }
218
219        fn is_live(&self, node: NodeIndex, var: Id) -> bool {
220            self.at_block(node).contains(&var)
221        }
222
223        /// Do a conservative block level liveness analysis
224        fn analyze_liveness(&mut self, opt: &mut Optimizer) {
225            let mut state = State {
226                worklist: VecDeque::from(opt.analysis::<PostOrder>().reverse()),
227                block_sets: HashMap::new(),
228            };
229            while let Some(block) = state.worklist.pop_front() {
230                self.analyze_block(opt, block, &mut state);
231            }
232        }
233
234        /// Extend divergent liveness to the preceding uniform block. Shared memory is always
235        /// uniformly declared, so it must be allocated before the branch.
236        fn uniformize_liveness(&mut self, opt: &mut Optimizer) {
237            let mut state = State {
238                worklist: VecDeque::from(opt.analysis::<PostOrder>().forward()),
239                block_sets: HashMap::new(),
240            };
241            while let Some(block) = state.worklist.pop_front() {
242                self.uniformize_block(opt, block, &mut state);
243            }
244        }
245
246        /// Allocate slices while ensuring no concurrent shared memory slices overlap.
247        /// See also [`allocate_slice`]
248        fn allocate_slices(&mut self, opt: &mut Optimizer) {
249            for block in opt.node_ids() {
250                for live_smem in self.at_block(block).clone() {
251                    if !self.allocations.contains_key(&live_smem) {
252                        let smem = self.shared_memories[&live_smem];
253                        let offset = self.allocate_slice(block, smem.size(), smem.align());
254                        self.allocations
255                            .insert(smem.id(), SmemAllocation { smem, offset });
256                    }
257                }
258            }
259        }
260
261        /// Finds a valid offset for a specific slice, taking into account ranges that are already
262        /// in use.
263        ///
264        /// Essentially the same as the global memory pool, looking for a free slice first, then
265        /// extending the pool if there isn't one. Note that this linear algorithm isn't optimal
266        /// for offline allocations where we know all allocations beforehand, but should be good
267        /// enough for our current purposes. It may produce larger-than-required allocations in
268        /// some cases. Optimal allocation would require a far more complex algorithm.
269        fn allocate_slice(&mut self, block: NodeIndex, size: usize, align: usize) -> usize {
270            let live_slices = self.live_slices(block);
271            if live_slices.is_empty() {
272                return 0;
273            }
274
275            for i in 0..live_slices.len() - 1 {
276                let slice_0 = &live_slices[i];
277                let slice_1 = &live_slices[i + 1];
278                let end_0 = (slice_0.offset + slice_0.smem.size()).next_multiple_of(align);
279                let gap = slice_1.offset.saturating_sub(end_0);
280                if gap >= size {
281                    return end_0;
282                }
283            }
284            let last_slice = &live_slices[live_slices.len() - 1];
285            (last_slice.offset + last_slice.smem.size()).next_multiple_of(align)
286        }
287
288        /// List of allocations that are currently live
289        fn live_slices(&mut self, block: NodeIndex) -> Vec<SmemAllocation> {
290            let mut live_slices = self
291                .allocations
292                .iter()
293                .filter(|(k, _)| self.is_live(block, **k))
294                .map(|it| *it.1)
295                .collect::<Vec<_>>();
296            live_slices.sort_by_key(|it| it.offset);
297            live_slices
298        }
299
300        fn analyze_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
301            let BlockSets { generated, kill } = self.block_sets(opt, block, state);
302
303            let mut live_vars = generated.clone();
304
305            for predecessor in opt.predecessors(block) {
306                let predecessor = &self.live_vars[&predecessor];
307                live_vars.extend(predecessor.difference(kill));
308            }
309
310            if live_vars != self.live_vars[&block] {
311                state.worklist.extend(opt.successors(block));
312                self.live_vars.insert(block, live_vars);
313            }
314        }
315
316        fn uniformize_block(&mut self, opt: &mut Optimizer, block: NodeIndex, state: &mut State) {
317            let mut live_vars = self.live_vars[&block].clone();
318            let uniformity = opt.analysis::<Uniformity>();
319
320            for successor in opt.successors(block) {
321                if !uniformity.is_block_uniform(successor) {
322                    let successor = &self.live_vars[&successor];
323                    live_vars.extend(successor);
324                }
325            }
326
327            if live_vars != self.live_vars[&block] {
328                state.worklist.extend(opt.predecessors(block));
329                self.live_vars.insert(block, live_vars);
330            }
331        }
332
333        fn block_sets<'a>(
334            &mut self,
335            opt: &mut Optimizer,
336            block: NodeIndex,
337            state: &'a mut State,
338        ) -> &'a BlockSets {
339            let block_sets = state.block_sets.entry(block);
340            block_sets.or_insert_with(|| self.calculate_block_sets(opt, block))
341        }
342
343        /// Any use makes a shared memory live (`generated`), while `free` kills it (`kill`).
344        /// Also collects all shared memories into a map.
345        fn calculate_block_sets(&mut self, opt: &mut Optimizer, block: NodeIndex) -> BlockSets {
346            let mut generated = HashSet::new();
347            let mut kill = HashSet::new();
348
349            let ops = opt.program[block].ops.clone();
350
351            for op in ops.borrow_mut().values_mut() {
352                opt.visit_out(&mut op.out, |_, var| {
353                    if let Some(smem) = shared_memory(var) {
354                        generated.insert(smem.id());
355                        self.shared_memories.insert(smem.id(), smem);
356                    }
357                });
358                opt.visit_operation(&mut op.operation, &mut op.out, |_, var| {
359                    if let Some(smem) = shared_memory(var) {
360                        generated.insert(smem.id());
361                        self.shared_memories.insert(smem.id(), smem);
362                    }
363                });
364
365                if let Operation::Marker(Marker::Free(Variable {
366                    kind: VariableKind::SharedArray { id, .. } | VariableKind::Shared { id },
367                    ..
368                })) = &op.operation
369                {
370                    kill.insert(*id);
371                    generated.remove(id);
372                }
373            }
374
375            BlockSets { generated, kill }
376        }
377    }
378
379    fn shared_memory(var: &Variable) -> Option<SharedMemory> {
380        match var.kind {
381            VariableKind::SharedArray {
382                id,
383                length,
384                unroll_factor,
385                alignment,
386            } => Some(SharedMemory::Array {
387                id,
388                length: length * unroll_factor,
389                ty: var.ty,
390                align: alignment.unwrap_or_else(|| var.ty.size()),
391            }),
392            VariableKind::Shared { id } => Some(SharedMemory::Value {
393                id,
394                ty: var.ty,
395                align: var.ty.size(),
396            }),
397            _ => None,
398        }
399    }
400}