1use crate::{Dim, NamedDims, NamedTensor};
2use burn_tensor::Tensor;
3use burn_tensor::backend::Backend;
4
/// Exchanges two axes of a value, identified by the const positions `D1` and `D2`.
///
/// `Output` is the type produced after the swap. Implementors pick `Output` so that
/// the resulting dimension ordering is expressed in the type system.
pub trait SwapDims<Output, const D1: usize, const D2: usize> {
    /// Consumes `self` and returns the value with axes `D1` and `D2` exchanged.
    fn swap_dims(self) -> Output;
}
8
9impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
10where
11 ND: NamedDims<B, Tensor = Tensor<B, D>>,
12{
13 pub fn swap_dims<ND2, const D1: usize, const D2: usize>(self) -> NamedTensor<B, ND2>
15 where
16 ND2: NamedDims<B, Tensor = Tensor<B, D>>,
17 Self: SwapDims<NamedTensor<B, ND2>, D1, D2>,
18 {
19 SwapDims::swap_dims(self)
20 }
21}
22
/// Generates a [`SwapDims`] implementation for a rank-2, rank-3 or rank-4 named tensor.
///
/// Invocation: `generate_permute!(RANK => OUTPUT_DIMS, (DIM1, DIM2));`
/// - `RANK` is `2`, `3` or `4`. It selects the input dimension parameters
///   `(D1, ..)` up to `D{RANK}`.
/// - `OUTPUT_DIMS` is the named-dimension tuple obtained after swapping the axes at
///   positions `DIM1` and `DIM2`.
///
/// Each rank arm only fixes its list of dimension type parameters and forwards to the
/// shared internal `@impl` arm, which emits the actual trait implementation.
macro_rules! generate_permute {
    (2 => $output:ty, ($dim1:expr, $dim2:expr)) => {
        generate_permute!(@impl [D1, D2] => $output, ($dim1, $dim2));
    };

    (3 => $output:ty, ($dim1:expr, $dim2:expr)) => {
        generate_permute!(@impl [D1, D2, D3] => $output, ($dim1, $dim2));
    };

    (4 => $output:ty, ($dim1:expr, $dim2:expr)) => {
        generate_permute!(@impl [D1, D2, D3, D4] => $output, ($dim1, $dim2));
    };

    // Internal arm: `$dims` lists the input tensor's dimension type parameters in order.
    (@impl [$($dims:ident),+] => $output:ty, ($dim1:expr, $dim2:expr)) => {
        impl<B: Backend, $($dims: Dim),+> SwapDims<NamedTensor<B, $output>, $dim1, $dim2>
            for NamedTensor<B, ($($dims),+)>
        {
            fn swap_dims(self) -> NamedTensor<B, $output> {
                let swapped = self.tensor.swap_dims($dim1, $dim2);
                NamedTensor::from_tensor(swapped)
            }
        }
    };
}
54
55generate_permute!(2 => (D2, D1), (0, 1));
56generate_permute!(3 => (D2, D1, D3), (0, 1));
57generate_permute!(3 => (D3, D2, D1), (0, 2));
58generate_permute!(3 => (D1, D3, D2), (1, 2));
59generate_permute!(4 => (D2, D1, D3, D4), (0, 1));
60generate_permute!(4 => (D3, D2, D1, D4), (0, 2));
61generate_permute!(4 => (D4, D2, D3, D1), (0, 3));
62generate_permute!(4 => (D1, D3, D2, D4), (1, 2));
63generate_permute!(4 => (D1, D4, D3, D2), (1, 3));