1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
use crate::{tensor::ephemeral_view::EphemeralView, traits::ToIterator};

/// Element-wise iterator over an [`EphemeralView`].
///
/// Walks the flat logical positions `0..num_elements` one at a time and
/// yields the element stored at each position.
pub struct TensorIter<'a, Dtype> {
    /// The view being traversed.
    tensor_view: EphemeralView<'a, Dtype>,
    /// Total number of logical elements, captured from the view's shape.
    num_elements: i64,
    /// Flat logical index of the next element to yield.
    i: i64,
}

impl<'a, Dtype> TensorIter<'a, Dtype> {
    fn new(tensor: EphemeralView<'a, Dtype>) -> TensorIter<'a, Dtype> {
        TensorIter {
            i: 0,
            num_elements: tensor.shape.num_elements() as i64,
            tensor_view: tensor,
        }
    }
}

impl<'a, Dtype> Iterator for TensorIter<'a, Dtype>
where
    Dtype: Copy,
{
    type Item = Dtype;

    fn next(&mut self) -> Option<Self::Item> {
        if self.i >= self.num_elements {
            None
        } else {
            let mut vec_index = 0;
            let mut index = self.i;
            for (len, stride) in
                self.tensor_view.shape.dims_strides.iter().rev()
            {
                vec_index += (index % len) * stride;
                index /= len;
            }
            self.i += 1;
            Some(self.tensor_view.data[vec_index as usize])
        }
    }
}

/// Produces a [`TensorIter`] over this view.
///
/// The view is cloned into the iterator, which then yields every element by
/// copy.
impl<'a, Dtype> ToIterator<'a, TensorIter<'a, Dtype>, Dtype>
    for EphemeralView<'a, Dtype>
where
    Dtype: Copy,
{
    fn to_iter(&'a self) -> TensorIter<'a, Dtype> {
        let owned_view: EphemeralView<'a, Dtype> = self.clone();
        TensorIter::new(owned_view)
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        tensor::{
            ephemeral_view::{EphemeralView, ToEphemeralView},
            matrix::Matrix,
            matrix_transpose::MatrixTranspose,
            tensor_storage::IntoTensorStorage,
        },
        traits::ToIterator,
    };

    /// Checks that `TensorIter` visits elements in the logical order of the
    /// view, for a plain view, a 2-D transpose and a 3-D axis permutation.
    ///
    /// Results are collected and compared as whole vectors instead of being
    /// zipped against the expectation. Zipping stops at the shorter side, so
    /// an iterator that yields too few or too many elements would otherwise
    /// pass silently.
    #[test]
    fn test_tensor_iter() {
        let matrix =
            Matrix::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 3, 4);
        let view = EphemeralView::from(&matrix);

        // The untransposed view yields the data in the order it was given.
        let values: Vec<usize> = view.to_iter().collect();
        assert_eq!(values, (1..=12).collect::<Vec<usize>>());

        {
            let transposed = view.t();
            let values: Vec<usize> = transposed.to_iter().collect();
            assert_eq!(values, vec![1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12]);
        }

        let t2 = (0..24).collect::<Vec<i32>>().into_tensor_storage();
        let t2_view = t2.as_shape([4, 3, 2]);

        // t2_102 has shape [3, 4, 2]
        let t2_102 = t2_view.transpose(vec![1, 0, 2]);
        let values: Vec<i32> = t2_102.to_iter().collect();
        assert_eq!(
            values,
            vec![
                0, 1, 6, 7, 12, 13, 18, 19, 2, 3, 8, 9, 14, 15, 20, 21, 4, 5,
                10, 11, 16, 17, 22, 23,
            ]
        );
    }
}