1use crate::prelude_dev::*;
4
/// Iterator over the memory offsets of a layout in column-major order
/// (axis 0 varies fastest).
///
/// The iterator keeps two independent cursors so it can be consumed from both
/// ends: a front cursor (`*_start`) used by `next`, and a back cursor (`*_end`)
/// used by `next_back`. Iteration stops when `iter_start >= iter_end`.
#[derive(Clone, Debug)]
pub struct IterLayoutColMajor<D>
where
    D: DimDevAPI,
{
    /// Layout being iterated (an owned clone taken in `new`).
    pub(crate) layout: Layout<D>,

    /// Multi-dimensional index of the next element yielded by `next`.
    pub(crate) index_start: D,
    /// Linear (column-major) position of the front cursor.
    pub(crate) iter_start: usize,
    /// Memory offset of the element at `index_start`.
    pub(crate) offset_start: isize,

    /// Multi-dimensional index one past the last element still available to
    /// `next_back`.
    pub(crate) index_end: D,
    /// Exclusive linear (column-major) end position.
    pub(crate) iter_end: usize,
    /// Offset computed for `index_end`. It may point one past the last element,
    /// so it is only read after `back_iter_index` steps backwards.
    pub(crate) offset_end: isize,
}
37
38impl<D> IterLayoutColMajor<D>
39where
40 D: DimDevAPI,
41{
42 pub fn index_start(&self) -> &D {
43 &self.index_start
44 }
45
46 pub fn index_end(&self) -> &D {
47 &self.index_end
48 }
49
50 pub fn iter_start(&self) -> usize {
51 self.iter_start
52 }
53
54 pub fn iter_end(&self) -> usize {
55 self.iter_end
56 }
57
58 pub fn offset_start(&self) -> isize {
59 self.offset_start
60 }
61
62 pub fn offset_end(&self) -> isize {
63 self.offset_end
64 }
65}
66
67impl<D> IterLayoutColMajor<D>
68where
69 D: DimDevAPI,
70{
71 pub fn new(layout: &Layout<D>) -> Result<Self> {
74 let layout = layout.clone();
75 let shape = layout.shape();
76 let iter_start = 0;
77 let iter_end = layout.size();
78 let index_start = layout.new_shape();
79 let index_end = match iter_end {
80 0 => index_start.clone(),
81 _ => unsafe { shape.unravel_index_f(iter_end) },
82 };
83 let offset_start = layout.offset() as isize;
84 let offset_end = unsafe { layout.index_uncheck(index_end.as_ref()) };
85
86 return Ok(Self { layout, index_start, iter_start, offset_start, index_end, iter_end, offset_end });
87 }
88
89 pub fn split_at(&self, index: usize) -> Result<(Self, Self)> {
90 let Self { layout, index_start, iter_start, offset_start, index_end, iter_end, offset_end } = self.clone();
91 let shape = layout.shape();
92 let iter_ins = iter_start + index;
93 let index_ins = unsafe { shape.unravel_index_f(iter_ins) };
94 let offset_ins = unsafe { layout.index_uncheck(index_ins.as_ref()) };
95 let split_lhs = Self {
96 layout: layout.clone(),
97 index_start,
98 iter_start,
99 offset_start,
100 index_end: index_ins.clone(),
101 iter_end: iter_ins,
102 offset_end: offset_ins,
103 };
104 let split_rhs = Self {
105 layout: layout.clone(),
106 index_start: index_ins,
107 iter_start: iter_ins,
108 offset_start: offset_ins,
109 index_end,
110 iter_end,
111 offset_end,
112 };
113 return Ok((split_lhs, split_rhs));
114 }
115}
116
impl<D> IterLayoutColMajor<D>
where
    D: DimDevAPI,
{
    /// Advance the front cursor (`index_start`, `offset_start`, `iter_start`) by one
    /// element in column-major order.
    ///
    /// Axis 0 is incremented first. When an axis reaches its extent it is reset
    /// to 0 and the next axis is incremented. Ranks 0-4 are hand-unrolled and
    /// higher ranks use a generic loop. `next` checks that elements remain before
    /// calling this.
    #[inline]
    fn next_iter_index(&mut self) {
        let layout = &self.layout;
        let index = self.index_start.as_mut();
        let mut offset = self.offset_start;
        let shape = layout.shape().as_ref();
        let stride = layout.stride().as_ref();
        match layout.ndim() {
            // 0-d layouts have a single element: only the counter below moves.
            0 => (),
            1 => {
                index[0] += 1;
                offset += stride[0];
            },
            2 => {
                index[0] += 1;
                offset += stride[0];
                // Axis 0 wrapped: rewind its contribution and carry into axis 1.
                if index[0] == shape[0] {
                    index[0] = 0;
                    offset -= shape[0] as isize * stride[0];
                    index[1] += 1;
                    offset += stride[1];
                }
            },
            3 => {
                index[0] += 1;
                offset += stride[0];
                if index[0] == shape[0] {
                    index[0] = 0;
                    offset -= shape[0] as isize * stride[0];
                    index[1] += 1;
                    offset += stride[1];
                    if index[1] == shape[1] {
                        index[1] = 0;
                        offset -= shape[1] as isize * stride[1];
                        index[2] += 1;
                        offset += stride[2];
                    }
                }
            },
            4 => {
                index[0] += 1;
                offset += stride[0];
                if index[0] == shape[0] {
                    index[0] = 0;
                    offset -= shape[0] as isize * stride[0];
                    index[1] += 1;
                    offset += stride[1];
                    if index[1] == shape[1] {
                        index[1] = 0;
                        offset -= shape[1] as isize * stride[1];
                        index[2] += 1;
                        offset += stride[2];
                        if index[2] == shape[2] {
                            index[2] = 0;
                            offset -= shape[2] as isize * stride[2];
                            // The last axis is never wrapped here: after the final
                            // element it sits at its extent (one past the end).
                            index[3] += 1;
                            offset += stride[3];
                        }
                    }
                }
            },
            _ => {
                // Generic carry loop, axis 0 first.
                // NOTE(review): unlike the unrolled arms, this loop also wraps the
                // last axis. After the final element the front cursor returns to the
                // all-zero index and the base offset instead of one past the end.
                // Yielded offsets are unaffected (`next` checks `iter_start`), but the
                // `index_start()` / `offset_start()` getters differ by rank after
                // exhaustion.
                for (d, t, idx) in izip!(shape, stride, index.as_mut()) {
                    *idx += 1;
                    offset += t;
                    if idx == d {
                        *idx = 0;
                        offset -= *d as isize * t;
                    } else {
                        break;
                    }
                }
            },
        }
        self.offset_start = offset;
        self.iter_start += 1;
    }

    /// Step the back cursor (`index_end`, `offset_end`, `iter_end`) one element
    /// backwards in column-major order.
    ///
    /// `next_back` calls this before reading `offset_end`, because the back
    /// cursor starts one past the last remaining element. An axis at 0 is set to
    /// its extent minus one and borrows from the next axis. The last axis is
    /// decremented without a wrap check, since the caller guarantees elements
    /// remain.
    #[inline]
    fn back_iter_index(&mut self) {
        let layout = &self.layout;
        let index = self.index_end.as_mut();
        let mut offset = self.offset_end;
        let shape = layout.shape().as_ref();
        let stride = layout.stride().as_ref();
        match layout.ndim() {
            0 => (),
            1 => {
                index[0] -= 1;
                offset -= stride[0];
            },
            2 => {
                // Axis 0 underflows: jump it to its last position and borrow from axis 1.
                if index[0] == 0 {
                    index[0] = shape[0] - 1;
                    offset += (shape[0] - 1) as isize * stride[0];
                    index[1] -= 1;
                    offset -= stride[1];
                } else {
                    index[0] -= 1;
                    offset -= stride[0];
                }
            },
            3 => {
                if index[0] == 0 {
                    index[0] = shape[0] - 1;
                    offset += (shape[0] - 1) as isize * stride[0];
                    if index[1] == 0 {
                        index[1] = shape[1] - 1;
                        offset += (shape[1] - 1) as isize * stride[1];
                        index[2] -= 1;
                        offset -= stride[2];
                    } else {
                        index[1] -= 1;
                        offset -= stride[1];
                    }
                } else {
                    index[0] -= 1;
                    offset -= stride[0];
                }
            },
            4 => {
                if index[0] == 0 {
                    index[0] = shape[0] - 1;
                    offset += (shape[0] - 1) as isize * stride[0];
                    if index[1] == 0 {
                        index[1] = shape[1] - 1;
                        offset += (shape[1] - 1) as isize * stride[1];
                        if index[2] == 0 {
                            index[2] = shape[2] - 1;
                            offset += (shape[2] - 1) as isize * stride[2];
                            index[3] -= 1;
                            offset -= stride[3];
                        } else {
                            index[2] -= 1;
                            offset -= stride[2];
                        }
                    } else {
                        index[1] -= 1;
                        offset -= stride[1];
                    }
                } else {
                    index[0] -= 1;
                    offset -= stride[0];
                }
            },
            _ => {
                // Generic borrow loop, axis 0 first.
                for (d, t, idx) in izip!(shape, stride, index.as_mut()) {
                    if *idx == 0 {
                        *idx = *d - 1;
                        offset += (*d - 1) as isize * t;
                    } else {
                        *idx -= 1;
                        offset -= t;
                        break;
                    }
                }
            },
        }
        self.offset_end = offset;
        self.iter_end -= 1;
    }
}
283
284impl<D> Iterator for IterLayoutColMajor<D>
285where
286 D: DimDevAPI,
287{
288 type Item = usize;
289
290 fn next(&mut self) -> Option<Self::Item> {
291 if self.iter_start >= self.iter_end {
292 return None;
293 }
294 let offset = self.offset_start;
295 self.next_iter_index();
296 return Some(offset.try_into().unwrap());
297 }
298}
299
300impl<D> DoubleEndedIterator for IterLayoutColMajor<D>
301where
302 D: DimDevAPI,
303{
304 fn next_back(&mut self) -> Option<Self::Item> {
305 if self.iter_start >= self.iter_end {
306 return None;
307 }
308 self.back_iter_index();
309 let offset = self.offset_end;
310 return Some(offset.try_into().unwrap());
311 }
312}
313
314impl<D> ExactSizeIterator for IterLayoutColMajor<D>
315where
316 D: DimDevAPI,
317{
318 fn len(&self) -> usize {
319 self.iter_end - self.iter_start
320 }
321}
322
323impl<D> IterSplitAtAPI for IterLayoutColMajor<D>
324where
325 D: DimDevAPI,
326{
327 fn split_at(self, index: usize) -> (Self, Self) {
328 Self::split_at(&self, index).unwrap()
329 }
330}
331
/// Iterator over the memory offsets of a layout in row-major order
/// (the last axis varies fastest).
///
/// The iterator keeps two independent cursors so it can be consumed from both
/// ends: a front cursor (`*_start`) used by `next`, and a back cursor (`*_end`)
/// used by `next_back`. Iteration stops when `iter_start >= iter_end`.
#[derive(Debug, Clone)]
pub struct IterLayoutRowMajor<D>
where
    D: DimDevAPI,
{
    /// Layout being iterated (an owned clone taken in `new`).
    pub(crate) layout: Layout<D>,

    /// Multi-dimensional index of the next element yielded by `next`.
    pub(crate) index_start: D,
    /// Linear (row-major) position of the front cursor.
    pub(crate) iter_start: usize,
    /// Memory offset of the element at `index_start`.
    pub(crate) offset_start: isize,

    /// Multi-dimensional index one past the last element still available to
    /// `next_back`.
    pub(crate) index_end: D,
    /// Exclusive linear (row-major) end position.
    pub(crate) iter_end: usize,
    /// Offset computed for `index_end`. It may point one past the last element,
    /// so it is only read after `back_iter_index` steps backwards.
    pub(crate) offset_end: isize,
}
360
361impl<D> IterLayoutRowMajor<D>
362where
363 D: DimDevAPI,
364{
365 pub fn index_start(&self) -> &D {
366 &self.index_start
367 }
368
369 pub fn index_end(&self) -> &D {
370 &self.index_end
371 }
372
373 pub fn iter_start(&self) -> usize {
374 self.iter_start
375 }
376
377 pub fn iter_end(&self) -> usize {
378 self.iter_end
379 }
380
381 pub fn offset_start(&self) -> isize {
382 self.offset_start
383 }
384
385 pub fn offset_end(&self) -> isize {
386 self.offset_end
387 }
388}
389
390impl<D> IterLayoutRowMajor<D>
391where
392 D: DimDevAPI,
393{
394 pub fn new(layout: &Layout<D>) -> Result<Self> {
397 let layout = layout.clone();
398 let shape = layout.shape();
399 let iter_start = 0;
400 let iter_end = layout.size();
401 let index_start = layout.new_shape();
402 let index_end = match iter_end {
403 0 => index_start.clone(),
404 _ => unsafe { shape.unravel_index_c(iter_end) },
405 };
406 let offset_start = layout.offset() as isize;
407 let offset_end = unsafe { layout.index_uncheck(index_end.as_ref()) };
408
409 return Ok(Self { layout, index_start, iter_start, offset_start, index_end, iter_end, offset_end });
410 }
411
412 pub fn split_at(&self, index: usize) -> Result<(Self, Self)> {
413 let Self { layout, index_start, iter_start, offset_start, index_end, iter_end, offset_end } = self.clone();
414 let shape = layout.shape();
415 let iter_ins = iter_start + index;
416 let index_ins = unsafe { shape.unravel_index_c(iter_ins) };
417 let offset_ins = unsafe { layout.index_uncheck(index_ins.as_ref()) };
418 let split_lhs = Self {
419 layout: layout.clone(),
420 index_start,
421 iter_start,
422 offset_start,
423 index_end: index_ins.clone(),
424 iter_end: iter_ins,
425 offset_end: offset_ins,
426 };
427 let split_rhs = Self {
428 layout: layout.clone(),
429 index_start: index_ins,
430 iter_start: iter_ins,
431 offset_start: offset_ins,
432 index_end,
433 iter_end,
434 offset_end,
435 };
436 return Ok((split_lhs, split_rhs));
437 }
438}
439
impl<D> IterLayoutRowMajor<D>
where
    D: DimDevAPI,
{
    /// Advance the front cursor (`index_start`, `offset_start`, `iter_start`) by one
    /// element in row-major order.
    ///
    /// The last axis is incremented first. When an axis reaches its extent it is
    /// reset to 0 and the previous axis is incremented. Ranks 0-4 are
    /// hand-unrolled and higher ranks use a generic reversed loop. `next` checks
    /// that elements remain before calling this.
    #[inline]
    fn next_iter_index(&mut self) {
        let layout = &self.layout;
        let index = self.index_start.as_mut();
        let mut offset = self.offset_start;
        let shape = layout.shape().as_ref();
        let stride = layout.stride().as_ref();
        match layout.ndim() {
            // 0-d layouts have a single element: only the counter below moves.
            0 => (),
            1 => {
                index[0] += 1;
                offset += stride[0];
            },
            2 => {
                index[1] += 1;
                offset += stride[1];
                // Axis 1 wrapped: rewind its contribution and carry into axis 0.
                if index[1] == shape[1] {
                    index[1] = 0;
                    offset -= shape[1] as isize * stride[1];
                    index[0] += 1;
                    offset += stride[0];
                }
            },
            3 => {
                index[2] += 1;
                offset += stride[2];
                if index[2] == shape[2] {
                    index[2] = 0;
                    offset -= shape[2] as isize * stride[2];
                    index[1] += 1;
                    offset += stride[1];
                    if index[1] == shape[1] {
                        index[1] = 0;
                        offset -= shape[1] as isize * stride[1];
                        index[0] += 1;
                        offset += stride[0];
                    }
                }
            },
            4 => {
                index[3] += 1;
                offset += stride[3];
                if index[3] == shape[3] {
                    index[3] = 0;
                    offset -= shape[3] as isize * stride[3];
                    index[2] += 1;
                    offset += stride[2];
                    if index[2] == shape[2] {
                        index[2] = 0;
                        offset -= shape[2] as isize * stride[2];
                        index[1] += 1;
                        offset += stride[1];
                        if index[1] == shape[1] {
                            index[1] = 0;
                            offset -= shape[1] as isize * stride[1];
                            // Axis 0 is never wrapped here: after the final element
                            // it sits at its extent (one past the end).
                            index[0] += 1;
                            offset += stride[0];
                        }
                    }
                }
            },
            _ => {
                // Generic carry loop, last axis first.
                // NOTE(review): unlike the unrolled arms, this loop also wraps axis 0.
                // After the final element the front cursor returns to the all-zero
                // index and the base offset instead of one past the end. Yielded
                // offsets are unaffected (`next` checks `iter_start`), but the
                // `index_start()` / `offset_start()` getters differ by rank after
                // exhaustion.
                for (d, t, idx) in izip!(shape, stride, index.as_mut()).rev() {
                    *idx += 1;
                    offset += t;
                    if idx == d {
                        *idx = 0;
                        offset -= *d as isize * t;
                    } else {
                        break;
                    }
                }
            },
        }
        self.offset_start = offset;
        self.iter_start += 1;
    }

    /// Step the back cursor (`index_end`, `offset_end`, `iter_end`) one element
    /// backwards in row-major order.
    ///
    /// `next_back` calls this before reading `offset_end`, because the back
    /// cursor starts one past the last remaining element. An axis at 0 is set to
    /// its extent minus one and borrows from the previous axis. Axis 0 is
    /// decremented without a wrap check, since the caller guarantees elements
    /// remain.
    #[inline]
    fn back_iter_index(&mut self) {
        let layout = &self.layout;
        let index = self.index_end.as_mut();
        let mut offset = self.offset_end;
        let shape = layout.shape().as_ref();
        let stride = layout.stride().as_ref();
        match layout.ndim() {
            0 => (),
            1 => {
                index[0] -= 1;
                offset -= stride[0];
            },
            2 => {
                // Axis 1 underflows: jump it to its last position and borrow from axis 0.
                if index[1] == 0 {
                    index[1] = shape[1] - 1;
                    offset += (shape[1] - 1) as isize * stride[1];
                    index[0] -= 1;
                    offset -= stride[0];
                } else {
                    index[1] -= 1;
                    offset -= stride[1];
                }
            },
            3 => {
                if index[2] == 0 {
                    index[2] = shape[2] - 1;
                    offset += (shape[2] - 1) as isize * stride[2];
                    if index[1] == 0 {
                        index[1] = shape[1] - 1;
                        offset += (shape[1] - 1) as isize * stride[1];
                        index[0] -= 1;
                        offset -= stride[0];
                    } else {
                        index[1] -= 1;
                        offset -= stride[1];
                    }
                } else {
                    index[2] -= 1;
                    offset -= stride[2];
                }
            },
            4 => {
                if index[3] == 0 {
                    index[3] = shape[3] - 1;
                    offset += (shape[3] - 1) as isize * stride[3];
                    if index[2] == 0 {
                        index[2] = shape[2] - 1;
                        offset += (shape[2] - 1) as isize * stride[2];
                        if index[1] == 0 {
                            index[1] = shape[1] - 1;
                            offset += (shape[1] - 1) as isize * stride[1];
                            index[0] -= 1;
                            offset -= stride[0];
                        } else {
                            index[1] -= 1;
                            offset -= stride[1];
                        }
                    } else {
                        index[2] -= 1;
                        offset -= stride[2];
                    }
                } else {
                    index[3] -= 1;
                    offset -= stride[3];
                }
            },
            _ => {
                // Generic borrow loop, last axis first.
                for (d, t, idx) in izip!(shape, stride, index.as_mut()).rev() {
                    if *idx == 0 {
                        *idx = *d - 1;
                        offset += (*d - 1) as isize * t;
                    } else {
                        *idx -= 1;
                        offset -= t;
                        break;
                    }
                }
            },
        }
        self.offset_end = offset;
        self.iter_end -= 1;
    }
}
606
607impl<D> Iterator for IterLayoutRowMajor<D>
608where
609 D: DimDevAPI,
610{
611 type Item = usize;
612
613 fn next(&mut self) -> Option<Self::Item> {
614 if self.iter_start >= self.iter_end {
615 return None;
616 }
617 let offset = self.offset_start;
618 self.next_iter_index();
619 return Some(offset.try_into().unwrap());
620 }
621}
622
623impl<D> DoubleEndedIterator for IterLayoutRowMajor<D>
624where
625 D: DimDevAPI,
626{
627 fn next_back(&mut self) -> Option<Self::Item> {
628 if self.iter_start >= self.iter_end {
629 return None;
630 }
631 self.back_iter_index();
632 let offset = self.offset_end;
633 return Some(offset.try_into().unwrap());
634 }
635}
636
637impl<D> ExactSizeIterator for IterLayoutRowMajor<D>
638where
639 D: DimDevAPI,
640{
641 fn len(&self) -> usize {
642 self.iter_end - self.iter_start
643 }
644}
645
646impl<D> IterSplitAtAPI for IterLayoutRowMajor<D>
647where
648 D: DimDevAPI,
649{
650 fn split_at(self, index: usize) -> (Self, Self) {
651 Self::split_at(&self, index).unwrap()
652 }
653}
654
/// Layout offset iterator whose traversal order is chosen at runtime.
///
/// Wraps either a row-major or a column-major iterator and forwards all iterator
/// traits to it.
#[derive(Clone, Debug)]
pub enum IterLayout<D>
where
    D: DimDevAPI,
{
    /// Last axis varies fastest.
    RowMajor(IterLayoutRowMajor<D>),
    /// First axis varies fastest.
    ColMajor(IterLayoutColMajor<D>),
}
667
668impl<D> IterLayout<D>
669where
670 D: DimDevAPI,
671{
672 pub fn new(layout: &Layout<D>, order: TensorIterOrder) -> Result<Self> {
673 use TensorIterOrder::*;
674 match order {
675 C => {
676 let iter = IterLayoutRowMajor::new(layout)?;
677 return Ok(Self::RowMajor(iter));
678 },
679 F => {
680 let iter = IterLayoutColMajor::new(layout)?;
681 return Ok(Self::ColMajor(iter));
682 },
683 A => match FlagOrder::default() {
684 RowMajor => {
685 let iter = IterLayoutRowMajor::new(layout)?;
686 return Ok(Self::RowMajor(iter));
687 },
688 ColMajor => {
689 let iter = IterLayoutColMajor::new(layout)?;
690 return Ok(Self::ColMajor(iter));
691 },
692 },
693 K | G => {
694 let layout = translate_to_col_major_unary(layout, TensorIterOrder::K)?;
695 let iter = IterLayoutColMajor::new(&layout)?;
696 return Ok(Self::ColMajor(iter));
697 },
698 _ => rstsr_raise!(InvalidValue),
699 }
700 }
701}
702
703impl<D> Iterator for IterLayout<D>
704where
705 D: DimDevAPI,
706{
707 type Item = usize;
708
709 fn next(&mut self) -> Option<Self::Item> {
710 match self {
711 Self::RowMajor(iter) => iter.next(),
712 Self::ColMajor(iter) => iter.next(),
713 }
714 }
715}
716
717impl<D> DoubleEndedIterator for IterLayout<D>
718where
719 D: DimDevAPI,
720{
721 fn next_back(&mut self) -> Option<Self::Item> {
722 match self {
723 Self::RowMajor(iter) => iter.next_back(),
724 Self::ColMajor(iter) => iter.next_back(),
725 }
726 }
727}
728
729impl<D> ExactSizeIterator for IterLayout<D>
730where
731 D: DimDevAPI,
732{
733 fn len(&self) -> usize {
734 match self {
735 Self::RowMajor(iter) => iter.len(),
736 Self::ColMajor(iter) => iter.len(),
737 }
738 }
739}
740
741impl<D> IterSplitAtAPI for IterLayout<D>
742where
743 D: DimDevAPI,
744{
745 fn split_at(self, index: usize) -> (Self, Self) {
746 match self {
747 Self::RowMajor(iter) => {
748 let (lhs, rhs) = iter.split_at(index);
749 (Self::RowMajor(lhs), Self::RowMajor(rhs))
750 },
751 Self::ColMajor(iter) => {
752 let (lhs, rhs) = iter.split_at(index);
753 (Self::ColMajor(lhs), Self::ColMajor(rhs))
754 },
755 }
756 }
757}
758
/// Offset iterator that also yields the multi-dimensional index of each element,
/// as `(index, offset)` pairs.
#[derive(Clone, Debug)]
pub struct IndexedIterLayout<D>
where
    D: DimDevAPI,
{
    /// Underlying offset iterator; its cursors supply the indices.
    pub(crate) layout_iter: IterLayout<D>,
}
770
771impl<D> IndexedIterLayout<D>
772where
773 D: DimDevAPI,
774{
775 pub fn new(layout: &Layout<D>, order: FlagOrder) -> Result<Self> {
776 let order = match order {
777 RowMajor => TensorIterOrder::C,
778 ColMajor => TensorIterOrder::F,
779 };
780 Ok(Self { layout_iter: IterLayout::new(layout, order)? })
781 }
782}
783
784impl<D> Iterator for IndexedIterLayout<D>
785where
786 D: DimDevAPI,
787{
788 type Item = (D, usize);
789
790 fn next(&mut self) -> Option<Self::Item> {
791 let index = match &self.layout_iter {
792 IterLayout::ColMajor(iter_inner) => iter_inner.index_start.clone(),
793 IterLayout::RowMajor(iter_inner) => iter_inner.index_start.clone(),
794 };
795 self.layout_iter.next().map(|offset| (index, offset))
796 }
797}
798
799impl<D> DoubleEndedIterator for IndexedIterLayout<D>
800where
801 D: DimDevAPI,
802{
803 fn next_back(&mut self) -> Option<Self::Item> {
804 let index = match &self.layout_iter {
805 IterLayout::ColMajor(iter_inner) => iter_inner.index_start.clone(),
806 IterLayout::RowMajor(iter_inner) => iter_inner.index_start.clone(),
807 };
808 self.layout_iter.next_back().map(|offset| (index, offset))
809 }
810}
811
812impl<D> ExactSizeIterator for IndexedIterLayout<D>
813where
814 D: DimDevAPI,
815{
816 fn len(&self) -> usize {
817 self.layout_iter.len()
818 }
819}
820
821impl<D> IterSplitAtAPI for IndexedIterLayout<D>
822where
823 D: DimDevAPI,
824{
825 fn split_at(self, mid: usize) -> (Self, Self) {
826 let (lhs, rhs) = self.layout_iter.split_at(mid);
827 let lhs = IndexedIterLayout { layout_iter: lhs };
828 let rhs = IndexedIterLayout { layout_iter: rhs };
829 (lhs, rhs)
830 }
831}
832
833#[allow(unused_mut)]
838pub fn layout_col_major_dim_dispatch_1<D, F>(la: &Layout<D>, mut f: F) -> Result<()>
839where
840 D: DimAPI,
841 F: FnMut(usize),
842{
843 #[cfg(feature = "dispatch_dim_layout_iter")]
844 {
845 macro_rules! dispatch {
846 ($dim: ident) => {{
847 let iter_a = IterLayoutColMajor::new(&la.to_dim::<$dim>()?)?;
848 iter_a.for_each(f);
849 }};
850 }
851 match la.ndim() {
852 0 => f(la.offset()),
853 1 => dispatch!(Ix1),
854 2 => dispatch!(Ix2),
855 3 => dispatch!(Ix3),
856 4 => dispatch!(Ix4),
857 5 => dispatch!(Ix5),
858 6 => dispatch!(Ix6),
859 _ => {
860 let iter_a = IterLayoutColMajor::new(la)?;
861 iter_a.for_each(f);
862 },
863 }
864 }
865
866 #[cfg(not(feature = "dispatch_dim_layout_iter"))]
867 {
868 let iter_a = IterLayoutColMajor::new(la)?;
869 iter_a.for_each(f);
870 }
871 Ok(())
872}
873
874#[allow(unused_mut)]
875pub fn layout_col_major_dim_dispatch_2<D, F>(la: &Layout<D>, lb: &Layout<D>, mut f: F) -> Result<()>
876where
877 D: DimAPI,
878 F: FnMut((usize, usize)),
879{
880 rstsr_assert_eq!(la.ndim(), lb.ndim(), RuntimeError)?;
881
882 #[cfg(feature = "dispatch_dim_layout_iter")]
883 {
884 macro_rules! dispatch {
885 ($dim: ident) => {{
886 let iter_a = IterLayoutColMajor::new(&la.to_dim::<$dim>()?)?;
887 let iter_b = IterLayoutColMajor::new(&lb.to_dim::<$dim>()?)?;
888 izip!(iter_a, iter_b).for_each(f);
889 }};
890 }
891 match la.ndim() {
892 0 => f((la.offset(), lb.offset())),
893 1 => dispatch!(Ix1),
894 2 => dispatch!(Ix2),
895 3 => dispatch!(Ix3),
896 4 => dispatch!(Ix4),
897 5 => dispatch!(Ix5),
898 6 => dispatch!(Ix6),
899 _ => {
900 let iter_a = IterLayoutColMajor::new(la)?;
901 let iter_b = IterLayoutColMajor::new(lb)?;
902 izip!(iter_a, iter_b).for_each(f);
903 },
904 }
905 }
906
907 #[cfg(not(feature = "dispatch_dim_layout_iter"))]
908 {
909 let iter_a = IterLayoutColMajor::new(la)?;
910 let iter_b = IterLayoutColMajor::new(lb)?;
911 izip!(iter_a, iter_b).for_each(f);
912 }
913 Ok(())
914}
915
916#[allow(unused_mut)]
917pub fn layout_col_major_dim_dispatch_3<D, F>(la: &Layout<D>, lb: &Layout<D>, lc: &Layout<D>, mut f: F) -> Result<()>
918where
919 D: DimAPI,
920 F: FnMut((usize, usize, usize)),
921{
922 rstsr_assert_eq!(la.ndim(), lb.ndim(), RuntimeError)?;
923 rstsr_assert_eq!(la.ndim(), lc.ndim(), RuntimeError)?;
924
925 #[cfg(feature = "dispatch_dim_layout_iter")]
926 {
927 macro_rules! dispatch {
928 ($dim: ident) => {{
929 let iter_a = IterLayoutColMajor::new(&la.to_dim::<$dim>()?)?;
930 let iter_b = IterLayoutColMajor::new(&lb.to_dim::<$dim>()?)?;
931 let iter_c = IterLayoutColMajor::new(&lc.to_dim::<$dim>()?)?;
932 izip!(iter_a, iter_b, iter_c).for_each(f);
933 }};
934 }
935 match la.ndim() {
936 0 => f((la.offset(), lb.offset(), lc.offset())),
937 1 => dispatch!(Ix1),
938 2 => dispatch!(Ix2),
939 3 => dispatch!(Ix3),
940 4 => dispatch!(Ix4),
941 5 => dispatch!(Ix5),
942 6 => dispatch!(Ix6),
943 _ => {
944 let iter_a = IterLayoutColMajor::new(la)?;
945 let iter_b = IterLayoutColMajor::new(lb)?;
946 let iter_c = IterLayoutColMajor::new(lc)?;
947 izip!(iter_a, iter_b, iter_c).for_each(f);
948 },
949 }
950 }
951
952 #[cfg(not(feature = "dispatch_dim_layout_iter"))]
953 {
954 let iter_a = IterLayoutColMajor::new(la)?;
955 let iter_b = IterLayoutColMajor::new(lb)?;
956 let iter_c = IterLayoutColMajor::new(lc)?;
957 izip!(iter_a, iter_b, iter_c).for_each(f);
958 }
959 Ok(())
960}
961
962#[allow(unused_mut)]
963pub fn layout_col_major_dim_dispatch_2diff<DA, DB, F>(la: &Layout<DA>, lb: &Layout<DB>, mut f: F) -> Result<()>
964where
965 DA: DimAPI,
966 DB: DimAPI,
967 F: FnMut((usize, usize)),
968{
969 #[cfg(feature = "dispatch_dim_layout_iter")]
970 {
971 macro_rules! dispatch {
972 ($dima: ident, $dimb: ident) => {{
973 let iter_a = IterLayoutColMajor::new(&la.to_dim::<$dima>()?)?;
974 let iter_b = IterLayoutColMajor::new(&lb.to_dim::<$dimb>()?)?;
975 izip!(iter_a, iter_b).for_each(f);
976 }};
977 }
978 match (la.ndim(), lb.ndim()) {
979 (0, 0) => f((la.offset(), lb.offset())),
980 (1, 1) => dispatch!(Ix1, Ix1),
981 (1, 2) => dispatch!(Ix1, Ix2),
982 (1, 3) => dispatch!(Ix1, Ix3),
983 (1, 4) => dispatch!(Ix1, Ix4),
984 (1, 5) => dispatch!(Ix1, Ix5),
985 (1, 6) => dispatch!(Ix1, Ix6),
986 (2, 1) => dispatch!(Ix2, Ix1),
987 (2, 2) => dispatch!(Ix2, Ix2),
988 (2, 3) => dispatch!(Ix2, Ix3),
989 (2, 4) => dispatch!(Ix2, Ix4),
990 (2, 5) => dispatch!(Ix2, Ix5),
991 (2, 6) => dispatch!(Ix2, Ix6),
992 (3, 1) => dispatch!(Ix3, Ix1),
993 (3, 2) => dispatch!(Ix3, Ix2),
994 (3, 3) => dispatch!(Ix3, Ix3),
995 (3, 4) => dispatch!(Ix3, Ix4),
996 (3, 5) => dispatch!(Ix3, Ix5),
997 (3, 6) => dispatch!(Ix3, Ix6),
998 (4, 1) => dispatch!(Ix4, Ix1),
999 (4, 2) => dispatch!(Ix4, Ix2),
1000 (4, 3) => dispatch!(Ix4, Ix3),
1001 (4, 4) => dispatch!(Ix4, Ix4),
1002 (4, 5) => dispatch!(Ix4, Ix5),
1003 (4, 6) => dispatch!(Ix4, Ix6),
1004 (5, 1) => dispatch!(Ix5, Ix1),
1005 (5, 2) => dispatch!(Ix5, Ix2),
1006 (5, 3) => dispatch!(Ix5, Ix3),
1007 (5, 4) => dispatch!(Ix5, Ix4),
1008 (5, 5) => dispatch!(Ix5, Ix5),
1009 (5, 6) => dispatch!(Ix5, Ix6),
1010 (6, 1) => dispatch!(Ix6, Ix1),
1011 (6, 2) => dispatch!(Ix6, Ix2),
1012 (6, 3) => dispatch!(Ix6, Ix3),
1013 (6, 4) => dispatch!(Ix6, Ix4),
1014 (6, 5) => dispatch!(Ix6, Ix5),
1015 (6, 6) => dispatch!(Ix6, Ix6),
1016 _ => {
1017 let iter_a = IterLayoutColMajor::new(la)?;
1018 let iter_b = IterLayoutColMajor::new(lb)?;
1019 izip!(iter_a, iter_b).for_each(f);
1020 },
1021 }
1022 }
1023
1024 #[cfg(not(feature = "dispatch_dim_layout_iter"))]
1025 {
1026 let iter_a = IterLayoutColMajor::new(la)?;
1027 let iter_b = IterLayoutColMajor::new(lb)?;
1028 izip!(iter_a, iter_b).for_each(f);
1029 }
1030 Ok(())
1031}
1032
#[cfg(test)]
mod test_col_major {
    //! Tests for `IterLayoutColMajor`, mostly driven through
    //! `translate_to_col_major_unary` so each `TensorIterOrder` is exercised.
    use super::*;

    type Order = TensorIterOrder;

    /// Forward iteration for C, F and K orders on a layout with a negative stride.
    #[test]
    fn test_iter_next() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let layout_trans = translate_to_col_major_unary(&layout, Order::C).unwrap();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec = iter.collect::<Vec<_>>();
        assert_eq!(vec, [
            782, 797, 812, 827, 842, 857, 602, 617, 632, 647, 662, 677, 785, 800, 815, 830, 845, 860, 605, 620, 635,
            650, 665, 680, 788, 803, 818, 833, 848, 863, 608, 623, 638, 653, 668, 683
        ]);
        let layout_trans = translate_to_col_major_unary(&layout, Order::F).unwrap();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec = iter.collect::<Vec<_>>();
        assert_eq!(vec, [
            782, 785, 788, 602, 605, 608, 797, 800, 803, 617, 620, 623, 812, 815, 818, 632, 635, 638, 827, 830, 833,
            647, 650, 653, 842, 845, 848, 662, 665, 668, 857, 860, 863, 677, 680, 683
        ]);
        // K order yields offsets in increasing memory order here.
        let layout_trans = translate_to_col_major_unary(&layout, Order::K).unwrap();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec = iter.collect::<Vec<_>>();
        assert_eq!(vec, [
            602, 605, 608, 617, 620, 623, 632, 635, 638, 647, 650, 653, 662, 665, 668, 677, 680, 683, 782, 785, 788,
            797, 800, 803, 812, 815, 818, 827, 830, 833, 842, 845, 848, 857, 860, 863
        ]);
        // NOTE(review): this repeats the `Order::K` check above verbatim; it may
        // have been intended to cover another order (e.g. `G`). Confirm or remove.
        let layout_trans = translate_to_col_major_unary(&layout, Order::K).unwrap();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec = iter.collect::<Vec<_>>();
        assert_eq!(vec, [
            602, 605, 608, 617, 620, 623, 632, 635, 638, 647, 650, 653, 662, 665, 668, 677, 680, 683, 782, 785, 788,
            797, 800, 803, 812, 815, 818, 827, 830, 833, 842, 845, 848, 857, 860, 863
        ]);
        // `B` is rejected by the translation.
        assert!(translate_to_col_major_unary(&layout, Order::B).is_err());
    }

    /// Reverse iteration must yield the forward sequence reversed.
    #[test]
    fn test_iter_back() {
        let layout = Layout::new([10, 10, 10], [10, 1, 100], 0).unwrap();
        let layout_trans = translate_to_col_major_unary(&layout, Order::C).unwrap();
        // Debug output only; not asserted.
        println!("{:?}", unsafe { layout.shape().unravel_index_f(100) });
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec_next = iter.collect::<Vec<_>>();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec_back = iter.rev().collect::<Vec<_>>();
        assert_eq!(vec_next, vec_back.iter().rev().copied().collect::<Vec<_>>());
    }

    /// Forward/reverse agreement on a 3-d layout with a negative stride and on a
    /// 1-d layout.
    #[test]
    fn test_iter_back_empty() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let layout_trans = translate_to_col_major_unary(&layout, Order::C).unwrap();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec_next = iter.clone().collect::<Vec<_>>();
        let vec_back = iter.clone().rev().collect::<Vec<_>>();
        assert_eq!(vec_next, vec_back.iter().rev().copied().collect::<Vec<_>>());

        let layout = Layout::new([10], [10], 10).unwrap();
        let layout_trans = translate_to_col_major_unary(&layout, Order::C).unwrap();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec_next = iter.clone().collect::<Vec<_>>();
        let vec_back = iter.clone().rev().collect::<Vec<_>>();
        assert_eq!(vec_next, vec_back.iter().rev().copied().collect::<Vec<_>>());
    }

    /// A layout with a zero-length axis yields nothing.
    #[test]
    fn test_zero_size_iterator() {
        let layout = [3, 0, 5].f();
        let iter = IterLayoutColMajor::new(&layout).unwrap();
        let vec = iter.collect_vec();
        assert_eq!(vec, vec![]);
    }
}
1126
#[cfg(test)]
mod test_row_major {
    //! Tests for `IterLayoutRowMajor`.
    use super::*;

    /// Forward and reverse row-major iteration on a layout with a negative stride;
    /// the reverse sequence is the forward one reversed.
    #[test]
    fn test_iter_next() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let iter = IterLayoutRowMajor::new(&layout).unwrap();
        let vec = iter.collect::<Vec<_>>();
        assert_eq!(vec, [
            782, 797, 812, 827, 842, 857, 602, 617, 632, 647, 662, 677, 785, 800, 815, 830, 845, 860, 605, 620, 635,
            650, 665, 680, 788, 803, 818, 833, 848, 863, 608, 623, 638, 653, 668, 683
        ]);
        let iter = IterLayoutRowMajor::new(&layout).unwrap();
        let vec = iter.rev().collect::<Vec<_>>();
        assert_eq!(vec, [
            683, 668, 653, 638, 623, 608, 863, 848, 833, 818, 803, 788, 680, 665, 650, 635, 620, 605, 860, 845, 830,
            815, 800, 785, 677, 662, 647, 632, 617, 602, 857, 842, 827, 812, 797, 782
        ]);
    }
}