ndarray_einsum/contractors/
singleton_contractors.rs

1// Copyright 2019 Jared Samet
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Contains the specific implementations of `SingletonContractor` and `SingletonViewer` that
//! represent the base-case ways to contract or simplify a single tensor.
//!
//! All the structs here perform some combination of permutation of the input axes
//! (e.g. `ijk->jki`), diagonalization across repeated but un-summed axes (e.g. `ii->i`),
//! and summation across axes not present in the output index list (e.g. `ijk->j`).
22
23use ndarray::prelude::*;
24use ndarray::LinalgScalar;
25
26use super::{SingletonContractor, SingletonViewer};
27use crate::{Contraction, SizedContraction};
28
29/// Returns a view or clone of the input tensor.
30///
31/// Example: `ij->ij`
32#[derive(Clone, Debug)]
33pub struct Identity {}
34
35impl Identity {
36    pub fn new(_sc: &SizedContraction) -> Self {
37        Identity {}
38    }
39}
40
41impl<A> SingletonViewer<A> for Identity {
42    fn view_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayViewD<'b, A>
43    where
44        'a: 'b,
45        A: Clone + LinalgScalar,
46    {
47        tensor.view()
48    }
49}
50
51impl<A> SingletonContractor<A> for Identity {
52    fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
53    where
54        'a: 'b,
55        A: Clone + LinalgScalar,
56    {
57        tensor.to_owned()
58    }
59}
60
61/// Permutes the axes of the input tensor and returns a view or clones the elements.
62///
63/// Example: `ij->ji`
64#[derive(Clone, Debug)]
65pub struct Permutation {
66    permutation: Vec<usize>,
67}
68
69impl Permutation {
70    pub fn new(sc: &SizedContraction) -> Self {
71        let SizedContraction {
72            contraction:
73                Contraction {
74                    operand_indices,
75                    output_indices,
76                    ..
77                },
78            ..
79        } = sc;
80
81        assert_eq!(operand_indices.len(), 1);
82        assert_eq!(operand_indices[0].len(), output_indices.len());
83
84        let mut permutation = Vec::new();
85        for &c in output_indices.iter() {
86            permutation.push(operand_indices[0].iter().position(|&x| x == c).unwrap());
87        }
88
89        Permutation { permutation }
90    }
91
92    pub fn from_indices(permutation: &[usize]) -> Self {
93        Permutation {
94            permutation: permutation.to_vec(),
95        }
96    }
97}
98
99impl<A> SingletonViewer<A> for Permutation {
100    fn view_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayViewD<'b, A>
101    where
102        'a: 'b,
103        A: Clone + LinalgScalar,
104    {
105        tensor.view().permuted_axes(IxDyn(&self.permutation))
106    }
107}
108
109impl<A> SingletonContractor<A> for Permutation {
110    fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
111    where
112        'a: 'b,
113        A: Clone + LinalgScalar,
114    {
115        tensor
116            .view()
117            .permuted_axes(IxDyn(&self.permutation))
118            .to_owned()
119    }
120}
121
122/// Sums across the elements of the input tensor that don't appear in the output tensor.
123///
124/// Example: `ij->i`
125#[derive(Clone, Debug)]
126pub struct Summation {
127    adjusted_axis_list: Vec<usize>,
128}
129
130impl Summation {
131    pub fn new(sc: &SizedContraction) -> Self {
132        let output_indices = &sc.contraction.output_indices;
133        let input_indices = &sc.contraction.operand_indices[0];
134
135        Summation::from_sizes(
136            output_indices.len(),
137            input_indices.len() - output_indices.len(),
138        )
139    }
140
141    fn from_sizes(start_index: usize, num_summed_axes: usize) -> Self {
142        assert!(num_summed_axes >= 1);
143        let adjusted_axis_list = (0..num_summed_axes).map(|_| start_index).collect();
144
145        Summation { adjusted_axis_list }
146    }
147}
148
149impl<A> SingletonContractor<A> for Summation {
150    fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
151    where
152        'a: 'b,
153        A: Clone + LinalgScalar,
154    {
155        let mut result = tensor.sum_axis(Axis(self.adjusted_axis_list[0]));
156        for &axis in self.adjusted_axis_list[1..].iter() {
157            result = result.sum_axis(Axis(axis));
158        }
159        result
160    }
161}
162
163/// Returns the elements of the input tensor where all instances of the repeated indices are equal to one another.
164/// Optionally permutes the axes of the tensor as well.
165///
166/// Examples:
167///
168/// 1. `ii->i`
169/// 2. `iij->ji`
170#[derive(Clone, Debug)]
171pub struct Diagonalization {
172    input_to_output_mapping: Vec<usize>,
173    output_shape: Vec<usize>,
174}
175
176impl Diagonalization {
177    pub fn new(sc: &SizedContraction) -> Self {
178        let SizedContraction {
179            contraction:
180                Contraction {
181                    operand_indices,
182                    output_indices,
183                    ..
184                },
185            output_size,
186        } = sc;
187        assert_eq!(operand_indices.len(), 1);
188
189        let mut adjusted_output_indices = output_indices.clone();
190        let mut input_to_output_mapping = Vec::new();
191        for &c in operand_indices[0].iter() {
192            let current_length = adjusted_output_indices.len();
193            match adjusted_output_indices.iter().position(|&x| x == c) {
194                Some(pos) => {
195                    input_to_output_mapping.push(pos);
196                }
197                None => {
198                    adjusted_output_indices.push(c);
199                    input_to_output_mapping.push(current_length);
200                }
201            }
202        }
203        let output_shape = adjusted_output_indices
204            .iter()
205            .map(|c| output_size[c])
206            .collect();
207
208        Diagonalization {
209            input_to_output_mapping,
210            output_shape,
211        }
212    }
213}
214
215impl<A> SingletonViewer<A> for Diagonalization {
216    fn view_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayViewD<'b, A>
217    where
218        'a: 'b,
219        A: Clone + LinalgScalar,
220    {
221        // Construct the stride array on the fly by enumerating (idx, stride) from strides() and
222        // adding stride to self.which_index_is_this
223        let mut strides = vec![0; self.output_shape.len()];
224        for (idx, &stride) in tensor.strides().iter().enumerate() {
225            assert!(stride > 0);
226            strides[self.input_to_output_mapping[idx]] += stride as usize;
227        }
228
229        // Output shape we want is already stored in self.output_shape
230        // let t = ArrayView::from_shape(IxDyn(&[3]).strides(IxDyn(&[4])), &sl).unwrap();
231        let data_slice = tensor.as_slice_memory_order().unwrap();
232        ArrayView::from_shape(
233            IxDyn(&self.output_shape).strides(IxDyn(&strides)),
234            data_slice,
235        )
236        .unwrap()
237    }
238}
239
240impl<A> SingletonContractor<A> for Diagonalization {
241    fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
242    where
243        'a: 'b,
244        A: Clone + LinalgScalar,
245    {
246        // We're only using this method if the tensor is not contiguous
247        // Clones twice as a result
248        let cloned_tensor: ArrayD<A> =
249            Array::from_shape_vec(tensor.raw_dim(), tensor.iter().cloned().collect()).unwrap();
250        self.view_singleton(&cloned_tensor.view()).into_owned()
251    }
252}
253
254/// Permutes the elements of the input tensor and sums across elements that don't appear in the output.
255///
256/// Example: `ijk->kj`
257#[derive(Clone, Debug)]
258pub struct PermutationAndSummation {
259    permutation: Permutation,
260    summation: Summation,
261}
262
263impl PermutationAndSummation {
264    pub fn new(sc: &SizedContraction) -> Self {
265        let mut output_order: Vec<usize> = Vec::new();
266
267        for &output_char in sc.contraction.output_indices.iter() {
268            let input_pos = sc.contraction.operand_indices[0]
269                .iter()
270                .position(|&input_char| input_char == output_char)
271                .unwrap();
272            output_order.push(input_pos);
273        }
274        for (i, &input_char) in sc.contraction.operand_indices[0].iter().enumerate() {
275            if !sc.contraction.output_indices.contains(&input_char) {
276                output_order.push(i);
277            }
278        }
279
280        let permutation = Permutation::from_indices(&output_order);
281        let summation = Summation::new(sc);
282
283        PermutationAndSummation {
284            permutation,
285            summation,
286        }
287    }
288}
289
290impl<A> SingletonContractor<A> for PermutationAndSummation {
291    fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
292    where
293        'a: 'b,
294        A: Clone + LinalgScalar,
295    {
296        let permuted_singleton = self.permutation.view_singleton(tensor);
297        self.summation.contract_singleton(&permuted_singleton)
298    }
299}
300
301/// Returns the elements of the input tensor where all instances of the repeated indices are equal
302/// to one another, optionally permuting the axes, and sums across indices that don't appear in the output.
303///
304/// Examples:
305///
306/// 1. `iijk->ik` (Diagonalizes the `i` axes and sums over `j`)
307/// 2. `jijik->ki` (Diagonalizes `i` and `j` and sums over `j` after diagonalization)
308#[derive(Clone, Debug)]
309pub struct DiagonalizationAndSummation {
310    diagonalization: Diagonalization,
311    summation: Summation,
312}
313
314impl DiagonalizationAndSummation {
315    pub fn new(sc: &SizedContraction) -> Self {
316        let diagonalization = Diagonalization::new(sc);
317        let summation = Summation::from_sizes(
318            sc.contraction.output_indices.len(),
319            diagonalization.output_shape.len() - sc.contraction.output_indices.len(),
320        );
321
322        DiagonalizationAndSummation {
323            diagonalization,
324            summation,
325        }
326    }
327}
328
329impl<A> SingletonContractor<A> for DiagonalizationAndSummation {
330    fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
331    where
332        'a: 'b,
333        A: Clone + LinalgScalar,
334    {
335        // We can only do Diagonalization directly as a view if all the strides are
336        // positive and if the tensor is contiguous. We can't know this just from
337        // looking at the SizedContraction; we need the actual tensor that will
338        // be operated on. So this needs to get checked at "runtime".
339        //
340        // If either condition fails, we use the contract_singleton version to
341        // create a new tensor and view() that intermediate result.
342        let contracted_singleton;
343        let viewed_singleton = if tensor.as_slice_memory_order().is_some()
344            && tensor.strides().iter().all(|&stride| stride > 0)
345        {
346            self.diagonalization.view_singleton(tensor)
347        } else {
348            contracted_singleton = self.diagonalization.contract_singleton(tensor);
349            contracted_singleton.view()
350        };
351
352        self.summation.contract_singleton(&viewed_singleton)
353    }
354}