use crate::layout::Layout;
use alloc::vec;
use alloc::vec::Vec;
/// Iterator over the storage offsets of every element described by a [`Layout`].
///
/// Elements are visited in logical order with the last dimension varying
/// fastest. Each yielded `usize` is that element's offset into the underlying
/// storage, which accounts for the layout's start offset and its strides.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct StridedIter<'a> {
    /// Storage offset of the next element to yield. Signed because strides
    /// are signed (`isize`) and are added to / subtracted from it.
    storage_index: isize,
    /// Per-dimension position of the next element; advanced like an odometer
    /// whose last entry turns over fastest.
    multi_index: Vec<usize>,
    /// Layout being walked; supplies the shape and strides on every step.
    layout: &'a Layout,
    /// Number of elements not yet yielded.
    remaining: usize,
}
impl<'a> StridedIter<'a> {
pub fn new(layout: &'a Layout) -> Self {
let ndims = layout.num_dims();
Self {
storage_index: layout.start_offset() as isize,
multi_index: vec![0; ndims],
layout,
remaining: layout.num_elements(),
}
}
}
impl Iterator for StridedIter<'_> {
    type Item = usize;

    /// Returns the current element's storage offset, then steps the
    /// per-dimension counters forward by one element.
    ///
    /// The step works like an odometer. The last axis is bumped first. If it
    /// stays within its extent, the offset moves by that axis' stride and we
    /// stop. Otherwise the counter wraps to zero, the offset is rewound by
    /// `(extent - 1) * stride`, and the carry moves to the next axis outward.
    fn next(&mut self) -> Option<usize> {
        if self.remaining == 0 {
            return None;
        }
        debug_assert!(
            self.storage_index >= 0,
            "StridedIter: negative storage index"
        );
        let current = self.storage_index as usize;
        self.remaining -= 1;

        let (shape, strides) = (self.layout.shape(), self.layout.strides());
        let mut axis = shape.num_dims();
        while axis > 0 {
            axis -= 1;
            let slot = &mut self.multi_index[axis];
            *slot += 1;
            let extent = shape[axis];
            if *slot < extent {
                self.storage_index += strides[axis];
                break;
            }
            // Wrapped: reset this axis and carry into the next outer one.
            *slot = 0;
            self.storage_index -= (extent as isize - 1) * strides[axis];
        }

        Some(current)
    }

    /// Exact bounds: the number of elements not yet yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.remaining;
        (left, Some(left))
    }
}
impl ExactSizeIterator for StridedIter<'_> {}
#[cfg(test)]
mod tests {
    use super::*;
    use burn_std::Shape;

    #[test]
    fn test_contiguous_iteration() {
        let layout = Layout::contiguous(Shape::from(vec![2, 3]));
        let indices: Vec<_> = StridedIter::new(&layout).collect();
        assert_eq!(indices, vec![0, 1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_transposed_iteration() {
        let layout = Layout::contiguous(Shape::from(vec![2, 3])).transpose(0, 1);
        let indices: Vec<_> = StridedIter::new(&layout).collect();
        assert_eq!(indices, vec![0, 3, 1, 4, 2, 5]);
    }

    #[test]
    fn test_narrowed_iteration() {
        let layout = Layout::contiguous(Shape::from(vec![2, 4])).narrow(1, 1, 2);
        let indices: Vec<_> = StridedIter::new(&layout).collect();
        assert_eq!(indices, vec![1, 2, 5, 6]);
    }

    /// Pins the ExactSizeIterator contract and that an exhausted iterator
    /// keeps returning `None`.
    #[test]
    fn test_len_tracks_remaining_and_stays_exhausted() {
        let layout = Layout::contiguous(Shape::from(vec![2, 3]));
        let mut iter = StridedIter::new(&layout);
        assert_eq!(iter.len(), 6);
        assert_eq!(iter.size_hint(), (6, Some(6)));

        assert_eq!(iter.next(), Some(0));
        assert_eq!(iter.len(), 5);

        let rest: Vec<_> = iter.by_ref().collect();
        assert_eq!(rest, vec![1, 2, 3, 4, 5]);
        assert_eq!(iter.len(), 0);
        assert_eq!(iter.next(), None);
        assert_eq!(iter.next(), None);
    }
}