ella_common/shape/
iter.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
use super::Shape;

/// Iterator over every multi-dimensional index of a shape.
///
/// Each item is an index with one coordinate per axis of `shape`. Indices are
/// produced in lexicographic order: the last axis varies fastest and carries
/// into the axis before it when it wraps around. A shape with zero elements
/// yields nothing.
///
/// Cloning the iterator snapshots its current position.
#[must_use = "iterators are lazy and do nothing unless consumed"]
#[derive(Debug, Clone)]
pub struct ShapeIndexIter<S> {
    /// The shape whose indices are being enumerated.
    pub(crate) shape: S,
    /// The next index to yield, or `None` once the iterator is exhausted.
    pub(crate) index: Option<S>,
}

impl<S: Shape> ShapeIndexIter<S> {
    /// Creates an iterator positioned at the first index of `shape`.
    ///
    /// When `shape` has no elements the iterator starts out exhausted.
    pub(crate) fn new(shape: S) -> Self {
        Self {
            index: Self::first_index(&shape),
            shape,
        }
    }

    /// Returns the all-zeros index for `shape`, or `None` when `shape.size()` is 0.
    pub(crate) fn first_index(shape: &S) -> Option<S> {
        let has_elements = shape.size() != 0;
        has_elements.then(|| S::zeros(shape.ndim()))
    }

    /// Advances `index` to its successor within `shape`, odometer style.
    ///
    /// The last axis is bumped first. An axis that reaches its extent is reset
    /// to zero and the carry moves on to the axis before it. Returns `None` when
    /// every axis wraps, meaning `index` was the final index.
    pub(crate) fn shape_next(shape: &S, mut index: S) -> Option<S> {
        let extents = shape.slice();
        for (slot, &extent) in index.as_mut().iter_mut().zip(extents).rev() {
            let bumped = *slot + 1;
            if bumped != extent {
                *slot = bumped;
                return Some(index);
            }
            // This axis overflowed: reset it and carry into the next-outer axis.
            *slot = 0;
        }
        None
    }
}

impl<S: Shape> Iterator for ShapeIndexIter<S> {
    type Item = S;

    /// Yields the current index and advances the iterator to its successor.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        // Move the current index out rather than cloning it twice. One copy is
        // handed to `shape_next` to build the successor; the original is
        // returned. If `index` is already `None` it stays `None`.
        let current = self.index.take()?;
        self.index = Self::shape_next(&self.shape, current.clone());
        Some(current)
    }

    /// Exact bounds, delegating to `ExactSizeIterator::len`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.len();
        (len, Some(len))
    }
}

// Once `index` becomes `None`, `next` short-circuits on `take()?` without
// touching it again, so every later call also returns `None`.
impl<S: Shape> std::iter::FusedIterator for ShapeIndexIter<S> {}

impl<S: Shape> ExactSizeIterator for ShapeIndexIter<S> {
    /// Number of indices still to be yielded, including the pending one.
    ///
    /// Computed as `shape.size()` minus the linear offset of the pending index
    /// (sum of `stride * coordinate` over all axes). It returns 0 once the
    /// iterator is exhausted.
    ///
    /// NOTE(review): this assumes `default_strides()` are the row-major strides
    /// (last axis fastest), matching the order `shape_next` walks — confirm in
    /// the `Shape` impls.
    fn len(&self) -> usize {
        self.index.as_ref().map_or(0, |pending| {
            let strides = self.shape.default_strides();
            let offset: usize = strides
                .slice()
                .iter()
                .zip(pending.slice())
                .map(|(&stride, &coord)| stride * coord)
                .sum();
            self.shape.size() - offset
        })
    }
}