rstsr_core/tensor/manuplication/transpose.rs
1use crate::prelude_dev::*;
2
3/* #region permute_dims */
4
5/// Permutes the axes (dimensions) of an array.
6///
7/// See also [`transpose`].
8pub fn into_transpose_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, D>>
9where
10 D: DimAPI,
11 I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
12{
13 let axes = axes.try_into().map_err(Into::into)?;
14 match axes {
15 AxesIndex::None => Ok(into_reverse_axes(tensor)),
16 _ => {
17 let (storage, layout) = tensor.into_raw_parts();
18 let layout = layout.transpose(axes.as_ref())?;
19 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
20 },
21 }
22}
23
24/// Permutes the axes (dimensions) of an array.
25///
26/// Returns an array with axes transposed.
27///
28/// - For a 1-D array, this returns an unchanged view of the original array.
29/// - For a 2-D array, this is the standard matrix transpose.
30/// - For an n-D array, if axes are given, their order indicates how the axes are permuted (see
31/// Examples).
32///
33/// # Parameters
34///
35/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
36///
37/// - The input tensor whose axes are to be permuted.
38///
39/// - `axes`: TryInto [`AxesIndex<isize>`]
40///
41/// - The permutation of axes. If `None`, reverses the order of all axes (equivalent to
42/// [`reverse_axes`]).
43/// - Otherwise, `axes[i]` specifies the new position of axis `i` in the output.
44/// - The length of `axes` must match the number of dimensions of the input tensor.
45/// - Each axis must appear exactly once in `axes`.
46/// - Negative values are supported and indicate counting dimensions from the back.
47///
48/// # Returns
49///
50/// - [`TensorView<'_, T, B, D>`](TensorView)
51///
52/// - A view of the input tensor with permuted axes.
53/// - No data is copied; only the shape and strides are modified.
54///
55/// # Examples
56///
57/// For a 2-D array, this is the standard matrix transpose:
58///
59/// ```rust
60/// # use rstsr::prelude::*;
61/// # let mut device = DeviceCpu::default();
62/// # device.set_default_order(RowMajor);
63/// let a = rt::tensor_from_nested!([[1, 2], [3, 4]], &device);
64/// let result = a.transpose(None);
65/// println!("{result}");
66/// // [[ 1 3]
67/// // [ 2 4]]
68/// ```
69///
70/// For a 1-D array, this returns an unchanged view:
71///
72/// ```rust
73/// # use rstsr::prelude::*;
74/// # let mut device = DeviceCpu::default();
75/// # device.set_default_order(RowMajor);
76/// let a = rt::tensor_from_nested!([1, 2, 3, 4], &device);
77/// let result = a.transpose(None);
78/// println!("{result}");
79/// // [ 1 2 3 4]
80/// ```
81///
82/// For an n-D array, you can specify a custom permutation, or None for reverse order:
83///
84/// ```rust
85/// # use rstsr::prelude::*;
86/// # let mut device = DeviceCpu::default();
87/// # device.set_default_order(RowMajor);
88/// // 3-D tensor
89/// let a: Tensor<i32, _> = rt::ones(([1, 2, 3], &device));
90/// let result = a.transpose(None);
91/// println!("{:?}", result.shape());
92/// // [3, 2, 1]
93/// let result = a.transpose([1, 0, 2]);
94/// println!("{:?}", result.shape());
95/// // [2, 1, 3]
96///
97/// // 4-D tensor
98/// let a: Tensor<i32, _> = rt::ones(([2, 3, 4, 5], &device));
99/// let result = a.transpose(None);
100/// println!("{:?}", result.shape());
101/// // [5, 4, 3, 2]
102/// ```
103///
104/// Negative indices are also supported:
105///
106/// ```rust
107/// # use rstsr::prelude::*;
108/// # let mut device = DeviceCpu::default();
109/// # device.set_default_order(RowMajor);
110/// let a: Tensor<i32, _> = rt::arange((3 * 4 * 5, &device)).into_shape([3, 4, 5]);
111/// let result = a.transpose([-1, 0, -2]);
112/// println!("{:?}", result.shape());
113/// // [5, 3, 4]
114/// ```
115///
116/// # Notes of API accordance
117///
118/// - Array-API: `permute_dims(x, /, axes)` ([`permute_dims`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.permute_dims.html))
119/// - NumPy: `transpose(a, axes=None)` ([`numpy.transpose`](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html))
120/// - RSTSR: `tensor.transpose(axes)` or `rt::transpose(&tensor, axes)`
121///
122/// Note that `axes=None` in NumPy/RSTSR reverses the order of all axes, which is equivalent to
123/// calling [`reverse_axes`] or [`TensorAny::t`] for 2D arrays.
124///
125/// # Panics
126///
127/// Panics if
128///
129/// - The length of `axes` does not match the number of dimensions of the input tensor.
130/// - Any axis index in `axes` is out of bounds (i.e., not in `[-ndim, ndim-1]`).
131/// - The `axes` array contains duplicate values (each axis must appear exactly once).
132///
133/// For a fallible version, use [`transpose_f`].
134///
135/// # See also
136///
137/// ## Related functions in RSTSR
138///
139/// - [`permute_dims`] - Alias for this function
140/// - [`reverse_axes`] - Reverse all axes order
141/// - [`swapaxes`] - Swap two specific axes
142///
143/// ## Variants of this function
144///
145/// - [`transpose`] / [`transpose_f`]: Returning a view.
146/// - [`into_transpose`] / [`into_transpose_f`]: Consuming version.
147///
148/// - Associated methods on `TensorAny`:
149///
150/// - [`TensorAny::transpose`] / [`TensorAny::transpose_f`]
151/// - [`TensorAny::into_transpose`] / [`TensorAny::into_transpose_f`]
152/// - [`TensorAny::t`] as shorthand for [`reverse_axes`]
153pub fn transpose<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, D>
154where
155 D: DimAPI,
156 I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
157 R: DataAPI<Data = B::Raw>,
158 B: DeviceAPI<T>,
159{
160 into_transpose_f(tensor.view(), axes).rstsr_unwrap()
161}
162
163/// Permutes the axes (dimensions) of an array.
164///
165/// See also [`transpose`].
166pub fn transpose_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, D>>
167where
168 D: DimAPI,
169 I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
170 R: DataAPI<Data = B::Raw>,
171 B: DeviceAPI<T>,
172{
173 into_transpose_f(tensor.view(), axes)
174}
175
176/// Permutes the axes (dimensions) of an array.
177///
178/// See also [`transpose`].
179pub fn into_transpose<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, D>
180where
181 D: DimAPI,
182 I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
183{
184 into_transpose_f(tensor, axes).rstsr_unwrap()
185}
186
187pub use into_transpose as into_permute_dims;
188pub use into_transpose_f as into_permute_dims_f;
189pub use transpose as permute_dims;
190pub use transpose_f as permute_dims_f;
191
192impl<R, T, B, D> TensorAny<R, T, B, D>
193where
194 R: DataAPI<Data = B::Raw>,
195 B: DeviceAPI<T>,
196 D: DimAPI,
197{
198 /// Permutes the axes (dimensions) of an array `x`.
199 ///
200 /// # See also
201 ///
202 /// [`transpose`]
203 pub fn transpose<I>(&self, axes: I) -> TensorView<'_, T, B, D>
204 where
205 I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
206 {
207 transpose(self, axes)
208 }
209
210 /// Permutes the axes (dimensions) of an array.
211 ///
212 /// See also [`transpose`].
213 pub fn transpose_f<I>(&self, axes: I) -> Result<TensorView<'_, T, B, D>>
214 where
215 I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
216 {
217 transpose_f(self, axes)
218 }
219
220 /// Permutes the axes (dimensions) of an array `x`.
221 ///
222 /// # See also
223 ///
224 /// [`transpose`]
225 pub fn into_transpose<I>(self, axes: I) -> TensorAny<R, T, B, D>
226 where
227 I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
228 {
229 into_transpose(self, axes)
230 }
231
232 /// Permutes the axes (dimensions) of an array.
233 ///
234 /// See also [`transpose`].
235 pub fn into_transpose_f<I>(self, axes: I) -> Result<TensorAny<R, T, B, D>>
236 where
237 I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
238 {
239 into_transpose_f(self, axes)
240 }
241
242 /// Permutes the axes (dimensions) of an array `x`.
243 ///
244 /// # See also
245 ///
246 /// [`transpose`]
247 pub fn permute_dims<I>(&self, axes: I) -> TensorView<'_, T, B, D>
248 where
249 I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
250 {
251 transpose(self, axes)
252 }
253
254 /// Permutes the axes (dimensions) of an array.
255 ///
256 /// See also [`transpose`].
257 pub fn permute_dims_f<I>(&self, axes: I) -> Result<TensorView<'_, T, B, D>>
258 where
259 I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
260 {
261 transpose_f(self, axes)
262 }
263
264 /// Permutes the axes (dimensions) of an array `x`.
265 ///
266 /// # See also
267 ///
268 /// [`transpose`]
269 pub fn into_permute_dims<I>(self, axes: I) -> TensorAny<R, T, B, D>
270 where
271 I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
272 {
273 into_transpose(self, axes)
274 }
275
276 /// Permutes the axes (dimensions) of an array.
277 ///
278 /// See also [`transpose`].
279 pub fn into_permute_dims_f<I>(self, axes: I) -> Result<TensorAny<R, T, B, D>>
280 where
281 I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
282 {
283 into_transpose_f(self, axes)
284 }
285}
286
287/* #endregion */
288
289/* #region reverse_axes */
290
291/// Reverse the order of the axes (dimensions) of an array.
292///
293/// See also [`reverse_axes`].
294pub fn into_reverse_axes<S, D>(tensor: TensorBase<S, D>) -> TensorBase<S, D>
295where
296 D: DimAPI,
297{
298 let (storage, layout) = tensor.into_raw_parts();
299 let layout = layout.reverse_axes();
300 unsafe { TensorBase::new_unchecked(storage, layout) }
301}
302
303/// Reverse the order of the axes (dimensions) of an array.
304///
305/// Returns an array with the order of axes reversed.
306///
307/// For a 2-D array, this is equivalent to a matrix transpose. For
308/// higher-dimensional arrays, this reverses the axis order (e.g., for 3D with
309/// axes [0, 1, 2], the result has axes [2, 1, 0]).
310///
311/// This is by definition equivalent to `transpose(None)` or `tensor.t()`.
312///
313/// # Parameters
314///
315/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
316///
317/// - The input tensor whose axes are to be reversed.
318///
319/// # Returns
320///
321/// - [`TensorView<'_, T, B, D>`](TensorView)
322///
323/// - A view of the input tensor with reversed axis order.
324/// - No data is copied; only the shape and strides are modified.
325///
326/// # Examples
327///
328/// For a 2-D array, this is equivalent to a matrix transpose:
329///
330/// ```rust
331/// # use rstsr::prelude::*;
332/// # let mut device = DeviceCpu::default();
333/// # device.set_default_order(RowMajor);
334/// let a = rt::tensor_from_nested!([[1, 2], [3, 4]], &device);
335/// let result = a.reverse_axes();
336/// println!("{result}");
337/// // [[ 1 3]
338/// // [ 2 4]]
339/// ```
340///
341/// For a 1-D array, this returns an unchanged view:
342///
343/// ```rust
344/// # use rstsr::prelude::*;
345/// # let mut device = DeviceCpu::default();
346/// # device.set_default_order(RowMajor);
347/// let a = rt::tensor_from_nested!([1, 2, 3, 4], &device);
348/// let result = a.reverse_axes();
349/// println!("{result}");
350/// // [ 1 2 3 4]
351/// ```
352///
353/// For higher-dimensional arrays, the axis order is reversed:
354///
355/// ```rust
356/// # use rstsr::prelude::*;
357/// # let mut device = DeviceCpu::default();
358/// # device.set_default_order(RowMajor);
359/// // 3-D array: reverse_axes reverses all axis order
360/// let a = rt::tensor_from_nested!([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], &device);
361/// println!("{:?}", a.shape());
362/// // [2, 2, 2]
363/// let result = a.reverse_axes();
364/// println!("{:?}", result.shape());
365/// // [2, 2, 2]
366/// // For [2,2,2] shape, reverse doesn't change shape but changes axis order
367///
368/// // 4-D array: reverse_axes shows clear shape change
369/// let a: Tensor<i32, _> = rt::ones(([2, 3, 4, 5], &device));
370/// let result = a.reverse_axes();
371/// println!("{:?}", result.shape());
372/// // [5, 4, 3, 2]
373/// ```
374///
375/// # Notes of API accordance
376///
377/// - NumPy: `transpose(a)` or `a.T` ([`numpy.transpose`](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html))
378/// - RSTSR: `tensor.reverse_axes()` or `tensor.t()`
379///
380/// Also note for multiple-dimensional arrays, `a.T` (NumPy) is equivalent to `a.reverse_axes()`
381/// (RSTSR) (reverse all axes); but the `a.mT` (NumPy) is actually equivalent to `a.swapaxes(-1,
382/// -2)` (RSTSR) (only swap the last two axes).
383///
384/// # See also
385///
386/// ## Related functions in RSTSR
387///
388/// - [`transpose`] - General axis permutation
389/// - [`swapaxes`] - Swap two specific axes
390/// - [`TensorAny::t()`] - Shorthand for reverse axes
391///
392/// ## Variants of this function
393///
394/// Note that this function is by definition infallible, so no fallible version is provided.
395///
396/// - [`reverse_axes`]: Returning a view.
397/// - [`into_reverse_axes`]: Consuming version.
398///
399/// - Associated methods on `TensorAny`:
400///
401/// - [`TensorAny::reverse_axes`]
402/// - [`TensorAny::into_reverse_axes`]
403/// - [`TensorAny::t`] as shorthand for reverse axes
404pub fn reverse_axes<R, T, B, D>(tensor: &TensorAny<R, T, B, D>) -> TensorView<'_, T, B, D>
405where
406 D: DimAPI,
407 R: DataAPI<Data = B::Raw>,
408 B: DeviceAPI<T>,
409{
410 into_reverse_axes(tensor.view())
411}
412
413impl<R, T, B, D> TensorAny<R, T, B, D>
414where
415 R: DataAPI<Data = B::Raw>,
416 B: DeviceAPI<T>,
417 D: DimAPI,
418{
419 /// Reverse the order of the axes (dimensions) of an array.
420 ///
421 /// See also [`reverse_axes`].
422 pub fn reverse_axes(&self) -> TensorView<'_, T, B, D> {
423 into_reverse_axes(self.view())
424 }
425
426 /// Reverse the order of the axes (dimensions) of an array.
427 ///
428 /// See also [`reverse_axes`].
429 pub fn into_reverse_axes(self) -> TensorAny<R, T, B, D> {
430 into_reverse_axes(self)
431 }
432
433 /// Reverse the order of the axes (dimensions) of an array.
434 ///
435 /// See also [`reverse_axes`].
436 pub fn t(&self) -> TensorView<'_, T, B, D> {
437 into_reverse_axes(self.view())
438 }
439}
440
441/* #endregion */
442
443/* #region swapaxes */
444
445/// Interchange two axes of an array.
446///
447/// See also [`swapaxes`].
448pub fn into_swapaxes_f<I, S, D>(tensor: TensorBase<S, D>, axis1: I, axis2: I) -> Result<TensorBase<S, D>>
449where
450 D: DimAPI,
451 I: TryInto<isize>,
452{
453 let axis1 = axis1.try_into().map_err(|_| rstsr_error!(TryFromIntError))?;
454 let axis2 = axis2.try_into().map_err(|_| rstsr_error!(TryFromIntError))?;
455 let (storage, layout) = tensor.into_raw_parts();
456 let layout = layout.swapaxes(axis1, axis2)?;
457 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
458}
459
460/// Interchange two axes of an array.
461///
462/// Returns an array with two axes interchanged. No data is copied; only the
463/// shape and strides are modified.
464///
465/// # Parameters
466///
467/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
468///
469/// - The input tensor whose axes are to be swapped.
470///
471/// - `axis1`: `impl TryInto<isize>`
472///
473/// - First axis to be swapped.
474/// - Negative values are supported and indicate counting dimensions from the back.
475///
476/// - `axis2`: `impl TryInto<isize>`
477///
478/// - Second axis to be swapped.
479/// - Negative values are supported and indicate counting dimensions from the back.
480///
481/// # Returns
482///
483/// - [`TensorView<'_, T, B, D>`](TensorView)
484///
485/// - A view of the input tensor with the two axes interchanged.
486/// - No data is copied; only the shape and strides are modified.
487///
488/// # Examples
489///
490/// For a 2-D array, swapping axes 0 and 1 is equivalent to transpose:
491///
492/// ```rust
493/// # use rstsr::prelude::*;
494/// # let mut device = DeviceCpu::default();
495/// # device.set_default_order(RowMajor);
496/// let x = rt::tensor_from_nested!([[1, 2, 3]], &device);
497/// let result = x.swapaxes(0, 1);
498/// println!("{result}");
499/// // [[ 1]
500/// // [ 2]
501/// // [ 3]]
502/// ```
503///
504/// For a 3-D array, swapping axes 0 and 2:
505///
506/// ```rust
507/// # use rstsr::prelude::*;
508/// # let mut device = DeviceCpu::default();
509/// # device.set_default_order(RowMajor);
510/// let x = rt::tensor_from_nested!([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], &device);
511/// let result = x.swapaxes(0, 2);
512/// println!("{result}");
513/// // [[[ 0 4]
514/// // [ 2 6]]
515/// // [[ 1 5]
516/// // [ 3 7]]]
517/// ```
518///
519/// Using negative indices to swap axes:
520///
521/// ```rust
522/// # use rstsr::prelude::*;
523/// # let mut device = DeviceCpu::default();
524/// # device.set_default_order(RowMajor);
525/// let x = rt::tensor_from_nested!([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], &device);
526/// let result = x.swapaxes(-1, -3);
527/// println!("{:?}", result.shape());
528/// // [2, 2, 2]
529/// ```
530///
531/// # Notes of API accordance
532///
533/// - NumPy: `swapaxes(a, axis1, axis2)` ([`numpy.swapaxes`](https://numpy.org/doc/stable/reference/generated/numpy.swapaxes.html))
534/// - RSTSR: `tensor.swapaxes(axis1, axis2)` or `rt::swapaxes(&tensor, axis1, axis2)`
535///
536/// # Panics
537///
538/// Panics if either `axis1` or `axis2` is out of bounds (i.e., not in `[-ndim, ndim-1]`).
539///
540/// For a fallible version, use [`swapaxes_f`].
541///
542/// # See also
543///
544/// ## Related functions in RSTSR
545///
546/// - [`transpose`] - General axis permutation
547/// - [`reverse_axes`] - Reverse all axes order
548///
549/// ## Variants of this function
550///
551/// - [`swapaxes`] / [`swapaxes_f`]: Returning a view.
552/// - [`into_swapaxes`] / [`into_swapaxes_f`]: Consuming version.
553///
554/// - Associated methods on `TensorAny`:
555///
556/// - [`TensorAny::swapaxes`] / [`TensorAny::swapaxes_f`]
557/// - [`TensorAny::into_swapaxes`] / [`TensorAny::into_swapaxes_f`]
558pub fn swapaxes<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axis1: I, axis2: I) -> TensorView<'_, T, B, D>
559where
560 D: DimAPI,
561 I: TryInto<isize>,
562 R: DataAPI<Data = B::Raw>,
563 B: DeviceAPI<T>,
564{
565 into_swapaxes_f(tensor.view(), axis1, axis2).rstsr_unwrap()
566}
567
568/// Interchange two axes of an array.
569///
570/// See also [`swapaxes`].
571pub fn swapaxes_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axis1: I, axis2: I) -> Result<TensorView<'_, T, B, D>>
572where
573 D: DimAPI,
574 I: TryInto<isize>,
575 R: DataAPI<Data = B::Raw>,
576 B: DeviceAPI<T>,
577{
578 into_swapaxes_f(tensor.view(), axis1, axis2)
579}
580
581/// Interchange two axes of an array.
582///
583/// See also [`swapaxes`].
584pub fn into_swapaxes<I, S, D>(tensor: TensorBase<S, D>, axis1: I, axis2: I) -> TensorBase<S, D>
585where
586 D: DimAPI,
587 I: TryInto<isize>,
588{
589 into_swapaxes_f(tensor, axis1, axis2).rstsr_unwrap()
590}
591
592impl<R, T, B, D> TensorAny<R, T, B, D>
593where
594 R: DataAPI<Data = B::Raw>,
595 B: DeviceAPI<T>,
596 D: DimAPI,
597{
598 /// Interchange two axes of an array.
599 ///
600 /// See also [`swapaxes`].
601 pub fn swapaxes<I>(&self, axis1: I, axis2: I) -> TensorView<'_, T, B, D>
602 where
603 I: TryInto<isize>,
604 {
605 swapaxes(self, axis1, axis2)
606 }
607
608 /// Interchange two axes of an array.
609 ///
610 /// See also [`swapaxes`].
611 pub fn swapaxes_f<I>(&self, axis1: I, axis2: I) -> Result<TensorView<'_, T, B, D>>
612 where
613 I: TryInto<isize>,
614 {
615 swapaxes_f(self, axis1, axis2)
616 }
617
618 /// Interchange two axes of an array.
619 ///
620 /// See also [`swapaxes`].
621 pub fn into_swapaxes<I>(self, axis1: I, axis2: I) -> TensorAny<R, T, B, D>
622 where
623 I: TryInto<isize>,
624 {
625 into_swapaxes(self, axis1, axis2)
626 }
627
628 /// Interchange two axes of an array.
629 ///
630 /// See also [`swapaxes`].
631 pub fn into_swapaxes_f<I>(self, axis1: I, axis2: I) -> Result<TensorAny<R, T, B, D>>
632 where
633 I: TryInto<isize>,
634 {
635 into_swapaxes_f(self, axis1, axis2)
636 }
637}
638
639/* #endregion */