#![cfg(feature = "lm")]
use mlxrs::{
Array,
lm::cache::{KvCache, MaskMode, QuantizedKvCache, StandardQuantizedKvCache, from_state},
ops,
};
/// Quantization group size: the scales arrays below have `HEAD_DIM / GROUP_SIZE` entries per step.
const GROUP_SIZE: i32 = 64;
/// Bits per quantized element. The packed weight shapes asserted in this file use
/// `HEAD_DIM / 4`, i.e. 4 elements per packed word at 8 bits (presumably 32-bit words —
/// `32 / BITS`); those literals must change if this constant does.
const BITS: i32 = 8;
/// Feature width of every synthetic K/V step built by `kv_base`.
const HEAD_DIM: usize = 64;
/// Quantization mode handed to `dequantize`; affine mode produces biases
/// (see the `Some(biases)` assertion and the 6-array states).
const MODE: &str = "affine";
/// Builds the deterministic `[1, 1, n_steps, HEAD_DIM]` ramp with no offset added.
///
/// Because the ramp depends only on the flat element index, `kv(a)` equals the first
/// `a` steps of `kv(b)` for any `a <= b`.
fn kv(n_steps: usize) -> Array {
    const NO_SHIFT: f32 = 0.0;
    kv_base(n_steps, NO_SHIFT)
}
/// Builds a deterministic `[1, 1, n_steps, HEAD_DIM]` f32 tensor whose flattened element
/// `i` equals `i * 0.013 - 0.4 + base`.
///
/// A large `base` (e.g. `100.0`) yields steps that stay clearly distinguishable from the
/// zero-based ramp even after lossy quantization.
fn kv_base(n_steps: usize, base: f32) -> Array {
    let len = n_steps * HEAD_DIM;
    let mut ramp = Vec::with_capacity(len);
    for idx in 0..len {
        ramp.push(idx as f32 * 0.013 - 0.4 + base);
    }
    Array::from_slice::<f32>(&ramp, &(1usize, 1, n_steps, HEAD_DIM)).unwrap()
}
/// Dequantizes a `(packed weights, scales, optional biases)` triple using this file's
/// `GROUP_SIZE`, `BITS` and `MODE`.
fn dequant(t: &(Array, Array, Option<Array>)) -> Array {
    let (packed, scales, biases) = t;
    ops::quantized::dequantize(
        packed,
        scales,
        biases.as_ref(),
        GROUP_SIZE,
        BITS,
        MODE,
        None,
        None,
    )
    .unwrap()
}
/// Asserts `got` matches `want` element-wise within a quantization-friendly tolerance.
///
/// Both arrays are read back as `f32` and must have equal length. Every element may
/// deviate by at most `0.05 * max|want| + 1e-3`. That single bound is shared by all
/// elements, so small-magnitude entries are loosely checked when `want` also contains
/// large ones.
fn assert_close(got: &mut Array, want: &mut Array) {
    let (got_vals, want_vals) = (
        got.to_vec::<f32>().unwrap(),
        want.to_vec::<f32>().unwrap(),
    );
    assert_eq!(got_vals.len(), want_vals.len(), "length mismatch");
    let peak = want_vals.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
    let tolerance = 0.05 * peak + 1e-3;
    got_vals.iter().zip(&want_vals).for_each(|(a, b)| {
        assert!(
            (a - b).abs() <= tolerance,
            "dequant drift too large: got={a} want={b}"
        );
    });
}
/// `update_quantized` returns quantized triples covering the whole history (not only the
/// new chunk), grows the offset, and dequantizes back to the inputs within tolerance.
///
/// Keys and values use different ramps, and the second chunk differs from the first one,
/// so a K/V swap or a mis-ordered append is actually detected.
#[test]
fn update_quantized_roundtrips_and_grows() {
    /// Concatenates two `[1, 1, _, HEAD_DIM]` tensors along the sequence axis (host-side).
    fn concat_steps(first: &mut Array, second: &mut Array, n_steps: usize) -> Array {
        let mut data = first.to_vec::<f32>().unwrap();
        data.extend_from_slice(&second.to_vec::<f32>().unwrap());
        Array::from_slice::<f32>(&data, &(1usize, 1, n_steps, HEAD_DIM)).unwrap()
    }

    // Packed weight width, assuming 32-bit words holding `32 / BITS` elements each
    // (equals HEAD_DIM / 4 at 8 bits).
    let packed_dim = HEAD_DIM * BITS as usize / 32;
    let scales_dim = HEAD_DIM / GROUP_SIZE as usize;

    let mut c = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    assert!(c.is_empty());
    assert_eq!(c.offset(), 0);
    assert_eq!(c.group_size(), GROUP_SIZE);
    assert_eq!(c.bits(), BITS);

    let mut k1 = kv(3);
    let mut v1 = kv_base(3, 0.5);
    let (qk1, qv1) = c.update_quantized(&k1, &v1).unwrap();
    assert!(!c.is_empty());
    assert_eq!(c.offset(), 3);
    assert_eq!(qk1.0.shape(), vec![1, 1, 3, packed_dim]);
    assert_eq!(qk1.1.shape(), vec![1, 1, 3, scales_dim]);
    assert!(qk1.2.is_some(), "affine quantize yields Some(biases)");
    assert_eq!(qv1.0.shape(), vec![1, 1, 3, packed_dim]);
    let mut dk1 = dequant(&qk1);
    let mut dv1 = dequant(&qv1);
    assert_eq!(dk1.shape(), vec![1, 1, 3, HEAD_DIM]);
    assert_close(&mut dk1, &mut k1);
    assert_close(&mut dv1, &mut v1);

    // Second chunk is offset from the first so the appended rows are distinguishable.
    let mut k2 = kv_base(2, 1.0);
    let mut v2 = kv_base(2, 1.5);
    let (qk2, qv2) = c.update_quantized(&k2, &v2).unwrap();
    assert_eq!(c.offset(), 5);
    assert_eq!(qk2.0.shape(), vec![1, 1, 5, packed_dim]);
    assert_eq!(qv2.0.shape(), vec![1, 1, 5, packed_dim]);
    let mut dk2 = dequant(&qk2);
    let mut dv2 = dequant(&qv2);
    assert_eq!(dk2.shape(), vec![1, 1, 5, HEAD_DIM]);
    assert_eq!(dv2.shape(), vec![1, 1, 5, HEAD_DIM]);

    // Expected history: first chunk followed by the second along the sequence axis.
    let mut want_k = concat_steps(&mut k1, &mut k2, 5);
    let mut want_v = concat_steps(&mut v1, &mut v2, 5);
    assert_close(&mut dk2, &mut want_k);
    assert_close(&mut dv2, &mut want_v);
}
/// The base `KvCache::update` path stores quantized data but hands back dequantized
/// full-history K/V tensors.
///
/// Both returned tensors are checked after each update. Keys and values differ so each
/// side is verified independently.
#[test]
fn base_update_returns_dequantized() {
    /// Concatenates two `[1, 1, _, HEAD_DIM]` tensors along the sequence axis (host-side).
    fn concat_steps(first: &mut Array, second: &mut Array, n_steps: usize) -> Array {
        let mut data = first.to_vec::<f32>().unwrap();
        data.extend_from_slice(&second.to_vec::<f32>().unwrap());
        Array::from_slice::<f32>(&data, &(1usize, 1, n_steps, HEAD_DIM)).unwrap()
    }

    let mut c = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    let mut k = kv(4);
    let mut v = kv_base(4, 0.5);
    let (mut dk, mut dv) = c.update(&k, &v).unwrap();
    assert_eq!(c.offset(), 4);
    assert_eq!(dk.shape(), vec![1, 1, 4, HEAD_DIM]);
    assert_eq!(dv.shape(), vec![1, 1, 4, HEAD_DIM]);
    assert_close(&mut dk, &mut k);
    assert_close(&mut dv, &mut v);

    // Second update: the returned values were previously discarded; check both sides.
    let mut k2 = kv_base(2, 1.0);
    let mut v2 = kv_base(2, 1.5);
    let (mut dk2, mut dv2) = c.update(&k2, &v2).unwrap();
    assert_eq!(c.offset(), 6);
    assert_eq!(dk2.shape(), vec![1, 1, 6, HEAD_DIM]);
    assert_eq!(dv2.shape(), vec![1, 1, 6, HEAD_DIM]);
    let mut want_k = concat_steps(&mut k, &mut k2, 6);
    let mut want_v = concat_steps(&mut v, &mut v2, 6);
    assert_close(&mut dk2, &mut want_k);
    assert_close(&mut dv2, &mut want_v);
}
/// A fresh cache reports no quantized state. After one update the state is present and
/// dequantizes back to the inputs, and re-reading it leaves the offset untouched.
#[test]
fn quantized_state_none_then_some() {
    let mut cache = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    assert!(cache.quantized_state().unwrap().is_none());
    assert!(cache.as_quantized().is_some());

    let mut keys = kv(3);
    let mut values = kv(3);
    cache.update_quantized(&keys, &values).unwrap();

    let state = cache.quantized_state().unwrap();
    assert!(state.is_some());
    let (q_keys, q_values) = state.unwrap();
    assert_eq!(q_keys.0.shape(), vec![1, 1, 3, HEAD_DIM / 4]);

    let mut deq_keys = dequant(&q_keys);
    let mut deq_values = dequant(&q_values);
    assert_close(&mut deq_keys, &mut keys);
    assert_close(&mut deq_values, &mut values);

    // Reading the quantized state again must not move the offset.
    assert_eq!(cache.offset(), 3);
    assert!(cache.quantized_state().unwrap().is_some());
    assert_eq!(cache.offset(), 3);
}
/// `state()`/`meta_state()` capture a cache that `set_state`/`set_meta_state` restore into
/// a fresh cache. Calling `set_state` with no arrays resets the cache to empty.
#[test]
fn state_set_state_roundtrip() {
    let mut c = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    assert!(c.state().unwrap().is_empty());
    let mut k = kv(3);
    // Distinct from `k` so a K/V swap on restore would be caught.
    let mut v = kv_base(3, 0.5);
    c.update_quantized(&k, &v).unwrap();
    let st = c.state().unwrap();
    // Affine mode: (weights, scales, biases) for keys, then the same for values.
    assert_eq!(st.len(), 6);
    let st_clone: Vec<Array> = st.iter().map(|a| a.try_clone().unwrap()).collect();
    let meta = c.meta_state();
    // mlx-lm 3-string form [offset, group_size, bits], tied to the file constants.
    assert_eq!(
        meta,
        vec!["3".to_string(), GROUP_SIZE.to_string(), BITS.to_string()]
    );

    let mut c2 = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    c2.set_state(st_clone).unwrap();
    c2.set_meta_state(&meta).unwrap();
    assert_eq!(c2.offset(), 3);
    assert!(!c2.is_empty());
    let (qk, qv) = c2.quantized_state().unwrap().unwrap();
    let mut dk = dequant(&qk);
    let mut dv = dequant(&qv);
    assert_close(&mut dk, &mut k);
    assert_close(&mut dv, &mut v);

    // An empty state vector resets the cache.
    c2.set_state(Vec::new()).unwrap();
    assert!(c2.is_empty());
    assert_eq!(c2.offset(), 0);
    assert!(c2.state().unwrap().is_empty());
}
/// `from_state("QuantizedKVCache", …)` rebuilds an equivalent quantized cache from a
/// snapshot.
///
/// It also rejects an unknown class name and a non-zero offset with no arrays, and
/// accepts an empty snapshot with offset 0.
#[test]
fn from_state_quantized_roundtrip() {
    let mut src = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    let mut k = kv(2);
    let mut v = kv(2);
    src.update_quantized(&k, &v).unwrap();
    let snapshot: Vec<Array> = src
        .state()
        .unwrap()
        .iter()
        .map(|a| a.try_clone().unwrap())
        .collect();
    let meta = src.meta_state();

    let rebuilt = from_state("QuantizedKVCache", snapshot, &meta).unwrap();
    assert_eq!(rebuilt.offset(), 2);
    assert!(!rebuilt.is_empty());
    let quant = rebuilt
        .as_quantized()
        .expect("reconstructed cache is quantized");
    assert_eq!(quant.group_size(), GROUP_SIZE);
    assert_eq!(quant.bits(), BITS);
    let (qk, qv) = quant.quantized_state().unwrap().unwrap();
    let (mut dk, mut dv) = (dequant(&qk), dequant(&qv));
    assert_close(&mut dk, &mut k);
    assert_close(&mut dv, &mut v);

    // Unknown class name.
    assert!(from_state("NotACache", Vec::new(), &[]).is_err());

    // Meta claims one cached step but no arrays are supplied.
    let one_step: [String; 3] = ["1".into(), "64".into(), "8".into()];
    assert!(from_state("QuantizedKVCache", Vec::new(), &one_step).is_err());

    // Offset 0 with no arrays is a valid empty cache.
    let zero_steps: [String; 3] = ["0".into(), "64".into(), "8".into()];
    let empty = from_state("QuantizedKVCache", Vec::new(), &zero_steps).unwrap();
    assert!(empty.is_empty());
    assert_eq!(empty.offset(), 0);
}
/// `make_mask` on a quantized cache holding 3 steps behaves as follows:
/// - one query token → `MaskMode::None`;
/// - two tokens → symbolic `MaskMode::Causal`;
/// - two tokens with the third argument `true` → a materialized `[n, offset + n]` array;
/// - `Some(1)` as the second argument (presumably a window size) → an array.
#[test]
fn make_mask_forwards_to_create_attention_mask() {
    let mut cache = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    cache.update_quantized(&kv(3), &kv(3)).unwrap();

    assert!(matches!(cache.make_mask(1, None, false).unwrap(), MaskMode::None));
    assert!(matches!(cache.make_mask(2, None, false).unwrap(), MaskMode::Causal));

    if let MaskMode::Array(mask) = cache.make_mask(2, None, true).unwrap() {
        assert_eq!(mask.shape(), vec![2, 3 + 2]);
    } else {
        panic!("make_mask(2, None, true) must be a materialized Array (cache.py:122)");
    }

    assert!(matches!(
        cache.make_mask(2, Some(1), false).unwrap(),
        MaskMode::Array(_)
    ));
}
/// `nbytes` is zero for an empty cache and positive once data is stored, and `copy`
/// yields an independent cache. `trim` removes at most `offset` steps and returns how
/// many it actually removed.
#[test]
fn trim_nbytes_copy() {
    let mut cache = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    assert_eq!(cache.nbytes(), 0);
    assert!(cache.is_trimmable());

    cache.update_quantized(&kv(4), &kv(4)).unwrap();
    assert_eq!(cache.offset(), 4);
    assert!(cache.nbytes() > 0);

    // Mutating the copy must not leak back into the original.
    let mut duplicate = cache.copy().unwrap();
    duplicate.update(&kv(1), &kv(1)).unwrap();
    assert_eq!(duplicate.offset(), 5);
    assert_eq!(
        cache.offset(),
        4,
        "original must be unaffected by copy mutation"
    );

    // (requested, actually removed, offset afterwards): an over-long trim clamps.
    let trims = [(3, 3, 1), (10, 1, 0)];
    for (requested, removed, remaining) in trims {
        assert_eq!(cache.trim(requested).unwrap(), removed);
        assert_eq!(cache.offset(), remaining);
    }
}
/// Malformed inputs surface as `Err` rather than being accepted. The cases are:
/// - rank-2 K/V tensors;
/// - meta state with only two strings;
/// - a non-numeric offset;
/// - a one-element state vector.
#[test]
fn wrong_rank_errors() {
    let mut cache = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    let rank2 = Array::from_slice::<f32>(&[1.0, 2.0], &(1usize, 2)).unwrap();

    assert!(cache.update_quantized(&rank2, &rank2).is_err());
    assert!(cache.update(&rank2, &rank2).is_err());

    let too_short: [String; 2] = ["1".into(), "64".into()];
    assert!(cache.set_meta_state(&too_short).is_err());
    let non_numeric: [String; 3] = ["x".into(), "64".into(), "8".into()];
    assert!(cache.set_meta_state(&non_numeric).is_err());

    let lone_array = vec![rank2.try_clone().unwrap()];
    assert!(cache.set_state(lone_array).is_err());
}
/// After `trim`, the next `update_quantized` must write at the trimmed offset, overwriting
/// the stale tail, instead of appending after the old end of storage.
#[test]
fn update_after_trim_overwrites_not_appends() {
    /// Asserts one step's features match `want` within `0.05 * max|want| + 1e-3`.
    fn assert_token_close(got: &[f32], want: &[f32], what: &str) {
        assert_eq!(got.len(), want.len(), "{what}: length mismatch");
        let max_abs = want.iter().fold(0.0f32, |m, x| m.max(x.abs()));
        for (g, n) in got.iter().zip(want) {
            assert!(
                (g - n).abs() <= 0.05 * max_abs + 1e-3,
                "{what}: got={g} want={n}"
            );
        }
    }

    let mut c = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    let mut k4 = kv(4);
    // Only borrowed, so no `mut` and no dummy read needed to silence the lint.
    let v4 = kv(4);
    c.update_quantized(&k4, &v4).unwrap();
    assert_eq!(c.offset(), 4);
    assert_eq!(c.trim(3).unwrap(), 3);
    assert_eq!(c.offset(), 1);

    // The exposed state is sliced to the trimmed offset and still holds token 0.
    let (qk_t, _) = c.quantized_state().unwrap().unwrap();
    assert_eq!(qk_t.0.shape(), vec![1, 1, 1, HEAD_DIM / 4]);
    let mut dk_t = dequant(&qk_t);
    let mut tok0 = kv(1);
    assert_close(&mut dk_t, &mut tok0);

    // A far-off token (base 100) keeps "new" and "stale" distinguishable after quantization.
    let mut new_tok = kv_base(1, 100.0);
    let (qk, qv) = c.update_quantized(&new_tok, &new_tok).unwrap();
    assert_eq!(c.offset(), 2);
    assert_eq!(qk.0.shape(), vec![1, 1, 2, HEAD_DIM / 4]);
    let mut dk = dequant(&qk);
    assert_eq!(dk.shape(), vec![1, 1, 2, HEAD_DIM]);

    let tok0_vals = tok0.to_vec::<f32>().unwrap();
    let new_vals = new_tok.to_vec::<f32>().unwrap();
    let mut want = tok0_vals.clone();
    want.extend_from_slice(&new_vals);
    let mut want_arr = Array::from_slice::<f32>(&want, &(1usize, 1, 2usize, HEAD_DIM)).unwrap();
    // The whole-tensor tolerance is dominated by the base-100 token, so each position is
    // additionally checked against its own magnitude below.
    assert_close(&mut dk, &mut want_arr);
    let mut dv = dequant(&qv);
    assert_close(&mut dv, &mut want_arr);

    let got = dk.to_vec::<f32>().unwrap();
    assert_token_close(
        &got[..HEAD_DIM],
        &tok0_vals,
        "post-trim update position 0 must still be token 0",
    );
    let pos1 = &got[HEAD_DIM..2 * HEAD_DIM];
    assert_token_close(
        pos1,
        &new_vals,
        "post-trim update position 1 must be the NEW token",
    );

    let k4_vals = k4.to_vec::<f32>().unwrap();
    let stale_tok1 = &k4_vals[HEAD_DIM..2 * HEAD_DIM];
    let drift_from_stale: f32 = pos1
        .iter()
        .zip(stale_tok1)
        .map(|(g, s)| (g - s).abs())
        .fold(0.0, f32::max);
    assert!(
        drift_from_stale > 10.0,
        "position 1 must NOT be the stale trimmed token1 (drift {drift_from_stale} too small)"
    );

    // Re-reading the state reflects the same two positions.
    let (qk2, _) = c.quantized_state().unwrap().unwrap();
    assert_eq!(qk2.0.shape(), vec![1, 1, 2, HEAD_DIM / 4]);
    let mut dk2 = dequant(&qk2);
    assert_close(&mut dk2, &mut want_arr);
}
/// Through a `Box<dyn KvCache>`, `as_quantized_mut` exposes the quantized API of a
/// `StandardQuantizedKvCache`, and its updates show up in the box's offset. A plain
/// `StandardKvCache` downcasts to neither the shared nor the mutable quantized view.
#[test]
fn as_quantized_mut_reaches_update_quantized_through_dyn() {
    let mut boxed: Box<dyn KvCache> =
        Box::new(StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap());

    // Scope the mutable quantized view so `boxed` is usable again afterwards.
    {
        let quant = boxed
            .as_quantized_mut()
            .expect("StandardQuantizedKvCache must downcast via as_quantized_mut");
        assert_eq!(quant.group_size(), GROUP_SIZE);
        assert_eq!(quant.bits(), BITS);

        let mut keys = kv(3);
        let mut values = kv(3);
        let (qk, qv) = quant
            .update_quantized(&keys, &values)
            .expect("update_quantized through &mut dyn KvCache must succeed");
        assert_eq!(qk.0.shape(), vec![1, 1, 3, HEAD_DIM / 4]);
        let mut dk = dequant(&qk);
        let mut dv = dequant(&qv);
        assert_close(&mut dk, &mut keys);
        assert_close(&mut dv, &mut values);
    }
    assert_eq!(boxed.offset(), 3);

    boxed
        .as_quantized_mut()
        .unwrap()
        .update_quantized(&kv(2), &kv(2))
        .unwrap();
    assert_eq!(boxed.offset(), 5);

    let mut plain: Box<dyn KvCache> = Box::new(mlxrs::lm::cache::StandardKvCache::new());
    assert!(plain.as_quantized().is_none());
    assert!(plain.as_quantized_mut().is_none());
}
/// A snapshot whose arrays hold 5 steps but whose meta claims offset 3 must be restored
/// with its triples sliced down to 3 steps. The next update then lands at position 3 and
/// replaces the forged tail instead of appending after it.
///
/// As a control, an honest snapshot (meta offset == stored length) must round-trip to
/// identical dequantized values.
#[test]
fn from_state_slices_forged_overlong_triples_to_offset() {
    let mut src = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    src.update_quantized(&kv(5), &kv(5)).unwrap();
    let st: Vec<Array> = src
        .state()
        .unwrap()
        .iter()
        .map(|a| a.try_clone().unwrap())
        .collect();
    assert_eq!(st.len(), 6, "affine → biased 6-array state");
    assert_eq!(st[0].shape(), vec![1, 1, 5, HEAD_DIM / 4]);
    // Forge: the meta claims offset 3 while every stored array still holds 5 steps.
    let forged_meta = vec!["3".to_string(), "64".to_string(), "8".to_string()];
    let mut c = from_state("QuantizedKVCache", st, &forged_meta).unwrap();
    assert_eq!(c.offset(), 3);
    {
        let q = c.as_quantized().expect("reconstructed cache is quantized");
        let (qk, qv) = q.quantized_state().unwrap().unwrap();
        assert_eq!(
            qk.0.shape(),
            vec![1, 1, 3, HEAD_DIM / 4],
            "restored triples must be sliced to offset 3, not the forged seq-len 5"
        );
        let mut dk = dequant(&qk);
        let mut dv = dequant(&qv);
        // kv(3) equals the first three steps of kv(5): both come from the same index ramp.
        let mut want3 = kv(3); assert_close(&mut dk, &mut want3);
        let mut want3v = kv(3);
        assert_close(&mut dv, &mut want3v);
    }
    // Append one far-off token (base 100) so "new" vs "stale forged t3" is unambiguous.
    let mut new_tok = kv_base(1, 100.0);
    // The block confines the mutable downcast borrow of `c` to this update.
    let (qk, _) = {
        let q = c
            .as_quantized_mut()
            .expect("reconstructed cache downcasts mutably");
        q.update_quantized(&new_tok, &new_tok).unwrap()
    };
    assert_eq!(c.offset(), 4);
    assert_eq!(qk.0.shape(), vec![1, 1, 4, HEAD_DIM / 4]);
    let mut dk = dequant(&qk);
    // Expected history: the 3 legitimate steps followed by the new token.
    let want: Vec<f32> = {
        let mut a = kv(3).to_vec::<f32>().unwrap(); a.extend_from_slice(&new_tok.to_vec::<f32>().unwrap()); a
    };
    let mut want_arr = Array::from_slice::<f32>(&want, &(1usize, 1, 4usize, HEAD_DIM)).unwrap();
    assert_close(&mut dk, &mut want_arr);
    // Position 3 gets its own check: it must match the new token and be far from the
    // stale step 3 of the forged 5-step source.
    let got = dk.to_vec::<f32>().unwrap();
    let pos3: Vec<f32> = got[3 * HEAD_DIM..4 * HEAD_DIM].to_vec();
    let new_vals = new_tok.to_vec::<f32>().unwrap();
    let stale_t3 = kv(5).to_vec::<f32>().unwrap()[3 * HEAD_DIM..4 * HEAD_DIM].to_vec();
    let max_abs = new_vals.iter().fold(0.0f32, |m, x| m.max(x.abs()));
    for (g, n) in pos3.iter().zip(new_vals.iter()) {
        assert!(
            (g - n).abs() <= 0.05 * max_abs + 1e-3,
            "post-restore update position 3 must be the NEW token: got={g} want={n}"
        );
    }
    let drift_from_stale: f32 = pos3
        .iter()
        .zip(stale_t3.iter())
        .map(|(g, s)| (g - s).abs())
        .fold(0.0, f32::max);
    assert!(
        drift_from_stale > 10.0,
        "position 3 must NOT be the stale forged t3 (drift {drift_from_stale} too small)"
    );
    // Control: an honest 4-step snapshot must restore to exactly the same dequantized values.
    let mut src2 = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    src2.update_quantized(&kv(4), &kv(4)).unwrap();
    let consistent_st: Vec<Array> = src2
        .state()
        .unwrap()
        .iter()
        .map(|a| a.try_clone().unwrap())
        .collect();
    let consistent_meta = src2.meta_state(); assert_eq!(consistent_meta[0], "4");
    let (sk, sv) = src2.quantized_state().unwrap().unwrap();
    let mut want_k = dequant(&sk);
    let mut want_v = dequant(&sv);
    let rt = from_state("QuantizedKVCache", consistent_st, &consistent_meta).unwrap();
    assert_eq!(rt.offset(), 4, "consistent offset preserved");
    let q = rt.as_quantized().unwrap();
    let (rk, rv) = q.quantized_state().unwrap().unwrap();
    assert_eq!(rk.0.shape(), vec![1, 1, 4, HEAD_DIM / 4]);
    let mut got_k = dequant(&rk);
    let mut got_v = dequant(&rv);
    let gk = got_k.to_vec::<f32>().unwrap();
    let wk = want_k.to_vec::<f32>().unwrap();
    assert_eq!(gk.len(), wk.len());
    for (g, w) in gk.iter().zip(wk.iter()) {
        assert_eq!(
            g, w,
            "consistent-state round-trip must be byte-identical (keys)"
        );
    }
    let gv = got_v.to_vec::<f32>().unwrap();
    let wv = want_v.to_vec::<f32>().unwrap();
    for (g, w) in gv.iter().zip(wv.iter()) {
        assert_eq!(
            g, w,
            "consistent-state round-trip must be byte-identical (values)"
        );
    }
}
/// A snapshot whose arrays hold 3 steps but whose meta claims offset 5 must clamp the
/// offset down to the stored length (3). The next update must then land at position 3,
/// giving offset 4 rather than 6.
///
/// As a control, an honest snapshot must round-trip to identical dequantized values.
#[test]
fn from_state_underlength_state_clamps_offset_down() {
    let mut src = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    src.update_quantized(&kv(3), &kv(3)).unwrap();
    let st: Vec<Array> = src
        .state()
        .unwrap()
        .iter()
        .map(|a| a.try_clone().unwrap())
        .collect();
    assert_eq!(st.len(), 6, "affine → biased 6-array state");
    assert_eq!(st[0].shape(), vec![1, 1, 3, HEAD_DIM / 4]);
    // Forge: the meta claims offset 5 while the stored arrays only hold 3 steps.
    let forged_meta = vec!["5".to_string(), "64".to_string(), "8".to_string()];
    let mut c = from_state("QuantizedKVCache", st, &forged_meta).unwrap();
    assert_eq!(
        c.offset(),
        3,
        "underlength forge: offset must clamp down to stored seq-len 3, not the forged 5"
    );
    {
        let q = c.as_quantized().expect("reconstructed cache is quantized");
        let (qk, qv) = q.quantized_state().unwrap().unwrap();
        assert_eq!(
            qk.0.shape(),
            vec![1, 1, 3, HEAD_DIM / 4],
            "underlength forge: storage seq-len stays 3 (NumPy clamp on the trim)"
        );
        let mut dk = dequant(&qk);
        let mut dv = dequant(&qv);
        let mut want3k = kv(3);
        let mut want3v = kv(3);
        assert_close(&mut dk, &mut want3k);
        assert_close(&mut dv, &mut want3v);
    }
    // Append one far-off token (base 100) so its landing position is unambiguous.
    let mut new_tok = kv_base(1, 100.0);
    // The block confines the mutable downcast borrow of `c` to this update.
    let (qk, _) = {
        let q = c
            .as_quantized_mut()
            .expect("reconstructed cache downcasts mutably");
        q.update_quantized(&new_tok, &new_tok).unwrap()
    };
    assert_eq!(
        c.offset(),
        4,
        "post-clamp `update_quantized` lands at offset 3 + 1 = 4, NOT 5 + 1 = 6"
    );
    assert_eq!(
        qk.0.shape(),
        vec![1, 1, 4, HEAD_DIM / 4],
        "post-clamp `update_quantized` storage is length 4 (3 + 1), NOT length 6"
    );
    let mut dk = dequant(&qk);
    // Expected history: the 3 stored steps followed by the new token.
    let want: Vec<f32> = {
        let mut a = kv(3).to_vec::<f32>().unwrap(); a.extend_from_slice(&new_tok.to_vec::<f32>().unwrap()); a
    };
    let mut want_arr = Array::from_slice::<f32>(&want, &(1usize, 1, 4usize, HEAD_DIM)).unwrap();
    assert_close(&mut dk, &mut want_arr);
    // Position 3 is checked against the new token's own magnitude.
    let got = dk.to_vec::<f32>().unwrap();
    let pos3: Vec<f32> = got[3 * HEAD_DIM..4 * HEAD_DIM].to_vec();
    let new_vals = new_tok.to_vec::<f32>().unwrap();
    let max_abs = new_vals.iter().fold(0.0f32, |m, x| m.max(x.abs()));
    for (g, n) in pos3.iter().zip(new_vals.iter()) {
        assert!(
            (g - n).abs() <= 0.05 * max_abs + 1e-3,
            "post-clamp append position 3 must be the NEW token: got={g} want={n}"
        );
    }
    // Stored K and V lengths must both agree with the reported offset.
    {
        let q = c.as_quantized().expect("post-clamp cache still quantized");
        let (qk2, qv2) = q.quantized_state().unwrap().unwrap();
        assert_eq!(qk2.0.shape()[2], c.offset());
        assert_eq!(qv2.0.shape()[2], c.offset());
    }
    // Control: an honest 4-step snapshot must restore to exactly the same dequantized values.
    let mut src2 = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    src2.update_quantized(&kv(4), &kv(4)).unwrap();
    let consistent_st: Vec<Array> = src2
        .state()
        .unwrap()
        .iter()
        .map(|a| a.try_clone().unwrap())
        .collect();
    let consistent_meta = src2.meta_state(); assert_eq!(
        consistent_meta[0], "4",
        "honest meta offset matches seq-len"
    );
    let (sk, sv) = src2.quantized_state().unwrap().unwrap();
    let mut want_k = dequant(&sk);
    let mut want_v = dequant(&sv);
    let rt = from_state("QuantizedKVCache", consistent_st, &consistent_meta).unwrap();
    assert_eq!(
        rt.offset(),
        4,
        "consistent offset preserved (clamp is no-op)"
    );
    let q = rt.as_quantized().unwrap();
    let (rk, rv) = q.quantized_state().unwrap().unwrap();
    assert_eq!(rk.0.shape(), vec![1, 1, 4, HEAD_DIM / 4]);
    let mut got_k = dequant(&rk);
    let mut got_v = dequant(&rv);
    let gk = got_k.to_vec::<f32>().unwrap();
    let wk = want_k.to_vec::<f32>().unwrap();
    assert_eq!(gk.len(), wk.len());
    for (g, w) in gk.iter().zip(wk.iter()) {
        assert_eq!(
            g, w,
            "consistent-state round-trip must be byte-identical under symmetric clamp (keys)"
        );
    }
    let gv = got_v.to_vec::<f32>().unwrap();
    let wv = want_v.to_vec::<f32>().unwrap();
    for (g, w) in gv.iter().zip(wv.iter()) {
        assert_eq!(
            g, w,
            "consistent-state round-trip must be byte-identical under symmetric clamp (values)"
        );
    }
}
/// A forged snapshot is rejected by `set_state` with a diagnostic naming the load
/// boundary, rather than being clamped. The forgery takes its key triple from a 3-step
/// cache and its value triple from a 5-step cache.
#[test]
fn from_state_asymmetric_keys_shorter_is_rejected_at_set_state() {
    // Two honest caches of different lengths supply the halves of the forgery.
    let mut short_cache = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    short_cache.update_quantized(&kv(3), &kv(3)).unwrap();
    let short_state: Vec<Array> = short_cache
        .state()
        .unwrap()
        .iter()
        .map(|a| a.try_clone().unwrap())
        .collect();
    let mut long_cache = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    long_cache.update_quantized(&kv(5), &kv(5)).unwrap();
    let long_state: Vec<Array> = long_cache
        .state()
        .unwrap()
        .iter()
        .map(|a| a.try_clone().unwrap())
        .collect();

    // Keys (indices 0..3) from the 3-step cache, values (3..6) from the 5-step cache.
    let mut forged_state: Vec<Array> = Vec::with_capacity(6);
    forged_state.extend(short_state[0..3].iter().map(|a| a.try_clone().unwrap()));
    forged_state.extend(long_state[3..6].iter().map(|a| a.try_clone().unwrap()));
    let forged_meta = vec!["5".to_string(), "64".to_string(), "8".to_string()];

    let err = from_state("QuantizedKVCache", forged_state, &forged_meta)
        .err()
        .expect("asymmetric K/V forge must be REJECTED at set_state (not clamped)");
    let msg = err.to_string();
    let names_boundary = msg.contains("set_state");
    let names_mismatch = msg.contains("K and V") || msg.contains("axis");
    assert!(
        names_boundary && names_mismatch,
        "diagnostic must name the load boundary + the K/V mismatch; got {msg}"
    );
}
/// Mirror of the keys-shorter case: the key triple comes from a 5-step cache and the
/// value triple from a 3-step cache. `set_state` must reject it, not clamp it.
#[test]
fn from_state_asymmetric_values_shorter_is_rejected_at_set_state() {
    // Two honest caches of different lengths supply the halves of the forgery.
    let mut short_cache = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    short_cache.update_quantized(&kv(3), &kv(3)).unwrap();
    let short_state: Vec<Array> = short_cache
        .state()
        .unwrap()
        .iter()
        .map(|a| a.try_clone().unwrap())
        .collect();
    let mut long_cache = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    long_cache.update_quantized(&kv(5), &kv(5)).unwrap();
    let long_state: Vec<Array> = long_cache
        .state()
        .unwrap()
        .iter()
        .map(|a| a.try_clone().unwrap())
        .collect();

    // Keys (indices 0..3) from the 5-step cache, values (3..6) from the 3-step cache.
    let mut forged_state: Vec<Array> = Vec::with_capacity(6);
    forged_state.extend(long_state[0..3].iter().map(|a| a.try_clone().unwrap()));
    forged_state.extend(short_state[3..6].iter().map(|a| a.try_clone().unwrap()));
    let forged_meta = vec!["5".to_string(), "64".to_string(), "8".to_string()];

    let err = from_state("QuantizedKVCache", forged_state, &forged_meta)
        .err()
        .expect("asymmetric K/V forge must be REJECTED at set_state (not clamped)");
    let msg = err.to_string();
    let names_boundary = msg.contains("set_state");
    let names_mismatch = msg.contains("K and V") || msg.contains("axis");
    assert!(
        names_boundary && names_mismatch,
        "diagnostic must name the load boundary + the K/V mismatch; got {msg}"
    );
}
/// A within-triple forgery must be rejected by `set_state`. The forged snapshot has
/// 5-step `keys.weight`/`keys.biases` but 3-step `keys.scales`, with values from the
/// 5-step cache.
///
/// NOTE(review): despite the `clamps_to_min` suffix in the name, the asserted behavior is
/// rejection, not clamping.
#[test]
fn from_state_underlength_state_within_triple_asymmetric_clamps_to_min() {
    let mut short_cache = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    short_cache.update_quantized(&kv(3), &kv(3)).unwrap();
    let short_state: Vec<Array> = short_cache
        .state()
        .unwrap()
        .iter()
        .map(|a| a.try_clone().unwrap())
        .collect();
    let mut long_cache = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    long_cache.update_quantized(&kv(5), &kv(5)).unwrap();
    let long_state: Vec<Array> = long_cache
        .state()
        .unwrap()
        .iter()
        .map(|a| a.try_clone().unwrap())
        .collect();

    // The honest 5-step snapshot with only keys.scales (index 1) taken from the 3-step one.
    let forged_state: Vec<Array> = (0..6)
        .map(|idx| if idx == 1 { &short_state[idx] } else { &long_state[idx] })
        .map(|a| a.try_clone().unwrap())
        .collect();
    assert_eq!(
        forged_state[0].shape(),
        vec![1, 1, 5, HEAD_DIM / 4],
        "keys.weight forged seq=5"
    );
    assert_eq!(
        forged_state[1].shape(),
        vec![1, 1, 3, 1],
        "keys.scales forged seq=3 -- WITHIN-triple asymmetry"
    );
    assert_eq!(
        forged_state[2].shape(),
        vec![1, 1, 5, 1],
        "keys.biases forged seq=5"
    );
    let forged_meta = vec!["5".to_string(), "64".to_string(), "8".to_string()];

    let err = from_state("QuantizedKVCache", forged_state, &forged_meta)
        .err()
        .expect(
            "within-triple asymmetry surfaces as a K/V scales shape mismatch \
             at set_state — REJECTED, not silently converged",
        );
    let msg = err.to_string();
    let names_boundary = msg.contains("set_state");
    let names_element = msg.contains("scales") || msg.contains("K and V");
    assert!(
        names_boundary && names_element,
        "diagnostic must name the load boundary + the offending element; got {msg}"
    );
}
/// Bias-less (4-array) variant of the within-triple forgery: the biases are dropped and
/// `keys.scales` comes from the 3-step cache. `set_state` must reject it.
///
/// NOTE(review): despite the `clamps_to_min` suffix in the name, the asserted behavior is
/// rejection, not clamping.
#[test]
fn from_state_underlength_state_within_triple_asymmetric_bias_less_clamps_to_min() {
    let mut short_cache = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    short_cache.update_quantized(&kv(3), &kv(3)).unwrap();
    let s_short_full: Vec<Array> = short_cache
        .state()
        .unwrap()
        .iter()
        .map(|a| a.try_clone().unwrap())
        .collect();
    let mut long_cache = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    long_cache.update_quantized(&kv(5), &kv(5)).unwrap();
    let s_long_full: Vec<Array> = long_cache
        .state()
        .unwrap()
        .iter()
        .map(|a| a.try_clone().unwrap())
        .collect();

    // (source snapshot, index): keys.weight, keys.scales (3-step), values.weight,
    // values.scales. Biases (indices 2 and 5) are omitted.
    let picks = [
        (&s_long_full, 0usize),
        (&s_short_full, 1),
        (&s_long_full, 3),
        (&s_long_full, 4),
    ];
    let forged_state: Vec<Array> = picks
        .iter()
        .map(|&(src, idx)| src[idx].try_clone().unwrap())
        .collect();
    assert_eq!(forged_state.len(), 4, "bias-less state has 4 arrays");
    assert_eq!(forged_state[0].shape(), vec![1, 1, 5, HEAD_DIM / 4]);
    assert_eq!(forged_state[1].shape(), vec![1, 1, 3, 1]);
    let forged_meta = vec!["5".to_string(), "64".to_string(), "8".to_string()];

    let err = from_state("QuantizedKVCache", forged_state, &forged_meta)
        .err()
        .expect("bias-less within-triple asymmetry is REJECTED at set_state");
    let msg = err.to_string();
    let names_boundary = msg.contains("set_state");
    let names_element = msg.contains("scales") || msg.contains("K and V");
    assert!(
        names_boundary && names_element,
        "diagnostic must name the load boundary + the offending element; got {msg}"
    );
}
/// `set_meta_state` accepts the 4-string mlx-swift-lm form. It reads offset, group size
/// and bits from indices 1..=3; index 0 (`"256"` here) is not checked by this test.
#[test]
fn set_meta_state_accepts_swift_4string_form() {
    // Start from a deliberately zeroed configuration so every value must come from meta.
    let mut cache = StandardQuantizedKvCache::new_unchecked(0, 0);
    let swift_meta: Vec<String> = ["256", "10", "64", "4"]
        .iter()
        .map(|s| s.to_string())
        .collect();
    cache.set_meta_state(&swift_meta).unwrap();

    assert_eq!(cache.offset(), 10, "offset restored from index [1]");
    let quant = cache.as_quantized().unwrap();
    assert_eq!(quant.group_size(), 64, "group_size restored from index [2]");
    assert_eq!(quant.bits(), 4, "bits restored from index [3]");
}
/// `set_meta_state` accepts the 3-string mlx-lm form `[offset, group_size, bits]`.
#[test]
fn set_meta_state_accepts_mlx_lm_3string_form() {
    // Start from a deliberately zeroed configuration so every value must come from meta.
    let mut cache = StandardQuantizedKvCache::new_unchecked(0, 0);
    let lm_meta: Vec<String> = ["10", "64", "4"].iter().map(|s| s.to_string()).collect();
    cache.set_meta_state(&lm_meta).unwrap();

    assert_eq!(cache.offset(), 10);
    let quant = cache.as_quantized().unwrap();
    assert_eq!(quant.group_size(), 64);
    assert_eq!(quant.bits(), 4);
}
/// Meta states of length 2 or 5 are rejected with a message listing both accepted
/// lengths, and the rejection leaves the cache's offset, group size and bits unchanged.
#[test]
fn set_meta_state_rejects_2_or_5_string_form() {
    /// Converts string literals into the owned form `set_meta_state` takes.
    fn owned(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|p| p.to_string()).collect()
    }
    /// True when the diagnostic mentions both the mlx-lm and mlx-swift-lm forms.
    fn lists_both_forms(msg: &str) -> bool {
        msg.contains("3 (mlx-lm form)") && msg.contains("4 (mlx-swift-lm form)")
    }

    let mut cache = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();

    let err2 = cache
        .set_meta_state(&owned(&["10", "64"]))
        .unwrap_err()
        .to_string();
    assert!(
        lists_both_forms(&err2),
        "2-string rejection message must list BOTH accepted forms; got: {err2}"
    );

    let err5 = cache
        .set_meta_state(&owned(&["10", "20", "64", "4", "1"]))
        .unwrap_err()
        .to_string();
    assert!(
        lists_both_forms(&err5),
        "5-string rejection message must list BOTH accepted forms; got: {err5}"
    );

    // Rejected forms must not have touched the configuration.
    assert_eq!(cache.offset(), 0);
    let quant = cache.as_quantized().unwrap();
    assert_eq!(quant.group_size(), GROUP_SIZE);
    assert_eq!(quant.bits(), BITS);
}
/// The mlx-lm 3-string meta and the equivalent mlx-swift-lm 4-string meta, applied to the
/// same snapshot, must rebuild equivalent caches. Equivalent means the same offset,
/// config and state shapes, and dequantized contents matching the original K/V.
#[test]
fn from_state_round_trip_via_swift_form() {
    let mut src = StandardQuantizedKvCache::new(GROUP_SIZE, BITS).unwrap();
    let mut k = kv(3);
    // Distinct from `k` so a K/V swap on either restore path is detected.
    let mut v = kv_base(3, 0.5);
    src.update_quantized(&k, &v).unwrap();

    // `from_state` takes ownership, so build one independent snapshot per restore path
    // (no intermediate third copy is needed).
    let st_a: Vec<Array> = src
        .state()
        .unwrap()
        .iter()
        .map(|a| a.try_clone().unwrap())
        .collect();
    let st_b: Vec<Array> = st_a.iter().map(|a| a.try_clone().unwrap()).collect();

    let lm_meta = src.meta_state();
    assert_eq!(
        lm_meta,
        vec!["3".to_string(), GROUP_SIZE.to_string(), BITS.to_string()]
    );
    // Swift form: one extra leading field (not interpreted by this test), then
    // offset / group_size / bits.
    let swift_meta = vec![
        "256".to_string(),
        "3".to_string(),
        GROUP_SIZE.to_string(),
        BITS.to_string(),
    ];

    let c_lm = from_state("QuantizedKVCache", st_a, &lm_meta).unwrap();
    let c_sw = from_state("QuantizedKVCache", st_b, &swift_meta).unwrap();
    assert_eq!(c_lm.offset(), c_sw.offset());
    assert_eq!(c_lm.offset(), 3);

    let q_lm = c_lm.as_quantized().unwrap();
    let q_sw = c_sw.as_quantized().unwrap();
    assert_eq!(q_lm.group_size(), q_sw.group_size());
    assert_eq!(q_lm.bits(), q_sw.bits());
    assert_eq!(q_lm.group_size(), GROUP_SIZE);
    assert_eq!(q_lm.bits(), BITS);

    let st_lm = c_lm.state().unwrap();
    let st_sw = c_sw.state().unwrap();
    assert_eq!(st_lm.len(), st_sw.len());
    for (a, b) in st_lm.iter().zip(st_sw.iter()) {
        assert_eq!(a.shape(), b.shape());
    }

    let (qk_lm, qv_lm) = q_lm.quantized_state().unwrap().unwrap();
    let (qk_sw, qv_sw) = q_sw.quantized_state().unwrap().unwrap();
    let mut dk_lm = dequant(&qk_lm);
    let mut dv_lm = dequant(&qv_lm);
    let mut dk_sw = dequant(&qk_sw);
    let mut dv_sw = dequant(&qv_sw);
    assert_close(&mut dk_lm, &mut k);
    assert_close(&mut dv_lm, &mut v);
    assert_close(&mut dk_sw, &mut k);
    assert_close(&mut dv_sw, &mut v);
}