Module dfdx::tensor_ops


Operations on tensors like relu(), matmul(), softmax(), and more.

Generic function and struct methods

All functionality is provided in two ways.

  1. The generic standalone function that takes a generic parameter, e.g. relu().
  2. The method on the tensor struct, e.g. crate::tensor::Tensor::relu().

The functions are all just pass-throughs to the tensor methods.
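
For instance, a minimal sketch (assuming the prelude import and a Cpu device) showing that the two forms are interchangeable:

use dfdx::prelude::*;
let dev: Cpu = Default::default();
let t: Tensor<Rank1<3>, f32, _> = dev.tensor([-1.0, 0.0, 1.0]);
// the standalone function forwards to the method
let a = relu(t.clone());
let b = t.relu();
assert_eq!(a.array(), b.array());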

Fallibility

All tensor methods also have a try_* variant, like crate::tensor::Tensor::relu() and crate::tensor::Tensor::try_relu().

These methods return a Result, where the error in most cases indicates an allocation error.
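
A minimal sketch of handling the Result (assuming a Cpu device; the exact error type is device-specific):

use dfdx::prelude::*;
let dev: Cpu = Default::default();
let t: Tensor<Rank1<3>, f32, _> = dev.tensor([-1.0, 0.0, 1.0]);
// the try_* variant surfaces allocation failures instead of panicking
match t.try_relu() {
    Ok(r) => println!("{:?}", r.array()),
    Err(e) => eprintln!("relu failed: {:?}", e),
}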

Axes/Dimensions for broadcasting/reductions/selecting

In the following sections, some traits/functions use a const isize to determine the axis to apply the transformation to.

Here are the valid axes for each tensor:

  1. 0d tensor: Axis<0>
  2. 1d tensor: Axis<0>
  3. 2d tensor: Axis<0>, Axis<1>
  4. 3d tensor: Axis<0>, Axis<1>, Axis<2>
  5. 4d tensor: Axis<0>, Axis<1>, Axis<2>, Axis<3>
  6. etc.

To specify multiple axes you can use Axes2, Axes3, and Axes4.
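
For instance, a sketch (assuming the prelude and a Cpu device) that reduces all three axes of a 3d tensor at once:

use dfdx::prelude::*;
let dev: Cpu = Default::default();
let t: Tensor<Rank3<2, 4, 6>, f32, _> = dev.zeros();
// Axes3<0, 1, 2> collapses every axis, leaving a scalar (Rank0) tensor
let s = t.sum::<Rank0, Axes3<0, 1, 2>>();
assert_eq!(s.array(), 0.0);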

Reductions

There are a number of methods that reduce one or more axes. Anything that can be reduced can also be broadcast back to the original shape using BroadcastTo.

Each axis-reducing function has two generic parameters:

  1. The target shape
  2. The axes to reduce along

You only need to specify one of these! Generally it is better practice to specify the target shape; if that is ambiguous, specify the axes instead.

For example:

use dfdx::prelude::*;
let dev: Cpu = Default::default();
let t: Tensor<Rank3<2, 4, 6>, f32, _> = dev.zeros();
// shape version
let _ = t.clone().sum::<Rank1<4>, _>();
// axes version
let _ = t.clone().sum::<_, Axes2<0, 2>>();
// typed version
let _: Tensor<Rank1<4>, _, _> = t.sum();


Broadcasts

Broadcasting tensors is provided through the BroadcastTo trait. Similar to reductions, there are two generic parameters for broadcasting:

  1. (Required) The target shape
  2. (Usually optional) The axes of the result type to broadcast

You’ll only need to specify axes if the target shape makes the broadcast ambiguous; an ambiguous case is sketched after the examples below.

For example:

use dfdx::prelude::*;
let dev: Cpu = Default::default();
let t: Tensor<Rank1<4>, f32, _> = dev.zeros();
// shape version
let _ = t.clone().broadcast::<Rank3<2, 4, 6>, _>();
// typed version
let _: Tensor<Rank3<2, 4, 6>, _, _> = t.broadcast();

Rust can also infer the output type if you use it in another operation:

use dfdx::prelude::*;
let dev: Cpu = Default::default();
let big: Tensor<Rank2<2, 5>, f32, _> = dev.zeros();
let small: Tensor<Rank1<5>, f32, _> = dev.zeros();
// the broadcast target shape is inferred from the addition
let _ = big + small.broadcast();
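
When the target shape leaves the broadcast ambiguous, spell out the axes. A minimal sketch (the square Rank2<4, 4> target is chosen here purely for illustration):

use dfdx::prelude::*;
let dev: Cpu = Default::default();
let t: Tensor<Rank1<4>, f32, _> = dev.zeros();
// a Rank1<4> could occupy either axis of Rank2<4, 4>,
// so the new (broadcast) axis must be named explicitly
let _ = t.clone().broadcast::<Rank2<4, 4>, Axis<0>>();
let _ = t.broadcast::<Rank2<4, 4>, Axis<1>>();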

Permutes

Permuting has an identical interface to broadcasts/reductions:

use dfdx::prelude::*;
let dev: Cpu = Default::default();
let t: Tensor<Rank3<2, 3, 4>, f32, _> = dev.zeros();
// shape version
let _ = t.clone().permute::<Rank3<3, 4, 2>, _>();
// axes version
let _ = t.permute::<_, Axes3<1, 2, 0>>();
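
To see what a permute does to the underlying data, a short sketch (values chosen for illustration) that transposes a 2d tensor:

use dfdx::prelude::*;
let dev: Cpu = Default::default();
let t: Tensor<Rank2<2, 3>, f32, _> = dev.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
// swapping the two axes transposes the tensor
let r = t.permute::<Rank2<3, 2>, _>();
assert_eq!(r.array(), [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]);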

Indexing using select and gather

Two traits provide indexing capability: SelectTo and GatherTo. The difference is:

  1. SelectTo::select allows you to select a single value from an axis.
  2. GatherTo::gather allows you to select multiple values from the same axis.

For example, you can select from the 0th axis like so:

use dfdx::prelude::*;
let dev: Cpu = Default::default();
let t = dev.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
let r: Tensor<Rank1<3>, f32, _> = t.select(dev.tensor(1));
assert_eq!(r.array(), [4.0, 5.0, 6.0]);

Or you can gather from the 0th axis to select multiple entries:

use dfdx::prelude::*;
let dev: Cpu = Default::default();
let t = dev.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
let r: Tensor<Rank2<3, 3>, f32, _> = t.gather(dev.tensor([1, 1, 0]));
assert_eq!(r.array(), [
    [4.0, 5.0, 6.0],
    [4.0, 5.0, 6.0],
    [1.0, 2.0, 3.0],
]);

To select from anything after the 0th axis, you need a multi-dimensional index tensor. See the GatherTo and SelectTo docstrings for examples of this.
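
As a brief sketch (index values chosen for illustration), selecting along the 1st axis of a 2d tensor takes one index per row:

use dfdx::prelude::*;
let dev: Cpu = Default::default();
let t = dev.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
// one index per row: r[i] == t[i][idx[i]]
let r: Tensor<Rank1<2>, f32, _> = t.select(dev.tensor([0, 2]));
assert_eq!(r.array(), [1.0, 6.0]);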

But you can use BroadcastTo to make this easy! In this example we select the same index from the 1st axis of a tensor:

use dfdx::prelude::*;
let dev: Cpu = Default::default();
let t: Tensor<Rank2<2, 3>, f32, _> = dev.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
let r = t.select::<Rank1<2>, _>(dev.tensor(1).broadcast());
assert_eq!(r.array(), [2.0, 5.0]);
