1use std::array;
2use std::ops::{Index, IndexMut, Range};
3
4use crate::array_iter::CircularArrayIterator;
5use crate::span::BoundSpan;
6use crate::span_iter::{RawIndexAdaptor, SpanIterator};
7use crate::CircularArray;
8
/// Read-only element access and iteration for an N-dimensional circular array.
///
/// Two coordinate systems are exposed:
/// - **logical** indices, which for [`CircularArray`] are shifted by `offset[axis]`
///   (modulo `shape[axis]`) before the backing storage is addressed;
/// - **raw** indices (the `*_raw` methods), which address the backing storage
///   directly and ignore the offset.
///
/// Every iterator reports its exact length through [`ExactSizeIterator`].
pub trait CircularArrayIndex<'a, const N: usize, T: 'a> {
    /// Iterates over every element in logical (offset-adjusted) order.
    fn iter(&'a self) -> impl ExactSizeIterator<Item = &'a T>;

    /// Iterates over every element in storage order, ignoring the offset.
    fn iter_raw(&'a self) -> impl ExactSizeIterator<Item = &'a T>;

    /// Iterates over the elements whose logical coordinate along `axis` equals `index`.
    ///
    /// # Panics
    ///
    /// Expected to panic if `axis >= N` or `index` is out of bounds for `axis`.
    fn iter_index(&'a self, axis: usize, index: usize) -> impl ExactSizeIterator<Item = &'a T>;

    /// Iterates over the elements whose storage coordinate along `axis` equals `index`,
    /// ignoring the offset.
    ///
    /// # Panics
    ///
    /// Expected to panic if `axis >= N` or `index` is out of bounds for `axis`.
    fn iter_index_raw(&'a self, axis: usize, index: usize) -> impl ExactSizeIterator<Item = &'a T>;

    /// Iterates over the logical `range` along `axis`, covering every other axis in full.
    ///
    /// The range may run past `shape[axis]` and wrap around (e.g. `1..4` on an axis of
    /// length 3).
    ///
    /// # Panics
    ///
    /// Expected to panic if `axis >= N` or `range` is not valid for `axis`.
    fn iter_range(
        &'a self,
        axis: usize,
        range: Range<usize>,
    ) -> impl ExactSizeIterator<Item = &'a T>;

    /// Iterates over the storage `range` along `axis`, covering every other axis in full
    /// and ignoring the offset.
    ///
    /// # Panics
    ///
    /// Expected to panic if `axis >= N` or `range` is not valid for `axis`.
    fn iter_range_raw(
        &'a self,
        axis: usize,
        range: Range<usize>,
    ) -> impl ExactSizeIterator<Item = &'a T>;

    /// Iterates over the sub-array selected by one logical range per axis.
    ///
    /// # Panics
    ///
    /// Expected to panic if any range is not valid for its axis.
    fn iter_slice(&'a self, slice: [Range<usize>; N]) -> impl ExactSizeIterator<Item = &'a T>;

    /// Returns the element at logical (offset-adjusted) position `index`.
    ///
    /// # Panics
    ///
    /// Expected to panic if any component of `index` is out of bounds for its axis.
    fn get(&'a self, index: [usize; N]) -> &'a T;

    /// Returns the element at storage position `index`, ignoring the offset.
    ///
    /// # Panics
    ///
    /// Implementations may panic if `index` is out of bounds.
    fn get_raw(&'a self, index: [usize; N]) -> &'a T;
}
47
/// Mutable element access for an N-dimensional circular array.
///
/// Mirrors the accessors of [`CircularArrayIndex`]:
/// - `get_mut` takes a logical (offset-adjusted) index;
/// - `get_mut_raw` addresses the backing storage directly.
pub trait CircularArrayIndexMut<'a, const N: usize, T: 'a> {
    /// Returns a mutable reference to the element at logical position `index`.
    ///
    /// # Panics
    ///
    /// Expected to panic if any component of `index` is out of bounds for its axis.
    fn get_mut(&mut self, index: [usize; N]) -> &mut T;

    /// Returns a mutable reference to the element at storage position `index`,
    /// ignoring the offset.
    ///
    /// # Panics
    ///
    /// Implementations may panic if `index` is out of bounds.
    fn get_mut_raw(&mut self, index: [usize; N]) -> &mut T;
}
56
57impl<const N: usize, A, T> CircularArray<N, A, T> {
58 pub(crate) fn spans(&self) -> [BoundSpan; N] {
60 array::from_fn(|i| BoundSpan::new(self.offset[i], self.shape[i], self.shape[i]))
61 }
62
63 #[allow(dead_code)]
65 pub(crate) fn spans_raw(&self) -> [BoundSpan; N] {
66 array::from_fn(|i| BoundSpan::new(0, self.shape[i], self.shape[i]))
67 }
68
69 pub(crate) fn spans_axis_bound(&self, axis: usize, span: BoundSpan) -> [BoundSpan; N] {
72 debug_assert!(span.len() <= self.shape[axis]);
73 array::from_fn(|i| {
74 if i == axis {
75 (span + self.offset[i]) % self.shape[i]
76 } else {
77 BoundSpan::new(self.offset[i], self.shape[i], self.shape[i])
78 }
79 })
80 }
81
82 pub(crate) fn spans_axis_bound_raw(&self, axis: usize, span: BoundSpan) -> [BoundSpan; N] {
84 array::from_fn(|i| {
85 if i == axis {
86 span
87 } else {
88 BoundSpan::new(0, self.shape[i], self.shape[i])
89 }
90 })
91 }
92}
93
94impl<'a, const N: usize, A: AsRef<[T]>, T: 'a> CircularArrayIndex<'a, N, T>
95 for CircularArray<N, A, T>
96{
97 fn iter(&'a self) -> impl ExactSizeIterator<Item = &'a T> {
98 let iter = SpanIterator::new(self.spans())
99 .into_ranges(&self.strides)
100 .flat_map(|range| &self.array.as_ref()[range]);
101
102 CircularArrayIterator::new(iter, self.len())
103 }
104
105 fn iter_raw(&'a self) -> impl ExactSizeIterator<Item = &'a T> {
106 let iter = self.array.as_ref().iter();
107
108 CircularArrayIterator::new(iter, self.len())
109 }
110
111 fn iter_index(&'a self, axis: usize, index: usize) -> impl ExactSizeIterator<Item = &'a T> {
112 assert_shape_index!(axis, N);
113 assert_slice_index!(self, axis, index);
114
115 let iter = SpanIterator::new(
116 self.spans_axis_bound(axis, BoundSpan::new(index, 1, self.shape[axis])),
117 )
118 .into_ranges(&self.strides)
119 .flat_map(|range| &self.array.as_ref()[range]);
120
121 CircularArrayIterator::new(iter, self.slice_len(axis))
122 }
123
124 fn iter_index_raw(&'a self, axis: usize, index: usize) -> impl ExactSizeIterator<Item = &'a T> {
125 assert_shape_index!(axis, N);
126 assert_slice_index!(self, axis, index);
127
128 let iter = SpanIterator::new(
129 self.spans_axis_bound_raw(axis, BoundSpan::new(index, 1, self.shape[axis])),
130 )
131 .into_ranges(&self.strides)
132 .flat_map(|range| &self.array.as_ref()[range]);
133
134 CircularArrayIterator::new(iter, self.slice_len(axis))
135 }
136
137 fn iter_range(
138 &'a self,
139 axis: usize,
140 range: Range<usize>,
141 ) -> impl ExactSizeIterator<Item = &'a T> {
142 assert_shape_index!(axis, N);
143 assert_slice_range!(self, axis, range);
144
145 let iter = SpanIterator::new(self.spans_axis_bound(
146 axis,
147 BoundSpan::new(range.start, range.len(), self.shape[axis]),
148 ))
149 .into_ranges(&self.strides)
150 .flat_map(|range| &self.array.as_ref()[range]);
151
152 CircularArrayIterator::new(iter, range.len() * self.slice_len(axis))
153 }
154
155 fn iter_range_raw(
156 &'a self,
157 axis: usize,
158 range: Range<usize>,
159 ) -> impl ExactSizeIterator<Item = &'a T> {
160 assert_shape_index!(axis, N);
161 assert_slice_range!(self, axis, range);
162
163 let iter = SpanIterator::new(self.spans_axis_bound_raw(
164 axis,
165 BoundSpan::new(range.start, range.len(), self.shape[axis]),
166 ))
167 .into_ranges(&self.strides)
168 .flat_map(|range| &self.array.as_ref()[range]);
169
170 CircularArrayIterator::new(iter, range.len() * self.slice_len(axis))
171 }
172
173 fn iter_slice(&'a self, slice: [Range<usize>; N]) -> impl ExactSizeIterator<Item = &'a T> {
174 let spans = array::from_fn(|i| {
175 let range = &slice[i];
176 assert_slice_range!(self, i, range);
177
178 BoundSpan::new(
179 (range.start + self.offset[i]) % self.shape[i],
180 range.len(),
181 self.shape[i],
182 ) % self.shape[i]
183 });
184
185 let iter = SpanIterator::new(spans)
186 .into_ranges(&self.strides)
187 .flat_map(|range| &self.array.as_ref()[range]);
188 let len = spans.iter().map(|spans| spans.len()).product();
189
190 CircularArrayIterator::new(iter, len)
191 }
192
193 fn get(&'a self, mut index: [usize; N]) -> &'a T {
194 index.iter_mut().enumerate().for_each(|(i, idx)| {
195 assert_slice_index!(self, i, *idx);
196 *idx = (*idx + self.offset[i]) % (self.shape[i]);
197 });
198
199 &self.array.as_ref()[self.strides.apply_to_index(index)]
200 }
201
202 fn get_raw(&'a self, index: [usize; N]) -> &'a T {
203 &self.array.as_ref()[self.strides.apply_to_index(index)]
204 }
205}
206
207impl<'a, const N: usize, A: AsMut<[T]>, T: 'a> CircularArrayIndexMut<'a, N, T>
208 for CircularArray<N, A, T>
209{
210 fn get_mut(&mut self, mut index: [usize; N]) -> &mut T {
211 index.iter_mut().enumerate().for_each(|(i, idx)| {
212 assert_slice_index!(self, i, *idx);
213 *idx = (*idx + self.offset[i]) % (self.shape[i]);
214 });
215
216 &mut self.array.as_mut()[self.strides.apply_to_index(index)]
217 }
218
219 fn get_mut_raw(&mut self, index: [usize; N]) -> &mut T {
220 &mut self.array.as_mut()[self.strides.apply_to_index(index)]
221 }
222}
223
224impl<'a, const N: usize, A: AsRef<[T]>, T: 'a> Index<[usize; N]> for CircularArray<N, A, T> {
225 type Output = T;
226
227 fn index(&self, index: [usize; N]) -> &Self::Output {
228 self.get(index)
229 }
230}
231
232impl<'a, const N: usize, A: AsRef<[T]> + AsMut<[T]>, T: 'a> IndexMut<[usize; N]>
233 for CircularArray<N, A, T>
234{
235 fn index_mut(&mut self, index: [usize; N]) -> &mut Self::Output {
236 self.get_mut(index)
237 }
238}
239
#[cfg(test)]
mod tests {

    use super::*;
    use crate::CircularArrayVec;

    #[test]
    fn iter() {
        let shape = [3, 3, 3];
        let mut m = CircularArrayVec::from_iter(shape, 0..shape.iter().product());
        m.offset = [1, 1, 1];

        #[rustfmt::skip]
        assert_eq!(m.iter().cloned().collect::<Vec<_>>(), [
            13, 14, 12,
            16, 17, 15,
            10, 11, 9,

            22, 23, 21,
            25, 26, 24,
            19, 20, 18,

            4, 5, 3,
            7, 8, 6,
            1, 2, 0
        ]);
        assert_eq!(m.iter().len(), 27);
    }

    #[test]
    fn iter_raw() {
        let shape = [3, 3, 3];
        let m = CircularArrayVec::from_iter(shape, 0..shape.iter().product());

        assert_eq!(
            m.iter_raw().cloned().collect::<Vec<_>>(),
            (0..3 * 3 * 3).collect::<Vec<_>>()
        );
        // Check the length reported by the raw iterator itself.
        assert_eq!(m.iter_raw().len(), 27);
    }

    #[test]
    fn iter_index() {
        let shape = [3, 3, 3];
        let mut m = CircularArrayVec::from_iter_offset(shape, 0..shape.iter().product(), [1, 0, 0]);

        #[rustfmt::skip]
        assert_eq!(
            m.iter_index(0, 1).cloned().collect::<Vec<_>>(),
            [2, 5, 8, 11, 14, 17, 20, 23, 26]
        );
        assert_eq!(m.iter_index(0, 1).len(), 9);
        m.offset = [0, 1, 0];
        assert_eq!(
            m.iter_index(1, 1).cloned().collect::<Vec<_>>(),
            [6, 7, 8, 15, 16, 17, 24, 25, 26]
        );
        assert_eq!(m.iter_index(1, 1).len(), 9);
        m.offset = [0, 0, 1];
        assert_eq!(
            m.iter_index(2, 1).cloned().collect::<Vec<_>>(),
            [18, 19, 20, 21, 22, 23, 24, 25, 26]
        );
        assert_eq!(m.iter_index(2, 1).len(), 9);
        m.offset = [1, 1, 1];
        #[rustfmt::skip]
        assert_eq!(
            m.iter_index(0, 0).cloned().collect::<Vec<_>>(),
            [13, 16, 10, 22, 25, 19, 4, 7, 1]
        );
        assert_eq!(m.iter_index(0, 0).len(), 9);
    }

    #[test]
    fn iter_index_raw() {
        let shape = [3, 3, 3];
        // A non-zero offset must have no effect on the raw variant.
        let m = CircularArrayVec::from_iter_offset(shape, 0..shape.iter().product(), [1, 1, 1]);

        assert_eq!(
            m.iter_index_raw(0, 1).cloned().collect::<Vec<_>>(),
            [1, 4, 7, 10, 13, 16, 19, 22, 25]
        );
        assert_eq!(m.iter_index_raw(0, 1).len(), 9);
        assert_eq!(
            m.iter_index_raw(1, 2).cloned().collect::<Vec<_>>(),
            [6, 7, 8, 15, 16, 17, 24, 25, 26]
        );
        assert_eq!(m.iter_index_raw(1, 2).len(), 9);
        assert_eq!(
            m.iter_index_raw(2, 0).cloned().collect::<Vec<_>>(),
            (0..9).collect::<Vec<_>>()
        );
        assert_eq!(m.iter_index_raw(2, 0).len(), 9);
    }

    #[test]
    fn iter_range() {
        let shape = [3, 3, 3];
        let mut m = CircularArrayVec::from_iter_offset(shape, 0..shape.iter().product(), [1, 0, 0]);

        #[rustfmt::skip]
        assert_eq!(
            m.iter_range(0, 0..2).cloned().collect::<Vec<_>>(),
            [1, 2, 4, 5, 7, 8, 10, 11, 13, 14, 16, 17, 19, 20, 22, 23, 25, 26]
        );
        assert_eq!(m.iter_range(0, 0..2).len(), 18);
        m.offset = [0, 1, 0];
        assert_eq!(
            m.iter_range(1, 1..3).cloned().collect::<Vec<_>>(),
            [6, 7, 8, 0, 1, 2, 15, 16, 17, 9, 10, 11, 24, 25, 26, 18, 19, 20]
        );
        assert_eq!(m.iter_range(1, 1..3).len(), 18);
        m.offset = [0, 0, 1];
        assert_eq!(
            m.iter_range(2, 1..2).cloned().collect::<Vec<_>>(),
            [18, 19, 20, 21, 22, 23, 24, 25, 26]
        );
        assert_eq!(m.iter_range(2, 1..2).len(), 9);
        m.offset = [1, 1, 1];
        #[rustfmt::skip]
        assert_eq!(m.iter_range(0, 1..4).cloned().collect::<Vec<_>>(), [
            14, 12, 13,
            17, 15, 16,
            11, 9, 10,

            23, 21, 22,
            26, 24, 25,
            20, 18, 19,

            5, 3, 4,
            8, 6, 7,
            2, 0, 1
        ]);
        assert_eq!(m.iter_range(0, 1..4).len(), 27);
    }

    #[test]
    fn iter_range_raw() {
        let shape = [3, 3, 3];
        let mut m = CircularArrayVec::from_iter_offset(shape, 0..shape.iter().product(), [1, 0, 0]);

        #[rustfmt::skip]
        assert_eq!(
            m.iter_range_raw(0, 0..2).cloned().collect::<Vec<_>>(),
            [0, 1, 3, 4, 6, 7, 9, 10, 12, 13, 15, 16, 18, 19, 21, 22, 24, 25]
        );
        assert_eq!(m.iter_range_raw(0, 0..2).len(), 18);
        m.offset = [0, 1, 0];
        assert_eq!(
            m.iter_range_raw(1, 1..3).cloned().collect::<Vec<_>>(),
            [3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 21, 22, 23, 24, 25, 26]
        );
        assert_eq!(m.iter_range_raw(1, 1..3).len(), 18);
        m.offset = [0, 0, 1];
        assert_eq!(
            m.iter_range_raw(2, 1..2).cloned().collect::<Vec<_>>(),
            [9, 10, 11, 12, 13, 14, 15, 16, 17]
        );
        assert_eq!(m.iter_range_raw(2, 1..2).len(), 9);
        m.offset = [1, 1, 1];
        #[rustfmt::skip]
        assert_eq!(m.iter_range_raw(0, 1..3).cloned().collect::<Vec<_>>(), [
            1, 2,
            4, 5,
            7, 8,

            10, 11,
            13, 14,
            16, 17,

            19, 20,
            22, 23,
            25, 26
        ]);
        assert_eq!(m.iter_range_raw(0, 1..3).len(), 18);
    }

    #[test]
    fn iter_slice() {
        let shape = [3, 3, 3];
        let mut m = CircularArrayVec::from_iter_offset(shape, 0..shape.iter().product(), [1, 1, 1]);

        #[rustfmt::skip]
        assert_eq!(m.iter_slice([0..1, 0..1, 0..1]).cloned().collect::<Vec<_>>(), &[13]);
        assert_eq!(m.iter_slice([0..1, 0..1, 0..1]).len(), 1);
        #[rustfmt::skip]
        assert_eq!(m.iter_slice([0..3, 0..3, 1..2]).cloned().collect::<Vec<_>>(), &[
            22, 23, 21,
            25, 26, 24,
            19, 20, 18
        ]);
        assert_eq!(m.iter_slice([0..3, 0..3, 1..2]).len(), 9);

        m.offset = [2, 2, 2];

        #[rustfmt::skip]
        assert_eq!(m.iter_slice([0..1, 0..1, 0..1]).cloned().collect::<Vec<_>>(), &[26]);
        assert_eq!(m.iter_slice([0..1, 0..1, 0..1]).len(), 1);
        #[rustfmt::skip]
        assert_eq!(m.iter_slice([0..3, 0..3, 1..2]).cloned().collect::<Vec<_>>(), &[
            8, 6, 7,
            2, 0, 1,
            5, 3, 4
        ]);
        assert_eq!(m.iter_slice([0..3, 0..3, 1..2]).len(), 9);
    }

    #[test]
    fn get() {
        let shape = [3, 3, 3];
        let m = CircularArrayVec::from_iter_offset(shape, 0..shape.iter().product(), [1, 1, 1]);

        assert_eq!(m.get([0, 0, 0]), &13);
        assert_eq!(m.get([1, 1, 1]), &26);
        assert_eq!(m.get([2, 2, 2]), &0);
    }

    #[test]
    fn get_raw() {
        let m = CircularArray::new([3, 3, 3], (0..3 * 3 * 3).collect::<Vec<_>>());

        assert_eq!(m.get_raw([0, 0, 0]), &0);
        assert_eq!(m.get_raw([1, 1, 1]), &13);
        assert_eq!(m.get_raw([2, 2, 2]), &26);
    }

    #[test]
    fn get_mut() {
        let shape = [3, 3, 3];
        let mut m = CircularArrayVec::from_iter_offset(shape, 0..shape.iter().product(), [1, 1, 1]);

        // Logical [0, 0, 0] maps to storage [1, 1, 1] under an offset of one on every axis.
        *m.get_mut([0, 0, 0]) = 100;
        assert_eq!(m.get_raw([1, 1, 1]), &100);

        // Storage [0, 0, 0] is logical [2, 2, 2].
        *m.get_mut_raw([0, 0, 0]) = 200;
        assert_eq!(m.get([2, 2, 2]), &200);
    }

    #[test]
    fn index() {
        let shape = [3, 3, 3];
        let mut m = CircularArrayVec::from_iter_offset(shape, 0..shape.iter().product(), [1, 1, 1]);

        assert_eq!(m[[0, 0, 0]], 13);
        assert_eq!(m[[2, 2, 2]], 0);

        m[[2, 2, 2]] = 42;
        assert_eq!(m.get_raw([0, 0, 0]), &42);
    }
}
443}