use crate::eincode::{EinCode, NestedEinsum};
use crate::Label;
use std::collections::HashSet;
/// One vector-absorption step recorded by [`merge_vectors`].
///
/// All indices refer to tensor positions in the *original* [`EinCode`], not the simplified one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MergeOperation<L: Label> {
    /// Position of the rank-1 tensor that was removed from the network.
    pub vector_index: usize,
    /// Position of the tensor the vector was folded into.
    ///
    /// This may itself be a rank-1 tensor that is folded into yet another tensor by a later
    /// merge, so merges can form chains.
    pub target_index: usize,
    /// The vector's single label, which the target tensor also carries.
    pub shared_index: L,
}
/// Result of [`merge_vectors`]: a reduced einsum network plus everything needed to map a
/// contraction tree for the reduced network back onto the original one.
#[derive(Debug, Clone)]
pub struct NetworkSimplifier<L>
where
    L: Label,
{
    /// Vectors that were folded away, in the order `merge_vectors` recorded them.
    merges: Vec<MergeOperation<L>>,
    /// The original code with every merged vector removed (same output labels).
    simplified_code: EinCode<L>,
    /// The unsimplified code this simplifier was built from.
    original_code: EinCode<L>,
    /// `index_map[k]` is the original tensor position of tensor `k` in `simplified_code`.
    index_map: Vec<usize>,
}
impl<L: Label> NetworkSimplifier<L> {
    /// The reduced code (merged vectors removed) that an optimizer should run on.
    pub fn simplified_code(&self) -> &EinCode<L> {
        &self.simplified_code
    }

    /// The unsimplified code this simplifier was built from.
    pub fn original_code(&self) -> &EinCode<L> {
        &self.original_code
    }

    /// `true` when at least one vector was merged away.
    pub fn is_simplified(&self) -> bool {
        !self.merges.is_empty()
    }

    /// Number of vectors that were merged away.
    pub fn num_merges(&self) -> usize {
        self.merges.len()
    }

    /// The recorded merges, in the order they were performed.
    pub fn merges(&self) -> &[MergeOperation<L>] {
        &self.merges
    }

    /// Maps a contraction tree for [`Self::simplified_code`] back onto the original code.
    ///
    /// Leaf indices are translated to original tensor positions. Every merged vector is
    /// re-inserted as a binary contraction with the tensor it was folded into. The output
    /// labels of that contraction are the target's labels.
    ///
    /// # Panics
    /// Panics if `simplified_tree` references a tensor index outside the simplified code.
    pub fn embed(&self, simplified_tree: NestedEinsum<L>) -> NestedEinsum<L> {
        if !self.is_simplified() {
            return simplified_tree;
        }
        self.embed_recursive(simplified_tree)
    }

    /// Walks the simplified tree, replacing each leaf with its rebuilt original subtree.
    fn embed_recursive(&self, tree: NestedEinsum<L>) -> NestedEinsum<L> {
        match tree {
            NestedEinsum::Leaf { tensor_index } => self.rebuild_leaf(self.index_map[tensor_index]),
            NestedEinsum::Node { args, eins } => {
                let embedded_args: Vec<NestedEinsum<L>> =
                    args.into_iter().map(|a| self.embed_recursive(a)).collect();
                NestedEinsum::node(embedded_args, eins)
            }
        }
    }

    /// Builds the subtree for original tensor `original_idx` with every vector that was
    /// merged into it contracted back in. This includes vectors merged in indirectly.
    fn rebuild_leaf(&self, original_idx: usize) -> NestedEinsum<L> {
        let target_indices = &self.original_code.ixs[original_idx];
        let mut current = NestedEinsum::leaf(original_idx);
        for merge in self.merges.iter().filter(|m| m.target_index == original_idx) {
            // A merged vector can itself be the target of earlier merges (chains such as
            // v0 -> v1 -> t). Recursing makes sure those vectors are not dropped.
            let vector_tree = self.rebuild_leaf(merge.vector_index);
            let vector_indices = &self.original_code.ixs[merge.vector_index];
            let eins = EinCode::new(
                vec![target_indices.clone(), vector_indices.clone()],
                target_indices.clone(),
            );
            current = NestedEinsum::node(vec![current, vector_tree], eins);
        }
        current
    }
}
/// Folds rank-1 tensors (vectors) into other tensors that carry the same label.
///
/// Tensors are scanned in index order. A vector is folded into the lowest-indexed other
/// tensor that contains its label and has not already been folded away. That step is
/// recorded as a [`MergeOperation`], and the vector is dropped from the network. A vector
/// with no such partner stays in place.
///
/// The returned [`NetworkSimplifier`] holds the reduced code (with the original output labels)
/// and the bookkeeping that [`NetworkSimplifier::embed`] needs.
pub fn merge_vectors<L: Label>(code: &EinCode<L>) -> NetworkSimplifier<L> {
    let tensor_count = code.ixs.len();
    let mut absorbed: HashSet<usize> = HashSet::new();
    let mut merges: Vec<MergeOperation<L>> = Vec::new();

    for (i, labels) in code.ixs.iter().enumerate() {
        // Only rank-1 tensors are candidates for absorption.
        let label = match labels.as_slice() {
            [only] => only,
            _ => continue,
        };
        let partner = (0..tensor_count)
            .find(|&j| j != i && !absorbed.contains(&j) && code.ixs[j].contains(label));
        if let Some(target_index) = partner {
            merges.push(MergeOperation {
                vector_index: i,
                target_index,
                shared_index: label.clone(),
            });
            absorbed.insert(i);
        }
    }

    let index_map: Vec<usize> = (0..tensor_count).filter(|j| !absorbed.contains(j)).collect();
    let simplified_ixs: Vec<Vec<L>> = index_map.iter().map(|&j| code.ixs[j].clone()).collect();

    NetworkSimplifier {
        merges,
        simplified_code: EinCode::new(simplified_ixs, code.iy.clone()),
        original_code: code.clone(),
        index_map,
    }
}
/// Runs `optimizer` on the vector-merged form of `code`, then embeds the result back.
///
/// If [`merge_vectors`] finds nothing to merge, `code` is optimized directly. Returns `None`
/// whenever the optimizer itself returns `None`.
pub fn optimize_simplified<L: Label, O: crate::CodeOptimizer>(
    code: &EinCode<L>,
    size_dict: &std::collections::HashMap<L, usize>,
    optimizer: &O,
) -> Option<NestedEinsum<L>> {
    let simplifier = merge_vectors(code);
    if simplifier.is_simplified() {
        optimizer
            .optimize(simplifier.simplified_code(), size_dict)
            .map(|tree| simplifier.embed(tree))
    } else {
        optimizer.optimize(code, size_dict)
    }
}
#[cfg(test)]
mod tests {
// Unit tests for vector merging and for embedding optimized trees back into the original code.
// NOTE(review): no test here covers chained merges, where a vector is merged into another
// vector that is itself merged later (e.g. ixs [[j], [j], [i, j]]).
use super::*;
use crate::eincode::uniform_size_dict;
use crate::greedy::{optimize_greedy, GreedyMethod};
// The single vector 'j' is absorbed, leaving two matrices.
#[test]
fn test_merge_vectors_simple() {
let code = EinCode::new(
vec![vec!['i', 'j'], vec!['j'], vec!['j', 'k']],
vec!['i', 'k'],
);
let simplifier = merge_vectors(&code);
assert!(simplifier.is_simplified());
assert_eq!(simplifier.num_merges(), 1);
assert_eq!(simplifier.simplified_code().num_tensors(), 2);
let merge = &simplifier.merges()[0];
assert_eq!(merge.vector_index, 1);
// merge_vectors picks the lowest-indexed partner (0), but either carrier of 'j' is valid.
assert!(merge.target_index == 0 || merge.target_index == 2);
assert_eq!(merge.shared_index, 'j');
}
// Vector 'j' folds into tensor 0 and vector 'k' folds into tensor 3.
#[test]
fn test_merge_vectors_multiple_vectors() {
let code = EinCode::new(
vec![vec!['i', 'j'], vec!['j'], vec!['k'], vec!['j', 'k']],
vec!['i'],
);
let simplifier = merge_vectors(&code);
assert_eq!(simplifier.num_merges(), 2);
assert_eq!(simplifier.simplified_code().num_tensors(), 2);
}
// With no rank-1 tensors the code passes through untouched.
#[test]
fn test_merge_vectors_no_vectors() {
let code = EinCode::new(vec![vec!['i', 'j'], vec!['j', 'k']], vec!['i', 'k']);
let simplifier = merge_vectors(&code);
assert!(!simplifier.is_simplified());
assert_eq!(simplifier.num_merges(), 0);
assert_eq!(simplifier.simplified_code().num_tensors(), 2);
}
// 'k' appears on no other tensor, so the vector has no partner and is kept.
#[test]
fn test_merge_vectors_disconnected_vector() {
let code = EinCode::new(
vec![vec!['i', 'j'], vec!['k'], vec!['j', 'l']],
vec!['i', 'l', 'k'],
);
let simplifier = merge_vectors(&code);
assert!(!simplifier.is_simplified());
}
// Embedding restores all three original leaves as a binary tree.
#[test]
fn test_embed_simplifier_simple() {
let code = EinCode::new(
vec![vec!['i', 'j'], vec!['j'], vec!['j', 'k']],
vec!['i', 'k'],
);
let sizes = uniform_size_dict(&code, 4);
let simplifier = merge_vectors(&code);
let simplified_tree = optimize_greedy(
simplifier.simplified_code(),
&sizes,
&GreedyMethod::default(),
)
.unwrap();
let embedded = simplifier.embed(simplified_tree);
assert_eq!(embedded.leaf_count(), 3);
assert!(embedded.is_binary());
}
// Two vectors ('b' into tensor 0, 'c' into tensor 2) must both reappear after embedding.
// NOTE(review): despite the name, only the leaf count is checked, not numerical correctness.
#[test]
fn test_embed_preserves_correctness() {
let code = EinCode::new(
vec![vec!['a', 'b'], vec!['b'], vec!['b', 'c'], vec!['c']],
vec!['a'],
);
let sizes = uniform_size_dict(&code, 2);
let simplifier = merge_vectors(&code);
assert_eq!(simplifier.num_merges(), 2);
let simplified_tree = optimize_greedy(
simplifier.simplified_code(),
&sizes,
&GreedyMethod::default(),
)
.unwrap();
let embedded = simplifier.embed(simplified_tree);
assert_eq!(embedded.leaf_count(), 4); }
// End-to-end: merge, optimize greedily, embed; every original tensor is a leaf.
#[test]
fn test_optimize_simplified() {
let code = EinCode::new(
vec![vec!['i', 'j'], vec!['j'], vec!['j', 'k'], vec!['k']],
vec!['i'],
);
let sizes = uniform_size_dict(&code, 4);
let result = optimize_simplified(&code, &sizes, &GreedyMethod::default());
assert!(result.is_some());
let tree = result.unwrap();
assert_eq!(tree.leaf_count(), 4);
}
// Without vectors, optimize_simplified falls back to optimizing the original code.
#[test]
fn test_optimize_simplified_no_vectors() {
let code = EinCode::new(vec![vec!['i', 'j'], vec!['j', 'k']], vec!['i', 'k']);
let sizes = uniform_size_dict(&code, 4);
let result = optimize_simplified(&code, &sizes, &GreedyMethod::default());
assert!(result.is_some());
let tree = result.unwrap();
assert_eq!(tree.leaf_count(), 2);
}
// Same pipeline driven by the simulated-annealing optimizer.
#[test]
fn test_simplification_with_treesa() {
use crate::treesa::TreeSA;
let code = EinCode::new(
vec![vec!['i', 'j'], vec!['j'], vec!['j', 'k']],
vec!['i', 'k'],
);
let sizes = uniform_size_dict(&code, 4);
let result = optimize_simplified(&code, &sizes, &TreeSA::fast());
assert!(result.is_some());
let tree = result.unwrap();
assert_eq!(tree.leaf_count(), 3);
}
// Accessors expose the original code and the one-tensor simplified code.
#[test]
fn test_network_simplifier_accessors() {
let code = EinCode::new(vec![vec!['i', 'j'], vec!['j']], vec!['i']);
let simplifier = merge_vectors(&code);
assert_eq!(simplifier.original_code(), &code);
assert_eq!(simplifier.simplified_code().num_tensors(), 1);
assert!(simplifier.is_simplified());
assert_eq!(simplifier.num_merges(), 1);
}
// The recorded merge names the vector, its target, and the shared label exactly.
#[test]
fn test_merge_operation_fields() {
let code = EinCode::new(vec![vec!['i', 'j'], vec!['j']], vec!['i']);
let simplifier = merge_vectors(&code);
let merge = &simplifier.merges()[0];
assert_eq!(merge.vector_index, 1);
assert_eq!(merge.target_index, 0);
assert_eq!(merge.shared_index, 'j');
}
// embed is a pass-through when nothing was merged.
#[test]
fn test_embed_no_simplification() {
let code = EinCode::new(vec![vec!['i', 'j'], vec!['j', 'k']], vec!['i', 'k']);
let sizes = uniform_size_dict(&code, 4);
let simplifier = merge_vectors(&code);
let tree = optimize_greedy(&code, &sizes, &GreedyMethod::default()).unwrap();
let embedded = simplifier.embed(tree.clone());
assert_eq!(embedded.leaf_count(), tree.leaf_count());
}
// Three independent vectors ('b', 'c', 'e') each fold into a different carrier tensor.
#[test]
fn test_simplification_complex_network() {
let code = EinCode::new(
vec![
vec!['a', 'b'], vec!['b'], vec!['b', 'c', 'd'], vec!['c'], vec!['d', 'e'], vec!['e'], ],
vec!['a'],
);
let sizes = uniform_size_dict(&code, 2);
let simplifier = merge_vectors(&code);
assert_eq!(simplifier.num_merges(), 3);
let result = optimize_simplified(&code, &sizes, &GreedyMethod::default());
assert!(result.is_some());
let tree = result.unwrap();
assert_eq!(tree.leaf_count(), 6); }
}