n_circular_array/
array_index.rs

1use std::array;
2use std::ops::{Index, IndexMut, Range};
3
4use crate::array_iter::CircularArrayIterator;
5use crate::span::BoundSpan;
6use crate::span_iter::{RawIndexAdaptor, SpanIterator};
7use crate::CircularArray;
8
9/// Operations for retrieving elements from the array.
/// Operations for retrieving elements from the array.
///
/// Most methods come in two flavours:
///
/// * offset-aligned (`iter`, `iter_index`, `iter_range`, `iter_slice`, `get`): indices are
///   logical, so index `i` on an axis addresses storage index `(i + offset) % shape` on that
///   axis;
/// * raw (`*_raw`): indices address the backing storage directly and the offset is ignored.
pub trait CircularArrayIndex<'a, const N: usize, T: 'a> {
    /// Iterate over all elements of the inner array, aligned to the offset.
    fn iter(&'a self) -> impl ExactSizeIterator<Item = &'a T>;

    /// Iterate over all elements of the inner array in backing-storage order, ignoring the
    /// offset.
    fn iter_raw(&'a self) -> impl ExactSizeIterator<Item = &'a T>;

    /// Iterate over all elements of logical `index` for the given `axis`, aligned to the
    /// offset.
    ///
    /// # Panics
    ///
    /// Panics if `axis >= N` or `index` is out of bounds for `axis`.
    fn iter_index(&'a self, axis: usize, index: usize) -> impl ExactSizeIterator<Item = &'a T>;

    /// Iterate over all elements of raw storage `index` for the given `axis`, ignoring the
    /// offset.
    ///
    /// # Panics
    ///
    /// Panics if `axis >= N` or `index` is out of bounds for `axis`.
    fn iter_index_raw(&'a self, axis: usize, index: usize) -> impl ExactSizeIterator<Item = &'a T>;

    /// Iterate over all elements of the logical index `range` for the given `axis`, aligned to
    /// the offset. The range may extend past the end of the axis, in which case it wraps
    /// around to its start.
    fn iter_range(
        &'a self,
        axis: usize,
        range: Range<usize>,
    ) -> impl ExactSizeIterator<Item = &'a T>;

    /// Iterate over all elements of the raw storage index `range` for the given `axis`,
    /// ignoring the offset.
    fn iter_range_raw(
        &'a self,
        axis: usize,
        range: Range<usize>,
    ) -> impl ExactSizeIterator<Item = &'a T>;

    /// Iterate over all elements inside the per-axis logical index ranges of `slice`, aligned
    /// to the offset (each range is interpreted the same way as the indices passed to `get`).
    fn iter_slice(&'a self, slice: [Range<usize>; N]) -> impl ExactSizeIterator<Item = &'a T>;

    /// Get a reference to the element at the given logical index, aligned to the offset.
    ///
    /// # Panics
    ///
    /// Panics if any component of `index` is out of bounds for its axis.
    fn get(&'a self, index: [usize; N]) -> &'a T;

    /// Get a reference to the element at the given raw storage index, ignoring the offset.
    fn get_raw(&'a self, index: [usize; N]) -> &'a T;
}
47
48/// Methods for retrieving mutable references to elements of the array.
/// Mutable element access for a circular array.
///
/// The counterpart of the `get` / `get_raw` pair of `CircularArrayIndex`, handing out
/// `&mut T` instead of `&T`.
pub trait CircularArrayIndexMut<'a, const N: usize, T: 'a> {
    /// Return a mutable reference to the element at logical `index`, i.e. with the array's
    /// offset applied on every axis.
    fn get_mut(&mut self, index: [usize; N]) -> &mut T;

    /// Return a mutable reference to the element at `index` in the backing storage, without
    /// applying the offset.
    fn get_mut_raw(&mut self, index: [usize; N]) -> &mut T;
}
56
57impl<const N: usize, A, T> CircularArray<N, A, T> {
58    /// Get the exhaustive spans of the array, aligned to the offset.
59    pub(crate) fn spans(&self) -> [BoundSpan; N] {
60        array::from_fn(|i| BoundSpan::new(self.offset[i], self.shape[i], self.shape[i]))
61    }
62
63    /// Get the raw exhaustive spans of the array.
64    #[allow(dead_code)]
65    pub(crate) fn spans_raw(&self) -> [BoundSpan; N] {
66        array::from_fn(|i| BoundSpan::new(0, self.shape[i], self.shape[i]))
67    }
68
69    /// Get the spans of the array, bound by the given `span` on the given `axis`,
70    /// aligned to the offset.
71    pub(crate) fn spans_axis_bound(&self, axis: usize, span: BoundSpan) -> [BoundSpan; N] {
72        debug_assert!(span.len() <= self.shape[axis]);
73        array::from_fn(|i| {
74            if i == axis {
75                (span + self.offset[i]) % self.shape[i]
76            } else {
77                BoundSpan::new(self.offset[i], self.shape[i], self.shape[i])
78            }
79        })
80    }
81
82    /// Get the raw spans of the array, bound by the given `span` on the given `axis`.
83    pub(crate) fn spans_axis_bound_raw(&self, axis: usize, span: BoundSpan) -> [BoundSpan; N] {
84        array::from_fn(|i| {
85            if i == axis {
86                span
87            } else {
88                BoundSpan::new(0, self.shape[i], self.shape[i])
89            }
90        })
91    }
92}
93
94impl<'a, const N: usize, A: AsRef<[T]>, T: 'a> CircularArrayIndex<'a, N, T>
95    for CircularArray<N, A, T>
96{
97    fn iter(&'a self) -> impl ExactSizeIterator<Item = &'a T> {
98        let iter = SpanIterator::new(self.spans())
99            .into_ranges(&self.strides)
100            .flat_map(|range| &self.array.as_ref()[range]);
101
102        CircularArrayIterator::new(iter, self.len())
103    }
104
105    fn iter_raw(&'a self) -> impl ExactSizeIterator<Item = &'a T> {
106        let iter = self.array.as_ref().iter();
107
108        CircularArrayIterator::new(iter, self.len())
109    }
110
111    fn iter_index(&'a self, axis: usize, index: usize) -> impl ExactSizeIterator<Item = &'a T> {
112        assert_shape_index!(axis, N);
113        assert_slice_index!(self, axis, index);
114
115        let iter = SpanIterator::new(
116            self.spans_axis_bound(axis, BoundSpan::new(index, 1, self.shape[axis])),
117        )
118        .into_ranges(&self.strides)
119        .flat_map(|range| &self.array.as_ref()[range]);
120
121        CircularArrayIterator::new(iter, self.slice_len(axis))
122    }
123
124    fn iter_index_raw(&'a self, axis: usize, index: usize) -> impl ExactSizeIterator<Item = &'a T> {
125        assert_shape_index!(axis, N);
126        assert_slice_index!(self, axis, index);
127
128        let iter = SpanIterator::new(
129            self.spans_axis_bound_raw(axis, BoundSpan::new(index, 1, self.shape[axis])),
130        )
131        .into_ranges(&self.strides)
132        .flat_map(|range| &self.array.as_ref()[range]);
133
134        CircularArrayIterator::new(iter, self.slice_len(axis))
135    }
136
137    fn iter_range(
138        &'a self,
139        axis: usize,
140        range: Range<usize>,
141    ) -> impl ExactSizeIterator<Item = &'a T> {
142        assert_shape_index!(axis, N);
143        assert_slice_range!(self, axis, range);
144
145        let iter = SpanIterator::new(self.spans_axis_bound(
146            axis,
147            BoundSpan::new(range.start, range.len(), self.shape[axis]),
148        ))
149        .into_ranges(&self.strides)
150        .flat_map(|range| &self.array.as_ref()[range]);
151
152        CircularArrayIterator::new(iter, range.len() * self.slice_len(axis))
153    }
154
155    fn iter_range_raw(
156        &'a self,
157        axis: usize,
158        range: Range<usize>,
159    ) -> impl ExactSizeIterator<Item = &'a T> {
160        assert_shape_index!(axis, N);
161        assert_slice_range!(self, axis, range);
162
163        let iter = SpanIterator::new(self.spans_axis_bound_raw(
164            axis,
165            BoundSpan::new(range.start, range.len(), self.shape[axis]),
166        ))
167        .into_ranges(&self.strides)
168        .flat_map(|range| &self.array.as_ref()[range]);
169
170        CircularArrayIterator::new(iter, range.len() * self.slice_len(axis))
171    }
172
173    fn iter_slice(&'a self, slice: [Range<usize>; N]) -> impl ExactSizeIterator<Item = &'a T> {
174        let spans = array::from_fn(|i| {
175            let range = &slice[i];
176            assert_slice_range!(self, i, range);
177
178            BoundSpan::new(
179                (range.start + self.offset[i]) % self.shape[i],
180                range.len(),
181                self.shape[i],
182            ) % self.shape[i]
183        });
184
185        let iter = SpanIterator::new(spans)
186            .into_ranges(&self.strides)
187            .flat_map(|range| &self.array.as_ref()[range]);
188        let len = spans.iter().map(|spans| spans.len()).product();
189
190        CircularArrayIterator::new(iter, len)
191    }
192
193    fn get(&'a self, mut index: [usize; N]) -> &'a T {
194        index.iter_mut().enumerate().for_each(|(i, idx)| {
195            assert_slice_index!(self, i, *idx);
196            *idx = (*idx + self.offset[i]) % (self.shape[i]);
197        });
198
199        &self.array.as_ref()[self.strides.apply_to_index(index)]
200    }
201
202    fn get_raw(&'a self, index: [usize; N]) -> &'a T {
203        &self.array.as_ref()[self.strides.apply_to_index(index)]
204    }
205}
206
207impl<'a, const N: usize, A: AsMut<[T]>, T: 'a> CircularArrayIndexMut<'a, N, T>
208    for CircularArray<N, A, T>
209{
210    fn get_mut(&mut self, mut index: [usize; N]) -> &mut T {
211        index.iter_mut().enumerate().for_each(|(i, idx)| {
212            assert_slice_index!(self, i, *idx);
213            *idx = (*idx + self.offset[i]) % (self.shape[i]);
214        });
215
216        &mut self.array.as_mut()[self.strides.apply_to_index(index)]
217    }
218
219    fn get_mut_raw(&mut self, index: [usize; N]) -> &mut T {
220        &mut self.array.as_mut()[self.strides.apply_to_index(index)]
221    }
222}
223
224impl<'a, const N: usize, A: AsRef<[T]>, T: 'a> Index<[usize; N]> for CircularArray<N, A, T> {
225    type Output = T;
226
227    fn index(&self, index: [usize; N]) -> &Self::Output {
228        self.get(index)
229    }
230}
231
232impl<'a, const N: usize, A: AsRef<[T]> + AsMut<[T]>, T: 'a> IndexMut<[usize; N]>
233    for CircularArray<N, A, T>
234{
235    fn index_mut(&mut self, index: [usize; N]) -> &mut Self::Output {
236        self.get_mut(index)
237    }
238}
239
#[cfg(test)]
mod tests {

    use super::*;
    use crate::CircularArrayVec;

    #[test]
    fn iter() {
        let shape = [3, 3, 3];
        let mut m = CircularArrayVec::from_iter(shape, 0..shape.iter().product());
        m.offset = [1, 1, 1];

        #[rustfmt::skip]
        assert_eq!(m.iter().cloned().collect::<Vec<_>>(), [
            13, 14, 12,
            16, 17, 15,
            10, 11, 9,

            22, 23, 21,
            25, 26, 24,
            19, 20, 18,

             4,  5,  3,
             7,  8,  6,
             1,  2,  0
        ]);
        assert_eq!(m.iter().len(), 27);
    }

    #[test]
    fn iter_raw() {
        let shape = [3, 3, 3];
        let m = CircularArrayVec::from_iter(shape, 0..shape.iter().product());

        assert_eq!(
            m.iter_raw().cloned().collect::<Vec<_>>(),
            (0..3 * 3 * 3).collect::<Vec<_>>()
        );
        // Check the length reported by `iter_raw` itself (previously `iter` was checked).
        assert_eq!(m.iter_raw().len(), 27);
    }

    #[test]
    fn iter_index() {
        let shape = [3, 3, 3];
        let mut m = CircularArrayVec::from_iter_offset(shape, 0..shape.iter().product(), [1, 0, 0]);

        #[rustfmt::skip]
        assert_eq!(
            m.iter_index(0, 1).cloned().collect::<Vec<_>>(),
            [2, 5, 8, 11, 14, 17, 20, 23, 26]
        );
        assert_eq!(m.iter_index(0, 1).len(), 9);
        m.offset = [0, 1, 0];
        assert_eq!(
            m.iter_index(1, 1).cloned().collect::<Vec<_>>(),
            [6, 7, 8, 15, 16, 17, 24, 25, 26]
        );
        assert_eq!(m.iter_index(1, 1).len(), 9);
        m.offset = [0, 0, 1];
        assert_eq!(
            m.iter_index(2, 1).cloned().collect::<Vec<_>>(),
            [18, 19, 20, 21, 22, 23, 24, 25, 26]
        );
        assert_eq!(m.iter_index(2, 1).len(), 9);
        m.offset = [1, 1, 1];
        #[rustfmt::skip]
        assert_eq!(
            m.iter_index(0, 0).cloned().collect::<Vec<_>>(),
            [13, 16, 10, 22, 25, 19, 4, 7, 1]
        );
        assert_eq!(m.iter_index(0, 0).len(), 9);
    }

    #[test]
    fn iter_index_raw() {
        let shape = [3, 3, 3];
        let m = CircularArrayVec::from_iter_offset(shape, 0..shape.iter().product(), [1, 1, 1]);

        // The offset must be ignored by the raw variant on every axis.
        assert_eq!(
            m.iter_index_raw(0, 1).cloned().collect::<Vec<_>>(),
            [1, 4, 7, 10, 13, 16, 19, 22, 25]
        );
        assert_eq!(m.iter_index_raw(0, 1).len(), 9);
        assert_eq!(
            m.iter_index_raw(1, 2).cloned().collect::<Vec<_>>(),
            [6, 7, 8, 15, 16, 17, 24, 25, 26]
        );
        assert_eq!(m.iter_index_raw(1, 2).len(), 9);
    }

    #[test]
    fn iter_range() {
        let shape = [3, 3, 3];
        let mut m = CircularArrayVec::from_iter_offset(shape, 0..shape.iter().product(), [1, 0, 0]);

        #[rustfmt::skip]
        assert_eq!(
            m.iter_range(0, 0..2).cloned().collect::<Vec<_>>(),
            [1, 2, 4, 5, 7, 8, 10, 11, 13, 14, 16, 17, 19, 20, 22, 23, 25, 26]
        );
        assert_eq!(m.iter_range(0, 0..2).len(), 18);
        m.offset = [0, 1, 0];
        assert_eq!(
            m.iter_range(1, 1..3).cloned().collect::<Vec<_>>(),
            [6, 7, 8, 0, 1, 2, 15, 16, 17, 9, 10, 11, 24, 25, 26, 18, 19, 20]
        );
        assert_eq!(m.iter_range(1, 1..3).len(), 18);
        m.offset = [0, 0, 1];
        assert_eq!(
            m.iter_range(2, 1..2).cloned().collect::<Vec<_>>(),
            [18, 19, 20, 21, 22, 23, 24, 25, 26]
        );
        assert_eq!(m.iter_range(2, 1..2).len(), 9);
        m.offset = [1, 1, 1];
        #[rustfmt::skip]
        assert_eq!(m.iter_range(0, 1..4).cloned().collect::<Vec<_>>(), [
                14, 12, 13,
                17, 15, 16,
                11,  9, 10,

                23, 21, 22,
                26, 24, 25,
                20, 18, 19,

                 5,  3,  4,
                 8,  6,  7,
                 2,  0,  1
            ]);
        assert_eq!(m.iter_range(0, 1..4).len(), 27);
    }

    #[test]
    fn iter_range_raw() {
        let shape = [3, 3, 3];
        let mut m = CircularArrayVec::from_iter_offset(shape, 0..shape.iter().product(), [1, 0, 0]);

        #[rustfmt::skip]
        assert_eq!(
            m.iter_range_raw(0, 0..2).cloned().collect::<Vec<_>>(),
            [0, 1, 3, 4, 6, 7, 9, 10, 12, 13, 15, 16, 18, 19, 21, 22, 24, 25]
        );
        assert_eq!(m.iter_range_raw(0, 0..2).len(), 18);
        m.offset = [0, 1, 0];
        assert_eq!(
            m.iter_range_raw(1, 1..3).cloned().collect::<Vec<_>>(),
            [3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 21, 22, 23, 24, 25, 26]
        );
        assert_eq!(m.iter_range_raw(1, 1..3).len(), 18);
        m.offset = [0, 0, 1];
        assert_eq!(
            m.iter_range_raw(2, 1..2).cloned().collect::<Vec<_>>(),
            [9, 10, 11, 12, 13, 14, 15, 16, 17]
        );
        assert_eq!(m.iter_range_raw(2, 1..2).len(), 9);
        m.offset = [1, 1, 1];
        #[rustfmt::skip]
        assert_eq!(m.iter_range_raw(0, 1..3).cloned().collect::<Vec<_>>(), [
             1,  2,
             4,  5,
             7,  8,

            10, 11,
            13, 14,
            16, 17,

            19, 20,
            22, 23,
            25, 26
            ]);
        assert_eq!(m.iter_range_raw(0, 1..3).len(), 18);
    }

    #[test]
    fn iter_slice() {
        let shape = [3, 3, 3];
        let mut m = CircularArrayVec::from_iter_offset(shape, 0..shape.iter().product(), [1, 1, 1]);

        #[rustfmt::skip]
        assert_eq!(m.iter_slice([0..1, 0..1, 0..1]).cloned().collect::<Vec<_>>(), &[13]);
        assert_eq!(m.iter_slice([0..1, 0..1, 0..1]).len(), 1);
        #[rustfmt::skip]
        assert_eq!(m.iter_slice([0..3, 0..3, 1..2]).cloned().collect::<Vec<_>>(), &[
            22, 23, 21,
            25, 26, 24,
            19, 20, 18
        ]);
        assert_eq!(m.iter_slice([0..3, 0..3, 1..2]).len(), 9);

        m.offset = [2, 2, 2];

        #[rustfmt::skip]
        assert_eq!(m.iter_slice([0..1, 0..1, 0..1]).cloned().collect::<Vec<_>>(), &[26]);
        assert_eq!(m.iter_slice([0..1, 0..1, 0..1]).len(), 1);
        #[rustfmt::skip]
        assert_eq!(m.iter_slice([0..3, 0..3, 1..2]).cloned().collect::<Vec<_>>(), &[
            8, 6, 7,
            2, 0, 1,
            5, 3, 4
        ]);
        assert_eq!(m.iter_slice([0..3, 0..3, 1..2]).len(), 9);
    }

    #[test]
    fn get() {
        let shape = [3, 3, 3];
        let m = CircularArrayVec::from_iter_offset(shape, 0..shape.iter().product(), [1, 1, 1]);

        assert_eq!(m.get([0, 0, 0]), &13);
        assert_eq!(m.get([1, 1, 1]), &26);
        assert_eq!(m.get([2, 2, 2]), &0);
    }

    #[test]
    fn get_raw() {
        let m = CircularArray::new([3, 3, 3], (0..3 * 3 * 3).collect::<Vec<_>>());

        assert_eq!(m.get_raw([0, 0, 0]), &0);
        assert_eq!(m.get_raw([1, 1, 1]), &13);
        assert_eq!(m.get_raw([2, 2, 2]), &26);
    }

    #[test]
    fn get_mut() {
        let shape = [3, 3, 3];
        let mut m = CircularArrayVec::from_iter_offset(shape, 0..shape.iter().product(), [1, 1, 1]);

        // Logical [0, 0, 0] is storage [1, 1, 1] under a [1, 1, 1] offset.
        *m.get_mut([0, 0, 0]) = 100;
        assert_eq!(m.get_raw([1, 1, 1]), &100);

        // Raw [0, 0, 0] is logical [2, 2, 2] under a [1, 1, 1] offset.
        *m.get_mut_raw([0, 0, 0]) = 200;
        assert_eq!(m.get([2, 2, 2]), &200);

        // `IndexMut` / `Index` go through the offset-aligned accessors.
        m[[1, 1, 1]] = 300;
        assert_eq!(m.get_raw([2, 2, 2]), &300);
        assert_eq!(m[[1, 1, 1]], 300);
    }
}
443}