Skip to main content

rstsr_core/tensor/manuplication/
expand_dims.rs

1use crate::prelude_dev::*;
2
3/// Expands the shape of an array by inserting a new axis (dimension) of size one at the position
4/// specified by `axis`.
5///
6/// See also [`expand_dims`].
7pub fn into_expand_dims_f<S, D>(
8    tensor: TensorBase<S, D>,
9    axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
10) -> Result<TensorBase<S, IxD>>
11where
12    D: DimAPI,
13{
14    // convert axis to negative indexes and sort
15    let ndim = tensor.ndim();
16    let (storage, layout) = tensor.into_raw_parts();
17    let mut layout = layout.into_dim::<IxD>()?;
18    let axes = axes.try_into().map_err(Into::into)?;
19    let len_axes = axes.as_ref().len();
20    let axes = normalize_axes_index(axes, ndim + len_axes, false, true)?;
21    for axis in axes {
22        layout = layout.dim_insert(axis)?;
23    }
24    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
25}
26
27/// Expands the shape of an array by inserting a new axis (dimension) of size one at the position
28/// specified by `axis`.
29///
30/// # Parameters
31///
32/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
33///
34///   - The input tensor.
35///   - Note on variant [`into_expand_dims`]: This takes ownership [`Tensor<R, T, B, D>`] of input
36///     tensor, and will not perform change to underlying data, only layout changes.
37///
38/// - `axes`: TryInto [`AxesIndex<isize>`]
39///
40///   - Position in the expanded axes where the new axis (or axes) is placed.
41///   - Can be a single integer, or a list/tuple of integers.
42///   - Negative values are supported and indicate counting dimensions from the back.
43///
44/// # Returns
45///
46/// - [`TensorView<'_, T, B, IxD>`](TensorView)
47///
48///   - A view of the input tensor with the new axis (or axes) inserted.
49///   - If you want to convert the tensor itself (taking the ownership instead of returning view),
50///     use [`into_expand_dims`] instead.
51///
52/// # Panics
53///
54/// - If `axis` is greater than the number of axes in the original tensor.
55/// - If expanded axis has duplicated values.
56///
57/// For a fallible version, use [`expand_dims_f`].
58///
59/// # Examples
60///
61/// Expand dims at axis 0, which is equivalent to `x.i(None)`:
62///
63/// ```rust
64/// # use rstsr::prelude::*;
65/// # let mut device = DeviceCpu::default();
66/// # device.set_default_order(RowMajor);
67/// let x = rt::arange((2, &device));
68/// let y = x.expand_dims(0);
69/// println!("{y}");
70/// // [[ 0 1]]
71/// println!("y shape: {:?}", y.shape());
72/// // y shape: [1, 2]
73/// assert_eq!(y.shape(), &[1, 2]);
74/// assert_eq!(x.i(None).shape(), y.shape());
75/// ```
76///
77/// Expand dims at axis -1 (last axis), which is equivalent to `x.i((Ellipsis, None))`:
78///
79/// ```rust
80/// # use rstsr::prelude::*;
81/// # let mut device = DeviceCpu::default();
82/// # device.set_default_order(RowMajor);
83/// let x = rt::arange((2, &device));
84/// let y = x.expand_dims(-1);
85/// println!("{y}");
86/// // [[ 0]
87/// //  [ 1]]
88/// println!("y shape: {:?}", y.shape());
89/// // y shape: [2, 1]
90/// assert_eq!(y.shape(), &[2, 1]);
91/// assert_eq!(x.i((Ellipsis, None)).shape(), &[2, 1]);
92/// ```
93///
94/// Expand dims at axes 0 and 1, which is equivalent to `x.i((None, None))`:
95///
96/// ```rust
97/// # use rstsr::prelude::*;
98/// # let mut device = DeviceCpu::default();
99/// # device.set_default_order(RowMajor);
100/// let x = rt::arange((2, &device));
101/// let y = x.expand_dims([0, 1]);
102/// println!("{y}");
103/// // [[[ 0 1]]]
104/// println!("y shape: {:?}", y.shape());
105/// // y shape: [1, 1, 2]
106/// assert_eq!(y.shape(), &[1, 1, 2]);
107/// assert_eq!(x.i((None, None)).shape(), &[1, 1, 2]);
108/// ```
109///
110/// Expand dims at axes 0 and 2, which is equivalent to `x.i((None, Ellipsis, None))`:
111///
112/// ```rust
113/// # use rstsr::prelude::*;
114/// # let mut device = DeviceCpu::default();
115/// # device.set_default_order(RowMajor);
116/// let x = rt::arange((2, &device));
117/// let y = x.expand_dims([0, 2]);
118/// println!("{y}");
119/// // [[[ 0]]
120/// //  [[ 1]]]
121/// println!("y shape: {:?}", y.shape());
122/// // y shape: [1, 2, 1]
123/// assert_eq!(y.shape(), &[1, 2, 1]);
124/// assert_eq!(x.i((None, Ellipsis, None)).shape(), &[1, 2, 1]);
125/// ```
126///
127/// # Notes of API accordance
128///
129/// - Array-API: `expand_dims(x, /, axis)` ([`expand_dims` in Array-API](https://data-apis.org/array-api/latest/API_specification/generated/array_api.expand_dims.html))
130/// - NumPy: `expand_dims(a, axis)` ([`numpy.expand_dims`](https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html))
131/// - RSTSR: `rt::expand_dims(tensor, axes)`
132///
133/// # See also
134///
135/// ## Related functions in RSTSR
136///
137/// - [`i`](TensorAny::i) or [`slice`](slice()): Basic indexing and slicing of tensors, without
138///   modification of the underlying data.
139/// - [`squeeze`]: Removes singleton dimensions (axes) from `x`.
140///
141/// ## Variants of this function
142///
143/// - [expand_dims] / [`expand_dims_f`]: Returning a view.
144/// - [`into_expand_dims`] / [`into_expand_dims_f`]: Consuming version.
145/// - [`unsqueeze`] / [`unsqueeze_f`]: Alias of [`expand_dims`] / [`expand_dims_f`].
146/// - [`into_unsqueeze`] / [`into_unsqueeze_f`]: Alias of [`into_expand_dims`] /
147///   [`into_expand_dims_f`].
148///
149/// - Associated methods on [`TensorAny`]:
150///
151///   - [`TensorAny::expand_dims`] / [`TensorAny::expand_dims_f`]
152///   - [`TensorAny::into_expand_dims`] / [`TensorAny::into_expand_dims_f`]
153///   - [`TensorAny::unsqueeze`] / [`TensorAny::unsqueeze_f`]
154///   - [`TensorAny::into_unsqueeze`] / [`TensorAny::into_unsqueeze_f`]
155pub fn expand_dims<R, T, B, D>(
156    tensor: &TensorAny<R, T, B, D>,
157    axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
158) -> TensorView<'_, T, B, IxD>
159where
160    D: DimAPI,
161    R: DataAPI<Data = B::Raw>,
162    B: DeviceAPI<T>,
163{
164    into_expand_dims_f(tensor.view(), axes).rstsr_unwrap()
165}
166
167/// Expands the shape of an array by inserting a new axis (dimension) of size one at the position
168/// specified by `axis`.
169///
170/// See also [`expand_dims`].
171pub fn expand_dims_f<R, T, B, D>(
172    tensor: &TensorAny<R, T, B, D>,
173    axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
174) -> Result<TensorView<'_, T, B, IxD>>
175where
176    D: DimAPI,
177    R: DataAPI<Data = B::Raw>,
178    B: DeviceAPI<T>,
179{
180    into_expand_dims_f(tensor.view(), axes)
181}
182
183/// Expands the shape of an array by inserting a new axis (dimension) of size one at the position
184/// specified by `axis`.
185///
186/// See also [`expand_dims`].
187pub fn into_expand_dims<S, D>(
188    tensor: TensorBase<S, D>,
189    axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
190) -> TensorBase<S, IxD>
191where
192    D: DimAPI,
193{
194    into_expand_dims_f(tensor, axes).rstsr_unwrap()
195}
196
197pub use expand_dims as unsqueeze;
198pub use expand_dims_f as unsqueeze_f;
199pub use into_expand_dims as into_unsqueeze;
200pub use into_expand_dims_f as into_unsqueeze_f;
201
202impl<R, T, B, D> TensorAny<R, T, B, D>
203where
204    R: DataAPI<Data = B::Raw>,
205    B: DeviceAPI<T>,
206    D: DimAPI,
207{
208    /// Expands the shape of an array by inserting a new axis (dimension) of size one at the
209    /// position specified by `axis`.
210    ///
211    /// See also [`expand_dims`].
212    pub fn expand_dims(&self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> TensorView<'_, T, B, IxD> {
213        into_expand_dims(self.view(), axes)
214    }
215
216    /// Expands the shape of an array by inserting a new axis (dimension) of size one at the
217    /// position specified by `axis`.
218    ///
219    /// See also [`expand_dims`].
220    pub fn expand_dims_f(
221        &self,
222        axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
223    ) -> Result<TensorView<'_, T, B, IxD>> {
224        into_expand_dims_f(self.view(), axes)
225    }
226
227    /// Expands the shape of an array by inserting a new axis (dimension) of size one at the
228    /// position specified by `axis`.
229    ///
230    /// See also [`expand_dims`].
231    pub fn into_expand_dims(self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> TensorAny<R, T, B, IxD> {
232        into_expand_dims(self, axes)
233    }
234
235    /// Expands the shape of an array by inserting a new axis (dimension) of size one at the
236    /// position specified by `axis`.
237    ///
238    /// See also [`expand_dims`].
239    pub fn into_expand_dims_f(
240        self,
241        axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
242    ) -> Result<TensorAny<R, T, B, IxD>> {
243        into_expand_dims_f(self, axes)
244    }
245
246    /// Expands the shape of an array by inserting a new axis (dimension) of size one at the
247    /// position specified by `axis`.
248    ///
249    /// See also [`expand_dims`].
250    pub fn unsqueeze(&self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> TensorView<'_, T, B, IxD> {
251        self.expand_dims(axes)
252    }
253
254    /// Expands the shape of an array by inserting a new axis (dimension) of size one at the
255    /// position specified by `axis`.
256    ///
257    /// See also [`expand_dims`].
258    pub fn unsqueeze_f(
259        &self,
260        axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
261    ) -> Result<TensorView<'_, T, B, IxD>> {
262        self.expand_dims_f(axes)
263    }
264
265    /// Expands the shape of an array by inserting a new axis (dimension) of size one at the
266    /// position specified by `axis`.
267    ///
268    /// See also [`expand_dims`].
269    pub fn into_unsqueeze(self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> TensorAny<R, T, B, IxD> {
270        self.into_expand_dims(axes)
271    }
272
273    /// Expands the shape of an array by inserting a new axis (dimension) of size one at the
274    /// position specified by `axis`.
275    ///
276    /// See also [`expand_dims`].
277    pub fn into_unsqueeze_f(
278        self,
279        axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
280    ) -> Result<TensorAny<R, T, B, IxD>> {
281        self.into_expand_dims_f(axes)
282    }
283}