use std::iter::FusedIterator;
use std::ops::Range;
use smallvec::{smallvec, SmallVec};
/// Storage for a single multi-dimensional index: a cloneable container that
/// exposes its coordinates as a (mutable) slice of `usize`.
pub trait IndexArray: Clone + AsRef<[usize]> + AsMut<[usize]> {}

// Runtime-rank storage: `N` coordinates inline, more spill to the heap.
impl<const N: usize> IndexArray for SmallVec<[usize; N]> {}

// Compile-time-rank storage.
impl<const N: usize> IndexArray for [usize; N] {}

/// Index type yielded by dynamic-rank iterators; up to 5 coordinates are stored inline.
pub type DynIndex = SmallVec<[usize; 5]>;
pub struct Indices<Index: IndexArray>
where
Index: IndexArray,
{
start: Index,
end: Index,
next: Option<Index>,
steps: usize,
}
/// Number of points in the box whose axis `i` spans `from[i]..to[i]`.
///
/// An axis with `to[i] <= from[i]` contributes zero, making the total zero.
/// Zero axes (two empty slices) gives 1: the single empty index.
///
/// # Panics
/// Panics if `from` and `to` have different lengths.
fn steps(from: &[usize], to: &[usize]) -> usize {
    assert!(from.len() == to.len());
    // Axes are multiplied last-to-first. The order is kept deliberately because it
    // decides whether an intermediate product can overflow.
    from.iter()
        .zip(to)
        .rev()
        .map(|(&lo, &hi)| hi.saturating_sub(lo))
        .product()
}
impl<Index: IndexArray> Indices<Index> {
    /// Builds an iterator over the box whose axis `i` spans `start[i]..end[i]`.
    ///
    /// The first yielded index is `start`, unless the box is empty. A zero-rank
    /// box (no axes) still contains exactly one point: the empty index.
    fn from_start_and_end(start: Index, end: Index) -> Indices<Index> {
        let total = steps(start.as_ref(), end.as_ref());
        let zero_rank = start.as_ref().is_empty();
        let first = if total == 0 && !zero_rank {
            None
        } else {
            Some(start.clone())
        };
        Indices {
            start,
            end,
            next: first,
            steps: total,
        }
    }
}
impl<const N: usize> Indices<SmallVec<[usize; N]>> {
    /// Iterates over the Cartesian product of `ranges`, one range per axis.
    pub fn from_ranges(ranges: &[Range<usize>]) -> Indices<SmallVec<[usize; N]>> {
        let (start, end): (SmallVec<[usize; N]>, SmallVec<[usize; N]>) = ranges
            .iter()
            .map(|range| (range.start, range.end))
            .unzip();
        Self::from_start_and_end(start, end)
    }

    /// Iterates over every index of an array with the given `shape`; each axis starts at 0.
    pub fn from_shape(shape: &[usize]) -> Indices<SmallVec<[usize; N]>> {
        let end: SmallVec<[usize; N]> = shape.iter().copied().collect();
        let start = smallvec![0; end.len()];
        Self::from_start_and_end(start, end)
    }
}
impl<const N: usize> Indices<[usize; N]> {
    /// Iterates over the Cartesian product of `ranges`, one range per axis.
    pub fn from_ranges(ranges: [Range<usize>; N]) -> Indices<[usize; N]> {
        let end = ranges.clone().map(|range| range.end);
        let start = ranges.map(|range| range.start);
        Self::from_start_and_end(start, end)
    }

    /// Iterates over every index of an array with the given `shape`; each axis starts at 0.
    pub fn from_shape(shape: [usize; N]) -> Indices<[usize; N]> {
        let ranges = shape.map(|len| 0..len);
        Self::from_ranges(ranges)
    }
}
impl<Index: IndexArray> Iterator for Indices<Index> {
    type Item = Index;

    /// Yields the pending index and advances to its row-major successor.
    fn next(&mut self) -> Option<Self::Item> {
        // Move the pending index out instead of cloning it. `self.next` is
        // re-populated below if a successor exists.
        let current = self.next.take()?;
        let mut successor = current.clone();

        // Odometer-style increment: bump the last axis, carrying into earlier
        // axes whenever an axis wraps from `end` back to `start`. If every axis
        // wraps, the box has been fully visited.
        let mut has_next = false;
        for ((&dim_end, &dim_start), index) in self
            .end
            .as_ref()
            .iter()
            .zip(self.start.as_ref())
            .zip(successor.as_mut().iter_mut())
            .rev()
        {
            *index += 1;
            if *index == dim_end {
                *index = dim_start;
            } else {
                has_next = true;
                break;
            }
        }

        self.next = has_next.then_some(successor);
        // Keep `size_hint` (and hence `ExactSizeIterator::len`) equal to the number
        // of indices still to be yielded. Saturating guards against a count that
        // wrapped when it was computed. Exhaustion always reports zero.
        self.steps = if has_next {
            self.steps.saturating_sub(1)
        } else {
            0
        };
        Some(current)
    }

    /// Exact number of indices remaining.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.steps, Some(self.steps))
    }
}

// `size_hint` is exact, as `ExactSizeIterator` requires.
impl<Index: IndexArray> ExactSizeIterator for Indices<Index> {}
// Once `self.next` is `None`, `next` keeps returning `None`.
impl<Index: IndexArray> FusedIterator for Indices<Index> {}
/// Row-major iterator over all indices of an `N`-dimensional box, yielding `[usize; N]`.
pub struct NdIndices<const N: usize> {
    inner: Indices<[usize; N]>,
}

impl<const N: usize> NdIndices<N> {
    /// Iterates over the Cartesian product of `ranges`, one range per axis.
    pub fn from_ranges(ranges: [Range<usize>; N]) -> NdIndices<N> {
        let inner = Indices::<[usize; N]>::from_ranges(ranges);
        NdIndices { inner }
    }

    /// Iterates over every index of an array with the given `shape`.
    pub fn from_shape(shape: [usize; N]) -> NdIndices<N> {
        let inner = Indices::<[usize; N]>::from_shape(shape);
        NdIndices { inner }
    }
}

impl<const N: usize> Iterator for NdIndices<N> {
    type Item = [usize; N];

    fn next(&mut self) -> Option<[usize; N]> {
        Iterator::next(&mut self.inner)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        Iterator::size_hint(&self.inner)
    }
}

// Both properties are forwarded from the wrapped `Indices` iterator.
impl<const N: usize> ExactSizeIterator for NdIndices<N> {}
impl<const N: usize> FusedIterator for NdIndices<N> {}
/// Largest rank served by the fixed-size array path of `DynIndices`; shapes
/// with at most this many axes are left-padded up to exactly this rank.
const DYN_SMALL_LEN: usize = 4;
/// Backing iterator for `DynIndices`, selected by the rank of the requested shape.
enum DynIndicesInner {
/// Rank <= `DYN_SMALL_LEN`: iterate a fixed-rank `NdIndices` over the padded
/// shape, then drop the leading `pad` coordinates from each yielded index.
Small {
iter: NdIndices<DYN_SMALL_LEN>,
/// Number of leading padding axes prepended to reach `DYN_SMALL_LEN`.
pad: usize,
},
/// Rank > `DYN_SMALL_LEN`: iterate `DynIndex` (SmallVec-backed) indices directly.
Large(Indices<DynIndex>),
}
/// Row-major iterator over the indices of a box whose rank is only known at runtime.
pub struct DynIndices {
inner: DynIndicesInner,
}
/// Left-pads `shape` with size-1 axes until it has exactly `N` axes.
///
/// Returns `(pad, padded)`, where `pad` is the number of axes prepended. A
/// size-1 axis only ever takes index 0, so padding does not change the number
/// of indices the shape describes.
///
/// # Panics
/// Panics if `shape` has more than `N` axes.
fn left_pad_shape<const N: usize>(shape: &[usize]) -> (usize, [usize; N]) {
    assert!(shape.len() <= N);
    let pad = N - shape.len();
    // Start from all-ones (the padding), then overwrite the tail with the real shape.
    let mut padded_shape = [1; N];
    padded_shape[pad..].copy_from_slice(shape);
    (pad, padded_shape)
}
/// Left-pads `ranges` with `0..1` ranges until there are exactly `N` of them.
///
/// Returns `(pad, padded)`, where `pad` is the number of ranges prepended. A
/// `0..1` range contributes only the index 0, so padding does not change the
/// number of indices the ranges describe.
///
/// # Panics
/// Panics if `ranges` has more than `N` elements.
fn left_pad_ranges<const N: usize>(ranges: &[Range<usize>]) -> (usize, [Range<usize>; N]) {
    assert!(ranges.len() <= N);
    let pad = N - ranges.len();
    // `Range` is not `Copy`, so a `[0..1; N]` repeat expression is unavailable.
    // Build the array element by element instead of round-tripping through a
    // growable container.
    let padded_ranges: [Range<usize>; N] = std::array::from_fn(|i| {
        if i < pad {
            0..1
        } else {
            ranges[i - pad].clone()
        }
    });
    (pad, padded_ranges)
}
impl DynIndices {
pub fn from_shape(shape: &[usize]) -> DynIndices {
let inner = if shape.len() <= DYN_SMALL_LEN {
let (pad, padded) = left_pad_shape(shape);
DynIndicesInner::Small {
iter: NdIndices::from_shape(padded),
pad,
}
} else {
DynIndicesInner::Large(Indices::<DynIndex>::from_shape(shape))
};
DynIndices { inner }
}
pub fn from_ranges(ranges: &[Range<usize>]) -> DynIndices {
let inner = if ranges.len() <= DYN_SMALL_LEN {
let (pad, padded) = left_pad_ranges(ranges);
DynIndicesInner::Small {
iter: NdIndices::from_ranges(padded),
pad,
}
} else {
DynIndicesInner::Large(Indices::<DynIndex>::from_ranges(ranges))
};
DynIndices { inner }
}
}
impl Iterator for DynIndices {
type Item = DynIndex;
#[inline]
fn next(&mut self) -> Option<Self::Item> {
match self.inner {
DynIndicesInner::Small { ref mut iter, pad } => {
iter.next().map(|idx| SmallVec::from_slice(&idx[pad..]))
}
DynIndicesInner::Large(ref mut inner) => inner.next(),
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
match self.inner {
DynIndicesInner::Small { ref iter, .. } => iter.size_hint(),
DynIndicesInner::Large(ref inner) => inner.size_hint(),
}
}
}
impl ExactSizeIterator for DynIndices {}
impl FusedIterator for DynIndices {}
// Unit tests for the index iterators, plus an opt-in timing comparison.
#[cfg(test)]
mod tests {
use super::{DynIndices, NdIndices};
/// `NdIndices` covers four cases: an empty range, a zero-rank box, a 1-D range,
/// and an offset 2-D box. Order is row-major: the last axis varies fastest.
#[test]
fn test_nd_indices() {
// An empty range yields nothing, and keeps yielding nothing (fused).
let mut iter = NdIndices::from_ranges([0..0]);
assert_eq!(iter.next(), None);
assert_eq!(iter.next(), None);
// A zero-rank box holds exactly one point: the empty index.
let mut iter = NdIndices::from_ranges([]);
assert_eq!(iter.next(), Some([]));
assert_eq!(iter.next(), None);
let iter = NdIndices::from_ranges([0..5]);
let visited: Vec<_> = iter.collect();
assert_eq!(visited, &[[0], [1], [2], [3], [4]]);
// Ranges need not start at zero.
let iter = NdIndices::from_ranges([2..4, 2..4]);
let visited: Vec<_> = iter.collect();
assert_eq!(visited, &[[2, 2], [2, 3], [3, 2], [3, 3]]);
}
/// Runs the same cases through `DynIndices`, whose indices have runtime rank.
/// It adds a rank-5 shape, which exceeds `DYN_SMALL_LEN` and so takes the
/// non-padded path.
#[test]
fn test_dyn_indices() {
type Index = <DynIndices as Iterator>::Item;
let mut iter = DynIndices::from_ranges(&[0..0]);
assert_eq!(iter.next(), None);
assert_eq!(iter.next(), None);
// Zero-rank: padding axes are stripped, leaving a single empty index.
let mut iter = DynIndices::from_ranges(&[]);
assert_eq!(iter.next(), Some(Index::new()));
assert_eq!(iter.next(), None);
let iter = DynIndices::from_ranges(&[0..5]);
let visited: Vec<Vec<usize>> = iter.map(|ix| ix.into_iter().collect()).collect();
assert_eq!(visited, vec![vec![0], vec![1], vec![2], vec![3], vec![4]]);
let iter = DynIndices::from_ranges(&[2..4, 2..4]);
let visited: Vec<Vec<usize>> = iter.map(|ix| ix.into_iter().collect()).collect();
assert_eq!(
visited,
vec![vec![2, 2], vec![2, 3], vec![3, 2], vec![3, 3],]
);
// Five axes (> DYN_SMALL_LEN), including size-1 axes that stay at 0.
let iter = DynIndices::from_shape(&[2, 1, 1, 2, 2]);
let visited: Vec<Vec<usize>> = iter.map(|ix| ix.into_iter().collect()).collect();
assert_eq!(
visited,
vec![
vec![0, 0, 0, 0, 0],
vec![0, 0, 0, 0, 1],
vec![0, 0, 0, 1, 0],
vec![0, 0, 0, 1, 1],
vec![1, 0, 0, 0, 0],
vec![1, 0, 0, 0, 1],
vec![1, 0, 0, 1, 0],
vec![1, 0, 0, 1, 1],
]
);
}
/// Timing comparison of `DynIndices` vs `NdIndices`. Each walks a
/// `[16, 128, 128]` shape 100 times and prints its step count and elapsed
/// milliseconds. It is `#[ignore]`d, so it only runs when ignored tests are
/// requested.
#[test]
#[ignore]
fn bench_indices() {
use std::time::Instant;
// `black_box` keeps the optimizer from treating the shape as a constant.
let shape = std::hint::black_box([16, 128, 128]);
let start = Instant::now();
let mut count = 0;
for _ in 0..100 {
let indices = DynIndices::from_shape(&shape);
for _ in indices {
count += 1;
}
}
let elapsed = start.elapsed().as_millis();
println!("DynIndices stepped {} times in {} ms", count, elapsed);
let start = Instant::now();
let mut count = 0;
for _ in 0..100 {
let indices = NdIndices::from_shape(shape);
for _ in indices {
count += 1;
}
}
let elapsed = start.elapsed().as_millis();
println!("NdIndices stepped {} times in {} ms", count, elapsed);
}
}