ndarray_einsum/contractors/
mod.rs

1// Copyright 2019 Jared Samet
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Implementations of the base-case singleton and pair contractors for different types of contractions.
16//!
17//! This module defines the `SingletonViewer`, `SingletonContractor`, and `PairContractor` traits as well
18//! as the generic "container" objects `EinsumPath`, `SingletonContraction`, and `PairContraction` that
19//! hold `Box`ed trait objects of the specific cases determined at runtime to be most appropriate
20//! for the requested contraction.
21//!
22//! The specific singleton and pair contractors defined in the `singleton_contractors` and
23//! `pair_contractors` submodules implement the relevant traits defined here. The six specific singleton
24//! contractors defined in `singleton_contractors` perform some combination of permutation of the input
25//! axes (e.g. `ijk->jki`), diagonalization across repeated but un-summed axes (e.g. `ii->i`), and
//! summation across axes not present in the output index list (e.g. `ijk->j`). Not all of the ten
//! pair contractors defined in `pair_contractors` are currently used, as some appear to be slower than others.
28//!
29//! Each struct implementing one of the `*Contractor` traits performs all the "setup work"
30//! required to perform the actual contraction. For example, `HadamardProductGeneral` permutes
31//! the input and output tensors and then computes the element-wise product of the two tensors.
32//! Given a `SizedContraction` (but no actual tensors), `HadamardProductGeneral::new()` figures out
33//! the permutation orders that will be needed so that `contract_pair` can simply execute the two
34//! permutations and then produce the element-wise product. This can be thought of as a way of
35//! compiling the `einsum` string into a set of instructions and the `EinsumPath` object
36//! can be thought of as an AST that is ready to compute a contraction when supplied with an
37//! actual set of operands to contract.
38
39use crate::optimizers::{
40    generate_optimized_order, ContractionOrder, OperandNumber, OptimizationMethod,
41};
42use crate::{ArrayLike, SizedContraction};
43use hashbrown::HashSet;
44use ndarray::prelude::*;
45use ndarray::LinalgScalar;
46use std::fmt::Debug;
47
48mod singleton_contractors;
49use singleton_contractors::{
50    Diagonalization, DiagonalizationAndSummation, Identity, Permutation, PermutationAndSummation,
51    Summation,
52};
53
54mod pair_contractors;
55pub use pair_contractors::TensordotGeneral;
56use pair_contractors::{
57    BroadcastProductGeneral, HadamardProduct, HadamardProductGeneral, MatrixScalarProduct,
58    MatrixScalarProductGeneral, ScalarMatrixProduct, ScalarMatrixProductGeneral,
59    StackedTensordotGeneral, TensordotFixedPosition,
60};
61
62mod strategies;
63use strategies::{PairMethod, PairSummary, SingletonMethod, SingletonSummary};
64
65/// `let new_view = obj.view_singleton(tensor_view);`
66///
67/// This trait represents contractions that can be performed by returning a view of the original
68/// tensor. The structs that currently implement this view are the ones that don't perform
69/// any summation over indices and hence return only a subset of the elements of the original tensor:
70/// `Identity`, `Permutation`, and `Diagonalization`. Note that whether `Diagonalization`
71/// can actually return a view is dependent on the memory layout of the input tensor; if the input
72/// tensor is not contiguous, `diag.view_singleton()` will `panic`.
73pub trait SingletonViewer<A>: Debug {
74    fn view_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayViewD<'b, A>
75    where
76        'a: 'b,
77        A: Clone + LinalgScalar;
78}
79
80/// `let new_array = obj.contract_singleton(tensor_view);`
81///
82/// All singleton contractions should implement this trait. It returns a new owned `ArrayD`.
83pub trait SingletonContractor<A>: Debug {
84    fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
85    where
86        'a: 'b,
87        A: Clone + LinalgScalar;
88}
89
90/// `let new_array = obj.contract_pair(lhs_view, rhs_view);`
91///
92/// All pair contractions should implement this trait. It returns a new owned `ArrayD`. The trait
93/// also has a method with a default implementation, `obj.contract_and_assign_pair(lhs_view: &ArrayViewD,
94/// rhs_view: &ArrayViewD, out: &mut ArrayViewD) -> ()`.
95pub trait PairContractor<A>: Debug {
96    fn contract_pair<'a, 'b, 'c, 'd>(
97        &self,
98        lhs: &'b ArrayViewD<'a, A>,
99        rhs: &'d ArrayViewD<'c, A>,
100    ) -> ArrayD<A>
101    where
102        'a: 'b,
103        'c: 'd,
104        A: Clone + LinalgScalar;
105
106    fn contract_and_assign_pair<'a, 'b, 'c, 'd, 'e, 'f>(
107        &self,
108        lhs: &'b ArrayViewD<'a, A>,
109        rhs: &'d ArrayViewD<'c, A>,
110        out: &'f mut ArrayViewMutD<'e, A>,
111    ) where
112        'a: 'b,
113        'c: 'd,
114        'e: 'f,
115        A: Clone + LinalgScalar,
116    {
117        let result = self.contract_pair(lhs, rhs);
118        out.assign(&result);
119    }
120}
121
122/// Holds a `Box`ed `SingletonContractor` trait object.
123///
124/// Constructed at runtime based on the number of diagonalized, summed, and permuted axes
125/// in the input. Reimplements the `SingletonContractor` trait by delegating to the inner
126/// object.
127///
128/// For example, the contraction `iij->i` will be performed by assigning a `Box`ed
129/// `DiagonalizationAndSummation` to `op`. The contraction `ijk->kij` will be performed
130/// by assigning a `Box`ed `Permutation` to `op`.
131pub struct SingletonContraction<A> {
132    method: SingletonMethod,
133    op: Box<dyn SingletonContractor<A>>,
134}
135
136impl<A> SingletonContraction<A> {
137    pub fn new(sc: &SizedContraction) -> Self {
138        let singleton_summary = SingletonSummary::new(sc);
139        let method = singleton_summary.get_strategy();
140
141        SingletonContraction {
142            method,
143            op: match method {
144                SingletonMethod::Identity => Box::new(Identity::new(sc)),
145                SingletonMethod::Permutation => Box::new(Permutation::new(sc)),
146                SingletonMethod::Summation => Box::new(Summation::new(sc)),
147                SingletonMethod::Diagonalization => Box::new(Diagonalization::new(sc)),
148                SingletonMethod::PermutationAndSummation => {
149                    Box::new(PermutationAndSummation::new(sc))
150                }
151                SingletonMethod::DiagonalizationAndSummation => {
152                    Box::new(DiagonalizationAndSummation::new(sc))
153                }
154            },
155        }
156    }
157}
158
159impl<A> SingletonContractor<A> for SingletonContraction<A> {
160    fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
161    where
162        'a: 'b,
163        A: Clone + LinalgScalar,
164    {
165        self.op.contract_singleton(tensor)
166    }
167}
168
169impl<A> Debug for SingletonContraction<A> {
170    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
171        write!(
172            f,
173            "SingletonContraction {{ method: {:?}, op: {:?} }}",
174            self.method, self.op
175        )
176    }
177}
178
179/// Holds an `Box<dyn SingletonContractor<A>>` and the resulting simplified indices.
180struct SimplificationMethodAndOutput<A> {
181    method: SingletonMethod,
182    op: Box<dyn SingletonContractor<A>>,
183    new_indices: Vec<char>,
184    einsum_string: String,
185}
186
187impl<A> SimplificationMethodAndOutput<A> {
188    /// Based on the number of diagonalized, permuted, and summed axes, chooses a struct implementing
189    /// `SingletonContractor` to simplify the tensor (or `None` if the tensor doesn't need simplification)
190    /// and computes the indices of the simplified tensor.
191    fn from_indices_and_sizes(
192        this_input_indices: &[char],
193        other_input_indices: &[char],
194        output_indices: &[char],
195        orig_contraction: &SizedContraction,
196    ) -> Option<Self> {
197        let this_input_uniques: HashSet<char> = this_input_indices.iter().cloned().collect();
198        let other_input_uniques: HashSet<char> = other_input_indices.iter().cloned().collect();
199        let output_uniques: HashSet<char> = output_indices.iter().cloned().collect();
200
201        let other_and_output: HashSet<char> = other_input_uniques
202            .union(&output_uniques)
203            .cloned()
204            .collect();
205        let desired_uniques: HashSet<char> = this_input_uniques
206            .intersection(&other_and_output)
207            .cloned()
208            .collect();
209        let new_indices: Vec<char> = desired_uniques.iter().cloned().collect();
210
211        let simplification_sc = orig_contraction
212            .subset(&[this_input_indices.to_vec()], &new_indices)
213            .unwrap();
214
215        let SingletonContraction { method, op } = SingletonContraction::new(&simplification_sc);
216
217        match method {
218            SingletonMethod::Identity | SingletonMethod::Permutation => None,
219            _ => Some(SimplificationMethodAndOutput {
220                method,
221                op,
222                new_indices,
223                einsum_string: simplification_sc.as_einsum_string(),
224            }),
225        }
226    }
227}
228
229impl<A> Debug for SimplificationMethodAndOutput<A> {
230    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
231        write!(
232            f,
233            "SingletonContraction {{ method: {:?}, op: {:?}, new_indices: {:?}, einsum_string: {:?} }}",
234            self.method, self.op, self.new_indices, self.einsum_string
235        )
236    }
237}
238
239/// Holds a `Box`ed `PairContractor` trait object and two `Option<Box>`ed simplifications for the LHS and RHS tensors.
240///
241/// For example, the contraction `ijk,kj->jk` will currently be performed as follows:
242///
243/// 1. Simplify the LHS with the contraction `ijk->jk`
244/// 2. Don't simplify the RHS
245/// 3. Use HadamardProductGeneral to compute `jk,kj->jk`
246///
247/// A second example is the contraction `iij,jkk->ik`:
248///
249/// 1. Simplify the LHS with the contraction `iij->ij`
250/// 2. Simplify the RHS with the contraction `jkk->jk`
251/// 3. Use TensordotGeneral to compute `ij,jk->ik`
252///
253/// Since the axis lengths aren't known until runtime, and the actual einsum string may not
254/// be either, it is generally not possible to know at compile time which specific PairContractor
255/// will be used to perform a given contraction, or even which contractions will be performed;
256/// the optimizer could choose a different order.
257pub struct PairContraction<A> {
258    lhs_simplification: Option<SimplificationMethodAndOutput<A>>,
259    rhs_simplification: Option<SimplificationMethodAndOutput<A>>,
260    method: PairMethod,
261    op: Box<dyn PairContractor<A>>,
262    simplified_einsum_string: String,
263}
264
265impl<A> PairContraction<A> {
266    pub fn new(sc: &SizedContraction) -> Self {
267        assert_eq!(sc.contraction.operand_indices.len(), 2);
268        let lhs_indices = &sc.contraction.operand_indices[0];
269        let rhs_indices = &sc.contraction.operand_indices[1];
270        let output_indices = &sc.contraction.output_indices;
271
272        let lhs_simplification = SimplificationMethodAndOutput::from_indices_and_sizes(
273            lhs_indices,
274            rhs_indices,
275            output_indices,
276            sc,
277        );
278        let rhs_simplification = SimplificationMethodAndOutput::from_indices_and_sizes(
279            rhs_indices,
280            lhs_indices,
281            output_indices,
282            sc,
283        );
284        let new_lhs_indices = match &lhs_simplification {
285            Some(s) => s.new_indices.clone(),
286            None => lhs_indices.clone(),
287        };
288        let new_rhs_indices = match &rhs_simplification {
289            Some(s) => s.new_indices.clone(),
290            None => rhs_indices.clone(),
291        };
292
293        let reduced_sc = sc
294            .subset(&[new_lhs_indices, new_rhs_indices], output_indices)
295            .unwrap();
296
297        let pair_summary = PairSummary::new(&reduced_sc);
298        let method = pair_summary.get_strategy();
299
300        let op: Box<dyn PairContractor<A>> = match method {
301            PairMethod::HadamardProduct => {
302                // Never gets returned in current implementation
303                Box::new(HadamardProduct::new(&reduced_sc))
304            }
305            PairMethod::HadamardProductGeneral => {
306                Box::new(HadamardProductGeneral::new(&reduced_sc))
307            }
308            PairMethod::ScalarMatrixProduct => {
309                // Never gets returned in current implementation
310                Box::new(ScalarMatrixProduct::new(&reduced_sc))
311            }
312            PairMethod::ScalarMatrixProductGeneral => {
313                Box::new(ScalarMatrixProductGeneral::new(&reduced_sc))
314            }
315            PairMethod::MatrixScalarProduct => {
316                // Never gets returned in current implementation
317                Box::new(MatrixScalarProduct::new(&reduced_sc))
318            }
319            PairMethod::MatrixScalarProductGeneral => {
320                Box::new(MatrixScalarProductGeneral::new(&reduced_sc))
321            }
322            PairMethod::TensordotFixedPosition => {
323                // Never gets returned in current implementation
324                Box::new(TensordotFixedPosition::new(&reduced_sc))
325            }
326            PairMethod::TensordotGeneral => Box::new(TensordotGeneral::new(&reduced_sc)),
327            PairMethod::StackedTensordotGeneral => {
328                Box::new(StackedTensordotGeneral::new(&reduced_sc))
329            }
330            PairMethod::BroadcastProductGeneral => {
331                // Never gets returned in current implementation
332                Box::new(BroadcastProductGeneral::new(&reduced_sc))
333            }
334        };
335        PairContraction {
336            lhs_simplification,
337            rhs_simplification,
338            method,
339            op,
340            simplified_einsum_string: reduced_sc.as_einsum_string(),
341        }
342    }
343}
344
345impl<A> PairContractor<A> for PairContraction<A> {
346    fn contract_pair<'a, 'b, 'c, 'd>(
347        &self,
348        lhs: &'b ArrayViewD<'a, A>,
349        rhs: &'d ArrayViewD<'c, A>,
350    ) -> ArrayD<A>
351    where
352        'a: 'b,
353        'c: 'd,
354        A: Clone + LinalgScalar,
355    {
356        match (&self.lhs_simplification, &self.rhs_simplification) {
357            (None, None) => self.op.contract_pair(lhs, rhs),
358            (Some(lhs_contraction), None) => self
359                .op
360                .contract_pair(&lhs_contraction.op.contract_singleton(lhs).view(), rhs),
361            (None, Some(rhs_contraction)) => self
362                .op
363                .contract_pair(lhs, &rhs_contraction.op.contract_singleton(rhs).view()),
364            (Some(lhs_contraction), Some(rhs_contraction)) => self.op.contract_pair(
365                &lhs_contraction.op.contract_singleton(lhs).view(),
366                &rhs_contraction.op.contract_singleton(rhs).view(),
367            ),
368        }
369    }
370}
371
372impl<A> Debug for PairContraction<A> {
373    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
374        write!(
375            f,
376            "PairContraction {{ \
377             lhs_simplification: {:?}, \
378             rhs_simplification: {:?}, \
379             method: {:?}, \
380             op: {:?}, \
381             simplified_einsum_string: {:?}",
382            self.lhs_simplification,
383            self.rhs_simplification,
384            self.method,
385            self.op,
386            self.simplified_einsum_string
387        )
388    }
389}
390
391/// Either a singleton contraction, in the case of a single input operand, or a list of pair contractions,
392/// given two or more input operands
393#[derive(Debug)]
394pub enum EinsumPathSteps<A> {
395    /// A `SingletonContraction` consists of some combination of permutation of the input axes,
396    /// diagonalization of repeated indices, and summation across axes not present in the output
397    SingletonContraction(SingletonContraction<A>),
398
399    /// Each `PairContraction` consists of a possible simplification of each of the two input tensors followed
400    /// by a contraction of the two simplified tensors. The two simplified tensors can be combined in a
401    /// number of fashions.
402    PairContractions(Vec<PairContraction<A>>),
403}
404
405/// An `EinsumPath`, returned by [`einsum_path`](fn.einsum_path.html), represents a fully-prepared plan to perform a tensor contraction.
406///
407/// It contains the order in which the input tensors should be contracted with one another or with one of the previous intermediate results,
408/// and for each step in the path, how to perform the pairwise contraction. For example, two tensors might be contracted
409/// with one another by computing the Hadamard (element-wise) product of the tensors, while a different pair might be contracted
410/// by performing a matrix multiplication. The contractions that will be performed are fully specified within the `EinsumPath`.
411pub struct EinsumPath<A> {
412    /// The order in which tensors should be paired off and contracted with one another
413    pub contraction_order: ContractionOrder,
414
415    /// The details of the contractions to be performed
416    pub steps: EinsumPathSteps<A>,
417}
418
419impl<A> EinsumPath<A> {
420    pub fn new(sc: &SizedContraction) -> Self {
421        let contraction_order = generate_optimized_order(sc, OptimizationMethod::Naive);
422
423        EinsumPath::from_path(&contraction_order)
424    }
425
426    pub fn from_path(contraction_order: &ContractionOrder) -> Self {
427        match contraction_order {
428            ContractionOrder::Singleton(sized_contraction) => EinsumPath {
429                contraction_order: contraction_order.clone(),
430                steps: EinsumPathSteps::SingletonContraction(SingletonContraction::new(
431                    sized_contraction,
432                )),
433            },
434            ContractionOrder::Pairs(order_steps) => {
435                let mut steps = Vec::new();
436
437                for step in order_steps.iter() {
438                    steps.push(PairContraction::new(&step.sized_contraction));
439                }
440
441                EinsumPath {
442                    contraction_order: contraction_order.clone(),
443                    steps: EinsumPathSteps::PairContractions(steps),
444                }
445            }
446        }
447    }
448}
449
450impl<A> EinsumPath<A> {
451    pub fn contract_operands(&self, operands: &[&dyn ArrayLike<A>]) -> ArrayD<A>
452    where
453        A: Clone + LinalgScalar,
454    {
455        // Uncomment for help debugging
456        // println!("{:?}", self);
457        match (&self.steps, &self.contraction_order) {
458            (EinsumPathSteps::SingletonContraction(c), ContractionOrder::Singleton(_)) => {
459                c.contract_singleton(&operands[0].into_dyn_view())
460            }
461            (EinsumPathSteps::PairContractions(steps), ContractionOrder::Pairs(order_steps)) => {
462                let mut intermediate_results: Vec<ArrayD<A>> = Vec::new();
463                for (step, order_step) in steps.iter().zip(order_steps.iter()) {
464                    let lhs = match order_step.operand_nums.lhs {
465                        OperandNumber::Input(pos) => operands[pos].into_dyn_view(),
466                        OperandNumber::IntermediateResult(pos) => intermediate_results[pos].view(),
467                    };
468                    let rhs = match order_step.operand_nums.rhs {
469                        OperandNumber::Input(pos) => operands[pos].into_dyn_view(),
470                        OperandNumber::IntermediateResult(pos) => intermediate_results[pos].view(),
471                    };
472                    let intermediate_result = step.contract_pair(&lhs, &rhs);
473                    // let lhs = match order_step.
474                    intermediate_results.push(intermediate_result);
475                }
476                intermediate_results.pop().unwrap()
477            }
478            _ => panic!(), // steps and contraction_order don't match
479        }
480    }
481}
482
483impl<A> Debug for EinsumPath<A> {
484    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
485        match &self.steps {
486            EinsumPathSteps::SingletonContraction(step) => write!(f, "only_step: {:?}", step),
487            EinsumPathSteps::PairContractions(steps) => write!(f, "steps: {:?}", steps),
488        }
489    }
490}