1use crate::numeric::{Numeric, NumericRef};
11use crate::tensors::indexing::DynamicShapeIterator;
12use crate::tensors::views::{TensorRef, TensorRename, TensorView};
13use crate::tensors::{Dimension, Tensor};
14
15use std::collections::HashSet;
16use std::error::Error;
17use std::fmt;
18
/// Entry point for building einsum style computations.
///
/// This type carries no data; its associated functions (`with_1` through
/// `with_6`) return the builder type for the corresponding number of inputs.
#[derive(Clone, Debug, Default)]
pub struct Einsum {
    // Private unit field so the struct cannot be built with a struct literal
    // from outside this module.
    _private: (),
}
91
92impl Einsum {
93 pub fn with_1<T, S, I, const D: usize>(input_1: I) -> Einsum1<T, S, D>
99 where
100 S: TensorRef<T, D>,
101 I: Into<TensorView<T, S, D>>,
102 {
103 Einsum1 {
104 tensor_1: input_1.into(),
105 }
106 }
107
108 pub fn with_2<T, S1, S2, I1, I2, const D1: usize, const D2: usize>(
114 input_1: I1,
115 input_2: I2,
116 ) -> Einsum2<T, S1, S2, D1, D2>
117 where
118 S1: TensorRef<T, D1>,
119 S2: TensorRef<T, D2>,
120 I1: Into<TensorView<T, S1, D1>>,
121 I2: Into<TensorView<T, S2, D2>>,
122 {
123 Einsum2 {
124 tensor_1: input_1.into(),
125 tensor_2: input_2.into(),
126 }
127 }
128
129 pub fn with_3<T, S1, S2, S3, I1, I2, I3, const D1: usize, const D2: usize, const D3: usize>(
135 input_1: I1,
136 input_2: I2,
137 input_3: I3,
138 ) -> Einsum3<T, S1, S2, S3, D1, D2, D3>
139 where
140 S1: TensorRef<T, D1>,
141 S2: TensorRef<T, D2>,
142 S3: TensorRef<T, D3>,
143 I1: Into<TensorView<T, S1, D1>>,
144 I2: Into<TensorView<T, S2, D2>>,
145 I3: Into<TensorView<T, S3, D3>>,
146 {
147 Einsum3 {
148 tensor_1: input_1.into(),
149 tensor_2: input_2.into(),
150 tensor_3: input_3.into(),
151 }
152 }
153
154 pub fn with_4<
160 T,
161 S1,
162 S2,
163 S3,
164 S4,
165 I1,
166 I2,
167 I3,
168 I4,
169 const D1: usize,
170 const D2: usize,
171 const D3: usize,
172 const D4: usize,
173 >(
174 input_1: I1,
175 input_2: I2,
176 input_3: I3,
177 input_4: I4,
178 ) -> Einsum4<T, S1, S2, S3, S4, D1, D2, D3, D4>
179 where
180 S1: TensorRef<T, D1>,
181 S2: TensorRef<T, D2>,
182 S3: TensorRef<T, D3>,
183 S4: TensorRef<T, D4>,
184 I1: Into<TensorView<T, S1, D1>>,
185 I2: Into<TensorView<T, S2, D2>>,
186 I3: Into<TensorView<T, S3, D3>>,
187 I4: Into<TensorView<T, S4, D4>>,
188 {
189 Einsum4 {
190 tensor_1: input_1.into(),
191 tensor_2: input_2.into(),
192 tensor_3: input_3.into(),
193 tensor_4: input_4.into(),
194 }
195 }
196
197 pub fn with_5<
203 T,
204 S1,
205 S2,
206 S3,
207 S4,
208 S5,
209 I1,
210 I2,
211 I3,
212 I4,
213 I5,
214 const D1: usize,
215 const D2: usize,
216 const D3: usize,
217 const D4: usize,
218 const D5: usize,
219 >(
220 input_1: I1,
221 input_2: I2,
222 input_3: I3,
223 input_4: I4,
224 input_5: I5,
225 ) -> Einsum5<T, S1, S2, S3, S4, S5, D1, D2, D3, D4, D5>
226 where
227 S1: TensorRef<T, D1>,
228 S2: TensorRef<T, D2>,
229 S3: TensorRef<T, D3>,
230 S4: TensorRef<T, D4>,
231 S5: TensorRef<T, D5>,
232 I1: Into<TensorView<T, S1, D1>>,
233 I2: Into<TensorView<T, S2, D2>>,
234 I3: Into<TensorView<T, S3, D3>>,
235 I4: Into<TensorView<T, S4, D4>>,
236 I5: Into<TensorView<T, S5, D5>>,
237 {
238 Einsum5 {
239 tensor_1: input_1.into(),
240 tensor_2: input_2.into(),
241 tensor_3: input_3.into(),
242 tensor_4: input_4.into(),
243 tensor_5: input_5.into(),
244 }
245 }
246
247 pub fn with_6<
264 T,
265 S1,
266 S2,
267 S3,
268 S4,
269 S5,
270 S6,
271 I1,
272 I2,
273 I3,
274 I4,
275 I5,
276 I6,
277 const D1: usize,
278 const D2: usize,
279 const D3: usize,
280 const D4: usize,
281 const D5: usize,
282 const D6: usize,
283 >(
284 input_1: I1,
285 input_2: I2,
286 input_3: I3,
287 input_4: I4,
288 input_5: I5,
289 input_6: I6,
290 ) -> Einsum6<T, S1, S2, S3, S4, S5, S6, D1, D2, D3, D4, D5, D6>
291 where
292 S1: TensorRef<T, D1>,
293 S2: TensorRef<T, D2>,
294 S3: TensorRef<T, D3>,
295 S4: TensorRef<T, D4>,
296 S5: TensorRef<T, D5>,
297 S6: TensorRef<T, D6>,
298 I1: Into<TensorView<T, S1, D1>>,
299 I2: Into<TensorView<T, S2, D2>>,
300 I3: Into<TensorView<T, S3, D3>>,
301 I4: Into<TensorView<T, S4, D4>>,
302 I5: Into<TensorView<T, S5, D5>>,
303 I6: Into<TensorView<T, S6, D6>>,
304 {
305 Einsum6 {
306 tensor_1: input_1.into(),
307 tensor_2: input_2.into(),
308 tensor_3: input_3.into(),
309 tensor_4: input_4.into(),
310 tensor_5: input_5.into(),
311 tensor_6: input_6.into(),
312 }
313 }
314}
315
316#[derive(Clone, Debug, Eq, PartialEq)]
321pub struct InconsistentDimensionLengthError<const I: usize> {
322 pub lengths: [Option<usize>; I],
329 pub dimension: Dimension,
333}
334
335impl<const I: usize> fmt::Display for InconsistentDimensionLengthError<I> {
336 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
337 write!(
338 f,
339 "inconsistent dimension lengths for dimension '{}': {:?}, lengths must match when repeated in different shapes as the same dimension name",
340 self.dimension,
341 self.lengths,
342 )
343 }
344}
345
346impl<const I: usize> Error for InconsistentDimensionLengthError<I> {}
347
// Pins the exact `Display` output of `InconsistentDimensionLengthError`.
#[test]
fn test_inconsistent_dimension_length_error() {
    let lengths = [Some(3), None, Some(2)];
    let message = InconsistentDimensionLengthError {
        lengths,
        dimension: "a",
    }
    .to_string();
    assert_eq!(
        message,
        "inconsistent dimension lengths for dimension 'a': [Some(3), None, Some(2)], lengths must match when repeated in different shapes as the same dimension name",
    );
}
359
/// A single contraction step, identified by the positions of the tensors
/// (within the list of shapes still to be processed) that take part in it.
#[allow(dead_code)]
#[derive(Clone, Debug, Eq, PartialEq)]
struct Contraction {
    tensor_indexes: Vec<usize>,
}

#[allow(dead_code)]
impl Contraction {
    /// Wraps the given tensor positions without validating them.
    fn from(tensor_indexes: Vec<usize>) -> Self {
        Self { tensor_indexes }
    }

    /// The tensor positions taking part in this contraction, in the order given.
    fn indexes(&self) -> &[usize] {
        self.tensor_indexes.as_slice()
    }
}
398
/// The outcome of applying one [`Contraction`] via `step_by_step_contraction`.
#[allow(dead_code)]
#[derive(Clone, Debug, Eq, PartialEq)]
struct StepByStepContractionResult {
    /// The shapes still to be processed after the step: the shapes that did
    /// not take part, followed by the shape the contraction produced.
    input_shapes_left: Vec<Vec<(Dimension, usize)>>,
    /// The shape produced by the contraction itself.
    contraction_output: Vec<(Dimension, usize)>,
}
405
406#[allow(dead_code)]
414fn step_by_step_contraction(
415 input_shapes_left: &[&[(Dimension, usize)]],
416 output_shape: &[(Dimension, usize)],
417 contraction: &Contraction,
418) -> StepByStepContractionResult {
419 let contracting: Vec<&[(Dimension, usize)]> = contraction
422 .tensor_indexes
423 .iter()
424 .map(|index| input_shapes_left[*index])
425 .collect();
426
427 let not_contracting_yet: Vec<&[(Dimension, usize)]> = input_shapes_left
430 .iter()
431 .enumerate()
432 .filter(|(i, _)| !contraction.tensor_indexes.contains(i))
433 .map(|(_, s)| *s)
434 .collect();
435
436 let contracting_dimensions: Vec<(Dimension, usize)> = {
441 let mut seen = HashSet::new();
442 let mut set = Vec::new();
443 for shape in &contracting {
444 for d in shape.iter() {
445 let new = seen.insert(*d);
446 if new {
447 set.push(*d);
448 }
449 }
450 }
451 set
452 };
453
454 let retained_dimensions: Vec<(Dimension, usize)> = {
459 let mut seen = HashSet::new();
460 let mut set = Vec::new();
461 for shape in ¬_contracting_yet {
462 for d in shape.iter() {
463 let new = seen.insert(*d);
464 if new {
465 set.push(*d);
466 }
467 }
468 }
469 for d in output_shape.iter() {
470 let new = seen.insert(*d);
471 if new {
472 set.push(*d);
473 }
474 }
475 set
476 };
477
478 let contraction_output: Vec<(Dimension, usize)> = {
482 let mut intersection = retained_dimensions.clone();
483 intersection.retain(|shape| contracting_dimensions.contains(shape));
484 intersection
485 };
486
487 let new_input_shapes_left = {
492 let mut vec = Vec::with_capacity(not_contracting_yet.len() + 1);
493 for d in not_contracting_yet.iter() {
494 vec.push(d.to_vec());
495 }
496 vec.push(contraction_output.clone());
497 vec
498 };
499
500 StepByStepContractionResult {
501 contraction_output,
502 input_shapes_left: new_input_shapes_left,
503 }
504}
505
506fn length_of<const I: usize>(
509 output_dimension: Dimension,
510 input: &[&[(Dimension, usize)]; I],
511) -> Result<usize, InconsistentDimensionLengthError<I>> {
512 let lengths = input.map(|shapes| {
513 shapes
514 .iter()
515 .find(|(dimension, _length)| *dimension == output_dimension)
516 .map(|(_dimension, length)| *length)
517 });
518
519 let first_length = lengths.iter().filter_map(|l| *l).next();
520 if let Some(length) = first_length {
521 if lengths.iter().any(|l| l.is_some() && *l != Some(length)) {
523 Err(InconsistentDimensionLengthError {
525 lengths,
526 dimension: output_dimension,
527 })
528 } else {
529 Ok(length)
530 }
531 } else {
532 Err(InconsistentDimensionLengthError {
534 lengths,
535 dimension: output_dimension,
536 })
537 }
538}
539
540#[track_caller]
541fn tensor_with_name<T, I, S, const D: usize>(
542 dimensions: [Dimension; D],
543 tensor: I,
544) -> TensorView<T, TensorRename<T, S, D>, D>
545where
546 I: Into<TensorView<T, S, D>>,
547 S: TensorRef<T, D>,
548{
549 let source: S = tensor.into().source();
550 let with_names = TensorRename::from(source, dimensions);
551 TensorView::from(with_names)
552}
553
554fn output_shape_for<const I: usize, const O: usize>(
564 input: &[&[(Dimension, usize)]; I],
565 output: &[Dimension; O],
566) -> Result<[(Dimension, usize); O], InconsistentDimensionLengthError<I>> {
567 let mut output_shape = std::array::from_fn(|d| (output[d], 0));
568 for x in output_shape.iter_mut() {
569 x.1 = length_of(x.0, input)?;
570 }
571 Ok(output_shape)
572}
573
574fn summation_dimensions<const I: usize, const O: usize>(
580 input: &[&[(Dimension, usize)]; I],
581 output: &[Dimension; O],
582) -> Result<Vec<(Dimension, usize)>, InconsistentDimensionLengthError<I>> {
583 let mut total_dimensions = 0;
584 for shape in input {
585 total_dimensions += shape.len();
586 }
587
588 let mut unique_dimensions = Vec::with_capacity(total_dimensions);
591
592 for shape in input {
593 for (dimension, length) in shape.iter() {
594 if output.contains(dimension) {
595 continue;
599 }
600 let existing = unique_dimensions.iter().find(|(d, _)| d == dimension);
601 match existing {
602 None => unique_dimensions.push((*dimension, *length)),
603 Some((_, l)) => {
604 if length != l {
605 return Err(InconsistentDimensionLengthError {
607 lengths: std::array::from_fn(|i| {
608 input[i]
609 .iter()
610 .find(|(d, _)| d == dimension)
611 .map(|(_, l)| *l)
612 }),
613 dimension,
614 });
615 }
616 }
617 }
618 }
619 }
620
621 Ok(unique_dimensions)
622}
623
624fn filter_outer_indexes<const D: usize, const O: usize>(
630 outer_indexes: &[usize; O],
631 outer_shape: &[(Dimension, usize); O],
632 input_shape: &[(Dimension, usize); D],
633) -> [usize; D] {
634 let mut input_indexes = [0; D];
635 for d in 0..D {
636 let mut found = false;
637 let dimension = input_shape[d].0;
638 for o in 0..O {
639 let possible_dimension = outer_shape[o].0;
640 if possible_dimension == dimension {
641 input_indexes[d] = outer_indexes[o];
642 found = true;
643 break;
644 }
645 }
646 if !found {
647 panic!(
648 "Expected to find an index for dimension {:?} but was not present in {:?} for {:?} while trying to index tensor of shape {:?}",
649 dimension,
650 outer_indexes,
651 outer_shape,
652 input_shape,
653 );
654 }
655 }
656 input_indexes
657}
658
659fn filter_outer_and_summation_indexes<const D: usize, const O: usize>(
667 outer_indexes: &[usize; O],
668 outer_shape: &[(Dimension, usize); O],
669 summation_indexes: &[usize],
670 summation_shape: &[(Dimension, usize)],
671 input_shape: &[(Dimension, usize); D],
672) -> [usize; D] {
673 let mut input_indexes = [0; D];
674 for d in 0..D {
675 let mut found = false;
676 let dimension = input_shape[d].0;
677 for o in 0..O {
678 let possible_dimension = outer_shape[o].0;
679 if possible_dimension == dimension {
680 input_indexes[d] = outer_indexes[o];
681 found = true;
682 break;
683 }
684 }
685 let summation_iter = summation_indexes.iter().zip(summation_shape.iter());
686 for (index, (possible_dimension, _length)) in summation_iter {
687 if *possible_dimension == dimension {
688 input_indexes[d] = *index;
689 found = true;
690 break;
691 }
692 }
693 if !found {
694 panic!(
695 "Expected to find an index for dimension {:?} but was not present in {:?} for {:?} or {:?} for {:?} while trying to index tensor of shape {:?}",
696 dimension,
697 outer_indexes,
698 outer_shape,
699 summation_indexes,
700 summation_shape,
701 input_shape,
702 );
703 }
704 }
705 input_indexes
706}
707
/// Einsum builder holding one input tensor view, created by `Einsum::with_1`.
pub struct Einsum1<T, S1, const D1: usize> {
    // The single input, already converted to a view.
    tensor_1: TensorView<T, S1, D1>,
}
714
/// Einsum builder holding two input tensor views, created by `Einsum::with_2`.
pub struct Einsum2<T, S1, S2, const D1: usize, const D2: usize> {
    // Inputs in the order they were supplied, already converted to views.
    tensor_1: TensorView<T, S1, D1>,
    tensor_2: TensorView<T, S2, D2>,
}
722
/// Einsum builder holding three input tensor views, created by `Einsum::with_3`.
pub struct Einsum3<T, S1, S2, S3, const D1: usize, const D2: usize, const D3: usize> {
    // Inputs in the order they were supplied, already converted to views.
    tensor_1: TensorView<T, S1, D1>,
    tensor_2: TensorView<T, S2, D2>,
    tensor_3: TensorView<T, S3, D3>,
}
731
/// Einsum builder holding four input tensor views, created by `Einsum::with_4`.
pub struct Einsum4<
    T,
    S1,
    S2,
    S3,
    S4,
    const D1: usize,
    const D2: usize,
    const D3: usize,
    const D4: usize,
> {
    // Inputs in the order they were supplied, already converted to views.
    tensor_1: TensorView<T, S1, D1>,
    tensor_2: TensorView<T, S2, D2>,
    tensor_3: TensorView<T, S3, D3>,
    tensor_4: TensorView<T, S4, D4>,
}
751
/// Einsum builder holding five input tensor views, created by `Einsum::with_5`.
pub struct Einsum5<
    T,
    S1,
    S2,
    S3,
    S4,
    S5,
    const D1: usize,
    const D2: usize,
    const D3: usize,
    const D4: usize,
    const D5: usize,
> {
    // Inputs in the order they were supplied, already converted to views.
    tensor_1: TensorView<T, S1, D1>,
    tensor_2: TensorView<T, S2, D2>,
    tensor_3: TensorView<T, S3, D3>,
    tensor_4: TensorView<T, S4, D4>,
    tensor_5: TensorView<T, S5, D5>,
}
774
/// Einsum builder holding six input tensor views, created by `Einsum::with_6`.
pub struct Einsum6<
    T,
    S1,
    S2,
    S3,
    S4,
    S5,
    S6,
    const D1: usize,
    const D2: usize,
    const D3: usize,
    const D4: usize,
    const D5: usize,
    const D6: usize,
> {
    // Inputs in the order they were supplied, already converted to views.
    tensor_1: TensorView<T, S1, D1>,
    tensor_2: TensorView<T, S2, D2>,
    tensor_3: TensorView<T, S3, D3>,
    tensor_4: TensorView<T, S4, D4>,
    tensor_5: TensorView<T, S5, D5>,
    tensor_6: TensorView<T, S6, D6>,
}
811
812impl<T, S1, const D1: usize> Einsum1<T, S1, D1> {
813 #[track_caller]
820 pub fn named(self, input_1: [Dimension; D1]) -> Einsum1<T, TensorRename<T, S1, D1>, D1>
821 where
822 S1: TensorRef<T, D1>,
823 {
824 Einsum1 {
825 tensor_1: tensor_with_name(input_1, self.tensor_1),
826 }
827 }
828
829 pub fn to<const O: usize>(
830 self,
831 output: [Dimension; O],
832 ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<1>>
833 where
834 T: Numeric,
835 for<'a> &'a T: NumericRef<T>,
836 S1: TensorRef<T, D1>,
837 {
838 let input_1_shape_const = &self.tensor_1.shape();
839 let input_1_shape: &[(Dimension, usize)] = input_1_shape_const;
840 let input = &[input_1_shape];
841
842 let output_shape = output_shape_for(input, &output)?;
843 let mut output_tensor = Tensor::empty(output_shape, T::zero());
844
845 let summation_dimensions = summation_dimensions(input, &output)?;
846 let tensor_1_indexing = self.tensor_1.index();
847
848 for (indexes, element) in output_tensor.index_mut().iter_reference_mut().with_index() {
849 let mut sum = T::zero();
850
851 if summation_dimensions.is_empty() {
852 let product_1 = tensor_1_indexing.get_ref(filter_outer_indexes(
853 &indexes,
854 &output_shape,
855 input_1_shape_const,
856 ));
857 sum = sum + product_1;
858 } else {
859 let mut summation_iterator = DynamicShapeIterator::from(&summation_dimensions);
860 loop {
861 let next = summation_iterator.next();
862 match next {
863 Some(summation_indexes) => {
864 let product_1 =
865 tensor_1_indexing.get_ref(filter_outer_and_summation_indexes(
866 &indexes,
867 &output_shape,
868 summation_indexes,
869 &summation_dimensions,
870 input_1_shape_const,
871 ));
872 sum = sum + product_1;
873 }
874 None => break,
875 }
876 }
877 }
878 *element = sum;
879 }
880
881 Ok(output_tensor)
882 }
883}
884
885impl<T, S1, S2, const D1: usize, const D2: usize> Einsum2<T, S1, S2, D1, D2> {
886 #[track_caller]
893 pub fn named(
894 self,
895 input_1: [Dimension; D1],
896 input_2: [Dimension; D2],
897 ) -> Einsum2<T, TensorRename<T, S1, D1>, TensorRename<T, S2, D2>, D1, D2>
898 where
899 S1: TensorRef<T, D1>,
900 S2: TensorRef<T, D2>,
901 {
902 Einsum2 {
903 tensor_1: tensor_with_name(input_1, self.tensor_1),
904 tensor_2: tensor_with_name(input_2, self.tensor_2),
905 }
906 }
907
908 pub fn to<const O: usize>(
909 self,
910 output: [Dimension; O],
911 ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<2>>
912 where
913 T: Numeric,
914 for<'a> &'a T: NumericRef<T>,
915 S1: TensorRef<T, D1>,
916 S2: TensorRef<T, D2>,
917 {
918 let input_1_shape_const = &self.tensor_1.shape();
919 let input_1_shape: &[(Dimension, usize)] = input_1_shape_const;
920 let input_2_shape_const = &self.tensor_2.shape();
921 let input_2_shape: &[(Dimension, usize)] = input_2_shape_const;
922 let input = &[input_1_shape, input_2_shape];
923
924 let output_shape = output_shape_for(input, &output)?;
925 let mut output_tensor = Tensor::empty(output_shape, T::zero());
926
927 let summation_dimensions = summation_dimensions(input, &output)?;
928 let tensor_1_indexing = self.tensor_1.index();
929 let tensor_2_indexing = self.tensor_2.index();
930
931 for (indexes, element) in output_tensor.index_mut().iter_reference_mut().with_index() {
932 let mut sum = T::zero();
933
934 if summation_dimensions.is_empty() {
935 let product_1 = tensor_1_indexing.get_ref(filter_outer_indexes(
936 &indexes,
937 &output_shape,
938 input_1_shape_const,
939 ));
940 let product_2 = tensor_2_indexing.get_ref(filter_outer_indexes(
941 &indexes,
942 &output_shape,
943 input_2_shape_const,
944 ));
945 sum = sum + (product_1 * product_2);
946 } else {
947 let mut summation_iterator = DynamicShapeIterator::from(&summation_dimensions);
948 loop {
949 let next = summation_iterator.next();
950 match next {
951 Some(summation_indexes) => {
952 let product_1 =
953 tensor_1_indexing.get_ref(filter_outer_and_summation_indexes(
954 &indexes,
955 &output_shape,
956 summation_indexes,
957 &summation_dimensions,
958 input_1_shape_const,
959 ));
960 let product_2 =
961 tensor_2_indexing.get_ref(filter_outer_and_summation_indexes(
962 &indexes,
963 &output_shape,
964 summation_indexes,
965 &summation_dimensions,
966 input_2_shape_const,
967 ));
968 sum = sum + (product_1 * product_2);
969 }
970 None => break,
971 }
972 }
973 }
974
975 *element = sum;
976 }
977
978 Ok(output_tensor)
979 }
980}
981
982impl<T, S1, S2, S3, const D1: usize, const D2: usize, const D3: usize>
983 Einsum3<T, S1, S2, S3, D1, D2, D3>
984{
985 #[track_caller]
992 #[allow(clippy::type_complexity)]
993 pub fn named(
994 self,
995 input_1: [Dimension; D1],
996 input_2: [Dimension; D2],
997 input_3: [Dimension; D3],
998 ) -> Einsum3<
999 T,
1000 TensorRename<T, S1, D1>,
1001 TensorRename<T, S2, D2>,
1002 TensorRename<T, S3, D3>,
1003 D1,
1004 D2,
1005 D3,
1006 >
1007 where
1008 S1: TensorRef<T, D1>,
1009 S2: TensorRef<T, D2>,
1010 S3: TensorRef<T, D3>,
1011 {
1012 Einsum3 {
1013 tensor_1: tensor_with_name(input_1, self.tensor_1),
1014 tensor_2: tensor_with_name(input_2, self.tensor_2),
1015 tensor_3: tensor_with_name(input_3, self.tensor_3),
1016 }
1017 }
1018
1019 pub fn to<const O: usize>(
1020 self,
1021 output: [Dimension; O],
1022 ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<3>>
1023 where
1024 T: Numeric,
1025 for<'a> &'a T: NumericRef<T>,
1026 S1: TensorRef<T, D1>,
1027 S2: TensorRef<T, D2>,
1028 S3: TensorRef<T, D3>,
1029 {
1030 let input_1_shape_const = &self.tensor_1.shape();
1031 let input_1_shape: &[(Dimension, usize)] = input_1_shape_const;
1032 let input_2_shape_const = &self.tensor_2.shape();
1033 let input_2_shape: &[(Dimension, usize)] = input_2_shape_const;
1034 let input_3_shape_const = &self.tensor_3.shape();
1035 let input_3_shape: &[(Dimension, usize)] = input_3_shape_const;
1036 let input = &[input_1_shape, input_2_shape, input_3_shape];
1037
1038 let output_shape = output_shape_for(input, &output)?;
1039 let mut output_tensor = Tensor::empty(output_shape, T::zero());
1040
1041 let summation_dimensions = summation_dimensions(input, &output)?;
1042 let tensor_1_indexing = self.tensor_1.index();
1043 let tensor_2_indexing = self.tensor_2.index();
1044 let tensor_3_indexing = self.tensor_3.index();
1045
1046 for (indexes, element) in output_tensor.index_mut().iter_reference_mut().with_index() {
1047 let mut sum = T::zero();
1048
1049 if summation_dimensions.is_empty() {
1050 let product_1 = tensor_1_indexing.get_ref(filter_outer_indexes(
1051 &indexes,
1052 &output_shape,
1053 input_1_shape_const,
1054 ));
1055 let product_2 = tensor_2_indexing.get_ref(filter_outer_indexes(
1056 &indexes,
1057 &output_shape,
1058 input_2_shape_const,
1059 ));
1060 let product_3 = tensor_3_indexing.get_ref(filter_outer_indexes(
1061 &indexes,
1062 &output_shape,
1063 input_3_shape_const,
1064 ));
1065 sum = sum + (product_1 * product_2 * product_3);
1066 } else {
1067 let mut summation_iterator = DynamicShapeIterator::from(&summation_dimensions);
1068 loop {
1069 let next = summation_iterator.next();
1070 match next {
1071 Some(summation_indexes) => {
1072 let product_1 =
1073 tensor_1_indexing.get_ref(filter_outer_and_summation_indexes(
1074 &indexes,
1075 &output_shape,
1076 summation_indexes,
1077 &summation_dimensions,
1078 input_1_shape_const,
1079 ));
1080 let product_2 =
1081 tensor_2_indexing.get_ref(filter_outer_and_summation_indexes(
1082 &indexes,
1083 &output_shape,
1084 summation_indexes,
1085 &summation_dimensions,
1086 input_2_shape_const,
1087 ));
1088 let product_3 =
1089 tensor_3_indexing.get_ref(filter_outer_and_summation_indexes(
1090 &indexes,
1091 &output_shape,
1092 summation_indexes,
1093 &summation_dimensions,
1094 input_3_shape_const,
1095 ));
1096 sum = sum + (product_1 * product_2 * product_3);
1097 }
1098 None => break,
1099 }
1100 }
1101 }
1102
1103 *element = sum;
1104 }
1105
1106 Ok(output_tensor)
1107 }
1108}
1109
1110impl<T, S1, S2, S3, S4, const D1: usize, const D2: usize, const D3: usize, const D4: usize>
1111 Einsum4<T, S1, S2, S3, S4, D1, D2, D3, D4>
1112{
1113 #[track_caller]
1120 #[allow(clippy::type_complexity)]
1121 pub fn named(
1122 self,
1123 input_1: [Dimension; D1],
1124 input_2: [Dimension; D2],
1125 input_3: [Dimension; D3],
1126 input_4: [Dimension; D4],
1127 ) -> Einsum4<
1128 T,
1129 TensorRename<T, S1, D1>,
1130 TensorRename<T, S2, D2>,
1131 TensorRename<T, S3, D3>,
1132 TensorRename<T, S4, D4>,
1133 D1,
1134 D2,
1135 D3,
1136 D4,
1137 >
1138 where
1139 S1: TensorRef<T, D1>,
1140 S2: TensorRef<T, D2>,
1141 S3: TensorRef<T, D3>,
1142 S4: TensorRef<T, D4>,
1143 {
1144 Einsum4 {
1145 tensor_1: tensor_with_name(input_1, self.tensor_1),
1146 tensor_2: tensor_with_name(input_2, self.tensor_2),
1147 tensor_3: tensor_with_name(input_3, self.tensor_3),
1148 tensor_4: tensor_with_name(input_4, self.tensor_4),
1149 }
1150 }
1151
1152 pub fn to<const O: usize>(
1153 self,
1154 output: [Dimension; O],
1155 ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<4>>
1156 where
1157 T: Numeric,
1158 for<'a> &'a T: NumericRef<T>,
1159 S1: TensorRef<T, D1>,
1160 S2: TensorRef<T, D2>,
1161 S3: TensorRef<T, D3>,
1162 S4: TensorRef<T, D4>,
1163 {
1164 let input_1_shape_const = &self.tensor_1.shape();
1165 let input_1_shape: &[(Dimension, usize)] = input_1_shape_const;
1166 let input_2_shape_const = &self.tensor_2.shape();
1167 let input_2_shape: &[(Dimension, usize)] = input_2_shape_const;
1168 let input_3_shape_const = &self.tensor_3.shape();
1169 let input_3_shape: &[(Dimension, usize)] = input_3_shape_const;
1170 let input_4_shape_const = &self.tensor_4.shape();
1171 let input_4_shape: &[(Dimension, usize)] = input_4_shape_const;
1172 let input = &[input_1_shape, input_2_shape, input_3_shape, input_4_shape];
1173
1174 let output_shape = output_shape_for(input, &output)?;
1175 let mut output_tensor = Tensor::empty(output_shape, T::zero());
1176
1177 let summation_dimensions = summation_dimensions(input, &output)?;
1178 let tensor_1_indexing = self.tensor_1.index();
1179 let tensor_2_indexing = self.tensor_2.index();
1180 let tensor_3_indexing = self.tensor_3.index();
1181 let tensor_4_indexing = self.tensor_4.index();
1182
1183 for (indexes, element) in output_tensor.index_mut().iter_reference_mut().with_index() {
1184 let mut sum = T::zero();
1185
1186 if summation_dimensions.is_empty() {
1187 let product_1 = tensor_1_indexing.get_ref(filter_outer_indexes(
1188 &indexes,
1189 &output_shape,
1190 input_1_shape_const,
1191 ));
1192 let product_2 = tensor_2_indexing.get_ref(filter_outer_indexes(
1193 &indexes,
1194 &output_shape,
1195 input_2_shape_const,
1196 ));
1197 let product_3 = tensor_3_indexing.get_ref(filter_outer_indexes(
1198 &indexes,
1199 &output_shape,
1200 input_3_shape_const,
1201 ));
1202 let product_4 = tensor_4_indexing.get_ref(filter_outer_indexes(
1203 &indexes,
1204 &output_shape,
1205 input_4_shape_const,
1206 ));
1207 sum = sum + (product_1 * product_2 * product_3 * product_4);
1208 } else {
1209 let mut summation_iterator = DynamicShapeIterator::from(&summation_dimensions);
1210 loop {
1211 let next = summation_iterator.next();
1212 match next {
1213 Some(summation_indexes) => {
1214 let product_1 =
1215 tensor_1_indexing.get_ref(filter_outer_and_summation_indexes(
1216 &indexes,
1217 &output_shape,
1218 summation_indexes,
1219 &summation_dimensions,
1220 input_1_shape_const,
1221 ));
1222 let product_2 =
1223 tensor_2_indexing.get_ref(filter_outer_and_summation_indexes(
1224 &indexes,
1225 &output_shape,
1226 summation_indexes,
1227 &summation_dimensions,
1228 input_2_shape_const,
1229 ));
1230 let product_3 =
1231 tensor_3_indexing.get_ref(filter_outer_and_summation_indexes(
1232 &indexes,
1233 &output_shape,
1234 summation_indexes,
1235 &summation_dimensions,
1236 input_3_shape_const,
1237 ));
1238 let product_4 =
1239 tensor_4_indexing.get_ref(filter_outer_and_summation_indexes(
1240 &indexes,
1241 &output_shape,
1242 summation_indexes,
1243 &summation_dimensions,
1244 input_4_shape_const,
1245 ));
1246 sum = sum + (product_1 * product_2 * product_3 * product_4);
1247 }
1248 None => break,
1249 }
1250 }
1251 }
1252
1253 *element = sum;
1254 }
1255
1256 Ok(output_tensor)
1257 }
1258}
1259
1260impl<
1261 T,
1262 S1,
1263 S2,
1264 S3,
1265 S4,
1266 S5,
1267 const D1: usize,
1268 const D2: usize,
1269 const D3: usize,
1270 const D4: usize,
1271 const D5: usize,
1272 > Einsum5<T, S1, S2, S3, S4, S5, D1, D2, D3, D4, D5>
1273{
1274 #[track_caller]
1281 #[allow(clippy::type_complexity)]
1282 pub fn named(
1283 self,
1284 input_1: [Dimension; D1],
1285 input_2: [Dimension; D2],
1286 input_3: [Dimension; D3],
1287 input_4: [Dimension; D4],
1288 input_5: [Dimension; D5],
1289 ) -> Einsum5<
1290 T,
1291 TensorRename<T, S1, D1>,
1292 TensorRename<T, S2, D2>,
1293 TensorRename<T, S3, D3>,
1294 TensorRename<T, S4, D4>,
1295 TensorRename<T, S5, D5>,
1296 D1,
1297 D2,
1298 D3,
1299 D4,
1300 D5,
1301 >
1302 where
1303 S1: TensorRef<T, D1>,
1304 S2: TensorRef<T, D2>,
1305 S3: TensorRef<T, D3>,
1306 S4: TensorRef<T, D4>,
1307 S5: TensorRef<T, D5>,
1308 {
1309 Einsum5 {
1310 tensor_1: tensor_with_name(input_1, self.tensor_1),
1311 tensor_2: tensor_with_name(input_2, self.tensor_2),
1312 tensor_3: tensor_with_name(input_3, self.tensor_3),
1313 tensor_4: tensor_with_name(input_4, self.tensor_4),
1314 tensor_5: tensor_with_name(input_5, self.tensor_5),
1315 }
1316 }
1317
1318 pub fn to<const O: usize>(
1319 self,
1320 output: [Dimension; O],
1321 ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<5>>
1322 where
1323 T: Numeric,
1324 for<'a> &'a T: NumericRef<T>,
1325 S1: TensorRef<T, D1>,
1326 S2: TensorRef<T, D2>,
1327 S3: TensorRef<T, D3>,
1328 S4: TensorRef<T, D4>,
1329 S5: TensorRef<T, D5>,
1330 {
1331 let input_1_shape_const = &self.tensor_1.shape();
1332 let input_1_shape: &[(Dimension, usize)] = input_1_shape_const;
1333 let input_2_shape_const = &self.tensor_2.shape();
1334 let input_2_shape: &[(Dimension, usize)] = input_2_shape_const;
1335 let input_3_shape_const = &self.tensor_3.shape();
1336 let input_3_shape: &[(Dimension, usize)] = input_3_shape_const;
1337 let input_4_shape_const = &self.tensor_4.shape();
1338 let input_4_shape: &[(Dimension, usize)] = input_4_shape_const;
1339 let input_5_shape_const = &self.tensor_5.shape();
1340 let input_5_shape: &[(Dimension, usize)] = input_5_shape_const;
1341 let input = &[
1342 input_1_shape,
1343 input_2_shape,
1344 input_3_shape,
1345 input_4_shape,
1346 input_5_shape,
1347 ];
1348
1349 let output_shape = output_shape_for(input, &output)?;
1350 let mut output_tensor = Tensor::empty(output_shape, T::zero());
1351
1352 let summation_dimensions = summation_dimensions(input, &output)?;
1353 let tensor_1_indexing = self.tensor_1.index();
1354 let tensor_2_indexing = self.tensor_2.index();
1355 let tensor_3_indexing = self.tensor_3.index();
1356 let tensor_4_indexing = self.tensor_4.index();
1357 let tensor_5_indexing = self.tensor_5.index();
1358
1359 for (indexes, element) in output_tensor.index_mut().iter_reference_mut().with_index() {
1360 let mut sum = T::zero();
1361
1362 if summation_dimensions.is_empty() {
1363 let product_1 = tensor_1_indexing.get_ref(filter_outer_indexes(
1364 &indexes,
1365 &output_shape,
1366 input_1_shape_const,
1367 ));
1368 let product_2 = tensor_2_indexing.get_ref(filter_outer_indexes(
1369 &indexes,
1370 &output_shape,
1371 input_2_shape_const,
1372 ));
1373 let product_3 = tensor_3_indexing.get_ref(filter_outer_indexes(
1374 &indexes,
1375 &output_shape,
1376 input_3_shape_const,
1377 ));
1378 let product_4 = tensor_4_indexing.get_ref(filter_outer_indexes(
1379 &indexes,
1380 &output_shape,
1381 input_4_shape_const,
1382 ));
1383 let product_5 = tensor_5_indexing.get_ref(filter_outer_indexes(
1384 &indexes,
1385 &output_shape,
1386 input_5_shape_const,
1387 ));
1388 sum = sum + (product_1 * product_2 * product_3 * product_4 * product_5);
1389 } else {
1390 let mut summation_iterator = DynamicShapeIterator::from(&summation_dimensions);
1391 loop {
1392 let next = summation_iterator.next();
1393 match next {
1394 Some(summation_indexes) => {
1395 let product_1 =
1396 tensor_1_indexing.get_ref(filter_outer_and_summation_indexes(
1397 &indexes,
1398 &output_shape,
1399 summation_indexes,
1400 &summation_dimensions,
1401 input_1_shape_const,
1402 ));
1403 let product_2 =
1404 tensor_2_indexing.get_ref(filter_outer_and_summation_indexes(
1405 &indexes,
1406 &output_shape,
1407 summation_indexes,
1408 &summation_dimensions,
1409 input_2_shape_const,
1410 ));
1411 let product_3 =
1412 tensor_3_indexing.get_ref(filter_outer_and_summation_indexes(
1413 &indexes,
1414 &output_shape,
1415 summation_indexes,
1416 &summation_dimensions,
1417 input_3_shape_const,
1418 ));
1419 let product_4 =
1420 tensor_4_indexing.get_ref(filter_outer_and_summation_indexes(
1421 &indexes,
1422 &output_shape,
1423 summation_indexes,
1424 &summation_dimensions,
1425 input_4_shape_const,
1426 ));
1427 let product_5 =
1428 tensor_5_indexing.get_ref(filter_outer_and_summation_indexes(
1429 &indexes,
1430 &output_shape,
1431 summation_indexes,
1432 &summation_dimensions,
1433 input_5_shape_const,
1434 ));
1435 sum = sum + (product_1 * product_2 * product_3 * product_4 * product_5);
1436 }
1437 None => break,
1438 }
1439 }
1440 }
1441
1442 *element = sum;
1443 }
1444
1445 Ok(output_tensor)
1446 }
1447}
1448
1449impl<
1450 T,
1451 S1,
1452 S2,
1453 S3,
1454 S4,
1455 S5,
1456 S6,
1457 const D1: usize,
1458 const D2: usize,
1459 const D3: usize,
1460 const D4: usize,
1461 const D5: usize,
1462 const D6: usize,
1463 > Einsum6<T, S1, S2, S3, S4, S5, S6, D1, D2, D3, D4, D5, D6>
1464{
1465 #[track_caller]
1472 #[allow(clippy::type_complexity)]
1473 pub fn named(
1474 self,
1475 input_1: [Dimension; D1],
1476 input_2: [Dimension; D2],
1477 input_3: [Dimension; D3],
1478 input_4: [Dimension; D4],
1479 input_5: [Dimension; D5],
1480 input_6: [Dimension; D6],
1481 ) -> Einsum6<
1482 T,
1483 TensorRename<T, S1, D1>,
1484 TensorRename<T, S2, D2>,
1485 TensorRename<T, S3, D3>,
1486 TensorRename<T, S4, D4>,
1487 TensorRename<T, S5, D5>,
1488 TensorRename<T, S6, D6>,
1489 D1,
1490 D2,
1491 D3,
1492 D4,
1493 D5,
1494 D6,
1495 >
1496 where
1497 S1: TensorRef<T, D1>,
1498 S2: TensorRef<T, D2>,
1499 S3: TensorRef<T, D3>,
1500 S4: TensorRef<T, D4>,
1501 S5: TensorRef<T, D5>,
1502 S6: TensorRef<T, D6>,
1503 {
1504 Einsum6 {
1505 tensor_1: tensor_with_name(input_1, self.tensor_1),
1506 tensor_2: tensor_with_name(input_2, self.tensor_2),
1507 tensor_3: tensor_with_name(input_3, self.tensor_3),
1508 tensor_4: tensor_with_name(input_4, self.tensor_4),
1509 tensor_5: tensor_with_name(input_5, self.tensor_5),
1510 tensor_6: tensor_with_name(input_6, self.tensor_6),
1511 }
1512 }
1513
1514 pub fn to<const O: usize>(
1515 self,
1516 output: [Dimension; O],
1517 ) -> Result<Tensor<T, O>, InconsistentDimensionLengthError<6>>
1518 where
1519 T: Numeric,
1520 for<'a> &'a T: NumericRef<T>,
1521 S1: TensorRef<T, D1>,
1522 S2: TensorRef<T, D2>,
1523 S3: TensorRef<T, D3>,
1524 S4: TensorRef<T, D4>,
1525 S5: TensorRef<T, D5>,
1526 S6: TensorRef<T, D6>,
1527 {
1528 let input_1_shape_const = &self.tensor_1.shape();
1529 let input_1_shape: &[(Dimension, usize)] = input_1_shape_const;
1530 let input_2_shape_const = &self.tensor_2.shape();
1531 let input_2_shape: &[(Dimension, usize)] = input_2_shape_const;
1532 let input_3_shape_const = &self.tensor_3.shape();
1533 let input_3_shape: &[(Dimension, usize)] = input_3_shape_const;
1534 let input_4_shape_const = &self.tensor_4.shape();
1535 let input_4_shape: &[(Dimension, usize)] = input_4_shape_const;
1536 let input_5_shape_const = &self.tensor_5.shape();
1537 let input_5_shape: &[(Dimension, usize)] = input_5_shape_const;
1538 let input_6_shape_const = &self.tensor_6.shape();
1539 let input_6_shape: &[(Dimension, usize)] = input_6_shape_const;
1540 let input = &[
1541 input_1_shape,
1542 input_2_shape,
1543 input_3_shape,
1544 input_4_shape,
1545 input_5_shape,
1546 input_6_shape,
1547 ];
1548
1549 let output_shape = output_shape_for(input, &output)?;
1550 let mut output_tensor = Tensor::empty(output_shape, T::zero());
1551
1552 let summation_dimensions = summation_dimensions(input, &output)?;
1553 let tensor_1_indexing = self.tensor_1.index();
1554 let tensor_2_indexing = self.tensor_2.index();
1555 let tensor_3_indexing = self.tensor_3.index();
1556 let tensor_4_indexing = self.tensor_4.index();
1557 let tensor_5_indexing = self.tensor_5.index();
1558 let tensor_6_indexing = self.tensor_6.index();
1559
1560 for (indexes, element) in output_tensor.index_mut().iter_reference_mut().with_index() {
1561 let mut sum = T::zero();
1562
1563 if summation_dimensions.is_empty() {
1564 let product_1 = tensor_1_indexing.get_ref(filter_outer_indexes(
1565 &indexes,
1566 &output_shape,
1567 input_1_shape_const,
1568 ));
1569 let product_2 = tensor_2_indexing.get_ref(filter_outer_indexes(
1570 &indexes,
1571 &output_shape,
1572 input_2_shape_const,
1573 ));
1574 let product_3 = tensor_3_indexing.get_ref(filter_outer_indexes(
1575 &indexes,
1576 &output_shape,
1577 input_3_shape_const,
1578 ));
1579 let product_4 = tensor_4_indexing.get_ref(filter_outer_indexes(
1580 &indexes,
1581 &output_shape,
1582 input_4_shape_const,
1583 ));
1584 let product_5 = tensor_5_indexing.get_ref(filter_outer_indexes(
1585 &indexes,
1586 &output_shape,
1587 input_5_shape_const,
1588 ));
1589 let product_6 = tensor_6_indexing.get_ref(filter_outer_indexes(
1590 &indexes,
1591 &output_shape,
1592 input_6_shape_const,
1593 ));
1594 sum = sum + (product_1 * product_2 * product_3 * product_4 * product_5 * product_6);
1595 } else {
1596 let mut summation_iterator = DynamicShapeIterator::from(&summation_dimensions);
1597 loop {
1598 let next = summation_iterator.next();
1599 match next {
1600 Some(summation_indexes) => {
1601 let product_1 =
1602 tensor_1_indexing.get_ref(filter_outer_and_summation_indexes(
1603 &indexes,
1604 &output_shape,
1605 summation_indexes,
1606 &summation_dimensions,
1607 input_1_shape_const,
1608 ));
1609 let product_2 =
1610 tensor_2_indexing.get_ref(filter_outer_and_summation_indexes(
1611 &indexes,
1612 &output_shape,
1613 summation_indexes,
1614 &summation_dimensions,
1615 input_2_shape_const,
1616 ));
1617 let product_3 =
1618 tensor_3_indexing.get_ref(filter_outer_and_summation_indexes(
1619 &indexes,
1620 &output_shape,
1621 summation_indexes,
1622 &summation_dimensions,
1623 input_3_shape_const,
1624 ));
1625 let product_4 =
1626 tensor_4_indexing.get_ref(filter_outer_and_summation_indexes(
1627 &indexes,
1628 &output_shape,
1629 summation_indexes,
1630 &summation_dimensions,
1631 input_4_shape_const,
1632 ));
1633 let product_5 =
1634 tensor_5_indexing.get_ref(filter_outer_and_summation_indexes(
1635 &indexes,
1636 &output_shape,
1637 summation_indexes,
1638 &summation_dimensions,
1639 input_5_shape_const,
1640 ));
1641 let product_6 =
1642 tensor_6_indexing.get_ref(filter_outer_and_summation_indexes(
1643 &indexes,
1644 &output_shape,
1645 summation_indexes,
1646 &summation_dimensions,
1647 input_6_shape_const,
1648 ));
1649 sum = sum
1650 + (product_1
1651 * product_2
1652 * product_3
1653 * product_4
1654 * product_5
1655 * product_6);
1656 }
1657 None => break,
1658 }
1659 }
1660 }
1661
1662 *element = sum;
1663 }
1664
1665 Ok(output_tensor)
1666 }
1667}
1668
#[test]
fn step_by_step_contraction_tests() {
    // Two inputs xy and yz with a final output of xz. Contracting tensors 0
    // and 1 is expected to leave just the contraction's own output, xz.
    {
        let contraction = Contraction {
            tensor_indexes: vec![0, 1],
        };
        let expected = StepByStepContractionResult {
            input_shapes_left: vec![vec![("x", 2), ("z", 4)]],
            contraction_output: vec![("x", 2), ("z", 4)],
        };
        assert_eq!(
            step_by_step_contraction(
                &[&[("x", 2), ("y", 3)], &[("y", 3), ("z", 4)]],
                &[("x", 2), ("z", 4)],
                &contraction,
            ),
            expected
        );
    }

    // Three inputs abd, ac and bdc with a final output of ac. Contracting
    // tensors 0 and 2 is expected to produce ac, leaving the untouched ac
    // input followed by that contraction output.
    {
        let contraction = Contraction {
            tensor_indexes: vec![0, 2],
        };
        let expected = StepByStepContractionResult {
            input_shapes_left: vec![
                vec![("a", 2), ("c", 4)],
                vec![("a", 2), ("c", 4)],
            ],
            contraction_output: vec![("a", 2), ("c", 4)],
        };
        #[rustfmt::skip]
        assert_eq!(
            step_by_step_contraction(
                &[
                    &[("a", 2), ("b", 3), ("d", 5)],
                    &[("a", 2), ("c", 4)],
                    &[("b", 3), ("d", 5), ("c", 4)],
                ],
                &[("a", 2), ("c", 4)],
                &contraction,
            ),
            expected
        );
    }

    // Same three inputs and final output of ac, but contracting tensors 0
    // and 1 instead. The expected contraction output keeps all four
    // dimensions, in the order b, d, c, a.
    // NOTE(review): presumably nothing can be summed yet because a and c are
    // in the final output while b and d are still needed by the remaining
    // bdc input - confirm against step_by_step_contraction.
    {
        let contraction = Contraction {
            tensor_indexes: vec![0, 1],
        };
        let expected = StepByStepContractionResult {
            input_shapes_left: vec![
                vec![("b", 3), ("d", 5), ("c", 4)],
                vec![("b", 3), ("d", 5), ("c", 4), ("a", 2)],
            ],
            contraction_output: vec![("b", 3), ("d", 5), ("c", 4), ("a", 2)],
        };
        assert_eq!(
            step_by_step_contraction(
                &[
                    &[("a", 2), ("b", 3), ("d", 5)],
                    &[("a", 2), ("c", 4)],
                    &[("b", 3), ("d", 5), ("c", 4)],
                ],
                &[("a", 2), ("c", 4)],
                &contraction,
            ),
            expected
        );
    }

    // Same three inputs contracting tensors 0 and 1, but now the final
    // output is only c. The expected contraction output no longer contains
    // a, leaving b, d and c.
    {
        let contraction = Contraction {
            tensor_indexes: vec![0, 1],
        };
        let expected = StepByStepContractionResult {
            input_shapes_left: vec![
                vec![("b", 3), ("d", 5), ("c", 4)],
                vec![("b", 3), ("d", 5), ("c", 4)],
            ],
            contraction_output: vec![("b", 3), ("d", 5), ("c", 4)],
        };
        assert_eq!(
            step_by_step_contraction(
                &[
                    &[("a", 2), ("b", 3), ("d", 5)],
                    &[("a", 2), ("c", 4)],
                    &[("b", 3), ("d", 5), ("c", 4)],
                ],
                &[("c", 4)],
                &contraction,
            ),
            expected
        );
    }
}
1754
1755