rstsr_core/tensor/manuplication/moveaxis.rs
1use crate::prelude_dev::*;
2
3/// Moves array axes (dimensions) to new positions, while leaving other axes in their original
4/// positions.
5///
6/// See also [`moveaxis`].
7pub fn into_moveaxis_f<IS, ID, S, D>(tensor: TensorBase<S, D>, source: IS, destination: ID) -> Result<TensorBase<S, D>>
8where
9 D: DimAPI,
10 IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
11 ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
12{
13 let source = source.try_into().map_err(Into::into)?;
14 let destination = destination.try_into().map_err(Into::into)?;
15
16 let ndim = tensor.ndim();
17
18 // Normalize axes
19 let source = normalize_axes_index(source, tensor.ndim(), false, false)?;
20 let destination = normalize_axes_index(destination, tensor.ndim(), false, false)?;
21
22 // Check that source and destination have the same length
23 rstsr_assert_eq!(
24 source.len(),
25 destination.len(),
26 InvalidValue,
27 "`source` and `destination` arguments must have the same number of elements"
28 )?;
29
30 // Build the permutation order
31 // Start with all axes that are not in source
32 let mut order: Vec<isize> = (0..ndim as isize).filter(|&i| !source.contains(&i)).collect();
33
34 // Insert source axes at their destination positions
35 // Sort pairs by destination to insert in correct order
36 let mut pairs: Vec<(isize, isize)> = destination.iter().zip(source.iter()).map(|(&d, &s)| (d, s)).collect();
37 pairs.sort_by_key(|&(d, _)| d);
38
39 for (dest, src) in pairs {
40 order.insert(dest as usize, src);
41 }
42
43 // Apply the transpose
44 let (storage, layout) = tensor.into_raw_parts();
45 let layout = layout.transpose(&order)?;
46 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
47}
48
49/// Moves array axes (dimensions) to new positions, while leaving other axes in their original
50/// positions.
51///
52/// Returns an array with axes moved to new positions. Other axes remain in their
53/// original order. This is a view operation; no data is copied.
54///
55/// # Parameters
56///
57/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
58///
59/// - The input tensor whose axes are to be moved.
60///
61/// - `source`: TryInto [`AxesIndex<isize>`]
62///
63/// - Original positions of the axes to move. These must be unique.
64/// - Can be a single axis index or a sequence of axis indices.
65/// - Negative values are supported and indicate counting dimensions from the back.
66///
67/// - `destination`: TryInto [`AxesIndex<isize>`]
68///
69/// - Destination positions for each of the original axes. These must also be unique.
70/// - Can be a single axis index or a sequence of axis indices.
71/// - Negative values are supported and indicate counting dimensions from the back.
72/// - Must have the same number of elements as `source`.
73///
74/// # Returns
75///
76/// - [`TensorView<'_, T, B, D>`](TensorView)
77///
78/// - A view of the input tensor with moved axes.
79/// - No data is copied; only the shape and strides are modified.
80///
81/// # Examples
82///
83/// Move a single axis to a new position:
84///
85/// ```rust
86/// # use rstsr::prelude::*;
87/// # let mut device = DeviceCpu::default();
88/// # device.set_default_order(RowMajor);
89/// let x: Tensor<f64, _> = rt::zeros(([3, 4, 5], &device));
90/// let result = x.moveaxis(0, -1);
91/// println!("{:?}", result.shape());
92/// // [4, 5, 3]
93/// ```
94///
95/// Move multiple axes to new positions:
96///
97/// ```rust
98/// # use rstsr::prelude::*;
99/// # let mut device = DeviceCpu::default();
100/// # device.set_default_order(RowMajor);
101/// let x: Tensor<f64, _> = rt::zeros(([3, 4, 5], &device));
102/// let result = x.moveaxis([0, 1], [-1, -2]);
103/// println!("{:?}", result.shape());
104/// // [5, 4, 3]
105/// ```
106///
107/// Using negative indices:
108///
109/// ```rust
110/// # use rstsr::prelude::*;
111/// # let mut device = DeviceCpu::default();
112/// # device.set_default_order(RowMajor);
113/// let x: Tensor<f64, _> = rt::zeros(([3, 4, 5], &device));
114/// let result = x.moveaxis(-1, 0);
115/// println!("{:?}", result.shape());
116/// // [5, 3, 4]
117/// ```
118///
119/// # Notes of API accordance
120///
121/// - Array-API: `moveaxis(x, source, destination, /)` ([`moveaxis`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.moveaxis.html))
122/// - NumPy: `moveaxis(a, source, destination)` ([`numpy.moveaxis`](https://numpy.org/doc/stable/reference/generated/numpy.moveaxis.html))
123/// - RSTSR: `tensor.moveaxis(source, destination)` or `rt::moveaxis(&tensor, source, destination)`
124///
125/// # See also
126///
127/// ## Related functions in RSTSR
128///
129/// - [`transpose`] - General axis permutation
130/// - [`swapaxes`] - Swap two specific axes
131/// - [`reverse_axes`] - Reverse all axes order
132///
133/// ## Variants of this function
134///
135/// - [`moveaxis`] / [`moveaxis_f`]: Returning a view.
136/// - [`into_moveaxis`] / [`into_moveaxis_f`]: Consuming version.
137///
138/// - Associated methods on `TensorAny`:
139///
140/// - [`TensorAny::moveaxis`] / [`TensorAny::moveaxis_f`]
141/// - [`TensorAny::into_moveaxis`] / [`TensorAny::into_moveaxis_f`]
142pub fn moveaxis<IS, ID, R, T, B, D>(
143 tensor: &TensorAny<R, T, B, D>,
144 source: IS,
145 destination: ID,
146) -> TensorView<'_, T, B, D>
147where
148 D: DimAPI,
149 IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
150 ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
151 R: DataAPI<Data = B::Raw>,
152 B: DeviceAPI<T>,
153{
154 into_moveaxis_f(tensor.view(), source, destination).rstsr_unwrap()
155}
156
157/// Moves array axes (dimensions) to new positions, while leaving other axes in their original
158/// positions.
159///
160/// See also [`moveaxis`].
161pub fn moveaxis_f<IS, ID, R, T, B, D>(
162 tensor: &TensorAny<R, T, B, D>,
163 source: IS,
164 destination: ID,
165) -> Result<TensorView<'_, T, B, D>>
166where
167 D: DimAPI,
168 IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
169 ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
170 R: DataAPI<Data = B::Raw>,
171 B: DeviceAPI<T>,
172{
173 into_moveaxis_f(tensor.view(), source, destination)
174}
175
176/// Moves array axes (dimensions) to new positions, while leaving other axes in their original
177/// positions.
178///
179/// See also [`moveaxis`].
180pub fn into_moveaxis<IS, ID, S, D>(tensor: TensorBase<S, D>, source: IS, destination: ID) -> TensorBase<S, D>
181where
182 D: DimAPI,
183 IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
184 ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
185{
186 into_moveaxis_f(tensor, source, destination).rstsr_unwrap()
187}
188
189impl<R, T, B, D> TensorAny<R, T, B, D>
190where
191 R: DataAPI<Data = B::Raw>,
192 B: DeviceAPI<T>,
193 D: DimAPI,
194{
195 /// Moves array axes (dimensions) to new positions, while leaving other axes in their original
196 /// positions.
197 ///
198 /// See also [`moveaxis`].
199 pub fn moveaxis<IS, ID>(&self, source: IS, destination: ID) -> TensorView<'_, T, B, D>
200 where
201 IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
202 ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
203 {
204 moveaxis(self, source, destination)
205 }
206
207 /// Moves array axes (dimensions) to new positions, while leaving other axes in their original
208 /// positions.
209 ///
210 /// See also [`moveaxis`].
211 pub fn moveaxis_f<IS, ID>(&self, source: IS, destination: ID) -> Result<TensorView<'_, T, B, D>>
212 where
213 IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
214 ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
215 {
216 moveaxis_f(self, source, destination)
217 }
218
219 /// Moves array axes (dimensions) to new positions, while leaving other axes in their original
220 /// positions.
221 ///
222 /// See also [`moveaxis`].
223 pub fn into_moveaxis<IS, ID>(self, source: IS, destination: ID) -> TensorAny<R, T, B, D>
224 where
225 IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
226 ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
227 {
228 into_moveaxis(self, source, destination)
229 }
230
231 /// Moves array axes (dimensions) to new positions, while leaving other axes in their original
232 /// positions.
233 ///
234 /// See also [`moveaxis`].
235 pub fn into_moveaxis_f<IS, ID>(self, source: IS, destination: ID) -> Result<TensorAny<R, T, B, D>>
236 where
237 IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
238 ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
239 {
240 into_moveaxis_f(self, source, destination)
241 }
242}