#![cfg(feature = "lm")]
use mlxrs::{
Array,
lm::cache::{CacheList, ChunkedKvCache, KvCache, RotatingKvCache, StandardKvCache, from_state},
};
/// Builds an `f32` array of shape `(1, 1, ids.len(), 1)` holding `ids`.
///
/// Axis 2 is the axis the caches advance along: feeding an N-element
/// array to `update` moves a cache's offset by N.
fn kv(ids: &[f32]) -> Array {
    let shape = (1usize, 1, ids.len(), 1);
    Array::from_slice::<f32>(ids, &shape).unwrap()
}
/// Two-child fixture used throughout this file.
///
/// - child 0: `StandardKvCache` written once with positions `[0, 1, 2]` (offset 3)
/// - child 1: `RotatingKvCache::new(8, 4)` fed five single-position writes
///   `0..5` (offset 5)
fn populated_pair() -> CacheList {
    let keys = kv(&[0.0, 1.0, 2.0]);
    let values = kv(&[0.0, 1.0, 2.0]);
    let mut standard = StandardKvCache::new();
    standard.update(&keys, &values).unwrap();

    let mut rotating = RotatingKvCache::new(8, 4);
    for step in 0..5u8 {
        let token = kv(&[f32::from(step)]);
        rotating.update(&token, &token).unwrap();
    }

    CacheList::new(vec![Box::new(standard), Box::new(rotating)])
}
/// `get(i)` exposes the i-th child and returns `None` past the end.
#[test]
fn get_indexes_children() {
    let list = populated_pair();
    assert_eq!(list.len(), 2);

    let standard = list.get(0).unwrap();
    assert_eq!(standard.offset(), 3);

    let rotating = list.get(1).unwrap();
    assert_eq!(rotating.offset(), 5);
    assert_eq!(rotating.max_size(), Some(8));

    assert!(list.get(2).is_none(), "out-of-range index must be None");
}
/// The composite offset is the largest child offset (0 with no children).
#[test]
fn offset_is_max_child_offset() {
    // Children sit at offsets 3 and 5; the list reports the larger one.
    assert_eq!(populated_pair().offset(), 5);
    // Nothing to take a maximum over.
    assert_eq!(CacheList::new(Vec::new()).offset(), 0);
}
/// A list is trimmable only when every child is (vacuously true when empty).
#[test]
fn is_trimmable_is_all_children() {
    assert!(populated_pair().is_trimmable());

    let mut standard = StandardKvCache::new();
    standard.update(&kv(&[0.0]), &kv(&[0.0])).unwrap();

    // Six writes into a max_size=4 rotating cache leave it full.
    let mut full_rotating = RotatingKvCache::new(4, 2);
    for step in 0..6u8 {
        let token = kv(&[f32::from(step)]);
        full_rotating.update(&token, &token).unwrap();
    }
    assert!(
        !full_rotating.is_trimmable(),
        "rotating must be full / not trimmable"
    );

    let mixed = CacheList::new(vec![Box::new(standard), Box::new(full_rotating)]);
    assert!(
        !mixed.is_trimmable(),
        "one non-trimmable child => list not trimmable"
    );

    assert!(CacheList::new(Vec::new()).is_trimmable());
}
/// `trim(n)` trims every child and reports the count trimmed by the last one.
#[test]
fn trim_delegates_to_all_returns_last() {
    let mut list = populated_pair();

    // Offsets (3, 5) -> (1, 3).
    let first = list.trim(2).unwrap();
    assert_eq!(first, 2, "returns the LAST child's trimmed count");
    assert_eq!(list.get(0).unwrap().offset(), 1, "child 0 also trimmed");
    assert_eq!(list.get(1).unwrap().offset(), 3, "child 1 trimmed");

    // Offsets (1, 3) -> (0, 0); the reported value comes from child 1.
    let second = list.trim(3).unwrap();
    assert_eq!(
        second, 3,
        "last child (Rotating) trimmed 3 -> that is the returned value"
    );
    assert_eq!(list.get(0).unwrap().offset(), 0);
    assert_eq!(list.get(1).unwrap().offset(), 0);

    let mut empty = CacheList::new(Vec::new());
    assert_eq!(empty.trim(5).unwrap(), 0);
}
/// `state()` concatenates each child's state arrays in child order.
#[test]
fn state_is_flattened_child_states() {
    let list = populated_pair();
    let flat = list.state().unwrap();
    assert_eq!(flat.len(), 4, "2 (Standard k,v) + 2 (Rotating k,v)");

    // Entry 0 is the Standard child's keys.
    let first_keys = {
        let mut owned = flat[0].try_clone().unwrap();
        owned.to_vec::<f32>().unwrap()
    };
    assert_eq!(first_keys, vec![0.0, 1.0, 2.0]);

    let empty = CacheList::new(Vec::new());
    assert!(empty.state().unwrap().is_empty());
}
/// `meta_state()` framing is
/// `[child_count, {class_name, state_count, meta_count, meta...}*]`.
#[test]
fn meta_state_carries_reference_class_names() {
    let list = populated_pair();
    let meta = list.meta_state();

    let framing: [(usize, &str, &str); 7] = [
        (0, "2", "child count"),
        (1, "KVCache", "child 0 reference class name"),
        (2, "2", "child 0 state array count"),
        (3, "0", "child 0 (Standard) has no meta_state"),
        (4, "RotatingKVCache", "child 1 reference class name"),
        (5, "2", "child 1 state array count"),
        (6, "4", "child 1 (Rotating) meta_state has 4 values"),
    ];
    for (index, expected, what) in framing {
        assert_eq!(meta[index], expected, "{what}");
    }
    // The Rotating child's own four meta values follow its header.
    assert_eq!(&meta[7..11], &["4", "8", "5", "5"]);

    assert_eq!(
        CacheList::new(Vec::new()).meta_state(),
        vec!["0".to_string()],
        "empty list -> just child count 0"
    );
}
/// `state()` + `meta_state()` fed back through `from_state` rebuild an
/// equivalent list, and the rebuilt list round-trips again.
#[test]
fn from_state_roundtrip_rebuilds_children() {
    let original = populated_pair();
    let flat = original.state().unwrap();
    let framing = original.meta_state();
    let rebuilt = from_state("CacheList", flat, &framing).unwrap();

    assert_eq!(rebuilt.offset(), 5);
    assert!(!rebuilt.is_empty());
    assert_eq!(rebuilt.meta_state(), original.meta_state());

    let rebuilt_state = rebuilt.state().unwrap();
    assert_eq!(rebuilt_state.len(), 4);
    let mut first_keys = rebuilt_state[0].try_clone().unwrap();
    assert_eq!(first_keys.to_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0]);

    // A second hop must be just as lossless.
    let twice =
        from_state("CacheList", rebuilt.state().unwrap(), &rebuilt.meta_state()).unwrap();
    assert_eq!(twice.offset(), 5);
}
/// A keep=1 rotating child produces an all-numeric meta shape. It must still
/// be tagged (and rebuilt) as `RotatingKVCache`, whether fresh or populated.
#[test]
fn keep_one_rotating_child_is_not_misidentified_as_cache_list() {
    let fresh = RotatingKvCache::new(8, 1);
    let ambiguous: Vec<String> = ["1", "8", "0", "0"].iter().map(|s| s.to_string()).collect();
    assert_eq!(
        fresh.meta_state(),
        ambiguous,
        "precondition: keep=1 rotating meta is the ambiguous numeric shape"
    );

    let list = CacheList::new(vec![Box::new(fresh)]);
    let meta = list.meta_state();
    assert_eq!(meta[0], "1", "one child");
    assert_eq!(
        meta[1], "RotatingKVCache",
        "the keep=1 rotating child must be named RotatingKVCache, NOT CacheList"
    );
    assert_eq!(meta[2], "0", "fresh rotating child has 0 state arrays");
    assert_eq!(meta[3], "4", "rotating meta_state has 4 values");
    assert_eq!(&meta[4..8], ambiguous.as_slice());

    let rebuilt = from_state("CacheList", list.state().unwrap(), &list.meta_state()).unwrap();
    assert_eq!(
        rebuilt.meta_state(),
        list.meta_state(),
        "round-trip meta must be byte-identical (child rebuilt as RotatingKVCache)"
    );
    assert_eq!(rebuilt.offset(), 0);
    assert!(
        rebuilt.is_empty(),
        "fresh rotating child -> composite empty"
    );

    // Same check once the keep=1 child actually holds data.
    let mut filled = RotatingKvCache::new(6, 1);
    for step in 0..3u8 {
        let token = kv(&[f32::from(step)]);
        filled.update(&token, &token).unwrap();
    }
    let filled_list = CacheList::new(vec![Box::new(filled)]);
    assert_eq!(
        filled_list.meta_state()[1],
        "RotatingKVCache",
        "populated keep=1 rotating child must also be named RotatingKVCache"
    );
    let filled_rebuilt =
        from_state("CacheList", filled_list.state().unwrap(), &filled_list.meta_state()).unwrap();
    assert_eq!(filled_rebuilt.meta_state(), filled_list.meta_state());
    assert_eq!(filled_rebuilt.offset(), 3);
}
/// `set_state` hands each child its own slice of the flat array list.
#[test]
fn set_state_splits_per_child() {
    let source = populated_pair();
    let flat = source.state().unwrap();

    // The target has the same child layout, but every stored value is 9.0.
    let mut standard = StandardKvCache::new();
    standard.update(&kv(&[9.0; 3]), &kv(&[9.0; 3])).unwrap();
    let mut rotating = RotatingKvCache::new(8, 4);
    for _ in 0..5 {
        let nine = kv(&[9.0]);
        rotating.update(&nine, &nine).unwrap();
    }
    let mut target = CacheList::new(vec![Box::new(standard), Box::new(rotating)]);

    target.set_state(flat).unwrap();

    let restored = target.state().unwrap();
    assert_eq!(restored.len(), 4);
    let mut first_keys = restored[0].try_clone().unwrap();
    assert_eq!(
        first_keys.to_vec::<f32>().unwrap(),
        vec![0.0, 1.0, 2.0],
        "child 0 state replaced from the flat list"
    );
}
/// If a later child rejects its chunk, `set_state` fails and leaves earlier
/// children exactly as they were.
#[test]
fn set_state_is_transactional_on_later_child_failure() {
    let mut first = StandardKvCache::new();
    first.update(&kv(&[1.0, 2.0]), &kv(&[1.0, 2.0])).unwrap();
    let mut second = StandardKvCache::new();
    second.update(&kv(&[3.0, 4.0]), &kv(&[3.0, 4.0])).unwrap();
    let mut list = CacheList::new(vec![Box::new(first), Box::new(second)]);

    let offset_before = list.get(0).unwrap().offset();
    let keys_before = list.state().unwrap()[0].to_vec::<f32>().unwrap();
    assert_eq!(offset_before, 2);
    assert_eq!(keys_before, vec![1.0, 2.0]);

    // Child 0's (k, v) pair is well-formed. Child 1's keys are a 2-axis
    // array instead of the 4-axis shape `kv` builds, so child 1 must reject them.
    let flat = vec![
        kv(&[7.0, 8.0, 9.0]),
        kv(&[7.0, 8.0, 9.0]),
        Array::from_slice::<f32>(&[5.0, 6.0], &(1usize, 2)).unwrap(),
        kv(&[5.0, 6.0]),
    ];
    let outcome = list.set_state(flat);
    assert!(
        outcome.is_err(),
        "a later child rejecting its chunk must make set_state Err"
    );

    assert_eq!(
        list.get(0).unwrap().offset(),
        offset_before,
        "child 0 offset must be unchanged after the failed restore"
    );
    let keys_after = list.state().unwrap()[0].to_vec::<f32>().unwrap();
    assert_eq!(
        keys_after, keys_before,
        "child 0 state must be byte-identical after the failed restore \
         (no partial mutation / half-applied old-new mix)"
    );
}
/// `copy()` yields a deep copy that later mutation of the source cannot reach.
#[test]
fn copy_is_independent() {
    let mut original = populated_pair();
    let snapshot = original.copy().unwrap();
    assert_eq!(snapshot.offset(), 5);

    // Fully trim the original; the snapshot must keep its offset.
    original.trim(5).unwrap();
    assert_eq!(original.get(0).unwrap().offset(), 0);
    assert_eq!(
        snapshot.offset(),
        5,
        "the deep copy is unaffected by trimming the original"
    );
}
/// `nbytes()` is the sum of the children's `nbytes()`. `is_empty()` mirrors
/// the FIRST child only; an empty list is empty and reports 0 bytes.
#[test]
fn nbytes_sum_and_is_empty_is_first_child() {
    let cl = populated_pair();
    let s_only_k = kv(&[0.0, 1.0, 2.0]);
    let total = cl.nbytes();
    let standard_bytes = cl.get(0).unwrap().nbytes();
    let rotating_bytes = cl.get(1).unwrap().nbytes();

    // Pin the sum directly. Comparing `total` only against a hand-computed
    // byte count cannot show the Rotating child was counted when the Standard
    // child alone already exceeds that count (e.g. if it reports reserved
    // capacity rather than only its live positions).
    assert_eq!(
        total,
        standard_bytes + rotating_bytes,
        "CacheList::nbytes must be the sum of its children's nbytes"
    );
    // 3 live positions, for both k and v, at 4 bytes per f32.
    assert!(
        standard_bytes >= 2 * s_only_k.size() * 4,
        "at least the Standard child's k+v bytes"
    );
    assert!(rotating_bytes > 0, "includes the Rotating child too");
    assert_eq!(CacheList::new(Vec::new()).nbytes(), 0);

    let empty_children = CacheList::new(vec![
        Box::new(StandardKvCache::new()),
        Box::new(RotatingKvCache::new(4, 2)),
    ]);
    assert!(empty_children.is_empty(), "first child (fresh) is empty");

    // A populated first child makes the list non-empty, even though the
    // second child is fresh.
    let mut s = StandardKvCache::new();
    s.update(&kv(&[1.0]), &kv(&[1.0])).unwrap();
    let non_empty = CacheList::new(vec![Box::new(s), Box::new(RotatingKvCache::new(4, 2))]);
    assert!(!non_empty.is_empty());
    assert!(CacheList::new(Vec::new()).is_empty());
}
/// Calling `update` directly on a composite is a recoverable `Err`.
#[test]
fn update_is_unsupported_error_not_panic() {
    let mut list = populated_pair();
    let token = kv(&[7.0]);
    let outcome = list.update(&token, &token);
    assert!(
        outcome.is_err(),
        "CacheList.update must be a recoverable Err, not a panic"
    );
}
/// A composite has no mask of its own; `make_mask` returns an `Err`.
#[test]
fn make_mask_is_unsupported_error_not_panic() {
    let list = populated_pair();
    let mask = list.make_mask(1, None, false);
    assert!(
        mask.is_err(),
        "CacheList.make_mask must be a recoverable Err (no _BaseCache mask)"
    );
}
/// Malformed `CacheList` meta framing is rejected with `Err`, never a panic.
#[test]
fn from_state_rejects_malformed_meta() {
    /// Owned copy of a list of string literals.
    fn strings(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|p| p.to_string()).collect()
    }

    // No child count at all.
    assert!(from_state("CacheList", Vec::new(), &[]).is_err());
    // Child count is not an integer.
    assert!(from_state("CacheList", Vec::new(), &strings(&["x"])).is_err());
    // One child declared, but its framing is missing.
    assert!(
        from_state("CacheList", Vec::new(), &strings(&["1"])).is_err(),
        "truncated per-child framing must error, not panic"
    );
    // Child's state-array count is not an integer.
    let bad = strings(&["1", "KVCache", "two", "0"]);
    assert!(from_state("CacheList", Vec::new(), &bad).is_err());
    // Child claims two state arrays, but none are supplied.
    let claims_two_arrays = strings(&["1", "KVCache", "2", "0"]);
    assert!(
        from_state("CacheList", Vec::new(), &claims_two_arrays).is_err(),
        "declared stateCount > provided arrays must error, not panic"
    );
}
/// An absurd child count is rejected before any allocation. A count that
/// exactly fits the supplied framing must still be accepted.
#[test]
fn from_state_huge_child_count_is_err_not_panic_or_oom() {
    /// Owned copy of a list of string literals.
    fn strings(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|p| p.to_string()).collect()
    }

    let huge = vec![usize::MAX.to_string()];
    let huge_result = from_state("CacheList", Vec::new(), &huge);
    assert!(
        huge_result.is_err(),
        "an absurd child_count must be a recoverable Err, never a capacity \
         panic / OOM abort on the public from_state load path"
    );

    // The framing only has room for one child, not a billion.
    let big = strings(&["1000000000", "KVCache", "0", "0"]);
    assert!(
        from_state("CacheList", Vec::new(), &big).is_err(),
        "child_count far exceeding the frame budget must error pre-allocation"
    );

    let exactly_one = strings(&["1", "KVCache", "0", "0"]);
    let fitting = from_state("CacheList", Vec::new(), &exactly_one);
    assert!(
        fitting.is_ok(),
        "child_count == the frame-budget max must still construct, not be \
         spuriously rejected by the pre-allocation bound"
    );
    assert_eq!(fitting.unwrap().offset(), 0);
}
/// A pathologically deep nested-CacheList meta chain is an `Err`. Shallow,
/// realistic nesting still reconstructs exactly.
#[test]
fn from_state_deeply_nested_chain_is_err_not_stack_overflow() {
    /// Meta for `depth` CacheLists nested one inside another, bottoming out in
    /// an empty list (`["0"]`). Each level wraps the previous meta as its single
    /// child: `["1", "CacheList", "0", <inner meta len>, inner...]`.
    fn nest(depth: usize) -> Vec<String> {
        (0..depth).fold(vec!["0".to_string()], |mut inner, _| {
            let mut level = vec![
                "1".to_string(),
                "CacheList".to_string(),
                "0".to_string(),
                inner.len().to_string(),
            ];
            level.append(&mut inner);
            level
        })
    }

    let deep = nest(5000);
    let deep_result = from_state("CacheList", Vec::new(), &deep);
    assert!(
        deep_result.is_err(),
        "a pathologically deep nested-CacheList chain must be a recoverable \
         Err, never a stack-overflow process abort on the from_state load path"
    );

    let shallow = nest(3);
    let shallow_result = from_state("CacheList", Vec::new(), &shallow);
    assert!(
        shallow_result.is_ok(),
        "a shallow (depth-3) nested CacheList must still reconstruct — the \
         nesting bound must not reject realistic nesting"
    );
    let rebuilt = shallow_result.unwrap();
    assert_eq!(rebuilt.offset(), 0);
    assert!(rebuilt.is_empty());
    assert_eq!(
        rebuilt.meta_state(),
        shallow,
        "shallow nested round-trip is exact"
    );
}
/// A `CacheList` can itself be a child, and the whole tree round-trips.
#[test]
fn nested_cache_list_roundtrips() {
    let inner = populated_pair();
    let mut leading = StandardKvCache::new();
    leading.update(&kv(&[8.0; 2]), &kv(&[8.0; 2])).unwrap();
    let outer = CacheList::new(vec![Box::new(leading), Box::new(inner)]);
    assert_eq!(outer.len(), 2);

    // [count, "KVCache", state_count, meta_count (0 for Standard), "CacheList", ...]
    let meta = outer.meta_state();
    assert_eq!(meta[0], "2");
    assert_eq!(meta[1], "KVCache");
    assert_eq!(meta[4], "CacheList");

    let rebuilt = from_state("CacheList", outer.state().unwrap(), &outer.meta_state()).unwrap();
    assert_eq!(rebuilt.meta_state(), outer.meta_state());
    // Max over children: the nested list reports 5, the leading Standard 2.
    assert_eq!(rebuilt.offset(), 5);
}
/// `as_cache_list` / `as_cache_list_mut` downcast a `Box<dyn KvCache>` that
/// wraps a `CacheList`. Other cache kinds return `None`.
#[test]
fn as_cache_list_downcast_through_dyn() {
    let mut boxed: Box<dyn KvCache> = Box::new(CacheList::new(vec![
        Box::new(StandardKvCache::new()),
        Box::new(RotatingKvCache::new(8, 4)),
    ]));
    assert!(
        boxed.as_cache_list().is_some(),
        "Box<dyn KvCache> wrapping a CacheList must downcast via as_cache_list"
    );

    let view = boxed.as_cache_list().unwrap();
    assert_eq!(view.len(), 2);
    assert_eq!(view.get(0).unwrap().offset(), 0);

    // Write one position into child 0 through the mutable downcast.
    let list_mut = boxed.as_cache_list_mut().unwrap();
    let first_child = list_mut.get_mut(0).unwrap();
    let one = kv(&[5.0]);
    first_child.update(&one, &one).unwrap();

    assert_eq!(
        boxed.as_cache_list().unwrap().get(0).unwrap().offset(),
        1,
        "the through-dyn `&mut` downcast actually mutated the child"
    );

    let plain: Box<dyn KvCache> = Box::new(StandardKvCache::new());
    assert!(
        plain.as_cache_list().is_none(),
        "a non-CacheList cache must inherit the defaulted None downcast"
    );
    let plain_rot: Box<dyn KvCache> = Box::new(RotatingKvCache::new(8, 4));
    assert!(plain_rot.as_cache_list().is_none());
}
/// A Chunked child is tagged `ChunkedKVCache` and round-trips with its meta intact.
#[test]
fn cache_list_chunked_child_class_name_and_roundtrip() {
    let mut chunked = ChunkedKvCache::new(Some(64));
    chunked.update(&kv(&[0.0, 1.0]), &kv(&[0.0, 1.0])).unwrap();
    let child_meta = chunked.meta_state();
    assert_eq!(
        child_meta.len(),
        2,
        "precondition: Chunked meta_state has the 2-value shape"
    );

    let list = CacheList::new(vec![Box::new(chunked)]);
    let meta = list.meta_state();
    assert_eq!(meta[0], "1", "one child");
    assert_eq!(
        meta[1], "ChunkedKVCache",
        "Chunked child must be named ChunkedKVCache, NOT KVCache (silent \
         fallback would drop chunk_size/start_position on reload)"
    );
    assert_eq!(meta[2], "2", "populated Chunked child has 2 state arrays");
    assert_eq!(meta[3], "2", "Chunked meta_state has 2 values");
    assert_eq!(&meta[4..6], child_meta.as_slice());

    let rebuilt = from_state("CacheList", list.state().unwrap(), &list.meta_state()).unwrap();
    assert_eq!(
        rebuilt.meta_state(),
        list.meta_state(),
        "round-trip meta must be byte-identical (Chunked child rebuilt as the \
         right concrete kind, preserving chunk_size/start_position)"
    );
}
/// `state_count()` agrees with `state().len()`, both directly and via `dyn`.
#[test]
fn cache_list_state_count_matches_state_len() {
    let pair = populated_pair();
    let counted = pair.state_count().unwrap();
    let materialized = pair.state().unwrap().len();
    assert_eq!(
        counted, materialized,
        "CacheList::state_count must equal CacheList::state().len()"
    );

    let empty = CacheList::new(Vec::new());
    assert_eq!(empty.state_count().unwrap(), empty.state().unwrap().len());
    assert_eq!(empty.state_count().unwrap(), 0);

    // Same invariant when dispatched through the trait object.
    let dynamic: Box<dyn KvCache> = Box::new(populated_pair());
    assert_eq!(dynamic.state_count().unwrap(), dynamic.state().unwrap().len());
}
/// With any non-trimmable child, `trim` returns `Ok(0)` and touches no child.
/// With all children trimmable, it really trims.
#[test]
fn cache_list_trim_transactional_short_circuits_on_non_trimmable_child() {
    let mut standard = StandardKvCache::new();
    standard
        .update(&kv(&[0.0, 1.0, 2.0]), &kv(&[0.0, 1.0, 2.0]))
        .unwrap();
    let standard_offset = standard.offset();

    let mut saturated = RotatingKvCache::new(4, 2);
    for step in 0..6u8 {
        let token = kv(&[f32::from(step)]);
        saturated.update(&token, &token).unwrap();
    }
    assert!(
        !saturated.is_trimmable(),
        "sanity: filled Rotating is not trimmable"
    );
    let rotating_offset = saturated.offset();

    let mut mixed = CacheList::new(vec![Box::new(standard), Box::new(saturated)]);
    assert!(
        !mixed.is_trimmable(),
        "sanity: filled-Rotating child makes the list non-trimmable"
    );

    let trimmed = mixed.trim(2).unwrap();
    assert_eq!(trimmed, 0, "trim must short-circuit Ok(0) on non-trimmable list");
    assert_eq!(
        mixed.get(0).unwrap().offset(),
        standard_offset,
        "TRANSACTIONAL: trimmable Standard child must NOT mutate when sibling is non-trimmable"
    );
    assert_eq!(
        mixed.get(1).unwrap().offset(),
        rotating_offset,
        "filled Rotating child also unchanged"
    );

    // Control case: every child trimmable, so trim must make progress.
    let mut four_positions = StandardKvCache::new();
    four_positions
        .update(&kv(&[0.0, 1.0, 2.0, 3.0]), &kv(&[0.0, 1.0, 2.0, 3.0]))
        .unwrap();
    let mut partly_filled = RotatingKvCache::new(8, 4);
    for step in 0..5u8 {
        let token = kv(&[f32::from(step)]);
        partly_filled.update(&token, &token).unwrap();
    }
    let mut all_trim = CacheList::new(vec![Box::new(four_positions), Box::new(partly_filled)]);
    assert!(all_trim.is_trimmable(), "sanity: both children trimmable");

    let r2 = all_trim.trim(2).unwrap();
    assert!(
        r2 > 0,
        "all-trimmable list: trim must actually trim (returned {r2})"
    );
}
/// `state_into` appends exactly what `state()` returns: the same count, order
/// and contents. It appends rather than clearing the caller's buffer.
#[test]
fn state_into_buffer_reuse_matches_state_for_cache_list() {
    /// Host copy of an array's contents (every array in this fixture is f32).
    fn host(a: &Array) -> Vec<f32> {
        let mut owned = a.try_clone().unwrap();
        owned.to_vec::<f32>().unwrap()
    }

    let cl = populated_pair();
    let s = cl.state().unwrap();
    let expected: Vec<Vec<f32>> = s.iter().map(host).collect();

    let mut buf: Vec<Array> = Vec::new();
    cl.state_into(&mut buf).unwrap();
    assert_eq!(
        s.len(),
        buf.len(),
        "state_into and state must append the same number of arrays"
    );
    let first_batch: Vec<Vec<f32>> = buf.iter().map(host).collect();
    assert_eq!(
        first_batch, expected,
        "state_into must append the same arrays, in the same order, as state"
    );

    cl.state_into(&mut buf).unwrap();
    assert_eq!(
        buf.len(),
        s.len() * 2,
        "state_into must APPEND, not clear (multi-cache callers depend on this)"
    );
    // Both halves must hold the full state: the second call leaves the first
    // batch intact and repeats it exactly.
    let both: Vec<Vec<f32>> = buf.iter().map(host).collect();
    assert_eq!(
        &both[..s.len()],
        expected.as_slice(),
        "the first batch must survive the second append"
    );
    assert_eq!(
        &both[s.len()..],
        expected.as_slice(),
        "the second append must repeat state() exactly"
    );
}
/// `meta_state_into` yields the same strings as `meta_state()` and appends
/// on repeated calls.
#[test]
fn meta_state_into_buffer_reuse_matches_meta_state_for_cache_list() {
    let list = populated_pair();
    let reference = list.meta_state();

    let mut sink: Vec<String> = Vec::new();
    list.meta_state_into(&mut sink);
    assert_eq!(
        reference, sink,
        "meta_state and meta_state_into must produce byte-identical output"
    );

    // A second call extends the buffer instead of replacing it.
    list.meta_state_into(&mut sink);
    assert_eq!(
        sink.len(),
        reference.len() * 2,
        "meta_state_into must APPEND, not clear"
    );
}
/// The default `meta_state_into` matches `meta_state()`. For a Standard
/// cache, both are empty.
#[test]
fn meta_state_into_default_delegates_to_meta_state_for_standard_cache() {
    let mut cache = StandardKvCache::new();
    cache.update(&kv(&[0.0, 1.0]), &kv(&[0.0, 1.0])).unwrap();

    let direct = cache.meta_state();
    let mut appended: Vec<String> = Vec::new();
    cache.meta_state_into(&mut appended);

    assert_eq!(
        direct, appended,
        "default meta_state_into must produce identical output to meta_state"
    );
    assert!(appended.is_empty(), "StandardKvCache has no meta_state");
}
/// The default `state_into` appends the same arrays as `state()`: same count
/// and same contents.
#[test]
fn state_into_default_delegates_to_state_for_standard_cache() {
    /// Host copy of an array's contents (every array here is f32).
    fn host(a: &Array) -> Vec<f32> {
        let mut owned = a.try_clone().unwrap();
        owned.to_vec::<f32>().unwrap()
    }

    let mut s = StandardKvCache::new();
    s.update(&kv(&[0.0, 1.0, 2.0]), &kv(&[0.0, 1.0, 2.0]))
        .unwrap();
    let st = s.state().unwrap();
    let mut buf: Vec<Array> = Vec::new();
    s.state_into(&mut buf).unwrap();
    assert_eq!(
        st.len(),
        buf.len(),
        "default state_into must produce identical count to state"
    );
    assert_eq!(buf.len(), 2, "populated StandardKvCache state has 2 arrays");

    // Counts alone cannot catch a delegation that appends the wrong arrays.
    let direct: Vec<Vec<f32>> = st.iter().map(host).collect();
    let appended: Vec<Vec<f32>> = buf.iter().map(host).collect();
    assert_eq!(
        appended, direct,
        "default state_into must append the same arrays as state"
    );
    assert_eq!(
        direct,
        vec![vec![0.0, 1.0, 2.0]; 2],
        "k and v hold the three written positions"
    );
}