Skip to main content

burn_named_tensor/
swap_dims.rs

1use crate::{Dim, NamedDims, NamedTensor};
2use burn_tensor::Tensor;
3use burn_tensor::backend::Backend;
4
/// Compile-time–checked swap of two dimensions, producing a value of type `Out`.
///
/// The output type and both axis indices are trait parameters, so every supported
/// swap is a separate implementation. Requesting a combination that has no
/// implementation is rejected by the compiler instead of failing at runtime.
pub trait SwapDims<Out, const A: usize, const B: usize> {
    /// Consumes `self` and returns it converted to `Out`; implementors are expected
    /// to exchange the axes at positions `A` and `B`.
    fn swap_dims(self) -> Out;
}
8
9impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
10where
11    ND: NamedDims<B, Tensor = Tensor<B, D>>,
12{
13    /// Swap two dimensions.
14    pub fn swap_dims<ND2, const D1: usize, const D2: usize>(self) -> NamedTensor<B, ND2>
15    where
16        ND2: NamedDims<B, Tensor = Tensor<B, D>>,
17        Self: SwapDims<NamedTensor<B, ND2>, D1, D2>,
18    {
19        SwapDims::swap_dims(self)
20    }
21}
22
/// Emits one [`SwapDims`] implementation for a `NamedTensor` of a given rank.
///
/// Usage: `generate_permute!(RANK => OUTPUT, (I, J));`
/// - `RANK` is the literal `2`, `3` or `4` and fixes the input dimension tuple to
///   `(D1, D2)`, `(D1, D2, D3)` or `(D1, D2, D3, D4)`.
/// - `OUTPUT` is the dimension tuple of the resulting `NamedTensor`, written with the
///   same generic names `D1`..`D4`.
/// - `(I, J)` become the const parameters of the impl and are also the arguments
///   forwarded to the underlying tensor's `swap_dims`.
macro_rules! generate_permute {
    (2 => $out:ty, ($a:expr, $b:expr)) => {
        impl<B, D1, D2> SwapDims<NamedTensor<B, $out>, $a, $b> for NamedTensor<B, (D1, D2)>
        where
            B: Backend,
            D1: Dim,
            D2: Dim,
        {
            fn swap_dims(self) -> NamedTensor<B, $out> {
                let swapped = self.tensor.swap_dims($a, $b);
                NamedTensor::from_tensor(swapped)
            }
        }
    };

    (3 => $out:ty, ($a:expr, $b:expr)) => {
        impl<B, D1, D2, D3> SwapDims<NamedTensor<B, $out>, $a, $b>
            for NamedTensor<B, (D1, D2, D3)>
        where
            B: Backend,
            D1: Dim,
            D2: Dim,
            D3: Dim,
        {
            fn swap_dims(self) -> NamedTensor<B, $out> {
                let swapped = self.tensor.swap_dims($a, $b);
                NamedTensor::from_tensor(swapped)
            }
        }
    };

    (4 => $out:ty, ($a:expr, $b:expr)) => {
        impl<B, D1, D2, D3, D4> SwapDims<NamedTensor<B, $out>, $a, $b>
            for NamedTensor<B, (D1, D2, D3, D4)>
        where
            B: Backend,
            D1: Dim,
            D2: Dim,
            D3: Dim,
            D4: Dim,
        {
            fn swap_dims(self) -> NamedTensor<B, $out> {
                let swapped = self.tensor.swap_dims($a, $b);
                NamedTensor::from_tensor(swapped)
            }
        }
    };
}
54
55generate_permute!(2 => (D2, D1), (0, 1));
56generate_permute!(3 => (D2, D1, D3), (0, 1));
57generate_permute!(3 => (D3, D2, D1), (0, 2));
58generate_permute!(3 => (D1, D3, D2), (1, 2));
59generate_permute!(4 => (D2, D1, D3, D4), (0, 1));
60generate_permute!(4 => (D3, D2, D1, D4), (0, 2));
61generate_permute!(4 => (D4, D2, D3, D1), (0, 3));
62generate_permute!(4 => (D1, D3, D2, D4), (1, 2));
63generate_permute!(4 => (D1, D4, D3, D2), (1, 3));