rstsr_core/tensor/manuplication/
squeeze.rs1use crate::prelude_dev::*;
2
3pub fn into_squeeze_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, IxD>>
6where
7 D: DimAPI,
8 I: TryInto<AxesIndex<isize>, Error = Error>,
9{
10 let ndim: isize = TryInto::<isize>::try_into(tensor.ndim())?;
12 let (storage, layout) = tensor.into_raw_parts();
13 let mut layout = layout.into_dim::<IxD>()?;
14 let mut axes: Vec<isize> =
15 axes.try_into()?.as_ref().iter().map(|&v| if v >= 0 { v } else { v + ndim }).collect::<_>();
16 axes.sort_by(|a, b| b.cmp(a));
17 if axes.first().is_some_and(|&v| v < 0) {
18 return Err(rstsr_error!(InvalidValue, "Some negative index is too small."));
19 }
20 for i in 0..axes.len() - 1 {
22 rstsr_assert!(axes[i] != axes[i + 1], InvalidValue, "Same axes is not allowed here.")?;
23 }
24 for &axis in axes.iter() {
26 layout = layout.dim_eliminate(axis)?;
27 }
28 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
29}
30
31pub fn squeeze<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, IxD>
37where
38 D: DimAPI,
39 I: TryInto<AxesIndex<isize>, Error = Error>,
40 R: DataAPI<Data = B::Raw>,
41 B: DeviceAPI<T>,
42{
43 into_squeeze_f(tensor.view(), axes).rstsr_unwrap()
44}
45
46pub fn squeeze_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, IxD>>
47where
48 D: DimAPI,
49 I: TryInto<AxesIndex<isize>, Error = Error>,
50 R: DataAPI<Data = B::Raw>,
51 B: DeviceAPI<T>,
52{
53 into_squeeze_f(tensor.view(), axes)
54}
55
56pub fn into_squeeze<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, IxD>
57where
58 D: DimAPI,
59 I: TryInto<AxesIndex<isize>, Error = Error>,
60{
61 into_squeeze_f(tensor, axes).rstsr_unwrap()
62}
63
64impl<R, T, B, D> TensorAny<R, T, B, D>
65where
66 R: DataAPI<Data = B::Raw>,
67 B: DeviceAPI<T>,
68 D: DimAPI,
69{
70 pub fn squeeze<I>(&self, axis: I) -> TensorView<'_, T, B, IxD>
76 where
77 I: TryInto<AxesIndex<isize>, Error = Error>,
78 {
79 squeeze(self, axis)
80 }
81
82 pub fn squeeze_f<I>(&self, axis: I) -> Result<TensorView<'_, T, B, IxD>>
83 where
84 I: TryInto<AxesIndex<isize>, Error = Error>,
85 {
86 squeeze_f(self, axis)
87 }
88
89 pub fn into_squeeze<I>(self, axis: I) -> TensorAny<R, T, B, IxD>
95 where
96 I: TryInto<AxesIndex<isize>, Error = Error>,
97 {
98 into_squeeze(self, axis)
99 }
100
101 pub fn into_squeeze_f<I>(self, axis: I) -> Result<TensorAny<R, T, B, IxD>>
102 where
103 I: TryInto<AxesIndex<isize>, Error = Error>,
104 {
105 into_squeeze_f(self, axis)
106 }
107}
108
109