use std::fmt;
use crate::shape::Shape;
use crate::slice::Slice;
/// A coordinate into a slice: one index per dimension.
pub type Coord = Vec<usize>;

/// A [`Shape`] whose dimensions have been factored into smaller sub-dimensions,
/// together with a record of how each original dimension was split.
pub struct ReshapedShape {
    /// The reshaped shape: expanded labels over the factored, strided slice.
    pub shape: Shape,
    /// For each original dimension, in order: its label and the sizes it was
    /// factored into (a single-element vector when it was left unsplit).
    pub factors: Vec<(String, Vec<usize>)>,
}
// Compile-time guarantee that `ReshapedShape` is `Send + Sync + 'static`.
// The generic helper is only instantiated, never called, so `dead_code`
// is silenced.
#[allow(dead_code)]
const _: () = {
    fn requires_send_sync_static<T: Send + Sync + 'static>() {}
    let _ = requires_send_sync_static::<ReshapedShape>;
};
impl std::fmt::Debug for ReshapedShape {
    // Flattens the inner shape's labels and slice geometry into top-level
    // fields, followed by the per-dimension factor record.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let slice = self.shape.slice();
        let mut builder = f.debug_struct("ReshapedShape");
        builder.field("labels", &self.shape.labels());
        builder.field("sizes", &slice.sizes());
        builder.field("strides", &slice.strides());
        builder.field("offset", &slice.offset());
        builder.field("factors", &self.factors);
        builder.finish()
    }
}
impl std::fmt::Display for ReshapedShape {
    // One-line summary: offset, sizes, strides, labels, then factors.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
        let slice = self.shape.slice();
        let offset = slice.offset();
        let sizes = slice.sizes();
        let strides = slice.strides();
        let labels = self.shape.labels();
        write!(
            f,
            "ReshapedShape {{ [off={} sz={:?} st={:?} lab={:?} fac={:?}] }}",
            offset, sizes, strides, labels, self.factors
        )
    }
}
/// Splits each size in `sizes` into a list of factors bounded by `limit`.
///
/// A size that is already `<= limit` (including 0 and 1) is returned unchanged
/// as a single-element vector. Otherwise it is divided greedily by every
/// divisor from `limit` down to 2, recording each divisor as many times as it
/// divides. Whatever remains above 1 has no divisor in `2..=limit`, so it is
/// appended last and may exceed `limit`. The product of each output vector
/// always equals the corresponding input size.
pub(crate) fn factor_dims(sizes: &[usize], limit: Limit) -> Vec<Vec<usize>> {
    let max = limit.get();
    let factor_one = |size: usize| -> Vec<usize> {
        if size <= max {
            return vec![size];
        }
        let mut remaining = size;
        let mut out = Vec::new();
        // Largest divisor first, so the leading factors are as big as allowed.
        let mut divisor = max;
        while divisor >= 2 {
            while remaining % divisor == 0 {
                out.push(divisor);
                remaining /= divisor;
            }
            divisor -= 1;
        }
        if remaining > 1 {
            out.push(remaining);
        }
        out
    };
    sizes.iter().copied().map(factor_one).collect()
}
/// Builds a closure mapping a coordinate of `from` to the coordinate of `to`
/// that refers to the same flat location.
///
/// Both slices are borrowed, not cloned. The returned closure is already
/// bounded by `'a`, so owning copies would only cost allocations.
///
/// # Panics
/// The closure panics if `Slice::location` rejects the coordinate for `from`,
/// or if `Slice::coordinates` rejects the resulting location for `to`
/// (e.g. an out-of-range coordinate).
fn remap_coord<'a>(from: &'a Slice, to: &'a Slice) -> impl Fn(&[usize]) -> Coord + 'a {
    move |coord: &[usize]| -> Coord {
        let flat = from
            .location(coord)
            .expect("coordinate is not valid for the source slice");
        to.coordinates(flat)
            .expect("location is not addressed by the target slice")
    }
}

/// Returns a closure that maps coordinates of `original` to the coordinates of
/// `reshaped` that address the same flat location.
///
/// # Panics
/// The closure panics if a coordinate cannot be resolved (see the private
/// `remap_coord` helper).
pub fn to_reshaped_coord<'a>(
    original: &'a Slice,
    reshaped: &'a Slice,
) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
    remap_coord(original, reshaped)
}

/// Returns a closure that maps coordinates of `reshaped` back to the
/// coordinates of `original` that address the same flat location.
///
/// # Panics
/// The closure panics if a coordinate cannot be resolved (see the private
/// `remap_coord` helper).
pub fn to_original_coord<'a>(
    reshaped: &'a Slice,
    original: &'a Slice,
) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
    remap_coord(reshaped, original)
}
/// The dimension size that `factor_dims` aims to stay within when splitting.
///
/// Dimensions at or below the limit are left alone. Larger ones are split into
/// factors no greater than the limit. The exception is a leftover factor with
/// no divisor in `2..=limit`: it is kept whole and may exceed the limit.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Limit(usize);

impl Limit {
    /// Creates a limit of `n`.
    ///
    /// # Panics
    /// Panics if `n` is zero.
    pub fn new(n: usize) -> Self {
        if n == 0 {
            panic!("Limit must be at least 1");
        }
        Limit(n)
    }

    /// Returns the wrapped limit value.
    pub fn get(self) -> usize {
        let Limit(n) = self;
        n
    }
}

// The default limit is 32.
impl Default for Limit {
    fn default() -> Self {
        Limit(32)
    }
}

// Same as `Limit::new`, including the panic on zero.
impl From<usize> for Limit {
    fn from(n: usize) -> Self {
        Limit::new(n)
    }
}
/// Extension trait exposing the free `view_limit` function as a method on
/// [`Slice`].
pub trait ReshapeSliceExt {
    /// Returns a view of `self` with dimensions factored against `limit`,
    /// exactly as the free function `view_limit` does.
    fn view_limit(&self, limit: Limit) -> Slice;
}

impl ReshapeSliceExt for Slice {
    fn view_limit(&self, limit: Limit) -> Self {
        // Resolves to the module-level free function, not this method.
        view_limit(self, limit)
    }
}
/// Extension trait exposing `reshape_shape` as a method on [`Shape`].
pub trait ReshapeShapeExt {
    /// Factors `self` against `limit`, returning the reshaped shape together
    /// with its per-dimension factor record (see `reshape_shape`).
    fn reshape(&self, limit: Limit) -> ReshapedShape;
}

impl ReshapeShapeExt for Shape {
    fn reshape(&self, limit: Limit) -> ReshapedShape {
        let shape: &Self = self;
        reshape_shape(shape, limit)
    }
}
pub mod prelude {
pub use super::ReshapeShapeExt;
pub use super::ReshapeSliceExt;
}
/// Returns a view of `slice` in which every dimension larger than `limit` is
/// split into smaller sub-dimensions, without moving any data.
///
/// Sizes are factored with `factor_dims`. Suppose a dimension has original
/// stride `s` and is split into factors `[f0, f1, ..., fk]`. It then gets
/// strides `[s*f1*...*fk, ..., s*fk, s]`: the factors are laid out row-major
/// within the original dimension. The offset is unchanged. As a result, each
/// reshaped coordinate addresses the same location as the original coordinate
/// it corresponds to.
///
/// # Panics
/// Panics if `Slice::new` rejects the computed layout. This is not expected,
/// because sizes and strides are built with one entry per factor.
pub fn view_limit(slice: &Slice, limit: Limit) -> Slice {
    let factored_sizes = factor_dims(slice.sizes(), limit);
    let reshaped_sizes: Vec<usize> = factored_sizes.iter().flatten().copied().collect();

    let mut reshaped_strides = Vec::with_capacity(reshaped_sizes.len());
    for (&orig_stride, factors) in slice.strides().iter().zip(&factored_sizes) {
        // Write this dimension's strides innermost-first straight into the
        // output, then flip only that segment. This avoids a temporary Vec
        // per dimension.
        let start = reshaped_strides.len();
        let mut stride = orig_stride;
        for &f in factors.iter().rev() {
            reshaped_strides.push(stride);
            stride *= f;
        }
        reshaped_strides[start..].reverse();
    }

    Slice::new(slice.offset(), reshaped_sizes, reshaped_strides)
        .expect("reshaped sizes and strides must describe a valid slice")
}
pub fn reshape_shape(shape: &Shape, limit: Limit) -> ReshapedShape {
let reshaped_slice = shape.slice().view_limit(limit);
let original_labels = shape.labels();
let original_sizes = shape.slice().sizes();
let factors = factor_dims(original_sizes, limit);
let factored_dims: Vec<(String, Vec<usize>)> =
original_labels.iter().cloned().zip(factors).collect();
let labels = expand_labels(&factored_dims);
let shape = Shape::new(labels, reshaped_slice).expect("invalid reshaped shape");
ReshapedShape {
shape,
factors: factored_dims,
}
}
/// Expands per-dimension labels to match a factored layout.
///
/// A dimension with exactly one factor keeps its label unchanged. A dimension
/// with `n` factors (any `n != 1`) yields `label/0` through `label/{n-1}`, so
/// an entry with no factors contributes nothing.
pub fn expand_labels(factors: &[(String, Vec<usize>)]) -> Vec<String> {
    factors
        .iter()
        .flat_map(|(label, dims)| -> Vec<String> {
            match dims.len() {
                1 => vec![label.clone()],
                n => (0..n).map(|i| format!("{}/{}", label, i)).collect(),
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
    // Unit tests for dimension factoring, strided reshaping, label expansion,
    // and the coordinate-mapping helpers.
    use super::*;
    use crate::Slice;
    use crate::shape;
    #[test]
    fn test_factor_dims_basic() {
        assert_eq!(
            factor_dims(&[6, 8], Limit::from(4)),
            vec![vec![3, 2], vec![4, 2]]
        );
        assert_eq!(factor_dims(&[5], Limit::from(3)), vec![vec![5]]);
        assert_eq!(factor_dims(&[30], Limit::from(5)), vec![vec![5, 3, 2]]);
    }
    /// Asserts that `$reshaped` re-indexes exactly the same locations as
    /// `$original`. For every coordinate of `$original`:
    /// - mapping it forward and back is the identity;
    /// - both coordinates resolve to the same flat location;
    /// - the reshaped coordinate is recovered from that location.
    #[macro_export]
    macro_rules! assert_layout_preserved {
        ($original:expr_2021, $reshaped:expr_2021) => {{
            // Evaluate each argument once; the coordinate maps do not depend
            // on the loop variable, so they are built outside the loop.
            let original = $original;
            let reshaped = &$reshaped;
            let forward = to_reshaped_coord(original, reshaped);
            let inverse = to_original_coord(reshaped, original);
            for coord in original.dim_iter(original.num_dim()) {
                let reshaped_coord = forward(&coord);
                let roundtrip = inverse(&reshaped_coord);
                assert_eq!(
                    roundtrip, coord,
                    "Inverse mismatch: reshaped {:?} → original {:?}, expected {:?}",
                    reshaped_coord, roundtrip, coord
                );
                let flat_orig = original.location(&coord).unwrap();
                let flat_reshaped = reshaped.location(&reshaped_coord).unwrap();
                assert_eq!(
                    flat_orig, flat_reshaped,
                    "Flat index mismatch: original {:?} → reshaped {:?}",
                    coord, reshaped_coord
                );
                let recovered = reshaped.coordinates(flat_reshaped).unwrap();
                assert_eq!(
                    reshaped_coord, recovered,
                    "Coordinate mismatch: flat index {} → expected {:?}, got {:?}",
                    flat_reshaped, reshaped_coord, recovered
                );
            }
        }};
    }
    #[test]
    fn test_reshape_split_1d_row_major() {
        let s = Slice::new_row_major(vec![1024]);
        let reshaped = s.view_limit(Limit::from(8));
        assert_eq!(reshaped.offset(), 0);
        assert_eq!(reshaped.sizes(), &vec![8, 8, 8, 2]);
        assert_eq!(reshaped.strides(), &vec![128, 16, 2, 1]);
        assert_eq!(
            factor_dims(s.sizes(), Limit::from(8)),
            vec![vec![8, 8, 8, 2]]
        );
        assert_layout_preserved!(&s, &reshaped);
    }
    #[test]
    fn test_reshape_6_with_limit_2() {
        let s = Slice::new_row_major(vec![6]);
        let reshaped = view_limit(&s, Limit::from(2));
        assert_eq!(factor_dims(s.sizes(), Limit::from(2)), vec![vec![2, 3]]);
        assert_layout_preserved!(&s, &reshaped);
    }
    #[test]
    fn test_reshape_identity_noop_2d() {
        let original = Slice::new_row_major(vec![4, 8]);
        let reshaped = original.view_limit(Limit::from(8));
        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());
        // Both dimensions already fit within the limit, so `factor_dims`
        // must leave each as a single factor.
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![4], vec![8]]
        );
        assert_layout_preserved!(&original, &reshaped);
    }
    #[test]
    fn test_reshape_empty_slice() {
        let original = Slice::new_row_major(vec![]);
        let reshaped = view_limit(&original, Limit::from(8));
        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());
        assert_layout_preserved!(&original, &reshaped);
    }
    #[test]
    fn test_reshape_mixed_dims_3d() {
        let original = Slice::new_row_major(vec![6, 8, 10]);
        let reshaped = original.view_limit(Limit::from(4));
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![3, 2], vec![4, 2], vec![2, 5]]
        );
        assert_eq!(reshaped.sizes(), &[3, 2, 4, 2, 2, 5]);
        assert_layout_preserved!(&original, &reshaped);
    }
    #[test]
    fn test_reshape_all_large_dims() {
        let original = Slice::new_row_major(vec![12, 18, 20]);
        let reshaped = original.view_limit(Limit::from(4));
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![4, 3], vec![3, 3, 2], vec![4, 5]]
        );
        assert_eq!(reshaped.sizes(), &[4, 3, 3, 3, 2, 4, 5]);
        assert_layout_preserved!(&original, &reshaped);
    }
    #[test]
    fn test_reshape_split_1d_factors_3_3_2_2() {
        let original = Slice::new_row_major(vec![36]);
        let reshaped = view_limit(&original, Limit::from(3));
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(3)),
            vec![vec![3, 3, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[3, 3, 2, 2]);
        assert_layout_preserved!(&original, &reshaped);
    }
    #[test]
    fn test_reshape_large_prime_dimension() {
        let original = Slice::new_row_major(vec![7]);
        let reshaped = view_limit(&original, Limit::from(4));
        assert_eq!(factor_dims(original.sizes(), Limit::from(4)), vec![vec![7]]);
        assert_eq!(reshaped.sizes(), &[7]);
        assert_layout_preserved!(&original, &reshaped);
    }
    #[test]
    fn test_reshape_split_1d_factors_5_3_2() {
        let original = Slice::new_row_major(vec![30]);
        let reshaped = view_limit(&original, Limit::from(5));
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(5)),
            vec![vec![5, 3, 2]]
        );
        assert_eq!(reshaped.sizes(), &[5, 3, 2]);
        assert_eq!(reshaped.strides(), &[6, 2, 1]);
        assert_layout_preserved!(&original, &reshaped);
    }
    #[test]
    fn test_reshape_factors_2_6_2_8_8() {
        let original = Slice::new_row_major(vec![2, 12, 64]);
        let reshaped = original.view_limit(Limit::from(8));
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![2], vec![6, 2], vec![8, 8]]
        );
        assert_eq!(reshaped.sizes(), &[2, 6, 2, 8, 8]);
        assert_eq!(reshaped.strides(), &[768, 128, 64, 8, 1]);
        assert_layout_preserved!(&original, &reshaped);
    }
    #[test]
    fn test_reshape_all_dims_within_limit() {
        let original = Slice::new_row_major(vec![2, 3, 4]);
        let reshaped = original.view_limit(Limit::from(4));
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![2], vec![3], vec![4]]
        );
        assert_eq!(reshaped.sizes(), &[2, 3, 4]);
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());
        assert_layout_preserved!(&original, &reshaped);
    }
    #[test]
    fn test_reshape_degenerate_dimension() {
        let original = Slice::new_row_major(vec![1, 12]);
        let reshaped = original.view_limit(Limit::from(4));
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![1], vec![4, 3]]
        );
        assert_eq!(reshaped.sizes(), &[1, 4, 3]);
        assert_layout_preserved!(&original, &reshaped);
    }
    #[test]
    fn test_select_then_reshape() {
        let original = shape!(zone = 2, host = 3, gpu = 4);
        let selected = original.select("zone", 1).unwrap();
        assert_eq!(selected.slice().offset(), 12);
        assert_eq!(selected.slice().sizes(), &[1, 3, 4]);
        let reshaped = selected.slice().view_limit(Limit::from(2));
        assert_eq!(
            factor_dims(selected.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![3], vec![2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 3, 2, 2]);
        assert_eq!(reshaped.strides(), &[12, 4, 2, 1]);
        assert_eq!(reshaped.offset(), 12);
        assert_layout_preserved!(selected.slice(), &reshaped);
    }
    #[test]
    fn test_select_host_plane_then_reshape() {
        let original = shape!(zone = 2, host = 3, gpu = 4);
        let selected = original.select("host", 2).unwrap();
        let reshaped = selected.slice().view_limit(Limit::from(2));
        assert_layout_preserved!(selected.slice(), &reshaped);
    }
    #[test]
    fn test_reshape_after_select_no_factoring_due_to_primes() {
        let original = shape!(zone = 3, host = 4, gpu = 5);
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 5]);
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 5]);
        let reshaped = selected_host.slice().view_limit(Limit::from(2));
        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![5]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 5]);
        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }
    #[test]
    fn test_reshape_after_multiple_selects_triggers_factoring() {
        let original = shape!(zone = 2, host = 4, gpu = 8);
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);
        let reshaped = selected_host.slice().view_limit(Limit::from(2));
        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![2, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 2, 2, 2]);
        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }
    #[test]
    fn test_expand_labels_singleton_dims() {
        let factors = vec![("x".into(), vec![2]), ("y".into(), vec![4])];
        let expected = vec!["x", "y"];
        assert_eq!(expand_labels(&factors), expected);
    }
    #[test]
    fn test_expand_labels_factored_dims() {
        let factors = vec![("gpu".into(), vec![2, 2, 2])];
        let expected = vec!["gpu/0", "gpu/1", "gpu/2"];
        assert_eq!(expand_labels(&factors), expected);
    }
    #[test]
    fn test_expand_labels_mixed_dims() {
        let factors = vec![("zone".into(), vec![2]), ("gpu".into(), vec![2, 2])];
        let expected = vec!["zone", "gpu/0", "gpu/1"];
        assert_eq!(expand_labels(&factors), expected);
    }
    #[test]
    fn test_expand_labels_empty() {
        let factors: Vec<(String, Vec<usize>)> = vec![];
        let expected: Vec<String> = vec![];
        assert_eq!(expand_labels(&factors), expected);
    }
    #[test]
    fn test_reshape_shape_noop() {
        let shape = shape!(x = 4, y = 8);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x", "y"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }
    #[test]
    fn test_reshape_shape_factored() {
        let shape = shape!(gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(reshaped.shape.labels(), &["gpu/0", "gpu/1", "gpu/2"]);
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2]);
        let expected = shape.slice().view_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }
    #[test]
    fn test_reshape_shape_singleton() {
        let shape = shape!(x = 3);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }
    #[test]
    fn test_reshape_shape_prime_exceeds_limit() {
        let shape = shape!(x = 11);
        let reshaped = reshape_shape(&shape, Limit::from(5));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }
    #[test]
    fn test_reshape_shape_mixed_dims() {
        let shape = shape!(zone = 2, gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "gpu/0", "gpu/1", "gpu/2"]
        );
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2, 2]);
        let expected = shape.slice().view_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }
    #[test]
    fn test_reshape_shape_after_selects() {
        let original = shape!(zone = 2, host = 4, gpu = 8);
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);
        let reshaped = reshape_shape(&selected_host, Limit::from(2));
        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "host", "gpu/0", "gpu/1", "gpu/2"]
        );
        assert_eq!(reshaped.shape.slice().sizes(), &[1, 1, 2, 2, 2]);
        let expected = selected_host.slice().view_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }
}