1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
/*
    Appellation: position <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
use crate::prelude::{Layout, Shape, Stride};

/// Walks every element of a strided layout in row-major order (the last axis varies fastest).
///
/// Each step yields a `(Vec<usize>, usize)` pair whose `usize` is the element's flat storage
/// index, `offset + Σ position[i] * stride[i]`. Build one with [`IndexIter::new`] or from a
/// `&Layout`.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct IndexIter<'a> {
    /// Storage index of the element to yield next; `None` once every element has been yielded.
    next: Option<usize>,
    /// Per-axis coordinates of the element whose storage index is `next`.
    position: Vec<usize>,
    /// Extent of each axis.
    shape: &'a Shape,
    /// Storage step taken when the corresponding axis advances by one.
    stride: &'a Stride,
}

impl<'a> IndexIter<'a> {
    /// Creates an iterator over every element described by `shape` and `stride`, starting at
    /// storage index `offset`.
    ///
    /// If any axis has extent zero, the layout holds no elements and the iterator starts out
    /// exhausted. A rank-0 (scalar) shape has an empty product of 1, so it yields exactly one
    /// element, located at `offset`.
    pub fn new(offset: usize, shape: &'a Shape, stride: &'a Stride) -> Self {
        let is_empty = shape.iter().product::<usize>() == 0;
        let next = if is_empty { None } else { Some(offset) };
        let position = vec![0; *shape.rank()];
        Self {
            next,
            position,
            shape,
            stride,
        }
    }

    /// Dot product of `index` with the strides.
    ///
    /// The layout offset is NOT added; the result is relative to the offset. When `index` and
    /// the strides differ in length, the surplus entries of the longer one are ignored.
    pub(crate) fn index(&self, index: impl AsRef<[usize]>) -> usize {
        let coords = index.as_ref();
        self.stride
            .iter()
            .zip(coords)
            .fold(0, |acc, (step, coord)| acc + coord * step)
    }
}

impl<'a> DoubleEndedIterator for IndexIter<'a> {
    /// Yields elements in reverse row-major order.
    ///
    /// `IndexIter` keeps only a single (front) cursor. This method therefore consumes the front
    /// element at coordinates `p` and returns its mirror image `shape - 1 - p`, together with
    /// the mirror's storage index. Used on its own (e.g. via `.rev()`), it visits every element
    /// exactly once, last to first. Interleaving `next` and `next_back` does NOT behave like a
    /// true double-ended iterator; supporting that would need a second cursor in `IndexIter`.
    fn next_back(&mut self) -> Option<Self::Item> {
        // Snapshot the front coordinates before `next()` advances them; they pair with the
        // storage index that `next()` is about to return.
        let front = self.position.clone();
        let (_, front_scope) = self.next()?;
        let mirrored: Vec<usize> = self
            .shape
            .iter()
            .zip(front.iter())
            // `p < s` on every axis (an empty axis means `next()` already returned `None`),
            // so `s - 1 - p` cannot underflow and stays in bounds.
            .map(|(s, p)| s - 1 - p)
            .collect();
        // `front_scope == offset + index(front)`, so this recovers the layout offset, which
        // `index` on its own does not include.
        let offset = front_scope - self.index(&front);
        let scope = offset + self.index(&mirrored);
        Some((mirrored, scope))
    }
}

impl<'a> Iterator for IndexIter<'a> {
    type Item = (Vec<usize>, usize);

    /// Yields the coordinates of the current element together with its storage index, then
    /// advances to the next element in row-major order (the last axis varies fastest).
    fn next(&mut self) -> Option<Self::Item> {
        let scope = self.next?;
        // Capture the coordinates that belong to `scope` *before* advancing. Returning
        // `self.position` after the update would pair `scope` with the next element's
        // coordinates.
        let current = self.position.clone();
        let mut updated = false;
        let mut next = scope;
        for ((multi_i, max_i), stride_i) in self
            .position
            .iter_mut()
            .zip(self.shape.iter())
            .zip(self.stride.iter())
            .rev()
        {
            let next_i = *multi_i + 1;
            if next_i < *max_i {
                // Odometer step: bump this axis by one and stop.
                *multi_i = next_i;
                updated = true;
                next += stride_i;
                break;
            }
            // This axis overflowed: rewind it to 0 and carry into the next-slower axis.
            next -= *multi_i * stride_i;
            *multi_i = 0;
        }
        // If no axis could be bumped, every element has been visited.
        self.next = if updated { Some(next) } else { None };
        Some((current, scope))
    }
}

impl<'a> From<&'a Layout> for IndexIter<'a> {
    fn from(layout: &'a Layout) -> Self {
        Self::new(layout.offset, &layout.shape, &layout.strides)
    }
}