1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
// Copyright 2019 Jared Samet
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Methods to produce a `ContractionOrder`, which specifies the order in which pairwise
//! contractions between tensors are performed to carry out the full contraction.
use crate::SizedContraction;
use std::collections::HashSet;

/// Identifies one operand of a pairwise contraction: either an input operand or an
/// intermediate result.
///
/// As constructed by the path generation in this module, the wrapped `usize` is a
/// zero-based position (see the variants below).
#[derive(Debug, Clone)]
pub enum OperandNumber {
    /// An original input operand, identified by its position in the list of input operands.
    Input(usize),
    /// The output of an earlier pairwise contraction, identified by the zero-based position
    /// of that earlier step in the list of pairs.
    IntermediateResult(usize),
}

/// Which two tensors to contract in a single pairwise step
#[derive(Debug, Clone)]
pub struct OperandNumPair {
    /// The left-hand operand of the pairwise contraction
    pub lhs: OperandNumber,
    /// The right-hand operand of the pairwise contraction
    pub rhs: OperandNumber,
}

/// A single pairwise contraction between two input operands, an input operand and an intermediate
/// result, or two intermediate results.
///
/// Bundles the two-operand contraction to perform with the identities of the tensors it
/// is applied to.
#[derive(Debug, Clone)]
pub struct Pair {
    /// The two-operand contraction to be performed in this step
    pub sized_contraction: SizedContraction,

    /// Which two tensors (inputs and/or intermediate results) to contract
    pub operand_nums: OperandNumPair,
}

/// The order in which to contract pairs of tensors and the specific contractions to be performed between the pairs.
///
/// Either a singleton contraction, in the case of a single input operand, or a list of pair contractions,
/// given two or more input operands.
#[derive(Debug, Clone)]
pub enum ContractionOrder {
    /// If there's only one input operand, this holds the full contraction for that operand;
    /// no pairwise steps are needed.
    Singleton(SizedContraction),

    /// If there are two or more input operands, this is a vector of pairwise contractions between
    /// input operands and/or intermediate results from prior contractions, in the order they
    /// are to be performed.
    Pairs(Vec<Pair>),
}

/// Strategy for optimizing the contraction. The only currently supported options are "Naive" and "Reverse";
/// `generate_optimized_order` panics when given any other variant.
///
/// TODO: Figure out whether this should be done with traits
#[derive(Debug)]
pub enum OptimizationMethod {
    /// Contracts each pair of tensors in the order given in the input and uses the intermediate
    /// result as the LHS of the next contraction.
    Naive,

    /// Contracts each pair of tensors in the reverse of the order given in the input and uses the
    /// intermediate result as the LHS of the next contraction. Only implemented to help test
    /// that this is actually functioning properly.
    Reverse,

    /// (Not yet supported; currently panics) Something like
    /// [this](https://optimized-einsum.readthedocs.io/en/latest/greedy_path.html)
    Greedy,

    /// (Not yet supported; currently panics) Something like
    /// [this](https://optimized-einsum.readthedocs.io/en/latest/optimal_path.html)
    Optimal,

    /// (Not yet supported; currently panics) Something like
    /// [this](https://optimized-einsum.readthedocs.io/en/latest/branching_path.html)
    Branch,
}

/// Collects every index that appears in any of the remaining operands or in the output.
///
/// `operand_indices` holds the index lists of the operands that have not been contracted
/// yet; `output_indices` is the index list of the final result. Duplicates collapse because
/// the result is a set.
fn get_remaining_indices(operand_indices: &[Vec<char>], output_indices: &[char]) -> HashSet<char> {
    operand_indices
        .iter()
        .flatten()
        .chain(output_indices)
        .cloned()
        .collect()
}

/// Collects every index that appears in the LHS or the RHS operand of a pairwise step.
///
/// Repeated indices (within one operand or shared by both) appear once in the result.
fn get_existing_indices(lhs_indices: &[char], rhs_indices: &[char]) -> HashSet<char> {
    lhs_indices
        .iter()
        .chain(rhs_indices.iter())
        .cloned()
        .collect()
}

/// Builds a copy of `sized_contraction` whose operands are rearranged according to
/// `tensor_order`; the output indices are carried over unchanged.
///
/// Entry `k` of `tensor_order` names the position (in the original operand list) of the
/// operand that should end up at position `k`.
///
/// # Panics
///
/// Panics if `tensor_order` does not have exactly one entry per operand, if any entry is
/// out of range, or if `subset` rejects the reordered operand list.
fn generate_permuted_contraction(
    sized_contraction: &SizedContraction,
    tensor_order: &[usize],
) -> SizedContraction {
    let original_operands = &sized_contraction.contraction.operand_indices;
    assert_eq!(original_operands.len(), tensor_order.len());

    // Pick the operands out in the requested order; everything else is handled by `subset`.
    let reordered: Vec<_> = tensor_order
        .iter()
        .map(|&position| original_operands[position].clone())
        .collect();

    sized_contraction
        .subset(&reordered, &sized_contraction.contraction.output_indices)
        .unwrap()
}

/// Builds the two-operand mini-contraction
/// `lhs_operand_indices`,`rhs_operand_indices`->`output_indices` as a subset of
/// `orig_contraction`.
///
/// # Panics
///
/// Panics if `orig_contraction.subset` returns an error for the requested operands/output.
fn generate_sized_contraction_pair(
    lhs_operand_indices: &[char],
    rhs_operand_indices: &[char],
    output_indices: &[char],
    orig_contraction: &SizedContraction,
) -> SizedContraction {
    let operand_pair = [
        lhs_operand_indices.to_vec(),
        rhs_operand_indices.to_vec(),
    ];
    orig_contraction
        .subset(&operand_pair, output_indices)
        .unwrap()
}

/// Generate the actual path consisting of all the mini-contractions. Currently always
/// contracts two input operands and then repeatedly uses the result as the LHS of the
/// next pairwise contraction.
///
/// `tensor_order` gives, for each position in the contraction sequence, the position of
/// the operand in `sized_contraction` to use there.
///
/// Returns `ContractionOrder::Singleton` when there is exactly one operand, and otherwise
/// `ContractionOrder::Pairs` with one step per operand after the first. The indices of
/// every intermediate result are emitted in sorted order, so the generated path is
/// identical from run to run.
///
/// # Panics
///
/// Panics if `tensor_order` is not a valid ordering for `sized_contraction` (see
/// `generate_permuted_contraction`), if there are no operands at all, or if building one
/// of the pairwise contractions fails.
fn generate_path(sized_contraction: &SizedContraction, tensor_order: &[usize]) -> ContractionOrder {
    // TODO: Take a &[OperandNumPair] instead of &[usize]
    // and keep track of the intermediate results

    // Make a reordered full SizedContraction in the order specified by the caller.
    let permuted_contraction = generate_permuted_contraction(sized_contraction, tensor_order);
    let num_operands = permuted_contraction.contraction.operand_indices.len();

    if num_operands == 1 {
        // Only one input tensor: the path is a single singleton contraction, and the
        // permuted contraction can be handed over without copying it.
        return ContractionOrder::Singleton(permuted_contraction);
    }

    let operands = &permuted_contraction.contraction.operand_indices;
    let final_output = &permuted_contraction.contraction.output_indices;

    // lhs_indices is the first tensor on the first iteration of the loop and the output
    // of the previous step on every later iteration.
    let mut lhs_indices: Vec<char> = operands[0].clone();
    let mut steps = Vec::with_capacity(num_operands - 1);

    // Two operands need exactly one step and three or more need several; the same loop
    // covers both cases.
    for idx_of_rhs in 1..num_operands {
        // rhs_indices is always the next tensor.
        let rhs_indices = &operands[idx_of_rhs];

        // existing_indices and remaining_indices are only needed to figure out
        // what the output indices will be for this step.
        //
        // existing_indices consists of the indices in either the LHS or the RHS tensor
        // for this step.
        //
        // remaining_indices consists of the indices in all the operands after the RHS
        // tensor or in the final output.
        //
        // The output indices we want are the intersection of the two (unless the RHS is
        // the last operand, in which case they are just the final output indices).
        //
        // For example, say the string is "ij,jk,kl,lm->im".
        // First iteration:
        //      lhs = [i,j]
        //      rhs = [j,k]
        //      existing = {i,j,k}
        //      remaining = {k,l,m,i} (the i is used in the final output so we need to
        //      keep it around)
        //      output = [i,k]
        //      Mini-contraction: ij,jk->ik
        // Second iteration:
        //      lhs = [i,k]
        //      rhs = [k,l]
        //      existing = {i,k,l}
        //      remaining = {l,m,i}
        //      output = [i,l]
        //      Mini-contraction: ik,kl->il
        // Third (and final) iteration:
        //      lhs = [i,l]
        //      rhs = [l,m]
        //      (Short-circuit) output = [i,m]
        //      Mini-contraction: il,lm->im
        let output_indices: Vec<char> = if idx_of_rhs + 1 == num_operands {
            // Used up all the operands; this step produces the final output.
            final_output.clone()
        } else {
            let existing_indices = get_existing_indices(&lhs_indices, rhs_indices);
            let remaining_indices =
                get_remaining_indices(&operands[(idx_of_rhs + 1)..], final_output);
            let mut kept: Vec<char> = existing_indices
                .intersection(&remaining_indices)
                .cloned()
                .collect();
            // HashSet iteration order is arbitrary and differs between runs; sort so the
            // axis order of every intermediate result is reproducible.
            kept.sort_unstable();
            kept
        };

        // Now make the mini-contraction.
        let sc = generate_sized_contraction_pair(
            &lhs_indices,
            rhs_indices,
            &output_indices,
            &permuted_contraction,
        );

        // The first step contracts two inputs; every later step contracts the result of
        // the immediately preceding step (step number idx_of_rhs - 2) with the next input.
        let lhs = if idx_of_rhs == 1 {
            OperandNumber::Input(tensor_order[0])
        } else {
            OperandNumber::IntermediateResult(idx_of_rhs - 2)
        };
        steps.push(Pair {
            sized_contraction: sc,
            operand_nums: OperandNumPair {
                lhs,
                rhs: OperandNumber::Input(tensor_order[idx_of_rhs]),
            },
        });

        // This step's output becomes the LHS of the next step.
        lhs_indices = output_indices;
    }

    ContractionOrder::Pairs(steps)
}

/// Produces the identity ordering `0, 1, ..., n-1` over the operands, so that the first two
/// operands are contracted first, then the result with the third operand, and so on.
fn naive_order(sized_contraction: &SizedContraction) -> Vec<usize> {
    let num_operands = sized_contraction.contraction.operand_indices.len();
    let mut order = Vec::with_capacity(num_operands);
    order.extend(0..num_operands);
    order
}

/// Produces the ordering `n-1, ..., 1, 0` over the operands, so that the last two operands
/// are contracted first, then the result with the third-to-last operand, and so on.
fn reverse_order(sized_contraction: &SizedContraction) -> Vec<usize> {
    let num_operands = sized_contraction.contraction.operand_indices.len();
    let mut order: Vec<usize> = (0..num_operands).collect();
    order.reverse();
    order
}

// TODO: Maybe this should take a function pointer from &SizedContraction -> Vec<usize>?
/// Given a `SizedContraction` and an optimization strategy, returns an order in which to
/// perform pairwise contractions in order to produce the final result
pub fn generate_optimized_order(
    sized_contraction: &SizedContraction,
    strategy: OptimizationMethod,
) -> ContractionOrder {
    let tensor_order = match strategy {
        OptimizationMethod::Naive => naive_order(sized_contraction),
        OptimizationMethod::Reverse => reverse_order(sized_contraction),
        _ => panic!("Unsupported optimization method"),
    };
    generate_path(sized_contraction, &tensor_order)
}