use std::fmt;
use crate::error::{StatorError, StatorResult};
/// A virtual register operand: the accumulator, a parameter, or a local.
///
/// The kind is encoded in the wrapped `i32`:
///
/// * `i32::MIN` is the accumulator ([`Register::ACCUMULATOR`]);
/// * any other negative value `-(n + 1)` is parameter `n`;
/// * a non-negative value `n` is local `n`.
///
/// The field is public, so a `Register` can also be built directly from a raw
/// encoding; the constructors below are the checked way to obtain one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Register(pub i32);

impl Register {
    /// The accumulator register, encoded as `i32::MIN`.
    pub const ACCUMULATOR: Self = Self(i32::MIN);

    /// Returns the register for parameter `index` (encoded as `-(index + 1)`).
    ///
    /// # Panics
    ///
    /// Panics if `index >= i32::MAX as u32`. Index `i32::MAX` would encode to
    /// `i32::MIN`, which is [`Register::ACCUMULATOR`], and larger indices do not
    /// fit in an `i32` at all (an unchecked cast would silently yield a local
    /// register instead).
    pub fn parameter(index: u32) -> Self {
        let limit = i32::MAX as u32;
        assert!(
            index < limit,
            "parameter index {} cannot be encoded as a Register",
            index
        );
        // Lossless: the assertion above guarantees `index` fits in an `i32`
        // and that the result stays strictly above `i32::MIN`.
        Self(-(index as i32) - 1)
    }

    /// Returns the register for local `index` (encoded as `index` itself).
    ///
    /// # Panics
    ///
    /// Panics if `index > i32::MAX as u32`; an unchecked cast would wrap to a
    /// negative value and silently yield a parameter or the accumulator.
    pub fn local(index: u32) -> Self {
        let limit = i32::MAX as u32;
        assert!(
            index <= limit,
            "local index {} cannot be encoded as a Register",
            index
        );
        // Lossless: guarded by the assertion above.
        Self(index as i32)
    }

    /// Returns `true` if this is [`Register::ACCUMULATOR`].
    pub fn is_accumulator(self) -> bool {
        self == Self::ACCUMULATOR
    }

    /// Returns `true` if this is a parameter register, i.e. negative but not
    /// the accumulator.
    pub fn is_parameter(self) -> bool {
        self.0 < 0 && self != Self::ACCUMULATOR
    }

    /// Returns `true` if this is a local register, i.e. non-negative.
    pub fn is_local(self) -> bool {
        self.0 >= 0
    }

    /// Returns the parameter index, or `None` if this is not a parameter.
    pub fn parameter_index(self) -> Option<u32> {
        if self.is_parameter() {
            // `self.0` is in `i32::MIN + 1 ..= -1` here, so `self.0 + 1` is in
            // `i32::MIN + 2 ..= 0` and the negation cannot overflow.
            Some((-(self.0 + 1)) as u32)
        } else {
            None
        }
    }

    /// Returns the local index, or `None` if this is not a local.
    pub fn local_index(self) -> Option<u32> {
        if self.is_local() {
            // Non-negative `i32` always fits in `u32`.
            Some(self.0 as u32)
        } else {
            None
        }
    }
}
/// Renders the accumulator as `acc`, parameter `n` as `a{n}`, and local `n`
/// as `r{n}`.
impl fmt::Display for Register {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(idx) = self.parameter_index() {
            return write!(f, "a{idx}");
        }
        if let Some(idx) = self.local_index() {
            return write!(f, "r{idx}");
        }
        // Neither a parameter nor a local: only the accumulator remains.
        write!(f, "acc")
    }
}
/// Hands out [`Register`]s for one frame: validated parameters, sequentially
/// numbered locals, and stack-disciplined temporaries placed after the locals.
#[derive(Debug)]
pub struct RegisterAllocator {
    /// Number of parameters; fixed at construction.
    parameter_count: u32,
    /// Number of locals allocated so far.
    local_count: u32,
    /// Number of temporaries currently live.
    temporary_count: u32,
    /// Largest value `local_count + temporary_count` has reached.
    frame_size: u32,
}

impl RegisterAllocator {
    /// Creates an allocator with `parameter_count` parameters and no locals or
    /// temporaries.
    pub fn new(parameter_count: u32) -> Self {
        Self {
            parameter_count,
            local_count: 0,
            temporary_count: 0,
            frame_size: 0,
        }
    }

    /// Returns the register for parameter `index`.
    ///
    /// # Errors
    ///
    /// Returns [`StatorError::Internal`] if `index` is not below the parameter
    /// count given to [`RegisterAllocator::new`].
    pub fn new_parameter(&self, index: u32) -> StatorResult<Register> {
        if index >= self.parameter_count {
            return Err(StatorError::Internal(format!(
                "parameter index {index} out of range (count = {})",
                self.parameter_count
            )));
        }
        Ok(Register::parameter(index))
    }

    /// Allocates the next local register and returns it.
    ///
    /// NOTE(review): temporaries are numbered starting at `local_count`, so a
    /// local allocated while temporaries are live gets the same index as the
    /// oldest live temporary. Presumably callers allocate all locals before
    /// any temporary — confirm against the call sites.
    pub fn new_local(&mut self) -> Register {
        let slot = self.local_count;
        let reg = Register::local(slot);
        self.local_count = slot + 1;
        self.update_frame_size();
        reg
    }

    /// Allocates a temporary in the first slot after all locals and all
    /// currently live temporaries.
    pub fn allocate_temporary(&mut self) -> Register {
        let slot = self.local_count + self.temporary_count;
        let reg = Register::local(slot);
        self.temporary_count += 1;
        self.update_frame_size();
        reg
    }

    /// Releases `reg`, which must be the most recently allocated live
    /// temporary.
    ///
    /// # Errors
    ///
    /// Returns [`StatorError::Internal`] if no temporary is live, or if `reg`
    /// is not the temporary on top of the stack. The allocator is left
    /// unchanged in both cases.
    pub fn release_temporary(&mut self, reg: Register) -> StatorResult<()> {
        if self.temporary_count == 0 {
            return Err(StatorError::Internal(
                "release_temporary called with no live temporaries".into(),
            ));
        }

        // Index of the temporary currently on top of the stack.
        let expected_index = self.local_count + self.temporary_count - 1;
        if reg.local_index() != Some(expected_index) {
            return Err(StatorError::Internal(format!(
                "release_temporary: expected r{expected_index}, got {reg}"
            )));
        }

        self.temporary_count -= 1;
        Ok(())
    }

    /// Number of parameters this allocator was created with.
    pub fn parameter_count(&self) -> u32 {
        self.parameter_count
    }

    /// Number of locals allocated so far.
    pub fn local_count(&self) -> u32 {
        self.local_count
    }

    /// Number of temporaries currently live.
    pub fn temporary_count(&self) -> u32 {
        self.temporary_count
    }

    /// High-water mark of locals plus live temporaries; parameters are not
    /// counted.
    pub fn frame_size(&self) -> u32 {
        self.frame_size
    }

    /// Raises `frame_size` to the current number of locals plus live
    /// temporaries if that exceeds the recorded maximum; never lowers it.
    fn update_frame_size(&mut self) {
        let in_use = self.local_count + self.temporary_count;
        self.frame_size = self.frame_size.max(in_use);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The accumulator is neither a parameter nor a local and prints as `acc`.
    #[test]
    fn test_register_accumulator() {
        let reg = Register::ACCUMULATOR;
        assert!(reg.is_accumulator());
        assert!(!reg.is_parameter());
        assert!(!reg.is_local());
        assert_eq!(reg.parameter_index(), None);
        assert_eq!(reg.local_index(), None);
        assert_eq!(reg.to_string(), "acc");
    }

    /// Parameter `n` is encoded as `-(n + 1)` and prints as `a{n}`.
    #[test]
    fn test_register_parameter() {
        let first = Register::parameter(0);
        assert_eq!(first, Register(-1));
        assert!(first.is_parameter());
        assert!(!first.is_accumulator());
        assert!(!first.is_local());
        assert_eq!(first.parameter_index(), Some(0));
        assert_eq!(first.local_index(), None);
        assert_eq!(first.to_string(), "a0");

        let fourth = Register::parameter(3);
        assert_eq!(fourth, Register(-4));
        assert_eq!(fourth.parameter_index(), Some(3));
        assert_eq!(fourth.to_string(), "a3");
    }

    /// Local `n` is encoded as `n` and prints as `r{n}`.
    #[test]
    fn test_register_local() {
        let first = Register::local(0);
        assert_eq!(first, Register(0));
        assert!(first.is_local());
        assert!(!first.is_accumulator());
        assert!(!first.is_parameter());
        assert_eq!(first.local_index(), Some(0));
        assert_eq!(first.parameter_index(), None);
        assert_eq!(first.to_string(), "r0");

        let sixth = Register::local(5);
        assert_eq!(sixth.local_index(), Some(5));
        assert_eq!(sixth.to_string(), "r5");
    }

    /// Parameter indices below the count succeed; the count itself is rejected.
    #[test]
    fn test_allocator_parameter_range() {
        let allocator = RegisterAllocator::new(3);
        assert_eq!(allocator.parameter_count(), 3);
        assert_eq!(allocator.new_parameter(0).unwrap(), Register::parameter(0));
        assert_eq!(allocator.new_parameter(2).unwrap(), Register::parameter(2));
        assert!(allocator.new_parameter(3).is_err());
    }

    /// Locals are numbered 0, 1, 2, ... in allocation order.
    #[test]
    fn test_allocator_locals_sequential() {
        let mut allocator = RegisterAllocator::new(0);
        let first = allocator.new_local();
        let second = allocator.new_local();
        let third = allocator.new_local();
        assert_eq!(first, Register::local(0));
        assert_eq!(second, Register::local(1));
        assert_eq!(third, Register::local(2));
        assert_eq!(allocator.local_count(), 3);
    }

    /// Temporaries follow the locals and can be released in reverse order.
    #[test]
    fn test_allocate_release_temporary() {
        let mut allocator = RegisterAllocator::new(0);
        let _local = allocator.new_local();

        let lower = allocator.allocate_temporary();
        assert_eq!(lower, Register::local(1));
        assert_eq!(allocator.temporary_count(), 1);

        let upper = allocator.allocate_temporary();
        assert_eq!(upper, Register::local(2));
        assert_eq!(allocator.temporary_count(), 2);

        allocator.release_temporary(upper).unwrap();
        assert_eq!(allocator.temporary_count(), 1);

        allocator.release_temporary(lower).unwrap();
        assert_eq!(allocator.temporary_count(), 0);
    }

    /// Releasing anything but the most recent temporary is an error.
    #[test]
    fn test_release_out_of_order_is_error() {
        let mut allocator = RegisterAllocator::new(0);
        let lower = allocator.allocate_temporary();
        let _upper = allocator.allocate_temporary();
        assert!(allocator.release_temporary(lower).is_err());
    }

    /// Releasing with no live temporaries is an error.
    #[test]
    fn test_release_when_none_live_is_error() {
        let mut allocator = RegisterAllocator::new(0);
        let never_allocated = Register::local(0);
        assert!(allocator.release_temporary(never_allocated).is_err());
    }

    /// `frame_size` grows with locals and temporaries and never shrinks.
    #[test]
    fn test_frame_size_tracks_high_water_mark() {
        let mut allocator = RegisterAllocator::new(1);
        let _first_local = allocator.new_local();
        let _second_local = allocator.new_local();
        assert_eq!(allocator.frame_size(), 2);

        let lower = allocator.allocate_temporary();
        assert_eq!(allocator.frame_size(), 3);
        let upper = allocator.allocate_temporary();
        assert_eq!(allocator.frame_size(), 4);

        // Releasing temporaries must not lower the recorded maximum.
        allocator.release_temporary(upper).unwrap();
        allocator.release_temporary(lower).unwrap();
        assert_eq!(allocator.frame_size(), 4);

        // Reusing a released slot stays within the existing maximum.
        let reused = allocator.allocate_temporary();
        assert_eq!(allocator.frame_size(), 4);
        allocator.release_temporary(reused).unwrap();
        assert_eq!(allocator.frame_size(), 4);
    }

    /// Parameters alone do not contribute to the frame size.
    #[test]
    fn test_frame_size_zero_with_no_locals() {
        let allocator = RegisterAllocator::new(5);
        assert_eq!(allocator.frame_size(), 0);
    }

    /// The accumulator differs from the first 100 parameters and locals.
    #[test]
    fn test_accumulator_distinct_from_all_registers() {
        let accumulator = Register::ACCUMULATOR;
        for index in 0_u32..100 {
            assert_ne!(accumulator, Register::parameter(index));
            assert_ne!(accumulator, Register::local(index));
        }
    }
}