rstsr_core/tensor/manuplication/
expand_dims.rs

1use crate::prelude_dev::*;
2
3/* #region expand_dims */
4
5/// Expands the shape of an array by inserting a new axis (dimension) of size
6/// one at the position specified by `axis`.
7///
8/// # See also
9///
10/// Refer to [`expand_dims`] and [`into_expand_dims`] for more detailed documentation.
11pub fn into_expand_dims_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, IxD>>
12where
13    D: DimAPI,
14    I: TryInto<AxesIndex<isize>, Error = Error>,
15{
16    // convert axis to negative indexes and sort
17    let ndim = tensor.ndim();
18    let (storage, layout) = tensor.into_raw_parts();
19    let mut layout = layout.into_dim::<IxD>()?;
20    let axes = axes.try_into()?;
21    let len_axes = axes.as_ref().len();
22    let axes = normalize_axes_index(axes, ndim + len_axes, false)?;
23    for axis in axes {
24        layout = layout.dim_insert(axis)?;
25    }
26    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
27}
28
29/// Expands the shape of an array by inserting a new axis (dimension) of size
30/// one at the position specified by `axis`.
31///
32/// # Parameters
33///
34/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
35///
36///   - The input tensor.
37///
38/// - `axes`: TryInto [`AxesIndex<isize>`]
39///
40///   - Position in the expanded axes where the new axis (or axes) is placed.
41///   - Can be a single integer, or a list/tuple of integers.
42///   - Negative values are supported and indicate counting dimensions from the back.
43///
44/// # Returns
45///
46/// - [`TensorView<'_, T, B, IxD>`](TensorView)
47///
48///   - A view of the input tensor with the new axis (or axes) inserted.
49///   - If you want to convert the tensor itself (taking the ownership instead of returning view),
50///     use [`into_expand_dims`] instead.
51///
52/// # Panics
53///
54/// - If `axis` is greater than the number of axes in the original tensor.
55/// - If expaneded axis has duplicated values.
56///
57/// # Examples
58///
59/// We first initialize a vector of shape (2,):
60///
61/// ```rust
62/// use rstsr::prelude::*;
63/// let x = rt::asarray(vec![1, 2]);
64/// ```
65///
66/// Expand dims at axis 0, which is equilvalent to `x.i(None)`:
67///
68/// ```rust
69/// # use rstsr::prelude::*;
70/// # let x = rt::asarray(vec![1, 2]);
71/// // [1, 2] -> [[1, 2]]
72/// let y = x.expand_dims(0);
73/// let y_expected = rt::tensor_from_nested!([[1, 2]]);
74/// assert!(rt::allclose(&y, &y_expected, None));
75/// assert_eq!(y.shape(), &[1, 2]);
76/// assert_eq!(x.i(None).shape(), &[1, 2]);
77/// ```
78///
79/// Expand dims at axis -1 (last axis), which is equilvalent to `x.i((Ellipsis, None))`, or in this
80/// 1-dimension specific case, `x.i((.., None))`:
81///
82/// ```rust
83/// # use rstsr::prelude::*;
84/// # let x = rt::asarray(vec![1, 2]);
85/// // [1, 2] -> [[1], [2]]
86/// let y = x.expand_dims(-1);
87/// let y_expected = rt::tensor_from_nested!([[1], [2]]);
88/// assert!(rt::allclose(&y, &y_expected, None));
89/// assert_eq!(y.shape(), &[2, 1]);
90/// assert_eq!(x.i((Ellipsis, None)).shape(), &[2, 1]);
91/// ```
92///
93/// Expand dims at axes 0 and 1, which is equilvalent to `x.i((None, None))`:
94///
95/// ```rust
96/// # use rstsr::prelude::*;
97/// # let x = rt::asarray(vec![1, 2]);
98/// // Expand dims at axes 0 and 1
99/// // [1, 2] -> [[[1, 2]]]
100/// let y = x.expand_dims([0, 1]);
101/// let y_expected = rt::tensor_from_nested!([[[1, 2]]]);
102/// assert!(rt::allclose(&y, &y_expected, None));
103/// assert_eq!(y.shape(), &[1, 1, 2]);
104/// assert_eq!(x.i((None, None)).shape(), &[1, 1, 2]);
105/// ```
106///
107/// /// Expand dims at axes 0 and 2, which is equilvalent to `x.i((None, Ellipsis, None))`:
108///
109/// ```rust
110/// # use rstsr::prelude::*;
111/// # let x = rt::asarray(vec![1, 2]);
112/// // Expand dims at axes 0 and 2
113/// // [1, 2] -> [[[1], [2]]]
114/// let y = x.expand_dims([0, 2]);
115/// let y_expected = rt::tensor_from_nested!([[[1], [2]]]);
116/// assert!(rt::allclose(&y, &y_expected, None));
117/// assert_eq!(y.shape(), &[1, 2, 1]);
118/// assert_eq!(x.i((None, Ellipsis, None)).shape(), &[1, 2, 1]);
119/// ```
120///
121/// # See also
122///
123/// ## Similar functions from other crates/libraries
124///
125/// - Python Array API standard: [`expand_dims`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.expand_dims.html)
126/// - NumPy: [`numpy.expand_dims`](https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html)
127/// - PyTorch: [`torch.unsqueeze`](https://pytorch.org/docs/stable/generated/torch.unsqueeze.html)
128///
129/// ## Related functions in RSTSR
130///
131/// - [`i`](TensorAny::i) or [`slice`](slice()): Basic indexing and slicing of tensors, without
132///   modification of the underlying data.
133/// - [`squeeze`]: Removes singleton dimensions (axes) from `x`.
134///
135/// ## Variants of this function
136///
137/// - [`expand_dims_f`]: Failable version.
138/// - [`into_expand_dims`]: Consuming version.
139/// - [`into_expand_dims_f`]: Failable and consuming version, actual implementation.
140pub fn expand_dims<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, IxD>
141where
142    D: DimAPI,
143    I: TryInto<AxesIndex<isize>, Error = Error>,
144    R: DataAPI<Data = B::Raw>,
145    B: DeviceAPI<T>,
146{
147    into_expand_dims_f(tensor.view(), axes).rstsr_unwrap()
148}
149
150/// Expands the shape of an array by inserting a new axis (dimension) of size
151/// one at the position specified by `axis`.
152///
153/// # See also
154///
155/// Refer to [`expand_dims`] and [`into_expand_dims`] for more detailed documentation.
156pub fn expand_dims_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, IxD>>
157where
158    D: DimAPI,
159    I: TryInto<AxesIndex<isize>, Error = Error>,
160    R: DataAPI<Data = B::Raw>,
161    B: DeviceAPI<T>,
162{
163    into_expand_dims_f(tensor.view(), axes)
164}
165
166/// Expands the shape of an array by inserting a new axis (dimension) of size
167/// one at the position specified by `axis`.
168///
169/// # Parameters
170///
171/// - `tensor`: [`TensorBase<S, D>`]
172///
173///   - The input tensor.
174///   - Please note that this function takes ownership of the input tensor.
175///
176/// - `axes`: TryInto [`AxesIndex<isize>`]
177///
178///   - Position in the expanded axes where the new axis (or axes) is placed.
179///   - Can be a single integer, or a list/tuple of integers.
180///   - Negative values are supported and indicate counting dimensions from the back.
181///
182/// # Returns
183///
184/// - [`TensorBase<S, IxD>`]
185///
186///   - The tensor with the new axis (or axes) inserted.
187///   - Ownership of the returned tensor is transferred from the input tensor. Only the layout is
188///     modified; the underlying data remains unchanged.
189///
190/// # See also
191///
192/// Refer to [`expand_dims`] for more detailed documentation.
193pub fn into_expand_dims<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, IxD>
194where
195    D: DimAPI,
196    I: TryInto<AxesIndex<isize>, Error = Error>,
197{
198    into_expand_dims_f(tensor, axes).rstsr_unwrap()
199}
200
201impl<R, T, B, D> TensorAny<R, T, B, D>
202where
203    R: DataAPI<Data = B::Raw>,
204    B: DeviceAPI<T>,
205    D: DimAPI,
206{
207    /// Expands the shape of an array by inserting a new axis (dimension) of size
208    /// one at the position specified by `axis`.
209    ///
210    /// # See also
211    ///
212    /// Refer to [`expand_dims`] and [`into_expand_dims`] for more detailed documentation.
213    pub fn expand_dims<I>(&self, axes: I) -> TensorView<'_, T, B, IxD>
214    where
215        I: TryInto<AxesIndex<isize>, Error = Error>,
216    {
217        into_expand_dims(self.view(), axes)
218    }
219
220    /// Expands the shape of an array by inserting a new axis (dimension) of size
221    /// one at the position specified by `axis`.
222    ///
223    /// # See also
224    ///
225    /// Refer to [`expand_dims`] and [`into_expand_dims`] for more detailed documentation.
226    pub fn expand_dims_f<I>(&self, axes: I) -> Result<TensorView<'_, T, B, IxD>>
227    where
228        I: TryInto<AxesIndex<isize>, Error = Error>,
229    {
230        into_expand_dims_f(self.view(), axes)
231    }
232
233    /// Expands the shape of an array by inserting a new axis (dimension) of size
234    /// one at the position specified by `axis`.
235    ///
236    /// # See also
237    ///
238    /// Refer to [`expand_dims`] and [`into_expand_dims`] for more detailed documentation.
239    pub fn into_expand_dims<I>(self, axes: I) -> TensorAny<R, T, B, IxD>
240    where
241        I: TryInto<AxesIndex<isize>, Error = Error>,
242    {
243        into_expand_dims(self, axes)
244    }
245
246    /// Expands the shape of an array by inserting a new axis (dimension) of size
247    /// one at the position specified by `axis`.
248    ///
249    /// # See also
250    ///
251    /// Refer to [`expand_dims`] and [`into_expand_dims`] for more detailed documentation.
252    pub fn into_expand_dims_f<I>(self, axes: I) -> Result<TensorAny<R, T, B, IxD>>
253    where
254        I: TryInto<AxesIndex<isize>, Error = Error>,
255    {
256        into_expand_dims_f(self, axes)
257    }
258}
259
260/* #endregion */
261
#[cfg(test)]
mod tests {
    /// Mirrors the examples in the `expand_dims` documentation.
    #[test]
    fn doc_expand_dims() {
        use rstsr::prelude::*;

        // 1-D input of shape (2,), shared by every case below.
        let base = rt::asarray(vec![1, 2]);

        // Axis 0: [1, 2] -> [[1, 2]], same shape as `base.i(None)`.
        let expanded = base.expand_dims(0);
        let expected = rt::tensor_from_nested!([[1, 2]]);
        assert_eq!(expanded.shape(), &[1, 2]);
        assert!(rt::allclose(&expanded, &expected, None));
        assert_eq!(base.i(None).shape(), &[1, 2]);

        // Axis -1 (last): [1, 2] -> [[1], [2]], same shape as `base.i((Ellipsis, None))`.
        let expanded = base.expand_dims(-1);
        let expected = rt::tensor_from_nested!([[1], [2]]);
        assert_eq!(expanded.shape(), &[2, 1]);
        assert!(rt::allclose(&expanded, &expected, None));
        assert_eq!(base.i((Ellipsis, None)).shape(), &[2, 1]);

        // Axes 0 and 1: [1, 2] -> [[[1, 2]]], same shape as `base.i((None, None))`.
        let expanded = base.expand_dims([0, 1]);
        let expected = rt::tensor_from_nested!([[[1, 2]]]);
        assert_eq!(expanded.shape(), &[1, 1, 2]);
        assert!(rt::allclose(&expanded, &expected, None));
        assert_eq!(base.i((None, None)).shape(), &[1, 1, 2]);

        // Axes 0 and 2: [1, 2] -> [[[1], [2]]], same shape as `base.i((None, Ellipsis, None))`.
        let expanded = base.expand_dims([0, 2]);
        let expected = rt::tensor_from_nested!([[[1], [2]]]);
        assert_eq!(expanded.shape(), &[1, 2, 1]);
        assert!(rt::allclose(&expanded, &expected, None));
        assert_eq!(base.i((None, Ellipsis, None)).shape(), &[1, 2, 1]);
    }
}