//! Conservative block-level liveness analysis (`cubecl_opt/passes/liveness.rs`).

use std::collections::{HashMap, HashSet, VecDeque};

use petgraph::graph::NodeIndex;

use crate::{visit_noop, Optimizer};

/// Per-block dataflow sets consumed by the liveness analysis.
///
/// Elements are the `(u16, u8)` keys returned by `Optimizer::local_variable_id`;
/// variables for which that lookup yields `None` are never tracked.
#[derive(Clone, Debug)]
struct BlockSets {
    /// Variables read in the block before any write to them in that same block.
    gen: HashSet<(u16, u8)>,
    /// Variables written anywhere in the block.
    kill: HashSet<(u16, u8)>,
}

/// Mutable bookkeeping for a single run of the liveness fixed-point iteration.
struct State {
    /// Gen/kill sets per block, filled lazily the first time a block is analyzed.
    block_sets: HashMap<NodeIndex, BlockSets>,
    /// Blocks whose live set still has to be (re)computed, consumed front to back.
    worklist: VecDeque<NodeIndex>,
}

impl Optimizer {
    /// Do a conservative block level liveness analysis.
    ///
    /// This is a backward dataflow fixed-point iteration. It is seeded with every
    /// block returned by `post_order()`. For each block, the set
    /// `gen ∪ (⋃ live_vars(successor) − kill)` is stored in `live_vars`; these are
    /// the variables live on entry to the block. When a block's stored set changes,
    /// its predecessors are queued again. The loop stops once the queue drains.
    ///
    /// NOTE(review): whatever `live_vars` already holds is used as the starting
    /// point; this assumes it is empty (or already sound) when the pass runs.
    pub fn analyze_liveness(&mut self) {
        let worklist: VecDeque<NodeIndex> = self.post_order().into();
        let mut state = State {
            worklist,
            block_sets: HashMap::new(),
        };
        while let Some(next) = state.worklist.pop_front() {
            self.analyze_block(next, &mut state);
        }
    }

    /// Recompute the live-in set of `block`.
    ///
    /// The new set is built from the block's gen/kill sets and the current live
    /// sets of its successors. Its predecessors are re-queued only when the result
    /// differs from the set stored on the block.
    fn analyze_block(&mut self, block: NodeIndex, state: &mut State) {
        let sets = self.block_sets(block, state);

        // live_in = gen ∪ (union of successor live sets − kill)
        let mut live_in = sets.gen.clone();
        live_in.extend(
            self.successors(block)
                .into_iter()
                .flat_map(|succ| self.program[succ].live_vars.difference(&sets.kill)),
        );

        if self.program[block].live_vars == live_in {
            return;
        }
        // Predecessor live sets are derived from this one, so revisit them.
        state.worklist.extend(self.predecessors(block));
        self.program[block].live_vars = live_in;
    }

    /// Return the gen/kill sets of `block`.
    ///
    /// The sets are computed and cached in `state.block_sets` on first request;
    /// later calls reuse the cached entry.
    fn block_sets<'a>(&mut self, block: NodeIndex, state: &'a mut State) -> &'a BlockSets {
        state
            .block_sets
            .entry(block)
            .or_insert_with(|| self.calculate_block_sets(block))
    }

    /// Compute the gen/kill sets of `block` by scanning its instructions from last
    /// to first.
    ///
    /// Only variables for which `local_variable_id` returns `Some` are tracked.
    /// `kill` receives every such variable written in the block. `gen` keeps those
    /// read before any write to them within the block.
    fn calculate_block_sets(&mut self, block: NodeIndex) -> BlockSets {
        let mut upward_exposed = HashSet::new();
        let mut defined = HashSet::new();

        // `ops` is cloned so its contents can be borrowed while `self` is passed to
        // `visit_operation` (presumably a shared `Rc<RefCell<..>>` handle — confirm).
        let instructions = self.program[block].ops.clone();
        let mut instructions = instructions.borrow_mut();

        for op in instructions.values_mut().rev() {
            // When walking backwards, an instruction's writes must be applied before
            // its reads. Otherwise a variable that the same instruction both reads
            // and writes would wrongly drop out of `gen`.
            self.visit_operation(op, visit_noop, |this, var| {
                if let Some(id) = this.local_variable_id(var) {
                    upward_exposed.remove(&id);
                    defined.insert(id);
                }
            });
            self.visit_operation(
                op,
                |this, var| {
                    if let Some(id) = this.local_variable_id(var) {
                        upward_exposed.insert(id);
                    }
                },
                visit_noop,
            );
        }

        BlockSets {
            gen: upward_exposed,
            kill: defined,
        }
    }
}