ndarray_einsum/contractors/
pair_contractors.rs

// Copyright 2019 Jared Samet
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Contains the specific implementations of `PairContractor` that represent the base-case ways
//! to contract two simplified tensors.
//!
//! A tensor is simplified with respect to another tensor and a set of output indices
//! if two conditions are met:
//!
//! 1. Each index in the tensor is present in either the other tensor or in the output indices.
//! 2. Each index in the tensor appears only once.
//!
//! Examples of `einsum` strings with two simplified tensors:
//!
//! 1. `ijk,jkl->il`
//! 2. `ijk,jkl->ijkl`
//! 3. `ijk,ijk->`
//!
//! Examples of `einsum` strings with two tensors that are NOT simplified:
//!
//! 1. `iijk,jkl->il` Not simplified because `i` appears twice in the LHS tensor
//! 2. `ijk,jkl->i` Not simplified because `l` only appears in the RHS tensor
//!
//! If the two input tensors are both simplified, all instances of tensor contraction
//! can be expressed as one of a small number of cases. Note that there may be more than
//! one way to express the same contraction; some preliminary benchmarking has been
//! done to identify the faster choice.
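//!
//! As a loose illustration of the two conditions above (not part of this crate's API),
//! a hypothetical check for whether one operand is simplified might look like this:
//!
//! ```
//! // Sketch only: true if `tensor` is simplified w.r.t. `other` and `output`.
//! fn is_simplified(tensor: &[char], other: &[char], output: &[char]) -> bool {
//!     tensor.iter().enumerate().all(|(i, &c)| {
//!         let appears_once = !tensor[..i].contains(&c) && !tensor[i + 1..].contains(&c);
//!         let appears_elsewhere = other.contains(&c) || output.contains(&c);
//!         appears_once && appears_elsewhere
//!     })
//! }
//! assert!(is_simplified(&['i', 'j', 'k'], &['j', 'k', 'l'], &['i', 'l'])); // ijk,jkl->il
//! assert!(!is_simplified(&['i', 'i', 'j', 'k'], &['j', 'k', 'l'], &['i', 'l'])); // iijk,jkl->il
//! ```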

use hashbrown::HashSet;
use ndarray::prelude::*;
use ndarray::LinalgScalar;

use super::{PairContractor, Permutation, SingletonContractor, SingletonViewer};
use crate::SizedContraction;

/// Helper function used throughout this module to find the positions of one set of indices
/// in a second set of indices. The most common case is generating a permutation to
/// be supplied to `permuted_axes`.
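///
/// For example, looking up the output indices `['i', 'l']` in the input indices
/// `['i', 'j', 'k']` yields `[Some(0), None]`.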
fn maybe_find_outputs_in_inputs_unique(
    output_indices: &[char],
    input_indices: &[char],
) -> Vec<Option<usize>> {
    output_indices
        .iter()
        .map(|&output_char| {
            let input_pos = input_indices
                .iter()
                .position(|&input_char| input_char == output_char);
            if let Some(input) = input_pos {
                assert!(!input_indices
                    .iter()
                    .skip(input + 1)
                    .any(|&input_char| input_char == output_char));
            }
            input_pos
        })
        .collect()
}

fn find_outputs_in_inputs_unique(output_indices: &[char], input_indices: &[char]) -> Vec<usize> {
    maybe_find_outputs_in_inputs_unique(output_indices, input_indices)
        .iter()
        .map(|x| x.unwrap())
        .collect()
}

/// Performs tensor dot product for two tensors where no permutation needs to be performed,
/// e.g. `ijk,jkl->il` or `ijk,klm->ijlm`.
///
/// The axes to be contracted must be the last axes of the LHS tensor and the first axes
/// of the RHS tensor, and the axis order for the output tensor must be all the uncontracted
/// axes of the LHS tensor followed by all the uncontracted axes of the RHS tensor, in the
/// orders those originally appear in the LHS and RHS tensors.
///
/// The contraction is performed by reshaping the LHS into a matrix (2-D tensor) of shape
/// `[len_uncontracted_lhs, len_contracted_axes]`, reshaping the RHS into shape
/// `[len_contracted_axes, len_uncontracted_rhs]`, matrix-multiplying the two reshaped tensors,
/// and then reshaping the result into `output_shape`.
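///
/// As a rough sketch of the same idea using `ndarray` directly (with hypothetical sizes
/// `i=2`, `j=3`, `k=4`, `l=5`), the contraction `ijk,jkl->il` reduces to a single matrix
/// product:
///
/// ```
/// use ndarray::Array;
/// let lhs = Array::range(0., 24., 1.).into_shape_with_order((2, 3, 4)).unwrap(); // ijk
/// let rhs = Array::range(0., 60., 1.).into_shape_with_order((3, 4, 5)).unwrap(); // jkl
/// // Flatten the contracted axes (j, k) into a single axis of length 3 * 4 = 12.
/// let lhs_mat = lhs.into_shape_with_order((2, 12)).unwrap();
/// let rhs_mat = rhs.into_shape_with_order((12, 5)).unwrap();
/// // Matrix-multiply; the result already has the output shape [2, 5] (i, l).
/// assert_eq!(lhs_mat.dot(&rhs_mat).shape(), &[2, 5]);
/// ```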
#[derive(Clone, Debug)]
pub struct TensordotFixedPosition {
    /// The product of the lengths of all the uncontracted axes in the LHS (or 1 if all of the
    /// LHS axes are contracted)
    len_uncontracted_lhs: usize,

    /// The product of the lengths of all the uncontracted axes in the RHS (or 1 if all of the
    /// RHS axes are contracted)
    len_uncontracted_rhs: usize,

    /// The product of the lengths of all the contracted axes (or 1 if no axes are contracted,
    /// i.e. the outer product is computed)
    len_contracted_axes: usize,

    /// The shape that the tensor dot product will be recast to
    output_shape: Vec<usize>,
}

impl TensordotFixedPosition {
    pub fn new(sc: &SizedContraction) -> Self {
        assert_eq!(sc.contraction.operand_indices.len(), 2);
        let lhs_indices = &sc.contraction.operand_indices[0];
        let rhs_indices = &sc.contraction.operand_indices[1];
        let output_indices = &sc.contraction.output_indices;
        // The output has ndim(LHS) + ndim(RHS) - 2 * num_contracted_axes axes, so the
        // number of contracted axes can be recovered from the three index counts.
        let twice_num_contracted_axes =
            lhs_indices.len() + rhs_indices.len() - output_indices.len();
        assert_eq!(twice_num_contracted_axes % 2, 0);
        let num_contracted_axes = twice_num_contracted_axes / 2;
        // TODO: Add an assert! that they have the same indices

        let lhs_shape: Vec<usize> = lhs_indices.iter().map(|c| sc.output_size[c]).collect();
        let rhs_shape: Vec<usize> = rhs_indices.iter().map(|c| sc.output_size[c]).collect();

        TensordotFixedPosition::from_shapes_and_number_of_contracted_axes(
            &lhs_shape,
            &rhs_shape,
            num_contracted_axes,
        )
    }

    /// Computes the uncontracted and contracted axis lengths and the output shape based on the
    /// input shapes and how many axes should be contracted from each tensor.
    ///
    /// TODO: The `assert_eq!` here could be tightened up by verifying that the
    /// last `num_contracted_axes` axes of the LHS match the first `num_contracted_axes` axes
    /// of the RHS axis-by-axis (as opposed to only checking that the products match, as is done here).
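    ///
    /// For example, `lhs_shape = [2, 3, 4]`, `rhs_shape = [3, 4, 5]`, and `num_contracted_axes = 2`
    /// yield `len_uncontracted_lhs = 2`, `len_contracted_axes = 12`, `len_uncontracted_rhs = 5`,
    /// and `output_shape = [2, 5]`.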
    pub fn from_shapes_and_number_of_contracted_axes(
        lhs_shape: &[usize],
        rhs_shape: &[usize],
        num_contracted_axes: usize,
    ) -> Self {
        let mut len_uncontracted_lhs = 1;
        let mut len_uncontracted_rhs = 1;
        let mut len_contracted_lhs = 1;
        let mut len_contracted_rhs = 1;
        let mut output_shape = Vec::new();

        let num_axes_lhs = lhs_shape.len();
        for (axis, &axis_length) in lhs_shape.iter().enumerate() {
            if axis < (num_axes_lhs - num_contracted_axes) {
                len_uncontracted_lhs *= axis_length;
                output_shape.push(axis_length);
            } else {
                len_contracted_lhs *= axis_length;
            }
        }

        for (axis, &axis_length) in rhs_shape.iter().enumerate() {
            if axis < num_contracted_axes {
                len_contracted_rhs *= axis_length;
            } else {
                len_uncontracted_rhs *= axis_length;
                output_shape.push(axis_length);
            }
        }
        assert_eq!(len_contracted_rhs, len_contracted_lhs);
        let len_contracted_axes = len_contracted_lhs;

        TensordotFixedPosition {
            len_uncontracted_lhs,
            len_uncontracted_rhs,
            len_contracted_axes,
            output_shape,
        }
    }
}

impl<A> PairContractor<A> for TensordotFixedPosition {
    fn contract_pair<'a, 'b, 'c, 'd>(
        &self,
        lhs: &'b ArrayViewD<'a, A>,
        rhs: &'d ArrayViewD<'c, A>,
    ) -> ArrayD<A>
    where
        'a: 'b,
        'c: 'd,
        A: Clone + LinalgScalar,
    {
        let lhs_array;
        let lhs_view = if lhs.is_standard_layout() {
            lhs.view()
                .into_shape_with_order((self.len_uncontracted_lhs, self.len_contracted_axes))
                .unwrap()
        } else {
            lhs_array = Array::from_shape_vec(
                [self.len_uncontracted_lhs, self.len_contracted_axes],
                lhs.iter().cloned().collect(),
            )
            .unwrap();
            lhs_array.view()
        };

        let rhs_array;
        let rhs_view = if rhs.is_standard_layout() {
            rhs.view()
                .into_shape_with_order((self.len_contracted_axes, self.len_uncontracted_rhs))
                .unwrap()
        } else {
            rhs_array = Array::from_shape_vec(
                [self.len_contracted_axes, self.len_uncontracted_rhs],
                rhs.iter().cloned().collect(),
            )
            .unwrap();
            rhs_array.view()
        };

        lhs_view
            .dot(&rhs_view)
            .into_shape_with_order(IxDyn(&self.output_shape))
            .unwrap()
    }
}

// TODO: Micro-optimization possible: Have a version without the final permutation,
// which clones the array
/// Computes the tensor dot product of two tensors, with individual permutations of the
/// LHS and RHS performed as necessary, as well as a final permutation of the output.
///
/// Examples that qualify for `TensordotGeneral` but not `TensordotFixedPosition`:
///
/// 1. `jik,jkl->il`: the LHS tensor needs to be permuted `jik->ijk`
/// 2. `ijk,klm->imlj`: the output tensor needs to be permuted `ijlm->imlj`
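///
/// A sketch of case 1 using `ndarray` directly (hypothetical sizes `j=3`, `i=2`, `k=4`, `l=5`):
///
/// ```
/// use ndarray::Array;
/// let lhs = Array::range(0., 24., 1.).into_shape_with_order((3, 2, 4)).unwrap(); // jik
/// let rhs = Array::range(0., 60., 1.).into_shape_with_order((3, 4, 5)).unwrap(); // jkl
/// // Move the contracted axes (j, k) of the LHS to the back: jik -> ijk.
/// let lhs_perm = lhs.permuted_axes([1, 0, 2]);
/// // The permuted array is no longer contiguous, so copy it before reshaping to 2-D.
/// let lhs_mat = lhs_perm
///     .as_standard_layout()
///     .into_owned()
///     .into_shape_with_order((2, 12))
///     .unwrap();
/// let rhs_mat = rhs.into_shape_with_order((12, 5)).unwrap();
/// assert_eq!(lhs_mat.dot(&rhs_mat).shape(), &[2, 5]); // il
/// ```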
#[derive(Clone, Debug)]
pub struct TensordotGeneral {
    lhs_permutation: Permutation,
    rhs_permutation: Permutation,
    tensordot_fixed_position: TensordotFixedPosition,
    output_permutation: Permutation,
}

impl TensordotGeneral {
    pub fn new(sc: &SizedContraction) -> Self {
        assert_eq!(sc.contraction.operand_indices.len(), 2);
        let lhs_indices = &sc.contraction.operand_indices[0];
        let rhs_indices = &sc.contraction.operand_indices[1];
        let contracted_indices = &sc.contraction.summation_indices;
        let output_indices = &sc.contraction.output_indices;
        let lhs_shape: Vec<usize> = lhs_indices.iter().map(|c| sc.output_size[c]).collect();
        let rhs_shape: Vec<usize> = rhs_indices.iter().map(|c| sc.output_size[c]).collect();

        TensordotGeneral::from_shapes_and_indices(
            &lhs_shape,
            &rhs_shape,
            lhs_indices,
            rhs_indices,
            contracted_indices,
            output_indices,
        )
    }

    fn from_shapes_and_indices(
        lhs_shape: &[usize],
        rhs_shape: &[usize],
        lhs_indices: &[char],
        rhs_indices: &[char],
        contracted_indices: &[char],
        output_indices: &[char],
    ) -> Self {
        let lhs_contracted_axes = find_outputs_in_inputs_unique(contracted_indices, lhs_indices);
        let rhs_contracted_axes = find_outputs_in_inputs_unique(contracted_indices, rhs_indices);
        let mut uncontracted_chars: Vec<char> = lhs_indices
            .iter()
            .filter(|&&input_char| output_indices.contains(&input_char))
            .cloned()
            .collect();
        let mut rhs_uncontracted_chars: Vec<char> = rhs_indices
            .iter()
            .filter(|&&input_char| output_indices.contains(&input_char))
            .cloned()
            .collect();
        uncontracted_chars.append(&mut rhs_uncontracted_chars);
        let output_order = find_outputs_in_inputs_unique(output_indices, &uncontracted_chars);

        TensordotGeneral::from_shapes_and_axis_numbers(
            lhs_shape,
            rhs_shape,
            &lhs_contracted_axes,
            &rhs_contracted_axes,
            &output_order,
        )
    }

    /// Produces a `TensordotGeneral` from the shapes and list of axes to be contracted.
    ///
    /// Wrapped by the public `tensordot` function and used by `TensordotGeneral::new()`.
    /// `lhs_axes` lists the axes from the LHS tensor to contract and `rhs_axes` lists the
    /// axes from the RHS tensor to contract.
    pub fn from_shapes_and_axis_numbers(
        lhs_shape: &[usize],
        rhs_shape: &[usize],
        lhs_axes: &[usize],
        rhs_axes: &[usize],
        output_order: &[usize],
    ) -> Self {
        let num_contracted_axes = lhs_axes.len();
        assert!(num_contracted_axes == rhs_axes.len());
        let lhs_uniques: HashSet<_> = lhs_axes.iter().cloned().collect();
        let rhs_uniques: HashSet<_> = rhs_axes.iter().cloned().collect();
        assert!(num_contracted_axes == lhs_uniques.len());
        assert!(num_contracted_axes == rhs_uniques.len());
        let mut adjusted_lhs_shape = Vec::new();
        let mut adjusted_rhs_shape = Vec::new();

        // Rolls the contracted axes of the LHS to the back and the contracted axes of the RHS
        // to the front, then hands the adjusted shapes to
        // TensordotFixedPosition::from_shapes_and_number_of_contracted_axes.
        let mut permutation_lhs = Vec::new();
        for (i, &axis_length) in lhs_shape.iter().enumerate() {
            if !(lhs_uniques.contains(&i)) {
                permutation_lhs.push(i);
                adjusted_lhs_shape.push(axis_length);
            }
        }
        for &axis in lhs_axes.iter() {
            permutation_lhs.push(axis);
            adjusted_lhs_shape.push(lhs_shape[axis]);
        }

        // Note: These two for loops are (intentionally!) in the reverse order
        // from the corresponding loops for the LHS.
        let mut permutation_rhs = Vec::new();
        for &axis in rhs_axes.iter() {
            permutation_rhs.push(axis);
            adjusted_rhs_shape.push(rhs_shape[axis]);
        }
        for (i, &axis_length) in rhs_shape.iter().enumerate() {
            if !(rhs_uniques.contains(&i)) {
                permutation_rhs.push(i);
                adjusted_rhs_shape.push(axis_length);
            }
        }

        let lhs_permutation = Permutation::from_indices(&permutation_lhs);
        let rhs_permutation = Permutation::from_indices(&permutation_rhs);
        let tensordot_fixed_position =
            TensordotFixedPosition::from_shapes_and_number_of_contracted_axes(
                &adjusted_lhs_shape,
                &adjusted_rhs_shape,
                num_contracted_axes,
            );

        let output_permutation = Permutation::from_indices(output_order);

        TensordotGeneral {
            lhs_permutation,
            rhs_permutation,
            tensordot_fixed_position,
            output_permutation,
        }
    }
}

impl<A> PairContractor<A> for TensordotGeneral {
    fn contract_pair<'a, 'b, 'c, 'd>(
        &self,
        lhs: &'b ArrayViewD<'a, A>,
        rhs: &'d ArrayViewD<'c, A>,
    ) -> ArrayD<A>
    where
        'a: 'b,
        'c: 'd,
        A: Clone + LinalgScalar,
    {
        let permuted_lhs = self.lhs_permutation.view_singleton(lhs);
        let permuted_rhs = self.rhs_permutation.view_singleton(rhs);
        let tensordotted = self
            .tensordot_fixed_position
            .contract_pair(&permuted_lhs, &permuted_rhs);
        self.output_permutation
            .contract_singleton(&tensordotted.view())
    }
}

/// Computes the Hadamard (element-wise) product of two tensors.
///
/// All instances of `SizedContraction` making use of this contractor must have the form
/// `ij,ij->ij`.
///
/// Contractions of the form `ij,ji->ij` need to use `HadamardProductGeneral` instead.
#[derive(Clone, Debug)]
pub struct HadamardProduct {}

impl HadamardProduct {
    pub fn new(sc: &SizedContraction) -> Self {
        assert_eq!(sc.contraction.operand_indices.len(), 2);
        let lhs_indices = &sc.contraction.operand_indices[0];
        let rhs_indices = &sc.contraction.operand_indices[1];
        let output_indices = &sc.contraction.output_indices;
        assert_eq!(lhs_indices, rhs_indices);
        assert_eq!(lhs_indices, output_indices);

        HadamardProduct {}
    }

    fn from_nothing() -> Self {
        HadamardProduct {}
    }
}

impl<A> PairContractor<A> for HadamardProduct {
    fn contract_pair<'a, 'b, 'c, 'd>(
        &self,
        lhs: &'b ArrayViewD<'a, A>,
        rhs: &'d ArrayViewD<'c, A>,
    ) -> ArrayD<A>
    where
        'a: 'b,
        'c: 'd,
        A: Clone + LinalgScalar,
    {
        lhs * rhs
    }
}

/// Permutes the axes of the LHS and RHS tensors to the order in which those axes appear in the
/// output before computing the Hadamard (element-wise) product.
///
/// Examples of contractions that fit this category:
///
/// 1. `ij,ij->ij` (can also use `HadamardProduct`)
/// 2. `ij,ji->ij` (can only use `HadamardProductGeneral`)
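///
/// A sketch of case 2 using `ndarray` directly (hypothetical 2x3 operands):
///
/// ```
/// use ndarray::array;
/// let a = array![[1., 2., 3.], [4., 5., 6.]];          // indexed ij
/// let b = array![[10., 40.], [20., 50.], [30., 60.]];  // indexed ji
/// // Permuting the RHS to the output order ij turns this into an element-wise product.
/// let product = &a * &b.t();
/// assert_eq!(product, array![[10., 40., 90.], [160., 250., 360.]]);
/// ```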
#[derive(Clone, Debug)]
pub struct HadamardProductGeneral {
    lhs_permutation: Permutation,
    rhs_permutation: Permutation,
    hadamard_product: HadamardProduct,
}

impl HadamardProductGeneral {
    pub fn new(sc: &SizedContraction) -> Self {
        assert_eq!(sc.contraction.operand_indices.len(), 2);
        let lhs_indices = &sc.contraction.operand_indices[0];
        let rhs_indices = &sc.contraction.operand_indices[1];
        let output_indices = &sc.contraction.output_indices;
        assert_eq!(lhs_indices.len(), rhs_indices.len());
        assert_eq!(lhs_indices.len(), output_indices.len());

        let lhs_permutation =
            Permutation::from_indices(&find_outputs_in_inputs_unique(output_indices, lhs_indices));
        let rhs_permutation =
            Permutation::from_indices(&find_outputs_in_inputs_unique(output_indices, rhs_indices));
        let hadamard_product = HadamardProduct::from_nothing();

        HadamardProductGeneral {
            lhs_permutation,
            rhs_permutation,
            hadamard_product,
        }
    }
}

impl<A> PairContractor<A> for HadamardProductGeneral {
    fn contract_pair<'a, 'b, 'c, 'd>(
        &self,
        lhs: &'b ArrayViewD<'a, A>,
        rhs: &'d ArrayViewD<'c, A>,
    ) -> ArrayD<A>
    where
        'a: 'b,
        'c: 'd,
        A: Clone + LinalgScalar,
    {
        self.hadamard_product.contract_pair(
            &self.lhs_permutation.view_singleton(lhs),
            &self.rhs_permutation.view_singleton(rhs),
        )
    }
}

/// Multiplies every element of the RHS tensor by the single scalar in the 0-d LHS tensor.
///
/// This contraction can arise when the simplification of the LHS tensor results in all the
/// axes being summed before the two tensors are contracted. For example, in the contraction
/// `i,jk->jk`, every element of the RHS tensor is simply multiplied by the sum of the elements
/// of the LHS tensor.
#[derive(Clone, Debug)]
pub struct ScalarMatrixProduct {}

impl ScalarMatrixProduct {
    pub fn new(sc: &SizedContraction) -> Self {
        assert_eq!(sc.contraction.operand_indices.len(), 2);
        let lhs_indices = &sc.contraction.operand_indices[0];
        let rhs_indices = &sc.contraction.operand_indices[1];
        let output_indices = &sc.contraction.output_indices;
        assert_eq!(lhs_indices.len(), 0);
        assert_eq!(output_indices, rhs_indices);

        ScalarMatrixProduct {}
    }

    pub fn from_nothing() -> Self {
        ScalarMatrixProduct {}
    }
}

impl<A> PairContractor<A> for ScalarMatrixProduct {
    fn contract_pair<'a, 'b, 'c, 'd>(
        &self,
        lhs: &'b ArrayViewD<'a, A>,
        rhs: &'d ArrayViewD<'c, A>,
    ) -> ArrayD<A>
    where
        'a: 'b,
        'c: 'd,
        A: Clone + LinalgScalar,
    {
        let lhs_0d: A = *lhs.first().unwrap();
        rhs.mapv(|x| x * lhs_0d)
    }
}

/// Permutes the axes of the RHS tensor to the output order and multiplies all elements by the
/// single scalar in the 0-d LHS tensor.
///
/// This contraction can arise when the simplification of the LHS tensor results in all the
/// axes being summed before the two tensors are contracted. For example, in the contraction
/// `i,jk->kj`, the output matrix is equal to the RHS matrix, transposed and then scalar-multiplied
/// by the sum of the elements of the LHS tensor.
#[derive(Clone, Debug)]
pub struct ScalarMatrixProductGeneral {
    rhs_permutation: Permutation,
    scalar_matrix_product: ScalarMatrixProduct,
}

impl ScalarMatrixProductGeneral {
    pub fn new(sc: &SizedContraction) -> Self {
        assert_eq!(sc.contraction.operand_indices.len(), 2);
        let lhs_indices = &sc.contraction.operand_indices[0];
        let rhs_indices = &sc.contraction.operand_indices[1];
        let output_indices = &sc.contraction.output_indices;
        assert_eq!(lhs_indices.len(), 0);
        assert_eq!(rhs_indices.len(), output_indices.len());

        ScalarMatrixProductGeneral::from_indices(rhs_indices, output_indices)
    }

    pub fn from_indices(input_indices: &[char], output_indices: &[char]) -> Self {
        let rhs_permutation = Permutation::from_indices(&find_outputs_in_inputs_unique(
            output_indices,
            input_indices,
        ));
        let scalar_matrix_product = ScalarMatrixProduct::from_nothing();

        ScalarMatrixProductGeneral {
            rhs_permutation,
            scalar_matrix_product,
        }
    }
}

impl<A> PairContractor<A> for ScalarMatrixProductGeneral {
    fn contract_pair<'a, 'b, 'c, 'd>(
        &self,
        lhs: &'b ArrayViewD<'a, A>,
        rhs: &'d ArrayViewD<'c, A>,
    ) -> ArrayD<A>
    where
        'a: 'b,
        'c: 'd,
        A: Clone + LinalgScalar,
    {
        self.scalar_matrix_product
            .contract_pair(lhs, &self.rhs_permutation.view_singleton(rhs))
    }
}

/// Multiplies every element of the LHS tensor by the single scalar in the 0-d RHS tensor.
///
/// This contraction can arise when the simplification of the RHS tensor results in all the
/// axes being summed before the two tensors are contracted. For example, in the contraction
/// `ij,k->ij`, every element of the LHS tensor is simply multiplied by the sum of the elements
/// of the RHS tensor.
#[derive(Clone, Debug)]
pub struct MatrixScalarProduct {}

impl MatrixScalarProduct {
    pub fn new(sc: &SizedContraction) -> Self {
        assert_eq!(sc.contraction.operand_indices.len(), 2);
        let lhs_indices = &sc.contraction.operand_indices[0];
        let rhs_indices = &sc.contraction.operand_indices[1];
        let output_indices = &sc.contraction.output_indices;
        assert_eq!(rhs_indices.len(), 0);
        assert_eq!(output_indices, lhs_indices);

        MatrixScalarProduct {}
    }

    pub fn from_nothing() -> Self {
        MatrixScalarProduct {}
    }
}

impl<A> PairContractor<A> for MatrixScalarProduct {
    fn contract_pair<'a, 'b, 'c, 'd>(
        &self,
        lhs: &'b ArrayViewD<'a, A>,
        rhs: &'d ArrayViewD<'c, A>,
    ) -> ArrayD<A>
    where
        'a: 'b,
        'c: 'd,
        A: Clone + LinalgScalar,
    {
        let rhs_0d: A = *rhs.first().unwrap();
        lhs.mapv(|x| x * rhs_0d)
    }
}

/// Permutes the axes of the LHS tensor to the output order and multiplies all elements by the
/// single scalar in the 0-d RHS tensor.
///
/// This contraction can arise when the simplification of the RHS tensor results in all the
/// axes being summed before the two tensors are contracted. For example, in the contraction
/// `ij,k->ji`, the output matrix is equal to the LHS matrix, transposed and then scalar-multiplied
/// by the sum of the elements of the RHS tensor.
#[derive(Clone, Debug)]
pub struct MatrixScalarProductGeneral {
    lhs_permutation: Permutation,
    matrix_scalar_product: MatrixScalarProduct,
}

impl MatrixScalarProductGeneral {
    pub fn new(sc: &SizedContraction) -> Self {
        assert_eq!(sc.contraction.operand_indices.len(), 2);
        let lhs_indices = &sc.contraction.operand_indices[0];
        let rhs_indices = &sc.contraction.operand_indices[1];
        let output_indices = &sc.contraction.output_indices;
        assert_eq!(rhs_indices.len(), 0);
        assert_eq!(lhs_indices.len(), output_indices.len());

        MatrixScalarProductGeneral::from_indices(lhs_indices, output_indices)
    }

    pub fn from_indices(input_indices: &[char], output_indices: &[char]) -> Self {
        let lhs_permutation = Permutation::from_indices(&find_outputs_in_inputs_unique(
            output_indices,
            input_indices,
        ));
        let matrix_scalar_product = MatrixScalarProduct::from_nothing();

        MatrixScalarProductGeneral {
            lhs_permutation,
            matrix_scalar_product,
        }
    }
}

impl<A> PairContractor<A> for MatrixScalarProductGeneral {
    fn contract_pair<'a, 'b, 'c, 'd>(
        &self,
        lhs: &'b ArrayViewD<'a, A>,
        rhs: &'d ArrayViewD<'c, A>,
    ) -> ArrayD<A>
    where
        'a: 'b,
        'c: 'd,
        A: Clone + LinalgScalar,
    {
        self.matrix_scalar_product
            .contract_pair(&self.lhs_permutation.view_singleton(lhs), rhs)
    }
}

/// Permutes the axes of the LHS and RHS tensors, broadcasts both into the output shape,
/// and then computes the element-wise product of the two broadcast tensors.
///
/// Currently unused due to unfavorable results in the (limited) benchmarking done so far
/// against `StackedTensordotGeneral`. An example of a contraction that could theoretically
/// be performed by this contractor is `ij,jk->ijk`: the LHS and RHS are both
/// broadcast into the output shape (|i|, |j|, |k|) and then multiplied elementwise.
///
/// However, that benchmarking favored iterating along the `j` axis and computing the
/// outer products `i,k->ik` for each subview of the tensors.
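///
/// A sketch of the broadcasting approach using `ndarray` directly (hypothetical 2x2 and
/// 2x3 operands for `ij,jk->ijk`):
///
/// ```
/// use ndarray::{array, Axis};
/// let a = array![[1., 2.], [3., 4.]];           // ij, shape (2, 2)
/// let b = array![[5., 6., 7.], [8., 9., 10.]];  // jk, shape (2, 3)
/// // Insert the missing axes, broadcast both operands to (i, j, k) = (2, 2, 3), and multiply.
/// let a3 = a.insert_axis(Axis(2));
/// let b3 = b.insert_axis(Axis(0));
/// let out = &a3.broadcast((2, 2, 3)).unwrap() * &b3.broadcast((2, 2, 3)).unwrap();
/// assert_eq!(out[[1, 0, 2]], 3. * 7.); // a[[1, 0]] * b[[0, 2]]
/// ```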
#[derive(Clone, Debug)]
pub struct BroadcastProductGeneral {
    lhs_permutation: Permutation,
    rhs_permutation: Permutation,
    lhs_insertions: Vec<usize>,
    rhs_insertions: Vec<usize>,
    output_sizes: Vec<usize>,
    hadamard_product: HadamardProduct,
}

impl BroadcastProductGeneral {
    pub fn new(sc: &SizedContraction) -> Self {
        assert_eq!(sc.contraction.operand_indices.len(), 2);
        let lhs_indices = &sc.contraction.operand_indices[0];
        let rhs_indices = &sc.contraction.operand_indices[1];
        let output_indices = &sc.contraction.output_indices;

        let maybe_lhs_indices = maybe_find_outputs_in_inputs_unique(output_indices, lhs_indices);
        let maybe_rhs_indices = maybe_find_outputs_in_inputs_unique(output_indices, rhs_indices);
        let lhs_indices: Vec<usize> = maybe_lhs_indices.iter().copied().flatten().collect();
        let rhs_indices: Vec<usize> = maybe_rhs_indices.iter().copied().flatten().collect();
        let lhs_insertions: Vec<usize> = maybe_lhs_indices
            .into_iter()
            .enumerate()
            .filter(|(_, x)| x.is_none())
            .map(|(i, _)| i)
            .collect();
        let rhs_insertions: Vec<usize> = maybe_rhs_indices
            .into_iter()
            .enumerate()
            .filter(|(_, x)| x.is_none())
            .map(|(i, _)| i)
            .collect();
        let lhs_permutation = Permutation::from_indices(&lhs_indices);
        let rhs_permutation = Permutation::from_indices(&rhs_indices);
        let output_sizes: Vec<usize> = output_indices.iter().map(|c| sc.output_size[c]).collect();
        let hadamard_product = HadamardProduct::from_nothing();

        BroadcastProductGeneral {
            lhs_permutation,
            rhs_permutation,
            lhs_insertions,
            rhs_insertions,
            output_sizes,
            hadamard_product,
        }
    }
}

impl<A> PairContractor<A> for BroadcastProductGeneral {
    fn contract_pair<'a, 'b, 'c, 'd>(
        &self,
        lhs: &'b ArrayViewD<'a, A>,
        rhs: &'d ArrayViewD<'c, A>,
    ) -> ArrayD<A>
    where
        'a: 'b,
        'c: 'd,
        A: Clone + LinalgScalar,
    {
        let mut adjusted_lhs = self.lhs_permutation.view_singleton(lhs);
        let mut adjusted_rhs = self.rhs_permutation.view_singleton(rhs);
        for &i in self.lhs_insertions.iter() {
            adjusted_lhs = adjusted_lhs.insert_axis(Axis(i));
        }
        for &i in self.rhs_insertions.iter() {
            adjusted_rhs = adjusted_rhs.insert_axis(Axis(i));
        }
        let output_shape = IxDyn(&self.output_sizes);
        let broadcast_lhs = adjusted_lhs.broadcast(output_shape.clone()).unwrap();
        let broadcast_rhs = adjusted_rhs.broadcast(output_shape).unwrap();
        self.hadamard_product
            .contract_pair(&broadcast_lhs, &broadcast_rhs)
    }
}

// TODO: Micro-optimization: Have a version without the output permutation,
// which clones the array
//
// TODO: Fix whatever bug prevents this from being used in all cases
//
// TODO: convert this to directly reshape into a 3-D matrix instead of delegating
// that to TensordotGeneral

/// Repeatedly computes the tensor dot product of subviews of the two tensors, iterating over
/// indices which appear in the LHS, RHS, and output.
///
/// The indices appearing in all three places are referred to here as the "stack" indices.
/// For example, in the contraction `ijk,ikl->ijl`, `i` would be the (only) "stack" index.
/// This contraction is an instance of batch matrix multiplication: the LHS and RHS are both
/// 3-D tensors, and the `i`th (2-D) subview of the output is the matrix product of the `i`th
/// subview of the LHS and the `i`th subview of the RHS.
///
/// This is the most general contraction and in theory could handle all pairwise contractions,
/// but it is less performant than special-casing when there are no "stack" indices. It is also
/// currently the only case that requires `.outer_iter_mut()` (which might make parallelizing
/// operations more difficult).
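///
/// A sketch of the `ijk,ikl->ijl` case using `ndarray` directly (hypothetical sizes
/// `i=2`, `j=3`, `k=4`, `l=5`):
///
/// ```
/// use ndarray::{Array, Array3};
/// let lhs = Array::range(0., 24., 1.).into_shape_with_order((2, 3, 4)).unwrap(); // ijk
/// let rhs = Array::range(0., 40., 1.).into_shape_with_order((2, 4, 5)).unwrap(); // ikl
/// let mut out = Array3::<f64>::zeros((2, 3, 5));                                 // ijl
/// // For each value of the stack index i, matrix-multiply the corresponding 2-D subviews.
/// for ((l, r), mut o) in lhs.outer_iter().zip(rhs.outer_iter()).zip(out.outer_iter_mut()) {
///     o.assign(&l.dot(&r));
/// }
/// assert_eq!(out.shape(), &[2, 3, 5]);
/// ```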
#[derive(Clone, Debug)]
pub struct StackedTensordotGeneral {
    lhs_permutation: Permutation,
    rhs_permutation: Permutation,
    lhs_output_shape: Vec<usize>,
    rhs_output_shape: Vec<usize>,
    intermediate_shape: Vec<usize>,
    tensordot_fixed_position: TensordotFixedPosition,
    output_shape: Vec<usize>,
    output_permutation: Permutation,
}

impl StackedTensordotGeneral {
    pub fn new(sc: &SizedContraction) -> Self {
        let mut lhs_order = Vec::new();
        let mut rhs_order = Vec::new();
        let mut lhs_output_shape = Vec::new();
        let mut rhs_output_shape = Vec::new();
        let mut intermediate_shape = Vec::new();

        assert_eq!(sc.contraction.operand_indices.len(), 2);
        let lhs_indices = &sc.contraction.operand_indices[0];
        let rhs_indices = &sc.contraction.operand_indices[1];
        let output_indices = &sc.contraction.output_indices;

        let maybe_lhs_axes = maybe_find_outputs_in_inputs_unique(output_indices, lhs_indices);
        let maybe_rhs_axes = maybe_find_outputs_in_inputs_unique(output_indices, rhs_indices);
        let mut lhs_stack_axes = Vec::new();
        let mut rhs_stack_axes = Vec::new();
        let mut stack_indices = Vec::new();
        let mut lhs_outer_axes = Vec::new();
        let mut lhs_outer_indices = Vec::new();
        let mut rhs_outer_axes = Vec::new();
        let mut rhs_outer_indices = Vec::new();
        let mut lhs_contracted_axes = Vec::new();
        let mut rhs_contracted_axes = Vec::new();
        let mut intermediate_indices = Vec::new();
        let mut lhs_uncontracted_shape = Vec::new();
        let mut rhs_uncontracted_shape = Vec::new();
        let mut contracted_shape = Vec::new();

        lhs_output_shape.push(1);
        rhs_output_shape.push(1);

        for ((&maybe_lhs_pos, &maybe_rhs_pos), &output_char) in maybe_lhs_axes
            .iter()
            .zip(maybe_rhs_axes.iter())
            .zip(output_indices.iter())
        {
            match (maybe_lhs_pos, maybe_rhs_pos) {
                (Some(lhs_pos), Some(rhs_pos)) => {
                    lhs_stack_axes.push(lhs_pos);
                    rhs_stack_axes.push(rhs_pos);
                    stack_indices.push(output_char);
                    lhs_output_shape[0] *= sc.output_size[&output_char];
                    rhs_output_shape[0] *= sc.output_size[&output_char];
                }
                (Some(lhs_pos), None) => {
                    lhs_outer_axes.push(lhs_pos);
                    lhs_outer_indices.push(output_char);
                    lhs_uncontracted_shape.push(sc.output_size[&output_char]);
                }
                (None, Some(rhs_pos)) => {
                    rhs_outer_axes.push(rhs_pos);
                    rhs_outer_indices.push(output_char);
                    rhs_uncontracted_shape.push(sc.output_size[&output_char]);
                }
                (None, None) => {
                    panic!() // Output char must be either in lhs or rhs
                }
            }
        }

        for (lhs_pos, &lhs_char) in lhs_indices.iter().enumerate() {
            if !output_indices.contains(&lhs_char) {
                // Contracted index
                lhs_contracted_axes.push(lhs_pos);
                // Must be in RHS if it's not in output
                rhs_contracted_axes.push(
                    rhs_indices
                        .iter()
                        .position(|&rhs_char| rhs_char == lhs_char)
                        .unwrap(),
                );
                contracted_shape.push(sc.output_size[&lhs_char]);
            }
        }
        // What order do we want the axes in?
        //
        // LHS: Stack axes, outer axes, contracted axes
        // RHS: Stack axes, contracted axes, outer axes

        lhs_order.append(&mut lhs_stack_axes.clone());
        lhs_order.append(&mut lhs_outer_axes.clone());
        lhs_output_shape.append(&mut lhs_uncontracted_shape);
        lhs_order.append(&mut lhs_contracted_axes.clone());
        lhs_output_shape.append(&mut contracted_shape.clone());

        rhs_order.append(&mut rhs_stack_axes.clone());
        rhs_order.append(&mut rhs_contracted_axes.clone());
        rhs_output_shape.append(&mut contracted_shape);
        rhs_order.append(&mut rhs_outer_axes.clone());
        rhs_output_shape.append(&mut rhs_uncontracted_shape);

        // What order will the intermediate output indices be in?
        // Stack indices, lhs outer indices, rhs outer indices
        intermediate_indices.append(&mut stack_indices.clone());
        intermediate_indices.append(&mut lhs_outer_indices.clone());
        intermediate_indices.append(&mut rhs_outer_indices.clone());

        assert_eq!(lhs_output_shape[0], rhs_output_shape[0]);
        intermediate_shape.push(lhs_output_shape[0]);
        for lhs_char in lhs_outer_indices.iter() {
            intermediate_shape.push(sc.output_size[lhs_char]);
        }
        for rhs_char in rhs_outer_indices.iter() {
            intermediate_shape.push(sc.output_size[rhs_char]);
        }

        let output_order = find_outputs_in_inputs_unique(output_indices, &intermediate_indices);
        let output_shape = intermediate_indices
            .iter()
            .map(|c| sc.output_size[c])
            .collect();

        let tensordot_fixed_position =
            TensordotFixedPosition::from_shapes_and_number_of_contracted_axes(
                &lhs_output_shape[1..],
                &rhs_output_shape[1..],
                lhs_contracted_axes.len(),
            );
        let lhs_permutation = Permutation::from_indices(&lhs_order);
        let rhs_permutation = Permutation::from_indices(&rhs_order);
        let output_permutation = Permutation::from_indices(&output_order);
        StackedTensordotGeneral {
            lhs_permutation,
            rhs_permutation,
            lhs_output_shape,
            rhs_output_shape,
            intermediate_shape,
            tensordot_fixed_position,
            output_shape,
            output_permutation,
        }
    }
}

impl<A> PairContractor<A> for StackedTensordotGeneral {
    fn contract_pair<'a, 'b, 'c, 'd>(
        &self,
        lhs: &'b ArrayViewD<'a, A>,
        rhs: &'d ArrayViewD<'c, A>,
    ) -> ArrayD<A>
    where
        'a: 'b,
        'c: 'd,
        A: Clone + LinalgScalar,
    {
        let lhs_permuted = self.lhs_permutation.view_singleton(lhs);
        let lhs_reshaped = Array::from_shape_vec(
            IxDyn(&self.lhs_output_shape),
            lhs_permuted.iter().cloned().collect(),
        )
        .unwrap();
        let rhs_permuted = self.rhs_permutation.view_singleton(rhs);
        let rhs_reshaped = Array::from_shape_vec(
            IxDyn(&self.rhs_output_shape),
            rhs_permuted.iter().cloned().collect(),
        )
        .unwrap();
        let mut intermediate_result: ArrayD<A> = Array::zeros(IxDyn(&self.intermediate_shape));
        let mut lhs_iter = lhs_reshaped.outer_iter();
        let mut rhs_iter = rhs_reshaped.outer_iter();
        for mut output_subview in intermediate_result.outer_iter_mut() {
            let lhs_subview = lhs_iter.next().unwrap();
            let rhs_subview = rhs_iter.next().unwrap();
            self.tensordot_fixed_position.contract_and_assign_pair(
                &lhs_subview,
                &rhs_subview,
                &mut output_subview,
            );
        }
        let intermediate_reshaped = intermediate_result
            .into_shape_with_order(IxDyn(&self.output_shape))
            .unwrap();
        self.output_permutation
            .contract_singleton(&intermediate_reshaped.view())
    }
}