1use std::fmt;
26
27use crate::shape::Shape;
28use crate::slice::Slice;
29
/// A multi-dimensional coordinate, holding one index per dimension.
pub type Coord = Vec<usize>;
33
/// The result of `reshape_shape`: a [`Shape`] whose oversized dimensions have
/// been split into factors, together with the factorization that produced it.
pub struct ReshapedShape {
    /// The reshaped shape. Its labels come from `expand_labels` and its slice
    /// from `view_limit`.
    pub shape: Shape,

    /// One entry per *original* dimension: its label paired with the sizes it
    /// was factored into. A single-element vector means the dimension was
    /// left unsplit.
    pub factors: Vec<(String, Vec<usize>)>,
}
49
// Compile-time check that `ReshapedShape` is `Send + Sync + 'static`. The
// build fails here if a future field breaks any of those bounds. The helper
// is only referenced, never called, hence `dead_code` is allowed.
#[allow(dead_code)]
const _: () = {
    fn assert<T: Send + Sync + 'static>() {}
    let _ = assert::<ReshapedShape>;
};
55
56impl std::fmt::Debug for ReshapedShape {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("ReshapedShape")
59 .field("labels", &self.shape.labels())
60 .field("sizes", &self.shape.slice().sizes())
61 .field("strides", &self.shape.slice().strides())
62 .field("offset", &self.shape.slice().offset())
63 .field("factors", &self.factors)
64 .finish()
65 }
66}
67
68impl std::fmt::Display for ReshapedShape {
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
70 write!(
71 f,
72 "ReshapedShape {{ [off={} sz={:?} st={:?} lab={:?} fac={:?}] }}",
73 self.shape.slice().offset(),
74 self.shape.slice().sizes(),
75 self.shape.slice().strides(),
76 self.shape.labels(),
77 self.factors
78 )
79 }
80}
81
82pub(crate) fn factor_dims(sizes: &[usize], limit: Limit) -> Vec<Vec<usize>> {
96 let limit = limit.get();
97 sizes
98 .iter()
99 .map(|&size| {
100 if size <= limit {
101 return vec![size];
102 }
103 let mut rem = size;
104 let mut factors = Vec::new();
105 for d in (2..=limit).rev() {
106 while rem % d == 0 {
107 factors.push(d);
108 rem /= d;
109 }
110 }
111 if rem > 1 {
112 factors.push(rem);
113 }
114 factors
115 })
116 .collect()
117}
118
119pub fn to_reshaped_coord<'a>(
123 original: &'a Slice,
124 reshaped: &'a Slice,
125) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
126 let original = original.clone();
127 let reshaped = reshaped.clone();
128 move |coord: &[usize]| -> Coord {
129 let flat = original.location(coord).unwrap();
130 reshaped.coordinates(flat).unwrap()
131 }
132}
133
134pub fn to_original_coord<'a>(
138 reshaped: &'a Slice,
139 original: &'a Slice,
140) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
141 let reshaped = reshaped.clone();
142 let original = original.clone();
143 move |coord: &[usize]| -> Coord {
144 let flat = reshaped.location(coord).unwrap();
145 original.coordinates(flat).unwrap()
146 }
147}
148
/// Target maximum size for the dimensions produced when factoring a slice or
/// shape.
///
/// A `Limit` is always at least 1, and the default is 32. Factoring can only
/// use divisors in `2..=limit`, so a dimension containing a prime factor larger
/// than the limit will still end up with a factor above it.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Limit(usize);

impl Limit {
    /// Creates a new limit.
    ///
    /// This is a `const fn`, so limits can be declared as constants, e.g.
    /// `const L: Limit = Limit::new(8);`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is 0.
    pub const fn new(n: usize) -> Self {
        assert!(n >= 1, "Limit must be at least 1");
        Self(n)
    }

    /// Returns the limit as a plain `usize`.
    pub const fn get(self) -> usize {
        self.0
    }
}

impl Default for Limit {
    /// The default limit of 32.
    fn default() -> Self {
        // Go through `new` so every construction path enforces the same invariant.
        Self::new(32)
    }
}

impl From<usize> for Limit {
    /// Equivalent to [`Limit::new`].
    ///
    /// # Panics
    ///
    /// Panics if `n` is 0.
    fn from(n: usize) -> Self {
        Self::new(n)
    }
}
193
/// Extension trait that adds limit-based reshaping to [`Slice`].
pub trait ReshapeSliceExt {
    /// Returns a view of this slice in which dimensions larger than `limit`
    /// are split into smaller factors.
    ///
    /// The `Slice` implementation in this module delegates to the free
    /// function `view_limit`.
    fn view_limit(&self, limit: Limit) -> Slice;
}
227
228impl ReshapeSliceExt for Slice {
229 fn view_limit(&self, limit: Limit) -> Slice {
230 view_limit(self, limit)
231 }
232}
233
/// Extension trait that adds limit-based reshaping to [`Shape`].
pub trait ReshapeShapeExt {
    /// Reshapes this shape so that dimensions larger than `limit` are split
    /// into factors with expanded labels.
    ///
    /// The `Shape` implementation in this module delegates to the free
    /// function `reshape_shape`.
    fn reshape(&self, limit: Limit) -> ReshapedShape;
}
240
241impl ReshapeShapeExt for Shape {
242 fn reshape(&self, limit: Limit) -> ReshapedShape {
243 reshape_shape(self, limit)
244 }
245}
246
/// Re-exports the reshaping extension traits so callers can bring both into
/// scope with a single glob import of this module.
pub mod prelude {
    pub use super::ReshapeShapeExt;
    pub use super::ReshapeSliceExt;
}
253
254pub fn view_limit(slice: &Slice, limit: Limit) -> Slice {
284 let orig_sizes = slice.sizes();
285 let orig_strides = slice.strides();
286
287 let factored_sizes = factor_dims(orig_sizes, limit);
289
290 let reshaped_sizes: Vec<usize> = factored_sizes.iter().flatten().cloned().collect();
292 let mut reshaped_strides = Vec::with_capacity(reshaped_sizes.len());
293
294 for (&orig_stride, factors) in orig_strides.iter().zip(&factored_sizes) {
295 let mut sub_strides = Vec::with_capacity(factors.len());
296 let mut stride = orig_stride;
297 for &f in factors.iter().rev() {
298 sub_strides.push(stride);
299 stride *= f;
300 }
301 sub_strides.reverse();
302 reshaped_strides.extend(sub_strides);
303 }
304
305 Slice::new(slice.offset(), reshaped_sizes, reshaped_strides).unwrap()
306}
307
308pub fn reshape_shape(shape: &Shape, limit: Limit) -> ReshapedShape {
328 let reshaped_slice = shape.slice().view_limit(limit);
329 let original_labels = shape.labels();
330 let original_sizes = shape.slice().sizes();
331
332 let factors = factor_dims(original_sizes, limit);
333 let factored_dims: Vec<(String, Vec<usize>)> =
334 original_labels.iter().cloned().zip(factors).collect();
335
336 let labels = expand_labels(&factored_dims);
337 let shape = Shape::new(labels, reshaped_slice).expect("invalid reshaped shape");
338
339 ReshapedShape {
340 shape,
341 factors: factored_dims,
342 }
343}
344
/// Expands per-dimension labels to match a factored layout.
///
/// Each `(label, dims)` entry contributes:
/// - `label` itself, when `dims` has exactly one element (the dimension was
///   not split);
/// - otherwise, one label per element of `dims`, formatted as `"<label>/<i>"`
///   for `i = 0, 1, ..`. An entry with empty `dims` contributes nothing.
///
/// For example, `[("zone", [2]), ("gpu", [2, 2])]` expands to
/// `["zone", "gpu/0", "gpu/1"]`.
pub fn expand_labels(factors: &[(String, Vec<usize>)]) -> Vec<String> {
    factors
        .iter()
        .flat_map(|(label, dims)| match dims.len() {
            1 => vec![label.clone()],
            n => (0..n).map(|i| format!("{}/{}", label, i)).collect(),
        })
        .collect()
}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Slice;
    use crate::shape;

    #[test]
    fn test_factor_dims_basic() {
        assert_eq!(
            factor_dims(&[6, 8], Limit::from(4)),
            vec![vec![3, 2], vec![4, 2]]
        );
        assert_eq!(factor_dims(&[5], Limit::from(3)), vec![vec![5]]);
        assert_eq!(factor_dims(&[30], Limit::from(5)), vec![vec![5, 3, 2]]);
    }

    #[test]
    fn test_factor_dims_limit_one_leaves_sizes_unsplit() {
        // With limit 1 there are no candidate divisors, so every size is kept whole.
        assert_eq!(
            factor_dims(&[1, 5, 6], Limit::from(1)),
            vec![vec![1], vec![5], vec![6]]
        );
    }

    #[test]
    #[should_panic(expected = "Limit must be at least 1")]
    fn test_limit_zero_panics() {
        let _ = Limit::from(0);
    }

    /// Asserts that `$reshaped` addresses exactly the same locations as
    /// `$original`. For every coordinate of `$original`, it checks that:
    /// mapping forward and back round-trips, both coordinates resolve to the
    /// same flat location, and `$reshaped.coordinates` inverts
    /// `$reshaped.location`.
    #[macro_export]
    macro_rules! assert_layout_preserved {
        ($original:expr_2021, $reshaped:expr_2021) => {{
            // The coordinate maps clone both slices, so build them once rather
            // than once per coordinate.
            let forward = to_reshaped_coord($original, &$reshaped);
            let inverse = to_original_coord(&$reshaped, $original);
            for coord in $original.dim_iter($original.num_dim()) {
                let reshaped_coord = forward(&coord);
                let roundtrip = inverse(&reshaped_coord);
                assert_eq!(
                    roundtrip, coord,
                    "Inverse mismatch: reshaped {:?} → original {:?}, expected {:?}",
                    reshaped_coord, roundtrip, coord
                );
                let flat_orig = $original.location(&coord).unwrap();
                let flat_reshaped = $reshaped.location(&reshaped_coord).unwrap();
                assert_eq!(
                    flat_orig, flat_reshaped,
                    "Flat index mismatch: original {:?} → reshaped {:?}",
                    coord, reshaped_coord
                );
                let recovered = $reshaped.coordinates(flat_reshaped).unwrap();
                assert_eq!(
                    reshaped_coord, recovered,
                    "Coordinate mismatch: flat index {} → expected {:?}, got {:?}",
                    flat_reshaped, reshaped_coord, recovered
                );
            }
        }};
    }

    #[test]
    fn test_reshape_split_1d_row_major() {
        let s = Slice::new_row_major(vec![1024]);
        let reshaped = s.view_limit(Limit::from(8));

        assert_eq!(reshaped.offset(), 0);
        assert_eq!(reshaped.sizes(), &[8, 8, 8, 2]);
        assert_eq!(reshaped.strides(), &[128, 16, 2, 1]);
        assert_eq!(
            factor_dims(s.sizes(), Limit::from(8)),
            vec![vec![8, 8, 8, 2]]
        );

        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_6_with_limit_2() {
        let s = Slice::new_row_major(vec![6]);
        let reshaped = view_limit(&s, Limit::from(2));
        assert_eq!(factor_dims(s.sizes(), Limit::from(2)), vec![vec![2, 3]]);
        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_identity_noop_2d() {
        let original = Slice::new_row_major(vec![4, 8]);
        let reshaped = original.view_limit(Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());
        // Every size is within the limit, so each dimension is left unsplit.
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![4], vec![8]]
        );
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_empty_slice() {
        let original = Slice::new_row_major(vec![]);
        let reshaped = view_limit(&original, Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_mixed_dims_3d() {
        let original = Slice::new_row_major(vec![6, 8, 10]);
        let reshaped = original.view_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![3, 2], vec![4, 2], vec![2, 5]]
        );
        assert_eq!(reshaped.sizes(), &[3, 2, 4, 2, 2, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_large_dims() {
        let original = Slice::new_row_major(vec![12, 18, 20]);
        let reshaped = original.view_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![4, 3], vec![3, 3, 2], vec![4, 5]]
        );
        assert_eq!(reshaped.sizes(), &[4, 3, 3, 3, 2, 4, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_3_3_2_2() {
        let original = Slice::new_row_major(vec![36]);
        let reshaped = view_limit(&original, Limit::from(3));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(3)),
            vec![vec![3, 3, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[3, 3, 2, 2]);
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_large_prime_dimension() {
        let original = Slice::new_row_major(vec![7]);
        let reshaped = view_limit(&original, Limit::from(4));

        assert_eq!(factor_dims(original.sizes(), Limit::from(4)), vec![vec![7]]);
        assert_eq!(reshaped.sizes(), &[7]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_5_3_2() {
        let original = Slice::new_row_major(vec![30]);
        let reshaped = view_limit(&original, Limit::from(5));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(5)),
            vec![vec![5, 3, 2]]
        );
        assert_eq!(reshaped.sizes(), &[5, 3, 2]);
        assert_eq!(reshaped.strides(), &[6, 2, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_factors_2_6_2_8_8() {
        let original = Slice::new_row_major(vec![2, 12, 64]);
        let reshaped = original.view_limit(Limit::from(8));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![2], vec![6, 2], vec![8, 8]]
        );
        assert_eq!(reshaped.sizes(), &[2, 6, 2, 8, 8]);
        assert_eq!(reshaped.strides(), &[768, 128, 64, 8, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_dims_within_limit() {
        let original = Slice::new_row_major(vec![2, 3, 4]);
        let reshaped = original.view_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![2], vec![3], vec![4]]
        );
        assert_eq!(reshaped.sizes(), &[2, 3, 4]);
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_degenerate_dimension() {
        let original = Slice::new_row_major(vec![1, 12]);
        let reshaped = original.view_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![1], vec![4, 3]]
        );
        assert_eq!(reshaped.sizes(), &[1, 4, 3]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_select_then_reshape() {
        let original = shape!(zone = 2, host = 3, gpu = 4);

        let selected = original.select("zone", 1).unwrap();
        assert_eq!(selected.slice().offset(), 12);
        assert_eq!(selected.slice().sizes(), &[1, 3, 4]);

        let reshaped = selected.slice().view_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![3], vec![2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 3, 2, 2]);
        assert_eq!(reshaped.strides(), &[12, 4, 2, 1]);
        assert_eq!(reshaped.offset(), 12);
        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_select_host_plane_then_reshape() {
        let original = shape!(zone = 2, host = 3, gpu = 4);
        let selected = original.select("host", 2).unwrap();
        let reshaped = selected.slice().view_limit(Limit::from(2));

        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_select_no_factoring_due_to_primes() {
        let original = shape!(zone = 3, host = 4, gpu = 5);
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 5]);
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 5]);
        let reshaped = selected_host.slice().view_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![5]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 5]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_multiple_selects_triggers_factoring() {
        let original = shape!(zone = 2, host = 4, gpu = 8);
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        let reshaped = selected_host.slice().view_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![2, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 2, 2, 2]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_expand_labels_singleton_dims() {
        let factors = vec![("x".into(), vec![2]), ("y".into(), vec![4])];
        let expected = vec!["x", "y"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_factored_dims() {
        let factors = vec![("gpu".into(), vec![2, 2, 2])];
        let expected = vec!["gpu/0", "gpu/1", "gpu/2"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_mixed_dims() {
        let factors = vec![("zone".into(), vec![2]), ("gpu".into(), vec![2, 2])];
        let expected = vec!["zone", "gpu/0", "gpu/1"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_empty() {
        let factors: Vec<(String, Vec<usize>)> = vec![];
        let expected: Vec<String> = vec![];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_reshape_shape_noop() {
        let shape = shape!(x = 4, y = 8);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x", "y"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_factored() {
        let shape = shape!(gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(reshaped.shape.labels(), &["gpu/0", "gpu/1", "gpu/2"]);
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2]);

        let expected = shape.slice().view_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_records_factors() {
        let shape = shape!(zone = 2, gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(
            reshaped.factors,
            vec![
                ("zone".to_string(), vec![2]),
                ("gpu".to_string(), vec![2, 2, 2]),
            ]
        );
    }

    #[test]
    fn test_reshape_shape_singleton() {
        let shape = shape!(x = 3);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_prime_exceeds_limit() {
        let shape = shape!(x = 11);
        let reshaped = reshape_shape(&shape, Limit::from(5));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_mixed_dims() {
        let shape = shape!(zone = 2, gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "gpu/0", "gpu/1", "gpu/2"]
        );
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2, 2]);

        let expected = shape.slice().view_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_after_selects() {
        let original = shape!(zone = 2, host = 4, gpu = 8);

        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        let reshaped = reshape_shape(&selected_host, Limit::from(2));

        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "host", "gpu/0", "gpu/1", "gpu/2"]
        );

        assert_eq!(reshaped.shape.slice().sizes(), &[1, 1, 2, 2, 2]);

        let expected = selected_host.slice().view_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }
}