Skip to main content

rstsr_core/tensor/linalg/
matrix_transpose.rs

1use crate::prelude_dev::*;
2
3/* #region matrix_transpose */
4
5/// Transposes a matrix (or a stack of matrices).
6///
7/// See also [`matrix_transpose`].
8pub fn into_matrix_transpose_f<S, D>(tensor: TensorBase<S, D>) -> Result<TensorBase<S, D>>
9where
10    D: DimAPI,
11{
12    into_swapaxes_f(tensor, -1, -2)
13}
14
15/// Transposes a matrix (or a stack of matrices).
16///
17/// Returns an array with the last two axes interchanged. This is equivalent
18/// to `swapaxes(-1, -2)`, but is provided as a convenience function for
19/// transposing matrices in multi-dimensional arrays.
20///
21/// For a 2-D array, this is equivalent to the standard matrix transpose.
22/// For higher-dimensional arrays, this transposes each matrix in a stack of
23/// matrices, leaving other axes unchanged.
24///
25/// # Examples
26///
27/// For a 2-D array, this is equivalent to the standard matrix transpose:
28///
29/// ```rust
30/// # use rstsr::prelude::*;
31/// # let mut device = DeviceCpu::default();
32/// # device.set_default_order(RowMajor);
33/// let x = rt::tensor_from_nested!([[1, 2], [3, 4]], &device);
34/// let result = x.matrix_transpose();
35/// println!("{result}");
36/// // [[ 1 3]
37/// //  [ 2 4]]
38/// ```
39///
40/// For a 3-D array (a stack of matrices), each matrix is transposed independently:
41///
42/// ```rust
43/// # use rstsr::prelude::*;
44/// # let mut device = DeviceCpu::default();
45/// # device.set_default_order(RowMajor);
46/// let x = rt::tensor_from_nested!([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], &device);
47/// let result = x.matrix_transpose();
48/// println!("{result}");
49/// // [[[ 1 3]
50/// //   [ 2 4]]
51/// //  [[ 5 7]
52/// //   [ 6 8]]]
53/// ```
54///
55/// # Parameters
56///
57/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
58///
59///   - The input tensor to be transposed.
60///
61/// # Returns
62///
63/// - [`TensorView<'_, T, B, D>`](TensorView)
64///
65///   - A view of the input tensor with the last two axes interchanged.
66///   - No data is copied; only the shape and strides are modified.
67///
68/// # Notes of API accordance
69///
70/// - Array-API: `matrix_transpose(x)` ([`matrix_transpose`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.matrix_transpose.html))
71/// - NumPy: `numpy.linalg.matrix_transpose(x)` ([`numpy.matrix_transpose`](https://numpy.org/doc/stable/reference/generated/numpy.matrix_transpose.html))
72/// - RSTSR: `tensor.matrix_transpose()` or `rt::matrix_transpose(&tensor)`
73///
74/// Note that this is different from `T` (NumPy) / `t()` (RSTSR), which reverses
75/// all axes for n-dimensional arrays. This function only swaps the last two axes,
76/// which corresponds to `mT` in NumPy.
77///
78/// # See also
79///
80/// ## Related functions in RSTSR
81///
82/// - [`transpose`] - General axis permutation
83/// - [`swapaxes`] - Swap two specific axes
84/// - [`reverse_axes`] - Reverse all axes order
85///
86/// ## Variants of this function
87///
88/// - [`matrix_transpose`] / [`matrix_transpose_f`]: Returning a view.
89/// - [`into_matrix_transpose`] / [`into_matrix_transpose_f`]: Consuming version.
90///
91/// - Associated methods on `TensorAny`:
92///
93///   - [`TensorAny::matrix_transpose`] / [`TensorAny::matrix_transpose_f`]
94///   - [`TensorAny::into_matrix_transpose`] / [`TensorAny::into_matrix_transpose_f`]
95pub fn matrix_transpose<R, T, B, D>(tensor: &TensorAny<R, T, B, D>) -> TensorView<'_, T, B, D>
96where
97    D: DimAPI,
98    R: DataAPI<Data = B::Raw>,
99    B: DeviceAPI<T>,
100{
101    into_swapaxes_f(tensor.view(), -1, -2).rstsr_unwrap()
102}
103
104/// Transposes a matrix (or a stack of matrices).
105///
106/// See also [`matrix_transpose`].
107pub fn matrix_transpose_f<R, T, B, D>(tensor: &TensorAny<R, T, B, D>) -> Result<TensorView<'_, T, B, D>>
108where
109    D: DimAPI,
110    R: DataAPI<Data = B::Raw>,
111    B: DeviceAPI<T>,
112{
113    into_swapaxes_f(tensor.view(), -1, -2)
114}
115
116/// Transposes a matrix (or a stack of matrices).
117///
118/// See also [`matrix_transpose`].
119pub fn into_matrix_transpose<S, D>(tensor: TensorBase<S, D>) -> TensorBase<S, D>
120where
121    D: DimAPI,
122{
123    into_swapaxes_f(tensor, -1, -2).rstsr_unwrap()
124}
125
126impl<R, T, B, D> TensorAny<R, T, B, D>
127where
128    R: DataAPI<Data = B::Raw>,
129    B: DeviceAPI<T>,
130    D: DimAPI,
131{
132    /// Transposes a matrix (or a stack of matrices).
133    ///
134    /// See also [`matrix_transpose`].
135    pub fn matrix_transpose(&self) -> TensorView<'_, T, B, D> {
136        matrix_transpose(self)
137    }
138
139    /// Transposes a matrix (or a stack of matrices).
140    ///
141    /// See also [`matrix_transpose`].
142    pub fn matrix_transpose_f(&self) -> Result<TensorView<'_, T, B, D>> {
143        matrix_transpose_f(self)
144    }
145
146    /// Transposes a matrix (or a stack of matrices).
147    ///
148    /// See also [`matrix_transpose`].
149    pub fn into_matrix_transpose(self) -> TensorAny<R, T, B, D> {
150        into_matrix_transpose(self)
151    }
152
153    /// Transposes a matrix (or a stack of matrices).
154    ///
155    /// See also [`matrix_transpose`].
156    pub fn into_matrix_transpose_f(self) -> Result<TensorAny<R, T, B, D>> {
157        into_matrix_transpose_f(self)
158    }
159}
160
161/* #endregion */