// ndarray_einsum/optimizers.rs

1// Copyright 2019 Jared Samet
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Methods to produce a `ContractionOrder`, specifying what order in which to perform pairwise contractions between tensors
16//! in order to perform the full contraction.
17use crate::SizedContraction;
18use hashbrown::HashSet;
19
/// Either an input operand or an intermediate result
///
/// Identifies where one side of a pairwise contraction comes from: an original
/// input operand or the output of an earlier pairwise contraction step.
#[derive(Debug, Clone)]
pub enum OperandNumber {
    /// Position within the original list of input operands
    Input(usize),
    /// Position within the list of intermediate results produced by prior steps
    IntermediateResult(usize),
}
26
/// Which two tensors to contract
#[derive(Debug, Clone)]
pub struct OperandNumPair {
    /// Source of the left-hand tensor (input operand or intermediate result)
    pub lhs: OperandNumber,
    /// Source of the right-hand tensor (input operand or intermediate result)
    pub rhs: OperandNumber,
}
33
/// A single pairwise contraction between two input operands, an input operand and an intermediate
/// result, or two intermediate results.
///
/// One step of a `ContractionOrder::Pairs` plan.
#[derive(Debug, Clone)]
pub struct Pair {
    /// The contraction to be performed
    pub sized_contraction: SizedContraction,

    /// Which two tensors to contract
    pub operand_nums: OperandNumPair,
}
44
/// The order in which to contract pairs of tensors and the specific contractions to be performed between the pairs.
///
/// Either a singleton contraction, in the case of a single input operand, or a list of pair contractions,
/// given two or more input operands
#[derive(Debug, Clone)]
pub enum ContractionOrder {
    /// If there's only one input operand, this is simply a clone of the original SizedContraction
    Singleton(SizedContraction),

    /// If there are two or more input operands, this is a vector of pairwise contractions between
    /// input operands and/or intermediate results from prior contractions.
    Pairs(Vec<Pair>),
}
58
/// Strategy for optimizing the contraction. The only currently supported options are "Naive" and "Reverse".
///
/// TODO: Figure out whether this should be done with traits
// Fieldless enum: cheap to copy and compare, so derive the standard traits
// eagerly (backward compatible for all existing callers).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationMethod {
    /// Contracts each pair of tensors in the order given in the input and uses the intermediate
    /// result as the LHS of the next contraction.
    Naive,

    /// Contracts each pair of tensors in the reverse of the order given in the input and uses the
    /// intermediate result as the LHS of the next contraction. Only implemented to help test
    /// that this is actually functioning properly.
    Reverse,

    /// (Not yet supported) Something like [this](https://optimized-einsum.readthedocs.io/en/latest/greedy_path.html)
    Greedy,

    /// (Not yet supported) Something like [this](https://optimized-einsum.readthedocs.io/en/latest/optimal_path.html)
    Optimal,

    /// (Not yet supported) Something like [this](https://optimized-einsum.readthedocs.io/en/latest/branching_path.html)
    Branch,
}
82
/// Returns a set of all the indices in any of the remaining operands or in the output
fn get_remaining_indices(operand_indices: &[Vec<char>], output_indices: &[char]) -> HashSet<char> {
    // Union of every index character appearing in a remaining operand with
    // every index character appearing in the output, built in one pass.
    operand_indices
        .iter()
        .flatten()
        .chain(output_indices.iter())
        .copied()
        .collect()
}
94
/// Returns a set of all the indices in the LHS or the RHS
fn get_existing_indices(lhs_indices: &[char], rhs_indices: &[char]) -> HashSet<char> {
    // Union of the LHS and RHS index characters.
    lhs_indices
        .iter()
        .chain(rhs_indices.iter())
        .copied()
        .collect()
}
103
104/// Returns a permuted version of `sized_contraction`, specified by `tensor_order`
105fn generate_permuted_contraction(
106    sized_contraction: &SizedContraction,
107    tensor_order: &[usize],
108) -> SizedContraction {
109    // Reorder the operands of the SizedContraction and clone everything else
110    assert_eq!(
111        sized_contraction.contraction.operand_indices.len(),
112        tensor_order.len()
113    );
114    let mut new_operand_indices = Vec::new();
115    for &i in tensor_order {
116        new_operand_indices.push(sized_contraction.contraction.operand_indices[i].clone());
117    }
118    sized_contraction
119        .subset(
120            &new_operand_indices,
121            &sized_contraction.contraction.output_indices,
122        )
123        .unwrap()
124}
125
126/// Generates a mini-contraction corresponding to `lhs_operand_indices`,`rhs_operand_indices`->`output_indices`
127fn generate_sized_contraction_pair(
128    lhs_operand_indices: &[char],
129    rhs_operand_indices: &[char],
130    output_indices: &[char],
131    orig_contraction: &SizedContraction,
132) -> SizedContraction {
133    orig_contraction
134        .subset(
135            &[lhs_operand_indices.to_vec(), rhs_operand_indices.to_vec()],
136            output_indices,
137        )
138        .unwrap()
139}
140
/// Generate the actual path consisting of all the mini-contractions. Currently always
/// contracts two input operands and then repeatedly uses the result as the LHS of the
/// next pairwise contraction.
fn generate_path(sized_contraction: &SizedContraction, tensor_order: &[usize]) -> ContractionOrder {
    // Generate the actual path consisting of all the mini-contractions.
    //
    // TODO: Take a &[OperandNumPair] instead of &[usize]
    // and Keep track of the intermediate results

    // Make a reordered full SizedContraction in the order specified by the caller
    let permuted_contraction = generate_permuted_contraction(sized_contraction, tensor_order);

    match permuted_contraction.contraction.operand_indices.len() {
        1 => {
            // If there's only one input tensor, make a single-step path consisting of a
            // singleton contraction (no pairwise step is needed).
            ContractionOrder::Singleton(permuted_contraction)
        }
        2 => {
            // If there's exactly two input tensors, make a single-step path consisting
            // of one pair contraction between the two input operands.
            let sc = generate_sized_contraction_pair(
                &permuted_contraction.contraction.operand_indices[0],
                &permuted_contraction.contraction.operand_indices[1],
                &permuted_contraction.contraction.output_indices,
                &permuted_contraction,
            );
            // Both sides are original inputs; map positions 0 and 1 of the
            // permuted contraction back to the caller's operand numbering.
            let operand_num_pair = OperandNumPair {
                lhs: OperandNumber::Input(tensor_order[0]),
                rhs: OperandNumber::Input(tensor_order[1]),
            };
            let only_step = Pair {
                sized_contraction: sc,
                operand_nums: operand_num_pair,
            };
            ContractionOrder::Pairs(vec![only_step])
        }
        _ => {
            // If there's three or more input tensors, we have some work to do.

            let mut steps = Vec::new();
            // In the main body of the loop, output_indices will contain the result of the prior pair
            // contraction. Initialize it to the elements of the first LHS tensor so that we can
            // clone it on the first go-around as well as all the later ones.
            let mut output_indices = permuted_contraction.contraction.operand_indices[0].clone();

            for idx_of_lhs in 0..(permuted_contraction.contraction.operand_indices.len() - 1) {
                // lhs_indices is either the first tensor (on the first iteration of the loop)
                // or the output from the previous step.
                let lhs_indices = output_indices.clone();

                // rhs_indices is always the next tensor.
                let idx_of_rhs = idx_of_lhs + 1;
                let rhs_indices = &permuted_contraction.contraction.operand_indices[idx_of_rhs];

                // existing_indices and remaining_indices are only needed to figure out
                // what output_indices will be for this step.
                //
                // existing_indices consists of the indices in either the LHS or the RHS tensor
                // for this step.
                //
                // remaining_indices consists of the indices in all the elements after the RHS
                // tensor or in the outputs.
                //
                // The output indices we want is the intersection of the two (unless
                // the RHS is the last operand, in which case it's just the output indices).
                //
                // For example, say the string is "ij,jk,kl,lm->im".
                // First iteration:
                //      lhs = [i,j]
                //      rhs = [j,k]
                //      existing = {i,j,k}
                //      remaining = {k,l,m,i} (the i is used in the final output so we need to
                //      keep it around)
                //      output = {i,k}
                //      Mini-contraction: ij,jk->ik
                // Second iteration:
                //      lhs = [i,k]
                //      rhs = [k,l]
                //      existing = {i,k,l}
                //      remaining = {l,m,i}
                //      output = {i,l}
                //      Mini-contraction: ik,kl->il
                // Third (and final) iteration:
                //      lhs = [i,l]
                //      rhs = [l,m]
                //      (Short-circuit) output = {i,m}
                //      Mini-contraction: il,lm->im
                output_indices =
                    if idx_of_rhs == (permuted_contraction.contraction.operand_indices.len() - 1) {
                        // Used up all the operands; just return output
                        permuted_contraction.contraction.output_indices.clone()
                    } else {
                        let existing_indices = get_existing_indices(&lhs_indices, rhs_indices);
                        let remaining_indices = get_remaining_indices(
                            &permuted_contraction.contraction.operand_indices[(idx_of_rhs + 1)..],
                            &permuted_contraction.contraction.output_indices,
                        );
                        // NOTE(review): HashSet intersection gives an unspecified iteration
                        // order, so the index order of intermediate results is arbitrary
                        // (but consistent within this one contraction plan).
                        existing_indices
                            .intersection(&remaining_indices)
                            .cloned()
                            .collect()
                    };

                // Phew, now make the mini-contraction.
                let sc = generate_sized_contraction_pair(
                    &lhs_indices,
                    rhs_indices,
                    &output_indices,
                    &permuted_contraction,
                );

                // First step contracts two original inputs; every later step
                // contracts the previous step's intermediate result (index
                // idx_of_lhs - 1 in `steps`) with the next original input.
                let operand_nums = if idx_of_lhs == 0 {
                    OperandNumPair {
                        lhs: OperandNumber::Input(tensor_order[idx_of_lhs]), // tensor_order[0]
                        rhs: OperandNumber::Input(tensor_order[idx_of_rhs]), // tensor_order[1]
                    }
                } else {
                    OperandNumPair {
                        lhs: OperandNumber::IntermediateResult(idx_of_lhs - 1),
                        rhs: OperandNumber::Input(tensor_order[idx_of_rhs]),
                    }
                };
                steps.push(Pair {
                    sized_contraction: sc,
                    operand_nums,
                });
            }

            ContractionOrder::Pairs(steps)
        }
    }
}
274
275/// Contracts the first two operands, then contracts the result with the third operand, etc.
276fn naive_order(sized_contraction: &SizedContraction) -> Vec<usize> {
277    (0..sized_contraction.contraction.operand_indices.len()).collect()
278}
279
280/// Contracts the last two operands, then contracts the result with the third-to-last operand, etc.
281fn reverse_order(sized_contraction: &SizedContraction) -> Vec<usize> {
282    (0..sized_contraction.contraction.operand_indices.len())
283        .rev()
284        .collect()
285}
286
287// TODO: Maybe this should take a function pointer from &SizedContraction -> Vec<usize>?
288/// Given a `SizedContraction` and an optimization strategy, returns an order in which to
289/// perform pairwise contractions in order to produce the final result
290pub fn generate_optimized_order(
291    sized_contraction: &SizedContraction,
292    strategy: OptimizationMethod,
293) -> ContractionOrder {
294    let tensor_order = match strategy {
295        OptimizationMethod::Naive => naive_order(sized_contraction),
296        OptimizationMethod::Reverse => reverse_order(sized_contraction),
297        _ => panic!("Unsupported optimization method"),
298    };
299    generate_path(sized_contraction, &tensor_order)
300}