Crate ndarray_einsum_beta

The ndarray_einsum crate implements the einsum function, originally implemented for NumPy by Mark Wiebe and subsequently reimplemented for other tensor libraries such as TensorFlow and PyTorch. einsum (short for Einstein summation) implements general multidimensional tensor contraction. Many linear algebra operations, and generalizations of them, can be expressed as special cases of tensor contraction. Examples include matrix multiplication, matrix trace, vector dot product, tensor Hadamard (element-wise) product, axis permutation, outer product, batch matrix multiplication, bilinear transformations, and many more.
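To make "general multidimensional tensor contraction" concrete, here is a deliberately slow reference evaluator in plain Rust (standard library only). It is a minimal sketch, not how ndarray_einsum_beta works internally. The function naive_einsum and its operand format (index labels, shape, row-major data) are hypothetical, and only explicit output labels are accepted. It sums the product of the operand elements over every combination of index values.

```rust
use std::collections::HashMap;

/// Hypothetical reference evaluator (not part of ndarray_einsum_beta).
/// Each operand is (index labels, shape, row-major data); `output` lists the
/// labels of the result. Returns (result shape, row-major result data).
/// Panics on malformed input, e.g. an output label missing from every input.
pub fn naive_einsum(operands: &[(&str, &[usize], &[f64])], output: &str) -> (Vec<usize>, Vec<f64>) {
    // Record the length of every label, checking that all operands agree.
    let mut sizes: HashMap<char, usize> = HashMap::new();
    for (labels, shape, data) in operands {
        assert_eq!(labels.chars().count(), shape.len(), "label count != rank");
        assert_eq!(shape.iter().product::<usize>(), data.len(), "shape != data length");
        for (c, &n) in labels.chars().zip(shape.iter()) {
            assert_eq!(*sizes.entry(c).or_insert(n), n, "inconsistent length for {c}");
        }
    }
    // Output labels must be distinct; they come first, summed-over labels after.
    let mut all: Vec<char> = Vec::new();
    for c in output.chars() {
        assert!(!all.contains(&c), "repeated output label {c}");
        all.push(c);
    }
    for (labels, _, _) in operands {
        for c in labels.chars() {
            if !all.contains(&c) {
                all.push(c);
            }
        }
    }
    let dims: Vec<usize> = all.iter().map(|c| sizes[c]).collect();
    let out_shape = dims[..output.chars().count()].to_vec();
    let mut out = vec![0.0; out_shape.iter().product::<usize>()];
    // Position in `all` of each operand axis. A repeated label such as "ii"
    // maps both axes to the same position, which selects the diagonal.
    let positions: Vec<Vec<usize>> = operands
        .iter()
        .map(|(labels, _, _)| {
            labels.chars().map(|c| all.iter().position(|&a| a == c).unwrap()).collect()
        })
        .collect();
    // Visit every combination of label values, odometer style.
    let mut idx = vec![0usize; all.len()];
    for _ in 0..dims.iter().product::<usize>() {
        // Multiply the operand elements addressed by the current index values.
        let mut term = 1.0;
        for ((_, shape, data), pos) in operands.iter().zip(&positions) {
            let mut flat = 0;
            for (&p, &n) in pos.iter().zip(shape.iter()) {
                flat = flat * n + idx[p];
            }
            term *= data[flat];
        }
        // The output labels are the leading entries of `idx`.
        let mut flat_out = 0;
        for k in 0..out_shape.len() {
            flat_out = flat_out * dims[k] + idx[k];
        }
        out[flat_out] += term;
        // Advance the odometer; the last label varies fastest.
        for k in (0..idx.len()).rev() {
            idx[k] += 1;
            if idx[k] < dims[k] {
                break;
            }
            idx[k] = 0;
        }
    }
    (out_shape, out)
}
```

Each operation in the list above is a choice of labels. "ij,jk" with output "ik" is matrix multiplication, "ii" with output "" is the trace, and "ij" with output "ji" is a transpose. The work grows with the product of all index lengths, which is one motivation for breaking a contraction into pairwise steps.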

Examples (deliberately similar to NumPy's documentation):

use ndarray::prelude::*;
use ndarray_einsum_beta::*;

// Row-major matrices and a vector filled with 0., 1., 2., ...
let a: Array2<f64> = Array::range(0., 25., 1.)
    .into_shape((5, 5)).unwrap();
let b: Array1<f64> = Array::range(0., 5., 1.);
let c: Array2<f64> = Array::range(0., 6., 1.)
    .into_shape((2, 3)).unwrap();
let d: Array2<f64> = Array::range(0., 12., 1.)
    .into_shape((3, 4)).unwrap();

Trace of a matrix

assert_eq!(
    einsum("ii", &[&a]).unwrap(),
    arr0(60.).into_dyn()
);
assert_eq!(
    einsum("ii", &[&a]).unwrap(),
    arr0(a.diag().sum()).into_dyn()
);

Extract the diagonal

assert_eq!(
    einsum("ii->i", &[&a]).unwrap(),
    arr1(&[0., 6., 12., 18., 24.]).into_dyn()
);
assert_eq!(
    einsum("ii->i", &[&a]).unwrap(),
    a.diag().into_dyn()
);

Sum over an axis

assert_eq!(
    einsum("ij->i", &[&a]).unwrap(),
    arr1(&[10., 35., 60., 85., 110.]).into_dyn()
);
assert_eq!(
    einsum("ij->i", &[&a]).unwrap(),
    a.sum_axis(Axis(1)).into_dyn()
);

Compute matrix transpose

assert_eq!(
    einsum("ji", &[&c]).unwrap(),
    c.t().into_dyn()
);
assert_eq!(
    einsum("ji", &[&c]).unwrap(),
    arr2(&[[0., 3.], [1., 4.], [2., 5.]]).into_dyn()
);
assert_eq!(
    einsum("ji", &[&c]).unwrap(),
    einsum("ij->ji", &[&c]).unwrap()
);
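The "ji" examples use implicit mode: when "->" is omitted, the output labels are inferred. The results above (a transpose for "ji", a scalar for "ii") agree with NumPy's rule, under which the output keeps every label that appears exactly once, sorted alphabetically. The sketch below implements that rule. implicit_output is a hypothetical helper, and the assumption that the crate follows NumPy's rule in every case is inferred from these examples rather than stated by them.

```rust
/// Hypothetical helper: output labels for an einsum string without "->",
/// assuming NumPy's implicit-mode rule (labels that occur exactly once
/// across all inputs, in alphabetical order).
pub fn implicit_output(inputs: &str) -> String {
    let mut once: Vec<char> = inputs
        .chars()
        // Skip the commas that separate operands.
        .filter(|c| c.is_ascii_alphabetic())
        // Keep labels that are not summed over (not repeated anywhere).
        .filter(|&c| inputs.chars().filter(|&d| d == c).count() == 1)
        .collect();
    once.sort_unstable();
    once.into_iter().collect()
}
```

Under this rule "ji" becomes "ji->ij" (a transpose), "ii" becomes "ii->" (a trace), and "ij,jk" becomes "ij,jk->ik" (a matrix product).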

Multiply two matrices

assert_eq!(
    einsum("ij,jk->ik", &[&c, &d]).unwrap(),
    c.dot(&d).into_dyn()
);

Compute the path separately from the result

let path = einsum_path(
    "ij,jk->ik",
    &[&c, &d],
    OptimizationMethod::Naive
).unwrap();
assert_eq!(
    path.contract_operands(&[&c, &d]),
    c.dot(&d).into_dyn()
);

Modules

contractors

Implementations of the base-case singleton and pair contractors for different types of contractions.

optimizers

Methods to produce a ContractionOrder, specifying the order in which to perform pairwise contractions between tensors to carry out the full contraction.

slow_versions

Very inefficient implementations, expected to be removed or used only for testing.

validation

Contains functions and structs related to parsing an einsum-formatted string
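As an illustration of the checks a parser for einsum-formatted strings typically performs, the following sketch handles explicit strings such as "ij,jk->ik". Parsed and parse_explicit are hypothetical names, not the validation module's API. The rules enforced (ASCII letter labels, distinct output labels, every output label present in some input) are common einsum conventions, not a list of the crate's exact checks.

```rust
/// Hypothetical parsed form of an explicit einsum string such as "ij,jk->ik".
#[derive(Debug, PartialEq)]
pub struct Parsed {
    pub inputs: Vec<String>,
    pub output: String,
}

/// Minimal sketch: split on "->" and ",", then validate the labels.
pub fn parse_explicit(s: &str) -> Result<Parsed, String> {
    let (lhs, rhs) = s
        .split_once("->")
        .ok_or_else(|| format!("missing \"->\" in {s:?}"))?;
    let inputs: Vec<String> = lhs.split(',').map(|t| t.trim().to_string()).collect();
    let output = rhs.trim().to_string();
    // Every label, in inputs and output, must be an ASCII letter.
    for label in inputs.iter().flat_map(|t| t.chars()).chain(output.chars()) {
        if !label.is_ascii_alphabetic() {
            return Err(format!("invalid index label {label:?}"));
        }
    }
    // Output labels must be unique and must come from some input.
    for (i, c) in output.chars().enumerate() {
        if output.chars().skip(i + 1).any(|d| d == c) {
            return Err(format!("output index {c:?} repeated"));
        }
        if !inputs.iter().any(|t| t.contains(c)) {
            return Err(format!("output index {c:?} not found in any input"));
        }
    }
    Ok(Parsed { inputs, output })
}
```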

Structs

Contraction

A Contraction contains the result of parsing an einsum-formatted string.

EinsumPath

An EinsumPath, returned by einsum_path, represents a fully-prepared plan to perform a tensor contraction.

SizedContraction

A SizedContraction contains a Contraction as well as a HashMap<char, usize> specifying the axis lengths for each index in the contraction.
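The axis-length map held by a SizedContraction can be illustrated by a sketch that builds a HashMap<char, usize> from index strings and operand shapes. It rejects rank mismatches and any label bound to two different lengths. axis_lengths is a hypothetical function, not SizedContraction::new, and operands are described here only by their shapes.

```rust
use std::collections::HashMap;

/// Hypothetical sizing step: map each index label to its axis length.
pub fn axis_lengths(inputs: &[&str], shapes: &[&[usize]]) -> Result<HashMap<char, usize>, String> {
    if inputs.len() != shapes.len() {
        return Err(format!("{} index strings but {} operands", inputs.len(), shapes.len()));
    }
    let mut sizes = HashMap::new();
    for (labels, shape) in inputs.iter().zip(shapes) {
        // One label per axis.
        if labels.chars().count() != shape.len() {
            return Err(format!("{labels:?} does not match an operand of rank {}", shape.len()));
        }
        for (c, &n) in labels.chars().zip(shape.iter()) {
            // A label seen before (in any operand) must keep the same length.
            match sizes.insert(c, n) {
                Some(prev) if prev != n => {
                    return Err(format!("index {c:?} has lengths {prev} and {n}"));
                }
                _ => {}
            }
        }
    }
    Ok(sizes)
}
```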

Enums

ContractionOrder

The order in which to contract pairs of tensors and the specific contractions to be performed between the pairs.

EinsumPathSteps

Either a singleton contraction, in the case of a single input operand, or a list of pair contractions, given two or more input operands

OptimizationMethod

Strategy for optimizing the contraction. The only currently supported options are "Naive" and "Reverse".
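Why the order of pairwise contractions matters can be seen with a chain of matrices, where a dense (m × n)(n × p) product costs m·n·p multiply-adds. The sketch below counts that cost when contracting left to right versus right to left, each time reusing the intermediate result. Reading Naive as input order and Reverse as the opposite order is an assumption made for illustration, and chain_cost is a hypothetical helper.

```rust
/// Hypothetical cost model: multiply-adds needed for the matrix chain
/// M_0 (dims[0] x dims[1]), M_1 (dims[1] x dims[2]), ..., contracted pairwise
/// either left to right or right to left, always reusing the intermediate.
pub fn chain_cost(dims: &[usize], left_to_right: bool) -> usize {
    assert!(dims.len() >= 2, "need at least one matrix");
    let k = dims.len() - 1; // number of matrices
    let mut cost = 0;
    if left_to_right {
        // ((M_0 M_1) M_2) ...: the intermediate is dims[0] x dims[i].
        for i in 1..k {
            cost += dims[0] * dims[i] * dims[i + 1];
        }
    } else {
        // ... (M_{k-2} M_{k-1}): before step i the intermediate is dims[i+1] x dims[k].
        for i in (0..k - 1).rev() {
            cost += dims[i] * dims[i + 1] * dims[k];
        }
    }
    cost
}
```

For dimensions [10, 100, 5, 50] the left-to-right order needs 7500 multiply-adds and the right-to-left order needs 75000, a tenfold difference for the same result.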

Traits

ArrayLike

This trait is implemented for all ArrayBase variants and is parameterized by the data type.

Functions

einsum

Performs all steps of the process in one function: parse the string, compile the execution plan, and execute the contraction.

einsum_path

Create a SizedContraction, optimize the contraction order, and compile the result into an EinsumPath.

einsum_sc

Wrapper around SizedContraction::contract_operands.

generate_optimized_order

Given a SizedContraction and an optimization strategy, returns the order in which to perform pairwise contractions to produce the final result.

tensordot

Compute tensor dot product between two tensors.

validate

Wrapper around Contraction::new().

validate_and_optimize_order

Create a SizedContraction and then optimize the order in which pairs of inputs will be contracted.

validate_and_size

Wrapper around SizedContraction::new().
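tensordot is listed above without detail. A common way to compute a tensor dot product, used for example by NumPy's tensordot with an integer axes count, is to contract the last n axes of the first tensor with the first n axes of the second. Both tensors are viewed as matrices and multiplied. The sketch below shows that technique on row-major buffers. tensordot_last_first is a hypothetical name, and the crate's tensordot, which operates on ndarray arrays, may select axes differently.

```rust
/// Hypothetical sketch: contract the last `n` axes of `a` with the first `n`
/// axes of `b` by viewing `a` as an (M x K) matrix and `b` as (K x N).
pub fn tensordot_last_first(
    a: &[f64],
    a_shape: &[usize],
    b: &[f64],
    b_shape: &[usize],
    n: usize,
) -> (Vec<usize>, Vec<f64>) {
    assert!(n <= a_shape.len() && n <= b_shape.len(), "too many contracted axes");
    let (a_free, a_con) = a_shape.split_at(a_shape.len() - n);
    let (b_con, b_free) = b_shape.split_at(n);
    assert_eq!(a_con, b_con, "contracted axis lengths must match");
    assert_eq!(a.len(), a_shape.iter().product::<usize>());
    assert_eq!(b.len(), b_shape.iter().product::<usize>());
    let m: usize = a_free.iter().product();
    let k: usize = a_con.iter().product();
    let p: usize = b_free.iter().product();
    // Plain (M x K)(K x N) matrix product on the flattened buffers.
    let mut out = vec![0.0; m * p];
    for i in 0..m {
        for l in 0..k {
            let a_il = a[i * k + l];
            for j in 0..p {
                out[i * p + j] += a_il * b[l * p + j];
            }
        }
    }
    // Result shape: free axes of `a` followed by free axes of `b`.
    let shape = a_free.iter().chain(b_free).copied().collect();
    (shape, out)
}
```

With n = 1 on two matrices this is ordinary matrix multiplication. With n = 0 it is the outer product.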