use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
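/// Einstein summation with contraction-order optimization.
///
/// Validates the equation against the operands, chooses a pairwise
/// contraction order (an exact subset DP for small operand counts, a
/// greedy heuristic otherwise), and executes it as a chain of pairwise
/// `einsum` calls. For example, `einsum_optimized("ij,jk,kl->il", &[&a, &b, &c])`
/// contracts a three-matrix chain in whichever pairwise order the cost
/// model estimates to be cheapest.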
pub fn einsum_optimized(equation: &str, operands: &[&Tensor]) -> TorshResult<Tensor> {
use std::collections::HashMap;
if operands.is_empty() {
return Err(TorshError::invalid_argument_with_context(
"einsum requires at least one operand",
"einsum_optimized",
));
}
let (inputs, output) = parse_einsum_equation(equation)?;
if inputs.len() != operands.len() {
return Err(TorshError::invalid_argument_with_context(
&format!(
"einsum equation expects {} operands, got {}",
inputs.len(),
operands.len()
),
"einsum_optimized",
));
}
let mut index_sizes: HashMap<char, usize> = HashMap::new();
    for (op_idx, (input_idx_str, &operand)) in inputs.iter().zip(operands.iter()).enumerate() {
let shape = operand.shape();
let dims = shape.dims();
let chars: Vec<char> = input_idx_str.chars().collect();
if chars.len() != dims.len() {
return Err(TorshError::invalid_argument_with_context(
&format!(
"operand {} index string '{}' has {} chars but tensor has {} dimensions",
                op_idx,
input_idx_str,
chars.len(),
dims.len()
),
"einsum_optimized",
));
}
for (&ch, &sz) in chars.iter().zip(dims.iter()) {
if let Some(&existing) = index_sizes.get(&ch) {
if existing != sz {
return Err(TorshError::invalid_argument_with_context(
&format!(
"index '{}' has inconsistent sizes: {} vs {}",
ch, existing, sz
),
"einsum_optimized",
));
}
} else {
index_sizes.insert(ch, sz);
}
}
}
let optimal_path = optimize_contraction_path(&inputs, &output, &index_sizes)?;
    execute_contraction_path(operands, &inputs, &optimal_path, &output)
}
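/// Splits an einsum equation into per-operand index strings and an output
/// index string. Without an explicit `->`, the output is inferred in
/// implicit mode (see `infer_output_indices`).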
fn parse_einsum_equation(equation: &str) -> TorshResult<(Vec<String>, String)> {
let parts: Vec<&str> = equation.split("->").collect();
if parts.len() > 2 {
return Err(TorshError::invalid_argument_with_context(
"einsum equation can have at most one '->' separator",
"parse_einsum_equation",
));
}
let input_str = parts[0];
let inputs: Vec<String> = input_str.split(',').map(|s| s.trim().to_string()).collect();
let output = if parts.len() == 2 {
parts[1].trim().to_string()
} else {
infer_output_indices(&inputs)
};
Ok((inputs, output))
}
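/// Implicit-mode output indices: every index that appears exactly once
/// across the inputs, in alphabetical order (the usual einsum convention).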
fn infer_output_indices(inputs: &[String]) -> String {
use std::collections::HashMap;
let mut index_counts = HashMap::new();
for input in inputs {
for ch in input.chars() {
if ch.is_alphabetic() {
*index_counts.entry(ch).or_insert(0) += 1;
}
}
}
let mut output_chars: Vec<char> = index_counts
.iter()
.filter(|(_, &count)| count == 1)
.map(|(&ch, _)| ch)
.collect();
output_chars.sort_unstable();
output_chars.into_iter().collect()
}
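/// One pairwise contraction in a path. `operand1` and `operand2` are slots
/// in the fixed operand pool; the result is written back into `operand1`'s
/// slot and carries the indices in `result_indices`.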
#[derive(Debug, Clone)]
struct ContractionStep {
operand1: usize,
operand2: usize,
result_indices: String,
}
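/// Chooses a pairwise contraction order. Up to `DP_THRESHOLD` operands, an
/// exact dynamic program over operand subsets minimizes an estimated FLOP
/// count; larger tensor networks fall back to the greedy heuristic.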
fn optimize_contraction_path(
inputs: &[String],
output: &str,
index_sizes: &std::collections::HashMap<char, usize>,
) -> TorshResult<Vec<ContractionStep>> {
let n = inputs.len();
if n <= 1 {
return Ok(vec![]);
}
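    // The subset DP touches O(3^n) (subset, split) pairs, so cap the exact
    // search and use the greedy heuristic for larger tensor networks.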
const DP_THRESHOLD: usize = 14;
if n > DP_THRESHOLD {
return greedy_contraction_path(inputs, output, index_sizes);
}
let all_chars: Vec<char> = {
let mut chars: Vec<char> = inputs
.iter()
.flat_map(|s| s.chars())
.chain(output.chars())
.collect();
chars.sort_unstable();
chars.dedup();
chars
};
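    // Give each distinct index character one bit of a u64 mask; this
    // assumes at most 64 distinct indices, which letter-based einsum
    // equations satisfy.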
let char_to_bit: std::collections::HashMap<char, u64> = all_chars
.iter()
.enumerate()
.map(|(i, &c)| (c, 1u64 << i))
.collect();
let char_sizes: Vec<u64> = all_chars
.iter()
.map(|c| *index_sizes.get(c).unwrap_or(&1) as u64)
.collect();
let tensor_char_mask: Vec<u64> = inputs
.iter()
.map(|s| {
s.chars()
.filter_map(|c| char_to_bit.get(&c))
.fold(0u64, |acc, &b| acc | b)
})
.collect();
let num_subsets = 1usize << n;
let mut subset_char_mask = vec![0u64; num_subsets];
for i in 0..n {
let bit = 1usize << i;
for s in 0..num_subsets {
if s & bit != 0 {
subset_char_mask[s] |= tensor_char_mask[i];
}
}
}
let full_set = num_subsets - 1;
let contraction_flops = |l: usize, r: usize| -> u64 {
let union_chars = subset_char_mask[l] | subset_char_mask[r];
all_chars
.iter()
.enumerate()
.filter(|(i, _)| union_chars & (1u64 << i) != 0)
.map(|(i, _)| char_sizes[i])
.product::<u64>()
};
let inf = u64::MAX / 2;
let mut dp_cost = vec![inf; num_subsets];
let mut dp_split: Vec<(usize, usize)> = vec![(0, 0); num_subsets];
for i in 0..n {
dp_cost[1 << i] = 0;
}
for size in 2..=n {
        // Enumerate every subset `s` with exactly `size` member tensors.
        let mut s = (1usize << size) - 1;
        while s < num_subsets {
            // Walk all proper sub-masks `l` of `s`; `r = s ^ l` is the
            // complementary half of the split.
            let mut l = (s - 1) & s;
            while l > 0 {
let r = s ^ l;
if l < r {
let cost_l = dp_cost[l];
let cost_r = dp_cost[r];
if cost_l < inf && cost_r < inf {
let flops = contraction_flops(l, r);
let total = cost_l.saturating_add(cost_r).saturating_add(flops);
if total < dp_cost[s] {
dp_cost[s] = total;
dp_split[s] = (l, r);
}
}
}
l = (l - 1) & s;
}
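            // Gosper's hack: advance `s` to the next subset with `size` bits set.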
let c = s & s.wrapping_neg();
let r = s + c;
s = (((r ^ s) >> 2) / c) | r;
}
}
let mut steps: Vec<ContractionStep> = Vec::with_capacity(n - 1);
let mut live_pos: Vec<usize> = (0..n).collect();
let mut live_indices: Vec<String> = inputs.to_vec();
let mut stack: Vec<usize> = vec![full_set];
let mut raw_pairs: Vec<(usize, usize)> = Vec::with_capacity(n - 1);
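    // DFS from the full set records (left, right) splits in preorder;
    // reversing below yields a children-before-parents execution order.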
while let Some(s) = stack.pop() {
if s.count_ones() <= 1 {
continue;
}
let (l, r) = dp_split[s];
raw_pairs.push((l, r));
if r.count_ones() > 1 {
stack.push(r);
}
if l.count_ones() > 1 {
stack.push(l);
}
}
raw_pairs.reverse();
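    // Map each contracted subset to the original operand index that
    // represents it (its pool slot holds the intermediate result).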
let mut subset_live: std::collections::HashMap<usize, usize> = std::collections::HashMap::new();
for i in 0..n {
subset_live.insert(1 << i, i);
}
for (l, r) in raw_pairs {
let live_l = *subset_live.get(&l).ok_or_else(|| {
TorshError::invalid_argument_with_context(
"DP backtrack: missing live mapping for left subset",
"optimize_contraction_path",
)
})?;
let live_r = *subset_live.get(&r).ok_or_else(|| {
TorshError::invalid_argument_with_context(
"DP backtrack: missing live mapping for right subset",
"optimize_contraction_path",
)
})?;
        let pos_l = live_pos[live_l];
        let pos_r = live_pos[live_r];
        let idx_l = live_indices[pos_l].clone();
        let idx_r = live_indices[pos_r].clone();
        let result_indices =
            compute_pairwise_result(&idx_l, &idx_r, output, &live_indices, pos_l, pos_r);
        steps.push(ContractionStep {
            operand1: pos_l,
            operand2: pos_r,
            result_indices: result_indices.clone(),
        });
        live_indices[pos_l] = result_indices;
        // Clear the consumed slot so its stale index string no longer counts
        // as "used by another operand" in later pairwise-result computations.
        live_indices[pos_r].clear();
        live_pos[live_r] = pos_l;
        subset_live.insert(l | r, live_l);
}
Ok(steps)
}
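/// Computes the indices of the intermediate produced by contracting `idx_l`
/// with `idx_r`: the union of both index sets, restricted to indices still
/// required by the final output or by another live operand, in sorted order.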
fn compute_pairwise_result(
idx_l: &str,
idx_r: &str,
final_output: &str,
all_live_indices: &[String],
skip_l: usize,
skip_r: usize,
) -> String {
use std::collections::HashSet;
let chars_l: HashSet<char> = idx_l.chars().collect();
let chars_r: HashSet<char> = idx_r.chars().collect();
let union: HashSet<char> = chars_l.union(&chars_r).copied().collect();
let outside: HashSet<char> = all_live_indices
.iter()
.enumerate()
.filter(|(i, _)| *i != skip_l && *i != skip_r)
.flat_map(|(_, s)| s.chars())
.collect();
let output_chars: HashSet<char> = final_output.chars().collect();
let mut result_chars: Vec<char> = union
.into_iter()
.filter(|c| output_chars.contains(c) || outside.contains(c))
.collect();
result_chars.sort_unstable();
result_chars.into_iter().collect()
}
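/// Greedy fallback for large tensor networks: repeatedly contracts the pair
/// whose intermediate has the smallest estimated size, re-scanning all
/// remaining pairs each round.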
fn greedy_contraction_path(
inputs: &[String],
output: &str,
index_sizes: &std::collections::HashMap<char, usize>,
) -> TorshResult<Vec<ContractionStep>> {
    let mut steps = Vec::new();
    let mut remaining = inputs.to_vec();
    // `slots[i]` is the fixed pool slot backing `remaining[i]`. Steps must
    // reference pool slots (the convention shared with the DP path and
    // execute_contraction_path), not positions in the shrinking list.
    let mut slots: Vec<usize> = (0..inputs.len()).collect();
while remaining.len() > 1 {
let n = remaining.len();
let mut best_cost = u64::MAX;
let mut best_i = 0usize;
let mut best_j = 1usize;
for i in 0..n {
for j in (i + 1)..n {
let chars_i: std::collections::HashSet<char> = remaining[i].chars().collect();
let chars_j: std::collections::HashSet<char> = remaining[j].chars().collect();
let union: std::collections::HashSet<char> =
chars_i.union(&chars_j).copied().collect();
let cost: u64 = union
.iter()
.map(|c| *index_sizes.get(c).unwrap_or(&1) as u64)
.product();
if cost < best_cost {
best_cost = cost;
best_i = i;
best_j = j;
}
}
}
let result_indices = compute_pairwise_result(
&remaining[best_i].clone(),
&remaining[best_j].clone(),
output,
&remaining,
best_i,
best_j,
);
        steps.push(ContractionStep {
            operand1: slots[best_i],
            operand2: slots[best_j],
            result_indices: result_indices.clone(),
        });
        // The intermediate lives in the first operand's pool slot.
        let result_slot = slots[best_i];
        remaining.remove(best_j.max(best_i));
        remaining.remove(best_j.min(best_i));
        slots.remove(best_j.max(best_i));
        slots.remove(best_j.min(best_i));
        remaining.push(result_indices);
        slots.push(result_slot);
}
Ok(steps)
}
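/// Executes a contraction path over a pool of operand slots. Each step is
/// dispatched as a pairwise `einsum` call; a final unary `einsum` permutes
/// the result into the requested output index order when needed.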
fn execute_contraction_path(
    operands: &[&Tensor],
    inputs: &[String],
    path: &[ContractionStep],
    output: &str,
) -> TorshResult<Tensor> {
    let mut pool: Vec<Option<Tensor>> = operands.iter().map(|&t| Some(t.clone())).collect();
    if path.is_empty() {
        // Single operand: dispatch the unary equation directly, which also
        // covers traces ("ii->") and permutations ("ij->ji"). This assumes
        // the base `einsum` accepts single-operand equations.
        let tensor = pool.into_iter().find_map(|slot| slot).ok_or_else(|| {
            TorshError::InvalidOperation("execute_contraction_path: empty operand pool".to_string())
        })?;
        if inputs.first().map(String::as_str) == Some(output) {
            return Ok(tensor);
        }
        let eq = format!("{}->{}", inputs.first().cloned().unwrap_or_default(), output);
        return crate::math::einsum(&eq, &[tensor]);
    }
    // Track the current index string of each live pool slot so every step
    // can be dispatched as a concrete pairwise einsum equation.
    let mut idx_pool: Vec<Option<String>> = inputs.iter().map(|s| Some(s.clone())).collect();
    let mut last_slot = 0usize;
    for step in path {
        let i = step.operand1;
        let j = step.operand2;
        if i >= pool.len() || j >= pool.len() || i == j {
            return Err(TorshError::InvalidOperation(format!(
                "execute_contraction_path: invalid step indices ({}, {}) for pool size {} \
                 (result_indices='{}')",
                i,
                j,
                pool.len(),
                step.result_indices
            )));
        }
        let (a, idx_a) = match (pool[i].take(), idx_pool[i].take()) {
            (Some(t), Some(s)) => (t, s),
            _ => {
                return Err(TorshError::InvalidOperation(format!(
                    "execute_contraction_path: slot {} already consumed (result_indices='{}')",
                    i, step.result_indices
                )))
            }
        };
        let (b, idx_b) = match (pool[j].take(), idx_pool[j].take()) {
            (Some(t), Some(s)) => (t, s),
            _ => {
                return Err(TorshError::InvalidOperation(format!(
                    "execute_contraction_path: slot {} already consumed (result_indices='{}')",
                    j, step.result_indices
                )))
            }
        };
        // Contract this pair with the equation recorded for the step.
        let equation = format!("{},{}->{}", idx_a, idx_b, step.result_indices);
        let operand_vec: Vec<Tensor> = vec![a, b];
        let result = crate::math::einsum(&equation, &operand_vec)?;
        pool[i] = Some(result);
        idx_pool[i] = Some(step.result_indices.clone());
        last_slot = i;
    }
    let result = pool[last_slot].take().ok_or_else(|| {
        TorshError::InvalidOperation(
            "execute_contraction_path: no result tensor after all steps".to_string(),
        )
    })?;
    let final_idx = idx_pool[last_slot].take().unwrap_or_default();
    // Intermediates keep their indices sorted, so a final permutation may
    // be needed to match the requested output order.
    if final_idx == output {
        Ok(result)
    } else {
        crate::math::einsum(&format!("{}->{}", final_idx, output), &[result])
    }
}
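/// Contracts `a` and `b` along the given axis pairs (a generalized dot
/// product), validating the axes and delegating to `tensordot`.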
pub fn tensor_contract(
a: &Tensor,
b: &Tensor,
axes_a: &[usize],
axes_b: &[usize],
) -> TorshResult<Tensor> {
if axes_a.len() != axes_b.len() {
return Err(TorshError::invalid_argument_with_context(
"number of contraction axes must match",
"tensor_contract",
));
}
let a_shape_obj = a.shape();
let shape_a = a_shape_obj.dims();
let b_shape_obj = b.shape();
let shape_b = b_shape_obj.dims();
for &axis in axes_a {
if axis >= shape_a.len() {
return Err(TorshError::invalid_argument_with_context(
&format!(
"axis {} out of range for tensor with {} dimensions",
axis,
shape_a.len()
),
"tensor_contract",
));
}
}
for &axis in axes_b {
if axis >= shape_b.len() {
return Err(TorshError::invalid_argument_with_context(
&format!(
"axis {} out of range for tensor with {} dimensions",
axis,
shape_b.len()
),
"tensor_contract",
));
}
}
for (&axis_a, &axis_b) in axes_a.iter().zip(axes_b.iter()) {
if shape_a[axis_a] != shape_b[axis_b] {
return Err(TorshError::invalid_argument_with_context(
&format!(
"contracted dimensions must match: {} != {}",
shape_a[axis_a], shape_b[axis_b]
),
"tensor_contract",
));
}
}
crate::manipulation::tensordot(
a,
b,
crate::manipulation::TensorDotAxes::Arrays(axes_a.to_vec(), axes_b.to_vec()),
)
}
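/// Applies `f` to every element, switching to parallel iteration for
/// tensors with more than 10,000 elements.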
pub fn tensor_map<F>(input: &Tensor<f32>, f: F) -> TorshResult<Tensor<f32>>
where
F: Fn(f32) -> f32 + Send + Sync,
{
let data = input.data()?;
let shape = input.shape().dims().to_vec();
let device = input.device();
let result_data: Vec<f32> = if data.len() > 10000 {
use scirs2_core::parallel_ops::*;
data.iter()
.copied()
.collect::<Vec<_>>()
.into_par_iter()
.map(f)
.collect()
} else {
data.iter().map(|&x| f(x)).collect()
};
Tensor::from_data(result_data, shape, device)
}
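/// Folds `f` starting from `init` over all elements when `axis` is `None`,
/// or along a single axis, removing that axis from the output shape.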
pub fn tensor_reduce<F>(
input: &Tensor<f32>,
axis: Option<usize>,
f: F,
init: f32,
) -> TorshResult<Tensor<f32>>
where
F: Fn(f32, f32) -> f32 + Send + Sync,
{
let input_shape = input.shape();
let shape = input_shape.dims();
if let Some(ax) = axis {
if ax >= shape.len() {
return Err(TorshError::invalid_argument_with_context(
&format!(
"axis {} out of range for tensor with {} dimensions",
ax,
shape.len()
),
"tensor_reduce",
));
}
let data = input.data()?;
let mut output_shape = shape.to_vec();
output_shape.remove(ax);
if output_shape.is_empty() {
let result = data.iter().fold(init, |acc, &x| f(acc, x));
return Tensor::from_data(vec![result], vec![1], input.device());
}
let mut strides = vec![1; shape.len()];
for i in (0..shape.len() - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
let output_size: usize = output_shape.iter().product();
let axis_size = shape[ax];
let mut result_data = vec![init; output_size];
    for (out_idx, result_val) in result_data.iter_mut().enumerate() {
        for axis_idx in 0..axis_size {
            // Decode out_idx (row-major over output_shape) into input
            // coordinates starting from the innermost dimension, so the
            // element order matches the row-major result buffer.
            let mut in_idx = axis_idx * strides[ax];
            let mut remaining = out_idx;
            let mut out_dim_idx = output_shape.len();
            for dim_idx in (0..shape.len()).rev() {
                if dim_idx != ax {
                    out_dim_idx -= 1;
                    let size = output_shape[out_dim_idx];
                    let coord = remaining % size;
                    remaining /= size;
                    in_idx += coord * strides[dim_idx];
                }
            }
            if in_idx < data.len() {
                *result_val = f(*result_val, data[in_idx]);
            }
        }
    }
Tensor::from_data(result_data, output_shape, input.device())
} else {
let data = input.data()?;
let result = data.iter().fold(init, |acc, &x| f(acc, x));
Tensor::from_data(vec![result], vec![1], input.device())
}
}
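/// Inclusive scan (running fold) of `f` along `axis`: each element is
/// replaced by the fold of all elements up to and including it.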
pub fn tensor_scan<F>(input: &Tensor<f32>, axis: usize, f: F, init: f32) -> TorshResult<Tensor<f32>>
where
F: Fn(f32, f32) -> f32,
{
let input_shape = input.shape();
let shape = input_shape.dims();
if axis >= shape.len() {
return Err(TorshError::invalid_argument_with_context(
&format!(
"axis {} out of range for tensor with {} dimensions",
axis,
shape.len()
),
"tensor_scan",
));
}
let data = input.data()?;
let mut result_data = data.to_vec();
let mut strides = vec![1; shape.len()];
for i in (0..shape.len() - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
let axis_size = shape[axis];
let axis_stride = strides[axis];
let other_size: usize = shape
.iter()
.enumerate()
.filter(|(i, _)| *i != axis)
.map(|(_, &s)| s)
.product();
for other_idx in 0..other_size {
let mut base_idx = 0;
let mut remaining = other_idx;
for (dim_idx, &size) in shape.iter().enumerate() {
if dim_idx != axis {
let coord = remaining % size;
remaining /= size;
base_idx += coord * strides[dim_idx];
}
}
let mut acc = init;
for axis_idx in 0..axis_size {
let idx = base_idx + axis_idx * axis_stride;
if idx < result_data.len() {
acc = f(acc, result_data[idx]);
result_data[idx] = acc;
}
}
}
Tensor::from_data(result_data, shape.to_vec(), input.device())
}
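/// Folds `f` over every element in storage order and returns the scalar.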
pub fn tensor_fold<F>(input: &Tensor<f32>, f: F, init: f32) -> TorshResult<f32>
where
F: Fn(f32, f32) -> f32,
{
let data = input.data()?;
Ok(data.iter().fold(init, |acc, &x| f(acc, x)))
}
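/// Outer product via broadcasting: `a` is reshaped to `[a_dims..., 1, ...]`
/// and `b` to `[1, ..., b_dims...]`, then the two are multiplied.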
pub fn tensor_outer(a: &Tensor<f32>, b: &Tensor<f32>) -> TorshResult<Tensor<f32>> {
let a_shape_obj = a.shape();
let shape_a = a_shape_obj.dims();
let b_shape_obj = b.shape();
let shape_b = b_shape_obj.dims();
let mut new_shape_a = shape_a.to_vec();
new_shape_a.extend(vec![1; shape_b.len()]);
let mut new_shape_b = vec![1; shape_a.len()];
new_shape_b.extend(shape_b);
let a_reshaped = a.view(&new_shape_a.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
let b_reshaped = b.view(&new_shape_b.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
a_reshaped.mul(&b_reshaped)
}
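/// Combines two same-shaped tensors elementwise with `f`, parallelizing
/// when there are more than 10,000 elements.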
pub fn tensor_zip<F>(a: &Tensor<f32>, b: &Tensor<f32>, f: F) -> TorshResult<Tensor<f32>>
where
F: Fn(f32, f32) -> f32 + Send + Sync,
{
if a.shape().dims() != b.shape().dims() {
return Err(TorshError::invalid_argument_with_context(
&format!(
"tensor shapes must match for zip: {:?} vs {:?}",
a.shape().dims(),
b.shape().dims()
),
"tensor_zip",
));
}
let data_a = a.data()?;
let data_b = b.data()?;
let shape = a.shape().dims().to_vec();
let device = a.device();
let result_data: Vec<f32> = if data_a.len() > 10000 {
use scirs2_core::parallel_ops::*;
let pairs: Vec<(f32, f32)> = data_a.iter().copied().zip(data_b.iter().copied()).collect();
pairs.into_par_iter().map(|(x, y)| f(x, y)).collect()
} else {
data_a
.iter()
.zip(data_b.iter())
.map(|(&x, &y)| f(x, y))
.collect()
};
Tensor::from_data(result_data, shape, device)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
#[test]
fn test_tensor_map() {
let input = Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0],
vec![2, 2],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let output = tensor_map(&input, |x| x * 2.0).expect("map failed");
let output_data = output.data().expect("failed to get data");
assert_relative_eq!(output_data[0], 2.0, epsilon = 1e-6);
assert_relative_eq!(output_data[1], 4.0, epsilon = 1e-6);
assert_relative_eq!(output_data[2], 6.0, epsilon = 1e-6);
assert_relative_eq!(output_data[3], 8.0, epsilon = 1e-6);
}
#[test]
fn test_tensor_reduce() {
let input = Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0],
vec![4],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let output = tensor_reduce(&input, None, |a, b| a + b, 0.0).expect("reduce failed");
let output_data = output.data().expect("failed to get data");
assert_relative_eq!(output_data[0], 10.0, epsilon = 1e-6);
}
#[test]
fn test_tensor_fold() {
let input = Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0],
vec![4],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let result = tensor_fold(&input, |acc, x| acc + x, 0.0).expect("fold failed");
assert_relative_eq!(result, 10.0, epsilon = 1e-6);
}
#[test]
fn test_tensor_scan() {
let input = Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0],
vec![4],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let output = tensor_scan(&input, 0, |a, b| a + b, 0.0).expect("scan failed");
let output_data = output.data().expect("failed to get data");
assert_relative_eq!(output_data[0], 1.0, epsilon = 1e-6);
assert_relative_eq!(output_data[1], 3.0, epsilon = 1e-6);
assert_relative_eq!(output_data[2], 6.0, epsilon = 1e-6);
assert_relative_eq!(output_data[3], 10.0, epsilon = 1e-6);
}
#[test]
fn test_tensor_outer() {
let a = Tensor::from_data(
vec![1.0, 2.0, 3.0],
vec![3],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let b = Tensor::from_data(vec![4.0, 5.0], vec![2], torsh_core::device::DeviceType::Cpu)
.expect("failed to create tensor");
let c = tensor_outer(&a, &b).expect("outer product failed");
assert_eq!(c.shape().dims(), &[3, 2]);
let c_data = c.data().expect("failed to get data");
        assert_relative_eq!(c_data[0], 4.0, epsilon = 1e-6);
        assert_relative_eq!(c_data[1], 5.0, epsilon = 1e-6);
        assert_relative_eq!(c_data[2], 8.0, epsilon = 1e-6);
        assert_relative_eq!(c_data[3], 10.0, epsilon = 1e-6);
    }
#[test]
fn test_tensor_zip() {
let a = Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0],
vec![4],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let b = Tensor::from_data(
vec![5.0, 6.0, 7.0, 8.0],
vec![4],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let c = tensor_zip(&a, &b, |x, y| x + y).expect("zip failed");
let c_data = c.data().expect("failed to get data");
assert_relative_eq!(c_data[0], 6.0, epsilon = 1e-6);
assert_relative_eq!(c_data[1], 8.0, epsilon = 1e-6);
assert_relative_eq!(c_data[2], 10.0, epsilon = 1e-6);
assert_relative_eq!(c_data[3], 12.0, epsilon = 1e-6);
}
#[test]
fn test_parse_einsum_equation() {
let (inputs, output) = parse_einsum_equation("ij,jk->ik").expect("parse failed");
assert_eq!(inputs, vec!["ij", "jk"]);
assert_eq!(output, "ik");
let (inputs, output) = parse_einsum_equation("ii->").expect("parse failed");
assert_eq!(inputs, vec!["ii"]);
assert_eq!(output, "");
}
#[test]
fn test_tensor_reduce_axis() {
let input = Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
vec![2, 3],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let output = tensor_reduce(&input, Some(0), |a, b| a + b, 0.0).expect("reduce failed");
assert_eq!(output.shape().dims(), &[3]);
let output_data = output.data().expect("failed to get data");
        assert_relative_eq!(output_data[0], 5.0, epsilon = 1e-6);
        assert_relative_eq!(output_data[1], 7.0, epsilon = 1e-6);
        assert_relative_eq!(output_data[2], 9.0, epsilon = 1e-6);
    }
#[test]
fn test_dp_path_single_tensor() {
let index_sizes = std::collections::HashMap::from([('i', 3usize), ('j', 4usize)]);
let inputs = vec!["ij".to_string()];
let path = optimize_contraction_path(&inputs, "ij", &index_sizes)
.expect("optimize should succeed");
assert!(path.is_empty(), "single-tensor path must be empty");
}
#[test]
fn test_dp_path_two_tensors() {
let index_sizes =
std::collections::HashMap::from([('i', 10usize), ('j', 20usize), ('k', 30usize)]);
let inputs = vec!["ij".to_string(), "jk".to_string()];
let path = optimize_contraction_path(&inputs, "ik", &index_sizes)
.expect("optimize should succeed");
assert_eq!(path.len(), 1, "two-tensor path must have exactly one step");
let step = &path[0];
assert!(
(step.operand1 == 0 && step.operand2 == 1)
|| (step.operand1 == 1 && step.operand2 == 0),
"step must reference operands 0 and 1, got ({}, {})",
step.operand1,
step.operand2
);
}
#[test]
fn test_dp_path_three_tensors() {
let index_sizes = std::collections::HashMap::from([
('i', 5usize),
            ('j', 100usize),
            ('k', 4usize),
('l', 6usize),
]);
let inputs = vec!["ij".to_string(), "jk".to_string(), "kl".to_string()];
let path = optimize_contraction_path(&inputs, "il", &index_sizes)
.expect("optimize should succeed");
assert_eq!(
path.len(),
2,
"three-tensor path must have exactly two steps"
);
}
#[test]
fn test_dp_path_optimal_vs_greedy_cost() {
use std::collections::HashMap;
let index_sizes: HashMap<char, usize> = HashMap::from([
('i', 2usize),
            ('j', 500usize),
            ('k', 3usize),
('l', 4usize),
('m', 2usize),
]);
let inputs = vec![
"ij".to_string(),
"jk".to_string(),
"kl".to_string(),
"lm".to_string(),
];
let path = optimize_contraction_path(&inputs, "im", &index_sizes)
.expect("optimize should succeed");
assert_eq!(
path.len(),
3,
"four-tensor path must have exactly three steps"
);
}
#[test]
fn test_dp_path_greedy_fallback_agreement() {
use std::collections::HashMap;
let index_sizes: HashMap<char, usize> =
HashMap::from([('a', 4usize), ('b', 5usize), ('c', 6usize)]);
let inputs = vec!["ab".to_string(), "bc".to_string()];
let dp_path = optimize_contraction_path(&inputs, "ac", &index_sizes)
.expect("dp optimize should succeed");
let greedy = greedy_contraction_path(&inputs, "ac", &index_sizes)
.expect("greedy optimize should succeed");
assert_eq!(dp_path.len(), greedy.len(), "path lengths must match");
}
#[test]
fn test_infer_output_indices() {
let inputs = vec!["ij".to_string(), "jk".to_string()];
let output = infer_output_indices(&inputs);
assert!(output.contains('i'), "output must contain 'i'");
assert!(output.contains('k'), "output must contain 'k'");
assert!(!output.contains('j'), "output must not contain 'j'");
}
#[test]
fn test_compute_pairwise_result_keeps_outside_chars() {
let all_live = vec!["ij".to_string(), "jk".to_string(), "km".to_string()];
let result = compute_pairwise_result("ij", "jk", "im", &all_live, 0, 1);
assert!(
result.contains('k'),
"k must survive because tensor 2 uses it"
);
assert!(!result.contains('j'), "j must be contracted away");
}
}