rstsr_core/tensor/linalg/matrix_transpose.rs
1use crate::prelude_dev::*;
2
3/* #region matrix_transpose */
4
5/// Transposes a matrix (or a stack of matrices).
6///
7/// See also [`matrix_transpose`].
8pub fn into_matrix_transpose_f<S, D>(tensor: TensorBase<S, D>) -> Result<TensorBase<S, D>>
9where
10 D: DimAPI,
11{
12 into_swapaxes_f(tensor, -1, -2)
13}
14
15/// Transposes a matrix (or a stack of matrices).
16///
17/// Returns an array with the last two axes interchanged. This is equivalent
18/// to `swapaxes(-1, -2)`, but is provided as a convenience function for
19/// transposing matrices in multi-dimensional arrays.
20///
21/// For a 2-D array, this is equivalent to the standard matrix transpose.
22/// For higher-dimensional arrays, this transposes each matrix in a stack of
23/// matrices, leaving other axes unchanged.
24///
25/// # Examples
26///
27/// For a 2-D array, this is equivalent to the standard matrix transpose:
28///
29/// ```rust
30/// # use rstsr::prelude::*;
31/// # let mut device = DeviceCpu::default();
32/// # device.set_default_order(RowMajor);
33/// let x = rt::tensor_from_nested!([[1, 2], [3, 4]], &device);
34/// let result = x.matrix_transpose();
35/// println!("{result}");
36/// // [[ 1 3]
37/// // [ 2 4]]
38/// ```
39///
40/// For a 3-D array (a stack of matrices), each matrix is transposed independently:
41///
42/// ```rust
43/// # use rstsr::prelude::*;
44/// # let mut device = DeviceCpu::default();
45/// # device.set_default_order(RowMajor);
46/// let x = rt::tensor_from_nested!([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], &device);
47/// let result = x.matrix_transpose();
48/// println!("{result}");
49/// // [[[ 1 3]
50/// // [ 2 4]]
51/// // [[ 5 7]
52/// // [ 6 8]]]
53/// ```
54///
55/// # Parameters
56///
57/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
58///
59/// - The input tensor to be transposed.
60///
61/// # Returns
62///
63/// - [`TensorView<'_, T, B, D>`](TensorView)
64///
65/// - A view of the input tensor with the last two axes interchanged.
66/// - No data is copied; only the shape and strides are modified.
67///
68/// # Notes of API accordance
69///
70/// - Array-API: `matrix_transpose(x)` ([`matrix_transpose`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.matrix_transpose.html))
71/// - NumPy: `numpy.linalg.matrix_transpose(x)` ([`numpy.matrix_transpose`](https://numpy.org/doc/stable/reference/generated/numpy.matrix_transpose.html))
72/// - RSTSR: `tensor.matrix_transpose()` or `rt::matrix_transpose(&tensor)`
73///
74/// Note that this is different from `T` (NumPy) / `t()` (RSTSR), which reverses
75/// all axes for n-dimensional arrays. This function only swaps the last two axes,
76/// which corresponds to `mT` in NumPy.
77///
78/// # See also
79///
80/// ## Related functions in RSTSR
81///
82/// - [`transpose`] - General axis permutation
83/// - [`swapaxes`] - Swap two specific axes
84/// - [`reverse_axes`] - Reverse all axes order
85///
86/// ## Variants of this function
87///
88/// - [`matrix_transpose`] / [`matrix_transpose_f`]: Returning a view.
89/// - [`into_matrix_transpose`] / [`into_matrix_transpose_f`]: Consuming version.
90///
91/// - Associated methods on `TensorAny`:
92///
93/// - [`TensorAny::matrix_transpose`] / [`TensorAny::matrix_transpose_f`]
94/// - [`TensorAny::into_matrix_transpose`] / [`TensorAny::into_matrix_transpose_f`]
95pub fn matrix_transpose<R, T, B, D>(tensor: &TensorAny<R, T, B, D>) -> TensorView<'_, T, B, D>
96where
97 D: DimAPI,
98 R: DataAPI<Data = B::Raw>,
99 B: DeviceAPI<T>,
100{
101 into_swapaxes_f(tensor.view(), -1, -2).rstsr_unwrap()
102}
103
104/// Transposes a matrix (or a stack of matrices).
105///
106/// See also [`matrix_transpose`].
107pub fn matrix_transpose_f<R, T, B, D>(tensor: &TensorAny<R, T, B, D>) -> Result<TensorView<'_, T, B, D>>
108where
109 D: DimAPI,
110 R: DataAPI<Data = B::Raw>,
111 B: DeviceAPI<T>,
112{
113 into_swapaxes_f(tensor.view(), -1, -2)
114}
115
116/// Transposes a matrix (or a stack of matrices).
117///
118/// See also [`matrix_transpose`].
119pub fn into_matrix_transpose<S, D>(tensor: TensorBase<S, D>) -> TensorBase<S, D>
120where
121 D: DimAPI,
122{
123 into_swapaxes_f(tensor, -1, -2).rstsr_unwrap()
124}
125
126impl<R, T, B, D> TensorAny<R, T, B, D>
127where
128 R: DataAPI<Data = B::Raw>,
129 B: DeviceAPI<T>,
130 D: DimAPI,
131{
132 /// Transposes a matrix (or a stack of matrices).
133 ///
134 /// See also [`matrix_transpose`].
135 pub fn matrix_transpose(&self) -> TensorView<'_, T, B, D> {
136 matrix_transpose(self)
137 }
138
139 /// Transposes a matrix (or a stack of matrices).
140 ///
141 /// See also [`matrix_transpose`].
142 pub fn matrix_transpose_f(&self) -> Result<TensorView<'_, T, B, D>> {
143 matrix_transpose_f(self)
144 }
145
146 /// Transposes a matrix (or a stack of matrices).
147 ///
148 /// See also [`matrix_transpose`].
149 pub fn into_matrix_transpose(self) -> TensorAny<R, T, B, D> {
150 into_matrix_transpose(self)
151 }
152
153 /// Transposes a matrix (or a stack of matrices).
154 ///
155 /// See also [`matrix_transpose`].
156 pub fn into_matrix_transpose_f(self) -> Result<TensorAny<R, T, B, D>> {
157 into_matrix_transpose_f(self)
158 }
159}
160
161/* #endregion */