1use hashbrown::HashSet;
41use ndarray::prelude::*;
42use ndarray::LinalgScalar;
43
44use super::{PairContractor, Permutation, SingletonContractor, SingletonViewer};
45use crate::SizedContraction;
46
/// For each character of `output_indices`, returns the position at which it
/// occurs in `input_indices`, or `None` if it does not occur there.
///
/// # Panics
///
/// Panics if an output character that *is* found in `input_indices` occurs
/// there more than once, because its position would then be ambiguous.
/// Duplicates of characters that are not requested are not checked.
fn maybe_find_outputs_in_inputs_unique(
    output_indices: &[char],
    input_indices: &[char],
) -> Vec<Option<usize>> {
    output_indices
        .iter()
        .map(|&output_char| {
            let input_pos = input_indices
                .iter()
                .position(|&input_char| input_char == output_char);
            if let Some(pos) = input_pos {
                // `position` found the first occurrence; make sure there is no second one.
                assert!(
                    !input_indices[pos + 1..].contains(&output_char),
                    "index {:?} appears more than once in input indices {:?}",
                    output_char,
                    input_indices
                );
            }
            input_pos
        })
        .collect()
}
70
71fn find_outputs_in_inputs_unique(output_indices: &[char], input_indices: &[char]) -> Vec<usize> {
72 maybe_find_outputs_in_inputs_unique(output_indices, input_indices)
73 .iter()
74 .map(|x| x.unwrap())
75 .collect()
76}
77
/// Contracts the trailing `N` axes of the left operand against the leading `N`
/// axes of the right operand, for an `N` fixed at construction time.
///
/// Both operands are viewed (or, if not in standard layout, copied) as 2-D
/// matrices, multiplied with `dot`, and the product is reshaped to
/// `output_shape`.
#[derive(Clone, Debug)]
pub struct TensordotFixedPosition {
    /// Product of the lengths of the lhs axes that are not contracted
    /// (the leading axes of the lhs).
    len_uncontracted_lhs: usize,

    /// Product of the lengths of the rhs axes that are not contracted
    /// (the trailing axes of the rhs).
    len_uncontracted_rhs: usize,

    /// Product of the lengths of the contracted axes; construction checks that
    /// this product agrees between lhs and rhs.
    len_contracted_axes: usize,

    /// Shape of the result: the uncontracted lhs axis lengths followed by the
    /// uncontracted rhs axis lengths.
    output_shape: Vec<usize>,
}
107
108impl TensordotFixedPosition {
109 pub fn new(sc: &SizedContraction) -> Self {
110 assert_eq!(sc.contraction.operand_indices.len(), 2);
111 let lhs_indices = &sc.contraction.operand_indices[0];
112 let rhs_indices = &sc.contraction.operand_indices[1];
113 let output_indices = &sc.contraction.output_indices;
114 let twice_num_contracted_axes =
116 lhs_indices.len() + rhs_indices.len() - output_indices.len();
117 assert_eq!(twice_num_contracted_axes % 2, 0);
118 let num_contracted_axes = twice_num_contracted_axes / 2;
119 let lhs_shape: Vec<usize> = lhs_indices.iter().map(|c| sc.output_size[c]).collect();
122 let rhs_shape: Vec<usize> = rhs_indices.iter().map(|c| sc.output_size[c]).collect();
123
124 TensordotFixedPosition::from_shapes_and_number_of_contracted_axes(
125 &lhs_shape,
126 &rhs_shape,
127 num_contracted_axes,
128 )
129 }
130
131 pub fn from_shapes_and_number_of_contracted_axes(
138 lhs_shape: &[usize],
139 rhs_shape: &[usize],
140 num_contracted_axes: usize,
141 ) -> Self {
142 let mut len_uncontracted_lhs = 1;
143 let mut len_uncontracted_rhs = 1;
144 let mut len_contracted_lhs = 1;
145 let mut len_contracted_rhs = 1;
146 let mut output_shape = Vec::new();
147
148 let num_axes_lhs = lhs_shape.len();
149 for (axis, &axis_length) in lhs_shape.iter().enumerate() {
150 if axis < (num_axes_lhs - num_contracted_axes) {
151 len_uncontracted_lhs *= axis_length;
152 output_shape.push(axis_length);
153 } else {
154 len_contracted_lhs *= axis_length;
155 }
156 }
157
158 for (axis, &axis_length) in rhs_shape.iter().enumerate() {
159 if axis < num_contracted_axes {
160 len_contracted_rhs *= axis_length;
161 } else {
162 len_uncontracted_rhs *= axis_length;
163 output_shape.push(axis_length);
164 }
165 }
166 assert_eq!(len_contracted_rhs, len_contracted_lhs);
167 let len_contracted_axes = len_contracted_lhs;
168
169 TensordotFixedPosition {
170 len_uncontracted_lhs,
171 len_uncontracted_rhs,
172 len_contracted_axes,
173 output_shape,
174 }
175 }
176}
177
178impl<A> PairContractor<A> for TensordotFixedPosition {
179 fn contract_pair<'a, 'b, 'c, 'd>(
180 &self,
181 lhs: &'b ArrayViewD<'a, A>,
182 rhs: &'d ArrayViewD<'c, A>,
183 ) -> ArrayD<A>
184 where
185 'a: 'b,
186 'c: 'd,
187 A: Clone + LinalgScalar,
188 {
189 let lhs_array;
190 let lhs_view = if lhs.is_standard_layout() {
191 lhs.view()
192 .into_shape_with_order((self.len_uncontracted_lhs, self.len_contracted_axes))
193 .unwrap()
194 } else {
195 lhs_array = Array::from_shape_vec(
196 [self.len_uncontracted_lhs, self.len_contracted_axes],
197 lhs.iter().cloned().collect(),
198 )
199 .unwrap();
200 lhs_array.view()
201 };
202
203 let rhs_array;
204 let rhs_view = if rhs.is_standard_layout() {
205 rhs.view()
206 .into_shape_with_order((self.len_contracted_axes, self.len_uncontracted_rhs))
207 .unwrap()
208 } else {
209 rhs_array = Array::from_shape_vec(
210 [self.len_contracted_axes, self.len_uncontracted_rhs],
211 rhs.iter().cloned().collect(),
212 )
213 .unwrap();
214 rhs_array.view()
215 };
216
217 lhs_view
218 .dot(&rhs_view)
219 .into_shape_with_order(IxDyn(&self.output_shape))
220 .unwrap()
221 }
222}
223
/// Tensordot whose contracted axes may sit at any position in either operand.
///
/// The lhs is permuted so that its contracted axes come last, and the rhs so
/// that its contracted axes come first (in matching order). The pair is then
/// contracted by a `TensordotFixedPosition`, and the result is permuted into
/// output order.
#[derive(Clone, Debug)]
pub struct TensordotGeneral {
    /// Puts the lhs uncontracted axes first and its contracted axes last.
    lhs_permutation: Permutation,
    /// Puts the rhs contracted axes first and its uncontracted axes last.
    rhs_permutation: Permutation,
    /// Contracts the two permuted operands.
    tensordot_fixed_position: TensordotFixedPosition,
    /// Reorders the intermediate axes (lhs uncontracted, then rhs uncontracted)
    /// into output order.
    output_permutation: Permutation,
}
240
241impl TensordotGeneral {
242 pub fn new(sc: &SizedContraction) -> Self {
243 assert_eq!(sc.contraction.operand_indices.len(), 2);
244 let lhs_indices = &sc.contraction.operand_indices[0];
245 let rhs_indices = &sc.contraction.operand_indices[1];
246 let contracted_indices = &sc.contraction.summation_indices;
247 let output_indices = &sc.contraction.output_indices;
248 let lhs_shape: Vec<usize> = lhs_indices.iter().map(|c| sc.output_size[c]).collect();
249 let rhs_shape: Vec<usize> = rhs_indices.iter().map(|c| sc.output_size[c]).collect();
250
251 TensordotGeneral::from_shapes_and_indices(
252 &lhs_shape,
253 &rhs_shape,
254 lhs_indices,
255 rhs_indices,
256 contracted_indices,
257 output_indices,
258 )
259 }
260
261 fn from_shapes_and_indices(
262 lhs_shape: &[usize],
263 rhs_shape: &[usize],
264 lhs_indices: &[char],
265 rhs_indices: &[char],
266 contracted_indices: &[char],
267 output_indices: &[char],
268 ) -> Self {
269 let lhs_contracted_axes = find_outputs_in_inputs_unique(contracted_indices, lhs_indices);
270 let rhs_contracted_axes = find_outputs_in_inputs_unique(contracted_indices, rhs_indices);
271 let mut uncontracted_chars: Vec<char> = lhs_indices
272 .iter()
273 .filter(|&&input_char| {
274 output_indices
275 .iter()
276 .any(|&output_char| input_char == output_char)
277 })
278 .cloned()
279 .collect();
280 let mut rhs_uncontracted_chars: Vec<char> = rhs_indices
281 .iter()
282 .filter(|&&input_char| {
283 output_indices
284 .iter()
285 .any(|&output_char| input_char == output_char)
286 })
287 .cloned()
288 .collect();
289 uncontracted_chars.append(&mut rhs_uncontracted_chars);
290 let output_order = find_outputs_in_inputs_unique(output_indices, &uncontracted_chars);
291
292 TensordotGeneral::from_shapes_and_axis_numbers(
293 lhs_shape,
294 rhs_shape,
295 &lhs_contracted_axes,
296 &rhs_contracted_axes,
297 &output_order,
298 )
299 }
300
301 pub fn from_shapes_and_axis_numbers(
307 lhs_shape: &[usize],
308 rhs_shape: &[usize],
309 lhs_axes: &[usize],
310 rhs_axes: &[usize],
311 output_order: &[usize],
312 ) -> Self {
313 let num_contracted_axes = lhs_axes.len();
314 assert!(num_contracted_axes == rhs_axes.len());
315 let lhs_uniques: HashSet<_> = lhs_axes.iter().cloned().collect();
316 let rhs_uniques: HashSet<_> = rhs_axes.iter().cloned().collect();
317 assert!(num_contracted_axes == lhs_uniques.len());
318 assert!(num_contracted_axes == rhs_uniques.len());
319 let mut adjusted_lhs_shape = Vec::new();
320 let mut adjusted_rhs_shape = Vec::new();
321
322 let mut permutation_lhs = Vec::new();
325 for (i, &axis_length) in lhs_shape.iter().enumerate() {
326 if !(lhs_uniques.contains(&i)) {
327 permutation_lhs.push(i);
328 adjusted_lhs_shape.push(axis_length);
329 }
330 }
331 for &axis in lhs_axes.iter() {
332 permutation_lhs.push(axis);
333 adjusted_lhs_shape.push(lhs_shape[axis]);
334 }
335
336 let mut permutation_rhs = Vec::new();
339 for &axis in rhs_axes.iter() {
340 permutation_rhs.push(axis);
341 adjusted_rhs_shape.push(rhs_shape[axis]);
342 }
343 for (i, &axis_length) in rhs_shape.iter().enumerate() {
344 if !(rhs_uniques.contains(&i)) {
345 permutation_rhs.push(i);
346 adjusted_rhs_shape.push(axis_length);
347 }
348 }
349
350 let lhs_permutation = Permutation::from_indices(&permutation_lhs);
351 let rhs_permutation = Permutation::from_indices(&permutation_rhs);
352 let tensordot_fixed_position =
353 TensordotFixedPosition::from_shapes_and_number_of_contracted_axes(
354 &adjusted_lhs_shape,
355 &adjusted_rhs_shape,
356 num_contracted_axes,
357 );
358
359 let output_permutation = Permutation::from_indices(output_order);
360
361 TensordotGeneral {
362 lhs_permutation,
363 rhs_permutation,
364 tensordot_fixed_position,
365 output_permutation,
366 }
367 }
368}
369
370impl<A> PairContractor<A> for TensordotGeneral {
371 fn contract_pair<'a, 'b, 'c, 'd>(
372 &self,
373 lhs: &'b ArrayViewD<'a, A>,
374 rhs: &'d ArrayViewD<'c, A>,
375 ) -> ArrayD<A>
376 where
377 'a: 'b,
378 'c: 'd,
379 A: Clone + LinalgScalar,
380 {
381 let permuted_lhs = self.lhs_permutation.view_singleton(lhs);
382 let permuted_rhs = self.rhs_permutation.view_singleton(rhs);
383 let tensordotted = self
384 .tensordot_fixed_position
385 .contract_pair(&permuted_lhs, &permuted_rhs);
386 self.output_permutation
387 .contract_singleton(&tensordotted.view())
388 }
389}
390
/// Elementwise (Hadamard) product of two operands.
///
/// `new` requires both operand index strings to equal each other and the
/// output index string, so no axis reordering is needed.
#[derive(Clone, Debug)]
pub struct HadamardProduct {}
399
400impl HadamardProduct {
401 pub fn new(sc: &SizedContraction) -> Self {
402 assert_eq!(sc.contraction.operand_indices.len(), 2);
403 let lhs_indices = &sc.contraction.operand_indices[0];
404 let rhs_indices = &sc.contraction.operand_indices[1];
405 let output_indices = &sc.contraction.output_indices;
406 assert_eq!(lhs_indices, rhs_indices);
407 assert_eq!(lhs_indices, output_indices);
408
409 HadamardProduct {}
410 }
411
412 fn from_nothing() -> Self {
413 HadamardProduct {}
414 }
415}
416
417impl<A> PairContractor<A> for HadamardProduct {
418 fn contract_pair<'a, 'b, 'c, 'd>(
419 &self,
420 lhs: &'b ArrayViewD<'a, A>,
421 rhs: &'d ArrayViewD<'c, A>,
422 ) -> ArrayD<A>
423 where
424 'a: 'b,
425 'c: 'd,
426 A: Clone + LinalgScalar,
427 {
428 lhs * rhs
429 }
430}
431
/// Elementwise product of two operands whose indices match the output indices,
/// possibly in a different order on each operand.
///
/// Each operand is permuted into output order before the Hadamard product.
#[derive(Clone, Debug)]
pub struct HadamardProductGeneral {
    /// Reorders the lhs axes into output index order.
    lhs_permutation: Permutation,
    /// Reorders the rhs axes into output index order.
    rhs_permutation: Permutation,
    hadamard_product: HadamardProduct,
}
445
446impl HadamardProductGeneral {
447 pub fn new(sc: &SizedContraction) -> Self {
448 assert_eq!(sc.contraction.operand_indices.len(), 2);
449 let lhs_indices = &sc.contraction.operand_indices[0];
450 let rhs_indices = &sc.contraction.operand_indices[1];
451 let output_indices = &sc.contraction.output_indices;
452 assert_eq!(lhs_indices.len(), rhs_indices.len());
453 assert_eq!(lhs_indices.len(), output_indices.len());
454
455 let lhs_permutation =
456 Permutation::from_indices(&find_outputs_in_inputs_unique(output_indices, lhs_indices));
457 let rhs_permutation =
458 Permutation::from_indices(&find_outputs_in_inputs_unique(output_indices, rhs_indices));
459 let hadamard_product = HadamardProduct::from_nothing();
460
461 HadamardProductGeneral {
462 lhs_permutation,
463 rhs_permutation,
464 hadamard_product,
465 }
466 }
467}
468
469impl<A> PairContractor<A> for HadamardProductGeneral {
470 fn contract_pair<'a, 'b, 'c, 'd>(
471 &self,
472 lhs: &'b ArrayViewD<'a, A>,
473 rhs: &'d ArrayViewD<'c, A>,
474 ) -> ArrayD<A>
475 where
476 'a: 'b,
477 'c: 'd,
478 A: Clone + LinalgScalar,
479 {
480 self.hadamard_product.contract_pair(
481 &self.lhs_permutation.view_singleton(lhs),
482 &self.rhs_permutation.view_singleton(rhs),
483 )
484 }
485}
486
/// Multiplies every element of the rhs by the single value of the lhs.
///
/// `new` requires the lhs to have no indices (0-dimensional) and the rhs
/// indices to equal the output indices.
#[derive(Clone, Debug)]
pub struct ScalarMatrixProduct {}
495
496impl ScalarMatrixProduct {
497 pub fn new(sc: &SizedContraction) -> Self {
498 assert_eq!(sc.contraction.operand_indices.len(), 2);
499 let lhs_indices = &sc.contraction.operand_indices[0];
500 let rhs_indices = &sc.contraction.operand_indices[1];
501 let output_indices = &sc.contraction.output_indices;
502 assert_eq!(lhs_indices.len(), 0);
503 assert_eq!(output_indices, rhs_indices);
504
505 ScalarMatrixProduct {}
506 }
507
508 pub fn from_nothing() -> Self {
509 ScalarMatrixProduct {}
510 }
511}
512
513impl<A> PairContractor<A> for ScalarMatrixProduct {
514 fn contract_pair<'a, 'b, 'c, 'd>(
515 &self,
516 lhs: &'b ArrayViewD<'a, A>,
517 rhs: &'d ArrayViewD<'c, A>,
518 ) -> ArrayD<A>
519 where
520 'a: 'b,
521 'c: 'd,
522 A: Clone + LinalgScalar,
523 {
524 let lhs_0d: A = *lhs.first().unwrap();
525 rhs.mapv(|x| x * lhs_0d)
526 }
527}
528
/// Product of a 0-dimensional lhs with an rhs whose indices are a reordering
/// of the output indices.
///
/// The rhs is viewed in output order and then scaled by the lhs value.
#[derive(Clone, Debug)]
pub struct ScalarMatrixProductGeneral {
    /// Reorders the rhs axes into output index order.
    rhs_permutation: Permutation,
    scalar_matrix_product: ScalarMatrixProduct,
}
541
542impl ScalarMatrixProductGeneral {
543 pub fn new(sc: &SizedContraction) -> Self {
544 assert_eq!(sc.contraction.operand_indices.len(), 2);
545 let lhs_indices = &sc.contraction.operand_indices[0];
546 let rhs_indices = &sc.contraction.operand_indices[1];
547 let output_indices = &sc.contraction.output_indices;
548 assert_eq!(lhs_indices.len(), 0);
549 assert_eq!(rhs_indices.len(), output_indices.len());
550
551 ScalarMatrixProductGeneral::from_indices(rhs_indices, output_indices)
552 }
553
554 pub fn from_indices(input_indices: &[char], output_indices: &[char]) -> Self {
555 let rhs_permutation = Permutation::from_indices(&find_outputs_in_inputs_unique(
556 output_indices,
557 input_indices,
558 ));
559 let scalar_matrix_product = ScalarMatrixProduct::from_nothing();
560
561 ScalarMatrixProductGeneral {
562 rhs_permutation,
563 scalar_matrix_product,
564 }
565 }
566}
567
568impl<A> PairContractor<A> for ScalarMatrixProductGeneral {
569 fn contract_pair<'a, 'b, 'c, 'd>(
570 &self,
571 lhs: &'b ArrayViewD<'a, A>,
572 rhs: &'d ArrayViewD<'c, A>,
573 ) -> ArrayD<A>
574 where
575 'a: 'b,
576 'c: 'd,
577 A: Clone + LinalgScalar,
578 {
579 self.scalar_matrix_product
580 .contract_pair(lhs, &self.rhs_permutation.view_singleton(rhs))
581 }
582}
583
/// Multiplies every element of the lhs by the single value of the rhs.
///
/// `new` requires the rhs to have no indices (0-dimensional) and the lhs
/// indices to equal the output indices.
#[derive(Clone, Debug)]
pub struct MatrixScalarProduct {}
592
593impl MatrixScalarProduct {
594 pub fn new(sc: &SizedContraction) -> Self {
595 assert_eq!(sc.contraction.operand_indices.len(), 2);
596 let lhs_indices = &sc.contraction.operand_indices[0];
597 let rhs_indices = &sc.contraction.operand_indices[1];
598 let output_indices = &sc.contraction.output_indices;
599 assert_eq!(rhs_indices.len(), 0);
600 assert_eq!(output_indices, lhs_indices);
601
602 MatrixScalarProduct {}
603 }
604
605 pub fn from_nothing() -> Self {
606 MatrixScalarProduct {}
607 }
608}
609
610impl<A> PairContractor<A> for MatrixScalarProduct {
611 fn contract_pair<'a, 'b, 'c, 'd>(
612 &self,
613 lhs: &'b ArrayViewD<'a, A>,
614 rhs: &'d ArrayViewD<'c, A>,
615 ) -> ArrayD<A>
616 where
617 'a: 'b,
618 'c: 'd,
619 A: Clone + LinalgScalar,
620 {
621 let rhs_0d: A = *rhs.first().unwrap();
622 lhs.mapv(|x| x * rhs_0d)
623 }
624}
625
/// Product of an lhs whose indices are a reordering of the output indices with
/// a 0-dimensional rhs.
///
/// The lhs is viewed in output order and then scaled by the rhs value.
#[derive(Clone, Debug)]
pub struct MatrixScalarProductGeneral {
    /// Reorders the lhs axes into output index order.
    lhs_permutation: Permutation,
    matrix_scalar_product: MatrixScalarProduct,
}
638
639impl MatrixScalarProductGeneral {
640 pub fn new(sc: &SizedContraction) -> Self {
641 assert_eq!(sc.contraction.operand_indices.len(), 2);
642 let lhs_indices = &sc.contraction.operand_indices[0];
643 let rhs_indices = &sc.contraction.operand_indices[1];
644 let output_indices = &sc.contraction.output_indices;
645 assert_eq!(rhs_indices.len(), 0);
646 assert_eq!(lhs_indices.len(), output_indices.len());
647
648 MatrixScalarProductGeneral::from_indices(lhs_indices, output_indices)
649 }
650
651 pub fn from_indices(input_indices: &[char], output_indices: &[char]) -> Self {
652 let lhs_permutation = Permutation::from_indices(&find_outputs_in_inputs_unique(
653 output_indices,
654 input_indices,
655 ));
656 let matrix_scalar_product = MatrixScalarProduct::from_nothing();
657
658 MatrixScalarProductGeneral {
659 lhs_permutation,
660 matrix_scalar_product,
661 }
662 }
663}
664
665impl<A> PairContractor<A> for MatrixScalarProductGeneral {
666 fn contract_pair<'a, 'b, 'c, 'd>(
667 &self,
668 lhs: &'b ArrayViewD<'a, A>,
669 rhs: &'d ArrayViewD<'c, A>,
670 ) -> ArrayD<A>
671 where
672 'a: 'b,
673 'c: 'd,
674 A: Clone + LinalgScalar,
675 {
676 self.matrix_scalar_product
677 .contract_pair(&self.lhs_permutation.view_singleton(lhs), rhs)
678 }
679}
680
/// Elementwise product of two operands, either of which may lack some of the
/// output indices.
///
/// Each operand is permuted into output order, given length-1 axes at the
/// output positions of the indices it lacks, and broadcast to the full output
/// shape. The two broadcast views are then multiplied elementwise.
/// NOTE(review): assumes every operand index also appears in the output —
/// `new` only looks up output indices; confirm against the caller.
#[derive(Clone, Debug)]
pub struct BroadcastProductGeneral {
    /// Orders the lhs axes by the output index order.
    lhs_permutation: Permutation,
    /// Orders the rhs axes by the output index order.
    rhs_permutation: Permutation,
    /// Output positions (ascending) at which the lhs receives a length-1 axis.
    lhs_insertions: Vec<usize>,
    /// Output positions (ascending) at which the rhs receives a length-1 axis.
    rhs_insertions: Vec<usize>,
    /// Lengths of the output axes, in output order.
    output_sizes: Vec<usize>,
    hadamard_product: HadamardProduct,
}
700
701impl BroadcastProductGeneral {
702 pub fn new(sc: &SizedContraction) -> Self {
703 assert_eq!(sc.contraction.operand_indices.len(), 2);
704 let lhs_indices = &sc.contraction.operand_indices[0];
705 let rhs_indices = &sc.contraction.operand_indices[1];
706 let output_indices = &sc.contraction.output_indices;
707
708 let maybe_lhs_indices = maybe_find_outputs_in_inputs_unique(output_indices, lhs_indices);
709 let maybe_rhs_indices = maybe_find_outputs_in_inputs_unique(output_indices, rhs_indices);
710 let lhs_indices: Vec<usize> = maybe_lhs_indices.iter().copied().flatten().collect();
711 let rhs_indices: Vec<usize> = maybe_rhs_indices.iter().copied().flatten().collect();
712 let lhs_insertions: Vec<usize> = maybe_lhs_indices
713 .into_iter()
714 .enumerate()
715 .filter(|(_, x)| x.is_none())
716 .map(|(i, _)| i)
717 .collect();
718 let rhs_insertions: Vec<usize> = maybe_rhs_indices
719 .into_iter()
720 .enumerate()
721 .filter(|(_, x)| x.is_none())
722 .map(|(i, _)| i)
723 .collect();
724 let lhs_permutation = Permutation::from_indices(&lhs_indices);
725 let rhs_permutation = Permutation::from_indices(&rhs_indices);
726 let output_sizes: Vec<usize> = output_indices.iter().map(|c| sc.output_size[c]).collect();
727 let hadamard_product = HadamardProduct::from_nothing();
728
729 BroadcastProductGeneral {
730 lhs_permutation,
731 rhs_permutation,
732 lhs_insertions,
733 rhs_insertions,
734 output_sizes,
735 hadamard_product,
736 }
737 }
738}
739
740impl<A> PairContractor<A> for BroadcastProductGeneral {
741 fn contract_pair<'a, 'b, 'c, 'd>(
742 &self,
743 lhs: &'b ArrayViewD<'a, A>,
744 rhs: &'d ArrayViewD<'c, A>,
745 ) -> ArrayD<A>
746 where
747 'a: 'b,
748 'c: 'd,
749 A: Clone + LinalgScalar,
750 {
751 let mut adjusted_lhs = self.lhs_permutation.view_singleton(lhs);
752 let mut adjusted_rhs = self.rhs_permutation.view_singleton(rhs);
753 for &i in self.lhs_insertions.iter() {
754 adjusted_lhs = adjusted_lhs.insert_axis(Axis(i));
755 }
756 for &i in self.rhs_insertions.iter() {
757 adjusted_rhs = adjusted_rhs.insert_axis(Axis(i));
758 }
759 let output_shape = IxDyn(&self.output_sizes);
760 let broadcast_lhs = adjusted_lhs.broadcast(output_shape.clone()).unwrap();
761 let broadcast_rhs = adjusted_rhs.broadcast(output_shape).unwrap();
762 self.hadamard_product
763 .contract_pair(&broadcast_lhs, &broadcast_rhs)
764 }
765}
766
/// Pairwise contraction where some indices appear in both operands *and* the
/// output ("stack" indices), alongside indices that are summed over.
///
/// Both operands are permuted and copied so that all stack axes are flattened
/// into one leading axis. For each position along that axis, the remaining
/// sub-tensors are contracted with a `TensordotFixedPosition`. The stacked
/// result is then unflattened and permuted into output order.
#[derive(Clone, Debug)]
pub struct StackedTensordotGeneral {
    /// Orders lhs axes as: stack axes, lhs-only output axes, contracted axes.
    lhs_permutation: Permutation,
    /// Orders rhs axes as: stack axes, contracted axes, rhs-only output axes.
    rhs_permutation: Permutation,
    /// Shape the permuted lhs is copied into:
    /// `[product of stack lengths, lhs-only output lengths..., contracted lengths...]`.
    lhs_output_shape: Vec<usize>,
    /// Shape the permuted rhs is copied into:
    /// `[product of stack lengths, contracted lengths..., rhs-only output lengths...]`.
    rhs_output_shape: Vec<usize>,
    /// Shape of the stacked result:
    /// `[product of stack lengths, lhs-only output lengths..., rhs-only output lengths...]`.
    intermediate_shape: Vec<usize>,
    /// Contracts one stack slice of the lhs with the matching slice of the rhs.
    tensordot_fixed_position: TensordotFixedPosition,
    /// `intermediate_shape` with the flattened stack axis expanded back into the
    /// individual stack axis lengths.
    output_shape: Vec<usize>,
    /// Reorders the (stack, lhs-only, rhs-only) axes into output order.
    output_permutation: Permutation,
}
799
800impl StackedTensordotGeneral {
801 pub fn new(sc: &SizedContraction) -> Self {
802 let mut lhs_order = Vec::new();
803 let mut rhs_order = Vec::new();
804 let mut lhs_output_shape = Vec::new();
805 let mut rhs_output_shape = Vec::new();
806 let mut intermediate_shape = Vec::new();
807
808 assert_eq!(sc.contraction.operand_indices.len(), 2);
809 let lhs_indices = &sc.contraction.operand_indices[0];
810 let rhs_indices = &sc.contraction.operand_indices[1];
811 let output_indices = &sc.contraction.output_indices;
812
813 let maybe_lhs_axes = maybe_find_outputs_in_inputs_unique(output_indices, lhs_indices);
814 let maybe_rhs_axes = maybe_find_outputs_in_inputs_unique(output_indices, rhs_indices);
815 let mut lhs_stack_axes = Vec::new();
816 let mut rhs_stack_axes = Vec::new();
817 let mut stack_indices = Vec::new();
818 let mut lhs_outer_axes = Vec::new();
819 let mut lhs_outer_indices = Vec::new();
820 let mut rhs_outer_axes = Vec::new();
821 let mut rhs_outer_indices = Vec::new();
822 let mut lhs_contracted_axes = Vec::new();
823 let mut rhs_contracted_axes = Vec::new();
824 let mut intermediate_indices = Vec::new();
825 let mut lhs_uncontracted_shape = Vec::new();
826 let mut rhs_uncontracted_shape = Vec::new();
827 let mut contracted_shape = Vec::new();
828
829 lhs_output_shape.push(1);
830 rhs_output_shape.push(1);
831
832 for ((&maybe_lhs_pos, &maybe_rhs_pos), &output_char) in maybe_lhs_axes
833 .iter()
834 .zip(maybe_rhs_axes.iter())
835 .zip(output_indices.iter())
836 {
837 match (maybe_lhs_pos, maybe_rhs_pos) {
838 (Some(lhs_pos), Some(rhs_pos)) => {
839 lhs_stack_axes.push(lhs_pos);
840 rhs_stack_axes.push(rhs_pos);
841 stack_indices.push(output_char);
842 lhs_output_shape[0] *= sc.output_size[&output_char];
843 rhs_output_shape[0] *= sc.output_size[&output_char];
844 }
845 (Some(lhs_pos), None) => {
846 lhs_outer_axes.push(lhs_pos);
847 lhs_outer_indices.push(output_char);
848 lhs_uncontracted_shape.push(sc.output_size[&output_char]);
849 }
850 (None, Some(rhs_pos)) => {
851 rhs_outer_axes.push(rhs_pos);
852 rhs_outer_indices.push(output_char);
853 rhs_uncontracted_shape.push(sc.output_size[&output_char]);
854 }
855 (None, None) => {
856 panic!() }
858 }
859 }
860
861 for (lhs_pos, &lhs_char) in lhs_indices.iter().enumerate() {
862 if !output_indices
863 .iter()
864 .any(|&output_char| output_char == lhs_char)
865 {
866 lhs_contracted_axes.push(lhs_pos);
868 rhs_contracted_axes.push(
870 rhs_indices
871 .iter()
872 .position(|&rhs_char| rhs_char == lhs_char)
873 .unwrap(),
874 );
875 contracted_shape.push(sc.output_size[&lhs_char]);
876 }
877 }
878 lhs_order.append(&mut lhs_stack_axes.clone());
884 lhs_order.append(&mut lhs_outer_axes.clone());
885 lhs_output_shape.append(&mut lhs_uncontracted_shape);
886 lhs_order.append(&mut lhs_contracted_axes.clone());
887 lhs_output_shape.append(&mut contracted_shape.clone());
888
889 rhs_order.append(&mut rhs_stack_axes.clone());
890 rhs_order.append(&mut rhs_contracted_axes.clone());
891 rhs_output_shape.append(&mut contracted_shape);
892 rhs_order.append(&mut rhs_outer_axes.clone());
893 rhs_output_shape.append(&mut rhs_uncontracted_shape);
894
895 intermediate_indices.append(&mut stack_indices.clone());
898 intermediate_indices.append(&mut lhs_outer_indices.clone());
899 intermediate_indices.append(&mut rhs_outer_indices.clone());
900
901 assert_eq!(lhs_output_shape[0], rhs_output_shape[0]);
902 intermediate_shape.push(lhs_output_shape[0]);
903 for lhs_char in lhs_outer_indices.iter() {
904 intermediate_shape.push(sc.output_size[lhs_char]);
905 }
906 for rhs_char in rhs_outer_indices.iter() {
907 intermediate_shape.push(sc.output_size[rhs_char]);
908 }
909
910 let output_order = find_outputs_in_inputs_unique(output_indices, &intermediate_indices);
911 let output_shape = intermediate_indices
912 .iter()
913 .map(|c| sc.output_size[c])
914 .collect();
915
916 let tensordot_fixed_position =
917 TensordotFixedPosition::from_shapes_and_number_of_contracted_axes(
918 &lhs_output_shape[1..],
919 &rhs_output_shape[1..],
920 lhs_contracted_axes.len(),
921 );
922 let lhs_permutation = Permutation::from_indices(&lhs_order);
923 let rhs_permutation = Permutation::from_indices(&rhs_order);
924 let output_permutation = Permutation::from_indices(&output_order);
925 StackedTensordotGeneral {
926 lhs_permutation,
927 rhs_permutation,
928 lhs_output_shape,
929 rhs_output_shape,
930 intermediate_shape,
931 tensordot_fixed_position,
932 output_shape,
933 output_permutation,
934 }
935 }
936}
937
938impl<A> PairContractor<A> for StackedTensordotGeneral {
939 fn contract_pair<'a, 'b, 'c, 'd>(
940 &self,
941 lhs: &'b ArrayViewD<'a, A>,
942 rhs: &'d ArrayViewD<'c, A>,
943 ) -> ArrayD<A>
944 where
945 'a: 'b,
946 'c: 'd,
947 A: Clone + LinalgScalar,
948 {
949 let lhs_permuted = self.lhs_permutation.view_singleton(lhs);
950 let lhs_reshaped = Array::from_shape_vec(
951 IxDyn(&self.lhs_output_shape),
952 lhs_permuted.iter().cloned().collect(),
953 )
954 .unwrap();
955 let rhs_permuted = self.rhs_permutation.view_singleton(rhs);
956 let rhs_reshaped = Array::from_shape_vec(
957 IxDyn(&self.rhs_output_shape),
958 rhs_permuted.iter().cloned().collect(),
959 )
960 .unwrap();
961 let mut intermediate_result: ArrayD<A> = Array::zeros(IxDyn(&self.intermediate_shape));
962 let mut lhs_iter = lhs_reshaped.outer_iter();
963 let mut rhs_iter = rhs_reshaped.outer_iter();
964 for mut output_subview in intermediate_result.outer_iter_mut() {
965 let lhs_subview = lhs_iter.next().unwrap();
966 let rhs_subview = rhs_iter.next().unwrap();
967 self.tensordot_fixed_position.contract_and_assign_pair(
968 &lhs_subview,
969 &rhs_subview,
970 &mut output_subview,
971 );
972 }
973 let intermediate_reshaped = intermediate_result
974 .into_shape_with_order(IxDyn(&self.output_shape))
975 .unwrap();
976 self.output_permutation
977 .contract_singleton(&intermediate_reshaped.view())
978 }
979}