rstsr_common/layout/
iterator.rs

1//! Layout (double-ended) iterator.
2
3use crate::prelude_dev::*;
4
5/* #region col-major */
6
7/// Layout iterator (column-major).
8///
9/// This iterator only handles column-major iterator.
10/// For other iteration orders, use function [`translate_to_col_major`] to
11/// generate the corresponding col-major (f-prefer) layout, then iterate as
12/// col-major.
13///
14/// # Note
15///
16/// This crate implements col-major iterator only; the layout iterator that
17/// actaully works is internal realization; though it's public struct, it is not
18/// intended to be exposed to user.
19/// Choosing col-major iterator is because it is possibly the most efficient
20/// way. It is not related to default order, which could be defined by crate
21/// feature `f_prefer`.
22#[derive(Clone, Debug)]
23pub struct IterLayoutColMajor<D>
24where
25    D: DimDevAPI,
26{
27    pub(crate) layout: Layout<D>,
28
29    pub(crate) index_start: D, // this is not used for buffer-order
30    pub(crate) iter_start: usize,
31    pub(crate) offset_start: isize,
32
33    pub(crate) index_end: D, // this is not used for buffer-order
34    pub(crate) iter_end: usize,
35    pub(crate) offset_end: isize,
36}
37
38impl<D> IterLayoutColMajor<D>
39where
40    D: DimDevAPI,
41{
42    pub fn index_start(&self) -> &D {
43        &self.index_start
44    }
45
46    pub fn index_end(&self) -> &D {
47        &self.index_end
48    }
49
50    pub fn iter_start(&self) -> usize {
51        self.iter_start
52    }
53
54    pub fn iter_end(&self) -> usize {
55        self.iter_end
56    }
57
58    pub fn offset_start(&self) -> isize {
59        self.offset_start
60    }
61
62    pub fn offset_end(&self) -> isize {
63        self.offset_end
64    }
65}
66
67impl<D> IterLayoutColMajor<D>
68where
69    D: DimDevAPI,
70{
71    /// This function generates col-major (f-prefer) layout, then give its
72    /// iterator object.
73    pub fn new(layout: &Layout<D>) -> Result<Self> {
74        let layout = layout.clone();
75        let shape = layout.shape();
76        let iter_start = 0;
77        let iter_end = layout.size();
78        let index_start = layout.new_shape();
79        let index_end = unsafe { shape.unravel_index_f(iter_end) };
80        let offset_start = layout.offset() as isize;
81        let offset_end = unsafe { layout.index_uncheck(index_end.as_ref()) };
82
83        return Ok(Self {
84            layout,
85            index_start,
86            iter_start,
87            offset_start,
88            index_end,
89            iter_end,
90            offset_end,
91        });
92    }
93
94    pub fn split_at(&self, index: usize) -> Result<(Self, Self)> {
95        let Self { layout, index_start, iter_start, offset_start, index_end, iter_end, offset_end } =
96            self.clone();
97        let shape = layout.shape();
98        let iter_ins = iter_start + index;
99        let index_ins = unsafe { shape.unravel_index_f(iter_ins) };
100        let offset_ins = unsafe { layout.index_uncheck(index_ins.as_ref()) };
101        let split_lhs = Self {
102            layout: layout.clone(),
103            index_start,
104            iter_start,
105            offset_start,
106            index_end: index_ins.clone(),
107            iter_end: iter_ins,
108            offset_end: offset_ins,
109        };
110        let split_rhs = Self {
111            layout: layout.clone(),
112            index_start: index_ins,
113            iter_start: iter_ins,
114            offset_start: offset_ins,
115            index_end,
116            iter_end,
117            offset_end,
118        };
119        return Ok((split_lhs, split_rhs));
120    }
121}
122
123impl<D> IterLayoutColMajor<D>
124where
125    D: DimDevAPI,
126{
127    #[inline]
128    fn next_iter_index(&mut self) {
129        let layout = &self.layout;
130        let index = self.index_start.as_mut();
131        let mut offset = self.offset_start;
132        let shape = layout.shape().as_ref();
133        let stride = layout.stride().as_ref();
134        match layout.ndim() {
135            0 => (),
136            1 => {
137                index[0] += 1;
138                offset += stride[0];
139            },
140            2 => {
141                index[0] += 1;
142                offset += stride[0];
143                if index[0] == shape[0] {
144                    index[0] = 0;
145                    offset -= shape[0] as isize * stride[0];
146                    index[1] += 1;
147                    offset += stride[1];
148                }
149            },
150            3 => {
151                index[0] += 1;
152                offset += stride[0];
153                if index[0] == shape[0] {
154                    index[0] = 0;
155                    offset -= shape[0] as isize * stride[0];
156                    index[1] += 1;
157                    offset += stride[1];
158                    if index[1] == shape[1] {
159                        index[1] = 0;
160                        offset -= shape[1] as isize * stride[1];
161                        index[2] += 1;
162                        offset += stride[2];
163                    }
164                }
165            },
166            4 => {
167                index[0] += 1;
168                offset += stride[0];
169                if index[0] == shape[0] {
170                    index[0] = 0;
171                    offset -= shape[0] as isize * stride[0];
172                    index[1] += 1;
173                    offset += stride[1];
174                    if index[1] == shape[1] {
175                        index[1] = 0;
176                        offset -= shape[1] as isize * stride[1];
177                        index[2] += 1;
178                        offset += stride[2];
179                        if index[2] == shape[2] {
180                            index[2] = 0;
181                            offset -= shape[2] as isize * stride[2];
182                            index[3] += 1;
183                            offset += stride[3];
184                        }
185                    }
186                }
187            },
188            _ => {
189                for (d, t, idx) in izip!(shape, stride, index.as_mut()) {
190                    *idx += 1;
191                    offset += t;
192                    if idx == d {
193                        *idx = 0;
194                        offset -= *d as isize * t;
195                    } else {
196                        break;
197                    }
198                }
199            },
200        }
201        self.offset_start = offset;
202        self.iter_start += 1;
203    }
204
205    #[inline]
206    fn back_iter_index(&mut self) {
207        let layout = &self.layout;
208        let index = self.index_end.as_mut();
209        let mut offset = self.offset_end;
210        let shape = layout.shape().as_ref();
211        let stride = layout.stride().as_ref();
212        match layout.ndim() {
213            0 => (),
214            1 => {
215                index[0] -= 1;
216                offset -= stride[0];
217            },
218            2 => {
219                if index[0] == 0 {
220                    index[0] = shape[0] - 1;
221                    offset += (shape[0] - 1) as isize * stride[0];
222                    index[1] -= 1;
223                    offset -= stride[1];
224                } else {
225                    index[0] -= 1;
226                    offset -= stride[0];
227                }
228            },
229            3 => {
230                if index[0] == 0 {
231                    index[0] = shape[0] - 1;
232                    offset += (shape[0] - 1) as isize * stride[0];
233                    if index[1] == 0 {
234                        index[1] = shape[1] - 1;
235                        offset += (shape[1] - 1) as isize * stride[1];
236                        index[2] -= 1;
237                        offset -= stride[2];
238                    } else {
239                        index[1] -= 1;
240                        offset -= stride[1];
241                    }
242                } else {
243                    index[0] -= 1;
244                    offset -= stride[0];
245                }
246            },
247            4 => {
248                if index[0] == 0 {
249                    index[0] = shape[0] - 1;
250                    offset += (shape[0] - 1) as isize * stride[0];
251                    if index[1] == 0 {
252                        index[1] = shape[1] - 1;
253                        offset += (shape[1] - 1) as isize * stride[1];
254                        if index[2] == 0 {
255                            index[2] = shape[2] - 1;
256                            offset += (shape[2] - 1) as isize * stride[2];
257                            index[3] -= 1;
258                            offset -= stride[3];
259                        } else {
260                            index[2] -= 1;
261                            offset -= stride[2];
262                        }
263                    } else {
264                        index[1] -= 1;
265                        offset -= stride[1];
266                    }
267                } else {
268                    index[0] -= 1;
269                    offset -= stride[0];
270                }
271            },
272            _ => {
273                for (d, t, idx) in izip!(shape, stride, index.as_mut()) {
274                    if *idx == 0 {
275                        *idx = *d - 1;
276                        offset += (*d - 1) as isize * t;
277                    } else {
278                        *idx -= 1;
279                        offset -= t;
280                        break;
281                    }
282                }
283            },
284        }
285        self.offset_end = offset;
286        self.iter_end -= 1;
287    }
288}
289
290impl<D> Iterator for IterLayoutColMajor<D>
291where
292    D: DimDevAPI,
293{
294    type Item = usize;
295
296    fn next(&mut self) -> Option<Self::Item> {
297        if self.iter_start >= self.iter_end {
298            return None;
299        }
300        let offset = self.offset_start;
301        self.next_iter_index();
302        return Some(offset.try_into().unwrap());
303    }
304}
305
306impl<D> DoubleEndedIterator for IterLayoutColMajor<D>
307where
308    D: DimDevAPI,
309{
310    fn next_back(&mut self) -> Option<Self::Item> {
311        if self.iter_start >= self.iter_end {
312            return None;
313        }
314        self.back_iter_index();
315        let offset = self.offset_end;
316        return Some(offset.try_into().unwrap());
317    }
318}
319
320impl<D> ExactSizeIterator for IterLayoutColMajor<D>
321where
322    D: DimDevAPI,
323{
324    fn len(&self) -> usize {
325        self.iter_end - self.iter_start
326    }
327}
328
329impl<D> IterSplitAtAPI for IterLayoutColMajor<D>
330where
331    D: DimDevAPI,
332{
333    fn split_at(self, index: usize) -> (Self, Self) {
334        Self::split_at(&self, index).unwrap()
335    }
336}
337
338/* #endregion */
339
340/* #region row-major */
341
342/// Layout iterator (row-major).
343///
344/// This iterator only handles row-major iterator.
345///
346/// # Note
347///
348/// This crate implements row-major iterator only; the layout iterator that
349/// actaully works is internal realization; though it's public struct, it is not
350/// intended to be exposed to user.
351#[derive(Debug, Clone)]
352pub struct IterLayoutRowMajor<D>
353where
354    D: DimDevAPI,
355{
356    pub(crate) layout: Layout<D>,
357
358    pub(crate) index_start: D, // this is not used for buffer-order
359    pub(crate) iter_start: usize,
360    pub(crate) offset_start: isize,
361
362    pub(crate) index_end: D, // this is not used for buffer-order
363    pub(crate) iter_end: usize,
364    pub(crate) offset_end: isize,
365}
366
367impl<D> IterLayoutRowMajor<D>
368where
369    D: DimDevAPI,
370{
371    pub fn index_start(&self) -> &D {
372        &self.index_start
373    }
374
375    pub fn index_end(&self) -> &D {
376        &self.index_end
377    }
378
379    pub fn iter_start(&self) -> usize {
380        self.iter_start
381    }
382
383    pub fn iter_end(&self) -> usize {
384        self.iter_end
385    }
386
387    pub fn offset_start(&self) -> isize {
388        self.offset_start
389    }
390
391    pub fn offset_end(&self) -> isize {
392        self.offset_end
393    }
394}
395
396impl<D> IterLayoutRowMajor<D>
397where
398    D: DimDevAPI,
399{
400    /// This function generates row-major (c-prefer) layout, then give its
401    /// iterator object.
402    pub fn new(layout: &Layout<D>) -> Result<Self> {
403        let layout = layout.clone();
404        let shape = layout.shape();
405        let iter_start = 0;
406        let iter_end = layout.size();
407        let index_start = layout.new_shape();
408        let index_end = unsafe { shape.unravel_index_c(iter_end) };
409        let offset_start = layout.offset() as isize;
410        let offset_end = unsafe { layout.index_uncheck(index_end.as_ref()) };
411
412        return Ok(Self {
413            layout,
414            index_start,
415            iter_start,
416            offset_start,
417            index_end,
418            iter_end,
419            offset_end,
420        });
421    }
422
423    pub fn split_at(&self, index: usize) -> Result<(Self, Self)> {
424        let Self { layout, index_start, iter_start, offset_start, index_end, iter_end, offset_end } =
425            self.clone();
426        let shape = layout.shape();
427        let iter_ins = iter_start + index;
428        let index_ins = unsafe { shape.unravel_index_c(iter_ins) };
429        let offset_ins = unsafe { layout.index_uncheck(index_ins.as_ref()) };
430        let split_lhs = Self {
431            layout: layout.clone(),
432            index_start,
433            iter_start,
434            offset_start,
435            index_end: index_ins.clone(),
436            iter_end: iter_ins,
437            offset_end: offset_ins,
438        };
439        let split_rhs = Self {
440            layout: layout.clone(),
441            index_start: index_ins,
442            iter_start: iter_ins,
443            offset_start: offset_ins,
444            index_end,
445            iter_end,
446            offset_end,
447        };
448        return Ok((split_lhs, split_rhs));
449    }
450}
451
452impl<D> IterLayoutRowMajor<D>
453where
454    D: DimDevAPI,
455{
456    #[inline]
457    fn next_iter_index(&mut self) {
458        let layout = &self.layout;
459        let index = self.index_start.as_mut();
460        let mut offset = self.offset_start;
461        let shape = layout.shape().as_ref();
462        let stride = layout.stride().as_ref();
463        match layout.ndim() {
464            0 => (),
465            1 => {
466                index[0] += 1;
467                offset += stride[0];
468            },
469            2 => {
470                index[1] += 1;
471                offset += stride[1];
472                if index[1] == shape[1] {
473                    index[1] = 0;
474                    offset -= shape[1] as isize * stride[1];
475                    index[0] += 1;
476                    offset += stride[0];
477                }
478            },
479            3 => {
480                index[2] += 1;
481                offset += stride[2];
482                if index[2] == shape[2] {
483                    index[2] = 0;
484                    offset -= shape[2] as isize * stride[2];
485                    index[1] += 1;
486                    offset += stride[1];
487                    if index[1] == shape[1] {
488                        index[1] = 0;
489                        offset -= shape[1] as isize * stride[1];
490                        index[0] += 1;
491                        offset += stride[0];
492                    }
493                }
494            },
495            4 => {
496                index[3] += 1;
497                offset += stride[3];
498                if index[3] == shape[3] {
499                    index[3] = 0;
500                    offset -= shape[3] as isize * stride[3];
501                    index[2] += 1;
502                    offset += stride[2];
503                    if index[2] == shape[2] {
504                        index[2] = 0;
505                        offset -= shape[2] as isize * stride[2];
506                        index[1] += 1;
507                        offset += stride[1];
508                        if index[1] == shape[1] {
509                            index[1] = 0;
510                            offset -= shape[1] as isize * stride[1];
511                            index[0] += 1;
512                            offset += stride[0];
513                        }
514                    }
515                }
516            },
517            _ => {
518                for (d, t, idx) in izip!(shape, stride, index.as_mut()).rev() {
519                    *idx += 1;
520                    offset += t;
521                    if idx == d {
522                        *idx = 0;
523                        offset -= *d as isize * t;
524                    } else {
525                        break;
526                    }
527                }
528            },
529        }
530        self.offset_start = offset;
531        self.iter_start += 1;
532    }
533
534    #[inline]
535    fn back_iter_index(&mut self) {
536        let layout = &self.layout;
537        let index = self.index_end.as_mut();
538        let mut offset = self.offset_end;
539        let shape = layout.shape().as_ref();
540        let stride = layout.stride().as_ref();
541        match layout.ndim() {
542            0 => (),
543            1 => {
544                index[0] -= 1;
545                offset -= stride[0];
546            },
547            2 => {
548                if index[1] == 0 {
549                    index[1] = shape[1] - 1;
550                    offset += (shape[1] - 1) as isize * stride[1];
551                    index[0] -= 1;
552                    offset -= stride[0];
553                } else {
554                    index[1] -= 1;
555                    offset -= stride[1];
556                }
557            },
558            3 => {
559                if index[2] == 0 {
560                    index[2] = shape[2] - 1;
561                    offset += (shape[2] - 1) as isize * stride[2];
562                    if index[1] == 0 {
563                        index[1] = shape[1] - 1;
564                        offset += (shape[1] - 1) as isize * stride[1];
565                        index[0] -= 1;
566                        offset -= stride[0];
567                    } else {
568                        index[1] -= 1;
569                        offset -= stride[1];
570                    }
571                } else {
572                    index[2] -= 1;
573                    offset -= stride[2];
574                }
575            },
576            4 => {
577                if index[3] == 0 {
578                    index[3] = shape[3] - 1;
579                    offset += (shape[3] - 1) as isize * stride[3];
580                    if index[2] == 0 {
581                        index[2] = shape[2] - 1;
582                        offset += (shape[2] - 1) as isize * stride[2];
583                        if index[1] == 0 {
584                            index[1] = shape[1] - 1;
585                            offset += (shape[1] - 1) as isize * stride[1];
586                            index[0] -= 1;
587                            offset -= stride[0];
588                        } else {
589                            index[1] -= 1;
590                            offset -= stride[1];
591                        }
592                    } else {
593                        index[2] -= 1;
594                        offset -= stride[2];
595                    }
596                } else {
597                    index[3] -= 1;
598                    offset -= stride[3];
599                }
600            },
601            _ => {
602                for (d, t, idx) in izip!(shape, stride, index.as_mut()).rev() {
603                    if *idx == 0 {
604                        *idx = *d - 1;
605                        offset += (*d - 1) as isize * t;
606                    } else {
607                        *idx -= 1;
608                        offset -= t;
609                        break;
610                    }
611                }
612            },
613        }
614        self.offset_end = offset;
615        self.iter_end -= 1;
616    }
617}
618
619impl<D> Iterator for IterLayoutRowMajor<D>
620where
621    D: DimDevAPI,
622{
623    type Item = usize;
624
625    fn next(&mut self) -> Option<Self::Item> {
626        if self.iter_start >= self.iter_end {
627            return None;
628        }
629        let offset = self.offset_start;
630        self.next_iter_index();
631        return Some(offset.try_into().unwrap());
632    }
633}
634
635impl<D> DoubleEndedIterator for IterLayoutRowMajor<D>
636where
637    D: DimDevAPI,
638{
639    fn next_back(&mut self) -> Option<Self::Item> {
640        if self.iter_start >= self.iter_end {
641            return None;
642        }
643        self.back_iter_index();
644        let offset = self.offset_end;
645        return Some(offset.try_into().unwrap());
646    }
647}
648
649impl<D> ExactSizeIterator for IterLayoutRowMajor<D>
650where
651    D: DimDevAPI,
652{
653    fn len(&self) -> usize {
654        self.iter_end - self.iter_start
655    }
656}
657
658impl<D> IterSplitAtAPI for IterLayoutRowMajor<D>
659where
660    D: DimDevAPI,
661{
662    fn split_at(self, index: usize) -> (Self, Self) {
663        Self::split_at(&self, index).unwrap()
664    }
665}
666
667/* #endregion */
668
669/* #region enum of layout iterator */
670
671#[derive(Clone, Debug)]
672pub enum IterLayout<D>
673where
674    D: DimDevAPI,
675{
676    RowMajor(IterLayoutRowMajor<D>),
677    ColMajor(IterLayoutColMajor<D>),
678}
679
680impl<D> IterLayout<D>
681where
682    D: DimDevAPI,
683{
684    pub fn new(layout: &Layout<D>, order: TensorIterOrder) -> Result<Self> {
685        use TensorIterOrder::*;
686        match order {
687            C => {
688                let iter = IterLayoutRowMajor::new(layout)?;
689                return Ok(Self::RowMajor(iter));
690            },
691            F => {
692                let iter = IterLayoutColMajor::new(layout)?;
693                return Ok(Self::ColMajor(iter));
694            },
695            A => match FlagOrder::default() {
696                RowMajor => {
697                    let iter = IterLayoutRowMajor::new(layout)?;
698                    return Ok(Self::RowMajor(iter));
699                },
700                ColMajor => {
701                    let iter = IterLayoutColMajor::new(layout)?;
702                    return Ok(Self::ColMajor(iter));
703                },
704            },
705            K | G => {
706                let layout = translate_to_col_major_unary(layout, TensorIterOrder::K)?;
707                let iter = IterLayoutColMajor::new(&layout)?;
708                return Ok(Self::ColMajor(iter));
709            },
710            _ => rstsr_raise!(InvalidValue),
711        }
712    }
713}
714
715impl<D> Iterator for IterLayout<D>
716where
717    D: DimDevAPI,
718{
719    type Item = usize;
720
721    fn next(&mut self) -> Option<Self::Item> {
722        match self {
723            Self::RowMajor(iter) => iter.next(),
724            Self::ColMajor(iter) => iter.next(),
725        }
726    }
727}
728
729impl<D> DoubleEndedIterator for IterLayout<D>
730where
731    D: DimDevAPI,
732{
733    fn next_back(&mut self) -> Option<Self::Item> {
734        match self {
735            Self::RowMajor(iter) => iter.next_back(),
736            Self::ColMajor(iter) => iter.next_back(),
737        }
738    }
739}
740
741impl<D> ExactSizeIterator for IterLayout<D>
742where
743    D: DimDevAPI,
744{
745    fn len(&self) -> usize {
746        match self {
747            Self::RowMajor(iter) => iter.len(),
748            Self::ColMajor(iter) => iter.len(),
749        }
750    }
751}
752
753impl<D> IterSplitAtAPI for IterLayout<D>
754where
755    D: DimDevAPI,
756{
757    fn split_at(self, index: usize) -> (Self, Self) {
758        match self {
759            Self::RowMajor(iter) => {
760                let (lhs, rhs) = iter.split_at(index);
761                (Self::RowMajor(lhs), Self::RowMajor(rhs))
762            },
763            Self::ColMajor(iter) => {
764                let (lhs, rhs) = iter.split_at(index);
765                (Self::ColMajor(lhs), Self::ColMajor(rhs))
766            },
767        }
768    }
769}
770
771/* #endregion */
772
773/* #region layout iterator with index */
774
775#[derive(Clone, Debug)]
776pub struct IndexedIterLayout<D>
777where
778    D: DimDevAPI,
779{
780    pub(crate) layout_iter: IterLayout<D>,
781}
782
783impl<D> IndexedIterLayout<D>
784where
785    D: DimDevAPI,
786{
787    pub fn new(layout: &Layout<D>, order: FlagOrder) -> Result<Self> {
788        let order = match order {
789            RowMajor => TensorIterOrder::C,
790            ColMajor => TensorIterOrder::F,
791        };
792        Ok(Self { layout_iter: IterLayout::new(layout, order)? })
793    }
794}
795
796impl<D> Iterator for IndexedIterLayout<D>
797where
798    D: DimDevAPI,
799{
800    type Item = (D, usize);
801
802    fn next(&mut self) -> Option<Self::Item> {
803        let index = match &self.layout_iter {
804            IterLayout::ColMajor(iter_inner) => iter_inner.index_start.clone(),
805            IterLayout::RowMajor(iter_inner) => iter_inner.index_start.clone(),
806        };
807        self.layout_iter.next().map(|offset| (index, offset))
808    }
809}
810
811impl<D> DoubleEndedIterator for IndexedIterLayout<D>
812where
813    D: DimDevAPI,
814{
815    fn next_back(&mut self) -> Option<Self::Item> {
816        let index = match &self.layout_iter {
817            IterLayout::ColMajor(iter_inner) => iter_inner.index_start.clone(),
818            IterLayout::RowMajor(iter_inner) => iter_inner.index_start.clone(),
819        };
820        self.layout_iter.next_back().map(|offset| (index, offset))
821    }
822}
823
824impl<D> ExactSizeIterator for IndexedIterLayout<D>
825where
826    D: DimDevAPI,
827{
828    fn len(&self) -> usize {
829        self.layout_iter.len()
830    }
831}
832
833impl<D> IterSplitAtAPI for IndexedIterLayout<D>
834where
835    D: DimDevAPI,
836{
837    fn split_at(self, mid: usize) -> (Self, Self) {
838        let (lhs, rhs) = self.layout_iter.split_at(mid);
839        let lhs = IndexedIterLayout { layout_iter: lhs };
840        let rhs = IndexedIterLayout { layout_iter: rhs };
841        (lhs, rhs)
842    }
843}
844
845/* #endregion */
846
847/* #region col-major layout dim dispatch */
848
849#[allow(unused_mut)]
850pub fn layout_col_major_dim_dispatch_1<D, F>(la: &Layout<D>, mut f: F) -> Result<()>
851where
852    D: DimAPI,
853    F: FnMut(usize),
854{
855    #[cfg(feature = "dispatch_dim_layout_iter")]
856    {
857        macro_rules! dispatch {
858            ($dim: ident) => {{
859                let iter_a = IterLayoutColMajor::new(&la.to_dim::<$dim>()?)?;
860                iter_a.for_each(f);
861            }};
862        }
863        match la.ndim() {
864            0 => f(la.offset()),
865            1 => dispatch!(Ix1),
866            2 => dispatch!(Ix2),
867            3 => dispatch!(Ix3),
868            4 => dispatch!(Ix4),
869            5 => dispatch!(Ix5),
870            6 => dispatch!(Ix6),
871            _ => {
872                let iter_a = IterLayoutColMajor::new(la)?;
873                iter_a.for_each(f);
874            },
875        }
876    }
877
878    #[cfg(not(feature = "dispatch_dim_layout_iter"))]
879    {
880        let iter_a = IterLayoutColMajor::new(la)?;
881        iter_a.for_each(f);
882    }
883    Ok(())
884}
885
886#[allow(unused_mut)]
887pub fn layout_col_major_dim_dispatch_2<D, F>(la: &Layout<D>, lb: &Layout<D>, mut f: F) -> Result<()>
888where
889    D: DimAPI,
890    F: FnMut((usize, usize)),
891{
892    rstsr_assert_eq!(la.ndim(), lb.ndim(), RuntimeError)?;
893
894    #[cfg(feature = "dispatch_dim_layout_iter")]
895    {
896        macro_rules! dispatch {
897            ($dim: ident) => {{
898                let iter_a = IterLayoutColMajor::new(&la.to_dim::<$dim>()?)?;
899                let iter_b = IterLayoutColMajor::new(&lb.to_dim::<$dim>()?)?;
900                izip!(iter_a, iter_b).for_each(f);
901            }};
902        }
903        match la.ndim() {
904            0 => f((la.offset(), lb.offset())),
905            1 => dispatch!(Ix1),
906            2 => dispatch!(Ix2),
907            3 => dispatch!(Ix3),
908            4 => dispatch!(Ix4),
909            5 => dispatch!(Ix5),
910            6 => dispatch!(Ix6),
911            _ => {
912                let iter_a = IterLayoutColMajor::new(la)?;
913                let iter_b = IterLayoutColMajor::new(lb)?;
914                izip!(iter_a, iter_b).for_each(f);
915            },
916        }
917    }
918
919    #[cfg(not(feature = "dispatch_dim_layout_iter"))]
920    {
921        let iter_a = IterLayoutColMajor::new(la)?;
922        let iter_b = IterLayoutColMajor::new(lb)?;
923        izip!(iter_a, iter_b).for_each(f);
924    }
925    Ok(())
926}
927
928#[allow(unused_mut)]
929pub fn layout_col_major_dim_dispatch_3<D, F>(
930    la: &Layout<D>,
931    lb: &Layout<D>,
932    lc: &Layout<D>,
933    mut f: F,
934) -> Result<()>
935where
936    D: DimAPI,
937    F: FnMut((usize, usize, usize)),
938{
939    rstsr_assert_eq!(la.ndim(), lb.ndim(), RuntimeError)?;
940    rstsr_assert_eq!(la.ndim(), lc.ndim(), RuntimeError)?;
941
942    #[cfg(feature = "dispatch_dim_layout_iter")]
943    {
944        macro_rules! dispatch {
945            ($dim: ident) => {{
946                let iter_a = IterLayoutColMajor::new(&la.to_dim::<$dim>()?)?;
947                let iter_b = IterLayoutColMajor::new(&lb.to_dim::<$dim>()?)?;
948                let iter_c = IterLayoutColMajor::new(&lc.to_dim::<$dim>()?)?;
949                izip!(iter_a, iter_b, iter_c).for_each(f);
950            }};
951        }
952        match la.ndim() {
953            0 => f((la.offset(), lb.offset(), lc.offset())),
954            1 => dispatch!(Ix1),
955            2 => dispatch!(Ix2),
956            3 => dispatch!(Ix3),
957            4 => dispatch!(Ix4),
958            5 => dispatch!(Ix5),
959            6 => dispatch!(Ix6),
960            _ => {
961                let iter_a = IterLayoutColMajor::new(la)?;
962                let iter_b = IterLayoutColMajor::new(lb)?;
963                let iter_c = IterLayoutColMajor::new(lc)?;
964                izip!(iter_a, iter_b, iter_c).for_each(f);
965            },
966        }
967    }
968
969    #[cfg(not(feature = "dispatch_dim_layout_iter"))]
970    {
971        let iter_a = IterLayoutColMajor::new(la)?;
972        let iter_b = IterLayoutColMajor::new(lb)?;
973        let iter_c = IterLayoutColMajor::new(lc)?;
974        izip!(iter_a, iter_b, iter_c).for_each(f);
975    }
976    Ok(())
977}
978
979#[allow(unused_mut)]
980pub fn layout_col_major_dim_dispatch_2diff<DA, DB, F>(
981    la: &Layout<DA>,
982    lb: &Layout<DB>,
983    mut f: F,
984) -> Result<()>
985where
986    DA: DimAPI,
987    DB: DimAPI,
988    F: FnMut((usize, usize)),
989{
990    #[cfg(feature = "dispatch_dim_layout_iter")]
991    {
992        macro_rules! dispatch {
993            ($dima: ident, $dimb: ident) => {{
994                let iter_a = IterLayoutColMajor::new(&la.to_dim::<$dima>()?)?;
995                let iter_b = IterLayoutColMajor::new(&lb.to_dim::<$dimb>()?)?;
996                izip!(iter_a, iter_b).for_each(f);
997            }};
998        }
999        match (la.ndim(), lb.ndim()) {
1000            (0, 0) => f((la.offset(), lb.offset())),
1001            (1, 1) => dispatch!(Ix1, Ix1),
1002            (1, 2) => dispatch!(Ix1, Ix2),
1003            (1, 3) => dispatch!(Ix1, Ix3),
1004            (1, 4) => dispatch!(Ix1, Ix4),
1005            (1, 5) => dispatch!(Ix1, Ix5),
1006            (1, 6) => dispatch!(Ix1, Ix6),
1007            (2, 1) => dispatch!(Ix2, Ix1),
1008            (2, 2) => dispatch!(Ix2, Ix2),
1009            (2, 3) => dispatch!(Ix2, Ix3),
1010            (2, 4) => dispatch!(Ix2, Ix4),
1011            (2, 5) => dispatch!(Ix2, Ix5),
1012            (2, 6) => dispatch!(Ix2, Ix6),
1013            (3, 1) => dispatch!(Ix3, Ix1),
1014            (3, 2) => dispatch!(Ix3, Ix2),
1015            (3, 3) => dispatch!(Ix3, Ix3),
1016            (3, 4) => dispatch!(Ix3, Ix4),
1017            (3, 5) => dispatch!(Ix3, Ix5),
1018            (3, 6) => dispatch!(Ix3, Ix6),
1019            (4, 1) => dispatch!(Ix4, Ix1),
1020            (4, 2) => dispatch!(Ix4, Ix2),
1021            (4, 3) => dispatch!(Ix4, Ix3),
1022            (4, 4) => dispatch!(Ix4, Ix4),
1023            (4, 5) => dispatch!(Ix4, Ix5),
1024            (4, 6) => dispatch!(Ix4, Ix6),
1025            (5, 1) => dispatch!(Ix5, Ix1),
1026            (5, 2) => dispatch!(Ix5, Ix2),
1027            (5, 3) => dispatch!(Ix5, Ix3),
1028            (5, 4) => dispatch!(Ix5, Ix4),
1029            (5, 5) => dispatch!(Ix5, Ix5),
1030            (5, 6) => dispatch!(Ix5, Ix6),
1031            (6, 1) => dispatch!(Ix6, Ix1),
1032            (6, 2) => dispatch!(Ix6, Ix2),
1033            (6, 3) => dispatch!(Ix6, Ix3),
1034            (6, 4) => dispatch!(Ix6, Ix4),
1035            (6, 5) => dispatch!(Ix6, Ix5),
1036            (6, 6) => dispatch!(Ix6, Ix6),
1037            _ => {
1038                let iter_a = IterLayoutColMajor::new(la)?;
1039                let iter_b = IterLayoutColMajor::new(lb)?;
1040                izip!(iter_a, iter_b).for_each(f);
1041            },
1042        }
1043    }
1044
1045    #[cfg(not(feature = "dispatch_dim_layout_iter"))]
1046    {
1047        let iter_a = IterLayoutColMajor::new(la)?;
1048        let iter_b = IterLayoutColMajor::new(lb)?;
1049        izip!(iter_a, iter_b).for_each(f);
1050    }
1051    Ok(())
1052}
1053
1054/* #endregion */
1055
#[cfg(test)]
mod test_col_major {
    use super::*;

    // type alias for this file
    type Order = TensorIterOrder;

    #[test]
    fn test_iter_next() {
        // a = np.arange(9 * 12 * 15)
        //       .reshape(9, 12, 15)[4:2:-1, 4:10, 2:10:3]
        //       .transpose(2, 0, 1)
        // `a` is a strided view into `arange`, so every element value equals its
        // buffer offset: shape (3, 2, 6), strides (3, -180, 15), and offset
        // 4 * 180 + 4 * 15 + 2 = 782.
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();

        // Translate `layout` for the requested order, then collect the offsets
        // visited by the col-major iterator.
        let offsets_in = |order: Order| {
            let layout_trans = translate_to_col_major_unary(&layout, order).unwrap();
            IterLayoutColMajor::new(&layout_trans).unwrap().collect::<Vec<_>>()
        };

        // np.array(np.nditer(a, order="C"))
        assert_eq!(offsets_in(Order::C), [
            782, 797, 812, 827, 842, 857, 602, 617, 632, 647, 662, 677, 785, 800, 815, 830, 845,
            860, 605, 620, 635, 650, 665, 680, 788, 803, 818, 833, 848, 863, 608, 623, 638, 653,
            668, 683
        ]);
        // np.array(np.nditer(a, order="F"))
        assert_eq!(offsets_in(Order::F), [
            782, 785, 788, 602, 605, 608, 797, 800, 803, 617, 620, 623, 812, 815, 818, 632, 635,
            638, 827, 830, 833, 647, 650, 653, 842, 845, 848, 662, 665, 668, 857, 860, 863, 677,
            680, 683
        ]);
        // np.array(np.nditer(a, order="K")): ascending memory order of the view
        let expected_k = [
            602, 605, 608, 617, 620, 623, 632, 635, 638, 647, 650, 653, 662, 665, 668, 677, 680,
            683, 782, 785, 788, 797, 800, 803, 812, 815, 818, 827, 830, 833, 842, 845, 848, 857,
            860, 863,
        ];
        assert_eq!(offsets_in(Order::K), expected_k);
        // Greedy order: for the no-broadcast case, greedy order is the same as K order.
        // NOTE(review): this repeats the K-order check and does not exercise
        // `Order::G` itself — TODO confirm whether `Order::G` should be used here.
        assert_eq!(offsets_in(Order::K), expected_k);
        // buffer (B) order should fail for this strided layout
        assert!(translate_to_col_major_unary(&layout, Order::B).is_err());
    }

    #[test]
    fn test_iter_back() {
        let layout = Layout::new([10, 10, 10], [10, 1, 100], 0).unwrap();
        // np.array(np.nditer(a, order="C"))
        let layout_trans = translate_to_col_major_unary(&layout, Order::C).unwrap();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        // Backward iteration must visit exactly the forward offsets, reversed.
        let vec_next = iter.clone().collect::<Vec<_>>();
        let vec_back = iter.rev().collect::<Vec<_>>();
        assert_eq!(vec_next, vec_back.iter().rev().copied().collect::<Vec<_>>());
    }

    // NOTE(review): despite the name, neither layout below is empty (sizes 36 and
    // 10). This checks forward/backward consistency for a negative-stride 3-D
    // layout and for a 1-D layout with non-unit stride and non-zero offset.
    #[test]
    fn test_iter_back_empty() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        // np.array(np.nditer(a, order="C"))
        let layout_trans = translate_to_col_major_unary(&layout, Order::C).unwrap();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec_next = iter.clone().collect::<Vec<_>>();
        let vec_back = iter.clone().rev().collect::<Vec<_>>();
        assert_eq!(vec_next, vec_back.iter().rev().copied().collect::<Vec<_>>());

        let layout = Layout::new([10], [10], 10).unwrap();
        // np.array(np.nditer(a, order="C"))
        let layout_trans = translate_to_col_major_unary(&layout, Order::C).unwrap();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec_next = iter.clone().collect::<Vec<_>>();
        let vec_back = iter.clone().rev().collect::<Vec<_>>();
        assert_eq!(vec_next, vec_back.iter().rev().copied().collect::<Vec<_>>());
    }
}
1143
#[cfg(test)]
mod test_row_major {
    use super::*;

    /// Row-major iteration over a strided 3-D layout (one negative stride,
    /// offset 782) should reproduce numpy's C-order `nditer` sequence when
    /// walked forwards, and the exact reverse when walked backwards.
    #[test]
    fn test_iter_next() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();

        // np.array(np.nditer(a, order="C"))
        let forward: Vec<_> = IterLayoutRowMajor::new(&layout).unwrap().collect();
        assert_eq!(forward, [
            782, 797, 812, 827, 842, 857, 602, 617, 632, 647, 662, 677, 785, 800, 815, 830, 845,
            860, 605, 620, 635, 650, 665, 680, 788, 803, 818, 833, 848, 863, 608, 623, 638, 653,
            668, 683
        ]);

        // the same sequence, consumed from the back
        let backward: Vec<_> = IterLayoutRowMajor::new(&layout).unwrap().rev().collect();
        assert_eq!(backward, [
            683, 668, 653, 638, 623, 608, 863, 848, 833, 818, 803, 788, 680, 665, 650, 635, 620,
            605, 860, 845, 830, 815, 800, 785, 677, 662, 647, 632, 617, 602, 857, 842, 827, 812,
            797, 782
        ]);
    }
}
1167}