use crate::Scalar;
use crate::error::{CoreError, Result};
use crate::tensor::Tensor;
use std::collections::{BTreeMap, BTreeSet};
use super::einsum::einsum;
/// How the order of pairwise contractions is chosen.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PathStrategy {
    /// Repeatedly contract the pair with the lowest estimated cost.
    Greedy,
    /// Search contraction orders exhaustively for the lowest total cost
    /// (the planner switches to the greedy order for large operand counts).
    Optimal,
}

/// One step of a contraction path.
///
/// The operands at positions `first` and `second` of the current working
/// list are removed and their contraction result is appended to the end of
/// that list. The planners in this module always record `first < second`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ContractionPair {
    /// Working-list position of the first operand.
    pub first: usize,
    /// Working-list position of the second operand.
    pub second: usize,
}

/// A planned contraction order together with its estimated cost.
#[derive(Clone, Debug)]
pub struct PathInfo {
    /// Pairwise steps in execution order.
    pub path: Vec<ContractionPair>,
    /// Sum of the per-step cost estimates.
    pub flops: usize,
    /// Element count of the largest tensor produced by any step.
    pub largest_intermediate: usize,
}

/// Index labels of one operand (input or intermediate) during planning.
#[derive(Clone, Debug)]
struct OperandDesc {
    /// Subscript characters, one per tensor axis.
    indices: Vec<char>,
}

/// Parsed subscripts: per-operand index lists, output index list, and the
/// size bound to each index.
type ParsedSubscripts = (Vec<Vec<char>>, Vec<char>, BTreeMap<char, usize>);
pub fn einsum_path<T: Scalar>(
subscripts: &str,
operands: &[&Tensor<T>],
strategy: PathStrategy,
) -> Result<PathInfo> {
let (input_subs, output_sub, index_sizes) = parse_path_subscripts(subscripts, operands)?;
let descs: Vec<OperandDesc> = input_subs
.iter()
.map(|indices| OperandDesc {
indices: indices.clone(),
})
.collect();
match strategy {
PathStrategy::Greedy => greedy_path(&descs, &output_sub, &index_sizes),
PathStrategy::Optimal => optimal_path(&descs, &output_sub, &index_sizes),
}
}
/// Evaluates an einsum expression by contracting operands pairwise in a
/// planned order.
///
/// With two or fewer operands there is no order to choose, so the call goes
/// straight to [`einsum`]. Otherwise a path is planned with `strategy` and
/// executed one pair at a time.
///
/// # Errors
///
/// Returns parsing, planning, or execution errors, including any error
/// raised by [`einsum`] itself.
pub fn einsum_optimized<T: Scalar>(
    subscripts: &str,
    operands: &[&Tensor<T>],
    strategy: PathStrategy,
) -> Result<Tensor<T>> {
    if operands.len() > 2 {
        let (inputs, output, sizes) = parse_path_subscripts(subscripts, operands)?;
        let operand_descs: Vec<OperandDesc> = inputs
            .iter()
            .cloned()
            .map(|indices| OperandDesc { indices })
            .collect();
        let plan = match strategy {
            PathStrategy::Greedy => greedy_path(&operand_descs, &output, &sizes),
            PathStrategy::Optimal => optimal_path(&operand_descs, &output, &sizes),
        }?;
        execute_path(subscripts, operands, &plan, &inputs, &output)
    } else {
        einsum(subscripts, operands)
    }
}
fn parse_path_subscripts<T: Scalar>(
subscripts: &str,
operands: &[&Tensor<T>],
) -> Result<ParsedSubscripts> {
let subscripts = subscripts.replace(' ', "");
let (inputs_str, output_sub) = if let Some((inp, out)) = subscripts.split_once("->") {
let output_indices: Vec<char> = out.chars().collect();
(inp.to_string(), output_indices)
} else {
let mut counts: BTreeMap<char, usize> = BTreeMap::new();
for c in subscripts.chars() {
if c == ',' {
continue;
}
*counts.entry(c).or_insert(0) += 1;
}
let output_indices: Vec<char> = counts
.iter()
.filter(|(_, count)| **count == 1)
.map(|(&c, _)| c)
.collect();
(subscripts.clone(), output_indices)
};
let input_parts: Vec<&str> = inputs_str.split(',').collect();
if input_parts.len() != operands.len() {
return Err(CoreError::InvalidArgument {
reason: "number of subscript groups does not match number of operands",
});
}
let mut input_subs = Vec::with_capacity(input_parts.len());
let mut index_sizes: BTreeMap<char, usize> = BTreeMap::new();
for (i, part) in input_parts.iter().enumerate() {
let indices: Vec<char> = part.chars().collect();
if indices.len() != operands[i].ndim() {
return Err(CoreError::InvalidArgument {
reason: "operand rank does not match number of subscript indices",
});
}
let shape = operands[i].shape();
for (d, &c) in indices.iter().enumerate() {
if let Some(&existing) = index_sizes.get(&c) {
if existing != shape[d] {
return Err(CoreError::DimensionMismatch {
expected: vec![existing],
got: vec![shape[d]],
});
}
} else {
index_sizes.insert(c, shape[d]);
}
}
input_subs.push(indices);
}
Ok((input_subs, output_sub, index_sizes))
}
/// Estimates the cost of contracting `a` with `b` in isolation.
///
/// Indices shared by both operands and absent from `output_indices` are
/// treated as summed. Every other index of either operand is kept, listed
/// once, in first-appearance order across `a` then `b`.
///
/// Returns `(flops, result_indices, result_shape)`. `flops` is the result size
/// times the summed-index size, and each factor is clamped to at least 1.
/// All products saturate at `usize::MAX` rather than overflowing. Very large
/// candidate contractions therefore rank as maximally expensive instead of
/// panicking (debug builds) or wrapping (release builds).
fn contraction_cost(
    a: &OperandDesc,
    b: &OperandDesc,
    output_indices: &[char],
    index_sizes: &BTreeMap<char, usize>,
) -> (usize, Vec<char>, Vec<usize>) {
    let a_set: BTreeSet<char> = a.indices.iter().copied().collect();
    let b_set: BTreeSet<char> = b.indices.iter().copied().collect();
    let output_set: BTreeSet<char> = output_indices.iter().copied().collect();
    let contracted: BTreeSet<char> = a_set
        .intersection(&b_set)
        .filter(|c| !output_set.contains(*c))
        .copied()
        .collect();

    let mut result_indices: Vec<char> = Vec::new();
    for &c in a.indices.iter().chain(b.indices.iter()) {
        if !contracted.contains(&c) && !result_indices.contains(&c) {
            result_indices.push(c);
        }
    }

    let result_shape: Vec<usize> = result_indices.iter().map(|c| index_sizes[c]).collect();
    let result_size = result_shape
        .iter()
        .copied()
        .fold(1usize, usize::saturating_mul)
        .max(1);
    let contract_size = contracted
        .iter()
        .map(|c| index_sizes[c])
        .fold(1usize, usize::saturating_mul)
        .max(1);
    let flops = result_size.saturating_mul(contract_size);
    (flops, result_indices, result_shape)
}
/// Computes the index list and shape produced by contracting `a` with `b`
/// inside a larger network.
///
/// An index is summed away only when both `a` and `b` carry it and neither
/// `final_output` nor any operand in `remaining` still refers to it. All other
/// indices survive, each listed once, in first-appearance order across `a`
/// then `b`.
fn pairwise_result_indices(
    a: &OperandDesc,
    b: &OperandDesc,
    remaining: &[OperandDesc],
    final_output: &[char],
    index_sizes: &BTreeMap<char, usize>,
) -> (Vec<char>, Vec<usize>) {
    let still_used: BTreeSet<char> = final_output
        .iter()
        .chain(remaining.iter().flat_map(|op| op.indices.iter()))
        .copied()
        .collect();
    let in_a: BTreeSet<char> = a.indices.iter().copied().collect();
    let in_b: BTreeSet<char> = b.indices.iter().copied().collect();
    let is_summed = |c: &char| in_a.contains(c) && in_b.contains(c) && !still_used.contains(c);

    let mut kept: Vec<char> = Vec::new();
    for &c in a.indices.iter().chain(b.indices.iter()) {
        if is_summed(&c) || kept.contains(&c) {
            continue;
        }
        kept.push(c);
    }
    let shape: Vec<usize> = kept.iter().map(|c| index_sizes[c]).collect();
    (kept, shape)
}
/// Builds a contraction order greedily.
///
/// At each step every pair in the working list is scored with
/// [`contraction_cost`]. The cheapest pair is contracted, and on ties the
/// first pair found in `(i, j)` order wins. The two operands are removed and
/// their result is appended to the end of the list, which is the convention
/// recorded in each [`ContractionPair`].
///
/// `flops` and `largest_intermediate` saturate at `usize::MAX` instead of
/// overflowing.
// Never fails; returning `Result` keeps the signature interchangeable with
// `optimal_path`, which is why the clippy lint is silenced.
#[allow(clippy::unnecessary_wraps)]
fn greedy_path(
    descs: &[OperandDesc],
    output_sub: &[char],
    index_sizes: &BTreeMap<char, usize>,
) -> Result<PathInfo> {
    let n = descs.len();
    if n <= 1 {
        return Ok(PathInfo {
            path: vec![],
            flops: 0,
            largest_intermediate: 0,
        });
    }
    let mut current: Vec<OperandDesc> = descs.to_vec();
    let mut path = Vec::with_capacity(n - 1);
    let mut total_flops = 0usize;
    let mut largest_intermediate = 0usize;
    while current.len() > 1 {
        let mut best_cost = usize::MAX;
        let mut best_i = 0;
        let mut best_j = 1;
        for i in 0..current.len() {
            for j in (i + 1)..current.len() {
                let (cost, _, _) =
                    contraction_cost(&current[i], &current[j], output_sub, index_sizes);
                if cost < best_cost {
                    best_cost = cost;
                    best_i = i;
                    best_j = j;
                }
            }
        }
        // `best_j > best_i`, so removing `best_j` first leaves `best_i` valid.
        // Afterwards `current` holds exactly the operands not being contracted.
        let second = current.remove(best_j);
        let first = current.remove(best_i);
        let (result_indices, result_shape) =
            pairwise_result_indices(&first, &second, &current, output_sub, index_sizes);
        let result_size = result_shape
            .iter()
            .copied()
            .fold(1usize, usize::saturating_mul)
            .max(1);
        largest_intermediate = largest_intermediate.max(result_size);
        total_flops = total_flops.saturating_add(best_cost);
        path.push(ContractionPair {
            first: best_i,
            second: best_j,
        });
        current.push(OperandDesc {
            indices: result_indices,
        });
    }
    Ok(PathInfo {
        path,
        flops: total_flops,
        largest_intermediate,
    })
}
/// Finds a minimum-cost contraction order by exhaustive branch-and-bound search.
///
/// The search enumerates pairwise orderings. Beyond
/// `MAX_EXHAUSTIVE_OPERANDS` operands it is therefore skipped in favour of
/// [`greedy_path`]. The greedy order is also used if the search records no
/// complete path, for example when every candidate total reaches `usize::MAX`.
/// An empty path would otherwise fail later during execution.
fn optimal_path(
    descs: &[OperandDesc],
    output_sub: &[char],
    index_sizes: &BTreeMap<char, usize>,
) -> Result<PathInfo> {
    /// Largest operand count searched exhaustively.
    const MAX_EXHAUSTIVE_OPERANDS: usize = 8;

    let n = descs.len();
    if n <= 1 {
        return Ok(PathInfo {
            path: vec![],
            flops: 0,
            largest_intermediate: 0,
        });
    }
    if n > MAX_EXHAUSTIVE_OPERANDS {
        return greedy_path(descs, output_sub, index_sizes);
    }
    let mut best_path: Vec<ContractionPair> = Vec::new();
    let mut best_flops = usize::MAX;
    let mut best_largest = 0usize;
    find_optimal(
        descs,
        output_sub,
        index_sizes,
        &mut Vec::with_capacity(n - 1),
        0,
        0,
        &mut best_path,
        &mut best_flops,
        &mut best_largest,
    );
    // A complete order for n operands has exactly n - 1 steps.
    if best_path.len() != n - 1 {
        return greedy_path(descs, output_sub, index_sizes);
    }
    Ok(PathInfo {
        path: best_path,
        flops: best_flops,
        largest_intermediate: best_largest,
    })
}
/// Depth-first branch-and-bound search used by [`optimal_path`].
///
/// `current` is the working operand list. `current_path`, `current_flops` and
/// `current_largest` describe the partial order that produced it. The `best_*`
/// outputs hold the cheapest complete order found so far.
///
/// A branch is abandoned once its accumulated cost reaches `*best_flops`. A
/// complete order replaces the best only when strictly cheaper, so among
/// equal-cost orders the first one visited is kept. Accumulated costs
/// saturate at `usize::MAX` rather than overflowing.
// The search state is threaded through explicitly instead of via a struct,
// hence the argument-count allowance.
#[allow(clippy::too_many_arguments)]
fn find_optimal(
    current: &[OperandDesc],
    output_sub: &[char],
    index_sizes: &BTreeMap<char, usize>,
    current_path: &mut Vec<ContractionPair>,
    current_flops: usize,
    current_largest: usize,
    best_path: &mut Vec<ContractionPair>,
    best_flops: &mut usize,
    best_largest: &mut usize,
) {
    if current.len() <= 1 {
        if current_flops < *best_flops {
            *best_flops = current_flops;
            *best_path = current_path.clone();
            *best_largest = current_largest;
        }
        return;
    }
    if current_flops >= *best_flops {
        return;
    }
    for i in 0..current.len() {
        for j in (i + 1)..current.len() {
            let (cost, _, _) = contraction_cost(&current[i], &current[j], output_sub, index_sizes);
            let next_flops = current_flops.saturating_add(cost);
            // The recursive call would reject this branch immediately (either as
            // a non-improving leaf or via the bound check), so skip building it.
            if next_flops >= *best_flops {
                continue;
            }
            let mut next: Vec<OperandDesc> = current
                .iter()
                .enumerate()
                .filter(|&(k, _)| k != i && k != j)
                .map(|(_, d)| d.clone())
                .collect();
            let (result_indices, result_shape) = pairwise_result_indices(
                &current[i],
                &current[j],
                &next,
                output_sub,
                index_sizes,
            );
            let result_size = result_shape
                .iter()
                .copied()
                .fold(1usize, usize::saturating_mul)
                .max(1);
            next.push(OperandDesc {
                indices: result_indices,
            });
            current_path.push(ContractionPair {
                first: i,
                second: j,
            });
            find_optimal(
                &next,
                output_sub,
                index_sizes,
                current_path,
                next_flops,
                current_largest.max(result_size),
                best_path,
                best_flops,
                best_largest,
            );
            current_path.pop();
        }
    }
}
/// Executes a planned contraction order pairwise with [`einsum`].
///
/// `input_subs[k]` must be the index list of `operands[k]`. `path_info.path`
/// must follow the planners' convention: each step removes positions `second`
/// then `first` from the working list and appends their contraction result.
/// Each pairwise output keeps every index still needed by the remaining
/// operands or by `final_output`.
///
/// Once a single tensor remains, it is returned as-is if its index list equals
/// `final_output`. Otherwise it is mapped to `final_output` with one more
/// `einsum` call.
///
/// Input tensors are borrowed rather than copied; only intermediates (and a
/// directly returned input) are owned.
///
/// # Errors
///
/// Propagates any error from [`einsum`]. Returns `InvalidArgument` if the path
/// does not reduce the operands to exactly one tensor.
///
/// # Panics
///
/// Panics if a path step refers to a position past the end of the working list.
fn execute_path<T: Scalar>(
    _subscripts: &str,
    operands: &[&Tensor<T>],
    path_info: &PathInfo,
    input_subs: &[Vec<char>],
    final_output: &[char],
) -> Result<Tensor<T>> {
    // Inputs are only read, so borrow them instead of deep-copying every tensor.
    let mut tensors = input_subs
        .iter()
        .zip(operands.iter())
        .map(|(indices, &t)| (indices.clone(), std::borrow::Cow::Borrowed(t)))
        .collect::<Vec<_>>();
    for step in &path_info.path {
        // Planners record `first < second`; removing `second` first keeps
        // `first` pointing at the same tensor.
        let (b_indices, b_tensor) = tensors.remove(step.second);
        let (a_indices, a_tensor) = tensors.remove(step.first);
        let remaining_descs: Vec<OperandDesc> = tensors
            .iter()
            .map(|(indices, _)| OperandDesc {
                indices: indices.clone(),
            })
            .collect();
        let mut local_sizes: BTreeMap<char, usize> = BTreeMap::new();
        for (c, &s) in a_indices.iter().zip(a_tensor.shape().iter()) {
            local_sizes.insert(*c, s);
        }
        for (c, &s) in b_indices.iter().zip(b_tensor.shape().iter()) {
            local_sizes.insert(*c, s);
        }
        let a_desc = OperandDesc { indices: a_indices };
        let b_desc = OperandDesc { indices: b_indices };
        let (result_indices, _result_shape) = pairwise_result_indices(
            &a_desc,
            &b_desc,
            &remaining_descs,
            final_output,
            &local_sizes,
        );
        let a_sub: String = a_desc.indices.iter().collect();
        let b_sub: String = b_desc.indices.iter().collect();
        let out_sub: String = result_indices.iter().collect();
        let pair_subscripts = format!("{a_sub},{b_sub}->{out_sub}");
        let result = einsum(&pair_subscripts, &[&*a_tensor, &*b_tensor])?;
        tensors.push((result_indices, std::borrow::Cow::Owned(result)));
    }

    if tensors.len() != 1 {
        return Err(CoreError::InvalidArgument {
            reason: "einsum path execution did not reduce to a single tensor",
        });
    }
    let (current_indices, tensor) = tensors
        .pop()
        .expect("working list holds exactly one tensor (checked above)");
    if current_indices == final_output {
        Ok(tensor.into_owned())
    } else {
        let cur_sub: String = current_indices.iter().collect();
        let out_sub: String = final_output.iter().collect();
        let reorder = format!("{cur_sub}->{out_sub}");
        einsum(&reorder, &[&*tensor])
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// All-ones `f64` tensor with the given shape.
    fn ones(shape: &[usize]) -> Tensor<f64> {
        let len: usize = shape.iter().product();
        Tensor::from_vec(vec![1.0_f64; len], shape.to_vec()).unwrap()
    }

    /// Element-wise comparison with an absolute tolerance of `1e-10`.
    fn assert_all_close(expected: &Tensor<f64>, actual: &Tensor<f64>) {
        let pairs = expected.as_slice().iter().zip(actual.as_slice().iter());
        for (a, b) in pairs {
            assert!((a - b).abs() < 1e-10, "mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn test_einsum_path_chain_matmul() {
        let (a, b, c) = (ones(&[2, 3]), ones(&[3, 4]), ones(&[4, 5]));
        let info = einsum_path("ij,jk,kl->il", &[&a, &b, &c], PathStrategy::Greedy).unwrap();
        assert_eq!(info.path.len(), 2);
        assert!(info.flops > 0);
    }

    #[test]
    fn test_einsum_path_optimal_vs_greedy() {
        let (a, b, c) = (ones(&[2, 3]), ones(&[3, 4]), ones(&[4, 5]));
        let ops = [&a, &b, &c];
        let greedy = einsum_path("ij,jk,kl->il", &ops, PathStrategy::Greedy).unwrap();
        let optimal = einsum_path("ij,jk,kl->il", &ops, PathStrategy::Optimal).unwrap();
        assert!(optimal.flops <= greedy.flops);
        assert_eq!(optimal.path.len(), 2);
    }

    #[test]
    fn test_einsum_optimized_chain_matmul() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let b = Tensor::from_vec(
            vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
            vec![3, 4],
        )
        .unwrap();
        let c = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 1.0, 0.0,
            ],
            vec![4, 5],
        )
        .unwrap();
        let ops = [&a, &b, &c];
        let direct = einsum("ij,jk,kl->il", &ops).unwrap();
        let optimized = einsum_optimized("ij,jk,kl->il", &ops, PathStrategy::Greedy).unwrap();
        assert_eq!(direct.shape(), optimized.shape());
        assert_all_close(&direct, &optimized);
    }

    #[test]
    fn test_einsum_optimized_four_operands() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let c = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        let d = Tensor::from_vec(vec![2.0, 1.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let ops = [&a, &b, &c, &d];
        let direct = einsum("ij,jk,kl,lm->im", &ops).unwrap();
        let optimized = einsum_optimized("ij,jk,kl,lm->im", &ops, PathStrategy::Optimal).unwrap();
        assert_eq!(direct.shape(), optimized.shape());
        assert_all_close(&direct, &optimized);
    }

    #[test]
    fn test_einsum_path_two_operands() {
        let (a, b) = (ones(&[2, 3]), ones(&[3, 2]));
        let result = einsum_optimized("ij,jk->ik", &[&a, &b], PathStrategy::Greedy).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
    }

    #[test]
    fn test_einsum_path_single_operand() {
        let a = ones(&[2, 2]);
        let info = einsum_path("ij->ji", &[&a], PathStrategy::Greedy).unwrap();
        assert!(info.path.is_empty());
    }

    #[test]
    fn test_einsum_path_asymmetric_shapes() {
        let (a, b, c) = (ones(&[100, 2]), ones(&[2, 3]), ones(&[3, 100]));
        let info = einsum_path("ij,jk,kl->il", &[&a, &b, &c], PathStrategy::Greedy).unwrap();
        assert_eq!(info.path.len(), 2);
        assert!(info.flops > 0);
    }

    #[test]
    fn test_einsum_optimized_with_trace() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        let c = Tensor::from_vec(vec![3.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
        let ops = [&a, &b, &c];
        let direct = einsum("ij,jk,kk->i", &ops).unwrap();
        let optimized = einsum_optimized("ij,jk,kk->i", &ops, PathStrategy::Greedy).unwrap();
        assert_eq!(direct.shape(), optimized.shape());
        let pairs = direct.as_slice().iter().zip(optimized.as_slice().iter());
        for (x, y) in pairs {
            assert!((x - y).abs() < 1e-10);
        }
    }
}