1use hashbrown::HashSet;
41use ndarray::prelude::*;
42use ndarray::LinalgScalar;
43
44use super::{PairContractor, Permutation, SingletonContractor, SingletonViewer};
45use crate::SizedContraction;
46
/// Locates every output index inside `input_indices`.
///
/// Returns one entry per element of `output_indices`, in the same order: `Some(pos)` holds the
/// position of that character within `input_indices`, `None` means the character does not occur
/// there at all.
///
/// # Panics
///
/// Panics if a character of `output_indices` occurs more than once in `input_indices`. Repeated
/// characters in `input_indices` that are not requested by `output_indices` are not checked.
fn maybe_find_outputs_in_inputs_unique(
    output_indices: &[char],
    input_indices: &[char],
) -> Vec<Option<usize>> {
    output_indices
        .iter()
        .map(|&output_char| {
            let mut occurrences = input_indices
                .iter()
                .enumerate()
                .filter(|&(_, &input_char)| input_char == output_char)
                .map(|(pos, _)| pos);
            let first = occurrences.next();
            // A second hit means the index is repeated within a single operand, which the
            // callers of this helper cannot handle; say which index is at fault.
            assert!(
                occurrences.next().is_none(),
                "index {:?} appears more than once in input indices {:?}",
                output_char,
                input_indices
            );
            first
        })
        .collect()
}
70
71fn find_outputs_in_inputs_unique(output_indices: &[char], input_indices: &[char]) -> Vec<usize> {
72 maybe_find_outputs_in_inputs_unique(output_indices, input_indices)
73 .iter()
74 .map(|x| x.unwrap())
75 .collect()
76}
77
78#[derive(Clone, Debug)]
91pub struct TensordotFixedPosition {
92 len_uncontracted_lhs: usize,
95
96 len_uncontracted_rhs: usize,
99
100 len_contracted_axes: usize,
103
104 output_shape: Vec<usize>,
106}
107
108impl TensordotFixedPosition {
109 pub fn new(sc: &SizedContraction) -> Self {
110 assert_eq!(sc.contraction.operand_indices.len(), 2);
111 let lhs_indices = &sc.contraction.operand_indices[0];
112 let rhs_indices = &sc.contraction.operand_indices[1];
113 let output_indices = &sc.contraction.output_indices;
114 let twice_num_contracted_axes =
116 lhs_indices.len() + rhs_indices.len() - output_indices.len();
117 assert_eq!(twice_num_contracted_axes % 2, 0);
118 let num_contracted_axes = twice_num_contracted_axes / 2;
119 let lhs_shape: Vec<usize> = lhs_indices.iter().map(|c| sc.output_size[c]).collect();
122 let rhs_shape: Vec<usize> = rhs_indices.iter().map(|c| sc.output_size[c]).collect();
123
124 TensordotFixedPosition::from_shapes_and_number_of_contracted_axes(
125 &lhs_shape,
126 &rhs_shape,
127 num_contracted_axes,
128 )
129 }
130
131 pub fn from_shapes_and_number_of_contracted_axes(
138 lhs_shape: &[usize],
139 rhs_shape: &[usize],
140 num_contracted_axes: usize,
141 ) -> Self {
142 let mut len_uncontracted_lhs = 1;
143 let mut len_uncontracted_rhs = 1;
144 let mut len_contracted_lhs = 1;
145 let mut len_contracted_rhs = 1;
146 let mut output_shape = Vec::new();
147
148 let num_axes_lhs = lhs_shape.len();
149 for (axis, &axis_length) in lhs_shape.iter().enumerate() {
150 if axis < (num_axes_lhs - num_contracted_axes) {
151 len_uncontracted_lhs *= axis_length;
152 output_shape.push(axis_length);
153 } else {
154 len_contracted_lhs *= axis_length;
155 }
156 }
157
158 for (axis, &axis_length) in rhs_shape.iter().enumerate() {
159 if axis < num_contracted_axes {
160 len_contracted_rhs *= axis_length;
161 } else {
162 len_uncontracted_rhs *= axis_length;
163 output_shape.push(axis_length);
164 }
165 }
166 assert_eq!(len_contracted_rhs, len_contracted_lhs);
167 let len_contracted_axes = len_contracted_lhs;
168
169 TensordotFixedPosition {
170 len_uncontracted_lhs,
171 len_uncontracted_rhs,
172 len_contracted_axes,
173 output_shape,
174 }
175 }
176}
177
178impl<A> PairContractor<A> for TensordotFixedPosition {
179 fn contract_pair<'a, 'b, 'c, 'd>(
180 &self,
181 lhs: &'b ArrayViewD<'a, A>,
182 rhs: &'d ArrayViewD<'c, A>,
183 ) -> ArrayD<A>
184 where
185 'a: 'b,
186 'c: 'd,
187 A: Clone + LinalgScalar,
188 {
189 let lhs_array;
190 let lhs_view = if lhs.is_standard_layout() {
191 lhs.view()
192 .into_shape_with_order((self.len_uncontracted_lhs, self.len_contracted_axes))
193 .unwrap()
194 } else {
195 lhs_array = Array::from_shape_vec(
196 [self.len_uncontracted_lhs, self.len_contracted_axes],
197 lhs.iter().cloned().collect(),
198 )
199 .unwrap();
200 lhs_array.view()
201 };
202
203 let rhs_array;
204 let rhs_view = if rhs.is_standard_layout() {
205 rhs.view()
206 .into_shape_with_order((self.len_contracted_axes, self.len_uncontracted_rhs))
207 .unwrap()
208 } else {
209 rhs_array = Array::from_shape_vec(
210 [self.len_contracted_axes, self.len_uncontracted_rhs],
211 rhs.iter().cloned().collect(),
212 )
213 .unwrap();
214 rhs_array.view()
215 };
216
217 lhs_view
218 .dot(&rhs_view)
219 .into_shape_with_order(IxDyn(&self.output_shape))
220 .unwrap()
221 }
222}
223
224#[derive(Clone, Debug)]
234pub struct TensordotGeneral {
235 lhs_permutation: Permutation,
236 rhs_permutation: Permutation,
237 tensordot_fixed_position: TensordotFixedPosition,
238 output_permutation: Permutation,
239}
240
241impl TensordotGeneral {
242 pub fn new(sc: &SizedContraction) -> Self {
243 assert_eq!(sc.contraction.operand_indices.len(), 2);
244 let lhs_indices = &sc.contraction.operand_indices[0];
245 let rhs_indices = &sc.contraction.operand_indices[1];
246 let contracted_indices = &sc.contraction.summation_indices;
247 let output_indices = &sc.contraction.output_indices;
248 let lhs_shape: Vec<usize> = lhs_indices.iter().map(|c| sc.output_size[c]).collect();
249 let rhs_shape: Vec<usize> = rhs_indices.iter().map(|c| sc.output_size[c]).collect();
250
251 TensordotGeneral::from_shapes_and_indices(
252 &lhs_shape,
253 &rhs_shape,
254 lhs_indices,
255 rhs_indices,
256 contracted_indices,
257 output_indices,
258 )
259 }
260
261 fn from_shapes_and_indices(
262 lhs_shape: &[usize],
263 rhs_shape: &[usize],
264 lhs_indices: &[char],
265 rhs_indices: &[char],
266 contracted_indices: &[char],
267 output_indices: &[char],
268 ) -> Self {
269 let lhs_contracted_axes = find_outputs_in_inputs_unique(contracted_indices, lhs_indices);
270 let rhs_contracted_axes = find_outputs_in_inputs_unique(contracted_indices, rhs_indices);
271 let mut uncontracted_chars: Vec<char> = lhs_indices
272 .iter()
273 .filter(|&&input_char| output_indices.contains(&input_char))
274 .cloned()
275 .collect();
276 let mut rhs_uncontracted_chars: Vec<char> = rhs_indices
277 .iter()
278 .filter(|&&input_char| output_indices.contains(&input_char))
279 .cloned()
280 .collect();
281 uncontracted_chars.append(&mut rhs_uncontracted_chars);
282 let output_order = find_outputs_in_inputs_unique(output_indices, &uncontracted_chars);
283
284 TensordotGeneral::from_shapes_and_axis_numbers(
285 lhs_shape,
286 rhs_shape,
287 &lhs_contracted_axes,
288 &rhs_contracted_axes,
289 &output_order,
290 )
291 }
292
293 pub fn from_shapes_and_axis_numbers(
299 lhs_shape: &[usize],
300 rhs_shape: &[usize],
301 lhs_axes: &[usize],
302 rhs_axes: &[usize],
303 output_order: &[usize],
304 ) -> Self {
305 let num_contracted_axes = lhs_axes.len();
306 assert!(num_contracted_axes == rhs_axes.len());
307 let lhs_uniques: HashSet<_> = lhs_axes.iter().cloned().collect();
308 let rhs_uniques: HashSet<_> = rhs_axes.iter().cloned().collect();
309 assert!(num_contracted_axes == lhs_uniques.len());
310 assert!(num_contracted_axes == rhs_uniques.len());
311 let mut adjusted_lhs_shape = Vec::new();
312 let mut adjusted_rhs_shape = Vec::new();
313
314 let mut permutation_lhs = Vec::new();
317 for (i, &axis_length) in lhs_shape.iter().enumerate() {
318 if !(lhs_uniques.contains(&i)) {
319 permutation_lhs.push(i);
320 adjusted_lhs_shape.push(axis_length);
321 }
322 }
323 for &axis in lhs_axes.iter() {
324 permutation_lhs.push(axis);
325 adjusted_lhs_shape.push(lhs_shape[axis]);
326 }
327
328 let mut permutation_rhs = Vec::new();
331 for &axis in rhs_axes.iter() {
332 permutation_rhs.push(axis);
333 adjusted_rhs_shape.push(rhs_shape[axis]);
334 }
335 for (i, &axis_length) in rhs_shape.iter().enumerate() {
336 if !(rhs_uniques.contains(&i)) {
337 permutation_rhs.push(i);
338 adjusted_rhs_shape.push(axis_length);
339 }
340 }
341
342 let lhs_permutation = Permutation::from_indices(&permutation_lhs);
343 let rhs_permutation = Permutation::from_indices(&permutation_rhs);
344 let tensordot_fixed_position =
345 TensordotFixedPosition::from_shapes_and_number_of_contracted_axes(
346 &adjusted_lhs_shape,
347 &adjusted_rhs_shape,
348 num_contracted_axes,
349 );
350
351 let output_permutation = Permutation::from_indices(output_order);
352
353 TensordotGeneral {
354 lhs_permutation,
355 rhs_permutation,
356 tensordot_fixed_position,
357 output_permutation,
358 }
359 }
360}
361
362impl<A> PairContractor<A> for TensordotGeneral {
363 fn contract_pair<'a, 'b, 'c, 'd>(
364 &self,
365 lhs: &'b ArrayViewD<'a, A>,
366 rhs: &'d ArrayViewD<'c, A>,
367 ) -> ArrayD<A>
368 where
369 'a: 'b,
370 'c: 'd,
371 A: Clone + LinalgScalar,
372 {
373 let permuted_lhs = self.lhs_permutation.view_singleton(lhs);
374 let permuted_rhs = self.rhs_permutation.view_singleton(rhs);
375 let tensordotted = self
376 .tensordot_fixed_position
377 .contract_pair(&permuted_lhs, &permuted_rhs);
378 self.output_permutation
379 .contract_singleton(&tensordotted.view())
380 }
381}
382
383#[derive(Clone, Debug)]
390pub struct HadamardProduct {}
391
392impl HadamardProduct {
393 pub fn new(sc: &SizedContraction) -> Self {
394 assert_eq!(sc.contraction.operand_indices.len(), 2);
395 let lhs_indices = &sc.contraction.operand_indices[0];
396 let rhs_indices = &sc.contraction.operand_indices[1];
397 let output_indices = &sc.contraction.output_indices;
398 assert_eq!(lhs_indices, rhs_indices);
399 assert_eq!(lhs_indices, output_indices);
400
401 HadamardProduct {}
402 }
403
404 fn from_nothing() -> Self {
405 HadamardProduct {}
406 }
407}
408
409impl<A> PairContractor<A> for HadamardProduct {
410 fn contract_pair<'a, 'b, 'c, 'd>(
411 &self,
412 lhs: &'b ArrayViewD<'a, A>,
413 rhs: &'d ArrayViewD<'c, A>,
414 ) -> ArrayD<A>
415 where
416 'a: 'b,
417 'c: 'd,
418 A: Clone + LinalgScalar,
419 {
420 lhs * rhs
421 }
422}
423
424#[derive(Clone, Debug)]
432pub struct HadamardProductGeneral {
433 lhs_permutation: Permutation,
434 rhs_permutation: Permutation,
435 hadamard_product: HadamardProduct,
436}
437
438impl HadamardProductGeneral {
439 pub fn new(sc: &SizedContraction) -> Self {
440 assert_eq!(sc.contraction.operand_indices.len(), 2);
441 let lhs_indices = &sc.contraction.operand_indices[0];
442 let rhs_indices = &sc.contraction.operand_indices[1];
443 let output_indices = &sc.contraction.output_indices;
444 assert_eq!(lhs_indices.len(), rhs_indices.len());
445 assert_eq!(lhs_indices.len(), output_indices.len());
446
447 let lhs_permutation =
448 Permutation::from_indices(&find_outputs_in_inputs_unique(output_indices, lhs_indices));
449 let rhs_permutation =
450 Permutation::from_indices(&find_outputs_in_inputs_unique(output_indices, rhs_indices));
451 let hadamard_product = HadamardProduct::from_nothing();
452
453 HadamardProductGeneral {
454 lhs_permutation,
455 rhs_permutation,
456 hadamard_product,
457 }
458 }
459}
460
461impl<A> PairContractor<A> for HadamardProductGeneral {
462 fn contract_pair<'a, 'b, 'c, 'd>(
463 &self,
464 lhs: &'b ArrayViewD<'a, A>,
465 rhs: &'d ArrayViewD<'c, A>,
466 ) -> ArrayD<A>
467 where
468 'a: 'b,
469 'c: 'd,
470 A: Clone + LinalgScalar,
471 {
472 self.hadamard_product.contract_pair(
473 &self.lhs_permutation.view_singleton(lhs),
474 &self.rhs_permutation.view_singleton(rhs),
475 )
476 }
477}
478
479#[derive(Clone, Debug)]
486pub struct ScalarMatrixProduct {}
487
488impl ScalarMatrixProduct {
489 pub fn new(sc: &SizedContraction) -> Self {
490 assert_eq!(sc.contraction.operand_indices.len(), 2);
491 let lhs_indices = &sc.contraction.operand_indices[0];
492 let rhs_indices = &sc.contraction.operand_indices[1];
493 let output_indices = &sc.contraction.output_indices;
494 assert_eq!(lhs_indices.len(), 0);
495 assert_eq!(output_indices, rhs_indices);
496
497 ScalarMatrixProduct {}
498 }
499
500 pub fn from_nothing() -> Self {
501 ScalarMatrixProduct {}
502 }
503}
504
505impl<A> PairContractor<A> for ScalarMatrixProduct {
506 fn contract_pair<'a, 'b, 'c, 'd>(
507 &self,
508 lhs: &'b ArrayViewD<'a, A>,
509 rhs: &'d ArrayViewD<'c, A>,
510 ) -> ArrayD<A>
511 where
512 'a: 'b,
513 'c: 'd,
514 A: Clone + LinalgScalar,
515 {
516 let lhs_0d: A = *lhs.first().unwrap();
517 rhs.mapv(|x| x * lhs_0d)
518 }
519}
520
521#[derive(Clone, Debug)]
529pub struct ScalarMatrixProductGeneral {
530 rhs_permutation: Permutation,
531 scalar_matrix_product: ScalarMatrixProduct,
532}
533
534impl ScalarMatrixProductGeneral {
535 pub fn new(sc: &SizedContraction) -> Self {
536 assert_eq!(sc.contraction.operand_indices.len(), 2);
537 let lhs_indices = &sc.contraction.operand_indices[0];
538 let rhs_indices = &sc.contraction.operand_indices[1];
539 let output_indices = &sc.contraction.output_indices;
540 assert_eq!(lhs_indices.len(), 0);
541 assert_eq!(rhs_indices.len(), output_indices.len());
542
543 ScalarMatrixProductGeneral::from_indices(rhs_indices, output_indices)
544 }
545
546 pub fn from_indices(input_indices: &[char], output_indices: &[char]) -> Self {
547 let rhs_permutation = Permutation::from_indices(&find_outputs_in_inputs_unique(
548 output_indices,
549 input_indices,
550 ));
551 let scalar_matrix_product = ScalarMatrixProduct::from_nothing();
552
553 ScalarMatrixProductGeneral {
554 rhs_permutation,
555 scalar_matrix_product,
556 }
557 }
558}
559
560impl<A> PairContractor<A> for ScalarMatrixProductGeneral {
561 fn contract_pair<'a, 'b, 'c, 'd>(
562 &self,
563 lhs: &'b ArrayViewD<'a, A>,
564 rhs: &'d ArrayViewD<'c, A>,
565 ) -> ArrayD<A>
566 where
567 'a: 'b,
568 'c: 'd,
569 A: Clone + LinalgScalar,
570 {
571 self.scalar_matrix_product
572 .contract_pair(lhs, &self.rhs_permutation.view_singleton(rhs))
573 }
574}
575
576#[derive(Clone, Debug)]
583pub struct MatrixScalarProduct {}
584
585impl MatrixScalarProduct {
586 pub fn new(sc: &SizedContraction) -> Self {
587 assert_eq!(sc.contraction.operand_indices.len(), 2);
588 let lhs_indices = &sc.contraction.operand_indices[0];
589 let rhs_indices = &sc.contraction.operand_indices[1];
590 let output_indices = &sc.contraction.output_indices;
591 assert_eq!(rhs_indices.len(), 0);
592 assert_eq!(output_indices, lhs_indices);
593
594 MatrixScalarProduct {}
595 }
596
597 pub fn from_nothing() -> Self {
598 MatrixScalarProduct {}
599 }
600}
601
602impl<A> PairContractor<A> for MatrixScalarProduct {
603 fn contract_pair<'a, 'b, 'c, 'd>(
604 &self,
605 lhs: &'b ArrayViewD<'a, A>,
606 rhs: &'d ArrayViewD<'c, A>,
607 ) -> ArrayD<A>
608 where
609 'a: 'b,
610 'c: 'd,
611 A: Clone + LinalgScalar,
612 {
613 let rhs_0d: A = *rhs.first().unwrap();
614 lhs.mapv(|x| x * rhs_0d)
615 }
616}
617
618#[derive(Clone, Debug)]
626pub struct MatrixScalarProductGeneral {
627 lhs_permutation: Permutation,
628 matrix_scalar_product: MatrixScalarProduct,
629}
630
631impl MatrixScalarProductGeneral {
632 pub fn new(sc: &SizedContraction) -> Self {
633 assert_eq!(sc.contraction.operand_indices.len(), 2);
634 let lhs_indices = &sc.contraction.operand_indices[0];
635 let rhs_indices = &sc.contraction.operand_indices[1];
636 let output_indices = &sc.contraction.output_indices;
637 assert_eq!(rhs_indices.len(), 0);
638 assert_eq!(lhs_indices.len(), output_indices.len());
639
640 MatrixScalarProductGeneral::from_indices(lhs_indices, output_indices)
641 }
642
643 pub fn from_indices(input_indices: &[char], output_indices: &[char]) -> Self {
644 let lhs_permutation = Permutation::from_indices(&find_outputs_in_inputs_unique(
645 output_indices,
646 input_indices,
647 ));
648 let matrix_scalar_product = MatrixScalarProduct::from_nothing();
649
650 MatrixScalarProductGeneral {
651 lhs_permutation,
652 matrix_scalar_product,
653 }
654 }
655}
656
657impl<A> PairContractor<A> for MatrixScalarProductGeneral {
658 fn contract_pair<'a, 'b, 'c, 'd>(
659 &self,
660 lhs: &'b ArrayViewD<'a, A>,
661 rhs: &'d ArrayViewD<'c, A>,
662 ) -> ArrayD<A>
663 where
664 'a: 'b,
665 'c: 'd,
666 A: Clone + LinalgScalar,
667 {
668 self.matrix_scalar_product
669 .contract_pair(&self.lhs_permutation.view_singleton(lhs), rhs)
670 }
671}
672
673#[derive(Clone, Debug)]
684pub struct BroadcastProductGeneral {
685 lhs_permutation: Permutation,
686 rhs_permutation: Permutation,
687 lhs_insertions: Vec<usize>,
688 rhs_insertions: Vec<usize>,
689 output_sizes: Vec<usize>,
690 hadamard_product: HadamardProduct,
691}
692
693impl BroadcastProductGeneral {
694 pub fn new(sc: &SizedContraction) -> Self {
695 assert_eq!(sc.contraction.operand_indices.len(), 2);
696 let lhs_indices = &sc.contraction.operand_indices[0];
697 let rhs_indices = &sc.contraction.operand_indices[1];
698 let output_indices = &sc.contraction.output_indices;
699
700 let maybe_lhs_indices = maybe_find_outputs_in_inputs_unique(output_indices, lhs_indices);
701 let maybe_rhs_indices = maybe_find_outputs_in_inputs_unique(output_indices, rhs_indices);
702 let lhs_indices: Vec<usize> = maybe_lhs_indices.iter().copied().flatten().collect();
703 let rhs_indices: Vec<usize> = maybe_rhs_indices.iter().copied().flatten().collect();
704 let lhs_insertions: Vec<usize> = maybe_lhs_indices
705 .into_iter()
706 .enumerate()
707 .filter(|(_, x)| x.is_none())
708 .map(|(i, _)| i)
709 .collect();
710 let rhs_insertions: Vec<usize> = maybe_rhs_indices
711 .into_iter()
712 .enumerate()
713 .filter(|(_, x)| x.is_none())
714 .map(|(i, _)| i)
715 .collect();
716 let lhs_permutation = Permutation::from_indices(&lhs_indices);
717 let rhs_permutation = Permutation::from_indices(&rhs_indices);
718 let output_sizes: Vec<usize> = output_indices.iter().map(|c| sc.output_size[c]).collect();
719 let hadamard_product = HadamardProduct::from_nothing();
720
721 BroadcastProductGeneral {
722 lhs_permutation,
723 rhs_permutation,
724 lhs_insertions,
725 rhs_insertions,
726 output_sizes,
727 hadamard_product,
728 }
729 }
730}
731
732impl<A> PairContractor<A> for BroadcastProductGeneral {
733 fn contract_pair<'a, 'b, 'c, 'd>(
734 &self,
735 lhs: &'b ArrayViewD<'a, A>,
736 rhs: &'d ArrayViewD<'c, A>,
737 ) -> ArrayD<A>
738 where
739 'a: 'b,
740 'c: 'd,
741 A: Clone + LinalgScalar,
742 {
743 let mut adjusted_lhs = self.lhs_permutation.view_singleton(lhs);
744 let mut adjusted_rhs = self.rhs_permutation.view_singleton(rhs);
745 for &i in self.lhs_insertions.iter() {
746 adjusted_lhs = adjusted_lhs.insert_axis(Axis(i));
747 }
748 for &i in self.rhs_insertions.iter() {
749 adjusted_rhs = adjusted_rhs.insert_axis(Axis(i));
750 }
751 let output_shape = IxDyn(&self.output_sizes);
752 let broadcast_lhs = adjusted_lhs.broadcast(output_shape.clone()).unwrap();
753 let broadcast_rhs = adjusted_rhs.broadcast(output_shape).unwrap();
754 self.hadamard_product
755 .contract_pair(&broadcast_lhs, &broadcast_rhs)
756 }
757}
758
759#[derive(Clone, Debug)]
781pub struct StackedTensordotGeneral {
782 lhs_permutation: Permutation,
783 rhs_permutation: Permutation,
784 lhs_output_shape: Vec<usize>,
785 rhs_output_shape: Vec<usize>,
786 intermediate_shape: Vec<usize>,
787 tensordot_fixed_position: TensordotFixedPosition,
788 output_shape: Vec<usize>,
789 output_permutation: Permutation,
790}
791
792impl StackedTensordotGeneral {
793 pub fn new(sc: &SizedContraction) -> Self {
794 let mut lhs_order = Vec::new();
795 let mut rhs_order = Vec::new();
796 let mut lhs_output_shape = Vec::new();
797 let mut rhs_output_shape = Vec::new();
798 let mut intermediate_shape = Vec::new();
799
800 assert_eq!(sc.contraction.operand_indices.len(), 2);
801 let lhs_indices = &sc.contraction.operand_indices[0];
802 let rhs_indices = &sc.contraction.operand_indices[1];
803 let output_indices = &sc.contraction.output_indices;
804
805 let maybe_lhs_axes = maybe_find_outputs_in_inputs_unique(output_indices, lhs_indices);
806 let maybe_rhs_axes = maybe_find_outputs_in_inputs_unique(output_indices, rhs_indices);
807 let mut lhs_stack_axes = Vec::new();
808 let mut rhs_stack_axes = Vec::new();
809 let mut stack_indices = Vec::new();
810 let mut lhs_outer_axes = Vec::new();
811 let mut lhs_outer_indices = Vec::new();
812 let mut rhs_outer_axes = Vec::new();
813 let mut rhs_outer_indices = Vec::new();
814 let mut lhs_contracted_axes = Vec::new();
815 let mut rhs_contracted_axes = Vec::new();
816 let mut intermediate_indices = Vec::new();
817 let mut lhs_uncontracted_shape = Vec::new();
818 let mut rhs_uncontracted_shape = Vec::new();
819 let mut contracted_shape = Vec::new();
820
821 lhs_output_shape.push(1);
822 rhs_output_shape.push(1);
823
824 for ((&maybe_lhs_pos, &maybe_rhs_pos), &output_char) in maybe_lhs_axes
825 .iter()
826 .zip(maybe_rhs_axes.iter())
827 .zip(output_indices.iter())
828 {
829 match (maybe_lhs_pos, maybe_rhs_pos) {
830 (Some(lhs_pos), Some(rhs_pos)) => {
831 lhs_stack_axes.push(lhs_pos);
832 rhs_stack_axes.push(rhs_pos);
833 stack_indices.push(output_char);
834 lhs_output_shape[0] *= sc.output_size[&output_char];
835 rhs_output_shape[0] *= sc.output_size[&output_char];
836 }
837 (Some(lhs_pos), None) => {
838 lhs_outer_axes.push(lhs_pos);
839 lhs_outer_indices.push(output_char);
840 lhs_uncontracted_shape.push(sc.output_size[&output_char]);
841 }
842 (None, Some(rhs_pos)) => {
843 rhs_outer_axes.push(rhs_pos);
844 rhs_outer_indices.push(output_char);
845 rhs_uncontracted_shape.push(sc.output_size[&output_char]);
846 }
847 (None, None) => {
848 panic!() }
850 }
851 }
852
853 for (lhs_pos, &lhs_char) in lhs_indices.iter().enumerate() {
854 if !output_indices.contains(&lhs_char) {
855 lhs_contracted_axes.push(lhs_pos);
857 rhs_contracted_axes.push(
859 rhs_indices
860 .iter()
861 .position(|&rhs_char| rhs_char == lhs_char)
862 .unwrap(),
863 );
864 contracted_shape.push(sc.output_size[&lhs_char]);
865 }
866 }
867 lhs_order.append(&mut lhs_stack_axes.clone());
873 lhs_order.append(&mut lhs_outer_axes.clone());
874 lhs_output_shape.append(&mut lhs_uncontracted_shape);
875 lhs_order.append(&mut lhs_contracted_axes.clone());
876 lhs_output_shape.append(&mut contracted_shape.clone());
877
878 rhs_order.append(&mut rhs_stack_axes.clone());
879 rhs_order.append(&mut rhs_contracted_axes.clone());
880 rhs_output_shape.append(&mut contracted_shape);
881 rhs_order.append(&mut rhs_outer_axes.clone());
882 rhs_output_shape.append(&mut rhs_uncontracted_shape);
883
884 intermediate_indices.append(&mut stack_indices.clone());
887 intermediate_indices.append(&mut lhs_outer_indices.clone());
888 intermediate_indices.append(&mut rhs_outer_indices.clone());
889
890 assert_eq!(lhs_output_shape[0], rhs_output_shape[0]);
891 intermediate_shape.push(lhs_output_shape[0]);
892 for lhs_char in lhs_outer_indices.iter() {
893 intermediate_shape.push(sc.output_size[lhs_char]);
894 }
895 for rhs_char in rhs_outer_indices.iter() {
896 intermediate_shape.push(sc.output_size[rhs_char]);
897 }
898
899 let output_order = find_outputs_in_inputs_unique(output_indices, &intermediate_indices);
900 let output_shape = intermediate_indices
901 .iter()
902 .map(|c| sc.output_size[c])
903 .collect();
904
905 let tensordot_fixed_position =
906 TensordotFixedPosition::from_shapes_and_number_of_contracted_axes(
907 &lhs_output_shape[1..],
908 &rhs_output_shape[1..],
909 lhs_contracted_axes.len(),
910 );
911 let lhs_permutation = Permutation::from_indices(&lhs_order);
912 let rhs_permutation = Permutation::from_indices(&rhs_order);
913 let output_permutation = Permutation::from_indices(&output_order);
914 StackedTensordotGeneral {
915 lhs_permutation,
916 rhs_permutation,
917 lhs_output_shape,
918 rhs_output_shape,
919 intermediate_shape,
920 tensordot_fixed_position,
921 output_shape,
922 output_permutation,
923 }
924 }
925}
926
927impl<A> PairContractor<A> for StackedTensordotGeneral {
928 fn contract_pair<'a, 'b, 'c, 'd>(
929 &self,
930 lhs: &'b ArrayViewD<'a, A>,
931 rhs: &'d ArrayViewD<'c, A>,
932 ) -> ArrayD<A>
933 where
934 'a: 'b,
935 'c: 'd,
936 A: Clone + LinalgScalar,
937 {
938 let lhs_permuted = self.lhs_permutation.view_singleton(lhs);
939 let lhs_reshaped = Array::from_shape_vec(
940 IxDyn(&self.lhs_output_shape),
941 lhs_permuted.iter().cloned().collect(),
942 )
943 .unwrap();
944 let rhs_permuted = self.rhs_permutation.view_singleton(rhs);
945 let rhs_reshaped = Array::from_shape_vec(
946 IxDyn(&self.rhs_output_shape),
947 rhs_permuted.iter().cloned().collect(),
948 )
949 .unwrap();
950 let mut intermediate_result: ArrayD<A> = Array::zeros(IxDyn(&self.intermediate_shape));
951 let mut lhs_iter = lhs_reshaped.outer_iter();
952 let mut rhs_iter = rhs_reshaped.outer_iter();
953 for mut output_subview in intermediate_result.outer_iter_mut() {
954 let lhs_subview = lhs_iter.next().unwrap();
955 let rhs_subview = rhs_iter.next().unwrap();
956 self.tensordot_fixed_position.contract_and_assign_pair(
957 &lhs_subview,
958 &rhs_subview,
959 &mut output_subview,
960 );
961 }
962 let intermediate_reshaped = intermediate_result
963 .into_shape_with_order(IxDyn(&self.output_shape))
964 .unwrap();
965 self.output_permutation
966 .contract_singleton(&intermediate_reshaped.view())
967 }
968}