Trait dfdx::tensor_ops::GatherTo
pub trait GatherTo<E, D: Storage<E> + Storage<usize>>: HasErr + HasShape {
// Required method
fn try_gather<Dst: Shape, Idx: Shape>(
self,
idx: Tensor<Idx, usize, D>
) -> Result<Self::WithShape<Dst>, Self::Err>
where Self::Shape: ReplaceDimTo<Dst, Idx>;
// Provided method
fn gather<Dst: Shape, Idx: Shape>(
self,
idx: Tensor<Idx, usize, D>
) -> Self::WithShape<Dst>
where Self::Shape: ReplaceDimTo<Dst, Idx> { ... }
}
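The trait pairs a required, fallible try_gather with a provided gather that has the same signature but returns the value directly instead of a Result. The provided body is elided above. The sketch below assumes it follows the common "call try_ and unwrap" pattern. This is plain Rust, not dfdx code: the GatherLike trait, the OutOfBounds error, and the Vec<f32> impl are all hypothetical. In dfdx the error type comes from HasErr rather than from an index check.

```rust
/// Hypothetical stand-in for the required/provided method split in `GatherTo`.
trait GatherLike: Sized {
    type Err: std::fmt::Debug;

    /// Required: the fallible version that each implementor writes.
    fn try_gather(self, idx: &[usize]) -> Result<Self, Self::Err>;

    /// Provided: the convenience wrapper, which panics if `try_gather` fails.
    fn gather(self, idx: &[usize]) -> Self {
        self.try_gather(idx).unwrap()
    }
}

/// Error type used only by this sketch (dfdx's error comes from `HasErr`).
#[derive(Debug, PartialEq)]
struct OutOfBounds {
    index: usize,
    len: usize,
}

impl GatherLike for Vec<f32> {
    type Err = OutOfBounds;

    fn try_gather(self, idx: &[usize]) -> Result<Self, Self::Err> {
        // Look up each index; the first out-of-range index short-circuits into Err.
        idx.iter()
            .map(|&i| {
                self.get(i)
                    .copied()
                    .ok_or(OutOfBounds { index: i, len: self.len() })
            })
            .collect()
    }
}

fn main() {
    let v = vec![10.0_f32, 20.0, 30.0];
    assert_eq!(v.clone().gather(&[2, 0, 0]), vec![30.0, 10.0, 10.0]);
    assert_eq!(v.try_gather(&[5]), Err(OutOfBounds { index: 5, len: 3 }));
}
```

Keeping the fallible method as the one implementors must write means every implementation supports error handling, while callers who do not care about errors still get the shorter gather.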
Select multiple values from a single axis, replacing that dimension with a different one. Equivalent to torch.gather from PyTorch.
The index shape is the tensor's shape up to (but not including) the axis you gather from, followed by the size of the new dimension.
For example, given a tensor of shape (M, N, O), here are the required index shapes to gather each axis:
- Axis 0: index shape (Z, )
- Axis 1: index shape (M, Z)
- Axis 2: index shape (M, N, Z)
where Z is the new dimension.
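The shape rule above can be written out as two small functions over plain shape slices. Both helpers (gather_index_shape and gather_output_shape) are hypothetical and are not part of dfdx. The output-shape rule, "replace dimension axis with Z", is taken from the description and the 2D example in this section.

```rust
/// Hypothetical helper: the index shape needed to gather along `axis` of a
/// tensor with shape `shape`, producing a new dimension of size `z`.
fn gather_index_shape(shape: &[usize], axis: usize, z: usize) -> Vec<usize> {
    assert!(axis < shape.len(), "axis out of range");
    // Keep every dimension before `axis`, then append the new dimension.
    let mut idx = shape[..axis].to_vec();
    idx.push(z);
    idx
}

/// Hypothetical helper: the result shape, i.e. `shape` with `shape[axis]` replaced by `z`.
fn gather_output_shape(shape: &[usize], axis: usize, z: usize) -> Vec<usize> {
    assert!(axis < shape.len(), "axis out of range");
    let mut out = shape.to_vec();
    out[axis] = z;
    out
}

fn main() {
    // (M, N, O) = (2, 3, 4), new dimension Z = 7.
    let (m, n, o, z) = (2, 3, 4, 7);
    let shape = [m, n, o];
    assert_eq!(gather_index_shape(&shape, 0, z), vec![z]);
    assert_eq!(gather_index_shape(&shape, 1, z), vec![m, z]);
    assert_eq!(gather_index_shape(&shape, 2, z), vec![m, n, z]);
    assert_eq!(gather_output_shape(&shape, 1, z), vec![m, z, o]);
}
```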
Here is an example gathering from a 2D tensor:
```rust
use dfdx::prelude::*;
fn main() {
    let dev: Cpu = Default::default();
    let a: Tensor<Rank2<3, 5>, f32, _> = dev.zeros();
    // gather from the 0th axis; dimension 0 becomes 4
    let idx: Tensor<Rank1<4>, usize, _> = dev.tensor([0, 0, 1, 2]);
    let _: Tensor<Rank2<4, 5>, f32, _> = a.clone().gather(idx);
    // gather from the 1st axis; dimension 1 becomes 2
    let idx: Tensor<Rank2<3, 2>, usize, _> = dev.tensor([[0, 1], [2, 3], [4, 4]]);
    let _: Tensor<Rank2<3, 2>, f32, _> = a.gather(idx);
}
```
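The example above only checks shapes, and its source tensor is all zeros. The plain-Rust sketch below spells out the element-level rule implied by the index shapes, using fixed-size arrays in place of tensors. Along axis 0, whole rows are selected: out[z][j] = a[idx[z]][j]. Along axis 1, each row gets its own index list: out[i][z] = a[i][idx[i][z]]. The functions gather_axis0 and gather_axis1 are hypothetical illustrations, not dfdx's CPU or CUDA kernels.

```rust
/// Hypothetical: gather along axis 0, so out[z][j] = a[idx[z]][j].
fn gather_axis0<const M: usize, const N: usize, const Z: usize>(
    a: &[[f32; N]; M],
    idx: &[usize; Z],
) -> [[f32; N]; Z] {
    let mut out = [[0.0; N]; Z];
    for (row_out, &i) in out.iter_mut().zip(idx.iter()) {
        *row_out = a[i]; // copy the whole selected row
    }
    out
}

/// Hypothetical: gather along axis 1, so out[i][z] = a[i][idx[i][z]].
fn gather_axis1<const M: usize, const N: usize, const Z: usize>(
    a: &[[f32; N]; M],
    idx: &[[usize; Z]; M],
) -> [[f32; Z]; M] {
    let mut out = [[0.0; Z]; M];
    for (row_out, (row_a, row_idx)) in out.iter_mut().zip(a.iter().zip(idx.iter())) {
        // Each row i picks its own columns, given by idx[i].
        for (o, &j) in row_out.iter_mut().zip(row_idx.iter()) {
            *o = row_a[j];
        }
    }
    out
}

fn main() {
    // Same shapes as the example above, but with distinct values: a[i][j] = 10*i + j.
    let mut a = [[0.0_f32; 5]; 3];
    for (i, row) in a.iter_mut().enumerate() {
        for (j, x) in row.iter_mut().enumerate() {
            *x = (10 * i + j) as f32;
        }
    }
    let rows = gather_axis0(&a, &[0, 0, 1, 2]); // shape (4, 5)
    assert_eq!(rows[3], [20.0, 21.0, 22.0, 23.0, 24.0]);
    let cols = gather_axis1(&a, &[[0, 1], [2, 3], [4, 4]]); // shape (3, 2)
    assert_eq!(cols, [[0.0, 1.0], [12.0, 13.0], [24.0, 24.0]]);
}
```

Repeated indices (such as [4, 4] above) are allowed and simply copy the same element twice.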