acme_tensor/actions/iter/
layout.rs

1/*
2    Appellation: layout <mod>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use crate::shape::Layout;
6
/// A location inside an n-dimensional layout.
///
/// Pairs the per-axis coordinates of an element (`position`) with the linear
/// `index` that accompanies them. Comparison and ordering are derived, so two
/// positions compare by `index` first and by their coordinates second.
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(C)]
pub struct Position {
    /// Linear index associated with the coordinates.
    pub(crate) index: usize,
    /// Per-axis coordinates.
    pub(crate) position: Vec<usize>,
}
14
15/// An iterator over the positions of an n-dimensional tensor.
16/// Each step yields a [position](Position) containing the current position
17/// and corresponding index.
18pub struct LayoutIter {
19    layout: Layout,
20    next: Option<usize>,
21    pos: Vec<usize>,
22}
23
24impl LayoutIter {
25    pub(crate) fn new(layout: Layout) -> Self {
26        let next = if layout.size() == 0 {
27            None
28        } else {
29            // This applies to the scalar case.
30            Some(layout.offset())
31        };
32        let pos = vec![0; *layout.rank()];
33        Self { next, layout, pos }
34    }
35
36    pub(crate) fn index(&self, index: impl AsRef<[usize]>) -> usize {
37        self.layout.index(index)
38    }
39}
40
41impl DoubleEndedIterator for LayoutIter {
42    fn next_back(&mut self) -> Option<Self::Item> {
43        let Position { position, .. } = self.next()?;
44        let rev = self
45            .layout
46            .shape()
47            .get_final_position()
48            .iter()
49            .zip(position.iter())
50            .map(|(s, p)| (s - p))
51            .collect();
52        let pos = Position::new(self.index(&rev), rev);
53        Some(pos)
54    }
55}
56
57impl ExactSizeIterator for LayoutIter {
58    fn len(&self) -> usize {
59        self.layout.size()
60    }
61}
62
63impl Iterator for LayoutIter {
64    type Item = Position;
65
66    fn next(&mut self) -> Option<Self::Item> {
67        let index = match self.next {
68            None => return None,
69            Some(i) => i,
70        };
71        let cur = Position::new(index, self.pos.clone());
72        let mut updated = false;
73        let mut next = index;
74        for ((i, j), s) in self
75            .pos
76            .iter_mut()
77            .zip(self.layout.shape.iter())
78            .zip(self.layout.strides.iter())
79            .rev()
80        {
81            let next_i = *i + 1;
82            if next_i < *j {
83                *i = next_i;
84                updated = true;
85                next += s;
86                break;
87            } else {
88                next -= *i * s;
89                *i = 0;
90            }
91        }
92        self.next = if updated { Some(next) } else { None };
93        Some(cur)
94    }
95}
96
97mod impl_position {
98    use super::Position;
99    use crate::shape::Layout;
100
101    impl Position {
102        pub fn new(index: usize, position: Vec<usize>) -> Self {
103            Self { index, position }
104        }
105
106        pub fn first(rank: usize) -> Self {
107            Self::new(0, vec![0; rank])
108        }
109        /// Returns the index of the position.
110        pub fn index(&self) -> usize {
111            self.index
112        }
113        /// Given a particular layout, returns the next position.
114        pub fn next(&self, layout: &Layout) -> Option<Self> {
115            let mut position = self.position().to_vec();
116            let mut updated = false;
117            let mut next = self.index();
118            for ((i, j), s) in position
119                .iter_mut()
120                .zip(layout.shape().iter())
121                .zip(layout.strides().iter())
122                .rev()
123            {
124                let next_i = *i + 1;
125                if next_i < *j {
126                    *i = next_i;
127                    updated = true;
128                    next += s;
129                    break;
130                } else {
131                    next -= *i * s;
132                    *i = 0;
133                }
134            }
135            if updated {
136                Some(Self::new(next, position))
137            } else {
138                None
139            }
140        }
141        /// Returns a reference to the position.
142        pub fn position(&self) -> &[usize] {
143            &self.position
144        }
145        /// Returns a mutable reference to the position.
146        pub fn position_mut(&mut self) -> &mut [usize] {
147            &mut self.position
148        }
149    }
150
151    impl From<(usize, Vec<usize>)> for Position {
152        fn from((idx, pos): (usize, Vec<usize>)) -> Self {
153            Self::new(idx, pos)
154        }
155    }
156
157    impl From<(Vec<usize>, usize)> for Position {
158        fn from((pos, idx): (Vec<usize>, usize)) -> Self {
159            Self::new(idx, pos)
160        }
161    }
162
163    impl From<Position> for (usize, Vec<usize>) {
164        fn from(pos: Position) -> (usize, Vec<usize>) {
165            (pos.index, pos.position)
166        }
167    }
168}