rstsr_core/tensor/
indexing.rs

1// Indexing of tensors
2
3use core::ops::{Index, IndexMut};
4
5use crate::prelude_dev::*;
6
7/* #region slice */
8
9pub fn into_slice_f<S, D, I>(tensor: TensorBase<S, D>, index: I) -> Result<TensorBase<S, IxD>>
10where
11    D: DimAPI,
12    I: TryInto<AxesIndex<Indexer>, Error = Error>,
13{
14    let (data, layout) = tensor.into_raw_parts();
15    let index = index.try_into()?;
16    let layout = layout.dim_slice(index.as_ref())?;
17    return unsafe { Ok(TensorBase::new_unchecked(data, layout)) };
18}
19
20pub fn into_slice<S, D, I>(tensor: TensorBase<S, D>, index: I) -> TensorBase<S, IxD>
21where
22    D: DimAPI,
23    I: TryInto<AxesIndex<Indexer>, Error = Error>,
24{
25    into_slice_f(tensor, index).unwrap()
26}
27
28pub fn slice_f<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, index: I) -> Result<TensorView<'_, T, B, IxD>>
29where
30    D: DimAPI,
31    I: TryInto<AxesIndex<Indexer>, Error = Error>,
32    R: DataAPI<Data = B::Raw>,
33    B: DeviceAPI<T>,
34{
35    into_slice_f(tensor.view(), index)
36}
37
38pub fn slice<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, index: I) -> TensorView<'_, T, B, IxD>
39where
40    D: DimAPI,
41    I: TryInto<AxesIndex<Indexer>, Error = Error>,
42    R: DataAPI<Data = B::Raw>,
43    B: DeviceAPI<T>,
44{
45    slice_f(tensor, index).unwrap()
46}
47
48impl<R, T, B, D> TensorAny<R, T, B, D>
49where
50    R: DataAPI<Data = B::Raw>,
51    B: DeviceAPI<T>,
52    D: DimAPI,
53{
54    pub fn into_slice_f<I>(self, index: I) -> Result<TensorAny<R, T, B, IxD>>
55    where
56        I: TryInto<AxesIndex<Indexer>, Error = Error>,
57    {
58        into_slice_f(self, index)
59    }
60
61    pub fn into_slice<I>(self, index: I) -> TensorAny<R, T, B, IxD>
62    where
63        I: TryInto<AxesIndex<Indexer>, Error = Error>,
64    {
65        into_slice(self, index)
66    }
67
68    pub fn slice_f<I>(&self, index: I) -> Result<TensorView<'_, T, B, IxD>>
69    where
70        I: TryInto<AxesIndex<Indexer>, Error = Error>,
71    {
72        slice_f(self, index)
73    }
74
75    pub fn slice<I>(&self, index: I) -> TensorView<'_, T, B, IxD>
76    where
77        I: TryInto<AxesIndex<Indexer>, Error = Error>,
78    {
79        slice(self, index)
80    }
81
82    pub fn i_f<I>(&self, index: I) -> Result<TensorView<'_, T, B, IxD>>
83    where
84        I: TryInto<AxesIndex<Indexer>, Error = Error>,
85    {
86        slice_f(self, index)
87    }
88
89    pub fn i<I>(&self, index: I) -> TensorView<'_, T, B, IxD>
90    where
91        I: TryInto<AxesIndex<Indexer>, Error = Error>,
92    {
93        slice(self, index)
94    }
95}
96
97/* #endregion */
98
99/* #region slice mut */
100
101pub fn slice_mut_f<R, T, B, D, I>(tensor: &mut TensorAny<R, T, B, D>, index: I) -> Result<TensorMut<'_, T, B, IxD>>
102where
103    D: DimAPI,
104    I: TryInto<AxesIndex<Indexer>, Error = Error>,
105    R: DataMutAPI<Data = B::Raw>,
106    B: DeviceAPI<T>,
107{
108    into_slice_f(tensor.view_mut(), index)
109}
110
111pub fn slice_mut<R, T, B, D, I>(tensor: &mut TensorAny<R, T, B, D>, index: I) -> TensorMut<'_, T, B, IxD>
112where
113    D: DimAPI,
114    I: TryInto<AxesIndex<Indexer>, Error = Error>,
115    R: DataMutAPI<Data = B::Raw>,
116    B: DeviceAPI<T>,
117{
118    slice_mut_f(tensor, index).unwrap()
119}
120
121impl<R, T, B, D> TensorAny<R, T, B, D>
122where
123    R: DataMutAPI<Data = B::Raw>,
124    B: DeviceAPI<T>,
125    D: DimAPI,
126{
127    pub fn slice_mut_f<I>(&mut self, index: I) -> Result<TensorMut<'_, T, B, IxD>>
128    where
129        I: TryInto<AxesIndex<Indexer>, Error = Error>,
130    {
131        slice_mut_f(self, index)
132    }
133
134    pub fn slice_mut<I>(&mut self, index: I) -> TensorMut<'_, T, B, IxD>
135    where
136        I: TryInto<AxesIndex<Indexer>, Error = Error>,
137    {
138        slice_mut(self, index)
139    }
140
141    pub fn i_mut_f<I>(&mut self, index: I) -> Result<TensorMut<'_, T, B, IxD>>
142    where
143        I: TryInto<AxesIndex<Indexer>, Error = Error>,
144    {
145        slice_mut_f(self, index)
146    }
147
148    pub fn i_mut<I>(&mut self, index: I) -> TensorMut<'_, T, B, IxD>
149    where
150        I: TryInto<AxesIndex<Indexer>, Error = Error>,
151    {
152        slice_mut(self, index)
153    }
154}
155
156/* #endregion */
157
158/* #region diagonal */
159
/// Arguments for taking a diagonal: an offset and the pair of axes the
/// diagonal runs across. Each field is passed unchanged to the layout's
/// `diagonal`; `None` leaves the choice to that function's default.
///
/// Usually built through one of the `From` conversions, e.g. a bare integer
/// offset, an `(offset, axis1, axis2)` tuple, or `()` for all defaults.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DiagonalArgs {
    /// Offset of the diagonal, or `None` for the default.
    pub offset: Option<isize>,
    /// First axis of the diagonal, or `None` for the default.
    pub axis1: Option<isize>,
    /// Second axis of the diagonal, or `None` for the default.
    pub axis2: Option<isize>,
}
165
166#[duplicate_item(
167    S0      S1      S2;
168   [isize] [isize] [isize];
169   [usize] [isize] [isize];
170   [usize] [usize] [usize];
171   [i32  ] [i32  ] [i32  ];
172   [i64  ] [i64  ] [i64  ];
173)]
174#[allow(clippy::unnecessary_cast)]
175impl From<(S0, S1, S2)> for DiagonalArgs {
176    fn from(args: (S0, S1, S2)) -> Self {
177        let (offset, axis1, axis2) = args;
178        Self { offset: Some(offset as isize), axis1: Some(axis1 as isize), axis2: Some(axis2 as isize) }
179    }
180}
181
182#[duplicate_item(S; [isize]; [usize]; [i32]; [i64];)]
183#[allow(clippy::unnecessary_cast)]
184impl From<S> for DiagonalArgs {
185    fn from(offset: S) -> Self {
186        Self { offset: Some(offset as isize), axis1: None, axis2: None }
187    }
188}
189
190impl From<()> for DiagonalArgs {
191    fn from(_: ()) -> Self {
192        Self { offset: None, axis1: None, axis2: None }
193    }
194}
195
196impl From<Option<isize>> for DiagonalArgs {
197    fn from(offset: Option<isize>) -> Self {
198        Self { offset, axis1: None, axis2: None }
199    }
200}
201
202pub fn into_diagonal_f<S, D>(
203    tensor: TensorBase<S, D>,
204    diagonal_args: impl Into<DiagonalArgs>,
205) -> Result<TensorBase<S, D::SmallerOne>>
206where
207    D: DimAPI + DimSmallerOneAPI,
208    D::SmallerOne: DimAPI,
209{
210    let (data, layout) = tensor.into_raw_parts();
211    let DiagonalArgs { offset, axis1, axis2 } = diagonal_args.into();
212    let layout = layout.diagonal(offset, axis1, axis2)?;
213    return unsafe { Ok(TensorBase::new_unchecked(data, layout)) };
214}
215
216pub fn into_diagonal<S, D>(
217    tensor: TensorBase<S, D>,
218    diagonal_args: impl Into<DiagonalArgs>,
219) -> TensorBase<S, D::SmallerOne>
220where
221    D: DimAPI + DimSmallerOneAPI,
222    D::SmallerOne: DimAPI,
223{
224    into_diagonal_f(tensor, diagonal_args).unwrap()
225}
226
227pub fn diagonal_f<R, T, B, D>(
228    tensor: &TensorAny<R, T, B, D>,
229    diagonal_args: impl Into<DiagonalArgs>,
230) -> Result<TensorView<'_, T, B, D::SmallerOne>>
231where
232    D: DimAPI + DimSmallerOneAPI,
233    D::SmallerOne: DimAPI,
234    R: DataAPI<Data = B::Raw>,
235    B: DeviceAPI<T>,
236{
237    into_diagonal_f(tensor.view(), diagonal_args)
238}
239
240pub fn diagonal<R, T, B, D>(
241    tensor: &TensorAny<R, T, B, D>,
242    diagonal_args: impl Into<DiagonalArgs>,
243) -> TensorView<'_, T, B, D::SmallerOne>
244where
245    D: DimAPI + DimSmallerOneAPI,
246    D::SmallerOne: DimAPI,
247    R: DataAPI<Data = B::Raw>,
248    B: DeviceAPI<T>,
249{
250    diagonal_f(tensor, diagonal_args).unwrap()
251}
252
253impl<R, T, B, D> TensorAny<R, T, B, D>
254where
255    R: DataAPI<Data = B::Raw>,
256    B: DeviceAPI<T>,
257    D: DimAPI + DimSmallerOneAPI,
258    D::SmallerOne: DimAPI,
259{
260    pub fn into_diagonal_f(self, diagonal_args: impl Into<DiagonalArgs>) -> Result<TensorAny<R, T, B, D::SmallerOne>> {
261        into_diagonal_f(self, diagonal_args)
262    }
263
264    pub fn into_diagonal(self, diagonal_args: impl Into<DiagonalArgs>) -> TensorAny<R, T, B, D::SmallerOne> {
265        into_diagonal(self, diagonal_args)
266    }
267
268    pub fn diagonal_f(&self, diagonal_args: impl Into<DiagonalArgs>) -> Result<TensorView<'_, T, B, D::SmallerOne>> {
269        diagonal_f(self, diagonal_args)
270    }
271
272    pub fn diagonal(&self, diagonal_args: impl Into<DiagonalArgs>) -> TensorView<'_, T, B, D::SmallerOne> {
273        diagonal(self, diagonal_args)
274    }
275}
276
277/* #endregion */
278
279/* #region diagonal_mut */
280
281pub fn into_diagonal_mut_f<S, D>(
282    tensor: TensorBase<S, D>,
283    diagonal_args: impl Into<DiagonalArgs>,
284) -> Result<TensorBase<S, D::SmallerOne>>
285where
286    D: DimAPI + DimSmallerOneAPI,
287    D::SmallerOne: DimAPI,
288{
289    let (data, layout) = tensor.into_raw_parts();
290    let DiagonalArgs { offset, axis1, axis2 } = diagonal_args.into();
291    let layout = layout.diagonal(offset, axis1, axis2)?;
292    return unsafe { Ok(TensorBase::new_unchecked(data, layout)) };
293}
294
295pub fn into_diagonal_mut<S, D>(
296    tensor: TensorBase<S, D>,
297    diagonal_args: impl Into<DiagonalArgs>,
298) -> TensorBase<S, D::SmallerOne>
299where
300    D: DimAPI + DimSmallerOneAPI,
301    D::SmallerOne: DimAPI,
302{
303    into_diagonal_mut_f(tensor, diagonal_args).unwrap()
304}
305
306pub fn diagonal_mut_f<R, T, B, D>(
307    tensor: &mut TensorAny<R, T, B, D>,
308    diagonal_args: impl Into<DiagonalArgs>,
309) -> Result<TensorMut<'_, T, B, D::SmallerOne>>
310where
311    D: DimAPI + DimSmallerOneAPI,
312    D::SmallerOne: DimAPI,
313    R: DataMutAPI<Data = B::Raw>,
314    B: DeviceAPI<T>,
315{
316    into_diagonal_mut_f(tensor.view_mut(), diagonal_args)
317}
318
319pub fn diagonal_mut<R, T, B, D>(
320    tensor: &mut TensorAny<R, T, B, D>,
321    diagonal_args: impl Into<DiagonalArgs>,
322) -> TensorMut<'_, T, B, D::SmallerOne>
323where
324    D: DimAPI + DimSmallerOneAPI,
325    D::SmallerOne: DimAPI,
326    R: DataMutAPI<Data = B::Raw>,
327    B: DeviceAPI<T>,
328{
329    diagonal_mut_f(tensor, diagonal_args).unwrap()
330}
331
332impl<R, T, B, D> TensorAny<R, T, B, D>
333where
334    R: DataMutAPI<Data = B::Raw>,
335    B: DeviceAPI<T>,
336    D: DimAPI + DimSmallerOneAPI,
337    D::SmallerOne: DimAPI,
338{
339    pub fn into_diagonal_mut_f(
340        self,
341        diagonal_args: impl Into<DiagonalArgs>,
342    ) -> Result<TensorAny<R, T, B, D::SmallerOne>> {
343        into_diagonal_mut_f(self, diagonal_args)
344    }
345
346    pub fn into_diagonal_mut(self, diagonal_args: impl Into<DiagonalArgs>) -> TensorAny<R, T, B, D::SmallerOne> {
347        into_diagonal_mut(self, diagonal_args)
348    }
349
350    pub fn diagonal_mut_f(
351        &mut self,
352        diagonal_args: impl Into<DiagonalArgs>,
353    ) -> Result<TensorMut<'_, T, B, D::SmallerOne>> {
354        diagonal_mut_f(self, diagonal_args)
355    }
356
357    pub fn diagonal_mut(&mut self, diagonal_args: impl Into<DiagonalArgs>) -> TensorMut<'_, T, B, D::SmallerOne> {
358        diagonal_mut(self, diagonal_args)
359    }
360}
361
362/* #endregion */
363
364/* #region indexing */
365
// Implementing `Index` generically for `TensorBase` appears to be impossible
// because of lifetime issues; implementing it separately for each concrete
// tensor type avoids the problem.
369
370#[duplicate_item(
371    TensorStruct;
372    [Tensor<T, B, D>];
373    [TensorView<'_, T, B, D>];
374    [TensorViewMut<'_, T, B, D>];
375    [TensorCow<'_, T, B, D>];
376)]
377impl<T, D, B, I> Index<I> for TensorStruct
378where
379    T: Clone,
380    D: DimAPI,
381    B: DeviceAPI<T, Raw = Vec<T>>,
382    I: AsRef<[usize]>,
383{
384    type Output = T;
385
386    #[inline]
387    fn index(&self, index: I) -> &Self::Output {
388        let index = index.as_ref().iter().map(|&v| v as isize).collect::<Vec<_>>();
389        let i = self.layout().index(index.as_ref());
390        let raw = self.raw();
391        raw.index(i)
392    }
393}
394
395#[duplicate_item(
396    TensorStruct;
397    [Tensor<T, B, D>];
398    [TensorViewMut<'_, T, B, D>];
399)]
400impl<T, D, B, I> IndexMut<I> for TensorStruct
401where
402    T: Clone,
403    D: DimAPI,
404    B: DeviceAPI<T, Raw = Vec<T>>,
405    I: AsRef<[usize]>,
406{
407    #[inline]
408    fn index_mut(&mut self, index: I) -> &mut Self::Output {
409        let index = index.as_ref().iter().map(|&v| v as isize).collect::<Vec<_>>();
410        let i = self.layout().index(index.as_ref());
411        let raw = self.raw_mut();
412        raw.index_mut(i)
413    }
414}
415
416/* #endregion */
417
418/* #region indexing (unchecked) */
419
420#[duplicate_item(
421    TensorStruct;
422    [Tensor<T, B, D>];
423    [TensorView<'_, T, B, D>];
424    [TensorViewMut<'_, T, B, D>];
425    [TensorCow<'_, T, B, D>];
426)]
427impl<T, B, D> TensorStruct
428where
429    T: Clone,
430    D: DimAPI,
431    B: DeviceAPI<T, Raw = Vec<T>>,
432{
433    /// # Safety
434    ///
435    /// This function is unsafe because it does not check the validity of the
436    /// index.
437    #[inline]
438    pub unsafe fn index_uncheck<I>(&self, index: I) -> &T
439    where
440        I: AsRef<[usize]>,
441    {
442        let index = index.as_ref();
443        let i = unsafe { self.layout().index_uncheck(index) } as usize;
444        let raw = self.raw();
445        raw.index(i)
446    }
447}
448
449#[duplicate_item(
450    TensorStruct;
451    [Tensor<T, B, D>];
452    [TensorViewMut<'_, T, B, D>];
453)]
454impl<T, B, D> TensorStruct
455where
456    T: Clone,
457    D: DimAPI,
458    B: DeviceAPI<T, Raw = Vec<T>>,
459{
460    /// # Safety
461    ///
462    /// This function is unsafe because it does not check the validity of the
463    /// index.
464    #[inline]
465    pub unsafe fn index_mut_uncheck<I>(&mut self, index: I) -> &mut T
466    where
467        I: AsRef<[usize]>,
468    {
469        let index = index.as_ref();
470        let i = unsafe { self.layout().index_uncheck(index) } as usize;
471        let raw = self.raw_mut();
472        raw.index_mut(i)
473    }
474}
475
476/* #endregion */
477
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_tensor_slice_1d() {
        let tensor = asarray(vec![1, 2, 3, 4, 5]);

        // Contiguous range: elements 1..4 of the input, i.e. [2, 3, 4].
        let tensor_slice = tensor.slice(s![1..4]);
        println!("{tensor_slice:?}");
        assert_eq!(tensor_slice[[0]], 2);
        assert_eq!(tensor_slice[[1]], 3);
        assert_eq!(tensor_slice[[2]], 4);

        // Range combined with a `None` index entry.
        let tensor_slice = tensor.slice(s![1..4, None]);
        println!("{tensor_slice:?}");

        // A bare integer index.
        let tensor_slice = tensor.slice(1);
        println!("{tensor_slice:?}");

        // Strided slice whose stop (7) lies past the end of the tensor.
        let tensor_slice = tensor.slice(slice!(2, 7, 2));
        println!("{tensor_slice:?}");

        // Writing through a mutable slice must update only the sliced range
        // of the parent tensor.
        let mut tensor = asarray(vec![1, 2, 3, 4, 5]);
        let mut tensor_slice = tensor.slice_mut(s![1..4]);
        tensor_slice += 10;
        println!("{tensor:?}");
        assert_eq!(tensor[[0]], 1);
        assert_eq!(tensor[[1]], 12);
        assert_eq!(tensor[[2]], 13);
        assert_eq!(tensor[[3]], 14);
        assert_eq!(tensor[[4]], 5);

        // Same update applied through a temporary mutable slice.
        *&mut tensor.slice_mut(s![1..4]) += 10;
        println!("{tensor:?}");
        assert_eq!(tensor[[0]], 1);
        assert_eq!(tensor[[1]], 22);
        assert_eq!(tensor[[3]], 24);
        assert_eq!(tensor[[4]], 5);
    }

    #[test]
    fn test_tensor_nd() {
        let tensor = arange(24.0).into_shape([2, 3, 4]);

        // Every element of the slice must equal the parent element at the
        // shifted position. Comparing through indexing keeps this check
        // independent of the tensor's memory order.
        let tensor_slice = tensor.slice(s![1..2, 1..3, 1..4]);
        println!("{tensor_slice:?}");
        for j in 0..2_usize {
            for k in 0..3_usize {
                assert_eq!(tensor_slice[[0, j, k]], tensor[[1, 1 + j, 1 + k]]);
            }
        }

        // An integer index on the first axis removes that axis.
        let tensor_slice = tensor.slice(s![1]);
        println!("{tensor_slice:?}");
        for j in 0..3_usize {
            for k in 0..4_usize {
                assert_eq!(tensor_slice[[j, k]], tensor[[1, j, k]]);
            }
        }
    }

    #[test]
    fn test_tensor_index() {
        let mut tensor = asarray(vec![1, 2, 3, 4, 5]);
        let value = tensor[[1]];
        println!("{value:?}");
        assert_eq!(value, 2);
        let tensor_view = tensor.view();
        let value = tensor_view[[2]];
        assert_eq!(value, 3);
        {
            let tensor_view = tensor.view();
            let value = tensor_view[[3]];
            println!("{value:?}");
            assert_eq!(value, 4);
            let mut tensor_view_mut = tensor.view_mut();
            tensor_view_mut[[4]] += 1;
            *&mut tensor_view_mut.slice_mut(4) += 1;
        }
        println!("{value:?}");
        println!("{tensor:?}");
        // Both in-place increments above target element 4 (initially 5).
        assert_eq!(tensor[[3]], 4);
        assert_eq!(tensor[[4]], 7);
    }

    #[test]
    fn test_diagonal_compiles() {
        let a = arange(24.0).into_shape([2, 3, 4]);
        // The untyped integer literal should be read as `i32` and go through
        // the integer `From` impl of `DiagonalArgs`.
        let b = a.diagonal(1);
        println!("{b:?}");
    }
}