acme_tensor/actions/iter/
layout.rs1use crate::shape::Layout;
/// A single element location produced while walking a layout.
///
/// Pairs the flat `index` associated with the element with its per-axis
/// coordinates. Ordering (derived) compares `index` first, then `position`.
#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(C)]
pub struct Position {
    /// Linear (flat) index associated with this element.
    pub(crate) index: usize,
    /// Coordinate along each axis, outermost axis first.
    pub(crate) position: Vec<usize>,
}
15pub struct LayoutIter {
19 layout: Layout,
20 next: Option<usize>,
21 pos: Vec<usize>,
22}
23
24impl LayoutIter {
25 pub(crate) fn new(layout: Layout) -> Self {
26 let next = if layout.size() == 0 {
27 None
28 } else {
29 Some(layout.offset())
31 };
32 let pos = vec![0; *layout.rank()];
33 Self { next, layout, pos }
34 }
35
36 pub(crate) fn index(&self, index: impl AsRef<[usize]>) -> usize {
37 self.layout.index(index)
38 }
39}
40
41impl DoubleEndedIterator for LayoutIter {
42 fn next_back(&mut self) -> Option<Self::Item> {
43 let Position { position, .. } = self.next()?;
44 let rev = self
45 .layout
46 .shape()
47 .get_final_position()
48 .iter()
49 .zip(position.iter())
50 .map(|(s, p)| (s - p))
51 .collect();
52 let pos = Position::new(self.index(&rev), rev);
53 Some(pos)
54 }
55}
56
57impl ExactSizeIterator for LayoutIter {
58 fn len(&self) -> usize {
59 self.layout.size()
60 }
61}
62
63impl Iterator for LayoutIter {
64 type Item = Position;
65
66 fn next(&mut self) -> Option<Self::Item> {
67 let index = match self.next {
68 None => return None,
69 Some(i) => i,
70 };
71 let cur = Position::new(index, self.pos.clone());
72 let mut updated = false;
73 let mut next = index;
74 for ((i, j), s) in self
75 .pos
76 .iter_mut()
77 .zip(self.layout.shape.iter())
78 .zip(self.layout.strides.iter())
79 .rev()
80 {
81 let next_i = *i + 1;
82 if next_i < *j {
83 *i = next_i;
84 updated = true;
85 next += s;
86 break;
87 } else {
88 next -= *i * s;
89 *i = 0;
90 }
91 }
92 self.next = if updated { Some(next) } else { None };
93 Some(cur)
94 }
95}
96
97mod impl_position {
98 use super::Position;
99 use crate::shape::Layout;
100
101 impl Position {
102 pub fn new(index: usize, position: Vec<usize>) -> Self {
103 Self { index, position }
104 }
105
106 pub fn first(rank: usize) -> Self {
107 Self::new(0, vec![0; rank])
108 }
109 pub fn index(&self) -> usize {
111 self.index
112 }
113 pub fn next(&self, layout: &Layout) -> Option<Self> {
115 let mut position = self.position().to_vec();
116 let mut updated = false;
117 let mut next = self.index();
118 for ((i, j), s) in position
119 .iter_mut()
120 .zip(layout.shape().iter())
121 .zip(layout.strides().iter())
122 .rev()
123 {
124 let next_i = *i + 1;
125 if next_i < *j {
126 *i = next_i;
127 updated = true;
128 next += s;
129 break;
130 } else {
131 next -= *i * s;
132 *i = 0;
133 }
134 }
135 if updated {
136 Some(Self::new(next, position))
137 } else {
138 None
139 }
140 }
141 pub fn position(&self) -> &[usize] {
143 &self.position
144 }
145 pub fn position_mut(&mut self) -> &mut [usize] {
147 &mut self.position
148 }
149 }
150
151 impl From<(usize, Vec<usize>)> for Position {
152 fn from((idx, pos): (usize, Vec<usize>)) -> Self {
153 Self::new(idx, pos)
154 }
155 }
156
157 impl From<(Vec<usize>, usize)> for Position {
158 fn from((pos, idx): (Vec<usize>, usize)) -> Self {
159 Self::new(idx, pos)
160 }
161 }
162
163 impl From<Position> for (usize, Vec<usize>) {
164 fn from(pos: Position) -> (usize, Vec<usize>) {
165 (pos.index, pos.position)
166 }
167 }
168}