ndarray_einsum/optimizers.rs
1// Copyright 2019 Jared Samet
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Methods to produce a `ContractionOrder`, specifying what order in which to perform pairwise contractions between tensors
16//! in order to perform the full contraction.
17use crate::SizedContraction;
18use hashbrown::HashSet;
19
/// Either an input operand or an intermediate result.
///
/// `Input(i)` refers to the `i`-th original input operand; `IntermediateResult(j)`
/// refers to the output of the `j`-th pairwise contraction performed so far.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OperandNumber {
    Input(usize),
    IntermediateResult(usize),
}
26
/// Which two tensors to contract
#[derive(Debug, Clone)]
pub struct OperandNumPair {
    // Left-hand tensor of the pairwise contraction: either an original input or the
    // result of an earlier step
    pub lhs: OperandNumber,
    // Right-hand tensor of the pairwise contraction: either an original input or the
    // result of an earlier step
    pub rhs: OperandNumber,
}
33
/// A single pairwise contraction between two input operands, an input operand and an intermediate
/// result, or two intermediate results.
#[derive(Debug, Clone)]
pub struct Pair {
    /// The contraction to be performed (a two-operand `SizedContraction` describing this step)
    pub sized_contraction: SizedContraction,

    /// Which two tensors to contract
    pub operand_nums: OperandNumPair,
}
44
/// The order in which to contract pairs of tensors and the specific contractions to be performed between the pairs.
///
/// Either a singleton contraction, in the case of a single input operand, or a list of pair contractions,
/// given two or more input operands
#[derive(Debug, Clone)]
pub enum ContractionOrder {
    /// If there's only one input operand, this is simply a clone of the original SizedContraction
    Singleton(SizedContraction),

    /// If there are two or more input operands, this is a vector of pairwise contractions between
    /// input operands and/or intermediate results from prior contractions.
    /// The steps are performed in order; each step's result is available to later steps
    /// via `OperandNumber::IntermediateResult`.
    Pairs(Vec<Pair>),
}
58
/// Strategy for optimizing the contraction. The only currently supported options are "Naive" and "Reverse".
///
/// TODO: Figure out whether this should be done with traits
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationMethod {
    /// Contracts each pair of tensors in the order given in the input and uses the intermediate
    /// result as the LHS of the next contraction.
    Naive,

    /// Contracts each pair of tensors in the reverse of the order given in the input and uses the
    /// intermediate result as the LHS of the next contraction. Only implemented to help test
    /// that this is actually functioning properly.
    Reverse,

    /// (Not yet supported) Something like [this](https://optimized-einsum.readthedocs.io/en/latest/greedy_path.html)
    Greedy,

    /// (Not yet supported) Something like [this](https://optimized-einsum.readthedocs.io/en/latest/optimal_path.html)
    Optimal,

    /// (Not yet supported) Something like [this](https://optimized-einsum.readthedocs.io/en/latest/branching_path.html)
    Branch,
}
82
/// Returns a set of all the indices in any of the remaining operands or in the output
fn get_remaining_indices(operand_indices: &[Vec<char>], output_indices: &[char]) -> HashSet<char> {
    // Union of every index character appearing in any remaining operand with
    // every index character appearing in the final output.
    operand_indices
        .iter()
        .flatten()
        .chain(output_indices.iter())
        .copied()
        .collect()
}
94
/// Returns a set of all the indices in the LHS or the RHS
fn get_existing_indices(lhs_indices: &[char], rhs_indices: &[char]) -> HashSet<char> {
    // Simple set union of the two operands' index characters.
    lhs_indices.iter().chain(rhs_indices.iter()).copied().collect()
}
103
104/// Returns a permuted version of `sized_contraction`, specified by `tensor_order`
105fn generate_permuted_contraction(
106 sized_contraction: &SizedContraction,
107 tensor_order: &[usize],
108) -> SizedContraction {
109 // Reorder the operands of the SizedContraction and clone everything else
110 assert_eq!(
111 sized_contraction.contraction.operand_indices.len(),
112 tensor_order.len()
113 );
114 let mut new_operand_indices = Vec::new();
115 for &i in tensor_order {
116 new_operand_indices.push(sized_contraction.contraction.operand_indices[i].clone());
117 }
118 sized_contraction
119 .subset(
120 &new_operand_indices,
121 &sized_contraction.contraction.output_indices,
122 )
123 .unwrap()
124}
125
126/// Generates a mini-contraction corresponding to `lhs_operand_indices`,`rhs_operand_indices`->`output_indices`
127fn generate_sized_contraction_pair(
128 lhs_operand_indices: &[char],
129 rhs_operand_indices: &[char],
130 output_indices: &[char],
131 orig_contraction: &SizedContraction,
132) -> SizedContraction {
133 orig_contraction
134 .subset(
135 &[lhs_operand_indices.to_vec(), rhs_operand_indices.to_vec()],
136 output_indices,
137 )
138 .unwrap()
139}
140
/// Generate the actual path consisting of all the mini-contractions. Currently always
/// contracts two input operands and then repeatedly uses the result as the LHS of the
/// next pairwise contraction.
fn generate_path(sized_contraction: &SizedContraction, tensor_order: &[usize]) -> ContractionOrder {
    // Generate the actual path consisting of all the mini-contractions.
    //
    // TODO: Take a &[OperandNumPair] instead of &[usize]
    // and Keep track of the intermediate results

    // Make a reordered full SizedContraction in the order specified by the caller
    let permuted_contraction = generate_permuted_contraction(sized_contraction, tensor_order);

    match permuted_contraction.contraction.operand_indices.len() {
        1 => {
            // If there's only one input tensor, make a single-step path consisting of a
            // singleton contraction (no operand pair needed).
            ContractionOrder::Singleton(permuted_contraction)
        }
        2 => {
            // If there's exactly two input tensors, make a single-step path consisting
            // of a single pair contraction between them.
            let sc = generate_sized_contraction_pair(
                &permuted_contraction.contraction.operand_indices[0],
                &permuted_contraction.contraction.operand_indices[1],
                &permuted_contraction.contraction.output_indices,
                &permuted_contraction,
            );
            // Map back to the caller's original operand positions via tensor_order.
            let operand_num_pair = OperandNumPair {
                lhs: OperandNumber::Input(tensor_order[0]),
                rhs: OperandNumber::Input(tensor_order[1]),
            };
            let only_step = Pair {
                sized_contraction: sc,
                operand_nums: operand_num_pair,
            };
            ContractionOrder::Pairs(vec![only_step])
        }
        _ => {
            // If there's three or more input tensors, we have some work to do.

            let mut steps = Vec::new();
            // In the main body of the loop, output_indices will contain the result of the prior pair
            // contraction. Initialize it to the elements of the first LHS tensor so that we can
            // clone it on the first go-around as well as all the later ones.
            let mut output_indices = permuted_contraction.contraction.operand_indices[0].clone();

            for idx_of_lhs in 0..(permuted_contraction.contraction.operand_indices.len() - 1) {
                // lhs_indices is either the first tensor (on the first iteration of the loop)
                // or the output from the previous step.
                let lhs_indices = output_indices.clone();

                // rhs_indices is always the next tensor.
                let idx_of_rhs = idx_of_lhs + 1;
                let rhs_indices = &permuted_contraction.contraction.operand_indices[idx_of_rhs];

                // existing_indices and remaining_indices are only needed to figure out
                // what output_indices will be for this step.
                //
                // existing_indices consists of the indices in either the LHS or the RHS tensor
                // for this step.
                //
                // remaining_indices consists of the indices in all the elements after the RHS
                // tensor or in the outputs.
                //
                // The output indices we want is the intersection of the two (unless the RHS
                // is the last operand, in which case it's just the final output indices).
                //
                // For example, say the string is "ij,jk,kl,lm->im".
                // First iteration:
                //      lhs = [i,j]
                //      rhs = [j,k]
                //      existing = {i,j,k}
                //      remaining = {k,l,m,i} (the i is used in the final output so we need to
                //                             keep it around)
                //      output = {i,k}
                //      Mini-contraction: ij,jk->ik
                // Second iteration:
                //      lhs = [i,k]
                //      rhs = [k,l]
                //      existing = {i,k,l}
                //      remaining = {l,m,i}
                //      output = {i,l}
                //      Mini-contraction: ik,kl->il
                // Third (and final) iteration:
                //      lhs = [i,l]
                //      rhs = [l,m]
                //      (Short-circuit) output = {i,m}
                //      Mini-contraction: il,lm->im
                output_indices =
                    if idx_of_rhs == (permuted_contraction.contraction.operand_indices.len() - 1) {
                        // Used up all the operands; just return output
                        permuted_contraction.contraction.output_indices.clone()
                    } else {
                        // Keep exactly the indices that appear in this step AND are still
                        // needed later (by a future operand or the final output).
                        let existing_indices = get_existing_indices(&lhs_indices, rhs_indices);
                        let remaining_indices = get_remaining_indices(
                            &permuted_contraction.contraction.operand_indices[(idx_of_rhs + 1)..],
                            &permuted_contraction.contraction.output_indices,
                        );
                        existing_indices
                            .intersection(&remaining_indices)
                            .cloned()
                            .collect()
                    };

                // Phew, now make the mini-contraction.
                let sc = generate_sized_contraction_pair(
                    &lhs_indices,
                    rhs_indices,
                    &output_indices,
                    &permuted_contraction,
                );

                // Only the very first step takes two original inputs; every later step's
                // LHS is the result of the previous step (intermediate result idx_of_lhs - 1).
                let operand_nums = if idx_of_lhs == 0 {
                    OperandNumPair {
                        lhs: OperandNumber::Input(tensor_order[idx_of_lhs]), // tensor_order[0]
                        rhs: OperandNumber::Input(tensor_order[idx_of_rhs]), // tensor_order[1]
                    }
                } else {
                    OperandNumPair {
                        lhs: OperandNumber::IntermediateResult(idx_of_lhs - 1),
                        rhs: OperandNumber::Input(tensor_order[idx_of_rhs]),
                    }
                };
                steps.push(Pair {
                    sized_contraction: sc,
                    operand_nums,
                });
            }

            ContractionOrder::Pairs(steps)
        }
    }
}
274
275/// Contracts the first two operands, then contracts the result with the third operand, etc.
276fn naive_order(sized_contraction: &SizedContraction) -> Vec<usize> {
277 (0..sized_contraction.contraction.operand_indices.len()).collect()
278}
279
280/// Contracts the last two operands, then contracts the result with the third-to-last operand, etc.
281fn reverse_order(sized_contraction: &SizedContraction) -> Vec<usize> {
282 (0..sized_contraction.contraction.operand_indices.len())
283 .rev()
284 .collect()
285}
286
287// TODO: Maybe this should take a function pointer from &SizedContraction -> Vec<usize>?
288/// Given a `SizedContraction` and an optimization strategy, returns an order in which to
289/// perform pairwise contractions in order to produce the final result
290pub fn generate_optimized_order(
291 sized_contraction: &SizedContraction,
292 strategy: OptimizationMethod,
293) -> ContractionOrder {
294 let tensor_order = match strategy {
295 OptimizationMethod::Naive => naive_order(sized_contraction),
296 OptimizationMethod::Reverse => reverse_order(sized_contraction),
297 _ => panic!("Unsupported optimization method"),
298 };
299 generate_path(sized_contraction, &tensor_order)
300}