Skip to main content

easy_ml/tensors/
einsum.rs

1/*!
2 * Einstein summation notation
3 *
4 * A very general purpose sum of products that can represent many
5 * different tensor operations with a single notation.
6 *
7 * See [Einsum].
8 */
9
10use crate::numeric::{Numeric, NumericRef};
11use crate::tensors::indexing::DynamicShapeIterator;
12use crate::tensors::views::{TensorRef, TensorRename, TensorView};
13use crate::tensors::{Dimension, Tensor};
14
15use std::collections::HashSet;
16use std::error::Error;
17use std::fmt;
18
/**
 * Einstein summation notation
 *
 * A very general purpose sum of products that can represent many
 * different tensor operations with a single notation. In Easy-ML,
 * as tensors are already named, their dimension names are used instead
 * of arbitrary characters to refer to dimensions across inputs and
 * the output.
 *
 * Whereas the typical notation used in python libraries
 * is of the form `ab,bc->ac` or `ab->` these would be
 * `Einsum::with_2(&i, &j).to(["a", "c"])` or
 * `Einsum::with_1(&i).to([])` respectively. In scenarios
 * where the existing dimension names in a tensor aren't what you
 * need for the summation notation, there are `named` helper methods
 * to provide an override, so you can perform `ab,bc->ac` with
 * input tensors of different dimension names if you write
 * `Einsum::with_2(&i, &j).named(["a", "b"], ["b", "c"]).to(["a", "c"])`.
 *
 * As with other tensor APIs, the dimension names in a Tensor must be
 * unique, so diagonal summation notation like `aa->` is not supported.
 * Dimension names can and often will be repeated across input tensors and/or
 * the output tensor shape, and each dimension name must have the same length
 * among all of these inputs. APIs will return [InconsistentDimensionLengthError]
 * if a caller passes in inconsistent arguments.
 *
 * See also
 * - [Einsum is All you Need - Einstein Summation in Deep Learning](https://rockt.ai/2018/04/30/einsum)
 * - [Einsum Is All You Need (Video)](https://www.youtube.com/watch?v=pkVwUVEHmfI)
 *
 * # You can do everything with Einsum<sup>1</sup>
 *
 * ```
 * use easy_ml::tensors::Tensor;
 * use easy_ml::tensors::views::TensorView;
 * use easy_ml::tensors::einsum::Einsum;
 *
 * // Note the length of each dimension name needs to be consistent across
 * // inputs. We know this is the case here because these are constructed examples
 * // so we just unwrap each Result from the Einsum APIs.
 * let w = Tensor::from_fn([("a", 4), ("b", 2)], |[a,b]| ((a * 2) + b) as f32);
 * let x = Tensor::from_fn([("b", 2), ("c", 3)], |[b,c]| ((b * 3) + c) as f32);
 * let y = Tensor::from_fn([("b", 2), ("c", 3), ("d", 2)], |[b,c,d]| {
 *     return ((b * 6) + (c * 2) + d) as f32
 * });
 * let z = Tensor::from_fn([("a", 4), ("c", 3)], |[a,c]| ((a * 2) + c) as f32);
 *
 * let w_transposed = TensorView::from(w.index_by(["b", "a"]));
 * assert_eq!(w_transposed, Einsum::with_1(&w).to(["b", "a"]).expect(""));
 *
 * let x_sum: f32 = x.iter().sum();
 * assert_eq!(x_sum, Einsum::with_1(&x).to([]).expect("").into_scalar());
 *
 * let multiplied = &w * &x;
 * assert_eq!(multiplied, Einsum::with_2(&w, &x).to(["a", "c"]).expect(""));
 *
 * let multiplied_2 = (&w.transpose_view(["b", "a"]) * &z).rename_owned(["b", "c"]);
 * // No need to transpose first as Einsum doesn't need the matrices ordered "b"x"a" * "a"x"c".
 * // There is a slight difference in resulting names though, as `*` drops the "a" dimension
 * // names so the output is "b"x"c", whereas we have specified the same calculation with
 * // einsum as a "b"*"c" output.
 * assert_eq!(multiplied_2, Einsum::with_2(&w, &z).to(["b", "c"]).expect(""));
 *
 * let batch_multiply = Einsum::with_3(&x, &w, &z).to(["b", "a"]).expect("");
 * ```
 *
 * - 1 - as long as you only need sums of products
 */
// The private unit field stops callers constructing `Einsum` with a struct
// literal; the `with_N` associated functions are the entry points.
#[derive(Clone, Debug, Default)]
pub struct Einsum {
    _private: (),
}
91
92impl Einsum {
93    /**
94     * An operation with a single input tensor, taking an input that can
95     * be converted into a TensorView which includes Tensor, &Tensor,
96     * &mut Tensor as well as references to a TensorView.
97     */
98    pub fn with_1<T, S, I, const D: usize>(input_1: I) -> Einsum1<T, S, D>
99    where
100        S: TensorRef<T, D>,
101        I: Into<TensorView<T, S, D>>,
102    {
103        Einsum1 {
104            tensor_1: input_1.into(),
105        }
106    }
107
108    /**
109     * An operation with two input tensors, taking inputs that can
110     * be converted into a TensorView which include Tensor, &Tensor,
111     * &mut Tensor as well as references to a TensorView.
112     */
113    pub fn with_2<T, S1, S2, I1, I2, const D1: usize, const D2: usize>(
114        input_1: I1,
115        input_2: I2,
116    ) -> Einsum2<T, S1, S2, D1, D2>
117    where
118        S1: TensorRef<T, D1>,
119        S2: TensorRef<T, D2>,
120        I1: Into<TensorView<T, S1, D1>>,
121        I2: Into<TensorView<T, S2, D2>>,
122    {
123        Einsum2 {
124            tensor_1: input_1.into(),
125            tensor_2: input_2.into(),
126        }
127    }
128
129    /**
130     * An operation with three input tensors, taking inputs that can
131     * be converted into a TensorView which include Tensor, &Tensor,
132     * &mut Tensor as well as references to a TensorView.
133     */
134    pub fn with_3<T, S1, S2, S3, I1, I2, I3, const D1: usize, const D2: usize, const D3: usize>(
135        input_1: I1,
136        input_2: I2,
137        input_3: I3,
138    ) -> Einsum3<T, S1, S2, S3, D1, D2, D3>
139    where
140        S1: TensorRef<T, D1>,
141        S2: TensorRef<T, D2>,
142        S3: TensorRef<T, D3>,
143        I1: Into<TensorView<T, S1, D1>>,
144        I2: Into<TensorView<T, S2, D2>>,
145        I3: Into<TensorView<T, S3, D3>>,
146    {
147        Einsum3 {
148            tensor_1: input_1.into(),
149            tensor_2: input_2.into(),
150            tensor_3: input_3.into(),
151        }
152    }
153
154    /**
155     * An operation with four input tensors, taking inputs that can
156     * be converted into a TensorView which include Tensor, &Tensor,
157     * &mut Tensor as well as references to a TensorView.
158     */
159    pub fn with_4<
160        T,
161        S1,
162        S2,
163        S3,
164        S4,
165        I1,
166        I2,
167        I3,
168        I4,
169        const D1: usize,
170        const D2: usize,
171        const D3: usize,
172        const D4: usize,
173    >(
174        input_1: I1,
175        input_2: I2,
176        input_3: I3,
177        input_4: I4,
178    ) -> Einsum4<T, S1, S2, S3, S4, D1, D2, D3, D4>
179    where
180        S1: TensorRef<T, D1>,
181        S2: TensorRef<T, D2>,
182        S3: TensorRef<T, D3>,
183        S4: TensorRef<T, D4>,
184        I1: Into<TensorView<T, S1, D1>>,
185        I2: Into<TensorView<T, S2, D2>>,
186        I3: Into<TensorView<T, S3, D3>>,
187        I4: Into<TensorView<T, S4, D4>>,
188    {
189        Einsum4 {
190            tensor_1: input_1.into(),
191            tensor_2: input_2.into(),
192            tensor_3: input_3.into(),
193            tensor_4: input_4.into(),
194        }
195    }
196
197    /**
198     * An operation with five input tensors, taking inputs that can
199     * be converted into a TensorView which include Tensor, &Tensor,
200     * &mut Tensor as well as references to a TensorView.
201     */
202    pub fn with_5<
203        T,
204        S1,
205        S2,
206        S3,
207        S4,
208        S5,
209        I1,
210        I2,
211        I3,
212        I4,
213        I5,
214        const D1: usize,
215        const D2: usize,
216        const D3: usize,
217        const D4: usize,
218        const D5: usize,
219    >(
220        input_1: I1,
221        input_2: I2,
222        input_3: I3,
223        input_4: I4,
224        input_5: I5,
225    ) -> Einsum5<T, S1, S2, S3, S4, S5, D1, D2, D3, D4, D5>
226    where
227        S1: TensorRef<T, D1>,
228        S2: TensorRef<T, D2>,
229        S3: TensorRef<T, D3>,
230        S4: TensorRef<T, D4>,
231        S5: TensorRef<T, D5>,
232        I1: Into<TensorView<T, S1, D1>>,
233        I2: Into<TensorView<T, S2, D2>>,
234        I3: Into<TensorView<T, S3, D3>>,
235        I4: Into<TensorView<T, S4, D4>>,
236        I5: Into<TensorView<T, S5, D5>>,
237    {
238        Einsum5 {
239            tensor_1: input_1.into(),
240            tensor_2: input_2.into(),
241            tensor_3: input_3.into(),
242            tensor_4: input_4.into(),
243            tensor_5: input_5.into(),
244        }
245    }
246
247    /**
248     * An operation with six input tensors, taking inputs that can
249     * be converted into a TensorView which include Tensor, &Tensor,
250     * &mut Tensor as well as references to a TensorView.
251     *
252     * There are no technical limits on extending support to a greater number
253     * of input tensors, but as it's not feasible to write a generic implementation
254     * for any number of inputs at the moment we have to stop somewhere. For large
255     * numbers of inputs it may often be more efficient to break down the operation
256     * into substeps, such as detailed in
257     * <https://optimized-einsum.readthedocs.io/en/stable/index.html>
258     *
259     * Until support for using an optimiser to choose an efficient contraction order
260     * is added, the caller can still manually split larger operations into substeps
261     * by calling Einsum multiple times with a subset of the total inputs.
262     */
263    pub fn with_6<
264        T,
265        S1,
266        S2,
267        S3,
268        S4,
269        S5,
270        S6,
271        I1,
272        I2,
273        I3,
274        I4,
275        I5,
276        I6,
277        const D1: usize,
278        const D2: usize,
279        const D3: usize,
280        const D4: usize,
281        const D5: usize,
282        const D6: usize,
283    >(
284        input_1: I1,
285        input_2: I2,
286        input_3: I3,
287        input_4: I4,
288        input_5: I5,
289        input_6: I6,
290    ) -> Einsum6<T, S1, S2, S3, S4, S5, S6, D1, D2, D3, D4, D5, D6>
291    where
292        S1: TensorRef<T, D1>,
293        S2: TensorRef<T, D2>,
294        S3: TensorRef<T, D3>,
295        S4: TensorRef<T, D4>,
296        S5: TensorRef<T, D5>,
297        S6: TensorRef<T, D6>,
298        I1: Into<TensorView<T, S1, D1>>,
299        I2: Into<TensorView<T, S2, D2>>,
300        I3: Into<TensorView<T, S3, D3>>,
301        I4: Into<TensorView<T, S4, D4>>,
302        I5: Into<TensorView<T, S5, D5>>,
303        I6: Into<TensorView<T, S6, D6>>,
304    {
305        Einsum6 {
306            tensor_1: input_1.into(),
307            tensor_2: input_2.into(),
308            tensor_3: input_3.into(),
309            tensor_4: input_4.into(),
310            tensor_5: input_5.into(),
311            tensor_6: input_6.into(),
312        }
313    }
314}
315
316/**
317 * An error indicating the lengths of dimensions with the same
318 * name were inconsistent in the `I` input tensors.
319 */
320#[derive(Clone, Debug, Eq, PartialEq)]
321pub struct InconsistentDimensionLengthError<const I: usize> {
322    /**
323     * The lengths of each matching dimension name in each input
324     * in the same order as they were passed to the Einsum APIs.
325     *
326     * Some inputs may not have this dimension, so will be None.
327     */
328    pub lengths: [Option<usize>; I],
329    /**
330     * The dimension name with an inconsistency.
331     */
332    pub dimension: Dimension,
333}
334
335impl<const I: usize> fmt::Display for InconsistentDimensionLengthError<I> {
336    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
337        write!(
338            f,
339            "inconsistent dimension lengths for dimension '{}': {:?}, lengths must match when repeated in different shapes as the same dimension name",
340            self.dimension,
341            self.lengths,
342        )
343    }
344}
345
346impl<const I: usize> Error for InconsistentDimensionLengthError<I> {}
347
#[test]
fn test_inconsistent_dimension_length_error() {
    // The Display output should name the dimension and list every input's length.
    let expected = "inconsistent dimension lengths for dimension 'a': [Some(3), None, Some(2)], lengths must match when repeated in different shapes as the same dimension name";
    let error = InconsistentDimensionLengthError {
        dimension: "a",
        lengths: [Some(3), None, Some(2)],
    };
    assert_eq!(error.to_string(), expected);
}
359
360/**
361 * A single step in the contractions of an optimised Einsum calculation.
362 *
363 * The elements in the contraction correspond to indexes for the tensors
364 * left in the overall calculation. At the first contraction, there are as many
365 * tensor inputs as the number of tensors provided by the caller. For
366 * example, a contraction could select the first and third tensors to perform einsum
367 * on first, so would be [0, 2]. These tensors are removed from the remaining
368 * inputs and we add the results of the einsum operation to the end of the list.
369 * If we started with 3 tensors and selected the first and third, we would therefore
370 * have two tensors remaining, the second input (now at index 0) and the intermediate
371 * tensor we created (now at index 1). Therefore we could have
372 * `vec![Contraction::from(vec![0, 2], Contraction::from(vec![0, 1]))]` as our
373 * contraction order to split up an einsum calculation into two smaller substeps.
374 */
375#[allow(dead_code)]
376#[derive(Clone, Debug, Eq, PartialEq)]
377struct Contraction {
378    tensor_indexes: Vec<usize>,
379}
380
381// Will come back to using this eventually
382#[allow(dead_code)]
383impl Contraction {
384    /**
385     * Creates a Contraction from the input indexes.
386     */
387    fn from(tensor_indexes: Vec<usize>) -> Contraction {
388        Contraction { tensor_indexes }
389    }
390
391    /**
392     * Returns a reference to the indexes in this contraction.
393     */
394    fn indexes(&self) -> &[usize] {
395        &self.tensor_indexes
396    }
397}
398
399#[allow(dead_code)]
400#[derive(Clone, Debug, Eq, PartialEq)]
401struct StepByStepContractionResult {
402    input_shapes_left: Vec<Vec<(Dimension, usize)>>,
403    contraction_output: Vec<(Dimension, usize)>,
404}
405
406/// Given already validated input such that each dimension name repeated over
407/// the input_shapes_left and the output_shape share a common dimension length,
408/// returns the new list of input_shapes_left and the dimension names for
409/// the output of this contraction step.
410///
411/// Many thanks to Daniel G. A. Smith for proving assistance with understanding
412/// how this step by step process is done in https://github.com/dgasmith/opt_einsum
413#[allow(dead_code)]
414fn step_by_step_contraction(
415    input_shapes_left: &[&[(Dimension, usize)]],
416    output_shape: &[(Dimension, usize)],
417    contraction: &Contraction,
418) -> StepByStepContractionResult {
419    // 1. take the (Dimension, usize) shapes out of input_shapes_left matching the Contraction
420    // These are the shapes for the tensors we're contracting.
421    let contracting: Vec<&[(Dimension, usize)]> = contraction
422        .tensor_indexes
423        .iter()
424        .map(|index| input_shapes_left[*index])
425        .collect();
426
427    // 2. make a new list for the other ones not in this contraction (might be empty)
428    // These are the shapes for the tensors we'll contract later.
429    let not_contracting_yet: Vec<&[(Dimension, usize)]> = input_shapes_left
430        .iter()
431        .enumerate()
432        .filter(|(i, _)| !contraction.tensor_indexes.contains(i))
433        .map(|(_, s)| *s)
434        .collect();
435
436    // 3. take the union of the dimension names from 1., preserving
437    // the order they were originally in the inputs
438    // These are the dimensions our contraction will be able to remove via
439    // summation if they aren't needed after this step.
440    let contracting_dimensions: Vec<(Dimension, usize)> = {
441        let mut seen = HashSet::new();
442        let mut set = Vec::new();
443        for shape in &contracting {
444            for d in shape.iter() {
445                let new = seen.insert(*d);
446                if new {
447                    set.push(*d);
448                }
449            }
450        }
451        set
452    };
453
454    // 4. take the union of the dimension names from the output_shape and 2.,
455    // preserving the order they were originally in the inputs
456    // These are the dimensions we will still have after this step, due to
457    // them being in the final output shape or just required in a later step.
458    let retained_dimensions: Vec<(Dimension, usize)> = {
459        let mut seen = HashSet::new();
460        let mut set = Vec::new();
461        for shape in &not_contracting_yet {
462            for d in shape.iter() {
463                let new = seen.insert(*d);
464                if new {
465                    set.push(*d);
466                }
467            }
468        }
469        for d in output_shape.iter() {
470            let new = seen.insert(*d);
471            if new {
472                set.push(*d);
473            }
474        }
475        set
476    };
477
478    // 5. take the dimension names that are in individually in both 4. and 3.
479    // These are the dimensions we retain in the contraction at this
480    // step.
481    let contraction_output: Vec<(Dimension, usize)> = {
482        let mut intersection = retained_dimensions.clone();
483        intersection.retain(|shape| contracting_dimensions.contains(shape));
484        intersection
485    };
486
487    // 6. add 2. and new input shape from 5., return to caller to become new input_shapes_left
488    // These are the shapes of the tensors left to be contracted
489    // in later steps. This will eventually be a single element list
490    // matching the output shape when we complete the final step.
491    let new_input_shapes_left = {
492        let mut vec = Vec::with_capacity(not_contracting_yet.len() + 1);
493        for d in not_contracting_yet.iter() {
494            vec.push(d.to_vec());
495        }
496        vec.push(contraction_output.clone());
497        vec
498    };
499
500    StepByStepContractionResult {
501        contraction_output,
502        input_shapes_left: new_input_shapes_left,
503    }
504}
505
506/// Return length of matching dimension in inputs, and error if the length of
507/// this output dimension is inconsistent in the input.
508fn length_of<const I: usize>(
509    output_dimension: Dimension,
510    input: &[&[(Dimension, usize)]; I],
511) -> Result<usize, InconsistentDimensionLengthError<I>> {
512    let lengths = input.map(|shapes| {
513        shapes
514            .iter()
515            .find(|(dimension, _length)| *dimension == output_dimension)
516            .map(|(_dimension, length)| *length)
517    });
518
519    let first_length = lengths.iter().filter_map(|l| *l).next();
520    if let Some(length) = first_length {
521        // Check other lengths agree
522        if lengths.iter().any(|l| l.is_some() && *l != Some(length)) {
523            // Different length matches
524            Err(InconsistentDimensionLengthError {
525                lengths,
526                dimension: output_dimension,
527            })
528        } else {
529            Ok(length)
530        }
531    } else {
532        // No matching lengths, we needed 1 match
533        Err(InconsistentDimensionLengthError {
534            lengths,
535            dimension: output_dimension,
536        })
537    }
538}
539
540#[track_caller]
541fn tensor_with_name<T, I, S, const D: usize>(
542    dimensions: [Dimension; D],
543    tensor: I,
544) -> TensorView<T, TensorRename<T, S, D>, D>
545where
546    I: Into<TensorView<T, S, D>>,
547    S: TensorRef<T, D>,
548{
549    let source: S = tensor.into().source();
550    let with_names = TensorRename::from(source, dimensions);
551    TensorView::from(with_names)
552}
553
554/// Return required output shape given input shape sizes and output shape
555/// dimension names. This can fail if a dimension in the requested output
556/// shape isn't present in the input, or if the input has contradictory sizes
557/// for it.
558// We could validate some parts of the input earlier than when we have
559// the output dimensions, but validating tensor lengths are consistent for
560// each common input dimension name would happen multiple times in the scenario
561// of a user using the `named` helper method, so it's a lot easier to use the API
562// if we defer validation till the final method call.
563fn output_shape_for<const I: usize, const O: usize>(
564    input: &[&[(Dimension, usize)]; I],
565    output: &[Dimension; O],
566) -> Result<[(Dimension, usize); O], InconsistentDimensionLengthError<I>> {
567    let mut output_shape = std::array::from_fn(|d| (output[d], 0));
568    for x in output_shape.iter_mut() {
569        x.1 = length_of(x.0, input)?;
570    }
571    Ok(output_shape)
572}
573
574/// We sum over every dimension included in the input and not the output
575///
576/// Returns a vec of the summation dimensions along with their validated
577/// lengths, and errors if the lengths of any summation dimensions in
578/// the input are inconsistent.
579fn summation_dimensions<const I: usize, const O: usize>(
580    input: &[&[(Dimension, usize)]; I],
581    output: &[Dimension; O],
582) -> Result<Vec<(Dimension, usize)>, InconsistentDimensionLengthError<I>> {
583    let mut total_dimensions = 0;
584    for shape in input {
585        total_dimensions += shape.len();
586    }
587
588    // Worst case is every dimension in each input tensor
589    // has unique dimensions
590    let mut unique_dimensions = Vec::with_capacity(total_dimensions);
591
592    for shape in input {
593        for (dimension, length) in shape.iter() {
594            if output.contains(dimension) {
595                // If this dimension is requested in the output we will be checking
596                // for consistent lengths in `length_of` and this dimension won't be
597                // a summation dimension so we can ignore it here.
598                continue;
599            }
600            let existing = unique_dimensions.iter().find(|(d, _)| d == dimension);
601            match existing {
602                None => unique_dimensions.push((*dimension, *length)),
603                Some((_, l)) => {
604                    if length != l {
605                        // Inconsistent lengths
606                        return Err(InconsistentDimensionLengthError {
607                            lengths: std::array::from_fn(|i| {
608                                input[i]
609                                    .iter()
610                                    .find(|(d, _)| d == dimension)
611                                    .map(|(_, l)| *l)
612                            }),
613                            dimension,
614                        });
615                    }
616                }
617            }
618        }
619    }
620
621    Ok(unique_dimensions)
622}
623
624/// Filters outer indexes to only the matching dimensions
625/// for the input shape. Panics if any dimensions in the
626/// input shape are missing from the outer slices, but accepts
627/// more indexes and dimensions in the outer slices than
628/// actually needed for the input shape without any errors.
629fn filter_outer_indexes<const D: usize, const O: usize>(
630    outer_indexes: &[usize; O],
631    outer_shape: &[(Dimension, usize); O],
632    input_shape: &[(Dimension, usize); D],
633) -> [usize; D] {
634    let mut input_indexes = [0; D];
635    for d in 0..D {
636        let mut found = false;
637        let dimension = input_shape[d].0;
638        for o in 0..O {
639            let possible_dimension = outer_shape[o].0;
640            if possible_dimension == dimension {
641                input_indexes[d] = outer_indexes[o];
642                found = true;
643                break;
644            }
645        }
646        if !found {
647            panic!(
648                "Expected to find an index for dimension {:?} but was not present in {:?} for {:?} while trying to index tensor of shape {:?}",
649                dimension,
650                outer_indexes,
651                outer_shape,
652                input_shape,
653            );
654        }
655    }
656    input_indexes
657}
658
659/// Filters outer indexes and summation indexes to only the
660/// matching dimensions for the input shape. Panics if any dimensions
661/// in the input shape are missing from the outer and summation slices,
662/// but accepts more indexes and dimensions in the outer and summation
663/// slices than actually needed for the input shape without any errors.
664/// Summation slices must be the same length, we just don't know their
665/// length at compile time so can't enforce it in the type system.
666fn filter_outer_and_summation_indexes<const D: usize, const O: usize>(
667    outer_indexes: &[usize; O],
668    outer_shape: &[(Dimension, usize); O],
669    summation_indexes: &[usize],
670    summation_shape: &[(Dimension, usize)],
671    input_shape: &[(Dimension, usize); D],
672) -> [usize; D] {
673    let mut input_indexes = [0; D];
674    for d in 0..D {
675        let mut found = false;
676        let dimension = input_shape[d].0;
677        for o in 0..O {
678            let possible_dimension = outer_shape[o].0;
679            if possible_dimension == dimension {
680                input_indexes[d] = outer_indexes[o];
681                found = true;
682                break;
683            }
684        }
685        let summation_iter = summation_indexes.iter().zip(summation_shape.iter());
686        for (index, (possible_dimension, _length)) in summation_iter {
687            if *possible_dimension == dimension {
688                input_indexes[d] = *index;
689                found = true;
690                break;
691            }
692        }
693        if !found {
694            panic!(
695                "Expected to find an index for dimension {:?} but was not present in {:?} for {:?} or {:?} for {:?} while trying to index tensor of shape {:?}",
696                dimension,
697                outer_indexes,
698                outer_shape,
699                summation_indexes,
700                summation_shape,
701                input_shape,
702            );
703        }
704    }
705    input_indexes
706}
707
708/**
709 * Einstein summation notation operation with a single input tensor.
710 */
711pub struct Einsum1<T, S1, const D1: usize> {
712    tensor_1: TensorView<T, S1, D1>,
713}
714
715/**
716 * Einstein summation notation operation with two input tensors.
717 */
718pub struct Einsum2<T, S1, S2, const D1: usize, const D2: usize> {
719    tensor_1: TensorView<T, S1, D1>,
720    tensor_2: TensorView<T, S2, D2>,
721}
722
723/**
724 * Einstein summation notation operation with three input tensors
725 */
726pub struct Einsum3<T, S1, S2, S3, const D1: usize, const D2: usize, const D3: usize> {
727    tensor_1: TensorView<T, S1, D1>,
728    tensor_2: TensorView<T, S2, D2>,
729    tensor_3: TensorView<T, S3, D3>,
730}
731
732/**
733 * Einstein summation notation operation with four input tensors
734 */
735pub struct Einsum4<
736    T,
737    S1,
738    S2,
739    S3,
740    S4,
741    const D1: usize,
742    const D2: usize,
743    const D3: usize,
744    const D4: usize,
745> {
746    tensor_1: TensorView<T, S1, D1>,
747    tensor_2: TensorView<T, S2, D2>,
748    tensor_3: TensorView<T, S3, D3>,
749    tensor_4: TensorView<T, S4, D4>,
750}
751
752/**
753 * Einstein summation notation operation with five input tensors
754 */
755pub struct Einsum5<
756    T,
757    S1,
758    S2,
759    S3,
760    S4,
761    S5,
762    const D1: usize,
763    const D2: usize,
764    const D3: usize,
765    const D4: usize,
766    const D5: usize,
767> {
768    tensor_1: TensorView<T, S1, D1>,
769    tensor_2: TensorView<T, S2, D2>,
770    tensor_3: TensorView<T, S3, D3>,
771    tensor_4: TensorView<T, S4, D4>,
772    tensor_5: TensorView<T, S5, D5>,
773}
774
775/**
776 * Einstein summation notation operation with six input tensors
777 *
778 * There are no technical limits on extending support to a greater number
779 * of input tensors, but as it's not feasible to write a generic implementation
780 * for any number of inputs at the moment we have to stop somewhere. For large
781 * numbers of inputs it may often be more efficient to break down the operation
782 * into substeps, such as detailed in
783 * <https://optimized-einsum.readthedocs.io/en/stable/index.html>
784 *
785 * Until support for using an optimiser to choose an efficient contraction order
786 * is added, the caller can still manually split larger operations into substeps
787 * by calling Einsum multiple times with a subset of the total inputs.
788 */
789pub struct Einsum6<
790    T,
791    S1,
792    S2,
793    S3,
794    S4,
795    S5,
796    S6,
797    const D1: usize,
798    const D2: usize,
799    const D3: usize,
800    const D4: usize,
801    const D5: usize,
802    const D6: usize,
803> {
804    tensor_1: TensorView<T, S1, D1>,
805    tensor_2: TensorView<T, S2, D2>,
806    tensor_3: TensorView<T, S3, D3>,
807    tensor_4: TensorView<T, S4, D4>,
808    tensor_5: TensorView<T, S5, D5>,
809    tensor_6: TensorView<T, S6, D6>,
810}
811
812impl<T, S1, const D1: usize> Einsum1<T, S1, D1> {
813    /**
814     * Renames all input tensors to the new names. Their shapes will
815     * still be in the same order with the same lengths of data, as
816     * per [TensorRename]. As per TensorRename, dimension names for
817     * each individual tensor must be unique.
818     */
819    #[track_caller]
820    pub fn named(self, input_1: [Dimension; D1]) -> Einsum1<T, TensorRename<T, S1, D1>, D1>
821    where
822        S1: TensorRef<T, D1>,
823    {
824        Einsum1 {
825            tensor_1: tensor_with_name(input_1, self.tensor_1),
826        }
827    }
828
829    pub fn to<const O: usize>(
830        self,
831        output: [Dimension; O],
832    ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<1>>
833    where
834        T: Numeric,
835        for<'a> &'a T: NumericRef<T>,
836        S1: TensorRef<T, D1>,
837    {
838        let input_1_shape_const = &self.tensor_1.shape();
839        let input_1_shape: &[(Dimension, usize)] = input_1_shape_const;
840        let input = &[input_1_shape];
841
842        let output_shape = output_shape_for(input, &output)?;
843        let mut output_tensor = Tensor::empty(output_shape, T::zero());
844
845        let summation_dimensions = summation_dimensions(input, &output)?;
846        let tensor_1_indexing = self.tensor_1.index();
847
848        for (indexes, element) in output_tensor.index_mut().iter_reference_mut().with_index() {
849            let mut sum = T::zero();
850
851            if summation_dimensions.is_empty() {
852                let product_1 = tensor_1_indexing.get_ref(filter_outer_indexes(
853                    &indexes,
854                    &output_shape,
855                    input_1_shape_const,
856                ));
857                sum = sum + product_1;
858            } else {
859                let mut summation_iterator = DynamicShapeIterator::from(&summation_dimensions);
860                loop {
861                    let next = summation_iterator.next();
862                    match next {
863                        Some(summation_indexes) => {
864                            let product_1 =
865                                tensor_1_indexing.get_ref(filter_outer_and_summation_indexes(
866                                    &indexes,
867                                    &output_shape,
868                                    summation_indexes,
869                                    &summation_dimensions,
870                                    input_1_shape_const,
871                                ));
872                            sum = sum + product_1;
873                        }
874                        None => break,
875                    }
876                }
877            }
878            *element = sum;
879        }
880
881        Ok(output_tensor)
882    }
883}
884
885impl<T, S1, S2, const D1: usize, const D2: usize> Einsum2<T, S1, S2, D1, D2> {
886    /**
887     * Renames all input tensors to the new names. Their shapes will
888     * still be in the same order with the same lengths of data, as
889     * per [TensorRename]. As per TensorRename, dimension names for
890     * each individual tensor must be unique.
891     */
892    #[track_caller]
893    pub fn named(
894        self,
895        input_1: [Dimension; D1],
896        input_2: [Dimension; D2],
897    ) -> Einsum2<T, TensorRename<T, S1, D1>, TensorRename<T, S2, D2>, D1, D2>
898    where
899        S1: TensorRef<T, D1>,
900        S2: TensorRef<T, D2>,
901    {
902        Einsum2 {
903            tensor_1: tensor_with_name(input_1, self.tensor_1),
904            tensor_2: tensor_with_name(input_2, self.tensor_2),
905        }
906    }
907
908    pub fn to<const O: usize>(
909        self,
910        output: [Dimension; O],
911    ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<2>>
912    where
913        T: Numeric,
914        for<'a> &'a T: NumericRef<T>,
915        S1: TensorRef<T, D1>,
916        S2: TensorRef<T, D2>,
917    {
918        let input_1_shape_const = &self.tensor_1.shape();
919        let input_1_shape: &[(Dimension, usize)] = input_1_shape_const;
920        let input_2_shape_const = &self.tensor_2.shape();
921        let input_2_shape: &[(Dimension, usize)] = input_2_shape_const;
922        let input = &[input_1_shape, input_2_shape];
923
924        let output_shape = output_shape_for(input, &output)?;
925        let mut output_tensor = Tensor::empty(output_shape, T::zero());
926
927        let summation_dimensions = summation_dimensions(input, &output)?;
928        let tensor_1_indexing = self.tensor_1.index();
929        let tensor_2_indexing = self.tensor_2.index();
930
931        for (indexes, element) in output_tensor.index_mut().iter_reference_mut().with_index() {
932            let mut sum = T::zero();
933
934            if summation_dimensions.is_empty() {
935                let product_1 = tensor_1_indexing.get_ref(filter_outer_indexes(
936                    &indexes,
937                    &output_shape,
938                    input_1_shape_const,
939                ));
940                let product_2 = tensor_2_indexing.get_ref(filter_outer_indexes(
941                    &indexes,
942                    &output_shape,
943                    input_2_shape_const,
944                ));
945                sum = sum + (product_1 * product_2);
946            } else {
947                let mut summation_iterator = DynamicShapeIterator::from(&summation_dimensions);
948                loop {
949                    let next = summation_iterator.next();
950                    match next {
951                        Some(summation_indexes) => {
952                            let product_1 =
953                                tensor_1_indexing.get_ref(filter_outer_and_summation_indexes(
954                                    &indexes,
955                                    &output_shape,
956                                    summation_indexes,
957                                    &summation_dimensions,
958                                    input_1_shape_const,
959                                ));
960                            let product_2 =
961                                tensor_2_indexing.get_ref(filter_outer_and_summation_indexes(
962                                    &indexes,
963                                    &output_shape,
964                                    summation_indexes,
965                                    &summation_dimensions,
966                                    input_2_shape_const,
967                                ));
968                            sum = sum + (product_1 * product_2);
969                        }
970                        None => break,
971                    }
972                }
973            }
974
975            *element = sum;
976        }
977
978        Ok(output_tensor)
979    }
980}
981
982impl<T, S1, S2, S3, const D1: usize, const D2: usize, const D3: usize>
983    Einsum3<T, S1, S2, S3, D1, D2, D3>
984{
985    /**
986     * Renames all input tensors to the new names. Their shapes will
987     * still be in the same order with the same lengths of data, as
988     * per [TensorRename]. As per TensorRename, dimension names for
989     * each individual tensor must be unique.
990     */
991    #[track_caller]
992    #[allow(clippy::type_complexity)]
993    pub fn named(
994        self,
995        input_1: [Dimension; D1],
996        input_2: [Dimension; D2],
997        input_3: [Dimension; D3],
998    ) -> Einsum3<
999        T,
1000        TensorRename<T, S1, D1>,
1001        TensorRename<T, S2, D2>,
1002        TensorRename<T, S3, D3>,
1003        D1,
1004        D2,
1005        D3,
1006    >
1007    where
1008        S1: TensorRef<T, D1>,
1009        S2: TensorRef<T, D2>,
1010        S3: TensorRef<T, D3>,
1011    {
1012        Einsum3 {
1013            tensor_1: tensor_with_name(input_1, self.tensor_1),
1014            tensor_2: tensor_with_name(input_2, self.tensor_2),
1015            tensor_3: tensor_with_name(input_3, self.tensor_3),
1016        }
1017    }
1018
1019    pub fn to<const O: usize>(
1020        self,
1021        output: [Dimension; O],
1022    ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<3>>
1023    where
1024        T: Numeric,
1025        for<'a> &'a T: NumericRef<T>,
1026        S1: TensorRef<T, D1>,
1027        S2: TensorRef<T, D2>,
1028        S3: TensorRef<T, D3>,
1029    {
1030        let input_1_shape_const = &self.tensor_1.shape();
1031        let input_1_shape: &[(Dimension, usize)] = input_1_shape_const;
1032        let input_2_shape_const = &self.tensor_2.shape();
1033        let input_2_shape: &[(Dimension, usize)] = input_2_shape_const;
1034        let input_3_shape_const = &self.tensor_3.shape();
1035        let input_3_shape: &[(Dimension, usize)] = input_3_shape_const;
1036        let input = &[input_1_shape, input_2_shape, input_3_shape];
1037
1038        let output_shape = output_shape_for(input, &output)?;
1039        let mut output_tensor = Tensor::empty(output_shape, T::zero());
1040
1041        let summation_dimensions = summation_dimensions(input, &output)?;
1042        let tensor_1_indexing = self.tensor_1.index();
1043        let tensor_2_indexing = self.tensor_2.index();
1044        let tensor_3_indexing = self.tensor_3.index();
1045
1046        for (indexes, element) in output_tensor.index_mut().iter_reference_mut().with_index() {
1047            let mut sum = T::zero();
1048
1049            if summation_dimensions.is_empty() {
1050                let product_1 = tensor_1_indexing.get_ref(filter_outer_indexes(
1051                    &indexes,
1052                    &output_shape,
1053                    input_1_shape_const,
1054                ));
1055                let product_2 = tensor_2_indexing.get_ref(filter_outer_indexes(
1056                    &indexes,
1057                    &output_shape,
1058                    input_2_shape_const,
1059                ));
1060                let product_3 = tensor_3_indexing.get_ref(filter_outer_indexes(
1061                    &indexes,
1062                    &output_shape,
1063                    input_3_shape_const,
1064                ));
1065                sum = sum + (product_1 * product_2 * product_3);
1066            } else {
1067                let mut summation_iterator = DynamicShapeIterator::from(&summation_dimensions);
1068                loop {
1069                    let next = summation_iterator.next();
1070                    match next {
1071                        Some(summation_indexes) => {
1072                            let product_1 =
1073                                tensor_1_indexing.get_ref(filter_outer_and_summation_indexes(
1074                                    &indexes,
1075                                    &output_shape,
1076                                    summation_indexes,
1077                                    &summation_dimensions,
1078                                    input_1_shape_const,
1079                                ));
1080                            let product_2 =
1081                                tensor_2_indexing.get_ref(filter_outer_and_summation_indexes(
1082                                    &indexes,
1083                                    &output_shape,
1084                                    summation_indexes,
1085                                    &summation_dimensions,
1086                                    input_2_shape_const,
1087                                ));
1088                            let product_3 =
1089                                tensor_3_indexing.get_ref(filter_outer_and_summation_indexes(
1090                                    &indexes,
1091                                    &output_shape,
1092                                    summation_indexes,
1093                                    &summation_dimensions,
1094                                    input_3_shape_const,
1095                                ));
1096                            sum = sum + (product_1 * product_2 * product_3);
1097                        }
1098                        None => break,
1099                    }
1100                }
1101            }
1102
1103            *element = sum;
1104        }
1105
1106        Ok(output_tensor)
1107    }
1108}
1109
1110impl<T, S1, S2, S3, S4, const D1: usize, const D2: usize, const D3: usize, const D4: usize>
1111    Einsum4<T, S1, S2, S3, S4, D1, D2, D3, D4>
1112{
1113    /**
1114     * Renames all input tensors to the new names. Their shapes will
1115     * still be in the same order with the same lengths of data, as
1116     * per [TensorRename]. As per TensorRename, dimension names for
1117     * each individual tensor must be unique.
1118     */
1119    #[track_caller]
1120    #[allow(clippy::type_complexity)]
1121    pub fn named(
1122        self,
1123        input_1: [Dimension; D1],
1124        input_2: [Dimension; D2],
1125        input_3: [Dimension; D3],
1126        input_4: [Dimension; D4],
1127    ) -> Einsum4<
1128        T,
1129        TensorRename<T, S1, D1>,
1130        TensorRename<T, S2, D2>,
1131        TensorRename<T, S3, D3>,
1132        TensorRename<T, S4, D4>,
1133        D1,
1134        D2,
1135        D3,
1136        D4,
1137    >
1138    where
1139        S1: TensorRef<T, D1>,
1140        S2: TensorRef<T, D2>,
1141        S3: TensorRef<T, D3>,
1142        S4: TensorRef<T, D4>,
1143    {
1144        Einsum4 {
1145            tensor_1: tensor_with_name(input_1, self.tensor_1),
1146            tensor_2: tensor_with_name(input_2, self.tensor_2),
1147            tensor_3: tensor_with_name(input_3, self.tensor_3),
1148            tensor_4: tensor_with_name(input_4, self.tensor_4),
1149        }
1150    }
1151
1152    pub fn to<const O: usize>(
1153        self,
1154        output: [Dimension; O],
1155    ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<4>>
1156    where
1157        T: Numeric,
1158        for<'a> &'a T: NumericRef<T>,
1159        S1: TensorRef<T, D1>,
1160        S2: TensorRef<T, D2>,
1161        S3: TensorRef<T, D3>,
1162        S4: TensorRef<T, D4>,
1163    {
1164        let input_1_shape_const = &self.tensor_1.shape();
1165        let input_1_shape: &[(Dimension, usize)] = input_1_shape_const;
1166        let input_2_shape_const = &self.tensor_2.shape();
1167        let input_2_shape: &[(Dimension, usize)] = input_2_shape_const;
1168        let input_3_shape_const = &self.tensor_3.shape();
1169        let input_3_shape: &[(Dimension, usize)] = input_3_shape_const;
1170        let input_4_shape_const = &self.tensor_4.shape();
1171        let input_4_shape: &[(Dimension, usize)] = input_4_shape_const;
1172        let input = &[input_1_shape, input_2_shape, input_3_shape, input_4_shape];
1173
1174        let output_shape = output_shape_for(input, &output)?;
1175        let mut output_tensor = Tensor::empty(output_shape, T::zero());
1176
1177        let summation_dimensions = summation_dimensions(input, &output)?;
1178        let tensor_1_indexing = self.tensor_1.index();
1179        let tensor_2_indexing = self.tensor_2.index();
1180        let tensor_3_indexing = self.tensor_3.index();
1181        let tensor_4_indexing = self.tensor_4.index();
1182
1183        for (indexes, element) in output_tensor.index_mut().iter_reference_mut().with_index() {
1184            let mut sum = T::zero();
1185
1186            if summation_dimensions.is_empty() {
1187                let product_1 = tensor_1_indexing.get_ref(filter_outer_indexes(
1188                    &indexes,
1189                    &output_shape,
1190                    input_1_shape_const,
1191                ));
1192                let product_2 = tensor_2_indexing.get_ref(filter_outer_indexes(
1193                    &indexes,
1194                    &output_shape,
1195                    input_2_shape_const,
1196                ));
1197                let product_3 = tensor_3_indexing.get_ref(filter_outer_indexes(
1198                    &indexes,
1199                    &output_shape,
1200                    input_3_shape_const,
1201                ));
1202                let product_4 = tensor_4_indexing.get_ref(filter_outer_indexes(
1203                    &indexes,
1204                    &output_shape,
1205                    input_4_shape_const,
1206                ));
1207                sum = sum + (product_1 * product_2 * product_3 * product_4);
1208            } else {
1209                let mut summation_iterator = DynamicShapeIterator::from(&summation_dimensions);
1210                loop {
1211                    let next = summation_iterator.next();
1212                    match next {
1213                        Some(summation_indexes) => {
1214                            let product_1 =
1215                                tensor_1_indexing.get_ref(filter_outer_and_summation_indexes(
1216                                    &indexes,
1217                                    &output_shape,
1218                                    summation_indexes,
1219                                    &summation_dimensions,
1220                                    input_1_shape_const,
1221                                ));
1222                            let product_2 =
1223                                tensor_2_indexing.get_ref(filter_outer_and_summation_indexes(
1224                                    &indexes,
1225                                    &output_shape,
1226                                    summation_indexes,
1227                                    &summation_dimensions,
1228                                    input_2_shape_const,
1229                                ));
1230                            let product_3 =
1231                                tensor_3_indexing.get_ref(filter_outer_and_summation_indexes(
1232                                    &indexes,
1233                                    &output_shape,
1234                                    summation_indexes,
1235                                    &summation_dimensions,
1236                                    input_3_shape_const,
1237                                ));
1238                            let product_4 =
1239                                tensor_4_indexing.get_ref(filter_outer_and_summation_indexes(
1240                                    &indexes,
1241                                    &output_shape,
1242                                    summation_indexes,
1243                                    &summation_dimensions,
1244                                    input_4_shape_const,
1245                                ));
1246                            sum = sum + (product_1 * product_2 * product_3 * product_4);
1247                        }
1248                        None => break,
1249                    }
1250                }
1251            }
1252
1253            *element = sum;
1254        }
1255
1256        Ok(output_tensor)
1257    }
1258}
1259
1260impl<
1261        T,
1262        S1,
1263        S2,
1264        S3,
1265        S4,
1266        S5,
1267        const D1: usize,
1268        const D2: usize,
1269        const D3: usize,
1270        const D4: usize,
1271        const D5: usize,
1272    > Einsum5<T, S1, S2, S3, S4, S5, D1, D2, D3, D4, D5>
1273{
1274    /**
1275     * Renames all input tensors to the new names. Their shapes will
1276     * still be in the same order with the same lengths of data, as
1277     * per [TensorRename]. As per TensorRename, dimension names for
1278     * each individual tensor must be unique.
1279     */
1280    #[track_caller]
1281    #[allow(clippy::type_complexity)]
1282    pub fn named(
1283        self,
1284        input_1: [Dimension; D1],
1285        input_2: [Dimension; D2],
1286        input_3: [Dimension; D3],
1287        input_4: [Dimension; D4],
1288        input_5: [Dimension; D5],
1289    ) -> Einsum5<
1290        T,
1291        TensorRename<T, S1, D1>,
1292        TensorRename<T, S2, D2>,
1293        TensorRename<T, S3, D3>,
1294        TensorRename<T, S4, D4>,
1295        TensorRename<T, S5, D5>,
1296        D1,
1297        D2,
1298        D3,
1299        D4,
1300        D5,
1301    >
1302    where
1303        S1: TensorRef<T, D1>,
1304        S2: TensorRef<T, D2>,
1305        S3: TensorRef<T, D3>,
1306        S4: TensorRef<T, D4>,
1307        S5: TensorRef<T, D5>,
1308    {
1309        Einsum5 {
1310            tensor_1: tensor_with_name(input_1, self.tensor_1),
1311            tensor_2: tensor_with_name(input_2, self.tensor_2),
1312            tensor_3: tensor_with_name(input_3, self.tensor_3),
1313            tensor_4: tensor_with_name(input_4, self.tensor_4),
1314            tensor_5: tensor_with_name(input_5, self.tensor_5),
1315        }
1316    }
1317
1318    pub fn to<const O: usize>(
1319        self,
1320        output: [Dimension; O],
1321    ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<5>>
1322    where
1323        T: Numeric,
1324        for<'a> &'a T: NumericRef<T>,
1325        S1: TensorRef<T, D1>,
1326        S2: TensorRef<T, D2>,
1327        S3: TensorRef<T, D3>,
1328        S4: TensorRef<T, D4>,
1329        S5: TensorRef<T, D5>,
1330    {
1331        let input_1_shape_const = &self.tensor_1.shape();
1332        let input_1_shape: &[(Dimension, usize)] = input_1_shape_const;
1333        let input_2_shape_const = &self.tensor_2.shape();
1334        let input_2_shape: &[(Dimension, usize)] = input_2_shape_const;
1335        let input_3_shape_const = &self.tensor_3.shape();
1336        let input_3_shape: &[(Dimension, usize)] = input_3_shape_const;
1337        let input_4_shape_const = &self.tensor_4.shape();
1338        let input_4_shape: &[(Dimension, usize)] = input_4_shape_const;
1339        let input_5_shape_const = &self.tensor_5.shape();
1340        let input_5_shape: &[(Dimension, usize)] = input_5_shape_const;
1341        let input = &[
1342            input_1_shape,
1343            input_2_shape,
1344            input_3_shape,
1345            input_4_shape,
1346            input_5_shape,
1347        ];
1348
1349        let output_shape = output_shape_for(input, &output)?;
1350        let mut output_tensor = Tensor::empty(output_shape, T::zero());
1351
1352        let summation_dimensions = summation_dimensions(input, &output)?;
1353        let tensor_1_indexing = self.tensor_1.index();
1354        let tensor_2_indexing = self.tensor_2.index();
1355        let tensor_3_indexing = self.tensor_3.index();
1356        let tensor_4_indexing = self.tensor_4.index();
1357        let tensor_5_indexing = self.tensor_5.index();
1358
1359        for (indexes, element) in output_tensor.index_mut().iter_reference_mut().with_index() {
1360            let mut sum = T::zero();
1361
1362            if summation_dimensions.is_empty() {
1363                let product_1 = tensor_1_indexing.get_ref(filter_outer_indexes(
1364                    &indexes,
1365                    &output_shape,
1366                    input_1_shape_const,
1367                ));
1368                let product_2 = tensor_2_indexing.get_ref(filter_outer_indexes(
1369                    &indexes,
1370                    &output_shape,
1371                    input_2_shape_const,
1372                ));
1373                let product_3 = tensor_3_indexing.get_ref(filter_outer_indexes(
1374                    &indexes,
1375                    &output_shape,
1376                    input_3_shape_const,
1377                ));
1378                let product_4 = tensor_4_indexing.get_ref(filter_outer_indexes(
1379                    &indexes,
1380                    &output_shape,
1381                    input_4_shape_const,
1382                ));
1383                let product_5 = tensor_5_indexing.get_ref(filter_outer_indexes(
1384                    &indexes,
1385                    &output_shape,
1386                    input_5_shape_const,
1387                ));
1388                sum = sum + (product_1 * product_2 * product_3 * product_4 * product_5);
1389            } else {
1390                let mut summation_iterator = DynamicShapeIterator::from(&summation_dimensions);
1391                loop {
1392                    let next = summation_iterator.next();
1393                    match next {
1394                        Some(summation_indexes) => {
1395                            let product_1 =
1396                                tensor_1_indexing.get_ref(filter_outer_and_summation_indexes(
1397                                    &indexes,
1398                                    &output_shape,
1399                                    summation_indexes,
1400                                    &summation_dimensions,
1401                                    input_1_shape_const,
1402                                ));
1403                            let product_2 =
1404                                tensor_2_indexing.get_ref(filter_outer_and_summation_indexes(
1405                                    &indexes,
1406                                    &output_shape,
1407                                    summation_indexes,
1408                                    &summation_dimensions,
1409                                    input_2_shape_const,
1410                                ));
1411                            let product_3 =
1412                                tensor_3_indexing.get_ref(filter_outer_and_summation_indexes(
1413                                    &indexes,
1414                                    &output_shape,
1415                                    summation_indexes,
1416                                    &summation_dimensions,
1417                                    input_3_shape_const,
1418                                ));
1419                            let product_4 =
1420                                tensor_4_indexing.get_ref(filter_outer_and_summation_indexes(
1421                                    &indexes,
1422                                    &output_shape,
1423                                    summation_indexes,
1424                                    &summation_dimensions,
1425                                    input_4_shape_const,
1426                                ));
1427                            let product_5 =
1428                                tensor_5_indexing.get_ref(filter_outer_and_summation_indexes(
1429                                    &indexes,
1430                                    &output_shape,
1431                                    summation_indexes,
1432                                    &summation_dimensions,
1433                                    input_5_shape_const,
1434                                ));
1435                            sum = sum + (product_1 * product_2 * product_3 * product_4 * product_5);
1436                        }
1437                        None => break,
1438                    }
1439                }
1440            }
1441
1442            *element = sum;
1443        }
1444
1445        Ok(output_tensor)
1446    }
1447}
1448
1449impl<
1450        T,
1451        S1,
1452        S2,
1453        S3,
1454        S4,
1455        S5,
1456        S6,
1457        const D1: usize,
1458        const D2: usize,
1459        const D3: usize,
1460        const D4: usize,
1461        const D5: usize,
1462        const D6: usize,
1463    > Einsum6<T, S1, S2, S3, S4, S5, S6, D1, D2, D3, D4, D5, D6>
1464{
1465    /**
1466     * Renames all input tensors to the new names. Their shapes will
1467     * still be in the same order with the same lengths of data, as
1468     * per [TensorRename]. As per TensorRename, dimension names for
1469     * each individual tensor must be unique.
1470     */
1471    #[track_caller]
1472    #[allow(clippy::type_complexity)]
1473    pub fn named(
1474        self,
1475        input_1: [Dimension; D1],
1476        input_2: [Dimension; D2],
1477        input_3: [Dimension; D3],
1478        input_4: [Dimension; D4],
1479        input_5: [Dimension; D5],
1480        input_6: [Dimension; D6],
1481    ) -> Einsum6<
1482        T,
1483        TensorRename<T, S1, D1>,
1484        TensorRename<T, S2, D2>,
1485        TensorRename<T, S3, D3>,
1486        TensorRename<T, S4, D4>,
1487        TensorRename<T, S5, D5>,
1488        TensorRename<T, S6, D6>,
1489        D1,
1490        D2,
1491        D3,
1492        D4,
1493        D5,
1494        D6,
1495    >
1496    where
1497        S1: TensorRef<T, D1>,
1498        S2: TensorRef<T, D2>,
1499        S3: TensorRef<T, D3>,
1500        S4: TensorRef<T, D4>,
1501        S5: TensorRef<T, D5>,
1502        S6: TensorRef<T, D6>,
1503    {
1504        Einsum6 {
1505            tensor_1: tensor_with_name(input_1, self.tensor_1),
1506            tensor_2: tensor_with_name(input_2, self.tensor_2),
1507            tensor_3: tensor_with_name(input_3, self.tensor_3),
1508            tensor_4: tensor_with_name(input_4, self.tensor_4),
1509            tensor_5: tensor_with_name(input_5, self.tensor_5),
1510            tensor_6: tensor_with_name(input_6, self.tensor_6),
1511        }
1512    }
1513
1514    pub fn to<const O: usize>(
1515        self,
1516        output: [Dimension; O],
1517    ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<6>>
1518    where
1519        T: Numeric,
1520        for<'a> &'a T: NumericRef<T>,
1521        S1: TensorRef<T, D1>,
1522        S2: TensorRef<T, D2>,
1523        S3: TensorRef<T, D3>,
1524        S4: TensorRef<T, D4>,
1525        S5: TensorRef<T, D5>,
1526        S6: TensorRef<T, D6>,
1527    {
1528        let input_1_shape_const = &self.tensor_1.shape();
1529        let input_1_shape: &[(Dimension, usize)] = input_1_shape_const;
1530        let input_2_shape_const = &self.tensor_2.shape();
1531        let input_2_shape: &[(Dimension, usize)] = input_2_shape_const;
1532        let input_3_shape_const = &self.tensor_3.shape();
1533        let input_3_shape: &[(Dimension, usize)] = input_3_shape_const;
1534        let input_4_shape_const = &self.tensor_4.shape();
1535        let input_4_shape: &[(Dimension, usize)] = input_4_shape_const;
1536        let input_5_shape_const = &self.tensor_5.shape();
1537        let input_5_shape: &[(Dimension, usize)] = input_5_shape_const;
1538        let input_6_shape_const = &self.tensor_6.shape();
1539        let input_6_shape: &[(Dimension, usize)] = input_6_shape_const;
1540        let input = &[
1541            input_1_shape,
1542            input_2_shape,
1543            input_3_shape,
1544            input_4_shape,
1545            input_5_shape,
1546            input_6_shape,
1547        ];
1548
1549        let output_shape = output_shape_for(input, &output)?;
1550        let mut output_tensor = Tensor::empty(output_shape, T::zero());
1551
1552        let summation_dimensions = summation_dimensions(input, &output)?;
1553        let tensor_1_indexing = self.tensor_1.index();
1554        let tensor_2_indexing = self.tensor_2.index();
1555        let tensor_3_indexing = self.tensor_3.index();
1556        let tensor_4_indexing = self.tensor_4.index();
1557        let tensor_5_indexing = self.tensor_5.index();
1558        let tensor_6_indexing = self.tensor_6.index();
1559
1560        for (indexes, element) in output_tensor.index_mut().iter_reference_mut().with_index() {
1561            let mut sum = T::zero();
1562
1563            if summation_dimensions.is_empty() {
1564                let product_1 = tensor_1_indexing.get_ref(filter_outer_indexes(
1565                    &indexes,
1566                    &output_shape,
1567                    input_1_shape_const,
1568                ));
1569                let product_2 = tensor_2_indexing.get_ref(filter_outer_indexes(
1570                    &indexes,
1571                    &output_shape,
1572                    input_2_shape_const,
1573                ));
1574                let product_3 = tensor_3_indexing.get_ref(filter_outer_indexes(
1575                    &indexes,
1576                    &output_shape,
1577                    input_3_shape_const,
1578                ));
1579                let product_4 = tensor_4_indexing.get_ref(filter_outer_indexes(
1580                    &indexes,
1581                    &output_shape,
1582                    input_4_shape_const,
1583                ));
1584                let product_5 = tensor_5_indexing.get_ref(filter_outer_indexes(
1585                    &indexes,
1586                    &output_shape,
1587                    input_5_shape_const,
1588                ));
1589                let product_6 = tensor_6_indexing.get_ref(filter_outer_indexes(
1590                    &indexes,
1591                    &output_shape,
1592                    input_6_shape_const,
1593                ));
1594                sum = sum + (product_1 * product_2 * product_3 * product_4 * product_5 * product_6);
1595            } else {
1596                let mut summation_iterator = DynamicShapeIterator::from(&summation_dimensions);
1597                loop {
1598                    let next = summation_iterator.next();
1599                    match next {
1600                        Some(summation_indexes) => {
1601                            let product_1 =
1602                                tensor_1_indexing.get_ref(filter_outer_and_summation_indexes(
1603                                    &indexes,
1604                                    &output_shape,
1605                                    summation_indexes,
1606                                    &summation_dimensions,
1607                                    input_1_shape_const,
1608                                ));
1609                            let product_2 =
1610                                tensor_2_indexing.get_ref(filter_outer_and_summation_indexes(
1611                                    &indexes,
1612                                    &output_shape,
1613                                    summation_indexes,
1614                                    &summation_dimensions,
1615                                    input_2_shape_const,
1616                                ));
1617                            let product_3 =
1618                                tensor_3_indexing.get_ref(filter_outer_and_summation_indexes(
1619                                    &indexes,
1620                                    &output_shape,
1621                                    summation_indexes,
1622                                    &summation_dimensions,
1623                                    input_3_shape_const,
1624                                ));
1625                            let product_4 =
1626                                tensor_4_indexing.get_ref(filter_outer_and_summation_indexes(
1627                                    &indexes,
1628                                    &output_shape,
1629                                    summation_indexes,
1630                                    &summation_dimensions,
1631                                    input_4_shape_const,
1632                                ));
1633                            let product_5 =
1634                                tensor_5_indexing.get_ref(filter_outer_and_summation_indexes(
1635                                    &indexes,
1636                                    &output_shape,
1637                                    summation_indexes,
1638                                    &summation_dimensions,
1639                                    input_5_shape_const,
1640                                ));
1641                            let product_6 =
1642                                tensor_6_indexing.get_ref(filter_outer_and_summation_indexes(
1643                                    &indexes,
1644                                    &output_shape,
1645                                    summation_indexes,
1646                                    &summation_dimensions,
1647                                    input_6_shape_const,
1648                                ));
1649                            sum = sum
1650                                + (product_1
1651                                    * product_2
1652                                    * product_3
1653                                    * product_4
1654                                    * product_5
1655                                    * product_6);
1656                        }
1657                        None => break,
1658                    }
1659                }
1660            }
1661
1662            *element = sum;
1663        }
1664
1665        Ok(output_tensor)
1666    }
1667}
1668
/// Exercises a single step of `step_by_step_contraction`, checking both the
/// shapes that remain to be contracted afterwards and the shape of the
/// intermediate tensor produced by the step.
#[test]
fn step_by_step_contraction_tests() {
    // Two inputs contracted together use up the whole input list, so the only
    // shape left over is the contraction output itself.
    let actual = step_by_step_contraction(
        &[&[("x", 2), ("y", 3)], &[("y", 3), ("z", 4)]],
        &[("x", 2), ("z", 4)],
        &Contraction {
            tensor_indexes: vec![0, 1],
        },
    );
    let expected = StepByStepContractionResult {
        input_shapes_left: vec![vec![("x", 2), ("z", 4)]],
        contraction_output: vec![("x", 2), ("z", 4)],
    };
    assert_eq!(actual, expected);

    // Contracting the first and last inputs sums out `b` and `d`, which nothing
    // else needs, leaving two operands that only share `a` and `c`.
    let actual = step_by_step_contraction(
        &[
            &[("a", 2), ("b", 3), ("d", 5)],
            &[("a", 2), ("c", 4)],
            &[("b", 3), ("d", 5), ("c", 4)],
        ],
        &[("a", 2), ("c", 4)],
        &Contraction {
            tensor_indexes: vec![0, 2],
        },
    );
    let expected = StepByStepContractionResult {
        input_shapes_left: vec![
            vec![("a", 2), ("c", 4)],
            vec![("a", 2), ("c", 4)],
        ],
        contraction_output: vec![("a", 2), ("c", 4)],
    };
    assert_eq!(actual, expected);

    // Contracting the first two inputs instead is a worse route. The untouched
    // third input still depends on `b` and `d`, and `a` is part of the
    // requested output, so none of them can be summed out yet.
    let actual = step_by_step_contraction(
        &[
            &[("a", 2), ("b", 3), ("d", 5)],
            &[("a", 2), ("c", 4)],
            &[("b", 3), ("d", 5), ("c", 4)],
        ],
        &[("a", 2), ("c", 4)],
        &Contraction {
            tensor_indexes: vec![0, 1],
        },
    );
    let expected = StepByStepContractionResult {
        input_shapes_left: vec![
            vec![("b", 3), ("d", 5), ("c", 4)],
            vec![("b", 3), ("d", 5), ("c", 4), ("a", 2)],
        ],
        contraction_output: vec![("b", 3), ("d", 5), ("c", 4), ("a", 2)],
    };
    assert_eq!(actual, expected);

    // Same contraction as above, but the output no longer asks for `a`, so it
    // can be summed out during this step.
    let actual = step_by_step_contraction(
        &[
            &[("a", 2), ("b", 3), ("d", 5)],
            &[("a", 2), ("c", 4)],
            &[("b", 3), ("d", 5), ("c", 4)],
        ],
        &[("c", 4)],
        &Contraction {
            tensor_indexes: vec![0, 1],
        },
    );
    let expected = StepByStepContractionResult {
        input_shapes_left: vec![
            vec![("b", 3), ("d", 5), ("c", 4)],
            vec![("b", 3), ("d", 5), ("c", 4)],
        ],
        contraction_output: vec![("b", 3), ("d", 5), ("c", 4)],
    };
    assert_eq!(actual, expected);
}
1754
// TODO: Letting the caller pass in the desired contraction order
// should largely build on top of the naive Einsum implementation, but we
// are going to need quite a few more APIs to generalise over tensors of
// different dimensionalities first, since we have to somehow erase the
// dimension length and dimension arguments.
1760// fn by_contraction_order<const O: usize>(
1761//     self,
1762//     output: [Dimension; O],
1763//     contraction_order: &[Contraction],
1764// ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<3>>
1765// where
1766//     T: Numeric,
1767//     for<'a> &'a T: NumericRef<T>,
1768//     S1: TensorRef<T, D1>,
1769//     S2: TensorRef<T, D2>,
1770//     S3: TensorRef<T, D3>,
1771// {
1772//     let input_1_shape_const = &self.tensor_1.shape();
1773//     let input_1_shape: &[(Dimension, usize)] = input_1_shape_const;
1774//     let input_2_shape_const = &self.tensor_2.shape();
1775//     let input_2_shape: &[(Dimension, usize)] = input_2_shape_const;
1776//     let input_3_shape_const = &self.tensor_3.shape();
1777//     let input_3_shape: &[(Dimension, usize)] = input_3_shape_const;
1778//     let input = &[input_1_shape, input_2_shape, input_3_shape];
1779//
1780//     let output_shape = output_shape_for(input, &output)?;
1781//     let mut output_tensor = Tensor::empty(output_shape, T::zero());
1782//
1783//     let summation_dimensions = summation_dimensions(input, &output)?;
1784//
1785//     let mut input: Vec<Vec<(Dimension, usize)>> = input.iter().map(|i| i.to_vec()).collect();
1786//
1787//     for contraction in contraction_order.iter() {
1788//         let step = step_by_step_contraction(
1789//             &input.iter().map(|i| i.as_slice()).collect::<Vec<&[(Dimension, usize)]>>(),
1790//             &output_shape,
1791//             &contraction,
1792//         );
1793//         input = step.input_shapes_left;
1794//         let einsum_step = step.contraction_output;
1795//         // Need to store the unprocessed inputs in a list somehow which
1796//         // is going to require first erasing or at least enumerating over
1797//         // their dimensionality.
1798//         match einsum_step.len() {
1799//             0 => unimplemented!(),
1800//             1 => Einsum::with_1(...),
1801//             2 => Einsum::with_2(...),
1802//             3 => Einsum::with_3(...),
1803//             _ => panic!("Unsupported contraction step, output was larger than supported")
1804//         }
1805//     }
1806//
1807//     unimplemented!()
1808// }
1809
// TODO: Once the Tensor implementation is working, we should be able to generalise
// this to work on RecordTensor inputs too, since they can already be passed in up to
// the .to() step. The final step needs to be aware of some kind of NumericLike type
// that knows how to lower values carrying additional context to a Numeric type, and
// lift the results back up, for addition and multiplication.
// Future work could introduce a generic associated type for TensorRef that 'knows'
// what the desired output container is for tensor operations like these to collect
// the results back into, so the result type becomes
// Result<S1::Output<T, O>, InconsistentDimensionLengthError>
// and somehow we enforce that S1::Output == S2::Output (open question).