rstsr_core/tensor/manuplication/
flip.rs

1use crate::prelude_dev::*;
2
3/* #region flip */
4
5/// Reverses the order of elements in an array along the given axis.
6///
7/// # See also
8///
9/// Refer to [`flip`] for more detailed documentation.
10pub fn into_flip_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, D>>
11where
12    D: DimAPI,
13    I: TryInto<AxesIndex<isize>, Error = Error>,
14{
15    let (storage, mut layout) = tensor.into_raw_parts();
16    let mut axes = normalize_axes_index(axes.try_into()?, layout.ndim(), false)?;
17    if axes.is_empty() {
18        axes = (0..layout.ndim() as isize).collect();
19    }
20    for axis in axes {
21        layout = layout.dim_narrow(axis, slice!(None, None, -1))?;
22    }
23    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
24}
25
26/// Reverses the order of elements in an array along the given axis.
27///
28/// The shape of the array will be preserved after flipping.
29///
30/// # Parameters
31///
32/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
33///
34///   - The input tensor to be flipped.
35///
36/// - `axes`: TryInto [`AxesIndex<isize>`]
37///
38///   - Axis or axes along which to flip over.
39///   - If `axes` is a single integer, flipping is performed along that axis.
40///   - If `axes` is a tuple/list of integers, flipping is performed on all specified axes.
41///   - If `axes` is empty, the function will flip over all axes.
42///   - Negative values are supported and indicate counting dimensions from the back.
43///
44/// # Returns
45///
46/// - [`TensorView<'_, T, B, D>`](TensorView)
47///
48///   - A view of the input tensor with the entries along the specified axes reversed.
49///   - The shape of the array is preserved, but the elements are reordered.
50///   - The underlying data is not copied; only the layout of the view is modified.
51///   - If you want to convert the tensor itself (taking the ownership instead of returning view),
52///     use [`into_flip`] instead.
53///
54/// # Examples
55///
56/// ## Flipping along a single axis
57///
58/// Flipping the first (0) axis:
59///
60/// ```rust
61/// use rstsr::prelude::*;
62/// let mut device = DeviceCpu::default();
63/// device.set_default_order(RowMajor);
64///
65/// let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
66/// let b = a.flip(0);
67/// let b_expected = rt::tensor_from_nested!([[[4, 5], [6, 7]], [[0, 1], [2, 3]]], &device);
68/// assert!(rt::allclose(&b, &b_expected, None));
69/// # let b_sliced = a.i(slice!(None, None, -1));
70/// # assert!(rt::allclose(&b_sliced, &b_expected, None));
71/// ```
72///
73/// The flipping is equivalent to slicing with a step of -1 along the specified axis:
74///
75///
76/// ```rust
77/// # use rstsr::prelude::*;
78/// # let mut device = DeviceCpu::default();
79/// # device.set_default_order(RowMajor);
80/// #
81/// # let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
82/// # let b = a.flip(0);
83/// # let b_expected = rt::tensor_from_nested!([[[4, 5], [6, 7]], [[0, 1], [2, 3]]], &device);
84/// # assert!(rt::allclose(&b, &b_expected, None));
85/// let b_sliced = a.i(slice!(None, None, -1));
86/// assert!(rt::allclose(&b_sliced, &b_expected, None));
87/// ```
88///
89/// Flipping the second (1) axis:
90///
91/// ```rust
92/// # use rstsr::prelude::*;
93/// # let mut device = DeviceCpu::default();
94/// # device.set_default_order(RowMajor);
95/// #
96/// # let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
97/// let b = a.flip(1);
98/// let b_expected = rt::tensor_from_nested!([[[2, 3], [0, 1]], [[6, 7], [4, 5]]], &device);
99/// assert!(rt::allclose(&b, &b_expected, None));
100/// ```
101///
102/// ## Flipping along multiple axes
103///
104/// Flipping the first (0) and last (-1 or in this specific case, 2) axes:
105///
106/// ```rust
107/// # use rstsr::prelude::*;
108/// # let mut device = DeviceCpu::default();
109/// # device.set_default_order(RowMajor);
110/// #
111/// # let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
112/// let b = a.flip([0, -1]);
113/// let b_expected = rt::tensor_from_nested!([[[5, 4], [7, 6]], [[1, 0], [3, 2]]], &device);
114/// assert!(rt::allclose(&b, &b_expected, None));
115/// ```
116///
117/// ## Flipping all axes
118///
119/// You can specify `None` or empty tuple `()` to flip all axes:
120///
121/// ```rust
122/// # use rstsr::prelude::*;
123/// # let mut device = DeviceCpu::default();
124/// # device.set_default_order(RowMajor);
125/// #
126/// # let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
127/// let b = a.flip(None);
128/// let b_expected = rt::tensor_from_nested!([[[7, 6], [5, 4]], [[3, 2], [1, 0]]], &device);
129/// assert!(rt::allclose(&b, &b_expected, None));
130/// ```
131///
132/// # Panics
133///
134/// - If some index in `axes` is greater than the number of axes in the original tensor.
135/// - If `axes` has duplicated values.
136///
137/// # See also
138///
139/// ## Similar function from other crates/libraries
140///
141/// - Python Array API standard: [`flip`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.flip.html)
142/// - NumPy: [`numpy.flip`](https://numpy.org/doc/stable/reference/generated/numpy.flip.html)
143///
144/// ## Related functions in RSTSR
145///
146/// - [`i`](TensorAny::i) or [`slice`](slice()): Basic indexing and slicing of tensors, without
147///   modification of the underlying data.
148///
149/// ## Variants of this function
150///
151/// - [`flip`]: Borrowing version.
152/// - [`flip_f`]: Fallible version.
153/// - [`into_flip`]: Consuming version.
154/// - [`into_flip_f`]: Consuming and fallible version, actual implementation.
155/// - Associated methods on [`TensorAny`]:
156///
157///   - [`TensorAny::flip`]
158///   - [`TensorAny::flip_f`]
159///   - [`TensorAny::into_flip`]
160///   - [`TensorAny::into_flip_f`]
161pub fn flip<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, D>
162where
163    D: DimAPI,
164    I: TryInto<AxesIndex<isize>, Error = Error>,
165    R: DataAPI<Data = B::Raw>,
166    B: DeviceAPI<T>,
167{
168    into_flip_f(tensor.view(), axes).rstsr_unwrap()
169}
170
171/// Reverses the order of elements in an array along the given axis.
172///
173/// # See also
174///
175/// Refer to [`flip`] for more detailed documentation.
176pub fn flip_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, D>>
177where
178    D: DimAPI,
179    I: TryInto<AxesIndex<isize>, Error = Error>,
180    R: DataAPI<Data = B::Raw>,
181    B: DeviceAPI<T>,
182{
183    into_flip_f(tensor.view(), axes)
184}
185
186/// Reverses the order of elements in an array along the given axis.
187///
188/// # See also
189///
190/// Refer to [`flip`] for more detailed documentation.
191pub fn into_flip<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, D>
192where
193    D: DimAPI,
194    I: TryInto<AxesIndex<isize>, Error = Error>,
195{
196    into_flip_f(tensor, axes).rstsr_unwrap()
197}
198
199impl<R, T, B, D> TensorAny<R, T, B, D>
200where
201    R: DataAPI<Data = B::Raw>,
202    B: DeviceAPI<T>,
203    D: DimAPI,
204{
205    /// Reverses the order of elements in an array along the given axis.
206    ///
207    /// # See also
208    ///
209    /// Refer to [`flip`] for more detailed documentation.
210    pub fn flip<I>(&self, axis: I) -> TensorView<'_, T, B, D>
211    where
212        I: TryInto<AxesIndex<isize>, Error = Error>,
213    {
214        flip(self, axis)
215    }
216
217    pub fn flip_f<I>(&self, axis: I) -> Result<TensorView<'_, T, B, D>>
218    where
219        I: TryInto<AxesIndex<isize>, Error = Error>,
220    {
221        flip_f(self, axis)
222    }
223
224    /// Reverses the order of elements in an array along the given axis.
225    ///
226    /// # See also
227    ///
228    /// Refer to [`flip`] for more detailed documentation.
229    pub fn into_flip<I>(self, axis: I) -> TensorAny<R, T, B, D>
230    where
231        I: TryInto<AxesIndex<isize>, Error = Error>,
232    {
233        into_flip(self, axis)
234    }
235
236    /// Reverses the order of elements in an array along the given axis.
237    ///
238    /// # See also
239    ///
240    /// Refer to [`flip`] for more detailed documentation.
241    pub fn into_flip_f<I>(self, axis: I) -> Result<TensorAny<R, T, B, D>>
242    where
243        I: TryInto<AxesIndex<isize>, Error = Error>,
244    {
245        into_flip_f(self, axis)
246    }
247}
248
249/* #endregion */
250
#[cfg(test)]
mod tests {
    /// Checks the same inputs and expected outputs used in the documentation examples of
    /// `flip`.
    #[test]
    fn doc_flip() {
        use rstsr::prelude::*;

        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);

        // 2 x 2 x 2 tensor holding 0..8 in row-major order.
        let base = rt::arange((8, &device)).into_shape([2, 2, 2]);
        let base_expected = rt::tensor_from_nested!([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], &device);
        assert!(rt::allclose(&base, &base_expected, None));

        // Axis 0: must also match a step -1 slice over the leading axis.
        let flipped_0 = base.flip(0);
        let expected_0 = rt::tensor_from_nested!([[[4, 5], [6, 7]], [[0, 1], [2, 3]]], &device);
        assert!(rt::allclose(&flipped_0, &expected_0, None));
        let reversed_slice = base.i(slice!(None, None, -1));
        assert!(rt::allclose(&reversed_slice, &expected_0, None));

        // Axis 1.
        let flipped_1 = base.flip(1);
        let expected_1 = rt::tensor_from_nested!([[[2, 3], [0, 1]], [[6, 7], [4, 5]]], &device);
        assert!(rt::allclose(&flipped_1, &expected_1, None));

        // First and last axes, the last one given as a negative index.
        let flipped_ends = base.flip([0, -1]);
        let expected_ends = rt::tensor_from_nested!([[[5, 4], [7, 6]], [[1, 0], [3, 2]]], &device);
        assert!(rt::allclose(&flipped_ends, &expected_ends, None));

        // `None` flips every axis.
        let flipped_all = base.flip(None);
        let expected_all = rt::tensor_from_nested!([[[7, 6], [5, 4]], [[3, 2], [1, 0]]], &device);
        assert!(rt::allclose(&flipped_all, &expected_all, None));
    }
}