1use crate::prelude_dev::*;
7
8pub fn index_select_f<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, axis: isize, indices: I) -> Result<Tensor<T, B, D>>
11where
12 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
13 D: DimAPI + DimSmallerOneAPI,
14 D::SmallerOne: DimAPI,
15 B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
16 I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
17{
18 let device = tensor.device().clone();
20 let tensor_layout = tensor.layout();
21 let ndim = tensor_layout.ndim();
22 let axis = if axis < 0 { ndim as isize + axis } else { axis };
24 rstsr_pattern!(axis, 0..ndim as isize, InvalidLayout, "Invalid axis that exceeds ndim.")?;
25 let axis = axis as usize;
26 let nshape: usize = tensor_layout.shape()[axis];
27 let indices = indices.try_into().map_err(Into::into)?;
28 let indices = indices
29 .as_ref()
30 .iter()
31 .map(|&i| -> Result<usize> {
32 let i = if i < 0 { nshape as isize + i } else { i };
33 rstsr_pattern!(
34 i,
35 0..nshape as isize,
36 InvalidLayout,
37 "Invalid index that exceeds shape length at axis {}.",
38 axis
39 )?;
40 Ok(i as usize)
41 })
42 .collect::<Result<Vec<usize>>>()?;
43 let mut out_shape = tensor_layout.shape().as_ref().to_vec();
44 out_shape[axis] = indices.len();
45 let out_layout = out_shape.new_contig(None, device.default_order()).into_dim()?;
46 let mut out_storage = device.uninit_impl(out_layout.size())?;
47 device.index_select(out_storage.raw_mut(), &out_layout, tensor.storage().raw(), tensor_layout, axis, &indices)?;
48 let out_storage = unsafe { B::assume_init_impl(out_storage)? };
49 TensorBase::new_f(out_storage, out_layout)
50}
51
/// Selects entries of `tensor` along `axis` at the positions listed in
/// `indices`, returning a new tensor.
///
/// This is the non-`Result` counterpart of `index_select_f`: all arguments
/// are forwarded unchanged and the returned `Result` is unwrapped with
/// `rstsr_unwrap`.
///
/// # Panics
///
/// NOTE(review): presumably panics whenever `index_select_f` would return an
/// error (invalid axis, out-of-range index, failed conversion of `indices`,
/// or a device failure) — confirm against the definition of `rstsr_unwrap`.
pub fn index_select<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, axis: isize, indices: I) -> Tensor<T, B, D>
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    D: DimAPI + DimSmallerOneAPI,
    D::SmallerOne: DimAPI,
    B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
    I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
{
    index_select_f(tensor, axis, indices).rstsr_unwrap()
}
68
/// Fallible selection along an axis with the argument order
/// `(tensor, indices, axis)`.
///
/// This performs exactly the same operation as `index_select_f`; the only
/// difference is that `indices` comes before `axis` in the parameter list.
///
/// # Errors
///
/// Returns whatever error `index_select_f` returns for the same inputs.
pub fn take_f<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, indices: I, axis: isize) -> Result<Tensor<T, B, D>>
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    D: DimAPI + DimSmallerOneAPI,
    D::SmallerOne: DimAPI,
    B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
    I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
{
    // Note the swapped argument order relative to this function's signature.
    index_select_f(tensor, axis, indices)
}
79
/// Selection along an axis with the argument order
/// `(tensor, indices, axis)`, returning the tensor directly.
///
/// This forwards to `index_select` with `indices` and `axis` swapped back
/// into that function's order; no other work is done here.
///
/// # Panics
///
/// NOTE(review): presumably panics under the same conditions as
/// `index_select` (which unwraps its result with `rstsr_unwrap`) — confirm.
pub fn take<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, indices: I, axis: isize) -> Tensor<T, B, D>
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    D: DimAPI + DimSmallerOneAPI,
    D::SmallerOne: DimAPI,
    B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
    I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
{
    // Note the swapped argument order relative to this function's signature.
    index_select(tensor, axis, indices)
}
95
/// Method forms of the index-selection functions.
///
/// Each method passes `self` plus its arguments to the free function of the
/// same name defined in this module; no additional logic lives here.
impl<R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    D: DimAPI + DimSmallerOneAPI,
    D::SmallerOne: DimAPI,
    B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
{
    /// Fallible selection of the entries at `indices` along `axis`.
    ///
    /// Delegates to the free function `index_select_f` and returns its
    /// `Result` unchanged.
    pub fn index_select_f<I>(&self, axis: isize, indices: I) -> Result<Tensor<T, B, D>>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        index_select_f(self, axis, indices)
    }

    /// Selection of the entries at `indices` along `axis`, returning the
    /// tensor directly.
    ///
    /// Delegates to the free function `index_select`.
    /// NOTE(review): that function unwraps its result with `rstsr_unwrap`,
    /// so this presumably panics on invalid input — confirm.
    pub fn index_select<I>(&self, axis: isize, indices: I) -> Tensor<T, B, D>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        index_select(self, axis, indices)
    }

    /// Fallible selection with the argument order `(indices, axis)`.
    ///
    /// Delegates to the free function `take_f` and returns its `Result`
    /// unchanged.
    pub fn take_f<I>(&self, indices: I, axis: isize) -> Result<Tensor<T, B, D>>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        take_f(self, indices, axis)
    }

    /// Selection with the argument order `(indices, axis)`, returning the
    /// tensor directly.
    ///
    /// Delegates to the free function `take`.
    /// NOTE(review): presumably panics on invalid input, like
    /// `index_select` — confirm.
    pub fn take<I>(&self, indices: I, axis: isize) -> Tensor<T, B, D>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        take(self, indices, axis)
    }
}
142
143pub fn bool_select_f<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, axis: isize, mask: I) -> Result<Tensor<T, B, D>>
148where
149 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
150 D: DimAPI + DimSmallerOneAPI,
151 D::SmallerOne: DimAPI,
152 B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
153 I: TryInto<AxesIndex<bool>, Error: Into<Error>>,
154{
155 let indices = mask
157 .try_into()
158 .map_err(Into::into)?
159 .as_ref()
160 .iter()
161 .enumerate()
162 .filter_map(|(i, &m)| m.then_some(i))
163 .collect::<Vec<usize>>();
164 index_select_f(tensor, axis, indices)
165}
166
/// Selects the entries of `tensor` along `axis` whose position is marked
/// `true` in `mask`, returning the tensor directly.
///
/// This is the non-`Result` counterpart of `bool_select_f`: all arguments
/// are forwarded unchanged and the returned `Result` is unwrapped with
/// `rstsr_unwrap`.
///
/// # Panics
///
/// NOTE(review): presumably panics whenever `bool_select_f` would return an
/// error — confirm against the definition of `rstsr_unwrap`.
pub fn bool_select<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, axis: isize, mask: I) -> Tensor<T, B, D>
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    D: DimAPI + DimSmallerOneAPI,
    D::SmallerOne: DimAPI,
    B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
    I: TryInto<AxesIndex<bool>, Error: Into<Error>>,
{
    bool_select_f(tensor, axis, mask).rstsr_unwrap()
}
179
180impl<R, T, B, D> TensorAny<R, T, B, D>
181where
182 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
183 D: DimAPI + DimSmallerOneAPI,
184 D::SmallerOne: DimAPI,
185 B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
186{
187 pub fn bool_select_f<I>(&self, axis: isize, indices: I) -> Result<Tensor<T, B, D>>
188 where
189 I: TryInto<AxesIndex<bool>, Error: Into<Error>>,
190 {
191 bool_select_f(self, axis, indices)
192 }
193
194 pub fn bool_select<I>(&self, axis: isize, indices: I) -> Tensor<T, B, D>
197 where
198 I: TryInto<AxesIndex<bool>, Error: Into<Error>>,
199 {
200 bool_select(self, axis, indices)
201 }
202}
203
#[cfg(test)]
mod test {
    use super::*;

    /// Asserts that `actual` lies within `tol` of `expected` on BOTH sides.
    ///
    /// A plain `actual - expected < tol` check is one-sided: it accepts any
    /// value smaller than `expected`, however far off, so the absolute
    /// difference is compared instead.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "fingerprint {actual} differs from expected {expected} by more than {tol}"
        );
    }

    /// Checks `index_select` on a small tensor along every axis, including
    /// repeated and negative indices, against reference fingerprints.
    #[test]
    fn test_index_select() {
        #[cfg(not(feature = "col_major"))]
        {
            let device = DeviceCpuSerial::default();
            let a = linspace((1.0, 24.0, 24, &device)).into_shape((2, 3, 4));
            let b = a.index_select(0, [0, 0, 1, -1]);
            assert_close(fingerprint(&b), -31.94175930917264, 1e-8);
            let b = a.index_select(1, [0, 0, 1, -1]);
            assert_close(fingerprint(&b), 3.5719025258942088, 1e-8);
            let b = a.index_select(2, [0, 0, 1, -1]);
            assert_close(fingerprint(&b), -25.648600916145096, 1e-8);
        }
        #[cfg(feature = "col_major")]
        {
            // Same checks as above with the shape and axes reversed.
            let device = DeviceCpuSerial::default();
            let a = linspace((1.0, 24.0, 24, &device)).into_shape((4, 3, 2));
            let b = a.index_select(2, [0, 0, 1, -1]);
            assert_close(fingerprint(&b), -31.94175930917264, 1e-8);
            let b = a.index_select(1, [0, 0, 1, -1]);
            assert_close(fingerprint(&b), 3.5719025258942088, 1e-8);
            let b = a.index_select(0, [0, 0, 1, -1]);
            assert_close(fingerprint(&b), -25.648600916145096, 1e-8);
        }
    }

    /// Checks `index_select` on a large tensor with the default CPU device.
    #[test]
    fn test_index_select_default_device() {
        #[cfg(not(feature = "col_major"))]
        {
            let device = DeviceCpu::default();
            let a = linspace((1.0, 2.0, 256 * 256 * 256, &device)).into_shape((256, 256, 256));
            let sel = [1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233];
            let b = a.index_select(0, &sel);
            assert_close(fingerprint(&b), 0.9357016252766746, 1e-10);
            let b = a.index_select(1, &sel);
            assert_close(fingerprint(&b), 1.012193909979973, 1e-10);
            let b = a.index_select(2, &sel);
            assert_close(fingerprint(&b), 1.010735112247236, 1e-10);
        }
        #[cfg(feature = "col_major")]
        {
            // Same checks as above with the axes visited in reverse order.
            let device = DeviceCpu::default();
            let a = linspace((1.0, 2.0, 256 * 256 * 256, &device)).into_shape((256, 256, 256));
            let sel = [1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233];
            let b = a.index_select(2, &sel);
            assert_close(fingerprint(&b), 0.9357016252766746, 1e-10);
            let b = a.index_select(1, &sel);
            assert_close(fingerprint(&b), 1.012193909979973, 1e-10);
            let b = a.index_select(0, &sel);
            assert_close(fingerprint(&b), 1.010735112247236, 1e-10);
        }
    }

    /// Smoke test: `bool_select` with a negative axis runs and its result
    /// can be printed. No values are asserted.
    #[test]
    fn test_bool_select_workable() {
        let a = arange(24).into_shape((2, 3, 4));
        let b = a.bool_select(-2, [true, false, true]);
        println!("{b:?}");
    }
}