rten_tensor/
slice_range.rs

1use smallvec::SmallVec;
2
3use std::fmt::Debug;
4use std::ops::{Range, RangeFrom, RangeFull, RangeTo};
5
/// Specifies a subset of a dimension to include when slicing a tensor or view.
///
/// Can be constructed from an index or range using `index_or_range.into()`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SliceItem {
    /// Extract a specific index from a dimension.
    ///
    /// The sliced view has one fewer dimension than the source for each
    /// dimension that is sliced with an index. If the index is negative, it
    /// counts back from the end of the dimension (eg. `-1` is the last
    /// element).
    Index(isize),

    /// Include a subset of the range of the dimension, as described by a
    /// [`SliceRange`].
    Range(SliceRange),
}
21
22impl SliceItem {
23    /// Return a SliceItem that extracts the full range of a dimension.
24    #[inline]
25    pub fn full_range() -> Self {
26        (..).into()
27    }
28
29    /// Return a SliceItem that extracts part of an axis.
30    #[inline]
31    pub fn range(start: isize, end: Option<isize>, step: isize) -> SliceItem {
32        SliceItem::Range(SliceRange::new(start, end, step))
33    }
34
35    /// Return stepped index range selected by this item from an axis with a
36    /// given size.
37    pub(crate) fn index_range(&self, dim_size: usize) -> IndexRange {
38        let range = match *self {
39            SliceItem::Range(range) => range,
40            SliceItem::Index(idx) => SliceRange::new(idx, Some(idx + 1), 1),
41        };
42        range.index_range(dim_size)
43    }
44}
45
46// This conversion exists to avoid ambiguity when slicing a tensor with a
47// numeric literal of unspecified type (eg. `tensor.slice((0, 0))`). In this
48// case it is ambiguous which `SliceItem::from` should be used, but the i32
49// case is used if it exists.
50impl From<i32> for SliceItem {
51    #[inline]
52    fn from(value: i32) -> Self {
53        SliceItem::Index(value as isize)
54    }
55}
56
57impl From<isize> for SliceItem {
58    #[inline]
59    fn from(value: isize) -> Self {
60        SliceItem::Index(value)
61    }
62}
63
64impl From<usize> for SliceItem {
65    #[inline]
66    fn from(value: usize) -> Self {
67        SliceItem::Index(value as isize)
68    }
69}
70
71impl<R> From<R> for SliceItem
72where
73    R: Into<SliceRange>,
74{
75    fn from(value: R) -> Self {
76        SliceItem::Range(value.into())
77    }
78}
79
/// Used to convert sequences of indices and/or ranges into a uniform
/// `[SliceItem]` array that can be used to slice a tensor.
///
/// This trait is implemented for:
///
///  - Individual indices and ranges (types satisfying `Into<SliceItem>`)
///  - Arrays of indices or ranges
///  - Tuples of up to four indices and/or ranges
///  - `[SliceItem]` slices
///
/// Ranges can be specified using regular Rust ranges (eg. `start..end`,
/// `start..`, `..end`, `..`) or a [`SliceRange`], which extends regular Rust
/// ranges with support for steps and for negative endpoints. Negative
/// endpoints count back from the end of the dimension, as in NumPy.
pub trait IntoSliceItems {
    /// Container holding the converted items.
    type Array: AsRef<[SliceItem]>;

    /// Convert `self` into a sequence of [`SliceItem`]s.
    fn into_slice_items(self) -> Self::Array;
}
99
100impl<'a> IntoSliceItems for &'a [SliceItem] {
101    type Array = &'a [SliceItem];
102
103    fn into_slice_items(self) -> &'a [SliceItem] {
104        self
105    }
106}
107
108impl<const N: usize, T: Into<SliceItem>> IntoSliceItems for [T; N] {
109    type Array = [SliceItem; N];
110
111    fn into_slice_items(self) -> [SliceItem; N] {
112        self.map(|x| x.into())
113    }
114}
115
116impl<T: Into<SliceItem>> IntoSliceItems for T {
117    type Array = [SliceItem; 1];
118
119    fn into_slice_items(self) -> [SliceItem; 1] {
120        [self.into()]
121    }
122}
123
124impl<T1: Into<SliceItem>> IntoSliceItems for (T1,) {
125    type Array = [SliceItem; 1];
126
127    fn into_slice_items(self) -> [SliceItem; 1] {
128        [self.0.into()]
129    }
130}
131
132impl<T1: Into<SliceItem>, T2: Into<SliceItem>> IntoSliceItems for (T1, T2) {
133    type Array = [SliceItem; 2];
134
135    fn into_slice_items(self) -> [SliceItem; 2] {
136        [self.0.into(), self.1.into()]
137    }
138}
139
140impl<T1: Into<SliceItem>, T2: Into<SliceItem>, T3: Into<SliceItem>> IntoSliceItems
141    for (T1, T2, T3)
142{
143    type Array = [SliceItem; 3];
144
145    fn into_slice_items(self) -> [SliceItem; 3] {
146        [self.0.into(), self.1.into(), self.2.into()]
147    }
148}
149
150impl<T1: Into<SliceItem>, T2: Into<SliceItem>, T3: Into<SliceItem>, T4: Into<SliceItem>>
151    IntoSliceItems for (T1, T2, T3, T4)
152{
153    type Array = [SliceItem; 4];
154
155    fn into_slice_items(self) -> [SliceItem; 4] {
156        [self.0.into(), self.1.into(), self.2.into(), self.3.into()]
157    }
158}
159
/// Dynamically sized array of [`SliceItem`]s, which avoids allocating in the
/// common case where the length is small.
///
/// Up to 5 items are stored inline before the storage spills to the heap.
pub type DynSliceItems = SmallVec<[SliceItem; 5]>;
163
164/// Convert a slice of indices into [`SliceItem`]s.
165///
166/// To convert indices of a statically known length to [`SliceItem`]s, use
167/// [`IntoSliceItems`] instead. This function is for the case when the length
168/// is not statically known, but is assumed to likely be small.
169pub fn to_slice_items<T: Clone + Into<SliceItem>>(index: &[T]) -> DynSliceItems {
170    index.iter().map(|x| x.clone().into()).collect()
171}
172
/// A range for slicing a [`Tensor`](crate::Tensor) or [`NdTensor`](crate::NdTensor).
///
/// This has two main differences from [`Range`]:
///
/// - A non-zero step between indices can be specified. The step can be negative,
///   which means that the dimension should be traversed in reverse order.
/// - The `start` and `end` indexes can also be negative, in which case they
///   count backwards from the end of the array.
///
/// This system for specifying slicing and indexing follows NumPy, which in
/// turn strongly influenced slicing in ONNX.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SliceRange {
    /// First index (inclusive) in the range. Negative values count back from
    /// the end of the dimension.
    pub start: isize,

    /// End index (exclusive) of the range, or `None` if the range extends to
    /// the end of the dimension in the direction of travel. That is past the
    /// last element when the step is positive, or before the first element
    /// when it is negative.
    pub end: Option<isize>,

    /// The step between adjacent elements selected by this range. Negative
    /// steps traverse the dimension in reverse. This is private so this module
    /// can enforce the invariant that it is non-zero.
    step: isize,
}
197
198impl SliceRange {
199    /// Create a new range from `start` to `end`. The `start` index is inclusive
200    /// and the `end` value is exclusive. If `end` is None, the range spans
201    /// to the end of the dimension.
202    ///
203    /// Panics if the `step` size is 0.
204    #[inline]
205    pub fn new(start: isize, end: Option<isize>, step: isize) -> SliceRange {
206        assert!(step != 0, "Slice step cannot be 0");
207        SliceRange { start, end, step }
208    }
209
210    /// Return the number of elements that would be retained if using this range
211    /// to slice a dimension of size `dim_size`.
212    pub fn steps(&self, dim_size: usize) -> usize {
213        let clamped = self.clamp(dim_size);
214
215        let start_idx = Self::offset_from_start(clamped.start, dim_size);
216        let end_idx = clamped
217            .end
218            .map(|index| Self::offset_from_start(index, dim_size))
219            .unwrap_or(if self.step > 0 { dim_size as isize } else { -1 });
220
221        if (clamped.step > 0 && end_idx <= start_idx) || (clamped.step < 0 && end_idx >= start_idx)
222        {
223            return 0;
224        }
225
226        let steps = if clamped.step > 0 {
227            1 + (end_idx - start_idx - 1) / clamped.step
228        } else {
229            1 + (start_idx - end_idx - 1) / -clamped.step
230        };
231
232        steps.max(0) as usize
233    }
234
235    /// Return a copy of this range with indexes adjusted so that they are valid
236    /// for a tensor dimension of size `dim_size`.
237    ///
238    /// Valid indexes depend on direction that the dimension is traversed
239    /// (forwards if `self.step` is positive or backwards if negative). They
240    /// start at the first element going in that direction and end after the
241    /// last element.
242    pub fn clamp(&self, dim_size: usize) -> SliceRange {
243        let len = dim_size as isize;
244
245        let min_idx;
246        let max_idx;
247
248        if self.step > 0 {
249            // When traversing forwards, the range of valid +ve indexes is `[0,
250            // len]` and for -ve indexes `[-len, -1]`.
251            min_idx = -len;
252            max_idx = len;
253        } else {
254            // When traversing backwards, the range of valid +ve indexes are
255            // `[0, len-1]` and for -ve indexes `[-len-1, -1]`.
256            min_idx = -len - 1;
257            max_idx = len - 1;
258        }
259
260        SliceRange::new(
261            self.start.clamp(min_idx, max_idx),
262            self.end.map(|e| e.clamp(min_idx, max_idx)),
263            self.step,
264        )
265    }
266
267    pub fn step(&self) -> isize {
268        self.step
269    }
270
271    /// Clamp this range so that it is valid for a dimension of size `dim_size`
272    /// and resolve it to a positive range.
273    ///
274    /// This method is useful for implementing Python/NumPy-style slicing where
275    /// range endpoints can be out of bounds.
276    pub fn resolve_clamped(&self, dim_size: usize) -> Range<usize> {
277        self.clamp(dim_size).resolve(dim_size).unwrap()
278    }
279
280    /// Resolve the range endpoints to a positive range in `[0, dim_size)`.
281    ///
282    /// Returns the range if resolved or None if out of bounds.
283    ///
284    /// If `self.step` is positive, the returned range counts forwards from
285    /// the first index of the dimension, otherwise it counts backwards from
286    /// the last index.
287    #[inline]
288    pub fn resolve(&self, dim_size: usize) -> Option<Range<usize>> {
289        let (start, end) = if self.step > 0 {
290            let start = Self::offset_from_start(self.start, dim_size);
291            let end = self
292                .end
293                .map(|end| Self::offset_from_start(end, dim_size))
294                .unwrap_or(dim_size as isize);
295            (start, end)
296        } else {
297            let start = Self::offset_from_end(self.start, dim_size);
298            let end = self
299                .end
300                .map(|end| Self::offset_from_end(end, dim_size))
301                .unwrap_or(dim_size as isize);
302            (start, end)
303        };
304
305        if start >= 0 && start <= dim_size as isize && end >= 0 && end <= dim_size as isize {
306            // If `end < start` this means the range is empty. Set `end ==
307            // start` to have a canonical representation for this case.
308            let end = end.max(start);
309
310            Some(start as usize..end as usize)
311        } else {
312            None
313        }
314    }
315
316    /// Return stepped index range selected by this range from an axis with a
317    /// given size.
318    pub(crate) fn index_range(&self, dim_size: usize) -> IndexRange {
319        // Resolve range endpoints to `[0, N]`, counting forwards from the
320        // start if step > 0 or backwards from the end otherwise.
321        let resolved = self.resolve_clamped(dim_size);
322
323        if self.step > 0 {
324            IndexRange::new(resolved.start, resolved.end as isize, self.step)
325        } else {
326            IndexRange::new(
327                dim_size - 1 - resolved.start,
328                dim_size as isize - 1 - resolved.end as isize,
329                self.step,
330            )
331        }
332    }
333
334    /// Resolve an index to an offset from the first index of the dimension.
335    #[inline]
336    fn offset_from_start(index: isize, dim_size: usize) -> isize {
337        if index >= 0 {
338            index
339        } else {
340            dim_size as isize + index
341        }
342    }
343
344    /// Resolve an index to an offset from the last index of the dimension.
345    #[inline]
346    fn offset_from_end(index: isize, dim_size: usize) -> isize {
347        if index >= 0 {
348            dim_size as isize - 1 - index
349        } else {
350            -index - 1
351        }
352    }
353}
354
355impl<T> From<Range<T>> for SliceRange
356where
357    T: TryInto<isize>,
358    <T as TryInto<isize>>::Error: Debug,
359{
360    fn from(r: Range<T>) -> SliceRange {
361        let start = r.start.try_into().unwrap();
362        let end = r.end.try_into().unwrap();
363        SliceRange::new(start, Some(end), 1)
364    }
365}
366
367impl<T> From<RangeTo<T>> for SliceRange
368where
369    T: TryInto<isize>,
370    <T as TryInto<isize>>::Error: Debug,
371{
372    fn from(r: RangeTo<T>) -> SliceRange {
373        let end = r.end.try_into().unwrap();
374        SliceRange::new(0, Some(end), 1)
375    }
376}
377
378impl<T> From<RangeFrom<T>> for SliceRange
379where
380    T: TryInto<isize>,
381    <T as TryInto<isize>>::Error: Debug,
382{
383    fn from(r: RangeFrom<T>) -> SliceRange {
384        let start = r.start.try_into().unwrap();
385        SliceRange::new(start, None, 1)
386    }
387}
388
389impl From<RangeFull> for SliceRange {
390    #[inline]
391    fn from(_: RangeFull) -> SliceRange {
392        SliceRange::new(0, None, 1)
393    }
394}
395
/// A range of indices with a step, which may be positive or negative.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct IndexRange {
    /// Start index in [0, (dim_size - 1).max(0)]
    start: usize,

    /// End index in [-1, dim_size]
    end: isize,

    /// Increment between successive indices. Never zero.
    step: isize,
}

impl IndexRange {
    /// Create a new range which steps from `start` (inclusive) to `end`
    /// (exclusive) with a given step.
    ///
    /// The `end` argument is signed to allow for a range which yields index 0
    /// when `step` is negative. eg. `IndexRange::new(4, -1, -1)` will yield
    /// indices `[4, 3, 2, 1, 0]`. Values of `end` below -1 are clamped to -1.
    ///
    /// # Panics
    ///
    /// Panics if `step` is zero or `start` exceeds `isize::MAX`.
    fn new(start: usize, end: isize, step: isize) -> Self {
        assert!(step != 0);
        // `start` is converted to `isize` when iterating.
        assert!(start <= isize::MAX as usize);

        IndexRange {
            start,
            end: end.max(-1),
            step,
        }
    }

    // The accessors below are marked `allow(unused)` to silence dead-code
    // warnings when they have no callers.

    /// Return the start index.
    #[allow(unused)]
    pub fn start(&self) -> usize {
        self.start
    }

    /// Return the index that is one past the end. This is signed since this
    /// index can be -1 when `self.step() < 0`.
    #[allow(unused)]
    pub fn end(&self) -> isize {
        self.end
    }

    /// Return the increment between indices.
    #[allow(unused)]
    pub fn step(&self) -> isize {
        self.step
    }

    /// Return the number of indices this range yields.
    pub fn steps(&self) -> usize {
        // Signed distance from `start` to `end`. Only distance in the
        // direction of `step` contributes; anything else means an empty range.
        let distance = self.end - self.start as isize;
        let len = if self.step > 0 {
            distance.max(0).unsigned_abs()
        } else {
            distance.min(0).unsigned_abs()
        };
        len.div_ceil(self.step.unsigned_abs())
    }
}
456
457impl IntoIterator for IndexRange {
458    type Item = usize;
459    type IntoIter = IndexRangeIter;
460
461    #[inline]
462    fn into_iter(self) -> IndexRangeIter {
463        IndexRangeIter {
464            step: self.step,
465            index: self.start as isize,
466            remaining: self.steps(),
467        }
468    }
469}
470
/// An iterator over the indices in an [`IndexRange`].
#[derive(Clone, Debug, PartialEq)]
pub struct IndexRangeIter {
    /// Next index to yield. The values yielded by `next` are always in
    /// `[0, N)`, where `N` is the size of the dimension. Once the iterator is
    /// exhausted this value is not meaningful.
    index: isize,

    /// Remaining indices to yield.
    remaining: usize,

    /// Increment applied to `index` after each yielded value.
    step: isize,
}

impl Iterator for IndexRangeIter {
    type Item = usize;

    #[inline]
    fn next(&mut self) -> Option<usize> {
        if self.remaining == 0 {
            return None;
        }
        let idx = self.index;
        self.remaining -= 1;
        // Advancing past the final index can overflow when `step` is very
        // large (eg. `isize::MAX`). That value is never yielded, so wrap
        // instead of panicking.
        self.index = self.index.wrapping_add(self.step);
        Some(idx as usize)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}

impl ExactSizeIterator for IndexRangeIter {}
impl std::iter::FusedIterator for IndexRangeIter {}
506
#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use super::{IntoSliceItems, SliceItem, SliceRange};

    // Conversion of indices, ranges, arrays and tuples into `SliceItem`
    // arrays via `IntoSliceItems`.
    #[test]
    fn test_into_slice_items() {
        // An unsuffixed integer literal; exercises the `i32` conversion.
        let x = (42).into_slice_items();
        assert_eq!(x, [SliceItem::Index(42)]);

        let x = (2..5).into_slice_items();
        assert_eq!(x, [SliceItem::Range((2..5).into())]);

        // `..end` converts to the same range as `0..end`.
        let x = (..5).into_slice_items();
        assert_eq!(x, [SliceItem::Range((0..5).into())]);

        let x = (3..).into_slice_items();
        assert_eq!(x, [SliceItem::Range((3..).into())]);

        let x = [1].into_slice_items();
        assert_eq!(x, [SliceItem::Index(1)]);
        let x = [1, 2].into_slice_items();
        assert_eq!(x, [SliceItem::Index(1), SliceItem::Index(2)]);

        // Tuples may mix indices and ranges.
        let x = (0, 1..2, ..).into_slice_items();
        assert_eq!(
            x,
            [
                SliceItem::Index(0),
                SliceItem::Range((1..2).into()),
                SliceItem::full_range()
            ]
        );
    }

    // Indices yielded by `SliceItem::index_range` for various combinations of
    // step direction, endpoint sign and out-of-range endpoints.
    #[test]
    fn test_index_range() {
        #[derive(Debug)]
        struct Case {
            range: SliceItem,
            dim_size: usize,
            indices: Vec<usize>,
        }

        let cases = [
            // +ve step, +ve endpoints
            Case {
                range: SliceItem::range(0, Some(4), 1),
                dim_size: 6,
                indices: (0..4).collect(),
            },
            Case {
                range: SliceItem::range(2, Some(4), 1),
                dim_size: 6,
                indices: vec![2, 3],
            },
            Case {
                range: SliceItem::range(2, Some(128), 1),
                dim_size: 5,
                indices: vec![2, 3, 4],
            },
            // +ve step > 1, +ve endpoints
            Case {
                range: SliceItem::range(0, Some(5), 2),
                dim_size: 5,
                indices: vec![0, 2, 4],
            },
            // +ve step, no end
            Case {
                range: SliceItem::range(0, None, 1),
                dim_size: 6,
                indices: (0..6).collect(),
            },
            // +ve step, -ve endpoints
            Case {
                range: SliceItem::range(-1, Some(-6), 2),
                dim_size: 5,
                indices: vec![],
            },
            // -ve step, -ve endpoints
            Case {
                range: SliceItem::range(-1, Some(-128), -1),
                dim_size: 5,
                indices: vec![4, 3, 2, 1, 0],
            },
            // -ve step, no end
            Case {
                range: SliceItem::range(-1, None, -1),
                dim_size: 5,
                indices: vec![4, 3, 2, 1, 0],
            },
            // -ve step < -1, -ve endpoints
            Case {
                range: SliceItem::range(-1, Some(-6), -2),
                dim_size: 5,
                indices: vec![4, 2, 0],
            },
            // -ve step, +ve endpoints
            Case {
                range: SliceItem::range(1, Some(5), -2),
                dim_size: 5,
                indices: vec![],
            },
            // Empty range, +ve step
            Case {
                range: SliceItem::range(0, Some(0), 1),
                dim_size: 4,
                indices: vec![],
            },
            // Empty range, -ve step
            Case {
                range: SliceItem::range(0, Some(0), -1),
                dim_size: 4,
                indices: vec![],
            },
            // Single index
            Case {
                range: SliceItem::Index(2),
                dim_size: 4,
                indices: vec![2],
            },
            // Single index, out of range
            Case {
                range: SliceItem::Index(2),
                dim_size: 0,
                indices: vec![],
            },
        ];

        cases.test_each(|case| {
            let Case {
                range,
                dim_size,
                indices,
            } = case;

            // Besides the yielded indices, check that the size hint is exact
            // before iteration and zero once the iterator is exhausted.
            let mut index_iter = range.index_range(*dim_size).into_iter();
            let size_hint = index_iter.size_hint();
            let index_vec: Vec<_> = index_iter.by_ref().collect();

            assert_eq!(size_hint, (index_vec.len(), Some(index_vec.len())));
            assert_eq!(index_vec, *indices);
            assert_eq!(index_iter.size_hint(), (0, Some(0)));
        })
    }

    // `IndexRange::steps` for ranges produced by `SliceRange::index_range`,
    // including steps larger than the dimension.
    #[test]
    fn test_index_range_steps() {
        #[derive(Debug)]
        struct Case {
            range: SliceRange,
            dim_size: usize,
            steps: usize,
        }

        let cases = [
            // Positive step, no end.
            Case {
                range: SliceRange::new(0, None, 1),
                dim_size: 4,
                steps: 4,
            },
            // Positive step size exceeds range length.
            Case {
                range: SliceRange::new(0, None, 5),
                dim_size: 4,
                steps: 1,
            },
            // Negative step, no end.
            Case {
                range: SliceRange::new(-1, None, -1),
                dim_size: 3,
                steps: 3,
            },
            // Negative step size exceeds range length.
            Case {
                range: SliceRange::new(1, Some(0), -2),
                dim_size: 2,
                steps: 1,
            },
        ];

        cases.test_each(|case| {
            assert_eq!(case.range.index_range(case.dim_size).steps(), case.steps);
        })
    }

    // A zero step is rejected at construction time.
    #[test]
    #[should_panic(expected = "Slice step cannot be 0")]
    fn test_slice_range_zero_step() {
        SliceRange::new(0, None, 0);
    }

    // `resolve` returns `None` for out-of-bounds endpoints, while
    // `resolve_clamped` clamps them first and always succeeds.
    #[test]
    fn test_slice_range_resolve() {
        // +ve endpoints, +ve step
        assert_eq!(SliceRange::new(0, Some(5), 1).resolve_clamped(10), 0..5);
        assert_eq!(SliceRange::new(0, None, 1).resolve_clamped(10), 0..10);
        assert_eq!(SliceRange::new(15, Some(20), 1).resolve_clamped(10), 10..10);
        assert_eq!(SliceRange::new(15, Some(20), 1).resolve(10), None);
        assert_eq!(SliceRange::new(4, None, 1).resolve(3), None);
        assert_eq!(SliceRange::new(0, Some(10), 1).resolve(3), None);

        // -ve endpoints, +ve step
        assert_eq!(SliceRange::new(-5, Some(-1), 1).resolve_clamped(10), 5..9);
        assert_eq!(SliceRange::new(-20, Some(-1), 1).resolve_clamped(10), 0..9);
        assert_eq!(SliceRange::new(-20, Some(-1), 1).resolve(10), None);
        assert_eq!(SliceRange::new(-5, None, 1).resolve_clamped(10), 5..10);

        // +ve endpoints, -ve step.
        //
        // Note the returned ranges count backwards from the end of the
        // dimension.
        assert_eq!(SliceRange::new(5, Some(0), -1).resolve_clamped(10), 4..9);
        assert_eq!(SliceRange::new(5, None, -1).resolve_clamped(10), 4..10);
        assert_eq!(SliceRange::new(9, None, -1).resolve_clamped(10), 0..10);

        // -ve endpoints, -ve step.
        assert_eq!(SliceRange::new(-1, Some(-4), -1).resolve_clamped(3), 0..3);
        assert_eq!(SliceRange::new(-1, None, -1).resolve_clamped(2), 0..2);
    }
}