Trait dfdx::tensor_ops::PermuteTo
pub trait PermuteTo: HasErr + HasShape {
// Required method
fn try_permute<Dst: Shape, Ax: Axes>(
self
) -> Result<Self::WithShape<Dst>, Self::Err>
where Self::Shape: PermuteShapeTo<Dst, Ax>;
// Provided method
fn permute<Dst: Shape, Ax: Axes>(self) -> Self::WithShape<Dst>
where Self::Shape: PermuteShapeTo<Dst, Ax> { ... }
}
Changes the order of dimensions/axes in a tensor.
PyTorch equivalent: torch.permute.
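To make the semantics concrete, here is a dependency-free sketch of what a permute computes on a dense, row-major buffer. It assumes torch.permute semantics (output axis i is input axis axes[i]), since PyTorch is named as the equivalent. permute_row_major is a hypothetical helper for illustration, not a dfdx function, and it does not reproduce dfdx's own kernels.

```rust
/// Hypothetical helper (not part of dfdx): permutes the axes of a dense,
/// row-major tensor stored as a flat slice, assuming torch.permute semantics:
/// output axis `i` is input axis `axes[i]`.
fn permute_row_major<T: Copy>(data: &[T], shape: &[usize], axes: &[usize]) -> (Vec<T>, Vec<usize>) {
    let rank = shape.len();
    assert_eq!(axes.len(), rank, "axes must list every dimension");
    // Check that `axes` is a permutation of 0..rank.
    let mut seen = vec![false; rank];
    for &a in axes {
        assert!(a < rank && !seen[a], "axes must be a permutation");
        seen[a] = true;
    }
    assert_eq!(data.len(), shape.iter().product::<usize>());

    // Row-major strides of the input: stride[k] = product of shape[k+1..].
    let mut in_strides = vec![1usize; rank];
    for k in (0..rank.saturating_sub(1)).rev() {
        in_strides[k] = in_strides[k + 1] * shape[k + 1];
    }
    // Output dimension i has the size of input dimension axes[i].
    let out_shape: Vec<usize> = axes.iter().map(|&a| shape[a]).collect();

    let mut out = Vec::with_capacity(data.len());
    let mut idx = vec![0usize; rank]; // current multi-index into the output
    for _ in 0..data.len() {
        // Output coordinate idx[i] walks along input axis axes[i].
        let src: usize = (0..rank).map(|i| idx[i] * in_strides[axes[i]]).sum();
        out.push(data[src]);
        // Advance idx like an odometer (last axis fastest).
        for i in (0..rank).rev() {
            idx[i] += 1;
            if idx[i] < out_shape[i] {
                break;
            }
            idx[i] = 0;
        }
    }
    (out, out_shape)
}

fn main() {
    // The 2x3 matrix used in the dfdx examples, permuted with axes [1, 0].
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let (b, b_shape) = permute_row_major(&a, &[2, 3], &[1, 0]);
    assert_eq!(b_shape, vec![3, 2]);
    assert_eq!(b, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    println!("{:?} {:?}", b_shape, b);
}
```

For a rank-2 tensor, axes [1, 0] is simply a transpose; the same index arithmetic covers any rank, which is why a single generic Axes parameter suffices.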
Option 1: specifying the shape generic:
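use dfdx::prelude::*;
let dev: Cpu = Default::default();
let a: Tensor<Rank2<2, 3>, f32, _> = dev.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
// The axes generic is left as `_` and inferred from the destination shape.
let b: Tensor<Rank2<3, 2>, f32, _> = a.permute::<Rank2<3, 2>, _>();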
assert_eq!(b.array(), [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]);
Option 2: specifying the axes generic:
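use dfdx::prelude::*;
let dev: Cpu = Default::default();
let a: Tensor<Rank2<2, 3>, f32, _> = dev.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
// Axes2<1, 0>: output axis 0 is input axis 1, output axis 1 is input axis 0.
let b: Tensor<Rank2<3, 2>, f32, _> = a.permute::<_, Axes2<1, 0>>();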
assert_eq!(b.array(), [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]);
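Option 1 works here because only one axis order turns a Rank2<2, 3> into a Rank2<3, 2>. The runtime sketch below enumerates every axis order that maps a source shape onto a destination shape; matching_axes is a hypothetical helper and only an analogy for dfdx's compile-time PermuteShapeTo bound, not its implementation. When more than one order fits, as for a square Rank2<3, 3>, we would expect the compiler to be unable to infer Ax from the shapes alone, so Option 2 is presumably the one to use.

```rust
/// Hypothetical runtime analogue (not dfdx code): returns every axis order
/// `axes` such that dst[i] == src[axes[i]] for all i.
fn matching_axes(src: &[usize], dst: &[usize]) -> Vec<Vec<usize>> {
    let mut found = Vec::new();
    if src.len() == dst.len() {
        let mut current = Vec::new();
        let mut used = vec![false; src.len()];
        search(src, dst, &mut current, &mut used, &mut found);
    }
    found
}

/// Depth-first search: picks an unused source axis of the right size for
/// each destination axis in turn, recording every complete assignment.
fn search(
    src: &[usize],
    dst: &[usize],
    current: &mut Vec<usize>,
    used: &mut [bool],
    found: &mut Vec<Vec<usize>>,
) {
    let i = current.len();
    if i == dst.len() {
        found.push(current.clone());
        return;
    }
    for a in 0..src.len() {
        if !used[a] && src[a] == dst[i] {
            used[a] = true;
            current.push(a);
            search(src, dst, current, used, found);
            current.pop();
            used[a] = false;
        }
    }
}

fn main() {
    // 2x3 -> 3x2: exactly one order fits, so the shapes determine the axes.
    assert_eq!(matching_axes(&[2, 3], &[3, 2]), vec![vec![1, 0]]);
    // 3x3 -> 3x3: identity and transpose both fit, so the shapes are ambiguous.
    assert_eq!(matching_axes(&[3, 3], &[3, 3]), vec![vec![0, 1], vec![1, 0]]);
}
```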
Required Methods
fn try_permute<Dst: Shape, Ax: Axes>(
self
) -> Result<Self::WithShape<Dst>, Self::Err>
where Self::Shape: PermuteShapeTo<Dst, Ax>,
Fallible version of PermuteTo::permute.
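The try_/non-try pair follows a common Rust pattern: the fallible method returns a Result, and the provided method presumably unwraps it. The sketch below shows that pattern on plain shapes. In dfdx the axes are already checked at compile time through the PermuteShapeTo bound, so an Err from try_permute would come from the device's error type (via HasErr) rather than from invalid axes; the runtime axis check here, along with the names PermuteError, try_permute_shape and permute_shape, is a hypothetical stand-in failure used only for illustration.

```rust
/// Hypothetical error type for this sketch only; dfdx's real error is `Self::Err`.
#[derive(Debug, PartialEq)]
enum PermuteError {
    NotAPermutation,
}

/// Fallible version: reports a problem as `Err` instead of panicking.
/// Stand-in failure: `axes` must be a permutation of 0..shape.len().
fn try_permute_shape(shape: &[usize], axes: &[usize]) -> Result<Vec<usize>, PermuteError> {
    if axes.len() != shape.len() {
        return Err(PermuteError::NotAPermutation);
    }
    let mut seen = vec![false; shape.len()];
    for &a in axes {
        if a >= shape.len() || seen[a] {
            return Err(PermuteError::NotAPermutation);
        }
        seen[a] = true;
    }
    // Output dimension i takes its size from input dimension axes[i].
    Ok(axes.iter().map(|&a| shape[a]).collect())
}

/// Infallible wrapper: delegates to the fallible version and panics on `Err`.
fn permute_shape(shape: &[usize], axes: &[usize]) -> Vec<usize> {
    try_permute_shape(shape, axes).unwrap()
}

fn main() {
    assert_eq!(permute_shape(&[2, 3], &[1, 0]), vec![3, 2]);
    assert_eq!(try_permute_shape(&[2, 3], &[0, 0]), Err(PermuteError::NotAPermutation));
}
```

Keeping the error path in one function means callers that can recover (for example, from running out of device memory) use the try_ form, while everyone else gets the shorter panicking call.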