use crate::eincode::{log2_size_dict, EinCode, NestedEinsum};
use crate::incidence_list::{ContractionDims, IncidenceList};
use crate::Label;
use priority_queue::PriorityQueue;
use rand::prelude::*;
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
#[derive(Debug, Clone)]
pub enum ContractionTree {
Leaf(usize),
Node {
left: Box<ContractionTree>,
right: Box<ContractionTree>,
},
}
impl ContractionTree {
pub fn leaf(idx: usize) -> Self {
Self::Leaf(idx)
}
pub fn node(left: ContractionTree, right: ContractionTree) -> Self {
Self::Node {
left: Box::new(left),
right: Box::new(right),
}
}
fn fmt_with_indent(&self, f: &mut std::fmt::Formatter<'_>, indent: usize) -> std::fmt::Result {
let prefix = " ".repeat(indent);
match self {
ContractionTree::Leaf(idx) => writeln!(f, "{}Leaf({})", prefix, idx),
ContractionTree::Node { left, right } => {
writeln!(f, "{}Node {{", prefix)?;
write!(f, "{} left: ", prefix)?;
left.fmt_with_indent(f, indent + 1)?;
write!(f, "{} right: ", prefix)?;
right.fmt_with_indent(f, indent + 1)?;
writeln!(f, "{}}}", prefix)
}
}
}
}
impl std::fmt::Display for ContractionTree {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.fmt_with_indent(f, 0)
}
}
/// Configuration for the greedy contraction-order optimizer.
#[derive(Debug, Clone)]
pub struct GreedyMethod {
    /// Weight of the two input sizes in the greedy loss; `0.0` scores only
    /// the size of the contraction output.
    pub alpha: f64,
    /// Sampling temperature for pair selection; `0.0` is fully deterministic.
    pub temperature: f64,
}

impl Default for GreedyMethod {
    /// Deterministic greedy: `alpha = 0.0`, `temperature = 0.0`.
    fn default() -> Self {
        Self::new(0.0, 0.0)
    }
}

impl GreedyMethod {
    /// Builds a configuration from explicit `alpha` and `temperature`.
    pub fn new(alpha: f64, temperature: f64) -> Self {
        Self { alpha, temperature }
    }

    /// Stochastic variant: `alpha = 0.0` with the given `temperature`.
    pub fn stochastic(temperature: f64) -> Self {
        Self::new(0.0, temperature)
    }
}
/// Priority wrapper around a greedy loss value.
///
/// The ordering is intentionally REVERSED (`Cost(1.0) > Cost(2.0)`) so that a
/// max-priority queue pops the pair with the LOWEST loss first.
#[derive(Debug, Clone, Copy)]
struct Cost(f64);

impl PartialEq for Cost {
    fn eq(&self, other: &Self) -> bool {
        // Delegate to the total order below so that equality is consistent
        // with `cmp` and reflexive even for NaN payloads. The previous
        // `self.0 == other.0` violated the `Eq` contract for NaN
        // (`Cost(NAN) != Cost(NAN)`), while `cmp` reported them Equal.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Cost {}

impl PartialOrd for Cost {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Cost {
    fn cmp(&self, other: &Self) -> Ordering {
        // `f64::total_cmp` is a genuine total order (IEEE 754 totalOrder), so
        // NaN no longer collapses to "Equal with everything" as it did with
        // `partial_cmp(..).unwrap_or(Ordering::Equal)` — that fallback could
        // corrupt priority-queue ordering. Arguments are swapped on purpose
        // to keep the reversed, min-loss-pops-first ordering.
        other.0.total_cmp(&self.0)
    }
}
/// Greedy score for contracting a vertex pair: the (linear-scale) size of the
/// contraction output minus `alpha` times the combined sizes of the two
/// inputs. Lower is better; `alpha > 0` rewards contractions that consume
/// large intermediates.
fn greedy_loss(dims: &ContractionDims<impl Clone + Eq + std::hash::Hash>, alpha: f64) -> f64 {
    // d01/d02: legs exclusive to one input, d12: contracted legs,
    // d012: legs shared by both inputs and the output (hyperedges).
    let out_size = f64::exp2(dims.d01 + dims.d02 + dims.d012);
    let lhs_size = f64::exp2(dims.d01 + dims.d12 + dims.d012);
    let rhs_size = f64::exp2(dims.d02 + dims.d12 + dims.d012);
    out_size - alpha * (lhs_size + rhs_size)
}
/// Outcome of a greedy contraction-order search.
#[derive(Debug, Clone)]
pub struct GreedyResult<E>
where
    E: Clone + Eq + std::hash::Hash,
{
    /// The contraction order found by the search.
    pub tree: ContractionTree,
    /// log2 time complexity of each contraction step, in execution order.
    pub log2_tcs: Vec<f64>,
    /// log2 space complexity of each contraction step, in execution order.
    pub log2_scs: Vec<f64>,
    /// Edge labels left open on the final contracted tensor.
    pub output_edges: Vec<E>,
    // The untouched input incidence list; kept so the tree can later be
    // converted to a NestedEinsum against the original hypergraph.
    incidence_list: IncidenceList<usize, E>,
}

impl<E> GreedyResult<E>
where
    E: Clone + Eq + std::hash::Hash,
{
    /// The original (pre-contraction) incidence list this result was computed from.
    pub fn incidence_list(&self) -> &IncidenceList<usize, E> {
        &self.incidence_list
    }
}
pub fn tree_greedy<E: Label>(
il: &IncidenceList<usize, E>,
log2_sizes: &HashMap<E, f64>,
alpha: f64,
temperature: f64,
) -> Option<GreedyResult<E>> {
let original_il = il.clone();
let mut il = il.clone();
let n = il.nv();
if n == 0 {
return None;
}
if n == 1 {
let v = *il.vertices().next()?;
return Some(GreedyResult {
tree: ContractionTree::leaf(v),
log2_tcs: Vec::new(),
log2_scs: Vec::new(),
output_edges: il.edges(&v).cloned().unwrap_or_default(),
incidence_list: original_il,
});
}
let mut rng = rand::rng();
let mut log2_tcs = Vec::new();
let mut log2_scs = Vec::new();
let mut trees: HashMap<usize, ContractionTree> = il
.vertices()
.map(|&v| (v, ContractionTree::leaf(v)))
.collect();
let mut pq = PriorityQueue::new();
let mut cost_graph: HashSet<(usize, usize)> = HashSet::new();
let vertices: Vec<usize> = il.vertices().cloned().collect();
for &vi in &vertices {
for vj in il.neighbors(&vi) {
if vj > vi {
let pair = (vi, vj);
let dims = ContractionDims::compute(&il, log2_sizes, &vi, &vj);
let loss = greedy_loss(&dims, alpha);
pq.push(pair, Cost(loss));
cost_graph.insert(pair);
}
}
}
let mut next_vertex = vertices.iter().max().copied().unwrap_or(0) + 1;
while il.nv() > 1 {
let (vi, vj) = if pq.is_empty() {
let vpool: Vec<usize> = il.vertices().cloned().collect();
if vpool.len() < 2 {
break;
}
(vpool[0].min(vpool[1]), vpool[0].max(vpool[1]))
} else {
let (pair, _) = select_pair(&mut pq, temperature, &mut rng, &mut cost_graph)?;
pair
};
if il.edges(&vi).is_none() || il.edges(&vj).is_none() {
continue;
}
let dims = ContractionDims::compute(&il, log2_sizes, &vi, &vj);
log2_tcs.push(dims.time_complexity());
log2_scs.push(dims.space_complexity());
let tree_i = trees.remove(&vi)?;
let tree_j = trees.remove(&vj)?;
let new_tree = ContractionTree::node(tree_i, tree_j);
let new_v = next_vertex;
next_vertex += 1;
il.set_edges(new_v, dims.edges_out.clone());
il.remove_edges(&dims.edges_remove);
il.delete_vertex(&vi);
il.delete_vertex(&vj);
trees.insert(new_v, new_tree);
for other_v in il.neighbors(&new_v) {
let pair_key = (new_v.min(other_v), new_v.max(other_v));
let new_dims = ContractionDims::compute(&il, log2_sizes, &new_v, &other_v);
let loss = greedy_loss(&new_dims, alpha);
if cost_graph.contains(&pair_key) {
pq.change_priority(&pair_key, Cost(loss));
} else {
pq.push(pair_key, Cost(loss));
cost_graph.insert(pair_key);
}
}
let pairs_to_remove: Vec<_> = cost_graph
.iter()
.filter(|(a, b)| *a == vj || *b == vj || *a == vi || *b == vi)
.cloned()
.collect();
for pair in pairs_to_remove {
pq.remove(&pair);
cost_graph.remove(&pair);
}
}
let final_tree = trees.into_values().next()?;
let output_edges = il
.vertices()
.next()
.and_then(|v| il.edges(v).cloned())
.unwrap_or_default();
Some(GreedyResult {
tree: final_tree,
log2_tcs,
log2_scs,
output_edges,
incidence_list: original_il,
})
}
/// Pops the next candidate pair from the queue.
///
/// At `temperature <= 0` this simply returns the best (lowest-loss) pair.
/// Otherwise the runner-up may be chosen instead with probability
/// `exp(-(loss2 - loss1) / temperature)` (a Metropolis-style acceptance),
/// which lets the search escape local minima; the rejected candidate is
/// pushed back. `cost_graph` is kept in sync with queue membership at every
/// pop and push.
fn select_pair<R: Rng>(
    pq: &mut PriorityQueue<(usize, usize), Cost>,
    temperature: f64,
    rng: &mut R,
    cost_graph: &mut HashSet<(usize, usize)>,
) -> Option<((usize, usize), Cost)> {
    let (best, best_cost) = match pq.pop() {
        Some(entry) => entry,
        None => return None,
    };
    cost_graph.remove(&best);
    if temperature <= 0.0 {
        return Some((best, best_cost));
    }
    let (second, second_cost) = match pq.pop() {
        Some(entry) => entry,
        None => return Some((best, best_cost)),
    };
    cost_graph.remove(&second);
    // second_cost.0 >= best_cost.0, so the acceptance probability is <= 1.
    let accept_second = rng.random::<f64>() < (-(second_cost.0 - best_cost.0) / temperature).exp();
    let (chosen, rejected) = if accept_second {
        ((second, second_cost), (best, best_cost))
    } else {
        ((best, best_cost), (second, second_cost))
    };
    pq.push(rejected.0, rejected.1);
    cost_graph.insert(rejected.0);
    Some(chosen)
}
/// Converts a `ContractionTree` into a `NestedEinsum`, labelling the root
/// with `openedges` and deriving every intermediate output from the
/// hypergraph in `incidence_list`.
pub fn tree_to_nested_einsum<L: Label>(
    tree: &ContractionTree,
    incidence_list: &IncidenceList<usize, L>,
    openedges: &[L],
) -> NestedEinsum<L> {
    // Cache each leaf's label list once up front.
    let mut leaf_labels = HashMap::new();
    collect_leaf_labels(tree, incidence_list, &mut leaf_labels);
    build_nested_with_level(tree, &leaf_labels, incidence_list, openedges, 0)
}
/// Records, for every leaf in `tree`, the edge labels of its tensor into
/// `labels` (keyed by leaf index).
fn collect_leaf_labels<L: Label>(
    tree: &ContractionTree,
    incidence_list: &IncidenceList<usize, L>,
    labels: &mut HashMap<usize, Vec<L>>,
) {
    match tree {
        ContractionTree::Node { left, right } => {
            collect_leaf_labels(left, incidence_list, labels);
            collect_leaf_labels(right, incidence_list, labels);
        }
        ContractionTree::Leaf(idx) => {
            // Leaves absent from the incidence list are silently skipped.
            if let Some(edges) = incidence_list.edges(idx) {
                labels.insert(*idx, edges.clone());
            }
        }
    }
}
/// Recursively builds a `NestedEinsum` for `tree`.
///
/// The root (`level == 0`) takes `openedges` as its output labels; every
/// deeper node derives its output from the hypergraph so that hyperedges
/// still touching tensors outside the subtree are preserved.
fn build_nested_with_level<L: Label>(
    tree: &ContractionTree,
    leaf_labels: &HashMap<usize, Vec<L>>,
    incidence_list: &IncidenceList<usize, L>,
    openedges: &[L],
    level: usize,
) -> NestedEinsum<L> {
    let (left, right) = match tree {
        ContractionTree::Leaf(idx) => return NestedEinsum::leaf(*idx),
        ContractionTree::Node { left, right } => (left, right),
    };
    let lhs_labels = get_subtree_labels(left, leaf_labels, incidence_list);
    let rhs_labels = get_subtree_labels(right, leaf_labels, incidence_list);
    let out_labels = if level == 0 {
        // The root's output is fixed by the caller's open edges.
        openedges.to_vec()
    } else {
        compute_contraction_output_with_hypergraph(
            &lhs_labels,
            &rhs_labels,
            incidence_list,
            &get_subtree_vertices(left),
            &get_subtree_vertices(right),
        )
    };
    let children = vec![
        build_nested_with_level(left, leaf_labels, incidence_list, openedges, level + 1),
        build_nested_with_level(right, leaf_labels, incidence_list, openedges, level + 1),
    ];
    NestedEinsum::node(children, EinCode::new(vec![lhs_labels, rhs_labels], out_labels))
}
/// Output labels of the (sub)contraction rooted at `tree`: a leaf's cached
/// label list, or the hypergraph-aware contraction output of the two children
/// of a node.
fn get_subtree_labels<L: Label>(
    tree: &ContractionTree,
    leaf_labels: &HashMap<usize, Vec<L>>,
    incidence_list: &IncidenceList<usize, L>,
) -> Vec<L> {
    match tree {
        ContractionTree::Node { left, right } => compute_contraction_output_with_hypergraph(
            &get_subtree_labels(left, leaf_labels, incidence_list),
            &get_subtree_labels(right, leaf_labels, incidence_list),
            incidence_list,
            &get_subtree_vertices(left),
            &get_subtree_vertices(right),
        ),
        ContractionTree::Leaf(idx) => leaf_labels.get(idx).cloned().unwrap_or_default(),
    }
}
/// Collects the leaf (input-tensor) indices contained in `tree`, in
/// left-to-right order.
fn get_subtree_vertices(tree: &ContractionTree) -> Vec<usize> {
    // Inner helper appends into a shared accumulator to avoid building and
    // merging intermediate vectors at every node.
    fn push_leaves(tree: &ContractionTree, out: &mut Vec<usize>) {
        match tree {
            ContractionTree::Leaf(idx) => out.push(*idx),
            ContractionTree::Node { left, right } => {
                push_leaves(left, out);
                push_leaves(right, out);
            }
        }
    }
    let mut vertices = Vec::new();
    push_leaves(tree, &mut vertices);
    vertices
}
/// Output labels when contracting two groups of tensors, aware of hyperedges.
///
/// A label appearing on both sides survives only if it is "external": open in
/// the overall contraction, or also incident to a vertex outside the two
/// groups. Labels unique to one side always survive. Output order: left
/// labels first (in left order), then right-only labels; each label is
/// emitted at most once.
fn compute_contraction_output_with_hypergraph<L: Label>(
    left: &[L],
    right: &[L],
    incidence_list: &IncidenceList<usize, L>,
    left_vertices: &[usize],
    right_vertices: &[usize],
) -> Vec<L> {
    use std::collections::HashSet;
    let right_set: HashSet<_> = right.iter().cloned().collect();
    let left_set: HashSet<_> = left.iter().cloned().collect();
    // Every vertex participating in this contraction.
    let vertex_set: HashSet<_> = left_vertices.iter().chain(right_vertices).cloned().collect();
    let mut output = Vec::new();
    let mut seen = HashSet::new();
    for label in left {
        // Shared labels are only kept when something outside still needs them.
        let keep =
            !right_set.contains(label) || is_index_external(label, incidence_list, &vertex_set);
        if keep && seen.insert(label.clone()) {
            output.push(label.clone());
        }
    }
    for label in right {
        if !left_set.contains(label) && seen.insert(label.clone()) {
            output.push(label.clone());
        }
    }
    output
}
/// True if `index` must outlive the contraction of `vertices`: it is an open
/// (output) edge, or it is also incident to at least one vertex outside
/// `vertices`. Unknown edges are treated as internal.
fn is_index_external<L: Label>(
    index: &L,
    incidence_list: &IncidenceList<usize, L>,
    vertices: &std::collections::HashSet<usize>,
) -> bool {
    incidence_list.is_open(index)
        || incidence_list
            .vertices_of_edge(index)
            .is_some_and(|connected| connected.iter().any(|v| !vertices.contains(v)))
}
/// Entry point: greedily optimizes `code` into a binary `NestedEinsum` using
/// the dimensions in `size_dict` and the settings in `config`.
///
/// Returns `None` when the code has no input tensors.
pub fn optimize_greedy<L: Label>(
    code: &EinCode<L>,
    size_dict: &HashMap<L, usize>,
    config: &GreedyMethod,
) -> Option<NestedEinsum<L>> {
    let il = IncidenceList::<usize, L>::from_eincode(&code.ixs, &code.iy);
    let log2_sizes = log2_size_dict(size_dict);
    let result = tree_greedy(&il, &log2_sizes, config.alpha, config.temperature)?;
    let nested = tree_to_nested_einsum(&result.tree, result.incidence_list(), &code.iy);
    Some(nested)
}
#[cfg(test)]
mod tests {
    // Unit tests covering: configuration constructors, ContractionTree
    // display formatting, greedy optimization on small/degenerate graphs
    // (empty, single tensor, chains, outer products, disconnected inputs,
    // traces, hyperedges), Cost's reversed ordering, and the
    // hypergraph-aware output-label computation.
    use super::*;

    #[test]
    fn test_greedy_method_default() {
        let method = GreedyMethod::default();
        assert_eq!(method.alpha, 0.0);
        assert_eq!(method.temperature, 0.0);
    }

    #[test]
    fn test_greedy_method_new() {
        let method = GreedyMethod::new(0.5, 1.0);
        assert_eq!(method.alpha, 0.5);
        assert_eq!(method.temperature, 1.0);
    }

    #[test]
    fn test_greedy_method_stochastic() {
        let method = GreedyMethod::stochastic(2.5);
        assert_eq!(method.alpha, 0.0);
        assert_eq!(method.temperature, 2.5);
    }

    #[test]
    fn test_contraction_tree_leaf() {
        let leaf = ContractionTree::leaf(42);
        assert!(matches!(leaf, ContractionTree::Leaf(42)));
    }

    #[test]
    fn test_contraction_tree_node() {
        let left = ContractionTree::leaf(0);
        let right = ContractionTree::leaf(1);
        let node = ContractionTree::node(left, right);
        assert!(matches!(node, ContractionTree::Node { .. }));
    }

    #[test]
    fn test_contraction_tree_display_leaf() {
        let leaf = ContractionTree::leaf(42);
        let output = format!("{}", leaf);
        assert_eq!(output.trim(), "Leaf(42)");
    }

    #[test]
    fn test_contraction_tree_display_simple_node() {
        let left = ContractionTree::leaf(0);
        let right = ContractionTree::leaf(1);
        let node = ContractionTree::node(left, right);
        let output = format!("{}", node);
        assert!(output.contains("Node {"));
        assert!(output.contains(" left: Leaf(0)"));
        assert!(output.contains(" right: Leaf(1)"));
        assert!(output.contains("}"));
    }

    #[test]
    fn test_contraction_tree_display_nested() {
        let inner_left = ContractionTree::leaf(1);
        let inner_right = ContractionTree::leaf(2);
        let inner_node = ContractionTree::node(inner_left, inner_right);
        let outer_left = ContractionTree::leaf(0);
        let outer_node = ContractionTree::node(outer_left, inner_node);
        let output = format!("{}", outer_node);
        assert!(output.contains("Node {"));
        assert!(output.contains(" left: Leaf(0)"));
        assert!(output.contains(" right: Node {"));
        assert!(output.contains(" left: Leaf(1)"));
        assert!(output.contains(" right: Leaf(2)"));
        // Braces must balance: exactly one pair per Node in the tree.
        let open_braces = output.matches('{').count();
        let close_braces = output.matches('}').count();
        assert_eq!(open_braces, close_braces);
        assert_eq!(open_braces, 2);
    }

    #[test]
    fn test_contraction_tree_display_deep_nesting() {
        let left_tree = ContractionTree::node(ContractionTree::leaf(0), ContractionTree::leaf(1));
        let right_tree = ContractionTree::node(ContractionTree::leaf(2), ContractionTree::leaf(3));
        let root = ContractionTree::node(left_tree, right_tree);
        let output = format!("{}", root);
        assert!(output.contains("Node {"));
        assert!(output.contains(" left: Node {"));
        assert!(output.contains(" left: Leaf(0)"));
        assert!(output.contains(" right: Leaf(1)"));
        assert!(output.contains(" right: Node {"));
        assert!(output.contains(" left: Leaf(2)"));
        assert!(output.contains(" right: Leaf(3)"));
        let open_braces = output.matches('{').count();
        let close_braces = output.matches('}').count();
        assert_eq!(open_braces, close_braces);
        assert_eq!(open_braces, 3);
    }

    #[test]
    fn test_greedy_empty() {
        // No tensors at all: the optimizer must return None.
        let il: IncidenceList<usize, char> = IncidenceList::new(HashMap::new(), vec![]);
        let log2_sizes: HashMap<char, f64> = HashMap::new();
        let result = tree_greedy(&il, &log2_sizes, 0.0, 0.0);
        assert!(result.is_none());
    }

    #[test]
    fn test_greedy_single_tensor() {
        // One tensor: nothing to contract, so the result is a single leaf.
        let ixs = vec![vec!['i', 'j']];
        let iy = vec!['i', 'j'];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let mut log2_sizes = HashMap::new();
        log2_sizes.insert('i', 2.0);
        log2_sizes.insert('j', 2.0);
        let result = tree_greedy(&il, &log2_sizes, 0.0, 0.0);
        assert!(result.is_some());
        let result = result.unwrap();
        assert!(matches!(result.tree, ContractionTree::Leaf(0)));
    }

    #[test]
    fn test_greedy_two_tensors() {
        // Matrix product ij,jk->ik: one contraction step expected.
        let ixs = vec![vec!['i', 'j'], vec!['j', 'k']];
        let iy = vec!['i', 'k'];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let mut log2_sizes = HashMap::new();
        log2_sizes.insert('i', 2.0);
        log2_sizes.insert('j', 3.0);
        log2_sizes.insert('k', 2.0);
        let result = tree_greedy(&il, &log2_sizes, 0.0, 0.0);
        assert!(result.is_some());
        let result = result.unwrap();
        assert!(matches!(result.tree, ContractionTree::Node { .. }));
        assert_eq!(result.log2_tcs.len(), 1);
        assert_eq!(result.log2_scs.len(), 1);
    }

    #[test]
    fn test_greedy_chain() {
        // Chain of three tensors contracts in exactly two steps.
        let ixs = vec![vec!['i', 'j'], vec!['j', 'k'], vec!['k', 'l']];
        let iy = vec!['i', 'l'];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let mut log2_sizes = HashMap::new();
        log2_sizes.insert('i', 2.0);
        log2_sizes.insert('j', 3.0);
        log2_sizes.insert('k', 3.0);
        log2_sizes.insert('l', 2.0);
        let result = tree_greedy(&il, &log2_sizes, 0.0, 0.0);
        assert!(result.is_some());
        let result = result.unwrap();
        assert_eq!(result.log2_tcs.len(), 2);
    }

    #[test]
    fn test_greedy_with_alpha() {
        // Non-zero alpha changes the loss but must still succeed.
        let ixs = vec![vec!['i', 'j'], vec!['j', 'k'], vec!['k', 'l']];
        let iy = vec!['i', 'l'];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let mut log2_sizes = HashMap::new();
        log2_sizes.insert('i', 2.0);
        log2_sizes.insert('j', 3.0);
        log2_sizes.insert('k', 3.0);
        log2_sizes.insert('l', 2.0);
        let result = tree_greedy(&il, &log2_sizes, 0.5, 0.0);
        assert!(result.is_some());
    }

    #[test]
    fn test_greedy_with_temperature() {
        // Stochastic selection must still produce a valid result.
        let ixs = vec![vec!['i', 'j'], vec!['j', 'k'], vec!['k', 'l']];
        let iy = vec!['i', 'l'];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let mut log2_sizes = HashMap::new();
        log2_sizes.insert('i', 2.0);
        log2_sizes.insert('j', 3.0);
        log2_sizes.insert('k', 3.0);
        log2_sizes.insert('l', 2.0);
        let result = tree_greedy(&il, &log2_sizes, 0.0, 1.0);
        assert!(result.is_some());
    }

    #[test]
    fn test_optimize_greedy() {
        let code = EinCode::new(vec![vec!['i', 'j'], vec!['j', 'k']], vec!['i', 'k']);
        let mut size_dict = HashMap::new();
        size_dict.insert('i', 4);
        size_dict.insert('j', 8);
        size_dict.insert('k', 4);
        let config = GreedyMethod::default();
        let result = optimize_greedy(&code, &size_dict, &config);
        assert!(result.is_some());
        let nested = result.unwrap();
        assert!(nested.is_binary());
        assert_eq!(nested.leaf_count(), 2);
    }

    #[test]
    fn test_optimize_greedy_stochastic() {
        let code = EinCode::new(vec![vec!['i', 'j'], vec!['j', 'k']], vec!['i', 'k']);
        let mut size_dict = HashMap::new();
        size_dict.insert('i', 4);
        size_dict.insert('j', 8);
        size_dict.insert('k', 4);
        let config = GreedyMethod::stochastic(1.0);
        let result = optimize_greedy(&code, &size_dict, &config);
        assert!(result.is_some());
    }

    #[test]
    fn test_tree_to_nested_einsum() {
        let tree = ContractionTree::node(ContractionTree::leaf(0), ContractionTree::leaf(1));
        let ixs = vec![vec!['i', 'j'], vec!['j', 'k']];
        let iy = vec!['i', 'k'];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let nested = tree_to_nested_einsum(&tree, &il, &iy);
        assert!(nested.is_binary());
        assert_eq!(nested.leaf_count(), 2);
    }

    #[test]
    fn test_tree_to_nested_einsum_chain() {
        let inner = ContractionTree::node(ContractionTree::leaf(0), ContractionTree::leaf(1));
        let tree = ContractionTree::node(inner, ContractionTree::leaf(2));
        let ixs = vec![vec!['i', 'j'], vec!['j', 'k'], vec!['k', 'l']];
        let iy = vec!['i', 'l'];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let nested = tree_to_nested_einsum(&tree, &il, &iy);
        assert!(nested.is_binary());
        assert_eq!(nested.leaf_count(), 3);
    }

    #[test]
    fn test_cost_ordering() {
        // Cost's ordering is deliberately reversed: lower loss ranks higher,
        // so a max-priority queue pops the cheapest pair first.
        let cost1 = Cost(1.0);
        let cost2 = Cost(2.0);
        assert!(cost1 > cost2);
        assert!(cost2 < cost1);
        assert!(cost1 == Cost(1.0));
    }

    #[test]
    fn test_greedy_disconnected_tensors() {
        // Two tensors sharing no index: pure outer product.
        let ixs = vec![vec!['i', 'j'], vec!['k', 'l']];
        let iy = vec!['i', 'j', 'k', 'l'];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let mut log2_sizes = HashMap::new();
        log2_sizes.insert('i', 2.0);
        log2_sizes.insert('j', 2.0);
        log2_sizes.insert('k', 2.0);
        log2_sizes.insert('l', 2.0);
        let result = tree_greedy(&il, &log2_sizes, 0.0, 0.0);
        assert!(result.is_some());
    }

    #[test]
    fn test_outer_product_returns_node_not_leaf() {
        let ixs = vec![vec![0usize], vec![1usize]];
        let iy = vec![0usize, 1];
        let code = EinCode::new(ixs, iy);
        let size_dict: HashMap<usize, usize> = [(0, 2), (1, 3)].into();
        let optimizer = GreedyMethod::new(0.0, 0.0);
        let result = optimize_greedy(&code, &size_dict, &optimizer);
        assert!(
            result.is_some(),
            "Should return Some for multi-tensor einsum"
        );
        let nested = result.unwrap();
        assert!(
            !nested.is_leaf(),
            "Multi-tensor outer product should return Node, not Leaf. Got: {:?}",
            nested
        );
        assert_eq!(
            nested.leaf_count(),
            2,
            "Should have 2 leaves for 2 input tensors"
        );
        assert!(nested.is_binary(), "Should be a binary tree");
    }

    #[test]
    fn test_outer_product_three_tensors() {
        let ixs = vec![vec!['a'], vec!['b'], vec!['c']];
        let iy = vec!['a', 'b', 'c'];
        let code = EinCode::new(ixs, iy);
        let mut size_dict = HashMap::new();
        size_dict.insert('a', 2);
        size_dict.insert('b', 3);
        size_dict.insert('c', 4);
        let result = optimize_greedy(&code, &size_dict, &GreedyMethod::default());
        assert!(result.is_some());
        let nested = result.unwrap();
        assert!(
            !nested.is_leaf(),
            "3-tensor operation should not return Leaf"
        );
        assert_eq!(nested.leaf_count(), 3);
        assert!(nested.is_binary());
    }

    #[test]
    fn test_disconnected_contraction_tree() {
        // A connected component plus two floating tensors ('e' dangling,
        // 'f' completely disconnected).
        let ixs = vec![
            vec!['a', 'b'],
            vec!['a', 'c', 'd'],
            vec!['b', 'c', 'e'],
            vec!['e'],
            vec!['f'],
        ];
        let iy = vec!['a', 'f'];
        let code = EinCode::new(ixs, iy);
        let mut size_dict = HashMap::new();
        for (i, c) in ['a', 'b', 'c', 'd', 'e', 'f'].iter().enumerate() {
            size_dict.insert(*c, 1 << (i + 1));
        }
        let result = optimize_greedy(&code, &size_dict, &GreedyMethod::default());
        assert!(
            result.is_some(),
            "Disconnected contraction tree should be optimizable"
        );
        let nested = result.unwrap();
        assert_eq!(
            nested.leaf_count(),
            5,
            "Should have 5 leaves for 5 input tensors"
        );
        assert!(nested.is_binary(), "Should produce a binary tree");
    }

    #[test]
    fn test_mixed_connected_and_disconnected_tensors() {
        let ixs = vec![vec!['i', 'j'], vec!['j', 'k'], vec!['m']];
        let iy = vec!['i', 'k', 'm'];
        let code = EinCode::new(ixs, iy);
        let mut size_dict = HashMap::new();
        size_dict.insert('i', 2);
        size_dict.insert('j', 3);
        size_dict.insert('k', 4);
        size_dict.insert('m', 5);
        let result = optimize_greedy(&code, &size_dict, &GreedyMethod::default());
        assert!(result.is_some());
        let nested = result.unwrap();
        assert_eq!(nested.leaf_count(), 3);
        assert!(nested.is_binary());
    }

    #[test]
    fn test_single_element_tensors_outer_product() {
        let ixs = vec![vec!['a'], vec!['b']];
        let iy = vec!['a', 'b'];
        let code = EinCode::new(ixs, iy);
        let mut size_dict = HashMap::new();
        size_dict.insert('a', 3);
        size_dict.insert('b', 4);
        let result = optimize_greedy(&code, &size_dict, &GreedyMethod::default());
        assert!(result.is_some());
        let nested = result.unwrap();
        assert!(!nested.is_leaf());
        assert_eq!(nested.leaf_count(), 2);
        if let NestedEinsum::Node { eins, .. } = &nested {
            assert!(eins.iy.contains(&'a'), "Output should contain 'a'");
            assert!(eins.iy.contains(&'b'), "Output should contain 'b'");
        }
    }

    #[test]
    fn test_greedy_trace() {
        // ij,ji-> : full trace down to a scalar.
        let ixs = vec![vec!['i', 'j'], vec!['j', 'i']];
        let iy = vec![];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let mut log2_sizes = HashMap::new();
        log2_sizes.insert('i', 2.0);
        log2_sizes.insert('j', 2.0);
        let result = tree_greedy(&il, &log2_sizes, 0.0, 0.0);
        assert!(result.is_some());
    }

    #[test]
    fn test_hyperedge_index_preservation() {
        // Index 2 is a hyperedge touching all three tensors.
        let ixs = vec![vec![1usize, 2], vec![2usize], vec![2usize, 3]];
        let out = vec![1usize, 3];
        let code = EinCode::new(ixs.clone(), out.clone());
        let mut sizes = HashMap::new();
        sizes.insert(1usize, 2);
        sizes.insert(2usize, 3);
        sizes.insert(3usize, 2);
        let config = GreedyMethod::default();
        let nested = optimize_greedy(&code, &sizes, &config);
        assert!(nested.is_some());
        let nested = nested.unwrap();
        assert!(nested.is_binary());
        assert_eq!(nested.leaf_count(), 3);
    }

    #[test]
    fn test_compute_hypergraph_aware_output() {
        let ixs = vec![vec!['i', 'j'], vec!['j', 'k'], vec!['k', 'l']];
        let iy = vec!['i', 'l'];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let left = vec!['i', 'j'];
        let right = vec!['j', 'k'];
        let left_vertices = vec![0];
        let right_vertices = vec![1];
        let output = compute_contraction_output_with_hypergraph(
            &left,
            &right,
            &il,
            &left_vertices,
            &right_vertices,
        );
        assert!(output.contains(&'i'));
        assert!(!output.contains(&'j'));
        assert!(output.contains(&'k'));
    }

    #[test]
    fn test_compute_hypergraph_aware_output_simple_contraction() {
        let ixs = vec![vec!['i', 'j'], vec!['j', 'k']];
        let iy = vec!['i', 'k'];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let left = vec!['i', 'j'];
        let right = vec!['j', 'k'];
        let left_vertices = vec![0];
        let right_vertices = vec![1];
        let output = compute_contraction_output_with_hypergraph(
            &left,
            &right,
            &il,
            &left_vertices,
            &right_vertices,
        );
        assert_eq!(output.len(), 2);
        assert!(output.contains(&'i'));
        assert!(output.contains(&'k'));
        assert!(!output.contains(&'j'));
    }

    #[test]
    fn test_compute_hypergraph_aware_output_hyperedge() {
        // 'i' also touches tensor 2, which is outside the contracted pair,
        // so it must survive the contraction of tensors 0 and 1.
        let ixs = vec![vec!['i', 'j'], vec!['i', 'k'], vec!['i', 'l']];
        let iy = vec![];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let left = vec!['i', 'j'];
        let right = vec!['i', 'k'];
        let left_vertices = vec![0];
        let right_vertices = vec![1];
        let output = compute_contraction_output_with_hypergraph(
            &left,
            &right,
            &il,
            &left_vertices,
            &right_vertices,
        );
        assert_eq!(output.len(), 3);
        assert!(output.contains(&'i'), "Hyperedge 'i' should be preserved");
        assert!(output.contains(&'j'));
        assert!(output.contains(&'k'));
    }

    #[test]
    fn test_compute_hypergraph_aware_output_trace() {
        let ixs = vec![vec!['i', 'i'], vec!['i', 'j']];
        let iy = vec!['j'];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let left = vec!['i', 'i'];
        let right = vec!['i', 'j'];
        let left_vertices = vec![0];
        let right_vertices = vec![1];
        let output = compute_contraction_output_with_hypergraph(
            &left,
            &right,
            &il,
            &left_vertices,
            &right_vertices,
        );
        assert!(output.contains(&'j'));
        assert!(!output.contains(&'i') || output.iter().filter(|&&x| x == 'i').count() == 0);
    }

    #[test]
    fn test_compute_hypergraph_aware_output_open_edge() {
        let ixs = vec![vec!['i', 'j'], vec!['j', 'k']];
        let iy = vec!['i', 'k'];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let left = vec!['i', 'j'];
        let right = vec!['j', 'k'];
        let left_vertices = vec![0];
        let right_vertices = vec![1];
        let output = compute_contraction_output_with_hypergraph(
            &left,
            &right,
            &il,
            &left_vertices,
            &right_vertices,
        );
        assert_eq!(output.len(), 2);
        assert!(output.contains(&'i'));
        assert!(output.contains(&'k'));
        assert!(!output.contains(&'j'));
    }

    #[test]
    fn test_compute_hypergraph_aware_output_no_common_indices() {
        let ixs = vec![vec!['i', 'j'], vec!['k', 'l']];
        let iy = vec!['i', 'j', 'k', 'l'];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let left = vec!['i', 'j'];
        let right = vec!['k', 'l'];
        let left_vertices = vec![0];
        let right_vertices = vec![1];
        let output = compute_contraction_output_with_hypergraph(
            &left,
            &right,
            &il,
            &left_vertices,
            &right_vertices,
        );
        assert_eq!(output.len(), 4);
        assert!(output.contains(&'i'));
        assert!(output.contains(&'j'));
        assert!(output.contains(&'k'));
        assert!(output.contains(&'l'));
    }

    #[test]
    fn test_compute_hypergraph_aware_output_complex_hyperedge() {
        // 'k' is shared by all four tensors; contracting 0 and 1 must keep it.
        let ixs = vec![
            vec!['i', 'j', 'k'],
            vec!['i', 'k', 'l'],
            vec!['k', 'm'],
            vec!['k', 'n'],
        ];
        let iy = vec![];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let left = vec!['i', 'j', 'k'];
        let right = vec!['i', 'k', 'l'];
        let left_vertices = vec![0];
        let right_vertices = vec![1];
        let output = compute_contraction_output_with_hypergraph(
            &left,
            &right,
            &il,
            &left_vertices,
            &right_vertices,
        );
        assert!(output.contains(&'k'), "Hyperedge 'k' should be preserved");
        assert!(output.contains(&'j'));
        assert!(output.contains(&'l'));
    }

    #[test]
    fn test_compute_hypergraph_aware_output_all_contract() {
        let ixs = vec![vec!['i', 'j'], vec!['i', 'j']];
        let iy = vec![];
        let il = IncidenceList::<usize, char>::from_eincode(&ixs, &iy);
        let left = vec!['i', 'j'];
        let right = vec!['i', 'j'];
        let left_vertices = vec![0];
        let right_vertices = vec![1];
        let output = compute_contraction_output_with_hypergraph(
            &left,
            &right,
            &il,
            &left_vertices,
            &right_vertices,
        );
        assert_eq!(
            output.len(),
            0,
            "All indices should contract to produce scalar"
        );
    }
}
#[cfg(test)]
mod extensive_tests {
use super::*;
use crate::test_utils::{generate_random_eincode, NaiveContractor};
fn execute_nested(nested: &NestedEinsum<usize>, contractor: &mut NaiveContractor) -> usize {
match nested {
NestedEinsum::Leaf { tensor_index } => *tensor_index,
NestedEinsum::Node { args, eins } => {
let left_idx = execute_nested(&args[0], contractor);
let right_idx = execute_nested(&args[1], contractor);
contractor.contract(left_idx, right_idx, &eins.ixs[0], &eins.ixs[1], &eins.iy)
}
}
}
#[test]
fn test_issue_6_regression() {
let ixs = vec![vec![1usize, 2], vec![2usize], vec![2usize, 3]];
let out = vec![1usize, 3];
let code = EinCode::new(ixs.clone(), out.clone());
let mut sizes = HashMap::new();
sizes.insert(1usize, 2); sizes.insert(2usize, 3); sizes.insert(3usize, 2);
let config = GreedyMethod::default();
let nested = optimize_greedy(&code, &sizes, &config).unwrap();
let mut contractor = NaiveContractor::new();
contractor.add_tensor(0, vec![2, 3]); contractor.add_tensor(1, vec![3]); contractor.add_tensor(2, vec![3, 2]);
let result_idx = execute_nested(&nested, &mut contractor);
let result_shape = contractor.get_shape(result_idx).unwrap();
assert_eq!(
*result_shape,
vec![2, 2],
"Result should be 2x2 for indices i,k"
);
}
#[test]
fn test_large_graph_stress() {
let mut ixs = Vec::new();
let n = 10;
for i in 1..=n {
for j in 1..=n {
let idx = (i - 1) * n + j;
if j < n {
ixs.push(vec![idx, idx + 1]);
}
if i < n {
ixs.push(vec![idx, idx + n]);
}
}
}
let code = EinCode::new(ixs.clone(), vec![]);
let size_dict: HashMap<usize, usize> = (1..=n * n).map(|i| (i, 2)).collect();
let config = GreedyMethod::default();
let nested = optimize_greedy(&code, &size_dict, &config).unwrap();
let mut contractor = NaiveContractor::new();
for i in 0..ixs.len() {
contractor.add_tensor(i, vec![2, 2]);
}
let result_idx = execute_nested(&nested, &mut contractor);
let result_tensor = contractor.get_tensor(result_idx).unwrap();
assert_eq!(
result_tensor.ndim(),
0,
"Grid contraction should produce scalar"
);
}
#[test]
fn test_ring_topology() {
let n = 10;
let ixs: Vec<Vec<usize>> = (0..n).map(|i| vec![i + 1, ((i + 1) % n) + 1]).collect();
let code = EinCode::new(ixs.clone(), vec![]);
let size_dict: HashMap<usize, usize> = (1..=n).map(|i| (i, 2)).collect();
let nested = optimize_greedy(&code, &size_dict, &GreedyMethod::default()).unwrap();
assert!(
nested.is_binary(),
"Ring optimization should produce binary tree"
);
}
#[test]
fn test_chain_topology() {
let ixs = vec![vec![1, 2], vec![2, 3], vec![3, 4], vec![4, 5]];
let output = vec![1, 5]; let code = EinCode::new(ixs.clone(), output.clone());
let size_dict: HashMap<usize, usize> = (1..=5).map(|i| (i, 2)).collect();
let nested = optimize_greedy(&code, &size_dict, &GreedyMethod::default()).unwrap();
let mut contractor = NaiveContractor::new();
for i in 0..4 {
contractor.add_tensor(i, vec![2, 2]);
}
let result_idx = execute_nested(&nested, &mut contractor);
let result_tensor = contractor.get_tensor(result_idx).unwrap();
assert_eq!(
result_tensor.shape(),
&[2, 2],
"Chain contraction should produce 2x2 matrix for output [1,5]"
);
}
#[test]
fn test_random_instances_basic() {
for iteration in 0..10 {
let (ixs, output) = generate_random_eincode(
3 + iteration % 3, 8, false, false, );
if ixs.is_empty() {
continue;
}
let code = EinCode::new(ixs.clone(), output.clone());
let size_dict: HashMap<usize, usize> = (1..=20).map(|i| (i, 2)).collect();
let nested_result = optimize_greedy(&code, &size_dict, &GreedyMethod::default());
assert!(
nested_result.is_some(),
"Greedy optimization should succeed for valid random instance"
);
if let Some(nested) = nested_result {
let mut contractor = NaiveContractor::new();
for (i, tensor_indices) in ixs.iter().enumerate() {
let shape: Vec<usize> = tensor_indices
.iter()
.map(|&idx| *size_dict.get(&idx).unwrap_or(&2))
.collect();
contractor.add_tensor(i, shape);
}
let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
execute_nested(&nested, &mut contractor)
}));
}
}
}
#[test]
fn test_random_instances_with_duplicates() {
for iteration in 0..10 {
let (ixs, output) = generate_random_eincode(
2 + iteration % 3, 8, true, false,
);
if ixs.is_empty() {
continue;
}
let code = EinCode::new(ixs.clone(), output.clone());
let size_dict: HashMap<usize, usize> = (1..=20).map(|i| (i, 2)).collect();
let nested_result = optimize_greedy(&code, &size_dict, &GreedyMethod::default());
if let Some(nested) = nested_result {
let mut contractor = NaiveContractor::new();
for (i, tensor_indices) in ixs.iter().enumerate() {
let shape: Vec<usize> = tensor_indices
.iter()
.map(|&idx| *size_dict.get(&idx).unwrap_or(&2))
.collect();
contractor.add_tensor(i, shape);
}
let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
execute_nested(&nested, &mut contractor);
}));
}
}
}
#[test]
fn test_random_instances_with_output_only_indices() {
    // Random instances whose output may contain indices absent from every
    // input (broadcast-like outputs). Only structural validity of the
    // returned contraction tree is checked here.
    for round in 0..10 {
        let (inputs, output) = generate_random_eincode(2 + round % 3, 8, false, true);
        if inputs.is_empty() || output.is_empty() {
            continue;
        }
        let code = EinCode::new(inputs, output);
        let size_dict: HashMap<usize, usize> = (1..=25).map(|label| (label, 2)).collect();
        if let Some(nested) = optimize_greedy(&code, &size_dict, &GreedyMethod::default()) {
            assert!(
                nested.is_binary() || nested.leaf_count() == 1,
                "Result should be valid tree"
            );
        }
    }
}
#[test]
fn test_random_instances_all_edge_cases() {
    // Stress test combining both tricky features at once: duplicate
    // indices within tensors AND output-only indices. Execution of the
    // plan is best-effort (panics are swallowed via catch_unwind).
    for round in 0..20 {
        let (inputs, output) = generate_random_eincode(2 + round % 5, 12, true, true);
        if inputs.is_empty() {
            continue;
        }
        let code = EinCode::new(inputs.clone(), output);
        let size_dict: HashMap<usize, usize> = (1..=25).map(|label| (label, 2)).collect();
        if let Some(nested) = optimize_greedy(&code, &size_dict, &GreedyMethod::default()) {
            let mut contractor = NaiveContractor::new();
            for (tensor_id, indices) in inputs.iter().enumerate() {
                let shape: Vec<usize> = indices
                    .iter()
                    .map(|label| size_dict.get(label).copied().unwrap_or(2))
                    .collect();
                contractor.add_tensor(tensor_id, shape);
            }
            let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                execute_nested(&nested, &mut contractor);
            }));
        }
    }
}
#[test]
fn test_edge_case_with_trace_and_broadcast() {
    // A deliberately nasty instance: traces (`ii`, `kk`), an index summed
    // away (`l`), a repeated output index (`i` twice), and an output-only
    // index (`m`). Both Greedy and TreeSA must still produce structurally
    // valid trees whose root output keeps the expected labels.
    use crate::treesa::TreeSA;
    use crate::CodeOptimizer;
    let ixs = vec![
        vec!['i', 'i'],
        vec!['i', 'k'],
        vec!['i', 'k', 'l'],
        vec!['k', 'k'],
    ];
    let output = vec!['k', 'i', 'i', 'm'];
    let code = EinCode::new(ixs, output);
    let sizes: HashMap<char, usize> = vec![('i', 2), ('k', 2), ('l', 2), ('m', 2)]
        .into_iter()
        .collect();
    let greedy_nested = GreedyMethod::default()
        .optimize(&code, &sizes)
        .expect("Greedy should handle trace + broadcast case");
    assert!(
        greedy_nested.is_binary() || greedy_nested.leaf_count() == 1,
        "Greedy result should be valid"
    );
    if let NestedEinsum::Node { eins, .. } = &greedy_nested {
        assert!(
            eins.iy.contains(&'k'),
            "Greedy result should contain index 'k'"
        );
        assert!(
            eins.iy.contains(&'i'),
            "Greedy result should contain index 'i'"
        );
    }
    let treesa_nested = TreeSA::fast()
        .optimize(&code, &sizes)
        .expect("TreeSA should handle trace + broadcast case");
    assert!(
        treesa_nested.is_binary() || treesa_nested.leaf_count() == 1,
        "TreeSA result should be valid"
    );
    if let NestedEinsum::Node { eins, .. } = &treesa_nested {
        assert!(
            eins.iy.contains(&'k') || eins.iy.contains(&'i'),
            "TreeSA result should contain at least one index from inputs"
        );
    }
}
#[test]
fn test_cross_optimizer_simple_chain() {
    // Matrix-chain `ij,jk->ik` with i=k=3, j=4: different contraction
    // orders must yield the same numbers, so Greedy and TreeSA plans are
    // executed on identical tensor data and compared elementwise.
    use crate::test_utils::{execute_nested, tensors_approx_equal, NaiveContractor};
    use crate::treesa::TreeSA;
    use crate::CodeOptimizer;
    let code = EinCode::new(vec![vec!['i', 'j'], vec!['j', 'k']], vec!['i', 'k']);
    let sizes: HashMap<char, usize> = vec![('i', 3), ('j', 4), ('k', 3)].into_iter().collect();
    let label_map: HashMap<char, usize> =
        vec![('i', 1), ('j', 2), ('k', 3)].into_iter().collect();
    let mut greedy_contractor = NaiveContractor::new();
    greedy_contractor.add_tensor(0, vec![3, 4]);
    greedy_contractor.add_tensor(1, vec![4, 3]);
    // Clone so both optimizers start from the same tensor data.
    let mut treesa_contractor = greedy_contractor.clone();
    let greedy_plan = GreedyMethod::default()
        .optimize(&code, &sizes)
        .expect("Greedy should succeed");
    let treesa_plan = TreeSA::fast()
        .optimize(&code, &sizes)
        .expect("TreeSA should succeed");
    let g_idx = execute_nested(&greedy_plan, &mut greedy_contractor, &label_map);
    let t_idx = execute_nested(&treesa_plan, &mut treesa_contractor, &label_map);
    let g_tensor = greedy_contractor
        .get_tensor(g_idx)
        .expect("Result should exist");
    let t_tensor = treesa_contractor
        .get_tensor(t_idx)
        .expect("Result should exist");
    assert!(
        tensors_approx_equal(g_tensor, t_tensor, 1e-5, 1e-8),
        "Greedy and TreeSA should produce same numerical result"
    );
}
#[test]
fn test_cross_optimizer_3_regular_graph_small() {
    // A 3-regular structure: a ring of n edge tensors plus one rank-1
    // tensor per vertex (each vertex then touches 2 ring edges + 1 vertex
    // tensor). Fully contracted (empty output), so both optimizers must
    // agree on the resulting scalar when executed on identical data.
    use crate::test_utils::{
        execute_nested, generate_ring_edges, tensors_approx_equal, NaiveContractor,
    };
    use crate::treesa::TreeSA;
    use crate::CodeOptimizer;
    let n = 10;
    let edges = generate_ring_edges(n);
    let mut ixs: Vec<Vec<usize>> = edges.iter().map(|&(i, j)| vec![i, j]).collect();
    // One dangling rank-1 tensor per vertex raises every vertex degree to 3.
    for i in 1..=n {
        ixs.push(vec![i]);
    }
    let code = EinCode::new(ixs.clone(), vec![]);
    let sizes: HashMap<usize, usize> = (1..=n).map(|i| (i, 2)).collect();
    let label_map: HashMap<usize, usize> = (1..=n).map(|i| (i, i)).collect();
    let mut contractor1 = NaiveContractor::new();
    for (idx, ix) in ixs.iter().enumerate() {
        let shape: Vec<usize> = ix.iter().map(|&label| sizes[&label]).collect();
        contractor1.add_tensor(idx, shape);
    }
    // Clone so both optimizers start from the same tensor data.
    let mut contractor2 = contractor1.clone();
    let greedy_result = GreedyMethod::default()
        .optimize(&code, &sizes)
        .expect("Greedy should succeed");
    let treesa_result = TreeSA::fast()
        .optimize(&code, &sizes)
        .expect("TreeSA should succeed");
    let greedy_idx = execute_nested(&greedy_result, &mut contractor1, &label_map);
    let treesa_idx = execute_nested(&treesa_result, &mut contractor2, &label_map);
    let greedy_tensor = contractor1
        .get_tensor(greedy_idx)
        .expect("Greedy result should exist");
    let treesa_tensor = contractor2
        .get_tensor(treesa_idx)
        .expect("TreeSA result should exist");
    // Leftover debug eprintln! calls removed: the assert message below
    // already reports both shapes on failure.
    assert!(
        tensors_approx_equal(greedy_tensor, treesa_tensor, 1e-5, 1e-8),
        "Greedy and TreeSA should produce same numerical result for 3-regular graph.\nGreedy shape: {:?}, TreeSA shape: {:?}",
        greedy_tensor.shape(),
        treesa_tensor.shape()
    );
}
#[test]
fn test_cross_optimizer_with_trace() {
    // `ii,ij->j`: the first tensor carries a trace. Both optimizer plans
    // are executed on identical tensor data and compared elementwise.
    use crate::test_utils::{execute_nested, tensors_approx_equal, NaiveContractor};
    use crate::treesa::TreeSA;
    use crate::CodeOptimizer;
    let code = EinCode::new(vec![vec!['i', 'i'], vec!['i', 'j']], vec!['j']);
    let sizes: HashMap<char, usize> = vec![('i', 3), ('j', 4)].into_iter().collect();
    let label_map: HashMap<char, usize> = vec![('i', 1), ('j', 2)].into_iter().collect();
    let mut greedy_contractor = NaiveContractor::new();
    greedy_contractor.add_tensor(0, vec![3, 3]);
    greedy_contractor.add_tensor(1, vec![3, 4]);
    // Clone so both optimizers start from the same tensor data.
    let mut treesa_contractor = greedy_contractor.clone();
    let greedy_plan = GreedyMethod::default()
        .optimize(&code, &sizes)
        .expect("Greedy should succeed");
    let treesa_plan = TreeSA::fast()
        .optimize(&code, &sizes)
        .expect("TreeSA should succeed");
    let g_idx = execute_nested(&greedy_plan, &mut greedy_contractor, &label_map);
    let t_idx = execute_nested(&treesa_plan, &mut treesa_contractor, &label_map);
    let g_tensor = greedy_contractor
        .get_tensor(g_idx)
        .expect("Result should exist");
    let t_tensor = treesa_contractor
        .get_tensor(t_idx)
        .expect("Result should exist");
    assert!(
        tensors_approx_equal(g_tensor, t_tensor, 1e-5, 1e-8),
        "Greedy and TreeSA should produce same result with trace"
    );
}
#[test]
fn test_optimize_petersen_graph() {
    // Full contraction of the Petersen graph (10 vertices, 15 edges):
    // greedy should return a binary tree over all 15 edge tensors with a
    // modest space complexity bound.
    use crate::complexity::nested_complexity;
    use crate::test_utils::generate_petersen_edges;
    let edges = generate_petersen_edges();
    assert_eq!(edges.len(), 15, "Petersen graph should have 15 edges");
    let ixs: Vec<Vec<usize>> = edges.iter().map(|&(a, b)| vec![a, b]).collect();
    let code = EinCode::new(ixs, vec![]);
    let size_dict: HashMap<usize, usize> = (1..=10).map(|i| (i, 2)).collect();
    let nested = optimize_greedy(&code, &size_dict, &GreedyMethod::default())
        .expect("Should optimize Petersen graph");
    assert!(nested.is_binary(), "Result should be binary tree");
    assert_eq!(
        nested.leaf_count(),
        15,
        "Should have 15 leaves (one per edge)"
    );
    let cc = nested_complexity(&nested, &size_dict, &code.ixs);
    assert!(
        cc.sc <= 6.0,
        "Space complexity should be reasonable, got {}",
        cc.sc
    );
}
#[test]
fn test_optimize_fullerene_c60() {
    // C60 fullerene contraction: checks tree shape and that both space
    // and time complexity estimates are finite.
    // NOTE(review): size_dict covers labels 1..=max_vertex, assuming
    // vertex labels are 1-based — confirm against generate_fullerene_edges.
    use crate::complexity::nested_complexity;
    use crate::test_utils::generate_fullerene_edges;
    let edges = generate_fullerene_edges();
    assert!(!edges.is_empty(), "Fullerene should have edges");
    let ixs: Vec<Vec<usize>> = edges.iter().map(|&(a, b)| vec![a, b]).collect();
    let code = EinCode::new(ixs.clone(), vec![]);
    // Fall back to 60 vertices if the edge list were ever empty (guarded above).
    let max_vertex = edges
        .iter()
        .flat_map(|&(a, b)| vec![a, b])
        .max()
        .unwrap_or(60);
    let size_dict: HashMap<usize, usize> = (1..=max_vertex).map(|i| (i, 2)).collect();
    let nested = optimize_greedy(&code, &size_dict, &GreedyMethod::default())
        .expect("Should optimize fullerene graph");
    assert!(nested.is_binary(), "Result should be binary tree");
    assert_eq!(
        nested.leaf_count(),
        ixs.len(),
        "Should have correct number of leaves"
    );
    let cc = nested_complexity(&nested, &size_dict, &code.ixs);
    assert!(cc.sc.is_finite(), "Space complexity should be finite");
    assert!(cc.tc.is_finite(), "Time complexity should be finite");
}
#[test]
fn test_optimize_chain_with_complexity() {
    // Open matrix-product chains of several lengths with bond dimension 4
    // and the two end labels (1 and n) kept open; an optimal order keeps
    // space complexity near one bond pair (sc ~ log2(4^2) = 4).
    use crate::complexity::nested_complexity;
    use crate::test_utils::generate_chain_edges;
    for n in [5, 10, 15] {
        let edges = generate_chain_edges(n);
        let ixs: Vec<Vec<usize>> = edges.iter().map(|&(a, b)| vec![a, b]).collect();
        let code = EinCode::new(ixs, vec![1, n]);
        let size_dict: HashMap<usize, usize> = (1..=n).map(|i| (i, 4)).collect();
        let nested = optimize_greedy(&code, &size_dict, &GreedyMethod::default())
            .unwrap_or_else(|| panic!("Should optimize chain of length {}", n));
        let cc = nested_complexity(&nested, &size_dict, &code.ixs);
        assert!(
            cc.sc <= 5.0,
            "Chain sc should be ~4, got {} for n={}",
            cc.sc,
            n
        );
    }
}
#[test]
fn test_optimize_ring_with_complexity() {
    // Closed rings of several sizes with bond dimension 2 and scalar
    // output; a good contraction order keeps space complexity small.
    use crate::complexity::nested_complexity;
    use crate::test_utils::generate_ring_edges;
    for n in [5, 10, 15] {
        let edges = generate_ring_edges(n);
        let ixs: Vec<Vec<usize>> = edges.iter().map(|&(a, b)| vec![a, b]).collect();
        let code = EinCode::new(ixs, vec![]);
        let size_dict: HashMap<usize, usize> = (1..=n).map(|i| (i, 2)).collect();
        let nested = optimize_greedy(&code, &size_dict, &GreedyMethod::default())
            .unwrap_or_else(|| panic!("Should optimize ring of size {}", n));
        let cc = nested_complexity(&nested, &size_dict, &code.ixs);
        assert!(
            cc.sc <= 4.0,
            "Ring sc should be low, got {} for n={}",
            cc.sc,
            n
        );
    }
}
}