Trait dfdx::tensor_ops::SelectTo

pub trait SelectTo<E, D: Storage<E> + Storage<usize>>: HasErr + HasShape {
    // Required method
    fn try_select<Dst: Shape, Idx: Shape>(
        self,
        idx: Tensor<Idx, usize, D>
    ) -> Result<Self::WithShape<Dst>, Self::Err>
       where Self::Shape: RemoveDimTo<Dst, Idx>;

    // Provided method
    fn select<Dst: Shape, Idx: Shape>(
        self,
        idx: Tensor<Idx, usize, D>
    ) -> Self::WithShape<Dst>
       where Self::Shape: RemoveDimTo<Dst, Idx> { ... }
}

Select a single element along one dimension, removing that dimension from the shape. Equivalent to torch.select from PyTorch.

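The shape of the index is the shape of the tensor up to, but not including, the axis you want to select from. For any axis after the first, the index therefore holds one position for each combination of leading indices, so each row can select a different entry.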

For example, given a tensor of shape (M, N, O), here are the required index shapes to select each axis:

  • Axis 0: index shape ()
  • Axis 1: index shape (M,)
  • Axis 2: index shape (M, N)
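To make the index-shape rule concrete, here is a plain-Rust sketch of what select computes on an (M, N, O) tensor for each axis. It does not use dfdx: nested fixed-size arrays stand in for tensors, and select_axis0, select_axis1 and select_axis2 are hypothetical helpers written for illustration, not dfdx functions. The semantics are inferred from the index shapes listed above and from the 2D example that follows.

```rust
// Plain-Rust model of `select` on an (M, N, O) tensor stored as nested arrays.
// Hypothetical helpers for illustration only; this is not the dfdx API.

/// Axis 0: index shape (), output shape (N, O).
fn select_axis0<T: Copy, const M: usize, const N: usize, const O: usize>(
    src: &[[[T; O]; N]; M],
    idx: usize,
) -> [[T; O]; N] {
    src[idx]
}

/// Axis 1: index shape (M,), output shape (M, O).
/// Row `m` of the output is `src[m][idx[m]]`.
fn select_axis1<T: Copy, const M: usize, const N: usize, const O: usize>(
    src: &[[[T; O]; N]; M],
    idx: &[usize; M],
) -> [[T; O]; M] {
    std::array::from_fn(|m| src[m][idx[m]])
}

/// Axis 2: index shape (M, N), output shape (M, N).
/// Element `(m, n)` of the output is `src[m][n][idx[m][n]]`.
fn select_axis2<T: Copy, const M: usize, const N: usize, const O: usize>(
    src: &[[[T; O]; N]; M],
    idx: &[[usize; N]; M],
) -> [[T; N]; M] {
    std::array::from_fn(|m| std::array::from_fn(|n| src[m][n][idx[m][n]]))
}

fn main() {
    // Shape (2, 3, 4) with src[m][n][o] = 100*m + 10*n + o.
    let src: [[[u32; 4]; 3]; 2] = std::array::from_fn(|m| {
        std::array::from_fn(|n| std::array::from_fn(|o| (100 * m + 10 * n + o) as u32))
    });

    // Axis 0: the whole (3, 4) slab at m = 1.
    println!("{:?}", select_axis0(&src, 1));
    // Axis 1: row 2 of slab 0, row 0 of slab 1.
    println!("{:?}", select_axis1(&src, &[2, 0])); // [[20, 21, 22, 23], [100, 101, 102, 103]]
    // Axis 2: one element per (m, n) position.
    println!("{:?}", select_axis2(&src, &[[3, 0, 1], [2, 2, 0]])); // [[3, 10, 21], [102, 112, 120]]
}
```

In each case the output shape is the input shape with the selected axis removed, and the index shape is exactly the prefix of the input shape before that axis.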

Here is an example selecting from a 2D tensor:

use dfdx::prelude::*;

let dev: Cpu = Default::default();
let a: Tensor<Rank2<3, 5>, f32, _> = dev.zeros();

// select from axis 0: the index is a scalar (shape ())
let idx: Tensor<Rank0, usize, _> = dev.tensor(0);
let _: Tensor<Rank1<5>, f32, _> = a.clone().select(idx);

// select from axis 1: the index has shape (3,), one column per row
let idx: Tensor<Rank1<3>, usize, _> = dev.tensor([0, 2, 4]);
let _: Tensor<Rank1<3>, f32, _> = a.select(idx);
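Because the example starts from dev.zeros(), it does not show which elements are actually picked. The plain-Rust sketch below fills a (3, 5) array with 0.0 through 14.0 and applies the same two selections. It uses nested arrays instead of dfdx tensors, and select_rows and select_cols are hypothetical helper names, not dfdx functions.

```rust
// Plain-Rust model of the 2D example: `a` has shape (3, 5).
// Hypothetical helpers for illustration only; this is not the dfdx API.

/// Axis 0 with a scalar index: returns row `idx`, shape (C,).
fn select_rows<const R: usize, const C: usize>(a: &[[f32; C]; R], idx: usize) -> [f32; C] {
    a[idx]
}

/// Axis 1 with an (R,)-shaped index: picks column `idx[r]` from each row `r`, shape (R,).
fn select_cols<const R: usize, const C: usize>(a: &[[f32; C]; R], idx: &[usize; R]) -> [f32; R] {
    std::array::from_fn(|r| a[r][idx[r]])
}

fn main() {
    // a[r][c] = 5*r + c, i.e. 0.0..=14.0 laid out row-major.
    let a: [[f32; 5]; 3] = std::array::from_fn(|r| std::array::from_fn(|c| (5 * r + c) as f32));

    println!("{:?}", select_rows(&a, 0)); // [0.0, 1.0, 2.0, 3.0, 4.0]
    println!("{:?}", select_cols(&a, &[0, 2, 4])); // [0.0, 7.0, 14.0]
}
```

The axis-1 result takes a[0][0], a[1][2] and a[2][4], which is why an index of shape (3,) produces an output of shape (3,).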

Required Methods

fn try_select<Dst: Shape, Idx: Shape>(
    self,
    idx: Tensor<Idx, usize, D>
) -> Result<Self::WithShape<Dst>, Self::Err>
where
    Self::Shape: RemoveDimTo<Dst, Idx>,

Fallible version of select: any error is returned as Self::Err instead of causing a panic.

Provided Methods

fn select<Dst: Shape, Idx: Shape>(
    self,
    idx: Tensor<Idx, usize, D>
) -> Self::WithShape<Dst>
where
    Self::Shape: RemoveDimTo<Dst, Idx>,

Select values given indices. Infallible version of try_select; panics if try_select would return an error.
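The split between the required try_select and the provided select follows a common Rust pattern: implementors write only the fallible method, and the trait supplies an infallible default that unwraps it. The toy trait below sketches that pattern, assuming dfdx's provided select is a thin wrapper that unwraps try_select. SelectLike and OutOfBounds are hypothetical names, and the out-of-bounds error is chosen purely for illustration; it is not a statement about which errors dfdx reports through Self::Err.

```rust
// Toy version of the required/provided method split used by `SelectTo`.
// `SelectLike` and `OutOfBounds` are hypothetical; they are not dfdx types.

#[derive(Debug, PartialEq)]
struct OutOfBounds {
    index: usize,
    len: usize,
}

trait SelectLike: Sized {
    type Item;

    // Required method: implementors report failure through `Result`.
    fn try_select(self, idx: usize) -> Result<Self::Item, OutOfBounds>;

    // Provided method: same operation, but panics if `try_select` fails.
    fn select(self, idx: usize) -> Self::Item {
        self.try_select(idx).unwrap()
    }
}

impl<T> SelectLike for Vec<T> {
    type Item = T;

    fn try_select(mut self, idx: usize) -> Result<T, OutOfBounds> {
        let len = self.len();
        if idx < len {
            // Take ownership of the selected element.
            Ok(self.swap_remove(idx))
        } else {
            Err(OutOfBounds { index: idx, len })
        }
    }
}

fn main() {
    let v = vec![10, 20, 30];
    // The fallible call surfaces the error as a value...
    println!("{:?}", v.clone().try_select(5)); // Err(OutOfBounds { index: 5, len: 3 })
    // ...while the provided method unwraps it.
    println!("{}", v.select(1)); // 20
}
```

Implementing only the fallible method keeps a single source of truth for the logic, while callers who treat failure as a bug get the shorter, panicking form for free.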

Implementors

impl<Src: Shape, E: Dtype, D: RemoveDimKernel<E>, T: Tape<E, D>> SelectTo<E, D> for Tensor<Src, E, D, T>