#![cfg(feature = "lm")]
use mlxrs::{
Array,
lm::cache::{
ArraysCache, BatchKvCache, BatchRotatingKvCache, CacheList, ChunkedKvCache, KvCache, MaskMode,
QuantizedKvCache, RotatingKvCache, StandardKvCache, StandardQuantizedKvCache,
},
};
/// Builds an `f32` array of shape `(1, 1, ids.len(), 1)` whose elements are `ids`.
///
/// Panics if `Array::from_slice` reports an error.
fn kv(ids: &[f32]) -> Array {
    let seq_len = ids.len();
    Array::from_slice::<f32>(ids, &(1usize, 1, seq_len, 1)).unwrap()
}
/// Asserts that `got` and `want` agree in shape and in their `f32` contents.
///
/// `ctx` is prepended to the failure message so the caller can tell which
/// comparison failed. Shapes are compared first, then the flattened values.
// NOTE(review): the arrays are taken as `&mut`, presumably because `to_vec`
// needs mutable access — confirm against the `Array` API.
fn assert_arrays_eq(got: &mut Array, want: &mut Array, ctx: &str) {
    let (got_shape, want_shape) = (got.shape(), want.shape());
    assert_eq!(got_shape, want_shape, "{ctx}: shape mismatch");

    let got_values = got.to_vec::<f32>().unwrap();
    let want_values = want.to_vec::<f32>().unwrap();
    assert_eq!(got_values, want_values, "{ctx}: contents mismatch");
}
/// `i32` counterpart of the `f32` array comparison: asserts that `got` and
/// `want` agree in shape and in their contents read back as `i32`.
///
/// `ctx` is prepended to the failure message.
fn assert_arrays_eq_i32(got: &mut Array, want: &mut Array, ctx: &str) {
    let (got_shape, want_shape) = (got.shape(), want.shape());
    assert_eq!(got_shape, want_shape, "{ctx}: shape mismatch");

    let got_values = got.to_vec::<i32>().unwrap();
    let want_values = want.to_vec::<i32>().unwrap();
    assert_eq!(got_values, want_values, "{ctx}: contents mismatch");
}
/// Asserts that two state vectors have the same length and that every pair of
/// corresponding arrays matches as `f32` (shape and contents).
///
/// Each per-element failure is labelled `"{ctx}[{i}]"`.
fn assert_state_eq(got: &mut [Array], want: &mut [Array], ctx: &str) {
    assert_eq!(got.len(), want.len(), "{ctx}: state.len() mismatch");

    let pairs = got.iter_mut().zip(want.iter_mut());
    for (i, (actual, expected)) in pairs.enumerate() {
        let label = format!("{ctx}[{i}]");
        assert_arrays_eq(actual, expected, &label);
    }
}
/// Compares two state vectors whose entries have mixed element types.
///
/// Positions listed in `i32_indices` are compared as `i32`; every other
/// position is compared as `f32`. The lengths must match, and each
/// per-element failure is labelled with its index and the dtype used.
fn assert_state_eq_mixed_kv_then_i32(
    got: &mut [Array],
    want: &mut [Array],
    i32_indices: &[usize],
    ctx: &str,
) {
    assert_eq!(got.len(), want.len(), "{ctx}: state.len() mismatch");

    let pairs = got.iter_mut().zip(want.iter_mut());
    for (i, (actual, expected)) in pairs.enumerate() {
        if !i32_indices.contains(&i) {
            assert_arrays_eq(actual, expected, &format!("{ctx}[{i} as f32]"));
            continue;
        }
        assert_arrays_eq_i32(actual, expected, &format!("{ctx}[{i} as i32]"));
    }
}
/// Minimal `KvCache` implementor that records which restore hooks were called.
///
/// The impl below does not define `from_serialized`, so calling that method on
/// a probe goes through whatever implementation the trait itself provides.
#[derive(Default)]
struct DefaultProbeCache {
    /// Names of the hooks invoked so far, in call order.
    calls: Vec<&'static str>,
}

impl KvCache for DefaultProbeCache {
    /// Always reports an offset of zero.
    fn offset(&self) -> usize {
        0
    }

    /// Not expected to be called; panics if it is.
    fn update(&mut self, _keys: &Array, _values: &Array) -> mlxrs::Result<(Array, Array)> {
        unreachable!("DefaultProbeCache::update is not exercised by the from_serialized default test")
    }

    /// Reports an empty state vector.
    fn state(&self) -> mlxrs::Result<Vec<Array>> {
        Ok(vec![])
    }

    /// Not expected to be called; panics if it is.
    fn materialize(&mut self) -> mlxrs::Result<()> {
        unreachable!(
            "DefaultProbeCache::materialize is not exercised by the from_serialized default test"
        )
    }

    /// Records the call and discards the supplied state.
    fn set_state(&mut self, _state: Vec<Array>) -> mlxrs::Result<()> {
        self.calls.push("set_state");
        Ok(())
    }

    /// Records the call and ignores the supplied metadata.
    fn set_meta_state(&mut self, _meta: &[String]) -> mlxrs::Result<()> {
        self.calls.push("set_meta_state");
        Ok(())
    }

    /// Not expected to be called; panics if it is.
    fn make_mask(&self, _n: usize, _w: Option<usize>, _r: bool) -> mlxrs::Result<MaskMode> {
        unreachable!("DefaultProbeCache::make_mask is not exercised")
    }

    /// Always reports zero bytes.
    fn nbytes(&self) -> usize {
        0
    }

    /// Always reports the cache as empty.
    fn is_empty(&self) -> bool {
        true
    }

    /// Not expected to be called; panics if it is.
    fn copy(&self) -> mlxrs::Result<Box<dyn KvCache>> {
        unreachable!("DefaultProbeCache::copy is not exercised")
    }

    /// Exposes the probe as `Any` for downcasting.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    /// Class name reported for this probe.
    fn reference_class_name(&self) -> &'static str {
        "DefaultProbeCache"
    }
}
/// The `from_serialized` reached through `DefaultProbeCache` (which does not
/// define its own) must call `set_state` first and `set_meta_state` second.
#[test]
fn trait_default_from_serialized_calls_set_state_then_meta() {
    let mut probe = DefaultProbeCache::default();
    let state = vec![kv(&[0.0])];

    probe.from_serialized(state, &[]).unwrap();

    let expected_calls = vec!["set_state", "set_meta_state"];
    assert_eq!(
        probe.calls, expected_calls,
        "trait-default from_serialized must call set_state then set_meta_state, in order"
    );
}
/// Saves the state and meta-state of a populated `StandardKvCache`, restores
/// them into a fresh instance, and checks offset, emptiness, state contents
/// and (empty) meta-state of the restored cache.
#[test]
fn standard_kvcache_from_serialized_round_trip() {
    let tokens: [f32; 4] = [0.0, 1.0, 2.0, 3.0];

    let mut source = StandardKvCache::new();
    source.update(&kv(&tokens), &kv(&tokens)).unwrap();
    let snapshot = source.state().unwrap();
    let meta = source.meta_state();

    let mut target = StandardKvCache::new();
    target.from_serialized(snapshot, &meta).unwrap();

    assert_eq!(target.offset(), 4);
    assert!(!target.is_empty());

    let mut restored_state = target.state().unwrap();
    assert_eq!(restored_state.len(), 2);
    assert_eq!(restored_state[0].shape(), vec![1, 1, 4, 1]);
    // Both entries (exactly two, per the length check) must hold the tokens.
    for entry in restored_state.iter_mut() {
        assert_eq!(entry.to_vec::<f32>().unwrap(), tokens.to_vec());
    }
    assert!(target.meta_state().is_empty());
}
/// Round-trips a `RotatingKvCache` after two updates: the restored instance
/// must report the same offset, meta-state and state arrays as the source.
#[test]
fn rotating_from_serialized_round_trip() {
    let prompt: [f32; 4] = [0.0, 1.0, 2.0, 3.0];
    let step: [f32; 1] = [4.0];

    let mut source = RotatingKvCache::new(8, 2);
    source.update(&kv(&prompt), &kv(&prompt)).unwrap();
    source.update(&kv(&step), &kv(&step)).unwrap();

    let snapshot = source.state().unwrap();
    let meta = source.meta_state();
    let expected_offset = source.offset();

    // The target is built with different constructor arguments on purpose, so
    // matching meta-state afterwards can only come from the restore.
    let mut target = RotatingKvCache::new(0, 0);
    target.from_serialized(snapshot, &meta).unwrap();

    assert_eq!(target.offset(), expected_offset);
    assert_eq!(target.meta_state(), meta);

    let mut actual_state = target.state().unwrap();
    let mut expected_state = source.state().unwrap();
    assert_state_eq(&mut actual_state, &mut expected_state, "rotating");
}
/// A `RotatingKvCache` handed meta-state containing a non-numeric field must
/// return `Err` and keep its previous offset, meta-state and state arrays.
#[test]
fn rotating_from_serialized_invalid_meta_leaves_self_unchanged() {
    let prompt: [f32; 4] = [0.0, 1.0, 2.0, 3.0];
    let mut cache = RotatingKvCache::new(8, 2);
    cache.update(&kv(&prompt), &kv(&prompt)).unwrap();

    let mut state_before = cache.state().unwrap();
    let meta_before = cache.meta_state();
    let offset_before = cache.offset();

    let bad_state = cache.state().unwrap();
    // Third field is not parseable as a number (the assertion message below
    // identifies it as the offset).
    let bad_meta: Vec<String> = ["2", "8", "not_a_number", "0"]
        .iter()
        .map(|field| field.to_string())
        .collect();

    let outcome = cache.from_serialized(bad_state, &bad_meta);
    assert!(outcome.is_err(), "expected Err on non-numeric offset");

    assert_eq!(cache.offset(), offset_before);
    assert_eq!(cache.meta_state(), meta_before);
    let mut state_after = cache.state().unwrap();
    assert_state_eq(&mut state_after, &mut state_before, "rotating-unchanged");
}
/// A `RotatingKvCache` handed meta-state with only three fields (the cache's
/// own meta-state is used with four elsewhere in this file) must return `Err`
/// and keep its previous offset, meta-state and state arrays.
#[test]
fn rotating_from_serialized_wrong_arity_meta_leaves_self_unchanged() {
    let mut cache = RotatingKvCache::new(8, 2);
    cache.update(&kv(&[0.0, 1.0]), &kv(&[0.0, 1.0])).unwrap();

    let mut original_state = cache.state().unwrap();
    let original_meta = cache.meta_state();
    // Captured so the "unchanged" guarantee also covers the offset, not only
    // the meta-state and the state arrays.
    let original_offset = cache.offset();

    let bad_state = cache.state().unwrap();
    let bad_meta: Vec<String> = vec!["2".into(), "8".into(), "0".into()];

    let result = cache.from_serialized(bad_state, &bad_meta);
    assert!(result.is_err());

    assert_eq!(cache.offset(), original_offset);
    assert_eq!(cache.meta_state(), original_meta);
    let mut after_state = cache.state().unwrap();
    assert_state_eq(
        &mut after_state,
        &mut original_state,
        "rotating-wrong-arity",
    );
}
/// Round-trips a `ChunkedKvCache`: the restored instance must report the same
/// offset, meta-state and state arrays as the source.
#[test]
fn chunked_from_serialized_round_trip() {
    let tokens: [f32; 3] = [0.0, 1.0, 2.0];

    let mut source = ChunkedKvCache::new(Some(8));
    source.update(&kv(&tokens), &kv(&tokens)).unwrap();

    let snapshot = source.state().unwrap();
    let meta = source.meta_state();
    let expected_offset = source.offset();

    // Target starts with `None` where the source used `Some(8)`.
    let mut target = ChunkedKvCache::new(None);
    target.from_serialized(snapshot, &meta).unwrap();

    assert_eq!(target.offset(), expected_offset);
    assert_eq!(target.meta_state(), meta);

    let mut actual_state = target.state().unwrap();
    let mut expected_state = source.state().unwrap();
    assert_state_eq(&mut actual_state, &mut expected_state, "chunked");
}
/// A `ChunkedKvCache` handed meta-state with a non-numeric field must return
/// `Err` and keep its previous offset, meta-state and state arrays.
#[test]
fn chunked_from_serialized_invalid_meta_leaves_self_unchanged() {
    let tokens: [f32; 3] = [0.0, 1.0, 2.0];
    let mut cache = ChunkedKvCache::new(Some(8));
    cache.update(&kv(&tokens), &kv(&tokens)).unwrap();

    let mut state_before = cache.state().unwrap();
    let meta_before = cache.meta_state();
    let offset_before = cache.offset();

    let bad_state = cache.state().unwrap();
    let bad_meta: Vec<String> = ["8", "garbage"]
        .iter()
        .map(|field| field.to_string())
        .collect();

    assert!(cache.from_serialized(bad_state, &bad_meta).is_err());

    assert_eq!(cache.offset(), offset_before);
    assert_eq!(cache.meta_state(), meta_before);
    let mut state_after = cache.state().unwrap();
    assert_state_eq(&mut state_after, &mut state_before, "chunked-unchanged");
}
/// Builds an `f32` array of shape `(1, 1, n_steps, 64)` in which every one of
/// the `n_steps` rows holds the values `0.0, 1.0, ..., 63.0`.
///
/// Panics if `Array::from_slice` reports an error.
fn kv_quant(n_steps: usize) -> Array {
    // Flat index modulo the row width reproduces the repeating 0..64 ramp.
    let data: Vec<f32> = (0..n_steps * 64).map(|flat| (flat % 64) as f32).collect();
    Array::from_slice::<f32>(&data, &(1usize, 1, n_steps, 64usize)).unwrap()
}
/// Round-trips a `StandardQuantizedKvCache`: the restored instance must report
/// the same offset, meta-state, group size and bit width, and its state must
/// have the same number of arrays with the same shapes.
// NOTE(review): only shapes are compared here, not contents — presumably
// because the quantized state mixes element types; confirm.
#[test]
fn quantized_from_serialized_round_trip() {
    let mut source = StandardQuantizedKvCache::new(64, 8).unwrap();
    source.update_quantized(&kv_quant(3), &kv_quant(3)).unwrap();

    let snapshot = source.state().unwrap();
    let meta = source.meta_state();
    let expected_offset = source.offset();

    let mut target = StandardQuantizedKvCache::new_unchecked(0, 0);
    target.from_serialized(snapshot, &meta).unwrap();

    assert_eq!(target.offset(), expected_offset);
    assert_eq!(target.meta_state(), meta);
    assert_eq!(target.group_size(), 64);
    assert_eq!(target.bits(), 8);

    let actual_state = target.state().unwrap();
    let expected_state = source.state().unwrap();
    assert_eq!(actual_state.len(), expected_state.len());
    actual_state
        .iter()
        .zip(&expected_state)
        .enumerate()
        .for_each(|(i, (actual, expected))| {
            assert_eq!(actual.shape(), expected.shape(), "quantized-state[{i}].shape");
        });
}
/// A `StandardQuantizedKvCache` handed a five-field meta-state must return
/// `Err` and keep its offset, meta-state, group size, bit width and
/// non-empty status.
#[test]
fn quantized_from_serialized_wrong_arity_meta_leaves_self_unchanged() {
    let mut cache = StandardQuantizedKvCache::new(64, 8).unwrap();
    cache.update_quantized(&kv_quant(2), &kv_quant(2)).unwrap();

    let offset_before = cache.offset();
    let meta_before = cache.meta_state();

    let bad_state = cache.state().unwrap();
    let bad_meta: Vec<String> = ["2", "64", "8", "extra", "extra2"]
        .iter()
        .map(|field| field.to_string())
        .collect();

    assert!(cache.from_serialized(bad_state, &bad_meta).is_err());

    assert_eq!(cache.offset(), offset_before);
    assert_eq!(cache.meta_state(), meta_before);
    assert_eq!(cache.group_size(), 64);
    assert_eq!(cache.bits(), 8);
    assert!(!cache.is_empty());
}
/// A `StandardQuantizedKvCache` handed a three-field meta-state whose last
/// field is not numeric must return `Err` and keep its offset, meta-state,
/// group size and bit width.
#[test]
fn quantized_from_serialized_invalid_meta_value_leaves_self_unchanged() {
    let mut cache = StandardQuantizedKvCache::new(64, 8).unwrap();
    cache.update_quantized(&kv_quant(2), &kv_quant(2)).unwrap();

    let offset_before = cache.offset();
    let meta_before = cache.meta_state();
    let group_size_before = cache.group_size();
    let bits_before = cache.bits();

    let bad_state = cache.state().unwrap();
    let bad_meta: Vec<String> = ["2", "64", "not_bits"]
        .iter()
        .map(|field| field.to_string())
        .collect();

    assert!(cache.from_serialized(bad_state, &bad_meta).is_err());

    assert_eq!(cache.offset(), offset_before);
    assert_eq!(cache.meta_state(), meta_before);
    assert_eq!(cache.group_size(), group_size_before);
    assert_eq!(cache.bits(), bits_before);
}
/// Builds a `CacheList` with three populated children, in this order: a
/// `StandardKvCache`, a `RotatingKvCache::new(4, 1)` and a
/// `StandardQuantizedKvCache::new(64, 8)`, each updated once with two steps.
fn build_heterogeneous_cache_list() -> CacheList {
    let std_cache = {
        let mut cache = StandardKvCache::new();
        cache.update(&kv(&[0.0, 1.0]), &kv(&[0.0, 1.0])).unwrap();
        cache
    };

    let rot_cache = {
        let mut cache = RotatingKvCache::new(4, 1);
        cache.update(&kv(&[10.0, 11.0]), &kv(&[10.0, 11.0])).unwrap();
        cache
    };

    let q_cache = {
        let mut cache = StandardQuantizedKvCache::new(64, 8).unwrap();
        cache.update_quantized(&kv_quant(2), &kv_quant(2)).unwrap();
        cache
    };

    CacheList::new(vec![
        Box::new(std_cache),
        Box::new(rot_cache),
        Box::new(q_cache),
    ])
}
/// Round-trips the heterogeneous `CacheList` into an initially empty list and
/// checks child count, meta-state, per-child offsets and per-child class
/// names.
#[test]
fn cache_list_from_serialized_round_trip() {
    let source = build_heterogeneous_cache_list();
    let snapshot = source.state().unwrap();
    let meta = source.meta_state();

    let mut target = CacheList::new(Vec::new());
    target.from_serialized(snapshot, &meta).unwrap();

    assert_eq!(target.len(), 3);
    assert_eq!(target.meta_state(), meta);

    // Every child was updated with two steps before serialization.
    for slot in 0..3 {
        assert_eq!(target.get(slot).unwrap().offset(), 2);
    }

    let expected_classes = [
        (0, "KVCache"),
        (1, "RotatingKVCache"),
        (2, "QuantizedKVCache"),
    ];
    for (slot, class_name) in expected_classes {
        assert_eq!(target.get(slot).unwrap().reference_class_name(), class_name);
    }
}
/// A `CacheList` handed meta-state naming a class that does not exist must
/// return `Err` and keep its length, meta-state and the class names of its
/// first two children.
#[test]
fn cache_list_from_serialized_unknown_class_name_leaves_self_unchanged() {
    let mut cache = build_heterogeneous_cache_list();

    let meta_before = cache.meta_state();
    let len_before = cache.len();
    let child0_class_before = cache.get(0).unwrap().reference_class_name();
    let child1_class_before = cache.get(1).unwrap().reference_class_name();

    let bad_meta: Vec<String> = ["1", "ThisClassDoesNotExist", "0", "0"]
        .iter()
        .map(|field| field.to_string())
        .collect();
    let bad_state: Vec<Array> = Vec::new();

    let outcome = cache.from_serialized(bad_state, &bad_meta);
    assert!(outcome.is_err(), "expected Err on unknown class name");

    assert_eq!(cache.len(), len_before);
    assert_eq!(cache.meta_state(), meta_before);
    assert_eq!(
        cache.get(0).unwrap().reference_class_name(),
        child0_class_before
    );
    assert_eq!(
        cache.get(1).unwrap().reference_class_name(),
        child1_class_before
    );
}
/// A `CacheList` handed the two-field meta-state `["1000", "KVCache"]`
/// together with an empty state vector must return `Err` and keep its length
/// and meta-state.
#[test]
fn cache_list_from_serialized_truncated_meta_leaves_self_unchanged() {
    let mut cache = build_heterogeneous_cache_list();
    let meta_before = cache.meta_state();
    let len_before = cache.len();

    let bad_meta: Vec<String> = ["1000", "KVCache"]
        .iter()
        .map(|field| field.to_string())
        .collect();
    let bad_state: Vec<Array> = Vec::new();

    assert!(cache.from_serialized(bad_state, &bad_meta).is_err());

    assert_eq!(cache.len(), len_before);
    assert_eq!(cache.meta_state(), meta_before);
}
/// Builds an `f32` array of shape `(seqs.len(), 1, row_len, 1)` by
/// concatenating the rows of `seqs` in order.
///
/// # Panics
///
/// - if `seqs` is empty (the row length is taken from the first row);
/// - if the rows do not all have the same length;
/// - if `Array::from_slice` reports an error.
fn kv_batch(seqs: &[&[f32]]) -> Array {
    // State the precondition explicitly instead of relying on an
    // index-out-of-bounds panic from `seqs[0]`.
    let first = seqs.first().expect("kv_batch: empty batch");
    let b = seqs.len();
    let s = first.len();

    let mut data: Vec<f32> = Vec::with_capacity(b * s);
    for row in seqs {
        assert_eq!(row.len(), s, "kv_batch: ragged input");
        data.extend_from_slice(row);
    }
    Array::from_slice::<f32>(&data, &(b, 1usize, s, 1usize)).unwrap()
}
/// Round-trips a two-sequence `BatchKvCache`: the restored instance must
/// report the same offset and the same state arrays (entries 2 and 3 compared
/// as `i32`, the rest as `f32`).
#[test]
fn batch_from_serialized_round_trip() {
    let rows: [&[f32]; 2] = [&[0.0, 1.0], &[10.0, 11.0]];

    let mut source = BatchKvCache::new(&[1, 0]);
    source.update(&kv_batch(&rows), &kv_batch(&rows)).unwrap();

    let snapshot = source.state().unwrap();
    let meta = source.meta_state();
    let expected_offset = source.offset();

    let mut target = BatchKvCache::new(&[]);
    target.from_serialized(snapshot, &meta).unwrap();
    assert_eq!(target.offset(), expected_offset);

    let mut actual_state = target.state().unwrap();
    let mut expected_state = source.state().unwrap();
    assert_state_eq_mixed_kv_then_i32(
        &mut actual_state,
        &mut expected_state,
        &[2, 3],
        "batch",
    );
}
/// A `BatchKvCache` handed a three-entry state vector must return `Err` and
/// keep its previous offset and state arrays.
#[test]
fn batch_from_serialized_wrong_state_arity_leaves_self_unchanged() {
    let rows: [&[f32]; 2] = [&[0.0, 1.0], &[10.0, 11.0]];
    let mut cache = BatchKvCache::new(&[1, 0]);
    cache.update(&kv_batch(&rows), &kv_batch(&rows)).unwrap();

    let offset_before = cache.offset();
    let mut state_before = cache.state().unwrap();

    // Only three arrays: two f32 batches followed by a single i32 vector.
    let bad_state = vec![
        kv_batch(&[&[0.0], &[1.0]]),
        kv_batch(&[&[0.0], &[1.0]]),
        Array::from_slice::<i32>(&[0, 0], &(2usize,)).unwrap(),
    ];

    assert!(cache.from_serialized(bad_state, &[]).is_err());

    assert_eq!(cache.offset(), offset_before);
    let mut state_after = cache.state().unwrap();
    assert_state_eq_mixed_kv_then_i32(
        &mut state_after,
        &mut state_before,
        &[2, 3],
        "batch-unchanged",
    );
}
/// Round-trips a two-sequence `BatchRotatingKvCache`: the restored instance
/// must report the same offset and meta-state, a `max_size` of `Some(4)`, and
/// the same state arrays (entries 2 and 3 compared as `i32`).
#[test]
fn batch_rotating_from_serialized_round_trip() {
    let rows: [&[f32]; 2] = [&[0.0, 1.0], &[10.0, 11.0]];

    let mut source = BatchRotatingKvCache::new(4, &[1, 0]);
    source.update(&kv_batch(&rows), &kv_batch(&rows)).unwrap();

    let snapshot = source.state().unwrap();
    let meta = source.meta_state();
    let expected_offset = source.offset();

    // The target is constructed with 0 where the source used 4, so the
    // `max_size` check below can only pass if the restore set it.
    let mut target = BatchRotatingKvCache::new(0, &[]);
    target.from_serialized(snapshot, &meta).unwrap();

    assert_eq!(target.offset(), expected_offset);
    assert_eq!(target.meta_state(), meta);
    assert_eq!(target.max_size(), Some(4));

    let mut actual_state = target.state().unwrap();
    let mut expected_state = source.state().unwrap();
    assert_state_eq_mixed_kv_then_i32(
        &mut actual_state,
        &mut expected_state,
        &[2, 3],
        "batch-rotating",
    );
}
/// A `BatchRotatingKvCache` handed meta-state whose last field is the
/// unparseable string `"neither"` must return `Err` and keep its offset,
/// meta-state and `max_size`.
#[test]
fn batch_rotating_from_serialized_invalid_meta_leaves_self_unchanged() {
    let rows: [&[f32]; 2] = [&[0.0, 1.0], &[10.0, 11.0]];
    let mut cache = BatchRotatingKvCache::new(4, &[0, 0]);
    cache.update(&kv_batch(&rows), &kv_batch(&rows)).unwrap();

    let offset_before = cache.offset();
    let meta_before = cache.meta_state();
    let max_size_before = cache.max_size();

    let bad_state = cache.state().unwrap();
    let bad_meta: Vec<String> = ["4", "2", "2", "neither"]
        .iter()
        .map(|field| field.to_string())
        .collect();

    assert!(cache.from_serialized(bad_state, &bad_meta).is_err());

    assert_eq!(cache.offset(), offset_before);
    assert_eq!(cache.meta_state(), meta_before);
    assert_eq!(cache.max_size(), max_size_before);
}
/// Round-trips an `ArraysCache` with one populated slot (index 2 of 4): the
/// restored instance must report the same meta-state, leave slot 0 empty,
/// have slot 2 populated, and expose a single state array with the original
/// values.
#[test]
fn arrays_from_serialized_round_trip() {
    let mut source = ArraysCache::new(4);
    let payload = Array::from_slice::<f32>(&[42.0, 43.0], &(1usize, 2)).unwrap();
    source.set(2, payload).unwrap();

    let snapshot = source.state().unwrap();
    let meta = source.meta_state();

    let mut target = ArraysCache::new(0);
    target.from_serialized(snapshot, &meta).unwrap();

    assert_eq!(target.meta_state(), meta);
    assert!(target.get(0).is_none());
    assert!(target.get(2).is_some());

    let mut actual_state = target.state().unwrap();
    assert_eq!(actual_state.len(), 1);
    let slot_values = actual_state[0].to_vec::<f32>().unwrap();
    assert_eq!(slot_values, vec![42.0, 43.0]);
}
/// An `ArraysCache` handed meta-state whose first field is not numeric must
/// return `Err` and keep its meta-state, its populated slot 1 and that slot's
/// contents.
#[test]
fn arrays_from_serialized_invalid_meta_leaves_self_unchanged() {
    let mut cache = ArraysCache::new(4);
    let payload = Array::from_slice::<f32>(&[7.0, 8.0], &(1usize, 2)).unwrap();
    cache.set(1, payload).unwrap();

    let meta_before = cache.meta_state();
    let mut state_before = cache.state().unwrap();
    assert_eq!(state_before.len(), 1);
    let contents_before = state_before[0].to_vec::<f32>().unwrap();

    let bad_state: Vec<Array> = vec![Array::from_slice::<f32>(&[99.0], &(1usize, 1)).unwrap()];
    let bad_meta: Vec<String> = ["not_a_count", "0"]
        .iter()
        .map(|field| field.to_string())
        .collect();

    assert!(cache.from_serialized(bad_state, &bad_meta).is_err());

    assert_eq!(cache.meta_state(), meta_before);
    assert!(cache.get(1).is_some());

    let mut state_after = cache.state().unwrap();
    assert_eq!(state_after.len(), 1);
    let contents_after = state_after[0].to_vec::<f32>().unwrap();
    assert_eq!(contents_after, contents_before);
}
/// A `StandardKvCache` handed a non-empty meta-state must return `Err` and
/// keep its previous offset and state arrays, even though the supplied state
/// vector carries different values.
#[test]
fn standard_from_serialized_nonempty_meta_leaves_self_unchanged() {
    let tokens: [f32; 2] = [10.0, 11.0];
    let mut cache = StandardKvCache::new();
    cache.update(&kv(&tokens), &kv(&tokens)).unwrap();

    let offset_before = cache.offset();
    let mut state_before = cache.state().unwrap();

    let replacement: [f32; 2] = [99.0, 88.0];
    let bad_state = vec![kv(&replacement), kv(&replacement)];
    let bad_meta: Vec<String> = vec![String::from("bogus")];

    let outcome = cache.from_serialized(bad_state, &bad_meta);
    assert!(
        outcome.is_err(),
        "must reject non-empty meta on StandardKvCache"
    );

    assert_eq!(cache.offset(), offset_before);
    let mut state_after = cache.state().unwrap();
    assert_state_eq(&mut state_after, &mut state_before, "standard");
}
/// A `RotatingKvCache` handed an empty state vector together with the
/// meta-state `["4", "8", "5", "5"]` must return `Err` and keep its offset,
/// meta-state and state arrays.
#[test]
fn rotating_from_serialized_empty_state_nonzero_meta_rejected() {
    let prompt: [f32; 4] = [0.0, 1.0, 2.0, 3.0];
    let mut cache = RotatingKvCache::new(8, 2);
    cache.update(&kv(&prompt), &kv(&prompt)).unwrap();

    let meta_before = cache.meta_state();
    let offset_before = cache.offset();
    let mut state_before = cache.state().unwrap();

    let bad_state: Vec<Array> = Vec::new();
    let bad_meta: Vec<String> = ["4", "8", "5", "5"]
        .iter()
        .map(|field| field.to_string())
        .collect();

    let outcome = cache.from_serialized(bad_state, &bad_meta);
    assert!(outcome.is_err(), "must reject empty state + non-zero meta");

    assert_eq!(cache.offset(), offset_before);
    assert_eq!(cache.meta_state(), meta_before);
    let mut state_after = cache.state().unwrap();
    assert_state_eq(&mut state_after, &mut state_before, "rotating");
}
/// A `StandardQuantizedKvCache` handed an empty state vector together with the
/// meta-state `["5", "64", "8"]` must return `Err` and keep its offset,
/// meta-state, group size, bit width and number of state arrays.
#[test]
fn quantized_from_serialized_empty_state_nonzero_offset_rejected() {
    let mut cache = StandardQuantizedKvCache::new(64, 8).unwrap();
    cache.update_quantized(&kv_quant(2), &kv_quant(2)).unwrap();

    let meta_before = cache.meta_state();
    let offset_before = cache.offset();
    let group_size_before = cache.group_size();
    let bits_before = cache.bits();
    let state_count_before = cache.state().unwrap().len();

    let bad_state: Vec<Array> = Vec::new();
    let bad_meta: Vec<String> = ["5", "64", "8"]
        .iter()
        .map(|field| field.to_string())
        .collect();

    let outcome = cache.from_serialized(bad_state, &bad_meta);
    assert!(outcome.is_err(), "must reject empty state + non-zero offset");

    assert_eq!(cache.offset(), offset_before);
    assert_eq!(cache.meta_state(), meta_before);
    assert_eq!(cache.group_size(), group_size_before);
    assert_eq!(cache.bits(), bits_before);
    let state_count_after = cache.state().unwrap().len();
    assert_eq!(state_count_after, state_count_before);
}
/// A `BatchRotatingKvCache` handed its own state together with a meta-state
/// whose third field is `"99"` must return `Err` and keep its offset,
/// meta-state and `max_size`.
#[test]
fn batch_rotating_from_serialized_structural_inconsistency_rejected() {
    let rows: [&[f32]; 2] = [&[0.0, 1.0], &[10.0, 11.0]];
    let mut cache = BatchRotatingKvCache::new(4, &[0, 0]);
    cache.update(&kv_batch(&rows), &kv_batch(&rows)).unwrap();

    let meta_before = cache.meta_state();
    let offset_before = cache.offset();
    let max_size_before = cache.max_size();

    let saved_state = cache.state().unwrap();
    // Per the assertion message below, the oversized "99" is the `_idx` field.
    let bad_meta: Vec<String> = ["4", "2", "99", "false"]
        .iter()
        .map(|field| field.to_string())
        .collect();

    let outcome = cache.from_serialized(saved_state, &bad_meta);
    assert!(
        outcome.is_err(),
        "must reject _idx beyond physical buffer length"
    );

    assert_eq!(cache.offset(), offset_before);
    assert_eq!(cache.meta_state(), meta_before);
    assert_eq!(cache.max_size(), max_size_before);
}