Crate ndarray_einsum_beta

The ndarray_einsum_beta crate provides the einsum function, originally implemented for NumPy by Mark Wiebe and subsequently reimplemented for other tensor libraries such as TensorFlow and PyTorch. einsum (short for Einstein summation) performs general multidimensional tensor contraction. Many linear algebra operations, and generalizations of them, can be expressed as special cases of tensor contraction. Examples include matrix multiplication, matrix trace, vector dot product, the Hadamard (element-wise) product, axis permutation, outer product, batch matrix multiplication, and bilinear transformations.
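
To make the index notation concrete, the following is a minimal std-only sketch of what two contraction strings mean when written as explicit loops. It is not how the crate computes its results: the helper names `contract_ij_jk_ik` and `contract_ii` are hypothetical, and nested `Vec`s stand in for ndarray arrays. An index that appears in the inputs but not in the output is summed over.

```rust
/// Hypothetical reference for "ij,jk->ik" on row-major nested Vecs.
/// `j` appears in both inputs but not in the output, so it is summed over.
/// Shapes are assumed compatible: every row of `a` has `b.len()` entries.
fn contract_ij_jk_ik(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let ni = a.len();
    let nj = b.len();
    let nk = b.first().map_or(0, |row| row.len());
    let mut out = vec![vec![0.0; nk]; ni];
    for i in 0..ni {
        for j in 0..nj {
            for k in 0..nk {
                out[i][k] += a[i][j] * b[j][k];
            }
        }
    }
    out
}

/// Hypothetical reference for "ii": the repeated index walks the diagonal
/// and, being absent from the (empty) output, is summed over.
fn contract_ii(a: &[Vec<f64>]) -> f64 {
    (0..a.len()).map(|i| a[i][i]).sum()
}

fn main() {
    // A 2x3 matrix times a 3x4 matrix, both filled with 0, 1, 2, ...
    let c = vec![vec![0., 1., 2.], vec![3., 4., 5.]];
    let d = vec![
        vec![0., 1., 2., 3.],
        vec![4., 5., 6., 7.],
        vec![8., 9., 10., 11.],
    ];
    let product = contract_ij_jk_ik(&c, &d);
    assert_eq!(product[0], vec![20., 23., 26., 29.]);
    assert_eq!(product[1], vec![56., 68., 80., 92.]);

    // Trace of the 5x5 matrix holding 0..25 in row-major order.
    let a: Vec<Vec<f64>> = (0..5)
        .map(|r| (0..5).map(|col| (5 * r + col) as f64).collect())
        .collect();
    assert_eq!(contract_ii(&a), 60.0);
}
```

The same loop pattern generalizes: one loop per distinct index, with the output addressed by the indices named after `->`.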

Examples (deliberately similar to NumPy’s documentation):

use ndarray::prelude::*;
use ndarray_einsum_beta::*;

// Shared inputs for the examples below.
let a: Array2<f64> = Array::range(0., 25., 1.)
    .into_shape((5, 5)).unwrap();
let b: Array1<f64> = Array::range(0., 5., 1.);
let c: Array2<f64> = Array::range(0., 6., 1.)
    .into_shape((2, 3)).unwrap();
let d: Array2<f64> = Array::range(0., 12., 1.)
    .into_shape((3, 4)).unwrap();

Trace of a matrix

assert_eq!(
    einsum("ii", &[&a]).unwrap(),
    arr0(60.).into_dyn()
);
assert_eq!(
    einsum("ii", &[&a]).unwrap(),
    arr0(a.diag().sum()).into_dyn()
);

Extract the diagonal

assert_eq!(
    einsum("ii->i", &[&a]).unwrap(),
    arr1(&[0., 6., 12., 18., 24.]).into_dyn()
);
assert_eq!(
    einsum("ii->i", &[&a]).unwrap(),
    a.diag().into_dyn()
);

Sum over an axis

assert_eq!(
    einsum("ij->i", &[&a]).unwrap(),
    arr1(&[10., 35., 60., 85., 110.]).into_dyn()
);
assert_eq!(
    einsum("ij->i", &[&a]).unwrap(),
    a.sum_axis(Axis(1)).into_dyn()
);

Compute matrix transpose

assert_eq!(
    einsum("ji", &[&c]).unwrap(),
    c.t().into_dyn()
);
assert_eq!(
    einsum("ji", &[&c]).unwrap(),
    arr2(&[[0., 3.], [1., 4.], [2., 5.]]).into_dyn()
);
assert_eq!(
    einsum("ji", &[&c]).unwrap(),
    einsum("ij->ji", &[&c]).unwrap()
);

Multiply two matrices

assert_eq!(
    einsum("ij,jk->ik", &[&c, &d]).unwrap(),
    c.dot(&d).into_dyn()
);

Compute the path separately from the result

let path = einsum_path(
    "ij,jk->ik",
    &[&c, &d],
    OptimizationMethod::Naive
).unwrap();
assert_eq!(
    path.contract_operands(&[&c, &d]),
    c.dot(&d).into_dyn()
);

Modules

contractors 🔒
Implementations of the base-case singleton and pair contractors for different types of contractions.
optimizers 🔒
Methods to produce a ContractionOrder, which specifies the order in which to perform pairwise contractions between tensors to carry out the full contraction.
validation 🔒
Contains functions and structs related to parsing an einsum-formatted string.

Structs

Contraction
A Contraction contains the result of parsing an einsum-formatted string.
EinsumPath
An EinsumPath, returned by einsum_path, represents a fully-prepared plan to perform a tensor contraction.
SizedContraction
A SizedContraction contains a Contraction as well as a HashMap<char, usize> specifying the axis lengths for each index in the contraction.
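
As a rough illustration of the two stages described above (parsing the string, then attaching axis lengths), here is a std-only sketch. It is not the crate's implementation: `parse_explicit` and `axis_lengths` are hypothetical helpers, the sketch handles only strings that contain `->`, and plain shape vectors stand in for real operands. The only detail taken from the documentation is that sizes are kept in a `HashMap<char, usize>`.

```rust
use std::collections::HashMap;

/// Hypothetical helper (not the crate's API): split "ij,jk->ik" into
/// per-operand index lists and the output index list.
/// Returns None when the string has no "->".
fn parse_explicit(spec: &str) -> Option<(Vec<Vec<char>>, Vec<char>)> {
    let (lhs, rhs) = spec.split_once("->")?;
    let operands: Vec<Vec<char>> = lhs.split(',').map(|s| s.chars().collect()).collect();
    Some((operands, rhs.chars().collect()))
}

/// Hypothetical helper: map each index character to its axis length,
/// rejecting operand-count mismatches, rank mismatches, and conflicting lengths.
fn axis_lengths(
    operands: &[Vec<char>],
    shapes: &[Vec<usize>],
) -> Result<HashMap<char, usize>, String> {
    if operands.len() != shapes.len() {
        return Err(format!(
            "expected {} operands, got {}",
            operands.len(),
            shapes.len()
        ));
    }
    let mut sizes = HashMap::new();
    for (indices, shape) in operands.iter().zip(shapes) {
        if indices.len() != shape.len() {
            return Err(format!("rank mismatch: {:?} vs {:?}", indices, shape));
        }
        for (&idx, &len) in indices.iter().zip(shape) {
            // Record the length the first time an index is seen;
            // afterwards, require every occurrence to agree.
            let seen = *sizes.entry(idx).or_insert(len);
            if seen != len {
                return Err(format!("index '{}' has lengths {} and {}", idx, seen, len));
            }
        }
    }
    Ok(sizes)
}

fn main() {
    let (operands, output) = parse_explicit("ij,jk->ik").expect("missing \"->\"");
    assert_eq!(output, vec!['i', 'k']);

    // Shapes of a 2x3 and a 3x4 operand: `j` must be 3 in both.
    let sizes = axis_lengths(&operands, &[vec![2, 3], vec![3, 4]]).unwrap();
    assert_eq!((sizes[&'i'], sizes[&'j'], sizes[&'k']), (2, 3, 4));

    // A 2x3 operand cannot be contracted with a 5x4 operand over `j`.
    assert!(axis_lengths(&operands, &[vec![2, 3], vec![5, 4]]).is_err());
}
```

Separating the two stages means a parsed string can, in principle, be checked once and then sized against different operand shapes.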

Enums

ContractionOrder
The order in which to contract pairs of tensors and the specific contractions to be performed between the pairs.
EinsumPathSteps
Either a singleton contraction, in the case of a single input operand, or a list of pair contractions, given two or more input operands.
OptimizationMethod
Strategy for optimizing the contraction. The only currently supported options are “Naive” and “Reverse”.
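
The order of pairwise contractions does not change the result, but it can change the amount of work considerably. The std-only sketch below counts scalar multiplications for a matrix chain such as "ij,jk,kl->il" under two orders. It assumes, without confirmation from this page, that “Naive” corresponds to contracting operands left to right and “Reverse” to contracting them right to left; the function names are hypothetical and nothing here calls the crate.

```rust
/// Scalar multiplications needed to multiply a matrix chain left to right.
/// `dims` lists the chain's dimensions: matrix `m` has shape
/// dims[m] x dims[m + 1], so three matrices need four entries.
fn left_to_right_cost(dims: &[usize]) -> usize {
    // The running intermediate always keeps dims[0] rows.
    (1..dims.len().saturating_sub(1))
        .map(|k| dims[0] * dims[k] * dims[k + 1])
        .sum()
}

/// Scalar multiplications needed to multiply the same chain right to left.
fn right_to_left_cost(dims: &[usize]) -> usize {
    // The running intermediate always keeps the final column count.
    let last = match dims.last() {
        Some(&n) => n,
        None => return 0,
    };
    (0..dims.len().saturating_sub(2))
        .map(|k| dims[k] * dims[k + 1] * last)
        .sum()
}

fn main() {
    // Shapes 10x100, 100x5, and 5x50 for "ij,jk,kl->il".
    let dims = [10, 100, 5, 50];

    // (AB)C: 10*100*5 + 10*5*50
    assert_eq!(left_to_right_cost(&dims), 7_500);
    // A(BC): 100*5*50 + 10*100*50
    assert_eq!(right_to_left_cost(&dims), 75_000);
}
```

For these shapes the left-to-right order is ten times cheaper; for other shapes the opposite order can win, which is why the contraction order is worth treating as a separate, configurable step.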

Traits

ArrayLike
This trait is implemented for all ArrayBase variants and is parameterized by the data type.

Functions

einsum
Performs all steps of the process in one function: parse the string, compile the execution plan, and execute the contraction.
einsum_path
Create a SizedContraction, optimize the contraction order, and compile the result into an EinsumPath.
einsum_sc
Wrapper around SizedContraction::contract_operands.
generate_optimized_order
Given a SizedContraction and an optimization strategy, returns an order in which to perform pairwise contractions to produce the final result.
tensordot
Compute tensor dot product between two tensors.
validate
Wrapper around Contraction::new().
validate_and_optimize_order
Create a SizedContraction and then optimize the order in which pairs of inputs will be contracted.
validate_and_size
Wrapper around SizedContraction::new().