use std::{
collections::{HashMap, HashSet},
fmt,
hash::Hash,
};
use log::trace;
use petgraph::Direction::Outgoing;
use super::ir::instruction::{Block, BlockId, ControlFlowGraph, Instruction, Value};
use crate::bytecode::{MIN_REQUIRED_REGISTER, Register};
/// A contiguous run of instruction indices belonging to a [`LiveInterval`].
///
/// Both bounds are inclusive: the run covers `start..=end`.
#[derive(Debug, Clone)]
pub struct LiveRange {
    /// First instruction index of the run (inclusive).
    start: usize,
    /// Last instruction index of the run (inclusive).
    end: usize,
}
/// The lifetime of one variable, expressed in global instruction indices.
///
/// `start` is the smallest index at which the variable is mentioned and `end`
/// the largest one, possibly pushed further out by [`LiveInterval::update_end`].
/// `ranges` records the individual runs of consecutive mentions.
/// [`LiveInterval::active`] never continues a run across a block start.
#[derive(Debug, Clone)]
pub struct LiveInterval {
    /// The variable this interval describes.
    var: Value,
    /// First index covered (inclusive); `usize::MAX` until the first `active` call.
    start: usize,
    /// Last index covered (inclusive); `0` until the first `active` call.
    end: usize,
    /// Runs of consecutive mentions, in the order they were recorded.
    ranges: Vec<LiveRange>,
    /// Register assigned to this variable by the allocator, if any.
    reg: Option<Register>,
    /// Stack slot assigned to this variable by the allocator, if any.
    stack: Option<usize>,
}
impl LiveInterval {
    /// Creates an interval for `var` with no ranges.
    ///
    /// `start`/`end` start at `usize::MAX`/`0` so that the first call to
    /// [`LiveInterval::active`] sets both to that index.
    pub fn new(var: Value) -> Self {
        LiveInterval {
            var,
            start: usize::MAX,
            end: 0,
            ranges: Vec::new(),
            reg: None,
            stack: None,
        }
    }

    /// Records that the variable is mentioned at instruction `index`.
    ///
    /// - If `index` already lies inside the last range, `ranges` is left
    ///   untouched. This happens when one instruction both uses and defines the
    ///   variable. A duplicate `[index, index]` range would introduce a bogus
    ///   range end and start at `index`, which the allocator turns into a
    ///   spurious spill or missed restore.
    /// - If `index` immediately follows the last range and is not the first
    ///   instruction of a block, the last range is extended.
    /// - Otherwise a new single-index range is started.
    ///
    /// `start` and `end` are widened to include `index` in every case.
    pub fn active(&mut self, index: usize, at_block_start: bool) {
        match self.ranges.last_mut() {
            Some(last) if last.start <= index && index <= last.end => {}
            // `checked_sub` avoids the underflow `index - 1` had for index 0.
            Some(last) if !at_block_start && index.checked_sub(1) == Some(last.end) => {
                last.end = index;
            }
            _ => self.ranges.push(LiveRange {
                start: index,
                end: index,
            }),
        }
        self.start = self.start.min(index);
        self.end = self.end.max(index);
    }

    /// Extends `end` to at least `index`; never shrinks it and leaves `ranges` alone.
    pub fn update_end(&mut self, index: usize) {
        self.end = self.end.max(index);
    }
}
/// Result of live-interval analysis: one [`LiveInterval`] per variable,
/// plus the register / stack-slot assignments written back by the allocator.
#[derive(Debug, Clone)]
pub(crate) struct Liveness {
    /// Interval of every variable seen, keyed by the variable itself.
    intervals: HashMap<Value, LiveInterval>,
}
impl Liveness {
    /// An analysis result containing no intervals.
    fn new() -> Self {
        Self {
            intervals: HashMap::new(),
        }
    }

    /// Returns a copy of every interval, in the map's (unspecified) iteration order.
    fn intervals(&self) -> Vec<LiveInterval> {
        self.intervals.values().map(Clone::clone).collect()
    }

    /// Records that `var` is kept in register `reg`.
    ///
    /// Panics if `var` has no interval.
    fn set_register(&mut self, var: Value, reg: Register) {
        let interval = self.intervals.get_mut(&var).unwrap();
        interval.reg = Some(reg);
    }

    /// Records that `var` is kept in stack slot `stack`.
    ///
    /// Panics if `var` has no interval.
    fn set_stack(&mut self, var: Value, stack: usize) {
        let interval = self.intervals.get_mut(&var).unwrap();
        interval.stack = Some(stack);
    }

    /// Number of stack slots required: one past the highest assigned slot,
    /// or `0` when no variable was given a slot.
    fn stack_size(&self) -> usize {
        self.intervals
            .values()
            .filter_map(|interval| interval.stack)
            .max()
            .map_or(0, |highest| highest + 1)
    }
}
/// Computes per-variable live intervals over a [`ControlFlowGraph`].
///
/// Instructions are numbered consecutively across `cfg.blocks` in order.
/// Every interval and range refers to those global indices.
pub struct LiveIntervalAnalyzer {}
impl LiveIntervalAnalyzer {
    /// Builds the [`Liveness`] for `cfg`.
    ///
    /// Intervals are first built from the instructions that mention each
    /// variable. The end of each interval is then stretched to the end index
    /// (one past the last instruction) of every block the variable is live out of.
    pub fn scan(cfg: &ControlFlowGraph) -> Liveness {
        let mut liveness = Self::build_basic_intervals(cfg);
        let (_live_in_sets, live_out_sets) = Self::compute_liveness_sets(cfg);
        Self::update_intervals(cfg, &live_out_sets, &mut liveness);
        liveness
    }

    /// Walks every instruction in block order, assigning global indices and
    /// recording each variable mention via [`Self::process_instruction`].
    fn build_basic_intervals(cfg: &ControlFlowGraph) -> Liveness {
        let mut liveness = Liveness::new();
        let mut index = 0;
        for block in cfg.blocks.iter() {
            for (offset, inst) in block.instructions.iter().enumerate() {
                Self::process_instruction(&mut liveness, index, inst, offset == 0);
                index += 1;
            }
        }
        liveness
    }

    /// Marks every variable used or defined by `inst` as active at `index`.
    ///
    /// Uses are recorded before definitions. Non-variable values are ignored.
    fn process_instruction(
        liveness: &mut Liveness,
        index: usize,
        inst: &Instruction,
        at_block_start: bool,
    ) {
        let (defined, used) = inst.defined_and_used_vars();
        for var in used.into_iter().chain(defined) {
            if matches!(var, Value::Variable(_)) {
                liveness
                    .intervals
                    .entry(var)
                    .or_insert_with(|| LiveInterval::new(var))
                    .active(index, at_block_start);
            }
        }
    }

    /// Classic iterative backward liveness dataflow, run to a fixed point:
    ///
    /// - `live_out(B) = ∪ live_in(S)` over successors `S` of `B`
    /// - `live_in(B)  = use(B) ∪ (live_out(B) − def(B))`
    ///
    /// Returns `(live_in, live_out)` keyed by block id.
    fn compute_liveness_sets(
        cfg: &ControlFlowGraph,
    ) -> (
        HashMap<BlockId, HashSet<Value>>,
        HashMap<BlockId, HashSet<Value>>,
    ) {
        let mut live_in_sets: HashMap<BlockId, HashSet<Value>> = HashMap::new();
        let mut live_out_sets: HashMap<BlockId, HashSet<Value>> = HashMap::new();
        for block in cfg.blocks.iter() {
            live_in_sets.insert(block.id, HashSet::new());
            live_out_sets.insert(block.id, HashSet::new());
        }
        let mut changed = true;
        while changed {
            changed = false;
            for block in cfg.blocks.iter().rev() {
                let new_live_out = Self::compute_block_liveness(cfg, block, &live_in_sets);
                // Derive live-in from a *copy* of live-out. Mutating live-out in
                // place would make both sets equal to live-in, losing the
                // variables that are defined here and live out of the block.
                let mut new_live_in = new_live_out.clone();
                for inst in block.instructions.iter().rev() {
                    let (defined, used) = inst.defined_and_used_vars();
                    for var in defined {
                        if matches!(var, Value::Variable(_)) {
                            new_live_in.remove(&var);
                        }
                    }
                    for var in used {
                        if matches!(var, Value::Variable(_)) {
                            new_live_in.insert(var);
                        }
                    }
                }
                if live_in_sets[&block.id] != new_live_in
                    || live_out_sets[&block.id] != new_live_out
                {
                    changed = true;
                    live_in_sets.insert(block.id, new_live_in);
                    live_out_sets.insert(block.id, new_live_out);
                }
            }
        }
        (live_in_sets, live_out_sets)
    }

    /// Live-out set of `block`: the union of the live-in sets of its CFG
    /// successors. Empty when `block` has no node in the graph.
    fn compute_block_liveness(
        cfg: &ControlFlowGraph,
        block: &Block,
        live_in_sets: &HashMap<BlockId, HashSet<Value>>,
    ) -> HashSet<Value> {
        let mut live_out = HashSet::new();
        if let Some(&node_index) = cfg.block_node_map.get(&block.id) {
            for succ_node in cfg.graph.neighbors_directed(node_index, Outgoing) {
                if let Some(succ_live_in) = live_in_sets.get(&cfg.graph[succ_node]) {
                    live_out.extend(succ_live_in.iter().copied());
                }
            }
        }
        live_out
    }

    /// Extends every variable's interval to the end of each block it is live
    /// out of.
    ///
    /// `block_end` is one past the block's last instruction, i.e. the first
    /// index of the following block.
    ///
    /// NOTE(review): only interval *ends* are adjusted. A variable that is live
    /// into a loop header through a back edge is not extended back to the
    /// header's first instruction; confirm this is acceptable for the allocator.
    fn update_intervals(
        cfg: &ControlFlowGraph,
        live_out_sets: &HashMap<BlockId, HashSet<Value>>,
        liveness: &mut Liveness,
    ) {
        let mut block_start = 0;
        for block in cfg.blocks.iter() {
            let block_end = block_start + block.instructions.len();
            if let Some(live_out) = live_out_sets.get(&block.id) {
                for var in live_out {
                    if let Some(interval) = liveness.intervals.get_mut(var) {
                        interval.update_end(block_end);
                    }
                }
            }
            block_start = block_end;
        }
    }
}
/// Register allocator driven by live-interval analysis.
///
/// [`RegAlloc::arrange`] assigns every variable of a CFG either a register or a
/// stack slot up front. [`RegAlloc::alloc`] / [`RegAlloc::release`] then track
/// which variable currently occupies which register while code is emitted, and
/// report the [`Action`]s (restore / spill) the emitter must insert.
#[derive(Debug, Clone)]
pub struct RegAlloc {
    liveness: Liveness,
    pub(super) reg_set: RegisterSet,
}
impl RegAlloc {
    /// Creates an allocator over `registers` (all initially free).
    ///
    /// # Panics
    /// Panics with "Not enough registers" unless strictly more than
    /// `MIN_REQUIRED_REGISTER` registers are supplied.
    pub fn new(registers: &[Register]) -> Self {
        if registers.len() <= MIN_REQUIRED_REGISTER {
            panic!("Not enough registers");
        }
        Self {
            liveness: Liveness::new(),
            reg_set: RegisterSet::new(registers),
        }
    }

    /// Maps argument number `arg` to a negative slot: 0 → -1, 1 → -2, ...
    pub fn load_arg(&mut self, arg: usize) -> isize {
        // Written as `-x - 1` so that `arg == isize::MAX` yields `isize::MIN`
        // instead of overflowing on `x + 1`.
        -(arg as isize) - 1
    }

    /// Registers that currently hold a variable, in register-set order.
    pub fn in_use_registers(&self) -> Vec<Register> {
        self.reg_set
            .registers
            .iter()
            .filter(|reg| reg.variable.is_some())
            .map(|reg| reg.register)
            .collect()
    }

    /// Number of stack slots the current arrangement needs.
    pub fn stack_size(&self) -> usize {
        self.liveness.stack_size()
    }

    /// Returns the register holding `value` at instruction `index`.
    ///
    /// Lookup order:
    /// 1. `value` already occupies a register: return it.
    /// 2. `arrange` gave `value` a register: claim it.
    /// 3. Otherwise take the first free register. If a new live range of
    ///    `value` starts at `index` (and this is not its very first occurrence),
    ///    also return a [`Action::Restore`] from its stack slot.
    ///
    /// # Panics
    /// Panics if `value` has no live interval (i.e. `arrange` was not run for
    /// the CFG containing it), or if no register is free in case 3.
    pub fn alloc(&mut self, value: Value, index: usize) -> (Register, Option<Action>) {
        trace!("allocating {value}");
        let interval = self
            .liveness
            .intervals
            .get(&value)
            .expect("alloc: value has no live interval; was arrange() run?");
        if let Some(register) = self.reg_set.find(value) {
            return (register, None);
        }
        if let Some(register) = interval.reg {
            self.reg_set.use_register(register, value, true);
            return (register, None);
        }
        let reg = self.reg_set.must_alloc(value);
        let range_restarts_here =
            interval.start != index && interval.ranges.iter().any(|range| range.start == index);
        let restore = range_restarts_here.then(|| Action::Restore {
            stack: interval
                .stack
                .expect("alloc: variable without register must have a stack slot"),
            register: reg,
        });
        (reg, restore)
    }

    /// Frees the register of a stack-resident `value` when one of its live
    /// ranges ends at `index`, returning the [`Action::Spill`] to emit.
    ///
    /// Returns `None` for non-variables, when no range of `value` ends at
    /// `index`, when `value` has no stack slot, or when it holds no register.
    ///
    /// # Panics
    /// Panics if `value` is a variable without a live interval.
    pub fn release(&mut self, value: Value, index: usize) -> Option<Action> {
        trace!("releasing {value}");
        if !matches!(value, Value::Variable(_)) {
            return None;
        }
        let interval = self
            .liveness
            .intervals
            .get(&value)
            .expect("release: variable has no live interval; was arrange() run?");
        if !interval.ranges.iter().any(|range| range.end == index) {
            return None;
        }
        let stack = interval.stack?;
        let register = self.reg_set.release(value)?;
        Some(Action::Spill { register, stack })
    }

    /// Runs liveness analysis on `cfg` and assigns every variable a register or
    /// a stack slot.
    ///
    /// Intervals are packed greedily into groups whose `[start, end]` hulls do
    /// not overlap (holes between ranges are ignored).
    /// - If there are at most as many groups as registers, each group gets its
    ///   own register.
    /// - Otherwise the first three registers are left unassigned; `alloc`
    ///   hands out free registers in set order for the stack-resident
    ///   variables. Groups are then ordered by total range count, descending.
    ///   The leading groups get the remaining registers, and every other group
    ///   gets its own stack slot.
    pub fn arrange(&mut self, cfg: &ControlFlowGraph) {
        self.liveness = LiveIntervalAnalyzer::scan(cfg);
        let registers: Vec<Register> = self
            .reg_set
            .registers
            .iter()
            .map(|reg| reg.register)
            .collect();

        let mut intervals = self.liveness.intervals();
        // Earliest start first, and the longest interval first among equal
        // starts. The final tie-break on the variable's printed name makes the
        // order, and therefore the register assignment, independent of HashMap
        // iteration order, which varies from run to run.
        intervals.sort_by(|a, b| {
            a.start
                .cmp(&b.start)
                .then_with(|| (b.end - b.start).cmp(&(a.end - a.start)))
                .then_with(|| a.var.to_string().cmp(&b.var.to_string()))
        });

        let mut groups: Vec<Vec<LiveInterval>> = Vec::new();
        for interval in intervals {
            let slot = groups
                .iter()
                .position(|group| Self::can_join_group(&interval, group));
            match slot {
                Some(i) => groups[i].push(interval),
                None => groups.push(vec![interval]),
            }
        }

        if groups.len() <= registers.len() {
            for (group, &reg) in groups.iter().zip(&registers) {
                for interval in group {
                    self.liveness.set_register(interval.var, reg);
                }
            }
            return;
        }

        // NOTE(review): the hard-coded 3 presumably corresponds to
        // MIN_REQUIRED_REGISTER (checked in `new`); confirm before changing either.
        let (_temp_regs, fixed_regs) = registers.split_at(3);
        // More ranges mean more restore/spill points if the group lived on the
        // stack, so those groups are preferred for the fixed registers.
        groups.sort_by(|a, b| {
            let a_len: usize = a.iter().map(|interval| interval.ranges.len()).sum();
            let b_len: usize = b.iter().map(|interval| interval.ranges.len()).sum();
            b_len.cmp(&a_len)
        });
        for (i, group) in groups.iter().enumerate() {
            trace!("Group[{i}]: {group:?}");
        }
        let (fixed_group, temp_group) = groups.split_at(fixed_regs.len());
        for (group, &reg) in fixed_group.iter().zip(fixed_regs) {
            for interval in group {
                self.liveness.set_register(interval.var, reg);
            }
        }
        for (slot, group) in temp_group.iter().enumerate() {
            for interval in group {
                self.liveness.set_stack(interval.var, slot);
            }
        }
    }

    /// True when `interval` overlaps no interval already in `group`.
    fn can_join_group(interval: &LiveInterval, group: &[LiveInterval]) -> bool {
        group
            .iter()
            .all(|existing| !Self::has_overlap(interval, existing))
    }

    /// True when the closed hulls `[start, end]` of the two intervals intersect.
    fn has_overlap(interval_a: &LiveInterval, interval_b: &LiveInterval) -> bool {
        interval_a.start <= interval_b.end && interval_b.start <= interval_a.end
    }
}
/// The allocator's view of the physical registers and their current occupants.
#[derive(Debug, Clone)]
pub(super) struct RegisterSet {
    /// One entry per physical register, in the order passed to `new`.
    registers: Vec<RegisterHold>,
}
impl RegisterSet {
    /// Creates a set in which every register in `registers` is free.
    fn new(registers: &[Register]) -> Self {
        let holds: Vec<RegisterHold> = registers
            .iter()
            .copied()
            .map(RegisterHold::new)
            .collect();
        Self { registers: holds }
    }

    /// Places `value` in the first free register (in set order) and returns it.
    ///
    /// Panics if every register is occupied.
    fn must_alloc(&mut self, value: Value) -> Register {
        let slot = self
            .registers
            .iter()
            .position(|hold| hold.variable.is_none())
            .unwrap();
        let hold = &mut self.registers[slot];
        hold.variable = Some(value);
        hold.register
    }

    /// Frees the first register holding `value` and returns it, or `None` if
    /// `value` is not in any register.
    fn release(&mut self, value: Value) -> Option<Register> {
        let hold = self
            .registers
            .iter_mut()
            .find(|hold| hold.variable == Some(value))?;
        hold.variable = None;
        Some(hold.register)
    }

    /// Makes `register` hold `variable` and sets its `is_fixed` flag.
    ///
    /// Does nothing if `register` is not part of this set.
    fn use_register(&mut self, register: Register, variable: Value, is_fixed: bool) {
        if let Some(hold) = self
            .registers
            .iter_mut()
            .find(|hold| hold.register == register)
        {
            hold.variable = Some(variable);
            hold.is_fixed = is_fixed;
        }
    }

    /// The first register currently holding `variable`, if any.
    fn find(&self, variable: Value) -> Option<Register> {
        self.registers.iter().find_map(|hold| {
            if hold.variable == Some(variable) {
                Some(hold.register)
            } else {
                None
            }
        })
    }
}
/// Renders one cell per register, in set order: the held variable (or `-` when
/// free), each followed by `"\t|"`.
impl fmt::Display for RegisterSet {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.registers.iter().try_for_each(|hold| {
            if let Some(var) = hold.variable {
                write!(f, "{var}")?;
            } else {
                write!(f, "-")?;
            }
            write!(f, "\t|")
        })
    }
}
/// One physical register together with the variable currently occupying it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct RegisterHold {
    /// The physical register.
    register: Register,
    /// Variable currently held, or `None` when the register is free.
    variable: Option<Value>,
    /// Flag passed through `RegisterSet::use_register`.
    /// NOTE(review): not read anywhere in the visible code.
    is_fixed: bool,
}
impl RegisterHold {
    /// A free, non-fixed entry for `register`.
    fn new(register: Register) -> Self {
        let variable = None;
        let is_fixed = false;
        Self {
            register,
            variable,
            is_fixed,
        }
    }
}
/// A register/stack transfer reported by the allocator for the code emitter
/// to insert.
///
/// Derives the standard traits so callers can log, copy and compare actions.
/// Every field type already supports them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Action {
    /// Returned by `RegAlloc::alloc` when a stack-resident variable starts a
    /// new live range in `register`. Presumably the emitter reloads slot
    /// `stack` into `register`; confirm against the emitter.
    Restore { stack: usize, register: Register },
    /// Returned by `RegAlloc::release` when a live range of a stack-resident
    /// variable ends. Presumably the emitter stores `register` into slot
    /// `stack`; confirm against the emitter.
    Spill { register: Register, stack: usize },
}