ndarray/
indexes.rs

// Copyright 2014-2016 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
8use super::Dimension;
9use crate::dimension::IntoDimension;
10use crate::split_at::SplitAt;
11use crate::zip::Offset;
12use crate::Axis;
13use crate::Layout;
14use crate::NdProducer;
15use crate::{ArrayBase, Data};
16
/// An iterator over the indices of an array shape, in row-major
/// ("C" order: the last axis varies fastest).
///
/// Iterator element type is `D::Pattern` (for example `(usize, usize)` for a
/// two-dimensional shape), not `D` itself.
#[derive(Clone)]
pub struct IndicesIter<D>
{
    /// The shape being iterated over.
    dim: D,
    /// The next index to yield, or `None` once the iterator is exhausted.
    index: Option<D>,
}
26
27/// Create an iterable of the array shape `shape`.
28///
29/// *Note:* prefer higher order methods, arithmetic operations and
30/// non-indexed iteration before using indices.
31pub fn indices<E>(shape: E) -> Indices<E::Dim>
32where E: IntoDimension
33{
34    let dim = shape.into_dimension();
35    Indices {
36        start: E::Dim::zeros(dim.ndim()),
37        dim,
38    }
39}
40
41/// Return an iterable of the indices of the passed-in array.
42///
43/// *Note:* prefer higher order methods, arithmetic operations and
44/// non-indexed iteration before using indices.
45pub fn indices_of<S, D>(array: &ArrayBase<S, D>) -> Indices<D>
46where
47    S: Data,
48    D: Dimension,
49{
50    indices(array.dim())
51}
52
53impl<D> Iterator for IndicesIter<D>
54where D: Dimension
55{
56    type Item = D::Pattern;
57    #[inline]
58    fn next(&mut self) -> Option<Self::Item>
59    {
60        let index = match self.index {
61            None => return None,
62            Some(ref ix) => ix.clone(),
63        };
64        self.index = self.dim.next_for(index.clone());
65        Some(index.into_pattern())
66    }
67
68    fn size_hint(&self) -> (usize, Option<usize>)
69    {
70        let l = match self.index {
71            None => 0,
72            Some(ref ix) => {
73                let gone = self
74                    .dim
75                    .default_strides()
76                    .slice()
77                    .iter()
78                    .zip(ix.slice().iter())
79                    .fold(0, |s, (&a, &b)| s + a * b);
80                self.dim.size() - gone
81            }
82        };
83        (l, Some(l))
84    }
85
86    fn fold<B, F>(self, init: B, mut f: F) -> B
87    where F: FnMut(B, D::Pattern) -> B
88    {
89        let IndicesIter { mut index, dim } = self;
90        let ndim = dim.ndim();
91        if ndim == 0 {
92            return match index {
93                Some(ix) => f(init, ix.into_pattern()),
94                None => init,
95            };
96        }
97        let inner_axis = ndim - 1;
98        let inner_len = dim[inner_axis];
99        let mut acc = init;
100        while let Some(mut ix) = index {
101            // unroll innermost axis
102            for i in ix[inner_axis]..inner_len {
103                ix[inner_axis] = i;
104                acc = f(acc, ix.clone().into_pattern());
105            }
106            index = dim.next_for(ix);
107        }
108        acc
109    }
110}
111
112impl<D> ExactSizeIterator for IndicesIter<D> where D: Dimension {}
113
114impl<D> IntoIterator for Indices<D>
115where D: Dimension
116{
117    type Item = D::Pattern;
118    type IntoIter = IndicesIter<D>;
119    fn into_iter(self) -> Self::IntoIter
120    {
121        let sz = self.dim.size();
122        let index = if sz != 0 { Some(self.start) } else { None };
123        IndicesIter { index, dim: self.dim }
124    }
125}
126
127/// Indices producer and iterable.
128///
129/// `Indices` is an `NdProducer` that produces the indices of an array shape.
130#[derive(Copy, Clone, Debug)]
131pub struct Indices<D>
132where D: Dimension
133{
134    start: D,
135    dim: D,
136}
137
/// The "pointer" of the `Indices` producer: it is simply the current index.
///
/// Offsetting it moves the index along a single axis.
#[derive(Copy, Clone, Debug)]
pub struct IndexPtr<D>
{
    /// The index this pointer currently refers to.
    index: D,
}
143
144impl<D> Offset for IndexPtr<D>
145where D: Dimension + Copy
146{
147    // stride: The axis to increment
148    type Stride = usize;
149
150    unsafe fn stride_offset(mut self, stride: Self::Stride, index: usize) -> Self
151    {
152        self.index[stride] += index;
153        self
154    }
155    private_impl! {}
156}
157
158// How the NdProducer for Indices works.
159//
160// NdProducer allows for raw pointers (Ptr), strides (Stride) and the produced
161// item (Item).
162//
163// Instead of Ptr, there is `IndexPtr<D>` which is an index value, like [0, 0, 0]
164// for the three dimensional case.
165//
166// The stride is simply which axis is currently being incremented. The stride for axis 1, is 1.
167//
168// .stride_offset(stride, index) simply computes the new index along that axis, for example:
169// [0, 0, 0].stride_offset(1, 10) => [0, 10, 0]  axis 1 is incremented by 10.
170//
171// .as_ref() converts the Ptr value to an Item. For example [0, 10, 0] => (0, 10, 0)
172impl<D: Dimension + Copy> NdProducer for Indices<D>
173{
174    type Item = D::Pattern;
175    type Dim = D;
176    type Ptr = IndexPtr<D>;
177    type Stride = usize;
178
179    private_impl! {}
180
181    fn raw_dim(&self) -> Self::Dim
182    {
183        self.dim
184    }
185
186    fn equal_dim(&self, dim: &Self::Dim) -> bool
187    {
188        self.dim.equal(dim)
189    }
190
191    fn as_ptr(&self) -> Self::Ptr
192    {
193        IndexPtr { index: self.start }
194    }
195
196    fn layout(&self) -> Layout
197    {
198        if self.dim.ndim() <= 1 {
199            Layout::one_dimensional()
200        } else {
201            Layout::none()
202        }
203    }
204
205    unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item
206    {
207        ptr.index.into_pattern()
208    }
209
210    unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr
211    {
212        let mut index = *i;
213        index += &self.start;
214        IndexPtr { index }
215    }
216
217    fn stride_of(&self, axis: Axis) -> Self::Stride
218    {
219        axis.index()
220    }
221
222    #[inline(always)]
223    fn contiguous_stride(&self) -> Self::Stride
224    {
225        0
226    }
227
228    fn split_at(self, axis: Axis, index: usize) -> (Self, Self)
229    {
230        let start_a = self.start;
231        let mut start_b = start_a;
232        let (a, b) = self.dim.split_at(axis, index);
233        start_b[axis.index()] += index;
234        (Indices { start: start_a, dim: a }, Indices { start: start_b, dim: b })
235    }
236}
237
/// An iterator over the indices of an array shape, in column-major
/// ("F" order: the first axis varies fastest).
///
/// Iterator element type is `D::Pattern`, not `D` itself.
#[derive(Clone)]
pub struct IndicesIterF<D>
{
    /// The shape being iterated over.
    dim: D,
    /// The next index to yield; only meaningful while `has_remaining` is set.
    index: D,
    /// Whether `index` still refers to an index that has not been yielded.
    has_remaining: bool,
}
248
249pub fn indices_iter_f<E>(shape: E) -> IndicesIterF<E::Dim>
250where E: IntoDimension
251{
252    let dim = shape.into_dimension();
253    let zero = E::Dim::zeros(dim.ndim());
254    IndicesIterF {
255        has_remaining: dim.size_checked() != Some(0),
256        index: zero,
257        dim,
258    }
259}
260
261impl<D> Iterator for IndicesIterF<D>
262where D: Dimension
263{
264    type Item = D::Pattern;
265    #[inline]
266    fn next(&mut self) -> Option<Self::Item>
267    {
268        if !self.has_remaining {
269            None
270        } else {
271            let elt = self.index.clone().into_pattern();
272            self.has_remaining = self.dim.next_for_f(&mut self.index);
273            Some(elt)
274        }
275    }
276
277    fn size_hint(&self) -> (usize, Option<usize>)
278    {
279        if !self.has_remaining {
280            return (0, Some(0));
281        }
282        let gone = self
283            .dim
284            .fortran_strides()
285            .slice()
286            .iter()
287            .zip(self.index.slice().iter())
288            .fold(0, |s, (&a, &b)| s + a * b);
289        let l = self.dim.size() - gone;
290        (l, Some(l))
291    }
292}
293
294impl<D> ExactSizeIterator for IndicesIterF<D> where D: Dimension {}
295
#[cfg(test)]
mod tests
{
    use super::indices;
    use super::indices_iter_f;

    /// The C-order iterator reports an exact, decreasing length.
    #[test]
    fn test_indices_iter_c_size_hint()
    {
        let dim = (3, 4);
        let mut it = indices(dim).into_iter();
        let mut len = dim.0 * dim.1;
        assert_eq!(it.len(), len);
        while it.next().is_some() {
            len -= 1;
            assert_eq!(it.len(), len);
        }
        assert_eq!(len, 0);
    }

    /// The specialised `fold` visits exactly the indices `next` would, even
    /// after a partial consumption.
    #[test]
    fn test_indices_iter_c_fold()
    {
        macro_rules! run_test {
            ($dim:expr) => {
                for num_consume in 0..3 {
                    let mut it = indices($dim).into_iter();
                    for _ in 0..num_consume {
                        it.next();
                    }
                    let clone = it.clone();
                    let len = it.len();
                    let acc = clone.fold(0, |acc, ix| {
                        assert_eq!(ix, it.next().unwrap());
                        acc + 1
                    });
                    assert_eq!(acc, len);
                    assert!(it.next().is_none());
                }
            };
        }
        run_test!(());
        run_test!((2,));
        run_test!((2, 3));
        run_test!((2, 0, 3));
        run_test!((2, 3, 4));
        run_test!((2, 3, 4, 2));
    }

    /// The F-order iterator reports an exact, decreasing length.
    #[test]
    fn test_indices_iter_f_size_hint()
    {
        let dim = (3, 4);
        let mut it = indices_iter_f(dim);
        let mut len = dim.0 * dim.1;
        assert_eq!(it.len(), len);
        while it.next().is_some() {
            len -= 1;
            assert_eq!(it.len(), len);
        }
        assert_eq!(len, 0);
    }

    /// C order: the last axis varies fastest.
    #[test]
    fn test_indices_iter_c_order()
    {
        let order: Vec<_> = indices((2, 3)).into_iter().collect();
        assert_eq!(order, vec![(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]);
    }

    /// F order: the first axis varies fastest.
    #[test]
    fn test_indices_iter_f_order()
    {
        let order: Vec<_> = indices_iter_f((2, 3)).collect();
        assert_eq!(order, vec![(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]);
    }

    /// A shape with a zero-length axis produces no indices at all.
    #[test]
    fn test_indices_iter_f_empty_shape()
    {
        let mut it = indices_iter_f((2, 0, 3));
        assert_eq!(it.len(), 0);
        assert!(it.next().is_none());
    }
}