Skip to main content

rstsr_core/tensor/manuplication/
moveaxis.rs

1use crate::prelude_dev::*;
2
3/// Moves array axes (dimensions) to new positions, while leaving other axes in their original
4/// positions.
5///
6/// See also [`moveaxis`].
7pub fn into_moveaxis_f<IS, ID, S, D>(tensor: TensorBase<S, D>, source: IS, destination: ID) -> Result<TensorBase<S, D>>
8where
9    D: DimAPI,
10    IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
11    ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
12{
13    let source = source.try_into().map_err(Into::into)?;
14    let destination = destination.try_into().map_err(Into::into)?;
15
16    let ndim = tensor.ndim();
17
18    // Normalize axes
19    let source = normalize_axes_index(source, tensor.ndim(), false, false)?;
20    let destination = normalize_axes_index(destination, tensor.ndim(), false, false)?;
21
22    // Check that source and destination have the same length
23    rstsr_assert_eq!(
24        source.len(),
25        destination.len(),
26        InvalidValue,
27        "`source` and `destination` arguments must have the same number of elements"
28    )?;
29
30    // Build the permutation order
31    // Start with all axes that are not in source
32    let mut order: Vec<isize> = (0..ndim as isize).filter(|&i| !source.contains(&i)).collect();
33
34    // Insert source axes at their destination positions
35    // Sort pairs by destination to insert in correct order
36    let mut pairs: Vec<(isize, isize)> = destination.iter().zip(source.iter()).map(|(&d, &s)| (d, s)).collect();
37    pairs.sort_by_key(|&(d, _)| d);
38
39    for (dest, src) in pairs {
40        order.insert(dest as usize, src);
41    }
42
43    // Apply the transpose
44    let (storage, layout) = tensor.into_raw_parts();
45    let layout = layout.transpose(&order)?;
46    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
47}
48
49/// Moves array axes (dimensions) to new positions, while leaving other axes in their original
50/// positions.
51///
52/// Returns an array with axes moved to new positions. Other axes remain in their
53/// original order. This is a view operation; no data is copied.
54///
55/// # Parameters
56///
57/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
58///
59///   - The input tensor whose axes are to be moved.
60///
61/// - `source`: TryInto [`AxesIndex<isize>`]
62///
63///   - Original positions of the axes to move. These must be unique.
64///   - Can be a single axis index or a sequence of axis indices.
65///   - Negative values are supported and indicate counting dimensions from the back.
66///
67/// - `destination`: TryInto [`AxesIndex<isize>`]
68///
69///   - Destination positions for each of the original axes. These must also be unique.
70///   - Can be a single axis index or a sequence of axis indices.
71///   - Negative values are supported and indicate counting dimensions from the back.
72///   - Must have the same number of elements as `source`.
73///
74/// # Returns
75///
76/// - [`TensorView<'_, T, B, D>`](TensorView)
77///
78///   - A view of the input tensor with moved axes.
79///   - No data is copied; only the shape and strides are modified.
80///
81/// # Examples
82///
83/// Move a single axis to a new position:
84///
85/// ```rust
86/// # use rstsr::prelude::*;
87/// # let mut device = DeviceCpu::default();
88/// # device.set_default_order(RowMajor);
89/// let x: Tensor<f64, _> = rt::zeros(([3, 4, 5], &device));
90/// let result = x.moveaxis(0, -1);
91/// println!("{:?}", result.shape());
92/// // [4, 5, 3]
93/// ```
94///
95/// Move multiple axes to new positions:
96///
97/// ```rust
98/// # use rstsr::prelude::*;
99/// # let mut device = DeviceCpu::default();
100/// # device.set_default_order(RowMajor);
101/// let x: Tensor<f64, _> = rt::zeros(([3, 4, 5], &device));
102/// let result = x.moveaxis([0, 1], [-1, -2]);
103/// println!("{:?}", result.shape());
104/// // [5, 4, 3]
105/// ```
106///
107/// Using negative indices:
108///
109/// ```rust
110/// # use rstsr::prelude::*;
111/// # let mut device = DeviceCpu::default();
112/// # device.set_default_order(RowMajor);
113/// let x: Tensor<f64, _> = rt::zeros(([3, 4, 5], &device));
114/// let result = x.moveaxis(-1, 0);
115/// println!("{:?}", result.shape());
116/// // [5, 3, 4]
117/// ```
118///
119/// # Notes of API accordance
120///
121/// - Array-API: `moveaxis(x, source, destination, /)` ([`moveaxis`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.moveaxis.html))
122/// - NumPy: `moveaxis(a, source, destination)` ([`numpy.moveaxis`](https://numpy.org/doc/stable/reference/generated/numpy.moveaxis.html))
123/// - RSTSR: `tensor.moveaxis(source, destination)` or `rt::moveaxis(&tensor, source, destination)`
124///
125/// # See also
126///
127/// ## Related functions in RSTSR
128///
129/// - [`transpose`] - General axis permutation
130/// - [`swapaxes`] - Swap two specific axes
131/// - [`reverse_axes`] - Reverse all axes order
132///
133/// ## Variants of this function
134///
135/// - [`moveaxis`] / [`moveaxis_f`]: Returning a view.
136/// - [`into_moveaxis`] / [`into_moveaxis_f`]: Consuming version.
137///
138/// - Associated methods on `TensorAny`:
139///
140///   - [`TensorAny::moveaxis`] / [`TensorAny::moveaxis_f`]
141///   - [`TensorAny::into_moveaxis`] / [`TensorAny::into_moveaxis_f`]
142pub fn moveaxis<IS, ID, R, T, B, D>(
143    tensor: &TensorAny<R, T, B, D>,
144    source: IS,
145    destination: ID,
146) -> TensorView<'_, T, B, D>
147where
148    D: DimAPI,
149    IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
150    ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
151    R: DataAPI<Data = B::Raw>,
152    B: DeviceAPI<T>,
153{
154    into_moveaxis_f(tensor.view(), source, destination).rstsr_unwrap()
155}
156
157/// Moves array axes (dimensions) to new positions, while leaving other axes in their original
158/// positions.
159///
160/// See also [`moveaxis`].
161pub fn moveaxis_f<IS, ID, R, T, B, D>(
162    tensor: &TensorAny<R, T, B, D>,
163    source: IS,
164    destination: ID,
165) -> Result<TensorView<'_, T, B, D>>
166where
167    D: DimAPI,
168    IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
169    ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
170    R: DataAPI<Data = B::Raw>,
171    B: DeviceAPI<T>,
172{
173    into_moveaxis_f(tensor.view(), source, destination)
174}
175
176/// Moves array axes (dimensions) to new positions, while leaving other axes in their original
177/// positions.
178///
179/// See also [`moveaxis`].
180pub fn into_moveaxis<IS, ID, S, D>(tensor: TensorBase<S, D>, source: IS, destination: ID) -> TensorBase<S, D>
181where
182    D: DimAPI,
183    IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
184    ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
185{
186    into_moveaxis_f(tensor, source, destination).rstsr_unwrap()
187}
188
189impl<R, T, B, D> TensorAny<R, T, B, D>
190where
191    R: DataAPI<Data = B::Raw>,
192    B: DeviceAPI<T>,
193    D: DimAPI,
194{
195    /// Moves array axes (dimensions) to new positions, while leaving other axes in their original
196    /// positions.
197    ///
198    /// See also [`moveaxis`].
199    pub fn moveaxis<IS, ID>(&self, source: IS, destination: ID) -> TensorView<'_, T, B, D>
200    where
201        IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
202        ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
203    {
204        moveaxis(self, source, destination)
205    }
206
207    /// Moves array axes (dimensions) to new positions, while leaving other axes in their original
208    /// positions.
209    ///
210    /// See also [`moveaxis`].
211    pub fn moveaxis_f<IS, ID>(&self, source: IS, destination: ID) -> Result<TensorView<'_, T, B, D>>
212    where
213        IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
214        ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
215    {
216        moveaxis_f(self, source, destination)
217    }
218
219    /// Moves array axes (dimensions) to new positions, while leaving other axes in their original
220    /// positions.
221    ///
222    /// See also [`moveaxis`].
223    pub fn into_moveaxis<IS, ID>(self, source: IS, destination: ID) -> TensorAny<R, T, B, D>
224    where
225        IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
226        ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
227    {
228        into_moveaxis(self, source, destination)
229    }
230
231    /// Moves array axes (dimensions) to new positions, while leaving other axes in their original
232    /// positions.
233    ///
234    /// See also [`moveaxis`].
235    pub fn into_moveaxis_f<IS, ID>(self, source: IS, destination: ID) -> Result<TensorAny<R, T, B, D>>
236    where
237        IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
238        ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
239    {
240        into_moveaxis_f(self, source, destination)
241    }
242}