Trait dfdx::tensor_ops::GatherTo

pub trait GatherTo<E, D: Storage<E> + Storage<usize>>: HasErr + HasShape {
    // Required method
    fn try_gather<Dst: Shape, Idx: Shape>(
        self,
        idx: Tensor<Idx, usize, D>
    ) -> Result<Self::WithShape<Dst>, Self::Err>
       where Self::Shape: ReplaceDimTo<Dst, Idx>;

    // Provided method
    fn gather<Dst: Shape, Idx: Shape>(
        self,
        idx: Tensor<Idx, usize, D>
    ) -> Self::WithShape<Dst>
       where Self::Shape: ReplaceDimTo<Dst, Idx> { ... }
}
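GatherTo pairs a required fallible method (try_gather, returning Result) with a provided method (gather) that takes the same inputs but returns the value directly. The standalone sketch below reproduces that split on a hypothetical Vector type with a hypothetical GatherError. Neither type is part of dfdx, and treating an out-of-bounds index as the error case is only this sketch's choice; it says nothing about when dfdx's Self::Err is actually returned.

```rust
use std::fmt::Debug;

// Hypothetical error type for this sketch only (not a dfdx type).
#[derive(Debug, PartialEq)]
pub enum GatherError {
    IndexOutOfBounds { index: usize, len: usize },
}

// Mirrors the shape of GatherTo: one required fallible method and one
// provided method built on top of it.
pub trait Gather: Sized {
    type Output: Debug;

    // Required method: reports failure through Result.
    fn try_gather(self, idx: &[usize]) -> Result<Self::Output, GatherError>;

    // Provided method: same inputs, but panics if try_gather fails.
    fn gather(self, idx: &[usize]) -> Self::Output {
        self.try_gather(idx).unwrap()
    }
}

// Hypothetical 1-D container used to exercise the trait.
#[derive(Debug, PartialEq)]
pub struct Vector(pub Vec<f32>);

impl Gather for Vector {
    type Output = Vector;

    fn try_gather(self, idx: &[usize]) -> Result<Vector, GatherError> {
        let len = self.0.len();
        idx.iter()
            // Look up each requested position; fail on the first bad index.
            .map(|&i| {
                self.0
                    .get(i)
                    .copied()
                    .ok_or(GatherError::IndexOutOfBounds { index: i, len })
            })
            .collect::<Result<Vec<f32>, _>>()
            .map(Vector)
    }
}

fn main() {
    // Infallible-looking call: fine as long as every index is in range.
    let v = Vector(vec![10.0, 20.0, 30.0]);
    println!("{:?}", v.gather(&[2, 0, 0]));

    // Fallible call: the bad index comes back as an Err instead of a panic.
    let w = Vector(vec![1.0]);
    println!("{:?}", w.try_gather(&[3]));
}
```

Only try_gather has to be written per implementor; gather comes for free from the default method body.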

Select multiple values from a single axis, replacing that dimension with a different one. Similar to torch.gather from PyTorch.

The index shape is the tensor's shape up to, but not including, the axis being gathered, followed by the size of the new dimension.

For example, given a tensor of shape (M, N, O), here are the required index shapes to gather each axis:

  • Axis 0: index shape (Z,), result shape (Z, N, O)
  • Axis 1: index shape (M, Z), result shape (M, Z, O)
  • Axis 2: index shape (M, N, Z), result shape (M, N, Z)

where Z is the new dimension.
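The index-shape rule above can be sketched in plain Rust on flat, row-major buffers. This is not dfdx code: the function names gather_shape and gather are hypothetical, shapes are checked at runtime here (dfdx checks them through the ReplaceDimTo bound instead), and returning None for a shape mismatch or an out-of-range index is this sketch's own choice. The axis being gathered is inferred as the last dimension of the index shape.

```rust
/// Output shape under the rule above: the index shape is src[..axis] plus the
/// new dimension Z, so the output is the index shape followed by src[axis+1..].
fn gather_shape(src: &[usize], idx: &[usize]) -> Option<Vec<usize>> {
    let axis = idx.len().checked_sub(1)?; // the axis being replaced
    if axis >= src.len() || idx[..axis] != src[..axis] {
        return None; // index does not match the leading dims of src
    }
    let mut out = idx.to_vec();
    out.extend_from_slice(&src[axis + 1..]);
    Some(out)
}

/// Row-major gather: out[p, z, r] = src[p, idx[p, z], r], where p covers the
/// leading dims, z the new dim, and r the trailing dims.
fn gather(src: &[f32], src_shape: &[usize], idx: &[usize], idx_shape: &[usize]) -> Option<Vec<f32>> {
    // Buffers must match their declared shapes.
    if src.len() != src_shape.iter().product::<usize>()
        || idx.len() != idx_shape.iter().product::<usize>()
    {
        return None;
    }
    gather_shape(src_shape, idx_shape)?;
    let axis = idx_shape.len() - 1;
    let n = src_shape[axis]; // size of the gathered axis
    let z = idx_shape[axis]; // size of the new axis
    let inner: usize = src_shape[axis + 1..].iter().product(); // trailing block size
    let mut out = Vec::with_capacity(idx.len() * inner);
    // idx is laid out as (leading dims..., Z); walk it in row-major order.
    for (flat, &i) in idx.iter().enumerate() {
        if i >= n {
            return None; // out-of-range index
        }
        let outer = flat / z; // flat position within the leading dims
        let start = (outer * n + i) * inner;
        out.extend_from_slice(&src[start..start + inner]);
    }
    Some(out)
}

fn main() {
    // The (M, N, O) cases from the list above, with M = 2, N = 3, O = 4, Z = 7.
    println!("{:?}", gather_shape(&[2, 3, 4], &[7]));
    println!("{:?}", gather_shape(&[2, 3, 4], &[2, 7]));
    println!("{:?}", gather_shape(&[2, 3, 4], &[2, 3, 7]));

    // Values 0..6 laid out as a 2x3 tensor; gather rows 1, 1, 0 along axis 0.
    let src: Vec<f32> = (0..6).map(|x| x as f32).collect();
    println!("{:?}", gather(&src, &[2, 3], &[1, 1, 0], &[3]));
}
```

Because every index entry selects a whole trailing block of size inner, the loop copies contiguous slices rather than single elements.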

Here is an example gathering from a 2D tensor:

use dfdx::prelude::*;

// `dev` is a CPU device in this example.
let dev: Cpu = Default::default();
let a: Tensor<Rank2<3, 5>, f32, _> = dev.zeros();

// gather from the 0th axis; dimension 0 becomes 4
let idx: Tensor<Rank1<4>, usize, _> = dev.tensor([0, 0, 1, 2]);
let _: Tensor<Rank2<4, 5>, f32, _> = a.clone().gather(idx);

// gather from the 1st axis; dimension 1 becomes 2
let idx: Tensor<Rank2<3, 2>, usize, _> = dev.tensor([[0, 1], [2, 3], [4, 4]]);
let _: Tensor<Rank2<3, 2>, f32, _> = a.gather(idx);
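Since a is all zeros, the example above does not show which values end up where. The plain-Rust mirror below uses fixed-size arrays instead of dfdx tensors and fills a with 10 * row + col so the selections are visible. The functions gather_axis0 and gather_axis1 are hypothetical helpers for illustration; their const-generic array types follow the index shapes from the list above (Z for axis 0, (M, Z) for axis 1).

```rust
// Axis 0: index shape (Z,); each index picks a whole row of `a`.
fn gather_axis0<const M: usize, const N: usize, const Z: usize>(
    a: &[[f32; N]; M],
    idx: [usize; Z],
) -> [[f32; N]; Z] {
    idx.map(|r| a[r])
}

// Axis 1: index shape (M, Z); row `row` of the output picks columns from row `row` of `a`.
fn gather_axis1<const M: usize, const N: usize, const Z: usize>(
    a: &[[f32; N]; M],
    idx: [[usize; Z]; M],
) -> [[f32; Z]; M] {
    let mut out = [[0.0f32; Z]; M];
    for (row, cols) in idx.iter().enumerate() {
        for (z, &c) in cols.iter().enumerate() {
            out[row][z] = a[row][c];
        }
    }
    out
}

fn main() {
    // 3x5 matrix with a[r][c] = 10 * r + c, standing in for the zeros tensor.
    let a: [[f32; 5]; 3] = std::array::from_fn(|r| std::array::from_fn(|c| (10 * r + c) as f32));

    // Same indices as the dfdx example: shape 3x5 becomes 4x5.
    let rows = gather_axis0(&a, [0, 0, 1, 2]);
    // Same indices as the dfdx example: shape 3x5 becomes 3x2.
    let cols = gather_axis1(&a, [[0, 1], [2, 3], [4, 4]]);

    println!("{:?}", rows);
    println!("{:?}", cols);
}
```

Repeated indices (0 twice along axis 0, 4 twice in the last row along axis 1) simply copy the same source values more than once.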

Required Methods


fn try_gather<Dst: Shape, Idx: Shape>(self, idx: Tensor<Idx, usize, D>) -> Result<Self::WithShape<Dst>, Self::Err>
where Self::Shape: ReplaceDimTo<Dst, Idx>,

Provided Methods


fn gather<Dst: Shape, Idx: Shape>(self, idx: Tensor<Idx, usize, D>) -> Self::WithShape<Dst>
where Self::Shape: ReplaceDimTo<Dst, Idx>,

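Gather values given indices. This is the non-fallible counterpart of try_gather: it panics where try_gather would return an error.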

Implementors


impl<Src: Shape, E: Dtype, D: ReplaceDimKernel<E>, T: Tape<E, D>> GatherTo<E, D> for Tensor<Src, E, D, T>