1use super::Dimension;
9use crate::dimension::IntoDimension;
10use crate::split_at::SplitAt;
11use crate::zip::Offset;
12use crate::Axis;
13use crate::Layout;
14use crate::NdProducer;
15use crate::{ArrayBase, Data};
16
/// An iterator over the indices of an array shape, in row-major ("C") order:
/// the last axis varies fastest.
///
/// Obtained by calling `.into_iter()` on an [`Indices`] value, e.g. from
/// [`indices`] or [`indices_of`].
#[derive(Clone)]
pub struct IndicesIter<D>
{
    /// The shape whose indices are being walked.
    dim: D,
    /// The next index to yield, or `None` once the iterator is exhausted.
    index: Option<D>,
}
26
27pub fn indices<E>(shape: E) -> Indices<E::Dim>
32where E: IntoDimension
33{
34 let dim = shape.into_dimension();
35 Indices {
36 start: E::Dim::zeros(dim.ndim()),
37 dim,
38 }
39}
40
41pub fn indices_of<S, D>(array: &ArrayBase<S, D>) -> Indices<D>
46where
47 S: Data,
48 D: Dimension,
49{
50 indices(array.dim())
51}
52
53impl<D> Iterator for IndicesIter<D>
54where D: Dimension
55{
56 type Item = D::Pattern;
57 #[inline]
58 fn next(&mut self) -> Option<Self::Item>
59 {
60 let index = match self.index {
61 None => return None,
62 Some(ref ix) => ix.clone(),
63 };
64 self.index = self.dim.next_for(index.clone());
65 Some(index.into_pattern())
66 }
67
68 fn size_hint(&self) -> (usize, Option<usize>)
69 {
70 let l = match self.index {
71 None => 0,
72 Some(ref ix) => {
73 let gone = self
74 .dim
75 .default_strides()
76 .slice()
77 .iter()
78 .zip(ix.slice().iter())
79 .fold(0, |s, (&a, &b)| s + a * b);
80 self.dim.size() - gone
81 }
82 };
83 (l, Some(l))
84 }
85
86 fn fold<B, F>(self, init: B, mut f: F) -> B
87 where F: FnMut(B, D::Pattern) -> B
88 {
89 let IndicesIter { mut index, dim } = self;
90 let ndim = dim.ndim();
91 if ndim == 0 {
92 return match index {
93 Some(ix) => f(init, ix.into_pattern()),
94 None => init,
95 };
96 }
97 let inner_axis = ndim - 1;
98 let inner_len = dim[inner_axis];
99 let mut acc = init;
100 while let Some(mut ix) = index {
101 for i in ix[inner_axis]..inner_len {
103 ix[inner_axis] = i;
104 acc = f(acc, ix.clone().into_pattern());
105 }
106 index = dim.next_for(ix);
107 }
108 acc
109 }
110}
111
112impl<D> ExactSizeIterator for IndicesIter<D> where D: Dimension {}
113
114impl<D> IntoIterator for Indices<D>
115where D: Dimension
116{
117 type Item = D::Pattern;
118 type IntoIter = IndicesIter<D>;
119 fn into_iter(self) -> Self::IntoIter
120 {
121 let sz = self.dim.size();
122 let index = if sz != 0 { Some(self.start) } else { None };
123 IndicesIter { index, dim: self.dim }
124 }
125}
126
/// The set of all indices of an array shape.
///
/// Can be turned into an iterator ([`IndicesIter`]) or used as an
/// [`NdProducer`] that produces each index as the dimension's `Pattern`.
#[derive(Copy, Clone, Debug)]
pub struct Indices<D>
where D: Dimension
{
    /// The first index of this set; all-zero unless produced by `split_at`,
    /// which offsets the second half along the split axis.
    start: D,
    /// The extent of the index set along each axis.
    dim: D,
}
137
/// The "pointer" type of the [`Indices`] producer: rather than pointing at
/// memory, it simply holds the current (absolute) index.
#[derive(Copy, Clone, Debug)]
pub struct IndexPtr<D>
{
    /// The index this pointer currently refers to.
    index: D,
}
143
144impl<D> Offset for IndexPtr<D>
145where D: Dimension + Copy
146{
147 type Stride = usize;
149
150 unsafe fn stride_offset(mut self, stride: Self::Stride, index: usize) -> Self
151 {
152 self.index[stride] += index;
153 self
154 }
155 private_impl! {}
156}
157
158impl<D: Dimension + Copy> NdProducer for Indices<D>
173{
174 type Item = D::Pattern;
175 type Dim = D;
176 type Ptr = IndexPtr<D>;
177 type Stride = usize;
178
179 private_impl! {}
180
181 fn raw_dim(&self) -> Self::Dim
182 {
183 self.dim
184 }
185
186 fn equal_dim(&self, dim: &Self::Dim) -> bool
187 {
188 self.dim.equal(dim)
189 }
190
191 fn as_ptr(&self) -> Self::Ptr
192 {
193 IndexPtr { index: self.start }
194 }
195
196 fn layout(&self) -> Layout
197 {
198 if self.dim.ndim() <= 1 {
199 Layout::one_dimensional()
200 } else {
201 Layout::none()
202 }
203 }
204
205 unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item
206 {
207 ptr.index.into_pattern()
208 }
209
210 unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr
211 {
212 let mut index = *i;
213 index += &self.start;
214 IndexPtr { index }
215 }
216
217 fn stride_of(&self, axis: Axis) -> Self::Stride
218 {
219 axis.index()
220 }
221
222 #[inline(always)]
223 fn contiguous_stride(&self) -> Self::Stride
224 {
225 0
226 }
227
228 fn split_at(self, axis: Axis, index: usize) -> (Self, Self)
229 {
230 let start_a = self.start;
231 let mut start_b = start_a;
232 let (a, b) = self.dim.split_at(axis, index);
233 start_b[axis.index()] += index;
234 (Indices { start: start_a, dim: a }, Indices { start: start_b, dim: b })
235 }
236}
237
/// An iterator over the indices of an array shape, in column-major
/// ("Fortran") order: it advances via `next_for_f`, and its length
/// bookkeeping uses Fortran strides.
///
/// Created by [`indices_iter_f`].
#[derive(Clone)]
pub struct IndicesIterF<D>
{
    /// The shape whose indices are being walked.
    dim: D,
    /// The next index to yield (only meaningful while `has_remaining`).
    index: D,
    /// Whether `index` still has to be yielded.
    has_remaining: bool,
}
248
249pub fn indices_iter_f<E>(shape: E) -> IndicesIterF<E::Dim>
250where E: IntoDimension
251{
252 let dim = shape.into_dimension();
253 let zero = E::Dim::zeros(dim.ndim());
254 IndicesIterF {
255 has_remaining: dim.size_checked() != Some(0),
256 index: zero,
257 dim,
258 }
259}
260
261impl<D> Iterator for IndicesIterF<D>
262where D: Dimension
263{
264 type Item = D::Pattern;
265 #[inline]
266 fn next(&mut self) -> Option<Self::Item>
267 {
268 if !self.has_remaining {
269 None
270 } else {
271 let elt = self.index.clone().into_pattern();
272 self.has_remaining = self.dim.next_for_f(&mut self.index);
273 Some(elt)
274 }
275 }
276
277 fn size_hint(&self) -> (usize, Option<usize>)
278 {
279 if !self.has_remaining {
280 return (0, Some(0));
281 }
282 let gone = self
283 .dim
284 .fortran_strides()
285 .slice()
286 .iter()
287 .zip(self.index.slice().iter())
288 .fold(0, |s, (&a, &b)| s + a * b);
289 let l = self.dim.size() - gone;
290 (l, Some(l))
291 }
292}
293
294impl<D> ExactSizeIterator for IndicesIterF<D> where D: Dimension {}
295
#[cfg(test)]
mod tests
{
    use super::indices;
    use super::indices_iter_f;

    #[test]
    fn test_indices_iter_c_size_hint()
    {
        let dim = (3, 4);
        let mut it = indices(dim).into_iter();
        let mut len = dim.0 * dim.1;
        assert_eq!(it.len(), len);
        while it.next().is_some() {
            len -= 1;
            assert_eq!(it.len(), len);
        }
        assert_eq!(len, 0);
    }

    #[test]
    fn test_indices_iter_c_fold()
    {
        macro_rules! run_test {
            ($dim:expr) => {
                for num_consume in 0..3 {
                    let mut it = indices($dim).into_iter();
                    for _ in 0..num_consume {
                        it.next();
                    }
                    let clone = it.clone();
                    let len = it.len();
                    let acc = clone.fold(0, |acc, ix| {
                        assert_eq!(ix, it.next().unwrap());
                        acc + 1
                    });
                    assert_eq!(acc, len);
                    assert!(it.next().is_none());
                }
            };
        }
        run_test!(());
        run_test!((2,));
        run_test!((2, 3));
        run_test!((2, 0, 3));
        run_test!((2, 3, 4));
        run_test!((2, 3, 4, 2));
    }

    #[test]
    fn test_indices_iter_f_size_hint()
    {
        let dim = (3, 4);
        let mut it = indices_iter_f(dim);
        let mut len = dim.0 * dim.1;
        assert_eq!(it.len(), len);
        while it.next().is_some() {
            len -= 1;
            assert_eq!(it.len(), len);
        }
        assert_eq!(len, 0);
    }

    /// Row-major iteration varies the last axis fastest.
    #[test]
    fn test_indices_iter_c_order()
    {
        let ixs: Vec<_> = indices((2, 3)).into_iter().collect();
        assert_eq!(ixs, vec![(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]);
    }

    /// Column-major iteration varies the first axis fastest.
    #[test]
    fn test_indices_iter_f_order()
    {
        let ixs: Vec<_> = indices_iter_f((2, 3)).collect();
        assert_eq!(ixs, vec![(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]);
    }

    /// A shape with a zero-length axis has no indices in either order.
    #[test]
    fn test_indices_iter_empty_shape()
    {
        let mut c = indices((2, 0, 3)).into_iter();
        assert_eq!(c.len(), 0);
        assert!(c.next().is_none());

        let mut f = indices_iter_f((2, 0, 3));
        assert_eq!(f.len(), 0);
        assert!(f.next().is_none());
    }

    /// A zero-dimensional shape has exactly one (empty) index.
    #[test]
    fn test_indices_iter_f_zero_dim()
    {
        let mut it = indices_iter_f(());
        assert_eq!(it.len(), 1);
        assert!(it.next().is_some());
        assert_eq!(it.len(), 0);
        assert!(it.next().is_none());
    }
}