Skip to main content

rstsr_core/tensor/manuplication/
flip.rs

1use crate::prelude_dev::*;
2
3/// Reverses the order of elements in an array along the given axis.
4///
5/// See also [`flip`].
6pub fn into_flip_f<S, D>(
7    tensor: TensorBase<S, D>,
8    axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
9) -> Result<TensorBase<S, D>>
10where
11    D: DimAPI,
12{
13    let (storage, mut layout) = tensor.into_raw_parts();
14    let axes = axes.try_into().map_err(Into::into)?;
15    let axes = match axes {
16        AxesIndex::None => (0..layout.ndim() as isize).collect(),
17        _ => normalize_axes_index(axes, layout.ndim(), false, true)?,
18    };
19    for axis in axes {
20        layout = layout.dim_narrow(axis, slice!(None, None, -1))?;
21    }
22    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
23}
24
25/// Reverses the order of elements in an array along the given axis.
26///
27/// The shape of the array will be preserved after flipping.
28///
29/// # Parameters
30///
31/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
32///
33///   - The input tensor.
34///   - Note on variant [`into_flip`]: This takes ownership [`Tensor<R, T, B, D>`] of input tensor,
35///     and will not perform change to underlying data, only layout changes.
36///
37/// - `axes`: TryInto [`AxesIndex<isize>`]
38///
39///   - Axis or axes along which to flip over.
40///   - If `axes` is a single integer, flipping is performed along that axis.
41///   - If `axes` is a tuple/list of integers, flipping is performed on all specified axes.
42///   - If `axes` is `None`, the function will flip over all axes.
43///   - If `axes` is an empty tuple `()`, no axes are flipped (returns a view of the original
44///     tensor).
45///   - Negative values are supported and indicate counting dimensions from the back.
46///
47/// # Returns
48///
49/// - [`TensorView<'_, T, B, D>`](TensorView)
50///
51///   - A view of the input tensor with the entries along the specified axes reversed.
52///   - The shape of the array is preserved, but the elements are reordered.
53///   - The underlying data is not copied; only the layout of the view is modified.
54///   - If you want to convert the tensor itself (taking the ownership instead of returning view),
55///     use [`into_flip`] instead.
56///
57/// # Examples
58///
59/// ## Flipping along a single axis
60///
61/// ```rust
62/// # use rstsr::prelude::*;
63/// # let mut device = DeviceCpu::default();
64/// # device.set_default_order(RowMajor);
65/// let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
66/// println!("{a}");
67/// // [[[ 0 1]
68/// //   [ 2 3]]
69/// //
70/// //  [[ 4 5]
71/// //   [ 6 7]]]
72/// ```
73///
74/// Flipping the first (0) axis (which is equivalent to slicing with step -1 along that axis):
75///
76/// ```rust
77/// # use rstsr::prelude::*;
78/// # let mut device = DeviceCpu::default();
79/// # device.set_default_order(RowMajor);
80/// let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
81/// let b = a.flip(0);
82/// println!("{b}");
83/// // [[[ 4 5]
84/// //   [ 6 7]]
85/// //
86/// //  [[ 0 1]
87/// //   [ 2 3]]]
88/// assert!(rt::allclose(a.i(slice!(None, None, -1)), &b, None));
89/// ```
90///
91/// Flipping the second (1) axis:
92///
93/// ```rust
94/// # use rstsr::prelude::*;
95/// # let mut device = DeviceCpu::default();
96/// # device.set_default_order(RowMajor);
97/// let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
98/// let b = a.flip(1);
99/// println!("{b}");
100/// // [[[ 2 3]
101/// //   [ 0 1]]
102/// //
103/// //  [[ 6 7]
104/// //   [ 4 5]]]
105/// assert!(rt::allclose(a.i((.., slice!(None, None, -1))), &b, None));
106/// ```
107///
108/// ## Flipping along multiple axes
109///
110/// Flipping the first (0) and last (-1 or in this specific case, 2) axes:
111///
112/// ```rust
113/// # use rstsr::prelude::*;
114/// # let mut device = DeviceCpu::default();
115/// # device.set_default_order(RowMajor);
116/// let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
117/// let b = a.flip([0, -1]);
118/// println!("{b}");
119/// // [[[ 5 4]
120/// //   [ 7 6]]
121/// //
122/// //  [[ 1 0]
123/// //   [ 3 2]]]
124/// ```
125///
126/// ## Flipping all axes
127///
128/// You can specify `None` to flip all axes:
129///
130/// ```rust
131/// # use rstsr::prelude::*;
132/// # let mut device = DeviceCpu::default();
133/// # device.set_default_order(RowMajor);
134/// let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
135/// let b = a.flip(None);
136/// println!("{b}");
137/// // [[[ 7 6]
138/// //   [ 5 4]]
139/// //
140/// //  [[ 3 2]
141/// //   [ 1 0]]]
142/// ```
143///
144/// ## No flipping (empty axes)
145///
146/// You can specify an empty tuple `()` to flip no axes (returns a view of the original tensor):
147///
148/// ```rust
149/// # use rstsr::prelude::*;
150/// # let mut device = DeviceCpu::default();
151/// # device.set_default_order(RowMajor);
152/// let a = rt::arange((8, &device)).into_shape([2, 2, 2]);
153/// let b = a.flip(());
154/// println!("{b}");
155/// // [[[ 0 1]
156/// //   [ 2 3]]
157/// //
158/// //  [[ 4 5]
159/// //   [ 6 7]]]
160/// ```
161///
162/// # Panics
163///
164/// - If some index in `axes` is greater than the number of axes in the original tensor.
165/// - If `axes` has duplicated values.
166///
167/// For a fallible version, use [`flip_f`].
168///
169/// # Notes of API accordance
170///
171/// - Array-API: `flip(x, /, *, axis=None)` ([`flip`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.flip.html))
172/// - NumPy: `flip(m, axis=None)` ([`numpy.flip`](https://numpy.org/doc/stable/reference/generated/numpy.flip.html))
173/// - RSTSR: `rt::flip(tensor, axes)`
174///
175/// RSTSR's behavior matches NumPy and Array-API:
176/// - `a.flip(None)` flips all axes
177/// - `a.flip(())` flips no axes (returns a view of the original tensor)
178///
179/// # See also
180///
181/// ## Related functions in RSTSR
182///
183/// - [`i`](TensorAny::i) or [`slice`](slice()): Basic indexing and slicing of tensors, without
184///   modification of the underlying data.
185///
186/// ## Variants of this function
187///
188/// - [`flip`]: Borrowing version.
189/// - [`flip_f`]: Fallible version.
190/// - [`into_flip`]: Consuming version.
191/// - [`into_flip_f`]: Consuming and fallible version, actual implementation.
192/// - Associated methods on [`TensorAny`]:
193///
194///   - [`TensorAny::flip`]
195///   - [`TensorAny::flip_f`]
196///   - [`TensorAny::into_flip`]
197///   - [`TensorAny::into_flip_f`]
198pub fn flip<R, T, B, D>(
199    tensor: &TensorAny<R, T, B, D>,
200    axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
201) -> TensorView<'_, T, B, D>
202where
203    D: DimAPI,
204    R: DataAPI<Data = B::Raw>,
205    B: DeviceAPI<T>,
206{
207    into_flip_f(tensor.view(), axes).rstsr_unwrap()
208}
209
210/// Reverses the order of elements in an array along the given axis.
211///
212/// See also [`flip`].
213pub fn flip_f<R, T, B, D>(
214    tensor: &TensorAny<R, T, B, D>,
215    axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
216) -> Result<TensorView<'_, T, B, D>>
217where
218    D: DimAPI,
219    R: DataAPI<Data = B::Raw>,
220    B: DeviceAPI<T>,
221{
222    into_flip_f(tensor.view(), axes)
223}
224
225/// Reverses the order of elements in an array along the given axis.
226///
227/// See also [`flip`].
228pub fn into_flip<S, D>(
229    tensor: TensorBase<S, D>,
230    axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
231) -> TensorBase<S, D>
232where
233    D: DimAPI,
234{
235    into_flip_f(tensor, axes).rstsr_unwrap()
236}
237
238impl<R, T, B, D> TensorAny<R, T, B, D>
239where
240    R: DataAPI<Data = B::Raw>,
241    B: DeviceAPI<T>,
242    D: DimAPI,
243{
244    /// Reverses the order of elements in an array along the given axis.
245    ///
246    /// See also [`flip`].
247    pub fn flip(&self, axis: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> TensorView<'_, T, B, D> {
248        flip(self, axis)
249    }
250
251    /// Reverses the order of elements in an array along the given axis.
252    ///
253    /// See also [`flip`].
254    pub fn flip_f(&self, axis: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> Result<TensorView<'_, T, B, D>> {
255        flip_f(self, axis)
256    }
257
258    /// Reverses the order of elements in an array along the given axis.
259    ///
260    /// See also [`flip`].
261    pub fn into_flip(self, axis: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> TensorAny<R, T, B, D> {
262        into_flip(self, axis)
263    }
264
265    /// Reverses the order of elements in an array along the given axis.
266    ///
267    /// See also [`flip`].
268    pub fn into_flip_f(
269        self,
270        axis: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
271    ) -> Result<TensorAny<R, T, B, D>> {
272        into_flip_f(self, axis)
273    }
274}