// rstsr_core/tensor/manuplication/transpose.rs

1use crate::prelude_dev::*;

/* #region permute_dims */

5pub fn into_transpose_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, D>>
6where
7    D: DimAPI,
8    I: TryInto<AxesIndex<isize>, Error = Error>,
9{
10    let axes = axes.try_into()?;
11    if axes.as_ref().is_empty() {
12        return Ok(into_reverse_axes(tensor));
13    }
14    let (storage, layout) = tensor.into_raw_parts();
15    let layout = layout.transpose(axes.as_ref())?;
16    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
17}
18
19/// Permutes the axes (dimensions) of an array `x`.
20///
21/// # See also
22///
23/// - [Python array API standard: `permute_dims`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.permute_dims.html)
24pub fn transpose<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, D>
25where
26    D: DimAPI,
27    I: TryInto<AxesIndex<isize>, Error = Error>,
28    R: DataAPI<Data = B::Raw>,
29    B: DeviceAPI<T>,
30{
31    into_transpose_f(tensor.view(), axes).rstsr_unwrap()
32}
33
34pub fn transpose_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, D>>
35where
36    D: DimAPI,
37    I: TryInto<AxesIndex<isize>, Error = Error>,
38    R: DataAPI<Data = B::Raw>,
39    B: DeviceAPI<T>,
40{
41    into_transpose_f(tensor.view(), axes)
42}
43
44pub fn into_transpose<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, D>
45where
46    D: DimAPI,
47    I: TryInto<AxesIndex<isize>, Error = Error>,
48{
49    into_transpose_f(tensor, axes).rstsr_unwrap()
50}
51
52pub use into_transpose as into_permute_dims;
53pub use into_transpose_f as into_permute_dims_f;
54pub use transpose as permute_dims;
55pub use transpose_f as permute_dims_f;
56
57impl<R, T, B, D> TensorAny<R, T, B, D>
58where
59    R: DataAPI<Data = B::Raw>,
60    B: DeviceAPI<T>,
61    D: DimAPI,
62{
63    /// Permutes the axes (dimensions) of an array `x`.
64    ///
65    /// # See also
66    ///
67    /// [`transpose`]
68    pub fn transpose<I>(&self, axes: I) -> TensorView<'_, T, B, D>
69    where
70        I: TryInto<AxesIndex<isize>, Error = Error>,
71    {
72        transpose(self, axes)
73    }
74
75    pub fn transpose_f<I>(&self, axes: I) -> Result<TensorView<'_, T, B, D>>
76    where
77        I: TryInto<AxesIndex<isize>, Error = Error>,
78    {
79        transpose_f(self, axes)
80    }
81
82    /// Permutes the axes (dimensions) of an array `x`.
83    ///
84    /// # See also
85    ///
86    /// [`transpose`]
87    pub fn into_transpose<I>(self, axes: I) -> TensorAny<R, T, B, D>
88    where
89        I: TryInto<AxesIndex<isize>, Error = Error>,
90    {
91        into_transpose(self, axes)
92    }
93
94    pub fn into_transpose_f<I>(self, axes: I) -> Result<TensorAny<R, T, B, D>>
95    where
96        I: TryInto<AxesIndex<isize>, Error = Error>,
97    {
98        into_transpose_f(self, axes)
99    }
100
101    /// Permutes the axes (dimensions) of an array `x`.
102    ///
103    /// # See also
104    ///
105    /// [`transpose`]
106    pub fn permute_dims<I>(&self, axes: I) -> TensorView<'_, T, B, D>
107    where
108        I: TryInto<AxesIndex<isize>, Error = Error>,
109    {
110        transpose(self, axes)
111    }
112
113    pub fn permute_dims_f<I>(&self, axes: I) -> Result<TensorView<'_, T, B, D>>
114    where
115        I: TryInto<AxesIndex<isize>, Error = Error>,
116    {
117        transpose_f(self, axes)
118    }
119
120    /// Permutes the axes (dimensions) of an array `x`.
121    ///
122    /// # See also
123    ///
124    /// [`transpose`]
125    pub fn into_permute_dims<I>(self, axes: I) -> TensorAny<R, T, B, D>
126    where
127        I: TryInto<AxesIndex<isize>, Error = Error>,
128    {
129        into_transpose(self, axes)
130    }
131
132    pub fn into_permute_dims_f<I>(self, axes: I) -> Result<TensorAny<R, T, B, D>>
133    where
134        I: TryInto<AxesIndex<isize>, Error = Error>,
135    {
136        into_transpose_f(self, axes)
137    }
138}

/* #endregion */

/* #region reverse_axes */

144pub fn into_reverse_axes<S, D>(tensor: TensorBase<S, D>) -> TensorBase<S, D>
145where
146    D: DimAPI,
147{
148    let (storage, layout) = tensor.into_raw_parts();
149    let layout = layout.reverse_axes();
150    unsafe { TensorBase::new_unchecked(storage, layout) }
151}
152
153/// Reverse the order of elements in an array along the given axis.
154pub fn reverse_axes<R, T, B, D>(tensor: &TensorAny<R, T, B, D>) -> TensorView<'_, T, B, D>
155where
156    D: DimAPI,
157    R: DataAPI<Data = B::Raw>,
158    B: DeviceAPI<T>,
159{
160    into_reverse_axes(tensor.view())
161}
162
163impl<R, T, B, D> TensorAny<R, T, B, D>
164where
165    R: DataAPI<Data = B::Raw>,
166    B: DeviceAPI<T>,
167    D: DimAPI,
168{
169    /// Reverse the order of elements in an array along the given axis.
170    ///
171    /// # See also
172    ///
173    /// [`reverse_axes`]
174    pub fn reverse_axes(&self) -> TensorView<'_, T, B, D> {
175        into_reverse_axes(self.view())
176    }
177
178    /// Reverse the order of elements in an array along the given axis.
179    ///
180    /// # See also
181    ///
182    /// [`reverse_axes`]
183    pub fn into_reverse_axes(self) -> TensorAny<R, T, B, D> {
184        into_reverse_axes(self)
185    }
186
187    /// Reverse the order of elements in an array along the given axis.
188    ///
189    /// # See also
190    ///
191    /// [`reverse_axes`]
192    pub fn t(&self) -> TensorView<'_, T, B, D> {
193        into_reverse_axes(self.view())
194    }
195}

/* #endregion */

/* #region swapaxes */

201pub fn into_swapaxes_f<I, S, D>(tensor: TensorBase<S, D>, axis1: I, axis2: I) -> Result<TensorBase<S, D>>
202where
203    D: DimAPI,
204    I: TryInto<isize>,
205{
206    let axis1 = axis1.try_into().map_err(|_| rstsr_error!(TryFromIntError))?;
207    let axis2 = axis2.try_into().map_err(|_| rstsr_error!(TryFromIntError))?;
208    let (storage, layout) = tensor.into_raw_parts();
209    let layout = layout.swapaxes(axis1, axis2)?;
210    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
211}
212
213/// Interchange two axes of an array.
214///
215/// # See also
216///
217/// - [numpy `swapaxes`](https://numpy.org/doc/stable/reference/generated/numpy.swapaxes.html)
218pub fn swapaxes<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axis1: I, axis2: I) -> TensorView<'_, T, B, D>
219where
220    D: DimAPI,
221    I: TryInto<isize>,
222    R: DataAPI<Data = B::Raw>,
223    B: DeviceAPI<T>,
224{
225    into_swapaxes_f(tensor.view(), axis1, axis2).rstsr_unwrap()
226}
227
228pub fn swapaxes_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axis1: I, axis2: I) -> Result<TensorView<'_, T, B, D>>
229where
230    D: DimAPI,
231    I: TryInto<isize>,
232    R: DataAPI<Data = B::Raw>,
233    B: DeviceAPI<T>,
234{
235    into_swapaxes_f(tensor.view(), axis1, axis2)
236}
237
238pub fn into_swapaxes<I, S, D>(tensor: TensorBase<S, D>, axis1: I, axis2: I) -> TensorBase<S, D>
239where
240    D: DimAPI,
241    I: TryInto<isize>,
242{
243    into_swapaxes_f(tensor, axis1, axis2).rstsr_unwrap()
244}
245
246impl<R, T, B, D> TensorAny<R, T, B, D>
247where
248    R: DataAPI<Data = B::Raw>,
249    B: DeviceAPI<T>,
250    D: DimAPI,
251{
252    /// Interchange two axes of an array.
253    ///
254    /// # See also
255    ///
256    /// [`swapaxes`]
257    pub fn swapaxes<I>(&self, axis1: I, axis2: I) -> TensorView<'_, T, B, D>
258    where
259        I: TryInto<isize>,
260    {
261        swapaxes(self, axis1, axis2)
262    }
263
264    pub fn swapaxes_f<I>(&self, axis1: I, axis2: I) -> Result<TensorView<'_, T, B, D>>
265    where
266        I: TryInto<isize>,
267    {
268        swapaxes_f(self, axis1, axis2)
269    }
270
271    /// Interchange two axes of an array.
272    ///
273    /// # See also
274    ///
275    /// [`swapaxes`]
276    pub fn into_swapaxes<I>(self, axis1: I, axis2: I) -> TensorAny<R, T, B, D>
277    where
278        I: TryInto<isize>,
279    {
280        into_swapaxes(self, axis1, axis2)
281    }
282
283    pub fn into_swapaxes_f<I>(self, axis1: I, axis2: I) -> Result<TensorAny<R, T, B, D>>
284    where
285        I: TryInto<isize>,
286    {
287        into_swapaxes_f(self, axis1, axis2)
288    }
289}

/* #endregion */