rstsr_core/tensor/manuplication/
squeeze.rs

1use crate::prelude_dev::*;
2
3/* #region squeeze */
4
5pub fn into_squeeze_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, IxD>>
6where
7    D: DimAPI,
8    I: TryInto<AxesIndex<isize>, Error = Error>,
9{
10    // convert axis to positive indexes and (reversed) sort
11    let ndim: isize = TryInto::<isize>::try_into(tensor.ndim())?;
12    let (storage, layout) = tensor.into_raw_parts();
13    let mut layout = layout.into_dim::<IxD>()?;
14    let mut axes: Vec<isize> =
15        axes.try_into()?.as_ref().iter().map(|&v| if v >= 0 { v } else { v + ndim }).collect::<_>();
16    axes.sort_by(|a, b| b.cmp(a));
17    if axes.first().is_some_and(|&v| v < 0) {
18        return Err(rstsr_error!(InvalidValue, "Some negative index is too small."));
19    }
20    // check no two axis are the same
21    for i in 0..axes.len() - 1 {
22        rstsr_assert!(axes[i] != axes[i + 1], InvalidValue, "Same axes is not allowed here.")?;
23    }
24    // perform squeeze
25    for &axis in axes.iter() {
26        layout = layout.dim_eliminate(axis)?;
27    }
28    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
29}
30
31/// Removes singleton dimensions (axes) from `x`.
32///
33/// # See also
34///
35/// [Python array API standard: `squeeze`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.squeeze.html)
36pub fn squeeze<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, IxD>
37where
38    D: DimAPI,
39    I: TryInto<AxesIndex<isize>, Error = Error>,
40    R: DataAPI<Data = B::Raw>,
41    B: DeviceAPI<T>,
42{
43    into_squeeze_f(tensor.view(), axes).rstsr_unwrap()
44}
45
46pub fn squeeze_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, IxD>>
47where
48    D: DimAPI,
49    I: TryInto<AxesIndex<isize>, Error = Error>,
50    R: DataAPI<Data = B::Raw>,
51    B: DeviceAPI<T>,
52{
53    into_squeeze_f(tensor.view(), axes)
54}
55
56pub fn into_squeeze<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, IxD>
57where
58    D: DimAPI,
59    I: TryInto<AxesIndex<isize>, Error = Error>,
60{
61    into_squeeze_f(tensor, axes).rstsr_unwrap()
62}
63
64impl<R, T, B, D> TensorAny<R, T, B, D>
65where
66    R: DataAPI<Data = B::Raw>,
67    B: DeviceAPI<T>,
68    D: DimAPI,
69{
70    /// Removes singleton dimensions (axes) from `x`.
71    ///
72    /// # See also
73    ///
74    /// [`squeeze`]
75    pub fn squeeze<I>(&self, axis: I) -> TensorView<'_, T, B, IxD>
76    where
77        I: TryInto<AxesIndex<isize>, Error = Error>,
78    {
79        squeeze(self, axis)
80    }
81
82    pub fn squeeze_f<I>(&self, axis: I) -> Result<TensorView<'_, T, B, IxD>>
83    where
84        I: TryInto<AxesIndex<isize>, Error = Error>,
85    {
86        squeeze_f(self, axis)
87    }
88
89    /// Removes singleton dimensions (axes) from `x`.
90    ///
91    /// # See also
92    ///
93    /// [`squeeze`]
94    pub fn into_squeeze<I>(self, axis: I) -> TensorAny<R, T, B, IxD>
95    where
96        I: TryInto<AxesIndex<isize>, Error = Error>,
97    {
98        into_squeeze(self, axis)
99    }
100
101    pub fn into_squeeze_f<I>(self, axis: I) -> Result<TensorAny<R, T, B, IxD>>
102    where
103        I: TryInto<AxesIndex<isize>, Error = Error>,
104    {
105        into_squeeze_f(self, axis)
106    }
107}
108
109/* #endregion */