ndarray_einsum/lib.rs
// Copyright 2019 Jared Samet
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! The `ndarray_einsum` crate implements the `einsum` function, originally
//! implemented for numpy by Mark Wiebe and subsequently reimplemented for
//! other tensor libraries such as TensorFlow and PyTorch. `einsum` (short for Einstein summation)
//! implements general multidimensional tensor contraction. Many linear algebra operations
//! and generalizations of those operations can be expressed as special cases of tensor
//! contraction. Examples include matrix multiplication, matrix trace, vector dot product,
//! tensor Hadamard (element-wise) product, axis permutation, outer product, batch
//! matrix multiplication, bilinear transformations, and many more.
//!
//! Examples (deliberately similar to [numpy's documentation](https://docs.scipy.org/doc/numpy/reference/generated/numpy.einsum.html)):
//!
//! ```
//! # use ndarray_einsum::*;
//! # use ndarray::prelude::*;
//! let a: Array2<f64> = Array::range(0., 25., 1.)
//!     .into_shape((5, 5)).unwrap();
//! let b: Array1<f64> = Array::range(0., 5., 1.);
//! let c: Array2<f64> = Array::range(0., 6., 1.)
//!     .into_shape((2, 3)).unwrap();
//! let d: Array2<f64> = Array::range(0., 12., 1.)
//!     .into_shape((3, 4)).unwrap();
//! ```
//!
//! Trace of a matrix
//! ```
//! # use ndarray_einsum::*;
//! # use ndarray::prelude::*;
//! # let a: Array2<f64> = Array::range(0., 25., 1.)
//! #     .into_shape((5, 5)).unwrap();
//! assert_eq!(
//!     einsum("ii", &[&a]).unwrap(),
//!     arr0(60.).into_dyn()
//! );
//! assert_eq!(
//!     einsum("ii", &[&a]).unwrap(),
//!     arr0(a.diag().sum()).into_dyn()
//! );
//! ```
//!
//! Extract the diagonal
//! ```
//! # use ndarray_einsum::*;
//! # use ndarray::prelude::*;
//! # let a: Array2<f64> = Array::range(0., 25., 1.)
//! #     .into_shape((5, 5)).unwrap();
//! assert_eq!(
//!     einsum("ii->i", &[&a]).unwrap(),
//!     arr1(&[0., 6., 12., 18., 24.]).into_dyn()
//! );
//! assert_eq!(
//!     einsum("ii->i", &[&a]).unwrap(),
//!     a.diag().into_dyn()
//! );
//! ```
//!
//! Sum over an axis
//! ```
//! # use ndarray_einsum::*;
//! # use ndarray::prelude::*;
//! # let a: Array2<f64> = Array::range(0., 25., 1.)
//! #     .into_shape((5, 5)).unwrap();
//! assert_eq!(
//!     einsum("ij->i", &[&a]).unwrap(),
//!     arr1(&[10., 35., 60., 85., 110.]).into_dyn()
//! );
//! assert_eq!(
//!     einsum("ij->i", &[&a]).unwrap(),
//!     a.sum_axis(Axis(1)).into_dyn()
//! );
//! ```
//!
//! Compute matrix transpose
//! ```
//! # use ndarray_einsum::*;
//! # use ndarray::prelude::*;
//! # let c: Array2<f64> = Array::range(0., 6., 1.)
//! #     .into_shape((2, 3)).unwrap();
//! assert_eq!(
//!     einsum("ji", &[&c]).unwrap(),
//!     c.t().into_dyn()
//! );
//! assert_eq!(
//!     einsum("ji", &[&c]).unwrap(),
//!     arr2(&[[0., 3.], [1., 4.], [2., 5.]]).into_dyn()
//! );
//! assert_eq!(
//!     einsum("ji", &[&c]).unwrap(),
//!     einsum("ij->ji", &[&c]).unwrap()
//! );
//! ```
//!
//! Multiply two matrices
//! ```
//! # use ndarray_einsum::*;
//! # use ndarray::prelude::*;
//! # let c: Array2<f64> = Array::range(0., 6., 1.)
//! #     .into_shape((2, 3)).unwrap();
//! # let d: Array2<f64> = Array::range(0., 12., 1.)
//! #     .into_shape((3, 4)).unwrap();
//! assert_eq!(
//!     einsum("ij,jk->ik", &[&c, &d]).unwrap(),
//!     c.dot(&d).into_dyn()
//! );
//! ```
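//!
//! Vector dot and outer products (further examples in the same style as the
//! ones above, using the vector `b` from the setup; these assume
//! numpy-compatible handling of the `"i,i->"` and `"i,j->ij"` strings)
//! ```
//! # use ndarray_einsum::*;
//! # use ndarray::prelude::*;
//! # let b: Array1<f64> = Array::range(0., 5., 1.);
//! // Dot product: the shared index `i` is summed away, yielding a 0-d array.
//! assert_eq!(
//!     einsum("i,i->", &[&b, &b]).unwrap(),
//!     arr0(b.dot(&b)).into_dyn()
//! );
//! // Outer product: no index is repeated, so nothing is summed.
//! assert_eq!(
//!     einsum("i,j->ij", &[&b, &b]).unwrap(),
//!     arr2(&[
//!         [0., 0., 0., 0., 0.],
//!         [0., 1., 2., 3., 4.],
//!         [0., 2., 4., 6., 8.],
//!         [0., 3., 6., 9., 12.],
//!         [0., 4., 8., 12., 16.],
//!     ]).into_dyn()
//! );
//! ```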
//!
//! Compute the path separately from the result
//! ```
//! # use ndarray_einsum::*;
//! # use ndarray::prelude::*;
//! # let c: Array2<f64> = Array::range(0., 6., 1.)
//! #     .into_shape((2, 3)).unwrap();
//! # let d: Array2<f64> = Array::range(0., 12., 1.)
//! #     .into_shape((3, 4)).unwrap();
//! let path = einsum_path(
//!     "ij,jk->ik",
//!     &[&c, &d],
//!     OptimizationMethod::Naive
//! ).unwrap();
//! assert_eq!(
//!     path.contract_operands(&[&c, &d]),
//!     c.dot(&d).into_dyn()
//! );
//! ```
use ndarray::prelude::*;
use ndarray::{Data, IxDyn, LinalgScalar};

mod validation;
pub use validation::{
    validate, validate_and_optimize_order, validate_and_size, Contraction, SizedContraction,
};

mod optimizers;
pub use optimizers::{generate_optimized_order, ContractionOrder, OptimizationMethod};

mod contractors;
pub use contractors::{EinsumPath, EinsumPathSteps};
use contractors::{PairContractor, TensordotGeneral};

#[allow(clippy::wrong_self_convention)]
/// This trait is implemented for all `ArrayBase` variants and is parameterized by the data type.
///
/// It's here so that `einsum` and the other functions accepting a list of operands
/// can take a slice `&[&dyn ArrayLike<A>]` whose elements can have different
/// numbers of dimensions and can be a mixture of `Array` and `ArrayView`.
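///
/// For example (a minimal sketch, not from the original docs), an owned
/// `Array2` and an `ArrayView1` can appear in the same operand slice:
///
/// ```
/// # use ndarray_einsum::*;
/// # use ndarray::prelude::*;
/// let m = arr2(&[[1., 2.], [3., 4.]]);
/// let v = arr1(&[10., 20.]);
/// // Both `&m` (an Array2) and `&v.view()` (an ArrayView1) coerce to
/// // `&dyn ArrayLike<f64>`, despite having different types and dimensions.
/// let result = einsum("ij,j->i", &[&m, &v.view()]).unwrap();
/// assert_eq!(result, arr1(&[50., 110.]).into_dyn());
/// ```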
pub trait ArrayLike<A> {
    fn into_dyn_view(&self) -> ArrayView<'_, A, IxDyn>;
}

impl<A, S, D> ArrayLike<A> for ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    fn into_dyn_view(&self) -> ArrayView<'_, A, IxDyn> {
        self.view().into_dyn()
    }
}

/// Wrapper around [SizedContraction::contract_operands](struct.SizedContraction.html#method.contract_operands).
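///
/// Useful when the same contraction will be performed repeatedly: the input
/// string can be validated and sized once and the result reused (a sketch;
/// the operand shapes must match those used when sizing):
///
/// ```
/// # use ndarray_einsum::*;
/// # use ndarray::prelude::*;
/// let x = arr2(&[[1., 2.], [3., 4.]]);
/// let y = arr2(&[[5., 6.], [7., 8.]]);
/// // Parse and size the contraction once...
/// let sc = validate_and_size("ij,jk->ik", &[&x, &y]).unwrap();
/// // ...then execute it (possibly many times) without re-parsing.
/// assert_eq!(einsum_sc(&sc, &[&x, &y]), x.dot(&y).into_dyn());
/// ```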
pub fn einsum_sc<A: LinalgScalar>(
    sized_contraction: &SizedContraction,
    operands: &[&dyn ArrayLike<A>],
) -> ArrayD<A> {
    sized_contraction.contract_operands(operands)
}

/// Create a [SizedContraction](struct.SizedContraction.html), optimize the contraction order, and compile the result into an [EinsumPath](struct.EinsumPath.html).
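///
/// The compiled path can be reused for repeated contractions with the same
/// index pattern (a sketch; the operand shapes must match those used to
/// build the path):
///
/// ```
/// # use ndarray_einsum::*;
/// # use ndarray::prelude::*;
/// let x = arr2(&[[1., 2.], [3., 4.]]);
/// let y = arr2(&[[5., 6.], [7., 8.]]);
/// let path = einsum_path("ij,jk->ik", &[&x, &y], OptimizationMethod::Naive).unwrap();
/// // Execute the precompiled plan against the operands.
/// assert_eq!(path.contract_operands(&[&x, &y]), x.dot(&y).into_dyn());
/// ```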
pub fn einsum_path<A>(
    input_string: &str,
    operands: &[&dyn ArrayLike<A>],
    optimization_strategy: OptimizationMethod,
) -> Result<EinsumPath<A>, &'static str> {
    let contraction_order =
        validate_and_optimize_order(input_string, operands, optimization_strategy)?;
    Ok(EinsumPath::from_path(&contraction_order))
}

/// Performs all steps of the process in one function: parse the string, compile the execution plan, and execute the contraction.
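///
/// Errors from parsing or shape-checking are returned as an `Err` rather
/// than panicking (a sketch illustrating the error path; see the
/// module-level examples for successful contractions):
///
/// ```
/// # use ndarray_einsum::*;
/// # use ndarray::prelude::*;
/// let x = arr2(&[[1., 2.], [3., 4.]]);
/// let v = arr1(&[1., 2., 3.]);
/// // The shared index `j` has size 2 in `x` but 3 in `v`, so validation fails.
/// assert!(einsum("ij,j->i", &[&x, &v]).is_err());
/// ```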
pub fn einsum<A: LinalgScalar>(
    input_string: &str,
    operands: &[&dyn ArrayLike<A>],
) -> Result<ArrayD<A>, &'static str> {
    let sized_contraction = validate_and_size(input_string, operands)?;
    Ok(einsum_sc(&sized_contraction, operands))
}

/// Compute the tensor dot product of two tensors.
///
/// Similar to [the numpy function of the same name](https://docs.scipy.org/doc/numpy/reference/generated/numpy.tensordot.html).
/// Easiest to explain by showing the `einsum` equivalents:
///
/// ```
/// # use ndarray::prelude::*;
/// # use ndarray_einsum::*;
/// let m1 = Array::range(0., (3 * 4 * 5 * 6) as f64, 1.)
///     .into_shape((3, 4, 5, 6))
///     .unwrap();
/// let m2 = Array::range(0., (4 * 5 * 6 * 7) as f64, 1.)
///     .into_shape((4, 5, 6, 7))
///     .unwrap();
/// assert_eq!(
///     einsum(
///         "ijkl,jklm->im",
///         &[&m1, &m2]
///     ).unwrap(),
///     tensordot(
///         &m1,
///         &m2,
///         &[Axis(1), Axis(2), Axis(3)],
///         &[Axis(0), Axis(1), Axis(2)]
///     )
/// );
///
/// assert_eq!(
///     einsum(
///         "abic,dief->abcdef",
///         &[&m1, &m2]
///     ).unwrap(),
///     tensordot(
///         &m1,
///         &m2,
///         &[Axis(2)],
///         &[Axis(1)]
///     )
/// );
/// ```
pub fn tensordot<A, S, S2, D, E>(
    lhs: &ArrayBase<S, D>,
    rhs: &ArrayBase<S2, E>,
    lhs_axes: &[Axis],
    rhs_axes: &[Axis],
) -> ArrayD<A>
where
    A: LinalgScalar,
    S: Data<Elem = A>,
    S2: Data<Elem = A>,
    D: Dimension,
    E: Dimension,
{
    // Each contracted axis on the LHS must be paired with one on the RHS.
    assert_eq!(lhs_axes.len(), rhs_axes.len());
    // Convert the `Axis` wrappers into raw axis indices.
    let lhs_axes_copy: Vec<_> = lhs_axes.iter().map(|x| x.index()).collect();
    let rhs_axes_copy: Vec<_> = rhs_axes.iter().map(|x| x.index()).collect();
    // The output keeps the uncontracted axes in their default order: the
    // remaining `lhs` axes followed by the remaining `rhs` axes.
    let output_order: Vec<usize> = (0..(lhs.ndim() + rhs.ndim() - 2 * lhs_axes.len())).collect();
    let tensordotter = TensordotGeneral::from_shapes_and_axis_numbers(
        lhs.shape(),
        rhs.shape(),
        &lhs_axes_copy,
        &rhs_axes_copy,
        &output_order,
    );
    tensordotter.contract_pair(&lhs.view().into_dyn(), &rhs.view().into_dyn())
}