// rstsr_common/layout/indexer.rs

1use crate::prelude_dev::*;
2
3#[non_exhaustive]
4#[derive(Debug, Clone, PartialEq, Eq)]
5pub enum Indexer {
6    /// Slice the tensor by a range, denoted by slice instead of
7    /// std::ops::Range.
8    Slice(SliceI),
9    /// Marginalize one dimension out by index.
10    Select(isize),
11    /// Insert dimension at index, something like unsqueeze. Currently not
12    /// applied.
13    Insert,
14    /// Expand dimensions.
15    Ellipsis,
16}
17
18pub use Indexer::Ellipsis;
19pub use Indexer::Insert as NewAxis;
20
21/* #region into Indexer */
22
23impl<R> From<R> for Indexer
24where
25    R: Into<SliceI>,
26{
27    fn from(slice: R) -> Self {
28        Self::Slice(slice.into())
29    }
30}
31
32impl From<Option<usize>> for Indexer {
33    fn from(opt: Option<usize>) -> Self {
34        match opt {
35            Some(_) => panic!("Option<T> should not be used in Indexer."),
36            None => Self::Insert,
37        }
38    }
39}
40
41macro_rules! impl_from_int_into_indexer {
42    ($($t:ty),*) => {
43        $(
44            impl From<$t> for Indexer {
45                fn from(index: $t) -> Self {
46                    Self::Select(index as isize)
47                }
48            }
49        )*
50    };
51}
52
53impl_from_int_into_indexer!(usize, isize, u32, i32, u64, i64);
54
55/* #endregion */
56
57/* #region into AxesIndex<Indexer> */
58
59macro_rules! impl_into_axes_index {
60    ($($t:ty),*) => {
61        $(
62            impl TryFrom<$t> for AxesIndex<Indexer> {
63                type Error = Error;
64
65                fn try_from(index: $t) -> Result<Self> {
66                    Ok(AxesIndex::Val(index.try_into()?))
67                }
68            }
69
70            impl<const N: usize> TryFrom<[$t; N]> for AxesIndex<Indexer> {
71                type Error = Error;
72
73                fn try_from(index: [$t; N]) -> Result<Self> {
74                    let index = index.iter().map(|v| v.clone().into()).collect::<Vec<_>>();
75                    Ok(AxesIndex::Vec(index))
76                }
77            }
78
79            impl TryFrom<Vec<$t>> for AxesIndex<Indexer> {
80                type Error = Error;
81
82                fn try_from(index: Vec<$t>) -> Result<Self> {
83                    let index = index.iter().map(|v| v.clone().into()).collect::<Vec<_>>();
84                    Ok(AxesIndex::Vec(index))
85                }
86            }
87        )*
88    };
89}
90
91impl_into_axes_index!(usize, isize, u32, i32, u64, i64);
92impl_into_axes_index!(Option<usize>);
93impl_into_axes_index!(
94    Slice<isize>,
95    core::ops::Range<isize>,
96    core::ops::RangeFrom<isize>,
97    core::ops::RangeTo<isize>,
98    core::ops::Range<usize>,
99    core::ops::RangeFrom<usize>,
100    core::ops::RangeTo<usize>,
101    core::ops::Range<i32>,
102    core::ops::RangeFrom<i32>,
103    core::ops::RangeTo<i32>,
104    core::ops::RangeFull
105);
106
107impl_from_tuple_to_axes_index!(Indexer);
108
109/* #endregion */
110
111pub trait IndexerPreserveAPI: Sized {
112    /// Narrowing tensor by slicing at a specific axis.
113    fn dim_narrow(&self, axis: isize, slice: SliceI) -> Result<Self>;
114}
115
116impl<D> IndexerPreserveAPI for Layout<D>
117where
118    D: DimDevAPI,
119{
120    fn dim_narrow(&self, axis: isize, slice: SliceI) -> Result<Self> {
121        // dimension check
122        let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
123        rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
124        let axis = axis as usize;
125
126        // get essential information
127        let mut shape = self.shape().clone();
128        let mut stride = self.stride().clone();
129
130        // fast return if slice is empty
131        if slice == Slice::new(None, None, None) {
132            return Ok(self.clone());
133        }
134
135        // previous shape length
136        let len_prev = shape[axis] as isize;
137
138        // handle cases of step > 0 and step < 0
139        let step = slice.step().unwrap_or(1);
140        rstsr_assert!(step != 0, InvalidValue)?;
141
142        // quick return if previous shape is zero
143        if len_prev == 0 {
144            return Ok(self.clone());
145        }
146
147        if step > 0 {
148            // default start = 0 and stop = len_prev
149            let mut start = slice.start().unwrap_or(0);
150            let mut stop = slice.stop().unwrap_or(len_prev);
151
152            // handle negative slice
153            if start < 0 {
154                start = (len_prev + start).max(0);
155            }
156            if stop < 0 {
157                stop = (len_prev + stop).max(0);
158            }
159
160            if start > len_prev || start > stop {
161                // zero size slice caused by inproper start and stop
162                start = 0;
163                stop = 0;
164            } else if stop > len_prev {
165                // stop is out of bound, set it to len_prev
166                stop = len_prev;
167            }
168
169            let offset = (self.offset() as isize + stride[axis] * start) as usize;
170            shape[axis] = ((stop - start + step - 1) / step).max(0) as usize;
171            stride[axis] *= step;
172            return Self::new(shape, stride, offset);
173        } else {
174            // step < 0
175            // default start = len_prev - 1 and stop = -1
176            let mut start = slice.start().unwrap_or(len_prev - 1);
177            let mut stop = slice.stop().unwrap_or(-1);
178
179            // handle negative slice
180            if start < 0 {
181                start = (len_prev + start).max(0);
182            }
183            if stop < -1 {
184                stop = (len_prev + stop).max(-1);
185            }
186
187            if stop > len_prev - 1 || stop > start {
188                // zero size slice caused by inproper start and stop
189                start = 0;
190                stop = 0;
191            } else if start > len_prev - 1 {
192                // start is out of bound, set it to len_prev
193                start = len_prev - 1;
194            }
195
196            let offset = (self.offset() as isize + stride[axis] * start) as usize;
197            shape[axis] = ((stop - start + step + 1) / step).max(0) as usize;
198            stride[axis] *= step;
199            return Self::new(shape, stride, offset);
200        }
201    }
202}
203
204pub trait IndexerSmallerOneAPI {
205    type DOut: DimDevAPI;
206
207    /// Select dimension at index. Number of dimension will decrease by 1.
208    fn dim_select(&self, axis: isize, index: isize) -> Result<Layout<Self::DOut>>;
209
210    /// Eliminate dimension at index. Number of dimension will decrease by 1.
211    ///
212    /// Dimension to be eliminated should have shape 1, otherwise it will raise error. This is
213    /// useful for squeezeing.
214    fn dim_eliminate(&self, axis: isize) -> Result<Layout<Self::DOut>>;
215
216    /// Eliminate dimension at index (without addition checks). This may be useful to handle
217    /// zero-size axes, which is not eliminatable from dim_select(axis, 0) or dim_eliminate.
218    fn dim_chop(&self, axis: isize) -> Result<Layout<Self::DOut>>;
219}
220
221impl<D> IndexerSmallerOneAPI for Layout<D>
222where
223    D: DimDevAPI + DimSmallerOneAPI,
224    D::SmallerOne: DimDevAPI,
225{
226    type DOut = <D as DimSmallerOneAPI>::SmallerOne;
227
228    fn dim_select(&self, axis: isize, index: isize) -> Result<Layout<Self::DOut>> {
229        // dimension check
230        let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
231        rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
232        let axis = axis as usize;
233
234        // get essential information
235        let shape = self.shape();
236        let stride = self.stride();
237        let mut offset = self.offset() as isize;
238        let mut shape_new = vec![];
239        let mut stride_new = vec![];
240
241        // change everything
242        for (i, (&d, &s)) in shape.as_ref().iter().zip(stride.as_ref().iter()).enumerate() {
243            if i == axis {
244                // dimension to be selected
245                let idx = if index < 0 { d as isize + index } else { index };
246                rstsr_pattern!(idx, 0..d as isize, ValueOutOfRange)?;
247                offset += s * idx;
248            } else {
249                // other dimensions
250                shape_new.push(d);
251                stride_new.push(s);
252            }
253        }
254
255        let offset = offset as usize;
256        let layout = Layout::<IxD>::new(shape_new, stride_new, offset)?;
257        return layout.into_dim();
258    }
259
260    fn dim_eliminate(&self, axis: isize) -> Result<Layout<Self::DOut>> {
261        // dimension check
262        let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
263        rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
264        let axis = axis as usize;
265
266        // get essential information
267        let mut shape = self.shape().as_ref().to_vec();
268        let mut stride = self.stride().as_ref().to_vec();
269        let offset = self.offset();
270
271        if shape[axis] != 1 {
272            rstsr_raise!(InvalidValue, "Dimension to be eliminated is not 1.")?;
273        }
274
275        shape.remove(axis);
276        stride.remove(axis);
277
278        let layout = Layout::<IxD>::new(shape, stride, offset)?;
279        return layout.into_dim();
280    }
281
282    fn dim_chop(&self, axis: isize) -> Result<Layout<Self::DOut>> {
283        // dimension check
284        let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
285        rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
286        let axis = axis as usize;
287
288        // get essential information
289        let mut shape = self.shape().as_ref().to_vec();
290        let mut stride = self.stride().as_ref().to_vec();
291        let offset = self.offset();
292
293        shape.remove(axis);
294        stride.remove(axis);
295
296        let layout = Layout::<IxD>::new(shape, stride, offset)?;
297        return layout.into_dim();
298    }
299}
300
301pub trait IndexerLargerOneAPI {
302    type DOut: DimDevAPI;
303
304    /// Insert dimension after, with shape 1. Number of dimension will increase
305    /// by 1.
306    fn dim_insert(&self, axis: isize) -> Result<Layout<Self::DOut>>;
307}
308
309impl<D> IndexerLargerOneAPI for Layout<D>
310where
311    D: DimDevAPI + DimLargerOneAPI,
312    D::LargerOne: DimDevAPI,
313{
314    type DOut = <D as DimLargerOneAPI>::LargerOne;
315
316    fn dim_insert(&self, axis: isize) -> Result<Layout<Self::DOut>> {
317        // dimension check
318        let axis = if axis < 0 { self.ndim() as isize + axis + 1 } else { axis };
319        rstsr_pattern!(axis, 0..(self.ndim() + 1) as isize, ValueOutOfRange)?;
320        let axis = axis as usize;
321
322        // get essential information
323        let is_f_prefer = self.f_prefer();
324        let mut shape = self.shape().as_ref().to_vec();
325        let mut stride = self.stride().as_ref().to_vec();
326        let offset = self.offset();
327
328        if is_f_prefer {
329            if axis == 0 {
330                shape.insert(0, 1);
331                stride.insert(0, 1);
332            } else {
333                shape.insert(axis, 1);
334                stride.insert(axis, stride[axis - 1]);
335            }
336        } else if axis == self.ndim() {
337            shape.push(1);
338            stride.push(1);
339        } else {
340            shape.insert(axis, 1);
341            stride.insert(axis, stride[axis]);
342        }
343
344        let layout = Layout::new(shape, stride, offset)?;
345        return layout.into_dim();
346    }
347}
348
349pub trait IndexerDynamicAPI: IndexerPreserveAPI {
350    /// Index tensor by a list of indexers.
351    fn dim_slice(&self, indexers: &[Indexer]) -> Result<Layout<IxD>>;
352
353    /// Split current layout into two layouts at axis, with offset unchanged.
354    fn dim_split_at(&self, axis: isize) -> Result<(Layout<IxD>, Layout<IxD>)>;
355
356    /// Split current layout into two layouts by axes, with offset unchanged. Returned layouts will
357    /// be (layout_axes, layout_rest).
358    ///
359    /// This function is designed for reduction, to split the layout into axes to be reduced and the
360    /// rest.
361    fn dim_split_axes(&self, axes: &[isize]) -> Result<(Layout<IxD>, Layout<IxD>)>;
362}
363
364impl<D> IndexerDynamicAPI for Layout<D>
365where
366    D: DimDevAPI,
367{
368    fn dim_slice(&self, indexers: &[Indexer]) -> Result<Layout<IxD>> {
369        // transform any layout to dynamic layout
370        let shape = self.shape().as_ref().to_vec();
371        let stride = self.stride().as_ref().to_vec();
372        let mut layout = Layout::new(shape, stride, self.offset)?;
373
374        // clone indexers to vec to make it changeable
375        let mut indexers = indexers.to_vec();
376
377        // counter for indexer
378        let mut counter_slice = 0;
379        let mut counter_select = 0;
380        let mut idx_ellipsis = None;
381        for (n, indexer) in indexers.iter().enumerate() {
382            match indexer {
383                Indexer::Slice(_) => counter_slice += 1,
384                Indexer::Select(_) => counter_select += 1,
385                Indexer::Ellipsis => match idx_ellipsis {
386                    Some(_) => rstsr_raise!(InvalidValue, "Only one ellipsis indexer allowed.")?,
387                    None => idx_ellipsis = Some(n),
388                },
389                _ => {},
390            }
391        }
392
393        // check if slice-type and select-type indexer exceed the number of dimensions
394        rstsr_pattern!(counter_slice + counter_select, 0..=self.ndim(), ValueOutOfRange)?;
395
396        // insert Ellipsis by slice(:) anyway, default append at last
397        let n_ellipsis = self.ndim() - counter_slice - counter_select;
398        if n_ellipsis == 0 {
399            if let Some(idx) = idx_ellipsis {
400                indexers.remove(idx);
401            }
402        } else if let Some(idx_ellipsis) = idx_ellipsis {
403            indexers[idx_ellipsis] = SliceI::new(None, None, None).into();
404            if n_ellipsis > 1 {
405                for _ in 1..n_ellipsis {
406                    indexers.insert(idx_ellipsis, SliceI::new(None, None, None).into());
407                }
408            }
409        } else {
410            for _ in 0..n_ellipsis {
411                indexers.push(SliceI::new(None, None, None).into());
412            }
413        }
414
415        // handle indexers from last
416        // it is possible to be zero-dim, minus after -= 1
417        let mut cur_dim = self.ndim() as isize;
418        for indexer in indexers.iter().rev() {
419            match indexer {
420                Indexer::Slice(slice) => {
421                    cur_dim -= 1;
422                    layout = layout.dim_narrow(cur_dim, *slice)?;
423                },
424                Indexer::Select(index) => {
425                    cur_dim -= 1;
426                    layout = layout.dim_select(cur_dim, *index)?;
427                },
428                Indexer::Insert => {
429                    layout = layout.dim_insert(cur_dim)?;
430                },
431                _ => rstsr_raise!(InvalidValue, "Invalid indexer found : {:?}", indexer)?,
432            }
433        }
434
435        // this program should be designed that cur_dim is zero at the end
436        rstsr_assert!(cur_dim == 0, Miscellaneous, "Internal program error in indexer.")?;
437
438        return Ok(layout);
439    }
440
441    fn dim_split_at(&self, axis: isize) -> Result<(Layout<IxD>, Layout<IxD>)> {
442        // dimension check
443        // this functions allows [-n, n], not previous functions [-n, n)
444        let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
445        rstsr_pattern!(axis, 0..=self.ndim() as isize, ValueOutOfRange)?;
446        let axis = axis as usize;
447
448        // split layouts
449        let shape = self.shape().as_ref().to_vec();
450        let stride = self.stride().as_ref().to_vec();
451        let offset = self.offset();
452
453        let (shape1, shape2) = shape.split_at(axis);
454        let (stride1, stride2) = stride.split_at(axis);
455
456        let layout1 = unsafe { Layout::new_unchecked(shape1.to_vec(), stride1.to_vec(), offset) };
457        let layout2 = unsafe { Layout::new_unchecked(shape2.to_vec(), stride2.to_vec(), offset) };
458        return Ok((layout1, layout2));
459    }
460
461    fn dim_split_axes(&self, axes: &[isize]) -> Result<(Layout<IxD>, Layout<IxD>)> {
462        // returned layouts will be
463        // (layout_axes, layout_rest)
464
465        let axes_update = normalize_axes_index(axes.into(), self.ndim(), false, false)?
466            .into_iter()
467            .map(|axis| axis as usize)
468            .collect::<Vec<usize>>();
469
470        // rest of axes
471        // this is not the most efficient way, but low cost when dimension is small
472        let axes_rest = (0..self.ndim()).filter(|&axis| !axes_update.contains(&axis)).collect::<Vec<_>>();
473
474        // split layouts for axes
475        let offset = self.offset();
476        let shape_axes = axes_update.iter().map(|&axis| self.shape()[axis]).collect::<Vec<_>>();
477        let strides_axes = axes_update.iter().map(|&axis| self.stride()[axis]).collect::<Vec<_>>();
478        let layout_axes = Layout::new(shape_axes, strides_axes, offset)?;
479
480        let shape_rest = axes_rest.iter().map(|&axis| self.shape()[axis]).collect::<Vec<_>>();
481        let strides_rest = axes_rest.iter().map(|&axis| self.stride()[axis]).collect::<Vec<_>>();
482        let layout_rest = Layout::new(shape_rest, strides_rest, offset)?;
483
484        return Ok((layout_axes, layout_rest));
485    }
486}
487
/// Builds a `Slice<isize>` from one to three arguments, each passed through
/// to `Slice::new`:
///
/// - `slice!(stop)`
/// - `slice!(start, stop)`
/// - `slice!(start, stop, step)`
///
/// Components that are not given are `None`.
#[macro_export]
macro_rules! slice {
    ($stop:expr) => {
        $crate::layout::slice::Slice::<isize>::from($crate::layout::slice::Slice::new(
            None, $stop, None,
        ))
    };
    ($start:expr, $stop:expr) => {
        $crate::layout::slice::Slice::<isize>::from($crate::layout::slice::Slice::new(
            $start, $stop, None,
        ))
    };
    ($start:expr, $stop:expr, $step:expr) => {
        $crate::layout::slice::Slice::<isize>::from($crate::layout::slice::Slice::new(
            $start, $stop, $step,
        ))
    };
}
504
/// Builds a borrowed slice out of a comma-separated list of expressions,
/// calling `.into()` on every one of them.
///
/// The element type is inferred from where the result is used, which makes it
/// convenient for writing index expressions of mixed kinds:
///
/// ```ignore
/// layout.dim_slice(s![Indexer::Ellipsis, 1..3, None, 2])
/// ```
///
/// A trailing comma is accepted. The macro expands to `[...].as_ref()` on a
/// temporary array, so the result should be used directly (for instance as a
/// function argument) rather than stored in a `let` binding.
#[macro_export]
macro_rules! s {
    // comma-separated expressions, optionally followed by a trailing comma
    [$($slc:expr),* $(,)?] => {
        [$(($slc).into()),*].as_ref()
    };
}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// `slice!` accepts mixed integer types and exposes all three components.
    #[test]
    fn test_slice() {
        let step = 3_usize;
        let slc = slice!(1, 2, step);
        assert_eq!(slc.start(), Some(1));
        assert_eq!(slc.stop(), Some(2));
        assert_eq!(slc.step(), Some(3));
    }

    /// Exercises the single-axis operations and `dim_slice` on 3-D layouts.
    #[test]
    fn test_slice_at_dim() {
        // strides growing with the axis index
        let layout_f = Layout::new([2, 3, 4], [1, 10, 100], 0).unwrap();
        let narrowed = layout_f.dim_narrow(1, slice!(10, 1, -1)).unwrap();
        println!("{narrowed:?}");
        let selected = layout_f.dim_select(1, -2).unwrap();
        println!("{selected:?}");
        let inserted = layout_f.dim_insert(1).unwrap();
        println!("{inserted:?}");

        // strides shrinking with the axis index
        let layout_c = Layout::new([2, 3, 4], [100, 10, 1], 0).unwrap();
        let inserted = layout_c.dim_insert(1).unwrap();
        println!("{inserted:?}");

        // ellipsis, range, new axis and integer selection in one expression
        let sliced = layout_c.dim_slice(s![Indexer::Ellipsis, 1..3, None, 2]).unwrap();
        let sliced = sliced.into_dim::<Ix3>().unwrap();
        println!("{sliced:?}");
        assert_eq!(sliced.shape(), &[2, 2, 1]);
        assert_eq!(sliced.offset(), 12);

        // no ellipsis: the remaining axis is kept in full at the end
        let sliced = layout_c.dim_slice(s![None, 1, None, 1..3]).unwrap();
        let sliced = sliced.into_dim::<Ix4>().unwrap();
        println!("{sliced:?}");
        assert_eq!(sliced.shape(), &[1, 1, 2, 4]);
        assert_eq!(sliced.offset(), 110);
    }

    /// Strided slices, forward and backward, on a contiguous 1-D layout.
    #[test]
    fn test_slice_with_stride() {
        let layout = Layout::new([24], [1], 0).unwrap();

        // (slice, expected shape, expected stride, expected offset)
        let cases = [
            (slice!(5, 15, 2), [5], [2], 5),
            (slice!(5, 16, 2), [6], [2], 5),
            (slice!(15, 5, -2), [5], [-2], 15),
            (slice!(15, 4, -2), [6], [-2], 15),
        ];
        for (slc, shape, stride, offset) in cases {
            let narrowed = layout.dim_narrow(0, slc).unwrap();
            assert_eq!(narrowed, Layout::new(shape, stride, offset).unwrap());
        }
    }

    /// Inserting an axis at every kind of position: front, middle, end, and
    /// the negative spellings of end and front.
    #[test]
    fn test_expand_dims() {
        let layout = Layout::<Ix3>::new([2, 3, 4], [1, 10, 100], 0).unwrap();
        for axis in [0, 1, 3, -1, -4] {
            let inserted = layout.dim_insert(axis).unwrap();
            println!("{inserted:?}");
        }
    }
}