1use std::marker::PhantomData;
4use std::ops::Add;
5
6use itertools::Itertools;
7use typenum::{Const, NonZero, PInt, ToUInt, U, Unsigned};
8
9use crate::dense::array::multislice::Multislice;
10use crate::dense::array::reference::ArrayRef;
11use crate::dense::array::{Array, Shape};
12use crate::dense::layout::{col_major_stride_from_shape, convert_1d_nd_from_shape, convert_nd_raw};
13use crate::traits::accessors::{
14 UnsafeRandom1DAccessByValue, UnsafeRandom1DAccessMut, UnsafeRandomAccessByValue,
15 UnsafeRandomAccessMut,
16};
17use crate::traits::{base_operations::BaseItem, iterators::AsMultiIndex};
18use crate::{IsSmallerThan, NumberType, UnsafeRandom1DAccessByRef, UnsafeRandomAccessByRef};
19
20use super::reference::{self, ArrayRefMut};
21use super::slice::ArraySlice;
22
23pub struct ArrayDefaultIteratorByValue<'a, ArrayImpl, const NDIM: usize> {
27 arr: &'a Array<ArrayImpl, NDIM>,
28 pos: usize,
29 nelements: usize,
30}
31
32pub struct ArrayDefaultIteratorByRef<'a, ArrayImpl, const NDIM: usize> {
36 arr: &'a Array<ArrayImpl, NDIM>,
37 pos: usize,
38 nelements: usize,
39}
40
41pub struct ArrayDefaultIteratorMut<'a, ArrayImpl, const NDIM: usize> {
43 arr: &'a mut Array<ArrayImpl, NDIM>,
44 pos: usize,
45 nelements: usize,
46}
47
48pub struct MultiIndexIterator<I, const NDIM: usize> {
50 shape: [usize; NDIM],
51 iter: I,
52}
53
54impl<T, I: Iterator<Item = (usize, T)>, const NDIM: usize> Iterator
55 for MultiIndexIterator<I, NDIM>
56{
57 type Item = ([usize; NDIM], T);
58
59 fn next(&mut self) -> Option<Self::Item> {
60 if let Some((index, value)) = self.iter.next() {
61 Some((convert_1d_nd_from_shape(index, self.shape), value))
62 } else {
63 None
64 }
65 }
66}
67
68impl<T, I: Iterator<Item = (usize, T)>, const NDIM: usize> AsMultiIndex<T, I, NDIM> for I {
69 fn multi_index(self, shape: [usize; NDIM]) -> MultiIndexIterator<I, NDIM> {
70 MultiIndexIterator::<I, NDIM> { shape, iter: self }
71 }
72}
73
74impl<'a, ArrayImpl, const NDIM: usize> ArrayDefaultIteratorByValue<'a, ArrayImpl, NDIM>
75where
76 ArrayImpl: Shape<NDIM>,
77{
78 pub fn new(arr: &'a Array<ArrayImpl, NDIM>) -> Self {
80 Self {
81 arr,
82 pos: 0,
83 nelements: arr.len(),
84 }
85 }
86}
87
88impl<'a, ArrayImpl, const NDIM: usize> ArrayDefaultIteratorByRef<'a, ArrayImpl, NDIM>
89where
90 ArrayImpl: Shape<NDIM>,
91{
92 pub fn new(arr: &'a Array<ArrayImpl, NDIM>) -> Self {
94 Self {
95 arr,
96 pos: 0,
97 nelements: arr.len(),
98 }
99 }
100}
101
102impl<'a, ArrayImpl, const NDIM: usize> ArrayDefaultIteratorMut<'a, ArrayImpl, NDIM>
103where
104 ArrayImpl: Shape<NDIM>,
105{
106 pub fn new(arr: &'a mut Array<ArrayImpl, NDIM>) -> Self {
108 let nelements = arr.len();
109 Self {
110 arr,
111 pos: 0,
112 nelements,
113 }
114 }
115}
116
117impl<ArrayImpl: UnsafeRandom1DAccessByValue, const NDIM: usize> std::iter::Iterator
118 for ArrayDefaultIteratorByValue<'_, ArrayImpl, NDIM>
119{
120 type Item = ArrayImpl::Item;
121 fn next(&mut self) -> Option<Self::Item> {
122 if self.pos >= self.nelements {
123 None
124 } else {
125 let value = unsafe { Some(self.arr.imp().get_value_1d_unchecked(self.pos)) };
126 self.pos += 1;
127 value
128 }
129 }
130}
131
132impl<'a, ArrayImpl: UnsafeRandom1DAccessByRef, const NDIM: usize> std::iter::Iterator
133 for ArrayDefaultIteratorByRef<'a, ArrayImpl, NDIM>
134{
135 type Item = &'a ArrayImpl::Item;
136 fn next(&mut self) -> Option<Self::Item> {
137 if self.pos >= self.nelements {
138 None
139 } else {
140 let value = unsafe { Some(self.arr.imp().get_1d_unchecked(self.pos)) };
141 self.pos += 1;
142 value
143 }
144 }
145}
146
147impl<'a, ArrayImpl: UnsafeRandom1DAccessMut, const NDIM: usize> std::iter::Iterator
155 for ArrayDefaultIteratorMut<'a, ArrayImpl, NDIM>
156{
157 type Item = &'a mut ArrayImpl::Item;
158 fn next(&mut self) -> Option<Self::Item> {
159 if self.pos >= self.nelements {
160 None
161 } else {
162 let value = unsafe {
163 std::mem::transmute::<
164 &mut <ArrayImpl as BaseItem>::Item,
165 &'a mut <ArrayImpl as BaseItem>::Item,
166 >(self.arr.imp_mut().get_1d_unchecked_mut(self.pos))
167 };
168 self.pos += 1;
169 Some(value)
170 }
171 }
172}
173
174pub struct RowIterator<'a, ArrayImpl, const NDIM: usize> {
176 arr: &'a Array<ArrayImpl, NDIM>,
177 nrows: usize,
178 current_row: usize,
179}
180
181impl<'a, ArrayImpl> RowIterator<'a, ArrayImpl, 2>
182where
183 ArrayImpl: Shape<2>,
184{
185 pub fn new(arr: &'a Array<ArrayImpl, 2>) -> Self {
187 let nrows = arr.shape()[0];
188 RowIterator {
189 arr,
190 nrows,
191 current_row: 0,
192 }
193 }
194}
195
196pub struct RowIteratorMut<'a, ArrayImpl, const NDIM: usize> {
198 arr: &'a mut Array<ArrayImpl, NDIM>,
199 nrows: usize,
200 current_row: usize,
201}
202
203impl<'a, ArrayImpl> RowIteratorMut<'a, ArrayImpl, 2>
204where
205 ArrayImpl: Shape<2>,
206{
207 pub fn new(arr: &'a mut Array<ArrayImpl, 2>) -> Self {
209 let nrows = arr.shape()[0];
210 RowIteratorMut {
211 arr,
212 nrows,
213 current_row: 0,
214 }
215 }
216}
217
218impl<'a, ArrayImpl> std::iter::Iterator for RowIterator<'a, ArrayImpl, 2>
219where
220 ArrayImpl: Shape<2>,
221{
222 type Item = Array<ArraySlice<reference::ArrayRef<'a, ArrayImpl, 2>, 2, 1>, 1>;
223 fn next(&mut self) -> Option<Self::Item> {
224 if self.current_row >= self.nrows {
225 return None;
226 }
227 let slice = self.arr.r().slice(0, self.current_row);
228 self.current_row += 1;
229 Some(slice)
230 }
231}
232
233impl<'a, ArrayImpl> std::iter::Iterator for RowIteratorMut<'a, ArrayImpl, 2>
234where
235 ArrayImpl: Shape<2>,
236{
237 type Item = Array<ArraySlice<ArrayRefMut<'a, ArrayImpl, 2>, 2, 1>, 1>;
238 fn next(&mut self) -> Option<Self::Item> {
239 if self.current_row >= self.nrows {
240 return None;
241 }
242 let slice = self.arr.r_mut().slice(0, self.current_row);
243 self.current_row += 1;
244 unsafe {
245 Some(std::mem::transmute::<
246 Array<ArraySlice<ArrayRefMut<'_, ArrayImpl, 2>, 2, 1>, 1>,
247 Array<ArraySlice<ArrayRefMut<'a, ArrayImpl, 2>, 2, 1>, 1>,
248 >(slice))
249 }
250 }
251}
252
253pub struct ColIterator<'a, ArrayImpl, const NDIM: usize> {
255 arr: &'a Array<ArrayImpl, NDIM>,
256 ncols: usize,
257 current_col: usize,
258}
259
260impl<'a, ArrayImpl> ColIterator<'a, ArrayImpl, 2>
261where
262 ArrayImpl: Shape<2>,
263{
264 pub fn new(arr: &'a Array<ArrayImpl, 2>) -> Self {
266 let ncols = arr.shape()[1];
267 ColIterator {
268 arr,
269 ncols,
270 current_col: 0,
271 }
272 }
273}
274
275pub struct ColIteratorMut<'a, ArrayImpl, const NDIM: usize> {
277 arr: &'a mut Array<ArrayImpl, NDIM>,
278 ncols: usize,
279 current_col: usize,
280}
281
282impl<'a, ArrayImpl> ColIteratorMut<'a, ArrayImpl, 2>
283where
284 ArrayImpl: Shape<2>,
285{
286 pub fn new(arr: &'a mut Array<ArrayImpl, 2>) -> Self {
288 let ncols = arr.shape()[1];
289 ColIteratorMut {
290 arr,
291 ncols,
292 current_col: 0,
293 }
294 }
295}
296
297impl<'a, ArrayImpl: Shape<2>> std::iter::Iterator for ColIterator<'a, ArrayImpl, 2> {
298 type Item = Array<ArraySlice<reference::ArrayRef<'a, ArrayImpl, 2>, 2, 1>, 1>;
299 fn next(&mut self) -> Option<Self::Item> {
300 if self.current_col >= self.ncols {
301 return None;
302 }
303 let slice = self.arr.r().slice(1, self.current_col);
304 self.current_col += 1;
305 Some(slice)
306 }
307}
308
309impl<'a, ArrayImpl> std::iter::Iterator for ColIteratorMut<'a, ArrayImpl, 2>
310where
311 ArrayImpl: Shape<2>,
312{
313 type Item = Array<ArraySlice<ArrayRefMut<'a, ArrayImpl, 2>, 2, 1>, 1>;
314 fn next(&mut self) -> Option<Self::Item> {
315 if self.current_col >= self.ncols {
316 return None;
317 }
318 let slice = self.arr.r_mut().slice(1, self.current_col);
319 self.current_col += 1;
320 unsafe {
321 Some(std::mem::transmute::<
322 Array<ArraySlice<ArrayRefMut<'_, ArrayImpl, 2>, 2, 1>, 1>,
323 Array<ArraySlice<ArrayRefMut<'a, ArrayImpl, 2>, 2, 1>, 1>,
324 >(slice))
325 }
326 }
327}
328
329pub struct ArrayDiagIteratorByValue<'a, ArrayImpl, const NDIM: usize> {
331 arr: &'a Array<ArrayImpl, NDIM>,
332 pos: usize,
333 nelements: usize,
334}
335
336impl<'a, ArrayImpl: Shape<NDIM>, const NDIM: usize> ArrayDiagIteratorByValue<'a, ArrayImpl, NDIM> {
337 pub fn new(arr: &'a Array<ArrayImpl, NDIM>) -> Self {
339 let nelements = *arr.shape().iter().min().unwrap();
340 ArrayDiagIteratorByValue {
341 arr,
342 pos: 0,
343 nelements,
344 }
345 }
346}
347
348impl<'a, ArrayImpl: UnsafeRandomAccessByValue<NDIM>, const NDIM: usize> Iterator
349 for ArrayDiagIteratorByValue<'a, ArrayImpl, NDIM>
350{
351 type Item = ArrayImpl::Item;
352
353 fn next(&mut self) -> Option<Self::Item> {
354 if self.pos >= self.nelements {
355 None
356 } else {
357 let value = unsafe { self.arr.get_value_unchecked([self.pos; NDIM]) };
358 self.pos += 1;
359 Some(value)
360 }
361 }
362}
363
364pub struct ArrayDiagIteratorByRef<'a, ArrayImpl, const NDIM: usize> {
366 arr: &'a Array<ArrayImpl, NDIM>,
367 pos: usize,
368 nelements: usize,
369}
370
371impl<'a, ArrayImpl: Shape<NDIM>, const NDIM: usize> ArrayDiagIteratorByRef<'a, ArrayImpl, NDIM> {
372 pub fn new(arr: &'a Array<ArrayImpl, NDIM>) -> Self {
374 let nelements = *arr.shape().iter().min().unwrap();
375 ArrayDiagIteratorByRef {
376 arr,
377 pos: 0,
378 nelements,
379 }
380 }
381}
382
383impl<'a, ArrayImpl: UnsafeRandomAccessByRef<NDIM>, const NDIM: usize> Iterator
384 for ArrayDiagIteratorByRef<'a, ArrayImpl, NDIM>
385{
386 type Item = &'a ArrayImpl::Item;
387
388 fn next(&mut self) -> Option<Self::Item> {
389 if self.pos >= self.nelements {
390 None
391 } else {
392 let value = unsafe { self.arr.get_unchecked([self.pos; NDIM]) };
393 self.pos += 1;
394 Some(value)
395 }
396 }
397}
398
399pub struct ArrayDiagIteratorMut<'a, ArrayImpl, const NDIM: usize> {
401 arr: &'a mut Array<ArrayImpl, NDIM>,
402 pos: usize,
403 nelements: usize,
404}
405
406impl<'a, ArrayImpl: Shape<NDIM>, const NDIM: usize> ArrayDiagIteratorMut<'a, ArrayImpl, NDIM> {
407 pub fn new(arr: &'a mut Array<ArrayImpl, NDIM>) -> Self {
409 let nelements = *arr.shape().iter().min().unwrap();
410 ArrayDiagIteratorMut {
411 arr,
412 pos: 0,
413 nelements,
414 }
415 }
416}
417
418impl<'a, ArrayImpl: UnsafeRandomAccessMut<NDIM>, const NDIM: usize> Iterator
419 for ArrayDiagIteratorMut<'a, ArrayImpl, NDIM>
420{
421 type Item = &'a mut ArrayImpl::Item;
422
423 fn next(&mut self) -> Option<Self::Item> {
424 if self.pos >= self.nelements {
425 None
426 } else {
427 let value = unsafe {
428 std::mem::transmute::<&mut ArrayImpl::Item, Self::Item>(
429 self.arr.get_unchecked_mut([self.pos; NDIM]),
430 )
431 };
432 self.pos += 1;
433 Some(value)
434 }
435 }
436}
437
438pub struct MultisliceIterator<
440 'a,
441 Item,
442 ArrayImpl,
443 const NDIM: usize,
444 const SDIM: usize,
445 const OUTDIM: usize,
446> where
447 Const<NDIM>: ToUInt,
448 Const<SDIM>: ToUInt,
449 Const<OUTDIM>: ToUInt,
450 <Const<NDIM> as ToUInt>::Output: Unsigned + NonZero,
451 <Const<SDIM> as ToUInt>::Output: Unsigned + NonZero,
452 <Const<OUTDIM> as ToUInt>::Output: Unsigned + NonZero,
453 NumberType<SDIM>: IsSmallerThan<NDIM>,
454 PInt<U<OUTDIM>>: Add<PInt<U<SDIM>>, Output = PInt<U<NDIM>>>,
455 ArrayImpl: UnsafeRandom1DAccessByValue<Item = Item>
456 + UnsafeRandomAccessByValue<NDIM, Item = Item>
457 + Shape<NDIM>,
458{
459 arr: &'a Array<ArrayImpl, NDIM>,
460 axes: [usize; SDIM],
461 axes_shape: [usize; SDIM],
462 nindices: usize,
463 indices: [usize; SDIM],
464 is_finished: bool,
465 _marker: PhantomData<Item>,
466}
467
468impl<'a, Item, ArrayImpl, const NDIM: usize, const SDIM: usize, const OUTDIM: usize>
469 MultisliceIterator<'a, Item, ArrayImpl, NDIM, SDIM, OUTDIM>
470where
471 Const<NDIM>: ToUInt,
472 Const<SDIM>: ToUInt,
473 Const<OUTDIM>: ToUInt,
474 <Const<NDIM> as ToUInt>::Output: Unsigned + NonZero,
475 <Const<SDIM> as ToUInt>::Output: Unsigned + NonZero,
476 <Const<OUTDIM> as ToUInt>::Output: Unsigned + NonZero,
477 NumberType<SDIM>: IsSmallerThan<NDIM>,
478 PInt<U<OUTDIM>>: Add<PInt<U<SDIM>>, Output = PInt<U<NDIM>>>,
479 ArrayImpl: UnsafeRandom1DAccessByValue<Item = Item>
480 + UnsafeRandomAccessByValue<NDIM, Item = Item>
481 + Shape<NDIM>,
482{
483 pub fn new(arr: &'a Array<ArrayImpl, NDIM>, axes: [usize; SDIM]) -> Self {
487 if !axes
488 .iter()
489 .tuple_windows()
490 .all(|(elem1, elem2)| elem2 > elem1)
491 {
492 panic!("`axes` must be sorted in ascending order.");
493 }
494
495 let shape = arr.shape();
496 let axes_shape = {
497 let mut tmp = [0; SDIM];
498
499 for (index, elem) in tmp.iter_mut().enumerate() {
500 *elem = shape[axes[index]];
501 }
502 tmp
503 };
504
505 Self {
506 arr,
507 axes,
508 axes_shape,
509 nindices: axes_shape.iter().product::<usize>(),
510 indices: [0; SDIM],
511 is_finished: false,
512 _marker: PhantomData,
513 }
514 }
515}
516
517impl<'a, Item, ArrayImpl, const NDIM: usize, const SDIM: usize, const OUTDIM: usize> Iterator
518 for MultisliceIterator<'a, Item, ArrayImpl, NDIM, SDIM, OUTDIM>
519where
520 Const<NDIM>: ToUInt,
521 Const<SDIM>: ToUInt,
522 Const<OUTDIM>: ToUInt,
523 <Const<NDIM> as ToUInt>::Output: Unsigned + NonZero,
524 <Const<SDIM> as ToUInt>::Output: Unsigned + NonZero,
525 <Const<OUTDIM> as ToUInt>::Output: Unsigned + NonZero,
526 NumberType<SDIM>: IsSmallerThan<NDIM>,
527 PInt<U<OUTDIM>>: Add<PInt<U<SDIM>>, Output = PInt<U<NDIM>>>,
528 ArrayImpl: UnsafeRandom1DAccessByValue<Item = Item>
529 + UnsafeRandomAccessByValue<NDIM, Item = Item>
530 + Shape<NDIM>,
531{
532 type Item = (
533 [usize; SDIM],
534 Array<Multislice<ArrayRef<'a, ArrayImpl, NDIM>, NDIM, SDIM, OUTDIM>, OUTDIM>,
535 );
536
537 fn next(&mut self) -> Option<Self::Item> {
538 if self.is_finished {
539 None
540 } else {
541 let out = (
542 self.indices,
543 self.arr.r().multislice(self.axes, self.indices),
544 );
545
546 let index_1d =
547 convert_nd_raw(self.indices, col_major_stride_from_shape(self.axes_shape));
548
549 if 1 + index_1d < self.nindices {
550 self.indices = convert_1d_nd_from_shape(1 + index_1d, self.axes_shape);
551 } else {
552 self.is_finished = true;
553 }
554
555 Some(out)
556 }
557 }
558}
559
560pub struct MultisliceIteratorMut<
562 'a,
563 Item,
564 ArrayImpl,
565 const NDIM: usize,
566 const SDIM: usize,
567 const OUTDIM: usize,
568> where
569 Const<NDIM>: ToUInt,
570 Const<SDIM>: ToUInt,
571 Const<OUTDIM>: ToUInt,
572 <Const<NDIM> as ToUInt>::Output: Unsigned + NonZero,
573 <Const<SDIM> as ToUInt>::Output: Unsigned + NonZero,
574 <Const<OUTDIM> as ToUInt>::Output: Unsigned + NonZero,
575 NumberType<SDIM>: IsSmallerThan<NDIM>,
576 PInt<U<OUTDIM>>: Add<PInt<U<SDIM>>, Output = PInt<U<NDIM>>>,
577 ArrayImpl: UnsafeRandom1DAccessMut<Item = Item>
578 + UnsafeRandomAccessMut<NDIM, Item = Item>
579 + Shape<NDIM>,
580{
581 arr: &'a mut Array<ArrayImpl, NDIM>,
582 axes: [usize; SDIM],
583 axes_shape: [usize; SDIM],
584 nindices: usize,
585 indices: [usize; SDIM],
586 is_finished: bool,
587 _marker: PhantomData<Item>,
588}
589
590impl<'a, Item, ArrayImpl, const NDIM: usize, const SDIM: usize, const OUTDIM: usize>
591 MultisliceIteratorMut<'a, Item, ArrayImpl, NDIM, SDIM, OUTDIM>
592where
593 Const<NDIM>: ToUInt,
594 Const<SDIM>: ToUInt,
595 Const<OUTDIM>: ToUInt,
596 <Const<NDIM> as ToUInt>::Output: Unsigned + NonZero,
597 <Const<SDIM> as ToUInt>::Output: Unsigned + NonZero,
598 <Const<OUTDIM> as ToUInt>::Output: Unsigned + NonZero,
599 NumberType<SDIM>: IsSmallerThan<NDIM>,
600 PInt<U<OUTDIM>>: Add<PInt<U<SDIM>>, Output = PInt<U<NDIM>>>,
601 ArrayImpl: UnsafeRandom1DAccessMut<Item = Item>
602 + UnsafeRandomAccessMut<NDIM, Item = Item>
603 + Shape<NDIM>,
604{
605 pub fn new(arr: &'a mut Array<ArrayImpl, NDIM>, axes: [usize; SDIM]) -> Self {
609 if !axes
610 .iter()
611 .tuple_windows()
612 .all(|(elem1, elem2)| elem2 > elem1)
613 {
614 panic!("`axes` must be sorted in ascending order.");
615 }
616
617 let shape = arr.shape();
618 let axes_shape = {
619 let mut tmp = [0; SDIM];
620
621 for (index, elem) in tmp.iter_mut().enumerate() {
622 *elem = shape[axes[index]];
623 }
624 tmp
625 };
626
627 Self {
628 arr,
629 axes,
630 axes_shape,
631 nindices: axes_shape.iter().product::<usize>(),
632 indices: [0; SDIM],
633 is_finished: false,
634 _marker: PhantomData,
635 }
636 }
637}
638
639impl<'a, Item, ArrayImpl, const NDIM: usize, const SDIM: usize, const OUTDIM: usize> Iterator
640 for MultisliceIteratorMut<'a, Item, ArrayImpl, NDIM, SDIM, OUTDIM>
641where
642 Const<NDIM>: ToUInt,
643 Const<SDIM>: ToUInt,
644 Const<OUTDIM>: ToUInt,
645 <Const<NDIM> as ToUInt>::Output: Unsigned + NonZero,
646 <Const<SDIM> as ToUInt>::Output: Unsigned + NonZero,
647 <Const<OUTDIM> as ToUInt>::Output: Unsigned + NonZero,
648 NumberType<SDIM>: IsSmallerThan<NDIM>,
649 PInt<U<OUTDIM>>: Add<PInt<U<SDIM>>, Output = PInt<U<NDIM>>>,
650 ArrayImpl: UnsafeRandom1DAccessMut<Item = Item>
651 + UnsafeRandomAccessMut<NDIM, Item = Item>
652 + Shape<NDIM>,
653{
654 type Item = (
655 [usize; SDIM],
656 Array<Multislice<ArrayRefMut<'a, ArrayImpl, NDIM>, NDIM, SDIM, OUTDIM>, OUTDIM>,
657 );
658
659 fn next(&mut self) -> Option<Self::Item> {
660 if self.is_finished {
661 None
662 } else {
663 let out = (self.indices, unsafe {
664 std::mem::transmute::<
665 Array<Multislice<ArrayRefMut<'_, ArrayImpl, NDIM>, NDIM, SDIM, OUTDIM>, OUTDIM>,
666 Array<Multislice<ArrayRefMut<'a, ArrayImpl, NDIM>, NDIM, SDIM, OUTDIM>, OUTDIM>,
667 >(self.arr.r_mut().multislice(self.axes, self.indices))
668 });
669
670 let index_1d =
671 convert_nd_raw(self.indices, col_major_stride_from_shape(self.axes_shape));
672
673 if 1 + index_1d < self.nindices {
674 self.indices = convert_1d_nd_from_shape(1 + index_1d, self.axes_shape);
675 } else {
676 self.is_finished = true;
677 }
678
679 Some(out)
680 }
681 }
682}