Skip to main content

rstsr_core/tensor/manuplication/
transpose.rs

1use crate::prelude_dev::*;
2
3/* #region permute_dims */
4
5/// Permutes the axes (dimensions) of an array.
6///
7/// See also [`transpose`].
8pub fn into_transpose_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, D>>
9where
10    D: DimAPI,
11    I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
12{
13    let axes = axes.try_into().map_err(Into::into)?;
14    match axes {
15        AxesIndex::None => Ok(into_reverse_axes(tensor)),
16        _ => {
17            let (storage, layout) = tensor.into_raw_parts();
18            let layout = layout.transpose(axes.as_ref())?;
19            unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
20        },
21    }
22}
23
24/// Permutes the axes (dimensions) of an array.
25///
26/// Returns an array with axes transposed.
27///
28/// - For a 1-D array, this returns an unchanged view of the original array.
29/// - For a 2-D array, this is the standard matrix transpose.
30/// - For an n-D array, if axes are given, their order indicates how the axes are permuted (see
31///   Examples).
32///
33/// # Parameters
34///
35/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
36///
37///   - The input tensor whose axes are to be permuted.
38///
39/// - `axes`: TryInto [`AxesIndex<isize>`]
40///
41///   - The permutation of axes. If `None`, reverses the order of all axes (equivalent to
42///     [`reverse_axes`]).
43///   - Otherwise, `axes[i]` specifies the new position of axis `i` in the output.
44///   - The length of `axes` must match the number of dimensions of the input tensor.
45///   - Each axis must appear exactly once in `axes`.
46///   - Negative values are supported and indicate counting dimensions from the back.
47///
48/// # Returns
49///
50/// - [`TensorView<'_, T, B, D>`](TensorView)
51///
52///   - A view of the input tensor with permuted axes.
53///   - No data is copied; only the shape and strides are modified.
54///
55/// # Examples
56///
57/// For a 2-D array, this is the standard matrix transpose:
58///
59/// ```rust
60/// # use rstsr::prelude::*;
61/// # let mut device = DeviceCpu::default();
62/// # device.set_default_order(RowMajor);
63/// let a = rt::tensor_from_nested!([[1, 2], [3, 4]], &device);
64/// let result = a.transpose(None);
65/// println!("{result}");
66/// // [[ 1 3]
67/// //  [ 2 4]]
68/// ```
69///
70/// For a 1-D array, this returns an unchanged view:
71///
72/// ```rust
73/// # use rstsr::prelude::*;
74/// # let mut device = DeviceCpu::default();
75/// # device.set_default_order(RowMajor);
76/// let a = rt::tensor_from_nested!([1, 2, 3, 4], &device);
77/// let result = a.transpose(None);
78/// println!("{result}");
79/// // [ 1 2 3 4]
80/// ```
81///
82/// For an n-D array, you can specify a custom permutation, or None for reverse order:
83///
84/// ```rust
85/// # use rstsr::prelude::*;
86/// # let mut device = DeviceCpu::default();
87/// # device.set_default_order(RowMajor);
88/// // 3-D tensor
89/// let a: Tensor<i32, _> = rt::ones(([1, 2, 3], &device));
90/// let result = a.transpose(None);
91/// println!("{:?}", result.shape());
92/// // [3, 2, 1]
93/// let result = a.transpose([1, 0, 2]);
94/// println!("{:?}", result.shape());
95/// // [2, 1, 3]
96///
97/// // 4-D tensor
98/// let a: Tensor<i32, _> = rt::ones(([2, 3, 4, 5], &device));
99/// let result = a.transpose(None);
100/// println!("{:?}", result.shape());
101/// // [5, 4, 3, 2]
102/// ```
103///
104/// Negative indices are also supported:
105///
106/// ```rust
107/// # use rstsr::prelude::*;
108/// # let mut device = DeviceCpu::default();
109/// # device.set_default_order(RowMajor);
110/// let a: Tensor<i32, _> = rt::arange((3 * 4 * 5, &device)).into_shape([3, 4, 5]);
111/// let result = a.transpose([-1, 0, -2]);
112/// println!("{:?}", result.shape());
113/// // [5, 3, 4]
114/// ```
115///
116/// # Notes of API accordance
117///
118/// - Array-API: `permute_dims(x, /, axes)` ([`permute_dims`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.permute_dims.html))
119/// - NumPy: `transpose(a, axes=None)` ([`numpy.transpose`](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html))
120/// - RSTSR: `tensor.transpose(axes)` or `rt::transpose(&tensor, axes)`
121///
122/// Note that `axes=None` in NumPy/RSTSR reverses the order of all axes, which is equivalent to
123/// calling [`reverse_axes`] or [`TensorAny::t`] for 2D arrays.
124///
125/// # Panics
126///
127/// Panics if
128///
129/// - The length of `axes` does not match the number of dimensions of the input tensor.
130/// - Any axis index in `axes` is out of bounds (i.e., not in `[-ndim, ndim-1]`).
131/// - The `axes` array contains duplicate values (each axis must appear exactly once).
132///
133/// For a fallible version, use [`transpose_f`].
134///
135/// # See also
136///
137/// ## Related functions in RSTSR
138///
139/// - [`permute_dims`] - Alias for this function
140/// - [`reverse_axes`] - Reverse all axes order
141/// - [`swapaxes`] - Swap two specific axes
142///
143/// ## Variants of this function
144///
145/// - [`transpose`] / [`transpose_f`]: Returning a view.
146/// - [`into_transpose`] / [`into_transpose_f`]: Consuming version.
147///
148/// - Associated methods on `TensorAny`:
149///
150///   - [`TensorAny::transpose`] / [`TensorAny::transpose_f`]
151///   - [`TensorAny::into_transpose`] / [`TensorAny::into_transpose_f`]
152///   - [`TensorAny::t`] as shorthand for [`reverse_axes`]
153pub fn transpose<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, D>
154where
155    D: DimAPI,
156    I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
157    R: DataAPI<Data = B::Raw>,
158    B: DeviceAPI<T>,
159{
160    into_transpose_f(tensor.view(), axes).rstsr_unwrap()
161}
162
163/// Permutes the axes (dimensions) of an array.
164///
165/// See also [`transpose`].
166pub fn transpose_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, D>>
167where
168    D: DimAPI,
169    I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
170    R: DataAPI<Data = B::Raw>,
171    B: DeviceAPI<T>,
172{
173    into_transpose_f(tensor.view(), axes)
174}
175
176/// Permutes the axes (dimensions) of an array.
177///
178/// See also [`transpose`].
179pub fn into_transpose<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, D>
180where
181    D: DimAPI,
182    I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
183{
184    into_transpose_f(tensor, axes).rstsr_unwrap()
185}
186
187pub use into_transpose as into_permute_dims;
188pub use into_transpose_f as into_permute_dims_f;
189pub use transpose as permute_dims;
190pub use transpose_f as permute_dims_f;
191
192impl<R, T, B, D> TensorAny<R, T, B, D>
193where
194    R: DataAPI<Data = B::Raw>,
195    B: DeviceAPI<T>,
196    D: DimAPI,
197{
198    /// Permutes the axes (dimensions) of an array `x`.
199    ///
200    /// # See also
201    ///
202    /// [`transpose`]
203    pub fn transpose<I>(&self, axes: I) -> TensorView<'_, T, B, D>
204    where
205        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
206    {
207        transpose(self, axes)
208    }
209
210    /// Permutes the axes (dimensions) of an array.
211    ///
212    /// See also [`transpose`].
213    pub fn transpose_f<I>(&self, axes: I) -> Result<TensorView<'_, T, B, D>>
214    where
215        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
216    {
217        transpose_f(self, axes)
218    }
219
220    /// Permutes the axes (dimensions) of an array `x`.
221    ///
222    /// # See also
223    ///
224    /// [`transpose`]
225    pub fn into_transpose<I>(self, axes: I) -> TensorAny<R, T, B, D>
226    where
227        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
228    {
229        into_transpose(self, axes)
230    }
231
232    /// Permutes the axes (dimensions) of an array.
233    ///
234    /// See also [`transpose`].
235    pub fn into_transpose_f<I>(self, axes: I) -> Result<TensorAny<R, T, B, D>>
236    where
237        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
238    {
239        into_transpose_f(self, axes)
240    }
241
242    /// Permutes the axes (dimensions) of an array `x`.
243    ///
244    /// # See also
245    ///
246    /// [`transpose`]
247    pub fn permute_dims<I>(&self, axes: I) -> TensorView<'_, T, B, D>
248    where
249        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
250    {
251        transpose(self, axes)
252    }
253
254    /// Permutes the axes (dimensions) of an array.
255    ///
256    /// See also [`transpose`].
257    pub fn permute_dims_f<I>(&self, axes: I) -> Result<TensorView<'_, T, B, D>>
258    where
259        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
260    {
261        transpose_f(self, axes)
262    }
263
264    /// Permutes the axes (dimensions) of an array `x`.
265    ///
266    /// # See also
267    ///
268    /// [`transpose`]
269    pub fn into_permute_dims<I>(self, axes: I) -> TensorAny<R, T, B, D>
270    where
271        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
272    {
273        into_transpose(self, axes)
274    }
275
276    /// Permutes the axes (dimensions) of an array.
277    ///
278    /// See also [`transpose`].
279    pub fn into_permute_dims_f<I>(self, axes: I) -> Result<TensorAny<R, T, B, D>>
280    where
281        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
282    {
283        into_transpose_f(self, axes)
284    }
285}
286
287/* #endregion */
288
289/* #region reverse_axes */
290
291/// Reverse the order of the axes (dimensions) of an array.
292///
293/// See also [`reverse_axes`].
294pub fn into_reverse_axes<S, D>(tensor: TensorBase<S, D>) -> TensorBase<S, D>
295where
296    D: DimAPI,
297{
298    let (storage, layout) = tensor.into_raw_parts();
299    let layout = layout.reverse_axes();
300    unsafe { TensorBase::new_unchecked(storage, layout) }
301}
302
303/// Reverse the order of the axes (dimensions) of an array.
304///
305/// Returns an array with the order of axes reversed.
306///
307/// For a 2-D array, this is equivalent to a matrix transpose. For
308/// higher-dimensional arrays, this reverses the axis order (e.g., for 3D with
309/// axes [0, 1, 2], the result has axes [2, 1, 0]).
310///
311/// This is by definition equivalent to `transpose(None)` or `tensor.t()`.
312///
313/// # Parameters
314///
315/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
316///
317///   - The input tensor whose axes are to be reversed.
318///
319/// # Returns
320///
321/// - [`TensorView<'_, T, B, D>`](TensorView)
322///
323///   - A view of the input tensor with reversed axis order.
324///   - No data is copied; only the shape and strides are modified.
325///
326/// # Examples
327///
328/// For a 2-D array, this is equivalent to a matrix transpose:
329///
330/// ```rust
331/// # use rstsr::prelude::*;
332/// # let mut device = DeviceCpu::default();
333/// # device.set_default_order(RowMajor);
334/// let a = rt::tensor_from_nested!([[1, 2], [3, 4]], &device);
335/// let result = a.reverse_axes();
336/// println!("{result}");
337/// // [[ 1 3]
338/// //  [ 2 4]]
339/// ```
340///
341/// For a 1-D array, this returns an unchanged view:
342///
343/// ```rust
344/// # use rstsr::prelude::*;
345/// # let mut device = DeviceCpu::default();
346/// # device.set_default_order(RowMajor);
347/// let a = rt::tensor_from_nested!([1, 2, 3, 4], &device);
348/// let result = a.reverse_axes();
349/// println!("{result}");
350/// // [ 1 2 3 4]
351/// ```
352///
353/// For higher-dimensional arrays, the axis order is reversed:
354///
355/// ```rust
356/// # use rstsr::prelude::*;
357/// # let mut device = DeviceCpu::default();
358/// # device.set_default_order(RowMajor);
359/// // 3-D array: reverse_axes reverses all axis order
360/// let a = rt::tensor_from_nested!([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], &device);
361/// println!("{:?}", a.shape());
362/// // [2, 2, 2]
363/// let result = a.reverse_axes();
364/// println!("{:?}", result.shape());
365/// // [2, 2, 2]
366/// // For [2,2,2] shape, reverse doesn't change shape but changes axis order
367///
368/// // 4-D array: reverse_axes shows clear shape change
369/// let a: Tensor<i32, _> = rt::ones(([2, 3, 4, 5], &device));
370/// let result = a.reverse_axes();
371/// println!("{:?}", result.shape());
372/// // [5, 4, 3, 2]
373/// ```
374///
375/// # Notes of API accordance
376///
377/// - NumPy: `transpose(a)` or `a.T` ([`numpy.transpose`](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html))
378/// - RSTSR: `tensor.reverse_axes()` or `tensor.t()`
379///
380/// Also note for multiple-dimensional arrays, `a.T` (NumPy) is equivalent to `a.reverse_axes()`
381/// (RSTSR) (reverse all axes); but the `a.mT` (NumPy) is actually equivalent to `a.swapaxes(-1,
382/// -2)` (RSTSR) (only swap the last two axes).
383///
384/// # See also
385///
386/// ## Related functions in RSTSR
387///
388/// - [`transpose`] - General axis permutation
389/// - [`swapaxes`] - Swap two specific axes
390/// - [`TensorAny::t()`] - Shorthand for reverse axes
391///
392/// ## Variants of this function
393///
394/// Note that this function is by definition infallible, so no fallible version is provided.
395///
396/// - [`reverse_axes`]: Returning a view.
397/// - [`into_reverse_axes`]: Consuming version.
398///
399/// - Associated methods on `TensorAny`:
400///
401///   - [`TensorAny::reverse_axes`]
402///   - [`TensorAny::into_reverse_axes`]
403///   - [`TensorAny::t`] as shorthand for reverse axes
404pub fn reverse_axes<R, T, B, D>(tensor: &TensorAny<R, T, B, D>) -> TensorView<'_, T, B, D>
405where
406    D: DimAPI,
407    R: DataAPI<Data = B::Raw>,
408    B: DeviceAPI<T>,
409{
410    into_reverse_axes(tensor.view())
411}
412
413impl<R, T, B, D> TensorAny<R, T, B, D>
414where
415    R: DataAPI<Data = B::Raw>,
416    B: DeviceAPI<T>,
417    D: DimAPI,
418{
419    /// Reverse the order of the axes (dimensions) of an array.
420    ///
421    /// See also [`reverse_axes`].
422    pub fn reverse_axes(&self) -> TensorView<'_, T, B, D> {
423        into_reverse_axes(self.view())
424    }
425
426    /// Reverse the order of the axes (dimensions) of an array.
427    ///
428    /// See also [`reverse_axes`].
429    pub fn into_reverse_axes(self) -> TensorAny<R, T, B, D> {
430        into_reverse_axes(self)
431    }
432
433    /// Reverse the order of the axes (dimensions) of an array.
434    ///
435    /// See also [`reverse_axes`].
436    pub fn t(&self) -> TensorView<'_, T, B, D> {
437        into_reverse_axes(self.view())
438    }
439}
440
441/* #endregion */
442
443/* #region swapaxes */
444
445/// Interchange two axes of an array.
446///
447/// See also [`swapaxes`].
448pub fn into_swapaxes_f<I, S, D>(tensor: TensorBase<S, D>, axis1: I, axis2: I) -> Result<TensorBase<S, D>>
449where
450    D: DimAPI,
451    I: TryInto<isize>,
452{
453    let axis1 = axis1.try_into().map_err(|_| rstsr_error!(TryFromIntError))?;
454    let axis2 = axis2.try_into().map_err(|_| rstsr_error!(TryFromIntError))?;
455    let (storage, layout) = tensor.into_raw_parts();
456    let layout = layout.swapaxes(axis1, axis2)?;
457    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
458}
459
460/// Interchange two axes of an array.
461///
462/// Returns an array with two axes interchanged. No data is copied; only the
463/// shape and strides are modified.
464///
465/// # Parameters
466///
467/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
468///
469///   - The input tensor whose axes are to be swapped.
470///
471/// - `axis1`: `impl TryInto<isize>`
472///
473///   - First axis to be swapped.
474///   - Negative values are supported and indicate counting dimensions from the back.
475///
476/// - `axis2`: `impl TryInto<isize>`
477///
478///   - Second axis to be swapped.
479///   - Negative values are supported and indicate counting dimensions from the back.
480///
481/// # Returns
482///
483/// - [`TensorView<'_, T, B, D>`](TensorView)
484///
485///   - A view of the input tensor with the two axes interchanged.
486///   - No data is copied; only the shape and strides are modified.
487///
488/// # Examples
489///
490/// For a 2-D array, swapping axes 0 and 1 is equivalent to transpose:
491///
492/// ```rust
493/// # use rstsr::prelude::*;
494/// # let mut device = DeviceCpu::default();
495/// # device.set_default_order(RowMajor);
496/// let x = rt::tensor_from_nested!([[1, 2, 3]], &device);
497/// let result = x.swapaxes(0, 1);
498/// println!("{result}");
499/// // [[ 1]
500/// //  [ 2]
501/// //  [ 3]]
502/// ```
503///
504/// For a 3-D array, swapping axes 0 and 2:
505///
506/// ```rust
507/// # use rstsr::prelude::*;
508/// # let mut device = DeviceCpu::default();
509/// # device.set_default_order(RowMajor);
510/// let x = rt::tensor_from_nested!([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], &device);
511/// let result = x.swapaxes(0, 2);
512/// println!("{result}");
513/// // [[[ 0 4]
514/// //   [ 2 6]]
515/// //  [[ 1 5]
516/// //   [ 3 7]]]
517/// ```
518///
519/// Using negative indices to swap axes:
520///
521/// ```rust
522/// # use rstsr::prelude::*;
523/// # let mut device = DeviceCpu::default();
524/// # device.set_default_order(RowMajor);
525/// let x = rt::tensor_from_nested!([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], &device);
526/// let result = x.swapaxes(-1, -3);
527/// println!("{:?}", result.shape());
528/// // [2, 2, 2]
529/// ```
530///
531/// # Notes of API accordance
532///
533/// - NumPy: `swapaxes(a, axis1, axis2)` ([`numpy.swapaxes`](https://numpy.org/doc/stable/reference/generated/numpy.swapaxes.html))
534/// - RSTSR: `tensor.swapaxes(axis1, axis2)` or `rt::swapaxes(&tensor, axis1, axis2)`
535///
536/// # Panics
537///
538/// Panics if either `axis1` or `axis2` is out of bounds (i.e., not in `[-ndim, ndim-1]`).
539///
540/// For a fallible version, use [`swapaxes_f`].
541///
542/// # See also
543///
544/// ## Related functions in RSTSR
545///
546/// - [`transpose`] - General axis permutation
547/// - [`reverse_axes`] - Reverse all axes order
548///
549/// ## Variants of this function
550///
551/// - [`swapaxes`] / [`swapaxes_f`]: Returning a view.
552/// - [`into_swapaxes`] / [`into_swapaxes_f`]: Consuming version.
553///
554/// - Associated methods on `TensorAny`:
555///
556///   - [`TensorAny::swapaxes`] / [`TensorAny::swapaxes_f`]
557///   - [`TensorAny::into_swapaxes`] / [`TensorAny::into_swapaxes_f`]
558pub fn swapaxes<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axis1: I, axis2: I) -> TensorView<'_, T, B, D>
559where
560    D: DimAPI,
561    I: TryInto<isize>,
562    R: DataAPI<Data = B::Raw>,
563    B: DeviceAPI<T>,
564{
565    into_swapaxes_f(tensor.view(), axis1, axis2).rstsr_unwrap()
566}
567
568/// Interchange two axes of an array.
569///
570/// See also [`swapaxes`].
571pub fn swapaxes_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axis1: I, axis2: I) -> Result<TensorView<'_, T, B, D>>
572where
573    D: DimAPI,
574    I: TryInto<isize>,
575    R: DataAPI<Data = B::Raw>,
576    B: DeviceAPI<T>,
577{
578    into_swapaxes_f(tensor.view(), axis1, axis2)
579}
580
581/// Interchange two axes of an array.
582///
583/// See also [`swapaxes`].
584pub fn into_swapaxes<I, S, D>(tensor: TensorBase<S, D>, axis1: I, axis2: I) -> TensorBase<S, D>
585where
586    D: DimAPI,
587    I: TryInto<isize>,
588{
589    into_swapaxes_f(tensor, axis1, axis2).rstsr_unwrap()
590}
591
592impl<R, T, B, D> TensorAny<R, T, B, D>
593where
594    R: DataAPI<Data = B::Raw>,
595    B: DeviceAPI<T>,
596    D: DimAPI,
597{
598    /// Interchange two axes of an array.
599    ///
600    /// See also [`swapaxes`].
601    pub fn swapaxes<I>(&self, axis1: I, axis2: I) -> TensorView<'_, T, B, D>
602    where
603        I: TryInto<isize>,
604    {
605        swapaxes(self, axis1, axis2)
606    }
607
608    /// Interchange two axes of an array.
609    ///
610    /// See also [`swapaxes`].
611    pub fn swapaxes_f<I>(&self, axis1: I, axis2: I) -> Result<TensorView<'_, T, B, D>>
612    where
613        I: TryInto<isize>,
614    {
615        swapaxes_f(self, axis1, axis2)
616    }
617
618    /// Interchange two axes of an array.
619    ///
620    /// See also [`swapaxes`].
621    pub fn into_swapaxes<I>(self, axis1: I, axis2: I) -> TensorAny<R, T, B, D>
622    where
623        I: TryInto<isize>,
624    {
625        into_swapaxes(self, axis1, axis2)
626    }
627
628    /// Interchange two axes of an array.
629    ///
630    /// See also [`swapaxes`].
631    pub fn into_swapaxes_f<I>(self, axis1: I, axis2: I) -> Result<TensorAny<R, T, B, D>>
632    where
633        I: TryInto<isize>,
634    {
635        into_swapaxes_f(self, axis1, axis2)
636    }
637}
638
639/* #endregion */