zenforks-cubecl-opt 0.10.1

Compiler optimizations for CubeCL
Documentation
use std::{
    cell::RefCell,
    collections::{HashMap, HashSet, LinkedList, VecDeque},
    ops::Deref,
};

use crate::{ControlFlow, NodeIndex, analyses::Analysis};
use smallvec::SmallVec;

use crate::{
    Optimizer,
    analyses::dominance::{Dominators, PostDominators},
};

use super::{Expression, Value, ValueTable, convert::value_of_var};

/// Upper bound on the number of passes the backward set-building loop
/// (`GvnState::build_sets_backward`) performs before it stops, even if the sets are still
/// changing.
const MAX_SET_PASSES: usize = 10;

/// Analysis wrapper that holds the [`GvnState`] behind a [`RefCell`], so the state can be
/// mutated through a shared reference to the analysis. Dereferences to the inner `RefCell`.
#[derive(Default)]
pub struct GlobalValues(pub RefCell<GvnState>);

impl Deref for GlobalValues {
    type Target = RefCell<GvnState>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// State of the global value numbering analysis: the value table plus the per-block set
/// annotations derived from it.
#[derive(Debug, Clone, Default)]
pub struct GvnState {
    /// Table used to look up or assign value numbers for variables, phi nodes and expressions.
    pub values: ValueTable,
    /// Set annotations for each block, keyed by the block's node index. Filled in by
    /// [`GvnState::build_sets`].
    pub block_sets: HashMap<NodeIndex, BlockSets>,
}

impl Analysis for GlobalValues {
    /// Create the analysis by starting from an empty [`GvnState`] and building the set
    /// annotations for every block of the program held by `opt`.
    fn init(opt: &mut Optimizer) -> Self {
        let mut state = GvnState::default();
        state.build_sets(opt);
        Self(RefCell::new(state))
    }
}

/// The set annotations for a given block.
///
/// The `u32` keys and tuple tags used throughout are the value numbers assigned by the value
/// table. The first four sets are produced by the forward pass, the `antic_*` sets by the
/// backward pass (see [`GvnState::build_sets`]).
#[derive(Debug, Clone, Default)]
pub struct BlockSets {
    /// Expressions generated in this block
    pub exp_gen: LinkedList<(u32, Expression)>,
    /// Phi nodes that create new values in this block
    pub phi_gen: HashMap<u32, Value>,
    /// Temporaries that are assigned black box values (i.e. atomics, index to mutable array)
    pub tmp_gen: HashSet<Value>,
    /// The set of leaders for each value. This is the first temporary that contains the expression
    /// on any given path.
    pub leaders: HashMap<u32, Value>,

    /// The set of anticipated ("requested") expressions at the output point of the block
    pub antic_out: LinkedList<(u32, Expression)>,
    /// The set of anticipated ("requested") expressions at the input point of the block
    pub antic_in: LinkedList<(u32, Expression)>,
}

impl GvnState {
    /// Build the set annotations for every block. This runs in three steps:
    /// 1. A forward data flow pass that collects the available expressions, values and leaders
    ///    of each block.
    /// 2. A backward fixed-point data flow pass that collects the anticipated
    ///    expressions/antileaders of each block.
    /// 3. Every numbered constant, input, scalar, const array, builtin and output value is
    ///    registered as a leader of its value number in *every* block.
    pub fn build_sets(&mut self, opt: &mut Optimizer) {
        self.build_sets_forward(opt);
        self.build_sets_backward(opt);

        // Collect value number -> value for the kinds of values listed in step 3 above.
        let mut global_leaders = HashMap::new();
        for (value, num) in self.values.value_numbers.iter() {
            let is_global = matches!(
                value,
                Value::Constant(..)
                    | Value::Input(..)
                    | Value::Scalar(..)
                    | Value::ConstArray(..)
                    | Value::Builtin(..)
                    | Value::Output(..)
            );
            if is_global {
                global_leaders.insert(*num, *value);
            }
        }

        // Global leaders take precedence over whatever the forward pass recorded for the number.
        for sets in self.block_sets.values_mut() {
            sets.leaders.extend(&global_leaders);
        }
    }

    /// Run the forward pass over the dominator tree, starting at the entry block.
    ///
    /// Blocks are visited breadth-first. Each block starts from a copy of the leaders and
    /// killed temporaries produced by the block that immediately dominates it, and the entry
    /// block starts from empty sets.
    fn build_sets_forward(&mut self, opt: &mut Optimizer) {
        let dominators = opt.analysis::<Dominators>();

        let mut pending = VecDeque::new();
        pending.push_back((opt.entry(), HashMap::new(), HashSet::new()));

        while let Some((block, inherited_leaders, inherited_tmp_gen)) = pending.pop_front() {
            let (leaders, tmp_gen) =
                self.build_block_sets_forward(opt, block, inherited_leaders, inherited_tmp_gen);
            // Every immediately dominated block gets its own copy of this block's results.
            for dominated in dominators.immediately_dominated_by(block) {
                pending.push_back((dominated, leaders.clone(), tmp_gen.clone()));
            }
        }
    }

    /// Run the backward pass to a fixed point.
    ///
    /// Each pass walks the post-dominator tree breadth-first, starting at the return block,
    /// and recomputes the anticipated sets of every visited block. Passes are repeated until
    /// one of them changes nothing, or until `MAX_SET_PASSES` passes have run.
    fn build_sets_backward(&mut self, opt: &mut Optimizer) {
        let post_doms = opt.analysis::<PostDominators>();
        let mut worklist = VecDeque::new();

        let mut build_passes = 0;
        let mut changed = true;
        while changed && build_passes < MAX_SET_PASSES {
            changed = false;

            // The worklist is fully drained by the end of each pass, so it has to be re-seeded
            // here. Seeding it only once before the loop would make every pass after the first
            // a no-op and the analysis would never actually iterate to a fixed point.
            worklist.push_back(opt.ret);
            while let Some(current) = worklist.pop_front() {
                changed |= self.build_block_sets_backward(opt, current);
                worklist.extend(post_doms.immediately_dominated_by(current));
            }

            build_passes += 1;
        }
    }

    /// Compute the forward sets of a single block and store them in `self.block_sets`.
    ///
    /// `leaders` and `tmp_gen` are the sets inherited from the dominating block: leaders stay
    /// valid in dominated blocks because the variables that represent them are also available
    /// there. The returned pair is the updated `(leaders, tmp_gen)` at the end of this block,
    /// to be handed on to the blocks it dominates.
    fn build_block_sets_forward(
        &mut self,
        opt: &mut Optimizer,
        block: NodeIndex,
        mut leaders: HashMap<u32, Value>,
        mut tmp_gen: HashSet<Value>,
    ) -> (HashMap<u32, Value>, HashSet<Value>) {
        // Expressions used on the right hand side of an instruction in this block, plus the
        // set of values already recorded there (for local deduplication).
        let mut exp_gen = LinkedList::new();
        let mut added_exprs = HashSet::new();

        // Values produced by the phi nodes of this block. The phi's out variable becomes the
        // leader of its value unless one is already known.
        let mut phi_gen = HashMap::new();
        for phi in opt.program[block].phi_nodes.borrow().iter() {
            let (num, out) = self.values.lookup_or_add_phi(phi);
            leaders.entry(num).or_insert(out);
            phi_gen.entry(num).or_insert(out);
        }

        for op in opt.program[block].ops.borrow().values() {
            let outcome = self
                .values
                .maybe_insert_op(op, &mut exp_gen, &mut added_exprs);
            match outcome {
                // The operation defines a value: its out variable is a leader candidate.
                Ok((num, Some(out), _)) => {
                    leaders.entry(num).or_insert(out);
                }
                // Volatile right hand side: remember the out variable so that expressions
                // depending on it get killed later.
                Err(Some(killed)) => {
                    tmp_gen.insert(killed);
                }
                Ok(_) | Err(None) => {}
            }
        }

        self.block_sets.insert(
            block,
            BlockSets {
                exp_gen,
                phi_gen,
                tmp_gen: tmp_gen.clone(),
                leaders: leaders.clone(),
                // The anticipated sets are filled in by the backward pass.
                ..Default::default()
            },
        );
        (leaders, tmp_gen)
    }

    /// One step of the backward fixed-point data flow analysis that finds the expressions
    /// expected at each program point: recompute `antic_out` and `antic_in` of `current` from
    /// the sets currently stored for its successors. The caller walks the post-dominator tree
    /// because that is the fastest way to converge.
    ///
    /// Returns `true` if either set differs from the one stored before the call.
    fn build_block_sets_backward(&mut self, opt: &mut Optimizer, current: NodeIndex) -> bool {
        let successors = opt.successors(current);
        let control_flow = opt.block(current).control_flow.borrow().clone();

        // Since we have no critical edges, a block with more than one successor is the only
        // entry of each of them, so those successors have no phi nodes to translate.
        //
        // Loops are a special case because the conservative nature of PRE normally prevents
        // loop invariants from being moved out of the loop. Since only side-effect free values
        // are numbered, we can safely treat loops as being executed at least once: only the
        // loop body is considered. The worst case is that some expressions are executed
        // unnecessarily, but for a loop that never runs, performance is likely secondary.
        let antic_out = match control_flow {
            ControlFlow::Loop { body, .. } | ControlFlow::LoopBreak { body, .. } => {
                let body_sets = &self.block_sets[&body];
                Some(phi_translate(
                    opt,
                    &body_sets.phi_gen,
                    &body_sets.antic_in,
                    body,
                    current,
                    &mut self.values,
                ))
            }
            _ if successors.len() > 1 => {
                // Keep only the expressions whose value is expected at *all* successors.
                let first = &self.block_sets[&successors[0]].antic_in;
                let others = &successors[1..];
                let common = first
                    .iter()
                    .filter(|(num, _)| {
                        others.iter().all(|succ| {
                            self.block_sets[succ].antic_in.iter().any(|(n, _)| n == num)
                        })
                    })
                    .cloned()
                    .collect::<LinkedList<_>>();
                Some(common)
            }
            _ if successors.len() == 1 => {
                let child = successors[0];
                let child_sets = &self.block_sets[&child];
                Some(phi_translate(
                    opt,
                    &child_sets.phi_gen,
                    &child_sets.antic_in,
                    child,
                    current,
                    &mut self.values,
                ))
            }
            // No successors: `antic_out` is left untouched.
            _ => None,
        };

        let mut changed = false;
        if let Some(antic_out) = antic_out {
            let sets = self.block_sets.get_mut(&current).unwrap();
            changed = sets.antic_out != antic_out;
            sets.antic_out = antic_out;
        }

        // Value numbers of everything that was assigned a volatile value up to this block.
        let mut killed: HashSet<_> = self.block_sets[&current]
            .tmp_gen
            .iter()
            .map(|tmp| self.values.lookup_or_add_value(*tmp))
            .collect();

        // antic_in = (exp_gen ++ antic_out), minus killed expressions, deduplicated by value
        // number with the first occurrence winning.
        let sets = &self.block_sets[&current];
        let mut seen = HashSet::new();
        let mut antic_in = LinkedList::new();
        for (num, expr) in sets.exp_gen.iter().chain(sets.antic_out.iter()) {
            // An expression is killed if any dependency is killed or if it is volatile itself;
            // its own value is then killed too, so that later dependents are dropped as well.
            let is_killed = expr
                .depends_on()
                .into_iter()
                .any(|dependency| killed.contains(&dependency))
                || matches!(expr, Expression::Volatile(_));
            if is_killed {
                killed.insert(*num);
            } else if seen.insert(*num) {
                antic_in.push_back((*num, expr.clone()));
            }
        }

        let sets = self.block_sets.get_mut(&current).unwrap();
        changed |= sets.antic_in != antic_in;
        sets.antic_in = antic_in;

        changed
    }
}

/// Translate the anticipated expressions of `child` into the terms of its predecessor
/// `parent`: every value produced by a phi node of `child` is replaced by the value that phi
/// receives from `parent`, and expressions that depend on translated values are rebuilt and
/// renumbered accordingly.
///
/// * `phi_gen` - values generated by the phi nodes of `child`
/// * `antic` - anticipated expressions at the entry of `child`
/// * `values` - value table, used to look up or assign numbers for the translated expressions
///
/// Returns the translated expression list, in the order of `antic`. `Expression::Phi` entries
/// are dropped.
///
/// # Panics
///
/// Panics if a phi node of `child` has no entry for `parent`, or if a phi variable cannot be
/// resolved to a value.
pub fn phi_translate(
    opt: &Optimizer,
    phi_gen: &HashMap<u32, Value>,
    antic: &LinkedList<(u32, Expression)>,
    child: NodeIndex,
    parent: NodeIndex,
    values: &mut ValueTable,
) -> LinkedList<(u32, Expression)> {
    let phi_nodes = opt.block(child).phi_nodes.borrow();

    // Value number in `child` -> value number in `parent`, seeded with each phi's output
    // mapped to the input that arrives from `parent`.
    let mut translated = HashMap::new();
    for phi in phi_nodes.iter() {
        let (out_num, _) = values.lookup_or_add_phi(phi);
        let incoming = phi
            .entries
            .iter()
            .find(|entry| entry.block == parent)
            .unwrap();
        let in_num = values.lookup_or_add_var(&incoming.value).unwrap();
        translated.insert(out_num, in_num);
    }

    let mut result = LinkedList::new();
    for (num, expr) in antic {
        if let Some(phi_value) = phi_gen.get(num) {
            // The anticipated value is a phi output itself: replace it with the plain value
            // the matching phi node receives from `parent`.
            let matching = phi_nodes
                .iter()
                .find(|phi| value_of_var(&phi.out).unwrap() == *phi_value);
            if let Some(phi) = matching {
                let incoming = phi
                    .entries
                    .iter()
                    .find(|entry| entry.block == parent)
                    .unwrap();
                let value_here = value_of_var(&incoming.value).unwrap();
                let num_here = values.lookup_or_add_expr(Expression::Value(value_here), None);
                result.push_back((num_here, Expression::Value(value_here)));
                translated.insert(*num, num_here);
            }
        } else {
            // Numbers without a translation are the same in both blocks.
            let remap = |n: &u32| translated.get(n).copied().unwrap_or(*n);

            // Rewrite the dependencies of the expression. Since `antic` lists dependencies
            // before their dependents, translations propagate transitively from the phis.
            let updated = match expr {
                Expression::Instruction(inst) => {
                    let mut inst = inst.clone();
                    inst.args = inst.args.iter().map(remap).collect::<SmallVec<[u32; 4]>>();
                    Expression::Instruction(inst)
                }
                Expression::Copy(source, item) => Expression::Copy(remap(source), *item),
                Expression::Phi(_) => continue,
                other => other.clone(),
            };
            let updated_num = values.lookup_or_add_expr(updated.clone(), None);
            result.push_back((updated_num, updated));
            translated.insert(*num, updated_num);
        }
    }
    result
}