use super::*;
/// Absolute tolerance used by [`vclose`] when comparing `f32` outputs.
const TOL: f32 = 1e-4;

/// Element-wise approximate equality for `f32` slices.
///
/// Returns `true` only when both slices have the same length and every pair
/// `(g, w)` is finite on both sides and satisfies `|g - w| <= TOL`.
fn vclose(got: &[f32], want: &[f32]) -> bool {
    if got.len() == want.len() {
        for (&g, &w) in got.iter().zip(want.iter()) {
            let within_tol = (g - w).abs() <= TOL;
            if !(within_tol && g.is_finite() && w.is_finite()) {
                return false;
            }
        }
        true
    } else {
        false
    }
}
/// Hand-traced RMSNorm on the single row `[1, 2, 3]` with unit weights.
///
/// `mean(x²) = (1 + 4 + 9) / 3 = 14/3`, so each element is divided by
/// `sqrt(14/3 + eps)`. The expectation folds in the same `eps` that the layer is
/// built with, matching the `rsqrt(mean(x²) + eps)` formula used by the manual
/// fallback test and the way the LayerNorm hand trace includes `eps`.
#[test]
fn rms_norm_hand_traced() {
    let eps = 1e-6_f32;
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let w = Array::from_slice::<f32>(&[1.0, 1.0, 1.0], &(3,)).unwrap();
    let rn = RMSNorm::new(w, eps);
    let mut y = rn.forward(&x).unwrap();
    let rms = (14.0_f32 / 3.0 + eps).sqrt();
    let got = y.to_vec::<f32>().unwrap();
    let want = [1.0 / rms, 2.0 / rms, 3.0 / rms];
    assert!(vclose(&got, &want), "got {got:?}, want {want:?}");
}
/// An all-zero row drives `mean(x²)` to zero. RMSNorm must still produce
/// finite values, which relies on `eps` guarding the normalizer.
#[test]
fn rms_norm_zero_input_is_finite() {
    let x = Array::from_slice::<f32>(&[0.0; 4], &(1, 4)).unwrap();
    let weight = Array::ones::<f32>(&(4,)).unwrap();
    let mut y = RMSNorm::new(weight, 1e-5).forward(&x).unwrap();
    let v = y.to_vec::<f32>().unwrap();
    let all_finite = v.iter().all(|elem| elem.is_finite());
    assert!(all_finite, "expected finite, got {v:?}");
}
/// RMSNorm applied to a `(2, 3, 4)` input must return an array of the same
/// shape.
#[test]
fn rms_norm_preserves_rank3_shape() {
    let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
    let x = Array::from_slice::<f32>(&data, &(2, 3, 4)).unwrap();
    let rn = RMSNorm::new(Array::ones::<f32>(&(4,)).unwrap(), 1e-5);
    let y = rn.forward(&x).unwrap();
    assert_eq!(y.shape(), vec![2, 3, 4]);
}
/// The RMSNorm forward pass must agree with a manual composition of primitive
/// ops: `x * rsqrt(mean(x², axis = -1, keepdims) + eps) * w`.
#[test]
fn rms_norm_matches_manual_fallback() {
    let eps = 1e-5_f32;
    let x = Array::from_slice::<f32>(&[0.5, -1.5, 2.0, 3.0, 4.0, 5.0], &(1, 2, 3)).unwrap();
    let w = Array::from_slice::<f32>(&[0.5, 1.0, 1.5], &(3,)).unwrap();

    // Layer under test.
    let layer = RMSNorm::new(w.try_clone().unwrap(), eps);
    let mut via_kernel = layer.forward(&x).unwrap();

    // Reference path built from primitive ops.
    let squared = ops::arithmetic::square(&x).unwrap();
    let mean_sq = ops::reduction::mean_axes(&squared, &[-1], true).unwrap();
    let eps_like = scalar_like(eps, &mean_sq).unwrap();
    let shifted = ops::arithmetic::add(&mean_sq, &eps_like).unwrap();
    let inv_rms = ops::arithmetic::rsqrt(&shifted).unwrap();
    let normalized = ops::arithmetic::multiply(&x, &inv_rms).unwrap();
    let mut via_manual = ops::arithmetic::multiply(&normalized, &w).unwrap();

    let kernel_v = via_kernel.to_vec::<f32>().unwrap();
    let manual_v = via_manual.to_vec::<f32>().unwrap();
    assert!(vclose(&kernel_v, &manual_v));
}
/// Hand-traced LayerNorm with no affine parameters on `[1, 2, 3, 4]`.
///
/// The mean is 2.5 and the population variance is 1.25, so each element maps to
/// `(x - 2.5) / sqrt(1.25 + eps)`.
#[test]
fn layer_norm_hand_traced() {
    let input = [1.0_f32, 2.0, 3.0, 4.0];
    let x = Array::from_slice::<f32>(&input, &(1, 4)).unwrap();
    let mut y = LayerNorm::new(None, None, 1e-5).forward(&x).unwrap();
    let denom = (1.25_f32 + 1e-5).sqrt();
    let want: Vec<f32> = input.iter().map(|v| (v - 2.5) / denom).collect();
    assert!(vclose(&y.to_vec::<f32>().unwrap(), &want));
}
/// An all-zero row has zero variance. LayerNorm must still return finite
/// values, which relies on `eps` keeping the denominator positive.
#[test]
fn layer_norm_zero_input_is_finite() {
    let zeros = [0.0_f32; 6];
    let x = Array::from_slice::<f32>(&zeros, &(1, 6)).unwrap();
    let mut y = LayerNorm::new(None, None, 1e-5).forward(&x).unwrap();
    let v = y.to_vec::<f32>().unwrap();
    assert!(
        v.iter().copied().all(f32::is_finite),
        "expected finite, got {v:?}"
    );
}
/// LayerNorm applied to a `(2, 3, 4)` input must return an array of the same
/// shape.
#[test]
fn layer_norm_preserves_rank3_shape() {
    let values: Vec<f32> = (0..24).map(|i| i as f32).collect();
    let x = Array::from_slice::<f32>(&values, &(2, 3, 4)).unwrap();
    let y = LayerNorm::new(None, None, 1e-5).forward(&x).unwrap();
    assert_eq!(y.shape(), vec![2, 3, 4]);
}
/// LayerNorm with weight `2` and bias `1` must equal `2 * plain + 1`, where
/// `plain` is the same normalization without affine parameters.
#[test]
fn layer_norm_affine_applies_weight_and_bias() {
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(1, 4)).unwrap();
    let weight = Array::full::<f32>(&(4,), 2.0).unwrap();
    let bias = Array::ones::<f32>(&(4,)).unwrap();

    let plain_v = {
        let mut out = LayerNorm::new(None, None, 1e-5).forward(&x).unwrap();
        out.to_vec::<f32>().unwrap()
    };
    let affine_v = {
        let mut out = LayerNorm::new(Some(weight), Some(bias), 1e-5)
            .forward(&x)
            .unwrap();
        out.to_vec::<f32>().unwrap()
    };

    let want: Vec<f32> = plain_v.iter().map(|p| 2.0 * p + 1.0).collect();
    assert!(vclose(&affine_v, &want));
}
/// The LayerNorm forward pass (no affine) must agree with a manual composition
/// of primitive ops: `(x - mean(x)) * rsqrt(var(x) + eps)` over the last axis.
///
/// NOTE(review): the trailing `0` passed to `var_axes` is presumably `ddof`,
/// giving the population variance. This is consistent with the 1.25 used in the
/// LayerNorm hand trace.
#[test]
fn layer_norm_matches_manual_fallback() {
    let x = Array::from_slice::<f32>(&[0.5, -1.5, 2.0, 3.0, 4.0, 5.0], &(1, 2, 3)).unwrap();
    let eps = 1e-5_f32;
    let mut via_kernel = LayerNorm::new(None, None, eps).forward(&x).unwrap();
    let m = ops::reduction::mean_axes(&x, &[-1], true).unwrap();
    let v = ops::reduction::var_axes(&x, &[-1], true, 0).unwrap();
    let eps_arr = scalar_like(eps, &v).unwrap();
    let denom = ops::arithmetic::rsqrt(&ops::arithmetic::add(&v, &eps_arr).unwrap()).unwrap();
    let centered = ops::arithmetic::subtract(&x, &m).unwrap();
    let mut via_manual = ops::arithmetic::multiply(&centered, &denom).unwrap();
    let kernel_v = via_kernel.to_vec::<f32>().unwrap();
    let manual_v = via_manual.to_vec::<f32>().unwrap();
    assert!(
        vclose(&kernel_v, &manual_v),
        "kernel {kernel_v:?} vs manual {manual_v:?}"
    );
}
/// With a single group spanning all 4 features, GroupNorm on `[1, 2, 3, 4]`
/// must reproduce the LayerNorm hand trace: `(x - 2.5) / sqrt(1.25 + eps)`.
#[test]
fn group_norm_hand_traced_one_group_matches_layer_norm() {
    let row = [1.0_f32, 2.0, 3.0, 4.0];
    let x = Array::from_slice::<f32>(&row, &(1, 4)).unwrap();
    let gn = GroupNorm::new(1, 4, 1e-5, false, false).unwrap();
    let mut y = gn.forward(&x).unwrap();
    let mean = 2.5_f32;
    let denom = (1.25_f32 + 1e-5).sqrt();
    let want: Vec<f32> = row.iter().map(|&r| (r - mean) / denom).collect();
    assert!(vclose(&y.to_vec::<f32>().unwrap(), &want));
}
/// GroupNorm (2 groups over 4 features) on an all-zero rank-3 input must
/// produce only finite values.
#[test]
fn group_norm_zero_input_is_finite() {
    let zeros = [0.0_f32; 12];
    let x = Array::from_slice::<f32>(&zeros, &(1, 3, 4)).unwrap();
    let gn = GroupNorm::new(2, 4, 1e-5, false, false).unwrap();
    let mut y = gn.forward(&x).unwrap();
    let v = y.to_vec::<f32>().unwrap();
    let all_finite = v.iter().all(|e| e.is_finite());
    assert!(all_finite, "expected finite, got {v:?}");
}
/// GroupNorm with affine parameters on a `(2, 3, 3, 4)` input must return an
/// array of the same shape.
#[test]
fn group_norm_preserves_rank4_shape() {
    let numel = 2 * 3 * 3 * 4;
    let data: Vec<f32> = (0..numel).map(|i| i as f32).collect();
    let x = Array::from_slice::<f32>(&data, &(2, 3, 3, 4)).unwrap();
    let y = GroupNorm::new(2, 4, 1e-5, true, false)
        .unwrap()
        .forward(&x)
        .unwrap();
    assert_eq!(y.shape(), vec![2, 3, 3, 4]);
}
/// With `pytorch_compatible = true` (5th constructor argument), GroupNorm on a
/// `(2, 3, 3, 4)` input of values `1..=72` must keep the input shape and
/// produce finite values.
#[test]
fn group_norm_pytorch_compat_preserves_shape() {
    let n = 2 * 3 * 3 * 4;
    let data: Vec<f32> = (1..=n).map(|i| i as f32).collect();
    let x = Array::from_slice::<f32>(&data, &(2, 3, 3, 4)).unwrap();
    let gn = GroupNorm::new(2, 4, 1e-5, false, true).unwrap();
    let mut y = gn.forward(&x).unwrap();
    assert_eq!(y.shape(), vec![2, 3, 3, 4]);
    let v = y.to_vec::<f32>().unwrap();
    assert!(!v.iter().any(|e| !e.is_finite()));
}
/// With `affine = true`, the output must equal `weight * normalized + bias`,
/// using the layer's own parameters.
///
/// The freshly constructed parameters are ones and zeros, so the affine output
/// must also equal the plain normalized output.
#[test]
fn group_norm_affine_true_applies_scale_and_shift() {
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(1, 4)).unwrap();
    let plain = GroupNorm::new(2, 4, 1e-5, false, false).unwrap();
    let affine = GroupNorm::new(2, 4, 1e-5, true, false).unwrap();

    let mut normalized = plain.forward(&x).unwrap();
    let mut affine_out = affine.forward(&x).unwrap();
    let normalized_v = normalized.to_vec::<f32>().unwrap();
    let affine_v = affine_out.to_vec::<f32>().unwrap();

    // Reference: weight * normalized + bias.
    let (weight, bias) = affine.affine().expect("affine=true ⇒ Some");
    let scaled = ops::arithmetic::multiply(weight, &normalized).unwrap();
    let mut reference = ops::arithmetic::add(&scaled, bias).unwrap();

    assert!(vclose(&affine_v, &reference.to_vec::<f32>().unwrap()));
    // Identity parameters leave the normalized output unchanged.
    assert!(vclose(&affine_v, &normalized_v));
}
/// With `affine = false`, no parameters are stored, and the forward pass
/// still returns 4 finite values for a `(1, 4)` input.
#[test]
fn group_norm_affine_false_is_pure_normalization() {
    let gn = GroupNorm::new(2, 4, 1e-5, false, false).unwrap();
    assert!(gn.affine().is_none());
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(1, 4)).unwrap();
    let mut y = gn.forward(&x).unwrap();
    let v = y.to_vec::<f32>().unwrap();
    assert_eq!(v.len(), 4);
    assert!(!v.iter().any(|e| !e.is_finite()));
}
/// The `affine` flag passed to `GroupNorm::new` decides whether the
/// weight/bias pair is present (`Some`) or absent (`None`).
#[test]
fn group_norm_affine_is_both_or_none_by_construction() {
    for affine_flag in [true, false] {
        let gn = GroupNorm::new(2, 4, 1e-5, affine_flag, false).unwrap();
        if affine_flag {
            assert!(gn.affine().is_some());
        } else {
            assert!(gn.affine().is_none());
        }
    }
}
/// `GroupNorm::new(2, 4, eps, false, false)` must store 2 groups, no affine
/// parameters, and `pytorch_compatible == false`.
#[test]
fn group_norm_default_constructor_no_affine() {
    let gn = GroupNorm::new(2, 4, 1e-5, false, false).unwrap();
    assert_eq!(gn.num_groups(), 2);
    assert!(!gn.pytorch_compatible);
    assert!(gn.affine().is_none());
}
/// With `affine = true`, the constructor must allocate a length-4 weight of
/// ones and a length-4 bias of zeros.
#[test]
fn group_norm_default_constructor_affine_allocates() {
    let gn = GroupNorm::new(2, 4, 1e-5, true, false).unwrap();
    let (weight, bias) = gn.affine().expect("affine=true ⇒ Some");
    for param in [weight, bias] {
        assert_eq!(param.shape(), vec![4]);
    }
    let mut weight = weight.try_clone().unwrap();
    let mut bias = bias.try_clone().unwrap();
    assert_eq!(weight.to_vec::<f32>().unwrap(), vec![1.0; 4]);
    assert_eq!(bias.to_vec::<f32>().unwrap(), vec![0.0; 4]);
}
/// `with_affine(Some((w, b)))` must store the supplied tensors unchanged, so
/// they read back as `[2; 4]` and `[1; 4]`.
#[test]
fn group_norm_with_affine_accepts_checkpoint_tensors() {
    let weight = Array::from_slice::<f32>(&[2.0; 4], &(4,)).unwrap();
    let bias = Array::from_slice::<f32>(&[1.0; 4], &(4,)).unwrap();
    let gn = GroupNorm::with_affine(2, 4, 1e-5, Some((weight, bias)), false).unwrap();
    let (stored_w, stored_b) = gn.affine().expect("with_affine(Some(_)) ⇒ Some");
    let mut stored_w = stored_w.try_clone().unwrap();
    let mut stored_b = stored_b.try_clone().unwrap();
    assert_eq!(stored_w.to_vec::<f32>().unwrap(), vec![2.0; 4]);
    assert_eq!(stored_b.to_vec::<f32>().unwrap(), vec![1.0; 4]);
}
/// A non-identity affine pair (weight 2, bias 1) installed via `with_affine`
/// must produce `w * normalized + b`. The result must differ from the plain
/// normalized output.
#[test]
fn group_norm_with_affine_non_identity_forward_applies_scale_shift() {
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(1, 4)).unwrap();
    let w = Array::from_slice::<f32>(&[2.0, 2.0, 2.0, 2.0], &(4,)).unwrap();
    let b = Array::from_slice::<f32>(&[1.0, 1.0, 1.0, 1.0], &(4,)).unwrap();

    // Baseline: the same configuration without affine parameters.
    let plain = GroupNorm::with_affine(2, 4, 1e-5, None, false).unwrap();
    let mut normalized = plain.forward(&x).unwrap();
    let normalized_v = normalized.to_vec::<f32>().unwrap();

    // Same configuration with copies of (w, b) installed.
    let params = (w.try_clone().unwrap(), b.try_clone().unwrap());
    let affine = GroupNorm::with_affine(2, 4, 1e-5, Some(params), false).unwrap();
    let mut got = affine.forward(&x).unwrap();
    let got_v = got.to_vec::<f32>().unwrap();

    // Reference: w * normalized + b.
    let scaled = ops::arithmetic::multiply(&w, &normalized).unwrap();
    let mut want = ops::arithmetic::add(&scaled, &b).unwrap();

    assert!(vclose(&got_v, &want.to_vec::<f32>().unwrap()));
    assert!(
        !vclose(&got_v, &normalized_v),
        "non-identity affine must change the output"
    );
}
/// `with_affine` must reject a bad weight in two cases, and both errors must
/// name "weight" in their context:
/// - a length-5 weight for `dims = 4` gives `LengthMismatch`;
/// - a rank-2 `[1, 4]` weight gives `RankMismatch`.
#[test]
fn group_norm_with_affine_rejects_wrong_shape_weight() {
    use crate::error::Error;

    let bias = Array::zeros::<f32>(&(4,)).unwrap();

    // Case 1: weight length 5 instead of 4.
    let long_w = Array::ones::<f32>(&(5,)).unwrap();
    let bias_copy = bias.try_clone().unwrap();
    match GroupNorm::with_affine(2, 4, 1e-5, Some((long_w, bias_copy)), false).unwrap_err() {
        Error::LengthMismatch(payload) => {
            let ctx = payload.context();
            assert!(ctx.contains("weight"), "unexpected context: {:?}", ctx);
            assert_eq!(payload.expected(), 4, "expected length 4 (dims)");
            assert_eq!(payload.actual(), 5, "actual length 5");
        }
        other => panic!("expected LengthMismatch, got {other:?}"),
    }

    // Case 2: weight of rank 2, shape [1, 4].
    let rank2_w = Array::ones::<f32>(&(1, 4)).unwrap();
    match GroupNorm::with_affine(2, 4, 1e-5, Some((rank2_w, bias)), false).unwrap_err() {
        Error::RankMismatch(payload) => {
            let ctx = payload.context();
            assert!(ctx.contains("weight"), "unexpected context: {:?}", ctx);
            assert_eq!(payload.actual(), 2, "expected observed rank 2");
            assert_eq!(
                payload.actual_shape(),
                &[1usize, 4],
                "expected actual shape [1, 4]"
            );
        }
        other => panic!("expected RankMismatch, got {other:?}"),
    }
}
/// Mirror of the weight-shape test for the bias. Both errors must name "bias"
/// in their context:
/// - a length-5 bias gives `LengthMismatch`;
/// - a rank-2 `[1, 4]` bias gives `RankMismatch`.
#[test]
fn group_norm_with_affine_rejects_wrong_shape_bias() {
    use crate::error::Error;

    let weight = Array::ones::<f32>(&(4,)).unwrap();

    // Case 1: bias length 5 instead of 4.
    let long_b = Array::zeros::<f32>(&(5,)).unwrap();
    let weight_copy = weight.try_clone().unwrap();
    let length_payload =
        match GroupNorm::with_affine(2, 4, 1e-5, Some((weight_copy, long_b)), false).unwrap_err() {
            Error::LengthMismatch(payload) => payload,
            other => panic!("expected LengthMismatch, got {other:?}"),
        };
    let length_ctx = length_payload.context();
    assert!(
        length_ctx.contains("bias"),
        "unexpected context: {:?}",
        length_ctx
    );
    assert_eq!(length_payload.expected(), 4, "expected length 4 (dims)");
    assert_eq!(length_payload.actual(), 5, "actual length 5");

    // Case 2: bias of rank 2, shape [1, 4].
    let rank2_b = Array::zeros::<f32>(&(1, 4)).unwrap();
    let rank_payload =
        match GroupNorm::with_affine(2, 4, 1e-5, Some((weight, rank2_b)), false).unwrap_err() {
            Error::RankMismatch(payload) => payload,
            other => panic!("expected RankMismatch, got {other:?}"),
        };
    let rank_ctx = rank_payload.context();
    assert!(rank_ctx.contains("bias"), "unexpected context: {:?}", rank_ctx);
    assert_eq!(rank_payload.actual(), 2, "expected observed rank 2");
    assert_eq!(
        rank_payload.actual_shape(),
        &[1usize, 4],
        "expected actual shape [1, 4]"
    );
}
/// `with_affine(None)` must store no parameters, and `forward` must then
/// agree with the bare `group_norm` normalization.
#[test]
fn group_norm_with_affine_none_is_pure_normalization() {
    let gn = GroupNorm::with_affine(2, 4, 1e-5, None, false).unwrap();
    assert!(gn.affine().is_none());
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(1, 4)).unwrap();
    let mut via_forward = gn.forward(&x).unwrap();
    let mut via_group_norm = gn.group_norm(&x).unwrap();
    let forward_v = via_forward.to_vec::<f32>().unwrap();
    let group_norm_v = via_group_norm.to_vec::<f32>().unwrap();
    assert!(vclose(&forward_v, &group_norm_v));
}
/// A rank-1 input must be rejected with `RankMismatch`. The context must
/// mention "rank" and the reported rank must be 1.
#[test]
fn group_norm_rank1_input_errors() {
    let gn = GroupNorm::new(1, 4, 1e-5, false, false).unwrap();
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(4,)).unwrap();
    let payload = match gn.forward(&x).unwrap_err() {
        crate::error::Error::RankMismatch(payload) => payload,
        other => panic!("expected RankMismatch, got {other:?}"),
    };
    let ctx = payload.context();
    assert!(ctx.contains("rank"), "unexpected context: {:?}", ctx);
    assert_eq!(payload.actual(), 1);
}
/// A last axis of 3 against `dims = 4` must give `LengthMismatch` with a
/// "last-axis" context, expected 4 and actual 3.
#[test]
fn group_norm_feature_dim_mismatch_errors() {
    let gn = GroupNorm::new(2, 4, 1e-5, false, false).unwrap();
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    let payload = match gn.forward(&x).unwrap_err() {
        crate::error::Error::LengthMismatch(payload) => payload,
        other => panic!("expected LengthMismatch, got {other:?}"),
    };
    let ctx = payload.context();
    assert!(ctx.contains("last-axis"), "unexpected context: {:?}", ctx);
    assert_eq!(payload.expected(), 4);
    assert_eq!(payload.actual(), 3);
}
/// Same rank-1 rejection as the default mode, but with
/// `pytorch_compatible = true`.
#[test]
fn group_norm_pytorch_compat_rank1_input_errors() {
    use crate::error::Error;

    let gn = GroupNorm::new(1, 4, 1e-5, false, true).unwrap();
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(4,)).unwrap();
    match gn.forward(&x).unwrap_err() {
        Error::RankMismatch(payload) => {
            let ctx = payload.context();
            assert!(ctx.contains("rank"), "unexpected context: {:?}", ctx);
            assert_eq!(payload.actual(), 1);
        }
        other => panic!("expected RankMismatch, got {other:?}"),
    }
}
/// Same last-axis length check as the default mode, but with
/// `pytorch_compatible = true`: expected 4, actual 3.
#[test]
fn group_norm_pytorch_compat_feature_dim_mismatch_errors() {
    use crate::error::Error;

    let gn = GroupNorm::new(2, 4, 1e-5, false, true).unwrap();
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1, 3)).unwrap();
    match gn.forward(&x).unwrap_err() {
        Error::LengthMismatch(payload) => {
            let ctx = payload.context();
            assert!(ctx.contains("last-axis"), "unexpected context: {:?}", ctx);
            assert_eq!(payload.expected(), 4);
            assert_eq!(payload.actual(), 3);
        }
        other => panic!("expected LengthMismatch, got {other:?}"),
    }
}
/// A correctly shaped `(1, 4)` input must still normalize successfully into 4
/// finite values.
#[test]
fn group_norm_valid_rank2_still_works() {
    let gn = GroupNorm::new(2, 4, 1e-5, false, false).unwrap();
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(1, 4)).unwrap();
    let v = {
        let mut y = gn.forward(&x).unwrap();
        y.to_vec::<f32>().unwrap()
    };
    assert_eq!(v.len(), 4);
    assert!(v.iter().all(|val| val.is_finite()));
}
/// `dims = -1` must be rejected with `OutOfRange`. The context must name
/// "dims" and the requirement must mention "positive".
#[test]
fn group_norm_constructor_rejects_negative_dims() {
    let payload = match GroupNorm::new(2, -1, 1e-5, false, false).unwrap_err() {
        crate::error::Error::OutOfRange(payload) => payload,
        other => panic!("expected OutOfRange, got {other:?}"),
    };
    let ctx = payload.context();
    let requirement = payload.requirement();
    assert!(ctx.contains("dims"), "unexpected context: {:?}", ctx);
    assert!(
        requirement.contains("positive"),
        "unexpected requirement: {:?}",
        requirement
    );
}
/// `dims = 3` cannot be split into 2 groups. This must give
/// `DivisibilityConstraint` naming `dims` as the dividend and `num_groups` as
/// the divisor.
#[test]
fn group_norm_constructor_rejects_non_divisible_dims() {
    let payload = match GroupNorm::new(2, 3, 1e-5, false, false).unwrap_err() {
        crate::error::Error::DivisibilityConstraint(payload) => payload,
        other => panic!("expected DivisibilityConstraint, got {other:?}"),
    };
    assert_eq!(payload.name_dividend(), "dims");
    assert_eq!(payload.name_divisor(), "num_groups");
}
/// `dims = 0` must be rejected with `OutOfRange`. The context must name
/// "dims" and the requirement must mention "positive".
#[test]
fn group_norm_constructor_rejects_zero_dims() {
    use crate::error::Error;

    match GroupNorm::new(2, 0, 1e-5, false, false).unwrap_err() {
        Error::OutOfRange(payload) => {
            let ctx = payload.context();
            let requirement = payload.requirement();
            assert!(ctx.contains("dims"), "unexpected context: {:?}", ctx);
            assert!(
                requirement.contains("positive"),
                "unexpected requirement: {:?}",
                requirement
            );
        }
        other => panic!("expected OutOfRange, got {other:?}"),
    }
}
/// A valid, non-affine configuration (2 groups over 4 features) must be
/// accepted and reported back by the accessors.
#[test]
fn group_norm_constructor_accepts_valid_non_affine() {
    let gn = GroupNorm::new(2, 4, 1e-5, false, false).unwrap();
    assert!(gn.affine().is_none());
    assert_eq!(gn.num_groups(), 2);
    assert_eq!(gn.dims(), 4);
}
/// An input whose last axis (8) is larger than `dims` (4) must be rejected
/// with a "last-axis" `LengthMismatch`: expected 4, actual 8.
#[test]
fn group_norm_forward_rejects_dim_mismatch() {
    let eight: Vec<f32> = (1..=8).map(|i| i as f32).collect();
    let x = Array::from_slice::<f32>(&eight, &(1, 8)).unwrap();
    let gn = GroupNorm::new(2, 4, 1e-5, false, false).unwrap();
    let payload = match gn.forward(&x).unwrap_err() {
        crate::error::Error::LengthMismatch(payload) => payload,
        other => panic!("expected LengthMismatch, got {other:?}"),
    };
    let ctx = payload.context();
    assert!(
        ctx.contains("last-axis"),
        "expected context to name last-axis: {:?}",
        ctx
    );
    assert_eq!(payload.expected(), 4);
    assert_eq!(payload.actual(), 8);
}
/// Invalid configurations must be rejected even when `affine = true`:
/// - `num_groups = 0` gives `OutOfRange`;
/// - `dims = 8` with 3 groups gives `DivisibilityConstraint`.
///
/// NOTE(review): the test name records the intent that validation happens
/// before the affine parameters are allocated. This test only checks the error
/// variants.
#[test]
fn group_norm_new_affine_true_invalid_config_rejects_before_alloc() {
    use crate::error::Error;

    let err = GroupNorm::new(0, 4, 1e-5, true, false).unwrap_err();
    let is_out_of_range = matches!(err, Error::OutOfRange(_));
    assert!(
        is_out_of_range,
        "expected OutOfRange for num_groups=0, got {err:?}"
    );

    let err = GroupNorm::new(3, 8, 1e-5, true, false).unwrap_err();
    let is_divisibility = matches!(err, Error::DivisibilityConstraint(_));
    assert!(
        is_divisibility,
        "expected DivisibilityConstraint for non-divisible dims, got {err:?}"
    );
}
/// The configured group count and feature dimension are exposed through the
/// `num_groups()` and `dims()` accessors.
#[test]
fn group_norm_num_groups_dims_are_read_only_via_accessors() {
    let layer = GroupNorm::new(4, 16, 1e-5, false, false).unwrap();
    let (groups, features) = (layer.num_groups(), layer.dims());
    assert_eq!(groups, 4);
    assert_eq!(features, 16);
}
/// A shape whose element count overflows `usize` (`usize::MAX * 2`) must
/// produce `ArithmeticOverflow` with "overflow" in the context.
#[test]
fn inferred_dim_overflow_errors() {
    let huge_shape: [usize; 2] = [usize::MAX, 2];
    let payload = match inferred_dim(&huge_shape, &[1, 1]).unwrap_err() {
        crate::error::Error::ArithmeticOverflow(payload) => payload,
        other => panic!("expected ArithmeticOverflow, got {other:?}"),
    };
    let ctx = payload.context();
    assert!(ctx.contains("overflow"), "unexpected context: {:?}", ctx);
}
/// `RMSNorm::weight_ref` must expose the exact weight passed to the
/// constructor, with the same shape and values.
#[test]
fn rms_norm_weight_ref_returns_installed_weight() {
    let rn = RMSNorm::new(
        Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(3,)).unwrap(),
        1e-5,
    );
    let installed = rn.weight_ref();
    assert_eq!(installed.shape(), vec![3]);
    let mut owned = installed.try_clone().unwrap();
    assert_eq!(owned.to_vec::<f32>().unwrap(), vec![1.0, 2.0, 3.0]);
}
/// `LayerNorm::weight_ref` returns the installed weight when one was given,
/// and `None` otherwise.
#[test]
fn layer_norm_weight_ref_some_and_none() {
    let plain = LayerNorm::new(None, None, 1e-5);
    let weight = Array::from_slice::<f32>(&[2.0; 4], &(4,)).unwrap();
    let ln = LayerNorm::new(Some(weight), None, 1e-5);

    let stored = ln.weight_ref().expect("weight installed ⇒ Some");
    assert_eq!(stored.shape(), vec![4]);
    let mut stored = stored.try_clone().unwrap();
    assert_eq!(stored.to_vec::<f32>().unwrap(), vec![2.0; 4]);

    assert!(plain.weight_ref().is_none());
}
/// `LayerNorm::bias_ref` returns the installed bias when one was given, and
/// `None` otherwise.
#[test]
fn layer_norm_bias_ref_some_and_none() {
    let plain = LayerNorm::new(None, None, 1e-5);
    let bias = Array::from_slice::<f32>(&[1.0; 4], &(4,)).unwrap();
    let ln = LayerNorm::new(None, Some(bias), 1e-5);

    let stored = ln.bias_ref().expect("bias installed ⇒ Some");
    assert_eq!(stored.shape(), vec![4]);
    let mut stored = stored.try_clone().unwrap();
    assert_eq!(stored.to_vec::<f32>().unwrap(), vec![1.0; 4]);

    assert!(plain.bias_ref().is_none());
}
/// A feature dimension one past `i32::MAX` must produce `ArithmeticOverflow`.
/// The context must mention "feature dim", the op type must be `"i32"`, and
/// the offending value must be reported as the `dim` operand.
#[test]
fn validate_input_shape_feature_dim_overflow_errors() {
    let big = (i32::MAX as usize) + 1;
    let gn = GroupNorm::new(2, 4, 1e-5, false, false).unwrap();
    let payload = match gn.validate_input_shape(&[2, big]).unwrap_err() {
        crate::error::Error::ArithmeticOverflow(payload) => payload,
        other => panic!("expected ArithmeticOverflow, got {other:?}"),
    };
    let ctx = payload.context();
    assert!(ctx.contains("feature dim"), "unexpected context: {:?}", ctx);
    assert_eq!(payload.op_type(), "i32");
    assert_eq!(payload.operands(), &[("dim", big as u64)]);
}
/// Builds a `GroupNorm` directly from its fields, bypassing the constructor's
/// validation, with 3 groups over 4 features. `validate_input_shape` must
/// still catch the non-divisible feature dimension.
#[test]
fn validate_input_shape_divisibility_belt_and_suspenders() {
    let invalid = GroupNorm {
        dims: 4,
        num_groups: 3,
        eps: 1e-5,
        affine: None,
        pytorch_compatible: false,
    };
    let payload = match invalid.validate_input_shape(&[1, 4]).unwrap_err() {
        crate::error::Error::DivisibilityConstraint(payload) => payload,
        other => panic!("expected DivisibilityConstraint, got {other:?}"),
    };
    assert_eq!(payload.name_dividend(), "feature_dim");
    assert_eq!(payload.dividend(), 4);
    assert_eq!(payload.name_divisor(), "num_groups");
    assert_eq!(payload.divisor(), 3);
}
/// `shape_to_i32` must refuse a dimension one past `i32::MAX`. The error must
/// be `ArithmeticOverflow` with an "exceeds i32::MAX" context and the value
/// reported as the `dim` operand.
#[test]
fn shape_to_i32_overflow_errors() {
    use crate::error::Error;

    let big = (i32::MAX as usize) + 1;
    match shape_to_i32(&[2, big]).unwrap_err() {
        Error::ArithmeticOverflow(payload) => {
            let ctx = payload.context();
            assert!(ctx.contains("exceeds i32::MAX"), "unexpected context: {:?}", ctx);
            assert_eq!(payload.op_type(), "i32");
            assert_eq!(payload.operands(), &[("dim", big as u64)]);
        }
        other => panic!("expected ArithmeticOverflow, got {other:?}"),
    }
}
/// In-range `usize` dimensions convert element-wise to `i32`.
#[test]
fn shape_to_i32_ok_roundtrip() {
    let converted = shape_to_i32(&[2, 3, 4]).unwrap();
    let expected: Vec<i32> = vec![2, 3, 4];
    assert_eq!(converted, expected);
}
/// `batch_dim` on an empty (rank-0) shape must give `RankMismatch`. The
/// context must mention "rank >= 1" and the reported actual rank/shape must be
/// 0/`[]`.
#[test]
fn batch_dim_rank0_errors() {
    let payload = match batch_dim(&[]).unwrap_err() {
        crate::error::Error::RankMismatch(payload) => payload,
        other => panic!("expected RankMismatch, got {other:?}"),
    };
    let ctx = payload.context();
    assert!(ctx.contains("rank >= 1"), "unexpected context: {:?}", ctx);
    assert_eq!(payload.actual(), 0);
    assert_eq!(payload.actual_shape(), &[] as &[usize]);
}
/// A leading (batch) dimension one past `i32::MAX` must give
/// `ArithmeticOverflow`, with the value reported as the `batch_dim` operand.
#[test]
fn batch_dim_overflow_errors() {
    let big = (i32::MAX as usize) + 1;
    let payload = match batch_dim(&[big, 3]).unwrap_err() {
        crate::error::Error::ArithmeticOverflow(payload) => payload,
        other => panic!("expected ArithmeticOverflow, got {other:?}"),
    };
    let ctx = payload.context();
    assert!(
        ctx.contains("batch dim exceeds i32::MAX"),
        "unexpected context: {:?}",
        ctx
    );
    assert_eq!(payload.op_type(), "i32");
    assert_eq!(payload.operands(), &[("batch_dim", big as u64)]);
}
/// `batch_dim` returns the leading dimension as `i32`.
#[test]
fn batch_dim_ok() {
    let shape: [usize; 3] = [7, 2, 5];
    let leading = batch_dim(&shape).unwrap();
    assert_eq!(leading, 7i32);
}
/// A negative "known" reshape dimension (`-1`) passed to `inferred_dim` must
/// be rejected with `OutOfRange`, and the offending value must be reported as
/// `"-1"`.
#[test]
fn inferred_dim_negative_known_dim_errors() {
    let payload = match inferred_dim(&[4], &[-1]).unwrap_err() {
        crate::error::Error::OutOfRange(payload) => payload,
        other => panic!("expected OutOfRange, got {other:?}"),
    };
    let ctx = payload.context();
    let requirement = payload.requirement();
    assert!(ctx.contains("known reshape dim"), "unexpected context: {:?}", ctx);
    assert!(
        requirement.contains("non-negative"),
        "unexpected requirement: {:?}",
        requirement
    );
    assert_eq!(payload.value(), "-1");
}
/// The product of the known dimensions `i32::MAX³` overflows `usize`. This
/// must give `ArithmeticOverflow` with a "divisor product" context and op type
/// `"usize"`.
#[test]
fn inferred_dim_divisor_product_overflow_errors() {
    let known = [i32::MAX; 3];
    match inferred_dim(&[1], &known).unwrap_err() {
        crate::error::Error::ArithmeticOverflow(payload) => {
            let ctx = payload.context();
            assert!(ctx.contains("divisor product"), "unexpected context: {:?}", ctx);
            assert_eq!(payload.op_type(), "usize");
        }
        other => panic!("expected ArithmeticOverflow, got {other:?}"),
    }
}
/// A zero known dimension would make the inferred dimension a division by
/// zero. `inferred_dim` must report `InvariantViolation` whose requirement
/// mentions "non-zero".
#[test]
fn inferred_dim_zero_divisor_errors() {
    let payload = match inferred_dim(&[4], &[0]).unwrap_err() {
        crate::error::Error::InvariantViolation(payload) => payload,
        other => panic!("expected InvariantViolation, got {other:?}"),
    };
    let ctx = payload.context();
    let requirement = payload.requirement();
    assert!(ctx.contains("reshape divisor"), "unexpected context: {:?}", ctx);
    assert!(
        requirement.contains("non-zero"),
        "unexpected requirement: {:?}",
        requirement
    );
}
/// 5 total elements cannot be split by a known dimension of 2. The result must
/// be `DivisibilityConstraint` naming `total_elements` and `divisor_per_slot`.
#[test]
fn inferred_dim_not_multiple_errors() {
    let payload = match inferred_dim(&[5], &[2]).unwrap_err() {
        crate::error::Error::DivisibilityConstraint(payload) => payload,
        other => panic!("expected DivisibilityConstraint, got {other:?}"),
    };
    assert_eq!(payload.name_dividend(), "total_elements");
    assert_eq!(payload.dividend(), 5);
    assert_eq!(payload.name_divisor(), "divisor_per_slot");
    assert_eq!(payload.divisor(), 2);
}
/// An inferred dimension one past `i32::MAX` must give `ArithmeticOverflow`
/// reporting the `inferred_dim` operand.
#[test]
fn inferred_dim_result_overflow_errors() {
    use crate::error::Error;

    let big = (i32::MAX as usize) + 1;
    match inferred_dim(&[big], &[1]).unwrap_err() {
        Error::ArithmeticOverflow(payload) => {
            let ctx = payload.context();
            assert!(
                ctx.contains("inferred dim exceeds i32::MAX"),
                "unexpected context: {:?}",
                ctx
            );
            assert_eq!(payload.op_type(), "i32");
            assert_eq!(payload.operands(), &[("inferred_dim", big as u64)]);
        }
        other => panic!("expected ArithmeticOverflow, got {other:?}"),
    }
}
/// A `[2, 12]` shape (24 elements) with known dims `[2, 3]` (product 6) infers
/// the remaining dimension as 4.
#[test]
fn inferred_dim_happy_path() {
    let shape: [usize; 2] = [2, 12];
    let known: [i32; 2] = [2, 3];
    let inferred = inferred_dim(&shape, &known).unwrap();
    assert_eq!(inferred, 4i32);
}