use super::*;
/// Builds the `(2, 3, 4)` weight tensor shared by the hand-traced `SwitchLinear` tests.
///
/// The first twelve values are passed for the first expert and the last twelve
/// for the second (assuming the usual row-major layout, which the expected
/// outputs in the tests are consistent with).
fn hand_traced_weight() -> Array {
    let values = [
        // First 3x4 slab: rows select input columns 0, 1, 2.
        1.0_f32, 0.0, 0.0, 0.0, //
        0.0, 1.0, 0.0, 0.0, //
        0.0, 0.0, 1.0, 0.0, //
        // Second 3x4 slab: rows select input columns 1, 2, 3.
        0.0, 1.0, 0.0, 0.0, //
        0.0, 0.0, 1.0, 0.0, //
        0.0, 0.0, 0.0, 1.0, //
    ];
    Array::from_slice::<f32>(&values, &(2, 3, 4)).unwrap()
}
/// Builds the `(2, 1, 4)` input shared by the hand-traced tests: the values 1 through 8.
fn hand_traced_input() -> Array {
    let values = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    Array::from_slice::<f32>(&values, &(2, 1, 4)).unwrap()
}
/// A bias-free `SwitchLinear` built from the `(2, 3, 4)` weight reports its
/// dimensions through the accessors and produces a `[2, 1, 3]` F32 output.
#[test]
fn switch_linear_shape_no_bias() {
    let layer = SwitchLinear::from_parts(hand_traced_weight(), None).unwrap();
    assert_eq!(layer.num_experts(), 2);
    assert_eq!(layer.output_dims(), 3);
    assert_eq!(layer.input_dims(), 4);

    let input = hand_traced_input();
    let expert_ids = Array::from_slice::<u32>(&[0, 1], &(2usize,)).unwrap();
    let result = layer.apply(&input, &expert_ids, false).unwrap();

    assert_eq!(result.shape(), vec![2, 1, 3]);
    assert_eq!(result.dtype().unwrap(), Dtype::F32);
}
/// Hand-traced values without bias: the first token goes through expert 0
/// and the second through expert 1.
#[test]
fn switch_linear_hand_traced_no_bias() {
    let input = hand_traced_input();
    let expert_ids = Array::from_slice::<u32>(&[0, 1], &(2,)).unwrap();
    let layer = SwitchLinear::from_parts(hand_traced_weight(), None).unwrap();

    let mut result = layer.apply(&input, &expert_ids, false).unwrap();
    let actual = result.to_vec::<f32>().unwrap();

    // Expert 0 picks columns 0..3 of [1, 2, 3, 4]; expert 1 picks columns 1..4 of [5, 6, 7, 8].
    let expected = vec![1.0_f32, 2.0, 3.0, 6.0, 7.0, 8.0];
    assert_eq!(actual, expected);
}
/// Hand-traced values with a `(2, 3)` bias: each output equals the bias-free
/// result plus the bias row of the expert that token was routed to.
#[test]
fn switch_linear_hand_traced_with_bias() {
    let bias_values = [10.0_f32, 20.0, 30.0, 40.0, 50.0, 60.0];
    let bias = Array::from_slice::<f32>(&bias_values, &(2, 3)).unwrap();
    let layer = SwitchLinear::from_parts(hand_traced_weight(), Some(bias)).unwrap();

    let input = hand_traced_input();
    let expert_ids = Array::from_slice::<u32>(&[0, 1], &(2,)).unwrap();
    let mut result = layer.apply(&input, &expert_ids, false).unwrap();
    let actual = result.to_vec::<f32>().unwrap();

    // [1, 2, 3] + [10, 20, 30] and [6, 7, 8] + [40, 50, 60].
    let expected = vec![11.0_f32, 22.0, 33.0, 46.0, 57.0, 68.0];
    assert_eq!(actual, expected);
}
/// When both tokens are routed to expert 0, both outputs are the first three
/// components of the corresponding input row.
#[test]
fn switch_linear_all_routed_to_one_expert_matches_plain_matmul() {
    let layer = SwitchLinear::from_parts(hand_traced_weight(), None).unwrap();
    let input = hand_traced_input();
    let expert_ids = Array::from_slice::<u32>(&[0, 0], &(2,)).unwrap();

    let mut result = layer.apply(&input, &expert_ids, false).unwrap();
    let actual = result.to_vec::<f32>().unwrap();

    let expected = vec![1.0_f32, 2.0, 3.0, 5.0, 6.0, 7.0];
    assert_eq!(actual, expected);
}
/// `apply` yields the same values whether its third argument is `true` or
/// `false` for the already-ascending index list `[0, 1]`.
#[test]
fn switch_linear_sorted_indices_matches_unsorted() {
    let layer = SwitchLinear::from_parts(hand_traced_weight(), None).unwrap();
    let input = hand_traced_input();
    let expert_ids = Array::from_slice::<u32>(&[0, 1], &(2,)).unwrap();

    let mut sorted_out = layer.apply(&input, &expert_ids, true).unwrap();
    let mut unsorted_out = layer.apply(&input, &expert_ids, false).unwrap();

    let sorted_values = sorted_out.to_vec::<f32>().unwrap();
    let unsorted_values = unsorted_out.to_vec::<f32>().unwrap();
    assert_eq!(sorted_values, unsorted_values);
}
/// `SwitchLinear::from_parts` must reject a rank-2 weight with `Error::RankMismatch`.
#[test]
fn switch_linear_from_parts_rejects_2d_weight() {
    let bad = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2, 2)).unwrap();
    let err = SwitchLinear::from_parts(bad, None).unwrap_err();
    // Report the actual variant on failure, as the other rejection tests in this file do.
    assert!(
        matches!(err, crate::Error::RankMismatch(_)),
        "expected RankMismatch on rank-2 weight, got {err:?}"
    );
}
/// `SwitchLinear::from_parts` must reject a rank-2 bias of shape `(3, 3)` paired
/// with the `(2, 3, 4)` weight, reporting `Error::ShapePairMismatch`.
#[test]
fn switch_linear_from_parts_rejects_mismatched_bias() {
    let weight = hand_traced_weight();
    let bad_bias =
        Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], &(3, 3)).unwrap();
    let err = SwitchLinear::from_parts(weight, Some(bad_bias)).unwrap_err();
    // Report the actual variant on failure, as the other rejection tests in this file do.
    assert!(
        matches!(err, crate::Error::ShapePairMismatch(_)),
        "expected ShapePairMismatch on mismatched bias shape, got {err:?}"
    );
}
/// A bias that is not rank 2 (here rank 1, then rank 3) must be rejected with
/// `Error::RankMismatch`, and the payload must carry the offending rank and shape.
#[test]
fn switch_linear_from_parts_rejects_rank_mismatch_bias() {
    let weight = hand_traced_weight();

    // Case 1: rank-1 bias.
    let rank1_bias = Array::from_slice::<f32>(&[1.0, 2.0], &(2usize,)).unwrap();
    let rank1_err =
        SwitchLinear::from_parts(weight.try_clone().unwrap(), Some(rank1_bias)).unwrap_err();
    let rank1_payload = match rank1_err {
        crate::Error::RankMismatch(payload) => payload,
        other => panic!("expected RankMismatch on rank-1 bias, got {other:?}"),
    };
    assert_eq!(rank1_payload.actual(), 1, "rank-1 bias ⇒ actual rank 1");
    assert_eq!(rank1_payload.actual_shape(), &[2usize]);

    // Case 2: rank-3 bias.
    let rank3_bias =
        Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &(2usize, 3usize, 1usize))
            .unwrap();
    let rank3_err = SwitchLinear::from_parts(weight, Some(rank3_bias)).unwrap_err();
    let rank3_payload = match rank3_err {
        crate::Error::RankMismatch(payload) => payload,
        other => panic!("expected RankMismatch on rank-3 bias, got {other:?}"),
    };
    assert_eq!(rank3_payload.actual(), 3, "rank-3 bias ⇒ actual rank 3");
    assert_eq!(rank3_payload.actual_shape(), &[2usize, 3, 1]);
}
/// With `(2, 2)` indices and a `(2, 2, 1, 4)` input, the output is `[2, 2, 1, 3]`
/// and every slot holds the result of the expert named by its index.
#[test]
fn switch_linear_top_k_routing_shape() {
    let layer = SwitchLinear::from_parts(hand_traced_weight(), None).unwrap();

    // Each token's row is repeated once per routing slot.
    let input_values = [
        1.0_f32, 2.0, 3.0, 4.0, //
        1.0, 2.0, 3.0, 4.0, //
        5.0, 6.0, 7.0, 8.0, //
        5.0, 6.0, 7.0, 8.0, //
    ];
    let input = Array::from_slice::<f32>(&input_values, &(2, 2, 1, 4)).unwrap();
    let expert_ids = Array::from_slice::<u32>(&[0, 1, 1, 0], &(2, 2)).unwrap();

    let mut result = layer.apply(&input, &expert_ids, false).unwrap();
    assert_eq!(result.shape(), vec![2, 2, 1, 3]);

    let actual = result.to_vec::<f32>().unwrap();
    let expected = vec![
        1.0_f32, 2.0, 3.0, // token 0, expert 0
        2.0, 3.0, 4.0, // token 0, expert 1
        6.0, 7.0, 8.0, // token 1, expert 1
        5.0, 6.0, 7.0, // token 1, expert 0
    ];
    assert_eq!(actual, expected);
}
/// Size of the last (input) axis of the fixtures built by `quant_dense_weight`
/// and `quant_input`.
const QUANT_INPUT_DIMS: usize = 64;
/// Builds a deterministic `(2, 4, QUANT_INPUT_DIMS)` F32 weight for the
/// quantization tests.
///
/// The element at `(e, o, i)` is `(e * 100 + o * 10 + i) * 0.001`.
fn quant_dense_weight() -> Array {
    let (experts, outputs, inputs) = (2usize, 4usize, QUANT_INPUT_DIMS);
    let values: Vec<f32> = (0..experts)
        .flat_map(|e| (0..outputs).flat_map(move |o| (0..inputs).map(move |i| (e, o, i))))
        .map(|(e, o, i)| ((e * 100 + o * 10 + i) as f32) * 0.001)
        .collect();
    Array::from_slice::<f32>(&values, &(experts, outputs, inputs)).unwrap()
}
/// Builds a deterministic `(2, 1, QUANT_INPUT_DIMS)` F32 input for the
/// quantization tests.
///
/// The element for token `t`, column `c` is `(t * 50 + c) * 0.01`.
fn quant_input() -> Array {
    let (tokens, inputs) = (2usize, QUANT_INPUT_DIMS);
    let values: Vec<f32> = (0..tokens)
        .flat_map(|t| (0..inputs).map(move |c| ((t * 50 + c) as f32) * 0.01))
        .collect();
    Array::from_slice::<f32>(&values, &(tokens, 1usize, inputs)).unwrap()
}
/// The 4-bit affine-quantized layer must track the dense layer element-wise
/// within `0.1 * max|dense| + 1e-3`.
#[test]
fn quantized_switch_linear_parity_within_quant_error() {
    let dense_w = quant_dense_weight();
    let dense_layer = SwitchLinear::from_parts(dense_w.try_clone().unwrap(), None).unwrap();

    let (w_q, scales, q_biases) = quantized::quantize(&dense_w, 64, 4, "affine", None).unwrap();
    assert!(
        q_biases.is_some(),
        "affine scheme produces per-group biases"
    );
    let q_layer =
        QuantizedSwitchLinear::from_parts(w_q, scales, q_biases, None, 64, 4, "affine").unwrap();

    let input = quant_input();
    let expert_ids = Array::from_slice::<u32>(&[0, 1], &(2,)).unwrap();
    let mut dense_out = dense_layer.apply(&input, &expert_ids, false).unwrap();
    let mut quant_out = q_layer.apply(&input, &expert_ids, false).unwrap();
    assert_eq!(dense_out.shape(), quant_out.shape());

    let dense = dense_out.to_vec::<f32>().unwrap();
    let quant = quant_out.to_vec::<f32>().unwrap();

    // The tolerance scales with the largest dense magnitude.
    let max_abs = dense.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
    let tolerance = 0.1 * max_abs + 1e-3;
    for (d, q) in dense.iter().zip(quant.iter()) {
        let drift = (d - q).abs();
        assert!(
            drift <= tolerance,
            "quantized SwitchLinear drift too large: dense={d} quant={q}"
        );
    }
}
/// `QuantizedSwitchLinear::from_parts` must reject a rank-2 bias of shape
/// `(2, 1)` for a layer quantized from the `(2, 4, 64)` weight, reporting
/// `Error::ShapePairMismatch`.
#[test]
fn quantized_switch_linear_from_parts_rejects_mismatched_bias() {
    let dense_w = quant_dense_weight();
    let (w_q, scales, q_biases) = quantized::quantize(&dense_w, 64, 4, "affine", None).unwrap();
    let bad_bias = Array::from_slice::<f32>(&[1.0, 2.0], &(2, 1)).unwrap();
    let err =
        QuantizedSwitchLinear::from_parts(w_q, scales, q_biases, Some(bad_bias), 64, 4, "affine")
            .unwrap_err();
    // Report the actual variant on failure, as the other rejection tests in this file do.
    assert!(
        matches!(err, crate::Error::ShapePairMismatch(_)),
        "expected ShapePairMismatch on mismatched bias shape, got {err:?}"
    );
}
/// A rank-1 bias must be rejected with `Error::RankMismatch` whose payload
/// reports rank 1 and shape `[2]`.
#[test]
fn quantized_switch_linear_from_parts_rejects_rank_mismatch_bias() {
    let dense_w = quant_dense_weight();
    let (w_q, scales, q_biases) = quantized::quantize(&dense_w, 64, 4, "affine", None).unwrap();
    let rank1_bias = Array::from_slice::<f32>(&[1.0, 2.0], &(2usize,)).unwrap();

    let err =
        QuantizedSwitchLinear::from_parts(w_q, scales, q_biases, Some(rank1_bias), 64, 4, "affine")
            .unwrap_err();
    let payload = match err {
        crate::Error::RankMismatch(payload) => payload,
        other => panic!("expected RankMismatch on rank-1 bias, got {other:?}"),
    };
    assert_eq!(payload.actual(), 1, "rank-1 bias ⇒ actual rank 1");
    assert_eq!(payload.actual_shape(), &[2usize]);
}
/// Rank-2 `quant_biases` must be rejected with `Error::RankMismatch` whose
/// payload reports rank 2 and shape `[2, 4]`.
#[test]
fn quantized_switch_linear_from_parts_rejects_quant_biases_rank_mismatch() {
    let dense_w = quant_dense_weight();
    let (w_q, scales, _q_biases) = quantized::quantize(&dense_w, 64, 4, "affine", None).unwrap();

    // Eight zeros shaped `(2, 4)`: one axis short of the quantizer's own output.
    let rank2_qb = Array::from_slice::<f32>(&[0.0_f32; 8], &(2usize, 4usize)).unwrap();

    let err = QuantizedSwitchLinear::from_parts(w_q, scales, Some(rank2_qb), None, 64, 4, "affine")
        .unwrap_err();
    let payload = match err {
        crate::Error::RankMismatch(payload) => payload,
        other => panic!("expected RankMismatch on rank-2 quant_biases, got {other:?}"),
    };
    assert_eq!(payload.actual(), 2, "rank-2 quant_biases ⇒ actual rank 2");
    assert_eq!(payload.actual_shape(), &[2usize, 4]);
}
/// Same parity check as the bias-free variant, with an identical `(2, 4)` bias
/// attached to both the dense and the quantized layer.
#[test]
fn quantized_switch_linear_with_bias_parity_within_quant_error() {
    let dense_w = quant_dense_weight();
    let bias_values = [10.0_f32, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0];
    let bias = Array::from_slice::<f32>(&bias_values, &(2, 4)).unwrap();

    let dense_layer = SwitchLinear::from_parts(
        dense_w.try_clone().unwrap(),
        Some(bias.try_clone().unwrap()),
    )
    .unwrap();
    let (w_q, scales, q_biases) = quantized::quantize(&dense_w, 64, 4, "affine", None).unwrap();
    let q_layer =
        QuantizedSwitchLinear::from_parts(w_q, scales, q_biases, Some(bias), 64, 4, "affine")
            .unwrap();

    let input = quant_input();
    let expert_ids = Array::from_slice::<u32>(&[0, 1], &(2,)).unwrap();
    let mut dense_out = dense_layer.apply(&input, &expert_ids, false).unwrap();
    let mut quant_out = q_layer.apply(&input, &expert_ids, false).unwrap();
    assert_eq!(dense_out.shape(), quant_out.shape());

    let dense = dense_out.to_vec::<f32>().unwrap();
    let quant = quant_out.to_vec::<f32>().unwrap();

    // The tolerance scales with the largest dense magnitude.
    let max_abs = dense.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
    let tolerance = 0.1 * max_abs + 1e-3;
    for (d, q) in dense.iter().zip(quant.iter()) {
        let drift = (d - q).abs();
        assert!(
            drift <= tolerance,
            "quantized SwitchLinear (with bias) drift too large: dense={d} quant={q}"
        );
    }
}
/// Quantizes the shared dense weight with group size 32, 4 bits and mode
/// `"mxfp4"`, returning exactly what `quantized::quantize` returns.
fn quant_mxfp4_triple() -> (Array, Array, Option<Array>) {
    let dense = quant_dense_weight();
    let triple = quantized::quantize(&dense, 32, 4, "mxfp4", None);
    triple.unwrap()
}
/// Scales shaped `(3, 4, 1)` paired with a weight quantized from the two-expert
/// dense tensor must be rejected with `Error::ShapePairMismatch`.
#[test]
fn quantized_switch_linear_from_parts_rejects_mismatched_scales_leading_dims() {
    let dense_w = quant_dense_weight();
    let (w_q, _scales, _q_biases) = quantized::quantize(&dense_w, 64, 4, "affine", None).unwrap();

    // Twelve ones: the leading axis is 3 instead of 2.
    let bad_scales =
        Array::from_slice::<f32>(&[1.0_f32; 12], &(3usize, 4usize, 1usize)).unwrap();

    let err =
        QuantizedSwitchLinear::from_parts(w_q, bad_scales, None, None, 64, 4, "mxfp4").unwrap_err();
    let is_shape_pair_mismatch = matches!(err, crate::Error::ShapePairMismatch(_));
    assert!(
        is_shape_pair_mismatch,
        "expected ShapePairMismatch on scales leading dims, got {err:?}"
    );
}
/// An F32 (non-packed) weight must be rejected with `Error::InvariantViolation`
/// whose context mentions "weight dtype" or whose requirement mentions "uint32".
#[test]
fn quantized_switch_linear_from_parts_rejects_non_u32_weight() {
    let dense_weight =
        Array::from_slice::<f32>(&vec![0.5f32; 2 * 4 * 8], &(2usize, 4usize, 8usize)).unwrap();
    let scales =
        Array::from_slice::<f32>(&vec![1.0f32; 2 * 4], &(2usize, 4usize, 1usize)).unwrap();
    let quant_biases =
        Array::from_slice::<f32>(&vec![0.0f32; 2 * 4], &(2usize, 4usize, 1usize)).unwrap();

    let err = QuantizedSwitchLinear::from_parts(
        dense_weight,
        scales,
        Some(quant_biases),
        None,
        64,
        4,
        "affine",
    )
    .unwrap_err();

    let payload = match &err {
        crate::Error::InvariantViolation(payload) => payload,
        other => panic!("expected InvariantViolation naming the dtype invariant, got {other:?}"),
    };
    let names_dtype_invariant =
        payload.context().contains("weight dtype") || payload.requirement().contains("uint32");
    assert!(
        names_dtype_invariant,
        "InvariantViolation context/requirement should name the dtype invariant, got context={:?} requirement={:?}",
        payload.context(),
        payload.requirement()
    );
}
/// `quant_biases` shaped `(2, 4, 1)` paired with scales produced at group size
/// 32 must be rejected with `Error::ShapePairMismatch`.
#[test]
fn quantized_switch_linear_from_parts_rejects_quant_biases_shape_mismatch() {
    let dense_w = quant_dense_weight();
    let (w_q, scales, _q_biases) = quantized::quantize(&dense_w, 32, 4, "affine", None).unwrap();

    let bad_qb_values = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let bad_qb =
        Array::from_slice::<f32>(&bad_qb_values, &(2usize, 4usize, 1usize)).unwrap();

    let err = QuantizedSwitchLinear::from_parts(w_q, scales, Some(bad_qb), None, 32, 4, "affine")
        .unwrap_err();
    let is_shape_pair_mismatch = matches!(err, crate::Error::ShapePairMismatch(_));
    assert!(
        is_shape_pair_mismatch,
        "expected ShapePairMismatch on quant_biases shape, got {err:?}"
    );
}
/// Mode `"affine"` without `quant_biases` must be rejected with
/// `Error::InvariantViolation`.
#[test]
fn quantized_switch_linear_from_parts_affine_requires_quant_biases() {
    let dense_w = quant_dense_weight();
    let (w_q, scales, _q_biases) = quantized::quantize(&dense_w, 64, 4, "affine", None).unwrap();

    let outcome = QuantizedSwitchLinear::from_parts(w_q, scales, None, None, 64, 4, "affine");
    let err = outcome.unwrap_err();

    let is_invariant_violation = matches!(err, crate::Error::InvariantViolation(_));
    assert!(
        is_invariant_violation,
        "expected InvariantViolation on affine-missing-quant_biases, got {err:?}"
    );
}
/// Mode `"mxfp4"` supplied with `quant_biases` (shaped like the scales) must be
/// rejected with `Error::InvariantViolation`.
#[test]
fn quantized_switch_linear_from_parts_mxfp4_forbids_quant_biases() {
    let (w_q, scales, _none_qb) = quant_mxfp4_triple();

    // Use the scales' own last-axis length so that only the presence of the
    // biases, not their shape, differs from a valid call.
    let n_groups = scales.shape()[2];
    let zeros = vec![0.0f32; 2 * 4 * n_groups];
    let stale_qb = Array::from_slice::<f32>(&zeros, &(2usize, 4usize, n_groups)).unwrap();

    let err = QuantizedSwitchLinear::from_parts(w_q, scales, Some(stale_qb), None, 32, 4, "mxfp4")
        .unwrap_err();
    let is_invariant_violation = matches!(err, crate::Error::InvariantViolation(_));
    assert!(
        is_invariant_violation,
        "expected InvariantViolation on mxfp4-with-stale-quant_biases, got {err:?}"
    );
}
/// A mode string of `"unknown"` must be rejected with `Error::UnknownEnumValue`.
#[test]
fn quantized_switch_linear_from_parts_unknown_mode() {
    let dense_w = quant_dense_weight();
    let (w_q, scales, q_biases) = quantized::quantize(&dense_w, 64, 4, "affine", None).unwrap();

    let outcome = QuantizedSwitchLinear::from_parts(w_q, scales, q_biases, None, 64, 4, "unknown");
    let err = outcome.unwrap_err();

    let is_unknown_enum_value = matches!(err, crate::Error::UnknownEnumValue(_));
    assert!(
        is_unknown_enum_value,
        "expected UnknownEnumValue on unknown mode, got {err:?}"
    );
}
/// `bits == 0` and `group_size == 0` must each be rejected with `Error::OutOfRange`.
#[test]
fn quantized_switch_linear_from_parts_zero_bits_or_group_size() {
    let dense_w = quant_dense_weight();
    let (w_q, scales, q_biases) = quantized::quantize(&dense_w, 64, 4, "affine", None).unwrap();

    // First call works on clones so the originals remain available for the second call.
    let w_q_copy = w_q.try_clone().unwrap();
    let scales_copy = scales.try_clone().unwrap();
    let q_biases_copy = q_biases.as_ref().map(|q| q.try_clone().unwrap());
    let err_bits = QuantizedSwitchLinear::from_parts(
        w_q_copy,
        scales_copy,
        q_biases_copy,
        None,
        64,
        0,
        "affine",
    )
    .unwrap_err();
    assert!(
        matches!(err_bits, crate::Error::OutOfRange(_)),
        "expected OutOfRange on bits=0, got {err_bits:?}"
    );

    let err_gs = QuantizedSwitchLinear::from_parts(w_q, scales, q_biases, None, 0, 4, "affine")
        .unwrap_err();
    assert!(
        matches!(err_gs, crate::Error::OutOfRange(_)),
        "expected OutOfRange on group_size=0, got {err_gs:?}"
    );
}
/// Mode `"mxfp4"` with scales only (no `quant_biases`) constructs successfully
/// and the accessors reflect what was passed in.
#[test]
fn quantized_switch_linear_from_parts_mxfp4_scales_only_ok() {
    let (w_q, scales, none_qb) = quant_mxfp4_triple();
    assert!(none_qb.is_none(), "mxfp4 quantize must yield None biases");

    let layer =
        QuantizedSwitchLinear::from_parts(w_q, scales, None, None, 32, 4, "mxfp4").unwrap();

    assert_eq!(layer.weight_ref().shape()[0], 2);
    assert_eq!(layer.weight_ref().shape()[1], 4);
    assert!(layer.quant_biases().is_none());
    assert_eq!(layer.mode(), "mxfp4");
}
/// `weight_ref` and `bias` expose the constructor inputs, with and without a bias.
#[test]
fn switch_linear_fields_are_read_only_via_accessors() {
    // Without bias.
    let plain = SwitchLinear::from_parts(hand_traced_weight(), None).unwrap();
    assert_eq!(plain.weight_ref().shape(), vec![2, 3, 4]);
    assert!(plain.bias().is_none());

    // With a `(2, 3)` bias.
    let bias_values = [10.0_f32, 20.0, 30.0, 40.0, 50.0, 60.0];
    let bias = Array::from_slice::<f32>(&bias_values, &(2, 3)).unwrap();
    let biased = SwitchLinear::from_parts(hand_traced_weight(), Some(bias)).unwrap();
    assert_eq!(biased.weight_ref().shape(), vec![2, 3, 4]);
    assert_eq!(biased.bias().unwrap().shape(), vec![2, 3]);
}
/// The quantized layer's accessors expose the constructor inputs: leading
/// weight/scales dims, optional arrays, group size, bit width and mode.
#[test]
fn quantized_switch_linear_fields_are_read_only_via_accessors() {
    let dense_w = quant_dense_weight();
    let (w_q, scales, q_biases) = quantized::quantize(&dense_w, 64, 4, "affine", None).unwrap();
    let bias_values = [10.0_f32, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0];
    let bias = Array::from_slice::<f32>(&bias_values, &(2, 4)).unwrap();

    let q_layer =
        QuantizedSwitchLinear::from_parts(w_q, scales, q_biases, Some(bias), 64, 4, "affine")
            .unwrap();

    assert_eq!(q_layer.weight_ref().shape()[0], 2);
    assert_eq!(q_layer.weight_ref().shape()[1], 4);
    assert_eq!(q_layer.scales_ref().shape()[0], 2);
    assert!(q_layer.quant_biases().is_some());
    assert_eq!(q_layer.bias().unwrap().shape(), vec![2, 4]);
    assert_eq!(q_layer.group_size(), 64);
    assert_eq!(q_layer.bits(), 4);
    assert_eq!(q_layer.mode(), "affine");
}
/// Scalar reference logistic sigmoid: `1 / (1 + e^(-v))`.
fn sigmoid_ref(v: f32) -> f32 {
    let denominator = 1.0 + (-v).exp();
    denominator.recip()
}
/// Scalar reference SiLU: `v * sigmoid(v)`.
fn silu_ref(v: f32) -> f32 {
    let gate = sigmoid_ref(v);
    gate * v
}
/// Asserts that `got` and `want` have the same length and agree element-wise
/// within `1e-5 + 1e-5 * |want|`.
///
/// Non-finite elements (infinities, NaN) are compared by exact equality instead
/// of by tolerance: with an infinite `want` the tolerance itself is infinite,
/// so any finite `got` would pass, and `inf - inf` is NaN, so two equal
/// infinities would fail. NaN never compares equal and is always a mismatch.
///
/// # Panics
///
/// Panics with a "length mismatch" message when the slices differ in length,
/// and with a "block output mismatch" message on the first differing element.
fn assert_close(got: &[f32], want: &[f32]) {
    assert_eq!(
        got.len(),
        want.len(),
        "length mismatch: {got:?} vs {want:?}"
    );
    for (g, w) in got.iter().zip(want.iter()) {
        let close = if g.is_finite() && w.is_finite() {
            (g - w).abs() <= 1e-5 + 1e-5 * w.abs()
        } else {
            // The tolerance arithmetic is meaningless for inf/NaN.
            g == w
        };
        assert!(
            close,
            "block output mismatch: got {g}, want {w} (full got {got:?}, want {want:?})"
        );
    }
}
/// Builds a `(2, 2, 2)` weight: a 2x2 identity followed by a 2x2 matrix that
/// swaps the two components.
fn identity_then_swap_weight() -> Array {
    let values = [
        // identity
        1.0_f32, 0.0, //
        0.0, 1.0, //
        // swap
        0.0, 1.0, //
        1.0, 0.0, //
    ];
    Array::from_slice::<f32>(&values, &(2, 2, 2)).unwrap()
}
/// Builds a `(2, 2, 2)` weight holding two 2x2 identity matrices.
fn all_identity_weight() -> Array {
    let values = [
        1.0_f32, 0.0, //
        0.0, 1.0, //
        1.0, 0.0, //
        0.0, 1.0, //
    ];
    Array::from_slice::<f32>(&values, &(2, 2, 2)).unwrap()
}
/// Hand-traced `SwitchGLU` forward pass: the gate projection is identity for
/// expert 0 and a swap for expert 1, while up/down projections are identity,
/// so each output is `silu(gate) * up`.
#[test]
fn switch_glu_hand_traced_two_experts() {
    let glu = SwitchGLU::new(
        SwitchLinear::from_parts(identity_then_swap_weight(), None).unwrap(),
        SwitchLinear::from_parts(all_identity_weight(), None).unwrap(),
        SwitchLinear::from_parts(all_identity_weight(), None).unwrap(),
        SwitchGLU::default_activation(),
    )
    .unwrap();

    let input = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2, 2)).unwrap();
    let expert_ids = Array::from_slice::<u32>(&[0, 1], &(2, 1)).unwrap();
    let mut result = glu.forward(&input, &expert_ids).unwrap();
    assert_eq!(result.shape(), vec![2, 1, 2]);

    let actual = result.to_vec::<f32>().unwrap();
    let expected = vec![
        // Token 0 via expert 0: gate = [1, 2], up = [1, 2].
        silu_ref(1.0) * 1.0,
        silu_ref(2.0) * 2.0,
        // Token 1 via expert 1: gate = [4, 3] (swapped), up = [3, 4].
        silu_ref(4.0) * 3.0,
        silu_ref(3.0) * 4.0,
    ];
    assert_close(&actual, &expected);
}
/// Routing both tokens to expert 1 applies the swapped gate to both.
#[test]
fn switch_glu_routing_selects_the_indexed_expert() {
    let gate_proj = SwitchLinear::from_parts(identity_then_swap_weight(), None).unwrap();
    let up_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let down_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let glu = SwitchGLU::new(
        gate_proj,
        up_proj,
        down_proj,
        SwitchGLU::default_activation(),
    )
    .unwrap();

    let input = Array::from_slice::<f32>(&[1.0, 2.0, 5.0, 6.0], &(2, 2)).unwrap();
    let expert_ids = Array::from_slice::<u32>(&[1, 1], &(2, 1)).unwrap();
    let mut result = glu.forward(&input, &expert_ids).unwrap();

    let actual = result.to_vec::<f32>().unwrap();
    let expected = vec![
        // Token 0: gate = [2, 1], up = [1, 2].
        silu_ref(2.0) * 1.0,
        silu_ref(1.0) * 2.0,
        // Token 1: gate = [6, 5], up = [5, 6].
        silu_ref(6.0) * 5.0,
        silu_ref(5.0) * 6.0,
    ];
    assert_close(&actual, &expected);
}
/// `SwitchGLU::forward` on 64 tokens with alternating experts and `[N, 1]`
/// indices matches the hand trace. The test asserts at least 64 indices so
/// that, per its own message, the sorted path is exercised.
#[test]
fn switch_glu_sorted_path_matches_hand_trace() {
    let gate_proj = SwitchLinear::from_parts(identity_then_swap_weight(), None).unwrap();
    let up_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let down_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let glu = SwitchGLU::new(
        gate_proj,
        up_proj,
        down_proj,
        SwitchGLU::default_activation(),
    )
    .unwrap();

    // Token `t` is `[t, t + 1]` and goes to expert `t % 2`.
    let n = 64usize;
    let x_data: Vec<f32> = (0..n)
        .flat_map(|t| [t as f32, t as f32 + 1.0])
        .collect();
    let idx_data: Vec<u32> = (0..n).map(|t| (t % 2) as u32).collect();

    let x = Array::from_slice::<f32>(&x_data, &(n, 2usize)).unwrap();
    let indices = Array::from_slice::<u32>(&idx_data, &(n, 1usize)).unwrap();
    assert!(indices.size() >= 64, "test must exercise the sorted path");

    let mut out = glu.forward(&x, &indices).unwrap();
    assert_eq!(out.shape(), vec![n, 1, 2]);
    let got = out.to_vec::<f32>().unwrap();

    let want: Vec<f32> = (0..n)
        .flat_map(|t| {
            let (x0, x1) = (t as f32, t as f32 + 1.0);
            if t % 2 == 0 {
                // Expert 0: identity gate.
                [silu_ref(x0) * x0, silu_ref(x1) * x1]
            } else {
                // Expert 1: swapped gate.
                [silu_ref(x1) * x0, silu_ref(x0) * x1]
            }
        })
        .collect();
    assert_close(&got, &want);
}
/// `SwitchGLU::new` must reject a down projection shaped `(2, 3, 2)` next to
/// `(2, 2, 2)` gate/up projections with `Error::ShapePairMismatch`.
#[test]
fn switch_glu_new_rejects_mismatched_projection_shapes() {
    let bad_down_weight =
        Array::from_slice::<f32>(&[0.0f32; 2 * 3 * 2], &(2usize, 3usize, 2usize)).unwrap();

    let gate_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let up_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let down_proj = SwitchLinear::from_parts(bad_down_weight, None).unwrap();

    let err = SwitchGLU::new(
        gate_proj,
        up_proj,
        down_proj,
        SwitchGLU::default_activation(),
    )
    .unwrap_err();
    let is_shape_pair_mismatch = matches!(err, crate::Error::ShapePairMismatch(_));
    assert!(
        is_shape_pair_mismatch,
        "expected ShapePairMismatch on mismatched down_proj, got {err:?}"
    );
}
/// `SwitchGLU::new` must reject a down projection with 3 experts next to
/// 2-expert gate/up projections with `Error::ShapePairMismatch`.
#[test]
fn switch_glu_new_rejects_mismatched_num_experts() {
    let down_e3 =
        Array::from_slice::<f32>(&[1.0f32; 3 * 2 * 2], &(3usize, 2usize, 2usize)).unwrap();

    let gate_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let up_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let down_proj = SwitchLinear::from_parts(down_e3, None).unwrap();

    let err = SwitchGLU::new(
        gate_proj,
        up_proj,
        down_proj,
        SwitchGLU::default_activation(),
    )
    .unwrap_err();
    let is_shape_pair_mismatch = matches!(err, crate::Error::ShapePairMismatch(_));
    assert!(
        is_shape_pair_mismatch,
        "expected ShapePairMismatch on mismatched num_experts, got {err:?}"
    );
}
/// Hand-traced `SwitchMLP` forward pass with a squaring activation: `fc1` is
/// identity for expert 0 and a swap for expert 1, `fc2` is identity.
#[test]
fn switch_mlp_hand_traced_two_experts() {
    let square: Activation = Box::new(|a: &Array| a.multiply(a));
    let mlp = SwitchMLP::new(
        SwitchLinear::from_parts(identity_then_swap_weight(), None).unwrap(),
        SwitchLinear::from_parts(all_identity_weight(), None).unwrap(),
        square,
    )
    .unwrap();

    let input = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2, 2)).unwrap();
    let expert_ids = Array::from_slice::<u32>(&[0, 1], &(2, 1)).unwrap();
    let mut result = mlp.forward(&input, &expert_ids).unwrap();
    assert_eq!(result.shape(), vec![2, 1, 2]);

    let actual = result.to_vec::<f32>().unwrap();
    // Token 0: [1, 2] squared; token 1: [3, 4] swapped to [4, 3], then squared.
    let expected = vec![1.0_f32, 4.0, 16.0, 9.0];
    assert_eq!(actual, expected);
}
/// With identity projections, `SwitchMLP` using its default activation must
/// reproduce `activations::gelu_approx` applied directly to the input.
#[test]
fn switch_mlp_default_activation_is_gelu_approx() {
    let fc1 = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let fc2 = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let mlp = SwitchMLP::new(fc1, fc2, SwitchMLP::default_activation()).unwrap();

    let input = Array::from_slice::<f32>(&[-1.0, 0.5, 1.0, 2.0], &(2, 2)).unwrap();
    let expert_ids = Array::from_slice::<u32>(&[0, 1], &(2, 1)).unwrap();
    let mut result = mlp.forward(&input, &expert_ids).unwrap();
    let actual = result.to_vec::<f32>().unwrap();

    let mut reference = super::super::activations::gelu_approx(&input).unwrap();
    let expected = reference.to_vec::<f32>().unwrap();
    assert_close(&actual, &expected);
}
/// With F16 weights and an F16 input, the `SwitchMLP` output must also be F16.
#[test]
fn switch_mlp_forward_preserves_f16_dtype() {
    let w16 = all_identity_weight().astype(Dtype::F16).unwrap();
    let fc1 = SwitchLinear::from_parts(w16.try_clone().unwrap(), None).unwrap();
    let fc2 = SwitchLinear::from_parts(w16, None).unwrap();
    let mlp = SwitchMLP::new(fc1, fc2, SwitchMLP::default_activation()).unwrap();

    let input_f32 = Array::from_slice::<f32>(&[-1.0, 0.5, 1.0, 2.0], &(2, 2)).unwrap();
    let input = input_f32.astype(Dtype::F16).unwrap();
    let expert_ids = Array::from_slice::<u32>(&[0, 1], &(2, 1)).unwrap();

    let result = mlp.forward(&input, &expert_ids).unwrap();
    assert_eq!(
        result.dtype().unwrap(),
        Dtype::F16,
        "SwitchMLP default forward must preserve the F16 input dtype"
    );
}
/// `SwitchMLP::forward` on 64 tokens with alternating experts and `[N, 1]`
/// indices matches the hand trace. The test asserts at least 64 indices so
/// that, per its own message, the sorted path is exercised.
#[test]
fn switch_mlp_sorted_path_matches_hand_trace() {
    let square: Activation = Box::new(|a: &Array| a.multiply(a));
    let fc1 = SwitchLinear::from_parts(identity_then_swap_weight(), None).unwrap();
    let fc2 = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let mlp = SwitchMLP::new(fc1, fc2, square).unwrap();

    // Token `t` is `[t, t + 1]` and goes to expert `t % 2`.
    let n = 64usize;
    let x_data: Vec<f32> = (0..n)
        .flat_map(|t| [t as f32, t as f32 + 1.0])
        .collect();
    let idx_data: Vec<u32> = (0..n).map(|t| (t % 2) as u32).collect();

    let x = Array::from_slice::<f32>(&x_data, &(n, 2usize)).unwrap();
    let indices = Array::from_slice::<u32>(&idx_data, &(n, 1usize)).unwrap();
    assert!(indices.size() >= 64, "test must exercise the sorted path");

    let mut out = mlp.forward(&x, &indices).unwrap();
    assert_eq!(out.shape(), vec![n, 1, 2]);
    let got = out.to_vec::<f32>().unwrap();

    let want: Vec<f32> = (0..n)
        .flat_map(|t| {
            let (x0, x1) = (t as f32, t as f32 + 1.0);
            if t % 2 == 0 {
                // Expert 0: identity, then square.
                [x0 * x0, x1 * x1]
            } else {
                // Expert 1: swap, then square.
                [x1 * x1, x0 * x0]
            }
        })
        .collect();
    assert_close(&got, &want);
}
/// `SwitchMLP::new` must reject an `fc2` shaped `(2, 3, 2)` next to a
/// `(2, 2, 2)` `fc1` with `Error::ShapePairMismatch`.
#[test]
fn switch_mlp_new_rejects_mismatched_projection_shapes() {
    let bad_fc2 =
        Array::from_slice::<f32>(&[0.0f32; 2 * 3 * 2], &(2usize, 3usize, 2usize)).unwrap();

    let fc1 = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let fc2 = SwitchLinear::from_parts(bad_fc2, None).unwrap();

    let err = SwitchMLP::new(fc1, fc2, SwitchMLP::default_activation()).unwrap_err();
    let is_shape_pair_mismatch = matches!(err, crate::Error::ShapePairMismatch(_));
    assert!(
        is_shape_pair_mismatch,
        "expected ShapePairMismatch on mismatched fc2, got {err:?}"
    );
}
/// `gather_sort` must return the expert ids in ascending order, and feeding
/// those ids back through `scatter_unsort` with the returned inverse order
/// must restore the original `[3, 2]` layout.
#[test]
fn gather_sort_then_scatter_unsort_round_trips() {
    let indices = Array::from_slice::<u32>(&[2, 0, 1, 1, 0, 2], &(3, 2)).unwrap();
    let x_values = (0..6).map(|i| i as f32).collect::<Vec<_>>();
    let x =
        Array::from_slice::<f32>(&x_values, &(3usize, 2usize, 1usize, 1usize)).unwrap();
    let x_expanded = shape::expand_dims_axes(&x, &[-1]).unwrap();

    let (_x_sorted, mut idx_sorted, inv_order) = gather_sort(&x_expanded, &indices).unwrap();

    // The sorted ids must be the ascending permutation of the input ids.
    let sorted_ids = idx_sorted.to_vec::<u32>().unwrap();
    let expected_sorted = {
        let mut ids = vec![2u32, 0, 1, 1, 0, 2];
        ids.sort_unstable();
        ids
    };
    assert_eq!(sorted_ids, expected_sorted);

    // Un-sorting the sorted ids themselves must give back the original routing table.
    let idx_as_rows = shape::expand_dims_axes(&idx_sorted, &[-1]).unwrap();
    let mut restored = scatter_unsort(&idx_as_rows, &inv_order, &[3, 2]).unwrap();
    assert_eq!(restored.shape(), vec![3, 2, 1]);
    let restored_flat = restored.to_vec::<u32>().unwrap();
    assert_eq!(restored_flat, vec![2, 0, 1, 1, 0, 2]);
}
/// Rank-1 `[N]` indices (64 of them) must be rejected by `SwitchGLU::forward`
/// with `Error::RankMismatch` reporting rank 1 and shape `[64]`.
#[test]
fn switch_glu_sorted_path_rejects_ambiguous_flat_indices() {
    let gate_proj = SwitchLinear::from_parts(identity_then_swap_weight(), None).unwrap();
    let up_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let down_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let glu = SwitchGLU::new(
        gate_proj,
        up_proj,
        down_proj,
        SwitchGLU::default_activation(),
    )
    .unwrap();

    let n = 64usize;
    let x_data: Vec<f32> = (0..n)
        .flat_map(|t| [t as f32, t as f32 + 1.0])
        .collect();
    let idx_data: Vec<u32> = (0..n).map(|t| (t % 2) as u32).collect();

    let x = Array::from_slice::<f32>(&x_data, &(n, 2usize)).unwrap();
    // No explicit trailing routing axis: shape is `[N]`, not `[N, 1]`.
    let indices = Array::from_slice::<u32>(&idx_data, &(n,)).unwrap();
    assert!(indices.size() >= 64, "test must exercise the sorted path");

    let err = glu.forward(&x, &indices).unwrap_err();
    let payload = match err {
        crate::Error::RankMismatch(payload) => payload,
        other => panic!("expected RankMismatch on ambiguous [N] indices, got {other:?}"),
    };
    assert_eq!(payload.actual(), 1, "rank-1 indices ⇒ actual rank 1");
    assert_eq!(payload.actual_shape(), &[64usize]);
}
/// `[B, S]` indices paired with a `[B, S, 2]` input must be rejected by
/// `SwitchGLU::forward` with `Error::RankMismatch` reporting rank 2 and shape `[8, 8]`.
#[test]
fn switch_glu_sorted_path_rejects_ambiguous_batch_indices() {
    let gate_proj = SwitchLinear::from_parts(identity_then_swap_weight(), None).unwrap();
    let up_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let down_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let glu = SwitchGLU::new(
        gate_proj,
        up_proj,
        down_proj,
        SwitchGLU::default_activation(),
    )
    .unwrap();

    let (b, s) = (8usize, 8usize);
    let tokens = b * s;
    let x_data: Vec<f32> = (0..tokens)
        .flat_map(|t| [t as f32, t as f32 + 1.0])
        .collect();
    let idx_data: Vec<u32> = (0..tokens).map(|t| (t % 2) as u32).collect();

    let x = Array::from_slice::<f32>(&x_data, &(b, s, 2usize)).unwrap();
    // No explicit trailing routing axis: shape is `[B, S]`, not `[B, S, 1]`.
    let indices = Array::from_slice::<u32>(&idx_data, &(b, s)).unwrap();
    assert!(indices.size() >= 64, "test must exercise the sorted path");

    let err = glu.forward(&x, &indices).unwrap_err();
    let payload = match err {
        crate::Error::RankMismatch(payload) => payload,
        other => panic!("expected RankMismatch on ambiguous [B, S] indices, got {other:?}"),
    };
    assert_eq!(payload.actual(), 2, "rank-2 indices ⇒ actual rank 2");
    assert_eq!(payload.actual_shape(), &[8usize, 8]);
}
/// With explicit `[N, 1]` indices where the first half of the tokens go to
/// expert 0 and the second half to expert 1, every token must be processed by
/// its own expert.
#[test]
fn switch_glu_sorted_path_top1_explicit_k_routes_each_token_to_its_expert() {
    let gate_proj = SwitchLinear::from_parts(identity_then_swap_weight(), None).unwrap();
    let up_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let down_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let glu = SwitchGLU::new(
        gate_proj,
        up_proj,
        down_proj,
        SwitchGLU::default_activation(),
    )
    .unwrap();

    // Token `t` is `[t + 1, t + 2]`.
    let n = 64usize;
    let x_data: Vec<f32> = (0..n)
        .flat_map(|t| [t as f32 + 1.0, t as f32 + 2.0])
        .collect();
    let idx_data: Vec<u32> = (0..n)
        .map(|t| if t < n / 2 { 0u32 } else { 1u32 })
        .collect();

    let x = Array::from_slice::<f32>(&x_data, &(n, 2usize)).unwrap();
    let indices = Array::from_slice::<u32>(&idx_data, &(n, 1usize)).unwrap();
    assert!(indices.size() >= 64, "test must exercise the sorted path");

    let mut out = glu.forward(&x, &indices).unwrap();
    assert_eq!(out.shape(), vec![n, 1, 2]);
    let got = out.to_vec::<f32>().unwrap();

    let want: Vec<f32> = (0..n)
        .flat_map(|t| {
            let (x0, x1) = (t as f32 + 1.0, t as f32 + 2.0);
            if t < n / 2 {
                // Expert 0: identity gate.
                [silu_ref(x0) * x0, silu_ref(x1) * x1]
            } else {
                // Expert 1: swapped gate.
                [silu_ref(x1) * x0, silu_ref(x0) * x1]
            }
        })
        .collect();
    assert_close(&got, &want);
}
/// With a `[B, S, 2]` input and explicit `[B, S, 1]` indices, the output is
/// `[B, S, 1, 2]` and every token is processed by its own expert.
#[test]
fn switch_glu_sorted_path_explicit_2d_batch_k_routes_each_token() {
    let gate_proj = SwitchLinear::from_parts(identity_then_swap_weight(), None).unwrap();
    let up_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let down_proj = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let glu = SwitchGLU::new(
        gate_proj,
        up_proj,
        down_proj,
        SwitchGLU::default_activation(),
    )
    .unwrap();

    // Flattened token `t` is `[t + 1, t + 2]` and goes to expert `t % 2`.
    let (b, s) = (8usize, 8usize);
    let tokens = b * s;
    let x_data: Vec<f32> = (0..tokens)
        .flat_map(|t| [t as f32 + 1.0, t as f32 + 2.0])
        .collect();
    let idx_data: Vec<u32> = (0..tokens).map(|t| (t % 2) as u32).collect();

    let x = Array::from_slice::<f32>(&x_data, &(b, s, 2usize)).unwrap();
    let indices = Array::from_slice::<u32>(&idx_data, &(b, s, 1usize)).unwrap();
    assert!(indices.size() >= 64, "test must exercise the sorted path");

    let mut out = glu.forward(&x, &indices).unwrap();
    assert_eq!(out.shape(), vec![b, s, 1, 2]);
    let got = out.to_vec::<f32>().unwrap();

    let want: Vec<f32> = (0..tokens)
        .flat_map(|t| {
            let (x0, x1) = (t as f32 + 1.0, t as f32 + 2.0);
            if t % 2 == 0 {
                // Expert 0: identity gate.
                [silu_ref(x0) * x0, silu_ref(x1) * x1]
            } else {
                // Expert 1: swapped gate.
                [silu_ref(x1) * x0, silu_ref(x0) * x1]
            }
        })
        .collect();
    assert_close(&got, &want);
}
/// Rank-1 `[N]` indices (64 of them) must be rejected by `SwitchMLP::forward`
/// with `Error::RankMismatch` reporting rank 1 and shape `[64]`.
#[test]
fn switch_mlp_sorted_path_rejects_ambiguous_flat_indices() {
    let square: Activation = Box::new(|a: &Array| a.multiply(a));
    let fc1 = SwitchLinear::from_parts(identity_then_swap_weight(), None).unwrap();
    let fc2 = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let mlp = SwitchMLP::new(fc1, fc2, square).unwrap();

    let n = 64usize;
    let x_data: Vec<f32> = (0..n)
        .flat_map(|t| [t as f32, t as f32 + 1.0])
        .collect();
    let idx_data: Vec<u32> = (0..n).map(|t| (t % 2) as u32).collect();

    let x = Array::from_slice::<f32>(&x_data, &(n, 2usize)).unwrap();
    // No explicit trailing routing axis: shape is `[N]`, not `[N, 1]`.
    let indices = Array::from_slice::<u32>(&idx_data, &(n,)).unwrap();
    assert!(indices.size() >= 64, "test must exercise the sorted path");

    let err = mlp.forward(&x, &indices).unwrap_err();
    let payload = match err {
        crate::Error::RankMismatch(payload) => payload,
        other => panic!("expected RankMismatch on ambiguous [N] indices, got {other:?}"),
    };
    assert_eq!(payload.actual(), 1, "rank-1 indices ⇒ actual rank 1");
    assert_eq!(payload.actual_shape(), &[64usize]);
}
/// `[B, S]` indices paired with a `[B, S, 2]` input must be rejected by
/// `SwitchMLP::forward` with `Error::RankMismatch` reporting rank 2 and shape `[8, 8]`.
#[test]
fn switch_mlp_sorted_path_rejects_ambiguous_batch_indices() {
    let square: Activation = Box::new(|a: &Array| a.multiply(a));
    let fc1 = SwitchLinear::from_parts(identity_then_swap_weight(), None).unwrap();
    let fc2 = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let mlp = SwitchMLP::new(fc1, fc2, square).unwrap();

    let (b, s) = (8usize, 8usize);
    let tokens = b * s;
    let x_data: Vec<f32> = (0..tokens)
        .flat_map(|t| [t as f32, t as f32 + 1.0])
        .collect();
    let idx_data: Vec<u32> = (0..tokens).map(|t| (t % 2) as u32).collect();

    let x = Array::from_slice::<f32>(&x_data, &(b, s, 2usize)).unwrap();
    // No explicit trailing routing axis: shape is `[B, S]`, not `[B, S, 1]`.
    let indices = Array::from_slice::<u32>(&idx_data, &(b, s)).unwrap();
    assert!(indices.size() >= 64, "test must exercise the sorted path");

    let err = mlp.forward(&x, &indices).unwrap_err();
    let payload = match err {
        crate::Error::RankMismatch(payload) => payload,
        other => panic!("expected RankMismatch on ambiguous [B, S] indices, got {other:?}"),
    };
    assert_eq!(payload.actual(), 2, "rank-2 indices ⇒ actual rank 2");
    assert_eq!(payload.actual_shape(), &[8usize, 8]);
}
/// With explicit `[N, 1]` indices where the first half of the tokens go to
/// expert 0 and the second half to expert 1, every token must be processed by
/// its own expert (squaring activation).
#[test]
fn switch_mlp_sorted_path_top1_explicit_k_routes_each_token_to_its_expert() {
    let square: Activation = Box::new(|a: &Array| a.multiply(a));
    let fc1 = SwitchLinear::from_parts(identity_then_swap_weight(), None).unwrap();
    let fc2 = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let mlp = SwitchMLP::new(fc1, fc2, square).unwrap();

    // Token `t` is `[t + 1, t + 2]`.
    let n = 64usize;
    let x_data: Vec<f32> = (0..n)
        .flat_map(|t| [t as f32 + 1.0, t as f32 + 2.0])
        .collect();
    let idx_data: Vec<u32> = (0..n)
        .map(|t| if t < n / 2 { 0u32 } else { 1u32 })
        .collect();

    let x = Array::from_slice::<f32>(&x_data, &(n, 2usize)).unwrap();
    let indices = Array::from_slice::<u32>(&idx_data, &(n, 1usize)).unwrap();
    assert!(indices.size() >= 64, "test must exercise the sorted path");

    let mut out = mlp.forward(&x, &indices).unwrap();
    assert_eq!(out.shape(), vec![n, 1, 2]);
    let got = out.to_vec::<f32>().unwrap();

    let want: Vec<f32> = (0..n)
        .flat_map(|t| {
            let (x0, x1) = (t as f32 + 1.0, t as f32 + 2.0);
            if t < n / 2 {
                // Expert 0: identity, then square.
                [x0 * x0, x1 * x1]
            } else {
                // Expert 1: swap, then square.
                [x1 * x1, x0 * x0]
            }
        })
        .collect();
    assert_close(&got, &want);
}
/// `QuantizedSwitchLinear::from_parts` must report `Error::RankMismatch`
/// for a rank-2 weight, with the offending rank and shape in the payload.
#[test]
fn quantized_switch_linear_from_parts_rejects_2d_weight() {
    let weight_2d = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2usize, 2usize)).unwrap();
    let scales = Array::from_slice::<f32>(&[1.0, 1.0], &(2usize, 1usize)).unwrap();

    let result = QuantizedSwitchLinear::from_parts(weight_2d, scales, None, None, 64, 4, "affine");

    // Anything other than a RankMismatch is a test failure.
    let payload = match result.unwrap_err() {
        crate::Error::RankMismatch(payload) => payload,
        other => panic!("expected RankMismatch on rank-2 quantized weight, got {other:?}"),
    };
    assert_eq!(payload.actual(), 2, "rank-2 weight ⇒ actual rank 2");
    assert_eq!(payload.actual_shape(), &[2usize, 2]);
}
/// `QuantizedSwitchLinear::from_parts` must report `Error::RankMismatch`
/// when a rank-3 `u32` weight is paired with rank-2 scales; the payload is
/// expected to describe the scales (rank 2, shape `[2, 2]`).
#[test]
fn quantized_switch_linear_from_parts_rejects_scales_rank_mismatch() {
    let packed_weight =
        Array::from_slice::<u32>(&[0u32, 0, 0, 0], &(2usize, 2usize, 1usize)).unwrap();
    let scales_2d = Array::from_slice::<f32>(&[1.0, 1.0, 1.0, 1.0], &(2usize, 2usize)).unwrap();

    let result =
        QuantizedSwitchLinear::from_parts(packed_weight, scales_2d, None, None, 64, 4, "mxfp4");

    // Anything other than a RankMismatch is a test failure.
    let payload = match result.unwrap_err() {
        crate::Error::RankMismatch(payload) => payload,
        other => panic!("expected RankMismatch on rank-2 scales, got {other:?}"),
    };
    assert_eq!(payload.actual(), 2, "rank-2 scales ⇒ actual rank 2");
    assert_eq!(payload.actual_shape(), &[2usize, 2]);
}
/// `check_routing_indices` must reject a rank-0 (scalar) `x` with
/// `Error::RankMismatch`, reporting rank 0 and an empty shape.
#[test]
fn check_routing_indices_rejects_rank0_x() {
    // An empty shape vector builds a scalar array; the guard assertion
    // confirms the fixture really is rank-0 before the check runs.
    let scalar_shape: Vec<usize> = Vec::new();
    let x = Array::from_slice::<f32>(&[5.0f32], &scalar_shape).unwrap();
    assert!(x.shape().is_empty(), "x must be rank-0 for this branch");
    let indices = Array::from_slice::<u32>(&[0u32], &(1usize,)).unwrap();

    // Anything other than a RankMismatch is a test failure.
    let payload = match check_routing_indices(&x, &indices).unwrap_err() {
        crate::Error::RankMismatch(payload) => payload,
        other => panic!("expected RankMismatch on rank-0 x, got {other:?}"),
    };
    assert_eq!(payload.actual(), 0, "rank-0 x ⇒ actual rank 0");
    let reported_shape = payload.actual_shape();
    assert!(
        reported_shape.is_empty(),
        "rank-0 x ⇒ empty actual shape, got {:?}",
        reported_shape
    );
}
/// When `x` is `[2, 3, 4]` and `indices` is `[2, 5, 1]`, exactly one
/// leading dimension differs (3 vs 5); `check_routing_indices` must report
/// that as `Error::LengthMismatch` with expected 3 and actual 5.
#[test]
fn check_routing_indices_one_differing_batch_dim_is_length_mismatch() {
    let x_zeros = [0.0f32; 2 * 3 * 4];
    let idx_zeros = [0u32; 2 * 5];
    let x = Array::from_slice::<f32>(&x_zeros, &(2usize, 3usize, 4usize)).unwrap();
    let indices = Array::from_slice::<u32>(&idx_zeros, &(2usize, 5usize, 1usize)).unwrap();

    // Anything other than a LengthMismatch is a test failure.
    let payload = match check_routing_indices(&x, &indices).unwrap_err() {
        crate::Error::LengthMismatch(payload) => payload,
        other => panic!("expected LengthMismatch on single differing batch dim, got {other:?}"),
    };
    assert_eq!(payload.expected(), 3, "x's differing batch dim is 3");
    assert_eq!(payload.actual(), 5, "indices' differing leading dim is 5");
}
/// When `x` is `[2, 3, 4]` and `indices` is `[7, 9, 1]`, both leading
/// dimensions differ; `check_routing_indices` must report
/// `Error::ShapePairMismatch` carrying `[2, 3]` (expected) and `[7, 9]`
/// (actual).
#[test]
fn check_routing_indices_two_differing_batch_dims_is_shape_pair_mismatch() {
    let x_zeros = [0.0f32; 2 * 3 * 4];
    let idx_zeros = [0u32; 7 * 9];
    let x = Array::from_slice::<f32>(&x_zeros, &(2usize, 3usize, 4usize)).unwrap();
    let indices = Array::from_slice::<u32>(&idx_zeros, &(7usize, 9usize, 1usize)).unwrap();

    // Anything other than a ShapePairMismatch is a test failure.
    let payload = match check_routing_indices(&x, &indices).unwrap_err() {
        crate::Error::ShapePairMismatch(payload) => payload,
        other => panic!("expected ShapePairMismatch on two differing batch dims, got {other:?}"),
    };
    assert_eq!(payload.expected(), &[2usize, 3], "x batch dims");
    assert_eq!(payload.actual(), &[7usize, 9], "indices leading dims");
}
/// `SwitchGLU::new` must fail with `Error::ShapePairMismatch` when the gate
/// and up projections disagree on their hidden size.
///
/// The gate uses the shared all-identity fixture, while the up projection
/// is built from a zero-filled `(2, 3, 2)` weight; the payload is expected
/// to carry `[2, 2]` for the gate and `[2, 3]` for the up projection.
#[test]
fn switch_glu_new_rejects_gate_up_hidden_mismatch() {
    let gate = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();

    let up_zeros = [0.0f32; 2 * 3 * 2];
    let up_w = Array::from_slice::<f32>(&up_zeros, &(2usize, 3usize, 2usize)).unwrap();
    let up = SwitchLinear::from_parts(up_w, None).unwrap();

    let down = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();

    let result = SwitchGLU::new(gate, up, down, SwitchGLU::default_activation());

    // Anything other than a ShapePairMismatch is a test failure.
    let payload = match result.unwrap_err() {
        crate::Error::ShapePairMismatch(payload) => payload,
        other => panic!("expected ShapePairMismatch on gate/up hidden mismatch, got {other:?}"),
    };
    assert_eq!(payload.expected(), &[2usize, 2], "gate [input, hidden]");
    assert_eq!(payload.actual(), &[2usize, 3], "up [input, hidden]");
}
/// `SwitchMLP::new` must fail with `Error::LengthMismatch` when `fc1` and
/// `fc2` hold a different number of experts.
///
/// `fc1` uses the shared all-identity fixture (the assertion expects it to
/// report 2 experts); `fc2` is built from a `(3, 2, 2)` weight of ones.
#[test]
fn switch_mlp_new_rejects_num_experts_mismatch() {
    let fc1 = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();

    let fc2_ones = [1.0f32; 3 * 2 * 2];
    let fc2_w = Array::from_slice::<f32>(&fc2_ones, &(3usize, 2usize, 2usize)).unwrap();
    let fc2 = SwitchLinear::from_parts(fc2_w, None).unwrap();

    let result = SwitchMLP::new(fc1, fc2, SwitchMLP::default_activation());

    // Anything other than a LengthMismatch is a test failure.
    let payload = match result.unwrap_err() {
        crate::Error::LengthMismatch(payload) => payload,
        other => panic!("expected LengthMismatch on fc1/fc2 num_experts, got {other:?}"),
    };
    assert_eq!(payload.expected(), 2, "fc1 num_experts");
    assert_eq!(payload.actual(), 3, "fc2 num_experts");
}
/// The `gate_proj` / `up_proj` / `down_proj` accessors on `SwitchGLU` must
/// expose layers whose weights have the `[2, 2, 2]` shape the test expects
/// from the fixtures, and the gate's dimension getters must agree with it.
#[test]
fn switch_glu_projection_accessors_return_constructed_layers() {
    let gate_layer = SwitchLinear::from_parts(identity_then_swap_weight(), None).unwrap();
    let up_layer = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let down_layer = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let glu =
        SwitchGLU::new(gate_layer, up_layer, down_layer, SwitchGLU::default_activation()).unwrap();

    // Weight shapes of all three projections.
    assert_eq!(glu.gate_proj().weight_ref().shape(), vec![2, 2, 2]);
    assert_eq!(glu.up_proj().weight_ref().shape(), vec![2, 2, 2]);
    assert_eq!(glu.down_proj().weight_ref().shape(), vec![2, 2, 2]);

    // Dimension getters, checked on the gate projection only.
    let gate = glu.gate_proj();
    assert_eq!(gate.num_experts(), 2);
    assert_eq!(gate.input_dims(), 2);
    assert_eq!(gate.output_dims(), 2);
}
/// The `Debug` rendering of `SwitchGLU` must name the type and its three
/// projection fields, and must show the activation as the placeholder
/// `<fn>`.
#[test]
fn switch_glu_debug_elides_activation() {
    let gate = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let up = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let down = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let glu = SwitchGLU::new(gate, up, down, SwitchGLU::default_activation()).unwrap();

    let rendered = format!("{glu:?}");

    // Type name and field names, checked in the same order as before.
    for needle in ["SwitchGLU", "gate_proj", "up_proj", "down_proj"] {
        assert!(rendered.contains(needle), "got {rendered}");
    }
    assert!(
        rendered.contains("<fn>"),
        "activation must be elided as <fn>, got {rendered}"
    );
}
/// The `fc1` / `fc2` accessors on `SwitchMLP` must expose layers whose
/// weights have the `[2, 2, 2]` shape the test expects from the fixtures,
/// with matching expert count on `fc1` and output width on `fc2`.
#[test]
fn switch_mlp_projection_accessors_return_constructed_layers() {
    let activation: Activation = Box::new(|a: &Array| a.multiply(a));
    let first = SwitchLinear::from_parts(identity_then_swap_weight(), None).unwrap();
    let second = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let mlp = SwitchMLP::new(first, second, activation).unwrap();

    // Weight shapes of both layers.
    assert_eq!(mlp.fc1().weight_ref().shape(), vec![2, 2, 2]);
    assert_eq!(mlp.fc2().weight_ref().shape(), vec![2, 2, 2]);

    // One dimension getter per layer.
    assert_eq!(mlp.fc1().num_experts(), 2);
    assert_eq!(mlp.fc2().output_dims(), 2);
}
/// The `Debug` rendering of `SwitchMLP` must name the type and its two
/// layer fields, and must show the activation as the placeholder `<fn>`.
#[test]
fn switch_mlp_debug_elides_activation() {
    let activation: Activation = Box::new(|a: &Array| a.multiply(a));
    let first = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let second = SwitchLinear::from_parts(all_identity_weight(), None).unwrap();
    let mlp = SwitchMLP::new(first, second, activation).unwrap();

    let rendered = format!("{mlp:?}");

    // Type name and field names, checked in the same order as before.
    for needle in ["SwitchMLP", "fc1", "fc2"] {
        assert!(rendered.contains(needle), "got {rendered}");
    }
    assert!(
        rendered.contains("<fn>"),
        "activation must be elided as <fn>, got {rendered}"
    );
}