rstsr_core/tensor/manuplication/squeeze.rs
1use crate::prelude_dev::*;
2
3/// Removes singleton dimensions (axes) from a tensor.
4///
5/// See also [`squeeze`].
6pub fn into_squeeze_f<S, D>(
7 tensor: TensorBase<S, D>,
8 axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
9) -> Result<TensorBase<S, IxD>>
10where
11 D: DimAPI,
12{
13 // convert axis to positive indexes and (reversed) sort
14 let ndim: isize = TryInto::<isize>::try_into(tensor.ndim())?;
15 let (storage, layout) = tensor.into_raw_parts();
16 let mut layout = layout.into_dim::<IxD>()?;
17 let axes = axes.try_into().map_err(Into::into)?;
18 let axes = match axes {
19 AxesIndex::None => {
20 // find all axes with size 1
21 let mut axes: Vec<isize> = Vec::new();
22 for i in (0..ndim).rev() {
23 if layout.shape()[i as usize] == 1 {
24 axes.push(i);
25 }
26 }
27 axes
28 },
29 _ => {
30 let mut axes: Vec<isize> = axes.as_ref().iter().map(|&v| if v >= 0 { v } else { v + ndim }).collect();
31 axes.sort_by(|a, b| b.cmp(a));
32 if axes.first().is_some_and(|&v| v < 0) {
33 return Err(rstsr_error!(InvalidValue, "Some negative index is too small."));
34 }
35 // check no two axis are the same
36 for i in 0..axes.len().saturating_sub(1) {
37 rstsr_assert!(axes[i] != axes[i + 1], InvalidValue, "Same axes is not allowed here.")?;
38 }
39 axes
40 },
41 };
42 // perform squeeze
43 for &axis in axes.iter() {
44 layout = layout.dim_eliminate(axis)?;
45 }
46 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
47}
48
49/// Removes singleton dimensions (axes) from `x`.
50///
51/// See also [`squeeze`].
52pub fn into_squeeze<S, D>(
53 tensor: TensorBase<S, D>,
54 axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
55) -> TensorBase<S, IxD>
56where
57 D: DimAPI,
58{
59 into_squeeze_f(tensor, axes).rstsr_unwrap()
60}
61
62/// Removes singleton dimensions (axes) from a tensor.
63///
64/// # Parameters
65///
66/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
67///
68/// - The input tensor.
69/// - Note on variant [`into_squeeze`]: This takes ownership [`Tensor<R, T, B, D>`] of input
70/// tensor, and will not perform change to underlying data, only layout changes.
71///
72/// - `axes`: TryInto [`AxesIndex<isize>`]
73///
74/// - The axis (or axes) to squeeze.
75/// - If `axes` is a single integer, squeezing is performed along that axis.
76/// - If `axes` is a tuple/list of integers, squeezing is performed on all specified axes.
77/// - If `axes` is `None`, the function will squeeze all axes with size 1.
78/// - If `axes` is an empty tuple `()`, no axes are squeezed.
79/// - Negative values are supported and indicate counting dimensions from the back.
80/// - Each axis in `axes` must have size 1; otherwise an error is raised.
81///
82/// # Returns
83///
84/// - [`TensorView<'_, T, B, IxD>`](TensorView)
85///
86/// - A view of the input tensor with the specified singleton dimensions removed.
87/// - The underlying data is not copied; only the layout of the view is modified.
88/// - If you want to convert the tensor itself (taking the ownership instead of returning view),
89/// use [`into_squeeze`] instead.
90///
91/// # Panics
92///
93/// - If an axis specified does not have size 1.
94/// - If an axis is out of bounds.
95/// - If `axes` has duplicated values.
96///
97/// # Examples
98///
99/// ## Squeezing a single axis
100///
101/// Squeeze a tensor along axis 0:
102///
103/// ```rust
104/// # use rstsr::prelude::*;
105/// # let mut device = DeviceCpu::default();
106/// # device.set_default_order(RowMajor);
107/// let a: Tensor<f64, _> = rt::zeros(([1, 3, 1, 4], &device));
108/// let b = a.squeeze(0);
109/// assert_eq!(b.shape(), &[3, 1, 4]);
110/// ```
111///
112/// Squeeze a tensor along the axis 2 (third axis with size 1):
113///
114/// ```rust
115/// # use rstsr::prelude::*;
116/// # let mut device = DeviceCpu::default();
117/// # device.set_default_order(RowMajor);
118/// # let a: Tensor<f64, _> = rt::zeros(([1, 3, 1, 4], &device));
119/// let b = a.squeeze(2);
120/// assert_eq!(b.shape(), &[1, 3, 4]);
121/// ```
122///
123/// Squeeze using negative index (-2 refers to the axis with size 1):
124///
125/// ```rust
126/// # use rstsr::prelude::*;
127/// # let mut device = DeviceCpu::default();
128/// # device.set_default_order(RowMajor);
129/// # let a: Tensor<f64, _> = rt::zeros(([1, 3, 1, 4], &device));
130/// let b = a.squeeze(-2);
131/// assert_eq!(b.shape(), &[1, 3, 4]);
132/// ```
133///
134/// ## Squeezing multiple axes
135///
136/// Squeeze multiple axes at once:
137///
138/// ```rust
139/// # use rstsr::prelude::*;
140/// # let mut device = DeviceCpu::default();
141/// # device.set_default_order(RowMajor);
142/// # let a: Tensor<f64, _> = rt::zeros(([1, 3, 1, 4, 1], &device));
143/// let b = a.squeeze([0, 2]);
144/// assert_eq!(b.shape(), &[3, 4, 1]);
145/// ```
146///
147/// Use negative indices to squeeze from the back:
148///
149/// ```rust
150/// # use rstsr::prelude::*;
151/// # let mut device = DeviceCpu::default();
152/// # device.set_default_order(RowMajor);
153/// # let a: Tensor<f64, _> = rt::zeros(([1, 3, 1, 4, 1], &device));
154/// let b = a.squeeze([0, -1]);
155/// assert_eq!(b.shape(), &[3, 1, 4]);
156/// ```
157///
158/// ## Squeezing all singleton axes
159///
160/// Use `None` to squeeze all axes with size 1:
161///
162/// ```rust
163/// # use rstsr::prelude::*;
164/// # let mut device = DeviceCpu::default();
165/// # device.set_default_order(RowMajor);
166/// let a: Tensor<f64, _> = rt::zeros(([1, 3, 1, 4, 1], &device));
167/// let b = a.squeeze(None);
168/// assert_eq!(b.shape(), &[3, 4]);
169/// ```
170///
171/// ## No squeezing (empty axes)
172///
173/// Use an empty tuple `()` to squeeze no axes (returns a view of the original tensor):
174///
175/// ```rust
176/// # use rstsr::prelude::*;
177/// # let mut device = DeviceCpu::default();
178/// # device.set_default_order(RowMajor);
179/// # let a: Tensor<f64, _> = rt::zeros(([1, 3, 1, 4, 1], &device));
180/// let b = a.squeeze(());
181/// assert_eq!(b.shape(), &[1, 3, 1, 4, 1]);
182/// ```
183///
184/// # Notes of API accordance
185///
186/// - Array-API: `squeeze(x, /, axis)` ([`squeeze`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.squeeze.html))
187/// - NumPy: `squeeze(a, axis=None)` ([`numpy.squeeze`](https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html))
188/// - RSTSR: `rt::squeeze(tensor, axes)`
189///
190/// RSTSR's behavior matches NumPy and Array-API:
191/// - `a.squeeze(None)` squeezes all axes with size 1
192/// - `a.squeeze(())` squeezes no axes (returns a view of the original tensor)
193///
194/// # See also
195///
196/// ## Related functions in RSTSR
197///
198/// - [`expand_dims`]: Adds singleton dimensions (axes) to a tensor.
199///
200/// ## Variants of this function
201///
202/// - [`squeeze`] / [`squeeze_f`]: Returning a view.
203/// - [`into_squeeze`] / [`into_squeeze_f`]: Consuming version.
204/// - Associated methods on [`TensorAny`]:
205///
206/// - [`TensorAny::squeeze`] / [`TensorAny::squeeze_f`]
207/// - [`TensorAny::into_squeeze`] / [`TensorAny::into_squeeze_f`]
208pub fn squeeze<R, T, B, D>(
209 tensor: &TensorAny<R, T, B, D>,
210 axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
211) -> TensorView<'_, T, B, IxD>
212where
213 D: DimAPI,
214 R: DataAPI<Data = B::Raw>,
215 B: DeviceAPI<T>,
216{
217 into_squeeze_f(tensor.view(), axes).rstsr_unwrap()
218}
219
220/// Removes singleton dimensions (axes) from a tensor.
221///
222/// See also [`squeeze`].
223pub fn squeeze_f<R, T, B, D>(
224 tensor: &TensorAny<R, T, B, D>,
225 axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
226) -> Result<TensorView<'_, T, B, IxD>>
227where
228 D: DimAPI,
229 R: DataAPI<Data = B::Raw>,
230 B: DeviceAPI<T>,
231{
232 into_squeeze_f(tensor.view(), axes)
233}
234
235impl<R, T, B, D> TensorAny<R, T, B, D>
236where
237 R: DataAPI<Data = B::Raw>,
238 B: DeviceAPI<T>,
239 D: DimAPI,
240{
241 /// Removes singleton dimensions (axes) from a tensor.
242 ///
243 /// See also [`squeeze`].
244 pub fn squeeze(&self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> TensorView<'_, T, B, IxD> {
245 squeeze(self, axes)
246 }
247
248 /// Removes singleton dimensions (axes) from a tensor.
249 ///
250 /// See also [`squeeze`].
251 pub fn squeeze_f(
252 &self,
253 axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
254 ) -> Result<TensorView<'_, T, B, IxD>> {
255 squeeze_f(self, axes)
256 }
257
258 /// Removes singleton dimensions (axes) from a tensor.
259 ///
260 /// See also [`squeeze`].
261 pub fn into_squeeze(self, axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>) -> TensorAny<R, T, B, IxD> {
262 into_squeeze(self, axes)
263 }
264
265 /// Removes singleton dimensions (axes) from a tensor.
266 ///
267 /// See also [`squeeze`].
268 pub fn into_squeeze_f(
269 self,
270 axes: impl TryInto<AxesIndex<isize>, Error: Into<Error>>,
271 ) -> Result<TensorAny<R, T, B, IxD>> {
272 into_squeeze_f(self, axes)
273 }
274}