use crate::utils::{fast_log2sumexp2, fast_log2sumexp2_3};
/// Fixed-capacity set of small non-negative integers backed by packed `u64` words.
///
/// Used as scratch storage for O(1) label-membership tests. The capacity is fixed at
/// construction (rounded up to a multiple of 64): [`BitSet::insert`] silently ignores
/// indices beyond it and [`BitSet::contains`] reports them as absent.
#[derive(Clone, Debug)]
pub struct BitSet {
    /// Bit `i % 64` of word `i / 64` is set iff `i` is in the set.
    bits: Vec<u64>,
}
impl BitSet {
    /// Creates an empty set able to hold indices `0..capacity` (rounded up to a multiple of 64).
    #[inline]
    pub fn new(capacity: usize) -> Self {
        // Ceiling division that cannot overflow, unlike `(capacity + 63) / 64`.
        let num_words = capacity / 64 + usize::from(capacity % 64 != 0);
        Self {
            bits: vec![0; num_words],
        }
    }

    /// Splits `idx` into its word index and the single-bit mask inside that word.
    #[inline]
    fn locate(idx: usize) -> (usize, u64) {
        (idx / 64, 1u64 << (idx % 64))
    }

    /// Removes every element while keeping the allocation.
    #[inline]
    pub fn clear(&mut self) {
        self.bits.fill(0);
    }

    /// Adds `idx` to the set; a silent no-op when `idx` is beyond the capacity.
    #[inline]
    pub fn insert(&mut self, idx: usize) {
        let (word, mask) = Self::locate(idx);
        if let Some(w) = self.bits.get_mut(word) {
            *w |= mask;
        }
    }

    /// Returns whether `idx` is in the set (always `false` beyond the capacity).
    #[inline]
    pub fn contains(&self, idx: usize) -> bool {
        let (word, mask) = Self::locate(idx);
        self.bits.get(word).map_or(false, |w| w & mask != 0)
    }

    /// Replaces the contents with exactly the indices in `slice` (out-of-range ones ignored).
    #[inline]
    pub fn set_from_slice(&mut self, slice: &[usize]) {
        self.clear();
        for &idx in slice {
            self.insert(idx);
        }
    }
}
/// Reusable bit-set buffers for the label-set computations done while scoring
/// local rewrites of an [`ExprTree`].
///
/// The methods compute the same quantities as the free functions [`tcscrw`],
/// [`compute_intermediate_output`] and [`rule_diff`], but answer label-membership
/// queries with O(1) bit-set lookups instead of linear scans. The one difference is
/// [`rule_diff`](ScratchSpace::rule_diff) for `Rule::Rule5`. This method reports the node's
/// actual (unchanged) cost in both `tc0` and `tc1`, while the free function reports zeros.
/// The before/after differences are zero either way.
pub struct ScratchSpace {
    /// Exclusive upper bound on the labels the bit sets are currently sized for.
    /// Grown automatically when a larger label is encountered.
    pub nedge: usize,
    bits_a: BitSet,
    bits_b: BitSet,
    bits_c: BitSet,
    bits_d: BitSet,
}
impl ScratchSpace {
    /// Creates scratch space whose bit sets hold labels `0..nedge` without reallocating.
    pub fn new(nedge: usize) -> Self {
        Self {
            nedge,
            bits_a: BitSet::new(nedge),
            bits_b: BitSet::new(nedge),
            bits_c: BitSet::new(nedge),
            bits_d: BitSet::new(nedge),
        }
    }

    /// Reallocates every bit set if any label in `slices` is `>= self.nedge`.
    ///
    /// `BitSet::insert` silently drops out-of-range indices. Without this check, a
    /// too-small `nedge` would make membership tests answer `false` and quietly
    /// produce wrong labels and costs.
    #[inline]
    fn ensure_fits(&mut self, slices: &[&[usize]]) {
        let needed = slices
            .iter()
            .flat_map(|labels| labels.iter())
            .max()
            .map_or(0, |&max_label| max_label + 1);
        if needed > self.nedge {
            *self = Self::new(needed);
        }
    }

    /// Bit-set version of the free [`compute_intermediate_output`].
    ///
    /// Returns the labels of `a` (in order), then the labels of `c` not in `a`. Only labels
    /// that also occur in `b` or `d` are kept.
    #[inline]
    pub fn compute_intermediate_output(
        &mut self,
        a: &[usize],
        c: &[usize],
        b: &[usize],
        d: &[usize],
    ) -> Vec<usize> {
        // Only `a`, `b` and `d` are inserted. Labels of `c` are merely queried, and a query
        // for a never-inserted label correctly answers "absent".
        self.ensure_fits(&[a, b, d]);
        self.bits_a.set_from_slice(a);
        self.bits_b.set_from_slice(b);
        self.bits_d.set_from_slice(d);
        let mut output = Vec::with_capacity(a.len() + c.len());
        for &l in a {
            if self.bits_b.contains(l) || self.bits_d.contains(l) {
                output.push(l);
            }
        }
        for &l in c {
            if !self.bits_a.contains(l) && (self.bits_b.contains(l) || self.bits_d.contains(l)) {
                output.push(l);
            }
        }
        output
    }

    /// Bit-set version of the free [`tcscrw`]: returns `(tc, sc, rw)` in log2 space for
    /// contracting `ix1` with `ix2` into `iy` (`rw` is `0.0` unless `compute_rw`).
    ///
    /// # Panics
    /// Panics if a label that is read is out of bounds for `log2_sizes`.
    #[inline]
    pub fn tcscrw(
        &mut self,
        ix1: &[usize],
        ix2: &[usize],
        iy: &[usize],
        log2_sizes: &[f64],
        compute_rw: bool,
    ) -> (f64, f64, f64) {
        self.ensure_fits(&[ix2, iy]);
        self.bits_b.set_from_slice(ix2);
        self.bits_c.set_from_slice(iy);
        // Checked indexing on purpose: this is a safe `pub fn`, so a bad label must panic
        // rather than read out of bounds.
        let log2_size =
            |labels: &[usize]| -> f64 { labels.iter().map(|&l| log2_sizes[l]).sum() };
        let sc = log2_size(iy);
        let mut tc = sc;
        for &l in ix1 {
            // Labels shared by both inputs but absent from the output are summed over.
            if self.bits_b.contains(l) && !self.bits_c.contains(l) {
                tc += log2_sizes[l];
            }
        }
        let rw = if compute_rw {
            fast_log2sumexp2_3(sc, log2_size(ix1), log2_size(ix2))
        } else {
            0.0
        };
        (tc, sc, rw)
    }

    /// Cost of two consecutive pairwise contractions, each given as `(ix1, ix2, iy)`.
    ///
    /// Returns `(log2-sum of tc, max of sc, log2-sum of rw)`; `rw` is `0.0` when
    /// `compute_rw` is false.
    #[inline]
    fn two_step_cost(
        &mut self,
        first: (&[usize], &[usize], &[usize]),
        second: (&[usize], &[usize], &[usize]),
        log2_sizes: &[f64],
        compute_rw: bool,
    ) -> (f64, f64, f64) {
        let (tc_1, sc_1, rw_1) = self.tcscrw(first.0, first.1, first.2, log2_sizes, compute_rw);
        let (tc_2, sc_2, rw_2) =
            self.tcscrw(second.0, second.1, second.2, log2_sizes, compute_rw);
        let rw = if compute_rw {
            fast_log2sumexp2(rw_1, rw_2)
        } else {
            0.0
        };
        (fast_log2sumexp2(tc_1, tc_2), sc_1.max(sc_2), rw)
    }

    /// Scores `rule` at the root of `tree` without modifying the tree.
    ///
    /// Returns `None` in two cases:
    /// * `tree` is a leaf;
    /// * `tree` lacks the shape the rule needs. `Rule1`/`Rule2` need a non-leaf left child,
    ///   and `Rule3`/`Rule4` need a non-leaf right child.
    pub fn rule_diff(
        &mut self,
        tree: &ExprTree,
        rule: Rule,
        log2_sizes: &[f64],
        compute_rw: bool,
    ) -> Option<RuleDiff> {
        let (left, right, d) = match tree {
            ExprTree::Leaf(_) => return None,
            ExprTree::Node { left, right, info } => {
                (&**left, &**right, info.out_dims.as_slice())
            }
        };
        let (before, after, new_labels) = match rule {
            Rule::Rule5 => {
                // (x, y) -> (y, x) leaves the cost unchanged.
                let (tc, _sc, rw) =
                    self.tcscrw(left.labels(), right.labels(), d, log2_sizes, compute_rw);
                return Some(RuleDiff {
                    tc0: tc,
                    tc1: tc,
                    dsc: 0.0,
                    rw0: rw,
                    rw1: rw,
                    new_labels: d.to_vec(),
                });
            }
            Rule::Rule1 | Rule::Rule2 => {
                let (a, b, ab) = match left {
                    ExprTree::Node {
                        left: a,
                        right: b,
                        info,
                    } => (a.labels(), b.labels(), info.out_dims.as_slice()),
                    ExprTree::Leaf(_) => return None,
                };
                let c = right.labels();
                let before = self.two_step_cost((a, b, ab), (ab, c, d), log2_sizes, compute_rw);
                let new_labels = if rule == Rule::Rule1 {
                    self.compute_intermediate_output(a, c, b, d)
                } else {
                    self.compute_intermediate_output(b, c, a, d)
                };
                let nl = new_labels.as_slice();
                let after = if rule == Rule::Rule1 {
                    // ((a, b), c) -> ((a, c), b)
                    self.two_step_cost((a, c, nl), (nl, b, d), log2_sizes, compute_rw)
                } else {
                    // ((a, b), c) -> ((c, b), a)
                    self.two_step_cost((c, b, nl), (nl, a, d), log2_sizes, compute_rw)
                };
                (before, after, new_labels)
            }
            Rule::Rule3 | Rule::Rule4 => {
                let (b, c, bc) = match right {
                    ExprTree::Node {
                        left: b,
                        right: c,
                        info,
                    } => (b.labels(), c.labels(), info.out_dims.as_slice()),
                    ExprTree::Leaf(_) => return None,
                };
                let a = left.labels();
                let before = self.two_step_cost((b, c, bc), (a, bc, d), log2_sizes, compute_rw);
                let new_labels = if rule == Rule::Rule3 {
                    self.compute_intermediate_output(c, a, b, d)
                } else {
                    self.compute_intermediate_output(b, a, c, d)
                };
                let nl = new_labels.as_slice();
                let after = if rule == Rule::Rule3 {
                    // (a, (b, c)) -> (b, (a, c))
                    self.two_step_cost((a, c, nl), (b, nl, d), log2_sizes, compute_rw)
                } else {
                    // (a, (b, c)) -> (c, (b, a))
                    self.two_step_cost((b, a, nl), (c, nl, d), log2_sizes, compute_rw)
                };
                (before, after, new_labels)
            }
        };
        let (tc0, sc0, rw0) = before;
        let (tc1, sc1, rw1) = after;
        Some(RuleDiff {
            tc0,
            tc1,
            dsc: sc1 - sc0,
            rw0,
            rw1,
            new_labels,
        })
    }
}
/// Memoised complexity of a whole subtree, every field in log2 space.
#[derive(Debug, Default, Clone, Copy)]
pub struct CachedComplexity {
    /// log2 time complexity summed over all contractions in the subtree.
    pub tc: f64,
    /// log2 size of the largest tensor appearing anywhere in the subtree.
    pub sc: f64,
    /// log2 read-write cost summed over all contractions in the subtree.
    pub rw: f64,
}

/// Metadata stored on every node of an [`ExprTree`].
#[derive(Debug, Clone)]
pub struct ExprInfo {
    /// Labels of the tensor this node produces (for a leaf: the input tensor's labels).
    pub out_dims: Vec<usize>,
    /// Input-tensor index; `Some` for nodes built with [`ExprInfo::leaf`], else `None`.
    pub tensor_id: Option<usize>,
    /// Subtree complexity memoised by `tree_complexity_cached`; `None` until computed.
    pub cached: Option<CachedComplexity>,
}
impl ExprInfo {
    /// Metadata for an internal (contraction) node: no tensor id, nothing cached.
    pub fn internal(out_dims: Vec<usize>) -> Self {
        Self {
            out_dims,
            tensor_id: None,
            cached: None,
        }
    }

    /// Metadata for a leaf holding input tensor `tensor_id`.
    pub fn leaf(out_dims: Vec<usize>, tensor_id: usize) -> Self {
        Self {
            tensor_id: Some(tensor_id),
            ..Self::internal(out_dims)
        }
    }
}

/// Binary contraction tree: leaves are input tensors, internal nodes are pairwise contractions.
#[derive(Debug, Clone)]
pub enum ExprTree {
    Leaf(ExprInfo),
    Node {
        left: Box<ExprTree>,
        right: Box<ExprTree>,
        info: ExprInfo,
    },
}
impl ExprTree {
    /// Builds a leaf for input tensor `tensor_id` with labels `out_dims`.
    pub fn leaf(out_dims: Vec<usize>, tensor_id: usize) -> Self {
        ExprTree::Leaf(ExprInfo::leaf(out_dims, tensor_id))
    }

    /// Builds an internal node contracting `left` with `right` into a tensor labelled `out_dims`.
    pub fn node(left: ExprTree, right: ExprTree, out_dims: Vec<usize>) -> Self {
        ExprTree::Node {
            left: Box::new(left),
            right: Box::new(right),
            info: ExprInfo::internal(out_dims),
        }
    }

    /// `true` for [`ExprTree::Leaf`].
    #[inline]
    pub fn is_leaf(&self) -> bool {
        match self {
            ExprTree::Leaf(_) => true,
            ExprTree::Node { .. } => false,
        }
    }

    /// Labels of the tensor produced by this node.
    pub fn labels(&self) -> &[usize] {
        &self.info().out_dims
    }

    /// The leaf's tensor id; always `None` for internal nodes.
    pub fn tensor_id(&self) -> Option<usize> {
        if let ExprTree::Leaf(info) = self {
            info.tensor_id
        } else {
            None
        }
    }

    /// Shared access to this node's metadata.
    pub fn info(&self) -> &ExprInfo {
        match self {
            ExprTree::Node { info, .. } => info,
            ExprTree::Leaf(info) => info,
        }
    }

    /// Mutable access to this node's metadata.
    pub fn info_mut(&mut self) -> &mut ExprInfo {
        match self {
            ExprTree::Node { info, .. } => info,
            ExprTree::Leaf(info) => info,
        }
    }

    /// Number of leaves in the tree (walked with an explicit stack).
    pub fn leaf_count(&self) -> usize {
        let mut pending: Vec<&ExprTree> = vec![self];
        let mut leaves = 0;
        while let Some(current) = pending.pop() {
            match current {
                ExprTree::Leaf(_) => leaves += 1,
                ExprTree::Node { left, right, .. } => {
                    pending.push(&**left);
                    pending.push(&**right);
                }
            }
        }
        leaves
    }

    /// Tensor ids of all leaves, left to right; leaves without an id are skipped.
    pub fn leaf_ids(&self) -> Vec<usize> {
        let mut collected = Vec::new();
        self.collect_leaf_ids(&mut collected);
        collected
    }

    /// Appends the ids of all leaves, in left-to-right order, to `ids`.
    fn collect_leaf_ids(&self, ids: &mut Vec<usize>) {
        // Right is pushed before left so the stack pops leaves left to right.
        let mut pending: Vec<&ExprTree> = vec![self];
        while let Some(current) = pending.pop() {
            match current {
                ExprTree::Leaf(info) => ids.extend(info.tensor_id),
                ExprTree::Node { left, right, .. } => {
                    pending.push(&**right);
                    pending.push(&**left);
                }
            }
        }
    }
}
/// Family of contraction trees the rewrite rules are allowed to explore.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum DecompositionType {
    /// Arbitrary binary trees; `Rule::applicable_rules` proposes rules 1–4. This is the default.
    #[default]
    Tree,
    /// Path-shaped trees; `Rule::applicable_rules` proposes only `Rule1` or `Rule5`.
    Path,
}
/// Local rewrite moves on the root of a contraction tree:
///
/// * `Rule1`: `((a, b), c) -> ((a, c), b)`
/// * `Rule2`: `((a, b), c) -> ((c, b), a)`
/// * `Rule3`: `(a, (b, c)) -> (b, (a, c))`
/// * `Rule4`: `(a, (b, c)) -> (c, (b, a))`
/// * `Rule5`: `(x, y) -> (y, x)` (swap the two children)
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Rule {
    Rule1,
    Rule2,
    Rule3,
    Rule4,
    Rule5,
}
// Static rule sets so `applicable_rules` can hand out `&'static` slices without allocating.
const RULES_NONE: &[Rule] = &[];
const RULES_1_2: &[Rule] = &[Rule::Rule1, Rule::Rule2];
const RULES_3_4: &[Rule] = &[Rule::Rule3, Rule::Rule4];
const RULES_1_2_3_4: &[Rule] = &[Rule::Rule1, Rule::Rule2, Rule::Rule3, Rule::Rule4];
const RULES_5: &[Rule] = &[Rule::Rule5];
const RULES_1: &[Rule] = &[Rule::Rule1];
impl Rule {
    /// Rules that may be proposed at the root of `tree` under `decomp`.
    ///
    /// * `Tree`: rules 1–2 need a non-leaf left child and rules 3–4 a non-leaf right child.
    /// * `Path`: `Rule5` when the left child is a leaf, otherwise `Rule1`.
    ///
    /// A leaf never admits any rule.
    #[inline]
    pub fn applicable_rules(tree: &ExprTree, decomp: DecompositionType) -> &'static [Rule] {
        let (left_leaf, right_leaf) = match tree {
            ExprTree::Node { left, right, .. } => (left.is_leaf(), right.is_leaf()),
            ExprTree::Leaf(_) => return RULES_NONE,
        };
        match (decomp, left_leaf, right_leaf) {
            (DecompositionType::Path, true, _) => RULES_5,
            (DecompositionType::Path, false, _) => RULES_1,
            (DecompositionType::Tree, true, true) => RULES_NONE,
            (DecompositionType::Tree, false, true) => RULES_1_2,
            (DecompositionType::Tree, true, false) => RULES_3_4,
            (DecompositionType::Tree, false, false) => RULES_1_2_3_4,
        }
    }
}
/// Linear-scan membership test for the short label lists used throughout this module.
#[inline(always)]
fn slice_contains(slice: &[usize], elem: usize) -> bool {
    slice.iter().any(|&candidate| candidate == elem)
}
/// Costs of contracting tensors labelled `ix1` and `ix2` into a tensor labelled `iy`.
///
/// `log2_sizes[l]` is the log2 dimension of label `l`. Returns `(tc, sc, rw)`, all in log2 space:
/// * `tc`: log2 size of `iy`, plus the log2 size of every label of `ix1` that also occurs in
///   `ix2` but not in `iy` (the labels summed over);
/// * `sc`: log2 size of the output `iy`;
/// * `rw`: `fast_log2sumexp2_3` of the log2 sizes of `iy`, `ix1` and `ix2`. This is presumably
///   log2 of the total elements read and written. It is `0.0` when `compute_rw` is false.
///
/// # Panics
/// Panics if a label that is read is out of bounds for `log2_sizes`. That covers every label
/// of `iy`, every label of `ix1`/`ix2` when `compute_rw` is set, and any label contributing
/// to `tc`.
#[inline(always)]
pub fn tcscrw(
    ix1: &[usize],
    ix2: &[usize],
    iy: &[usize],
    log2_sizes: &[f64],
    compute_rw: bool,
) -> (f64, f64, f64) {
    // Checked indexing on purpose: this is a safe `pub fn`, so a bad label must panic
    // instead of reading out of bounds through `get_unchecked` (undefined behaviour).
    let log2_size = |labels: &[usize]| -> f64 { labels.iter().map(|&l| log2_sizes[l]).sum() };
    let sc = log2_size(iy);
    // Same left-to-right accumulation order as a plain `tc += ...` loop starting from `sc`.
    let tc = ix1
        .iter()
        .filter(|&&l| slice_contains(ix2, l) && !slice_contains(iy, l))
        .fold(sc, |acc, &l| acc + log2_sizes[l]);
    let rw = if compute_rw {
        fast_log2sumexp2_3(sc, log2_size(ix1), log2_size(ix2))
    } else {
        0.0
    };
    (tc, sc, rw)
}
/// Labels kept after contracting `ix1` with `ix2`, given the network's `final_output`.
///
/// A label survives if it appears in only one of the two inputs, or if it is part of
/// `final_output`. The result lists `ix1`'s survivors first, then `ix2`'s, without duplicates.
#[inline]
pub fn contraction_output(ix1: &[usize], ix2: &[usize], final_output: &[usize]) -> Vec<usize> {
    let mut kept: Vec<usize> = Vec::with_capacity(ix1.len() + ix2.len());
    // Pair every label with the *other* operand it must be checked against.
    let candidates = ix1
        .iter()
        .map(|&label| (label, ix2))
        .chain(ix2.iter().map(|&label| (label, ix1)));
    for (label, other) in candidates {
        let survives = !slice_contains(other, label) || slice_contains(final_output, label);
        if survives && !slice_contains(&kept, label) {
            kept.push(label);
        }
    }
    kept
}
/// Output labels of a new intermediate node contracting `a` with `c`, whose result is later
/// contracted with `b` to produce `d`.
///
/// A label is kept if it occurs in `a` or `c` and is still needed, i.e. occurs in `b` or `d`.
/// Labels of `a` come first, in order and with duplicates preserved. They are followed by the
/// labels of `c` that are not in `a`.
#[inline(always)]
pub fn compute_intermediate_output(
    a: &[usize],
    c: &[usize],
    b: &[usize],
    d: &[usize],
) -> Vec<usize> {
    let still_needed = |label: usize| slice_contains(b, label) || slice_contains(d, label);
    let from_a = a.iter().copied().filter(|&label| still_needed(label));
    let from_c = c
        .iter()
        .copied()
        .filter(|&label| !slice_contains(a, label) && still_needed(label));
    let mut labels = Vec::with_capacity(a.len() + c.len());
    labels.extend(from_a.chain(from_c));
    labels
}
pub fn tree_complexity(tree: &ExprTree, log2_sizes: &[f64]) -> (f64, f64, f64) {
if let Some(cached) = tree.info().cached {
return (cached.tc, cached.sc, cached.rw);
}
match tree {
ExprTree::Leaf(info) => {
let sc: f64 = info.out_dims.iter().map(|&l| log2_sizes[l]).sum();
(f64::NEG_INFINITY, sc, f64::NEG_INFINITY)
}
ExprTree::Node { left, right, info } => {
let (tcl, scl, rwl) = tree_complexity(left, log2_sizes);
let (tcr, scr, rwr) = tree_complexity(right, log2_sizes);
let (tc, sc, rw) = tcscrw(
left.labels(),
right.labels(),
&info.out_dims,
log2_sizes,
true,
);
(
fast_log2sumexp2_3(tc, tcl, tcr),
sc.max(scl).max(scr),
fast_log2sumexp2_3(rw, rwl, rwr),
)
}
}
}
/// Like `tree_complexity`, but memoises every node's subtree cost in `info.cached`.
///
/// Existing cache entries are trusted and returned as-is.
pub fn tree_complexity_cached(tree: &mut ExprTree, log2_sizes: &[f64]) -> CachedComplexity {
    if let Some(hit) = tree.info().cached {
        return hit;
    }
    let fresh = match tree {
        ExprTree::Leaf(info) => CachedComplexity {
            tc: f64::NEG_INFINITY,
            sc: info.out_dims.iter().map(|&l| log2_sizes[l]).sum(),
            rw: f64::NEG_INFINITY,
        },
        ExprTree::Node { left, right, info } => {
            let from_left = tree_complexity_cached(left, log2_sizes);
            let from_right = tree_complexity_cached(right, log2_sizes);
            let (tc, sc, rw) = tcscrw(
                left.labels(),
                right.labels(),
                &info.out_dims,
                log2_sizes,
                true,
            );
            CachedComplexity {
                tc: fast_log2sumexp2_3(tc, from_left.tc, from_right.tc),
                sc: sc.max(from_left.sc).max(from_right.sc),
                rw: fast_log2sumexp2_3(rw, from_left.rw, from_right.rw),
            }
        }
    };
    tree.info_mut().cached = Some(fresh);
    fresh
}
/// Largest log2 tensor size anywhere in `tree`, leaves included; ignores any cached values.
#[inline]
pub fn tree_sc_only(tree: &ExprTree, log2_sizes: &[f64]) -> f64 {
    let own: f64 = tree.labels().iter().map(|&l| log2_sizes[l]).sum();
    match tree {
        ExprTree::Leaf(_) => own,
        ExprTree::Node { left, right, .. } => own
            .max(tree_sc_only(left, log2_sizes))
            .max(tree_sc_only(right, log2_sizes)),
    }
}
/// Outcome of scoring a rewrite with `rule_diff` without applying it.
///
/// `*0` fields describe the two contractions touched by the rule before the rewrite, and
/// `*1` fields describe them after it. For `Rule5` (a pure child swap) `tc0 == tc1`,
/// `rw0 == rw1` and `dsc == 0.0`.
#[derive(Debug, Clone)]
pub struct RuleDiff {
    /// log2-sum of the time complexity of the affected contractions, before.
    pub tc0: f64,
    /// log2-sum of the time complexity of the affected contractions, after.
    pub tc1: f64,
    /// Change (after minus before) of the larger log2 output size of the affected contractions.
    pub dsc: f64,
    /// log2-sum read-write cost before (`0.0` when not requested).
    pub rw0: f64,
    /// log2-sum read-write cost after (`0.0` when not requested).
    pub rw1: f64,
    /// Output labels for the newly formed intermediate node; pass them to `apply_rule`.
    pub new_labels: Vec<usize>,
}
pub fn rule_diff(
tree: &ExprTree,
rule: Rule,
log2_sizes: &[f64],
compute_rw: bool,
) -> Option<RuleDiff> {
match tree {
ExprTree::Leaf(_) => None,
ExprTree::Node { left, right, info } => {
let d = &info.out_dims;
match rule {
Rule::Rule1 | Rule::Rule2 => {
match left.as_ref() {
ExprTree::Node {
left: a,
right: b,
info: ab_info,
} => {
let c = right;
let ab = &ab_info.out_dims;
let (tc_ab, sc_ab, rw_ab) =
tcscrw(a.labels(), b.labels(), ab, log2_sizes, compute_rw);
let (tc_d, sc_d, rw_d) =
tcscrw(ab, c.labels(), d, log2_sizes, compute_rw);
let tc0 = fast_log2sumexp2(tc_ab, tc_d);
let sc0 = sc_ab.max(sc_d);
let rw0 = if compute_rw {
fast_log2sumexp2(rw_ab, rw_d)
} else {
0.0
};
let new_labels = match rule {
Rule::Rule1 => {
compute_intermediate_output(
a.labels(),
c.labels(),
b.labels(),
d,
)
}
Rule::Rule2 => {
compute_intermediate_output(
b.labels(),
c.labels(),
a.labels(),
d,
)
}
_ => unreachable!(),
};
let (tc_new_left, sc_new_left, rw_new_left) = match rule {
Rule::Rule1 => tcscrw(
a.labels(),
c.labels(),
&new_labels,
log2_sizes,
compute_rw,
),
Rule::Rule2 => tcscrw(
c.labels(),
b.labels(),
&new_labels,
log2_sizes,
compute_rw,
),
_ => unreachable!(),
};
let (tc_new_d, sc_new_d, rw_new_d) = match rule {
Rule::Rule1 => {
tcscrw(&new_labels, b.labels(), d, log2_sizes, compute_rw)
}
Rule::Rule2 => {
tcscrw(&new_labels, a.labels(), d, log2_sizes, compute_rw)
}
_ => unreachable!(),
};
let tc1 = fast_log2sumexp2(tc_new_left, tc_new_d);
let sc1 = sc_new_left.max(sc_new_d);
let rw1 = if compute_rw {
fast_log2sumexp2(rw_new_left, rw_new_d)
} else {
0.0
};
Some(RuleDiff {
tc0,
tc1,
dsc: sc1 - sc0,
rw0,
rw1,
new_labels,
})
}
_ => None,
}
}
Rule::Rule3 | Rule::Rule4 => {
match right.as_ref() {
ExprTree::Node {
left: b,
right: c,
info: bc_info,
} => {
let a = left;
let bc = &bc_info.out_dims;
let (tc_bc, sc_bc, rw_bc) =
tcscrw(b.labels(), c.labels(), bc, log2_sizes, compute_rw);
let (tc_d, sc_d, rw_d) =
tcscrw(a.labels(), bc, d, log2_sizes, compute_rw);
let tc0 = fast_log2sumexp2(tc_bc, tc_d);
let sc0 = sc_bc.max(sc_d);
let rw0 = if compute_rw {
fast_log2sumexp2(rw_bc, rw_d)
} else {
0.0
};
let new_labels = match rule {
Rule::Rule3 => {
compute_intermediate_output(
c.labels(),
a.labels(),
b.labels(),
d,
)
}
Rule::Rule4 => {
compute_intermediate_output(
b.labels(),
a.labels(),
c.labels(),
d,
)
}
_ => unreachable!(),
};
let (tc_new_right, sc_new_right, rw_new_right) = match rule {
Rule::Rule3 => tcscrw(
a.labels(),
c.labels(),
&new_labels,
log2_sizes,
compute_rw,
),
Rule::Rule4 => tcscrw(
b.labels(),
a.labels(),
&new_labels,
log2_sizes,
compute_rw,
),
_ => unreachable!(),
};
let (tc_new_d, sc_new_d, rw_new_d) = match rule {
Rule::Rule3 => {
tcscrw(b.labels(), &new_labels, d, log2_sizes, compute_rw)
}
Rule::Rule4 => {
tcscrw(c.labels(), &new_labels, d, log2_sizes, compute_rw)
}
_ => unreachable!(),
};
let tc1 = fast_log2sumexp2(tc_new_right, tc_new_d);
let sc1 = sc_new_right.max(sc_new_d);
let rw1 = if compute_rw {
fast_log2sumexp2(rw_new_right, rw_new_d)
} else {
0.0
};
Some(RuleDiff {
tc0,
tc1,
dsc: sc1 - sc0,
rw0,
rw1,
new_labels,
})
}
_ => None,
}
}
Rule::Rule5 => {
Some(RuleDiff {
tc0: 0.0,
tc1: 0.0,
dsc: 0.0,
rw0: 0.0,
rw1: 0.0,
new_labels: info.out_dims.clone(),
})
}
}
}
}
}
/// Applies `rule` to the root of `tree` in place.
///
/// For rules 1–4, `new_labels` (normally `RuleDiff::new_labels`) becomes the output labels
/// of the rebuilt intermediate node. For `Rule5`, unlike `apply_rule`, `new_labels` is
/// ignored and only the children are swapped. Leaves, and trees lacking the shape the rule
/// needs, are left untouched.
///
/// A restructure clears the memoised complexity (`ExprInfo::cached`) of the rebuilt child
/// and of the root, because both subtrees changed. Untouched grandchildren keep their still
/// valid caches. Caches of ancestors of `tree` are not reachable here and must be
/// invalidated by the caller.
pub fn apply_rule_mut(tree: &mut ExprTree, rule: Rule, new_labels: Vec<usize>) {
    if let ExprTree::Node { left, right, info } = tree {
        match rule {
            Rule::Rule1 => {
                // ((a, b), c) -> ((a, c), b): exchange b and c.
                if let ExprTree::Node {
                    right: b,
                    info: left_info,
                    ..
                } = left.as_mut()
                {
                    std::mem::swap(b, right);
                    left_info.out_dims = new_labels;
                    left_info.cached = None;
                    info.cached = None;
                }
            }
            Rule::Rule2 => {
                // ((a, b), c) -> ((c, b), a): exchange a and c.
                if let ExprTree::Node {
                    left: a,
                    info: left_info,
                    ..
                } = left.as_mut()
                {
                    std::mem::swap(a, right);
                    left_info.out_dims = new_labels;
                    left_info.cached = None;
                    info.cached = None;
                }
            }
            Rule::Rule3 => {
                // (a, (b, c)) -> (b, (a, c)): exchange a and b.
                if let ExprTree::Node {
                    left: b,
                    info: right_info,
                    ..
                } = right.as_mut()
                {
                    std::mem::swap(left, b);
                    right_info.out_dims = new_labels;
                    right_info.cached = None;
                    info.cached = None;
                }
            }
            Rule::Rule4 => {
                // (a, (b, c)) -> (c, (b, a)): exchange a and c.
                if let ExprTree::Node {
                    right: c,
                    info: right_info,
                    ..
                } = right.as_mut()
                {
                    std::mem::swap(left, c);
                    right_info.out_dims = new_labels;
                    right_info.cached = None;
                    info.cached = None;
                }
            }
            Rule::Rule5 => {
                // (x, y) -> (y, x): the operands are symmetric, so labels and cost are
                // unchanged and the root's cache remains valid.
                std::mem::swap(left, right);
            }
        }
    }
}
/// Applies `rule` to the root of `tree` and returns the rewritten tree.
///
/// For rules 1–4, `new_labels` (normally `RuleDiff::new_labels`) becomes the output labels
/// of the newly formed intermediate node. For `Rule5` it replaces the root's own labels.
/// Leaves, and trees lacking the shape the rule needs, are returned unchanged.
///
/// The root's memoised complexity (`ExprInfo::cached`) is cleared whenever its subtree
/// changes, so a later `tree_complexity` cannot return a stale value. The new intermediate
/// node starts uncached, and the moved grandchildren keep their still valid caches. Caches of
/// ancestors are the caller's responsibility.
pub fn apply_rule(tree: ExprTree, rule: Rule, new_labels: Vec<usize>) -> ExprTree {
    match tree {
        ExprTree::Leaf(_) => tree,
        ExprTree::Node {
            left,
            right,
            mut info,
        } => match rule {
            Rule::Rule1 => match *left {
                // ((a, b), c) -> ((a, c), b)
                ExprTree::Node {
                    left: a, right: b, ..
                } => {
                    info.cached = None;
                    let new_left = ExprTree::Node {
                        left: a,
                        right,
                        info: ExprInfo::internal(new_labels),
                    };
                    ExprTree::Node {
                        left: Box::new(new_left),
                        right: b,
                        info,
                    }
                }
                _ => ExprTree::Node { left, right, info },
            },
            Rule::Rule2 => match *left {
                // ((a, b), c) -> ((c, b), a)
                ExprTree::Node {
                    left: a, right: b, ..
                } => {
                    info.cached = None;
                    let new_left = ExprTree::Node {
                        left: right,
                        right: b,
                        info: ExprInfo::internal(new_labels),
                    };
                    ExprTree::Node {
                        left: Box::new(new_left),
                        right: a,
                        info,
                    }
                }
                _ => ExprTree::Node { left, right, info },
            },
            Rule::Rule3 => match *right {
                // (a, (b, c)) -> (b, (a, c))
                ExprTree::Node {
                    left: b, right: c, ..
                } => {
                    info.cached = None;
                    let new_right = ExprTree::Node {
                        left,
                        right: c,
                        info: ExprInfo::internal(new_labels),
                    };
                    ExprTree::Node {
                        left: b,
                        right: Box::new(new_right),
                        info,
                    }
                }
                _ => ExprTree::Node { left, right, info },
            },
            Rule::Rule4 => match *right {
                // (a, (b, c)) -> (c, (b, a))
                ExprTree::Node {
                    left: b, right: c, ..
                } => {
                    info.cached = None;
                    let new_right = ExprTree::Node {
                        left: b,
                        right: left,
                        info: ExprInfo::internal(new_labels),
                    };
                    ExprTree::Node {
                        left: c,
                        right: Box::new(new_right),
                        info,
                    }
                }
                _ => ExprTree::Node { left, right, info },
            },
            Rule::Rule5 => {
                // (x, y) -> (y, x): swapping operands keeps the cost, so the cache only goes
                // stale if the labels themselves change.
                if info.out_dims != new_labels {
                    info.cached = None;
                }
                info.out_dims = new_labels;
                ExprTree::Node {
                    left: right,
                    right: left,
                    info,
                }
            }
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const SIZES3: [f64; 3] = [2.0, 3.0, 2.0];
    const SIZES4: [f64; 4] = [2.0, 3.0, 3.0, 2.0];
    const SIZES5: [f64; 5] = [2.0, 3.0, 3.0, 3.0, 2.0];

    /// Leaf `i` of a chain network: labels `[i, i + 1]`, tensor id `i`.
    fn chain_leaf(i: usize) -> ExprTree {
        ExprTree::leaf(vec![i, i + 1], i)
    }

    /// `(t0, t1)` with output `[0, 2]`.
    fn pair_tree() -> ExprTree {
        ExprTree::node(chain_leaf(0), chain_leaf(1), vec![0, 2])
    }

    /// `((t0, t1), t2)` with output `[0, 3]`.
    fn simple_tree() -> ExprTree {
        ExprTree::node(pair_tree(), chain_leaf(2), vec![0, 3])
    }

    /// `(t0, (t1, t2))` with output `[0, 3]`.
    fn right_nested_tree() -> ExprTree {
        let inner = ExprTree::node(chain_leaf(1), chain_leaf(2), vec![1, 3]);
        ExprTree::node(chain_leaf(0), inner, vec![0, 3])
    }

    /// `((t0, t1), (t2, t3))` with output `[0, 4]`.
    fn balanced_tree() -> ExprTree {
        let right = ExprTree::node(chain_leaf(2), chain_leaf(3), vec![2, 4]);
        ExprTree::node(pair_tree(), right, vec![0, 4])
    }

    fn approx_eq(x: f64, y: f64) -> bool {
        (x - y).abs() < 1e-10
    }

    /// Scores `rule` on `tree` (without read-write cost) and applies it with the proposed labels.
    fn rewrite(tree: ExprTree, rule: Rule, log2_sizes: &[f64]) -> ExprTree {
        let diff = rule_diff(&tree, rule, log2_sizes, false).unwrap();
        apply_rule(tree, rule, diff.new_labels)
    }

    #[test]
    fn test_expr_tree_leaf() {
        let leaf = chain_leaf(0);
        assert!(leaf.is_leaf());
        assert_eq!(leaf.tensor_id(), Some(0));
        assert_eq!(leaf.labels(), &[0, 1]);
    }

    #[test]
    fn test_expr_tree_node() {
        let tree = simple_tree();
        assert!(!tree.is_leaf());
        assert_eq!(tree.leaf_count(), 3);
        assert_eq!(tree.leaf_ids(), vec![0, 1, 2]);
    }

    #[test]
    fn test_applicable_rules_tree_decomp() {
        let rules = Rule::applicable_rules(&simple_tree(), DecompositionType::Tree);
        assert!(rules.contains(&Rule::Rule1));
        assert!(rules.contains(&Rule::Rule2));
        assert!(!rules.contains(&Rule::Rule3));
    }

    #[test]
    fn test_tcscrw() {
        let (tc, sc, _) = tcscrw(&[0, 1], &[1, 2], &[0, 2], &SIZES4, true);
        assert!(approx_eq(tc, 8.0));
        assert!(approx_eq(sc, 5.0));
    }

    #[test]
    fn test_contraction_output() {
        let output = contraction_output(&[0, 1], &[1, 2], &[0, 3]);
        assert!(output.contains(&0));
        assert!(output.contains(&2));
        assert!(!output.contains(&1));
    }

    #[test]
    fn test_tree_complexity() {
        let (tc, sc, _) = tree_complexity(&pair_tree(), &SIZES3);
        assert!(approx_eq(tc, 7.0));
        assert!(approx_eq(sc, 5.0));
    }

    #[test]
    fn test_rule_diff_rule1() {
        let diff = rule_diff(&simple_tree(), Rule::Rule1, &SIZES4, true)
            .expect("Rule1 applies to ((a, b), c)");
        assert!(diff.tc0 > 0.0);
        assert!(diff.tc1 > 0.0);
    }

    #[test]
    fn test_rule_diff_rule2() {
        let diff = rule_diff(&simple_tree(), Rule::Rule2, &SIZES4, false)
            .expect("Rule2 applies to ((a, b), c)");
        assert!(diff.tc0 > 0.0);
        assert!(diff.tc1 > 0.0);
        assert_eq!(diff.rw0, 0.0);
        assert_eq!(diff.rw1, 0.0);
    }

    #[test]
    fn test_rule_diff_rule3_rule4() {
        let tree = right_nested_tree();
        for rule in [Rule::Rule3, Rule::Rule4] {
            assert!(rule_diff(&tree, rule, &SIZES4, true).is_some());
        }
    }

    #[test]
    fn test_rule_diff_rule5() {
        assert!(rule_diff(&balanced_tree(), Rule::Rule5, &SIZES5, true).is_some());
    }

    #[test]
    fn test_apply_rule1() {
        let new_tree = rewrite(simple_tree(), Rule::Rule1, &SIZES4);
        assert_eq!(new_tree.leaf_count(), 3);
        assert!(!new_tree.is_leaf());
    }

    #[test]
    fn test_apply_rule3() {
        assert_eq!(rewrite(right_nested_tree(), Rule::Rule3, &SIZES4).leaf_count(), 3);
    }

    #[test]
    fn test_apply_rule5() {
        assert_eq!(rewrite(balanced_tree(), Rule::Rule5, &SIZES5).leaf_count(), 4);
    }

    #[test]
    fn test_applicable_rules_path_decomp() {
        let rules = Rule::applicable_rules(&simple_tree(), DecompositionType::Path);
        assert!(rules.contains(&Rule::Rule1));
        assert!(!rules.contains(&Rule::Rule5));
    }

    #[test]
    fn test_applicable_rules_path_decomp_left_leaf() {
        let rules = Rule::applicable_rules(&right_nested_tree(), DecompositionType::Path);
        assert!(rules.contains(&Rule::Rule5));
        assert!(!rules.contains(&Rule::Rule1));
    }

    #[test]
    fn test_rule_diff_on_leaf() {
        let sizes = [2.0, 3.0];
        assert!(rule_diff(&chain_leaf(0), Rule::Rule1, &sizes, false).is_none());
    }

    #[test]
    fn test_tree_complexity_single_leaf() {
        let sizes = [2.0, 3.0];
        let (tc, sc, rw) = tree_complexity(&chain_leaf(0), &sizes);
        assert!(tc < 1e-10 || tc == f64::NEG_INFINITY);
        assert!(approx_eq(sc, 5.0));
        assert!(rw < 1e-10 || rw == f64::NEG_INFINITY);
    }

    #[test]
    fn test_applicable_rules_both_children_nodes() {
        let rules = Rule::applicable_rules(&balanced_tree(), DecompositionType::Tree);
        for expected in [Rule::Rule1, Rule::Rule2, Rule::Rule3, Rule::Rule4] {
            assert!(rules.contains(&expected));
        }
        assert!(!rules.contains(&Rule::Rule5));
    }

    #[test]
    fn test_contraction_output_all_contracted() {
        assert!(contraction_output(&[0, 1], &[0, 1], &[]).is_empty());
    }

    #[test]
    fn test_tcscrw_no_rw() {
        let (tc, sc, rw) = tcscrw(&[0, 1], &[1, 2], &[0, 2], &SIZES3, false);
        assert!(tc > 0.0);
        assert!(sc > 0.0);
        assert_eq!(rw, 0.0);
    }

    #[test]
    fn test_expr_info_internal() {
        let info = ExprInfo::internal(vec![0, 1, 2]);
        assert_eq!(info.out_dims, vec![0, 1, 2]);
        assert!(info.tensor_id.is_none());
        assert!(info.cached.is_none());
    }

    #[test]
    fn test_expr_info_leaf() {
        let info = ExprInfo::leaf(vec![0, 1], 42);
        assert_eq!(info.out_dims, vec![0, 1]);
        assert_eq!(info.tensor_id, Some(42));
        assert!(info.cached.is_none());
    }

    #[test]
    fn test_expr_tree_info() {
        let leaf = chain_leaf(0);
        let info = leaf.info();
        assert_eq!(info.out_dims, vec![0, 1]);
        assert_eq!(info.tensor_id, Some(0));
    }

    #[test]
    fn test_expr_tree_info_mut() {
        let mut leaf = chain_leaf(0);
        leaf.info_mut().out_dims = vec![2, 3];
        assert_eq!(leaf.labels(), &[2, 3]);
    }

    #[test]
    fn test_tree_complexity_cached() {
        let mut tree = pair_tree();
        let first = tree_complexity_cached(&mut tree, &SIZES3);
        assert!(first.tc > 0.0);
        assert!(first.sc > 0.0);
        let second = tree_complexity_cached(&mut tree, &SIZES3);
        assert!(approx_eq(first.tc, second.tc));
        assert!(approx_eq(first.sc, second.sc));
    }

    #[test]
    fn test_tree_complexity_cached_node() {
        let mut tree = pair_tree();
        let cached = tree_complexity_cached(&mut tree, &SIZES3);
        assert!(cached.tc > 0.0);
        assert!(cached.sc > 0.0);
    }

    #[test]
    fn test_tree_sc_only() {
        let tree = pair_tree();
        let (_, full_sc, _) = tree_complexity(&tree, &SIZES3);
        assert!(approx_eq(tree_sc_only(&tree, &SIZES3), full_sc));
    }

    #[test]
    fn test_tree_sc_only_leaf() {
        let sizes = [2.0, 3.0];
        assert!(approx_eq(tree_sc_only(&chain_leaf(0), &sizes), 5.0));
    }

    #[test]
    fn test_apply_rule_on_leaf() {
        let leaf = chain_leaf(0);
        assert!(apply_rule(leaf.clone(), Rule::Rule1, vec![]).is_leaf());
    }

    #[test]
    fn test_apply_rule2() {
        assert_eq!(rewrite(simple_tree(), Rule::Rule2, &SIZES4).leaf_count(), 3);
    }

    #[test]
    fn test_apply_rule4() {
        assert_eq!(rewrite(right_nested_tree(), Rule::Rule4, &SIZES4).leaf_count(), 3);
    }

    #[test]
    fn test_apply_rule_wrong_structure() {
        assert_eq!(apply_rule(pair_tree(), Rule::Rule1, vec![]).leaf_count(), 2);
    }

    #[test]
    fn test_apply_rule3_wrong_structure() {
        assert_eq!(apply_rule(pair_tree(), Rule::Rule3, vec![]).leaf_count(), 2);
    }

    #[test]
    fn test_rule_diff_wrong_structure_rule1() {
        assert!(rule_diff(&pair_tree(), Rule::Rule1, &SIZES3, false).is_none());
    }

    #[test]
    fn test_rule_diff_wrong_structure_rule3() {
        assert!(rule_diff(&simple_tree(), Rule::Rule3, &SIZES4, false).is_none());
    }

    #[test]
    fn test_applicable_rules_both_leaves() {
        let tree = pair_tree();
        assert!(Rule::applicable_rules(&tree, DecompositionType::Tree).is_empty());
        assert!(Rule::applicable_rules(&tree, DecompositionType::Path).contains(&Rule::Rule5));
    }

    #[test]
    fn test_applicable_rules_right_is_node() {
        let rules = Rule::applicable_rules(&right_nested_tree(), DecompositionType::Tree);
        assert!(rules.contains(&Rule::Rule3));
        assert!(rules.contains(&Rule::Rule4));
        assert!(!rules.contains(&Rule::Rule1));
        assert!(!rules.contains(&Rule::Rule2));
    }

    #[test]
    fn test_tree_complexity_with_cached() {
        let mut tree = pair_tree();
        tree_complexity_cached(&mut tree, &SIZES3);
        let (tc, sc, rw) = tree_complexity(&tree, &SIZES3);
        assert!(tc > 0.0);
        assert!(sc > 0.0);
        assert!(rw > 0.0 || rw == f64::NEG_INFINITY);
    }

    #[test]
    fn test_leaf_ids_deep_tree() {
        let leaf = |i: usize| ExprTree::leaf(vec![i], i);
        let tree = ExprTree::node(
            ExprTree::node(leaf(0), leaf(1), vec![0, 1]),
            ExprTree::node(leaf(2), leaf(3), vec![2, 3]),
            vec![0, 1, 2, 3],
        );
        assert_eq!(tree.leaf_ids(), vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_contraction_output_duplicates() {
        let output = contraction_output(&[0, 1, 2], &[1, 2, 3], &[0, 3]);
        let mut unique = output.clone();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(unique.len(), output.len());
    }

    #[test]
    fn test_compute_intermediate_output() {
        let (a, c, b, d) = (vec![0, 1, 2], vec![3, 4], vec![1, 3], vec![0, 4]);
        let output = compute_intermediate_output(&a, &c, &b, &d);
        for kept in [0, 1, 3, 4] {
            assert!(output.contains(&kept));
        }
        assert!(!output.contains(&2));
        assert_eq!(output.len(), 4);
    }

    #[test]
    fn test_compute_intermediate_output_vs_contraction_output() {
        let (a, c, b, d) = (vec![0, 1, 2], vec![3, 4], vec![1, 3], vec![0, 4]);
        assert!(contraction_output(&a, &c, &d).contains(&2));
        assert!(!compute_intermediate_output(&a, &c, &b, &d).contains(&2));
    }

    #[test]
    fn test_bitset_basic_operations() {
        let mut bs = BitSet::new(100);
        let probes = [0, 50, 99];
        for i in probes {
            assert!(!bs.contains(i));
        }
        for i in probes {
            bs.insert(i);
        }
        for i in probes {
            assert!(bs.contains(i));
        }
        assert!(!bs.contains(1));
        assert!(!bs.contains(51));
        bs.clear();
        for i in probes {
            assert!(!bs.contains(i));
        }
    }

    #[test]
    fn test_bitset_set_from_slice() {
        let mut bs = BitSet::new(100);
        bs.set_from_slice(&[10, 20, 30, 40]);
        for present in [10, 20, 30, 40] {
            assert!(bs.contains(present));
        }
        for absent in [0, 15, 50] {
            assert!(!bs.contains(absent));
        }
        bs.set_from_slice(&[5, 15]);
        assert!(bs.contains(5));
        assert!(bs.contains(15));
        assert!(!bs.contains(10));
    }

    #[test]
    fn test_bitset_boundary_conditions() {
        let mut one_word = BitSet::new(64);
        one_word.insert(0);
        one_word.insert(63);
        assert!(one_word.contains(0));
        assert!(one_word.contains(63));
        let mut two_words = BitSet::new(128);
        two_words.insert(64);
        two_words.insert(127);
        assert!(two_words.contains(64));
        assert!(two_words.contains(127));
        assert!(!two_words.contains(63));
    }

    #[test]
    fn test_bitset_out_of_bounds() {
        let mut bs = BitSet::new(50);
        bs.insert(100);
        assert!(!bs.contains(100));
    }

    #[test]
    fn test_scratch_space_compute_intermediate_output() {
        let mut scratch = ScratchSpace::new(10);
        let (a, c, b, d) = (vec![0, 1, 2], vec![3, 4], vec![1, 3], vec![0, 4]);
        let output = scratch.compute_intermediate_output(&a, &c, &b, &d);
        for kept in [0, 1, 3, 4] {
            assert!(output.contains(&kept));
        }
        assert!(!output.contains(&2));
    }

    #[test]
    fn test_scratch_space_tcscrw() {
        let mut scratch = ScratchSpace::new(5);
        let sizes = [2.0, 3.0, 3.0, 2.0, 2.0];
        let (tc, sc, rw) = scratch.tcscrw(&[0, 1], &[1, 2], &[0, 2], &sizes, true);
        assert!(approx_eq(tc, 8.0));
        assert!(approx_eq(sc, 5.0));
        assert!(rw > 0.0);
    }

    #[test]
    fn test_scratch_space_tcscrw_no_rw() {
        let mut scratch = ScratchSpace::new(5);
        let sizes = [2.0, 3.0, 3.0, 2.0, 2.0];
        let (tc, sc, rw) = scratch.tcscrw(&[0, 1], &[1, 2], &[0, 2], &sizes, false);
        assert!(approx_eq(tc, 8.0));
        assert!(approx_eq(sc, 5.0));
        assert_eq!(rw, 0.0);
    }

    #[test]
    fn test_scratch_space_rule_diff() {
        let mut scratch = ScratchSpace::new(5);
        let diff = scratch
            .rule_diff(&simple_tree(), Rule::Rule1, &SIZES4, true)
            .expect("Rule1 applies to ((a, b), c)");
        assert!(diff.tc0 > 0.0);
        assert!(diff.tc1 > 0.0);
    }

    #[test]
    fn test_scratch_space_rule_diff_leaf() {
        let mut scratch = ScratchSpace::new(5);
        let sizes = [2.0, 3.0];
        assert!(scratch
            .rule_diff(&chain_leaf(0), Rule::Rule1, &sizes, true)
            .is_none());
    }

    #[test]
    fn test_scratch_space_all_rules() {
        let mut scratch = ScratchSpace::new(10);
        let cases = [
            (simple_tree(), Rule::Rule1),
            (simple_tree(), Rule::Rule2),
            (right_nested_tree(), Rule::Rule3),
            (right_nested_tree(), Rule::Rule4),
            (balanced_tree(), Rule::Rule5),
        ];
        for (tree, rule) in &cases {
            assert!(scratch.rule_diff(tree, *rule, &SIZES5, true).is_some());
        }
    }
}