pub trait SelectTo<T, Axes> {
    type Indices: Clone;

    fn select(self, indices: &Self::Indices) -> T;
}

Select values along Axes, resulting in T. Equivalent to torch.select and torch.gather from PyTorch.

There are two ways to select:

  1. Select a single value from an axis, which removes that axis and returns a smaller tensor.
  2. Select multiple values from an axis, which keeps the number of dimensions the same. The same element may be selected multiple times.
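The two ways to select can be illustrated on a 1D toy type. This is a minimal sketch, not dfdx's implementation: it re-declares a trait with the same shape as SelectTo so it compiles without the crate, and Arr1 and Axis0 are hypothetical stand-ins for a tensor type and an axis marker. The output type T decides which impl applies, and therefore which Indices type the caller must pass.

```rust
// Hypothetical re-declaration with the same shape as dfdx's `SelectTo`,
// so this sketch compiles on its own. Not dfdx's implementation.
pub trait SelectTo<T, Axes> {
    type Indices: Clone;
    fn select(self, indices: &Self::Indices) -> T;
}

/// Hypothetical marker type standing in for "axis 0".
pub struct Axis0;

/// Hypothetical toy 1D tensor with a compile-time length `N`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Arr1<const N: usize>(pub [f32; N]);

// Way 1: a single index removes the axis, so a 1D input yields a scalar.
impl<const N: usize> SelectTo<f32, Axis0> for Arr1<N> {
    type Indices = usize;
    fn select(self, indices: &usize) -> f32 {
        self.0[*indices]
    }
}

// Way 2: K indices keep the axis, whose new length is K. Indices may repeat.
impl<const N: usize, const K: usize> SelectTo<Arr1<K>, Axis0> for Arr1<N> {
    type Indices = [usize; K];
    fn select(self, indices: &[usize; K]) -> Arr1<K> {
        Arr1(std::array::from_fn(|k| self.0[indices[k]]))
    }
}
```

In this sketch, annotating the result (for example `let x: f32 = a.select(&2);` versus `let y: Arr1<3> = a.select(&[4, 0, 4]);`) is what picks the impl, mirroring how the examples below annotate the output tensor type.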

You can also select batches of data with this trait.

Required Associated Types

type Indices: Clone

Required Methods

fn select(self, indices: &Self::Indices) -> T

Select sub-elements using Self::Indices. The same element can be selected multiple times, depending on Self::Indices.

Selecting a single value from 2D tensors:

// the examples assume the dfdx prelude is in scope
use dfdx::prelude::*;

// select a single element from the 0th axis; the 0th axis is removed
let _: Tensor1D<5> = Tensor2D::<3, 5>::zeros().select(&0);

// select a single element from the 1st axis of each row: one index per row
// (3 indices, matching the size of the 0th axis), each in 0..5
let _: Tensor1D<3> = Tensor2D::<3, 5>::zeros().select(&[0, 2, 4]);
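The index semantics of the two single-value cases can be modeled with plain nested arrays. This is a minimal sketch, not dfdx's implementation: select_axis0_single and select_axis1_single are hypothetical helpers, and [[f32; N]; M] stands in for Tensor2D<M, N>.

```rust
// Hypothetical helpers; `[[f32; N]; M]` stands in for a Tensor2D<M, N>.

/// Select one entry along axis 0 (one whole row). Axis 0 is removed: M x N -> N.
fn select_axis0_single<const M: usize, const N: usize>(
    t: &[[f32; N]; M],
    idx: usize,
) -> [f32; N] {
    t[idx]
}

/// Select one entry along axis 1 for every row. One index per row is
/// required, each in 0..N; axis 1 is removed: M x N -> M.
fn select_axis1_single<const M: usize, const N: usize>(
    t: &[[f32; N]; M],
    idx: &[usize; M],
) -> [f32; M] {
    std::array::from_fn(|r| t[r][idx[r]])
}
```

With t[r][c] = 10 * r + c, selecting &[0, 2, 4] along axis 1 picks t[0][0], t[1][2] and t[2][4], a diagonal-style pick that a single scalar index could not express.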

Selecting multiple values from 2D tensors:

// select multiple elements from the 0th axis; indices may repeat.
// the number of indices is the new size of the 0th axis.
let _: Tensor2D<6, 5> = Tensor2D::<3, 5>::zeros().select(&[0, 1, 2, 0, 1, 2]);
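Selecting multiple values along the 0th axis copies whole rows in index order, so repeated indices produce repeated rows. A minimal sketch on plain arrays, not dfdx's implementation; select_axis0_multi is a hypothetical helper.

```rust
/// Hypothetical helper: gather K rows along axis 0 of an M x N array.
/// Indices may repeat, and K becomes the new size of axis 0: M x N -> K x N.
fn select_axis0_multi<const M: usize, const N: usize, const K: usize>(
    t: &[[f32; N]; M],
    idx: &[usize; K],
) -> [[f32; N]; K] {
    // Each output row k is a copy of input row idx[k].
    std::array::from_fn(|k| t[idx[k]])
}
```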

// select multiple elements from the 1st axis of each row.
// there must be one inner index array per row (3, the size of the 0th axis),
// and the length of each inner array (2) is the new size of the 1st axis.
let _: Tensor2D<3, 2> = Tensor2D::<3, 5>::zeros().select(&[[0, 4], [1, 3], [2, 2]]);
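This case follows the torch.gather(input, 1, index) rule, out[r][k] = input[r][index[r][k]]. A minimal sketch on plain arrays, not dfdx's implementation; select_axis1_multi is a hypothetical helper.

```rust
/// Hypothetical helper: per-row gather along axis 1 of an M x N array.
/// `idx` has one row of K indices per input row (each in 0..N);
/// K becomes the new size of axis 1: M x N -> M x K.
fn select_axis1_multi<const M: usize, const N: usize, const K: usize>(
    t: &[[f32; N]; M],
    idx: &[[usize; K]; M],
) -> [[f32; K]; M] {
    // out[r][k] = t[r][idx[r][k]]
    std::array::from_fn(|r| std::array::from_fn(|k| t[r][idx[r][k]]))
}
```

With t[r][c] = 10 * r + c, the indices [[0, 4], [1, 3], [2, 2]] give rows [0, 4], [11, 13] and [22, 22]; the last row shows the same element picked twice.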

Selecting a batch of values from a 1D tensor:

// 2 index rows of length 1: a new batch axis of size 2 is prepended
let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select(&[[0], [1]]);
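Batched selection applies the multiple-value selection once per index row and stacks the results along a new leading axis. A minimal sketch on plain arrays, not dfdx's implementation; select_batch_1d is a hypothetical helper.

```rust
/// Hypothetical helper: batched selection from a length-N array.
/// Each of the B index rows picks K values (each index in 0..N),
/// producing a B x K result.
fn select_batch_1d<const N: usize, const B: usize, const K: usize>(
    t: &[f32; N],
    idx: &[[usize; K]; B],
) -> [[f32; K]; B] {
    // out[b][k] = t[idx[b][k]]
    std::array::from_fn(|b| std::array::from_fn(|k| t[idx[b][k]]))
}
```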

Selecting a batch of values from a 2D tensor:

// each of the 2 index rows selects 1 row of the input; the 1st axis (5) is kept
let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select(&[[0], [1]]);
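For a 2D input, each batched index picks a whole row along the 0th axis, so the trailing axis is carried through unchanged. A minimal sketch on plain arrays, not dfdx's implementation; select_batch_2d is a hypothetical helper.

```rust
/// Hypothetical helper: batched row selection from an M x N array.
/// Each of the B index rows picks K rows (each index in 0..M),
/// producing a B x K x N result.
fn select_batch_2d<const M: usize, const N: usize, const B: usize, const K: usize>(
    t: &[[f32; N]; M],
    idx: &[[usize; K]; B],
) -> [[[f32; N]; K]; B] {
    // out[b][k] is a copy of input row idx[b][k]
    std::array::from_fn(|b| std::array::from_fn(|k| t[idx[b][k]]))
}
```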

Implementors