// ndarray_einsum/contractors/pair_contractors.rs

1// Copyright 2019 Jared Samet
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Contains the specific implementations of `PairContractor` that represent the base-case ways
16//! to contract two simplified tensors.
17//!
18//! A tensor is simplified with respect to another tensor and a set of output indices
19//! if two conditions are met:
20//!
21//! 1. Each index in the tensor is present in either the other tensor or in the output indices.
22//! 2. Each index in the tensor appears only once.
23//!
24//! Examples of `einsum` strings with two simplified tensors:
25//!
26//! 1. `ijk,jkl->il`
27//! 2. `ijk,jkl->ijkl`
28//! 3. `ijk,ijk->`
29//!
30//! Examples of `einsum` strings with two tensors that are NOT simplified:
31//!
32//! 1. `iijk,jkl->il` Not simplified because `i` appears twice in the LHS tensor
33//! 2. `ijk,jkl->i` Not simplified because `l` only appears in the RHS tensor
34//!
35//! If the two input tensors are both simplified, all instances of tensor contraction
36//! can be expressed as one of a small number of cases. Note that there may be more than
37//! one way to express the same contraction; some preliminary benchmarking has been
38//! done to identify the faster choice.
39
40use hashbrown::HashSet;
41use ndarray::prelude::*;
42use ndarray::LinalgScalar;
43
44use super::{PairContractor, Permutation, SingletonContractor, SingletonViewer};
45use crate::SizedContraction;
46
/// Helper function used throughout this module to find the positions of one set of indices
/// in a second set of indices. The most common case is generating a permutation to
/// be supplied to `permuted_axes`.
///
/// Returns, for each element of `output_indices`, `Some(position)` of that index within
/// `input_indices`, or `None` if it is absent. Panics if an output index occurs more than
/// once in `input_indices`, since the inputs are assumed to be "simplified" (see the
/// module-level docs) and a repeat would make the resulting permutation ambiguous.
fn maybe_find_outputs_in_inputs_unique(
    output_indices: &[char],
    input_indices: &[char],
) -> Vec<Option<usize>> {
    output_indices
        .iter()
        .map(|&output_char| {
            let input_pos = input_indices
                .iter()
                .position(|&input_char| input_char == output_char);
            // Enforce uniqueness: no second occurrence after the first match.
            if let Some(pos) = input_pos {
                assert!(
                    !input_indices[(pos + 1)..]
                        .iter()
                        .any(|&input_char| input_char == output_char),
                    "index '{}' appears more than once in the input indices",
                    output_char
                );
            }
            input_pos
        })
        .collect()
}
70
71fn find_outputs_in_inputs_unique(output_indices: &[char], input_indices: &[char]) -> Vec<usize> {
72    maybe_find_outputs_in_inputs_unique(output_indices, input_indices)
73        .iter()
74        .map(|x| x.unwrap())
75        .collect()
76}
77
/// Performs tensor dot product for two tensors where no permutation needs to be performed,
/// e.g. `ijk,jkl->il` or `ijk,klm->ijlm`.
///
/// The axes to be contracted must be the last axes of the LHS tensor and the first axes
/// of the RHS tensor, and the axis order for the output tensor must be all the uncontracted
/// axes of the LHS tensor followed by all the uncontracted axes of the RHS tensor, in the
/// orders those originally appear in the LHS and RHS tensors.
///
/// The contraction is performed by reshaping the LHS into a matrix (2-D tensor) of shape
/// [len_uncontracted_lhs, len_contracted_axes], reshaping the RHS into shape
/// [len_contracted_axes, len_uncontracted_rhs], matrix-multiplying the two reshaped tensors,
/// and then reshaping the result into [...self.output_shape].
///
/// Constructed via `TensordotFixedPosition::new` or
/// `TensordotFixedPosition::from_shapes_and_number_of_contracted_axes`.
#[derive(Clone, Debug)]
pub struct TensordotFixedPosition {
    /// The product of the lengths of all the uncontracted axes in the LHS (or 1 if all of the
    /// LHS axes are contracted)
    len_uncontracted_lhs: usize,

    /// The product of the lengths of all the uncontracted axes in the RHS (or 1 if all of the
    /// RHS axes are contracted)
    len_uncontracted_rhs: usize,

    /// The product of the lengths of all the contracted axes (or 1 if no axes are contracted,
    /// i.e. the outer product is computed)
    len_contracted_axes: usize,

    /// The shape that the tensor dot product will be recast to
    output_shape: Vec<usize>,
}
107
108impl TensordotFixedPosition {
109    pub fn new(sc: &SizedContraction) -> Self {
110        assert_eq!(sc.contraction.operand_indices.len(), 2);
111        let lhs_indices = &sc.contraction.operand_indices[0];
112        let rhs_indices = &sc.contraction.operand_indices[1];
113        let output_indices = &sc.contraction.output_indices;
114        // Returns an n-dimensional array where n = |D| + |E| - 2 * last_n.
115        let twice_num_contracted_axes =
116            lhs_indices.len() + rhs_indices.len() - output_indices.len();
117        assert_eq!(twice_num_contracted_axes % 2, 0);
118        let num_contracted_axes = twice_num_contracted_axes / 2;
119        // TODO: Add an assert! that they have the same indices
120
121        let lhs_shape: Vec<usize> = lhs_indices.iter().map(|c| sc.output_size[c]).collect();
122        let rhs_shape: Vec<usize> = rhs_indices.iter().map(|c| sc.output_size[c]).collect();
123
124        TensordotFixedPosition::from_shapes_and_number_of_contracted_axes(
125            &lhs_shape,
126            &rhs_shape,
127            num_contracted_axes,
128        )
129    }
130
131    /// Compute the uncontracted and contracted axis lengths and the output shape based on the
132    /// input shapes and how many axes should be contracted from each tensor.
133    ///
134    /// TODO: The assert_eq! here could be tightened up by verifying that the
135    /// last `num_contracted_axes` of the LHS match the first `num_contracted_axes` of the
136    /// RHS axis-by-axis (as opposed to only checking the product as is done here.)
137    pub fn from_shapes_and_number_of_contracted_axes(
138        lhs_shape: &[usize],
139        rhs_shape: &[usize],
140        num_contracted_axes: usize,
141    ) -> Self {
142        let mut len_uncontracted_lhs = 1;
143        let mut len_uncontracted_rhs = 1;
144        let mut len_contracted_lhs = 1;
145        let mut len_contracted_rhs = 1;
146        let mut output_shape = Vec::new();
147
148        let num_axes_lhs = lhs_shape.len();
149        for (axis, &axis_length) in lhs_shape.iter().enumerate() {
150            if axis < (num_axes_lhs - num_contracted_axes) {
151                len_uncontracted_lhs *= axis_length;
152                output_shape.push(axis_length);
153            } else {
154                len_contracted_lhs *= axis_length;
155            }
156        }
157
158        for (axis, &axis_length) in rhs_shape.iter().enumerate() {
159            if axis < num_contracted_axes {
160                len_contracted_rhs *= axis_length;
161            } else {
162                len_uncontracted_rhs *= axis_length;
163                output_shape.push(axis_length);
164            }
165        }
166        assert_eq!(len_contracted_rhs, len_contracted_lhs);
167        let len_contracted_axes = len_contracted_lhs;
168
169        TensordotFixedPosition {
170            len_uncontracted_lhs,
171            len_uncontracted_rhs,
172            len_contracted_axes,
173            output_shape,
174        }
175    }
176}
177
impl<A> PairContractor<A> for TensordotFixedPosition {
    /// Contracts the pair by reshaping both operands into 2-D matrices, multiplying
    /// them with `dot`, and recasting the result into `self.output_shape`.
    fn contract_pair<'a, 'b, 'c, 'd>(
        &self,
        lhs: &'b ArrayViewD<'a, A>,
        rhs: &'d ArrayViewD<'c, A>,
    ) -> ArrayD<A>
    where
        'a: 'b,
        'c: 'd,
        A: Clone + LinalgScalar,
    {
        // Declared before the `if` so a freshly copied array (needed when the view
        // is not contiguous) lives long enough to be borrowed by `lhs_view` below.
        let lhs_array;
        let lhs_view = if lhs.is_standard_layout() {
            // Contiguous row-major data can be reinterpreted as a
            // (len_uncontracted_lhs, len_contracted_axes) matrix without copying.
            lhs.view()
                .into_shape_with_order((self.len_uncontracted_lhs, self.len_contracted_axes))
                .unwrap()
        } else {
            // Non-contiguous view: copy the elements in logical order into a new
            // matrix of the required shape.
            lhs_array = Array::from_shape_vec(
                [self.len_uncontracted_lhs, self.len_contracted_axes],
                lhs.iter().cloned().collect(),
            )
            .unwrap();
            lhs_array.view()
        };

        // Same treatment for the RHS, reshaped to
        // (len_contracted_axes, len_uncontracted_rhs).
        let rhs_array;
        let rhs_view = if rhs.is_standard_layout() {
            rhs.view()
                .into_shape_with_order((self.len_contracted_axes, self.len_uncontracted_rhs))
                .unwrap()
        } else {
            rhs_array = Array::from_shape_vec(
                [self.len_contracted_axes, self.len_uncontracted_rhs],
                rhs.iter().cloned().collect(),
            )
            .unwrap();
            rhs_array.view()
        };

        // 2-D matrix product, then recast to the n-dimensional output shape.
        lhs_view
            .dot(&rhs_view)
            .into_shape_with_order(IxDyn(&self.output_shape))
            .unwrap()
    }
}
223
// TODO: Micro-optimization possible: Have a version without the final permutation,
// which clones the array
/// Computes the tensor dot product of two tensors, with individual permutations of the
/// LHS and RHS performed as necessary, as well as a final permutation of the output.
///
/// Examples that qualify for TensordotGeneral but not TensordotFixedPosition:
///
/// 1. `jik,jkl->il` LHS tensor needs to be permuted `jik->ijk`
/// 2. `ijk,klm->imlj` Output tensor needs to be permuted `ijlm->imlj`
#[derive(Clone, Debug)]
pub struct TensordotGeneral {
    /// Moves the contracted LHS axes to the back, keeping the rest in original order
    lhs_permutation: Permutation,
    /// Moves the contracted RHS axes to the front, keeping the rest in original order
    rhs_permutation: Permutation,
    /// Performs the contraction once both operands are in fixed position
    tensordot_fixed_position: TensordotFixedPosition,
    /// Reorders the uncontracted axes into the requested output order
    output_permutation: Permutation,
}
240
241impl TensordotGeneral {
242    pub fn new(sc: &SizedContraction) -> Self {
243        assert_eq!(sc.contraction.operand_indices.len(), 2);
244        let lhs_indices = &sc.contraction.operand_indices[0];
245        let rhs_indices = &sc.contraction.operand_indices[1];
246        let contracted_indices = &sc.contraction.summation_indices;
247        let output_indices = &sc.contraction.output_indices;
248        let lhs_shape: Vec<usize> = lhs_indices.iter().map(|c| sc.output_size[c]).collect();
249        let rhs_shape: Vec<usize> = rhs_indices.iter().map(|c| sc.output_size[c]).collect();
250
251        TensordotGeneral::from_shapes_and_indices(
252            &lhs_shape,
253            &rhs_shape,
254            lhs_indices,
255            rhs_indices,
256            contracted_indices,
257            output_indices,
258        )
259    }
260
261    fn from_shapes_and_indices(
262        lhs_shape: &[usize],
263        rhs_shape: &[usize],
264        lhs_indices: &[char],
265        rhs_indices: &[char],
266        contracted_indices: &[char],
267        output_indices: &[char],
268    ) -> Self {
269        let lhs_contracted_axes = find_outputs_in_inputs_unique(contracted_indices, lhs_indices);
270        let rhs_contracted_axes = find_outputs_in_inputs_unique(contracted_indices, rhs_indices);
271        let mut uncontracted_chars: Vec<char> = lhs_indices
272            .iter()
273            .filter(|&&input_char| {
274                output_indices
275                    .iter()
276                    .any(|&output_char| input_char == output_char)
277            })
278            .cloned()
279            .collect();
280        let mut rhs_uncontracted_chars: Vec<char> = rhs_indices
281            .iter()
282            .filter(|&&input_char| {
283                output_indices
284                    .iter()
285                    .any(|&output_char| input_char == output_char)
286            })
287            .cloned()
288            .collect();
289        uncontracted_chars.append(&mut rhs_uncontracted_chars);
290        let output_order = find_outputs_in_inputs_unique(output_indices, &uncontracted_chars);
291
292        TensordotGeneral::from_shapes_and_axis_numbers(
293            lhs_shape,
294            rhs_shape,
295            &lhs_contracted_axes,
296            &rhs_contracted_axes,
297            &output_order,
298        )
299    }
300
301    /// Produces a `TensordotGeneral` from the shapes and list of axes to be contracted.
302    ///
303    /// Wrapped by the public `tensordot` function and used by `TensordotGeneral::new()`.
304    /// lhs_axes lists the axes from the lhs tensor to contract and rhs_axes lists the
305    /// axes from the rhs tensor to contract.
306    pub fn from_shapes_and_axis_numbers(
307        lhs_shape: &[usize],
308        rhs_shape: &[usize],
309        lhs_axes: &[usize],
310        rhs_axes: &[usize],
311        output_order: &[usize],
312    ) -> Self {
313        let num_contracted_axes = lhs_axes.len();
314        assert!(num_contracted_axes == rhs_axes.len());
315        let lhs_uniques: HashSet<_> = lhs_axes.iter().cloned().collect();
316        let rhs_uniques: HashSet<_> = rhs_axes.iter().cloned().collect();
317        assert!(num_contracted_axes == lhs_uniques.len());
318        assert!(num_contracted_axes == rhs_uniques.len());
319        let mut adjusted_lhs_shape = Vec::new();
320        let mut adjusted_rhs_shape = Vec::new();
321
322        // Rolls the axes specified in lhs and rhs to the back and front respectively,
323        // then calls tensordot_fixed_order(rolled_lhs, rolled_rhs, lhs_axes.len())
324        let mut permutation_lhs = Vec::new();
325        for (i, &axis_length) in lhs_shape.iter().enumerate() {
326            if !(lhs_uniques.contains(&i)) {
327                permutation_lhs.push(i);
328                adjusted_lhs_shape.push(axis_length);
329            }
330        }
331        for &axis in lhs_axes.iter() {
332            permutation_lhs.push(axis);
333            adjusted_lhs_shape.push(lhs_shape[axis]);
334        }
335
336        // Note: These two for loops are (intentionally!) in the reverse order
337        // as they are for LHS.
338        let mut permutation_rhs = Vec::new();
339        for &axis in rhs_axes.iter() {
340            permutation_rhs.push(axis);
341            adjusted_rhs_shape.push(rhs_shape[axis]);
342        }
343        for (i, &axis_length) in rhs_shape.iter().enumerate() {
344            if !(rhs_uniques.contains(&i)) {
345                permutation_rhs.push(i);
346                adjusted_rhs_shape.push(axis_length);
347            }
348        }
349
350        let lhs_permutation = Permutation::from_indices(&permutation_lhs);
351        let rhs_permutation = Permutation::from_indices(&permutation_rhs);
352        let tensordot_fixed_position =
353            TensordotFixedPosition::from_shapes_and_number_of_contracted_axes(
354                &adjusted_lhs_shape,
355                &adjusted_rhs_shape,
356                num_contracted_axes,
357            );
358
359        let output_permutation = Permutation::from_indices(output_order);
360
361        TensordotGeneral {
362            lhs_permutation,
363            rhs_permutation,
364            tensordot_fixed_position,
365            output_permutation,
366        }
367    }
368}
369
370impl<A> PairContractor<A> for TensordotGeneral {
371    fn contract_pair<'a, 'b, 'c, 'd>(
372        &self,
373        lhs: &'b ArrayViewD<'a, A>,
374        rhs: &'d ArrayViewD<'c, A>,
375    ) -> ArrayD<A>
376    where
377        'a: 'b,
378        'c: 'd,
379        A: Clone + LinalgScalar,
380    {
381        let permuted_lhs = self.lhs_permutation.view_singleton(lhs);
382        let permuted_rhs = self.rhs_permutation.view_singleton(rhs);
383        let tensordotted = self
384            .tensordot_fixed_position
385            .contract_pair(&permuted_lhs, &permuted_rhs);
386        self.output_permutation
387            .contract_singleton(&tensordotted.view())
388    }
389}
390
/// Computes the Hadamard (element-wise) product of two tensors.
///
/// All instances of `SizedContraction` making use of this contractor must have the form
/// `ij,ij->ij`.
///
/// Contractions of the form `ij,ji->ij` need to use `HadamardProductGeneral` instead.
///
/// Stateless: the struct carries no fields; all validation happens in `new`.
#[derive(Clone, Debug)]
pub struct HadamardProduct {}
399
400impl HadamardProduct {
401    pub fn new(sc: &SizedContraction) -> Self {
402        assert_eq!(sc.contraction.operand_indices.len(), 2);
403        let lhs_indices = &sc.contraction.operand_indices[0];
404        let rhs_indices = &sc.contraction.operand_indices[1];
405        let output_indices = &sc.contraction.output_indices;
406        assert_eq!(lhs_indices, rhs_indices);
407        assert_eq!(lhs_indices, output_indices);
408
409        HadamardProduct {}
410    }
411
412    fn from_nothing() -> Self {
413        HadamardProduct {}
414    }
415}
416
impl<A> PairContractor<A> for HadamardProduct {
    /// Element-wise multiplication of the two views.
    fn contract_pair<'a, 'b, 'c, 'd>(
        &self,
        lhs: &'b ArrayViewD<'a, A>,
        rhs: &'d ArrayViewD<'c, A>,
    ) -> ArrayD<A>
    where
        'a: 'b,
        'c: 'd,
        A: Clone + LinalgScalar,
    {
        // ndarray's `Mul` impl on two view references multiplies element-wise and
        // allocates a new owned array; the operands are assumed to already have
        // matching shape and index order (enforced by `new` / the wrapping
        // contractors).
        lhs * rhs
    }
}
431
/// Permutes the axes of the LHS and RHS tensors to the order in which those axes appear in the
/// output before computing the Hadamard (element-wise) product.
///
/// Examples of contractions that fit this category:
///
/// 1. `ij,ij->ij` (Can also use `HadamardProduct`)
/// 2. `ij,ji->ij` (Can only use `HadamardProductGeneral`)
#[derive(Clone, Debug)]
pub struct HadamardProductGeneral {
    /// Puts the LHS axes into output order
    lhs_permutation: Permutation,
    /// Puts the RHS axes into output order
    rhs_permutation: Permutation,
    /// Multiplies element-wise once both operands are permuted
    hadamard_product: HadamardProduct,
}
445
446impl HadamardProductGeneral {
447    pub fn new(sc: &SizedContraction) -> Self {
448        assert_eq!(sc.contraction.operand_indices.len(), 2);
449        let lhs_indices = &sc.contraction.operand_indices[0];
450        let rhs_indices = &sc.contraction.operand_indices[1];
451        let output_indices = &sc.contraction.output_indices;
452        assert_eq!(lhs_indices.len(), rhs_indices.len());
453        assert_eq!(lhs_indices.len(), output_indices.len());
454
455        let lhs_permutation =
456            Permutation::from_indices(&find_outputs_in_inputs_unique(output_indices, lhs_indices));
457        let rhs_permutation =
458            Permutation::from_indices(&find_outputs_in_inputs_unique(output_indices, rhs_indices));
459        let hadamard_product = HadamardProduct::from_nothing();
460
461        HadamardProductGeneral {
462            lhs_permutation,
463            rhs_permutation,
464            hadamard_product,
465        }
466    }
467}
468
469impl<A> PairContractor<A> for HadamardProductGeneral {
470    fn contract_pair<'a, 'b, 'c, 'd>(
471        &self,
472        lhs: &'b ArrayViewD<'a, A>,
473        rhs: &'d ArrayViewD<'c, A>,
474    ) -> ArrayD<A>
475    where
476        'a: 'b,
477        'c: 'd,
478        A: Clone + LinalgScalar,
479    {
480        self.hadamard_product.contract_pair(
481            &self.lhs_permutation.view_singleton(lhs),
482            &self.rhs_permutation.view_singleton(rhs),
483        )
484    }
485}
486
/// Multiplies every element of the RHS tensor by the single scalar in the 0-d LHS tensor.
///
/// This contraction can arise when the simplification of the LHS tensor results in all the
/// axes being summed before the two tensors are contracted. For example, in the contraction
/// `i,jk->jk`, every element of the RHS tensor is simply multiplied by the sum of the elements
/// of the LHS tensor.
///
/// Stateless: the struct carries no fields; all validation happens in `new`.
#[derive(Clone, Debug)]
pub struct ScalarMatrixProduct {}
495
496impl ScalarMatrixProduct {
497    pub fn new(sc: &SizedContraction) -> Self {
498        assert_eq!(sc.contraction.operand_indices.len(), 2);
499        let lhs_indices = &sc.contraction.operand_indices[0];
500        let rhs_indices = &sc.contraction.operand_indices[1];
501        let output_indices = &sc.contraction.output_indices;
502        assert_eq!(lhs_indices.len(), 0);
503        assert_eq!(output_indices, rhs_indices);
504
505        ScalarMatrixProduct {}
506    }
507
508    pub fn from_nothing() -> Self {
509        ScalarMatrixProduct {}
510    }
511}
512
513impl<A> PairContractor<A> for ScalarMatrixProduct {
514    fn contract_pair<'a, 'b, 'c, 'd>(
515        &self,
516        lhs: &'b ArrayViewD<'a, A>,
517        rhs: &'d ArrayViewD<'c, A>,
518    ) -> ArrayD<A>
519    where
520        'a: 'b,
521        'c: 'd,
522        A: Clone + LinalgScalar,
523    {
524        let lhs_0d: A = *lhs.first().unwrap();
525        rhs.mapv(|x| x * lhs_0d)
526    }
527}
528
/// Permutes the axes of the RHS tensor to the output order and multiply all elements by the single
/// scalar in the 0-d LHS tensor.
///
/// This contraction can arise when the simplification of the LHS tensor results in all the
/// axes being summed before the two tensors are contracted. For example, in the contraction
/// `i,jk->kj`, the output matrix is equal to the RHS matrix, transposed and then scalar-multiplied
/// by the sum of the elements of the LHS tensor.
#[derive(Clone, Debug)]
pub struct ScalarMatrixProductGeneral {
    /// Puts the RHS axes into output order
    rhs_permutation: Permutation,
    /// Performs the scalar multiplication once the RHS is permuted
    scalar_matrix_product: ScalarMatrixProduct,
}
541
542impl ScalarMatrixProductGeneral {
543    pub fn new(sc: &SizedContraction) -> Self {
544        assert_eq!(sc.contraction.operand_indices.len(), 2);
545        let lhs_indices = &sc.contraction.operand_indices[0];
546        let rhs_indices = &sc.contraction.operand_indices[1];
547        let output_indices = &sc.contraction.output_indices;
548        assert_eq!(lhs_indices.len(), 0);
549        assert_eq!(rhs_indices.len(), output_indices.len());
550
551        ScalarMatrixProductGeneral::from_indices(rhs_indices, output_indices)
552    }
553
554    pub fn from_indices(input_indices: &[char], output_indices: &[char]) -> Self {
555        let rhs_permutation = Permutation::from_indices(&find_outputs_in_inputs_unique(
556            output_indices,
557            input_indices,
558        ));
559        let scalar_matrix_product = ScalarMatrixProduct::from_nothing();
560
561        ScalarMatrixProductGeneral {
562            rhs_permutation,
563            scalar_matrix_product,
564        }
565    }
566}
567
568impl<A> PairContractor<A> for ScalarMatrixProductGeneral {
569    fn contract_pair<'a, 'b, 'c, 'd>(
570        &self,
571        lhs: &'b ArrayViewD<'a, A>,
572        rhs: &'d ArrayViewD<'c, A>,
573    ) -> ArrayD<A>
574    where
575        'a: 'b,
576        'c: 'd,
577        A: Clone + LinalgScalar,
578    {
579        self.scalar_matrix_product
580            .contract_pair(lhs, &self.rhs_permutation.view_singleton(rhs))
581    }
582}
583
/// Multiplies every element of the LHS tensor by the single scalar in the 0-d RHS tensor.
///
/// This contraction can arise when the simplification of the RHS tensor results in all the
/// axes being summed before the two tensors are contracted. For example, in the contraction
/// `ij,k->ij`, every element of the LHS tensor is simply multiplied by the sum of the elements
/// of the RHS tensor.
///
/// Stateless: the struct carries no fields; all validation happens in `new`.
#[derive(Clone, Debug)]
pub struct MatrixScalarProduct {}
592
593impl MatrixScalarProduct {
594    pub fn new(sc: &SizedContraction) -> Self {
595        assert_eq!(sc.contraction.operand_indices.len(), 2);
596        let lhs_indices = &sc.contraction.operand_indices[0];
597        let rhs_indices = &sc.contraction.operand_indices[1];
598        let output_indices = &sc.contraction.output_indices;
599        assert_eq!(rhs_indices.len(), 0);
600        assert_eq!(output_indices, lhs_indices);
601
602        MatrixScalarProduct {}
603    }
604
605    pub fn from_nothing() -> Self {
606        MatrixScalarProduct {}
607    }
608}
609
610impl<A> PairContractor<A> for MatrixScalarProduct {
611    fn contract_pair<'a, 'b, 'c, 'd>(
612        &self,
613        lhs: &'b ArrayViewD<'a, A>,
614        rhs: &'d ArrayViewD<'c, A>,
615    ) -> ArrayD<A>
616    where
617        'a: 'b,
618        'c: 'd,
619        A: Clone + LinalgScalar,
620    {
621        let rhs_0d: A = *rhs.first().unwrap();
622        lhs.mapv(|x| x * rhs_0d)
623    }
624}
625
/// Permutes the axes of the LHS tensor to the output order and multiply all elements by the single
/// scalar in the 0-d RHS tensor.
///
/// This contraction can arise when the simplification of the RHS tensor results in all the
/// axes being summed before the two tensors are contracted. For example, in the contraction
/// `ij,k->ji`, the output matrix is equal to the LHS matrix, transposed and then scalar-multiplied
/// by the sum of the elements of the RHS tensor.
#[derive(Clone, Debug)]
pub struct MatrixScalarProductGeneral {
    /// Puts the LHS axes into output order
    lhs_permutation: Permutation,
    /// Performs the scalar multiplication once the LHS is permuted
    matrix_scalar_product: MatrixScalarProduct,
}
638
639impl MatrixScalarProductGeneral {
640    pub fn new(sc: &SizedContraction) -> Self {
641        assert_eq!(sc.contraction.operand_indices.len(), 2);
642        let lhs_indices = &sc.contraction.operand_indices[0];
643        let rhs_indices = &sc.contraction.operand_indices[1];
644        let output_indices = &sc.contraction.output_indices;
645        assert_eq!(rhs_indices.len(), 0);
646        assert_eq!(lhs_indices.len(), output_indices.len());
647
648        MatrixScalarProductGeneral::from_indices(lhs_indices, output_indices)
649    }
650
651    pub fn from_indices(input_indices: &[char], output_indices: &[char]) -> Self {
652        let lhs_permutation = Permutation::from_indices(&find_outputs_in_inputs_unique(
653            output_indices,
654            input_indices,
655        ));
656        let matrix_scalar_product = MatrixScalarProduct::from_nothing();
657
658        MatrixScalarProductGeneral {
659            lhs_permutation,
660            matrix_scalar_product,
661        }
662    }
663}
664
665impl<A> PairContractor<A> for MatrixScalarProductGeneral {
666    fn contract_pair<'a, 'b, 'c, 'd>(
667        &self,
668        lhs: &'b ArrayViewD<'a, A>,
669        rhs: &'d ArrayViewD<'c, A>,
670    ) -> ArrayD<A>
671    where
672        'a: 'b,
673        'c: 'd,
674        A: Clone + LinalgScalar,
675    {
676        self.matrix_scalar_product
677            .contract_pair(&self.lhs_permutation.view_singleton(lhs), rhs)
678    }
679}
680
/// Permutes the axes of the LHS and RHS tensor, broadcasts into the output shape,
/// and then computes the element-wise product of the two broadcast tensors.
///
/// Currently unused due to (limited) unfavorable benchmarking results compared to
/// `StackedTensordotGeneral`. An example of a contraction that could theoretically
/// be performed by this contraction is `ij,jk->ijk`: the LHS and RHS are both
/// broadcast into output shape (|i|, |j|, |k|) and then multiplied elementwise.
///
/// However, the limited benchmarking performed so far favored iterating along the
/// `j` axis and computing the outer products `i,k->ik` for each subview of the tensors.
#[derive(Clone, Debug)]
pub struct BroadcastProductGeneral {
    /// Puts the LHS axes into the order their indices appear in the output
    lhs_permutation: Permutation,
    /// Puts the RHS axes into the order their indices appear in the output
    rhs_permutation: Permutation,
    /// Output axis positions absent from the LHS, where length-1 axes get inserted
    lhs_insertions: Vec<usize>,
    /// Output axis positions absent from the RHS, where length-1 axes get inserted
    rhs_insertions: Vec<usize>,
    /// The full output shape both operands are broadcast to
    output_sizes: Vec<usize>,
    /// Multiplies the two broadcast operands element-wise
    hadamard_product: HadamardProduct,
}
700
701impl BroadcastProductGeneral {
702    pub fn new(sc: &SizedContraction) -> Self {
703        assert_eq!(sc.contraction.operand_indices.len(), 2);
704        let lhs_indices = &sc.contraction.operand_indices[0];
705        let rhs_indices = &sc.contraction.operand_indices[1];
706        let output_indices = &sc.contraction.output_indices;
707
708        let maybe_lhs_indices = maybe_find_outputs_in_inputs_unique(output_indices, lhs_indices);
709        let maybe_rhs_indices = maybe_find_outputs_in_inputs_unique(output_indices, rhs_indices);
710        let lhs_indices: Vec<usize> = maybe_lhs_indices.iter().copied().flatten().collect();
711        let rhs_indices: Vec<usize> = maybe_rhs_indices.iter().copied().flatten().collect();
712        let lhs_insertions: Vec<usize> = maybe_lhs_indices
713            .into_iter()
714            .enumerate()
715            .filter(|(_, x)| x.is_none())
716            .map(|(i, _)| i)
717            .collect();
718        let rhs_insertions: Vec<usize> = maybe_rhs_indices
719            .into_iter()
720            .enumerate()
721            .filter(|(_, x)| x.is_none())
722            .map(|(i, _)| i)
723            .collect();
724        let lhs_permutation = Permutation::from_indices(&lhs_indices);
725        let rhs_permutation = Permutation::from_indices(&rhs_indices);
726        let output_sizes: Vec<usize> = output_indices.iter().map(|c| sc.output_size[c]).collect();
727        let hadamard_product = HadamardProduct::from_nothing();
728
729        BroadcastProductGeneral {
730            lhs_permutation,
731            rhs_permutation,
732            lhs_insertions,
733            rhs_insertions,
734            output_sizes,
735            hadamard_product,
736        }
737    }
738}
739
740impl<A> PairContractor<A> for BroadcastProductGeneral {
741    fn contract_pair<'a, 'b, 'c, 'd>(
742        &self,
743        lhs: &'b ArrayViewD<'a, A>,
744        rhs: &'d ArrayViewD<'c, A>,
745    ) -> ArrayD<A>
746    where
747        'a: 'b,
748        'c: 'd,
749        A: Clone + LinalgScalar,
750    {
751        let mut adjusted_lhs = self.lhs_permutation.view_singleton(lhs);
752        let mut adjusted_rhs = self.rhs_permutation.view_singleton(rhs);
753        for &i in self.lhs_insertions.iter() {
754            adjusted_lhs = adjusted_lhs.insert_axis(Axis(i));
755        }
756        for &i in self.rhs_insertions.iter() {
757            adjusted_rhs = adjusted_rhs.insert_axis(Axis(i));
758        }
759        let output_shape = IxDyn(&self.output_sizes);
760        let broadcast_lhs = adjusted_lhs.broadcast(output_shape.clone()).unwrap();
761        let broadcast_rhs = adjusted_rhs.broadcast(output_shape).unwrap();
762        self.hadamard_product
763            .contract_pair(&broadcast_lhs, &broadcast_rhs)
764    }
765}
766
// TODO: Micro-optimization: Have a version without the output permutation,
// which clones the array
//
// TODO: Fix whatever bug prevents this from being used in all cases
//
// TODO: convert this to directly reshape into a 3-D matrix instead of delegating
// that to TensordotGeneral

/// Repeatedly computes the tensor dot of subviews of the two tensors, iterating over
/// indices which appear in the LHS, RHS, and output.
///
/// The indices appearing in all three places are referred to here as the "stack" indices.
/// For example, in the contraction `ijk,ikl->ijl`, `i` would be the (only) "stack" index.
/// This contraction is an instance of batch matrix multiplication. The LHS and RHS are both
/// 3-D tensors and the `i`th (2-D) subview of the output is the matrix product of the `i`th
/// subview of the LHS matrix-multiplied by the `i`th subview of the RHS.
///
/// This is the most general contraction and in theory could handle all pairwise contractions,
/// but is less performant than special-casing when there are no "stack" indices. It is also
/// currently the only case that requires `.outer_iter_mut()` (which might make parallelizing
/// operations more difficult).
#[derive(Clone, Debug)]
pub struct StackedTensordotGeneral {
    /// Axis permutation applied to the LHS before it is reshaped
    lhs_permutation: Permutation,
    /// Axis permutation applied to the RHS before it is reshaped
    rhs_permutation: Permutation,
    // NOTE(review): the shape fields below are populated by `new`, whose body is
    // not fully visible here; `new` seeds both operand shapes with a leading 1,
    // which looks like a stack dimension — confirm against the full constructor.
    /// Shape the permuted LHS is reshaped to before iteration
    lhs_output_shape: Vec<usize>,
    /// Shape the permuted RHS is reshaped to before iteration
    rhs_output_shape: Vec<usize>,
    /// Shape of the per-subview contraction results before the final reshape
    intermediate_shape: Vec<usize>,
    /// Performs the 2-D contraction for each stacked subview
    tensordot_fixed_position: TensordotFixedPosition,
    /// Shape of the assembled result before the final permutation
    output_shape: Vec<usize>,
    /// Reorders the assembled result into the requested output order
    output_permutation: Permutation,
}
799
800impl StackedTensordotGeneral {
801    pub fn new(sc: &SizedContraction) -> Self {
802        let mut lhs_order = Vec::new();
803        let mut rhs_order = Vec::new();
804        let mut lhs_output_shape = Vec::new();
805        let mut rhs_output_shape = Vec::new();
806        let mut intermediate_shape = Vec::new();
807
808        assert_eq!(sc.contraction.operand_indices.len(), 2);
809        let lhs_indices = &sc.contraction.operand_indices[0];
810        let rhs_indices = &sc.contraction.operand_indices[1];
811        let output_indices = &sc.contraction.output_indices;
812
813        let maybe_lhs_axes = maybe_find_outputs_in_inputs_unique(output_indices, lhs_indices);
814        let maybe_rhs_axes = maybe_find_outputs_in_inputs_unique(output_indices, rhs_indices);
815        let mut lhs_stack_axes = Vec::new();
816        let mut rhs_stack_axes = Vec::new();
817        let mut stack_indices = Vec::new();
818        let mut lhs_outer_axes = Vec::new();
819        let mut lhs_outer_indices = Vec::new();
820        let mut rhs_outer_axes = Vec::new();
821        let mut rhs_outer_indices = Vec::new();
822        let mut lhs_contracted_axes = Vec::new();
823        let mut rhs_contracted_axes = Vec::new();
824        let mut intermediate_indices = Vec::new();
825        let mut lhs_uncontracted_shape = Vec::new();
826        let mut rhs_uncontracted_shape = Vec::new();
827        let mut contracted_shape = Vec::new();
828
829        lhs_output_shape.push(1);
830        rhs_output_shape.push(1);
831
832        for ((&maybe_lhs_pos, &maybe_rhs_pos), &output_char) in maybe_lhs_axes
833            .iter()
834            .zip(maybe_rhs_axes.iter())
835            .zip(output_indices.iter())
836        {
837            match (maybe_lhs_pos, maybe_rhs_pos) {
838                (Some(lhs_pos), Some(rhs_pos)) => {
839                    lhs_stack_axes.push(lhs_pos);
840                    rhs_stack_axes.push(rhs_pos);
841                    stack_indices.push(output_char);
842                    lhs_output_shape[0] *= sc.output_size[&output_char];
843                    rhs_output_shape[0] *= sc.output_size[&output_char];
844                }
845                (Some(lhs_pos), None) => {
846                    lhs_outer_axes.push(lhs_pos);
847                    lhs_outer_indices.push(output_char);
848                    lhs_uncontracted_shape.push(sc.output_size[&output_char]);
849                }
850                (None, Some(rhs_pos)) => {
851                    rhs_outer_axes.push(rhs_pos);
852                    rhs_outer_indices.push(output_char);
853                    rhs_uncontracted_shape.push(sc.output_size[&output_char]);
854                }
855                (None, None) => {
856                    panic!() // Output char must be either in lhs or rhs
857                }
858            }
859        }
860
861        for (lhs_pos, &lhs_char) in lhs_indices.iter().enumerate() {
862            if !output_indices
863                .iter()
864                .any(|&output_char| output_char == lhs_char)
865            {
866                // Contracted index
867                lhs_contracted_axes.push(lhs_pos);
868                // Must be in RHS if it's not in output
869                rhs_contracted_axes.push(
870                    rhs_indices
871                        .iter()
872                        .position(|&rhs_char| rhs_char == lhs_char)
873                        .unwrap(),
874                );
875                contracted_shape.push(sc.output_size[&lhs_char]);
876            }
877        }
878        // What order do we want the axes in?
879        //
880        // LHS: Stack axes, outer axes, contracted axes
881        // RHS: Stack axes, contracted axes, outer axes
882
883        lhs_order.append(&mut lhs_stack_axes.clone());
884        lhs_order.append(&mut lhs_outer_axes.clone());
885        lhs_output_shape.append(&mut lhs_uncontracted_shape);
886        lhs_order.append(&mut lhs_contracted_axes.clone());
887        lhs_output_shape.append(&mut contracted_shape.clone());
888
889        rhs_order.append(&mut rhs_stack_axes.clone());
890        rhs_order.append(&mut rhs_contracted_axes.clone());
891        rhs_output_shape.append(&mut contracted_shape);
892        rhs_order.append(&mut rhs_outer_axes.clone());
893        rhs_output_shape.append(&mut rhs_uncontracted_shape);
894
895        // What order will the intermediate output indices be in?
896        // Stack indices, lhs outer indices, rhs outer indices
897        intermediate_indices.append(&mut stack_indices.clone());
898        intermediate_indices.append(&mut lhs_outer_indices.clone());
899        intermediate_indices.append(&mut rhs_outer_indices.clone());
900
901        assert_eq!(lhs_output_shape[0], rhs_output_shape[0]);
902        intermediate_shape.push(lhs_output_shape[0]);
903        for lhs_char in lhs_outer_indices.iter() {
904            intermediate_shape.push(sc.output_size[lhs_char]);
905        }
906        for rhs_char in rhs_outer_indices.iter() {
907            intermediate_shape.push(sc.output_size[rhs_char]);
908        }
909
910        let output_order = find_outputs_in_inputs_unique(output_indices, &intermediate_indices);
911        let output_shape = intermediate_indices
912            .iter()
913            .map(|c| sc.output_size[c])
914            .collect();
915
916        let tensordot_fixed_position =
917            TensordotFixedPosition::from_shapes_and_number_of_contracted_axes(
918                &lhs_output_shape[1..],
919                &rhs_output_shape[1..],
920                lhs_contracted_axes.len(),
921            );
922        let lhs_permutation = Permutation::from_indices(&lhs_order);
923        let rhs_permutation = Permutation::from_indices(&rhs_order);
924        let output_permutation = Permutation::from_indices(&output_order);
925        StackedTensordotGeneral {
926            lhs_permutation,
927            rhs_permutation,
928            lhs_output_shape,
929            rhs_output_shape,
930            intermediate_shape,
931            tensordot_fixed_position,
932            output_shape,
933            output_permutation,
934        }
935    }
936}
937
938impl<A> PairContractor<A> for StackedTensordotGeneral {
939    fn contract_pair<'a, 'b, 'c, 'd>(
940        &self,
941        lhs: &'b ArrayViewD<'a, A>,
942        rhs: &'d ArrayViewD<'c, A>,
943    ) -> ArrayD<A>
944    where
945        'a: 'b,
946        'c: 'd,
947        A: Clone + LinalgScalar,
948    {
949        let lhs_permuted = self.lhs_permutation.view_singleton(lhs);
950        let lhs_reshaped = Array::from_shape_vec(
951            IxDyn(&self.lhs_output_shape),
952            lhs_permuted.iter().cloned().collect(),
953        )
954        .unwrap();
955        let rhs_permuted = self.rhs_permutation.view_singleton(rhs);
956        let rhs_reshaped = Array::from_shape_vec(
957            IxDyn(&self.rhs_output_shape),
958            rhs_permuted.iter().cloned().collect(),
959        )
960        .unwrap();
961        let mut intermediate_result: ArrayD<A> = Array::zeros(IxDyn(&self.intermediate_shape));
962        let mut lhs_iter = lhs_reshaped.outer_iter();
963        let mut rhs_iter = rhs_reshaped.outer_iter();
964        for mut output_subview in intermediate_result.outer_iter_mut() {
965            let lhs_subview = lhs_iter.next().unwrap();
966            let rhs_subview = rhs_iter.next().unwrap();
967            self.tensordot_fixed_position.contract_and_assign_pair(
968                &lhs_subview,
969                &rhs_subview,
970                &mut output_subview,
971            );
972        }
973        let intermediate_reshaped = intermediate_result
974            .into_shape_with_order(IxDyn(&self.output_shape))
975            .unwrap();
976        self.output_permutation
977            .contract_singleton(&intermediate_reshaped.view())
978    }
979}