ndarray_einsum/contractors/strategies.rs

// Copyright 2019 Jared Samet
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! This module contains the "strategy choice" logic that decides which specific contractor
//! should be used for a given mini-contraction.
//!
//! In general, `DiagonalizationAndSummation` should be able to accommodate all singleton
//! contractions and `StackedTensordotGeneral` should be able to handle all pairs; however,
//! other trait implementations might be faster.
//!
//! The code here has some duplication and is probably not the most idiomatic way to accomplish this.

use crate::SizedContraction;
use hashbrown::{HashMap, HashSet};

/// The available strategies for a single-operand ("singleton") contraction.
#[derive(Copy, Clone, Debug)]
pub enum SingletonMethod {
    Identity,
    Permutation,
    Summation,
    Diagonalization,
    PermutationAndSummation,
    DiagonalizationAndSummation,
}

/// Counts of the axis operations a singleton contraction requires, used to pick a
/// `SingletonMethod`.
#[derive(Copy, Clone, Debug)]
pub struct SingletonSummary {
    num_summed_axes: usize,
    num_diagonalized_axes: usize,
    num_reordered_axes: usize,
}

impl SingletonSummary {
    pub fn new(sc: &SizedContraction) -> Self {
        assert_eq!(sc.contraction.operand_indices.len(), 1);
        let output_indices = &sc.contraction.output_indices;
        let input_indices = &sc.contraction.operand_indices[0];

        SingletonSummary::from_indices(input_indices, output_indices)
    }

    fn from_indices(input_indices: &[char], output_indices: &[char]) -> Self {
        let mut input_counts = HashMap::new();
        for &c in input_indices.iter() {
            *input_counts.entry(c).or_insert(0) += 1;
        }
        // Unique input indices that don't appear in the output get summed over.
        let num_summed_axes = input_counts.len() - output_indices.len();
        // Indices repeated within the input require taking a diagonal.
        let num_diagonalized_axes = input_counts.iter().filter(|&(_, &v)| v > 1).count();
        // Positions where the output index differs from the input index require a permutation.
        let num_reordered_axes = output_indices
            .iter()
            .zip(input_indices.iter())
            .filter(|&(&output_char, &input_char)| output_char != input_char)
            .count();

        SingletonSummary {
            num_summed_axes,
            num_diagonalized_axes,
            num_reordered_axes,
        }
    }

    pub fn get_strategy(&self) -> SingletonMethod {
        match (
            self.num_summed_axes,
            self.num_diagonalized_axes,
            self.num_reordered_axes,
        ) {
            (0, 0, 0) => SingletonMethod::Identity,
            (0, 0, _) => SingletonMethod::Permutation,
            (_, 0, 0) => SingletonMethod::Summation,
            (0, _, _) => SingletonMethod::Diagonalization,
            (_, 0, _) => SingletonMethod::PermutationAndSummation,
            (_, _, _) => SingletonMethod::DiagonalizationAndSummation,
        }
    }
}

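// Illustrative sketch (not part of the original file): how the singleton strategy
// selection behaves for a few index patterns, exercised through the private
// `from_indices` constructor defined above. The index strings and the module/test
// names are hypothetical examples chosen for illustration.
#[cfg(test)]
mod singleton_strategy_sketch {
    use super::*;

    #[test]
    fn picks_expected_singleton_methods() {
        // "ij->ij": nothing to do.
        let s = SingletonSummary::from_indices(&['i', 'j'], &['i', 'j']);
        assert!(matches!(s.get_strategy(), SingletonMethod::Identity));

        // "ij->ji": the axes only need reordering.
        let s = SingletonSummary::from_indices(&['i', 'j'], &['j', 'i']);
        assert!(matches!(s.get_strategy(), SingletonMethod::Permutation));

        // "iij->j": a repeated input index (diagonal) plus a summed index.
        let s = SingletonSummary::from_indices(&['i', 'i', 'j'], &['j']);
        assert!(matches!(
            s.get_strategy(),
            SingletonMethod::DiagonalizationAndSummation
        ));
    }
}
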
/// The available strategies for contracting a pair of operands.
#[allow(dead_code)]
#[derive(Debug, Copy, Clone)]
pub enum PairMethod {
    HadamardProduct,
    HadamardProductGeneral,
    TensordotFixedPosition,
    TensordotGeneral,
    ScalarMatrixProduct,
    ScalarMatrixProductGeneral,
    MatrixScalarProduct,
    MatrixScalarProductGeneral,
    BroadcastProductGeneral,
    StackedTensordotGeneral,
}

/// Counts of the axis roles in a pair contraction, used to pick a `PairMethod`.
#[derive(Debug, Clone)]
pub struct PairSummary {
    num_stacked_axes: usize,
    num_lhs_outer_axes: usize,
    num_rhs_outer_axes: usize,
    num_contracted_axes: usize,
}

impl PairSummary {
    pub fn new(sc: &SizedContraction) -> Self {
        assert_eq!(sc.contraction.operand_indices.len(), 2);
        let output_indices = &sc.contraction.output_indices;
        let lhs_indices = &sc.contraction.operand_indices[0];
        let rhs_indices = &sc.contraction.operand_indices[1];

        PairSummary::from_indices(lhs_indices, rhs_indices, output_indices)
    }

    fn from_indices(lhs_indices: &[char], rhs_indices: &[char], output_indices: &[char]) -> Self {
        let lhs_uniques: HashSet<char> = lhs_indices.iter().cloned().collect();
        let rhs_uniques: HashSet<char> = rhs_indices.iter().cloned().collect();
        let output_uniques: HashSet<char> = output_indices.iter().cloned().collect();
        assert_eq!(lhs_indices.len(), lhs_uniques.len());
        assert_eq!(rhs_indices.len(), rhs_uniques.len());
        assert_eq!(output_indices.len(), output_uniques.len());

        // Indices shared by both operands are either stacked (they also appear in the
        // output) or contracted (they don't).
        let lhs_and_rhs: HashSet<char> = lhs_uniques.intersection(&rhs_uniques).cloned().collect();
        let stacked: HashSet<char> = lhs_and_rhs.intersection(&output_uniques).cloned().collect();

        let num_stacked_axes = stacked.len();
        let num_contracted_axes = lhs_and_rhs.len() - num_stacked_axes;
        // The remaining indices are "outer" axes that appear in only one operand.
        let num_lhs_outer_axes = lhs_uniques.len() - num_stacked_axes - num_contracted_axes;
        let num_rhs_outer_axes = rhs_uniques.len() - num_stacked_axes - num_contracted_axes;

        PairSummary {
            num_stacked_axes,
            num_lhs_outer_axes,
            num_rhs_outer_axes,
            num_contracted_axes,
        }
    }

    pub fn get_strategy(&self) -> PairMethod {
        match (
            self.num_contracted_axes,
            self.num_lhs_outer_axes,
            self.num_rhs_outer_axes,
            self.num_stacked_axes,
        ) {
            (0, 0, 0, _) => PairMethod::HadamardProductGeneral,
            (0, 0, _, 0) => PairMethod::ScalarMatrixProductGeneral,
            (0, _, 0, 0) => PairMethod::MatrixScalarProductGeneral,
            // This contractor works, but appears to be slower
            // than StackedTensordotGeneral
            // (0, _, _, _) => PairMethod::BroadcastProductGeneral,
            (_, _, _, 0) => PairMethod::TensordotGeneral,
            (_, _, _, _) => PairMethod::StackedTensordotGeneral,
        }
    }
}
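
// Illustrative sketch (not part of the original file): how the pair strategy selection
// classifies a few common contractions, exercised through the private `from_indices`
// constructor defined above. The index strings and the module/test names are
// hypothetical examples chosen for illustration.
#[cfg(test)]
mod pair_strategy_sketch {
    use super::*;

    #[test]
    fn picks_expected_pair_methods() {
        // "ij,jk->ik": plain matrix multiplication; one contracted axis, no stacked axes.
        let p = PairSummary::from_indices(&['i', 'j'], &['j', 'k'], &['i', 'k']);
        assert!(matches!(p.get_strategy(), PairMethod::TensordotGeneral));

        // "bij,bjk->bik": batched matrix multiplication; the `b` axis is stacked.
        let p = PairSummary::from_indices(&['b', 'i', 'j'], &['b', 'j', 'k'], &['b', 'i', 'k']);
        assert!(matches!(p.get_strategy(), PairMethod::StackedTensordotGeneral));

        // "ij,ij->ij": an elementwise (Hadamard) product; every axis is stacked.
        let p = PairSummary::from_indices(&['i', 'j'], &['i', 'j'], &['i', 'j']);
        assert!(matches!(p.get_strategy(), PairMethod::HadamardProductGeneral));
    }
}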