use std::collections::{HashMap};
use std::fmt::{self, Debug, Formatter};
use super::{NUM_REGISTERS, all_registers, Resources, Dataflow, Node, Exit, Frontier};
use super::cost::{BUDGET, SPILL_COST, SLOT_COST};
use super::code::{Register, Variable};
use crate::util::{ArrayMap, map_filter_max, Usage};
mod pool;
use pool::{RegisterPool};
mod placer;
use placer::{Time, LEAST as EARLY, Placer};
/// One entry of the instruction schedule produced by [`allocate`].
#[derive(Clone, Copy, PartialEq)]
pub enum Instruction {
    /// Spill the values held by two nodes out of their registers.
    /// The allocator always emits spills in pairs.
    Spill(Node, Node),
    /// Place the given dataflow node.
    Node(Node),
}
use Instruction::*;
impl Debug for Instruction {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
match *self {
Spill(out_x, out_y) => write!(f, "({:?}, {:?})", out_x, out_y),
Node(node) => node.fmt(f),
}
}
}
/// One pending use of a node, as recorded on the usage stack built by
/// [`allocate`].
#[derive(Clone, Copy, Debug)]
pub struct Input {
    /// True when the use reads the node's value. It is false for uses that
    /// `allocate` records without a value, such as the exit sequence and
    /// memory-ordering uses.
    is_value: bool,
    /// True for uses recorded from a node's frontier. `Allocator::add_node`
    /// ignores cold uses when computing placement time.
    is_cold: bool,
}
/// State for the forward pass that assigns registers to nodes, inserts
/// spills, and places instructions in time.
#[derive(Debug)]
struct Allocator<'a> {
/// The dataflow graph being allocated; queried for costs and for whether a
/// node has an output.
dataflow: &'a Dataflow,
/// Stack of pending uses, popped as nodes are added.
usage: Usage<Node, Input>,
/// Places instructions (nodes and spills) given their resource costs.
/// It presumably assigns each one a time; TODO confirm against `placer`.
placer: Placer<Instruction>,
/// The last register assigned to each node. The value is only still present
/// if `regs` maps that register back to the same node.
allocation: HashMap<Node, Register>,
/// Latest time recorded via `access` for each node: reads by consumers,
/// spills, and the node's own placement. Used to delay overwriting its
/// register.
access_times: HashMap<Node, Time>,
/// Time at which each node was placed.
node_times: HashMap<Node, Time>,
/// The node most recently written to each register, if any.
regs: ArrayMap<Register, Option<Node>>,
/// Tracks which registers are clean (allocatable) and which are dirty.
pool: RegisterPool,
}
impl<'a> Allocator<'a> {
    /// Creates an allocator whose register file starts out as described by
    /// `variables`.
    ///
    /// Each variable held in a register that still has a pending use in
    /// `usage` becomes the recorded occupant of that register, and the register
    /// is marked dirty in the pool. Variables without a pending use are ignored.
    pub fn new(
        variables: &HashMap<Node, Variable>,
        dataflow: &'a Dataflow,
        usage: Usage<Node, Input>,
    ) -> Self {
        let mut dirty = ArrayMap::new(NUM_REGISTERS);
        let mut allocation: HashMap<Node, Register> = HashMap::new();
        let mut regs: ArrayMap<Register, Option<Node>> = ArrayMap::new(NUM_REGISTERS);
        for (&node, &value) in variables.iter() {
            if usage.topmost(&node).is_some() {
                if let Variable::Register(reg) = value {
                    dirty[reg] = true;
                    regs[reg] = Some(node);
                    allocation.insert(node, reg);
                }
            }
        }
        let placer = Placer::new();
        let access_times: HashMap<Node, Time> = HashMap::new();
        let node_times: HashMap<Node, Time> = HashMap::new();
        let pool = RegisterPool::new(dirty);
        Allocator {dataflow, usage, placer, allocation, access_times, node_times, regs, pool}
    }

    /// Returns the register that still holds `node`'s value, if any.
    ///
    /// `allocation` remembers the last register assigned to `node`. The value
    /// is only still available if that register has not since been given to
    /// another node.
    fn current_reg(&self, node: Node) -> Option<Register> {
        // `|&reg|` destructures the `&Register` passed to `Option::filter`.
        self.allocation.get(&node).copied().filter(
            |&reg| self.regs[reg] == Some(node)
        )
    }

    /// Pops the next pending use off the usage stack.
    ///
    /// If that was the last pending use of the node, any register still
    /// holding its value is returned to the pool.
    ///
    /// # Panics
    /// Panics if the usage stack is empty.
    fn pop_use(&mut self) -> (Node, Input) {
        let (node, input) = self.usage.pop().expect("Incorrect usage information");
        if self.usage.topmost(&node).is_none() {
            if let Some(reg) = self.current_reg(node) {
                self.pool.free(reg);
            }
        }
        (node, input)
    }

    /// Records that `node` is accessed at `time`, keeping the latest such time.
    fn access(&mut self, node: Node, time: Time) {
        self.access_times.entry(node).or_insert(EARLY).max_with(time);
    }

    /// Chooses a dirty register to evict, marks it clean in the pool, and
    /// returns it.
    ///
    /// Among the dirty occupied registers, it picks the one whose occupant has
    /// the smallest `topmost` position in the usage stack. This is presumably
    /// the value needed furthest in the future; TODO confirm against `Usage`.
    /// The occupant stays recorded in `regs`, so `current_reg` keeps reporting
    /// it until the register is reused.
    ///
    /// # Panics
    /// Panics if no register is dirty, or if a dirty register's occupant has
    /// no pending use.
    fn free_a_register(&mut self) -> Register {
        let i = map_filter_max(all_registers(), |reg| {
            self.regs[reg]
                .filter(|_| !self.pool.is_clean(reg))
                .map(|node| std::cmp::Reverse(
                    self.usage.topmost(&node).expect("Dirty register is unused")
                ))
        }).expect("No register is dirty");
        let reg = Register::new(i as u8).unwrap();
        self.pool.free(reg);
        reg
    }

    /// Returns the time at which `node` was placed, or `EARLY` if it has not
    /// been placed.
    ///
    /// When `add_latency` is true and the node has been placed, the node's
    /// latency is added to its placement time.
    fn node_time(&self, node: Node, add_latency: bool) -> Time {
        match self.node_times.get(&node) {
            Some(&time) if add_latency => time + (self.dataflow.cost(node).latency as usize),
            Some(&time) => time,
            None => EARLY,
        }
    }

    /// Spills pairs of dirty registers until at least `num_required` registers
    /// are clean.
    ///
    /// Each spill is placed no earlier than the time both spilled values are
    /// available (placement time plus latency). The spill counts as an access
    /// of both nodes.
    fn spill_until(&mut self, num_required: usize) {
        while self.pool.num_clean() < num_required {
            let reg_x = self.free_a_register();
            let reg_y = self.free_a_register();
            let node_x = self.regs[reg_x].unwrap();
            let node_y = self.regs[reg_y].unwrap();
            let mut time = self.node_time(node_x, true);
            time.max_with(self.node_time(node_y, true));
            self.placer.add_item(Instruction::Spill(node_x, node_y), SPILL_COST, &mut time);
            self.access(node_x, time);
            self.access(node_y, time);
        }
    }

    /// Adds `node` to the schedule, consuming its `num_inputs` pending uses.
    ///
    /// The node is placed no earlier than both of the following:
    /// - every non-cold input, plus that input's latency when the use reads its
    ///   value;
    /// - the last access of whatever previously occupied the register chosen
    ///   for the node's output (if it has one).
    ///
    /// If any non-cold value input is no longer held in a register, an extra
    /// `SLOT_COST` is charged to the node's resources.
    pub fn add_node(&mut self, node: Node, num_inputs: usize) {
        // Copy the reference out so `self` can be mutably borrowed below.
        let df: &'a Dataflow = self.dataflow;
        let mut time = EARLY;
        let mut inputs = Vec::<(Node, Input)>::with_capacity(num_inputs);
        let mut has_spilled_input = false;
        for _ in 0..num_inputs {
            let (in_, input) = self.pop_use();
            inputs.push((in_, input));
            if !input.is_cold {
                has_spilled_input |= input.is_value && self.current_reg(in_).is_none();
                time.max_with(self.node_time(in_, input.is_value));
            }
        }
        if df.has_out(node) {
            self.spill_until(1);
            let reg = self.pool.allocate();
            self.allocation.insert(node, reg);
            // The previous occupant must not be overwritten before its last access.
            if let Some(prev) = self.regs[reg].replace(node) {
                if let Some(&read_time) = self.access_times.get(&prev) {
                    time.max_with(read_time);
                }
            }
            // An output with no pending use can release its register at once.
            if self.usage.topmost(&node).is_none() {
                self.pool.free(reg);
            }
        }
        let mut resources = df.cost(node).resources;
        if has_spilled_input {
            resources += SLOT_COST;
        }
        self.placer.add_item(Instruction::Node(node), resources, &mut time);
        self.node_times.insert(node, time);
        for &(in_, input) in &inputs {
            if input.is_value {
                self.access(in_, time);
            }
        }
        if df.has_out(node) {
            self.access(node, time);
        }
    }

    /// Consumes the exit's remaining uses and returns the placed instructions
    /// together with the register allocation.
    ///
    /// In the usage stack built by `allocate`, the exit's outputs sit on top
    /// and the exit sequence lies just beneath them. They are popped in that
    /// order.
    ///
    /// # Panics
    /// Panics if uses remain afterwards or any register is still dirty.
    fn finish(mut self, num_outputs: usize) -> (Vec<Instruction>, HashMap<Node, Register>) {
        for _ in 0..num_outputs {
            self.pop_use();
        }
        // The exit sequence.
        self.pop_use();
        assert_eq!(self.usage.len(), 0);
        assert!(all_registers().all(|reg| self.pool.is_clean(reg)));
        (self.placer.iter().cloned().collect(), self.allocation)
    }
}
/// The nodes that depend on a single address node, grouped by kind of
/// dependency.
#[derive(Default, Debug)]
struct Address {
    /// Nodes with an `is_address` dependency on the address.
    mems: Vec<Node>,
    /// Nodes with an `is_send` dependency on the address.
    sends: Vec<Node>,
}
/// A reference-counting work list.
///
/// Every tracked node has a count of outstanding consumers. When a decrement
/// brings that count to zero, the node becomes ready and is pushed onto a
/// last-in-first-out stack. Nodes that were not passed to [`Queue::new`] are
/// ignored by every operation.
#[derive(Default, Debug)]
struct Queue {
    /// Outstanding count for each tracked node.
    counts: HashMap<Node, usize>,
    /// Ready nodes, popped most-recent first.
    queue: Vec<Node>,
}

impl Queue {
    /// Starts tracking each node in `nodes` with a count of zero.
    pub fn new(nodes: &[Node]) -> Self {
        let counts = nodes.iter().copied().zip(std::iter::repeat(0)).collect();
        Queue { counts, queue: Vec::new() }
    }

    /// Adds one outstanding consumer to `node`, if it is tracked.
    pub fn increment(&mut self, node: Node) {
        if let Some(count) = self.counts.get_mut(&node) {
            *count += 1
        }
    }

    /// Removes one outstanding consumer from `node`, if it is tracked.
    /// The node is pushed onto the ready stack when its count reaches zero.
    pub fn decrement(&mut self, node: Node) {
        let became_ready = match self.counts.get_mut(&node) {
            Some(count) => {
                *count -= 1;
                *count == 0
            }
            None => false,
        };
        if became_ready {
            self.queue.push(node);
        }
    }

    /// Takes the most recently readied node, if any.
    pub fn pop(&mut self) -> Option<Node> {
        self.queue.pop()
    }
}
/// Allocates registers for `nodes` and produces the placed instruction list.
///
/// The work happens in three phases:
/// 1. Count, for every node, the consumers that must be visited before it:
///    its non-cold users and the exit (for the exit sequence and outputs).
///    In addition, a node with an `is_address` dependency waits for every
///    *other* node with an `is_send` dependency on the same address.
/// 2. Walk backwards from the exit in that order. For each visited node, push
///    onto a `Usage` stack its non-cold inputs, the address-sharing nodes it
///    orders against, and its cold frontier inputs.
/// 3. Replay the visited nodes in the opposite order through an `Allocator`,
///    which pops those uses as it assigns registers and places instructions.
///
/// Returns the placed instructions and the allocator's node-to-register map.
///
/// # Panics
/// Panics if not every node in `nodes` is visited in phase 2, or if the
/// allocator's internal consistency checks fail.
pub fn allocate<'a>(
    variables: &HashMap<Node, Variable>,
    dataflow: &Dataflow,
    nodes: &[Node],
    get_frontier: impl Fn(Node) -> Option<&'a Frontier>,
    exit: &Exit,
) -> (
    Vec<Instruction>,
    HashMap<Node, Register>
) {
    // Phase 1: reference counts.
    let mut ready = Queue::new(nodes);
    let mut by_address = HashMap::<Node, Address>::new();
    for &node in nodes {
        dataflow.each_input(node, |in_, dep| {
            if !dep.is_cold() {
                ready.increment(in_);
            }
            if dep.is_address() {
                by_address.entry(in_).or_default().mems.push(node);
            }
            if dep.is_send() {
                by_address.entry(in_).or_default().sends.push(node);
            }
        });
    }
    for address in by_address.values() {
        for &send in &address.sends {
            address.mems.iter()
                .filter(|&&mem| mem != send)
                .for_each(|&mem| ready.increment(mem));
        }
    }
    // Count every exit use before releasing any, so a node that appears more
    // than once is only released after its final exit use.
    ready.increment(exit.sequence);
    for &out in &*exit.outputs {
        ready.increment(out);
    }

    // Phase 2: walk backwards from the exit, recording uses.
    let mut usage = Usage::default();
    let mut schedule = Vec::new();
    ready.decrement(exit.sequence);
    usage.push(exit.sequence, Input {is_value: false, is_cold: false});
    for &out in &*exit.outputs {
        ready.decrement(out);
        usage.push(out, Input {is_value: true, is_cold: false});
    }
    while let Some(node) = ready.pop() {
        let before = usage.len();
        dataflow.each_input(node, |in_, dep| {
            if !dep.is_cold() {
                ready.decrement(in_);
                usage.push(in_, Input {is_value: dep.is_value(), is_cold: false});
            }
            if dep.is_send() {
                for &mem in &by_address[&in_].mems {
                    if mem != node {
                        ready.decrement(mem);
                        usage.push(mem, Input {is_value: false, is_cold: false});
                    }
                }
            }
        });
        if let Some(frontier) = get_frontier(node) {
            for (&in_, &v) in &frontier.0 {
                usage.push(in_, Input {is_value: v.is_value(), is_cold: true});
            }
        }
        schedule.push((node, usage.len() - before));
    }
    assert_eq!(schedule.len(), nodes.len());

    // Phase 3: replay forwards through the allocator.
    let mut allocator = Allocator::new(variables, dataflow, usage);
    for (node, num_inputs) in schedule.into_iter().rev() {
        allocator.add_node(node, num_inputs);
    }
    allocator.finish(exit.outputs.len())
}