use ndarray::{ArrayD, IxDyn};
use rand::Rng;
use std::collections::{HashMap, HashSet};
/// A brute-force einsum executor used to check contraction trees in tests.
///
/// Tensors are stored by a caller-chosen integer index; [`NaiveContractor::contract`]
/// combines two stored tensors by enumerating every assignment of their labels.
#[derive(Debug, Default, Clone)]
pub struct NaiveContractor {
    /// Dense `f64` tensors keyed by the index callers use to refer to them.
    tensors: HashMap<usize, ArrayD<f64>>,
}
impl NaiveContractor {
pub fn new() -> Self {
Self::default()
}
pub fn add_tensor(&mut self, idx: usize, shape: Vec<usize>) {
let mut rng = rand::rng();
let size: usize = shape.iter().product();
let data: Vec<f64> = (0..size).map(|_| rng.random()).collect();
let tensor = ArrayD::from_shape_vec(IxDyn(&shape), data).unwrap();
self.tensors.insert(idx, tensor);
}
pub fn contract(
&mut self,
left_idx: usize,
right_idx: usize,
left_labels: &[usize],
right_labels: &[usize],
output_labels: &[usize],
) -> usize {
let left = self.tensors[&left_idx].clone();
let right = self.tensors[&right_idx].clone();
let result = self.einsum_contract(&left, &right, left_labels, right_labels, output_labels);
let result_idx = left_idx.min(right_idx);
self.tensors.insert(result_idx, result);
self.tensors.remove(&left_idx.max(right_idx));
result_idx
}
pub fn get_tensor(&self, idx: usize) -> Option<&ArrayD<f64>> {
self.tensors.get(&idx)
}
pub fn get_shape(&self, idx: usize) -> Option<Vec<usize>> {
self.tensors.get(&idx).map(|t| t.shape().to_vec())
}
fn einsum_contract(
&self,
left: &ArrayD<f64>,
right: &ArrayD<f64>,
left_labels: &[usize],
right_labels: &[usize],
output_labels: &[usize],
) -> ArrayD<f64> {
let mut label_sizes: HashMap<usize, usize> = HashMap::new();
for (i, &label) in left_labels.iter().enumerate() {
let size = left.shape()[i];
if let Some(&existing) = label_sizes.get(&label) {
assert_eq!(existing, size, "Label {} has inconsistent sizes", label);
} else {
label_sizes.insert(label, size);
}
}
for (i, &label) in right_labels.iter().enumerate() {
let size = right.shape()[i];
if let Some(&existing) = label_sizes.get(&label) {
assert_eq!(existing, size, "Label {} has inconsistent sizes", label);
} else {
label_sizes.insert(label, size);
}
}
let output_shape: Vec<usize> = output_labels
.iter()
.map(|&label| *label_sizes.get(&label).unwrap_or(&1))
.collect();
let output_size: usize = if output_shape.is_empty() {
1
} else {
output_shape.iter().product()
};
let mut result_data = vec![0.0; output_size];
let mut all_labels: HashSet<usize> = HashSet::new();
all_labels.extend(left_labels.iter().copied());
all_labels.extend(right_labels.iter().copied());
let all_labels: Vec<usize> = all_labels.into_iter().collect();
let total_iterations: usize = all_labels
.iter()
.map(|&label| *label_sizes.get(&label).unwrap_or(&1))
.product();
for iter_idx in 0..total_iterations {
let mut label_values: HashMap<usize, usize> = HashMap::new();
let mut remaining = iter_idx;
for &label in all_labels.iter().rev() {
let size = *label_sizes.get(&label).unwrap_or(&1);
label_values.insert(label, remaining % size);
remaining /= size;
}
let left_indices: Vec<usize> = left_labels
.iter()
.map(|&label| *label_values.get(&label).unwrap_or(&0))
.collect();
let right_indices: Vec<usize> = right_labels
.iter()
.map(|&label| *label_values.get(&label).unwrap_or(&0))
.collect();
let output_indices: Vec<usize> = output_labels
.iter()
.map(|&label| *label_values.get(&label).unwrap_or(&0))
.collect();
let left_val = if left.shape().is_empty() {
1.0
} else {
left[&*left_indices]
};
let right_val = if right.shape().is_empty() {
1.0
} else {
right[&*right_indices]
};
let mut out_idx = 0;
let mut out_stride = 1;
for i in (0..output_indices.len()).rev() {
out_idx += output_indices[i] * out_stride;
out_stride *= output_shape[i];
}
result_data[out_idx] += left_val * right_val;
}
if output_shape.is_empty() {
ArrayD::from_shape_vec(IxDyn(&[]), vec![result_data[0]]).unwrap()
} else {
ArrayD::from_shape_vec(IxDyn(&output_shape), result_data).unwrap()
}
}
}
/// Generates a random einsum specification `(input label lists, output labels)` for fuzz tests.
///
/// Each of the `num_tensors` tensors receives between 1 and 4 labels drawn from
/// `1..=num_indices`:
///
/// * When `allow_duplicates` is `false`, the labels within one tensor are pairwise distinct.
///   The rank is then capped at `num_indices`.
/// * When `allow_duplicates` is `true`, repeated draws are allowed. With probability 0.3 one of
///   the tensor's labels is additionally appended a second time.
///
/// The output has 0 to 3 distinct labels, usually chosen among the input labels. When
/// `allow_output_only_indices` is set, each output slot has a 0.2 chance of being a fresh label
/// greater than `num_indices` that appears in no input.
///
/// # Panics
/// Panics if `num_tensors > 0` while `num_indices == 0`, since no label could be drawn.
pub fn generate_random_eincode(
    num_tensors: usize,
    num_indices: usize,
    allow_duplicates: bool,
    allow_output_only_indices: bool,
) -> (Vec<Vec<usize>>, Vec<usize>) {
    assert!(
        num_tensors == 0 || num_indices > 0,
        "generate_random_eincode: num_indices must be positive when num_tensors > 0"
    );
    let mut rng = rand::rng();
    let mut ixs = Vec::with_capacity(num_tensors);
    let mut all_indices = HashSet::new();
    for _ in 0..num_tensors {
        // Distinct labels cannot exceed the number of available labels.
        let max_rank: usize = if allow_duplicates {
            4
        } else {
            num_indices.min(4)
        };
        let tensor_rank = rng.random_range(1..=max_rank);
        let mut tensor_indices = Vec::with_capacity(tensor_rank + 1);
        while tensor_indices.len() < tensor_rank {
            let idx = rng.random_range(1..=num_indices);
            if allow_duplicates || !tensor_indices.contains(&idx) {
                tensor_indices.push(idx);
            }
        }
        all_indices.extend(tensor_indices.iter().copied());
        if allow_duplicates && rng.random_bool(0.3) {
            let dup_idx = tensor_indices[rng.random_range(0..tensor_indices.len())];
            tensor_indices.push(dup_idx);
        }
        ixs.push(tensor_indices);
    }

    let mut output = Vec::new();
    let mut used_output_indices = HashSet::new();
    let num_output = if all_indices.is_empty() {
        0
    } else {
        rng.random_range(0..=3)
    };
    for _ in 0..num_output {
        let idx = if allow_output_only_indices && rng.random_bool(0.2) {
            // Above every input label; `output.len()` grows on each push, so these never collide.
            num_indices + 1 + output.len()
        } else {
            let available: Vec<usize> = all_indices
                .difference(&used_output_indices)
                .copied()
                .collect();
            if available.is_empty() {
                continue;
            }
            available[rng.random_range(0..available.len())]
        };
        if used_output_indices.insert(idx) {
            output.push(idx);
        }
    }
    (ixs, output)
}
/// Returns the edges of the C60 (buckminsterfullerene) graph, with vertices labelled `1..=60`.
///
/// The graph is built as a truncated icosahedron: every directed edge `v -> u` of an
/// icosahedron becomes one vertex. The result has 60 vertices and 90 edges. It is 3-regular,
/// its faces are 12 pentagons and 20 hexagons, and its girth is 5. Each pair `(a, b)` appears
/// once and contains no self-loop.
pub fn generate_fullerene_edges() -> Vec<(usize, usize)> {
    const ICO: usize = 12;

    // Icosahedron: apex 0, upper ring 1..=5, lower ring 6..=10, bottom 11.
    // Each vertex has 5 neighbours, and each vertex's neighbours form a 5-cycle.
    let mut adjacent = [[false; ICO]; ICO];
    {
        let mut link = |a: usize, b: usize| {
            adjacent[a][b] = true;
            adjacent[b][a] = true;
        };
        for k in 0..5 {
            let (upper, upper_next) = (1 + k, 1 + (k + 1) % 5);
            let (lower, lower_next) = (6 + k, 6 + (k + 1) % 5);
            link(0, upper);
            link(upper, upper_next);
            link(upper, lower);
            link(upper, lower_next);
            link(lower, lower_next);
            link(lower, 11);
        }
    }

    // Number the 60 directed icosahedron edges; each one is a carbon atom (0-based id).
    let mut atom = [[usize::MAX; ICO]; ICO];
    let mut next_atom = 0;
    for v in 0..ICO {
        for u in 0..ICO {
            if adjacent[v][u] {
                atom[v][u] = next_atom;
                next_atom += 1;
            }
        }
    }
    debug_assert_eq!(next_atom, 60);

    let mut edges = Vec::with_capacity(90);
    for v in 0..ICO {
        for u in 0..ICO {
            if !adjacent[v][u] {
                continue;
            }
            // The bond left over from the icosahedron edge {v, u}; it is shared by two hexagons.
            if v < u {
                edges.push((atom[v][u], atom[u][v]));
            }
            // Bonds of the pentagon that replaces vertex v. Two neighbours u, w of v are
            // consecutive around v exactly when they are adjacent themselves.
            for w in (u + 1)..ICO {
                if adjacent[v][w] && adjacent[u][w] {
                    edges.push((atom[v][u], atom[v][w]));
                }
            }
        }
    }
    edges.into_iter().map(|(a, b)| (a + 1, b + 1)).collect()
}
/// Returns the 69 edges of the Tutte graph, with vertices labelled `1..=46`.
///
/// The Tutte graph is 3-regular, planar and non-Hamiltonian. Vertex 1 is the centre. It is
/// joined to three 15-vertex Tutte fragments, which are linked to each other in a ring. The
/// adjacency below uses the 0-based numbering of NetworkX's `tutte_graph` and is shifted by
/// one on return. Each edge appears once as `(smaller, larger)`.
pub fn generate_tutte_edges() -> Vec<(usize, usize)> {
    const TUTTE_EDGES: [(usize, usize); 69] = [
        (0, 1), (0, 2), (0, 3), (1, 4), (1, 26), (2, 10), (2, 11), (3, 18), (3, 19), (4, 5),
        (4, 33), (5, 6), (5, 29), (6, 7), (6, 27), (7, 8), (7, 14), (8, 9), (8, 38), (9, 10),
        (9, 37), (10, 39), (11, 12), (11, 39), (12, 13), (12, 35), (13, 14), (13, 15), (14, 34),
        (15, 16), (15, 22), (16, 17), (16, 44), (17, 18), (17, 43), (18, 45), (19, 20), (19, 45),
        (20, 21), (20, 41), (21, 22), (21, 23), (22, 40), (23, 24), (23, 27), (24, 25), (24, 32),
        (25, 26), (25, 31), (26, 33), (27, 28), (28, 29), (28, 32), (29, 30), (30, 31), (30, 33),
        (31, 32), (34, 35), (34, 38), (35, 36), (36, 37), (36, 39), (37, 38), (40, 41), (40, 44),
        (41, 42), (42, 43), (42, 45), (43, 44),
    ];
    TUTTE_EDGES.iter().map(|&(a, b)| (a + 1, b + 1)).collect()
}
/// Returns the edges of the cycle graph on `n` vertices labelled `1..=n`.
///
/// Edge `v` joins vertex `v` to `v + 1`, and the final edge wraps from `n` back to 1. For
/// `n == 1` this yields the self-loop `(1, 1)`. For `n == 2` it yields the parallel pair
/// `(1, 2)`, `(2, 1)`. For `n == 0` the result is empty.
pub fn generate_ring_edges(n: usize) -> Vec<(usize, usize)> {
    (1..=n)
        .map(|vertex| {
            let successor = if vertex == n { 1 } else { vertex + 1 };
            (vertex, successor)
        })
        .collect()
}
/// Returns the edges of the path graph on `n` vertices labelled `1..=n`.
///
/// Consecutive vertices are joined as `(v - 1, v)`, giving `n - 1` edges. For `n <= 1` the
/// result is empty.
pub fn generate_chain_edges(n: usize) -> Vec<(usize, usize)> {
    let mut edges = Vec::with_capacity(n.saturating_sub(1));
    for vertex in 2..=n {
        edges.push((vertex - 1, vertex));
    }
    edges
}
/// Returns the 15 edges of the Petersen graph, with vertices labelled `1..=10`.
///
/// The edges come in a fixed order:
/// 1. The outer 5-cycle `1..=5`.
/// 2. The inner pentagram on `6..=10`, where each inner vertex joins the one two steps ahead.
/// 3. The five spokes `(v, v + 5)`.
pub fn generate_petersen_edges() -> Vec<(usize, usize)> {
    let outer_cycle = (0..5).map(|k| (k + 1, (k + 1) % 5 + 1));
    let inner_star = (0..5).map(|k| (k + 6, (k + 2) % 5 + 6));
    let spokes = (1..=5).map(|v| (v, v + 5));
    outer_cycle.chain(inner_star).chain(spokes).collect()
}
/// Executes the contraction tree `nested` on `contractor` and returns the index of the tensor
/// that holds the final result.
///
/// `label_map` translates every label used in `nested` into the integer labels the
/// contractor works with. Every leaf's `tensor_index` must already be stored in `contractor`.
pub fn execute_nested<L: crate::Label>(
    nested: &crate::NestedEinsum<L>,
    contractor: &mut NaiveContractor,
    label_map: &HashMap<L, usize>,
) -> usize {
    let (result_idx, _final_labels) = execute_nested_impl(nested, contractor, label_map);
    result_idx
}
/// Recursively executes `nested` on `contractor` and returns `(tensor index, output labels)`.
///
/// A leaf refers to a tensor already stored in `contractor` and reports no labels.
///
/// For a node, every child is evaluated first. The operands are then folded left to right with
/// pairwise [`NaiveContractor::contract`] calls. Each intermediate result keeps only the labels
/// still needed by a later operand or by the node's output `eins.iy`. The last pairwise
/// contraction produces exactly `eins.iy`.
///
/// NOTE(review): a node with a single child only relabels it. No trace, sum or permutation
/// from `eins.ixs[0]` to `eins.iy` is applied to the stored tensor.
///
/// # Panics
/// Panics if a node has no children, or if a label is missing from `label_map`.
fn execute_nested_impl<L: crate::Label>(
    nested: &crate::NestedEinsum<L>,
    contractor: &mut NaiveContractor,
    label_map: &HashMap<L, usize>,
) -> (usize, Vec<usize>) {
    use crate::NestedEinsum;

    // Translates crate-level labels into the contractor's integer labels.
    fn to_ids<'a, T>(labels: impl IntoIterator<Item = &'a T>, label_map: &HashMap<T, usize>) -> Vec<usize>
    where
        T: crate::Label + 'a,
    {
        labels
            .into_iter()
            .map(|l| *label_map.get(l).expect("Label should be in map"))
            .collect()
    }

    match nested {
        NestedEinsum::Leaf { tensor_index } => (*tensor_index, Vec::new()),
        NestedEinsum::Node { args, eins } => {
            // Children are executed first so their results are stored before this node uses them.
            let operands: Vec<usize> = args
                .iter()
                .map(|child| execute_nested_impl(child, contractor, label_map).0)
                .collect();
            if operands.is_empty() {
                panic!("execute_nested: Node with no children");
            }
            let final_output = to_ids(eins.iy.iter(), label_map);
            if operands.len() == 1 {
                return (operands[0], final_output);
            }

            let operand_labels: Vec<Vec<usize>> = (0..operands.len())
                .map(|j| to_ids(eins.ixs[j].iter(), label_map))
                .collect();
            let last = operands.len() - 1;
            let mut current_idx = operands[0];
            let mut current_labels = operand_labels[0].clone();
            for (i, (&right_idx, right_labels)) in
                operands.iter().zip(&operand_labels).enumerate().skip(1)
            {
                let output_labels = if i == last {
                    final_output.clone()
                } else {
                    // A label survives this step only if a later operand or the node's output
                    // still needs it; everything else is summed away now.
                    let still_needed: HashSet<usize> = operand_labels[i + 1..]
                        .iter()
                        .flatten()
                        .chain(&final_output)
                        .copied()
                        .collect();
                    let mut kept = Vec::new();
                    for &l in current_labels.iter().chain(right_labels) {
                        if still_needed.contains(&l) && !kept.contains(&l) {
                            kept.push(l);
                        }
                    }
                    kept
                };
                current_idx = contractor.contract(
                    current_idx,
                    right_idx,
                    &current_labels,
                    right_labels,
                    &output_labels,
                );
                current_labels = output_labels;
            }
            (current_idx, current_labels)
        }
    }
}
/// Reports whether `a` and `b` have the same shape and agree elementwise within tolerance.
///
/// A pair of elements `(x, y)` passes when `|x - y| <= atol + rtol * |y|`. The relative part
/// is measured against `b`'s element, so the check is not symmetric in its arguments. Any NaN
/// element makes the comparison fail.
pub fn tensors_approx_equal(a: &ArrayD<f64>, b: &ArrayD<f64>, rtol: f64, atol: f64) -> bool {
    let same_shape = a.shape() == b.shape();
    same_shape
        && a.iter().zip(b.iter()).all(|(&actual, &reference)| {
            let allowed = atol + rtol * reference.abs();
            (actual - reference).abs() <= allowed
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance when comparing brute-force sums with hand-computed references.
    const EPS: f64 = 1e-12;

    #[test]
    fn test_naive_contractor_basic() {
        let mut contractor = NaiveContractor::new();
        contractor.add_tensor(0, vec![2, 3]);
        contractor.add_tensor(1, vec![3, 2]);
        let a = contractor.get_tensor(0).unwrap().clone();
        let b = contractor.get_tensor(1).unwrap().clone();
        // ij,jk->ik is an ordinary matrix product.
        let result_idx = contractor.contract(0, 1, &[1, 2], &[2, 3], &[1, 3]);
        let result_shape = contractor.get_shape(result_idx).unwrap();
        assert_eq!(result_shape, vec![2, 2], "Result should be 2x2");
        let result = contractor.get_tensor(result_idx).unwrap();
        for i in 0..2 {
            for k in 0..2 {
                let expected: f64 = (0..3).map(|j| a[[i, j]] * b[[j, k]]).sum();
                assert!((result[[i, k]] - expected).abs() < EPS);
            }
        }
    }

    #[test]
    fn test_naive_contractor_scalar() {
        let mut contractor = NaiveContractor::new();
        contractor.add_tensor(0, vec![2, 2]);
        contractor.add_tensor(1, vec![2, 2]);
        let a = contractor.get_tensor(0).unwrap().clone();
        let b = contractor.get_tensor(1).unwrap().clone();
        // ij,ij-> is the Frobenius inner product.
        let result_idx = contractor.contract(0, 1, &[1, 2], &[1, 2], &[]);
        let result_tensor = contractor.get_tensor(result_idx).unwrap();
        assert_eq!(result_tensor.ndim(), 0, "Result should be scalar");
        let expected: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        assert!((result_tensor.sum() - expected).abs() < EPS);
    }

    #[test]
    fn test_naive_contractor_outer_product() {
        let mut contractor = NaiveContractor::new();
        contractor.add_tensor(0, vec![2]);
        contractor.add_tensor(1, vec![3]);
        let a = contractor.get_tensor(0).unwrap().clone();
        let b = contractor.get_tensor(1).unwrap().clone();
        let result_idx = contractor.contract(0, 1, &[1], &[2], &[1, 2]);
        let result_shape = contractor.get_shape(result_idx).unwrap();
        assert_eq!(result_shape, vec![2, 3], "Result should be 2x3");
        let result = contractor.get_tensor(result_idx).unwrap();
        for i in 0..2 {
            for j in 0..3 {
                assert!((result[[i, j]] - a[[i]] * b[[j]]).abs() < EPS);
            }
        }
    }

    #[test]
    fn test_generate_random_eincode_basic() {
        let (ixs, output) = generate_random_eincode(3, 5, false, false);
        assert_eq!(ixs.len(), 3, "Should generate 3 tensors");
        for tensor_indices in &ixs {
            for &idx in tensor_indices {
                assert!((1..=5).contains(&idx), "Index should be in range 1-5");
            }
        }
        let mut seen = HashSet::new();
        for &idx in &output {
            assert!(seen.insert(idx), "Output should not have duplicates");
        }
    }

    #[test]
    fn test_generate_random_eincode_with_duplicates() {
        let (ixs, _output) = generate_random_eincode(5, 8, true, false);
        assert_eq!(ixs.len(), 5, "Should generate 5 tensors");
    }

    #[test]
    fn test_generate_random_eincode_with_broadcast() {
        let (_ixs, output) = generate_random_eincode(3, 8, false, true);
        assert!(output.len() <= 3, "Output should have at most 3 indices");
    }

    #[test]
    fn test_generate_ring_edges() {
        let edges = generate_ring_edges(5);
        assert_eq!(edges.len(), 5, "Ring with 5 vertices should have 5 edges");
        assert_eq!(edges, vec![(1, 2), (2, 3), (3, 4), (4, 5), (5, 1)]);
    }

    #[test]
    fn test_generate_fullerene_edges() {
        let edges = generate_fullerene_edges();
        assert!(!edges.is_empty(), "Fullerene graph should have edges");
        for &(a, b) in &edges {
            assert!(a >= 1, "Vertices should be 1-indexed");
            assert!(b >= 1, "Vertices should be 1-indexed");
            assert_ne!(a, b, "No self-loops");
        }
    }

    #[test]
    fn test_generate_tutte_edges() {
        let edges = generate_tutte_edges();
        assert!(!edges.is_empty(), "Tutte graph should have edges");
        for &(a, b) in &edges {
            assert!(a >= 1, "Vertices should be 1-indexed");
            assert!(b >= 1, "Vertices should be 1-indexed");
            assert_ne!(a, b, "No self-loops");
        }
    }

    #[test]
    fn test_naive_contractor_default() {
        let contractor = NaiveContractor::default();
        assert_eq!(contractor.tensors.len(), 0, "Default should be empty");
    }

    #[test]
    fn test_generate_chain_edges() {
        let edges = generate_chain_edges(5);
        assert_eq!(edges.len(), 4, "Chain with 5 vertices should have 4 edges");
        assert_eq!(edges, vec![(1, 2), (2, 3), (3, 4), (4, 5)]);
    }

    #[test]
    fn test_generate_chain_edges_small() {
        let edges = generate_chain_edges(2);
        assert_eq!(edges.len(), 1, "Chain with 2 vertices should have 1 edge");
        assert_eq!(edges[0], (1, 2));
    }

    #[test]
    fn test_generate_petersen_edges() {
        let edges = generate_petersen_edges();
        assert_eq!(edges.len(), 15, "Petersen graph should have 15 edges");
        for &(a, b) in &edges {
            assert!((1..=10).contains(&a), "Vertices should be 1-10");
            assert!((1..=10).contains(&b), "Vertices should be 1-10");
            assert_ne!(a, b, "No self-loops");
        }
        let mut degree = [0; 11];
        for &(a, b) in &edges {
            degree[a] += 1;
            degree[b] += 1;
        }
        for (v, &d) in degree.iter().enumerate().skip(1) {
            assert_eq!(
                d, 3,
                "Petersen graph is 3-regular, vertex {} has degree {}",
                v, d
            );
        }
    }

    #[test]
    fn test_tensors_approx_equal() {
        let a = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(tensors_approx_equal(&a, &b, 1e-5, 1e-8));
        let c = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1.0, 2.0, 3.0, 4.0001]).unwrap();
        assert!(tensors_approx_equal(&a, &c, 1e-3, 1e-8));
        assert!(!tensors_approx_equal(&a, &c, 1e-6, 1e-10));
        let d = ArrayD::from_shape_vec(IxDyn(&[4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(!tensors_approx_equal(&a, &d, 1e-5, 1e-8));
    }

    #[test]
    fn test_execute_nested_single_contraction() {
        use crate::greedy::optimize_greedy;
        use crate::{EinCode, GreedyMethod};
        let code = EinCode::new(vec![vec![1usize, 2], vec![2, 3]], vec![1, 3]);
        let sizes: HashMap<usize, usize> = [(1, 2), (2, 3), (3, 2)].into();
        let tree = optimize_greedy(&code, &sizes, &GreedyMethod::default()).unwrap();
        let mut contractor = NaiveContractor::new();
        contractor.add_tensor(0, vec![2, 3]);
        contractor.add_tensor(1, vec![3, 2]);
        let label_map: HashMap<usize, usize> = [(1, 1), (2, 2), (3, 3)].into();
        let result_idx = execute_nested(&tree, &mut contractor, &label_map);
        let result_shape = contractor.get_shape(result_idx).unwrap();
        assert_eq!(result_shape, vec![2, 2]);
    }

    #[test]
    fn test_execute_nested_chain() {
        use crate::greedy::optimize_greedy;
        use crate::{EinCode, GreedyMethod};
        let code = EinCode::new(vec![vec![1usize, 2], vec![2, 3], vec![3, 4]], vec![1, 4]);
        let sizes: HashMap<usize, usize> = [(1, 2), (2, 3), (3, 4), (4, 2)].into();
        let tree = optimize_greedy(&code, &sizes, &GreedyMethod::default()).unwrap();
        let mut contractor = NaiveContractor::new();
        contractor.add_tensor(0, vec![2, 3]);
        contractor.add_tensor(1, vec![3, 4]);
        contractor.add_tensor(2, vec![4, 2]);
        let label_map: HashMap<usize, usize> = [(1, 1), (2, 2), (3, 3), (4, 4)].into();
        let result_idx = execute_nested(&tree, &mut contractor, &label_map);
        let result_shape = contractor.get_shape(result_idx).unwrap();
        assert_eq!(result_shape, vec![2, 2]);
    }

    #[test]
    fn test_execute_nested_scalar_result() {
        use crate::greedy::optimize_greedy;
        use crate::{EinCode, GreedyMethod};
        let code = EinCode::new(vec![vec![1usize, 2], vec![2, 1]], vec![]);
        let sizes: HashMap<usize, usize> = [(1, 2), (2, 3)].into();
        let tree = optimize_greedy(&code, &sizes, &GreedyMethod::default()).unwrap();
        let mut contractor = NaiveContractor::new();
        contractor.add_tensor(0, vec![2, 3]);
        contractor.add_tensor(1, vec![3, 2]);
        let label_map: HashMap<usize, usize> = [(1, 1), (2, 2)].into();
        let result_idx = execute_nested(&tree, &mut contractor, &label_map);
        let result_tensor = contractor.get_tensor(result_idx).unwrap();
        assert_eq!(result_tensor.ndim(), 0);
    }

    #[test]
    fn test_naive_contractor_get_tensor() {
        let mut contractor = NaiveContractor::new();
        contractor.add_tensor(0, vec![2, 3]);
        let tensor = contractor.get_tensor(0);
        assert!(tensor.is_some());
        assert_eq!(tensor.unwrap().shape(), &[2, 3]);
        assert!(contractor.get_tensor(99).is_none());
    }

    #[test]
    fn test_naive_contractor_get_shape() {
        let mut contractor = NaiveContractor::new();
        contractor.add_tensor(0, vec![2, 3, 4]);
        assert_eq!(contractor.get_shape(0), Some(vec![2, 3, 4]));
        assert!(contractor.get_shape(99).is_none());
    }

    /// Repeated generations always satisfy the structural invariants: tensor count, rank
    /// bounds (at most 4 labels, plus 1 duplicate when allowed) and duplicate-free output.
    #[test]
    fn test_generate_random_eincode_shape_invariants() {
        for _ in 0..5 {
            let (ixs, output) = generate_random_eincode(4, 6, false, false);
            assert_eq!(ixs.len(), 4);
            for ix in &ixs {
                assert!(!ix.is_empty());
                assert!(ix.len() <= 5);
            }
            let output_set: HashSet<_> = output.iter().collect();
            assert_eq!(output_set.len(), output.len());
        }
    }
}