ndarray_einsum/contractors/
strategies.rs1use crate::SizedContraction;
25use hashbrown::{HashMap, HashSet};
26
/// Strategy for evaluating a contraction that has a single operand.
///
/// Chosen by `SingletonSummary::get_strategy` from three counts:
/// - summed axes;
/// - diagonalized (repeated) input indices;
/// - reordered output positions.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum SingletonMethod {
    /// Nothing is summed, diagonalized or reordered.
    Identity,
    /// Only reordering is required; nothing is summed or diagonalized.
    Permutation,
    /// Only summation is required; the output indices match the leading input
    /// indices position by position.
    Summation,
    /// Some input index is repeated and nothing is summed (reordering may also occur).
    Diagonalization,
    /// Summation plus reordering, with no repeated input index.
    PermutationAndSummation,
    /// Some input index is repeated and some index is summed (reordering may also occur).
    DiagonalizationAndSummation,
}
36
/// Counts describing a single-operand contraction.
///
/// These counts drive the choice of `SingletonMethod` in
/// `SingletonSummary::get_strategy`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct SingletonSummary {
    /// Number of distinct input indices minus the number of output indices.
    num_summed_axes: usize,
    /// Number of distinct input indices that occur more than once in the input.
    num_diagonalized_axes: usize,
    /// Number of positions `k` where `output[k] != input[k]`, over the shorter
    /// of the two index lists.
    num_reordered_axes: usize,
}
43
44impl SingletonSummary {
45 pub fn new(sc: &SizedContraction) -> Self {
46 assert_eq!(sc.contraction.operand_indices.len(), 1);
47 let output_indices = &sc.contraction.output_indices;
48 let input_indices = &sc.contraction.operand_indices[0];
49
50 SingletonSummary::from_indices(input_indices, output_indices)
51 }
52
53 fn from_indices(input_indices: &[char], output_indices: &[char]) -> Self {
54 let mut input_counts = HashMap::new();
55 for &c in input_indices.iter() {
56 *input_counts.entry(c).or_insert(0) += 1;
57 }
58 let num_summed_axes = input_counts.len() - output_indices.len();
59 let num_diagonalized_axes = input_counts.iter().filter(|&(_, &v)| v > 1).count();
60 let num_reordered_axes = output_indices
61 .iter()
62 .zip(input_indices.iter())
63 .filter(|&(&output_char, &input_char)| output_char != input_char)
64 .count();
65
66 SingletonSummary {
67 num_summed_axes,
68 num_diagonalized_axes,
69 num_reordered_axes,
70 }
71 }
72
73 pub fn get_strategy(&self) -> SingletonMethod {
74 match (
75 self.num_summed_axes,
76 self.num_diagonalized_axes,
77 self.num_reordered_axes,
78 ) {
79 (0, 0, 0) => SingletonMethod::Identity,
80 (0, 0, _) => SingletonMethod::Permutation,
81 (_, 0, 0) => SingletonMethod::Summation,
82 (0, _, _) => SingletonMethod::Diagonalization,
83 (_, 0, _) => SingletonMethod::PermutationAndSummation,
84 (_, _, _) => SingletonMethod::DiagonalizationAndSummation,
85 }
86 }
87}
88
/// Strategy for evaluating a contraction of two operands.
///
/// Chosen by `PairSummary::get_strategy`.
// Only the `*General` Hadamard/ScalarMatrix/MatrixScalar/Tensordot variants and
// `StackedTensordotGeneral` are returned by `PairSummary::get_strategy`. The other
// variants are presumably kept for direct use or experimentation elsewhere, so
// the dead-code lint is silenced here.
#[allow(dead_code)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum PairMethod {
    /// Not returned by `PairSummary::get_strategy`.
    HadamardProduct,
    /// Selected when there are no contracted axes and neither operand has outer axes.
    HadamardProductGeneral,
    /// Not returned by `PairSummary::get_strategy`.
    TensordotFixedPosition,
    /// Selected when there are no stacked axes and a more specific method does not apply.
    TensordotGeneral,
    /// Not returned by `PairSummary::get_strategy`.
    ScalarMatrixProduct,
    /// Selected when the left operand has no indices at all.
    ScalarMatrixProductGeneral,
    /// Not returned by `PairSummary::get_strategy`.
    MatrixScalarProduct,
    /// Selected when the right operand has no indices at all.
    MatrixScalarProductGeneral,
    /// Not returned by `PairSummary::get_strategy`.
    BroadcastProductGeneral,
    /// Fallback when stacked axes are present and no more specific method applies.
    StackedTensordotGeneral,
}
103
/// Counts describing how the indices of a two-operand contraction are distributed.
///
/// These counts drive the choice of `PairMethod` in `PairSummary::get_strategy`.
/// Like `SingletonSummary`, this is a small bundle of `usize` counts, so it is
/// `Copy`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct PairSummary {
    /// Indices present in both operands and in the output.
    num_stacked_axes: usize,
    /// Distinct left-operand indices that do not appear in the right operand.
    num_lhs_outer_axes: usize,
    /// Distinct right-operand indices that do not appear in the left operand.
    num_rhs_outer_axes: usize,
    /// Indices present in both operands but absent from the output.
    num_contracted_axes: usize,
}
111
112impl PairSummary {
113 pub fn new(sc: &SizedContraction) -> Self {
114 assert_eq!(sc.contraction.operand_indices.len(), 2);
115 let output_indices = &sc.contraction.output_indices;
116 let lhs_indices = &sc.contraction.operand_indices[0];
117 let rhs_indices = &sc.contraction.operand_indices[1];
118
119 PairSummary::from_indices(lhs_indices, rhs_indices, output_indices)
120 }
121
122 fn from_indices(lhs_indices: &[char], rhs_indices: &[char], output_indices: &[char]) -> Self {
123 let lhs_uniques: HashSet<char> = lhs_indices.iter().cloned().collect();
124 let rhs_uniques: HashSet<char> = rhs_indices.iter().cloned().collect();
125 let output_uniques: HashSet<char> = output_indices.iter().cloned().collect();
126 assert_eq!(lhs_indices.len(), lhs_uniques.len());
127 assert_eq!(rhs_indices.len(), rhs_uniques.len());
128 assert_eq!(output_indices.len(), output_uniques.len());
129
130 let lhs_and_rhs: HashSet<char> = lhs_uniques.intersection(&rhs_uniques).cloned().collect();
131 let stacked: HashSet<char> = lhs_and_rhs.intersection(&output_uniques).cloned().collect();
132
133 let num_stacked_axes = stacked.len();
134 let num_contracted_axes = lhs_and_rhs.len() - num_stacked_axes;
135 let num_lhs_outer_axes = lhs_uniques.len() - num_stacked_axes - num_contracted_axes;
136 let num_rhs_outer_axes = rhs_uniques.len() - num_stacked_axes - num_contracted_axes;
137
138 PairSummary {
139 num_stacked_axes,
140 num_lhs_outer_axes,
141 num_rhs_outer_axes,
142 num_contracted_axes,
143 }
144 }
145
146 pub fn get_strategy(&self) -> PairMethod {
147 match (
148 self.num_contracted_axes,
149 self.num_lhs_outer_axes,
150 self.num_rhs_outer_axes,
151 self.num_stacked_axes,
152 ) {
153 (0, 0, 0, _) => PairMethod::HadamardProductGeneral,
154 (0, 0, _, 0) => PairMethod::ScalarMatrixProductGeneral,
155 (0, _, 0, 0) => PairMethod::MatrixScalarProductGeneral,
156 (_, _, _, 0) => PairMethod::TensordotGeneral,
160 (_, _, _, _) => PairMethod::StackedTensordotGeneral,
161 }
162 }
163}