use super::*;
/// Test helper: builds an `f32` [`Array`] holding `data` laid out as `shape`.
///
/// Panics with the message `from_slice` if the array cannot be constructed.
fn arr_f32(data: &[f32], shape: &[usize]) -> Array {
    let built = Array::from_slice::<f32>(data, &shape);
    built.expect("from_slice")
}
/// Test helper: builds a `u32` [`Array`] holding `data` laid out as `shape`.
///
/// Panics with the message `from_slice` if the array cannot be constructed.
fn arr_u32(data: &[u32], shape: &[usize]) -> Array {
    let built = Array::from_slice::<u32>(data, &shape);
    built.expect("from_slice")
}
/// A bare `{group_size, bits}` block parses into a global affine quantization
/// with no per-layer overrides.
#[test]
fn quantization_parses_minimal_block() {
    let json = r#"{ "quantization": { "group_size": 64, "bits": 4 } }"#;
    let parsed = parse_quantization(json).unwrap().unwrap();
    let global = parsed.quantization.expect("global quant present");
    assert_eq!(global.group_size, 64);
    assert_eq!(global.bits, 4);
    // No explicit "mode" key: affine is the expected default.
    assert_eq!(global.mode, QuantMode::Affine);
    assert!(parsed.per_layer.is_empty());
}
/// An explicit `"mode": "mxfp4"` entry is reflected in the parsed global mode.
#[test]
fn quantization_parses_mode_explicit() {
    let json = r#"{ "quantization": { "group_size": 32, "bits": 4, "mode": "mxfp4" } }"#;
    let parsed = parse_quantization(json).unwrap();
    let global = parsed.unwrap().quantization.unwrap();
    assert_eq!(global.mode, QuantMode::Mxfp4);
}
/// Non-scalar keys inside the `quantization` block become per-layer entries:
/// an object is a `Quantize` override, `false` is a `Skip`. `quantization_for`
/// resolves overrides first and falls back to the global setting.
#[test]
fn quantization_parses_per_layer_overrides() {
    let json = r#"{
"quantization": {
"group_size": 64,
"bits": 4,
"model.embed_tokens": { "group_size": 32, "bits": 4 },
"model.layers.0.self_attn.q_norm": false
}
}"#;
    let parsed = parse_quantization(json).unwrap().unwrap();
    let global = parsed.quantization.unwrap();
    assert_eq!(global.group_size, 64);
    assert_eq!(global.bits, 4);
    assert_eq!(parsed.per_layer.len(), 2);

    // Object-valued key -> explicit Quantize override.
    let embed_override = parsed.per_layer.get("model.embed_tokens");
    let Some(QuantizationOption::Quantize(embed_q)) = embed_override else {
        panic!("expected Quantize override, got {embed_override:?}");
    };
    assert_eq!(embed_q.group_size, 32);
    assert_eq!(embed_q.bits, 4);

    // `false` -> Skip.
    let q_norm = "model.layers.0.self_attn.q_norm";
    assert_eq!(
        parsed.per_layer.get(q_norm).copied(),
        Some(QuantizationOption::Skip)
    );

    // Resolution: override wins, Skip yields None, unknown paths use the global.
    let embed_expected = Quantization {
        group_size: 32,
        bits: 4,
        mode: QuantMode::Affine,
    };
    assert_eq!(
        parsed.quantization_for("model.embed_tokens"),
        Some(embed_expected)
    );
    assert_eq!(parsed.quantization_for(q_norm), None);
    assert_eq!(
        parsed.quantization_for("model.layers.5.mlp.gate_proj"),
        Some(global)
    );
}
/// Scalar legacy keys (`quant_method`, `linear_class`, `quantization_mode`) are
/// not mistaken for per-layer overrides.
#[test]
fn quantization_ignores_legacy_hf_keys() {
    let json = r#"{
"quantization": {
"group_size": 64,
"bits": 4,
"quant_method": "awq",
"linear_class": "QuantizedLinear",
"quantization_mode": "affine"
}
}"#;
    let parsed = parse_quantization(json).unwrap().unwrap();
    assert!(parsed.per_layer.is_empty());
    let group_size = parsed.quantization.unwrap().group_size;
    assert_eq!(group_size, 64);
}
/// A config with no `quantization` block parses successfully to `None`.
#[test]
fn quantization_absent_returns_none() {
    let parsed = parse_quantization(r#"{ "model_type": "qwen3", "hidden_size": 1024 }"#).unwrap();
    assert!(parsed.is_none());
}
/// Malformed JSON is reported as an error rather than as "no quantization".
#[test]
fn quantization_invalid_json_errors() {
    let outcome = parse_quantization("{ not json");
    assert!(outcome.is_err());
}
/// Mixed weight map: eligible dense `.weight`s get quantized into a
/// weight/scales/biases triple; an already-quantized triple, non-`.weight`
/// tensors, and a `.weight` whose last dim (63) is not a multiple of the group
/// size are all left as they were.
#[test]
fn quantize_weights_applies_to_eligible_and_skips_rest() {
    let group_size = 64_usize;
    let n_rows = 3_usize;
    let dense = |fill: f32| arr_f32(&vec![fill; n_rows * group_size], &[n_rows, group_size]);
    let zero_column = || arr_f32(&vec![0.0_f32; n_rows], &[n_rows, 1]);
    let weights = Weights::from([
        ("model.layers.0.q_proj.weight".to_string(), dense(0.5)),
        ("model.layers.0.k_proj.weight".to_string(), dense(-0.25)),
        (
            "model.layers.1.v_proj.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 8], &[n_rows, 8]),
        ),
        ("model.layers.1.v_proj.scales".to_string(), zero_column()),
        ("model.layers.1.v_proj.biases".to_string(), zero_column()),
        (
            "model.layers.0.q_proj.bias".to_string(),
            arr_f32(&[1.0_f32, 2.0, 3.0], &[3]),
        ),
        (
            "model.layers.2.bad.weight".to_string(),
            arr_f32(&vec![0.0_f32; 3 * 63], &[3, 63]),
        ),
        ("model.norm.gamma".to_string(), arr_f32(&[42.0_f32], &[1])),
    ]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(group_size as i32, 4));
    let out = quantize_weights(weights, &cfg, &default_eligible).expect("quantize");

    // Freshly quantized layers: packed U32 weight plus F32 per-group scales/biases.
    for path in ["model.layers.0.q_proj", "model.layers.0.k_proj"] {
        let key = |suffix: &str| format!("{path}.{suffix}");
        let w_q = out.get(&key("weight")).expect(".weight");
        let scales = out.get(&key("scales")).expect(".scales");
        let biases = out.get(&key("biases")).expect(".biases (affine)");
        assert_eq!(w_q.shape(), vec![3, 8]);
        assert_eq!(w_q.dtype().unwrap(), crate::dtype::Dtype::U32);
        for per_group in [scales, biases] {
            assert_eq!(per_group.shape(), vec![3, 1]);
            assert_eq!(per_group.dtype().unwrap(), crate::dtype::Dtype::F32);
        }
    }

    // The pre-quantized triple is kept as-is.
    let pre_q_w = out.get("model.layers.1.v_proj.weight").expect("already-w");
    assert_eq!(pre_q_w.shape(), vec![n_rows, 8]);
    assert_eq!(pre_q_w.dtype().unwrap(), crate::dtype::Dtype::U32);
    for sibling in ["model.layers.1.v_proj.scales", "model.layers.1.v_proj.biases"] {
        assert!(out.contains_key(sibling));
    }

    // Everything else keeps its original shape.
    for (key, shape) in [
        ("model.layers.0.q_proj.bias", vec![3_usize]),
        ("model.layers.2.bad.weight", vec![3_usize, 63]),
        ("model.norm.gamma", vec![1_usize]),
    ] {
        assert_eq!(out.get(key).unwrap().shape(), shape);
    }

    // No spurious keys were emitted.
    for absent in [
        "model.layers.0.q_proj.scales.scales",
        "model.layers.2.bad.scales",
        "model.layers.2.bad.biases",
    ] {
        assert!(!out.contains_key(absent));
    }
}
/// Quantizing a linear ramp to 4-bit affine and dequantizing it again keeps
/// every element within 0.05 of the original.
#[test]
fn quantize_then_dequantize_roundtrips_within_tolerance() {
    let group_size = 64_usize;
    let n_rows = 4_usize;
    // Ramp from -1.0 in steps of 1/128.
    let data: Vec<f32> = (0..n_rows * group_size)
        .map(|i| i as f32 / 128.0 - 1.0)
        .collect();
    let weights = Weights::from([(
        "model.linear.weight".to_string(),
        arr_f32(&data, &[n_rows, group_size]),
    )]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(group_size as i32, 4));

    let quantized = quantize_weights(weights, &cfg, &default_eligible).unwrap();
    let dequantized = dequantize_weights(quantized, &cfg).unwrap();
    let restored = dequantized
        .get("model.linear.weight")
        .expect("round-tripped .weight");
    let mut deq = restored.try_clone().unwrap();
    assert_eq!(deq.shape(), vec![n_rows, group_size]);

    let deq_vec: Vec<f32> = deq.to_vec().unwrap();
    let mut max_abs_err = 0.0_f32;
    for (original, round_tripped) in data.iter().zip(&deq_vec) {
        max_abs_err = max_abs_err.max((original - round_tripped).abs());
    }
    assert!(
        max_abs_err < 0.05,
        "round-trip max abs err = {max_abs_err}; expected < 0.05 for 4-bit affine"
    );
}
/// A per-layer `Skip` entry leaves that layer's dense `.weight` untouched even
/// though a global quantization is configured.
#[test]
fn quantize_weights_per_layer_skip_passes_through() {
    let group_size = 64_usize;
    let n_rows = 2_usize;
    let weights = Weights::from([(
        "model.embed_tokens.weight".to_string(),
        arr_f32(&vec![0.1_f32; n_rows * group_size], &[n_rows, group_size]),
    )]);
    let cfg = PerLayerQuantization {
        per_layer: HashMap::from([("model.embed_tokens".to_string(), QuantizationOption::Skip)]),
        quantization: Some(Quantization::affine(group_size as i32, 4)),
    };

    let out = quantize_weights(weights, &cfg, &default_eligible).unwrap();
    let pass = out.get("model.embed_tokens.weight").expect(".weight");
    assert_eq!(pass.shape(), vec![n_rows, group_size]);
    assert_eq!(pass.dtype().unwrap(), crate::dtype::Dtype::F32);
    for absent in ["model.embed_tokens.scales", "model.embed_tokens.biases"] {
        assert!(!out.contains_key(absent));
    }
}
/// A per-layer `Quantize` override (group size 32) is used instead of the global
/// setting (group size 64) for a 32-wide `.weight`.
#[test]
fn quantize_weights_per_layer_override_uses_override_params() {
    let n_rows = 2_usize;
    let last = 32_usize;
    let weights = Weights::from([(
        "model.embed_tokens.weight".to_string(),
        arr_f32(&vec![0.1_f32; n_rows * last], &[n_rows, last]),
    )]);
    let override_q = QuantizationOption::Quantize(Quantization::affine(32, 4));
    let cfg = PerLayerQuantization {
        per_layer: HashMap::from([("model.embed_tokens".to_string(), override_q)]),
        quantization: Some(Quantization::affine(64, 4)),
    };

    let out = quantize_weights(weights, &cfg, &default_eligible).unwrap();
    // One group of 32 per row.
    let scales = out.get("model.embed_tokens.scales").expect(".scales");
    assert_eq!(scales.shape(), vec![n_rows, 1]);
    // 32 values * 4 bits packed into u32 words -> 4 words per row.
    let w_q = out.get("model.embed_tokens.weight").expect(".weight");
    assert_eq!(w_q.shape(), vec![n_rows, 4]);
}
/// When the eligibility predicate rejects a layer, its dense `.weight` passes
/// through unchanged and no scales/biases are produced.
#[test]
fn quantize_weights_predicate_rejected_passes_through() {
    let group_size = 64_usize;
    let n_rows = 2_usize;
    let weights = Weights::from([(
        "model.some_future_module.weight".to_string(),
        arr_f32(&vec![0.5_f32; n_rows * group_size], &[n_rows, group_size]),
    )]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(group_size as i32, 4));
    let reject_all: &Eligible<'_> = &|_path: &str, _arr: &Array| false;

    let out = quantize_weights(weights, &cfg, reject_all).unwrap();
    let kept = out.get("model.some_future_module.weight").expect(".weight");
    assert_eq!(kept.shape(), vec![n_rows, group_size]);
    assert_eq!(kept.dtype().unwrap(), crate::dtype::Dtype::F32);
    for absent in [
        "model.some_future_module.scales",
        "model.some_future_module.biases",
    ] {
        assert!(!out.contains_key(absent));
    }
}
/// The eligibility predicate receives the layer path (without `.weight`): only the
/// approved layer is quantized, the other is kept dense.
#[test]
fn quantize_weights_predicate_approved_quantizes() {
    let group_size = 64_usize;
    let n_rows = 2_usize;
    let dense = || arr_f32(&vec![0.5_f32; n_rows * group_size], &[n_rows, group_size]);
    let weights = Weights::from([
        ("model.linear_class.weight".to_string(), dense()),
        ("model.other_class.weight".to_string(), dense()),
    ]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(group_size as i32, 4));
    let only_linear: &Eligible<'_> = &|path: &str, _arr: &Array| path == "model.linear_class";

    let out = quantize_weights(weights, &cfg, only_linear).unwrap();

    let approved_scales = out
        .get("model.linear_class.scales")
        .expect("scales for approved layer");
    assert_eq!(approved_scales.shape(), vec![n_rows, 1]);

    let rejected_weight = out
        .get("model.other_class.weight")
        .expect("rejected layer .weight kept");
    assert_eq!(rejected_weight.shape(), vec![n_rows, group_size]);

    for absent in ["model.other_class.scales", "model.other_class.biases"] {
        assert!(!out.contains_key(absent));
    }
}
/// A `quantization` block without `bits` is an error whose message names the key.
#[test]
fn quantization_missing_bits_errors() {
    let msg = parse_quantization(r#"{ "quantization": { "group_size": 64 } }"#)
        .unwrap_err()
        .to_string();
    assert!(
        msg.contains("bits"),
        "error should mention the missing `bits` key, got: {msg}"
    );
}
/// A `quantization` block without `group_size` is an error whose message names
/// the key.
#[test]
fn quantization_missing_group_size_errors() {
    let msg = parse_quantization(r#"{ "quantization": { "bits": 4 } }"#)
        .unwrap_err()
        .to_string();
    assert!(
        msg.contains("group_size"),
        "error should mention the missing `group_size` key, got: {msg}"
    );
}
/// With both required keys present the global quantization is populated verbatim.
#[test]
fn quantization_both_present_ok() {
    let parsed = parse_quantization(r#"{ "quantization": { "group_size": 32, "bits": 4 } }"#)
        .unwrap()
        .unwrap();
    let global = parsed.quantization.expect("global quant present");
    assert_eq!(global.group_size, 32);
    assert_eq!(global.bits, 4);
}
/// A dense `.weight` next to a `.biases` with no `.scales` is rejected with a
/// layer-keyed `MissingKey` error.
#[test]
fn quantize_weights_orphan_biases_collision_errors() {
    let group_size = 64_usize;
    let n_rows = 2_usize;
    let weights = Weights::from([
        (
            "model.foo.weight".to_string(),
            arr_f32(&vec![0.5_f32; n_rows * group_size], &[n_rows, group_size]),
        ),
        (
            "model.foo.biases".to_string(),
            arr_f32(&vec![0.0_f32; n_rows], &[n_rows, 1]),
        ),
    ]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(group_size as i32, 4));

    let err = quantize_weights(weights, &cfg, &default_eligible).unwrap_err();
    let Error::LayerKeyed(payload) = &err else {
        panic!("expected Error::LayerKeyed, got: {err:?}");
    };
    assert!(
        payload.layer().contains("model.foo"),
        "LayerKeyed must name the colliding layer, got layer={:?}",
        payload.layer()
    );
    assert!(
        matches!(payload.inner(), Error::MissingKey(_)),
        "inner must be MissingKey for stale `.biases` without `.scales`, got: {:?}",
        payload.inner()
    );
}
/// A well-formed pre-quantized affine triple (U32 weight, F32 scales and biases)
/// is accepted and passed through unchanged.
#[test]
fn quantize_weights_valid_existing_triple_still_skipped() {
    let n_rows = 2_usize;
    let weights = Weights::from([
        (
            "model.already.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 8], &[n_rows, 8]),
        ),
        (
            "model.already.scales".to_string(),
            arr_f32(&vec![1.0_f32; n_rows], &[n_rows, 1]),
        ),
        (
            "model.already.biases".to_string(),
            arr_f32(&vec![0.0_f32; n_rows], &[n_rows, 1]),
        ),
    ]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(64, 4));

    let out = quantize_weights(weights, &cfg, &default_eligible).expect("valid triple passes");
    let packed = out.get("model.already.weight").unwrap();
    assert_eq!(packed.shape(), vec![n_rows, 8]);
    assert_eq!(packed.dtype().unwrap(), crate::dtype::Dtype::U32);
    for sibling in ["model.already.scales", "model.already.biases"] {
        assert!(out.contains_key(sibling));
    }
}
/// `.scales`/`.biases` siblings next to a dense F32 `.weight` are rejected with a
/// layer-keyed `UnsupportedDtype` error (the weight is not packed).
#[test]
fn quantize_weights_orphan_scales_with_dense_weight_errors() {
    let group_size = 64_usize;
    let n_rows = 2_usize;
    let column = |fill: f32| arr_f32(&vec![fill; n_rows], &[n_rows, 1]);
    let weights = Weights::from([
        (
            "model.foo.weight".to_string(),
            arr_f32(&vec![0.5_f32; n_rows * group_size], &[n_rows, group_size]),
        ),
        ("model.foo.scales".to_string(), column(1.0)),
        ("model.foo.biases".to_string(), column(0.0)),
    ]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(group_size as i32, 4));

    let err = quantize_weights(weights, &cfg, &default_eligible).unwrap_err();
    let Error::LayerKeyed(payload) = &err else {
        panic!("expected Error::LayerKeyed, got: {err:?}");
    };
    assert!(
        payload.layer().contains("model.foo"),
        "LayerKeyed must name the colliding layer, got layer={:?}",
        payload.layer()
    );
    assert!(
        matches!(payload.inner(), Error::UnsupportedDtype(_)),
        "inner must be UnsupportedDtype for dense-weight orphan, got: {:?}",
        payload.inner()
    );
}
/// A triple whose `.scales` leading dim (3) disagrees with `.weight` rows (2) is
/// rejected with a layer-keyed `ShapePairMismatch`.
#[test]
fn quantize_weights_mismatched_scales_shape_errors() {
    let n_rows = 2_usize;
    let weights = Weights::from([
        (
            "model.foo.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 8], &[n_rows, 8]),
        ),
        (
            "model.foo.scales".to_string(),
            arr_f32(&[1.0_f32; 3], &[3, 1]),
        ),
        (
            "model.foo.biases".to_string(),
            arr_f32(&[0.0_f32; 3], &[3, 1]),
        ),
    ]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(64, 4));

    let err = quantize_weights(weights, &cfg, &default_eligible).unwrap_err();
    let Error::LayerKeyed(payload) = &err else {
        panic!("expected Error::LayerKeyed, got: {err:?}");
    };
    assert!(
        payload.layer().contains("model.foo"),
        "LayerKeyed must name the colliding layer, got layer={:?}",
        payload.layer()
    );
    assert!(
        matches!(payload.inner(), Error::ShapePairMismatch(_)),
        "inner must be ShapePairMismatch for leading-dim mismatch, got: {:?}",
        payload.inner()
    );
}
/// A triple whose packed `.weight` is rank 1 is rejected with a layer-keyed
/// `RankMismatch`.
#[test]
fn quantize_weights_rank1_uint32_triple_errors() {
    let weights = Weights::from([
        (
            "model.bad.weight".to_string(),
            arr_u32(&[0_u32, 0, 0, 0], &[4]),
        ),
        ("model.bad.scales".to_string(), arr_u32(&[1_u32], &[1])),
        ("model.bad.biases".to_string(), arr_u32(&[0_u32], &[1])),
    ]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(64, 4));

    let err = quantize_weights(weights, &cfg, &default_eligible).unwrap_err();
    let Error::LayerKeyed(payload) = &err else {
        panic!("expected Error::LayerKeyed, got: {err:?}");
    };
    assert!(
        payload.layer().contains("model.bad"),
        "LayerKeyed must name the malformed layer, got layer={:?}",
        payload.layer()
    );
    assert!(
        matches!(payload.inner(), Error::RankMismatch(_)),
        "inner must be RankMismatch for rank-1 `.weight`, got: {:?}",
        payload.inner()
    );
}
/// A pre-quantized triple whose scales have two groups per row (more than the
/// configured group size implies) is still passed through untouched.
#[test]
fn quantize_weights_pre_quantized_triple_passes_through_to_mlxc() {
    let n_rows = 2_usize;
    let weights = Weights::from([
        (
            "model.foo.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 8], &[n_rows, 8]),
        ),
        (
            "model.foo.scales".to_string(),
            arr_f32(&vec![1.0_f32; n_rows * 2], &[n_rows, 2]),
        ),
        (
            "model.foo.biases".to_string(),
            arr_f32(&vec![0.0_f32; n_rows * 2], &[n_rows, 2]),
        ),
    ]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(64, 4));

    let out = quantize_weights(weights, &cfg, &default_eligible)
        .expect("triple now passes through; mlx-c validates per-mode params at call time");
    let packed = out.get("model.foo.weight").expect(".weight");
    assert_eq!(packed.shape(), vec![n_rows, 8]);
    assert_eq!(packed.dtype().unwrap(), crate::dtype::Dtype::U32);
    let kept_scales = out.get("model.foo.scales").expect(".scales");
    assert_eq!(kept_scales.shape(), vec![n_rows, 2]);
    assert!(out.contains_key("model.foo.biases"));
}
/// A pre-quantized triple under a 3-bit affine config is accepted and passed through.
#[test]
fn quantize_weights_pre_quantized_bits3_triple_passes_through() {
    let n_rows = 2_usize;
    let column = |fill: f32| arr_f32(&vec![fill; n_rows], &[n_rows, 1]);
    let weights = Weights::from([
        (
            "model.foo.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 8], &[n_rows, 8]),
        ),
        ("model.foo.scales".to_string(), column(1.0)),
        ("model.foo.biases".to_string(), column(0.0)),
    ]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(64, 3));

    let out = quantize_weights(weights, &cfg, &default_eligible)
        .expect("bits=3 triple passes through; mlx supports bits ∈ {2,3,4,5,6,8}");
    let packed = out.get("model.foo.weight").expect(".weight");
    assert_eq!(packed.shape(), vec![n_rows, 8]);
    assert_eq!(packed.dtype().unwrap(), crate::dtype::Dtype::U32);
}
/// A well-formed affine triple is kept: packed weight, scales shape, and biases
/// all survive quantization unchanged.
#[test]
fn quantize_weights_valid_triple_skipped() {
    let n_rows = 2_usize;
    let column = |fill: f32| arr_f32(&vec![fill; n_rows], &[n_rows, 1]);
    let weights = Weights::from([
        (
            "model.foo.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 8], &[n_rows, 8]),
        ),
        ("model.foo.scales".to_string(), column(1.0)),
        ("model.foo.biases".to_string(), column(0.0)),
    ]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(64, 4));

    let out =
        quantize_weights(weights, &cfg, &default_eligible).expect("valid triple passes through");
    let packed = out.get("model.foo.weight").expect(".weight");
    assert_eq!(packed.shape(), vec![n_rows, 8]);
    assert_eq!(packed.dtype().unwrap(), crate::dtype::Dtype::U32);
    let kept_scales = out.get("model.foo.scales").expect(".scales");
    assert_eq!(kept_scales.shape(), vec![n_rows, 1]);
    assert!(out.contains_key("model.foo.biases"));
}
/// A layer configured as `Skip` that nevertheless carries a packed `.weight` and
/// `.scales` is rejected with a layer-keyed `KeyCollision`.
#[test]
fn quantize_weights_triple_on_skip_path_errors() {
    let n_rows = 2_usize;
    let weights = Weights::from([
        (
            "model.embed_tokens.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 8], &[n_rows, 8]),
        ),
        (
            "model.embed_tokens.scales".to_string(),
            arr_f32(&vec![1.0_f32; n_rows], &[n_rows, 1]),
        ),
    ]);
    let cfg = PerLayerQuantization {
        per_layer: HashMap::from([("model.embed_tokens".to_string(), QuantizationOption::Skip)]),
        quantization: Some(Quantization::affine(64, 4)),
    };

    let err = quantize_weights(weights, &cfg, &default_eligible).unwrap_err();
    let Error::LayerKeyed(payload) = &err else {
        panic!("expected Error::LayerKeyed, got: {err:?}");
    };
    assert!(
        payload.layer().contains("model.embed_tokens"),
        "LayerKeyed must name the Skip layer, got layer={:?}",
        payload.layer()
    );
    assert!(
        matches!(payload.inner(), Error::KeyCollision(_)),
        "inner must be KeyCollision for Skip-with-stale-scales, got: {:?}",
        payload.inner()
    );
}
/// An affine-mode packed `.weight` + `.scales` without `.biases` is rejected with a
/// layer-keyed `MissingKey` that names the absent `.biases` key.
#[test]
fn quantize_weights_affine_triple_missing_biases_errors() {
    let n_rows = 2_usize;
    let weights = Weights::from([
        (
            "model.affine_missing.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 8], &[n_rows, 8]),
        ),
        (
            "model.affine_missing.scales".to_string(),
            arr_f32(&vec![1.0_f32; n_rows], &[n_rows, 1]),
        ),
    ]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(64, 4));

    let err = quantize_weights(weights, &cfg, &default_eligible).unwrap_err();
    let Error::LayerKeyed(payload) = &err else {
        panic!("expected Error::LayerKeyed, got: {err:?}");
    };
    assert!(
        payload.layer().contains("model.affine_missing"),
        "LayerKeyed must name the incomplete layer, got layer={:?}",
        payload.layer()
    );
    let Error::MissingKey(inner) = payload.inner() else {
        panic!(
            "inner must be MissingKey for affine missing biases, got: {:?}",
            payload.inner()
        );
    };
    assert!(
        inner.key().contains(".biases"),
        "MissingKey must name the missing `.biases` sibling, got key={:?}",
        inner.key()
    );
}
/// In mxfp4 mode a `.biases` sibling is unexpected: quantizing is rejected with a
/// layer-keyed `KeyCollision` that names the `.biases` key.
#[test]
fn quantize_weights_mxfp4_triple_with_stale_biases_errors() {
    let n_rows = 2_usize;
    let column = |fill: f32| arr_f32(&vec![fill; n_rows], &[n_rows, 1]);
    let weights = Weights::from([
        (
            "model.mxfp4_stale.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 4], &[n_rows, 4]),
        ),
        ("model.mxfp4_stale.scales".to_string(), column(1.0)),
        ("model.mxfp4_stale.biases".to_string(), column(0.0)),
    ]);
    let mxfp4 = Quantization {
        group_size: 32,
        bits: 4,
        mode: QuantMode::Mxfp4,
    };
    let cfg = PerLayerQuantization::from_global(mxfp4);

    let err = quantize_weights(weights, &cfg, &default_eligible).unwrap_err();
    let Error::LayerKeyed(payload) = &err else {
        panic!("expected Error::LayerKeyed, got: {err:?}");
    };
    assert!(
        payload.layer().contains("model.mxfp4_stale"),
        "LayerKeyed must name the offending layer, got layer={:?}",
        payload.layer()
    );
    let Error::KeyCollision(inner) = payload.inner() else {
        panic!(
            "inner must be KeyCollision for mxfp4-with-stale-biases, got: {:?}",
            payload.inner()
        );
    };
    assert!(
        inner.key().contains(".biases"),
        "KeyCollision must name the stale `.biases` sibling, got key={:?}",
        inner.key()
    );
}
/// In mxfp4 mode a packed `.weight` + `.scales` pair (no `.biases`) is valid and
/// passes through without a `.biases` being synthesized.
#[test]
fn quantize_weights_valid_mxfp4_scales_only_triple_passes() {
    let n_rows = 2_usize;
    let weights = Weights::from([
        (
            "model.mxfp4_ok.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 4], &[n_rows, 4]),
        ),
        (
            "model.mxfp4_ok.scales".to_string(),
            arr_f32(&vec![1.0_f32; n_rows], &[n_rows, 1]),
        ),
    ]);
    let mxfp4 = Quantization {
        group_size: 32,
        bits: 4,
        mode: QuantMode::Mxfp4,
    };
    let cfg = PerLayerQuantization::from_global(mxfp4);

    let out = quantize_weights(weights, &cfg, &default_eligible)
        .expect("scale-only mxfp4 triple passes through");
    let packed = out.get("model.mxfp4_ok.weight").expect(".weight");
    assert_eq!(packed.shape(), vec![n_rows, 4]);
    assert_eq!(packed.dtype().unwrap(), crate::dtype::Dtype::U32);
    let kept_scales = out.get("model.mxfp4_ok.scales").expect(".scales");
    assert_eq!(kept_scales.shape(), vec![n_rows, 1]);
    assert!(!out.contains_key("model.mxfp4_ok.biases"));
}
/// Dequantizing an affine `.weight` + `.scales` pair with no `.biases` fails with
/// a layer-keyed `MissingKey` naming the `.biases` key.
#[test]
fn dequantize_weights_affine_missing_biases_errors() {
    let n_rows = 2_usize;
    let weights = Weights::from([
        (
            "model.affine_no_bias.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 8], &[n_rows, 8]),
        ),
        (
            "model.affine_no_bias.scales".to_string(),
            arr_f32(&vec![1.0_f32; n_rows], &[n_rows, 1]),
        ),
    ]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(64, 4));

    let err = dequantize_weights(weights, &cfg).unwrap_err();
    let Error::LayerKeyed(payload) = &err else {
        panic!("expected Error::LayerKeyed, got: {err:?}");
    };
    assert!(
        payload.layer().contains("model.affine_no_bias"),
        "LayerKeyed must name the layer, got layer={:?}",
        payload.layer()
    );
    let Error::MissingKey(inner) = payload.inner() else {
        panic!(
            "inner must be MissingKey for affine triple missing biases, got: {:?}",
            payload.inner()
        );
    };
    assert!(
        inner.key().contains(".biases"),
        "MissingKey must name the missing `.biases` sibling, got key={:?}",
        inner.key()
    );
}
/// Dequantizing an mxfp4 triple that carries a `.biases` sibling fails with a
/// layer-keyed `KeyCollision` naming the `.biases` key.
#[test]
fn dequantize_weights_mxfp4_with_stale_biases_errors() {
    let n_rows = 2_usize;
    let column = |fill: f32| arr_f32(&vec![fill; n_rows], &[n_rows, 1]);
    let weights = Weights::from([
        (
            "model.mxfp4_stale.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 4], &[n_rows, 4]),
        ),
        ("model.mxfp4_stale.scales".to_string(), column(1.0)),
        ("model.mxfp4_stale.biases".to_string(), column(0.0)),
    ]);
    let mxfp4 = Quantization {
        group_size: 32,
        bits: 4,
        mode: QuantMode::Mxfp4,
    };
    let cfg = PerLayerQuantization::from_global(mxfp4);

    let err = dequantize_weights(weights, &cfg).unwrap_err();
    let Error::LayerKeyed(payload) = &err else {
        panic!("expected Error::LayerKeyed, got: {err:?}");
    };
    assert!(
        payload.layer().contains("model.mxfp4_stale"),
        "LayerKeyed must name the layer, got layer={:?}",
        payload.layer()
    );
    let Error::KeyCollision(inner) = payload.inner() else {
        panic!(
            "inner must be KeyCollision for mxfp4-with-stale-biases on dequantize, got: {:?}",
            payload.inner()
        );
    };
    assert!(
        inner.key().contains(".biases"),
        "KeyCollision must name the stale `.biases` sibling, got key={:?}",
        inner.key()
    );
}
/// Dequantizing a packed U32 `.weight` with `.biases` but no `.scales` fails with a
/// layer-keyed `MissingKey` naming the `.scales` key.
#[test]
fn dequantize_weights_orphan_biases_with_packed_weight_errors() {
    let n_rows = 2_usize;
    let weights = Weights::from([
        (
            "model.orphan_bias.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 8], &[n_rows, 8]),
        ),
        (
            "model.orphan_bias.biases".to_string(),
            arr_f32(&vec![0.0_f32; n_rows], &[n_rows, 1]),
        ),
    ]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(64, 4));

    let err = dequantize_weights(weights, &cfg).unwrap_err();
    let Error::LayerKeyed(payload) = &err else {
        panic!("expected Error::LayerKeyed, got: {err:?}");
    };
    assert!(
        payload.layer().contains("model.orphan_bias"),
        "LayerKeyed must name the layer, got layer={:?}",
        payload.layer()
    );
    let Error::MissingKey(inner) = payload.inner() else {
        panic!(
            "inner must be MissingKey for orphan biases without scales, got: {:?}",
            payload.inner()
        );
    };
    assert!(
        inner.key().contains(".scales"),
        "MissingKey must name the missing `.scales` sibling, got key={:?}",
        inner.key()
    );
}
/// A dense F32 `.weight` next to an F32 `.biases` (no `.scales`) is not a quantized
/// triple; dequantization must return both tensors bit-for-bit unchanged.
#[test]
fn dequantize_weights_dense_weight_with_biases_passes_through() {
    let n_rows = 2_usize;
    let n_cols = 8_usize;
    let ramp: Vec<f32> = (0..n_rows * n_cols).map(|i| i as f32).collect();
    let weights = Weights::from([
        (
            "model.dense.weight".to_string(),
            arr_f32(&ramp, &[n_rows, n_cols]),
        ),
        (
            "model.dense.biases".to_string(),
            arr_f32(&vec![0.5_f32; n_cols], &[n_cols]),
        ),
    ]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(64, 4));

    let out = dequantize_weights(weights, &cfg)
        .expect("dense `.weight` (F32) + `.biases` (F32) with no `.scales` must pass through");
    let owned_copy = |key: &str, what: &str| out.get(key).expect(what).try_clone().unwrap();
    let mut w_out = owned_copy("model.dense.weight", "passed-through .weight");
    let mut b_out = owned_copy("model.dense.biases", "passed-through .biases");

    assert_eq!(w_out.dtype().unwrap(), Dtype::F32);
    assert_eq!(b_out.dtype().unwrap(), Dtype::F32);
    assert_eq!(w_out.shape(), vec![n_rows, n_cols]);
    assert_eq!(b_out.shape(), vec![n_cols]);

    let w_vec: Vec<f32> = w_out.to_vec().unwrap();
    let b_vec: Vec<f32> = b_out.to_vec().unwrap();
    assert_eq!(
        w_vec, ramp,
        "dense `.weight` data must be passed through verbatim"
    );
    assert_eq!(
        b_vec,
        vec![0.5_f32; n_cols],
        "`.biases` data must be passed through verbatim"
    );
}
/// One packed word expands to eight U32 nibbles. `0xFFFF` sets the four low
/// physical nibbles, which the AWQ de-scramble places at the even output slots.
#[test]
fn unpack_awq_weights_single_int32_gives_8_nibbles() {
    let packed = Array::from_slice::<u32>(&[0xFFFF_u32], &(1usize, 1)).unwrap();
    let mut unpacked = unpack_awq_weights(&packed).unwrap();
    assert_eq!(unpacked.shape(), vec![1, 8]);
    assert_eq!(unpacked.dtype().unwrap(), Dtype::U32);
    let expected: Vec<u32> = (0..8_u32)
        .map(|slot| if slot % 2 == 0 { 0xF } else { 0 })
        .collect();
    assert_eq!(unpacked.to_vec::<u32>().unwrap(), expected);
}
/// Logical value `v` is stored at the AWQ-interleaved bit offset; unpacking must
/// recover the logical order 0..8.
#[test]
fn unpack_awq_weights_reverses_awq_scramble() {
    // (logical value, bit shift) pairs; logical 0 sits at shift 0 and contributes nothing.
    let placements = [
        (1_u32, 16_u32),
        (2, 4),
        (3, 20),
        (4, 8),
        (5, 24),
        (6, 12),
        (7, 28),
    ];
    let packed_val: u32 = placements
        .iter()
        .fold(0, |word, &(value, shift)| word | (value << shift));
    assert_eq!(packed_val, 0x7531_6420);

    let packed = Array::from_slice::<u32>(&[packed_val], &(1usize, 1)).unwrap();
    let mut unpacked = unpack_awq_weights(&packed).unwrap();
    assert_eq!(unpacked.shape(), vec![1, 8]);
    let logical_order: Vec<u32> = (0..8).collect();
    assert_eq!(unpacked.to_vec::<u32>().unwrap(), logical_order);
}
/// Unpacking keeps the row count and multiplies the column count by eight.
#[test]
fn unpack_awq_weights_preserves_row_count_expands_cols_8x() {
    let packed = Array::from_slice::<u32>(&[0u32; 6], &(3usize, 2)).unwrap();
    let mut unpacked = unpack_awq_weights(&packed).unwrap();
    let expected_shape = vec![3, 2 * 8];
    assert_eq!(unpacked.shape(), expected_shape);
    let values = unpacked.to_vec::<u32>().unwrap();
    assert_eq!(values, vec![0u32; 3 * 16]);
}
/// All-zero packed words unpack to all-zero nibbles.
#[test]
fn unpack_awq_weights_handles_zero_input() {
    let packed = Array::from_slice::<u32>(&[0u32, 0, 0, 0], &(2usize, 2)).unwrap();
    let mut unpacked = unpack_awq_weights(&packed).unwrap();
    assert_eq!(unpacked.shape(), vec![2, 16]);
    let values = unpacked.to_vec::<u32>().unwrap();
    assert_eq!(values, vec![0u32; 2 * 16]);
}
/// Rank-1 and rank-3 inputs are rejected with `RankMismatch`.
#[test]
fn unpack_awq_weights_rejects_non_2d() {
    let rank1 = Array::from_slice::<u32>(&[0u32; 4], &(4usize,)).unwrap();
    let rank1_err = unpack_awq_weights(&rank1).unwrap_err();
    assert!(
        matches!(rank1_err, Error::RankMismatch(_)),
        "1-D should be RankMismatch, got {rank1_err:?}"
    );

    let rank3 = Array::from_slice::<u32>(&[0u32; 8], &(2usize, 2, 2)).unwrap();
    let rank3_err = unpack_awq_weights(&rank3).unwrap_err();
    assert!(matches!(rank3_err, Error::RankMismatch(_)));
}
/// A float-typed input is rejected with `UnsupportedDtype`.
#[test]
fn unpack_awq_weights_rejects_non_32bit_int_dtype() {
    let float_input = Array::from_slice::<f32>(&[0.0_f32; 4], &(2usize, 2)).unwrap();
    let outcome = unpack_awq_weights(&float_input);
    let err = outcome.unwrap_err();
    assert!(
        matches!(err, Error::UnsupportedDtype(_)),
        "f32 dtype should be UnsupportedDtype, got {err:?}"
    );
}
/// An `i32` input (negative, so the sign bit is set) is accepted and unpacks to the
/// same U32 nibbles as the bit-identical `u32` input.
#[test]
fn unpack_awq_weights_accepts_i32_input() {
    let raw: u32 = 0xF0FF_FFFF;
    // Same bit pattern reinterpreted as i32; the top bit makes it negative.
    let signed = raw as i32;
    assert!(signed < 0, "fixture must be negative to exercise the sign bit");

    let i32_packed = Array::from_slice::<i32>(&[signed], &(1usize, 1)).unwrap();
    let u32_packed = Array::from_slice::<u32>(&[raw], &(1usize, 1)).unwrap();
    let mut from_i32 = unpack_awq_weights(&i32_packed).expect("i32 input should be accepted");
    let mut from_u32 = unpack_awq_weights(&u32_packed).expect("u32 input still accepted");

    for unpacked in [&from_i32, &from_u32] {
        assert_eq!(unpacked.shape(), vec![1, 8]);
    }
    assert_eq!(from_i32.dtype().unwrap(), Dtype::U32);

    let i32_nibbles = from_i32.to_vec::<u32>().unwrap();
    let u32_nibbles = from_u32.to_vec::<u32>().unwrap();
    assert_eq!(
        i32_nibbles, u32_nibbles,
        "i32 input must produce the SAME nibbles as the equivalent u32 input (bit-preserving)"
    );
}
/// A `u32` input is accepted, produces U32 output, and is de-scrambled into the
/// expected logical nibble order.
///
/// In `0xF0FF_FFFF`, only the nibble at bit shift 24 is zero. Output slot `k`
/// reads the nibble at shift `[0, 16, 4, 20, 8, 24, 12, 28][k]` (the order pinned
/// by `unpack_awq_weights_reverses_awq_scramble`), so the zero lands in slot 5.
/// Checking the values as well as shape and dtype ensures a wrong de-scramble
/// cannot pass.
#[test]
fn unpack_awq_weights_accepts_u32_input() {
    let raw: u32 = 0xF0FF_FFFF;
    let packed = Array::from_slice::<u32>(&[raw], &(1usize, 1)).unwrap();
    let mut out = unpack_awq_weights(&packed).expect("u32 input accepted");
    assert_eq!(out.shape(), vec![1, 8]);
    assert_eq!(out.dtype().unwrap(), Dtype::U32);
    assert_eq!(
        out.to_vec::<u32>().unwrap(),
        vec![0xF, 0xF, 0xF, 0xF, 0xF, 0, 0xF, 0xF],
        "u32 input must de-scramble into logical nibble order"
    );
}
/// Test helper: packs eight 4-bit values into one `u32` word, placing value `k`
/// (masked to its low 4 bits) at bit offset `AWQ_SHIFTS[k]`.
///
/// NOTE(review): presumably this is the AWQ interleaved layout that
/// `unpack_awq_weights` reverses. Confirm against the `AWQ_SHIFTS` definition.
fn awq_pack_one_row(nibbles: [u32; 8]) -> u32 {
    nibbles
        .iter()
        .enumerate()
        .fold(0u32, |word, (k, &nibble)| {
            word | ((nibble & 0xF) << AWQ_SHIFTS[k])
        })
}
/// Reference AWQ dequantization for one element: `(nibble - zero) * scale`.
///
/// The subtraction is done in `i32` so a nibble smaller than its zero point
/// gives a negative result instead of wrapping.
fn awq_dequant(nibble: u32, zero: u32, scale: f32) -> f32 {
    let centered = nibble as i32 - zero as i32;
    scale * centered as f32
}
/// Reference MLX affine dequantization for one element: `nibble * scale + bias`.
fn mlx_dequant(nibble: u32, scale: f32, bias: f32) -> f32 {
    let magnitude = nibble as f32;
    bias + scale * magnitude
}
/// End-to-end AWQ → MLX conversion on a hand-built 8×8 layer.
///
/// The repacked `.weight` must be the transpose of the AWQ nibble grid.
/// MLX affine dequantization (`q * scale + bias`) must reproduce AWQ
/// dequantization (`(q - zero) * scale`) for every element.
#[test]
fn transform_awq_weights_round_trips_known_fixture() {
    let in_features = 8usize;
    let out_features = 8usize;
    let group_size = 4u32;
    let n_groups = 2usize;
    let group_len = group_size as usize;

    // AWQ layout: row `i` holds the `out_features` nibbles of input feature `i`.
    let awq_unpacked: Vec<Vec<u32>> = (0..in_features)
        .map(|i| {
            (0..out_features)
                .map(|o| (((i + 1) * 3 + o) % 16) as u32)
                .collect()
        })
        .collect();
    let qweight_data: Vec<u32> = awq_unpacked
        .iter()
        .map(|row| awq_pack_one_row(<[u32; 8]>::try_from(row.as_slice()).unwrap()))
        .collect();
    let qweight = Array::from_slice::<u32>(&qweight_data, &(in_features, 1)).unwrap();

    // One row of zero-point nibbles per quantization group.
    let qzero_unpacked: Vec<Vec<u32>> = (0..n_groups)
        .map(|g| (0..out_features).map(|o| ((g + o) % 16) as u32).collect())
        .collect();
    let qzeros_data: Vec<u32> = qzero_unpacked
        .iter()
        .map(|row| awq_pack_one_row(<[u32; 8]>::try_from(row.as_slice()).unwrap()))
        .collect();
    let qzeros = Array::from_slice::<u32>(&qzeros_data, &(n_groups, 1)).unwrap();

    let scales_data: Vec<f32> = (0..n_groups * out_features)
        .map(|i| 0.1_f32 * (i as f32 + 1.0))
        .collect();
    let scales = Array::from_slice::<f32>(&scales_data, &(n_groups, out_features)).unwrap();

    let weights: Weights = HashMap::from([
        ("layer.qweight".to_string(), qweight),
        ("layer.qzeros".to_string(), qzeros),
        ("layer.scales".to_string(), scales),
    ]);
    let config = AwqLoadConfig {
        bits: 4,
        group_size,
        zero_point: true,
        version: "gemm".into(),
    };
    let (out, plq) = transform_awq_weights(weights, &config).expect("transform");

    let g = plq.quantization.expect("global quant");
    assert_eq!(g.group_size, group_size as i32);
    assert_eq!(g.bits, 4);
    assert_eq!(g.mode, QuantMode::Affine);

    // Owned copies of the converted tensors.
    let take = |key: &str| out.get(key).expect(key).try_clone().unwrap();
    let mut weight_arr = take("layer.weight");
    let mut scales_arr = take("layer.scales");
    let mut biases_arr = take("layer.biases");
    assert!(
        !out.contains_key("layer.qweight"),
        "qweight key must be replaced by .weight"
    );
    assert!(
        !out.contains_key("layer.qzeros"),
        "qzeros key must be replaced by .biases"
    );
    assert_eq!(weight_arr.dtype().unwrap(), Dtype::U32);
    assert_eq!(weight_arr.shape(), vec![out_features, in_features / 8]);
    assert_eq!(scales_arr.shape(), vec![out_features, n_groups]);
    assert_eq!(biases_arr.shape(), vec![out_features, n_groups]);
    assert_eq!(scales_arr.dtype().unwrap(), Dtype::F32);
    assert_eq!(biases_arr.dtype().unwrap(), Dtype::F32);

    // Decode the MLX layout: nibble `k` of each word sits at bit `k * AWQ_BITS`.
    let words_per_row = in_features / 8;
    let weight_packed: Vec<u32> = weight_arr.to_vec().unwrap();
    let mlx_nibbles: Vec<Vec<u32>> = (0..out_features)
        .map(|o| {
            weight_packed[o * words_per_row..(o + 1) * words_per_row]
                .iter()
                .flat_map(|&word| {
                    (0..8u32).map(move |k| (word >> (k * AWQ_BITS)) & AWQ_NIBBLE_MASK)
                })
                .collect()
        })
        .collect();
    for (o, row) in mlx_nibbles.iter().enumerate() {
        for (i, &nibble_mlx) in row.iter().enumerate() {
            assert_eq!(
                nibble_mlx, awq_unpacked[i][o],
                "MLX-format nibble at (o={o}, i={i}) must equal AWQ-format nibble at (i={i}, o={o})"
            );
        }
    }

    // Element-wise dequantization equivalence, group by group.
    let scales_flat: Vec<f32> = scales_arr.to_vec().unwrap();
    let biases_flat: Vec<f32> = biases_arr.to_vec().unwrap();
    for o in 0..out_features {
        for g in 0..n_groups {
            let cell = o * n_groups + g;
            let (mlx_scale, mlx_bias) = (scales_flat[cell], biases_flat[cell]);
            let awq_scale = scales_data[g * out_features + o];
            let awq_zero = qzero_unpacked[g][o];
            for i in g * group_len..(g + 1) * group_len {
                let nibble = awq_unpacked[i][o];
                let awq_dq = awq_dequant(nibble, awq_zero, awq_scale);
                let mlx_dq = mlx_dequant(nibble, mlx_scale, mlx_bias);
                assert!(
                    (awq_dq - mlx_dq).abs() < 1e-4,
                    "AWQ dequant {awq_dq} != MLX dequant {mlx_dq} at (o={o}, g={g}, i={i}, nibble={nibble})"
                );
            }
        }
    }
}
/// Two AWQ layers in one checkpoint are each converted.
///
/// For every layer, `.weight` / `.scales` / `.biases` must be present and the
/// AWQ-only `.qweight` / `.qzeros` must be consumed. An unrelated F32 tensor
/// passes through unchanged.
#[test]
fn transform_awq_weights_handles_multiple_layers() {
    let group_size = 4u32;
    let in_features = 8usize;
    let out_features = 8usize;
    let n_groups = 2usize;
    // Builds the `.qweight` / `.qzeros` / `.scales` triple for one AWQ layer.
    let make_weights = |prefix: &str| -> Vec<(String, Array)> {
        let qw = Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();
        let qz = Array::from_slice::<u32>(&vec![0u32; n_groups], &(n_groups, 1)).unwrap();
        let scales_data: Vec<f32> = (0..n_groups * out_features)
            .map(|i| 0.1_f32 * (i as f32 + 1.0))
            .collect();
        let sc = Array::from_slice::<f32>(&scales_data, &(n_groups, out_features)).unwrap();
        vec![
            (format!("{prefix}.qweight"), qw),
            (format!("{prefix}.qzeros"), qz),
            (format!("{prefix}.scales"), sc),
        ]
    };
    let layer_prefixes = ["layer0.q", "layer1.q"];
    let mut weights: Weights = HashMap::new();
    for prefix in layer_prefixes {
        weights.extend(make_weights(prefix));
    }
    let passthrough = Array::from_slice::<f32>(&[1.0_f32; 16], &(2usize, 8)).unwrap();
    weights.insert("embed_tokens.weight".to_string(), passthrough);
    let config = AwqLoadConfig {
        bits: 4,
        group_size,
        zero_point: true,
        version: String::new(),
    };
    let (out, plq) = transform_awq_weights(weights, &config).expect("transform");

    // Check every layer and every key, not just one of each.
    for prefix in layer_prefixes {
        for suffix in ["weight", "scales", "biases"] {
            let key = format!("{prefix}.{suffix}");
            assert!(out.contains_key(&key), "converted key `{key}` must be present");
        }
        // `.scales` keeps its key; only these two AWQ-only tensors must disappear.
        for suffix in ["qweight", "qzeros"] {
            let key = format!("{prefix}.{suffix}");
            assert!(!out.contains_key(&key), "AWQ source key `{key}` must be consumed");
        }
    }

    let mut pt = out
        .get("embed_tokens.weight")
        .expect("pass-through")
        .try_clone()
        .unwrap();
    assert_eq!(pt.shape(), vec![2, 8]);
    assert_eq!(pt.to_vec::<f32>().unwrap(), vec![1.0_f32; 16]);
    let g = plq.quantization.unwrap();
    assert_eq!(g.group_size, group_size as i32);
    assert_eq!(g.bits, 4);
    assert!(plq.per_layer.is_empty());
}
/// A `.qweight` with no `.scales` companion is reported as a missing key.
///
/// The error names both the expected key and the violated rule.
#[test]
fn transform_awq_weights_rejects_missing_scales() {
    let rows = 8usize;
    let qweight = Array::from_slice::<u32>(&vec![0u32; rows], &(rows, 1)).unwrap();
    let weights: Weights = HashMap::from([("layer.qweight".to_string(), qweight)]);
    let config = AwqLoadConfig::default();
    let err = transform_awq_weights(weights, &config).unwrap_err();
    let payload = match &err {
        Error::MissingKey(p) => p,
        _ => panic!("expected Error::MissingKey, got {err:?}"),
    };
    assert_eq!(payload.key(), "layer.scales");
    let context = payload.context();
    assert!(
        context.contains("AWQ `.qweight` missing its `.scales` companion"),
        "context names the rule: {}",
        context
    );
}
/// `.scales` whose shape does not pair with `.qweight` must be rejected.
///
/// Here `.scales` is 4×8 for an 8-row `.qweight` with `group_size` 4, which
/// presumably implies 2 groups. The error is a layer-keyed
/// `ShapePairMismatch`.
#[test]
fn transform_awq_weights_rejects_mismatched_shapes() {
    let qweight = Array::from_slice::<u32>(&[0u32; 8], &(8usize, 1)).unwrap();
    let scales = Array::from_slice::<f32>(&[0.1_f32; 32], &(4usize, 8)).unwrap();
    let weights: Weights = HashMap::from([
        ("layer.qweight".to_string(), qweight),
        ("layer.scales".to_string(), scales),
    ]);
    let config = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: false,
        version: String::new(),
    };
    let err = transform_awq_weights(weights, &config).unwrap_err();
    let is_shape_mismatch = match &err {
        Error::LayerKeyed(p) => matches!(p.inner(), Error::ShapePairMismatch(_)),
        _ => false,
    };
    assert!(
        is_shape_mismatch,
        "expected LayerKeyed(ShapePairMismatch), got {err:?}"
    );
}
/// A layer carrying a `.g_idx` tensor is rejected.
///
/// The error's debug rendering must mention `g_idx`.
#[test]
fn transform_awq_weights_rejects_g_idx() {
    let weights: Weights = HashMap::from([
        (
            "layer.qweight".to_string(),
            Array::from_slice::<u32>(&[0u32; 8], &(8usize, 1)).unwrap(),
        ),
        (
            "layer.scales".to_string(),
            Array::from_slice::<f32>(&[0.1_f32; 16], &(2usize, 8)).unwrap(),
        ),
        (
            "layer.g_idx".to_string(),
            Array::from_slice::<i32>(&[0i32, 1, 0, 1, 0, 1, 0, 1], &(8usize,)).unwrap(),
        ),
    ]);
    let config = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: false,
        version: String::new(),
    };
    let msg = format!("{:?}", transform_awq_weights(weights, &config).unwrap_err());
    assert!(msg.contains("g_idx"), "error must mention g_idx, got: {msg}");
}
/// Only 4-bit AWQ is accepted.
///
/// `bits: 8` must be reported as `OutOfRange`, even with an empty weight map.
/// The error names the rule, the constraint, and the offending value.
#[test]
fn transform_awq_weights_rejects_non_4_bits() {
    let no_weights: Weights = HashMap::new();
    let config = AwqLoadConfig {
        bits: 8,
        ..AwqLoadConfig::default()
    };
    let err = transform_awq_weights(no_weights, &config).unwrap_err();
    match &err {
        Error::OutOfRange(p) => {
            let context = p.context();
            assert!(
                context.contains("AWQ bits"),
                "context names the AWQ bits rule: {}",
                context
            );
            let requirement = p.requirement();
            assert!(
                requirement.contains("must be 4"),
                "requirement names the constraint: {}",
                requirement
            );
            assert_eq!(p.value(), "8");
        }
        _ => panic!("expected Error::OutOfRange, got {err:?}"),
    }
}
/// With `zero_point: false` and no `.qzeros`, every MLX bias must equal
/// `-2^(bits-1) * scale`, which is `-8.0` for the unit scales used here.
#[test]
fn transform_awq_weights_symmetric_uses_implicit_zero() {
    let in_features = 8usize;
    let out_features = 8usize;
    let group_size = 4u32;
    let n_groups = 2usize;
    let qw = Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();
    let scales_data: Vec<f32> = vec![1.0_f32; n_groups * out_features];
    let sc = Array::from_slice::<f32>(&scales_data, &(n_groups, out_features)).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert("layer.qweight".to_string(), qw);
    weights.insert("layer.scales".to_string(), sc);
    let config = AwqLoadConfig {
        bits: 4,
        group_size,
        zero_point: false,
        version: String::new(),
    };
    let (out, _) = transform_awq_weights(weights, &config).expect("transform");
    let mut biases_arr = out
        .get("layer.biases")
        .expect("layer.biases")
        .try_clone()
        .unwrap();
    let biases: Vec<f32> = biases_arr.to_vec().unwrap();
    // Guard against a vacuous pass: the per-element loop below checks nothing
    // if the buffer is empty.
    assert_eq!(
        biases.len(),
        n_groups * out_features,
        "expected one bias per (output row, group)"
    );
    for &b in &biases {
        assert!(
            (b + 8.0_f32).abs() < 1e-5,
            "symmetric bias must be -2^(bits-1) * scale = -8.0, got {b}"
        );
    }
}
/// A checkpoint with no `.qweight` tensors must come through unchanged.
///
/// The reported global quantization reflects `AwqLoadConfig::default()`:
/// 4 bits, group size 128, affine.
#[test]
fn transform_awq_weights_empty_input_is_noop() {
    let pt = Array::from_slice::<f32>(&[1.0_f32, 2.0, 3.0], &(3usize,)).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert("layer.weight".to_string(), pt);
    let config = AwqLoadConfig::default();
    let (out, plq) = transform_awq_weights(weights, &config).expect("transform");
    // "No-op" means nothing is added or dropped, not merely that this key survives.
    assert_eq!(
        out.len(),
        1,
        "no AWQ layers: output must contain only the pass-through tensor"
    );
    let mut got = out
        .get("layer.weight")
        .expect("pass-through")
        .try_clone()
        .unwrap();
    assert_eq!(got.to_vec::<f32>().unwrap(), vec![1.0_f32, 2.0, 3.0]);
    let g = plq.quantization.unwrap();
    assert_eq!(g.bits, 4);
    assert_eq!(g.group_size, 128);
    assert_eq!(g.mode, QuantMode::Affine);
}
/// Minimal AWQ layer under the `layer.` prefix, shared by the version and
/// key-collision tests.
///
/// Contains an all-zero 8×1 `.qweight`, an all-zero 2×1 `.qzeros`, and
/// 2×8 `.scales` filled with `1.0`.
fn awq_gemm_fixture_weights() -> Weights {
    let (in_features, out_features, n_groups) = (8usize, 8usize, 2usize);
    let qweight = Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();
    let qzeros = Array::from_slice::<u32>(&vec![0u32; n_groups], &(n_groups, 1)).unwrap();
    let unit_scales = vec![1.0_f32; n_groups * out_features];
    let scales = Array::from_slice::<f32>(&unit_scales, &(n_groups, out_features)).unwrap();
    HashMap::from([
        ("layer.qweight".to_string(), qweight),
        ("layer.qzeros".to_string(), qzeros),
        ("layer.scales".to_string(), scales),
    ])
}
/// The `gemv` version string is not supported.
///
/// It must be rejected as an unknown enum value that echoes `"gemv"`.
#[test]
fn transform_awq_weights_rejects_gemv_version() {
    let config = AwqLoadConfig {
        version: "gemv".into(),
        bits: 4,
        group_size: 4,
        zero_point: true,
    };
    let err = transform_awq_weights(awq_gemm_fixture_weights(), &config).unwrap_err();
    let Error::UnknownEnumValue(payload) = &err else {
        panic!("expected Error::UnknownEnumValue, got: {err:?}");
    };
    assert_eq!(
        payload.value(),
        "gemv",
        "UnknownEnumValue must name the offending 'gemv' version"
    );
}
/// An arbitrary unrecognized version string is rejected as an unknown enum
/// value that echoes the offending string.
#[test]
fn transform_awq_weights_rejects_unknown_version() {
    let config = AwqLoadConfig {
        version: "unsupported".into(),
        bits: 4,
        group_size: 4,
        zero_point: true,
    };
    let err = transform_awq_weights(awq_gemm_fixture_weights(), &config).unwrap_err();
    let Error::UnknownEnumValue(payload) = &err else {
        panic!("expected Error::UnknownEnumValue, got: {err:?}");
    };
    assert_eq!(
        payload.value(),
        "unsupported",
        "UnknownEnumValue must name the offending version"
    );
}
/// An empty version string is accepted; per the assertion message it is the
/// serde default.
#[test]
fn transform_awq_weights_accepts_empty_version() {
    let config = AwqLoadConfig {
        version: String::new(),
        bits: 4,
        group_size: 4,
        zero_point: true,
    };
    let result = transform_awq_weights(awq_gemm_fixture_weights(), &config);
    result.expect("empty version (serde default) must be accepted");
}
/// An explicit `"gemm"` version string is accepted.
#[test]
fn transform_awq_weights_accepts_gemm_version() {
    let config = AwqLoadConfig {
        version: "gemm".into(),
        bits: 4,
        group_size: 4,
        zero_point: true,
    };
    let result = transform_awq_weights(awq_gemm_fixture_weights(), &config);
    result.expect("explicit 'gemm' version must be accepted");
}
/// Signed `i32` `qweight` / `qzeros` must be accepted, including packed words
/// that are negative as `i32`. The converted `.weight` is still `U32`.
#[test]
fn transform_awq_weights_accepts_i32_qweight_and_qzeros() {
    let in_features = 8usize;
    let out_features = 8usize;
    let group_size = 4u32;
    let n_groups = 2usize;
    // Slots 0..=6 hold (row + slot) % 16; slot 7 is pinned to 0xF. The assert
    // below confirms at least one packed word has its sign bit set.
    let qweight_data_u32: Vec<u32> = (0..in_features)
        .map(|i| {
            let mut nibbles = [0xF_u32; 8];
            for (k, slot) in nibbles.iter_mut().take(7).enumerate() {
                *slot = ((i + k) % 16) as u32;
            }
            awq_pack_one_row(nibbles)
        })
        .collect();
    let qweight_data_i32: Vec<i32> = qweight_data_u32.iter().copied().map(|u| u as i32).collect();
    assert!(
        qweight_data_i32.iter().any(|&v| v < 0),
        "fixture must contain a negative i32 to exercise the high-bit case"
    );
    let qzeros_data_u32: Vec<u32> = (0..n_groups)
        .map(|g| {
            let row: [u32; 8] = std::array::from_fn(|o| ((g + o) % 16) as u32);
            awq_pack_one_row(row)
        })
        .collect();
    let qzeros_data_i32: Vec<i32> = qzeros_data_u32.iter().copied().map(|u| u as i32).collect();
    let scales_data: Vec<f32> = (0..n_groups * out_features)
        .map(|i| 0.1_f32 * (i as f32 + 1.0))
        .collect();
    let weights_i32: Weights = HashMap::from([
        (
            "layer.qweight".to_string(),
            Array::from_slice::<i32>(&qweight_data_i32, &(in_features, 1)).unwrap(),
        ),
        (
            "layer.qzeros".to_string(),
            Array::from_slice::<i32>(&qzeros_data_i32, &(n_groups, 1)).unwrap(),
        ),
        (
            "layer.scales".to_string(),
            Array::from_slice::<f32>(&scales_data, &(n_groups, out_features)).unwrap(),
        ),
    ]);
    let config = AwqLoadConfig {
        bits: 4,
        group_size,
        zero_point: true,
        version: "gemm".into(),
    };
    let (out, plq) =
        transform_awq_weights(weights_i32, &config).expect("i32 qweight + qzeros accepted");
    let weight_dtype = out
        .get("layer.weight")
        .expect("layer.weight")
        .dtype()
        .unwrap();
    assert_eq!(weight_dtype, Dtype::U32);
    let g = plq.quantization.expect("global quant");
    assert_eq!(g.bits, 4);
    assert_eq!(g.group_size, group_size as i32);
}
/// Feeding the same packed words as `i32` instead of `u32` must act as a pure
/// bit reinterpretation.
///
/// The converted `.weight` and `.biases` must be identical for both input
/// dtypes. The fixtures include sign-bit-set words in both `qweight` and
/// `qzeros`.
#[test]
fn transform_awq_weights_preserves_bit_pattern_on_i32_input() {
    let in_features = 8usize;
    let out_features = 8usize;
    let group_size = 4u32;
    let n_groups = 2usize;
    let qweight_data_u32: Vec<u32> = (0..in_features)
        .map(|i| {
            let nibbles = [
                (i % 16) as u32,
                ((i + 7) % 16) as u32,
                ((i + 3) % 16) as u32,
                ((i + 5) % 16) as u32,
                ((i + 2) % 16) as u32,
                ((i + 6) % 16) as u32,
                ((i + 1) % 16) as u32,
                0xF_u32,
            ];
            awq_pack_one_row(nibbles)
        })
        .collect();
    let qweight_data_i32: Vec<i32> = qweight_data_u32.iter().map(|&u| u as i32).collect();
    // Every zero-point nibble is in 8..=15, so whichever slot lands in the top
    // four bits sets the sign bit. Each qzeros word is then negative as i32,
    // which exercises the signed path for qzeros as well as qweight.
    let qzeros_data_u32: Vec<u32> = (0..n_groups)
        .map(|g| awq_pack_one_row(std::array::from_fn(|o| 8 + ((g + o) % 8) as u32)))
        .collect();
    let qzeros_data_i32: Vec<i32> = qzeros_data_u32.iter().map(|&u| u as i32).collect();
    assert!(
        qzeros_data_i32.iter().all(|&v| v < 0),
        "qzeros fixture must be negative as i32 to exercise the high-bit case"
    );
    let scales_data: Vec<f32> = (0..n_groups * out_features)
        .map(|i| 0.5_f32 + (i as f32) * 0.01)
        .collect();
    // Builds the same layer with either i32 or u32 packed tensors.
    let build = |qw_dtype_i32: bool| -> Weights {
        let mut w: Weights = HashMap::new();
        if qw_dtype_i32 {
            w.insert(
                "layer.qweight".to_string(),
                Array::from_slice::<i32>(&qweight_data_i32, &(in_features, 1)).unwrap(),
            );
            w.insert(
                "layer.qzeros".to_string(),
                Array::from_slice::<i32>(&qzeros_data_i32, &(n_groups, 1)).unwrap(),
            );
        } else {
            w.insert(
                "layer.qweight".to_string(),
                Array::from_slice::<u32>(&qweight_data_u32, &(in_features, 1)).unwrap(),
            );
            w.insert(
                "layer.qzeros".to_string(),
                Array::from_slice::<u32>(&qzeros_data_u32, &(n_groups, 1)).unwrap(),
            );
        }
        w.insert(
            "layer.scales".to_string(),
            Array::from_slice::<f32>(&scales_data, &(n_groups, out_features)).unwrap(),
        );
        w
    };
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size,
        zero_point: true,
        version: "gemm".into(),
    };
    let (out_u32, _) = transform_awq_weights(build(false), &cfg).expect("u32 path");
    let (out_i32, _) = transform_awq_weights(build(true), &cfg).expect("i32 path");

    let mut w_u32 = out_u32.get("layer.weight").unwrap().try_clone().unwrap();
    let mut w_i32 = out_i32.get("layer.weight").unwrap().try_clone().unwrap();
    let u32_buf: Vec<u32> = w_u32.to_vec().unwrap();
    let i32_buf: Vec<u32> = w_i32.to_vec().unwrap();
    assert_eq!(
        u32_buf, i32_buf,
        "i32 qweight must produce the SAME .weight bit-pattern as the equivalent u32 input"
    );

    let mut b_u32 = out_u32
        .get("layer.biases")
        .expect("layer.biases (u32 path)")
        .try_clone()
        .unwrap();
    let mut b_i32 = out_i32
        .get("layer.biases")
        .expect("layer.biases (i32 path)")
        .try_clone()
        .unwrap();
    let u32_biases: Vec<f32> = b_u32.to_vec().unwrap();
    let i32_biases: Vec<f32> = b_i32.to_vec().unwrap();
    assert_eq!(
        u32_biases, i32_biases,
        "i32 qzeros must produce the SAME .biases as the equivalent u32 input"
    );
}
/// `.scales` stored as `i32` are rejected.
///
/// The error must be layer-keyed on the offending `.scales` tensor and wrap
/// `UnsupportedDtype`.
#[test]
fn transform_awq_weights_rejects_integer_scales_dtype() {
    let (in_features, out_features, n_groups) = (8usize, 8usize, 2usize);
    let int_scales = vec![1_i32; n_groups * out_features];
    let weights: Weights = HashMap::from([
        (
            "model.layer0.qweight".to_string(),
            Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap(),
        ),
        (
            "model.layer0.qzeros".to_string(),
            Array::from_slice::<u32>(&vec![0u32; n_groups], &(n_groups, 1)).unwrap(),
        ),
        (
            "model.layer0.scales".to_string(),
            Array::from_slice::<i32>(&int_scales, &(n_groups, out_features)).unwrap(),
        ),
    ]);
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: true,
        version: "gemm".into(),
    };
    let err = transform_awq_weights(weights, &cfg).unwrap_err();
    let Error::LayerKeyed(payload) = &err else {
        panic!("expected Error::LayerKeyed, got: {err:?}");
    };
    let layer = payload.layer();
    assert!(
        layer.contains("model.layer0.scales"),
        "LayerKeyed must name the offending layer's `.scales` key, got layer={:?}",
        layer
    );
    let inner = payload.inner();
    assert!(
        matches!(inner, Error::UnsupportedDtype(_)),
        "inner must be UnsupportedDtype for non-floating scales, got: {:?}",
        inner
    );
}
/// `.scales` stored as `u8` are rejected.
///
/// The error must be layer-keyed on the offending `.scales` tensor and wrap
/// `UnsupportedDtype`.
#[test]
fn transform_awq_weights_rejects_uint_scales_dtype() {
    let (in_features, out_features, n_groups) = (8usize, 8usize, 2usize);
    let byte_scales = vec![1_u8; n_groups * out_features];
    let weights: Weights = HashMap::from([
        (
            "model.layer0.qweight".to_string(),
            Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap(),
        ),
        (
            "model.layer0.qzeros".to_string(),
            Array::from_slice::<u32>(&vec![0u32; n_groups], &(n_groups, 1)).unwrap(),
        ),
        (
            "model.layer0.scales".to_string(),
            Array::from_slice::<u8>(&byte_scales, &(n_groups, out_features)).unwrap(),
        ),
    ]);
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: true,
        version: "gemm".into(),
    };
    let err = transform_awq_weights(weights, &cfg).unwrap_err();
    let Error::LayerKeyed(payload) = &err else {
        panic!("expected Error::LayerKeyed, got: {err:?}");
    };
    let layer = payload.layer();
    assert!(
        layer.contains("model.layer0.scales"),
        "LayerKeyed must name the offending layer's `.scales`, got layer={:?}",
        layer
    );
    let inner = payload.inner();
    assert!(
        matches!(inner, Error::UnsupportedDtype(_)),
        "inner must be UnsupportedDtype for non-floating scales, got: {:?}",
        inner
    );
}
/// When one scales dtype represents the other exactly (F16 ⊂ F32,
/// BF16 ⊂ F64), the resolved model dtype is the wider of the two.
#[test]
fn resolve_awq_model_dtype_uses_highest_when_hierarchical() {
    let in_features = 8usize;
    let out_features = 8usize;
    let n_groups = 2usize;
    let n_scales = n_groups * out_features;
    let scales_shape = (n_groups, out_features);
    let zero_qweight =
        || Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();

    // Case 1: F16 (layer_a) + F32 (layer_b).
    let f16_scales_data: Vec<half::f16> = (0..n_scales)
        .map(|i| half::f16::from_f32(0.1 * (i + 1) as f32))
        .collect();
    let f32_scales_data: Vec<f32> = (0..n_scales).map(|i| 0.5 + 0.01 * (i as f32)).collect();
    let weights: Weights = HashMap::from([
        ("layer_a.qweight".to_string(), zero_qweight()),
        (
            "layer_a.scales".to_string(),
            Array::from_slice::<half::f16>(&f16_scales_data, &scales_shape).unwrap(),
        ),
        ("layer_b.qweight".to_string(), zero_qweight()),
        (
            "layer_b.scales".to_string(),
            Array::from_slice::<f32>(&f32_scales_data, &scales_shape).unwrap(),
        ),
    ]);
    // Listed in sorted order already.
    let prefixes: Vec<String> = ["layer_a", "layer_b"].iter().map(|p| p.to_string()).collect();
    validate_awq_scales_are_floating(&weights, &prefixes).expect("both floating, must pass");
    let resolved = resolve_awq_model_dtype(&weights, &prefixes)
        .unwrap()
        .expect("some dtype");
    assert_eq!(
        resolved,
        Dtype::F32,
        "F32+F16 hierarchical must resolve to F32 (superset), got {resolved:?}"
    );

    // Case 2: BF16 (layer_c) + F64 (layer_d).
    let bf16_scales_data: Vec<half::bf16> = (0..n_scales)
        .map(|i| half::bf16::from_f32(0.5 + 0.01 * (i as f32)))
        .collect();
    let f64_scales_data: Vec<f64> = (0..n_scales).map(|i| 0.5 + 0.001 * (i as f64)).collect();
    let weights2: Weights = HashMap::from([
        ("layer_c.qweight".to_string(), zero_qweight()),
        (
            "layer_c.scales".to_string(),
            Array::from_slice::<half::bf16>(&bf16_scales_data, &scales_shape).unwrap(),
        ),
        ("layer_d.qweight".to_string(), zero_qweight()),
        (
            "layer_d.scales".to_string(),
            Array::from_slice::<f64>(&f64_scales_data, &scales_shape).unwrap(),
        ),
    ]);
    let prefixes2: Vec<String> = ["layer_c", "layer_d"].iter().map(|p| p.to_string()).collect();
    validate_awq_scales_are_floating(&weights2, &prefixes2).expect("both floating, must pass");
    let resolved2 = resolve_awq_model_dtype(&weights2, &prefixes2)
        .unwrap()
        .expect("some dtype");
    assert_eq!(
        resolved2,
        Dtype::F64,
        "F64+BF16 hierarchical must resolve to F64 (superset), got {resolved2:?}"
    );
}
/// F16 and BF16 are mutually lossy: neither contains the other. A mix of the
/// two must escalate to F32 regardless of prefix order.
#[test]
fn resolve_awq_model_dtype_escalates_f16_plus_bf16_to_f32() {
    let in_features = 8usize;
    let out_features = 8usize;
    let n_groups = 2usize;
    let n_scales = n_groups * out_features;
    let f16_scales_data: Vec<half::f16> = (0..n_scales)
        .map(|i| half::f16::from_f32(0.1 * (i + 1) as f32))
        .collect();
    let bf16_scales_data: Vec<half::bf16> = (0..n_scales)
        .map(|i| half::bf16::from_f32(0.5 + 0.01 * (i as f32)))
        .collect();
    // Fresh F16 `layer_a` + BF16 `layer_b` checkpoint for each run.
    let build = || -> Weights {
        let zero_qweight =
            || Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();
        HashMap::from([
            ("layer_a.qweight".to_string(), zero_qweight()),
            (
                "layer_a.scales".to_string(),
                Array::from_slice::<half::f16>(&f16_scales_data, &(n_groups, out_features))
                    .unwrap(),
            ),
            ("layer_b.qweight".to_string(), zero_qweight()),
            (
                "layer_b.scales".to_string(),
                Array::from_slice::<half::bf16>(&bf16_scales_data, &(n_groups, out_features))
                    .unwrap(),
            ),
        ])
    };
    // Validates, then resolves, with the prefixes visited in `order`.
    let resolve_for = |order: [&str; 2]| {
        let weights = build();
        let prefixes: Vec<String> = order.iter().map(|p| p.to_string()).collect();
        validate_awq_scales_are_floating(&weights, &prefixes).expect("both floating, must pass");
        resolve_awq_model_dtype(&weights, &prefixes)
            .unwrap()
            .expect("some dtype")
    };

    let resolved = resolve_for(["layer_a", "layer_b"]);
    assert_eq!(
        resolved,
        Dtype::F32,
        "F16+BF16 must escalate to F32 (no half is a superset), got {resolved:?}"
    );
    let resolved_r = resolve_for(["layer_b", "layer_a"]);
    assert_eq!(
        resolved_r,
        Dtype::F32,
        "F16+BF16 reversed order must still escalate to F32, got {resolved_r:?}"
    );
}
/// F32 already represents both F16 and BF16 exactly, so an F16 + BF16 + F32
/// mix resolves to F32 with no further escalation.
#[test]
fn resolve_awq_model_dtype_escalates_f16_plus_bf16_plus_f32_stays_at_f32() {
    let in_features = 8usize;
    let out_features = 8usize;
    let n_groups = 2usize;
    let n_scales = n_groups * out_features;
    let scales_shape = (n_groups, out_features);
    let f16_scales_data: Vec<half::f16> = (0..n_scales)
        .map(|i| half::f16::from_f32(0.1 * (i + 1) as f32))
        .collect();
    let bf16_scales_data: Vec<half::bf16> = (0..n_scales)
        .map(|i| half::bf16::from_f32(0.5 + 0.01 * (i as f32)))
        .collect();
    let f32_scales_data: Vec<f32> = (0..n_scales).map(|i| 0.25 + 0.001 * (i as f32)).collect();
    let zero_qweight =
        || Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();
    let weights: Weights = HashMap::from([
        ("layer_a.qweight".to_string(), zero_qweight()),
        (
            "layer_a.scales".to_string(),
            Array::from_slice::<half::f16>(&f16_scales_data, &scales_shape).unwrap(),
        ),
        ("layer_b.qweight".to_string(), zero_qweight()),
        (
            "layer_b.scales".to_string(),
            Array::from_slice::<half::bf16>(&bf16_scales_data, &scales_shape).unwrap(),
        ),
        ("layer_c.qweight".to_string(), zero_qweight()),
        (
            "layer_c.scales".to_string(),
            Array::from_slice::<f32>(&f32_scales_data, &scales_shape).unwrap(),
        ),
    ]);
    let prefixes: Vec<String> = ["layer_a", "layer_b", "layer_c"]
        .iter()
        .map(|p| p.to_string())
        .collect();
    validate_awq_scales_are_floating(&weights, &prefixes).expect("all floating, must pass");
    let resolved = resolve_awq_model_dtype(&weights, &prefixes)
        .unwrap()
        .expect("some dtype");
    assert_eq!(
        resolved,
        Dtype::F32,
        "F16+BF16+F32 must stay at F32 (F32 already > BF16 rank, no escalation), got {resolved:?}"
    );
}
/// With F64 present, the F16 + BF16 conflict is already covered: the mix
/// resolves to F64 and is not pulled down to F32.
#[test]
fn resolve_awq_model_dtype_escalates_f16_plus_bf16_with_f64_stays_at_f64() {
    let in_features = 8usize;
    let out_features = 8usize;
    let n_groups = 2usize;
    let n_scales = n_groups * out_features;
    let scales_shape = (n_groups, out_features);
    let f16_scales_data: Vec<half::f16> = (0..n_scales)
        .map(|i| half::f16::from_f32(0.1 * (i + 1) as f32))
        .collect();
    let bf16_scales_data: Vec<half::bf16> = (0..n_scales)
        .map(|i| half::bf16::from_f32(0.5 + 0.01 * (i as f32)))
        .collect();
    let f64_scales_data: Vec<f64> = (0..n_scales).map(|i| 0.25 + 0.001 * (i as f64)).collect();
    let zero_qweight =
        || Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();
    let weights: Weights = HashMap::from([
        ("layer_a.qweight".to_string(), zero_qweight()),
        (
            "layer_a.scales".to_string(),
            Array::from_slice::<half::f16>(&f16_scales_data, &scales_shape).unwrap(),
        ),
        ("layer_b.qweight".to_string(), zero_qweight()),
        (
            "layer_b.scales".to_string(),
            Array::from_slice::<half::bf16>(&bf16_scales_data, &scales_shape).unwrap(),
        ),
        ("layer_c.qweight".to_string(), zero_qweight()),
        (
            "layer_c.scales".to_string(),
            Array::from_slice::<f64>(&f64_scales_data, &scales_shape).unwrap(),
        ),
    ]);
    let prefixes: Vec<String> = ["layer_a", "layer_b", "layer_c"]
        .iter()
        .map(|p| p.to_string())
        .collect();
    validate_awq_scales_are_floating(&weights, &prefixes).expect("all floating, must pass");
    let resolved = resolve_awq_model_dtype(&weights, &prefixes)
        .unwrap()
        .expect("some dtype");
    assert_eq!(
        resolved,
        Dtype::F64,
        "F16+BF16+F64 must stay at F64 (F64 already > BF16 rank, no escalation), got {resolved:?}"
    );
}
/// Guards against a lossy F16 → BF16 cast during scales unification.
///
/// The F16 scale `1 + 2^-10` is not representable in BF16, so a checkpoint
/// mixing F16 and BF16 scales must unify through F32 and keep the exact
/// value.
#[test]
fn transform_awq_weights_preserves_f16_precision_when_mixed_with_bf16() {
    let in_features = 8usize;
    let out_features = 8usize;
    let n_groups = 2usize;
    let f16_value = half::f16::from_bits(0x3C01);
    assert_eq!(
        f16_value.to_f32(),
        1.0 + (2.0_f32).powi(-10),
        "F16 fixture value must be exactly 1 + 2^-10"
    );
    let bf_round = half::bf16::from_f32(f16_value.to_f32());
    assert_eq!(
        bf_round.to_f32(),
        1.0,
        "pre-condition: casting F16 1.0009765625 → BF16 must truncate to 1.0 \
         (this is the lossy behavior the F16+BF16→F32 escalation prevents)"
    );
    let f16_scales_data: Vec<half::f16> = vec![f16_value; n_groups * out_features];
    let bf16_scales_data: Vec<half::bf16> = (0..n_groups * out_features)
        .map(|i| half::bf16::from_f32(0.5 + 0.01 * (i as f32)))
        .collect();
    let qw_a = Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();
    let sc_a = Array::from_slice::<half::f16>(&f16_scales_data, &(n_groups, out_features)).unwrap();
    let qw_b = Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();
    let sc_b =
        Array::from_slice::<half::bf16>(&bf16_scales_data, &(n_groups, out_features)).unwrap();
    let weights: Weights = HashMap::from([
        ("layer_a.qweight".to_string(), qw_a),
        ("layer_a.scales".to_string(), sc_a),
        ("layer_b.qweight".to_string(), qw_b),
        ("layer_b.scales".to_string(), sc_b),
    ]);
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: false,
        version: "gemm".into(),
    };
    let (out, _) = transform_awq_weights(weights, &cfg).expect("transform must succeed");
    let mut sc_a_out = out
        .get("layer_a.scales")
        .expect("converted layer_a.scales present")
        .try_clone()
        .unwrap();
    assert_eq!(
        sc_a_out.dtype().unwrap(),
        Dtype::F32,
        "unified dtype must be F32 under the F16+BF16→F32 escalation"
    );
    let vals: Vec<f32> = sc_a_out.to_vec().expect("read back as F32");
    // Guard against a vacuous pass: an empty read-back would skip the check below.
    assert_eq!(
        vals.len(),
        n_groups * out_features,
        "expected one scale per (output row, group)"
    );
    for (i, &v) in vals.iter().enumerate() {
        assert_eq!(
            v,
            1.0 + (2.0_f32).powi(-10),
            "layer_a.scales[{i}] = {v} (bits 0x{:08X}) — F16 1.0009765625 must NOT have \
             been truncated through BF16 (would land at 1.0 == 0x3F800000)",
            v.to_bits()
        );
    }
    let sc_b_out = out
        .get("layer_b.scales")
        .expect("converted layer_b.scales present");
    assert_eq!(
        sc_b_out.dtype().unwrap(),
        Dtype::F32,
        "layer_b.scales must also be unified to F32"
    );
}
/// F16 precision must survive F16 + BF16 unification when the BF16 layer has
/// the alphabetically-first prefix (`alpha`) and the F16 layer the last
/// (`zeta`).
///
/// Both sides must end up F32, and the F16 value `1 + 2^-10` must be kept
/// exactly.
#[test]
fn transform_awq_weights_preserves_f16_precision_with_bf16_in_reversed_prefix_order() {
    let in_features = 8usize;
    let out_features = 8usize;
    let n_groups = 2usize;
    let f16_value = half::f16::from_bits(0x3C01);
    let f16_scales_data: Vec<half::f16> = vec![f16_value; n_groups * out_features];
    let bf16_scales_data: Vec<half::bf16> = (0..n_groups * out_features)
        .map(|i| half::bf16::from_f32(0.5 + 0.01 * (i as f32)))
        .collect();
    let qw_alpha = Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();
    let sc_alpha =
        Array::from_slice::<half::bf16>(&bf16_scales_data, &(n_groups, out_features)).unwrap();
    let qw_zeta = Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();
    let sc_zeta =
        Array::from_slice::<half::f16>(&f16_scales_data, &(n_groups, out_features)).unwrap();
    let weights: Weights = HashMap::from([
        ("alpha.qweight".to_string(), qw_alpha),
        ("alpha.scales".to_string(), sc_alpha),
        ("zeta.qweight".to_string(), qw_zeta),
        ("zeta.scales".to_string(), sc_zeta),
    ]);
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: false,
        version: "gemm".into(),
    };
    let (out, _) = transform_awq_weights(weights, &cfg).expect("transform must succeed");
    let mut sc_zeta_out = out
        .get("zeta.scales")
        .expect("converted zeta.scales present")
        .try_clone()
        .unwrap();
    assert_eq!(
        sc_zeta_out.dtype().unwrap(),
        Dtype::F32,
        "unified dtype must be F32 regardless of prefix order"
    );
    let vals: Vec<f32> = sc_zeta_out.to_vec().expect("read back as F32");
    // Guard against a vacuous pass: an empty read-back would skip the check below.
    assert_eq!(
        vals.len(),
        n_groups * out_features,
        "expected one scale per (output row, group)"
    );
    for (i, &v) in vals.iter().enumerate() {
        assert_eq!(
            v,
            1.0 + (2.0_f32).powi(-10),
            "zeta.scales[{i}] = {v} — F16 precision must be preserved in reversed-order layout"
        );
    }
    // The BF16 side must be unified to the same dtype.
    let sc_alpha_out = out
        .get("alpha.scales")
        .expect("converted alpha.scales present");
    assert_eq!(
        sc_alpha_out.dtype().unwrap(),
        Dtype::F32,
        "alpha.scales must also be unified to F32"
    );
}
/// A pre-existing `layer.weight` clashes with the `.weight` the converter
/// derives from `layer.qweight`.
///
/// The loader must refuse with a `KeyCollision` that names the key and both
/// sides of the conflict.
#[test]
fn transform_awq_weights_rejects_collision_with_stale_weight() {
    let mut weights = awq_gemm_fixture_weights();
    weights.insert(
        "layer.weight".to_string(),
        Array::from_slice::<f32>(&[0.0_f32; 16], &(2usize, 8)).unwrap(),
    );
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: true,
        version: "gemm".into(),
    };
    let err = transform_awq_weights(weights, &cfg).unwrap_err();
    match &err {
        Error::KeyCollision(p) => {
            assert_eq!(p.key(), "layer.weight");
            let ctx = p.context();
            assert!(
                ctx.contains(".qweight") && ctx.contains(".weight"),
                "context must name both the qweight and weight, got: {}",
                ctx
            );
        }
        _ => panic!("expected Error::KeyCollision, got: {err:?}"),
    }
}
/// A pre-existing `layer.biases` clashes with the `.biases` the converter
/// produces for `layer.qweight`.
///
/// The loader must refuse with a `KeyCollision` that names the key and both
/// sides of the conflict.
#[test]
fn transform_awq_weights_rejects_collision_with_stale_biases() {
    let mut weights = awq_gemm_fixture_weights();
    weights.insert(
        "layer.biases".to_string(),
        Array::from_slice::<f32>(&[0.0_f32; 8], &(8usize,)).unwrap(),
    );
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: true,
        version: "gemm".into(),
    };
    let err = transform_awq_weights(weights, &cfg).unwrap_err();
    match &err {
        Error::KeyCollision(p) => {
            assert_eq!(p.key(), "layer.biases");
            let ctx = p.context();
            assert!(
                ctx.contains(".qweight") && ctx.contains(".biases"),
                "context must name both the qweight and biases, got: {}",
                ctx
            );
        }
        _ => panic!("expected Error::KeyCollision, got: {err:?}"),
    }
}
/// A `.weight` under an unrelated prefix is not a collision.
///
/// It must pass through with its contents intact, next to the AWQ-converted
/// `layer.weight`.
#[test]
fn transform_awq_weights_accepts_unrelated_weight_keys() {
    let mut weights = awq_gemm_fixture_weights();
    let pt = Array::from_slice::<f32>(&[1.0_f32, 2.0, 3.0, 4.0], &(2usize, 2)).unwrap();
    weights.insert("embed_tokens.weight".to_string(), pt);
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: true,
        version: "gemm".into(),
    };
    let (out, _) = transform_awq_weights(weights, &cfg).expect("unrelated .weight must pass");
    assert!(
        out.contains_key("layer.weight"),
        "AWQ-converted .weight must be present"
    );
    assert!(
        out.contains_key("embed_tokens.weight"),
        "unrelated .weight must be preserved"
    );
    // "Preserved" means the data too, not just the key.
    let mut pt_out = out
        .get("embed_tokens.weight")
        .expect("unrelated .weight must be preserved")
        .try_clone()
        .unwrap();
    assert_eq!(pt_out.shape(), vec![2, 2], "pass-through shape preserved");
    assert_eq!(
        pt_out.to_vec::<f32>().unwrap(),
        vec![1.0_f32, 2.0, 3.0, 4.0],
        "pass-through contents preserved"
    );
}
/// Collision reporting when stale converted keys are already present.
///
/// With both `layer.weight` and `layer.biases` stale, the loader must reject
/// with a `KeyCollision` on one of those two keys. With only `layer.biases`
/// stale, the collision must name `.biases`.
#[test]
fn transform_awq_weights_rejects_collision_with_both_stale_keys() {
    let mut weights = awq_gemm_fixture_weights();
    let stale_w = Array::from_slice::<f32>(&[0.0_f32; 16], &(2usize, 8)).unwrap();
    let stale_b = Array::from_slice::<f32>(&[0.0_f32; 8], &(8usize,)).unwrap();
    weights.insert("layer.weight".to_string(), stale_w);
    weights.insert("layer.biases".to_string(), stale_b);
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: true,
        version: "gemm".into(),
    };
    let err = transform_awq_weights(weights, &cfg).unwrap_err();
    // Pin the reported key: a collision on an unrelated key must not pass.
    match &err {
        Error::KeyCollision(p) => {
            let key = p.key();
            assert!(
                key == "layer.weight" || key == "layer.biases",
                "collision must name one of the stale keys, got key={:?}",
                key
            );
        }
        _ => panic!("must reject with KeyCollision, got: {err:?}"),
    }

    let mut weights2 = awq_gemm_fixture_weights();
    let stale_b2 = Array::from_slice::<f32>(&[0.0_f32; 8], &(8usize,)).unwrap();
    weights2.insert("layer.biases".to_string(), stale_b2);
    let err2 = transform_awq_weights(weights2, &cfg).unwrap_err();
    match err2 {
        Error::KeyCollision(ref p) => {
            assert!(
                p.key().contains("layer.biases"),
                "must name the .biases collision when .weight is absent, got key={:?}",
                p.key()
            );
        }
        other => panic!("expected Error::KeyCollision, got: {other:?}"),
    }
}
#[test]
fn transform_awq_weights_does_not_widen_passthrough_bf16_tensor() {
    // One BF16-scaled AWQ layer plus a large BF16 tensor that is not AWQ data.
    // The generated scales resolve to BF16. The unrelated tensor must come back
    // with the same dtype, shape, length and bit pattern.
    let (in_features, out_features, n_groups) = (8usize, 8usize, 2usize);
    let (rows, cols) = (100usize, 100usize);

    let scale_values: Vec<half::bf16> = (0..n_groups * out_features)
        .map(|k| half::bf16::from_f32(0.5 + 0.01 * (k as f32)))
        .collect();
    let pt_data: Vec<half::bf16> = (0..rows * cols)
        .map(|k| half::bf16::from_f32(0.001 * (k as f32 % 1000.0)))
        .collect();

    let mut weights: Weights = HashMap::new();
    weights.insert(
        "layer.qweight".to_string(),
        Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap(),
    );
    weights.insert(
        "layer.scales".to_string(),
        Array::from_slice::<half::bf16>(&scale_values, &(n_groups, out_features)).unwrap(),
    );
    weights.insert(
        "embed_tokens.weight".to_string(),
        Array::from_slice::<half::bf16>(&pt_data, &(rows, cols)).unwrap(),
    );
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: false,
        version: "gemm".into(),
    };
    let (out, _) = transform_awq_weights(weights, &cfg).expect("transform");

    let scales_dtype = out
        .get("layer.scales")
        .expect("layer.scales generated")
        .dtype()
        .unwrap();
    assert_eq!(
        scales_dtype,
        Dtype::BF16,
        "BF16-only AWQ scales must resolve to BF16, got {:?}",
        scales_dtype
    );

    let mut passthrough = out
        .get("embed_tokens.weight")
        .expect("pass-through embed_tokens.weight preserved")
        .try_clone()
        .unwrap();
    assert_eq!(
        passthrough.dtype().unwrap(),
        Dtype::BF16,
        "pass-through BF16 tensor must NOT be widened by unification"
    );
    assert_eq!(
        passthrough.shape(),
        vec![rows, cols],
        "pass-through shape preserved"
    );
    let round_trip: Vec<half::bf16> = passthrough.to_vec().expect("read pass-through as BF16");
    assert_eq!(
        round_trip.len(),
        pt_data.len(),
        "pass-through element count preserved"
    );
    let pairs = round_trip.iter().copied().zip(pt_data.iter().copied());
    for (i, (got, want)) in pairs.enumerate() {
        assert_eq!(
            got.to_bits(),
            want.to_bits(),
            "pass-through value at index {i} must be byte-identical (got 0x{:04X}, want 0x{:04X})",
            got.to_bits(),
            want.to_bits()
        );
    }
}
#[test]
fn transform_awq_weights_does_not_widen_passthrough_f16_tensor_when_mixed_with_bf16_awq_scales() {
    // Two AWQ layers whose scales disagree on half-precision format (F16 vs
    // BF16). The test expects every generated `.scales`/`.biases` to come out
    // as F32. The unrelated F16 `lm_head.weight` must keep its dtype, shape,
    // length and exact bit pattern.
    let in_features = 8usize;
    let out_features = 8usize;
    let n_groups = 2usize;
    let f16_scales_data: Vec<half::f16> = (0..n_groups * out_features)
        .map(|i| half::f16::from_f32(0.1 * (i + 1) as f32))
        .collect();
    let bf16_scales_data: Vec<half::bf16> = (0..n_groups * out_features)
        .map(|i| half::bf16::from_f32(0.5 + 0.01 * (i as f32)))
        .collect();
    let qw_a = Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();
    let sc_a =
        Array::from_slice::<half::f16>(&f16_scales_data, &(n_groups, out_features)).unwrap();
    let qw_b = Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();
    let sc_b =
        Array::from_slice::<half::bf16>(&bf16_scales_data, &(n_groups, out_features)).unwrap();
    let lm_head_shape = (32usize, 16usize);
    let lm_head_data: Vec<half::f16> = (0..lm_head_shape.0 * lm_head_shape.1)
        .map(|i| half::f16::from_f32(0.01 * (i as f32 % 100.0)))
        .collect();
    let lm_head = Array::from_slice::<half::f16>(&lm_head_data, &lm_head_shape).unwrap();
    let weights: Weights = HashMap::from([
        ("layer_a.qweight".to_string(), qw_a),
        ("layer_a.scales".to_string(), sc_a),
        ("layer_b.qweight".to_string(), qw_b),
        ("layer_b.scales".to_string(), sc_b),
        ("lm_head.weight".to_string(), lm_head),
    ]);
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: false,
        version: "gemm".into(),
    };
    let (out, _) = transform_awq_weights(weights, &cfg).expect("transform");

    for key in [
        "layer_a.scales",
        "layer_b.scales",
        "layer_a.biases",
        "layer_b.biases",
    ] {
        let generated = out.get(key).expect(key);
        assert_eq!(
            generated.dtype().unwrap(),
            Dtype::F32,
            "AWQ-generated {key} must be cast to F32 under mixed-half escalation"
        );
    }

    let mut lm_head_out = out
        .get("lm_head.weight")
        .expect("pass-through lm_head.weight preserved")
        .try_clone()
        .unwrap();
    assert_eq!(
        lm_head_out.dtype().unwrap(),
        Dtype::F16,
        "pass-through F16 tensor must NOT be widened to F32 by the AWQ \
mixed-half escalation — only the AWQ-generated .scales/.biases get widened"
    );
    assert_eq!(
        lm_head_out.shape(),
        vec![lm_head_shape.0, lm_head_shape.1],
        "pass-through lm_head shape preserved"
    );
    let lm_back: Vec<half::f16> = lm_head_out.to_vec().expect("read lm_head as F16");
    // `zip` stops at the shorter side, so pin the length first. Otherwise a
    // truncated read-back would pass the byte comparison vacuously.
    assert_eq!(
        lm_back.len(),
        lm_head_data.len(),
        "lm_head.weight element count preserved"
    );
    for (i, (&got, &want)) in lm_back.iter().zip(lm_head_data.iter()).enumerate() {
        assert_eq!(
            got.to_bits(),
            want.to_bits(),
            "lm_head.weight[{i}] must be byte-identical (got 0x{:04X}, want 0x{:04X})",
            got.to_bits(),
            want.to_bits()
        );
    }
}
#[test]
fn transform_awq_weights_widens_only_generated_scales_and_biases() {
    // A single BF16-scaled AWQ layer resolves the model dtype to BF16 for the
    // generated `.scales`/`.biases`. An unrelated F16 tensor must not be cast
    // to that BF16, and must keep its shape, length and exact bits.
    let in_features = 8usize;
    let out_features = 8usize;
    let n_groups = 2usize;
    let bf16_scales_data: Vec<half::bf16> = (0..n_groups * out_features)
        .map(|i| half::bf16::from_f32(0.5 + 0.01 * (i as f32)))
        .collect();
    let qw = Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();
    let sc =
        Array::from_slice::<half::bf16>(&bf16_scales_data, &(n_groups, out_features)).unwrap();
    let pt_shape = (16usize, 8usize);
    let pt_data: Vec<half::f16> = (0..pt_shape.0 * pt_shape.1)
        .map(|i| half::f16::from_f32(0.01 * (i as f32)))
        .collect();
    let pt = Array::from_slice::<half::f16>(&pt_data, &pt_shape).unwrap();
    let weights: Weights = HashMap::from([
        ("layer.qweight".to_string(), qw),
        ("layer.scales".to_string(), sc),
        ("model.norm.weight".to_string(), pt),
    ]);
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: false,
        version: "gemm".into(),
    };
    let (out, _) = transform_awq_weights(weights, &cfg).expect("transform");

    let sc_out = out.get("layer.scales").expect("layer.scales");
    assert_eq!(
        sc_out.dtype().unwrap(),
        Dtype::BF16,
        "AWQ-generated .scales is at the resolved BF16 model_dtype"
    );
    let bi_out = out.get("layer.biases").expect("layer.biases");
    assert_eq!(
        bi_out.dtype().unwrap(),
        Dtype::BF16,
        "AWQ-generated .biases is at the resolved BF16 model_dtype"
    );

    let mut pt_out = out
        .get("model.norm.weight")
        .expect("pass-through model.norm.weight")
        .try_clone()
        .unwrap();
    assert_eq!(
        pt_out.dtype().unwrap(),
        Dtype::F16,
        "pass-through F16 tensor must NOT be cast to the resolved BF16 — \
unification is scoped to AWQ-generated outputs only"
    );
    assert_eq!(
        pt_out.shape(),
        vec![pt_shape.0, pt_shape.1],
        "pass-through shape preserved"
    );
    let pt_back: Vec<half::f16> = pt_out.to_vec().expect("read pass-through as F16");
    // Pin the length before the zip-based comparison so a truncated read-back
    // cannot pass vacuously.
    assert_eq!(
        pt_back.len(),
        pt_data.len(),
        "pass-through element count preserved"
    );
    for (i, (&got, &want)) in pt_back.iter().zip(pt_data.iter()).enumerate() {
        assert_eq!(
            got.to_bits(),
            want.to_bits(),
            "pass-through value at index {i} must be byte-identical"
        );
    }
}
#[test]
fn transform_awq_weights_preserves_resident_size_for_passthrough() {
    // Element width in bytes for each dtype. The match is kept exhaustive so
    // adding a dtype variant forces this table to be revisited.
    fn bytes_per_elem(d: Dtype) -> usize {
        match d {
            Dtype::U64 | Dtype::I64 | Dtype::F64 | Dtype::Complex64 => 8,
            Dtype::U32 | Dtype::I32 | Dtype::F32 => 4,
            Dtype::U16 | Dtype::I16 | Dtype::F16 | Dtype::BF16 => 2,
            Dtype::Bool | Dtype::U8 | Dtype::I8 => 1,
        }
    }

    // Mixed F16/BF16 AWQ scales alongside a 256x256 BF16 pass-through tensor:
    // the pass-through's byte footprint must be unchanged after the transform.
    let (in_features, out_features, n_groups) = (8usize, 8usize, 2usize);
    let (rows, cols) = (256usize, 256usize);

    let f16_scales: Vec<half::f16> = (0..n_groups * out_features)
        .map(|k| half::f16::from_f32(0.1 * (k + 1) as f32))
        .collect();
    let bf16_scales: Vec<half::bf16> = (0..n_groups * out_features)
        .map(|k| half::bf16::from_f32(0.5 + 0.01 * (k as f32)))
        .collect();
    let embed_values: Vec<half::bf16> = (0..rows * cols)
        .map(|k| half::bf16::from_f32((k as f32) * 1e-4))
        .collect();

    let embed = Array::from_slice::<half::bf16>(&embed_values, &(rows, cols)).unwrap();
    let pt_size_pre = embed.size() * bytes_per_elem(embed.dtype().unwrap());
    assert_eq!(
        pt_size_pre,
        rows * cols * 2,
        "pre-transform BF16 pass-through resident size sanity"
    );

    let packed = || Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();
    let weights: Weights = HashMap::from([
        ("layer_a.qweight".to_string(), packed()),
        (
            "layer_a.scales".to_string(),
            Array::from_slice::<half::f16>(&f16_scales, &(n_groups, out_features)).unwrap(),
        ),
        ("layer_b.qweight".to_string(), packed()),
        (
            "layer_b.scales".to_string(),
            Array::from_slice::<half::bf16>(&bf16_scales, &(n_groups, out_features)).unwrap(),
        ),
        ("embed_tokens.weight".to_string(), embed),
    ]);
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: false,
        version: "gemm".into(),
    };
    let (out, _) = transform_awq_weights(weights, &cfg).expect("transform");

    let after = out.get("embed_tokens.weight").expect("pass-through preserved");
    assert_eq!(
        after.dtype().unwrap(),
        Dtype::BF16,
        "pass-through must remain BF16 (not widened to F32 by the mixed-half escalation)"
    );
    assert_eq!(
        after.shape(),
        vec![rows, cols],
        "pass-through shape preserved"
    );
    let pt_size_post = after.size() * bytes_per_elem(after.dtype().unwrap());
    assert_eq!(
        pt_size_post,
        pt_size_pre,
        "pass-through resident size must be IDENTICAL post-transform \
(an unscoped cast would double it from {pt_size_pre} to {} bytes)",
        pt_size_pre * 2
    );
}
#[test]
fn awq_load_config_parses_quantization_json() {
    // Every explicitly supplied key must land in the matching field.
    let json = r#"{
"bits": 4,
"group_size": 128,
"zero_point": true,
"version": "gemm"
}"#;
    let AwqLoadConfig {
        bits,
        group_size,
        zero_point,
        version,
        ..
    } = serde_json::from_str(json).expect("parse");
    assert_eq!(bits, 4);
    assert_eq!(group_size, 128);
    assert!(zero_point);
    assert_eq!(version, "gemm");
}
#[test]
fn awq_load_config_defaults_when_keys_absent() {
    // An empty object must fall back to 4-bit, group 128, zero-point on,
    // and an empty version string.
    let AwqLoadConfig {
        bits,
        group_size,
        zero_point,
        version,
        ..
    } = serde_json::from_str("{}").expect("parse");
    assert_eq!(bits, 4);
    assert_eq!(group_size, 128);
    assert!(zero_point);
    assert_eq!(version, "");
}
#[test]
fn awq_load_config_default_matches_serde_default() {
    // `Default` and serde's per-field defaults must agree, so a config with no
    // keys loads the same as one built programmatically.
    assert_eq!(
        AwqLoadConfig::default(),
        serde_json::from_str::<AwqLoadConfig>("{}").unwrap()
    );
}
#[test]
fn quant_mode_as_str_covers_all_variants() {
    // Each mode's string tag must be stable in three places: `as_str`, serde
    // deserialization inside a `quantization` block, and `Display`.
    let cases = [
        ("affine", QuantMode::Affine),
        ("mxfp4", QuantMode::Mxfp4),
        ("mxfp8", QuantMode::Mxfp8),
        ("nvfp4", QuantMode::Nvfp4),
    ];
    for (expected, variant) in cases {
        assert_eq!(variant.as_str(), expected);
    }
    assert_eq!(QuantMode::default(), QuantMode::Affine);

    for (tag, mode) in cases {
        let json = format!("{{ \"group_size\": 32, \"bits\": 4, \"mode\": {tag:?} }}");
        let wrapped = format!("{{ \"quantization\": {json} }}");
        let parsed = parse_quantization(&wrapped).unwrap().unwrap();
        let q = parsed.quantization.unwrap();
        assert_eq!(q.mode, mode, "tag {tag:?} must deserialize to {mode:?}");
        assert_eq!(format!("{mode}"), tag);
    }
}
#[test]
fn per_layer_ref_exposes_override_map() {
    // A config built from only a global setting has no per-layer overrides.
    let global_only = PerLayerQuantization::from_global(Quantization::affine(64, 4));
    assert!(global_only.per_layer_ref().is_empty());

    // Overrides passed to `new` must be exposed unchanged through the accessor.
    let overrides = HashMap::from([
        ("model.embed_tokens".to_string(), QuantizationOption::Skip),
        (
            "model.layers.0.q_proj".to_string(),
            QuantizationOption::Quantize(Quantization::affine(32, 4)),
        ),
    ]);
    let plq = PerLayerQuantization::new(Some(Quantization::affine(64, 4)), overrides);
    let exposed = plq.per_layer_ref();
    assert_eq!(exposed.len(), 2);
    assert_eq!(
        exposed.get("model.embed_tokens").copied(),
        Some(QuantizationOption::Skip)
    );
    let q_proj = exposed.get("model.layers.0.q_proj");
    assert!(matches!(q_proj, Some(QuantizationOption::Quantize(_))));
}
#[test]
fn quantization_block_non_object_errors() {
    // A `quantization` value that is not a JSON object (here an array) is
    // rejected with a message saying an object is required.
    let msg = parse_quantization(r#"{ "quantization": [1, 2, 3] }"#)
        .unwrap_err()
        .to_string();
    assert!(
        msg.contains("must be a JSON object"),
        "error should explain the block must be a JSON object, got: {msg}"
    );
}
#[test]
fn quantization_per_layer_true_is_ignored() {
    // `true` for a layer key records no override, so the layer keeps the
    // global setting. `false` records an explicit Skip.
    let cfg_json = r#"{
"quantization": {
"group_size": 64,
"bits": 4,
"model.layers.0.kept": true,
"model.layers.1.skipped": false
}
}"#;
    let plq = parse_quantization(cfg_json).unwrap().unwrap();
    let overrides = plq.per_layer_ref();
    assert_eq!(overrides.len(), 1);
    assert!(!overrides.contains_key("model.layers.0.kept"));
    assert_eq!(
        overrides.get("model.layers.1.skipped").copied(),
        Some(QuantizationOption::Skip)
    );
    let kept = plq.quantization_for("model.layers.0.kept");
    assert_eq!(kept, Some(Quantization::affine(64, 4)));
}
#[test]
fn quantization_per_layer_scalar_value_errors() {
    // A per-layer value that is neither `false` nor an object (here a number)
    // must fail, and the message must both state the contract and name the key.
    let cfg_json = r#"{
"quantization": {
"group_size": 64,
"bits": 4,
"model.layers.0.bad": 7
}
}"#;
    let msg = parse_quantization(cfg_json).unwrap_err().to_string();
    assert!(
        msg.contains("must be `false` or a quantization object"),
        "error should explain the per-layer value contract, got: {msg}"
    );
    assert!(
        msg.contains("model.layers.0.bad"),
        "error should name the offending key, got: {msg}"
    );
}
#[test]
fn quantize_weights_valid_triple_resolves_via_per_layer_override() {
    // An already-packed U32 `.weight` with `.scales` and no `.biases` fits the
    // per-layer mxfp4 override (not the global affine config). It must be
    // passed through without a `.biases` being added.
    let n_rows = 2_usize;
    let mut weights: Weights = HashMap::new();
    weights.insert(
        "model.special.weight".to_string(),
        arr_u32(&vec![0_u32; n_rows * 4], &[n_rows, 4]),
    );
    weights.insert(
        "model.special.scales".to_string(),
        arr_f32(&vec![1.0_f32; n_rows], &[n_rows, 1]),
    );
    let mxfp4 = Quantization {
        group_size: 32,
        bits: 4,
        mode: QuantMode::Mxfp4,
    };
    let per_layer = HashMap::from([(
        "model.special".to_string(),
        QuantizationOption::Quantize(mxfp4),
    )]);
    let cfg = PerLayerQuantization::new(Some(Quantization::affine(64, 4)), per_layer);

    let out = quantize_weights(weights, &cfg, &default_eligible)
        .expect("triple resolves via per-layer mxfp4 override, passes through");
    let packed = out.get("model.special.weight").expect(".weight");
    assert_eq!(packed.dtype().unwrap(), Dtype::U32);
    assert!(out.contains_key("model.special.scales"));
    assert!(!out.contains_key("model.special.biases"));
}
#[test]
fn quantize_weights_triple_with_no_resolvable_params_errors() {
    // A quantized triple with neither a global nor a per-layer config cannot be
    // interpreted. The error must be keyed to the layer and wrap an
    // InvariantViolation.
    let n_rows = 2_usize;
    let weights: Weights = HashMap::from([
        (
            "model.foo.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 8], &[n_rows, 8]),
        ),
        (
            "model.foo.scales".to_string(),
            arr_f32(&vec![1.0_f32; n_rows], &[n_rows, 1]),
        ),
        (
            "model.foo.biases".to_string(),
            arr_f32(&vec![0.0_f32; n_rows], &[n_rows, 1]),
        ),
    ]);
    let cfg = PerLayerQuantization::new(None, HashMap::new());

    let err = quantize_weights(weights, &cfg, &default_eligible).unwrap_err();
    let payload = match &err {
        Error::LayerKeyed(payload) => payload,
        other => panic!("expected Error::LayerKeyed, got: {other:?}"),
    };
    assert!(
        payload.layer().contains("model.foo"),
        "LayerKeyed must name the layer, got layer={:?}",
        payload.layer()
    );
    assert!(
        matches!(payload.inner(), Error::InvariantViolation(_)),
        "inner must be InvariantViolation when params are unresolvable, got: {:?}",
        payload.inner()
    );
}
#[test]
fn quantize_weights_scales_rank_mismatch_errors() {
    // `.scales`/`.biases` of rank 3 next to a rank-2 `.weight` are
    // inconsistent. The error must be keyed to the layer and wrap a
    // LengthMismatch.
    let n_rows = 2_usize;
    let weights: Weights = HashMap::from([
        (
            "model.foo.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 8], &[n_rows, 8]),
        ),
        (
            "model.foo.scales".to_string(),
            arr_f32(&vec![1.0_f32; n_rows], &[n_rows, 1, 1]),
        ),
        (
            "model.foo.biases".to_string(),
            arr_f32(&vec![0.0_f32; n_rows], &[n_rows, 1, 1]),
        ),
    ]);
    let cfg = PerLayerQuantization::from_global(Quantization::affine(64, 4));

    let err = quantize_weights(weights, &cfg, &default_eligible).unwrap_err();
    let payload = match &err {
        Error::LayerKeyed(payload) => payload,
        other => panic!("expected Error::LayerKeyed, got: {other:?}"),
    };
    assert!(
        payload.layer().contains("model.foo"),
        "LayerKeyed must name the layer, got layer={:?}",
        payload.layer()
    );
    assert!(
        matches!(payload.inner(), Error::LengthMismatch(_)),
        "inner must be LengthMismatch for `.scales` rank != `.weight` rank, got: {:?}",
        payload.inner()
    );
}
#[test]
fn quantize_weights_rank1_weight_passes_through() {
    // Under a global affine config, a rank-1 F32 `.weight` is not quantized.
    // It comes back unchanged and gets no `.scales` companion.
    let cfg = PerLayerQuantization::from_global(Quantization::affine(64, 4));
    let weights: Weights = HashMap::from([(
        "model.scalar_like.weight".to_string(),
        arr_f32(&[1.0_f32, 2.0, 3.0, 4.0], &[4]),
    )]);

    let out = quantize_weights(weights, &cfg, &default_eligible).expect("rank-1 weight passes");
    let untouched = out.get("model.scalar_like.weight").expect(".weight");
    assert_eq!(untouched.shape(), vec![4]);
    assert_eq!(untouched.dtype().unwrap(), Dtype::F32);
    assert!(!out.contains_key("model.scalar_like.scales"));
}
#[test]
fn quantize_weights_negative_group_size_errors() {
    // A per-layer override with a negative group size is invalid. The error
    // must be keyed to the layer and wrap an OutOfRange.
    let (n_rows, last) = (2_usize, 64_usize);
    let weights: Weights = HashMap::from([(
        "model.bad_gs.weight".to_string(),
        arr_f32(&vec![0.5_f32; n_rows * last], &[n_rows, last]),
    )]);
    let bad = Quantization {
        group_size: -1,
        bits: 4,
        mode: QuantMode::Affine,
    };
    let per_layer = HashMap::from([(
        "model.bad_gs".to_string(),
        QuantizationOption::Quantize(bad),
    )]);
    let cfg = PerLayerQuantization::new(Some(Quantization::affine(64, 4)), per_layer);

    let err = quantize_weights(weights, &cfg, &default_eligible).unwrap_err();
    let payload = match &err {
        Error::LayerKeyed(payload) => payload,
        other => panic!("expected Error::LayerKeyed, got: {other:?}"),
    };
    assert!(
        payload.layer().contains("model.bad_gs"),
        "LayerKeyed must name the layer, got layer={:?}",
        payload.layer()
    );
    assert!(
        matches!(payload.inner(), Error::OutOfRange(_)),
        "inner must be OutOfRange for negative group_size, got: {:?}",
        payload.inner()
    );
}
#[test]
fn quantize_weights_zero_group_size_passes_through() {
    // A per-layer override with group size 0 is treated as "do not quantize".
    // The F32 `.weight` is returned as-is with no `.scales`.
    let (n_rows, last) = (2_usize, 64_usize);
    let weights: Weights = HashMap::from([(
        "model.zero_gs.weight".to_string(),
        arr_f32(&vec![0.5_f32; n_rows * last], &[n_rows, last]),
    )]);
    let zero = Quantization {
        group_size: 0,
        bits: 4,
        mode: QuantMode::Affine,
    };
    let per_layer = HashMap::from([(
        "model.zero_gs".to_string(),
        QuantizationOption::Quantize(zero),
    )]);
    let cfg = PerLayerQuantization::new(Some(Quantization::affine(64, 4)), per_layer);

    let out = quantize_weights(weights, &cfg, &default_eligible)
        .expect("group_size 0 skips, passes through");
    let untouched = out.get("model.zero_gs.weight").expect(".weight");
    assert_eq!(untouched.shape(), vec![n_rows, last]);
    assert_eq!(untouched.dtype().unwrap(), Dtype::F32);
    assert!(!out.contains_key("model.zero_gs.scales"));
}
#[test]
fn dequantize_weights_orphan_biases_without_weight_passes_through() {
    // A `.biases` entry with no sibling `.weight` is kept unchanged, and no
    // `.weight` is invented for it.
    let cfg = PerLayerQuantization::from_global(Quantization::affine(64, 4));
    let weights: Weights = HashMap::from([(
        "model.lonely.biases".to_string(),
        arr_f32(&[0.5_f32, 1.5], &[2, 1]),
    )]);

    let out = dequantize_weights(weights, &cfg).expect("orphan biases pass through");
    let kept = out.get("model.lonely.biases").expect(".biases preserved");
    assert_eq!(kept.shape(), vec![2, 1]);
    assert!(!out.contains_key("model.lonely.weight"));
}
#[test]
fn dequantize_weights_rank1_uint32_weight_with_biases_passes_through() {
    // A rank-1 U32 `.weight` alongside `.biases` is not dequantized. Both
    // entries come back unchanged.
    let cfg = PerLayerQuantization::from_global(Quantization::affine(64, 4));
    let weights: Weights = HashMap::from([
        ("model.r1.weight".to_string(), arr_u32(&[1_u32, 2, 3, 4], &[4])),
        ("model.r1.biases".to_string(), arr_f32(&[0.5_f32; 4], &[4])),
    ]);

    let out = dequantize_weights(weights, &cfg)
        .expect("rank-1 uint32 weight + biases passes through");
    let packed = out.get("model.r1.weight").expect(".weight preserved");
    assert_eq!(packed.dtype().unwrap(), Dtype::U32);
    assert_eq!(packed.shape(), vec![4]);
    assert!(out.contains_key("model.r1.biases"));
}
#[test]
fn dequantize_weights_non_suffix_key_passes_through() {
    // A key without a `.weight`/`.scales`/`.biases` suffix (`.gamma`) must
    // survive untouched while the neighbouring quantized triple is expanded:
    // 8 packed U32 words per row at 4 bits become 64 values per row.
    let n_rows = 2_usize;
    let cfg = PerLayerQuantization::from_global(Quantization::affine(64, 4));
    let weights: Weights = HashMap::from([
        (
            "model.q.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 8], &[n_rows, 8]),
        ),
        (
            "model.q.scales".to_string(),
            arr_f32(&vec![1.0_f32; n_rows], &[n_rows, 1]),
        ),
        (
            "model.q.biases".to_string(),
            arr_f32(&vec![0.0_f32; n_rows], &[n_rows, 1]),
        ),
        ("model.norm.gamma".to_string(), arr_f32(&[3.0_f32, 4.0], &[2])),
    ]);

    let out = dequantize_weights(weights, &cfg).expect("dequantize");

    let mut gamma = out
        .get("model.norm.gamma")
        .expect("pass-through gamma")
        .try_clone()
        .unwrap();
    assert_eq!(gamma.shape(), vec![2]);
    assert_eq!(gamma.to_vec::<f32>().unwrap(), vec![3.0_f32, 4.0]);

    let expanded = out.get("model.q.weight").expect("dequantized .weight");
    assert_eq!(expanded.shape(), vec![n_rows, 64]);
    assert!(!out.contains_key("model.q.scales"));
}
#[test]
fn dequantize_weights_unresolvable_params_errors() {
    // Dequantizing a triple with no global or per-layer config must fail with
    // a layer-keyed InvariantViolation.
    let n_rows = 2_usize;
    let weights: Weights = HashMap::from([
        (
            "model.q.weight".to_string(),
            arr_u32(&vec![0_u32; n_rows * 8], &[n_rows, 8]),
        ),
        (
            "model.q.scales".to_string(),
            arr_f32(&vec![1.0_f32; n_rows], &[n_rows, 1]),
        ),
        (
            "model.q.biases".to_string(),
            arr_f32(&vec![0.0_f32; n_rows], &[n_rows, 1]),
        ),
    ]);
    let cfg = PerLayerQuantization::new(None, HashMap::new());

    let err = dequantize_weights(weights, &cfg).unwrap_err();
    let payload = match &err {
        Error::LayerKeyed(payload) => payload,
        other => panic!("expected Error::LayerKeyed, got: {other:?}"),
    };
    assert!(
        payload.layer().contains("model.q"),
        "LayerKeyed must name the layer, got layer={:?}",
        payload.layer()
    );
    assert!(
        matches!(payload.inner(), Error::InvariantViolation(_)),
        "inner must be InvariantViolation for unresolvable params, got: {:?}",
        payload.inner()
    );
}
#[test]
fn floating_dtype_precision_rank_orders_and_sentinels() {
    let rank = floating_dtype_precision_rank;
    // Strict precision ordering: F64 > F32 > BF16 > F16.
    for (higher, lower) in [
        (Dtype::F64, Dtype::F32),
        (Dtype::F32, Dtype::BF16),
        (Dtype::BF16, Dtype::F16),
    ] {
        assert!(rank(higher) > rank(lower));
    }
    // Integer dtypes map to the 0 sentinel ...
    for int_dtype in [Dtype::U32, Dtype::I32] {
        assert_eq!(rank(int_dtype), 0);
    }
    // ... and even the lowest float ranks strictly above it.
    assert!(rank(Dtype::F16) > 0);
}
#[test]
fn transform_awq_weights_group_size_overflows_i32_errors() {
    // The first u32 group size that does not fit in i32 must be rejected with
    // OutOfRange, naming `group_size` in the context and the i32 bound in the
    // requirement. `unsigned_abs` converts i32::MAX to u32 without an `as` cast.
    let weights: Weights = HashMap::new();
    let config = AwqLoadConfig {
        bits: 4,
        group_size: i32::MAX.unsigned_abs() + 1,
        zero_point: true,
        version: "gemm".into(),
    };
    let err = transform_awq_weights(weights, &config).unwrap_err();
    let Error::OutOfRange(p) = &err else {
        panic!("expected Error::OutOfRange, got {err:?}");
    };
    assert!(
        p.context().contains("group_size"),
        "context must name the group_size rule, got: {}",
        p.context()
    );
    assert!(
        p.requirement().contains("i32"),
        "requirement must mention the i32 bound, got: {}",
        p.requirement()
    );
}
#[test]
fn transform_awq_weights_rejects_non_int_qweight_dtype() {
    // A float tensor under `.qweight` must be rejected as UnsupportedDtype,
    // keyed to the qweight entry.
    let (in_features, out_features, n_groups) = (8usize, 8usize, 2usize);
    let mut weights: Weights = HashMap::new();
    weights.insert(
        "layer.qweight".to_string(),
        Array::from_slice::<f32>(&vec![0.0_f32; in_features], &(in_features, 1)).unwrap(),
    );
    weights.insert(
        "layer.scales".to_string(),
        Array::from_slice::<f32>(
            &vec![1.0_f32; n_groups * out_features],
            &(n_groups, out_features),
        )
        .unwrap(),
    );
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: false,
        version: "gemm".into(),
    };

    let err = transform_awq_weights(weights, &cfg).unwrap_err();
    let payload = match &err {
        Error::LayerKeyed(payload) => payload,
        other => panic!("expected Error::LayerKeyed, got: {other:?}"),
    };
    assert!(
        payload.layer().contains("layer.qweight"),
        "LayerKeyed must name the qweight key, got layer={:?}",
        payload.layer()
    );
    assert!(
        matches!(payload.inner(), Error::UnsupportedDtype(_)),
        "inner must be UnsupportedDtype for non-int qweight, got: {:?}",
        payload.inner()
    );
}
#[test]
fn transform_awq_weights_rejects_non_2d_qweight() {
    // A rank-1 `.qweight` must be rejected as RankMismatch, keyed to the
    // qweight entry.
    let (n_groups, out_features) = (2usize, 8usize);
    let mut weights: Weights = HashMap::new();
    weights.insert(
        "layer.qweight".to_string(),
        Array::from_slice::<u32>(&[0u32; 8], &(8usize,)).unwrap(),
    );
    weights.insert(
        "layer.scales".to_string(),
        Array::from_slice::<f32>(
            &vec![1.0_f32; n_groups * out_features],
            &(n_groups, out_features),
        )
        .unwrap(),
    );
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: false,
        version: "gemm".into(),
    };

    let err = transform_awq_weights(weights, &cfg).unwrap_err();
    let payload = match &err {
        Error::LayerKeyed(payload) => payload,
        other => panic!("expected Error::LayerKeyed, got: {other:?}"),
    };
    assert!(
        payload.layer().contains("layer.qweight"),
        "LayerKeyed must name the qweight key, got layer={:?}",
        payload.layer()
    );
    assert!(
        matches!(payload.inner(), Error::RankMismatch(_)),
        "inner must be RankMismatch for non-2D qweight, got: {:?}",
        payload.inner()
    );
}
#[test]
fn transform_awq_weights_rejects_non_2d_scales() {
    // A rank-1 `.scales` next to a valid `.qweight` must be rejected as
    // RankMismatch, keyed to the scales entry.
    let in_features = 8usize;
    let mut weights: Weights = HashMap::new();
    weights.insert(
        "layer.qweight".to_string(),
        Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap(),
    );
    weights.insert(
        "layer.scales".to_string(),
        Array::from_slice::<f32>(&[1.0_f32; 8], &(8usize,)).unwrap(),
    );
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: false,
        version: "gemm".into(),
    };

    let err = transform_awq_weights(weights, &cfg).unwrap_err();
    let payload = match &err {
        Error::LayerKeyed(payload) => payload,
        other => panic!("expected Error::LayerKeyed, got: {other:?}"),
    };
    assert!(
        payload.layer().contains("layer.scales"),
        "LayerKeyed must name the scales key, got layer={:?}",
        payload.layer()
    );
    assert!(
        matches!(payload.inner(), Error::RankMismatch(_)),
        "inner must be RankMismatch for non-2D scales, got: {:?}",
        payload.inner()
    );
}
#[test]
fn transform_awq_weights_rejects_zero_group_size() {
    // A zero group size is rejected up front with OutOfRange (not layer-keyed).
    // The context names `group_size` and the requirement states the `> 0` bound.
    let (in_features, out_features) = (8usize, 8usize);
    let mut weights: Weights = HashMap::new();
    weights.insert(
        "layer.qweight".to_string(),
        Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap(),
    );
    weights.insert(
        "layer.scales".to_string(),
        Array::from_slice::<f32>(&vec![1.0_f32; out_features], &(1usize, out_features))
            .unwrap(),
    );
    let config = AwqLoadConfig {
        bits: 4,
        group_size: 0,
        zero_point: false,
        version: "gemm".into(),
    };

    let err = transform_awq_weights(weights, &config).unwrap_err();
    let p = match &err {
        Error::OutOfRange(p) => p,
        _ => panic!("expected Error::OutOfRange, got {err:?}"),
    };
    assert!(
        p.context().contains("group_size"),
        "context names group_size, got: {}",
        p.context()
    );
    assert!(
        p.requirement().contains("> 0"),
        "requirement names the > 0 bound, got: {}",
        p.requirement()
    );
}
#[test]
fn transform_awq_weights_rejects_indivisible_in_features() {
    // 8 input features cannot be split into groups of 3. The failure is a
    // DivisibilityConstraint keyed to the qweight entry.
    let (in_features, out_features) = (8usize, 8usize);
    let mut weights: Weights = HashMap::new();
    weights.insert(
        "layer.qweight".to_string(),
        Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap(),
    );
    weights.insert(
        "layer.scales".to_string(),
        Array::from_slice::<f32>(&vec![1.0_f32; out_features], &(1usize, out_features))
            .unwrap(),
    );
    let config = AwqLoadConfig {
        bits: 4,
        group_size: 3,
        zero_point: false,
        version: "gemm".into(),
    };

    let err = transform_awq_weights(weights, &config).unwrap_err();
    let payload = match &err {
        Error::LayerKeyed(payload) => payload,
        other => panic!("expected Error::LayerKeyed, got: {other:?}"),
    };
    assert!(
        payload.layer().contains("layer.qweight"),
        "LayerKeyed must name the qweight key, got layer={:?}",
        payload.layer()
    );
    assert!(
        matches!(payload.inner(), Error::DivisibilityConstraint(_)),
        "inner must be DivisibilityConstraint, got: {:?}",
        payload.inner()
    );
}
#[test]
fn transform_awq_weights_rejects_non_int_qzeros_dtype() {
    // With zero_point enabled, a float `.qzeros` must be rejected as
    // UnsupportedDtype, keyed to the qzeros entry.
    let (in_features, out_features, n_groups) = (8usize, 8usize, 2usize);
    let mut weights: Weights = HashMap::new();
    weights.insert(
        "layer.qweight".to_string(),
        Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap(),
    );
    weights.insert(
        "layer.scales".to_string(),
        Array::from_slice::<f32>(
            &vec![1.0_f32; n_groups * out_features],
            &(n_groups, out_features),
        )
        .unwrap(),
    );
    weights.insert(
        "layer.qzeros".to_string(),
        Array::from_slice::<f32>(&vec![0.0_f32; n_groups], &(n_groups, 1)).unwrap(),
    );
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: true,
        version: "gemm".into(),
    };

    let err = transform_awq_weights(weights, &cfg).unwrap_err();
    let payload = match &err {
        Error::LayerKeyed(payload) => payload,
        other => panic!("expected Error::LayerKeyed, got: {other:?}"),
    };
    assert!(
        payload.layer().contains("layer.qzeros"),
        "LayerKeyed must name the qzeros key, got layer={:?}",
        payload.layer()
    );
    assert!(
        matches!(payload.inner(), Error::UnsupportedDtype(_)),
        "inner must be UnsupportedDtype for non-int qzeros, got: {:?}",
        payload.inner()
    );
}
#[test]
fn transform_awq_weights_rejects_mismatched_qzeros_shape() {
    // `.qzeros` with 3 rows when the scales imply 2 groups must be rejected as
    // ShapePairMismatch, keyed to the qzeros entry.
    let (in_features, out_features, n_groups) = (8usize, 8usize, 2usize);
    let mut weights: Weights = HashMap::new();
    weights.insert(
        "layer.qweight".to_string(),
        Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap(),
    );
    weights.insert(
        "layer.scales".to_string(),
        Array::from_slice::<f32>(
            &vec![1.0_f32; n_groups * out_features],
            &(n_groups, out_features),
        )
        .unwrap(),
    );
    weights.insert(
        "layer.qzeros".to_string(),
        Array::from_slice::<u32>(&[0u32, 0, 0], &(3usize, 1)).unwrap(),
    );
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: true,
        version: "gemm".into(),
    };

    let err = transform_awq_weights(weights, &cfg).unwrap_err();
    let payload = match &err {
        Error::LayerKeyed(payload) => payload,
        other => panic!("expected Error::LayerKeyed, got: {other:?}"),
    };
    assert!(
        payload.layer().contains("layer.qzeros"),
        "LayerKeyed must name the qzeros key, got layer={:?}",
        payload.layer()
    );
    assert!(
        matches!(payload.inner(), Error::ShapePairMismatch(_)),
        "inner must be ShapePairMismatch for wrong qzeros shape, got: {:?}",
        payload.inner()
    );
}
#[test]
fn transform_awq_weights_orphan_scales_and_qzeros_pass_through() {
    // `.scales`/`.qzeros` with no matching `.qweight` are not AWQ layers. They
    // are carried over unchanged while the real AWQ layer is still converted.
    let cfg = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: true,
        version: "gemm".into(),
    };
    let mut weights = awq_gemm_fixture_weights();
    weights.insert(
        "orphan.scales".to_string(),
        Array::from_slice::<f32>(&[0.1_f32, 0.2], &(2usize,)).unwrap(),
    );
    weights.insert(
        "orphan.qzeros".to_string(),
        Array::from_slice::<u32>(&[0u32, 0], &(2usize,)).unwrap(),
    );

    let (out, _) = transform_awq_weights(weights, &cfg).expect("orphans pass through");
    assert!(out.contains_key("layer.weight"));
    for (key, why) in [
        ("orphan.scales", "orphan .scales must pass through"),
        ("orphan.qzeros", "orphan .qzeros must pass through"),
    ] {
        assert!(out.contains_key(key), "{}", why);
    }
}
#[test]
fn transform_awq_weights_zero_point_true_without_qzeros_uses_symmetric() {
    // `zero_point: true` with no `.qzeros` tensor falls back to the symmetric
    // (implicit) zero point. With all scales at 1.0, every generated bias is
    // expected to be -8.0.
    let in_features = 8usize;
    let out_features = 8usize;
    let n_groups = 2usize;
    let qw = Array::from_slice::<u32>(&vec![0u32; in_features], &(in_features, 1)).unwrap();
    let scales_data: Vec<f32> = vec![1.0_f32; n_groups * out_features];
    let sc = Array::from_slice::<f32>(&scales_data, &(n_groups, out_features)).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert("layer.qweight".to_string(), qw);
    weights.insert("layer.scales".to_string(), sc);
    let config = AwqLoadConfig {
        bits: 4,
        group_size: 4,
        zero_point: true,
        version: "gemm".into(),
    };
    let (out, _) = transform_awq_weights(weights, &config)
        .expect("zero_point=true w/o qzeros falls to symmetric");
    let mut biases_arr = out
        .get("layer.biases")
        .expect("layer.biases")
        .try_clone()
        .unwrap();
    let biases: Vec<f32> = biases_arr.to_vec().unwrap();
    // Without this, an empty bias tensor would satisfy the per-element loop
    // vacuously. One bias per scale entry is expected.
    assert_eq!(
        biases.len(),
        n_groups * out_features,
        "expected one bias per scale entry, got {} biases",
        biases.len()
    );
    for &b in &biases {
        assert!(
            (b + 8.0_f32).abs() < 1e-5,
            "implicit-zero symmetric bias must be -8.0, got {b}"
        );
    }
}
#[test]
fn quantize_weights_scale_only_mxfp4_emits_weight_and_scales_no_biases() {
    // mxfp4 is scale-only: quantizing a single `.weight` must produce a
    // uint32-packed `.weight` and a `.scales` tensor, and no `.biases`.
    let (rows, gs) = (4_usize, 32_usize);
    let values: Vec<f32> = (0..rows * gs).map(|k| k as f32 / 64.0 - 1.0).collect();
    let weights: Weights =
        HashMap::from([("model.proj.weight".to_string(), arr_f32(&values, &[rows, gs]))]);
    let quant = Quantization {
        group_size: gs as i32,
        bits: 4,
        mode: QuantMode::Mxfp4,
    };
    let cfg = PerLayerQuantization::from_global(quant);

    let out = quantize_weights(weights, &cfg, &default_eligible).expect("mxfp4 quantize");

    let packed = out.get("model.proj.weight").expect(".weight");
    assert_eq!(
        packed.dtype().unwrap(),
        Dtype::U32,
        "mxfp4-packed `.weight` must be uint32"
    );
    // Each row is exactly one group wide, so one scale per row.
    let per_row_scales = out.get("model.proj.scales").expect(".scales");
    assert_eq!(per_row_scales.shape(), vec![rows, 1]);
    assert!(
        !out.contains_key("model.proj.biases"),
        "mxfp4 is scale-only: `quantize_weights` must NOT emit a `.biases` entry"
    );
}
#[test]
fn quantize_then_dequantize_mxfp4_scale_only_roundtrips() {
    // Quantize one `.weight` to mxfp4 and confirm no `.biases` were emitted.
    // Then dequantize and check the reconstruction: same shape, bf16 output,
    // and every element within 25% of the input's peak magnitude (+1e-3 slack).
    let rows = 4_usize;
    let gs = 32_usize;
    let original: Vec<f32> = (0..rows * gs).map(|k| k as f32 / 64.0 - 1.0).collect();
    let cfg = PerLayerQuantization::from_global(Quantization {
        group_size: gs as i32,
        bits: 4,
        mode: QuantMode::Mxfp4,
    });
    let weights: Weights =
        HashMap::from([("model.proj.weight".to_string(), arr_f32(&original, &[rows, gs]))]);

    let quantized = quantize_weights(weights, &cfg, &default_eligible).expect("mxfp4 quantize");
    assert!(!quantized.contains_key("model.proj.biases"));

    let restored = dequantize_weights(quantized, &cfg).expect("mxfp4 dequantize");
    let deq = restored
        .get("model.proj.weight")
        .expect("round-tripped .weight")
        .try_clone()
        .unwrap();
    assert_eq!(deq.shape(), vec![rows, gs]);
    assert_eq!(deq.dtype().unwrap(), Dtype::BF16);

    let got: Vec<f32> = deq.astype(Dtype::F32).unwrap().to_vec().unwrap();
    let peak = original.iter().map(|v| v.abs()).fold(0.0_f32, f32::max);
    let tol = 0.25 * peak + 1e-3;
    for (g, e) in got.iter().zip(&original) {
        assert!(
            (g - e).abs() <= tol,
            "mxfp4 round-trip drift too large: got={g} want={e}"
        );
    }
}
#[test]
fn quantize_then_dequantize_affine_on_grid_recovers_values() {
    // Values cycle through 0..=15, so each 64-wide group spans exactly [0, 15].
    // That range fits a 16-level (4-bit) affine grid with no rounding, so the
    // round trip must be lossless up to float noise.
    let gs = 64_usize;
    let rows = 2_usize;
    let grid: Vec<f32> = (0..rows * gs).map(|k| (k % 16) as f32).collect();
    let cfg = PerLayerQuantization::from_global(Quantization::affine(gs as i32, 4));
    let weights: Weights =
        HashMap::from([("model.grid.weight".to_string(), arr_f32(&grid, &[rows, gs]))]);

    let quantized = quantize_weights(weights, &cfg, &default_eligible).expect("affine quantize");
    // Each row is a single group, so there is one (scale, bias) pair per row.
    let per_group = vec![rows, 1];
    assert_eq!(
        quantized.get("model.grid.scales").expect(".scales").shape(),
        per_group
    );
    assert_eq!(
        quantized
            .get("model.grid.biases")
            .expect(".biases (affine)")
            .shape(),
        per_group
    );

    let restored = dequantize_weights(quantized, &cfg).expect("affine dequantize");
    let mut deq = restored
        .get("model.grid.weight")
        .expect("round-tripped .weight")
        .try_clone()
        .unwrap();
    assert_eq!(deq.shape(), vec![rows, gs]);
    let got: Vec<f32> = deq.to_vec().unwrap();
    for (g, e) in got.iter().zip(&grid) {
        assert!(
            (g - e).abs() < 1e-3,
            "on-grid affine value must round-trip losslessly: got={g} want={e}"
        );
    }
}