math/tensor/
matrix_transpose.rs1use crate::tensor::{
2 borrow_tensor::BorrowTensor, has_tensor_shape_data::HasTensorShapeData,
3 AxisIndex,
4};
5
6pub trait MatrixTranspose<'a, Dtype: 'a>:
7 HasTensorShapeData<Dtype> + BorrowTensor<'a, Dtype> {
8 fn t(&'a self) -> <Self as BorrowTensor<'a, Dtype>>::Output {
10 let transposed_axes: Vec<AxisIndex> =
11 (0..self.shape().ndim()).into_iter().rev().collect();
12 let shape_transpose = self.shape().to_transposed(transposed_axes);
13 Self::create_borrowed_tensor(shape_transpose, &self.data())
14 }
15
16 fn transpose(
22 &'a self,
23 axes: Vec<AxisIndex>,
24 ) -> <Self as BorrowTensor<'a, Dtype>>::Output {
25 Self::create_borrowed_tensor(
26 self.shape().to_transposed(axes),
27 &self.data(),
28 )
29 }
30}
31
32impl<'a, Dtype: 'a, T> MatrixTranspose<'a, Dtype> for T where
33 T: HasTensorShapeData<Dtype> + BorrowTensor<'a, Dtype>
34{
35}
36
#[cfg(test)]
mod tests {
    use crate::tensor::{
        ephemeral_view::EphemeralView, matrix_transpose::MatrixTranspose,
        tensor_shape::HasTensorShape, tensor_storage::IntoTensorStorage,
    };

    /// Checks dims, strides and rank of views produced by `t` and `transpose`.
    ///
    /// The expected strides imply that `EphemeralView::new` lays storage out
    /// row-major (e.g. `[2, 3]` has strides `[3, 1]`).
    #[test]
    fn test_transpose() {
        // 2-D: `t` is the ordinary matrix transpose; strides [3, 1] become [1, 3].
        {
            let storage = vec![1, 2, 3, 4, 5, 6].into_tensor_storage();
            let arr = EphemeralView::new(&storage, [2, 3]);
            let arr_t = arr.t();
            assert_eq!(arr_t.shape().dims(), vec![3, 2]);
            assert_eq!(arr_t.shape().strides(), vec![1, 3]);
            assert_eq!(arr_t.shape().ndim(), 2);
        }
        // 3-D with strides [12, 3, 1].
        {
            let storage = (0..24).collect::<Vec<i32>>().into_tensor_storage();
            let arr = EphemeralView::new(&storage, [2, 4, 3]);

            // `t` reverses every axis.
            let arr_t = arr.t();
            assert_eq!(arr_t.shape().dims(), vec![3, 4, 2]);
            assert_eq!(arr_t.shape().strides(), vec![1, 3, 12]);
            assert_eq!(arr_t.shape().ndim(), 3);

            // An explicit full reversal must match `t`.
            let arr_rev = arr.transpose(vec![2, 1, 0]);
            assert_eq!(arr_rev.shape().dims(), vec![3, 4, 2]);
            assert_eq!(arr_rev.shape().strides(), vec![1, 3, 12]);
            assert_eq!(arr_rev.shape().ndim(), 3);

            // Swapping only the first two axes leaves the last one in place.
            let arr_t01 = arr.transpose(vec![1, 0, 2]);
            assert_eq!(arr_t01.shape().dims(), vec![4, 2, 3]);
            assert_eq!(arr_t01.shape().strides(), vec![3, 12, 1]);
            assert_eq!(arr_t01.shape().ndim(), 3);
        }
        // 4-D with strides [60, 15, 5, 1] and an arbitrary permutation.
        {
            let storage = (0..120).collect::<Vec<i32>>().into_tensor_storage();
            let arr = EphemeralView::new(&storage, [2, 4, 3, 5]);
            let arr_t = arr.transpose(vec![1, 3, 0, 2]);
            assert_eq!(arr_t.shape().dims(), vec![4, 5, 2, 3]);
            assert_eq!(arr_t.shape().strides(), vec![15, 1, 60, 5]);
            assert_eq!(arr_t.shape().ndim(), 4);
        }
    }
}