The ndarray_einsum crate implements the einsum function, originally implemented for numpy by Mark Wiebe and subsequently reimplemented for other tensor libraries such as TensorFlow and PyTorch. einsum (short for Einstein summation) implements general multidimensional tensor contraction. Many linear algebra operations, and generalizations of those operations, can be expressed as special cases of tensor contraction. Examples include matrix multiplication, matrix trace, vector dot product, tensor Hadamard (element-wise) product, axis permutation, outer product, batch matrix multiplication, bilinear transformations, and many more.
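To make "general tensor contraction" concrete, here is a minimal std-only sketch of the idea. It is not the crate's implementation: every index letter ranges over its axis length, the matching operand entries are multiplied, and each product is added into the output entry named by the output indices, so any letter missing from the output is summed over. The Tensor struct (flat row-major storage) and the naive_einsum function are hypothetical names; the sketch requires an explicit "->" output and performs no validation.

```rust
use std::collections::HashMap;

/// Hypothetical dense tensor: row-major data plus its shape.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub shape: Vec<usize>,
    pub data: Vec<f64>,
}

/// Row-major flat offset of the multi-index `idx` within `shape`.
fn offset(shape: &[usize], idx: &[usize]) -> usize {
    idx.iter().zip(shape).fold(0, |acc, (&i, &n)| acc * n + i)
}

/// Naive einsum sketch: explicit "->" required, no error handling.
pub fn naive_einsum(spec: &str, operands: &[&Tensor]) -> Tensor {
    let (lhs, out) = spec.split_once("->").expect("sketch needs an explicit '->'");
    let inputs: Vec<Vec<char>> = lhs.split(',').map(|s| s.chars().collect()).collect();
    let out: Vec<char> = out.chars().collect();

    // Record the length of every index letter (assumes shapes are consistent).
    let mut sizes: HashMap<char, usize> = HashMap::new();
    for (labels, t) in inputs.iter().zip(operands) {
        for (&c, &n) in labels.iter().zip(&t.shape) {
            sizes.insert(c, n);
        }
    }
    // All distinct letters: output letters first, then the summed-over ones.
    let mut letters = out.clone();
    for labels in &inputs {
        for &c in labels {
            if !letters.contains(&c) {
                letters.push(c);
            }
        }
    }
    let dims: Vec<usize> = letters.iter().map(|c| sizes[c]).collect();
    let out_shape: Vec<usize> = out.iter().map(|c| sizes[c]).collect();
    let mut result = vec![0.0; out_shape.iter().product()];

    // Odometer over every assignment of values to the letters.
    let mut counter = vec![0usize; letters.len()];
    let total: usize = dims.iter().product();
    for _ in 0..total {
        let value = |c: &char| counter[letters.iter().position(|l| l == c).unwrap()];
        let mut prod = 1.0;
        for (labels, t) in inputs.iter().zip(operands) {
            let idx: Vec<usize> = labels.iter().map(value).collect();
            prod *= t.data[offset(&t.shape, &idx)];
        }
        let out_idx: Vec<usize> = out.iter().map(value).collect();
        result[offset(&out_shape, &out_idx)] += prod;
        // Advance the odometer (the last letter varies fastest).
        for k in (0..counter.len()).rev() {
            counter[k] += 1;
            if counter[k] < dims[k] {
                break;
            }
            counter[k] = 0;
        }
    }
    Tensor { shape: out_shape, data: result }
}
```

For the 2 by 2 matrix [[1, 2], [3, 4]], "ii->" yields a scalar tensor holding 5.0 (the trace), "ij->ji" permutes the axes, and "ij,jk->ik" reproduces matrix multiplication. The work grows with the product of all index lengths, which is why practical implementations break a contraction into cheaper steps.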
Examples (deliberately similar to numpy's documentation). The snippets below all use the arrays a, b, c, and d defined here:
```rust
use ndarray::prelude::*;
use ndarray_einsum::*;

let a: Array2<f64> = Array::range(0., 25., 1.)
    .into_shape((5, 5)).unwrap();
let b: Array1<f64> = Array::range(0., 5., 1.);
let c: Array2<f64> = Array::range(0., 6., 1.)
    .into_shape((2, 3)).unwrap();
let d: Array2<f64> = Array::range(0., 12., 1.)
    .into_shape((3, 4)).unwrap();
```
Trace of a matrix
```rust
assert_eq!(
    einsum("ii", &[&a]).unwrap(),
    arr0(60.).into_dyn()
);
assert_eq!(
    einsum("ii", &[&a]).unwrap(),
    arr0(a.diag().sum()).into_dyn()
);
```
Extract the diagonal
```rust
assert_eq!(
    einsum("ii->i", &[&a]).unwrap(),
    arr1(&[0., 6., 12., 18., 24.]).into_dyn()
);
assert_eq!(
    einsum("ii->i", &[&a]).unwrap(),
    a.diag().into_dyn()
);
```
Sum over an axis
```rust
assert_eq!(
    einsum("ij->i", &[&a]).unwrap(),
    arr1(&[10., 35., 60., 85., 110.]).into_dyn()
);
assert_eq!(
    einsum("ij->i", &[&a]).unwrap(),
    a.sum_axis(Axis(1)).into_dyn()
);
```
Compute matrix transpose
```rust
assert_eq!(
    einsum("ji", &[&c]).unwrap(),
    c.t().into_dyn()
);
assert_eq!(
    einsum("ji", &[&c]).unwrap(),
    arr2(&[[0., 3.], [1., 4.], [2., 5.]]).into_dyn()
);
assert_eq!(
    einsum("ji", &[&c]).unwrap(),
    einsum("ij->ji", &[&c]).unwrap()
);
```
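The "ii" and "ji" examples rely on implicit mode. In numpy's convention, when a string has no "->", the output consists of the indices that appear exactly once across all inputs, in alphabetical order, and repeated indices are summed. That is why "ji" yields the transpose (implicit output "ij") and "ii" yields a scalar. The helper below, implicit_output, is a hypothetical std-only sketch of that rule, not a function exported by the crate; it assumes the crate follows numpy here, which the examples above are consistent with.

```rust
use std::collections::BTreeMap;

/// Hypothetical helper: derive the implicit output indices for an
/// einsum string without "->", following numpy's convention.
fn implicit_output(inputs: &str) -> String {
    // BTreeMap iterates its keys in sorted (alphabetical) order.
    let mut counts: BTreeMap<char, usize> = BTreeMap::new();
    for c in inputs.chars().filter(|c| *c != ',') {
        *counts.entry(c).or_insert(0) += 1;
    }
    // Keep only the letters that appear exactly once; the rest are summed.
    counts
        .into_iter()
        .filter(|&(_, n)| n == 1)
        .map(|(c, _)| c)
        .collect()
}
```

Under this rule, "ij,jk" on its own has the implicit output "ik", matching the explicit "ij,jk->ik" used for matrix multiplication.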
Multiply two matrices
```rust
assert_eq!(
    einsum("ij,jk->ik", &[&c, &d]).unwrap(),
    c.dot(&d).into_dyn()
);
```
Compute the path separately from the result
```rust
let path = einsum_path(
    "ij,jk->ik",
    &[&c, &d],
    OptimizationMethod::Naive
).unwrap();
assert_eq!(
    path.contract_operands(&[&c, &d]),
    c.dot(&d).into_dyn()
);
```
Modules
- contractors (private): Implementations of the base-case singleton and pair contractors for different types of contractions.
- optimizers (private): Methods to produce a ContractionOrder, specifying the order in which to perform pairwise contractions between tensors to compute the full contraction.
- validation (private): Contains functions and structs related to parsing an einsum-formatted string.
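Parsing an einsum-formatted string amounts to splitting it into one index list per operand plus an optional output list, then checking a few invariants. The sketch below is a hypothetical std-only illustration of that step: ParsedSpec and parse_spec are invented names, not the private validation module's API, and the checks shown (letters only; every output index must appear in some input and at most once in the output) are the usual einsum rules rather than a transcription of the crate's code.

```rust
/// Hypothetical parse result: one index list per operand plus the
/// output indices (None when the string has no "->").
#[derive(Debug, PartialEq)]
struct ParsedSpec {
    inputs: Vec<Vec<char>>,
    output: Option<Vec<char>>,
}

fn parse_spec(spec: &str) -> Result<ParsedSpec, String> {
    // Split off the optional explicit output.
    let (lhs, rhs) = match spec.split_once("->") {
        Some((l, r)) => (l, Some(r)),
        None => (spec, None),
    };
    let inputs: Vec<Vec<char>> = lhs.split(',').map(|s| s.chars().collect()).collect();
    for c in inputs.iter().flatten() {
        if !c.is_ascii_alphabetic() {
            return Err(format!("invalid index character {c:?}"));
        }
    }
    let output = match rhs {
        None => None,
        Some(r) => {
            let out: Vec<char> = r.chars().collect();
            for (pos, c) in out.iter().enumerate() {
                // Each output index must come from some input...
                if !inputs.iter().any(|labels| labels.contains(c)) {
                    return Err(format!("output index {c:?} not found in any input"));
                }
                // ...and may appear only once in the output.
                if out[..pos].contains(c) {
                    return Err(format!("output index {c:?} repeated"));
                }
            }
            Some(out)
        }
    };
    Ok(ParsedSpec { inputs, output })
}
```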
Structs
- Contraction: A Contraction contains the result of parsing an einsum-formatted string.
- EinsumPath: An EinsumPath, returned by einsum_path, represents a fully prepared plan to perform a tensor contraction.
- SizedContraction: A SizedContraction contains a Contraction as well as a HashMap<char, usize> specifying the axis lengths for each index in the contraction.
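The HashMap<char, usize> held by a SizedContraction can be pictured as the result of walking each operand's index letters alongside its shape and requiring every occurrence of a letter to agree on its length. The std-only sketch below illustrates that bookkeeping under that assumption; axis_sizes is a hypothetical name and its error messages are invented, not the crate's.

```rust
use std::collections::HashMap;

/// Hypothetical helper: map each index letter to its axis length,
/// rejecting rank mismatches and inconsistent lengths.
fn axis_sizes(
    inputs: &[Vec<char>],
    shapes: &[Vec<usize>],
) -> Result<HashMap<char, usize>, String> {
    if inputs.len() != shapes.len() {
        return Err("number of index lists and operands differ".to_string());
    }
    let mut sizes = HashMap::new();
    for (labels, shape) in inputs.iter().zip(shapes) {
        if labels.len() != shape.len() {
            return Err(format!(
                "{} indices for an operand of rank {}",
                labels.len(),
                shape.len()
            ));
        }
        for (&c, &n) in labels.iter().zip(shape) {
            // The first occurrence fixes the length; later ones must agree.
            let known = *sizes.entry(c).or_insert(n);
            if known != n {
                return Err(format!("index {c:?} has lengths {known} and {n}"));
            }
        }
    }
    Ok(sizes)
}
```

With the shapes of c (2 by 3) and d (3 by 4) from the examples, "ij,jk" maps i to 2, j to 3, and k to 4; a d of shape 5 by 4 would be rejected because j would need two different lengths.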
Enums
- ContractionOrder: The order in which to contract pairs of tensors, and the specific contractions to be performed between the pairs.
- EinsumPathSteps: Either a singleton contraction, in the case of a single input operand, or a list of pair contractions, given two or more input operands.
- OptimizationMethod: Strategy for optimizing the contraction. The only currently supported options are "Naive" and "Reverse".
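Optimizing the contraction order matters because the sizes of the intermediate results, and therefore the total work, depend on which pairs are contracted first. The std-only sketch below counts scalar multiply-adds for a three-matrix chain A·B·C contracted left to right versus right to left. Reading "Naive" and "Reverse" as roughly these two orders is an assumption made for illustration only; the enum's description above names the options without defining them. The function names and matrix shapes are hypothetical.

```rust
/// Multiply-add count for multiplying an (m x k) matrix by a (k x n) matrix.
fn matmul_cost(m: usize, k: usize, n: usize) -> usize {
    m * k * n
}

/// Cost of (A·B)·C versus A·(B·C) for A: p x q, B: q x r, C: r x s.
fn chain_costs(p: usize, q: usize, r: usize, s: usize) -> (usize, usize) {
    // Left to right: A·B is p x r, then (A·B)·C.
    let left_to_right = matmul_cost(p, q, r) + matmul_cost(p, r, s);
    // Right to left: B·C is q x s, then A·(B·C).
    let right_to_left = matmul_cost(q, r, s) + matmul_cost(p, q, s);
    (left_to_right, right_to_left)
}
```

For shapes 10 x 100, 100 x 5, and 5 x 50, left to right needs 7,500 multiply-adds and right to left needs 75,000; reversing the chain of shapes flips which order wins, so no single fixed order is best for every input.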
Traits
- ArrayLike: This trait is implemented for all ArrayBase variants and is parameterized by the data type.
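A trait like ArrayLike appears to be what lets operands with different storage (owned arrays, views) travel together in one slice of trait objects, as in the calls einsum("ij,jk->ik", &[&c, &d]) above. The std-only sketch below shows the same pattern in miniature with a hypothetical ElementsLike trait implemented for an owned Vec<f64> and a borrowed slice. It is an analogy for the design, not the crate's trait, whose methods are not listed here.

```rust
/// Hypothetical trait: anything that can expose its elements as a slice.
trait ElementsLike<A> {
    fn elements(&self) -> &[A];
}

// Owned storage.
impl<A> ElementsLike<A> for Vec<A> {
    fn elements(&self) -> &[A] {
        self.as_slice()
    }
}

// Borrowed storage, loosely analogous to an array view.
impl<'a, A> ElementsLike<A> for &'a [A] {
    fn elements(&self) -> &[A] {
        self
    }
}

/// Takes operands as trait objects, so owned and borrowed inputs mix freely.
fn dot(operands: &[&dyn ElementsLike<f64>]) -> f64 {
    let (x, y) = (operands[0].elements(), operands[1].elements());
    x.iter().zip(y).map(|(a, b)| a * b).sum()
}
```

For example, dot(&[&owned, &borrowed]) accepts a Vec<f64> and a &[f64] in the same slice, which is the kind of flexibility the trait-object signature provides.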
Functions
- einsum: Performs all steps of the process in one function: parse the string, compile the execution plan, and execute the contraction.
- einsum_path: Create a SizedContraction, optimize the contraction order, and compile the result into an EinsumPath.
- einsum_sc: Wrapper around SizedContraction::contract_operands.
- generate_optimized_order: Given a SizedContraction and an optimization strategy, returns an order in which to perform pairwise contractions to produce the final result.
- tensordot: Compute the tensor dot product of two tensors.
- validate: Wrapper around Contraction::new().
- validate_and_optimize_order: Create a SizedContraction and then optimize the order in which pairs of inputs will be contracted.
- validate_and_size: Wrapper around SizedContraction::new().