1use crate::prelude_dev::*;
4
5#[derive(Clone, Debug)]
23pub struct IterLayoutColMajor<D>
24where
25 D: DimDevAPI,
26{
27 pub(crate) layout: Layout<D>,
28
29 pub(crate) index_start: D, pub(crate) iter_start: usize,
31 pub(crate) offset_start: isize,
32
33 pub(crate) index_end: D, pub(crate) iter_end: usize,
35 pub(crate) offset_end: isize,
36}
37
38impl<D> IterLayoutColMajor<D>
39where
40 D: DimDevAPI,
41{
42 pub fn index_start(&self) -> &D {
43 &self.index_start
44 }
45
46 pub fn index_end(&self) -> &D {
47 &self.index_end
48 }
49
50 pub fn iter_start(&self) -> usize {
51 self.iter_start
52 }
53
54 pub fn iter_end(&self) -> usize {
55 self.iter_end
56 }
57
58 pub fn offset_start(&self) -> isize {
59 self.offset_start
60 }
61
62 pub fn offset_end(&self) -> isize {
63 self.offset_end
64 }
65}
66
67impl<D> IterLayoutColMajor<D>
68where
69 D: DimDevAPI,
70{
71 pub fn new(layout: &Layout<D>) -> Result<Self> {
74 let layout = layout.clone();
75 let shape = layout.shape();
76 let iter_start = 0;
77 let iter_end = layout.size();
78 let index_start = layout.new_shape();
79 let index_end = unsafe { shape.unravel_index_f(iter_end) };
80 let offset_start = layout.offset() as isize;
81 let offset_end = unsafe { layout.index_uncheck(index_end.as_ref()) };
82
83 return Ok(Self {
84 layout,
85 index_start,
86 iter_start,
87 offset_start,
88 index_end,
89 iter_end,
90 offset_end,
91 });
92 }
93
94 pub fn split_at(&self, index: usize) -> Result<(Self, Self)> {
95 let Self { layout, index_start, iter_start, offset_start, index_end, iter_end, offset_end } =
96 self.clone();
97 let shape = layout.shape();
98 let iter_ins = iter_start + index;
99 let index_ins = unsafe { shape.unravel_index_f(iter_ins) };
100 let offset_ins = unsafe { layout.index_uncheck(index_ins.as_ref()) };
101 let split_lhs = Self {
102 layout: layout.clone(),
103 index_start,
104 iter_start,
105 offset_start,
106 index_end: index_ins.clone(),
107 iter_end: iter_ins,
108 offset_end: offset_ins,
109 };
110 let split_rhs = Self {
111 layout: layout.clone(),
112 index_start: index_ins,
113 iter_start: iter_ins,
114 offset_start: offset_ins,
115 index_end,
116 iter_end,
117 offset_end,
118 };
119 return Ok((split_lhs, split_rhs));
120 }
121}
122
impl<D> IterLayoutColMajor<D>
where
    D: DimDevAPI,
{
    /// Advances the front cursor (`index_start`, `offset_start`, `iter_start`) by one element
    /// in column-major order.
    ///
    /// Axis 0 is incremented first. When an axis reaches its extent, it is reset to 0 and the
    /// carry moves to the next axis. `offset_start` is updated incrementally from the strides
    /// rather than recomputed from the index.
    ///
    /// The caller (`next`) checks `iter_start < iter_end` beforehand. After the final element,
    /// the index/offset left behind are not yielded: the fixed-rank arms leave the last axis
    /// at its extent, while the generic arm wraps every axis back to 0.
    #[inline]
    fn next_iter_index(&mut self) {
        let layout = &self.layout;
        let index = self.index_start.as_mut();
        let mut offset = self.offset_start;
        let shape = layout.shape().as_ref();
        let stride = layout.stride().as_ref();
        // Ranks 0..=4 are hand-unrolled; higher ranks fall through to the generic carry loop.
        match layout.ndim() {
            0 => (),
            1 => {
                index[0] += 1;
                offset += stride[0];
            },
            2 => {
                index[0] += 1;
                offset += stride[0];
                if index[0] == shape[0] {
                    // carry: rewind axis 0, step axis 1
                    index[0] = 0;
                    offset -= shape[0] as isize * stride[0];
                    index[1] += 1;
                    offset += stride[1];
                }
            },
            3 => {
                index[0] += 1;
                offset += stride[0];
                if index[0] == shape[0] {
                    index[0] = 0;
                    offset -= shape[0] as isize * stride[0];
                    index[1] += 1;
                    offset += stride[1];
                    if index[1] == shape[1] {
                        index[1] = 0;
                        offset -= shape[1] as isize * stride[1];
                        index[2] += 1;
                        offset += stride[2];
                    }
                }
            },
            4 => {
                index[0] += 1;
                offset += stride[0];
                if index[0] == shape[0] {
                    index[0] = 0;
                    offset -= shape[0] as isize * stride[0];
                    index[1] += 1;
                    offset += stride[1];
                    if index[1] == shape[1] {
                        index[1] = 0;
                        offset -= shape[1] as isize * stride[1];
                        index[2] += 1;
                        offset += stride[2];
                        if index[2] == shape[2] {
                            index[2] = 0;
                            offset -= shape[2] as isize * stride[2];
                            index[3] += 1;
                            offset += stride[3];
                        }
                    }
                }
            },
            _ => {
                // Generic carry, fastest axis first: bump the axis; if it overflowed, rewind it
                // (undoing its offset contribution) and continue with the next axis, else stop.
                for (d, t, idx) in izip!(shape, stride, index.as_mut()) {
                    *idx += 1;
                    offset += t;
                    if idx == d {
                        *idx = 0;
                        offset -= *d as isize * t;
                    } else {
                        break;
                    }
                }
            },
        }
        self.offset_start = offset;
        self.iter_start += 1;
    }

    /// Moves the back cursor (`index_end`, `offset_end`, `iter_end`) one element towards the
    /// front in column-major order.
    ///
    /// The back cursor is exclusive: it starts at linear position `layout.size()` (see `new`).
    /// After this call, it refers to the element that `next_back` yields. When an axis is
    /// already 0, it is set to its last value and the borrow moves to the next axis.
    ///
    /// The caller must ensure `iter_start < iter_end`; otherwise the unsigned index
    /// decrements can underflow.
    #[inline]
    fn back_iter_index(&mut self) {
        let layout = &self.layout;
        let index = self.index_end.as_mut();
        let mut offset = self.offset_end;
        let shape = layout.shape().as_ref();
        let stride = layout.stride().as_ref();
        // Ranks 0..=4 are hand-unrolled; higher ranks fall through to the generic borrow loop.
        match layout.ndim() {
            0 => (),
            1 => {
                index[0] -= 1;
                offset -= stride[0];
            },
            2 => {
                if index[0] == 0 {
                    // borrow: wrap axis 0 to its last value, step axis 1 back
                    index[0] = shape[0] - 1;
                    offset += (shape[0] - 1) as isize * stride[0];
                    index[1] -= 1;
                    offset -= stride[1];
                } else {
                    index[0] -= 1;
                    offset -= stride[0];
                }
            },
            3 => {
                if index[0] == 0 {
                    index[0] = shape[0] - 1;
                    offset += (shape[0] - 1) as isize * stride[0];
                    if index[1] == 0 {
                        index[1] = shape[1] - 1;
                        offset += (shape[1] - 1) as isize * stride[1];
                        index[2] -= 1;
                        offset -= stride[2];
                    } else {
                        index[1] -= 1;
                        offset -= stride[1];
                    }
                } else {
                    index[0] -= 1;
                    offset -= stride[0];
                }
            },
            4 => {
                if index[0] == 0 {
                    index[0] = shape[0] - 1;
                    offset += (shape[0] - 1) as isize * stride[0];
                    if index[1] == 0 {
                        index[1] = shape[1] - 1;
                        offset += (shape[1] - 1) as isize * stride[1];
                        if index[2] == 0 {
                            index[2] = shape[2] - 1;
                            offset += (shape[2] - 1) as isize * stride[2];
                            index[3] -= 1;
                            offset -= stride[3];
                        } else {
                            index[2] -= 1;
                            offset -= stride[2];
                        }
                    } else {
                        index[1] -= 1;
                        offset -= stride[1];
                    }
                } else {
                    index[0] -= 1;
                    offset -= stride[0];
                }
            },
            _ => {
                // Generic borrow, fastest axis first: an axis at 0 wraps to its last value and
                // the borrow continues; the first non-zero axis is decremented and the loop stops.
                for (d, t, idx) in izip!(shape, stride, index.as_mut()) {
                    if *idx == 0 {
                        *idx = *d - 1;
                        offset += (*d - 1) as isize * t;
                    } else {
                        *idx -= 1;
                        offset -= t;
                        break;
                    }
                }
            },
        }
        self.offset_end = offset;
        self.iter_end -= 1;
    }
}
289
290impl<D> Iterator for IterLayoutColMajor<D>
291where
292 D: DimDevAPI,
293{
294 type Item = usize;
295
296 fn next(&mut self) -> Option<Self::Item> {
297 if self.iter_start >= self.iter_end {
298 return None;
299 }
300 let offset = self.offset_start;
301 self.next_iter_index();
302 return Some(offset.try_into().unwrap());
303 }
304}
305
306impl<D> DoubleEndedIterator for IterLayoutColMajor<D>
307where
308 D: DimDevAPI,
309{
310 fn next_back(&mut self) -> Option<Self::Item> {
311 if self.iter_start >= self.iter_end {
312 return None;
313 }
314 self.back_iter_index();
315 let offset = self.offset_end;
316 return Some(offset.try_into().unwrap());
317 }
318}
319
320impl<D> ExactSizeIterator for IterLayoutColMajor<D>
321where
322 D: DimDevAPI,
323{
324 fn len(&self) -> usize {
325 self.iter_end - self.iter_start
326 }
327}
328
329impl<D> IterSplitAtAPI for IterLayoutColMajor<D>
330where
331 D: DimDevAPI,
332{
333 fn split_at(self, index: usize) -> (Self, Self) {
334 Self::split_at(&self, index).unwrap()
335 }
336}
337
338#[derive(Debug, Clone)]
352pub struct IterLayoutRowMajor<D>
353where
354 D: DimDevAPI,
355{
356 pub(crate) layout: Layout<D>,
357
358 pub(crate) index_start: D, pub(crate) iter_start: usize,
360 pub(crate) offset_start: isize,
361
362 pub(crate) index_end: D, pub(crate) iter_end: usize,
364 pub(crate) offset_end: isize,
365}
366
367impl<D> IterLayoutRowMajor<D>
368where
369 D: DimDevAPI,
370{
371 pub fn index_start(&self) -> &D {
372 &self.index_start
373 }
374
375 pub fn index_end(&self) -> &D {
376 &self.index_end
377 }
378
379 pub fn iter_start(&self) -> usize {
380 self.iter_start
381 }
382
383 pub fn iter_end(&self) -> usize {
384 self.iter_end
385 }
386
387 pub fn offset_start(&self) -> isize {
388 self.offset_start
389 }
390
391 pub fn offset_end(&self) -> isize {
392 self.offset_end
393 }
394}
395
396impl<D> IterLayoutRowMajor<D>
397where
398 D: DimDevAPI,
399{
400 pub fn new(layout: &Layout<D>) -> Result<Self> {
403 let layout = layout.clone();
404 let shape = layout.shape();
405 let iter_start = 0;
406 let iter_end = layout.size();
407 let index_start = layout.new_shape();
408 let index_end = unsafe { shape.unravel_index_c(iter_end) };
409 let offset_start = layout.offset() as isize;
410 let offset_end = unsafe { layout.index_uncheck(index_end.as_ref()) };
411
412 return Ok(Self {
413 layout,
414 index_start,
415 iter_start,
416 offset_start,
417 index_end,
418 iter_end,
419 offset_end,
420 });
421 }
422
423 pub fn split_at(&self, index: usize) -> Result<(Self, Self)> {
424 let Self { layout, index_start, iter_start, offset_start, index_end, iter_end, offset_end } =
425 self.clone();
426 let shape = layout.shape();
427 let iter_ins = iter_start + index;
428 let index_ins = unsafe { shape.unravel_index_c(iter_ins) };
429 let offset_ins = unsafe { layout.index_uncheck(index_ins.as_ref()) };
430 let split_lhs = Self {
431 layout: layout.clone(),
432 index_start,
433 iter_start,
434 offset_start,
435 index_end: index_ins.clone(),
436 iter_end: iter_ins,
437 offset_end: offset_ins,
438 };
439 let split_rhs = Self {
440 layout: layout.clone(),
441 index_start: index_ins,
442 iter_start: iter_ins,
443 offset_start: offset_ins,
444 index_end,
445 iter_end,
446 offset_end,
447 };
448 return Ok((split_lhs, split_rhs));
449 }
450}
451
impl<D> IterLayoutRowMajor<D>
where
    D: DimDevAPI,
{
    /// Advances the front cursor (`index_start`, `offset_start`, `iter_start`) by one element
    /// in row-major order.
    ///
    /// The last axis is incremented first. When an axis reaches its extent, it is reset to 0
    /// and the carry moves to the preceding axis. `offset_start` is updated incrementally from
    /// the strides rather than recomputed from the index.
    ///
    /// The caller (`next`) checks `iter_start < iter_end` beforehand. After the final element,
    /// the index/offset left behind are not yielded: the fixed-rank arms leave axis 0 at its
    /// extent, while the generic arm wraps every axis back to 0.
    #[inline]
    fn next_iter_index(&mut self) {
        let layout = &self.layout;
        let index = self.index_start.as_mut();
        let mut offset = self.offset_start;
        let shape = layout.shape().as_ref();
        let stride = layout.stride().as_ref();
        // Ranks 0..=4 are hand-unrolled; higher ranks fall through to the generic carry loop.
        match layout.ndim() {
            0 => (),
            1 => {
                index[0] += 1;
                offset += stride[0];
            },
            2 => {
                index[1] += 1;
                offset += stride[1];
                if index[1] == shape[1] {
                    // carry: rewind axis 1, step axis 0
                    index[1] = 0;
                    offset -= shape[1] as isize * stride[1];
                    index[0] += 1;
                    offset += stride[0];
                }
            },
            3 => {
                index[2] += 1;
                offset += stride[2];
                if index[2] == shape[2] {
                    index[2] = 0;
                    offset -= shape[2] as isize * stride[2];
                    index[1] += 1;
                    offset += stride[1];
                    if index[1] == shape[1] {
                        index[1] = 0;
                        offset -= shape[1] as isize * stride[1];
                        index[0] += 1;
                        offset += stride[0];
                    }
                }
            },
            4 => {
                index[3] += 1;
                offset += stride[3];
                if index[3] == shape[3] {
                    index[3] = 0;
                    offset -= shape[3] as isize * stride[3];
                    index[2] += 1;
                    offset += stride[2];
                    if index[2] == shape[2] {
                        index[2] = 0;
                        offset -= shape[2] as isize * stride[2];
                        index[1] += 1;
                        offset += stride[1];
                        if index[1] == shape[1] {
                            index[1] = 0;
                            offset -= shape[1] as isize * stride[1];
                            index[0] += 1;
                            offset += stride[0];
                        }
                    }
                }
            },
            _ => {
                // Generic carry, last axis first (`.rev()`): bump the axis; if it overflowed,
                // rewind it (undoing its offset contribution) and continue, else stop.
                for (d, t, idx) in izip!(shape, stride, index.as_mut()).rev() {
                    *idx += 1;
                    offset += t;
                    if idx == d {
                        *idx = 0;
                        offset -= *d as isize * t;
                    } else {
                        break;
                    }
                }
            },
        }
        self.offset_start = offset;
        self.iter_start += 1;
    }

    /// Moves the back cursor (`index_end`, `offset_end`, `iter_end`) one element towards the
    /// front in row-major order.
    ///
    /// The back cursor is exclusive: it starts at linear position `layout.size()` (see `new`).
    /// After this call, it refers to the element that `next_back` yields. When an axis is
    /// already 0, it is set to its last value and the borrow moves to the preceding axis.
    ///
    /// The caller must ensure `iter_start < iter_end`; otherwise the unsigned index
    /// decrements can underflow.
    #[inline]
    fn back_iter_index(&mut self) {
        let layout = &self.layout;
        let index = self.index_end.as_mut();
        let mut offset = self.offset_end;
        let shape = layout.shape().as_ref();
        let stride = layout.stride().as_ref();
        // Ranks 0..=4 are hand-unrolled; higher ranks fall through to the generic borrow loop.
        match layout.ndim() {
            0 => (),
            1 => {
                index[0] -= 1;
                offset -= stride[0];
            },
            2 => {
                if index[1] == 0 {
                    // borrow: wrap axis 1 to its last value, step axis 0 back
                    index[1] = shape[1] - 1;
                    offset += (shape[1] - 1) as isize * stride[1];
                    index[0] -= 1;
                    offset -= stride[0];
                } else {
                    index[1] -= 1;
                    offset -= stride[1];
                }
            },
            3 => {
                if index[2] == 0 {
                    index[2] = shape[2] - 1;
                    offset += (shape[2] - 1) as isize * stride[2];
                    if index[1] == 0 {
                        index[1] = shape[1] - 1;
                        offset += (shape[1] - 1) as isize * stride[1];
                        index[0] -= 1;
                        offset -= stride[0];
                    } else {
                        index[1] -= 1;
                        offset -= stride[1];
                    }
                } else {
                    index[2] -= 1;
                    offset -= stride[2];
                }
            },
            4 => {
                if index[3] == 0 {
                    index[3] = shape[3] - 1;
                    offset += (shape[3] - 1) as isize * stride[3];
                    if index[2] == 0 {
                        index[2] = shape[2] - 1;
                        offset += (shape[2] - 1) as isize * stride[2];
                        if index[1] == 0 {
                            index[1] = shape[1] - 1;
                            offset += (shape[1] - 1) as isize * stride[1];
                            index[0] -= 1;
                            offset -= stride[0];
                        } else {
                            index[1] -= 1;
                            offset -= stride[1];
                        }
                    } else {
                        index[2] -= 1;
                        offset -= stride[2];
                    }
                } else {
                    index[3] -= 1;
                    offset -= stride[3];
                }
            },
            _ => {
                // Generic borrow, last axis first (`.rev()`): an axis at 0 wraps to its last
                // value and the borrow continues; the first non-zero axis is decremented.
                for (d, t, idx) in izip!(shape, stride, index.as_mut()).rev() {
                    if *idx == 0 {
                        *idx = *d - 1;
                        offset += (*d - 1) as isize * t;
                    } else {
                        *idx -= 1;
                        offset -= t;
                        break;
                    }
                }
            },
        }
        self.offset_end = offset;
        self.iter_end -= 1;
    }
}
618
619impl<D> Iterator for IterLayoutRowMajor<D>
620where
621 D: DimDevAPI,
622{
623 type Item = usize;
624
625 fn next(&mut self) -> Option<Self::Item> {
626 if self.iter_start >= self.iter_end {
627 return None;
628 }
629 let offset = self.offset_start;
630 self.next_iter_index();
631 return Some(offset.try_into().unwrap());
632 }
633}
634
635impl<D> DoubleEndedIterator for IterLayoutRowMajor<D>
636where
637 D: DimDevAPI,
638{
639 fn next_back(&mut self) -> Option<Self::Item> {
640 if self.iter_start >= self.iter_end {
641 return None;
642 }
643 self.back_iter_index();
644 let offset = self.offset_end;
645 return Some(offset.try_into().unwrap());
646 }
647}
648
649impl<D> ExactSizeIterator for IterLayoutRowMajor<D>
650where
651 D: DimDevAPI,
652{
653 fn len(&self) -> usize {
654 self.iter_end - self.iter_start
655 }
656}
657
658impl<D> IterSplitAtAPI for IterLayoutRowMajor<D>
659where
660 D: DimDevAPI,
661{
662 fn split_at(self, index: usize) -> (Self, Self) {
663 Self::split_at(&self, index).unwrap()
664 }
665}
666
667#[derive(Clone, Debug)]
672pub enum IterLayout<D>
673where
674 D: DimDevAPI,
675{
676 RowMajor(IterLayoutRowMajor<D>),
677 ColMajor(IterLayoutColMajor<D>),
678}
679
680impl<D> IterLayout<D>
681where
682 D: DimDevAPI,
683{
684 pub fn new(layout: &Layout<D>, order: TensorIterOrder) -> Result<Self> {
685 use TensorIterOrder::*;
686 match order {
687 C => {
688 let iter = IterLayoutRowMajor::new(layout)?;
689 return Ok(Self::RowMajor(iter));
690 },
691 F => {
692 let iter = IterLayoutColMajor::new(layout)?;
693 return Ok(Self::ColMajor(iter));
694 },
695 A => match FlagOrder::default() {
696 RowMajor => {
697 let iter = IterLayoutRowMajor::new(layout)?;
698 return Ok(Self::RowMajor(iter));
699 },
700 ColMajor => {
701 let iter = IterLayoutColMajor::new(layout)?;
702 return Ok(Self::ColMajor(iter));
703 },
704 },
705 K | G => {
706 let layout = translate_to_col_major_unary(layout, TensorIterOrder::K)?;
707 let iter = IterLayoutColMajor::new(&layout)?;
708 return Ok(Self::ColMajor(iter));
709 },
710 _ => rstsr_raise!(InvalidValue),
711 }
712 }
713}
714
715impl<D> Iterator for IterLayout<D>
716where
717 D: DimDevAPI,
718{
719 type Item = usize;
720
721 fn next(&mut self) -> Option<Self::Item> {
722 match self {
723 Self::RowMajor(iter) => iter.next(),
724 Self::ColMajor(iter) => iter.next(),
725 }
726 }
727}
728
729impl<D> DoubleEndedIterator for IterLayout<D>
730where
731 D: DimDevAPI,
732{
733 fn next_back(&mut self) -> Option<Self::Item> {
734 match self {
735 Self::RowMajor(iter) => iter.next_back(),
736 Self::ColMajor(iter) => iter.next_back(),
737 }
738 }
739}
740
741impl<D> ExactSizeIterator for IterLayout<D>
742where
743 D: DimDevAPI,
744{
745 fn len(&self) -> usize {
746 match self {
747 Self::RowMajor(iter) => iter.len(),
748 Self::ColMajor(iter) => iter.len(),
749 }
750 }
751}
752
753impl<D> IterSplitAtAPI for IterLayout<D>
754where
755 D: DimDevAPI,
756{
757 fn split_at(self, index: usize) -> (Self, Self) {
758 match self {
759 Self::RowMajor(iter) => {
760 let (lhs, rhs) = iter.split_at(index);
761 (Self::RowMajor(lhs), Self::RowMajor(rhs))
762 },
763 Self::ColMajor(iter) => {
764 let (lhs, rhs) = iter.split_at(index);
765 (Self::ColMajor(lhs), Self::ColMajor(rhs))
766 },
767 }
768 }
769}
770
771#[derive(Clone, Debug)]
776pub struct IndexedIterLayout<D>
777where
778 D: DimDevAPI,
779{
780 pub(crate) layout_iter: IterLayout<D>,
781}
782
783impl<D> IndexedIterLayout<D>
784where
785 D: DimDevAPI,
786{
787 pub fn new(layout: &Layout<D>, order: FlagOrder) -> Result<Self> {
788 let order = match order {
789 RowMajor => TensorIterOrder::C,
790 ColMajor => TensorIterOrder::F,
791 };
792 Ok(Self { layout_iter: IterLayout::new(layout, order)? })
793 }
794}
795
796impl<D> Iterator for IndexedIterLayout<D>
797where
798 D: DimDevAPI,
799{
800 type Item = (D, usize);
801
802 fn next(&mut self) -> Option<Self::Item> {
803 let index = match &self.layout_iter {
804 IterLayout::ColMajor(iter_inner) => iter_inner.index_start.clone(),
805 IterLayout::RowMajor(iter_inner) => iter_inner.index_start.clone(),
806 };
807 self.layout_iter.next().map(|offset| (index, offset))
808 }
809}
810
811impl<D> DoubleEndedIterator for IndexedIterLayout<D>
812where
813 D: DimDevAPI,
814{
815 fn next_back(&mut self) -> Option<Self::Item> {
816 let index = match &self.layout_iter {
817 IterLayout::ColMajor(iter_inner) => iter_inner.index_start.clone(),
818 IterLayout::RowMajor(iter_inner) => iter_inner.index_start.clone(),
819 };
820 self.layout_iter.next_back().map(|offset| (index, offset))
821 }
822}
823
824impl<D> ExactSizeIterator for IndexedIterLayout<D>
825where
826 D: DimDevAPI,
827{
828 fn len(&self) -> usize {
829 self.layout_iter.len()
830 }
831}
832
833impl<D> IterSplitAtAPI for IndexedIterLayout<D>
834where
835 D: DimDevAPI,
836{
837 fn split_at(self, mid: usize) -> (Self, Self) {
838 let (lhs, rhs) = self.layout_iter.split_at(mid);
839 let lhs = IndexedIterLayout { layout_iter: lhs };
840 let rhs = IndexedIterLayout { layout_iter: rhs };
841 (lhs, rhs)
842 }
843}
844
845#[allow(unused_mut)]
850pub fn layout_col_major_dim_dispatch_1<D, F>(la: &Layout<D>, mut f: F) -> Result<()>
851where
852 D: DimAPI,
853 F: FnMut(usize),
854{
855 #[cfg(feature = "dispatch_dim_layout_iter")]
856 {
857 macro_rules! dispatch {
858 ($dim: ident) => {{
859 let iter_a = IterLayoutColMajor::new(&la.to_dim::<$dim>()?)?;
860 iter_a.for_each(f);
861 }};
862 }
863 match la.ndim() {
864 0 => f(la.offset()),
865 1 => dispatch!(Ix1),
866 2 => dispatch!(Ix2),
867 3 => dispatch!(Ix3),
868 4 => dispatch!(Ix4),
869 5 => dispatch!(Ix5),
870 6 => dispatch!(Ix6),
871 _ => {
872 let iter_a = IterLayoutColMajor::new(la)?;
873 iter_a.for_each(f);
874 },
875 }
876 }
877
878 #[cfg(not(feature = "dispatch_dim_layout_iter"))]
879 {
880 let iter_a = IterLayoutColMajor::new(la)?;
881 iter_a.for_each(f);
882 }
883 Ok(())
884}
885
886#[allow(unused_mut)]
887pub fn layout_col_major_dim_dispatch_2<D, F>(la: &Layout<D>, lb: &Layout<D>, mut f: F) -> Result<()>
888where
889 D: DimAPI,
890 F: FnMut((usize, usize)),
891{
892 rstsr_assert_eq!(la.ndim(), lb.ndim(), RuntimeError)?;
893
894 #[cfg(feature = "dispatch_dim_layout_iter")]
895 {
896 macro_rules! dispatch {
897 ($dim: ident) => {{
898 let iter_a = IterLayoutColMajor::new(&la.to_dim::<$dim>()?)?;
899 let iter_b = IterLayoutColMajor::new(&lb.to_dim::<$dim>()?)?;
900 izip!(iter_a, iter_b).for_each(f);
901 }};
902 }
903 match la.ndim() {
904 0 => f((la.offset(), lb.offset())),
905 1 => dispatch!(Ix1),
906 2 => dispatch!(Ix2),
907 3 => dispatch!(Ix3),
908 4 => dispatch!(Ix4),
909 5 => dispatch!(Ix5),
910 6 => dispatch!(Ix6),
911 _ => {
912 let iter_a = IterLayoutColMajor::new(la)?;
913 let iter_b = IterLayoutColMajor::new(lb)?;
914 izip!(iter_a, iter_b).for_each(f);
915 },
916 }
917 }
918
919 #[cfg(not(feature = "dispatch_dim_layout_iter"))]
920 {
921 let iter_a = IterLayoutColMajor::new(la)?;
922 let iter_b = IterLayoutColMajor::new(lb)?;
923 izip!(iter_a, iter_b).for_each(f);
924 }
925 Ok(())
926}
927
928#[allow(unused_mut)]
929pub fn layout_col_major_dim_dispatch_3<D, F>(
930 la: &Layout<D>,
931 lb: &Layout<D>,
932 lc: &Layout<D>,
933 mut f: F,
934) -> Result<()>
935where
936 D: DimAPI,
937 F: FnMut((usize, usize, usize)),
938{
939 rstsr_assert_eq!(la.ndim(), lb.ndim(), RuntimeError)?;
940 rstsr_assert_eq!(la.ndim(), lc.ndim(), RuntimeError)?;
941
942 #[cfg(feature = "dispatch_dim_layout_iter")]
943 {
944 macro_rules! dispatch {
945 ($dim: ident) => {{
946 let iter_a = IterLayoutColMajor::new(&la.to_dim::<$dim>()?)?;
947 let iter_b = IterLayoutColMajor::new(&lb.to_dim::<$dim>()?)?;
948 let iter_c = IterLayoutColMajor::new(&lc.to_dim::<$dim>()?)?;
949 izip!(iter_a, iter_b, iter_c).for_each(f);
950 }};
951 }
952 match la.ndim() {
953 0 => f((la.offset(), lb.offset(), lc.offset())),
954 1 => dispatch!(Ix1),
955 2 => dispatch!(Ix2),
956 3 => dispatch!(Ix3),
957 4 => dispatch!(Ix4),
958 5 => dispatch!(Ix5),
959 6 => dispatch!(Ix6),
960 _ => {
961 let iter_a = IterLayoutColMajor::new(la)?;
962 let iter_b = IterLayoutColMajor::new(lb)?;
963 let iter_c = IterLayoutColMajor::new(lc)?;
964 izip!(iter_a, iter_b, iter_c).for_each(f);
965 },
966 }
967 }
968
969 #[cfg(not(feature = "dispatch_dim_layout_iter"))]
970 {
971 let iter_a = IterLayoutColMajor::new(la)?;
972 let iter_b = IterLayoutColMajor::new(lb)?;
973 let iter_c = IterLayoutColMajor::new(lc)?;
974 izip!(iter_a, iter_b, iter_c).for_each(f);
975 }
976 Ok(())
977}
978
979#[allow(unused_mut)]
980pub fn layout_col_major_dim_dispatch_2diff<DA, DB, F>(
981 la: &Layout<DA>,
982 lb: &Layout<DB>,
983 mut f: F,
984) -> Result<()>
985where
986 DA: DimAPI,
987 DB: DimAPI,
988 F: FnMut((usize, usize)),
989{
990 #[cfg(feature = "dispatch_dim_layout_iter")]
991 {
992 macro_rules! dispatch {
993 ($dima: ident, $dimb: ident) => {{
994 let iter_a = IterLayoutColMajor::new(&la.to_dim::<$dima>()?)?;
995 let iter_b = IterLayoutColMajor::new(&lb.to_dim::<$dimb>()?)?;
996 izip!(iter_a, iter_b).for_each(f);
997 }};
998 }
999 match (la.ndim(), lb.ndim()) {
1000 (0, 0) => f((la.offset(), lb.offset())),
1001 (1, 1) => dispatch!(Ix1, Ix1),
1002 (1, 2) => dispatch!(Ix1, Ix2),
1003 (1, 3) => dispatch!(Ix1, Ix3),
1004 (1, 4) => dispatch!(Ix1, Ix4),
1005 (1, 5) => dispatch!(Ix1, Ix5),
1006 (1, 6) => dispatch!(Ix1, Ix6),
1007 (2, 1) => dispatch!(Ix2, Ix1),
1008 (2, 2) => dispatch!(Ix2, Ix2),
1009 (2, 3) => dispatch!(Ix2, Ix3),
1010 (2, 4) => dispatch!(Ix2, Ix4),
1011 (2, 5) => dispatch!(Ix2, Ix5),
1012 (2, 6) => dispatch!(Ix2, Ix6),
1013 (3, 1) => dispatch!(Ix3, Ix1),
1014 (3, 2) => dispatch!(Ix3, Ix2),
1015 (3, 3) => dispatch!(Ix3, Ix3),
1016 (3, 4) => dispatch!(Ix3, Ix4),
1017 (3, 5) => dispatch!(Ix3, Ix5),
1018 (3, 6) => dispatch!(Ix3, Ix6),
1019 (4, 1) => dispatch!(Ix4, Ix1),
1020 (4, 2) => dispatch!(Ix4, Ix2),
1021 (4, 3) => dispatch!(Ix4, Ix3),
1022 (4, 4) => dispatch!(Ix4, Ix4),
1023 (4, 5) => dispatch!(Ix4, Ix5),
1024 (4, 6) => dispatch!(Ix4, Ix6),
1025 (5, 1) => dispatch!(Ix5, Ix1),
1026 (5, 2) => dispatch!(Ix5, Ix2),
1027 (5, 3) => dispatch!(Ix5, Ix3),
1028 (5, 4) => dispatch!(Ix5, Ix4),
1029 (5, 5) => dispatch!(Ix5, Ix5),
1030 (5, 6) => dispatch!(Ix5, Ix6),
1031 (6, 1) => dispatch!(Ix6, Ix1),
1032 (6, 2) => dispatch!(Ix6, Ix2),
1033 (6, 3) => dispatch!(Ix6, Ix3),
1034 (6, 4) => dispatch!(Ix6, Ix4),
1035 (6, 5) => dispatch!(Ix6, Ix5),
1036 (6, 6) => dispatch!(Ix6, Ix6),
1037 _ => {
1038 let iter_a = IterLayoutColMajor::new(la)?;
1039 let iter_b = IterLayoutColMajor::new(lb)?;
1040 izip!(iter_a, iter_b).for_each(f);
1041 },
1042 }
1043 }
1044
1045 #[cfg(not(feature = "dispatch_dim_layout_iter"))]
1046 {
1047 let iter_a = IterLayoutColMajor::new(la)?;
1048 let iter_b = IterLayoutColMajor::new(lb)?;
1049 izip!(iter_a, iter_b).for_each(f);
1050 }
1051 Ok(())
1052}
1053
/// Tests for `IterLayoutColMajor`, mostly driven through `translate_to_col_major_unary`.
#[cfg(test)]
mod test_col_major {
    use super::*;

    type Order = TensorIterOrder;

    /// Forward iteration after translating a strided (partly negative-stride) layout with
    /// several orders.
    #[test]
    fn test_iter_next() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        // Order::C: the column-major walk of the translated layout visits elements in the
        // original's row-major order (same expected list as `test_row_major::test_iter_next`).
        let layout_trans = translate_to_col_major_unary(&layout, Order::C).unwrap();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec = iter.collect::<Vec<_>>();
        assert_eq!(vec, [
            782, 797, 812, 827, 842, 857, 602, 617, 632, 647, 662, 677, 785, 800, 815, 830, 845,
            860, 605, 620, 635, 650, 665, 680, 788, 803, 818, 833, 848, 863, 608, 623, 638, 653,
            668, 683
        ]);
        // Order::F: column-major order of the original layout.
        let layout_trans = translate_to_col_major_unary(&layout, Order::F).unwrap();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec = iter.collect::<Vec<_>>();
        assert_eq!(vec, [
            782, 785, 788, 602, 605, 608, 797, 800, 803, 617, 620, 623, 812, 815, 818, 632, 635,
            638, 827, 830, 833, 647, 650, 653, 842, 845, 848, 662, 665, 668, 857, 860, 863, 677,
            680, 683
        ]);
        // Order::K: for this layout the expected offsets come out in ascending memory order.
        let layout_trans = translate_to_col_major_unary(&layout, Order::K).unwrap();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec = iter.collect::<Vec<_>>();
        assert_eq!(vec, [
            602, 605, 608, 617, 620, 623, 632, 635, 638, 647, 650, 653, 662, 665, 668, 677, 680,
            683, 782, 785, 788, 797, 800, 803, 812, 815, 818, 827, 830, 833, 842, 845, 848, 857,
            860, 863
        ]);
        // NOTE(review): this repeats the Order::K case verbatim; it may have been meant to
        // exercise Order::G (which `IterLayout::new` treats like K) — confirm.
        let layout_trans = translate_to_col_major_unary(&layout, Order::K).unwrap();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec = iter.collect::<Vec<_>>();
        assert_eq!(vec, [
            602, 605, 608, 617, 620, 623, 632, 635, 638, 647, 650, 653, 662, 665, 668, 677, 680,
            683, 782, 785, 788, 797, 800, 803, 812, 815, 818, 827, 830, 833, 842, 845, 848, 857,
            860, 863
        ]);
        // Order::B is expected to be rejected by the translation.
        assert!(translate_to_col_major_unary(&layout, Order::B).is_err());
    }

    /// Reverse iteration must produce the forward sequence reversed.
    #[test]
    fn test_iter_back() {
        let layout = Layout::new([10, 10, 10], [10, 1, 100], 0).unwrap();
        let layout_trans = translate_to_col_major_unary(&layout, Order::C).unwrap();
        // NOTE(review): debug output only; asserts nothing.
        println!("{:?}", unsafe { layout.shape().unravel_index_f(100) });
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec_next = iter.collect::<Vec<_>>();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec_back = iter.rev().collect::<Vec<_>>();
        assert_eq!(vec_next, vec_back.iter().rev().copied().collect::<Vec<_>>());
    }

    /// Forward/reverse agreement on a 3-D strided layout and on a 1-D layout.
    /// NOTE(review): despite the name, neither layout here is empty.
    #[test]
    fn test_iter_back_empty() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let layout_trans = translate_to_col_major_unary(&layout, Order::C).unwrap();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec_next = iter.clone().collect::<Vec<_>>();
        let vec_back = iter.clone().rev().collect::<Vec<_>>();
        assert_eq!(vec_next, vec_back.iter().rev().copied().collect::<Vec<_>>());

        let layout = Layout::new([10], [10], 10).unwrap();
        let layout_trans = translate_to_col_major_unary(&layout, Order::C).unwrap();
        let iter = IterLayoutColMajor::new(&layout_trans).unwrap();
        let vec_next = iter.clone().collect::<Vec<_>>();
        let vec_back = iter.clone().rev().collect::<Vec<_>>();
        assert_eq!(vec_next, vec_back.iter().rev().copied().collect::<Vec<_>>());
    }
}
1143
/// Tests for `IterLayoutRowMajor` on an untranslated strided layout.
#[cfg(test)]
mod test_row_major {
    use super::*;

    /// Forward iteration in row-major order, then the same iterator consumed via `rev()`.
    #[test]
    fn test_iter_next() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let iter = IterLayoutRowMajor::new(&layout).unwrap();
        let vec = iter.collect::<Vec<_>>();
        assert_eq!(vec, [
            782, 797, 812, 827, 842, 857, 602, 617, 632, 647, 662, 677, 785, 800, 815, 830, 845,
            860, 605, 620, 635, 650, 665, 680, 788, 803, 818, 833, 848, 863, 608, 623, 638, 653,
            668, 683
        ]);
        // Reverse traversal: exactly the forward list backwards.
        let iter = IterLayoutRowMajor::new(&layout).unwrap();
        let vec = iter.rev().collect::<Vec<_>>();
        assert_eq!(vec, [
            683, 668, 653, 638, 623, 608, 863, 848, 833, 818, 803, 788, 680, 665, 650, 635, 620,
            605, 860, 845, 830, 815, 800, 785, 677, 662, 647, 632, 617, 602, 857, 842, 827, 812,
            797, 782
        ]);
    }
}