Skip to main content

rstsr_core/tensor/
adv_indexing.rs

//! Advanced-indexing-related tensor manipulations.
//!
//! Full advanced indexing is not supported yet. However, a single axis can
//! still be indexed, either by a list of (possibly negative) integer indices
//! or by a boolean mask.
5
6use crate::prelude_dev::*;
7
8/* #region index_select */
9
10pub fn index_select_f<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, axis: isize, indices: I) -> Result<Tensor<T, B, D>>
11where
12    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
13    D: DimAPI + DimSmallerOneAPI,
14    D::SmallerOne: DimAPI,
15    B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
16    I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
17{
18    // TODO: output layout control (TensorIterOrder::K or default layout)
19    let device = tensor.device().clone();
20    let tensor_layout = tensor.layout();
21    let ndim = tensor_layout.ndim();
22    // check axis and index
23    let axis = if axis < 0 { ndim as isize + axis } else { axis };
24    rstsr_pattern!(axis, 0..ndim as isize, InvalidLayout, "Invalid axis that exceeds ndim.")?;
25    let axis = axis as usize;
26    let nshape: usize = tensor_layout.shape()[axis];
27    let indices = indices.try_into().map_err(Into::into)?;
28    let indices = indices
29        .as_ref()
30        .iter()
31        .map(|&i| -> Result<usize> {
32            let i = if i < 0 { nshape as isize + i } else { i };
33            rstsr_pattern!(
34                i,
35                0..nshape as isize,
36                InvalidLayout,
37                "Invalid index that exceeds shape length at axis {}.",
38                axis
39            )?;
40            Ok(i as usize)
41        })
42        .collect::<Result<Vec<usize>>>()?;
43    let mut out_shape = tensor_layout.shape().as_ref().to_vec();
44    out_shape[axis] = indices.len();
45    let out_layout = out_shape.new_contig(None, device.default_order()).into_dim()?;
46    let mut out_storage = device.uninit_impl(out_layout.size())?;
47    device.index_select(out_storage.raw_mut(), &out_layout, tensor.storage().raw(), tensor_layout, axis, &indices)?;
48    let out_storage = unsafe { B::assume_init_impl(out_storage)? };
49    TensorBase::new_f(out_storage, out_layout)
50}
51
52/// Returns a new tensor, which indexes the input tensor along dimension `axis`
53/// using the entries in `indices`.
54///
55/// # See also
56///
57/// This function should be similar to PyTorch's [`torch.index_select`](https://docs.pytorch.org/docs/stable/generated/torch.index_select.html).
58pub fn index_select<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, axis: isize, indices: I) -> Tensor<T, B, D>
59where
60    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
61    D: DimAPI + DimSmallerOneAPI,
62    D::SmallerOne: DimAPI,
63    B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
64    I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
65{
66    index_select_f(tensor, axis, indices).rstsr_unwrap()
67}
68
69pub fn take_f<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, indices: I, axis: isize) -> Result<Tensor<T, B, D>>
70where
71    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
72    D: DimAPI + DimSmallerOneAPI,
73    D::SmallerOne: DimAPI,
74    B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
75    I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
76{
77    index_select_f(tensor, axis, indices)
78}
79
80/// Take elements from an array along an axis.
81///
82/// # See also
83///
84/// [Python Array API standard: take](https://data-apis.org/array-api/latest/API_specification/generated/array_api.take.html#array_api.take)
85pub fn take<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, indices: I, axis: isize) -> Tensor<T, B, D>
86where
87    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
88    D: DimAPI + DimSmallerOneAPI,
89    D::SmallerOne: DimAPI,
90    B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
91    I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
92{
93    index_select(tensor, axis, indices)
94}
95
96impl<R, T, B, D> TensorAny<R, T, B, D>
97where
98    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
99    D: DimAPI + DimSmallerOneAPI,
100    D::SmallerOne: DimAPI,
101    B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
102{
103    pub fn index_select_f<I>(&self, axis: isize, indices: I) -> Result<Tensor<T, B, D>>
104    where
105        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
106    {
107        index_select_f(self, axis, indices)
108    }
109
110    /// Returns a new tensor, which indexes the input tensor along dimension
111    /// `axis` using the entries in `indices`.
112    ///
113    /// # See also
114    ///
115    /// This function should be similar to PyTorch's [`torch.index_select`](https://docs.pytorch.org/docs/stable/generated/torch.index_select.html).
116    pub fn index_select<I>(&self, axis: isize, indices: I) -> Tensor<T, B, D>
117    where
118        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
119    {
120        index_select(self, axis, indices)
121    }
122
123    pub fn take_f<I>(&self, indices: I, axis: isize) -> Result<Tensor<T, B, D>>
124    where
125        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
126    {
127        take_f(self, indices, axis)
128    }
129
130    /// Take elements from an array along an axis.
131    ///
132    /// # See also
133    ///
134    /// [Python Array API standard: take](https://data-apis.org/array-api/latest/API_specification/generated/array_api.take.html#array_api.take)
135    pub fn take<I>(&self, indices: I, axis: isize) -> Tensor<T, B, D>
136    where
137        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
138    {
139        take(self, indices, axis)
140    }
141}
142
143/* #endregion */
144
145/* #region bool_select */
146
147pub fn bool_select_f<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, axis: isize, mask: I) -> Result<Tensor<T, B, D>>
148where
149    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
150    D: DimAPI + DimSmallerOneAPI,
151    D::SmallerOne: DimAPI,
152    B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
153    I: TryInto<AxesIndex<bool>, Error: Into<Error>>,
154{
155    // transform bool to index
156    let indices = mask
157        .try_into()
158        .map_err(Into::into)?
159        .as_ref()
160        .iter()
161        .enumerate()
162        .filter_map(|(i, &m)| m.then_some(i))
163        .collect::<Vec<usize>>();
164    index_select_f(tensor, axis, indices)
165}
166
167/// Returns a new tensor, which indexes the input tensor along dimension `axis`
168/// using the boolean entries in `mask`.
169pub fn bool_select<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, axis: isize, mask: I) -> Tensor<T, B, D>
170where
171    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
172    D: DimAPI + DimSmallerOneAPI,
173    D::SmallerOne: DimAPI,
174    B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
175    I: TryInto<AxesIndex<bool>, Error: Into<Error>>,
176{
177    bool_select_f(tensor, axis, mask).rstsr_unwrap()
178}
179
180impl<R, T, B, D> TensorAny<R, T, B, D>
181where
182    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
183    D: DimAPI + DimSmallerOneAPI,
184    D::SmallerOne: DimAPI,
185    B: DeviceAPI<T> + DeviceIndexSelectAPI<T, D> + DeviceCreationAnyAPI<T>,
186{
187    pub fn bool_select_f<I>(&self, axis: isize, indices: I) -> Result<Tensor<T, B, D>>
188    where
189        I: TryInto<AxesIndex<bool>, Error: Into<Error>>,
190    {
191        bool_select_f(self, axis, indices)
192    }
193
194    /// Returns a new tensor, which indexes the input tensor along dimension
195    /// `axis` using the boolean entries in `mask`.
196    pub fn bool_select<I>(&self, axis: isize, indices: I) -> Tensor<T, B, D>
197    where
198        I: TryInto<AxesIndex<bool>, Error: Into<Error>>,
199    {
200        bool_select(self, axis, indices)
201    }
202}
203
204/* #endregion */
205
#[cfg(test)]
mod test {
    use super::*;

    // Fingerprint checks compare the *absolute* difference with the expected
    // value. A bare `x - expected < tol` would also pass for any `x` far below
    // `expected`, so it would not detect regressions.

    #[test]
    fn test_index_select() {
        #[cfg(not(feature = "col_major"))]
        {
            let device = DeviceCpuSerial::default();
            let a = linspace((1.0, 24.0, 24, &device)).into_shape((2, 3, 4));
            let b = a.index_select(0, [0, 0, 1, -1]);
            assert!((fingerprint(&b) - -31.94175930917264).abs() < 1e-8);
            let b = a.index_select(1, [0, 0, 1, -1]);
            assert!((fingerprint(&b) - 3.5719025258942088).abs() < 1e-8);
            let b = a.index_select(2, [0, 0, 1, -1]);
            assert!((fingerprint(&b) - -25.648600916145096).abs() < 1e-8);
        }
        #[cfg(feature = "col_major")]
        {
            let device = DeviceCpuSerial::default();
            let a = linspace((1.0, 24.0, 24, &device)).into_shape((4, 3, 2));
            let b = a.index_select(2, [0, 0, 1, -1]);
            assert!((fingerprint(&b) - -31.94175930917264).abs() < 1e-8);
            let b = a.index_select(1, [0, 0, 1, -1]);
            assert!((fingerprint(&b) - 3.5719025258942088).abs() < 1e-8);
            let b = a.index_select(0, [0, 0, 1, -1]);
            assert!((fingerprint(&b) - -25.648600916145096).abs() < 1e-8);
        }
    }

    #[test]
    fn test_index_select_default_device() {
        #[cfg(not(feature = "col_major"))]
        {
            let device = DeviceCpu::default();
            let a = linspace((1.0, 2.0, 256 * 256 * 256, &device)).into_shape((256, 256, 256));
            let sel = [1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233];
            let b = a.index_select(0, &sel);
            assert!((fingerprint(&b) - 0.9357016252766746).abs() < 1e-10);
            let b = a.index_select(1, &sel);
            assert!((fingerprint(&b) - 1.012193909979973).abs() < 1e-10);
            let b = a.index_select(2, &sel);
            assert!((fingerprint(&b) - 1.010735112247236).abs() < 1e-10);
        }
        #[cfg(feature = "col_major")]
        {
            let device = DeviceCpu::default();
            let a = linspace((1.0, 2.0, 256 * 256 * 256, &device)).into_shape((256, 256, 256));
            let sel = [1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233];
            let b = a.index_select(2, &sel);
            assert!((fingerprint(&b) - 0.9357016252766746).abs() < 1e-10);
            let b = a.index_select(1, &sel);
            assert!((fingerprint(&b) - 1.012193909979973).abs() < 1e-10);
            let b = a.index_select(0, &sel);
            assert!((fingerprint(&b) - 1.010735112247236).abs() < 1e-10);
        }
    }

    #[test]
    fn test_bool_select_workable() {
        let a = arange(24).into_shape((2, 3, 4));
        // Axis -2 resolves to axis 1 (length 3). The mask keeps positions 0 and 2.
        let b = a.bool_select(-2, [true, false, true]);
        println!("{b:?}");
        let shape = b.layout().shape();
        assert_eq!(shape[0], 2);
        assert_eq!(shape[1], 2);
        assert_eq!(shape[2], 4);
    }
}