math/tensor/matrix_transpose.rs

1use crate::tensor::{
2    borrow_tensor::BorrowTensor, has_tensor_shape_data::HasTensorShapeData,
3    AxisIndex,
4};
5
6pub trait MatrixTranspose<'a, Dtype: 'a>:
7    HasTensorShapeData<Dtype> + BorrowTensor<'a, Dtype> {
8    /// Reverses the axes.
9    fn t(&'a self) -> <Self as BorrowTensor<'a, Dtype>>::Output {
10        let transposed_axes: Vec<AxisIndex> =
11            (0..self.shape().ndim()).into_iter().rev().collect();
12        let shape_transpose = self.shape().to_transposed(transposed_axes);
13        Self::create_borrowed_tensor(shape_transpose, &self.data())
14    }
15
16    /// # Arguments
17    /// * `axes` - Must be the same length as `self.shape().ndim()`. For each
18    ///   `i`, `axes[i] = j`
19    /// means that the original `j`-th axis will be at the `i`-th axis in the
20    /// new shape.
21    fn transpose(
22        &'a self,
23        axes: Vec<AxisIndex>,
24    ) -> <Self as BorrowTensor<'a, Dtype>>::Output {
25        Self::create_borrowed_tensor(
26            self.shape().to_transposed(axes),
27            &self.data(),
28        )
29    }
30}
31
32impl<'a, Dtype: 'a, T> MatrixTranspose<'a, Dtype> for T where
33    T: HasTensorShapeData<Dtype> + BorrowTensor<'a, Dtype>
34{
35}
36
#[cfg(test)]
mod tests {
    use crate::tensor::{
        ephemeral_view::EphemeralView, matrix_transpose::MatrixTranspose,
        tensor_shape::HasTensorShape, tensor_storage::IntoTensorStorage,
    };

    /// `t()` on a 2-D view swaps both dims and strides.
    #[test]
    fn test_t_2d() {
        // the original stride is (3, 1)
        let storage = vec![1, 2, 3, 4, 5, 6].into_tensor_storage();
        let arr = EphemeralView::new(&storage, [2, 3]);
        let arr_t = arr.t();
        assert_eq!(arr_t.shape().dims(), vec![3, 2]);
        assert_eq!(arr_t.shape().strides(), vec![1, 3]);
        assert_eq!(arr_t.shape().ndim(), 2);
    }

    /// `t()` reverses all three axes; `transpose` swaps only the first two.
    #[test]
    fn test_t_and_transpose_3d() {
        // the original stride is (12, 3, 1)
        let storage = (0..24).collect::<Vec<i32>>().into_tensor_storage();
        let arr = EphemeralView::new(&storage, [2, 4, 3]);

        let arr_t = arr.t();
        assert_eq!(arr_t.shape().dims(), vec![3, 4, 2]);
        assert_eq!(arr_t.shape().strides(), vec![1, 3, 12]);
        assert_eq!(arr_t.shape().ndim(), 3);

        let arr_t01 = arr.transpose(vec![1, 0, 2]);
        assert_eq!(arr_t01.shape().dims(), vec![4, 2, 3]);
        assert_eq!(arr_t01.shape().strides(), vec![3, 12, 1]);
        assert_eq!(arr_t01.shape().ndim(), 3);
    }

    /// An arbitrary 4-D permutation reorders dims and strides by `axes`.
    #[test]
    fn test_transpose_4d() {
        // the original stride is (60, 15, 5, 1)
        let storage = (0..120).collect::<Vec<i32>>().into_tensor_storage();
        let arr = EphemeralView::new(&storage, [2, 4, 3, 5]);

        let arr_t = arr.transpose(vec![1, 3, 0, 2]);
        assert_eq!(arr_t.shape().dims(), vec![4, 5, 2, 3]);
        assert_eq!(arr_t.shape().strides(), vec![15, 1, 60, 5]);
        assert_eq!(arr_t.shape().ndim(), 4);
    }
}