ndslice/
reshape.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Dimensional reshaping of slices and shapes.
10//!
11//! This module defines utilities for transforming a [`Slice`] or
12//! [`Shape`] by factoring large extents into smaller ones under a
13//! given limit. The result is a reshaped view with increased
14//! dimensionality and preserved memory layout.
15//!
16//! This is useful for hierarchical routing, structured fanout, and
17//! other multidimensional layout transformations.
18//!
//! For [`Shape`]s, reshaping also expands dimension labels using a
//! `label/N` naming convention, preserving the semantics of the
//! original shape in the reshaped view.
22//!
23//! See [`view_limit`] and [`reshape_shape`] for entry points.
24
25use std::fmt;
26
27use crate::shape::Shape;
28use crate::slice::Slice;
29
/// Coordinate vector used throughout reshape logic. Semantically
/// represents a point in multidimensional space: one index per
/// dimension of the slice it is used with.
pub type Coord = Vec<usize>;
33
/// A reshaped version of a `Shape`, with factored dimensions and
/// updated labels.
///
/// Produced by [`reshape_shape`] (and `ReshapeShapeExt::reshape`).
/// This type preserves coordinate bijections with the original shape
/// and provides access to the transformed layout and label mappings.
pub struct ReshapedShape {
    /// The reshaped shape, with new (expanded) labels and the
    /// underlying factored slice.
    pub shape: Shape,

    /// For each original dimension label, in original dimension
    /// order, the list of sizes it was split into. A dimension that
    /// was not factored carries a single-element list.
    pub factors: Vec<(String, Vec<usize>)>,
}
49
50#[allow(dead_code)]
51const _: () = {
52    fn assert<T: Send + Sync + 'static>() {}
53    let _ = assert::<ReshapedShape>;
54};
55
56impl std::fmt::Debug for ReshapedShape {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.debug_struct("ReshapedShape")
59            .field("labels", &self.shape.labels())
60            .field("sizes", &self.shape.slice().sizes())
61            .field("strides", &self.shape.slice().strides())
62            .field("offset", &self.shape.slice().offset())
63            .field("factors", &self.factors)
64            .finish()
65    }
66}
67
68impl std::fmt::Display for ReshapedShape {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
70        write!(
71            f,
72            "ReshapedShape {{ [off={} sz={:?} st={:?} lab={:?} fac={:?}] }}",
73            self.shape.slice().offset(),
74            self.shape.slice().sizes(),
75            self.shape.slice().strides(),
76            self.shape.labels(),
77            self.factors
78        )
79    }
80}
81
82/// Returns, for each size, a list of factors that respect the given
83/// limit. If a size is ≤ limit, it is returned as a singleton.
84/// Otherwise, it is factored greedily using divisors ≤ limit, from
85/// largest to smallest.
86///
87/// For best results, dimensions should be chosen to allow factoring
88/// into small values under the selected limit (e.g., ≤ 32).
89/// Large prime numbers cannot be broken down and will remain as-is,
90/// limiting reshaping potential.
91///
92/// Prefer powers of 2 or other highly composite numbers
93/// (e.g., 8, 16, 32, 60, 120) over large primes (e.g., 17, 37, 113)
94/// when designing shapes intended for reshaping.
95pub(crate) fn factor_dims(sizes: &[usize], limit: Limit) -> Vec<Vec<usize>> {
96    let limit = limit.get();
97    sizes
98        .iter()
99        .map(|&size| {
100            if size <= limit {
101                return vec![size];
102            }
103            let mut rem = size;
104            let mut factors = Vec::new();
105            for d in (2..=limit).rev() {
106                while rem % d == 0 {
107                    factors.push(d);
108                    rem /= d;
109                }
110            }
111            if rem > 1 {
112                factors.push(rem);
113            }
114            factors
115        })
116        .collect()
117}
118
119/// Constructs a function that maps coordinates from the original
120/// slice to equivalent coordinates in the reshaped slice, preserving
121/// their flat (linear) position.
122pub fn to_reshaped_coord<'a>(
123    original: &'a Slice,
124    reshaped: &'a Slice,
125) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
126    let original = original.clone();
127    let reshaped = reshaped.clone();
128    move |coord: &[usize]| -> Coord {
129        let flat = original.location(coord).unwrap();
130        reshaped.coordinates(flat).unwrap()
131    }
132}
133
134/// Constructs a function that maps coordinates from the reshaped
135/// slice back to equivalent coordinates in the original slice,
136/// preserving their flat (linear) position.
137pub fn to_original_coord<'a>(
138    reshaped: &'a Slice,
139    original: &'a Slice,
140) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
141    let reshaped = reshaped.clone();
142    let original = original.clone();
143    move |coord: &[usize]| -> Coord {
144        let flat = reshaped.location(coord).unwrap();
145        original.coordinates(flat).unwrap()
146    }
147}
148
/// A shaping constraint that bounds the maximum extent allowed in any
/// reshaped dimension.
///
/// This limit controls how a given dimension is factored during
/// reshaping. Extents larger than `limit` are greedily decomposed
/// into factors ≤ `limit` (e.g., `view_limit([1024],
/// Limit::new(32))` → `[32, 32]`). A leftover cofactor with no
/// divisor ≤ `limit` (such as a large prime) is kept as a single,
/// possibly oversized, extent.
///
/// The default limit is `32`, which balances fanout depth and layout
/// regularity.
///
/// # Example
/// ```
/// use ndslice::reshape::Limit;
/// let limit = Limit::new(64);
/// assert_eq!(limit.get(), 64);
/// ```
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Limit(usize);

impl Limit {
    /// Creates a new `Limit`.
    ///
    /// This is a `const fn`, so limits can be declared as constants.
    ///
    /// # Panics
    /// Panics if `n` is 0; a limit must be at least 1.
    #[must_use]
    pub const fn new(n: usize) -> Self {
        assert!(n >= 1, "Limit must be at least 1");
        Self(n)
    }

    /// Returns the inner value.
    #[must_use]
    pub const fn get(self) -> usize {
        self.0
    }
}
181
182impl Default for Limit {
183    fn default() -> Self {
184        Self(32)
185    }
186}
187
188impl From<usize> for Limit {
189    fn from(n: usize) -> Self {
190        Self::new(n)
191    }
192}
193
/// A trait for types that can be reshaped into a higher-dimensional
/// view by factoring large extents into smaller ones.
///
/// This is implemented for [`Slice`], enabling ergonomic access to
/// [`view_limit`] as a method.
///
/// # Example
/// ```
/// use ndslice::Slice;
/// use ndslice::reshape::Limit;
/// use ndslice::reshape::ReshapeSliceExt;
///
/// let slice = Slice::new_row_major(vec![1024]);
/// let reshaped = slice.view_limit(Limit::new(32));
/// assert_eq!(reshaped.sizes(), &[32, 32]);
/// ```
pub trait ReshapeSliceExt {
    /// Returns a reshaped version of this structure by factoring each
    /// dimension into smaller extents, preserving memory layout and
    /// flat index semantics. Extents are kept ≤ `limit` where the
    /// size allows; a leftover cofactor with no divisor ≤ `limit`
    /// (e.g. a large prime) is kept as-is. See [`view_limit`] for
    /// full behavior and rationale.
    ///
    /// # Arguments
    /// - `limit`: maximum size allowed in any factored dimension
    ///
    /// # Returns
    /// A reshaped [`Slice`] with equal or greater dimensionality and a
    /// bijective mapping to the original.
    fn view_limit(&self, limit: Limit) -> Slice;
}
227
228impl ReshapeSliceExt for Slice {
229    fn view_limit(&self, limit: Limit) -> Slice {
230        view_limit(self, limit)
231    }
232}
233
/// Extension trait for reshaping `Shape`s by factoring large dimensions.
///
/// Implemented for [`Shape`]; this is the method form of
/// [`reshape_shape`].
pub trait ReshapeShapeExt {
    /// Produces a reshaped version of the shape with expanded
    /// dimensions under the given size limit. Factored dimensions are
    /// relabeled `label/0`, `label/1`, …; see [`reshape_shape`].
    fn reshape(&self, limit: Limit) -> ReshapedShape;
}
240
241impl ReshapeShapeExt for Shape {
242    fn reshape(&self, limit: Limit) -> ReshapedShape {
243        reshape_shape(self, limit)
244    }
245}
246
247/// For convenient `slice.view_limit()`, `shape.reshape()`
248/// syntax, `use reshape::prelude::*`.
249pub mod prelude {
250    pub use super::ReshapeShapeExt;
251    pub use super::ReshapeSliceExt;
252}
253
254/// Reshapes a slice by factoring each dimension into smaller extents
255/// under the given limit.
256///
257/// This transformation increases dimensionality by breaking large
258/// sizes into products of smaller factors (e.g., `[1024]` with limit
259/// 32 becomes `[32, 32]`). The result is a new [`Slice`] that
260/// preserves memory layout and flat index semantics.
261///
262/// Factoring is greedy, starting from the largest divisors ≤ `limit`.
263/// Dimensions that cannot be factored under the limit are left
264/// unchanged.
265///
266/// # Arguments
267/// - `slice`: the original multidimensional slice
268/// - `limit`: maximum extent allowed in any factored subdimension
269///
270/// # Returns
271/// A reshaped [`Slice`] with updated sizes and strides.
272///
273/// # Example
274/// ```
275/// use ndslice::Slice;
276/// use ndslice::reshape::Limit;
277/// use ndslice::reshape::view_limit;
278///
279/// let slice = Slice::new_row_major(vec![1024]);
280/// let reshaped = view_limit(&slice, Limit::new(32));
281/// assert_eq!(reshaped.sizes(), &[32, 32]);
282/// ```
283pub fn view_limit(slice: &Slice, limit: Limit) -> Slice {
284    let orig_sizes = slice.sizes();
285    let orig_strides = slice.strides();
286
287    // Step 1: Factor each size into subdimensions ≤ limit.
288    let factored_sizes = factor_dims(orig_sizes, limit);
289
290    // Step 2: Compute reshaped sizes and strides (row-major only).
291    let reshaped_sizes: Vec<usize> = factored_sizes.iter().flatten().cloned().collect();
292    let mut reshaped_strides = Vec::with_capacity(reshaped_sizes.len());
293
294    for (&orig_stride, factors) in orig_strides.iter().zip(&factored_sizes) {
295        let mut sub_strides = Vec::with_capacity(factors.len());
296        let mut stride = orig_stride;
297        for &f in factors.iter().rev() {
298            sub_strides.push(stride);
299            stride *= f;
300        }
301        sub_strides.reverse();
302        reshaped_strides.extend(sub_strides);
303    }
304
305    Slice::new(slice.offset(), reshaped_sizes, reshaped_strides).unwrap()
306}
307
308/// Reshapes a labeled [`Shape`] by factoring large extents into
309/// smaller ones, producing a new shape with expanded dimensionality
310/// and updated labels.
311///
312/// This uses [`view_limit`] on the underlying slice and [`expand_labels`]
313/// to generate labels for each factored dimension.
314///
315/// # Arguments
316/// - `shape`: the labeled shape to reshape
317/// - `limit`: maximum extent allowed per factored dimension
318///
319/// # Returns
320/// A new [`ReshapedShape`] with an updated [`Shape`] and dimension
321/// factoring metadata.
322///
323/// # Panics
324/// Panics if constructing the new `Shape` fails. This should not
325/// occur unless the reshaped slice and labels are inconsistent (a
326/// programming logic error).
327pub fn reshape_shape(shape: &Shape, limit: Limit) -> ReshapedShape {
328    let reshaped_slice = shape.slice().view_limit(limit);
329    let original_labels = shape.labels();
330    let original_sizes = shape.slice().sizes();
331
332    let factors = factor_dims(original_sizes, limit);
333    let factored_dims: Vec<(String, Vec<usize>)> =
334        original_labels.iter().cloned().zip(factors).collect();
335
336    let labels = expand_labels(&factored_dims);
337    let shape = Shape::new(labels, reshaped_slice).expect("invalid reshaped shape");
338
339    ReshapedShape {
340        shape,
341        factors: factored_dims,
342    }
343}
344
/// Expands factored dimension labels into one label per subdimension.
///
/// Each input pair `(label, factors)` describes an original dimension
/// and the extents it was factored into. A dimension with exactly one
/// extent keeps its label verbatim. Any other dimension yields
/// `label/0`, `label/1`, … — one label per extent, so an empty extent
/// list yields no labels.
///
/// For example:
/// - `[("zone", vec![2]), ("gpu", vec![2, 2, 2])]`
///   becomes `["zone", "gpu/0", "gpu/1", "gpu/2"]`
///
/// This is used to generate new labels for reshaped shapes, where the
/// dimensionality increases due to factoring.
///
/// # Arguments
/// - `factors`: a list of factored dimension extents, paired with
///   their labels
///
/// # Returns
/// - A `Vec<String>` of expanded labels, one for each reshaped
///   dimension.
pub fn expand_labels(factors: &[(String, Vec<usize>)]) -> Vec<String> {
    factors
        .iter()
        .flat_map(|(label, dims)| match dims.len() {
            1 => vec![label.clone()],
            n => (0..n).map(|i| format!("{}/{}", label, i)).collect(),
        })
        .collect()
}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Slice;
    use crate::shape;

    #[test]
    fn test_factor_dims_basic() {
        assert_eq!(
            factor_dims(&[6, 8], Limit::from(4)),
            vec![vec![3, 2], vec![4, 2]]
        );
        assert_eq!(factor_dims(&[5], Limit::from(3)), vec![vec![5]]);
        assert_eq!(factor_dims(&[30], Limit::from(5)), vec![vec![5, 3, 2]]);
    }

    // Verify that reshaping preserves memory layout by checking:
    // 1. Coordinate round-tripping: original → reshaped → original
    // 2. Flat index equality: original and reshaped coordinates map
    //    to the same linear index
    // 3. Index inversion: reshaped flat index maps back to the same
    //    reshaped coordinate
    //
    // Together, these checks ensure that the reshaped view is
    // layout-preserving and provides a bijective mapping between
    // coordinate systems.
    //
    // The coordinate maps are referenced via `$crate::reshape::…` so
    // the exported macro also works from modules that have not
    // imported them.
    #[macro_export]
    macro_rules! assert_layout_preserved {
        ($original:expr_2021, $reshaped:expr_2021) => {{
            // The coordinate maps depend only on the two slices, so
            // build them once rather than once per coordinate.
            let forward = $crate::reshape::to_reshaped_coord($original, &$reshaped);
            let inverse = $crate::reshape::to_original_coord(&$reshaped, $original);
            // Iterate over all coordinates in the original slice.
            for coord in $original.dim_iter($original.num_dim()) {
                // Apply the forward coordinate mapping from original
                // to reshaped space.
                let reshaped_coord = forward(&coord);
                // Inverse mapping: reshaped coord → original coord.
                let roundtrip = inverse(&reshaped_coord);
                assert_eq!(
                    roundtrip, coord,
                    "Inverse mismatch: reshaped {:?} → original {:?}, expected {:?}",
                    reshaped_coord, roundtrip, coord
                );
                // Compute flat index in the original slice.
                let flat_orig = $original.location(&coord).unwrap();
                // Compute flat index in the reshaped slice.
                let flat_reshaped = $reshaped.location(&reshaped_coord).unwrap();
                // Check that the flat index is preserved by the
                // reshaping.
                assert_eq!(
                    flat_orig, flat_reshaped,
                    "Flat index mismatch: original {:?} → reshaped {:?}",
                    coord, reshaped_coord
                );
                // Invert the reshaped flat index back to coordinates.
                let recovered = $reshaped.coordinates(flat_reshaped).unwrap();
                // Ensure coordinate inversion is correct (round
                // trip).
                assert_eq!(
                    reshaped_coord, recovered,
                    "Coordinate mismatch: flat index {} → expected {:?}, got {:?}",
                    flat_reshaped, reshaped_coord, recovered
                );
            }
        }};
    }

    #[test]
    fn test_reshape_split_1d_row_major() {
        let s = Slice::new_row_major(vec![1024]);
        let reshaped = s.view_limit(Limit::from(8));

        assert_eq!(reshaped.offset(), 0);
        assert_eq!(reshaped.sizes(), &vec![8, 8, 8, 2]);
        assert_eq!(reshaped.strides(), &vec![128, 16, 2, 1]);
        assert_eq!(
            factor_dims(s.sizes(), Limit::from(8)),
            vec![vec![8, 8, 8, 2]]
        );

        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_6_with_limit_2() {
        let s = Slice::new_row_major(vec![6]);
        let reshaped = view_limit(&s, Limit::from(2));
        // The leftover cofactor 3 has no divisor ≤ 2, so it is kept
        // even though it exceeds the limit.
        assert_eq!(factor_dims(s.sizes(), Limit::from(2)), vec![vec![2, 3]]);
        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_identity_noop_2d() {
        // All dimensions ≤ limit.
        let original = Slice::new_row_major(vec![4, 8]);
        let reshaped = original.view_limit(Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![4], vec![8]]
        );
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_empty_slice() {
        // 0-dimensional slice.
        let original = Slice::new_row_major(vec![]);
        let reshaped = view_limit(&original, Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_mixed_dims_3d() {
        // 3D slice with one dimension exceeding the limit.
        let original = Slice::new_row_major(vec![6, 8, 10]);
        let reshaped = original.view_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![3, 2], vec![4, 2], vec![2, 5]]
        );
        assert_eq!(reshaped.sizes(), &[3, 2, 4, 2, 2, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_large_dims() {
        // 3D slice with all dimensions exceeding the limit.
        let original = Slice::new_row_major(vec![12, 18, 20]);
        let reshaped = original.view_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![4, 3], vec![3, 3, 2], vec![4, 5]]
        );
        assert_eq!(reshaped.sizes(), &[4, 3, 3, 3, 2, 4, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_3_3_2_2() {
        // 36 = 3 × 3 × 2 × 2.
        let original = Slice::new_row_major(vec![36]);
        let reshaped = view_limit(&original, Limit::from(3));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(3)),
            vec![vec![3, 3, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[3, 3, 2, 2]);
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_large_prime_dimension() {
        // Prime larger than limit, cannot be factored.
        let original = Slice::new_row_major(vec![7]);
        let reshaped = view_limit(&original, Limit::from(4));

        // Should remain as-is since 7 is prime > 4
        assert_eq!(factor_dims(original.sizes(), Limit::from(4)), vec![vec![7]]);
        assert_eq!(reshaped.sizes(), &[7]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_5_3_2() {
        // 30 = 5 × 3 × 2, all ≤ limit.
        let original = Slice::new_row_major(vec![30]);
        let reshaped = view_limit(&original, Limit::from(5));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(5)),
            vec![vec![5, 3, 2]]
        );
        assert_eq!(reshaped.sizes(), &[5, 3, 2]);
        assert_eq!(reshaped.strides(), &[6, 2, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_factors_2_6_2_8_8() {
        // 12 = 6 × 2, 64 = 8 × 8 — all ≤ 8
        let original = Slice::new_row_major(vec![2, 12, 64]);
        let reshaped = original.view_limit(Limit::from(8));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![2], vec![6, 2], vec![8, 8]]
        );
        assert_eq!(reshaped.sizes(), &[2, 6, 2, 8, 8]);
        assert_eq!(reshaped.strides(), &[768, 128, 64, 8, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_dims_within_limit() {
        // Original shape: [2, 3, 4] — all ≤ limit (4).
        let original = Slice::new_row_major(vec![2, 3, 4]);
        let reshaped = original.view_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![2], vec![3], vec![4]]
        );
        assert_eq!(reshaped.sizes(), &[2, 3, 4]);
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_degenerate_dimension() {
        // Degenerate dimension should remain unchanged.
        let original = Slice::new_row_major(vec![1, 12]);
        let reshaped = original.view_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![1], vec![4, 3]]
        );
        assert_eq!(reshaped.sizes(), &[1, 4, 3]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_select_then_reshape() {
        // Original shape: 2 zones, 3 hosts, 4 gpus
        let original = shape!(zone = 2, host = 3, gpu = 4);

        // Select the zone=1 plane: shape becomes [1, 3, 4]
        let selected = original.select("zone", 1).unwrap();
        assert_eq!(selected.slice().offset(), 12); // Nonzero offset.
        assert_eq!(selected.slice().sizes(), &[1, 3, 4]);

        // Reshape the selected slice using limit=2 in row-major
        // layout.
        let reshaped = selected.slice().view_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![3], vec![2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 3, 2, 2]);
        assert_eq!(reshaped.strides(), &[12, 4, 2, 1]);
        assert_eq!(reshaped.offset(), 12); // Offset verified preserved.

        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_select_host_plane_then_reshape() {
        // Original shape: 2 zones, 3 hosts, 4 gpus.
        let original = shape!(zone = 2, host = 3, gpu = 4);
        // Select the host=2 plane: shape becomes [2, 1, 4].
        let selected = original.select("host", 2).unwrap();
        assert_eq!(selected.slice().sizes(), &[2, 1, 4]);
        // Reshape the selected slice using limit=2 in row-major
        // layout.
        let reshaped = selected.slice().view_limit(Limit::from(2));

        // gpu=4 splits into 2 × 2; strides and offset are inherited
        // from the selected slice.
        assert_eq!(reshaped.sizes(), &[2, 1, 2, 2]);
        assert_eq!(reshaped.strides(), &[12, 4, 2, 1]);
        assert_eq!(reshaped.offset(), selected.slice().offset());

        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_select_no_factoring_due_to_primes() {
        // Original shape: 3 zones, 4 hosts, 5 gpus
        let original = shape!(zone = 3, host = 4, gpu = 5);
        // First select: fix zone = 1 → shape: [1, 4, 5].
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 5]);
        // Second select: fix host = 2 → shape: [1, 1, 5].
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 5]);
        // Reshape with limit = 2.
        let reshaped = selected_host.slice().view_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![5]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 5]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_multiple_selects_triggers_factoring() {
        // Original shape: 2 zones, 4 hosts, 8 gpus
        let original = shape!(zone = 2, host = 4, gpu = 8);
        // Select zone=1 → shape: [1, 4, 8]
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        // Select host=2 → shape: [1, 1, 8]
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        // Reshape with limit = 2 → gpu=8 should factor
        let reshaped = selected_host.slice().view_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![2, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 2, 2, 2]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_expand_labels_singleton_dims() {
        let factors = vec![("x".into(), vec![2]), ("y".into(), vec![4])];
        let expected = vec!["x", "y"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_factored_dims() {
        let factors = vec![("gpu".into(), vec![2, 2, 2])];
        let expected = vec!["gpu/0", "gpu/1", "gpu/2"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_mixed_dims() {
        let factors = vec![("zone".into(), vec![2]), ("gpu".into(), vec![2, 2])];
        let expected = vec!["zone", "gpu/0", "gpu/1"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_empty() {
        let factors: Vec<(String, Vec<usize>)> = vec![];
        let expected: Vec<String> = vec![];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_reshape_shape_noop() {
        let shape = shape!(x = 4, y = 8);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x", "y"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_factored() {
        let shape = shape!(gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(reshaped.shape.labels(), &["gpu/0", "gpu/1", "gpu/2"]);
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2]);

        let expected = shape.slice().view_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_singleton() {
        let shape = shape!(x = 3);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_prime_exceeds_limit() {
        let shape = shape!(x = 11);
        let reshaped = reshape_shape(&shape, Limit::from(5));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_mixed_dims() {
        let shape = shape!(zone = 2, gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "gpu/0", "gpu/1", "gpu/2"]
        );
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2, 2]);

        let expected = shape.slice().view_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_after_selects() {
        // Original shape: 2 zones, 4 hosts, 8 gpus
        let original = shape!(zone = 2, host = 4, gpu = 8);

        // Select zone=1 → shape: [1, 4, 8]
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        // Select host=2 → shape: [1, 1, 8]
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        // Reshape shape through high-level API
        let reshaped = reshape_shape(&selected_host, Limit::from(2));

        // Labels should be: zone, host, gpu/0, gpu/1, gpu/2
        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "host", "gpu/0", "gpu/1", "gpu/2"]
        );

        // Sizes should reflect factored GPU dimension
        assert_eq!(reshaped.shape.slice().sizes(), &[1, 1, 2, 2, 2]);

        // Check against low-level equivalent reshaped slice
        let expected = selected_host.slice().view_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }
}