use lamina_mir::{Function, Instruction, Operand, Register, VirtualReg};
use std::collections::{HashMap, HashSet};
/// Type-erased handle to a physical register.
///
/// Used by [`RegisterAllocatorDyn`] so callers can talk to an allocator without
/// knowing its concrete register type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PhysRegHandle {
    /// A register identified by a static name such as `"r12"`.
    Named(&'static str),
}

impl PhysRegHandle {
    /// Returns the register name carried by this handle.
    pub fn as_named(self) -> Option<&'static str> {
        let Self::Named(name) = self;
        Some(name)
    }
}

/// Two-way conversion between a concrete physical-register type and the
/// type-erased [`PhysRegHandle`].
pub trait PhysRegConvertible: Copy + Eq {
    /// Wraps `self` in a type-erased handle.
    fn into_handle(self) -> PhysRegHandle;

    /// Recovers a concrete register from `handle`, or `None` if the handle does
    /// not describe a register of this type.
    fn from_handle(handle: PhysRegHandle) -> Option<Self>
    where
        Self: Sized;
}

/// Register names map one-to-one onto [`PhysRegHandle::Named`].
impl PhysRegConvertible for &'static str {
    fn into_handle(self) -> PhysRegHandle {
        PhysRegHandle::Named(self)
    }

    fn from_handle(handle: PhysRegHandle) -> Option<Self> {
        handle.as_named()
    }
}
/// Physical-register bookkeeping for a backend.
///
/// The trait is typed over that backend's concrete register representation,
/// [`PhysReg`](Self::PhysReg). Every implementor also gets the handle-based
/// [`RegisterAllocatorDyn`] interface through a blanket impl.
///
/// NOTE(review): no implementation is visible next to this trait. The method
/// descriptions below follow the names and signatures; confirm them against the
/// implementors.
pub trait RegisterAllocator {
    /// Concrete physical-register type used by this allocator.
    type PhysReg: PhysRegConvertible;

    /// Hands out a scratch register, or `None` if none can be provided.
    fn alloc_scratch(&mut self) -> Option<Self::PhysReg>;
    /// Gives a scratch register back to the allocator.
    fn free_scratch(&mut self, phys: Self::PhysReg);

    /// Physical register currently mapped to `vreg`, if any (takes `&self`).
    fn get_mapping(&self, vreg: &VirtualReg) -> Option<Self::PhysReg>;
    /// Physical register for `vreg`. Takes `&mut self`, so an implementation may
    /// establish a mapping. Returns `None` if no register can be provided.
    fn ensure_mapping(&mut self, vreg: VirtualReg) -> Option<Self::PhysReg>;
    /// Physical register corresponding to an arbitrary MIR [`Register`], if any.
    fn mapped_for_register(&self, reg: &Register) -> Option<Self::PhysReg>;

    /// Marks `phys` as in use.
    fn occupy(&mut self, phys: Self::PhysReg);
    /// Clears the in-use mark on `phys`.
    fn release(&mut self, phys: Self::PhysReg);
    /// Whether `phys` is currently marked as in use.
    fn is_occupied(&self, phys: Self::PhysReg) -> bool;
}
/// Object-safe mirror of [`RegisterAllocator`].
///
/// Registers are passed as type-erased [`PhysRegHandle`]s instead of the backend's
/// concrete register type. Each `*_dyn` method forwards to the
/// `RegisterAllocator` method with the same stem through the blanket impl in
/// this module.
pub trait RegisterAllocatorDyn {
    /// Handle-based counterpart of `RegisterAllocator::alloc_scratch`.
    fn alloc_scratch_dyn(&mut self) -> Option<PhysRegHandle>;
    /// Handle-based counterpart of `RegisterAllocator::free_scratch`.
    fn free_scratch_dyn(&mut self, phys: PhysRegHandle);

    /// Handle-based counterpart of `RegisterAllocator::get_mapping`.
    fn get_mapping_dyn(&self, vreg: &VirtualReg) -> Option<PhysRegHandle>;
    /// Handle-based counterpart of `RegisterAllocator::ensure_mapping`.
    fn ensure_mapping_dyn(&mut self, vreg: VirtualReg) -> Option<PhysRegHandle>;
    /// Handle-based counterpart of `RegisterAllocator::mapped_for_register`.
    fn mapped_for_register_dyn(&self, reg: &Register) -> Option<PhysRegHandle>;

    /// Handle-based counterpart of `RegisterAllocator::occupy`.
    fn occupy_dyn(&mut self, phys: PhysRegHandle);
    /// Handle-based counterpart of `RegisterAllocator::release`.
    fn release_dyn(&mut self, phys: PhysRegHandle);
    /// Handle-based counterpart of `RegisterAllocator::is_occupied`.
    fn is_occupied_dyn(&self, phys: PhysRegHandle) -> bool;
}
/// Where a virtual register ends up after allocation.
///
/// The type itself carries no bound on `R`. The derived `Clone`/`Copy`/`Eq`
/// impls bound `R` only where they are needed, so generic code that merely names
/// `Allocation<R>` does not have to repeat `R: Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Allocation<R> {
    /// Assigned to the physical register `R`.
    Register(R),
    /// Spilled to a stack slot identified by its offset.
    ///
    /// The allocators in this module produce -8, -16, -24, … in that order.
    Spill(i32),
}
/// Live range of one virtual register, measured in instruction positions.
///
/// Positions count instructions across all blocks in layout order, as assigned by
/// `LinearScanAllocator::compute_intervals`. Both bounds are inclusive, so two
/// intervals that share an endpoint overlap (see `intervals_interfere`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LiveInterval {
    /// The virtual register this range describes.
    pub vreg: VirtualReg,
    /// First position at which the register is live.
    pub start: usize,
    /// Last position at which the register is live (inclusive).
    pub end: usize,
}
pub struct LinearScanAllocator;
impl LinearScanAllocator {
pub fn compute_intervals(function: &Function) -> Vec<LiveInterval> {
let mut intervals: HashMap<VirtualReg, LiveInterval> = HashMap::new();
let mut pos: usize = 0;
for param in &function.sig.params {
if let Register::Virtual(v) = ¶m.reg {
intervals.entry(*v).or_insert(LiveInterval {
vreg: *v,
start: 0,
end: 0,
});
}
}
for block in &function.blocks {
for instr in &block.instructions {
Self::scan_instruction(instr, pos, &mut intervals);
pos += 1;
}
}
let mut result: Vec<LiveInterval> = intervals.into_values().collect();
result.sort_by_key(|i| i.start);
result
}
fn scan_instruction(
instr: &Instruction,
pos: usize,
intervals: &mut HashMap<VirtualReg, LiveInterval>,
) {
if let Some(def) = Self::def_reg(instr)
&& let Register::Virtual(v) = def
{
let entry = intervals.entry(v).or_insert(LiveInterval {
vreg: v,
start: pos,
end: pos,
});
if pos > entry.end {
entry.end = pos;
}
}
for used in Self::use_regs(instr) {
if let Register::Virtual(v) = used {
let entry = intervals.entry(v).or_insert(LiveInterval {
vreg: v,
start: pos,
end: pos,
});
if pos > entry.end {
entry.end = pos;
}
}
}
}
fn def_reg(instr: &Instruction) -> Option<Register> {
match instr {
Instruction::IntBinary { dst, .. }
| Instruction::FloatBinary { dst, .. }
| Instruction::FloatUnary { dst, .. }
| Instruction::IntCmp { dst, .. }
| Instruction::FloatCmp { dst, .. }
| Instruction::Select { dst, .. }
| Instruction::Load { dst, .. }
| Instruction::Lea { dst, .. }
| Instruction::VectorOp { dst, .. } => Some(dst.clone()),
Instruction::Call { ret: Some(ret), .. } => Some(ret.clone()),
#[cfg(feature = "nightly")]
Instruction::SimdBinary { dst, .. }
| Instruction::SimdUnary { dst, .. }
| Instruction::SimdTernary { dst, .. }
| Instruction::SimdShuffle { dst, .. }
| Instruction::SimdExtract { dst, .. }
| Instruction::SimdInsert { dst, .. }
| Instruction::SimdLoad { dst, .. } => Some(dst.clone()),
#[cfg(feature = "nightly")]
Instruction::AtomicLoad { dst, .. }
| Instruction::AtomicBinary { dst, .. }
| Instruction::AtomicCompareExchange { dst, .. } => Some(dst.clone()),
_ => None,
}
}
fn use_regs(instr: &Instruction) -> Vec<Register> {
let mut uses = Vec::new();
let push_op = |uses: &mut Vec<Register>, op: &Operand| {
if let Operand::Register(r) = op {
uses.push(r.clone());
}
};
let push_addr = |uses: &mut Vec<Register>, addr: &lamina_mir::AddressMode| match addr {
lamina_mir::AddressMode::BaseOffset { base, .. } => uses.push(base.clone()),
lamina_mir::AddressMode::BaseIndexScale { base, index, .. } => {
uses.push(base.clone());
uses.push(index.clone());
}
};
match instr {
Instruction::IntBinary { lhs, rhs, .. }
| Instruction::FloatBinary { lhs, rhs, .. }
| Instruction::IntCmp { lhs, rhs, .. }
| Instruction::FloatCmp { lhs, rhs, .. } => {
push_op(&mut uses, lhs);
push_op(&mut uses, rhs);
}
Instruction::FloatUnary { src, .. } => push_op(&mut uses, src),
Instruction::Select {
cond,
true_val,
false_val,
..
} => {
uses.push(cond.clone());
push_op(&mut uses, true_val);
push_op(&mut uses, false_val);
}
Instruction::Load { addr, .. } => push_addr(&mut uses, addr),
Instruction::Store { src, addr, .. } => {
push_op(&mut uses, src);
push_addr(&mut uses, addr);
}
Instruction::Lea { base, .. } => uses.push(base.clone()),
Instruction::VectorOp { operands, .. } => {
for op in operands {
push_op(&mut uses, op);
}
}
Instruction::Call { args, .. } | Instruction::TailCall { args, .. } => {
for op in args {
push_op(&mut uses, op);
}
}
Instruction::Ret { value: Some(v) } => push_op(&mut uses, v),
Instruction::Br { cond, .. } | Instruction::Switch { value: cond, .. } => {
uses.push(cond.clone());
}
_ => {}
}
uses
}
pub fn allocate<R: Copy + Eq>(
intervals: &[LiveInterval],
available_regs: &[R],
) -> HashMap<VirtualReg, Allocation<R>> {
let mut result: HashMap<VirtualReg, Allocation<R>> = HashMap::new();
let mut active: Vec<(&LiveInterval, R)> = Vec::new();
let mut free: Vec<R> = available_regs.to_vec();
let mut next_spill: i32 = -8;
for interval in intervals {
let current_start = interval.start;
let mut freed: Vec<R> = Vec::new();
active.retain(|(ai, reg)| {
if ai.end < current_start {
freed.push(*reg);
false
} else {
true
}
});
free.extend(freed);
if let Some(reg) = free.pop() {
result.insert(interval.vreg, Allocation::Register(reg));
let pos = active
.binary_search_by_key(&interval.end, |(ai, _)| ai.end)
.unwrap_or_else(|i| i);
active.insert(pos, (interval, reg));
} else {
match active.last().cloned() {
Some((spill_interval, spill_reg)) if spill_interval.end > interval.end => {
result.insert(spill_interval.vreg, Allocation::Spill(next_spill));
next_spill -= 8;
active.pop();
result.insert(interval.vreg, Allocation::Register(spill_reg));
let pos = active
.binary_search_by_key(&interval.end, |(ai, _)| ai.end)
.unwrap_or_else(|i| i);
active.insert(pos, (interval, spill_reg));
}
_ => {
result.insert(interval.vreg, Allocation::Spill(next_spill));
next_spill -= 8;
}
}
}
}
result
}
}
/// Returns `true` when the closed ranges `[a.start, a.end]` and
/// `[b.start, b.end]` share at least one position.
///
/// Touching endpoints count as interference.
#[inline]
pub fn intervals_interfere(a: &LiveInterval, b: &LiveInterval) -> bool {
    let disjoint = a.end < b.start || b.end < a.start;
    !disjoint
}
pub struct GraphColorAllocator;
impl GraphColorAllocator {
pub fn allocate<R: Copy + Eq + std::hash::Hash>(
intervals: &[LiveInterval],
available_regs: &[R],
) -> HashMap<VirtualReg, Allocation<R>> {
if intervals.is_empty() {
return HashMap::new();
}
let mut order: Vec<usize> = (0..intervals.len()).collect();
order.sort_by(|&i, &j| {
let deg_i = intervals
.iter()
.enumerate()
.filter(|(k, other)| *k != i && intervals_interfere(&intervals[i], other))
.count();
let deg_j = intervals
.iter()
.enumerate()
.filter(|(k, other)| *k != j && intervals_interfere(&intervals[j], other))
.count();
deg_j.cmp(°_i).then_with(|| i.cmp(&j))
});
let mut result: HashMap<VirtualReg, Allocation<R>> = HashMap::new();
let mut next_spill: i32 = -8;
for idx in order {
let interval = &intervals[idx];
let mut blocked: HashSet<R> = HashSet::new();
for (j, other) in intervals.iter().enumerate() {
if j == idx || !intervals_interfere(interval, other) {
continue;
}
if let Some(Allocation::Register(r)) = result.get(&other.vreg) {
blocked.insert(*r);
}
}
let mut picked: Option<R> = None;
for reg in available_regs {
if !blocked.contains(reg) {
picked = Some(*reg);
break;
}
}
match picked {
Some(r) => {
result.insert(interval.vreg, Allocation::Register(r));
}
None => {
result.insert(interval.vreg, Allocation::Spill(next_spill));
next_spill -= 8;
}
}
}
result
}
}
impl<T> RegisterAllocatorDyn for T
where
T: RegisterAllocator,
{
fn alloc_scratch_dyn(&mut self) -> Option<PhysRegHandle> {
self.alloc_scratch().map(|reg| reg.into_handle())
}
fn free_scratch_dyn(&mut self, phys: PhysRegHandle) {
if let Some(reg) = <T::PhysReg as PhysRegConvertible>::from_handle(phys) {
self.free_scratch(reg);
} else {
debug_assert!(false, "failed to decode physical register handle");
}
}
fn get_mapping_dyn(&self, vreg: &VirtualReg) -> Option<PhysRegHandle> {
self.get_mapping(vreg).map(|reg| reg.into_handle())
}
fn ensure_mapping_dyn(&mut self, vreg: VirtualReg) -> Option<PhysRegHandle> {
self.ensure_mapping(vreg).map(|reg| reg.into_handle())
}
fn mapped_for_register_dyn(&self, reg: &Register) -> Option<PhysRegHandle> {
self.mapped_for_register(reg).map(|r| r.into_handle())
}
fn occupy_dyn(&mut self, phys: PhysRegHandle) {
if let Some(reg) = <T::PhysReg as PhysRegConvertible>::from_handle(phys) {
self.occupy(reg);
} else {
debug_assert!(false, "failed to decode physical register handle");
}
}
fn release_dyn(&mut self, phys: PhysRegHandle) {
if let Some(reg) = <T::PhysReg as PhysRegConvertible>::from_handle(phys) {
self.release(reg);
} else {
debug_assert!(false, "failed to decode physical register handle");
}
}
fn is_occupied_dyn(&self, phys: PhysRegHandle) -> bool {
if let Some(reg) = <T::PhysReg as PhysRegConvertible>::from_handle(phys) {
self.is_occupied(reg)
} else {
debug_assert!(false, "failed to decode physical register handle");
false
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use lamina_mir::{
        Block, FunctionBuilder, Instruction, IntBinOp, MirType, Operand, Register, ScalarType,
        VirtualReg,
    };

    /// Builds `add(v0, v1) { v2 = v0 + v1; ret v2 }`.
    fn make_add_function() -> Function {
        let v0 = Register::Virtual(VirtualReg::gpr(0));
        let v1 = Register::Virtual(VirtualReg::gpr(1));
        let v2 = Register::Virtual(VirtualReg::gpr(2));
        let i64_ty = MirType::Scalar(ScalarType::I64);
        FunctionBuilder::new("add")
            .param(v0.clone(), i64_ty)
            .param(v1.clone(), i64_ty)
            .returns(i64_ty)
            .block("entry")
            .instr(Instruction::IntBinary {
                op: IntBinOp::Add,
                ty: i64_ty,
                dst: v2.clone(),
                lhs: Operand::Register(v0),
                rhs: Operand::Register(v1),
            })
            .instr(Instruction::Ret {
                value: Some(Operand::Register(v2)),
            })
            .build()
    }

    /// Builds a single-block function of eight chained adds
    /// `v(i + 2) = v(i) + v(i + 1)` followed by `ret`.
    ///
    /// Three intervals overlap at every add, which is more than a two-register
    /// budget can hold.
    fn make_chain_function(name: &'static str) -> Function {
        let i64_ty = MirType::Scalar(ScalarType::I64);
        let mut func = FunctionBuilder::new(name).returns(i64_ty).build();
        let mut block = Block::new("entry");
        for i in 0u32..8 {
            let vi = Register::Virtual(VirtualReg::gpr(i));
            let vj = Register::Virtual(VirtualReg::gpr(i + 1));
            let vd = Register::Virtual(VirtualReg::gpr(i + 2));
            block.push(Instruction::IntBinary {
                op: IntBinOp::Add,
                ty: i64_ty,
                dst: vd,
                lhs: Operand::Register(vi),
                rhs: Operand::Register(vj),
            });
        }
        block.push(Instruction::Ret { value: None });
        func.add_block(block);
        func
    }

    /// Fails if two overlapping intervals were both assigned the same register.
    fn assert_no_register_conflicts(
        intervals: &[LiveInterval],
        alloc: &HashMap<VirtualReg, Allocation<&'static str>>,
    ) {
        for (i, a) in intervals.iter().enumerate() {
            for b in &intervals[i + 1..] {
                if !intervals_interfere(a, b) {
                    continue;
                }
                if let (Some(Allocation::Register(ra)), Some(Allocation::Register(rb))) =
                    (alloc.get(&a.vreg), alloc.get(&b.vreg))
                {
                    assert_ne!(
                        ra, rb,
                        "{:?} and {:?} overlap but share {}",
                        a.vreg, b.vreg, ra
                    );
                }
            }
        }
    }

    #[test]
    fn test_compute_intervals_basic() {
        let func = make_add_function();
        let intervals = LinearScanAllocator::compute_intervals(&func);
        assert!(!intervals.is_empty());
        let vreg_ids: Vec<u32> = intervals.iter().map(|i| i.vreg.id).collect();
        assert!(vreg_ids.contains(&0));
        assert!(vreg_ids.contains(&1));
        assert!(vreg_ids.contains(&2));
    }

    #[test]
    fn test_compute_intervals_sorted_by_start() {
        let func = make_add_function();
        let intervals = LinearScanAllocator::compute_intervals(&func);
        let starts: Vec<usize> = intervals.iter().map(|i| i.start).collect();
        let mut sorted = starts.clone();
        sorted.sort_unstable();
        assert_eq!(starts, sorted, "intervals should be sorted by start");
    }

    #[test]
    fn test_allocate_fits_in_registers() {
        let func = make_add_function();
        let intervals = LinearScanAllocator::compute_intervals(&func);
        let regs = ["r12", "r13", "r14", "r15"];
        let alloc = LinearScanAllocator::allocate(&intervals, &regs);
        for interval in &intervals {
            let a = alloc
                .get(&interval.vreg)
                .expect("every vreg should be allocated");
            assert!(
                matches!(a, Allocation::Register(_)),
                "vreg {:?} should be in a register, got {:?}",
                interval.vreg,
                a
            );
        }
    }

    #[test]
    fn test_allocate_spills_when_registers_exhausted() {
        let func = make_chain_function("spill_test");
        let intervals = LinearScanAllocator::compute_intervals(&func);
        let regs = ["r12", "r13"];
        let alloc = LinearScanAllocator::allocate(&intervals, &regs);
        let has_spill = alloc.values().any(|a| matches!(a, Allocation::Spill(_)));
        assert!(has_spill, "expected spills when registers are exhausted");
        for a in alloc.values() {
            if let Allocation::Spill(offset) = a {
                assert!(*offset < 0, "spill offset should be negative");
                assert_eq!(offset % 8, 0, "spill offset should be 8-byte aligned");
            }
        }
    }

    #[test]
    fn linear_scan_never_shares_a_register_between_overlapping_intervals() {
        let func = make_chain_function("conflict_ls");
        let intervals = LinearScanAllocator::compute_intervals(&func);
        let regs = ["r12", "r13"];
        let alloc = LinearScanAllocator::allocate(&intervals, &regs);
        assert_no_register_conflicts(&intervals, &alloc);
    }

    #[test]
    fn graph_color_fits_three_vregs_without_spill() {
        let func = make_add_function();
        let intervals = LinearScanAllocator::compute_intervals(&func);
        let regs = ["r12", "r13", "r14", "r15"];
        let gc = GraphColorAllocator::allocate(&intervals, &regs);
        for interval in &intervals {
            let a = gc.get(&interval.vreg).expect("allocated");
            assert!(
                matches!(a, Allocation::Register(_)),
                "graph color should keep simple add in registers, got {:?}",
                a
            );
        }
    }

    #[test]
    fn graph_color_spills_when_register_count_exhausted() {
        let func = make_chain_function("spill_gc");
        let intervals = LinearScanAllocator::compute_intervals(&func);
        let regs = ["r12", "r13"];
        let gc = GraphColorAllocator::allocate(&intervals, &regs);
        let has_spill = gc.values().any(|a| matches!(a, Allocation::Spill(_)));
        assert!(
            has_spill,
            "graph color should spill when k=2 and pressure is high"
        );
    }

    #[test]
    fn graph_color_never_shares_a_register_between_overlapping_intervals() {
        let func = make_chain_function("conflict_gc");
        let intervals = LinearScanAllocator::compute_intervals(&func);
        let regs = ["r12", "r13"];
        let gc = GraphColorAllocator::allocate(&intervals, &regs);
        assert_no_register_conflicts(&intervals, &gc);
    }
}