// rstsr_common/layout/indexer.rs

use crate::prelude_dev::*;

3#[non_exhaustive]
4#[derive(Debug, Clone, PartialEq, Eq)]
5pub enum Indexer {
6    /// Slice the tensor by a range, denoted by slice instead of
7    /// std::ops::Range.
8    Slice(SliceI),
9    /// Marginalize one dimension out by index.
10    Select(isize),
11    /// Insert dimension at index, something like unsqueeze. Currently not
12    /// applied.
13    Insert,
14    /// Expand dimensions.
15    Ellipsis,
16}
17
18pub use Indexer::Ellipsis;
19pub use Indexer::Insert as NewAxis;
20
21/* #region into Indexer */
22
23impl<R> From<R> for Indexer
24where
25    R: Into<SliceI>,
26{
27    fn from(slice: R) -> Self {
28        Self::Slice(slice.into())
29    }
30}
31
32impl From<Option<usize>> for Indexer {
33    fn from(opt: Option<usize>) -> Self {
34        match opt {
35            Some(_) => panic!("Option<T> should not be used in Indexer."),
36            None => Self::Insert,
37        }
38    }
39}
40
41macro_rules! impl_from_int_into_indexer {
42    ($($t:ty),*) => {
43        $(
44            impl From<$t> for Indexer {
45                fn from(index: $t) -> Self {
46                    Self::Select(index as isize)
47                }
48            }
49        )*
50    };
51}
52
53impl_from_int_into_indexer!(usize, isize, u32, i32, u64, i64);
54
55/* #endregion */
56
57/* #region into AxesIndex<Indexer> */
58
59macro_rules! impl_into_axes_index {
60    ($($t:ty),*) => {
61        $(
62            impl From<$t> for AxesIndex<Indexer> {
63                fn from(index: $t) -> Self {
64                    AxesIndex::Val(index.into())
65                }
66            }
67
68            impl<const N: usize> From<[$t; N]> for AxesIndex<Indexer> {
69                fn from(index: [$t; N]) -> Self {
70                    let index = index.iter().map(|v| v.clone().into()).collect::<Vec<_>>();
71                    AxesIndex::Vec(index)
72                }
73            }
74
75            impl From<Vec<$t>> for AxesIndex<Indexer> {
76                fn from(index: Vec<$t>) -> Self {
77                    let index = index.iter().map(|v| v.clone().into()).collect::<Vec<_>>();
78                    AxesIndex::Vec(index)
79                }
80            }
81        )*
82    };
83}
84
85impl_into_axes_index!(usize, isize, u32, i32, u64, i64);
86impl_into_axes_index!(Option<usize>);
87impl_into_axes_index!(
88    Slice<isize>,
89    core::ops::Range<isize>,
90    core::ops::RangeFrom<isize>,
91    core::ops::RangeTo<isize>,
92    core::ops::Range<usize>,
93    core::ops::RangeFrom<usize>,
94    core::ops::RangeTo<usize>,
95    core::ops::Range<i32>,
96    core::ops::RangeFrom<i32>,
97    core::ops::RangeTo<i32>,
98    core::ops::RangeFull
99);
100
101impl_from_tuple_to_axes_index!(Indexer);
102
103/* #endregion */
104
105pub trait IndexerPreserveAPI: Sized {
106    /// Narrowing tensor by slicing at a specific axis.
107    fn dim_narrow(&self, axis: isize, slice: SliceI) -> Result<Self>;
108}
109
110impl<D> IndexerPreserveAPI for Layout<D>
111where
112    D: DimDevAPI,
113{
114    fn dim_narrow(&self, axis: isize, slice: SliceI) -> Result<Self> {
115        // dimension check
116        let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
117        rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
118        let axis = axis as usize;
119
120        // get essential information
121        let mut shape = self.shape().clone();
122        let mut stride = self.stride().clone();
123
124        // fast return if slice is empty
125        if slice == Slice::new(None, None, None) {
126            return Ok(self.clone());
127        }
128
129        // previous shape length
130        let len_prev = shape[axis] as isize;
131
132        // handle cases of step > 0 and step < 0
133        let step = slice.step().unwrap_or(1);
134        rstsr_assert!(step != 0, InvalidValue)?;
135
136        // quick return if previous shape is zero
137        if len_prev == 0 {
138            return Ok(self.clone());
139        }
140
141        if step > 0 {
142            // default start = 0 and stop = len_prev
143            let mut start = slice.start().unwrap_or(0);
144            let mut stop = slice.stop().unwrap_or(len_prev);
145
146            // handle negative slice
147            if start < 0 {
148                start = (len_prev + start).max(0);
149            }
150            if stop < 0 {
151                stop = (len_prev + stop).max(0);
152            }
153
154            if start > len_prev || start > stop {
155                // zero size slice caused by inproper start and stop
156                start = 0;
157                stop = 0;
158            } else if stop > len_prev {
159                // stop is out of bound, set it to len_prev
160                stop = len_prev;
161            }
162
163            let offset = (self.offset() as isize + stride[axis] * start) as usize;
164            shape[axis] = ((stop - start + step - 1) / step).max(0) as usize;
165            stride[axis] *= step;
166            return Self::new(shape, stride, offset);
167        } else {
168            // step < 0
169            // default start = len_prev - 1 and stop = -1
170            let mut start = slice.start().unwrap_or(len_prev - 1);
171            let mut stop = slice.stop().unwrap_or(-1);
172
173            // handle negative slice
174            if start < 0 {
175                start = (len_prev + start).max(0);
176            }
177            if stop < -1 {
178                stop = (len_prev + stop).max(-1);
179            }
180
181            if stop > len_prev - 1 || stop > start {
182                // zero size slice caused by inproper start and stop
183                start = 0;
184                stop = 0;
185            } else if start > len_prev - 1 {
186                // start is out of bound, set it to len_prev
187                start = len_prev - 1;
188            }
189
190            let offset = (self.offset() as isize + stride[axis] * start) as usize;
191            shape[axis] = ((stop - start + step + 1) / step).max(0) as usize;
192            stride[axis] *= step;
193            return Self::new(shape, stride, offset);
194        }
195    }
196}
197
198pub trait IndexerSmallerOneAPI {
199    type DOut: DimDevAPI;
200
201    /// Select dimension at index. Number of dimension will decrease by 1.
202    fn dim_select(&self, axis: isize, index: isize) -> Result<Layout<Self::DOut>>;
203
204    /// Eliminate dimension at index. Number of dimension will decrease by 1.
205    fn dim_eliminate(&self, axis: isize) -> Result<Layout<Self::DOut>>;
206}
207
208impl<D> IndexerSmallerOneAPI for Layout<D>
209where
210    D: DimDevAPI + DimSmallerOneAPI,
211    D::SmallerOne: DimDevAPI,
212{
213    type DOut = <D as DimSmallerOneAPI>::SmallerOne;
214
215    fn dim_select(&self, axis: isize, index: isize) -> Result<Layout<Self::DOut>> {
216        // dimension check
217        let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
218        rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
219        let axis = axis as usize;
220
221        // get essential information
222        let shape = self.shape();
223        let stride = self.stride();
224        let mut offset = self.offset() as isize;
225        let mut shape_new = vec![];
226        let mut stride_new = vec![];
227
228        // change everything
229        for (i, (&d, &s)) in shape.as_ref().iter().zip(stride.as_ref().iter()).enumerate() {
230            if i == axis {
231                // dimension to be selected
232                let idx = if index < 0 { d as isize + index } else { index };
233                rstsr_pattern!(idx, 0..d as isize, ValueOutOfRange)?;
234                offset += s * idx;
235            } else {
236                // other dimensions
237                shape_new.push(d);
238                stride_new.push(s);
239            }
240        }
241
242        let offset = offset as usize;
243        let layout = Layout::<IxD>::new(shape_new, stride_new, offset)?;
244        return layout.into_dim();
245    }
246
247    fn dim_eliminate(&self, axis: isize) -> Result<Layout<Self::DOut>> {
248        // dimension check
249        let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
250        rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
251        let axis = axis as usize;
252
253        // get essential information
254        let mut shape = self.shape().as_ref().to_vec();
255        let mut stride = self.stride().as_ref().to_vec();
256        let offset = self.offset();
257
258        if shape[axis] != 1 {
259            rstsr_raise!(InvalidValue, "Dimension to be eliminated is not 1.")?;
260        }
261
262        shape.remove(axis);
263        stride.remove(axis);
264
265        let layout = Layout::<IxD>::new(shape, stride, offset)?;
266        return layout.into_dim();
267    }
268}
269
270pub trait IndexerLargerOneAPI {
271    type DOut: DimDevAPI;
272
273    /// Insert dimension after, with shape 1. Number of dimension will increase
274    /// by 1.
275    fn dim_insert(&self, axis: isize) -> Result<Layout<Self::DOut>>;
276}
277
278impl<D> IndexerLargerOneAPI for Layout<D>
279where
280    D: DimDevAPI + DimLargerOneAPI,
281    D::LargerOne: DimDevAPI,
282{
283    type DOut = <D as DimLargerOneAPI>::LargerOne;
284
285    fn dim_insert(&self, axis: isize) -> Result<Layout<Self::DOut>> {
286        // dimension check
287        let axis = if axis < 0 { self.ndim() as isize + axis + 1 } else { axis };
288        rstsr_pattern!(axis, 0..(self.ndim() + 1) as isize, ValueOutOfRange)?;
289        let axis = axis as usize;
290
291        // get essential information
292        let is_f_prefer = self.f_prefer();
293        let mut shape = self.shape().as_ref().to_vec();
294        let mut stride = self.stride().as_ref().to_vec();
295        let offset = self.offset();
296
297        if is_f_prefer {
298            if axis == 0 {
299                shape.insert(0, 1);
300                stride.insert(0, 1);
301            } else {
302                shape.insert(axis, 1);
303                stride.insert(axis, stride[axis - 1]);
304            }
305        } else if axis == self.ndim() {
306            shape.push(1);
307            stride.push(1);
308        } else {
309            shape.insert(axis, 1);
310            stride.insert(axis, stride[axis]);
311        }
312
313        let layout = Layout::new(shape, stride, offset)?;
314        return layout.into_dim();
315    }
316}
317
318pub trait IndexerDynamicAPI: IndexerPreserveAPI {
319    /// Index tensor by a list of indexers.
320    fn dim_slice(&self, indexers: &[Indexer]) -> Result<Layout<IxD>>;
321
322    /// Split current layout into two layouts at axis, with offset unchanged.
323    fn dim_split_at(&self, axis: isize) -> Result<(Layout<IxD>, Layout<IxD>)>;
324
325    fn dim_split_axes(&self, axes: &[isize]) -> Result<(Layout<IxD>, Layout<IxD>)>;
326}
327
328impl<D> IndexerDynamicAPI for Layout<D>
329where
330    D: DimDevAPI,
331{
332    fn dim_slice(&self, indexers: &[Indexer]) -> Result<Layout<IxD>> {
333        // transform any layout to dynamic layout
334        let shape = self.shape().as_ref().to_vec();
335        let stride = self.stride().as_ref().to_vec();
336        let mut layout = Layout::new(shape, stride, self.offset)?;
337
338        // clone indexers to vec to make it changeable
339        let mut indexers = indexers.to_vec();
340
341        // counter for indexer
342        let mut counter_slice = 0;
343        let mut counter_select = 0;
344        let mut idx_ellipsis = None;
345        for (n, indexer) in indexers.iter().enumerate() {
346            match indexer {
347                Indexer::Slice(_) => counter_slice += 1,
348                Indexer::Select(_) => counter_select += 1,
349                Indexer::Ellipsis => match idx_ellipsis {
350                    Some(_) => rstsr_raise!(InvalidValue, "Only one ellipsis indexer allowed.")?,
351                    None => idx_ellipsis = Some(n),
352                },
353                _ => {},
354            }
355        }
356
357        // check if slice-type and select-type indexer exceed the number of dimensions
358        rstsr_pattern!(counter_slice + counter_select, 0..=self.ndim(), ValueOutOfRange)?;
359
360        // insert Ellipsis by slice(:) anyway, default append at last
361        let n_ellipsis = self.ndim() - counter_slice - counter_select;
362        if n_ellipsis == 0 {
363            if let Some(idx) = idx_ellipsis {
364                indexers.remove(idx);
365            }
366        } else if let Some(idx_ellipsis) = idx_ellipsis {
367            indexers[idx_ellipsis] = SliceI::new(None, None, None).into();
368            if n_ellipsis > 1 {
369                for _ in 1..n_ellipsis {
370                    indexers.insert(idx_ellipsis, SliceI::new(None, None, None).into());
371                }
372            }
373        } else {
374            for _ in 0..n_ellipsis {
375                indexers.push(SliceI::new(None, None, None).into());
376            }
377        }
378
379        // handle indexers from last
380        // it is possible to be zero-dim, minus after -= 1
381        let mut cur_dim = self.ndim() as isize;
382        for indexer in indexers.iter().rev() {
383            match indexer {
384                Indexer::Slice(slice) => {
385                    cur_dim -= 1;
386                    layout = layout.dim_narrow(cur_dim, *slice)?;
387                },
388                Indexer::Select(index) => {
389                    cur_dim -= 1;
390                    layout = layout.dim_select(cur_dim, *index)?;
391                },
392                Indexer::Insert => {
393                    layout = layout.dim_insert(cur_dim)?;
394                },
395                _ => rstsr_raise!(InvalidValue, "Invalid indexer found : {:?}", indexer)?,
396            }
397        }
398
399        // this program should be designed that cur_dim is zero at the end
400        rstsr_assert!(cur_dim == 0, Miscellaneous, "Internal program error in indexer.")?;
401
402        return Ok(layout);
403    }
404
405    fn dim_split_at(&self, axis: isize) -> Result<(Layout<IxD>, Layout<IxD>)> {
406        // dimension check
407        // this functions allows [-n, n], not previous functions [-n, n)
408        let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
409        rstsr_pattern!(axis, 0..=self.ndim() as isize, ValueOutOfRange)?;
410        let axis = axis as usize;
411
412        // split layouts
413        let shape = self.shape().as_ref().to_vec();
414        let stride = self.stride().as_ref().to_vec();
415        let offset = self.offset();
416
417        let (shape1, shape2) = shape.split_at(axis);
418        let (stride1, stride2) = stride.split_at(axis);
419
420        let layout1 = unsafe { Layout::new_unchecked(shape1.to_vec(), stride1.to_vec(), offset) };
421        let layout2 = unsafe { Layout::new_unchecked(shape2.to_vec(), stride2.to_vec(), offset) };
422        return Ok((layout1, layout2));
423    }
424
425    fn dim_split_axes(&self, axes: &[isize]) -> Result<(Layout<IxD>, Layout<IxD>)> {
426        // returned layouts will be
427        // (layout_axes, layout_rest)
428
429        // axes to usize
430        let mut axes_update: Vec<usize> = vec![];
431        for &axis in axes {
432            let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
433            rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
434            axes_update.push(axis as usize);
435        }
436
437        // check same values
438        let axes_check = axes_update.clone();
439        axes_update.sort();
440        axes_update.dedup();
441        rstsr_assert_eq!(
442            axes_update.len(),
443            axes_check.len(),
444            InvalidLayout,
445            "Same axis is not allowed for this function."
446        )?;
447
448        // rest of axes
449        // this is not the most efficient way, but low cost when dimension is small
450        let axes_rest =
451            (0..self.ndim()).filter(|&axis| !axes_update.contains(&axis)).collect::<Vec<_>>();
452
453        // split layouts for axes
454        let offset = self.offset();
455        let shape_axes = axes_update.iter().map(|&axis| self.shape()[axis]).collect::<Vec<_>>();
456        let strides_axes = axes_update.iter().map(|&axis| self.stride()[axis]).collect::<Vec<_>>();
457        let layout_axes = Layout::new(shape_axes, strides_axes, offset)?;
458
459        let shape_rest = axes_rest.iter().map(|&axis| self.shape()[axis]).collect::<Vec<_>>();
460        let strides_rest = axes_rest.iter().map(|&axis| self.stride()[axis]).collect::<Vec<_>>();
461        let layout_rest = Layout::new(shape_rest, strides_rest, offset)?;
462
463        return Ok((layout_axes, layout_rest));
464    }
465}
466
/// Build a `Slice<isize>` from `(stop)`, `(start, stop)` or
/// `(start, stop, step)`.
///
/// Arguments are passed straight to `Slice::new`, so `None` or any value that
/// constructor accepts may be used. The result is converted into
/// `Slice<isize>`. Omitted `start`/`step` are `None`.
#[macro_export]
macro_rules! slice {
    ($stop:expr) => {
        $crate::slice!(None, $stop, None)
    };
    ($start:expr, $stop:expr) => {
        $crate::slice!($start, $stop, None)
    };
    ($start:expr, $stop:expr, $step:expr) => {{
        use $crate::layout::slice::Slice;
        Slice::<isize>::from(Slice::new($start, $stop, $step))
    }};
}

/// Build an indexer slice (`&[T]`, typically `&[Indexer]`) from a
/// comma-separated list.
///
/// Each item is converted with `.into()`, so anything with a matching `Into`
/// conversion is accepted. A trailing comma is allowed. The returned slice
/// borrows a temporary array, so use it directly as an argument.
#[macro_export]
macro_rules! s {
    // basic rule (optional trailing comma)
    [$($slc:expr),* $(,)?] => {
        [$(($slc).into()),*].as_ref()
    };
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_slice() {
        let t = 3_usize;
        let s = slice!(1, 2, t);
        assert_eq!(s.start(), Some(1));
        assert_eq!(s.stop(), Some(2));
        assert_eq!(s.step(), Some(3));
    }

    #[test]
    fn test_slice_at_dim() {
        let l = Layout::new([2, 3, 4], [1, 10, 100], 0).unwrap();
        let s = slice!(10, 1, -1);
        let l1 = l.dim_narrow(1, s).unwrap();
        println!("{:?}", l1);
        // start 10 clamps to the last index 2; stop 1 is exclusive -> one element
        assert_eq!(l1.shape(), &[2, 1, 4]);
        assert_eq!(l1.offset(), 20);

        let l2 = l.dim_select(1, -2).unwrap();
        println!("{:?}", l2);
        // index -2 on an axis of length 3 is position 1 -> offset 1 * 10
        assert_eq!(l2.shape(), &[2, 4]);
        assert_eq!(l2.offset(), 10);

        let l3 = l.dim_insert(1).unwrap();
        println!("{:?}", l3);
        assert_eq!(l3.shape(), &[2, 1, 3, 4]);

        let l = Layout::new([2, 3, 4], [100, 10, 1], 0).unwrap();
        let l3 = l.dim_insert(1).unwrap();
        println!("{:?}", l3);
        assert_eq!(l3.shape(), &[2, 1, 3, 4]);

        let l4 = l.dim_slice(s![Indexer::Ellipsis, 1..3, None, 2]).unwrap();
        let l4 = l4.into_dim::<Ix3>().unwrap();
        println!("{:?}", l4);
        assert_eq!(l4.shape(), &[2, 2, 1]);
        assert_eq!(l4.offset(), 12);

        let l5 = l.dim_slice(s![None, 1, None, 1..3]).unwrap();
        let l5 = l5.into_dim::<Ix4>().unwrap();
        println!("{:?}", l5);
        assert_eq!(l5.shape(), &[1, 1, 2, 4]);
        assert_eq!(l5.offset(), 110);
    }

    #[test]
    fn test_slice_with_stride() {
        let l = Layout::new([24], [1], 0).unwrap();
        let b = l.dim_narrow(0, slice!(5, 15, 2)).unwrap();
        assert_eq!(b, Layout::new([5], [2], 5).unwrap());
        let b = l.dim_narrow(0, slice!(5, 16, 2)).unwrap();
        assert_eq!(b, Layout::new([6], [2], 5).unwrap());
        let b = l.dim_narrow(0, slice!(15, 5, -2)).unwrap();
        assert_eq!(b, Layout::new([5], [-2], 15).unwrap());
        let b = l.dim_narrow(0, slice!(15, 4, -2)).unwrap();
        assert_eq!(b, Layout::new([6], [-2], 15).unwrap());
        // start past stop with a positive step: empty axis, offset unchanged
        let b = l.dim_narrow(0, slice!(20, 10, 1)).unwrap();
        assert_eq!(b, Layout::new([0], [1], 0).unwrap());
    }

    #[test]
    fn test_expand_dims() {
        let l = Layout::<Ix3>::new([2, 3, 4], [1, 10, 100], 0).unwrap();
        let l1 = l.dim_insert(0).unwrap();
        println!("{:?}", l1);
        assert_eq!(l1.shape(), &[1, 2, 3, 4]);
        let l2 = l.dim_insert(1).unwrap();
        println!("{:?}", l2);
        assert_eq!(l2.shape(), &[2, 1, 3, 4]);
        let l3 = l.dim_insert(3).unwrap();
        println!("{:?}", l3);
        assert_eq!(l3.shape(), &[2, 3, 4, 1]);
        // negative axes count from the end of the result: -1 appends
        let l4 = l.dim_insert(-1).unwrap();
        println!("{:?}", l4);
        assert_eq!(l4.shape(), &[2, 3, 4, 1]);
        let l5 = l.dim_insert(-4).unwrap();
        println!("{:?}", l5);
        assert_eq!(l5.shape(), &[1, 2, 3, 4]);
        // out of [-(ndim + 1), ndim]
        assert!(l.dim_insert(4).is_err());
        assert!(l.dim_insert(-5).is_err());
    }
}