rstsr_core/tensor/manuplication/flip.rs
1use crate::prelude_dev::*;
2
3/* #region flip */
4
5/// Reverses the order of elements in an array along the given axis.
6///
7/// # See also
8///
9/// Refer to [`flip`] for more detailed documentation.
10pub fn into_flip_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, D>>
11where
12 D: DimAPI,
13 I: TryInto<AxesIndex<isize>, Error = Error>,
14{
15 let (storage, mut layout) = tensor.into_raw_parts();
16 let mut axes = normalize_axes_index(axes.try_into()?, layout.ndim(), false)?;
17 if axes.is_empty() {
18 axes = (0..layout.ndim() as isize).collect();
19 }
20 for axis in axes {
21 layout = layout.dim_narrow(axis, slice!(None, None, -1))?;
22 }
23 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
24}
25
26/// Reverses the order of elements in an array along the given axis.
27///
28/// The shape of the array will be preserved after flipping.
29///
30/// # Parameters
31///
32/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
33///
34/// - The input tensor to be flipped.
35///
36/// - `axes`: TryInto [`AxesIndex<isize>`]
37///
38/// - Axis or axes along which to flip over.
39/// - If `axes` is a single integer, flipping is performed along that axis.
40/// - If `axes` is a tuple/list of integers, flipping is performed on all specified axes.
41/// - If `axes` is empty, the function will flip over all axes.
42/// - Negative values are supported and indicate counting dimensions from the back.
43///
44/// # Returns
45///
46/// - [`TensorView<'_, T, B, D>`](TensorView)
47///
48/// - A view of the input tensor with the entries along the specified axes reversed.
49/// - The shape of the array is preserved, but the elements are reordered.
50/// - The underlying data is not copied; only the layout of the view is modified.
51/// - If you want to convert the tensor itself (taking the ownership instead of returning view),
52/// use [`into_flip`] instead.
53///
54/// # Examples
55///
56/// ## Flipping along a single axis
57///
58/// Flipping the first (0) axis:
59///
60/// ```rust
61/// use rstsr::prelude::*;
62/// let mut device = DeviceCpu::default();
63/// device.set_default_order(RowMajor);
64///
65/// let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
66/// let b = a.flip(0);
67/// let b_expected = rt::tensor_from_nested!([[[4, 5], [6, 7]], [[0, 1], [2, 3]]], &device);
68/// assert!(rt::allclose(&b, &b_expected, None));
69/// # let b_sliced = a.i(slice!(None, None, -1));
70/// # assert!(rt::allclose(&b_sliced, &b_expected, None));
71/// ```
72///
73/// The flipping is equivalent to slicing with a step of -1 along the specified axis:
74///
75///
76/// ```rust
77/// # use rstsr::prelude::*;
78/// # let mut device = DeviceCpu::default();
79/// # device.set_default_order(RowMajor);
80/// #
81/// # let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
82/// # let b = a.flip(0);
83/// # let b_expected = rt::tensor_from_nested!([[[4, 5], [6, 7]], [[0, 1], [2, 3]]], &device);
84/// # assert!(rt::allclose(&b, &b_expected, None));
85/// let b_sliced = a.i(slice!(None, None, -1));
86/// assert!(rt::allclose(&b_sliced, &b_expected, None));
87/// ```
88///
89/// Flipping the second (1) axis:
90///
91/// ```rust
92/// # use rstsr::prelude::*;
93/// # let mut device = DeviceCpu::default();
94/// # device.set_default_order(RowMajor);
95/// #
96/// # let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
97/// let b = a.flip(1);
98/// let b_expected = rt::tensor_from_nested!([[[2, 3], [0, 1]], [[6, 7], [4, 5]]], &device);
99/// assert!(rt::allclose(&b, &b_expected, None));
100/// ```
101///
102/// ## Flipping along multiple axes
103///
104/// Flipping the first (0) and last (-1 or in this specific case, 2) axes:
105///
106/// ```rust
107/// # use rstsr::prelude::*;
108/// # let mut device = DeviceCpu::default();
109/// # device.set_default_order(RowMajor);
110/// #
111/// # let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
112/// let b = a.flip([0, -1]);
113/// let b_expected = rt::tensor_from_nested!([[[5, 4], [7, 6]], [[1, 0], [3, 2]]], &device);
114/// assert!(rt::allclose(&b, &b_expected, None));
115/// ```
116///
117/// ## Flipping all axes
118///
119/// You can specify `None` or empty tuple `()` to flip all axes:
120///
121/// ```rust
122/// # use rstsr::prelude::*;
123/// # let mut device = DeviceCpu::default();
124/// # device.set_default_order(RowMajor);
125/// #
126/// # let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
127/// let b = a.flip(None);
128/// let b_expected = rt::tensor_from_nested!([[[7, 6], [5, 4]], [[3, 2], [1, 0]]], &device);
129/// assert!(rt::allclose(&b, &b_expected, None));
130/// ```
131///
132/// # Panics
133///
134/// - If some index in `axes` is greater than the number of axes in the original tensor.
135/// - If `axes` has duplicated values.
136///
137/// # See also
138///
139/// ## Similar function from other crates/libraries
140///
141/// - Python Array API standard: [`flip`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.flip.html)
142/// - NumPy: [`numpy.flip`](https://numpy.org/doc/stable/reference/generated/numpy.flip.html)
143///
144/// ## Related functions in RSTSR
145///
146/// - [`i`](TensorAny::i) or [`slice`](slice()): Basic indexing and slicing of tensors, without
147/// modification of the underlying data.
148///
149/// ## Variants of this function
150///
151/// - [`flip`]: Borrowing version.
152/// - [`flip_f`]: Fallible version.
153/// - [`into_flip`]: Consuming version.
154/// - [`into_flip_f`]: Consuming and fallible version, actual implementation.
155/// - Associated methods on [`TensorAny`]:
156///
157/// - [`TensorAny::flip`]
158/// - [`TensorAny::flip_f`]
159/// - [`TensorAny::into_flip`]
160/// - [`TensorAny::into_flip_f`]
161pub fn flip<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, D>
162where
163 D: DimAPI,
164 I: TryInto<AxesIndex<isize>, Error = Error>,
165 R: DataAPI<Data = B::Raw>,
166 B: DeviceAPI<T>,
167{
168 into_flip_f(tensor.view(), axes).rstsr_unwrap()
169}
170
171/// Reverses the order of elements in an array along the given axis.
172///
173/// # See also
174///
175/// Refer to [`flip`] for more detailed documentation.
176pub fn flip_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, D>>
177where
178 D: DimAPI,
179 I: TryInto<AxesIndex<isize>, Error = Error>,
180 R: DataAPI<Data = B::Raw>,
181 B: DeviceAPI<T>,
182{
183 into_flip_f(tensor.view(), axes)
184}
185
186/// Reverses the order of elements in an array along the given axis.
187///
188/// # See also
189///
190/// Refer to [`flip`] for more detailed documentation.
191pub fn into_flip<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, D>
192where
193 D: DimAPI,
194 I: TryInto<AxesIndex<isize>, Error = Error>,
195{
196 into_flip_f(tensor, axes).rstsr_unwrap()
197}
198
199impl<R, T, B, D> TensorAny<R, T, B, D>
200where
201 R: DataAPI<Data = B::Raw>,
202 B: DeviceAPI<T>,
203 D: DimAPI,
204{
205 /// Reverses the order of elements in an array along the given axis.
206 ///
207 /// # See also
208 ///
209 /// Refer to [`flip`] for more detailed documentation.
210 pub fn flip<I>(&self, axis: I) -> TensorView<'_, T, B, D>
211 where
212 I: TryInto<AxesIndex<isize>, Error = Error>,
213 {
214 flip(self, axis)
215 }
216
217 pub fn flip_f<I>(&self, axis: I) -> Result<TensorView<'_, T, B, D>>
218 where
219 I: TryInto<AxesIndex<isize>, Error = Error>,
220 {
221 flip_f(self, axis)
222 }
223
224 /// Reverses the order of elements in an array along the given axis.
225 ///
226 /// # See also
227 ///
228 /// Refer to [`flip`] for more detailed documentation.
229 pub fn into_flip<I>(self, axis: I) -> TensorAny<R, T, B, D>
230 where
231 I: TryInto<AxesIndex<isize>, Error = Error>,
232 {
233 into_flip(self, axis)
234 }
235
236 /// Reverses the order of elements in an array along the given axis.
237 ///
238 /// # See also
239 ///
240 /// Refer to [`flip`] for more detailed documentation.
241 pub fn into_flip_f<I>(self, axis: I) -> Result<TensorAny<R, T, B, D>>
242 where
243 I: TryInto<AxesIndex<isize>, Error = Error>,
244 {
245 into_flip_f(self, axis)
246 }
247}
248
249/* #endregion */
250
#[cfg(test)]
mod tests {
    /// Exercises the same cases as the examples in the documentation of `flip`.
    #[test]
    fn doc_flip() {
        use rstsr::prelude::*;

        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);

        // Base tensor: the integers 0..8 arranged row-major as a 2x2x2 cube.
        let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
        let a_expected = rt::tensor_from_nested!([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], &device);
        assert!(rt::allclose(&a, &a_expected, None));

        // Axis 0, checked both through `flip` and through the equivalent `[::-1]` slice.
        let first_expected = rt::tensor_from_nested!([[[4, 5], [6, 7]], [[0, 1], [2, 3]]], &device);
        let first_flipped = a.flip(0);
        assert!(rt::allclose(&first_flipped, &first_expected, None));
        let first_sliced = a.i(slice!(None, None, -1));
        assert!(rt::allclose(&first_sliced, &first_expected, None));

        // Axis 1.
        let second_expected = rt::tensor_from_nested!([[[2, 3], [0, 1]], [[6, 7], [4, 5]]], &device);
        let second_flipped = a.flip(1);
        assert!(rt::allclose(&second_flipped, &second_expected, None));

        // First and last axes together; the last one is addressed with a negative index.
        let outer_expected = rt::tensor_from_nested!([[[5, 4], [7, 6]], [[1, 0], [3, 2]]], &device);
        let outer_flipped = a.flip([0, -1]);
        assert!(rt::allclose(&outer_flipped, &outer_expected, None));

        // `None` selects every axis.
        let all_expected = rt::tensor_from_nested!([[[7, 6], [5, 4]], [[3, 2], [1, 0]]], &device);
        let all_flipped = a.flip(None);
        assert!(rt::allclose(&all_flipped, &all_expected, None));
    }
}