Trait dfdx::tensor_ops::SelectTo
pub trait SelectTo<E, D: Storage<E> + Storage<usize>>: HasErr + HasShape {
    // Required method
    fn try_select<Dst: Shape, Idx: Shape>(
        self,
        idx: Tensor<Idx, usize, D>
    ) -> Result<Self::WithShape<Dst>, Self::Err>
       where Self::Shape: RemoveDimTo<Dst, Idx>;

    // Provided method
    fn select<Dst: Shape, Idx: Shape>(
        self,
        idx: Tensor<Idx, usize, D>
    ) -> Self::WithShape<Dst>
       where Self::Shape: RemoveDimTo<Dst, Idx> { ... }
}
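The trait pairs a required fallible method (try_select) with a provided infallible one (select). The sketch below models that pairing in plain Rust, assuming the provided select simply unwraps the result of try_select. SelectErr and Selector are hypothetical names, not dfdx types, and the toy reports an out-of-range index as its error case; that is an illustrative choice, not a statement about what dfdx returns in Self::Err.

```rust
// Hypothetical stand-ins for illustration only; these are NOT dfdx types.
#[derive(Debug, PartialEq)]
struct SelectErr;

trait Selector: Sized {
    type Out;

    // Required: the fallible version reports failure through a Result.
    fn try_select(self, idx: usize) -> Result<Self::Out, SelectErr>;

    // Provided: the infallible version panics if the fallible one fails
    // (assumed to mirror the try_/non-try_ pairing in SelectTo).
    fn select(self, idx: usize) -> Self::Out {
        self.try_select(idx).unwrap()
    }
}

impl Selector for Vec<f32> {
    type Out = f32;

    // Toy error case: an out-of-range index becomes Err(SelectErr).
    fn try_select(self, idx: usize) -> Result<f32, SelectErr> {
        self.get(idx).copied().ok_or(SelectErr)
    }
}
```

Calling select(1) on vec![1.0, 2.0, 3.0] returns 2.0, while try_select(5) on the same vector returns Err(SelectErr) instead of panicking.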
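Select a single value from a single dimension, removing that dimension from the shape. For axis 0 with a scalar index this is equivalent to torch.select from PyTorch; along later axes the index holds one position per entry of the leading dimensions, so different positions in those leading dimensions can select different elements.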
The index shape is the tensor's shape up to, but not including, the axis you want to select from.
For example, given a tensor of shape (M, N, O), here are the required index shapes to select each axis:
- Axis 0: index shape ()
- Axis 1: index shape (M, )
- Axis 2: index shape (M, N)
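To make the index-shape rule concrete, here is a dependency-free sketch that models an (M, N, O) tensor as nested fixed-size arrays. The helpers select_axis0, select_axis1 and select_axis2 are hypothetical and only illustrate which element each index position picks; they say nothing about how dfdx implements select internally.

```rust
// Plain-Rust model of the index-shape rule for a tensor of shape (M, N, O).
// These helpers are hypothetical and not part of dfdx.

/// Axis 0, index shape (): one index picks a whole (N, O) slab.
fn select_axis0<const M: usize, const N: usize, const O: usize>(
    t: &[[[f32; O]; N]; M],
    idx: usize,
) -> [[f32; O]; N] {
    t[idx]
}

/// Axis 1, index shape (M,): each i in 0..M picks its own row t[i][idx[i]].
fn select_axis1<const M: usize, const N: usize, const O: usize>(
    t: &[[[f32; O]; N]; M],
    idx: &[usize; M],
) -> [[f32; O]; M] {
    std::array::from_fn(|i| t[i][idx[i]])
}

/// Axis 2, index shape (M, N): each (i, j) picks its own scalar t[i][j][idx[i][j]].
fn select_axis2<const M: usize, const N: usize, const O: usize>(
    t: &[[[f32; O]; N]; M],
    idx: &[[usize; N]; M],
) -> [[f32; N]; M] {
    std::array::from_fn(|i| std::array::from_fn(|j| t[i][j][idx[i][j]]))
}
```

In each case the output shape is the input shape with the selected axis removed: (N, O), (M, O) and (M, N) respectively.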
Here is an example selecting from a 2D tensor:
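use dfdx::prelude::*;

fn main() {
    // CPU device used to allocate the tensors
    let dev: Cpu = Default::default();
    let a: Tensor<Rank2<3, 5>, f32, _> = dev.zeros();
    // select from the 0th axis: a scalar index picks one row of length 5
    let idx: Tensor<Rank0, usize, _> = dev.tensor(0);
    let _: Tensor<Rank1<5>, f32, _> = a.clone().select(idx);
    // select from the 1st axis: one column index per row, each must be < 5
    let idx: Tensor<Rank1<3>, usize, _> = dev.tensor([0, 2, 4]);
    let _: Tensor<Rank1<3>, f32, _> = a.select(idx);
}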
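Because the example uses zeros, it does not show which elements are picked. The same two selections can be traced by hand on a flat buffer of shape (3, 5) holding the values 0 through 14. This is a hypothetical sketch: the contiguous row-major layout and the function names are assumptions for illustration, not dfdx internals.

```rust
// Hypothetical sketch: the 2D example on a row-major flat buffer of shape (3, 5).
// The row-major layout is an assumption made for illustration only.
const ROWS: usize = 3;
const COLS: usize = 5;

/// Axis 0, scalar index: copy row `r`, giving shape (COLS,).
fn select_axis0_flat(data: &[f32], r: usize) -> Vec<f32> {
    assert_eq!(data.len(), ROWS * COLS);
    data[r * COLS..(r + 1) * COLS].to_vec()
}

/// Axis 1, index of shape (ROWS,): row i contributes data[i * COLS + idx[i]].
fn select_axis1_flat(data: &[f32], idx: &[usize; ROWS]) -> Vec<f32> {
    assert_eq!(data.len(), ROWS * COLS);
    (0..ROWS)
        .map(|i| {
            assert!(idx[i] < COLS, "index out of range for axis 1");
            data[i * COLS + idx[i]]
        })
        .collect()
}
```

With that buffer, selecting row 0 yields 0.0 through 4.0, and the index [0, 2, 4] along axis 1 picks 0.0, 7.0 and 14.0, one element per row.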