use crate::ir::{Ir, VReg};
use std::collections::HashMap;
/// Result of register allocation: maps each virtual register to the physical
/// register index (`u8`) it was assigned.
pub type Allocation = HashMap<VReg, u8>;
/// Lowest physical register handed out by `allocate`. Register 0 is not in the
/// pool because `allocate` pins vreg 0 to it.
const FIRST_PHYS: u8 = 1;
/// Highest physical register handed out by `allocate`. Register 255 is never
/// assigned by this module.
/// NOTE(review): the reason 255 is excluded is not visible here; confirm.
const LAST_PHYS: u8 = 254;
/// Live range of a single virtual register, expressed as instruction indices.
/// Both ends are inclusive: `allocate` only frees the register once a later
/// interval starts strictly after `end`.
#[derive(Debug, Clone)]
struct Interval {
/// The virtual register this interval describes.
vreg: VReg,
/// Index of the first definition, or 0 if the vreg is read before any def.
start: usize,
/// Index of the last def/use. It is extended to the last instruction when
/// the interval spans an `Ir::Call`.
end: usize,
}
/// Builds one live interval per vreg that appears in `instrs`.
///
/// For each vreg:
/// * `start` is the index of its first definition. If the vreg is read before
///   any instruction defines it, `start` is 0 and it is treated as live from
///   the start.
/// * `end` is the index of its last definition or use.
/// * If some `Ir::Call` lies strictly between `start` and `end`, `end` is
///   pushed to the last instruction index. The vreg then keeps its physical
///   register for the rest of `instrs`, so nothing allocated after the call
///   can reuse it (exercised by the `call_does_not_clobber_caller_regs` test).
///
/// The result is sorted by `(start, vreg)`. The `vreg` tie-break matters:
/// several vregs can share `start == 0`, and `HashMap` iteration order is
/// randomized per map. Sorting on `start` alone therefore made the final
/// register assignment nondeterministic.
fn compute_intervals(instrs: &[Ir]) -> Vec<Interval> {
    let mut first_def: HashMap<VReg, usize> = HashMap::new();
    let mut last_use: HashMap<VReg, usize> = HashMap::new();

    for (i, instr) in instrs.iter().enumerate() {
        if let Some(d) = instr.def() {
            first_def.entry(d).or_insert(i);
            // `i` only grows, so the latest occurrence is always the maximum.
            last_use.insert(d, i);
        }
        for u in instr.uses() {
            last_use.insert(u, i);
            // Read before any def: live from the very first instruction.
            first_def.entry(u).or_insert(0);
        }
    }

    // Indices of every call instruction, in ascending order.
    let call_sites: Vec<usize> = instrs
        .iter()
        .enumerate()
        .filter(|(_, instr)| matches!(instr, Ir::Call(_)))
        .map(|(i, _)| i)
        .collect();
    let last_index = instrs.len().saturating_sub(1);

    let mut intervals: Vec<Interval> = first_def
        .into_iter()
        .map(|(vreg, start)| {
            let mut end = last_use.get(&vreg).copied().unwrap_or(start);
            // First call strictly after `start`. If it is also strictly
            // before `end`, the value is live across that call.
            let next_call = call_sites.partition_point(|&c| c <= start);
            if matches!(call_sites.get(next_call), Some(&c) if c < end) {
                end = last_index;
            }
            Interval { vreg, start, end }
        })
        .collect();

    intervals.sort_unstable_by_key(|iv| (iv.start, iv.vreg));
    intervals
}
/// Assigns a physical register to every vreg in `instrs` using a linear scan
/// over the intervals produced by `compute_intervals`.
///
/// vreg 0 is always mapped to physical register 0 and never competes for the
/// allocatable range. Each other vreg receives the lowest-numbered free
/// register in `FIRST_PHYS..=LAST_PHYS` when its interval begins. A register
/// returns to the pool only after the interval holding it has ended strictly
/// before the new interval starts.
///
/// # Panics
///
/// Panics if more vregs are live at once than there are registers in
/// `FIRST_PHYS..=LAST_PHYS`, because spilling is not implemented.
pub fn allocate(instrs: &[Ir]) -> Allocation {
    let mut assignment: Allocation = HashMap::new();
    assignment.insert(0, 0);

    let intervals = compute_intervals(instrs);
    if intervals.is_empty() {
        return assignment;
    }

    // Free registers, always kept in ascending order.
    let mut pool: Vec<u8> = (FIRST_PHYS..=LAST_PHYS).collect();
    // Intervals currently holding a register, ordered by ascending `end`.
    let mut live: Vec<Interval> = Vec::new();

    for current in intervals.into_iter().filter(|iv| iv.vreg != 0) {
        // `live` is sorted by end, so the finished intervals form a prefix.
        let finished = live
            .iter()
            .take_while(|iv| iv.end < current.start)
            .count();
        for done in live.drain(..finished) {
            let reg = *assignment.get(&done.vreg).unwrap();
            let slot = pool.partition_point(|&r| r < reg);
            pool.insert(slot, reg);
        }

        if pool.is_empty() {
            panic!(
                "register spill required for vreg {} - \
                 more than 254 simultaneously live variables is not supported",
                current.vreg
            );
        }
        let reg = pool.remove(0);
        assignment.insert(current.vreg, reg);

        let slot = live.partition_point(|iv| iv.end <= current.end);
        live.insert(slot, current);
    }

    assignment
}
/// Returns the physical register that `alloc` assigned to `vreg`.
///
/// # Panics
///
/// Panics with `vreg <n> not allocated` if `alloc` has no entry for `vreg`.
#[inline(always)]
pub fn resolve(alloc: &Allocation, vreg: VReg) -> u8 {
    match alloc.get(&vreg) {
        Some(&phys) => phys,
        None => panic!("vreg {} not allocated", vreg),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::Ir;

    /// Three vregs with overlapping lifetimes must each land in the
    /// allocatable window and must not share a register.
    #[test]
    fn basic_allocation() {
        let program = vec![
            Ir::GetCaller(1),
            Ir::GetValue(2),
            Ir::Add(3, 1, 2),
            Ir::Halt,
        ];
        let alloc = allocate(&program);
        let regs: Vec<u8> = [1u32, 2, 3].iter().map(|v| alloc[v]).collect();
        for &r in &regs {
            assert!((1..=254).contains(&r));
        }
        for (i, &a) in regs.iter().enumerate() {
            for &b in &regs[i + 1..] {
                assert_ne!(a, b);
            }
        }
    }

    /// Once v1's last use has passed, v2 should pick up the same register.
    #[test]
    fn register_reuse_after_last_use() {
        let program = vec![
            Ir::GetCaller(1),
            Ir::RequireNonZero(1),
            Ir::GetValue(2),
            Ir::Halt,
        ];
        let alloc = allocate(&program);
        assert_eq!(alloc[&1], alloc[&2]);
    }

    /// A vreg that is live across a call must keep its register away from
    /// vregs allocated after the call.
    #[test]
    fn call_does_not_clobber_caller_regs() {
        let program = [
            Ir::GetCaller(1),
            Ir::Call(42),
            Ir::RequireNonZero(1),
            Ir::GetCaller(1000),
            Ir::RequireNonZero(1000),
            Ir::Halt,
        ];
        let alloc = allocate(&program);
        let (caller, callee) = (alloc[&1], alloc[&1000]);
        assert_ne!(
            caller, callee,
            "caller vreg v1 and callee vreg v1000 must not share a physical register"
        );
    }

    /// vreg 0 is pinned to physical register 0.
    #[test]
    fn zero_vreg_always_r0() {
        let alloc = allocate(&[Ir::Add(1, 0, 0), Ir::Halt]);
        assert_eq!(alloc[&0], 0);
    }

    /// 200 simultaneously live vregs fit in the pool without spilling.
    #[test]
    fn many_vregs_no_spill() {
        const COUNT: u32 = 200;
        let program: Vec<Ir> = (1..=COUNT)
            .map(Ir::GetCaller)
            .chain((1..=COUNT).map(Ir::RequireNonZero))
            .chain(std::iter::once(Ir::Halt))
            .collect();
        let alloc = allocate(&program);
        let distinct: std::collections::HashSet<u8> =
            (1..=COUNT).map(|v| alloc[&v]).collect();
        assert_eq!(
            distinct.len(),
            200,
            "all 200 vregs should get distinct physical regs"
        );
    }
}