use std::iter::FusedIterator;
use crate::Layout;
/// Capacity of the fixed-size per-dimension arrays stored inside [`NdIter`].
///
/// `NdIter::new` debug-asserts that the rank of the layouts it receives does
/// not exceed this value.
pub const MAX_DIMS: usize = 8;
/// Iterator over the starting offsets of the innermost blocks of `N` layouts
/// that share one shape.
///
/// Construction collapses adjacent dimensions wherever every layout allows it.
/// The last collapsed dimension is exposed through [`inner_size`] and
/// [`inner_strides`]; the remaining ("outer") dimensions are walked by the
/// iterator, which yields one `[usize; N]` array of offsets (one per layout)
/// for every combination of outer coordinates.
///
/// [`inner_size`]: NdIter::inner_size
/// [`inner_strides`]: NdIter::inner_strides
#[derive(Debug, Clone)]
pub struct NdIter<const N: usize> {
    /// Extent of the innermost collapsed dimension, i.e. how many elements
    /// each yielded block contains.
    pub inner_size: usize,
    /// Stride of the innermost collapsed dimension, one entry per layout.
    pub inner_strides: [usize; N],
    /// Extents of the outer collapsed dimensions; only the first `outer_len`
    /// entries are meaningful, the rest are zero.
    outer_dims: [usize; MAX_DIMS],
    /// Strides of the outer collapsed dimensions, indexed `[layout][dim]`;
    /// only the first `outer_len` entries of each row are meaningful.
    outer_strides: [[usize; MAX_DIMS]; N],
    /// Number of outer collapsed dimensions.
    outer_len: usize,
    /// Offsets that the next call to `next` will yield, one per layout.
    offsets: [usize; N],
    /// Current coordinate along each outer dimension.
    coords: [usize; MAX_DIMS],
    /// Number of items not yet yielded.
    remaining: usize,
}
impl<const N: usize> NdIter<N> {
    /// Builds an iterator over `layouts`, which are all expected to describe
    /// the same shape (the shape is read from `layouts[0]`, so `N` must be at
    /// least 1).
    ///
    /// Adjacent dimensions are collapsed into a single group when one of them
    /// has extent 1, or when for every layout the stride of the outer one
    /// equals the stride of the inner one times the inner extent. The last
    /// group becomes the inner block (`inner_size` / `inner_strides`); the
    /// groups before it are the outer dimensions the iterator steps through.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the rank exceeds [`MAX_DIMS`] or if the
    /// layouts do not all report the same dims. In any build, panics if more
    /// than [`MAX_DIMS`] groups remain after collapsing.
    pub fn new(layouts: [&Layout; N]) -> NdIter<N> {
        let dims = layouts[0].dims();
        let rank = dims.len();
        debug_assert!(
            rank <= MAX_DIMS,
            "rank {} exceeds MAX_DIMS={}",
            rank,
            MAX_DIMS
        );
        #[cfg(debug_assertions)]
        for layout in layouts.iter() {
            debug_assert_eq!(layout.dims(), dims);
        }

        // Collapsed extents, and per-layout strides of each collapsed group.
        let mut merged_dims = [0usize; MAX_DIMS];
        let mut merged_strides = [[0usize; MAX_DIMS]; N];
        let mut groups = 1usize;

        if rank == 0 {
            // A scalar is treated as a single group of one element, stride 0.
            merged_dims[0] = 1;
        } else {
            merged_dims[0] = dims[0];
            for (row, layout) in merged_strides.iter_mut().zip(layouts.iter()) {
                row[0] = layout.stride()[0];
            }

            for (axis, &extent) in dims.iter().enumerate().skip(1) {
                let top = groups - 1;
                let group_extent = merged_dims[top];

                // `fold`: merge this axis into the current group.
                // `adopt_stride`: the group takes over this axis' stride.
                let (fold, adopt_stride) = match (group_extent, extent) {
                    // A group of extent 1 contributes nothing; take the new stride.
                    (1, _) => (true, true),
                    // An axis of extent 1 contributes nothing; keep the group's stride.
                    (_, 1) => (true, false),
                    _ => {
                        let chained = merged_strides
                            .iter()
                            .zip(layouts.iter())
                            .all(|(row, layout)| row[top] == layout.stride()[axis] * extent);
                        (chained, true)
                    }
                };

                if !fold {
                    // Start a new group for this axis.
                    merged_dims[groups] = extent;
                    for (row, layout) in merged_strides.iter_mut().zip(layouts.iter()) {
                        row[groups] = layout.stride()[axis];
                    }
                    groups += 1;
                    continue;
                }

                merged_dims[top] = group_extent * extent;
                if adopt_stride {
                    for (row, layout) in merged_strides.iter_mut().zip(layouts.iter()) {
                        row[top] = layout.stride()[axis];
                    }
                }
            }
        }

        // The last group is the inner block; everything before it is "outer".
        let outer_len = groups - 1;
        let inner_size = merged_dims[outer_len];
        let mut inner_strides = [0usize; N];
        for (dst, row) in inner_strides.iter_mut().zip(merged_strides.iter()) {
            *dst = row[outer_len];
        }

        // Slots past the last group were never written, so clearing the inner
        // slot leaves exactly the outer entries followed by zeros.
        let mut outer_dims = merged_dims;
        outer_dims[outer_len] = 0;
        let mut outer_strides = merged_strides;
        for row in outer_strides.iter_mut() {
            row[outer_len] = 0;
        }

        let mut offsets = [0usize; N];
        for (dst, layout) in offsets.iter_mut().zip(layouts.iter()) {
            *dst = layout.start_offset();
        }

        let remaining = outer_dims[..outer_len].iter().product();

        Self {
            inner_size,
            inner_strides,
            outer_dims,
            outer_strides,
            outer_len,
            offsets,
            coords: [0; MAX_DIMS],
            remaining,
        }
    }
}
impl<const N: usize> Iterator for NdIter<N> {
    /// Offset of the current inner block in each of the `N` layouts.
    type Item = [usize; N];

    /// Yields the current offsets, then advances the outer coordinates like an
    /// odometer, with the last outer dimension moving fastest.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining == 0 {
            return None;
        }
        self.remaining -= 1;
        let current = self.offsets;

        for axis in (0..self.outer_len).rev() {
            let extent = self.outer_dims[axis];
            self.coords[axis] += 1;
            let wrapped = self.coords[axis] >= extent;

            for (offset, strides) in self.offsets.iter_mut().zip(self.outer_strides.iter()) {
                let step = strides[axis];
                *offset += step;
                if wrapped {
                    // Rewind this axis to coordinate 0.
                    *offset -= extent * step;
                }
            }

            if !wrapped {
                break;
            }
            // Carry into the next slower axis.
            self.coords[axis] = 0;
        }

        Some(current)
    }

    /// Exact: both bounds equal the number of items still to be yielded.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.remaining;
        (left, Some(left))
    }
}
// `size_hint` always reports `(remaining, Some(remaining))`, so the `len()`
// provided by `ExactSizeIterator` is exact.
impl<const N: usize> ExactSizeIterator for NdIter<N> {}
// `next` returns `None` whenever `remaining == 0` and only ever decrements
// `remaining`, so once exhausted the iterator keeps returning `None`.
impl<const N: usize> FusedIterator for NdIter<N> {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shape::Shape;

    /// Builds a zero-offset layout from explicit `dims` and `strides`.
    fn layout(dims: &[usize], strides: &[usize]) -> Layout {
        let shape = Shape::from(dims.to_vec());
        Layout::new(shape, strides.to_vec(), 0)
    }

    // A rank-0 layout is a single block of one element.
    #[test]
    fn rank0_scalar() {
        let scalar = Layout::contiguous(());
        let mut iter = NdIter::new([&scalar, &scalar]);
        assert_eq!(iter.inner_size, 1);
        assert_eq!(iter.inner_strides, [0, 0]);
        assert_eq!(iter.len(), 1);
        assert_eq!(iter.next(), Some([0, 0]));
        assert_eq!(iter.next(), None);
    }

    // A rank-1 contiguous layout is one block covering the whole dimension.
    #[test]
    fn rank1_contiguous_single_block() {
        let vector = Layout::contiguous(&[5]);
        let mut iter = NdIter::new([&vector, &vector]);
        assert_eq!(iter.inner_size, 5);
        assert_eq!(iter.inner_strides, [1, 1]);
        assert_eq!(iter.len(), 1);
        assert_eq!(iter.next(), Some([0, 0]));
        assert_eq!(iter.next(), None);
    }

    // Both dimensions of a contiguous matrix collapse into one block.
    #[test]
    fn rank2_contiguous_merges_to_one_block() {
        let matrix = Layout::contiguous(&[3, 4]);
        let iter = NdIter::new([&matrix, &matrix]);
        assert_eq!(iter.inner_size, 12);
        assert_eq!(iter.inner_strides, [1, 1]);
        assert_eq!(iter.len(), 1);
    }

    // All three dimensions of a contiguous layout collapse into one block.
    #[test]
    fn rank3_contiguous_fully_merged() {
        let cube = Layout::contiguous(&[2, 3, 4]);
        let iter = NdIter::new([&cube]);
        assert_eq!(iter.inner_size, 24);
        assert_eq!(iter.inner_strides, [1]);
        assert_eq!(iter.len(), 1);
    }

    // The outermost stride (24) leaves a gap, so only the two inner
    // dimensions merge and two blocks are yielded.
    #[test]
    fn rank3_outer_gap_partial_merge() {
        let padded = layout(&[2, 3, 4], &[24, 4, 1]);
        let iter = NdIter::new([&padded]);
        assert_eq!(iter.inner_size, 12);
        assert_eq!(iter.inner_strides, [1]);
        assert_eq!(iter.len(), 2);
        let starts = iter.collect::<Vec<_>>();
        assert_eq!(starts, vec![[0], [24]]);
    }

    // Strides [1, 3] do not chain, so the dimensions stay separate.
    #[test]
    fn rank2_no_merge() {
        let strided = layout(&[3, 4], &[1, 3]);
        let iter = NdIter::new([&strided, &strided]);
        assert_eq!(iter.inner_size, 4);
        assert_eq!(iter.inner_strides, [3, 3]);
        assert_eq!(iter.len(), 3);
        let starts = iter.collect::<Vec<_>>();
        assert_eq!(starts, vec![[0, 0], [1, 1], [2, 2]]);
    }

    // All-zero strides satisfy the merge condition.
    #[test]
    fn broadcast_zeros_merge() {
        let broadcast = layout(&[3, 4], &[0, 0]);
        let iter = NdIter::new([&broadcast, &broadcast]);
        assert_eq!(iter.inner_size, 12);
        assert_eq!(iter.inner_strides, [0, 0]);
        assert_eq!(iter.len(), 1);
    }

    // A contiguous layout and an all-zero-stride layout can merge together.
    #[test]
    fn mixed_contiguous_and_broadcast_merge() {
        let dense = Layout::contiguous(&[3, 4]);
        let broadcast = layout(&[3, 4], &[0, 0]);
        let iter = NdIter::new([&dense, &broadcast]);
        assert_eq!(iter.inner_size, 12);
        assert_eq!(iter.inner_strides, [1, 0]);
        assert_eq!(iter.len(), 1);
    }

    // One layout that cannot merge prevents merging for all layouts.
    #[test]
    fn offsets_lhs_contiguous_rhs_strided() {
        let dense = Layout::contiguous(&[2, 3]);
        let strided = layout(&[2, 3], &[1, 2]);
        let iter = NdIter::new([&dense, &strided]);
        assert_eq!(iter.inner_size, 3);
        assert_eq!(iter.inner_strides, [1, 2]);
        assert_eq!(iter.len(), 2);
        let starts = iter.collect::<Vec<_>>();
        assert_eq!(starts, vec![[0, 0], [3, 1]]);
    }

    // The first yielded offset is the layout's start offset.
    #[test]
    fn start_offset_reflected_in_first_iter() {
        let shifted = Layout::contiguous_with_offset(4, 7);
        let mut iter = NdIter::new([&shifted]);
        assert_eq!(iter.next(), Some([7]));
        assert_eq!(iter.next(), None);
    }

    // Later offsets are the start offset plus the outer stride steps.
    #[test]
    fn start_offset_advances_with_outer_dims() {
        let shifted = Layout::new(Shape::from(vec![2, 3]), vec![4, 1], 10);
        let starts = NdIter::new([&shifted]).collect::<Vec<_>>();
        assert_eq!(starts, vec![[10], [14]]);
    }
}