use crate::numeric::{Numeric, NumericRef};
use crate::tensors::indexing::DynamicShapeIterator;
use crate::tensors::views::{TensorRef, TensorRename, TensorView};
use crate::tensors::{Dimension, Tensor};
use std::collections::HashSet;
use std::error::Error;
use std::fmt;
/// Entry point for building an Einstein summation over one to six tensors.
///
/// The type holds no data; it only namespaces the `with_*` constructors.
/// Its private unit field prevents code outside this module from building
/// it with a struct literal.
#[derive(Clone, Debug, Default)]
pub struct Einsum {
    _private: (),
}
impl Einsum {
    /// Starts an einsum over one input.
    ///
    /// The input is converted into a [`TensorView`]; follow up with `named`
    /// to assign dimension names and `to` to evaluate.
    pub fn with_1<T, S, I, const D: usize>(input_1: I) -> Einsum1<T, S, D>
    where
        S: TensorRef<T, D>,
        I: Into<TensorView<T, S, D>>,
    {
        let tensor_1: TensorView<T, S, D> = input_1.into();
        Einsum1 { tensor_1 }
    }

    /// Starts an einsum over two inputs, each converted into a [`TensorView`].
    pub fn with_2<T, S1, S2, I1, I2, const D1: usize, const D2: usize>(
        input_1: I1,
        input_2: I2,
    ) -> Einsum2<T, S1, S2, D1, D2>
    where
        S1: TensorRef<T, D1>,
        S2: TensorRef<T, D2>,
        I1: Into<TensorView<T, S1, D1>>,
        I2: Into<TensorView<T, S2, D2>>,
    {
        let tensor_1: TensorView<T, S1, D1> = input_1.into();
        let tensor_2: TensorView<T, S2, D2> = input_2.into();
        Einsum2 { tensor_1, tensor_2 }
    }

    /// Starts an einsum over three inputs, each converted into a [`TensorView`].
    pub fn with_3<T, S1, S2, S3, I1, I2, I3, const D1: usize, const D2: usize, const D3: usize>(
        input_1: I1,
        input_2: I2,
        input_3: I3,
    ) -> Einsum3<T, S1, S2, S3, D1, D2, D3>
    where
        S1: TensorRef<T, D1>,
        S2: TensorRef<T, D2>,
        S3: TensorRef<T, D3>,
        I1: Into<TensorView<T, S1, D1>>,
        I2: Into<TensorView<T, S2, D2>>,
        I3: Into<TensorView<T, S3, D3>>,
    {
        let tensor_1: TensorView<T, S1, D1> = input_1.into();
        let tensor_2: TensorView<T, S2, D2> = input_2.into();
        let tensor_3: TensorView<T, S3, D3> = input_3.into();
        Einsum3 {
            tensor_1,
            tensor_2,
            tensor_3,
        }
    }

    /// Starts an einsum over four inputs, each converted into a [`TensorView`].
    pub fn with_4<
        T,
        S1,
        S2,
        S3,
        S4,
        I1,
        I2,
        I3,
        I4,
        const D1: usize,
        const D2: usize,
        const D3: usize,
        const D4: usize,
    >(
        input_1: I1,
        input_2: I2,
        input_3: I3,
        input_4: I4,
    ) -> Einsum4<T, S1, S2, S3, S4, D1, D2, D3, D4>
    where
        S1: TensorRef<T, D1>,
        S2: TensorRef<T, D2>,
        S3: TensorRef<T, D3>,
        S4: TensorRef<T, D4>,
        I1: Into<TensorView<T, S1, D1>>,
        I2: Into<TensorView<T, S2, D2>>,
        I3: Into<TensorView<T, S3, D3>>,
        I4: Into<TensorView<T, S4, D4>>,
    {
        let tensor_1: TensorView<T, S1, D1> = input_1.into();
        let tensor_2: TensorView<T, S2, D2> = input_2.into();
        let tensor_3: TensorView<T, S3, D3> = input_3.into();
        let tensor_4: TensorView<T, S4, D4> = input_4.into();
        Einsum4 {
            tensor_1,
            tensor_2,
            tensor_3,
            tensor_4,
        }
    }

    /// Starts an einsum over five inputs, each converted into a [`TensorView`].
    pub fn with_5<
        T,
        S1,
        S2,
        S3,
        S4,
        S5,
        I1,
        I2,
        I3,
        I4,
        I5,
        const D1: usize,
        const D2: usize,
        const D3: usize,
        const D4: usize,
        const D5: usize,
    >(
        input_1: I1,
        input_2: I2,
        input_3: I3,
        input_4: I4,
        input_5: I5,
    ) -> Einsum5<T, S1, S2, S3, S4, S5, D1, D2, D3, D4, D5>
    where
        S1: TensorRef<T, D1>,
        S2: TensorRef<T, D2>,
        S3: TensorRef<T, D3>,
        S4: TensorRef<T, D4>,
        S5: TensorRef<T, D5>,
        I1: Into<TensorView<T, S1, D1>>,
        I2: Into<TensorView<T, S2, D2>>,
        I3: Into<TensorView<T, S3, D3>>,
        I4: Into<TensorView<T, S4, D4>>,
        I5: Into<TensorView<T, S5, D5>>,
    {
        let tensor_1: TensorView<T, S1, D1> = input_1.into();
        let tensor_2: TensorView<T, S2, D2> = input_2.into();
        let tensor_3: TensorView<T, S3, D3> = input_3.into();
        let tensor_4: TensorView<T, S4, D4> = input_4.into();
        let tensor_5: TensorView<T, S5, D5> = input_5.into();
        Einsum5 {
            tensor_1,
            tensor_2,
            tensor_3,
            tensor_4,
            tensor_5,
        }
    }

    /// Starts an einsum over six inputs, each converted into a [`TensorView`].
    pub fn with_6<
        T,
        S1,
        S2,
        S3,
        S4,
        S5,
        S6,
        I1,
        I2,
        I3,
        I4,
        I5,
        I6,
        const D1: usize,
        const D2: usize,
        const D3: usize,
        const D4: usize,
        const D5: usize,
        const D6: usize,
    >(
        input_1: I1,
        input_2: I2,
        input_3: I3,
        input_4: I4,
        input_5: I5,
        input_6: I6,
    ) -> Einsum6<T, S1, S2, S3, S4, S5, S6, D1, D2, D3, D4, D5, D6>
    where
        S1: TensorRef<T, D1>,
        S2: TensorRef<T, D2>,
        S3: TensorRef<T, D3>,
        S4: TensorRef<T, D4>,
        S5: TensorRef<T, D5>,
        S6: TensorRef<T, D6>,
        I1: Into<TensorView<T, S1, D1>>,
        I2: Into<TensorView<T, S2, D2>>,
        I3: Into<TensorView<T, S3, D3>>,
        I4: Into<TensorView<T, S4, D4>>,
        I5: Into<TensorView<T, S5, D5>>,
        I6: Into<TensorView<T, S6, D6>>,
    {
        let tensor_1: TensorView<T, S1, D1> = input_1.into();
        let tensor_2: TensorView<T, S2, D2> = input_2.into();
        let tensor_3: TensorView<T, S3, D3> = input_3.into();
        let tensor_4: TensorView<T, S4, D4> = input_4.into();
        let tensor_5: TensorView<T, S5, D5> = input_5.into();
        let tensor_6: TensorView<T, S6, D6> = input_6.into();
        Einsum6 {
            tensor_1,
            tensor_2,
            tensor_3,
            tensor_4,
            tensor_5,
            tensor_6,
        }
    }
}
/// Error from evaluating an einsum whose inputs disagree about a dimension.
///
/// It is raised when inputs that share a dimension name give it different
/// lengths. It is also raised when a requested output dimension is not
/// present in any input; in that case every entry of `lengths` is `None`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InconsistentDimensionLengthError<const I: usize> {
    /// For each input, in order, the length it gives `dimension`, or `None`
    /// if that input does not have the dimension.
    pub lengths: [Option<usize>; I],
    /// The dimension whose lengths could not be reconciled.
    pub dimension: Dimension,
}

impl<const I: usize> fmt::Display for InconsistentDimensionLengthError<I> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { lengths, dimension } = self;
        write!(
            f,
            "inconsistent dimension lengths for dimension '{}': {:?}, lengths must match when repeated in different shapes as the same dimension name",
            dimension, lengths,
        )
    }
}

impl<const I: usize> Error for InconsistentDimensionLengthError<I> {}
#[test]
fn test_inconsistent_dimension_length_error() {
    // The message names the dimension and shows each input's length in order.
    let expected = "inconsistent dimension lengths for dimension 'a': [Some(3), None, Some(2)], lengths must match when repeated in different shapes as the same dimension name";
    let error = InconsistentDimensionLengthError {
        dimension: "a",
        lengths: [Some(3), None, Some(2)],
    };
    assert_eq!(error.to_string(), expected)
}
#[allow(dead_code)]
#[derive(Clone, Debug, Eq, PartialEq)]
struct Contraction {
tensor_indexes: Vec<usize>,
}
#[allow(dead_code)]
impl Contraction {
fn from(tensor_indexes: Vec<usize>) -> Contraction {
Contraction { tensor_indexes }
}
fn indexes(&self) -> &[usize] {
&self.tensor_indexes
}
}
#[allow(dead_code)]
#[derive(Clone, Debug, Eq, PartialEq)]
struct StepByStepContractionResult {
input_shapes_left: Vec<Vec<(Dimension, usize)>>,
contraction_output: Vec<(Dimension, usize)>,
}
#[allow(dead_code)]
fn step_by_step_contraction(
input_shapes_left: &[&[(Dimension, usize)]],
output_shape: &[(Dimension, usize)],
contraction: &Contraction,
) -> StepByStepContractionResult {
let contracting: Vec<&[(Dimension, usize)]> = contraction
.tensor_indexes
.iter()
.map(|index| input_shapes_left[*index])
.collect();
let not_contracting_yet: Vec<&[(Dimension, usize)]> = input_shapes_left
.iter()
.enumerate()
.filter(|(i, _)| !contraction.tensor_indexes.contains(i))
.map(|(_, s)| *s)
.collect();
let contracting_dimensions: Vec<(Dimension, usize)> = {
let mut seen = HashSet::new();
let mut set = Vec::new();
for shape in &contracting {
for d in shape.iter() {
let new = seen.insert(*d);
if new {
set.push(*d);
}
}
}
set
};
let retained_dimensions: Vec<(Dimension, usize)> = {
let mut seen = HashSet::new();
let mut set = Vec::new();
for shape in ¬_contracting_yet {
for d in shape.iter() {
let new = seen.insert(*d);
if new {
set.push(*d);
}
}
}
for d in output_shape.iter() {
let new = seen.insert(*d);
if new {
set.push(*d);
}
}
set
};
let contraction_output: Vec<(Dimension, usize)> = {
let mut intersection = retained_dimensions.clone();
intersection.retain(|shape| contracting_dimensions.contains(shape));
intersection
};
let new_input_shapes_left = {
let mut vec = Vec::with_capacity(not_contracting_yet.len() + 1);
for d in not_contracting_yet.iter() {
vec.push(d.to_vec());
}
vec.push(contraction_output.clone());
vec
};
StepByStepContractionResult {
contraction_output,
input_shapes_left: new_input_shapes_left,
}
}
/// Looks up the length of `output_dimension` across all einsum inputs.
///
/// Each input contributes the length of its first entry named
/// `output_dimension`, if it has one.
///
/// # Errors
///
/// Returns [`InconsistentDimensionLengthError`] in two cases: when no input
/// has the dimension (every entry of `lengths` is then `None`), or when two
/// inputs that do have it disagree on its length.
fn length_of<const I: usize>(
    output_dimension: Dimension,
    input: &[&[(Dimension, usize)]; I],
) -> Result<usize, InconsistentDimensionLengthError<I>> {
    let lengths: [Option<usize>; I] = std::array::from_fn(|i| {
        input[i]
            .iter()
            .find_map(|&(dimension, length)| (dimension == output_dimension).then_some(length))
    });
    let mut present = lengths.iter().flatten().copied();
    let first = present.next();
    let consistent = present.all(|other| Some(other) == first);
    match first {
        Some(length) if consistent => Ok(length),
        _ => Err(InconsistentDimensionLengthError {
            lengths,
            dimension: output_dimension,
        }),
    }
}
/// Wraps the source of `tensor` in a [`TensorRename`] carrying `dimensions`
/// as its new dimension names, and returns a [`TensorView`] over it.
///
/// NOTE(review): `#[track_caller]` presumably exists so that a panic from
/// `TensorRename::from` (e.g. on invalid names) points at the einsum caller —
/// confirm against `TensorRename`.
#[track_caller]
fn tensor_with_name<T, I, S, const D: usize>(
    dimensions: [Dimension; D],
    tensor: I,
) -> TensorView<T, TensorRename<T, S, D>, D>
where
    I: Into<TensorView<T, S, D>>,
    S: TensorRef<T, D>,
{
    let view: TensorView<T, S, D> = tensor.into();
    TensorView::from(TensorRename::from(view.source(), dimensions))
}
/// Pairs each dimension named in `output` with its length taken from the
/// inputs, keeping the order of `output`.
///
/// # Errors
///
/// Propagates the first error from `length_of`, checking `output` in order.
fn output_shape_for<const I: usize, const O: usize>(
    input: &[&[(Dimension, usize)]; I],
    output: &[Dimension; O],
) -> Result<[(Dimension, usize); O], InconsistentDimensionLengthError<I>> {
    let mut shape = output.map(|dimension| (dimension, 0));
    for (dimension, length) in shape.iter_mut() {
        *length = length_of(*dimension, input)?;
    }
    Ok(shape)
}
/// Collects the dimensions that are summed over.
///
/// These are the dimensions that appear in some input but not in `output`.
/// Each is paired with its length and listed once, in first-seen order
/// across the inputs.
///
/// # Errors
///
/// Returns [`InconsistentDimensionLengthError`] if such a dimension has
/// different lengths in different places. Its `lengths` report the first
/// occurrence of the dimension in each input.
fn summation_dimensions<const I: usize, const O: usize>(
    input: &[&[(Dimension, usize)]; I],
    output: &[Dimension; O],
) -> Result<Vec<(Dimension, usize)>, InconsistentDimensionLengthError<I>> {
    // At most every input dimension can be summed over.
    let capacity: usize = input.iter().map(|shape| shape.len()).sum();
    let mut summed: Vec<(Dimension, usize)> = Vec::with_capacity(capacity);
    for &(dimension, length) in input.iter().flat_map(|shape| shape.iter()) {
        if output.contains(&dimension) {
            continue;
        }
        let previous = summed
            .iter()
            .find(|(seen, _)| *seen == dimension)
            .map(|&(_, seen_length)| seen_length);
        match previous {
            None => summed.push((dimension, length)),
            Some(seen_length) if seen_length != length => {
                return Err(InconsistentDimensionLengthError {
                    lengths: std::array::from_fn(|i| {
                        input[i]
                            .iter()
                            .find(|(d, _)| *d == dimension)
                            .map(|&(_, l)| l)
                    }),
                    dimension,
                });
            }
            Some(_) => {}
        }
    }
    Ok(summed)
}
/// Builds the index into one input tensor from an output position.
///
/// Each of the input's dimensions takes the entry of `outer_indexes` whose
/// label in `outer_shape` has the same name.
///
/// # Panics
///
/// If a dimension of `input_shape` is not named in `outer_shape`.
fn filter_outer_indexes<const D: usize, const O: usize>(
    outer_indexes: &[usize; O],
    outer_shape: &[(Dimension, usize); O],
    input_shape: &[(Dimension, usize); D],
) -> [usize; D] {
    std::array::from_fn(|d| {
        let (dimension, _) = input_shape[d];
        let position = outer_shape
            .iter()
            .position(|&(candidate, _)| candidate == dimension);
        match position {
            Some(o) => outer_indexes[o],
            None => panic!(
                "Expected to find an index for dimension {:?} but was not present in {:?} for {:?} while trying to index tensor of shape {:?}",
                dimension,
                outer_indexes,
                outer_shape,
                input_shape,
            ),
        }
    })
}
/// Builds the index into one input tensor for a single summation step.
///
/// Each of the input's dimensions takes its index from the output position
/// (`outer_indexes`, labelled by `outer_shape`) when the output has that
/// dimension. Otherwise it takes it from the current summation position
/// (`summation_indexes`, labelled by `summation_shape`). The output label is
/// checked first, and the summation labels are scanned only when it misses.
///
/// # Panics
///
/// If a dimension of `input_shape` is named in neither `outer_shape` nor
/// `summation_shape`.
fn filter_outer_and_summation_indexes<const D: usize, const O: usize>(
    outer_indexes: &[usize; O],
    outer_shape: &[(Dimension, usize); O],
    summation_indexes: &[usize],
    summation_shape: &[(Dimension, usize)],
    input_shape: &[(Dimension, usize); D],
) -> [usize; D] {
    /// Index paired with `dimension` in `shape`/`indexes`, if the name is present.
    fn lookup(
        dimension: Dimension,
        shape: &[(Dimension, usize)],
        indexes: &[usize],
    ) -> Option<usize> {
        shape
            .iter()
            .zip(indexes)
            .find(|(pair, _)| pair.0 == dimension)
            .map(|(_, index)| *index)
    }

    std::array::from_fn(|d| {
        let dimension = input_shape[d].0;
        // The output index takes precedence, which also stops a same-named
        // summation label from overriding it.
        let found = lookup(dimension, &outer_shape[..], &outer_indexes[..])
            .or_else(|| lookup(dimension, summation_shape, summation_indexes));
        match found {
            Some(index) => index,
            None => panic!(
                "Expected to find an index for dimension {:?} but was not present in {:?} for {:?} or {:?} for {:?} while trying to index tensor of shape {:?}",
                dimension,
                outer_indexes,
                outer_shape,
                summation_indexes,
                summation_shape,
                input_shape,
            ),
        }
    })
}
/// Einsum builder over one input, created by [`Einsum::with_1`].
///
/// Call `named` to assign dimension names, then `to` to evaluate.
pub struct Einsum1<T, S1, const D1: usize> {
    tensor_1: TensorView<T, S1, D1>,
}

/// Einsum builder over two inputs, created by [`Einsum::with_2`].
pub struct Einsum2<T, S1, S2, const D1: usize, const D2: usize> {
    tensor_1: TensorView<T, S1, D1>,
    tensor_2: TensorView<T, S2, D2>,
}

/// Einsum builder over three inputs, created by [`Einsum::with_3`].
pub struct Einsum3<T, S1, S2, S3, const D1: usize, const D2: usize, const D3: usize> {
    tensor_1: TensorView<T, S1, D1>,
    tensor_2: TensorView<T, S2, D2>,
    tensor_3: TensorView<T, S3, D3>,
}

/// Einsum builder over four inputs, created by [`Einsum::with_4`].
pub struct Einsum4<
    T,
    S1,
    S2,
    S3,
    S4,
    const D1: usize,
    const D2: usize,
    const D3: usize,
    const D4: usize,
> {
    tensor_1: TensorView<T, S1, D1>,
    tensor_2: TensorView<T, S2, D2>,
    tensor_3: TensorView<T, S3, D3>,
    tensor_4: TensorView<T, S4, D4>,
}

/// Einsum builder over five inputs, created by [`Einsum::with_5`].
pub struct Einsum5<
    T,
    S1,
    S2,
    S3,
    S4,
    S5,
    const D1: usize,
    const D2: usize,
    const D3: usize,
    const D4: usize,
    const D5: usize,
> {
    tensor_1: TensorView<T, S1, D1>,
    tensor_2: TensorView<T, S2, D2>,
    tensor_3: TensorView<T, S3, D3>,
    tensor_4: TensorView<T, S4, D4>,
    tensor_5: TensorView<T, S5, D5>,
}

/// Einsum builder over six inputs, created by [`Einsum::with_6`].
pub struct Einsum6<
    T,
    S1,
    S2,
    S3,
    S4,
    S5,
    S6,
    const D1: usize,
    const D2: usize,
    const D3: usize,
    const D4: usize,
    const D5: usize,
    const D6: usize,
> {
    tensor_1: TensorView<T, S1, D1>,
    tensor_2: TensorView<T, S2, D2>,
    tensor_3: TensorView<T, S3, D3>,
    tensor_4: TensorView<T, S4, D4>,
    tensor_5: TensorView<T, S5, D5>,
    tensor_6: TensorView<T, S6, D6>,
}
impl<T, S1, const D1: usize> Einsum1<T, S1, D1> {
    /// Renames the input's dimensions to `input_1`.
    ///
    /// NOTE(review): any validation of the names (and any panic) happens in
    /// `TensorRename::from`; confirm its rules there.
    #[track_caller]
    pub fn named(self, input_1: [Dimension; D1]) -> Einsum1<T, TensorRename<T, S1, D1>, D1>
    where
        S1: TensorRef<T, D1>,
    {
        let Self { tensor_1 } = self;
        Einsum1 {
            tensor_1: tensor_with_name(input_1, tensor_1),
        }
    }

    /// Evaluates the einsum into a tensor whose dimensions are `output`.
    ///
    /// Each result element sums the input's elements over every combination
    /// of positions along the dimensions missing from `output`. When every
    /// input dimension is in `output`, each element is simply copied.
    ///
    /// # Errors
    ///
    /// If a dimension in `output` is missing from the input, or a summed
    /// dimension has inconsistent lengths.
    pub fn to<const O: usize>(
        self,
        output: [Dimension; O],
    ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<1>>
    where
        T: Numeric,
        for<'a> &'a T: NumericRef<T>,
        S1: TensorRef<T, D1>,
    {
        let shape_1 = self.tensor_1.shape();
        let shapes = &[&shape_1[..]];
        let output_shape = output_shape_for(shapes, &output)?;
        let mut result = Tensor::empty(output_shape, T::zero());
        let summed = summation_dimensions(shapes, &output)?;
        let view_1 = self.tensor_1.index();
        for (outer, slot) in result.index_mut().iter_reference_mut().with_index() {
            let mut total = T::zero();
            if summed.is_empty() {
                // Nothing is summed away: one input position per output position.
                let value_1 =
                    view_1.get_ref(filter_outer_indexes(&outer, &output_shape, &shape_1));
                total = total + value_1;
            } else {
                let mut positions = DynamicShapeIterator::from(&summed);
                while let Some(inner) = positions.next() {
                    let value_1 = view_1.get_ref(filter_outer_and_summation_indexes(
                        &outer, &output_shape, inner, &summed, &shape_1,
                    ));
                    total = total + value_1;
                }
            }
            *slot = total;
        }
        Ok(result)
    }
}
impl<T, S1, S2, const D1: usize, const D2: usize> Einsum2<T, S1, S2, D1, D2> {
    /// Renames each input's dimensions, in order, to the given names.
    ///
    /// NOTE(review): any validation of the names (and any panic) happens in
    /// `TensorRename::from`; confirm its rules there.
    #[track_caller]
    pub fn named(
        self,
        input_1: [Dimension; D1],
        input_2: [Dimension; D2],
    ) -> Einsum2<T, TensorRename<T, S1, D1>, TensorRename<T, S2, D2>, D1, D2>
    where
        S1: TensorRef<T, D1>,
        S2: TensorRef<T, D2>,
    {
        let Self { tensor_1, tensor_2 } = self;
        Einsum2 {
            tensor_1: tensor_with_name(input_1, tensor_1),
            tensor_2: tensor_with_name(input_2, tensor_2),
        }
    }

    /// Evaluates the einsum into a tensor whose dimensions are `output`.
    ///
    /// Each result element sums the product of the inputs' elements over every
    /// combination of positions along the dimensions missing from `output`.
    ///
    /// # Errors
    ///
    /// If a dimension in `output` is missing from every input, or a shared or
    /// summed dimension has inconsistent lengths.
    pub fn to<const O: usize>(
        self,
        output: [Dimension; O],
    ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<2>>
    where
        T: Numeric,
        for<'a> &'a T: NumericRef<T>,
        S1: TensorRef<T, D1>,
        S2: TensorRef<T, D2>,
    {
        let shape_1 = self.tensor_1.shape();
        let shape_2 = self.tensor_2.shape();
        let shapes = &[&shape_1[..], &shape_2[..]];
        let output_shape = output_shape_for(shapes, &output)?;
        let mut result = Tensor::empty(output_shape, T::zero());
        let summed = summation_dimensions(shapes, &output)?;
        let view_1 = self.tensor_1.index();
        let view_2 = self.tensor_2.index();
        for (outer, slot) in result.index_mut().iter_reference_mut().with_index() {
            let mut total = T::zero();
            if summed.is_empty() {
                // Nothing is summed away: one input position per output position.
                let value_1 =
                    view_1.get_ref(filter_outer_indexes(&outer, &output_shape, &shape_1));
                let value_2 =
                    view_2.get_ref(filter_outer_indexes(&outer, &output_shape, &shape_2));
                total = total + (value_1 * value_2);
            } else {
                let mut positions = DynamicShapeIterator::from(&summed);
                while let Some(inner) = positions.next() {
                    let value_1 = view_1.get_ref(filter_outer_and_summation_indexes(
                        &outer, &output_shape, inner, &summed, &shape_1,
                    ));
                    let value_2 = view_2.get_ref(filter_outer_and_summation_indexes(
                        &outer, &output_shape, inner, &summed, &shape_2,
                    ));
                    total = total + (value_1 * value_2);
                }
            }
            *slot = total;
        }
        Ok(result)
    }
}
impl<T, S1, S2, S3, const D1: usize, const D2: usize, const D3: usize>
    Einsum3<T, S1, S2, S3, D1, D2, D3>
{
    /// Renames each input's dimensions, in order, to the given names.
    ///
    /// NOTE(review): any validation of the names (and any panic) happens in
    /// `TensorRename::from`; confirm its rules there.
    #[track_caller]
    #[allow(clippy::type_complexity)]
    pub fn named(
        self,
        input_1: [Dimension; D1],
        input_2: [Dimension; D2],
        input_3: [Dimension; D3],
    ) -> Einsum3<
        T,
        TensorRename<T, S1, D1>,
        TensorRename<T, S2, D2>,
        TensorRename<T, S3, D3>,
        D1,
        D2,
        D3,
    >
    where
        S1: TensorRef<T, D1>,
        S2: TensorRef<T, D2>,
        S3: TensorRef<T, D3>,
    {
        let Self {
            tensor_1,
            tensor_2,
            tensor_3,
        } = self;
        Einsum3 {
            tensor_1: tensor_with_name(input_1, tensor_1),
            tensor_2: tensor_with_name(input_2, tensor_2),
            tensor_3: tensor_with_name(input_3, tensor_3),
        }
    }

    /// Evaluates the einsum into a tensor whose dimensions are `output`.
    ///
    /// Each result element sums the product of the inputs' elements over every
    /// combination of positions along the dimensions missing from `output`.
    ///
    /// # Errors
    ///
    /// If a dimension in `output` is missing from every input, or a shared or
    /// summed dimension has inconsistent lengths.
    pub fn to<const O: usize>(
        self,
        output: [Dimension; O],
    ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<3>>
    where
        T: Numeric,
        for<'a> &'a T: NumericRef<T>,
        S1: TensorRef<T, D1>,
        S2: TensorRef<T, D2>,
        S3: TensorRef<T, D3>,
    {
        let shape_1 = self.tensor_1.shape();
        let shape_2 = self.tensor_2.shape();
        let shape_3 = self.tensor_3.shape();
        let shapes = &[&shape_1[..], &shape_2[..], &shape_3[..]];
        let output_shape = output_shape_for(shapes, &output)?;
        let mut result = Tensor::empty(output_shape, T::zero());
        let summed = summation_dimensions(shapes, &output)?;
        let view_1 = self.tensor_1.index();
        let view_2 = self.tensor_2.index();
        let view_3 = self.tensor_3.index();
        for (outer, slot) in result.index_mut().iter_reference_mut().with_index() {
            let mut total = T::zero();
            if summed.is_empty() {
                // Nothing is summed away: one input position per output position.
                let value_1 =
                    view_1.get_ref(filter_outer_indexes(&outer, &output_shape, &shape_1));
                let value_2 =
                    view_2.get_ref(filter_outer_indexes(&outer, &output_shape, &shape_2));
                let value_3 =
                    view_3.get_ref(filter_outer_indexes(&outer, &output_shape, &shape_3));
                total = total + (value_1 * value_2 * value_3);
            } else {
                let mut positions = DynamicShapeIterator::from(&summed);
                while let Some(inner) = positions.next() {
                    let value_1 = view_1.get_ref(filter_outer_and_summation_indexes(
                        &outer, &output_shape, inner, &summed, &shape_1,
                    ));
                    let value_2 = view_2.get_ref(filter_outer_and_summation_indexes(
                        &outer, &output_shape, inner, &summed, &shape_2,
                    ));
                    let value_3 = view_3.get_ref(filter_outer_and_summation_indexes(
                        &outer, &output_shape, inner, &summed, &shape_3,
                    ));
                    total = total + (value_1 * value_2 * value_3);
                }
            }
            *slot = total;
        }
        Ok(result)
    }
}
impl<T, S1, S2, S3, S4, const D1: usize, const D2: usize, const D3: usize, const D4: usize>
Einsum4<T, S1, S2, S3, S4, D1, D2, D3, D4>
{
#[track_caller]
#[allow(clippy::type_complexity)]
pub fn named(
self,
input_1: [Dimension; D1],
input_2: [Dimension; D2],
input_3: [Dimension; D3],
input_4: [Dimension; D4],
) -> Einsum4<
T,
TensorRename<T, S1, D1>,
TensorRename<T, S2, D2>,
TensorRename<T, S3, D3>,
TensorRename<T, S4, D4>,
D1,
D2,
D3,
D4,
>
where
S1: TensorRef<T, D1>,
S2: TensorRef<T, D2>,
S3: TensorRef<T, D3>,
S4: TensorRef<T, D4>,
{
Einsum4 {
tensor_1: tensor_with_name(input_1, self.tensor_1),
tensor_2: tensor_with_name(input_2, self.tensor_2),
tensor_3: tensor_with_name(input_3, self.tensor_3),
tensor_4: tensor_with_name(input_4, self.tensor_4),
}
}
pub fn to<const O: usize>(
self,
output: [Dimension; O],
) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<4>>
where
T: Numeric,
for<'a> &'a T: NumericRef<T>,
S1: TensorRef<T, D1>,
S2: TensorRef<T, D2>,
S3: TensorRef<T, D3>,
S4: TensorRef<T, D4>,
{
let input_1_shape_const = &self.tensor_1.shape();
let input_1_shape: &[(Dimension, usize)] = input_1_shape_const;
let input_2_shape_const = &self.tensor_2.shape();
let input_2_shape: &[(Dimension, usize)] = input_2_shape_const;
let input_3_shape_const = &self.tensor_3.shape();
let input_3_shape: &[(Dimension, usize)] = input_3_shape_const;
let input_4_shape_const = &self.tensor_4.shape();
let input_4_shape: &[(Dimension, usize)] = input_4_shape_const;
let input = &[input_1_shape, input_2_shape, input_3_shape, input_4_shape];
let output_shape = output_shape_for(input, &output)?;
let mut output_tensor = Tensor::empty(output_shape, T::zero());
let summation_dimensions = summation_dimensions(input, &output)?;
let tensor_1_indexing = self.tensor_1.index();
let tensor_2_indexing = self.tensor_2.index();
let tensor_3_indexing = self.tensor_3.index();
let tensor_4_indexing = self.tensor_4.index();
for (indexes, element) in output_tensor.index_mut().iter_reference_mut().with_index() {
let mut sum = T::zero();
if summation_dimensions.is_empty() {
let product_1 = tensor_1_indexing.get_ref(filter_outer_indexes(
&indexes,
&output_shape,
input_1_shape_const,
));
let product_2 = tensor_2_indexing.get_ref(filter_outer_indexes(
&indexes,
&output_shape,
input_2_shape_const,
));
let product_3 = tensor_3_indexing.get_ref(filter_outer_indexes(
&indexes,
&output_shape,
input_3_shape_const,
));
let product_4 = tensor_4_indexing.get_ref(filter_outer_indexes(
&indexes,
&output_shape,
input_4_shape_const,
));
sum = sum + (product_1 * product_2 * product_3 * product_4);
} else {
let mut summation_iterator = DynamicShapeIterator::from(&summation_dimensions);
loop {
let next = summation_iterator.next();
match next {
Some(summation_indexes) => {
let product_1 =
tensor_1_indexing.get_ref(filter_outer_and_summation_indexes(
&indexes,
&output_shape,
summation_indexes,
&summation_dimensions,
input_1_shape_const,
));
let product_2 =
tensor_2_indexing.get_ref(filter_outer_and_summation_indexes(
&indexes,
&output_shape,
summation_indexes,
&summation_dimensions,
input_2_shape_const,
));
let product_3 =
tensor_3_indexing.get_ref(filter_outer_and_summation_indexes(
&indexes,
&output_shape,
summation_indexes,
&summation_dimensions,
input_3_shape_const,
));
let product_4 =
tensor_4_indexing.get_ref(filter_outer_and_summation_indexes(
&indexes,
&output_shape,
summation_indexes,
&summation_dimensions,
input_4_shape_const,
));
sum = sum + (product_1 * product_2 * product_3 * product_4);
}
None => break,
}
}
}
*element = sum;
}
Ok(output_tensor)
}
}
impl<
        T,
        S1,
        S2,
        S3,
        S4,
        S5,
        const D1: usize,
        const D2: usize,
        const D3: usize,
        const D4: usize,
        const D5: usize,
    > Einsum5<T, S1, S2, S3, S4, S5, D1, D2, D3, D4, D5>
{
    /// Renames each input's dimensions, in order, to the given names.
    ///
    /// NOTE(review): any validation of the names (and any panic) happens in
    /// `TensorRename::from`; confirm its rules there.
    #[track_caller]
    #[allow(clippy::type_complexity)]
    pub fn named(
        self,
        input_1: [Dimension; D1],
        input_2: [Dimension; D2],
        input_3: [Dimension; D3],
        input_4: [Dimension; D4],
        input_5: [Dimension; D5],
    ) -> Einsum5<
        T,
        TensorRename<T, S1, D1>,
        TensorRename<T, S2, D2>,
        TensorRename<T, S3, D3>,
        TensorRename<T, S4, D4>,
        TensorRename<T, S5, D5>,
        D1,
        D2,
        D3,
        D4,
        D5,
    >
    where
        S1: TensorRef<T, D1>,
        S2: TensorRef<T, D2>,
        S3: TensorRef<T, D3>,
        S4: TensorRef<T, D4>,
        S5: TensorRef<T, D5>,
    {
        let Self {
            tensor_1,
            tensor_2,
            tensor_3,
            tensor_4,
            tensor_5,
        } = self;
        Einsum5 {
            tensor_1: tensor_with_name(input_1, tensor_1),
            tensor_2: tensor_with_name(input_2, tensor_2),
            tensor_3: tensor_with_name(input_3, tensor_3),
            tensor_4: tensor_with_name(input_4, tensor_4),
            tensor_5: tensor_with_name(input_5, tensor_5),
        }
    }

    /// Evaluates the einsum into a tensor whose dimensions are `output`.
    ///
    /// Each result element sums the product of the inputs' elements over every
    /// combination of positions along the dimensions missing from `output`.
    ///
    /// # Errors
    ///
    /// If a dimension in `output` is missing from every input, or a shared or
    /// summed dimension has inconsistent lengths.
    pub fn to<const O: usize>(
        self,
        output: [Dimension; O],
    ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<5>>
    where
        T: Numeric,
        for<'a> &'a T: NumericRef<T>,
        S1: TensorRef<T, D1>,
        S2: TensorRef<T, D2>,
        S3: TensorRef<T, D3>,
        S4: TensorRef<T, D4>,
        S5: TensorRef<T, D5>,
    {
        let shape_1 = self.tensor_1.shape();
        let shape_2 = self.tensor_2.shape();
        let shape_3 = self.tensor_3.shape();
        let shape_4 = self.tensor_4.shape();
        let shape_5 = self.tensor_5.shape();
        let shapes = &[
            &shape_1[..],
            &shape_2[..],
            &shape_3[..],
            &shape_4[..],
            &shape_5[..],
        ];
        let output_shape = output_shape_for(shapes, &output)?;
        let mut result = Tensor::empty(output_shape, T::zero());
        let summed = summation_dimensions(shapes, &output)?;
        let view_1 = self.tensor_1.index();
        let view_2 = self.tensor_2.index();
        let view_3 = self.tensor_3.index();
        let view_4 = self.tensor_4.index();
        let view_5 = self.tensor_5.index();
        for (outer, slot) in result.index_mut().iter_reference_mut().with_index() {
            let mut total = T::zero();
            if summed.is_empty() {
                // Nothing is summed away: one input position per output position.
                let value_1 =
                    view_1.get_ref(filter_outer_indexes(&outer, &output_shape, &shape_1));
                let value_2 =
                    view_2.get_ref(filter_outer_indexes(&outer, &output_shape, &shape_2));
                let value_3 =
                    view_3.get_ref(filter_outer_indexes(&outer, &output_shape, &shape_3));
                let value_4 =
                    view_4.get_ref(filter_outer_indexes(&outer, &output_shape, &shape_4));
                let value_5 =
                    view_5.get_ref(filter_outer_indexes(&outer, &output_shape, &shape_5));
                total = total + (value_1 * value_2 * value_3 * value_4 * value_5);
            } else {
                let mut positions = DynamicShapeIterator::from(&summed);
                while let Some(inner) = positions.next() {
                    let value_1 = view_1.get_ref(filter_outer_and_summation_indexes(
                        &outer, &output_shape, inner, &summed, &shape_1,
                    ));
                    let value_2 = view_2.get_ref(filter_outer_and_summation_indexes(
                        &outer, &output_shape, inner, &summed, &shape_2,
                    ));
                    let value_3 = view_3.get_ref(filter_outer_and_summation_indexes(
                        &outer, &output_shape, inner, &summed, &shape_3,
                    ));
                    let value_4 = view_4.get_ref(filter_outer_and_summation_indexes(
                        &outer, &output_shape, inner, &summed, &shape_4,
                    ));
                    let value_5 = view_5.get_ref(filter_outer_and_summation_indexes(
                        &outer, &output_shape, inner, &summed, &shape_5,
                    ));
                    total = total + (value_1 * value_2 * value_3 * value_4 * value_5);
                }
            }
            *slot = total;
        }
        Ok(result)
    }
}
impl<
T,
S1,
S2,
S3,
S4,
S5,
S6,
const D1: usize,
const D2: usize,
const D3: usize,
const D4: usize,
const D5: usize,
const D6: usize,
> Einsum6<T, S1, S2, S3, S4, S5, S6, D1, D2, D3, D4, D5, D6>
{
#[track_caller]
#[allow(clippy::type_complexity)]
pub fn named(
self,
input_1: [Dimension; D1],
input_2: [Dimension; D2],
input_3: [Dimension; D3],
input_4: [Dimension; D4],
input_5: [Dimension; D5],
input_6: [Dimension; D6],
) -> Einsum6<
T,
TensorRename<T, S1, D1>,
TensorRename<T, S2, D2>,
TensorRename<T, S3, D3>,
TensorRename<T, S4, D4>,
TensorRename<T, S5, D5>,
TensorRename<T, S6, D6>,
D1,
D2,
D3,
D4,
D5,
D6,
>
where
S1: TensorRef<T, D1>,
S2: TensorRef<T, D2>,
S3: TensorRef<T, D3>,
S4: TensorRef<T, D4>,
S5: TensorRef<T, D5>,
S6: TensorRef<T, D6>,
{
Einsum6 {
tensor_1: tensor_with_name(input_1, self.tensor_1),
tensor_2: tensor_with_name(input_2, self.tensor_2),
tensor_3: tensor_with_name(input_3, self.tensor_3),
tensor_4: tensor_with_name(input_4, self.tensor_4),
tensor_5: tensor_with_name(input_5, self.tensor_5),
tensor_6: tensor_with_name(input_6, self.tensor_6),
}
}
pub fn to<const O: usize>(
self,
output: [Dimension; O],
) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<6>>
where
T: Numeric,
for<'a> &'a T: NumericRef<T>,
S1: TensorRef<T, D1>,
S2: TensorRef<T, D2>,
S3: TensorRef<T, D3>,
S4: TensorRef<T, D4>,
S5: TensorRef<T, D5>,
S6: TensorRef<T, D6>,
{
let input_1_shape_const = &self.tensor_1.shape();
let input_1_shape: &[(Dimension, usize)] = input_1_shape_const;
let input_2_shape_const = &self.tensor_2.shape();
let input_2_shape: &[(Dimension, usize)] = input_2_shape_const;
let input_3_shape_const = &self.tensor_3.shape();
let input_3_shape: &[(Dimension, usize)] = input_3_shape_const;
let input_4_shape_const = &self.tensor_4.shape();
let input_4_shape: &[(Dimension, usize)] = input_4_shape_const;
let input_5_shape_const = &self.tensor_5.shape();
let input_5_shape: &[(Dimension, usize)] = input_5_shape_const;
let input_6_shape_const = &self.tensor_6.shape();
let input_6_shape: &[(Dimension, usize)] = input_6_shape_const;
let input = &[
input_1_shape,
input_2_shape,
input_3_shape,
input_4_shape,
input_5_shape,
input_6_shape,
];
let output_shape = output_shape_for(input, &output)?;
let mut output_tensor = Tensor::empty(output_shape, T::zero());
let summation_dimensions = summation_dimensions(input, &output)?;
let tensor_1_indexing = self.tensor_1.index();
let tensor_2_indexing = self.tensor_2.index();
let tensor_3_indexing = self.tensor_3.index();
let tensor_4_indexing = self.tensor_4.index();
let tensor_5_indexing = self.tensor_5.index();
let tensor_6_indexing = self.tensor_6.index();
for (indexes, element) in output_tensor.index_mut().iter_reference_mut().with_index() {
let mut sum = T::zero();
if summation_dimensions.is_empty() {
let product_1 = tensor_1_indexing.get_ref(filter_outer_indexes(
&indexes,
&output_shape,
input_1_shape_const,
));
let product_2 = tensor_2_indexing.get_ref(filter_outer_indexes(
&indexes,
&output_shape,
input_2_shape_const,
));
let product_3 = tensor_3_indexing.get_ref(filter_outer_indexes(
&indexes,
&output_shape,
input_3_shape_const,
));
let product_4 = tensor_4_indexing.get_ref(filter_outer_indexes(
&indexes,
&output_shape,
input_4_shape_const,
));
let product_5 = tensor_5_indexing.get_ref(filter_outer_indexes(
&indexes,
&output_shape,
input_5_shape_const,
));
let product_6 = tensor_6_indexing.get_ref(filter_outer_indexes(
&indexes,
&output_shape,
input_6_shape_const,
));
sum = sum + (product_1 * product_2 * product_3 * product_4 * product_5 * product_6);
} else {
let mut summation_iterator = DynamicShapeIterator::from(&summation_dimensions);
loop {
let next = summation_iterator.next();
match next {
Some(summation_indexes) => {
let product_1 =
tensor_1_indexing.get_ref(filter_outer_and_summation_indexes(
&indexes,
&output_shape,
summation_indexes,
&summation_dimensions,
input_1_shape_const,
));
let product_2 =
tensor_2_indexing.get_ref(filter_outer_and_summation_indexes(
&indexes,
&output_shape,
summation_indexes,
&summation_dimensions,
input_2_shape_const,
));
let product_3 =
tensor_3_indexing.get_ref(filter_outer_and_summation_indexes(
&indexes,
&output_shape,
summation_indexes,
&summation_dimensions,
input_3_shape_const,
));
let product_4 =
tensor_4_indexing.get_ref(filter_outer_and_summation_indexes(
&indexes,
&output_shape,
summation_indexes,
&summation_dimensions,
input_4_shape_const,
));
let product_5 =
tensor_5_indexing.get_ref(filter_outer_and_summation_indexes(
&indexes,
&output_shape,
summation_indexes,
&summation_dimensions,
input_5_shape_const,
));
let product_6 =
tensor_6_indexing.get_ref(filter_outer_and_summation_indexes(
&indexes,
&output_shape,
summation_indexes,
&summation_dimensions,
input_6_shape_const,
));
sum = sum
+ (product_1
* product_2
* product_3
* product_4
* product_5
* product_6);
}
None => break,
}
}
}
*element = sum;
}
Ok(output_tensor)
}
}
#[test]
fn step_by_step_contraction_tests() {
    /// Runs one contraction step and compares it against the expected shapes.
    fn expect(
        inputs: &[&[(Dimension, usize)]],
        output: &[(Dimension, usize)],
        indexes: Vec<usize>,
        left: Vec<Vec<(Dimension, usize)>>,
        produced: Vec<(Dimension, usize)>,
    ) {
        let actual = step_by_step_contraction(
            inputs,
            output,
            &Contraction {
                tensor_indexes: indexes,
            },
        );
        let expected = StepByStepContractionResult {
            input_shapes_left: left,
            contraction_output: produced,
        };
        assert_eq!(actual, expected);
    }

    // Matrix multiplication: contracting both inputs leaves only the output.
    expect(
        &[&[("x", 2), ("y", 3)], &[("y", 3), ("z", 4)]],
        &[("x", 2), ("z", 4)],
        vec![0, 1],
        vec![vec![("x", 2), ("z", 4)]],
        vec![("x", 2), ("z", 4)],
    );

    let three: [&[(Dimension, usize)]; 3] = [
        &[("a", 2), ("b", 3), ("d", 5)],
        &[("a", 2), ("c", 4)],
        &[("b", 3), ("d", 5), ("c", 4)],
    ];
    expect(
        &three,
        &[("a", 2), ("c", 4)],
        vec![0, 2],
        vec![vec![("a", 2), ("c", 4)], vec![("a", 2), ("c", 4)]],
        vec![("a", 2), ("c", 4)],
    );
    expect(
        &three,
        &[("a", 2), ("c", 4)],
        vec![0, 1],
        vec![
            vec![("b", 3), ("d", 5), ("c", 4)],
            vec![("b", 3), ("d", 5), ("c", 4), ("a", 2)],
        ],
        vec![("b", 3), ("d", 5), ("c", 4), ("a", 2)],
    );
    expect(
        &three,
        &[("c", 4)],
        vec![0, 1],
        vec![
            vec![("b", 3), ("d", 5), ("c", 4)],
            vec![("b", 3), ("d", 5), ("c", 4)],
        ],
        vec![("b", 3), ("d", 5), ("c", 4)],
    );
}