#[cfg(test)]
mod tests {
use crate::specialization::*;
fn dh(divisor: i32) -> DivHint {
DivHint { divisor, max: 16 }
}
#[test]
fn max_pow2_divisor_powers_of_two() {
assert_eq!(max_pow2_divisor(1), 1);
assert_eq!(max_pow2_divisor(2), 2);
assert_eq!(max_pow2_divisor(4), 4);
assert_eq!(max_pow2_divisor(8), 8);
assert_eq!(max_pow2_divisor(16), 16);
assert_eq!(max_pow2_divisor(32), 16);
assert_eq!(max_pow2_divisor(1024), 16);
}
#[test]
fn max_pow2_divisor_non_powers() {
assert_eq!(max_pow2_divisor(3), 1); assert_eq!(max_pow2_divisor(6), 2); assert_eq!(max_pow2_divisor(12), 4); assert_eq!(max_pow2_divisor(24), 8); assert_eq!(max_pow2_divisor(48), 16); assert_eq!(max_pow2_divisor(7), 1);
assert_eq!(max_pow2_divisor(1023), 1);
}
#[test]
fn max_pow2_divisor_zero() {
assert_eq!(max_pow2_divisor(0), 16);
}
#[test]
fn divhint_from_value() {
assert_eq!(DivHint::from_value(1024), dh(16));
assert_eq!(DivHint::from_value(12), dh(4));
assert_eq!(DivHint::from_value(7), dh(1));
assert_eq!(DivHint::from_value(0), dh(16));
}
#[test]
fn divhint_with_max() {
let h: DivHint = DivHint::from_value(1024).with_max(8);
assert_eq!(h.divisor, 8);
assert_eq!(h.max, 8);
let h: DivHint = DivHint::from_value(3).with_max(8);
assert_eq!(h.divisor, 1); }
#[test]
fn divhint_default() {
let h: DivHint = DivHint::default();
assert_eq!(h.divisor, 1);
assert_eq!(h.max, 16);
}
#[test]
fn compute_spec_contiguous_aligned() {
let spec = compute_spec(0x1000, &[1024], &[1], 4);
assert_eq!(spec.shape_div, vec![dh(16)]); assert_eq!(spec.stride_div, vec![dh(1)]); assert_eq!(spec.stride_one, vec![true]);
assert_eq!(spec.base_ptr_div, dh(16)); assert!(spec.elements_disjoint);
}
#[test]
fn compute_spec_2d_row_major() {
let spec = compute_spec(0x1000, &[128, 256], &[256, 1], 2);
assert_eq!(spec.shape_div, vec![dh(16), dh(16)]); assert_eq!(spec.stride_div, vec![dh(16), dh(1)]); assert_eq!(spec.stride_one, vec![false, true]);
assert!(spec.elements_disjoint);
}
#[test]
fn compute_spec_odd_shape() {
let spec = compute_spec(0x1000, &[1023], &[1], 4);
assert_eq!(spec.shape_div, vec![dh(1)]); assert_eq!(spec.stride_div, vec![dh(1)]); assert_eq!(spec.stride_one, vec![true]);
}
#[test]
fn compute_spec_unaligned_ptr() {
let spec = compute_spec(0x1004, &[128], &[1], 4);
assert_eq!(spec.base_ptr_div, dh(4)); }
#[test]
fn stride_div_is_in_elements_not_bytes() {
let spec_f32 = compute_spec(0x1000, &[128], &[1], 4);
assert_eq!(
spec_f32.stride_div,
vec![dh(1)],
"stride=1 with f32: div should be 1 (elements), not 4 (bytes)"
);
let spec_f16 = compute_spec(0x1000, &[128], &[1], 2);
assert_eq!(
spec_f16.stride_div,
vec![dh(1)],
"stride=1 with f16: div should be 1 (elements), not 2 (bytes)"
);
let spec = compute_spec(0x1000, &[128, 256], &[256, 1], 2);
assert_eq!(
spec.stride_div,
vec![dh(16), dh(1)],
"stride=[256,1]: divs should be [16,1] in elements"
);
}
#[test]
fn base_ptr_div_is_in_bytes() {
let spec_f32 = compute_spec(0x1000, &[128], &[1], 4);
assert_eq!(
spec_f32.base_ptr_div,
dh(16),
"0x1000 is 16-byte aligned: base_ptr_div should be 16 (bytes)"
);
let spec_f16 = compute_spec(0x1000, &[128], &[1], 2);
assert_eq!(
spec_f16.base_ptr_div,
dh(16),
"same pointer, different dtype: base_ptr_div is still 16 (bytes, not elements)"
);
let spec = compute_spec(0x1004, &[128], &[1], 4);
assert_eq!(
spec.base_ptr_div,
dh(4),
"0x1004 is 4-byte aligned: base_ptr_div should be 4 (bytes)"
);
let spec = compute_spec(0x1002, &[128], &[1], 2);
assert_eq!(
spec.base_ptr_div,
dh(2),
"0x1002 is 2-byte aligned: base_ptr_div should be 2 (bytes)"
);
}
#[test]
fn units_elements_vs_bytes() {
let spec_f32 = compute_spec(0x1000, &[64, 128], &[128, 1], 4);
let spec_f16 = compute_spec(0x1000, &[64, 128], &[128, 1], 2);
assert_eq!(
spec_f32.shape_div, spec_f16.shape_div,
"shape_div is in elements: dtype does not affect it"
);
assert_eq!(
spec_f32.stride_div, spec_f16.stride_div,
"stride_div is in elements: dtype does not affect it"
);
assert_eq!(
spec_f32.base_ptr_div, spec_f16.base_ptr_div,
"base_ptr_div is in bytes: same pointer → same alignment"
);
}
#[test]
fn compute_spec_disjoint_detection() {
let spec = compute_spec(0x1000, &[4, 8], &[8, 1], 4);
assert!(spec.elements_disjoint);
let spec = compute_spec(0x1000, &[4, 8], &[4, 1], 4);
assert!(!spec.elements_disjoint);
}
#[test]
fn spec_equality_and_hash() {
use std::collections::HashSet;
let a = compute_spec(0x1000, &[128], &[1], 4);
let b = compute_spec(0x1000, &[128], &[1], 4);
let c = compute_spec(0x1000, &[127], &[1], 4); assert_eq!(a, b);
assert_ne!(a, c);
let mut set = HashSet::new();
set.insert(a.clone());
assert!(set.contains(&b));
assert!(!set.contains(&c));
}
}