use super::*;
/// Test helper: builds an f32 array with shape `(1, 1, n_steps, dim)` whose flat
/// element `i` holds `base + i`.
fn kv4(n_steps: usize, dim: usize, base: f32) -> Array {
    let len = n_steps * dim;
    let mut ramp: Vec<f32> = Vec::with_capacity(len);
    for idx in 0..len {
        ramp.push(base + idx as f32);
    }
    Array::from_slice::<f32>(&ramp, &(1usize, 1, n_steps, dim)).unwrap()
}
/// Test helper: wraps `vals` in an f32 array created with shape `(1, 1, vals.len(), 1)`.
fn kv(vals: &[f32]) -> Array {
    let shape = (1usize, 1, vals.len(), 1);
    Array::from_slice::<f32>(vals, &shape).unwrap()
}
/// Test helper: builds an f32 array of the requested `shape`, filled with
/// `0.0, 1.0, 2.0, ...` in flat order.
// NOTE(review): the by-value `Vec` is presumably kept so call sites can pass
// `vec![..]` directly, hence the lint allowance — confirm.
#[allow(clippy::needless_pass_by_value)]
fn ranked(shape: Vec<usize>) -> Array {
    let count: usize = shape.iter().product();
    let mut ramp: Vec<f32> = Vec::with_capacity(count);
    ramp.extend((0..count).map(|n| n as f32));
    Array::from_slice::<f32>(&ramp, &shape).unwrap()
}
/// A default-constructed cache is empty and reports offset 0, group_size 64, bits 8,
/// with a matching `meta_state` and no quantized state.
#[test]
fn default_is_empty_group64_bits8() {
    let cache = StandardQuantizedKvCache::default();

    assert!(cache.is_empty(), "fresh default cache holds no keys");
    assert_eq!(cache.offset(), 0);
    assert_eq!(cache.group_size(), 64, "mlx-lm default group_size");
    assert_eq!(cache.bits(), 8, "mlx-lm default bits");

    let expected_meta: Vec<String> = ["0", "64", "8"].iter().map(|s| s.to_string()).collect();
    assert_eq!(
        cache.meta_state(),
        expected_meta,
        "meta_state serializes the default offset/group_size/bits"
    );

    let quantized = cache.quantized_state().unwrap();
    assert!(quantized.is_none());
}
/// `set_state` with a rank-2 array in slot 0 of a six-array state must fail with a
/// `RankMismatch` whose context names "K w", and must leave the cache empty at offset 0.
#[test]
fn set_state_k_w_wrong_rank_uses_w_context_arm() {
    let mut cache = StandardQuantizedKvCache::default();
    let bad_kw = ranked(vec![1, 4]);
    let ok = kv(&[1.0]);
    // Malformed array first, followed by five clones of the 4-D `ok` array.
    let mut state = vec![bad_kw];
    state.extend((0..5).map(|_| ok.try_clone().unwrap()));

    let mismatch = match cache.set_state(state).unwrap_err() {
        Error::RankMismatch(p) => p,
        other => panic!("expected RankMismatch, got {other:?}"),
    };
    assert!(
        mismatch.context().contains("K w must be 4-D"),
        "must select the K-side `w` context arm, got: {}",
        mismatch.context()
    );
    assert_eq!(mismatch.actual(), 2, "observed rank is 2");
    assert_eq!(
        mismatch.actual_shape(),
        &[1, 4],
        "full observed shape carried"
    );

    // The failed call must not have committed anything.
    assert!(
        cache.is_empty(),
        "set_state must not mutate on the rank-gate Err"
    );
    assert_eq!(cache.offset(), 0);
}
/// `set_state` with a rank-3 array in slot 1 of a six-array state must fail with a
/// `RankMismatch` whose context names "K scales", leaving the cache empty.
#[test]
fn set_state_k_scales_wrong_rank_uses_scales_context_arm() {
    let mut cache = StandardQuantizedKvCache::default();
    let ok = kv(&[1.0]);
    let bad_ks = ranked(vec![1, 1, 3]);
    // Every slot except index 1 is a clone of the 4-D `ok` array.
    let dup = || ok.try_clone().unwrap();
    let state = vec![dup(), bad_ks, dup(), dup(), dup(), dup()];

    let mismatch = match cache.set_state(state).unwrap_err() {
        Error::RankMismatch(p) => p,
        other => panic!("expected RankMismatch, got {other:?}"),
    };
    assert!(
        mismatch.context().contains("K scales must be 4-D"),
        "must select the K-side `scales` arm, got: {}",
        mismatch.context()
    );
    assert_eq!(mismatch.actual(), 3);

    assert!(cache.is_empty());
}
/// `set_state` with a rank-1 array in slot 2 of a six-array state must fail with a
/// `RankMismatch` whose context says "K must be 4-D" without the `w`/`scales` wording.
#[test]
fn set_state_k_biases_wrong_rank_uses_default_context_arm() {
    let mut cache = StandardQuantizedKvCache::default();
    let ok = kv(&[1.0]);
    let bad_kb = ranked(vec![5]);
    // Five clones of the 4-D `ok` array, with the malformed array spliced in at index 2.
    let mut state: Vec<_> = (0..5).map(|_| ok.try_clone().unwrap()).collect();
    state.insert(2, bad_kb);

    let mismatch = match cache.set_state(state).unwrap_err() {
        Error::RankMismatch(p) => p,
        other => panic!("expected RankMismatch, got {other:?}"),
    };
    let ctx = mismatch.context();
    let is_generic_k = ctx.contains("K must be 4-D");
    let names_w = ctx.contains("w must");
    let names_scales = ctx.contains("scales must");
    assert!(
        is_generic_k && !(names_w || names_scales),
        "biases must select the generic K-side `_` arm, got: {}",
        mismatch.context()
    );
    assert_eq!(mismatch.actual(), 1);

    assert!(cache.is_empty());
}
/// `set_state` with a rank-2 array in slot 2 of a four-array state must fail with a
/// `RankMismatch` whose context names "V w", leaving the cache empty.
#[test]
fn set_state_v_w_wrong_rank_uses_v_w_context_arm() {
    let mut cache = StandardQuantizedKvCache::default();
    let ok = kv(&[1.0]);
    let bad_vw = ranked(vec![2, 2]);
    // NOTE(review): a four-array state is presumably the bias-less layout
    // (K w, K scales, V w, V scales) — confirm against `set_state`.
    let first = ok.try_clone().unwrap();
    let second = ok.try_clone().unwrap();
    let state = vec![first, second, bad_vw, ok];

    let mismatch = match cache.set_state(state).unwrap_err() {
        Error::RankMismatch(p) => p,
        other => panic!("expected RankMismatch, got {other:?}"),
    };
    assert!(
        mismatch.context().contains("V w must be 4-D"),
        "must select the V-side `w` arm, got: {}",
        mismatch.context()
    );
    assert_eq!(mismatch.actual(), 2);
    assert_eq!(mismatch.actual_shape(), &[2, 2]);

    assert!(cache.is_empty());
}
/// `set_state` with a rank-5 array in the last slot of a four-array state must fail
/// with a `RankMismatch` whose context names "V scales", leaving the cache empty.
#[test]
fn set_state_v_scales_wrong_rank_uses_v_scales_context_arm() {
    let mut cache = StandardQuantizedKvCache::default();
    let ok = kv(&[1.0]);
    let bad_vs = ranked(vec![1, 1, 1, 1, 1]);
    // Three clones of the 4-D `ok` array, then the malformed array at index 3.
    let mut state: Vec<_> = (0..3).map(|_| ok.try_clone().unwrap()).collect();
    state.push(bad_vs);

    let mismatch = match cache.set_state(state).unwrap_err() {
        Error::RankMismatch(p) => p,
        other => panic!("expected RankMismatch, got {other:?}"),
    };
    assert!(
        mismatch.context().contains("V scales must be 4-D"),
        "must select the V-side `scales` arm, got: {}",
        mismatch.context()
    );
    assert_eq!(mismatch.actual(), 5);

    assert!(cache.is_empty());
}
/// `set_state` with a rank-1 array in the last slot of a six-array state must fail with
/// a `RankMismatch` whose context says "V must be 4-D" without the `w`/`scales` wording.
#[test]
fn set_state_v_biases_wrong_rank_uses_v_default_context_arm() {
    let mut cache = StandardQuantizedKvCache::default();
    let ok = kv(&[1.0]);
    let bad_vb = ranked(vec![7]);
    // Five clones of the 4-D `ok` array, then the malformed array at index 5.
    let mut state = Vec::with_capacity(6);
    for _ in 0..5 {
        state.push(ok.try_clone().unwrap());
    }
    state.push(bad_vb);

    let mismatch = match cache.set_state(state).unwrap_err() {
        Error::RankMismatch(p) => p,
        other => panic!("expected RankMismatch, got {other:?}"),
    };
    let ctx = mismatch.context();
    let is_generic_v = ctx.contains("V must be 4-D");
    let names_w = ctx.contains("w must");
    let names_scales = ctx.contains("scales must");
    assert!(
        is_generic_v && !(names_w || names_scales),
        "V biases must select the generic V-side `_` arm, got: {}",
        mismatch.context()
    );
    assert_eq!(mismatch.actual(), 1);

    assert!(cache.is_empty());
}
/// With `offset` already at `usize::MAX`, both `update_quantized` and `update` must
/// return `ArithmeticOverflow` and leave the offset (and empty storage) untouched.
#[test]
fn update_quantized_offset_overflow_is_rejected() {
    // Builds an empty cache whose offset is already saturated.
    let saturated = || StandardQuantizedKvCache {
        keys: None,
        values: None,
        offset: usize::MAX,
        group_size: 64,
        bits: 8,
    };

    let mut cache = saturated();
    let steps = kv(&[1.0, 2.0]);
    let overflow = match cache.update_quantized(&steps, &steps).unwrap_err() {
        Error::ArithmeticOverflow(p) => p,
        other => panic!("expected ArithmeticOverflow, got {other:?}"),
    };
    assert!(
        overflow.context().contains("offset + num_steps"),
        "context must name the offset + num_steps add, got: {}",
        overflow.context()
    );
    assert!(
        overflow
            .operands()
            .iter()
            .any(|(name, value)| *name == "offset" && *value == usize::MAX as u64),
        "operands must carry offset=usize::MAX, got: {:?}",
        overflow.operands()
    );
    assert!(
        overflow
            .operands()
            .iter()
            .any(|(name, value)| *name == "num_steps" && *value == 2),
        "operands must carry num_steps=2, got: {:?}",
        overflow.operands()
    );
    assert_eq!(
        cache.offset(),
        usize::MAX,
        "offset unchanged on the Err path"
    );
    assert!(cache.is_empty(), "buffer not committed on the Err path");

    // Same starting state, driven through `update` instead.
    let mut via_update = saturated();
    assert!(matches!(
        via_update.update(&steps, &steps),
        Err(Error::ArithmeticOverflow(_))
    ));
    assert_eq!(via_update.offset(), usize::MAX);
}
/// An unparsable group_size field makes `set_meta_state` return `Error::Parse`
/// (context naming `group_size`, input kind `"i32"`) without committing any field,
/// for both the three-field and the four-field input.
#[test]
fn set_meta_state_group_size_parse_error_leaves_cache_unmutated() {
    // Three-field input: the middle field is not a number.
    let mut three_field = StandardQuantizedKvCache::new(64, 8).unwrap();
    let bad_three = [
        "3".to_string(),
        "not_a_number".to_string(),
        "8".to_string(),
    ];
    let parse_err = match three_field.set_meta_state(&bad_three).unwrap_err() {
        Error::Parse(p) => p,
        other => panic!("expected Parse, got {other:?}"),
    };
    assert!(
        parse_err.context().contains("group_size"),
        "context must name group_size, got: {}",
        parse_err.context()
    );
    assert_eq!(parse_err.input_kind(), "i32");
    let untouched_meta: Vec<String> = ["0", "64", "8"].iter().map(|s| s.to_string()).collect();
    assert_eq!(
        three_field.meta_state(),
        untouched_meta,
        "no field committed on the parse Err"
    );

    // Four-field input: the third field is not a number.
    // NOTE(review): presumably (step, offset, group_size, bits) as in mlx-lm — confirm.
    let mut four_field = StandardQuantizedKvCache::new(64, 8).unwrap();
    let bad_four = [
        "256".to_string(),
        "5".to_string(),
        "bad".to_string(),
        "4".to_string(),
    ];
    let parse_err = match four_field.set_meta_state(&bad_four).unwrap_err() {
        Error::Parse(p) => p,
        other => panic!("expected Parse, got {other:?}"),
    };
    assert!(
        parse_err.context().contains("group_size"),
        "got: {}",
        parse_err.context()
    );
    assert_eq!(parse_err.input_kind(), "i32");
    assert_eq!(
        four_field.offset(),
        0,
        "offset not committed on the Err path"
    );
    assert_eq!(
        four_field.group_size(),
        64,
        "group_size not committed on the Err path"
    );
}
/// `trim(0)` on a populated cache returns 0 and changes nothing; `trim(5)` on an
/// empty cache also returns 0 and keeps the offset at 0.
#[test]
fn trim_zero_token_is_noop_early_return() {
    // Bias-less K/V triples, each array holding three steps.
    let mut populated = StandardQuantizedKvCache {
        keys: Some((kv(&[10.0, 11.0, 12.0]), kv(&[1.0, 1.0, 1.0]), None)),
        values: Some((kv(&[100.0, 101.0, 102.0]), kv(&[2.0, 2.0, 2.0]), None)),
        offset: 3,
        group_size: 64,
        bits: 8,
    };
    let removed = populated.trim(0).unwrap();
    assert_eq!(removed, 0, "0-token trim returns 0");
    assert_eq!(populated.offset(), 3, "offset unchanged");
    let state = populated.state().unwrap();
    assert_eq!(state[0].shape(), vec![1, 1, 3, 1], "K w untrimmed");

    let mut fresh = StandardQuantizedKvCache::new(64, 8).unwrap();
    let removed_from_fresh = fresh.trim(5).unwrap();
    assert_eq!(removed_from_fresh, 0, "empty cache trims nothing");
    assert_eq!(fresh.offset(), 0);
}
/// Trimming 2 tokens from a cache with offset 4 but no stored keys/values returns 2,
/// lowers the offset to 2, and leaves the cache empty.
#[test]
fn trim_with_none_storage_takes_none_arms() {
    // Non-zero offset with no backing storage.
    let mut cache = StandardQuantizedKvCache {
        keys: None,
        values: None,
        offset: 4,
        group_size: 64,
        bits: 8,
    };

    let removed = cache.trim(2).unwrap();

    assert_eq!(removed, 2, "min(2, 4) trimmed");
    assert_eq!(
        cache.offset(),
        2,
        "offset decremented by the trimmed count"
    );
    assert!(cache.is_empty(), "storage stays None (None arms taken)");
}
/// Copying an empty `new(32, 4)` cache yields an empty quantized cache that keeps
/// group_size 32 and bits 4 and has no quantized state.
#[test]
fn copy_empty_cache_takes_none_arms() {
    let original = StandardQuantizedKvCache::new(32, 4).unwrap();
    let duplicate = original.copy().unwrap();

    assert!(duplicate.is_empty(), "copied empty cache is still empty");
    assert_eq!(duplicate.offset(), 0);
    assert_eq!(duplicate.reference_class_name(), "QuantizedKVCache");

    let quantized = duplicate
        .as_quantized()
        .expect("copy is still a quantized cache");
    assert_eq!(quantized.group_size(), 32, "group_size copied");
    assert_eq!(quantized.bits(), 4, "bits copied");
    let triples = quantized.quantized_state().unwrap();
    assert!(triples.is_none(), "no triples to copy");
}
/// `triple_component_len_range` reports the (min, max) length over the components that
/// are present: `(3, 5)` for a bias-less triple and `(2, 4)` once a longer bias is added.
#[test]
fn triple_component_len_range_bias_less_uses_none_arm() {
    type Triple = (Array, Array, Option<Array>);

    // No bias: w holds 5 steps, scales holds 3.
    let bias_less: Triple = (kv(&[0.0, 1.0, 2.0, 3.0, 4.0]), kv(&[0.0, 1.0, 2.0]), None);
    let bias_less_range =
        StandardQuantizedKvCache::triple_component_len_range("bias-less", &bias_less).unwrap();
    assert_eq!(
        bias_less_range.0, 3,
        "min seq-len across (w=5, scales=3) is 3"
    );
    assert_eq!(
        bias_less_range.1, 5,
        "max seq-len across (w=5, scales=3) is 5"
    );

    // With bias: w and scales hold 2 steps, the bias holds 4.
    let biased: Triple = (
        kv(&[0.0, 1.0]),
        kv(&[0.0, 1.0]),
        Some(kv(&[0.0, 1.0, 2.0, 3.0])),
    );
    let biased_range =
        StandardQuantizedKvCache::triple_component_len_range("biased", &biased).unwrap();
    assert_eq!(biased_range, (2, 4), "bias seq-len 4 widens the max to 4");
}
/// Copying a cache whose K/V triples carry no bias keeps the offset, keeps both
/// triples bias-less, and preserves the K `w` values.
#[test]
fn copy_bias_less_triples_takes_tree_map_none_arm() {
    let source = StandardQuantizedKvCache {
        keys: Some((kv(&[10.0, 20.0]), kv(&[1.0, 1.0]), None)),
        values: Some((kv(&[100.0, 200.0]), kv(&[2.0, 2.0]), None)),
        offset: 2,
        group_size: 64,
        bits: 8,
    };

    let duplicate = source.copy().unwrap();
    assert_eq!(duplicate.offset(), 2, "copied offset matches");
    assert!(!duplicate.is_empty());

    let quantized = duplicate.as_quantized().unwrap();
    let (k_triple, v_triple) = quantized
        .quantized_state()
        .unwrap()
        .expect("copied cache still has triples");
    assert!(
        k_triple.2.is_none(),
        "copied K triple is bias-less (None arm)"
    );
    assert!(
        v_triple.2.is_none(),
        "copied V triple is bias-less (None arm)"
    );

    // Read the K `w` values back out of the copy.
    let mut k_w = ops::shape::contiguous(&k_triple.0, false).unwrap();
    let k_w_values = k_w.to_vec::<f32>().unwrap();
    assert_eq!(
        k_w_values,
        vec![10.0, 20.0],
        "copied K w markers preserved (refcount clone)"
    );
}
/// `concat_triple` must return `InvariantViolation` when exactly one side carries a
/// bias, in either direction.
#[test]
fn concat_triple_mismatched_bias_is_invariant_violation() {
    type Triple = (Array, Array, Option<Array>);

    // Previous side has a bias, new side does not.
    let biased_prev: Triple = (kv(&[10.0]), kv(&[1.0]), Some(kv(&[0.5])));
    let bare_new: Triple = (kv(&[20.0]), kv(&[2.0]), None);
    let violation =
        match StandardQuantizedKvCache::concat_triple(&biased_prev, &bare_new).unwrap_err() {
            Error::InvariantViolation(p) => p,
            other => panic!("expected InvariantViolation, got {other:?}"),
        };
    assert!(
        violation
            .context()
            .contains("concatenating quantized triples"),
        "context must name the triple concat, got: {}",
        violation.context()
    );
    assert!(
        violation
            .requirement()
            .contains("biases must be present in both"),
        "requirement must describe the bias-presence invariant, got: {}",
        violation.requirement()
    );

    // Mirror case: previous side has no bias, new side does.
    let bare_prev: Triple = (kv(&[10.0]), kv(&[1.0]), None);
    let biased_new: Triple = (kv(&[20.0]), kv(&[2.0]), Some(kv(&[0.5])));
    assert!(matches!(
        StandardQuantizedKvCache::concat_triple(&bare_prev, &biased_new),
        Err(Error::InvariantViolation(_))
    ));
}
/// `enforce_offset_len_invariant` with only keys stored lowers an offset of 5 to the
/// keys' 3 stored steps; with only values stored and a matching offset it changes nothing.
#[test]
fn enforce_offset_invariant_keys_only_some_arm() {
    // Keys only: three stored steps but an offset of 5.
    let mut keys_only = StandardQuantizedKvCache {
        keys: Some((kv(&[10.0, 11.0, 12.0]), kv(&[1.0, 1.0, 1.0]), None)),
        values: None,
        offset: 5,
        group_size: 64,
        bits: 8,
    };
    keys_only.enforce_offset_len_invariant().unwrap();
    assert_eq!(
        keys_only.offset(),
        3,
        "offset clamps down to the keys' stored seq-len"
    );
    assert!(
        keys_only.values.is_none(),
        "values stay None through the (Some, None) arm"
    );
    let (k_w, k_scales, k_biases) = keys_only.keys.as_ref().expect("keys retained");
    assert_eq!(k_w.shape(), vec![1, 1, 3, 1], "K w at the clamped offset");
    assert_eq!(
        k_scales.shape(),
        vec![1, 1, 3, 1],
        "K scales at the clamped offset"
    );
    assert!(k_biases.is_none(), "bias-less keys stay bias-less");

    // Values only: two stored steps and an offset of 2.
    let mut values_only = StandardQuantizedKvCache {
        keys: None,
        values: Some((kv(&[100.0, 200.0]), kv(&[2.0, 2.0]), None)),
        offset: 2,
        group_size: 64,
        bits: 8,
    };
    values_only.enforce_offset_len_invariant().unwrap();
    assert_eq!(
        values_only.offset(),
        2,
        "consistent values offset is a no-op clamp"
    );
    assert!(
        values_only.keys.is_none(),
        "keys stay None through the (None, Some) arm"
    );
}
/// `materialize` on a cache with biased K/V triples keeps the offset, the biases and
/// the `w` values; on an empty cache it succeeds and the cache stays empty.
#[test]
fn materialize_evals_triples_and_empty_is_noop() {
    // Biased K and V triples, two steps each.
    let k_triple = (kv4(2, 1, 10.0), kv(&[1.0, 1.0]), Some(kv(&[0.5, 0.5])));
    let v_triple = (kv4(2, 1, 100.0), kv(&[2.0, 2.0]), Some(kv(&[0.25, 0.25])));
    let mut cache = StandardQuantizedKvCache {
        keys: Some(k_triple),
        values: Some(v_triple),
        offset: 2,
        group_size: 64,
        bits: 8,
    };

    cache.materialize().unwrap();
    assert_eq!(cache.offset(), 2);

    let (qk, qv) = cache.quantized_state().unwrap().unwrap();
    let both_biased = qk.2.is_some() && qv.2.is_some();
    assert!(both_biased, "biases survive materialize");

    let mut k_w = ops::shape::contiguous(&qk.0, false).unwrap();
    let k_w_values = k_w.to_vec::<f32>().unwrap();
    assert_eq!(
        k_w_values,
        vec![10.0, 11.0],
        "K w markers unchanged by materialize"
    );
    let mut v_w = ops::shape::contiguous(&qv.0, false).unwrap();
    let v_w_values = v_w.to_vec::<f32>().unwrap();
    assert_eq!(
        v_w_values,
        vec![100.0, 101.0],
        "V w markers unchanged by materialize (own stream)"
    );

    // Empty cache: materialize has nothing to do.
    let mut fresh = StandardQuantizedKvCache::new(64, 8).unwrap();
    fresh.materialize().unwrap();
    assert!(fresh.is_empty());
}