use dbg::DisplayList;
use dominator_tree::DominatorTreePreorder;
use entity::EntityRef;
use entity::{EntityList, ListPool};
use entity::{Keys, PrimaryMap, SecondaryMap};
use ir::{Function, Value};
use packed_option::PackedOption;
use ref_slice::ref_slice;
use std::cmp::Ordering;
use std::fmt;
use std::vec::Vec;
/// Opaque reference to a virtual register: a set of congruent `Value`s whose membership is
/// tracked by [`VirtRegs`].
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VirtReg(u32);
entity_impl!(VirtReg, "vreg");

/// Pooled list holding the member values of one virtual register.
type ValueList = EntityList<Value>;
/// Collection of virtual registers, the reverse value-to-register mapping, and the
/// union-find state used to build new registers incrementally.
pub struct VirtRegs {
    /// Backing storage for every `ValueList` in `vregs`.
    pool: ListPool<Value>,

    /// Member values of each virtual register. Registers that were removed or dissolved
    /// keep their key but have an empty list.
    vregs: PrimaryMap<VirtReg, ValueList>,

    /// Emptied virtual registers available for reuse.
    unused_vregs: Vec<VirtReg>,

    /// For each value, the virtual register it currently belongs to, if any.
    value_vregs: SecondaryMap<Value, PackedOption<VirtReg>>,

    /// Union-find forest, one `i32` per value: a non-negative entry is the rank of a set
    /// leader; a negative entry is the bitwise NOT of the parent value's index.
    union_find: SecondaryMap<Value, i32>,

    /// Values that entered the union-find since the last `finish_union_find`.
    pending_values: Vec<Value>,
}
impl VirtRegs {
pub fn new() -> Self {
Self {
pool: ListPool::new(),
vregs: PrimaryMap::new(),
unused_vregs: Vec::new(),
value_vregs: SecondaryMap::new(),
union_find: SecondaryMap::new(),
pending_values: Vec::new(),
}
}
pub fn clear(&mut self) {
self.vregs.clear();
self.unused_vregs.clear();
self.value_vregs.clear();
self.pool.clear();
self.union_find.clear();
self.pending_values.clear();
}
pub fn get(&self, value: Value) -> Option<VirtReg> {
self.value_vregs[value].into()
}
pub fn values(&self, vreg: VirtReg) -> &[Value] {
self.vregs[vreg].as_slice(&self.pool)
}
pub fn all_virtregs(&self) -> Keys<VirtReg> {
self.vregs.keys()
}
#[cfg_attr(feature = "cargo-clippy", allow(trivially_copy_pass_by_ref))]
pub fn congruence_class<'a, 'b>(&'a self, value: &'b Value) -> &'b [Value]
where
'a: 'b,
{
self.get(*value)
.map_or_else(|| ref_slice(value), |vr| self.values(vr))
}
pub fn same_class(&self, a: Value, b: Value) -> bool {
match (self.get(a), self.get(b)) {
(Some(va), Some(vb)) => va == vb,
_ => a == b,
}
}
pub fn sort_values(
&mut self,
vreg: VirtReg,
func: &Function,
preorder: &DominatorTreePreorder,
) -> &[Value] {
let s = self.vregs[vreg].as_mut_slice(&mut self.pool);
s.sort_unstable_by(|&a, &b| preorder.pre_cmp_def(a, b, func));
s
}
pub fn insert_single(
&mut self,
big: Value,
single: Value,
func: &Function,
preorder: &DominatorTreePreorder,
) -> VirtReg {
debug_assert_eq!(self.get(single), None, "Expected singleton {}", single);
let vreg = self.get(big).unwrap_or_else(|| {
let vr = self.alloc();
self.vregs[vr].push(big, &mut self.pool);
self.value_vregs[big] = vr.into();
vr
});
let index = match self
.values(vreg)
.binary_search_by(|&v| preorder.pre_cmp_def(v, single, func))
{
Ok(_) => panic!("{} already in {}", single, vreg),
Err(i) => i,
};
self.vregs[vreg].insert(index, single, &mut self.pool);
self.value_vregs[single] = vreg.into();
vreg
}
pub fn remove(&mut self, vreg: VirtReg) {
for &v in self.vregs[vreg].as_slice(&self.pool) {
let old = self.value_vregs[v].take();
debug_assert_eq!(old, Some(vreg));
}
self.vregs[vreg].clear(&mut self.pool);
self.unused_vregs.push(vreg);
}
fn alloc(&mut self) -> VirtReg {
self.unused_vregs
.pop()
.unwrap_or_else(|| self.vregs.push(Default::default()))
}
pub fn unify(&mut self, values: &[Value]) -> VirtReg {
let mut singletons = 0;
let mut cleared = 0;
for &val in values {
match self.get(val) {
None => singletons += 1,
Some(vreg) => {
if !self.vregs[vreg].is_empty() {
cleared += self.vregs[vreg].len(&self.pool);
self.vregs[vreg].clear(&mut self.pool);
self.unused_vregs.push(vreg);
}
}
}
}
debug_assert_eq!(
values.len(),
singletons + cleared,
"Can't unify partial virtual registers"
);
let vreg = self.alloc();
self.vregs[vreg].extend(values.iter().cloned(), &mut self.pool);
for &v in values {
self.value_vregs[v] = vreg.into();
}
vreg
}
}
impl fmt::Display for VirtRegs {
    /// Write one `<vreg> = <members>` entry per virtual register, including emptied ones.
    /// Each entry is preceded by a newline.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.all_virtregs()
            .try_for_each(|vreg| write!(f, "\n{} = {}", vreg, DisplayList(self.values(vreg))))
    }
}
/// Decoded view of one raw `union_find` table entry.
enum UFEntry {
    /// The value leads its set. The payload is the set's rank.
    Rank(u32),
    /// The value points at another value in the same set, closer to the leader.
    Link(Value),
}

impl UFEntry {
    /// Interpret a raw entry: negative means a link (bitwise NOT of the parent's index),
    /// non-negative means a rank.
    fn decode(x: i32) -> Self {
        match x {
            raw if raw < 0 => UFEntry::Link(Value::new((!raw) as usize)),
            raw => UFEntry::Rank(raw as u32),
        }
    }

    /// Encode a link to `v` as the bitwise NOT of its index. This is negative for every
    /// index in `0..=i32::MAX`, so it never collides with a rank.
    fn encode_link(v: Value) -> i32 {
        let index = v.index() as i32;
        !index
    }
}
impl VirtRegs {
    /// Find the leader of the union-find set containing `val`.
    ///
    /// Returns `(leader, rank)`. An untouched value has the default entry `0`, so it is
    /// its own leader with rank 0. Every value on the path from `val` to the leader is
    /// re-linked directly to the leader (path compression).
    fn find(&mut self, val: Value) -> (Value, u32) {
        // First pass: follow links up to the leader without modifying the table.
        let mut leader = val;
        let rank = loop {
            match UFEntry::decode(self.union_find[leader]) {
                UFEntry::Rank(rank) => break rank,
                UFEntry::Link(parent) => leader = parent,
            }
        };

        // Second pass: re-walk the same path and point each node straight at the leader.
        // Two passes avoid collecting the path into a heap-allocated stack. The raw entry
        // is read into a local first so no borrow of the table outlives the read.
        let mut node = val;
        loop {
            let entry = self.union_find[node];
            match UFEntry::decode(entry) {
                UFEntry::Link(parent) => {
                    self.union_find[node] = UFEntry::encode_link(leader);
                    node = parent;
                }
                UFEntry::Rank(_) => break,
            }
        }

        (leader, rank)
    }

    /// Merge the union-find sets containing `a` and `b`.
    ///
    /// Values entering the union-find for the first time are queued in `pending_values`.
    /// Call `finish_union_find` afterwards to turn the sets into virtual registers.
    pub fn union(&mut self, a: Value, b: Value) {
        let (leader_a, rank_a) = self.find(a);
        let (leader_b, rank_b) = self.find(b);

        if leader_a == leader_b {
            return;
        }

        // Every merge leaves the surviving leader with rank >= 1, so a rank-0 leader is a
        // value that has never been merged. Queue it exactly once.
        if rank_a == 0 {
            debug_assert_eq!(a, leader_a);
            self.pending_values.push(a);
        }
        if rank_b == 0 {
            debug_assert_eq!(b, leader_b);
            self.pending_values.push(b);
        }

        // Union by rank: hang the lower-ranked tree under the higher-ranked one. On a tie,
        // `leader_a` wins and its rank grows by one.
        match rank_a.cmp(&rank_b) {
            Ordering::Less => {
                self.union_find[leader_a] = UFEntry::encode_link(leader_b);
            }
            Ordering::Greater => {
                self.union_find[leader_b] = UFEntry::encode_link(leader_a);
            }
            Ordering::Equal => {
                self.union_find[leader_a] += 1;
                self.union_find[leader_b] = UFEntry::encode_link(leader_a);
            }
        }
    }

    /// Turn the sets built by `union` into virtual registers.
    ///
    /// Each set becomes one newly allocated register. When `new_vregs` is `Some`, each
    /// allocated register is also pushed onto it. Pending values are drained from the
    /// back of `pending_values`, so members are appended in reverse order of first
    /// appearance in `union`. Each processed value's union-find entry is reset to 0.
    ///
    /// Debug builds assert that no pending value already belongs to a virtual register.
    pub fn finish_union_find(&mut self, mut new_vregs: Option<&mut Vec<VirtReg>>) {
        debug_assert_eq!(
            self.pending_values.iter().find(|&&v| self.get(v).is_some()),
            None,
            "Values participating in union-find must not belong to existing virtual registers"
        );

        while let Some(val) = self.pending_values.pop() {
            let (leader, _) = self.find(val);

            // Reuse the leader's register if an earlier iteration created it. Otherwise
            // allocate one (initially empty) and record it on the leader.
            let vreg = self.get(leader).unwrap_or_else(|| {
                let vr = self.alloc();
                if let Some(ref mut vec) = new_vregs {
                    vec.push(vr);
                }
                self.value_vregs[leader] = vr.into();
                vr
            });

            self.vregs[vreg].push(val, &mut self.pool);
            self.value_vregs[val] = vreg.into();

            // Resetting the entry can make a later `find` stop early at `val` as if it were
            // a leader. That is harmless: `val` already maps to the correct register, so
            // `get(leader)` above still resolves it.
            self.union_find[val] = 0;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use entity::EntityRef;
    use ir::Value;

    #[test]
    fn empty_union_find() {
        let mut vregs = VirtRegs::new();
        vregs.finish_union_find(None);
        assert_eq!(vregs.all_virtregs().count(), 0);
    }

    #[test]
    fn union_self() {
        let mut vregs = VirtRegs::new();
        let v1 = Value::new(1);
        vregs.union(v1, v1);
        vregs.finish_union_find(None);
        assert_eq!(vregs.get(v1), None);
        assert_eq!(vregs.all_virtregs().count(), 0);
    }

    #[test]
    fn union_pair() {
        let mut vregs = VirtRegs::new();
        let v1 = Value::new(1);
        let v2 = Value::new(2);
        vregs.union(v1, v2);
        vregs.finish_union_find(None);
        assert_eq!(vregs.congruence_class(&v1), &[v2, v1]);
        assert_eq!(vregs.congruence_class(&v2), &[v2, v1]);
        assert_eq!(vregs.all_virtregs().count(), 1);
    }

    #[test]
    fn union_pair_backwards() {
        let mut vregs = VirtRegs::new();
        let v1 = Value::new(1);
        let v2 = Value::new(2);
        vregs.union(v2, v1);
        vregs.finish_union_find(None);
        assert_eq!(vregs.congruence_class(&v1), &[v1, v2]);
        assert_eq!(vregs.congruence_class(&v2), &[v1, v2]);
        assert_eq!(vregs.all_virtregs().count(), 1);
    }

    #[test]
    fn union_tree() {
        let mut vregs = VirtRegs::new();
        let v1 = Value::new(1);
        let v2 = Value::new(2);
        let v3 = Value::new(3);
        let v4 = Value::new(4);
        vregs.union(v2, v4);
        vregs.union(v3, v1);
        vregs.union(v4, v1);
        vregs.finish_union_find(None);
        assert_eq!(vregs.congruence_class(&v1), &[v1, v3, v4, v2]);
        assert_eq!(vregs.congruence_class(&v2), &[v1, v3, v4, v2]);
        assert_eq!(vregs.congruence_class(&v3), &[v1, v3, v4, v2]);
        assert_eq!(vregs.congruence_class(&v4), &[v1, v3, v4, v2]);
        assert_eq!(vregs.all_virtregs().count(), 1);
    }

    #[test]
    fn union_two() {
        let mut vregs = VirtRegs::new();
        let v1 = Value::new(1);
        let v2 = Value::new(2);
        let v3 = Value::new(3);
        let v4 = Value::new(4);
        vregs.union(v2, v4);
        vregs.union(v3, v1);
        vregs.finish_union_find(None);
        assert_eq!(vregs.congruence_class(&v1), &[v1, v3]);
        assert_eq!(vregs.congruence_class(&v2), &[v4, v2]);
        assert_eq!(vregs.congruence_class(&v3), &[v1, v3]);
        assert_eq!(vregs.congruence_class(&v4), &[v4, v2]);
        assert_eq!(vregs.all_virtregs().count(), 2);
    }

    #[test]
    fn union_uneven() {
        let mut vregs = VirtRegs::new();
        let v1 = Value::new(1);
        let v2 = Value::new(2);
        let v3 = Value::new(3);
        let v4 = Value::new(4);
        vregs.union(v2, v4);
        vregs.union(v3, v2);
        vregs.union(v2, v1);
        vregs.finish_union_find(None);
        assert_eq!(vregs.congruence_class(&v1), &[v1, v3, v4, v2]);
        assert_eq!(vregs.congruence_class(&v2), &[v1, v3, v4, v2]);
        assert_eq!(vregs.congruence_class(&v3), &[v1, v3, v4, v2]);
        assert_eq!(vregs.congruence_class(&v4), &[v1, v3, v4, v2]);
        assert_eq!(vregs.all_virtregs().count(), 1);
    }

    /// A chain of unions forms a single register, and the freshly allocated register is
    /// reported through `new_vregs`.
    #[test]
    fn union_chain_reports_new_vreg() {
        let mut vregs = VirtRegs::new();
        let v1 = Value::new(1);
        let v2 = Value::new(2);
        let v3 = Value::new(3);
        let v4 = Value::new(4);
        let v5 = Value::new(5);
        vregs.union(v1, v2);
        vregs.union(v2, v3);
        vregs.union(v3, v4);
        vregs.union(v4, v5);
        let mut new_vregs = Vec::new();
        vregs.finish_union_find(Some(&mut new_vregs));
        assert_eq!(vregs.congruence_class(&v1), &[v5, v4, v3, v2, v1]);
        assert_eq!(new_vregs.len(), 1);
        assert_eq!(vregs.get(v3), Some(new_vregs[0]));
        assert!(vregs.same_class(v1, v5));
        assert_eq!(vregs.all_virtregs().count(), 1);
    }

    /// `union` performs a `find` through a depth-2 path (v4 -> v3 -> v1) before
    /// `finish_union_find`, exercising path compression mid-construction.
    #[test]
    fn union_deep_path() {
        let mut vregs = VirtRegs::new();
        let v1 = Value::new(1);
        let v2 = Value::new(2);
        let v3 = Value::new(3);
        let v4 = Value::new(4);
        let v5 = Value::new(5);
        vregs.union(v1, v2);
        vregs.union(v3, v4);
        vregs.union(v2, v4);
        vregs.union(v4, v5);
        vregs.finish_union_find(None);
        assert_eq!(vregs.congruence_class(&v4), &[v5, v4, v3, v2, v1]);
        assert_eq!(vregs.congruence_class(&v5), &[v5, v4, v3, v2, v1]);
        assert_eq!(vregs.all_virtregs().count(), 1);
    }
}