use smallvec::SmallVec;
use std::fmt::Debug;
use std::ops::{Range, RangeFrom, RangeFull, RangeTo};
/// One entry in a slicing specification: either a single index or a range of
/// indices.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SliceItem {
    /// A single index. Negative values count back from the end of the
    /// dimension.
    Index(isize),
    /// A (possibly strided) range of indices.
    Range(SliceRange),
}
impl SliceItem {
    /// Return an item that selects every index of a dimension (equivalent to
    /// `..`).
    #[inline]
    pub fn full_range() -> Self {
        (..).into()
    }

    /// Return an item selecting indices from `start` to `end` (exclusive) in
    /// increments of `step`. See [`SliceRange::new`] for how negative and
    /// missing bounds are interpreted.
    ///
    /// # Panics
    ///
    /// Panics if `step` is zero.
    #[inline]
    pub fn range(start: isize, end: Option<isize>, step: isize) -> SliceItem {
        SliceItem::Range(SliceRange::new(start, end, step))
    }

    /// Return the concrete indices this item selects in a dimension of size
    /// `dim_size`.
    ///
    /// An `Index` is treated as the one-element range `idx..idx + 1`. Out of
    /// bounds indices produce an empty range.
    pub(crate) fn index_range(&self, dim_size: usize) -> IndexRange {
        let range = match *self {
            SliceItem::Range(range) => range,
            SliceItem::Index(idx) => {
                // For `idx == -1` the exclusive end `idx + 1` would be `0`,
                // which denotes the *start* of the dimension and would make the
                // range empty. An open end selects "through the last index"
                // instead. Saturating avoids overflow for `isize::MAX`, which
                // clamps to an empty range like any other out-of-range index.
                let end = if idx == -1 {
                    None
                } else {
                    Some(idx.saturating_add(1))
                };
                SliceRange::new(idx, end, 1)
            }
        };
        range.index_range(dim_size)
    }
}
impl From<i32> for SliceItem {
#[inline]
fn from(value: i32) -> Self {
SliceItem::Index(value as isize)
}
}
impl From<isize> for SliceItem {
#[inline]
fn from(value: isize) -> Self {
SliceItem::Index(value)
}
}
impl From<usize> for SliceItem {
#[inline]
fn from(value: usize) -> Self {
SliceItem::Index(value as isize)
}
}
impl<R> From<R> for SliceItem
where
    R: Into<SliceRange>,
{
    /// Convert any range-like value (such as `a..b`, `..b`, `a..` or `..`)
    /// into [`SliceItem::Range`].
    fn from(value: R) -> Self {
        let range: SliceRange = value.into();
        Self::Range(range)
    }
}
/// Conversion of a value into a sequence of [`SliceItem`]s, one per sliced
/// dimension.
pub trait IntoSliceItems {
    /// Container that holds the converted items.
    type Array: AsRef<[SliceItem]>;

    /// Consume `self` and produce the slice items it describes.
    fn into_slice_items(self) -> Self::Array;
}
impl<'a> IntoSliceItems for &'a [SliceItem] {
    type Array = &'a [SliceItem];

    /// A slice of items is already in the required form and is returned as-is.
    fn into_slice_items(self) -> Self::Array {
        self
    }
}
impl<const N: usize, T: Into<SliceItem>> IntoSliceItems for [T; N] {
    type Array = [SliceItem; N];

    /// Convert every array element into a [`SliceItem`], preserving order.
    fn into_slice_items(self) -> [SliceItem; N] {
        self.map(Into::<SliceItem>::into)
    }
}
impl<T: Into<SliceItem>> IntoSliceItems for T {
    type Array = [SliceItem; 1];

    /// Treat a lone value as a one-element list of slice items.
    fn into_slice_items(self) -> [SliceItem; 1] {
        let item: SliceItem = self.into();
        [item]
    }
}
impl<T1: Into<SliceItem>> IntoSliceItems for (T1,) {
    type Array = [SliceItem; 1];

    /// Convert the single tuple field into a slice item.
    fn into_slice_items(self) -> [SliceItem; 1] {
        let (first,) = self;
        [first.into()]
    }
}
impl<T1: Into<SliceItem>, T2: Into<SliceItem>> IntoSliceItems for (T1, T2) {
    type Array = [SliceItem; 2];

    /// Convert both tuple fields, in order, into slice items.
    fn into_slice_items(self) -> [SliceItem; 2] {
        let (first, second) = self;
        [first.into(), second.into()]
    }
}
impl<T1: Into<SliceItem>, T2: Into<SliceItem>, T3: Into<SliceItem>> IntoSliceItems
    for (T1, T2, T3)
{
    type Array = [SliceItem; 3];

    /// Convert all three tuple fields, in order, into slice items.
    fn into_slice_items(self) -> [SliceItem; 3] {
        let (first, second, third) = self;
        [first.into(), second.into(), third.into()]
    }
}
impl<T1: Into<SliceItem>, T2: Into<SliceItem>, T3: Into<SliceItem>, T4: Into<SliceItem>>
    IntoSliceItems for (T1, T2, T3, T4)
{
    type Array = [SliceItem; 4];

    /// Convert all four tuple fields, in order, into slice items.
    fn into_slice_items(self) -> [SliceItem; 4] {
        let (first, second, third, fourth) = self;
        [first.into(), second.into(), third.into(), fourth.into()]
    }
}
/// Variable-length list of slice items. Up to five items are stored inline
/// before the list spills to the heap.
pub type DynSliceItems = SmallVec<[SliceItem; 5]>;
/// Convert a slice of index-like values into a [`DynSliceItems`] list.
///
/// Each element is cloned and then converted via `Into<SliceItem>`.
pub fn to_slice_items<T: Clone + Into<SliceItem>>(index: &[T]) -> DynSliceItems {
    index
        .iter()
        .cloned()
        .map(Into::<SliceItem>::into)
        .collect()
}
/// A strided range of indices, `start:end:step`, similar to a Python slice.
///
/// Negative `start` / `end` values count back from the end of the dimension.
/// An `end` of `None` leaves the range open in the direction of `step`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SliceRange {
    /// First index of the range (may be negative).
    pub start: isize,
    /// Exclusive end of the range, or `None` for an open end.
    pub end: Option<isize>,
    /// Increment between successive indices; `SliceRange::new` rejects zero.
    step: isize,
}
impl SliceRange {
    /// Create a range selecting indices from `start` to `end` (exclusive) in
    /// increments of `step`.
    ///
    /// Negative `start` / `end` values count back from the end of the
    /// dimension. `end == None` leaves the range open in the direction of
    /// `step`.
    ///
    /// # Panics
    ///
    /// Panics if `step` is zero.
    #[inline]
    pub fn new(start: isize, end: Option<isize>, step: isize) -> SliceRange {
        assert!(step != 0, "Slice step cannot be 0");
        SliceRange { start, end, step }
    }

    /// Return the number of indices this range selects in a dimension of size
    /// `dim_size`, after clamping out-of-bounds `start` / `end` values.
    pub fn steps(&self, dim_size: usize) -> usize {
        let clamped = self.clamp(dim_size);
        let start_idx = Self::offset_from_start(clamped.start, dim_size);
        // An open end lies one past the last index for positive steps, or one
        // before index 0 for negative steps.
        let end_idx = clamped
            .end
            .map(|index| Self::offset_from_start(index, dim_size))
            .unwrap_or(if self.step > 0 { dim_size as isize } else { -1 });
        if (clamped.step > 0 && end_idx <= start_idx) || (clamped.step < 0 && end_idx >= start_idx)
        {
            return 0;
        }
        let steps = if clamped.step > 0 {
            1 + (end_idx - start_idx - 1) / clamped.step
        } else {
            1 + (start_idx - end_idx - 1) / -clamped.step
        };
        steps.max(0) as usize
    }

    /// Return a copy of this range with `start` and `end` clamped to the
    /// values meaningful for a dimension of size `dim_size`.
    ///
    /// For positive steps both bounds are clamped to `[-dim_size, dim_size]`.
    /// For negative steps they are clamped to `[-dim_size - 1, dim_size - 1]`,
    /// where `-dim_size - 1` denotes the position just before index 0.
    pub fn clamp(&self, dim_size: usize) -> SliceRange {
        let len = dim_size as isize;
        let min_idx;
        let max_idx;
        if self.step > 0 {
            min_idx = -len;
            max_idx = len;
        } else {
            min_idx = -len - 1;
            max_idx = len - 1;
        }
        SliceRange::new(
            self.start.clamp(min_idx, max_idx),
            self.end.map(|e| e.clamp(min_idx, max_idx)),
            self.step,
        )
    }

    /// Return the step between successive indices.
    pub fn step(&self) -> isize {
        self.step
    }

    /// Clamp this range to `dim_size`, then [`resolve`](Self::resolve) it.
    ///
    /// Clamping keeps both bounds within the limits `resolve` accepts, so
    /// this always produces a range.
    pub fn resolve_clamped(&self, dim_size: usize) -> Range<usize> {
        self.clamp(dim_size)
            .resolve(dim_size)
            .expect("clamped range should always resolve")
    }

    /// Resolve `start` and `end` into non-negative positions for a dimension
    /// of size `dim_size`, or return `None` if either is out of bounds.
    ///
    /// For positive steps the result is a range of indices. For negative steps
    /// it is in reversed coordinates, where position `p` corresponds to index
    /// `dim_size - 1 - p`. If the resolved end precedes the start, it is
    /// raised to the start, giving an empty range. The step itself is not
    /// applied.
    #[inline]
    pub fn resolve(&self, dim_size: usize) -> Option<Range<usize>> {
        let (start, end) = if self.step > 0 {
            let start = Self::offset_from_start(self.start, dim_size);
            let end = self
                .end
                .map(|end| Self::offset_from_start(end, dim_size))
                .unwrap_or(dim_size as isize);
            (start, end)
        } else {
            let start = Self::offset_from_end(self.start, dim_size);
            let end = self
                .end
                .map(|end| Self::offset_from_end(end, dim_size))
                .unwrap_or(dim_size as isize);
            (start, end)
        };
        if start >= 0 && start <= dim_size as isize && end >= 0 && end <= dim_size as isize {
            let end = end.max(start);
            Some(start as usize..end as usize)
        } else {
            None
        }
    }

    /// Return the concrete indices this range selects in a dimension of size
    /// `dim_size`, after clamping out-of-bounds `start` / `end` values.
    pub(crate) fn index_range(&self, dim_size: usize) -> IndexRange {
        let resolved = self.resolve_clamped(dim_size);
        if self.step > 0 {
            return IndexRange::new(resolved.start, resolved.end as isize, self.step);
        }

        // `resolved` is in reversed coordinates (position `p` is index
        // `dim_size - 1 - p`). Clamping can place the start at position
        // `dim_size`, i.e. just before index 0 (always the case when
        // `dim_size == 0`). The range is then empty, and converting the
        // position back to an index would underflow, so handle it explicitly.
        match dim_size.checked_sub(resolved.start + 1) {
            Some(start) => IndexRange::new(
                start,
                dim_size as isize - 1 - resolved.end as isize,
                self.step,
            ),
            None => IndexRange::new(0, 0, self.step),
        }
    }

    /// Convert a possibly-negative index into an offset from the start of the
    /// dimension (negative values count back from `dim_size`).
    #[inline]
    fn offset_from_start(index: isize, dim_size: usize) -> isize {
        if index >= 0 {
            index
        } else {
            dim_size as isize + index
        }
    }

    /// Convert a possibly-negative index into a position counted from the
    /// last element of the dimension (the last index maps to 0).
    #[inline]
    fn offset_from_end(index: isize, dim_size: usize) -> isize {
        if index >= 0 {
            dim_size as isize - 1 - index
        } else {
            -index - 1
        }
    }
}
impl<T> From<Range<T>> for SliceRange
where
    T: TryInto<isize>,
    <T as TryInto<isize>>::Error: Debug,
{
    /// Convert `start..end` into a unit-step slice range.
    ///
    /// # Panics
    ///
    /// Panics if either bound cannot be represented as `isize`.
    fn from(r: Range<T>) -> SliceRange {
        let Range { start, end } = r;
        SliceRange::new(start.try_into().unwrap(), Some(end.try_into().unwrap()), 1)
    }
}
impl<T> From<RangeTo<T>> for SliceRange
where
    T: TryInto<isize>,
    <T as TryInto<isize>>::Error: Debug,
{
    /// Convert `..end` into a unit-step range starting at index 0.
    ///
    /// # Panics
    ///
    /// Panics if `end` cannot be represented as `isize`.
    fn from(r: RangeTo<T>) -> SliceRange {
        let RangeTo { end } = r;
        SliceRange::new(0, Some(end.try_into().unwrap()), 1)
    }
}
impl<T> From<RangeFrom<T>> for SliceRange
where
    T: TryInto<isize>,
    <T as TryInto<isize>>::Error: Debug,
{
    /// Convert `start..` into a unit-step range with an open end.
    ///
    /// # Panics
    ///
    /// Panics if `start` cannot be represented as `isize`.
    fn from(r: RangeFrom<T>) -> SliceRange {
        let RangeFrom { start } = r;
        SliceRange::new(start.try_into().unwrap(), None, 1)
    }
}
impl From<RangeFull> for SliceRange {
#[inline]
fn from(_: RangeFull) -> SliceRange {
SliceRange::new(0, None, 1)
}
}
/// A resolved range of concrete indices: `start`, `start + step`, ... up to
/// but excluding `end`.
///
/// `start` is always a valid non-negative index. `end` may be `-1` so that a
/// negative-step range can include index 0.
#[derive(Copy, Clone, Debug, PartialEq)]
pub(crate) struct IndexRange {
    /// First index yielded, if the range is non-empty.
    start: usize,
    /// Exclusive bound; never less than `-1`.
    end: isize,
    /// Distance between successive indices; never zero.
    step: isize,
}

impl IndexRange {
    /// Create a range, raising `end` to at least `-1`.
    ///
    /// # Panics
    ///
    /// Panics if `step` is zero or `start` exceeds `isize::MAX`.
    fn new(start: usize, end: isize, step: isize) -> Self {
        assert!(step != 0);
        assert!(start <= isize::MAX as usize);
        Self {
            start,
            // Any bound below -1 is equivalent to -1: both lie before index 0.
            end: end.max(-1),
            step,
        }
    }

    // The accessors may be unused in some builds, hence `allow(unused)`.
    #[allow(unused)]
    pub fn start(&self) -> usize {
        self.start
    }

    #[allow(unused)]
    pub fn end(&self) -> isize {
        self.end
    }

    #[allow(unused)]
    pub fn step(&self) -> isize {
        self.step
    }

    /// Return the number of indices this range yields.
    pub fn steps(&self) -> usize {
        let delta = self.end - self.start as isize;
        // Only distance in the direction of travel counts; an `end` on the
        // wrong side of `start` means the range is empty.
        let distance = if self.step > 0 {
            delta.max(0)
        } else {
            delta.min(0)
        };
        distance.unsigned_abs().div_ceil(self.step.unsigned_abs())
    }
}

impl IntoIterator for IndexRange {
    type Item = usize;
    type IntoIter = IndexRangeIter;

    /// Iterate over the indices of this range.
    #[inline]
    fn into_iter(self) -> IndexRangeIter {
        let remaining = self.steps();
        IndexRangeIter {
            index: self.start as isize,
            remaining,
            step: self.step,
        }
    }
}

/// Iterator over the indices of an [`IndexRange`].
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct IndexRangeIter {
    /// Next index to yield.
    index: isize,
    /// Number of indices still to be yielded.
    remaining: usize,
    /// Added to `index` after each yield.
    step: isize,
}

impl Iterator for IndexRangeIter {
    type Item = usize;

    #[inline]
    fn next(&mut self) -> Option<usize> {
        // `?` returns `None` without touching any state once exhausted.
        self.remaining = self.remaining.checked_sub(1)?;
        let current = self.index;
        self.index += self.step;
        Some(current as usize)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.remaining;
        (left, Some(left))
    }
}

// `size_hint` is exact, and `next` keeps returning `None` once `remaining`
// reaches zero.
impl ExactSizeIterator for IndexRangeIter {}
impl std::iter::FusedIterator for IndexRangeIter {}
// Unit tests for `SliceItem`, `SliceRange` and the index ranges they produce.
#[cfg(test)]
mod tests {
    use rten_testing::TestCases;
    use super::{IntoSliceItems, SliceItem, SliceRange};

    // Each supported input shape (scalar, range, array, tuple) converts into
    // the expected list of `SliceItem`s.
    #[test]
    fn test_into_slice_items() {
        let x = (42).into_slice_items();
        assert_eq!(x, [SliceItem::Index(42)]);
        let x = (2..5).into_slice_items();
        assert_eq!(x, [SliceItem::Range((2..5).into())]);
        // `..5` is expected to equal `0..5`.
        let x = (..5).into_slice_items();
        assert_eq!(x, [SliceItem::Range((0..5).into())]);
        let x = (3..).into_slice_items();
        assert_eq!(x, [SliceItem::Range((3..).into())]);
        let x = [1].into_slice_items();
        assert_eq!(x, [SliceItem::Index(1)]);
        let x = [1, 2].into_slice_items();
        assert_eq!(x, [SliceItem::Index(1), SliceItem::Index(2)]);
        // Tuples may mix indices, bounded ranges and the full range.
        let x = (0, 1..2, ..).into_slice_items();
        assert_eq!(
            x,
            [
                SliceItem::Index(0),
                SliceItem::Range((1..2).into()),
                SliceItem::full_range()
            ]
        );
    }

    // Checks the concrete indices yielded by `SliceItem::index_range` for a
    // variety of starts, ends, steps and dimension sizes.
    #[test]
    fn test_index_range() {
        #[derive(Debug)]
        struct Case {
            range: SliceItem,
            dim_size: usize,
            indices: Vec<usize>,
        }
        let cases = [
            Case {
                range: SliceItem::range(0, Some(4), 1),
                dim_size: 6,
                indices: (0..4).collect(),
            },
            Case {
                range: SliceItem::range(2, Some(4), 1),
                dim_size: 6,
                indices: vec![2, 3],
            },
            // End beyond the dimension is clamped.
            Case {
                range: SliceItem::range(2, Some(128), 1),
                dim_size: 5,
                indices: vec![2, 3, 4],
            },
            Case {
                range: SliceItem::range(0, Some(5), 2),
                dim_size: 5,
                indices: vec![0, 2, 4],
            },
            // Open end with a positive step runs to the end of the dimension.
            Case {
                range: SliceItem::range(0, None, 1),
                dim_size: 6,
                indices: (0..6).collect(),
            },
            // Positive step with an end before the start selects nothing.
            Case {
                range: SliceItem::range(-1, Some(-6), 2),
                dim_size: 5,
                indices: vec![],
            },
            // Negative step with a far-negative end walks back to index 0.
            Case {
                range: SliceItem::range(-1, Some(-128), -1),
                dim_size: 5,
                indices: vec![4, 3, 2, 1, 0],
            },
            // Open end with a negative step runs back to index 0.
            Case {
                range: SliceItem::range(-1, None, -1),
                dim_size: 5,
                indices: vec![4, 3, 2, 1, 0],
            },
            Case {
                range: SliceItem::range(-1, Some(-6), -2),
                dim_size: 5,
                indices: vec![4, 2, 0],
            },
            // Negative step with an end after the start selects nothing.
            Case {
                range: SliceItem::range(1, Some(5), -2),
                dim_size: 5,
                indices: vec![],
            },
            Case {
                range: SliceItem::range(0, Some(0), 1),
                dim_size: 4,
                indices: vec![],
            },
            Case {
                range: SliceItem::range(0, Some(0), -1),
                dim_size: 4,
                indices: vec![],
            },
            Case {
                range: SliceItem::Index(2),
                dim_size: 4,
                indices: vec![2],
            },
            // An index into an empty dimension selects nothing.
            Case {
                range: SliceItem::Index(2),
                dim_size: 0,
                indices: vec![],
            },
        ];
        // `test_each` presumably runs the closure once per case (see
        // `rten_testing::TestCases`).
        cases.test_each(|case| {
            let Case {
                range,
                dim_size,
                indices,
            } = case;
            let mut index_iter = range.index_range(*dim_size).into_iter();
            // The size hint must be exact before iteration...
            let size_hint = index_iter.size_hint();
            let index_vec: Vec<_> = index_iter.by_ref().collect();
            assert_eq!(size_hint, (index_vec.len(), Some(index_vec.len())));
            assert_eq!(index_vec, *indices);
            // ...and report zero remaining once exhausted.
            assert_eq!(index_iter.size_hint(), (0, Some(0)));
        })
    }

    // Checks the step counts reported by `IndexRange::steps` for ranges
    // produced by `SliceRange::index_range`.
    #[test]
    fn test_index_range_steps() {
        #[derive(Debug)]
        struct Case {
            range: SliceRange,
            dim_size: usize,
            steps: usize,
        }
        let cases = [
            Case {
                range: SliceRange::new(0, None, 1),
                dim_size: 4,
                steps: 4,
            },
            // A step larger than the dimension still yields the first index.
            Case {
                range: SliceRange::new(0, None, 5),
                dim_size: 4,
                steps: 1,
            },
            Case {
                range: SliceRange::new(-1, None, -1),
                dim_size: 3,
                steps: 3,
            },
            Case {
                range: SliceRange::new(1, Some(0), -2),
                dim_size: 2,
                steps: 1,
            },
        ];
        cases.test_each(|case| {
            assert_eq!(case.range.index_range(case.dim_size).steps(), case.steps);
        })
    }

    #[test]
    #[should_panic(expected = "Slice step cannot be 0")]
    fn test_slice_range_zero_step() {
        SliceRange::new(0, None, 0);
    }

    // `resolve` returns `None` for out-of-bounds bounds, while
    // `resolve_clamped` clamps them first. For negative steps the resolved
    // range is in reversed coordinates (position `p` is index
    // `dim_size - 1 - p`).
    #[test]
    fn test_slice_range_resolve() {
        assert_eq!(SliceRange::new(0, Some(5), 1).resolve_clamped(10), 0..5);
        assert_eq!(SliceRange::new(0, None, 1).resolve_clamped(10), 0..10);
        assert_eq!(SliceRange::new(15, Some(20), 1).resolve_clamped(10), 10..10);
        assert_eq!(SliceRange::new(15, Some(20), 1).resolve(10), None);
        assert_eq!(SliceRange::new(4, None, 1).resolve(3), None);
        assert_eq!(SliceRange::new(0, Some(10), 1).resolve(3), None);
        assert_eq!(SliceRange::new(-5, Some(-1), 1).resolve_clamped(10), 5..9);
        assert_eq!(SliceRange::new(-20, Some(-1), 1).resolve_clamped(10), 0..9);
        assert_eq!(SliceRange::new(-20, Some(-1), 1).resolve(10), None);
        assert_eq!(SliceRange::new(-5, None, 1).resolve_clamped(10), 5..10);
        // Negative steps: reversed coordinates.
        assert_eq!(SliceRange::new(5, Some(0), -1).resolve_clamped(10), 4..9);
        assert_eq!(SliceRange::new(5, None, -1).resolve_clamped(10), 4..10);
        assert_eq!(SliceRange::new(9, None, -1).resolve_clamped(10), 0..10);
        assert_eq!(SliceRange::new(-1, Some(-4), -1).resolve_clamped(3), 0..3);
        assert_eq!(SliceRange::new(-1, None, -1).resolve_clamped(2), 0..2);
    }
}