1use smallvec::SmallVec;
2
3use std::fmt::Debug;
4use std::ops::{Range, RangeFrom, RangeFull, RangeTo};
5
6#[derive(Clone, Copy, Debug, PartialEq)]
10pub enum SliceItem {
11 Index(isize),
17
18 Range(SliceRange),
20}
21
22impl SliceItem {
23 #[inline]
25 pub fn full_range() -> Self {
26 (..).into()
27 }
28
29 #[inline]
31 pub fn range(start: isize, end: Option<isize>, step: isize) -> SliceItem {
32 SliceItem::Range(SliceRange::new(start, end, step))
33 }
34
35 pub(crate) fn index_range(&self, dim_size: usize) -> IndexRange {
38 let range = match *self {
39 SliceItem::Range(range) => range,
40 SliceItem::Index(idx) => SliceRange::new(idx, Some(idx + 1), 1),
41 };
42 range.index_range(dim_size)
43 }
44}
45
46impl From<i32> for SliceItem {
51 #[inline]
52 fn from(value: i32) -> Self {
53 SliceItem::Index(value as isize)
54 }
55}
56
57impl From<isize> for SliceItem {
58 #[inline]
59 fn from(value: isize) -> Self {
60 SliceItem::Index(value)
61 }
62}
63
64impl From<usize> for SliceItem {
65 #[inline]
66 fn from(value: usize) -> Self {
67 SliceItem::Index(value as isize)
68 }
69}
70
71impl<R> From<R> for SliceItem
72where
73 R: Into<SliceRange>,
74{
75 fn from(value: R) -> Self {
76 SliceItem::Range(value.into())
77 }
78}
79
80pub trait IntoSliceItems {
95 type Array: AsRef<[SliceItem]>;
96
97 fn into_slice_items(self) -> Self::Array;
98}
99
100impl<'a> IntoSliceItems for &'a [SliceItem] {
101 type Array = &'a [SliceItem];
102
103 fn into_slice_items(self) -> &'a [SliceItem] {
104 self
105 }
106}
107
108impl<const N: usize, T: Into<SliceItem>> IntoSliceItems for [T; N] {
109 type Array = [SliceItem; N];
110
111 fn into_slice_items(self) -> [SliceItem; N] {
112 self.map(|x| x.into())
113 }
114}
115
116impl<T: Into<SliceItem>> IntoSliceItems for T {
117 type Array = [SliceItem; 1];
118
119 fn into_slice_items(self) -> [SliceItem; 1] {
120 [self.into()]
121 }
122}
123
124impl<T1: Into<SliceItem>> IntoSliceItems for (T1,) {
125 type Array = [SliceItem; 1];
126
127 fn into_slice_items(self) -> [SliceItem; 1] {
128 [self.0.into()]
129 }
130}
131
132impl<T1: Into<SliceItem>, T2: Into<SliceItem>> IntoSliceItems for (T1, T2) {
133 type Array = [SliceItem; 2];
134
135 fn into_slice_items(self) -> [SliceItem; 2] {
136 [self.0.into(), self.1.into()]
137 }
138}
139
140impl<T1: Into<SliceItem>, T2: Into<SliceItem>, T3: Into<SliceItem>> IntoSliceItems
141 for (T1, T2, T3)
142{
143 type Array = [SliceItem; 3];
144
145 fn into_slice_items(self) -> [SliceItem; 3] {
146 [self.0.into(), self.1.into(), self.2.into()]
147 }
148}
149
150impl<T1: Into<SliceItem>, T2: Into<SliceItem>, T3: Into<SliceItem>, T4: Into<SliceItem>>
151 IntoSliceItems for (T1, T2, T3, T4)
152{
153 type Array = [SliceItem; 4];
154
155 fn into_slice_items(self) -> [SliceItem; 4] {
156 [self.0.into(), self.1.into(), self.2.into(), self.3.into()]
157 }
158}
159
160pub type DynSliceItems = SmallVec<[SliceItem; 5]>;
163
164pub fn to_slice_items<T: Clone + Into<SliceItem>>(index: &[T]) -> DynSliceItems {
170 index.iter().map(|x| x.clone().into()).collect()
171}
172
173#[derive(Clone, Copy, Debug, PartialEq)]
185pub struct SliceRange {
186 pub start: isize,
188
189 pub end: Option<isize>,
192
193 step: isize,
196}
197
198impl SliceRange {
199 #[inline]
205 pub fn new(start: isize, end: Option<isize>, step: isize) -> SliceRange {
206 assert!(step != 0, "Slice step cannot be 0");
207 SliceRange { start, end, step }
208 }
209
210 pub fn steps(&self, dim_size: usize) -> usize {
213 let clamped = self.clamp(dim_size);
214
215 let start_idx = Self::offset_from_start(clamped.start, dim_size);
216 let end_idx = clamped
217 .end
218 .map(|index| Self::offset_from_start(index, dim_size))
219 .unwrap_or(if self.step > 0 { dim_size as isize } else { -1 });
220
221 if (clamped.step > 0 && end_idx <= start_idx) || (clamped.step < 0 && end_idx >= start_idx)
222 {
223 return 0;
224 }
225
226 let steps = if clamped.step > 0 {
227 1 + (end_idx - start_idx - 1) / clamped.step
228 } else {
229 1 + (start_idx - end_idx - 1) / -clamped.step
230 };
231
232 steps.max(0) as usize
233 }
234
235 pub fn clamp(&self, dim_size: usize) -> SliceRange {
243 let len = dim_size as isize;
244
245 let min_idx;
246 let max_idx;
247
248 if self.step > 0 {
249 min_idx = -len;
252 max_idx = len;
253 } else {
254 min_idx = -len - 1;
257 max_idx = len - 1;
258 }
259
260 SliceRange::new(
261 self.start.clamp(min_idx, max_idx),
262 self.end.map(|e| e.clamp(min_idx, max_idx)),
263 self.step,
264 )
265 }
266
267 pub fn step(&self) -> isize {
268 self.step
269 }
270
271 pub fn resolve_clamped(&self, dim_size: usize) -> Range<usize> {
277 self.clamp(dim_size).resolve(dim_size).unwrap()
278 }
279
280 #[inline]
288 pub fn resolve(&self, dim_size: usize) -> Option<Range<usize>> {
289 let (start, end) = if self.step > 0 {
290 let start = Self::offset_from_start(self.start, dim_size);
291 let end = self
292 .end
293 .map(|end| Self::offset_from_start(end, dim_size))
294 .unwrap_or(dim_size as isize);
295 (start, end)
296 } else {
297 let start = Self::offset_from_end(self.start, dim_size);
298 let end = self
299 .end
300 .map(|end| Self::offset_from_end(end, dim_size))
301 .unwrap_or(dim_size as isize);
302 (start, end)
303 };
304
305 if start >= 0 && start <= dim_size as isize && end >= 0 && end <= dim_size as isize {
306 let end = end.max(start);
309
310 Some(start as usize..end as usize)
311 } else {
312 None
313 }
314 }
315
316 pub(crate) fn index_range(&self, dim_size: usize) -> IndexRange {
319 let resolved = self.resolve_clamped(dim_size);
322
323 if self.step > 0 {
324 IndexRange::new(resolved.start, resolved.end as isize, self.step)
325 } else {
326 IndexRange::new(
327 dim_size - 1 - resolved.start,
328 dim_size as isize - 1 - resolved.end as isize,
329 self.step,
330 )
331 }
332 }
333
334 #[inline]
336 fn offset_from_start(index: isize, dim_size: usize) -> isize {
337 if index >= 0 {
338 index
339 } else {
340 dim_size as isize + index
341 }
342 }
343
344 #[inline]
346 fn offset_from_end(index: isize, dim_size: usize) -> isize {
347 if index >= 0 {
348 dim_size as isize - 1 - index
349 } else {
350 -index - 1
351 }
352 }
353}
354
355impl<T> From<Range<T>> for SliceRange
356where
357 T: TryInto<isize>,
358 <T as TryInto<isize>>::Error: Debug,
359{
360 fn from(r: Range<T>) -> SliceRange {
361 let start = r.start.try_into().unwrap();
362 let end = r.end.try_into().unwrap();
363 SliceRange::new(start, Some(end), 1)
364 }
365}
366
367impl<T> From<RangeTo<T>> for SliceRange
368where
369 T: TryInto<isize>,
370 <T as TryInto<isize>>::Error: Debug,
371{
372 fn from(r: RangeTo<T>) -> SliceRange {
373 let end = r.end.try_into().unwrap();
374 SliceRange::new(0, Some(end), 1)
375 }
376}
377
378impl<T> From<RangeFrom<T>> for SliceRange
379where
380 T: TryInto<isize>,
381 <T as TryInto<isize>>::Error: Debug,
382{
383 fn from(r: RangeFrom<T>) -> SliceRange {
384 let start = r.start.try_into().unwrap();
385 SliceRange::new(start, None, 1)
386 }
387}
388
389impl From<RangeFull> for SliceRange {
390 #[inline]
391 fn from(_: RangeFull) -> SliceRange {
392 SliceRange::new(0, None, 1)
393 }
394}
395
/// A stepped range of non-negative indices, iterable via `IntoIterator`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct IndexRange {
    /// First index in the range.
    start: usize,

    /// Exclusive end of the range. This is signed because it is -1 when a
    /// range with a negative step runs down through index 0.
    end: isize,

    /// Non-zero distance between successive indices.
    step: isize,
}

impl IndexRange {
    /// Create a range from `start` towards `end` (exclusive) that advances by
    /// `step`. Values of `end` below -1 are raised to -1.
    ///
    /// # Panics
    ///
    /// Panics if `step` is zero or `start` exceeds `isize::MAX`.
    fn new(start: usize, end: isize, step: isize) -> Self {
        assert!(step != 0);
        assert!(start <= isize::MAX as usize);

        let end = end.max(-1);
        Self { start, end, step }
    }

    /// Return the first index in the range.
    #[allow(unused)]
    pub fn start(&self) -> usize {
        self.start
    }

    /// Return the exclusive end of the range, which is -1 or greater.
    #[allow(unused)]
    pub fn end(&self) -> isize {
        self.end
    }

    /// Return the step between successive indices.
    #[allow(unused)]
    pub fn step(&self) -> isize {
        self.step
    }

    /// Return the number of indices in the range.
    pub fn steps(&self) -> usize {
        let delta = self.end - self.start as isize;
        // Keep only the distance that lies in the direction of travel; a
        // range that points the other way is empty.
        let span = if self.step > 0 {
            delta.max(0)
        } else {
            delta.min(0)
        };
        span.unsigned_abs().div_ceil(self.step.unsigned_abs())
    }
}
456
457impl IntoIterator for IndexRange {
458 type Item = usize;
459 type IntoIter = IndexRangeIter;
460
461 #[inline]
462 fn into_iter(self) -> IndexRangeIter {
463 IndexRangeIter {
464 step: self.step,
465 index: self.start as isize,
466 remaining: self.steps(),
467 }
468 }
469}
470
/// Iterator over a stepped sequence of indices.
///
/// It yields `index`, `index + step`, `index + 2 * step`, ... for a total of
/// `remaining` items.
#[derive(Clone, Debug, PartialEq)]
pub struct IndexRangeIter {
    /// The index that the next call to `next` yields.
    index: isize,

    /// Number of items still to be yielded.
    remaining: usize,

    /// Amount added to `index` after each item. May be negative.
    step: isize,
}

impl Iterator for IndexRangeIter {
    type Item = usize;

    #[inline]
    fn next(&mut self) -> Option<usize> {
        // Returns `None` once the count reaches zero, leaving the state
        // untouched.
        self.remaining = self.remaining.checked_sub(1)?;

        let idx = self.index;
        // The advance after the final item may step outside the range of
        // `isize` (for example index 1 with a step of `isize::MAX`). That
        // value is never yielded, so it is allowed to wrap instead of
        // triggering an overflow panic.
        self.index = self.index.wrapping_add(self.step);
        Some(idx as usize)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}

// `size_hint` reports the exact number of remaining items.
impl ExactSizeIterator for IndexRangeIter {}
// Once `remaining` reaches zero, `next` keeps returning `None`.
impl std::iter::FusedIterator for IndexRangeIter {}
506
// Unit tests for the slice item conversions and for range resolution.
#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use super::{IntoSliceItems, SliceItem, SliceRange};

    // Checks each supported input form of `IntoSliceItems`: a single index,
    // the standard range expressions, arrays and a mixed tuple.
    #[test]
    fn test_into_slice_items() {
        // Single index.
        let x = (42).into_slice_items();
        assert_eq!(x, [SliceItem::Index(42)]);

        // Single range with both endpoints.
        let x = (2..5).into_slice_items();
        assert_eq!(x, [SliceItem::Range((2..5).into())]);

        // Range without a start is equivalent to one starting at 0.
        let x = (..5).into_slice_items();
        assert_eq!(x, [SliceItem::Range((0..5).into())]);

        // Range without an end.
        let x = (3..).into_slice_items();
        assert_eq!(x, [SliceItem::Range((3..).into())]);

        // Arrays of indices.
        let x = [1].into_slice_items();
        assert_eq!(x, [SliceItem::Index(1)]);
        let x = [1, 2].into_slice_items();
        assert_eq!(x, [SliceItem::Index(1), SliceItem::Index(2)]);

        // Tuple mixing an index, a range and a full range.
        let x = (0, 1..2, ..).into_slice_items();
        assert_eq!(
            x,
            [
                SliceItem::Index(0),
                SliceItem::Range((1..2).into()),
                SliceItem::full_range()
            ]
        );
    }

    // Checks the indices produced by iterating `SliceItem::index_range`, and
    // that the iterator's `size_hint` is exact before and after iteration.
    #[test]
    fn test_index_range() {
        #[derive(Debug)]
        struct Case {
            range: SliceItem,
            dim_size: usize,
            indices: Vec<usize>,
        }

        let cases = [
            // Positive step, in-bounds non-negative endpoints.
            Case {
                range: SliceItem::range(0, Some(4), 1),
                dim_size: 6,
                indices: (0..4).collect(),
            },
            Case {
                range: SliceItem::range(2, Some(4), 1),
                dim_size: 6,
                indices: vec![2, 3],
            },
            // End beyond the dimension is clamped.
            Case {
                range: SliceItem::range(2, Some(128), 1),
                dim_size: 5,
                indices: vec![2, 3, 4],
            },
            // Positive step greater than 1.
            Case {
                range: SliceItem::range(0, Some(5), 2),
                dim_size: 5,
                indices: vec![0, 2, 4],
            },
            // Positive step, open end.
            Case {
                range: SliceItem::range(0, None, 1),
                dim_size: 6,
                indices: (0..6).collect(),
            },
            // Positive step with the start after the end: nothing selected.
            Case {
                range: SliceItem::range(-1, Some(-6), 2),
                dim_size: 5,
                indices: vec![],
            },
            // Negative step with an out-of-bounds end, which is clamped.
            Case {
                range: SliceItem::range(-1, Some(-128), -1),
                dim_size: 5,
                indices: vec![4, 3, 2, 1, 0],
            },
            // Negative step, open end.
            Case {
                range: SliceItem::range(-1, None, -1),
                dim_size: 5,
                indices: vec![4, 3, 2, 1, 0],
            },
            // Negative step with a magnitude greater than 1.
            Case {
                range: SliceItem::range(-1, Some(-6), -2),
                dim_size: 5,
                indices: vec![4, 2, 0],
            },
            // Negative step with the start before the end: nothing selected.
            Case {
                range: SliceItem::range(1, Some(5), -2),
                dim_size: 5,
                indices: vec![],
            },
            // Empty range, positive step.
            Case {
                range: SliceItem::range(0, Some(0), 1),
                dim_size: 4,
                indices: vec![],
            },
            // Empty range, negative step.
            Case {
                range: SliceItem::range(0, Some(0), -1),
                dim_size: 4,
                indices: vec![],
            },
            // Single in-bounds index.
            Case {
                range: SliceItem::Index(2),
                dim_size: 4,
                indices: vec![2],
            },
            // Index into an empty dimension selects nothing.
            Case {
                range: SliceItem::Index(2),
                dim_size: 0,
                indices: vec![],
            },
        ];

        cases.test_each(|case| {
            let Case {
                range,
                dim_size,
                indices,
            } = case;

            let mut index_iter = range.index_range(*dim_size).into_iter();
            let size_hint = index_iter.size_hint();
            // `by_ref` keeps the iterator usable so the size hint can be
            // checked again after it has been exhausted.
            let index_vec: Vec<_> = index_iter.by_ref().collect();

            assert_eq!(size_hint, (index_vec.len(), Some(index_vec.len())));
            assert_eq!(index_vec, *indices);
            assert_eq!(index_iter.size_hint(), (0, Some(0)));
        })
    }

    // Checks the step count reported by the `IndexRange` built from a
    // `SliceRange`.
    #[test]
    fn test_index_range_steps() {
        #[derive(Debug)]
        struct Case {
            range: SliceRange,
            dim_size: usize,
            steps: usize,
        }

        let cases = [
            // Positive step of 1 over the whole dimension.
            Case {
                range: SliceRange::new(0, None, 1),
                dim_size: 4,
                steps: 4,
            },
            // Positive step larger than the dimension.
            Case {
                range: SliceRange::new(0, None, 5),
                dim_size: 4,
                steps: 1,
            },
            // Negative step of 1 over the whole dimension.
            Case {
                range: SliceRange::new(-1, None, -1),
                dim_size: 3,
                steps: 3,
            },
            // Negative step whose magnitude exceeds the distance to the end.
            Case {
                range: SliceRange::new(1, Some(0), -2),
                dim_size: 2,
                steps: 1,
            },
        ];

        cases.test_each(|case| {
            assert_eq!(case.range.index_range(case.dim_size).steps(), case.steps);
        })
    }

    // Constructing a range with a zero step must panic.
    #[test]
    #[should_panic(expected = "Slice step cannot be 0")]
    fn test_slice_range_zero_step() {
        SliceRange::new(0, None, 0);
    }

    // Checks `resolve` (fails on out-of-bounds endpoints) and
    // `resolve_clamped` (clamps them) for each sign of endpoint and step.
    #[test]
    fn test_slice_range_resolve() {
        // Non-negative endpoints, positive step.
        assert_eq!(SliceRange::new(0, Some(5), 1).resolve_clamped(10), 0..5);
        assert_eq!(SliceRange::new(0, None, 1).resolve_clamped(10), 0..10);
        assert_eq!(SliceRange::new(15, Some(20), 1).resolve_clamped(10), 10..10);
        assert_eq!(SliceRange::new(15, Some(20), 1).resolve(10), None);
        assert_eq!(SliceRange::new(4, None, 1).resolve(3), None);
        assert_eq!(SliceRange::new(0, Some(10), 1).resolve(3), None);

        // Negative endpoints, positive step.
        assert_eq!(SliceRange::new(-5, Some(-1), 1).resolve_clamped(10), 5..9);
        assert_eq!(SliceRange::new(-20, Some(-1), 1).resolve_clamped(10), 0..9);
        assert_eq!(SliceRange::new(-20, Some(-1), 1).resolve(10), None);
        assert_eq!(SliceRange::new(-5, None, 1).resolve_clamped(10), 5..10);

        // Non-negative endpoints, negative step. For a negative step the
        // resolved offsets count back from the last element, so index 5 of a
        // size-10 dimension resolves to offset 4.
        assert_eq!(SliceRange::new(5, Some(0), -1).resolve_clamped(10), 4..9);
        assert_eq!(SliceRange::new(5, None, -1).resolve_clamped(10), 4..10);
        assert_eq!(SliceRange::new(9, None, -1).resolve_clamped(10), 0..10);

        // Negative endpoints, negative step.
        assert_eq!(SliceRange::new(-1, Some(-4), -1).resolve_clamped(3), 0..3);
        assert_eq!(SliceRange::new(-1, None, -1).resolve_clamped(2), 0..2);
    }
}