use std::{
cmp::Reverse,
collections::{BinaryHeap, HashMap, HashSet},
};
use rayon::prelude::*;
use crate::analysis::{
AnalysisResults, DataFlowSolver, LiveVariables, LivenessResult, SsaCfg, SsaFunction, SsaType,
SsaVarId, TypeClass, VariableOrigin,
};
/// Undirected interference graph over SSA variables.
///
/// Two variables joined by an edge must not be placed in the same local slot.
/// The graph also remembers each variable's type so that slot sharing can be
/// restricted to storage-compatible types.
pub struct InterferenceGraph {
    /// Adjacency sets; every edge is stored in both directions.
    edges: HashMap<SsaVarId, HashSet<SsaVarId>>,
    /// Type recorded for each variable via [`InterferenceGraph::set_type`].
    var_types: HashMap<SsaVarId, SsaType>,
}

impl InterferenceGraph {
    /// Creates a graph with no edges and no recorded types.
    fn new() -> Self {
        let edges = HashMap::new();
        let var_types = HashMap::new();
        Self { edges, var_types }
    }

    /// Records that `a` and `b` interfere. Self-edges are ignored and adding
    /// the same edge twice has no further effect.
    fn add_edge(&mut self, a: SsaVarId, b: SsaVarId) {
        if a == b {
            return;
        }
        // Store the edge symmetrically so lookups work from either endpoint.
        for (from, to) in [(a, b), (b, a)] {
            self.edges.entry(from).or_default().insert(to);
        }
    }

    /// Associates `ty` with `var`, replacing any previously recorded type.
    fn set_type(&mut self, var: SsaVarId, ty: SsaType) {
        self.var_types.insert(var, ty);
    }

    /// Iterates over the variables that interfere with `var` (empty when
    /// `var` has no edges).
    fn neighbors(&self, var: SsaVarId) -> impl Iterator<Item = SsaVarId> + '_ {
        self.edges.get(&var).into_iter().flatten().copied()
    }

    /// Number of distinct variables that interfere with `var`.
    fn degree(&self, var: SsaVarId) -> usize {
        match self.edges.get(&var) {
            Some(adjacent) => adjacent.len(),
            None => 0,
        }
    }
}
/// Result of assigning SSA variables to local slots.
pub struct LocalAllocation {
/// Slot index chosen for each SSA variable that received one.
pub var_to_local: HashMap<SsaVarId, u16>,
/// Number of slots in use: one past the highest slot index handed out.
pub num_locals: u16,
/// Maps the index carried by a `VariableOrigin::Local` variable to the
/// compacted slot index that was assigned to it.
pub original_to_compacted: HashMap<u16, u16>,
}
/// Half-open range `[start, end)` of linear instruction positions over which
/// a variable is considered live.
#[derive(Debug, Clone)]
struct LiveInterval {
    start: usize,
    end: usize,
}

impl LiveInterval {
    /// Creates the one-position interval `[pos, pos + 1)`.
    fn new(pos: usize) -> Self {
        let end = pos + 1;
        Self { start: pos, end }
    }

    /// Moves the start back to `pos` if `pos` lies before the current start.
    fn extend_start(&mut self, pos: usize) {
        if pos < self.start {
            self.start = pos;
        }
    }

    /// Moves the end forward to `pos` if `pos` lies after the current end.
    fn extend_end(&mut self, pos: usize) {
        if pos > self.end {
            self.end = pos;
        }
    }
}
/// Assigns SSA variables to local slots, sharing slots where possible.
///
/// Built with `LocalCoalescer::build`, then queried with
/// `LocalCoalescer::allocate`.
pub struct LocalCoalescer {
/// Interference graph used by the graph-coloring path; left empty when the
/// linear-scan path produced the allocation.
interference: InterferenceGraph,
/// Variables the graph-coloring path assigns slots to (those with `Stack`
/// or `Phi` origin); left empty on the linear-scan path.
coalescable_vars: Vec<SsaVarId>,
/// Allocation computed up front by the linear-scan path; `None` when graph
/// coloring is used and the allocation is computed in `allocate`.
precomputed: Option<LocalAllocation>,
}
/// Functions with more than this many SSA variables are allocated with linear
/// scan instead of graph coloring (see `LocalCoalescer::build`).
const LINEAR_SCAN_THRESHOLD: usize = 500;
impl LocalCoalescer {
/// Prepares a coalescer for `ssa`.
///
/// Functions whose variable count exceeds [`LINEAR_SCAN_THRESHOLD`] go
/// through the linear-scan path; all others use graph coloring.
pub fn build(ssa: &SsaFunction) -> Self {
    let use_linear_scan = ssa.variable_count() > LINEAR_SCAN_THRESHOLD;
    if use_linear_scan {
        Self::build_linear_scan(ssa)
    } else {
        Self::build_graph_coloring(ssa)
    }
}
/// Builds the interference graph for `ssa` and collects the variables that
/// are eligible for coalescing.
///
/// Edges come from three sources:
/// 1. every pair of variables that are live together at a block's exit or
///    at its entry,
/// 2. definitions versus the variables live at the definition point inside
///    each block,
/// 3. phi operands of the same block that arrive from the same predecessor.
///
/// The allocation itself is deferred to `allocate`.
fn build_graph_coloring(ssa: &SsaFunction) -> Self {
    // Appends every unordered pair drawn from `vars` to `out`.
    fn push_all_pairs(vars: &[SsaVarId], out: &mut Vec<(SsaVarId, SsaVarId)>) {
        for (idx, &first) in vars.iter().enumerate() {
            out.extend(vars[idx + 1..].iter().map(|&second| (first, second)));
        }
    }

    let mut graph = InterferenceGraph::new();
    for variable in ssa.variables() {
        graph.set_type(variable.id(), variable.var_type().clone());
    }

    // Solve liveness once; all edge collectors below read from it.
    let cfg = SsaCfg::from_ssa(ssa);
    let liveness_analysis = LiveVariables::new(ssa);
    let liveness_solver = DataFlowSolver::new(liveness_analysis);
    let liveness = liveness_solver.solve(ssa, &cfg);

    let block_ids: Vec<usize> = (0..ssa.block_count()).collect();

    // Variables that are live together at a block boundary interfere.
    let boundary_edges: Vec<(SsaVarId, SsaVarId)> = block_ids
        .par_iter()
        .flat_map(|&block_id| {
            let mut pairs = Vec::new();
            if let Some(live_out) = liveness.out_state(block_id) {
                let at_exit: Vec<SsaVarId> = live_out.variables().collect();
                push_all_pairs(&at_exit, &mut pairs);
            }
            if let Some(live_in) = liveness.in_state(block_id) {
                let at_entry: Vec<SsaVarId> = live_in.variables().collect();
                push_all_pairs(&at_entry, &mut pairs);
            }
            pairs
        })
        .collect();

    let intra_block_edges: Vec<(SsaVarId, SsaVarId)> = block_ids
        .par_iter()
        .flat_map(|&block_id| Self::collect_intra_block_edges(ssa, &liveness, block_id))
        .collect();

    // Phi operands that flow in over the same predecessor edge interfere.
    let phi_edge_edges: Vec<(SsaVarId, SsaVarId)> = ssa
        .blocks()
        .par_iter()
        .flat_map(|block| {
            let mut by_predecessor: HashMap<usize, Vec<SsaVarId>> = HashMap::new();
            for phi in block.phi_nodes() {
                for operand in phi.operands() {
                    by_predecessor
                        .entry(operand.predecessor())
                        .or_default()
                        .push(operand.value());
                }
            }
            let mut pairs = Vec::new();
            for incoming in by_predecessor.values() {
                push_all_pairs(incoming, &mut pairs);
            }
            pairs
        })
        .collect();

    let all_edges = boundary_edges
        .into_iter()
        .chain(intra_block_edges)
        .chain(phi_edge_edges);
    for (a, b) in all_edges {
        graph.add_edge(a, b);
    }

    // Only stack temporaries and phi results may share slots; arguments and
    // original locals are excluded here.
    let coalescable_vars: Vec<SsaVarId> = ssa
        .variables()
        .iter()
        .filter(|v| match v.origin() {
            VariableOrigin::Stack(_) | VariableOrigin::Phi => true,
            VariableOrigin::Argument(_) | VariableOrigin::Local(_) => false,
        })
        .map(|v| v.id())
        .collect();

    Self {
        interference: graph,
        coalescable_vars,
        precomputed: None,
    }
}
fn build_linear_scan(ssa: &SsaFunction) -> Self {
let mut var_types: HashMap<SsaVarId, SsaType> = HashMap::new();
for var in ssa.variables() {
var_types.insert(var.id(), var.var_type().clone());
}
let intervals = Self::compute_live_intervals(ssa);
let mut var_to_local: HashMap<SsaVarId, u16> = HashMap::new();
let mut next_local: u16 = 0;
let mut reserved_slots: HashSet<u16> = HashSet::new();
let mut original_to_new: HashMap<u16, u16> = HashMap::new();
let mut used_local_vars: Vec<(SsaVarId, u16)> = ssa
.variables()
.iter()
.filter_map(|v| {
if let VariableOrigin::Local(idx) = v.origin() {
if intervals.contains_key(&v.id()) {
return Some((v.id(), idx));
}
}
None
})
.collect();
used_local_vars.sort_by_key(|(_, idx)| *idx);
for (var_id, original_idx) in used_local_vars {
let new_slot = *original_to_new.entry(original_idx).or_insert_with(|| {
let slot = next_local;
next_local += 1;
slot
});
var_to_local.insert(var_id, new_slot);
reserved_slots.insert(new_slot);
}
let var_sort_key = |var_id: SsaVarId| -> (VariableOrigin, u32) {
ssa.variable(var_id)
.map_or((VariableOrigin::Stack(u32::MAX), u32::MAX), |v| {
(v.origin(), v.version())
})
};
let mut sorted_intervals: Vec<_> = intervals.into_iter().collect();
sorted_intervals.sort_by(|(var_id_a, interval_a), (var_id_b, interval_b)| {
interval_a
.start
.cmp(&interval_b.start)
.then_with(|| var_sort_key(*var_id_a).cmp(&var_sort_key(*var_id_b)))
});
let mut active: BinaryHeap<Reverse<(usize, u16, TypeClass)>> = BinaryHeap::new();
let mut free_slots: HashMap<TypeClass, Vec<u16>> = HashMap::new();
for (var_id, interval) in sorted_intervals {
if var_to_local.contains_key(&var_id) {
if let Some(slot) = var_to_local.get(&var_id) {
let type_class = var_types
.get(&var_id)
.map_or(TypeClass::Int32, SsaType::storage_class);
active.push(Reverse((interval.end, *slot, type_class)));
}
continue;
}
while let Some(&Reverse((end, slot, type_class))) = active.peek() {
if end > interval.start {
break;
}
active.pop();
if !reserved_slots.contains(&slot) {
free_slots.entry(type_class).or_default().push(slot);
}
}
let type_class = var_types
.get(&var_id)
.map_or(TypeClass::Int32, SsaType::storage_class);
let slot = free_slots.get_mut(&type_class).and_then(Vec::pop);
let slot = slot.unwrap_or_else(|| {
let s = next_local;
next_local += 1;
s
});
var_to_local.insert(var_id, slot);
active.push(Reverse((interval.end, slot, type_class)));
}
let allocation = LocalAllocation {
var_to_local,
num_locals: next_local,
original_to_compacted: original_to_new,
};
Self {
interference: InterferenceGraph::new(),
coalescable_vars: Vec::new(),
precomputed: Some(allocation),
}
}
/// Computes one live interval per variable over a linear numbering of the
/// function's instructions.
///
/// Blocks are visited in index order and every instruction gets the next
/// position. Phi nodes do not consume a position; they are attributed to the
/// position of the first instruction of their block.
///
/// Every interval covers all positions where the variable is defined or
/// used. In addition:
/// * a definition always extends the interval to include its own position,
///   even when a use was seen earlier in linear order (e.g. a phi operand
///   that arrives over a back-edge), so an interval can never end before
///   the variable is defined;
/// * a phi operand is kept live up to the end of the predecessor block it
///   flows in from, because that is where the value is read.
///
/// Both rules only ever enlarge intervals, which is conservative for slot
/// sharing.
///
/// NOTE(review): a value defined before a loop and used inside it is still
/// only covered up to its last use in linear order, not to the end of the
/// loop — confirm whether callers rely on additional loop handling.
fn compute_live_intervals(ssa: &SsaFunction) -> HashMap<SsaVarId, LiveInterval> {
    let block_count = ssa.block_count();

    // Pass 1: linear position one past the last instruction of each block,
    // so phi operands can be extended to the end of their predecessor.
    let mut block_ends: Vec<usize> = Vec::with_capacity(block_count);
    let mut running_total = 0usize;
    for block_id in 0..block_count {
        if let Some(block) = ssa.block(block_id) {
            running_total += block.instructions().len();
        }
        block_ends.push(running_total);
    }

    // Pass 2: grow intervals from definitions and uses.
    let mut intervals: HashMap<SsaVarId, LiveInterval> = HashMap::new();
    let mut instr_idx = 0usize;
    for block_id in 0..block_count {
        let Some(block) = ssa.block(block_id) else {
            continue;
        };

        for phi in block.phi_nodes() {
            let result_interval = intervals
                .entry(phi.result())
                .or_insert_with(|| LiveInterval::new(instr_idx));
            result_interval.extend_start(instr_idx);
            result_interval.extend_end(instr_idx + 1);

            for operand in phi.operands() {
                // Unknown predecessor ids fall back to 0, which leaves the
                // interval unchanged.
                let predecessor_end = block_ends
                    .get(operand.predecessor())
                    .copied()
                    .unwrap_or(0);
                let operand_interval = intervals
                    .entry(operand.value())
                    .or_insert_with(|| LiveInterval::new(instr_idx));
                operand_interval.extend_end(instr_idx + 1);
                operand_interval.extend_end(predecessor_end);
            }
        }

        for instr in block.instructions() {
            for &use_var in &instr.uses() {
                intervals
                    .entry(use_var)
                    .or_insert_with(|| LiveInterval::new(instr_idx))
                    .extend_end(instr_idx + 1);
            }
            if let Some(def) = instr.def() {
                let def_interval = intervals
                    .entry(def)
                    .or_insert_with(|| LiveInterval::new(instr_idx));
                def_interval.extend_start(instr_idx);
                // An earlier use (in linear order) may have created a short
                // interval; make sure it reaches the definition itself.
                def_interval.extend_end(instr_idx + 1);
            }
            instr_idx += 1;
        }
    }
    intervals
}
/// Collects interference edges that arise inside block `block_id`.
///
/// Starting from the block's live-out set, instructions are walked in
/// reverse. Each definition interferes with everything live just after it;
/// the definition is then removed from the live set and the instruction's
/// uses are added. Finally, every phi result of the block interferes with
/// whatever is live once all instructions have been processed.
///
/// Returns an empty list when the block does not exist.
fn collect_intra_block_edges(
    ssa: &SsaFunction,
    results: &AnalysisResults<LivenessResult>,
    block_id: usize,
) -> Vec<(SsaVarId, SsaVarId)> {
    let Some(block) = ssa.block(block_id) else {
        return Vec::new();
    };

    let mut live: HashSet<SsaVarId> = match results.out_state(block_id) {
        Some(state) => state.variables().collect(),
        None => HashSet::new(),
    };
    let mut edges = Vec::new();

    for instr in block.instructions().iter().rev() {
        if let Some(def) = instr.def() {
            edges.extend(
                live.iter()
                    .copied()
                    .filter(|&other| other != def)
                    .map(|other| (def, other)),
            );
            live.remove(&def);
        }
        live.extend(&instr.uses());
    }

    for phi in block.phi_nodes() {
        let def = phi.result();
        edges.extend(
            live.iter()
                .copied()
                .filter(|&other| other != def)
                .map(|other| (def, other)),
        );
    }

    edges
}
/// Returns the slot allocation for `ssa`.
///
/// When an allocation was precomputed while building the coalescer, a copy
/// of it is returned; otherwise the interference graph is colored now.
pub fn allocate(&self, ssa: &SsaFunction) -> LocalAllocation {
    match &self.precomputed {
        Some(cached) => LocalAllocation {
            var_to_local: cached.var_to_local.clone(),
            num_locals: cached.num_locals,
            original_to_compacted: cached.original_to_compacted.clone(),
        },
        None => self.allocate_graph_coloring(ssa),
    }
}
fn allocate_graph_coloring(&self, ssa: &SsaFunction) -> LocalAllocation {
let mut var_to_local: HashMap<SsaVarId, u16> = HashMap::new();
let mut next_local: u16 = 0;
let mut reserved_slots: HashSet<u16> = HashSet::new();
let mut used_local_vars: HashSet<SsaVarId> = HashSet::new();
for block in ssa.blocks() {
for phi in block.phi_nodes() {
if let Some(var) = ssa.variable(phi.result()) {
if matches!(var.origin(), VariableOrigin::Local(_)) {
used_local_vars.insert(phi.result());
}
}
for operand in phi.operands() {
if let Some(var) = ssa.variable(operand.value()) {
if matches!(var.origin(), VariableOrigin::Local(_)) {
used_local_vars.insert(operand.value());
}
}
}
}
for instr in block.instructions() {
let op = instr.op();
if let Some(dest) = op.dest() {
if let Some(var) = ssa.variable(dest) {
if matches!(var.origin(), VariableOrigin::Local(_)) {
used_local_vars.insert(dest);
}
}
}
for use_var in op.uses() {
if let Some(var) = ssa.variable(use_var) {
if matches!(var.origin(), VariableOrigin::Local(_)) {
used_local_vars.insert(use_var);
}
}
}
}
}
let mut local_vars: Vec<(SsaVarId, u16)> = ssa
.variables()
.iter()
.filter_map(|v| {
if let VariableOrigin::Local(idx) = v.origin() {
if used_local_vars.contains(&v.id()) {
return Some((v.id(), idx));
}
}
None
})
.collect();
local_vars.sort_by_key(|(_, idx)| *idx);
let mut original_to_new: HashMap<u16, u16> = HashMap::new();
for (var_id, original_idx) in local_vars {
let new_slot = *original_to_new.entry(original_idx).or_insert_with(|| {
let slot = next_local;
next_local += 1;
slot
});
var_to_local.insert(var_id, new_slot);
reserved_slots.insert(new_slot);
}
let var_sort_key = |var_id: SsaVarId| -> (VariableOrigin, u32) {
ssa.variable(var_id)
.map_or((VariableOrigin::Stack(u32::MAX), u32::MAX), |v| {
(v.origin(), v.version())
})
};
let mut sorted_vars = self.coalescable_vars.clone();
sorted_vars.sort_by(|a, b| {
let deg_a = self.interference.degree(*a);
let deg_b = self.interference.degree(*b);
deg_b
.cmp(°_a)
.then_with(|| var_sort_key(*a).cmp(&var_sort_key(*b)))
});
for var in sorted_vars {
let used_slots: HashSet<u16> = self
.interference
.neighbors(var)
.filter_map(|neighbor| var_to_local.get(&neighbor).copied())
.collect();
let var_type = self.interference.var_types.get(&var);
#[allow(clippy::maybe_infinite_iter)]
let slot = (0u16..)
.find(|&s| {
if reserved_slots.contains(&s) {
return false;
}
if used_slots.contains(&s) {
return false;
}
for (&other_var, &other_slot) in &var_to_local {
if other_slot == s {
let other_type = self.interference.var_types.get(&other_var);
if !types_compatible(var_type, other_type) {
return false;
}
}
}
true
})
.expect("Should always find a valid slot");
var_to_local.insert(var, slot);
next_local = next_local.max(slot + 1);
}
LocalAllocation {
var_to_local,
num_locals: next_local,
original_to_compacted: original_to_new,
}
}
}
/// Reports whether two variables may share a slot as far as their types are
/// concerned.
///
/// If either type is unknown (`None`) the pair is treated as compatible;
/// otherwise the decision is delegated to `SsaType::is_compatible_for_storage`.
fn types_compatible(t1: Option<&SsaType>, t2: Option<&SsaType>) -> bool {
    t1.zip(t2)
        .map_or(true, |(lhs, rhs)| lhs.is_compatible_for_storage(rhs))
}
// Unit tests for the interference graph, the storage-compatibility helper and
// the greedy graph-coloring allocation path.
#[cfg(test)]
mod tests {
use super::*;
/// Creates `n` variable ids by calling `SsaVarId::new()` `n` times.
fn make_vars(n: usize) -> Vec<SsaVarId> {
(0..n).map(|_| SsaVarId::new()).collect()
}
/// Edges are symmetric, duplicate edges are not counted twice, and
/// self-edges are ignored.
#[test]
fn test_interference_graph_add_edge() {
let mut graph = InterferenceGraph::new();
let vars = make_vars(3);
let (var_a, var_b, var_c) = (vars[0], vars[1], vars[2]);
graph.add_edge(var_a, var_b);
assert_eq!(graph.degree(var_a), 1);
assert_eq!(graph.degree(var_b), 1);
assert_eq!(graph.degree(var_c), 0);
let neighbors_a: Vec<_> = graph.neighbors(var_a).collect();
assert_eq!(neighbors_a, vec![var_b]);
let neighbors_b: Vec<_> = graph.neighbors(var_b).collect();
assert_eq!(neighbors_b, vec![var_a]);
// Adding the same edge again must not change the degree.
graph.add_edge(var_a, var_b);
assert_eq!(graph.degree(var_a), 1);
// A self-edge must be ignored.
graph.add_edge(var_a, var_a);
assert_eq!(graph.degree(var_a), 1);
}
/// A triangle gives each of its vertices degree 2; untouched variables
/// keep degree 0.
#[test]
fn test_interference_graph_multiple_edges() {
let mut graph = InterferenceGraph::new();
let vars = make_vars(5);
graph.add_edge(vars[0], vars[1]);
graph.add_edge(vars[0], vars[2]);
graph.add_edge(vars[1], vars[2]);
assert_eq!(graph.degree(vars[0]), 2);
assert_eq!(graph.degree(vars[1]), 2);
assert_eq!(graph.degree(vars[2]), 2);
assert_eq!(graph.degree(vars[3]), 0);
assert_eq!(graph.degree(vars[4]), 0);
}
/// Type pairs that are expected to be storage-compatible.
#[test]
fn test_type_compatibility_same_class() {
assert!(types_compatible(Some(&SsaType::I32), Some(&SsaType::I32)));
assert!(types_compatible(Some(&SsaType::I32), Some(&SsaType::U32)));
assert!(types_compatible(Some(&SsaType::I32), Some(&SsaType::Bool)));
assert!(types_compatible(Some(&SsaType::Bool), Some(&SsaType::Char)));
assert!(types_compatible(Some(&SsaType::I64), Some(&SsaType::U64)));
assert!(types_compatible(
Some(&SsaType::Object),
Some(&SsaType::String)
));
}
/// Type pairs that are expected NOT to be storage-compatible.
#[test]
fn test_type_compatibility_different_class() {
assert!(!types_compatible(Some(&SsaType::I32), Some(&SsaType::I64)));
assert!(!types_compatible(
Some(&SsaType::I32),
Some(&SsaType::Object)
));
assert!(!types_compatible(Some(&SsaType::F32), Some(&SsaType::F64)));
assert!(!types_compatible(Some(&SsaType::I32), Some(&SsaType::F32)));
}
/// An unknown type on either side is treated as compatible.
#[test]
fn test_type_compatibility_with_none() {
assert!(types_compatible(None, Some(&SsaType::I32)));
assert!(types_compatible(Some(&SsaType::I32), None));
assert!(types_compatible(None, None));
}
/// Two I32 variables without an interference edge share slot 0.
#[test]
fn test_greedy_coloring_non_interfering_same_slot() {
let mut graph = InterferenceGraph::new();
let vars = make_vars(2);
let (var_a, var_b) = (vars[0], vars[1]);
graph.set_type(var_a, SsaType::I32);
graph.set_type(var_b, SsaType::I32);
let coalescer = LocalCoalescer {
interference: graph,
coalescable_vars: vec![var_a, var_b],
precomputed: None,
};
let ssa = create_minimal_ssa_function();
let allocation = coalescer.allocate(&ssa);
assert_eq!(allocation.var_to_local.get(&var_a), Some(&0));
assert_eq!(allocation.var_to_local.get(&var_b), Some(&0));
assert_eq!(allocation.num_locals, 1);
}
/// Two I32 variables joined by an interference edge get distinct slots.
#[test]
fn test_greedy_coloring_interfering_different_slots() {
let mut graph = InterferenceGraph::new();
let vars = make_vars(2);
let (var_a, var_b) = (vars[0], vars[1]);
graph.set_type(var_a, SsaType::I32);
graph.set_type(var_b, SsaType::I32);
graph.add_edge(var_a, var_b);
let coalescer = LocalCoalescer {
interference: graph,
coalescable_vars: vec![var_a, var_b],
precomputed: None,
};
let ssa = create_minimal_ssa_function();
let allocation = coalescer.allocate(&ssa);
let slot_a = allocation.var_to_local.get(&var_a).unwrap();
let slot_b = allocation.var_to_local.get(&var_b).unwrap();
assert_ne!(slot_a, slot_b);
assert_eq!(allocation.num_locals, 2);
}
/// An I32 and an I64 variable get distinct slots even without an
/// interference edge, because their types are not storage-compatible.
#[test]
fn test_greedy_coloring_type_incompatible_different_slots() {
let mut graph = InterferenceGraph::new();
let vars = make_vars(2);
let (var_a, var_b) = (vars[0], vars[1]);
graph.set_type(var_a, SsaType::I32);
graph.set_type(var_b, SsaType::I64);
let coalescer = LocalCoalescer {
interference: graph,
coalescable_vars: vec![var_a, var_b],
precomputed: None,
};
let ssa = create_minimal_ssa_function();
let allocation = coalescer.allocate(&ssa);
let slot_a = allocation.var_to_local.get(&var_a).unwrap();
let slot_b = allocation.var_to_local.get(&var_b).unwrap();
assert_ne!(slot_a, slot_b);
assert_eq!(allocation.num_locals, 2);
}
/// Four mutually interfering variables need four distinct slots.
#[test]
fn test_greedy_coloring_clique_needs_n_colors() {
let mut graph = InterferenceGraph::new();
let vars = make_vars(4);
for i in 0..4 {
graph.set_type(vars[i], SsaType::I32);
for j in (i + 1)..4 {
graph.add_edge(vars[i], vars[j]);
}
}
let coalescer = LocalCoalescer {
interference: graph,
coalescable_vars: vars.clone(),
precomputed: None,
};
let ssa = create_minimal_ssa_function();
let allocation = coalescer.allocate(&ssa);
let slots: HashSet<_> = vars
.iter()
.filter_map(|v| allocation.var_to_local.get(v).copied())
.collect();
assert_eq!(slots.len(), 4);
assert_eq!(allocation.num_locals, 4);
}
/// A chain 0-1-2-3 is colored with at most two slots, and adjacent
/// variables never share one.
#[test]
fn test_greedy_coloring_chain_needs_2_colors() {
let mut graph = InterferenceGraph::new();
let vars = make_vars(4);
for var in &vars {
graph.set_type(*var, SsaType::I32);
}
graph.add_edge(vars[0], vars[1]);
graph.add_edge(vars[1], vars[2]);
graph.add_edge(vars[2], vars[3]);
let coalescer = LocalCoalescer {
interference: graph,
coalescable_vars: vars.clone(),
precomputed: None,
};
let ssa = create_minimal_ssa_function();
let allocation = coalescer.allocate(&ssa);
for i in 0..3 {
let slot_i = allocation.var_to_local.get(&vars[i]).unwrap();
let slot_j = allocation.var_to_local.get(&vars[i + 1]).unwrap();
assert_ne!(
slot_i,
slot_j,
"Adjacent vars {} and {} share slot",
i,
i + 1
);
}
assert!(allocation.num_locals <= 2);
}
/// Without interference edges, I32/U32/Bool variables share slot 0 and
/// I64/U64 variables share slot 1.
#[test]
fn test_mixed_types_coalesce_within_class() {
let mut graph = InterferenceGraph::new();
let vars = make_vars(5);
let (var_i32, var_u32, var_bool, var_i64, var_u64) =
(vars[0], vars[1], vars[2], vars[3], vars[4]);
graph.set_type(var_i32, SsaType::I32);
graph.set_type(var_u32, SsaType::U32);
graph.set_type(var_bool, SsaType::Bool);
graph.set_type(var_i64, SsaType::I64);
graph.set_type(var_u64, SsaType::U64);
let coalescer = LocalCoalescer {
interference: graph,
coalescable_vars: vec![var_i32, var_u32, var_bool, var_i64, var_u64],
precomputed: None,
};
let ssa = create_minimal_ssa_function();
let allocation = coalescer.allocate(&ssa);
assert_eq!(allocation.var_to_local.get(&var_i32), Some(&0));
assert_eq!(allocation.var_to_local.get(&var_u32), Some(&0));
assert_eq!(allocation.var_to_local.get(&var_bool), Some(&0));
assert_eq!(allocation.var_to_local.get(&var_i64), Some(&1));
assert_eq!(allocation.var_to_local.get(&var_u64), Some(&1));
assert_eq!(allocation.num_locals, 2);
}
/// Builds an `SsaFunction` with a single block 0 that only returns.
// NOTE(review): the meaning of the two `0` arguments to
// `SsaFunctionBuilder::new` is not visible here — presumably argument and
// local counts; confirm against the builder.
fn create_minimal_ssa_function() -> SsaFunction {
use crate::analysis::SsaFunctionBuilder;
SsaFunctionBuilder::new(0, 0).build_with(|ctx| {
ctx.block(0, |b| {
b.ret();
});
})
}
}