#![cfg(feature = "lm")]
use mlxrs::{
Array,
lm::cache::{ChunkedKvCache, KvCache, MaskMode, RopeOffset, from_state},
};
/// Builds an `f32` array of shape `[1, 1, ids.len(), 1]` holding `ids`.
///
/// The tests below pass the result as both keys and values; the cache grows
/// it along the third axis, so each entry of `ids` becomes one row.
fn kv(ids: &[f32]) -> Array {
    let rows = ids.len();
    Array::from_slice::<f32>(ids, &(1usize, 1, rows, 1)).unwrap()
}
/// Feeding one token per step grows the cache by exactly one row each time.
/// The returned keys/values contain every token seen so far, the chunk
/// metadata stays `["4", "0"]`, and the RoPE offset ends up as a scalar 4.
#[test]
fn chunked_single_token_growth() {
    let mut cache = ChunkedKvCache::new(Some(4));
    assert!(cache.is_empty());
    assert_eq!(cache.offset(), 0);
    assert_eq!(cache.meta_state(), vec!["4", "0"]);

    for step in 0..4u32 {
        let token = kv(&[step as f32]);
        let (mut keys, mut values) = cache.update(&token, &token).unwrap();
        let seen = step + 1;
        assert_eq!(cache.offset() as u32, seen);
        assert_eq!(keys.shape(), vec![1, 1, seen as usize, 1]);
        // Every token fed so far, in insertion order.
        let history: Vec<f32> = (0..seen).map(|x| x as f32).collect();
        assert_eq!(keys.to_vec::<f32>().unwrap(), history);
        assert_eq!(values.to_vec::<f32>().unwrap(), history);
    }

    assert!(!cache.is_empty());
    assert_eq!(cache.meta_state(), vec!["4", "0"]);
    let rope = cache.rope_offset().unwrap();
    match rope {
        RopeOffset::Batch(_) => panic!("chunked cache must use a scalar RoPE offset"),
        RopeOffset::Scalar(o) => assert_eq!(o, 4),
    }
}
/// A single update longer than the chunk size (5 tokens, chunk 4) is accepted
/// whole. All five rows are returned, and the update alone does not touch the
/// chunk metadata.
#[test]
fn chunked_multi_token_spans_chunk_boundary() {
    let prompt = kv(&[0.0, 1.0, 2.0, 3.0, 4.0]);
    let mut cache = ChunkedKvCache::new(Some(4));
    let (mut keys, mut values) = cache.update(&prompt, &prompt).unwrap();

    assert_eq!(cache.offset(), 5);
    assert_eq!(keys.shape(), vec![1, 1, 5, 1]);

    let expected = vec![0.0, 1.0, 2.0, 3.0, 4.0];
    assert_eq!(keys.to_vec::<f32>().unwrap(), expected);
    assert_eq!(values.to_vec::<f32>().unwrap(), expected);
    assert_eq!(cache.meta_state(), vec!["4", "0"]);
}
/// Trimming a 6-row buffer with chunk size 4 keeps only the newest 4 rows and
/// records a start position of 2, without changing `offset`. A subsequent
/// update appends after the retained window.
#[test]
fn chunked_maybe_trim_front_then_update() {
    let seeded = kv(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    let mut cache = ChunkedKvCache::new(Some(4));
    cache
        .set_state(vec![seeded.try_clone().unwrap(), seeded.try_clone().unwrap()])
        .unwrap();
    assert_eq!(cache.offset(), 6);
    assert_eq!(cache.meta_state(), vec!["4", "0"]);

    cache.maybe_trim_front().unwrap();
    assert_eq!(cache.offset(), 6);
    assert_eq!(cache.meta_state(), vec!["4", "2"]);

    let state = cache.state().unwrap();
    assert_eq!(state.len(), 2);
    // Only the last chunk_size rows of the keys survive the trim.
    assert_eq!(
        state[0].try_clone().unwrap().to_vec::<f32>().unwrap(),
        vec![2.0, 3.0, 4.0, 5.0]
    );

    let next = kv(&[6.0]);
    let (mut keys, mut values) = cache.update(&next, &next).unwrap();
    assert_eq!(cache.offset(), 7);
    assert_eq!(keys.shape(), vec![1, 1, 5, 1]);

    let window = vec![2.0, 3.0, 4.0, 5.0, 6.0];
    assert_eq!(keys.to_vec::<f32>().unwrap(), window);
    assert_eq!(values.to_vec::<f32>().unwrap(), window);
    assert_eq!(cache.meta_state(), vec!["4", "2"]);
}
/// After a single one-token update, `maybe_trim_front` still moves the start
/// position to 252. The offset and the visible one-row state are left intact.
///
/// NOTE(review): 252 appears to come from the raw (step-padded) buffer length
/// minus `chunk_size`, rather than from `offset`. Confirm the allocation step
/// against the implementation.
#[test]
fn chunked_maybe_trim_front_raw_step_buffer_semantics() {
    let mut cache = ChunkedKvCache::new(Some(4));
    let token = kv(&[0.0]);
    cache.update(&token, &token).unwrap();
    assert_eq!(cache.offset(), 1);

    cache.maybe_trim_front().unwrap();
    assert_eq!(cache.offset(), 1);
    assert_eq!(cache.meta_state(), vec!["4", "252"]);

    let state = cache.state().unwrap();
    let mut keys = state[0].try_clone().unwrap();
    assert_eq!(keys.shape(), vec![1, 1, 1, 1]);
    assert_eq!(keys.to_vec::<f32>().unwrap(), vec![0.0]);
}
/// `maybe_trim_front` changes nothing in two cases:
/// - the buffer (4 rows) already fits the chunk size (8);
/// - no chunk size is configured (`None`).
#[test]
fn chunked_maybe_trim_front_noop_paths() {
    // Case 1: buffer smaller than the chunk.
    let small = kv(&[0.0, 1.0, 2.0, 3.0]);
    let mut fits = ChunkedKvCache::new(Some(8));
    fits.set_state(vec![small.try_clone().unwrap(), small.try_clone().unwrap()])
        .unwrap();
    fits.maybe_trim_front().unwrap();
    assert_eq!(fits.offset(), 4);
    assert_eq!(fits.meta_state(), vec!["8", "0"]);

    // Case 2: unbounded cache (no chunk size).
    let larger = kv(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    let mut unbounded = ChunkedKvCache::new(None);
    unbounded
        .set_state(vec![larger.try_clone().unwrap(), larger.try_clone().unwrap()])
        .unwrap();
    assert_eq!(unbounded.meta_state(), vec!["None", "0"]);
    unbounded.maybe_trim_front().unwrap();
    assert_eq!(unbounded.offset(), 6);
    assert_eq!(unbounded.meta_state(), vec!["None", "0"]);
}
/// `state()` exposes the keys/values accumulated by updates. Feeding that
/// state into a fresh cache via `set_state` reproduces the same offset and
/// contents. `set_state` rejects vectors that do not hold exactly a
/// key/value pair.
#[test]
fn chunked_state_roundtrip() {
    let mut source = ChunkedKvCache::new(Some(4));
    (0..3u32).for_each(|id| {
        let token = kv(&[id as f32]);
        source.update(&token, &token).unwrap();
    });

    let exported = source.state().unwrap();
    assert_eq!(exported.len(), 2);
    let mut keys = exported[0].try_clone().unwrap();
    let mut values = exported[1].try_clone().unwrap();
    assert_eq!(keys.shape(), vec![1, 1, 3, 1]);
    assert_eq!(keys.to_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0]);
    assert_eq!(values.to_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0]);

    let mut restored = ChunkedKvCache::new(Some(4));
    restored.set_state(exported).unwrap();
    assert_eq!(restored.offset(), 3);
    let reexported = restored.state().unwrap();
    assert_eq!(
        reexported[0].try_clone().unwrap().to_vec::<f32>().unwrap(),
        vec![0.0, 1.0, 2.0]
    );

    // Zero or one array is not a valid (keys, values) state.
    let mut rejecting = ChunkedKvCache::new(Some(4));
    assert!(rejecting.set_state(Vec::new()).is_err());
    assert!(rejecting.set_state(vec![kv(&[0.0])]).is_err());
}
/// `meta_state` / `set_meta_state` round-trip `[chunk_size, start_position]`.
/// `set_meta_state` overrides the constructor's chunk size and accepts
/// `"None"`. It rejects wrong-length or unparsable input and leaves the
/// existing metadata untouched in that case.
#[test]
fn chunked_meta_state_roundtrip() {
    let seeded = kv(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    let mut trimmed = ChunkedKvCache::new(Some(4));
    trimmed
        .set_state(vec![seeded.try_clone().unwrap(), seeded.try_clone().unwrap()])
        .unwrap();
    trimmed.maybe_trim_front().unwrap();
    assert_eq!(trimmed.meta_state(), vec!["4", "2"]);

    // The constructor's chunk size (99) is replaced by the restored one.
    let mut overridden = ChunkedKvCache::new(Some(99));
    overridden
        .set_meta_state(&["7".to_string(), "5".to_string()])
        .unwrap();
    assert_eq!(overridden.meta_state(), vec!["7", "5"]);

    let mut unbounded = ChunkedKvCache::new(Some(4));
    unbounded
        .set_meta_state(&["None".to_string(), "3".to_string()])
        .unwrap();
    assert_eq!(unbounded.meta_state(), vec!["None", "3"]);

    // Too few entries, too many entries, and a non-numeric start position.
    assert!(unbounded.set_meta_state(&["4".to_string()]).is_err());
    assert!(unbounded
        .set_meta_state(&["4".to_string(), "0".to_string(), "x".to_string()])
        .is_err());
    assert!(unbounded
        .set_meta_state(&["4".to_string(), "bad".to_string()])
        .is_err());
    // None of the rejected calls may have partially applied.
    assert_eq!(unbounded.meta_state(), vec!["None", "3"]);
}
/// After a front trim (start position 2, offset 6), the cache is trimmable.
/// `trim(n)` removes up to `n` rows from the end: the first call removes 3,
/// and the second is clamped so that `offset` stops at 2, the start
/// position. The metadata is left as-is.
#[test]
fn chunked_is_trimmable_and_trim() {
    let seeded = kv(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    let mut cache = ChunkedKvCache::new(Some(4));
    cache
        .set_state(vec![seeded.try_clone().unwrap(), seeded.try_clone().unwrap()])
        .unwrap();
    cache.maybe_trim_front().unwrap();
    assert!(cache.is_trimmable());

    let first = cache.trim(3).unwrap();
    assert_eq!(first, 3);
    assert_eq!(cache.offset(), 3);
    assert_eq!(cache.meta_state(), vec!["4", "2"]);

    // Asking for far more than remains is clamped.
    let second = cache.trim(10).unwrap();
    assert_eq!(second, 1);
    assert_eq!(cache.offset(), 2);
}
/// `nbytes` is zero for an empty cache and equals the key + value byte size
/// once populated. `copy` yields a cache whose updates do not affect the
/// original.
#[test]
fn chunked_nbytes_and_copy() {
    let mut original = ChunkedKvCache::new(Some(4));
    assert_eq!(original.nbytes(), 0);

    let rows = kv(&[0.0, 1.0, 2.0, 3.0]);
    original
        .set_state(vec![rows.try_clone().unwrap(), rows.try_clone().unwrap()])
        .unwrap();
    // Two tensors (keys and values), each holding four 4-byte f32 values.
    assert_eq!(original.nbytes(), 2 * (4 * 4));

    let mut duplicate = original.copy().unwrap();
    let extra = kv(&[9.0]);
    duplicate.update(&extra, &extra).unwrap();

    assert_eq!(original.offset(), 4);
    assert_eq!(original.meta_state(), vec!["4", "0"]);
    assert_eq!(duplicate.offset(), 5);
}
/// `make_mask` behaviour at offset 3:
/// - a single query token needs no mask;
/// - a multi-token query gets the causal shortcut;
/// - `return_array = true` materializes a `[n, offset + n]` array;
/// - a window size also forces an array mask.
#[test]
fn chunked_make_mask() {
    let mut cache = ChunkedKvCache::new(Some(4));
    for step in 0..3u32 {
        let token = kv(&[step as f32]);
        cache.update(&token, &token).unwrap();
    }
    assert_eq!(cache.offset(), 3);

    let single = cache.make_mask(1, None, false).unwrap();
    assert!(matches!(single, MaskMode::None));

    let causal = cache.make_mask(3, None, false).unwrap();
    assert!(matches!(causal, MaskMode::Causal));

    let materialized = cache.make_mask(2, None, true).unwrap();
    if let MaskMode::Array(m) = materialized {
        // 2 query rows against 3 cached + 2 new key positions.
        assert_eq!(m.shape(), vec![2, 3 + 2]);
    } else {
        panic!("return_array must materialize an array mask");
    }

    let windowed = cache.make_mask(3, Some(2), false).unwrap();
    assert!(matches!(windowed, MaskMode::Array(_)));
}
/// `from_state` rebuilds a chunked cache from exported state plus metadata.
/// - The offset is taken from the restored state's length (4 rows here), not
///   from the source cache's offset (6).
/// - The `"ChunkedKvCache"` spelling is accepted as well as `"ChunkedKVCache"`.
/// - An empty state vector is rejected.
#[test]
fn chunked_from_state_roundtrip() {
    let seeded = kv(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    let mut source = ChunkedKvCache::new(Some(4));
    source
        .set_state(vec![seeded.try_clone().unwrap(), seeded.try_clone().unwrap()])
        .unwrap();
    source.maybe_trim_front().unwrap();

    let exported = source.state().unwrap();
    let meta = source.meta_state();
    assert_eq!(meta, vec!["4", "2"]);

    let cloned_state: Vec<Array> = exported
        .iter()
        .map(|array| array.try_clone())
        .collect::<mlxrs::Result<Vec<_>>>()
        .unwrap();
    let rebuilt = from_state("ChunkedKVCache", cloned_state, &meta).unwrap();
    assert_eq!(rebuilt.offset(), 4);
    assert_eq!(rebuilt.meta_state(), vec!["4", "2"]);
    let rebuilt_state = rebuilt.state().unwrap();
    assert_eq!(
        rebuilt_state[0].try_clone().unwrap().to_vec::<f32>().unwrap(),
        vec![2.0, 3.0, 4.0, 5.0]
    );

    // Alternate class-name spelling.
    let alias = from_state(
        "ChunkedKvCache",
        vec![seeded.try_clone().unwrap(), seeded.try_clone().unwrap()],
        &["4".to_string(), "0".to_string()],
    )
    .unwrap();
    assert_eq!(alias.offset(), 6);
    assert_eq!(alias.meta_state(), vec!["4", "0"]);

    let empty = from_state(
        "ChunkedKVCache",
        Vec::new(),
        &["4".to_string(), "0".to_string()],
    );
    assert!(empty.is_err());
}
/// A values tensor of the wrong rank (2-D instead of 4-D) must produce `Err`
/// from both `from_state` and `set_state` rather than a panic. The failed
/// `set_state` must leave the cache empty. A well-formed state still loads
/// and trims normally.
#[test]
fn chunked_malformed_state_values_rank_is_err_not_panic() {
    let keys = kv(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    let bad_values = Array::from_slice::<f32>(&[0.0, 1.0], &(1usize, 2)).unwrap();
    let malformed = || vec![keys.try_clone().unwrap(), bad_values.try_clone().unwrap()];

    let via_from_state = from_state(
        "ChunkedKVCache",
        malformed(),
        &["4".to_string(), "0".to_string()],
    );
    assert!(via_from_state.is_err());

    let mut rejected = ChunkedKvCache::new(Some(4));
    assert!(rejected.set_state(malformed()).is_err());
    assert!(rejected.is_empty());
    assert_eq!(rejected.offset(), 0);
    assert_eq!(rejected.meta_state(), vec!["4", "0"]);

    // Control: the same keys paired with well-shaped values are accepted.
    let mut accepted = ChunkedKvCache::new(Some(4));
    accepted
        .set_state(vec![keys.try_clone().unwrap(), keys.try_clone().unwrap()])
        .unwrap();
    accepted.maybe_trim_front().unwrap();
    assert_eq!(accepted.offset(), 6);
    assert_eq!(accepted.meta_state(), vec!["4", "2"]);
}
/// If an update fails on the values side (a head_dim mismatch: buffer 2 vs
/// input 1) after the keys side would already have succeeded, the cache must
/// be left exactly as before: same offset, metadata, buffer shapes and
/// contents. A correctly shaped update afterwards must still work.
#[test]
fn chunked_update_err_after_keys_realloc_leaves_cache_unchanged() {
    let seed_keys = Array::from_slice::<f32>(&[7.0], &(1usize, 1, 1, 1)).unwrap();
    let seed_values = Array::from_slice::<f32>(&[7.0, 8.0], &(1usize, 1, 1, 2)).unwrap();
    let mut cache = ChunkedKvCache::new(None);
    cache
        .set_state(vec![
            seed_keys.try_clone().unwrap(),
            seed_values.try_clone().unwrap(),
        ])
        .unwrap();
    assert_eq!(cache.offset(), 1);
    assert!(!cache.is_empty());
    assert_eq!(cache.meta_state(), vec!["None", "0"]);

    // Snapshot contents before the failing update.
    let before = cache.state().unwrap();
    let keys_before = before[0].try_clone().unwrap().to_vec::<f32>().unwrap();
    let values_before = before[1].try_clone().unwrap().to_vec::<f32>().unwrap();

    // head_dim 1 matches the keys buffer but not the values buffer (head_dim 2).
    let mismatched = kv(&[9.0]);
    let outcome = cache.update(&mismatched, &mismatched);
    assert!(
        outcome.is_err(),
        "values concat must fail on the head_dim mismatch"
    );

    assert_eq!(cache.offset(), 1, "offset must not advance on a failed update");
    assert!(!cache.is_empty());
    assert_eq!(cache.meta_state(), vec!["None", "0"]);

    let after = cache.state().unwrap();
    assert_eq!(after.len(), 2);
    let mut keys_after = after[0].try_clone().unwrap();
    let mut values_after = after[1].try_clone().unwrap();
    assert_eq!(
        keys_after.shape(),
        vec![1, 1, 1, 1],
        "keys buffer must be unchanged"
    );
    assert_eq!(
        values_after.shape(),
        vec![1, 1, 1, 2],
        "values buffer must be unchanged"
    );
    assert_eq!(keys_after.to_vec::<f32>().unwrap(), keys_before);
    assert_eq!(values_after.to_vec::<f32>().unwrap(), values_before);

    // A correctly shaped update still appends after the failure.
    let good_keys = Array::from_slice::<f32>(&[1.0], &(1usize, 1, 1, 1)).unwrap();
    let good_values = Array::from_slice::<f32>(&[1.0, 2.0], &(1usize, 1, 1, 2)).unwrap();
    let (mut grown_keys, mut grown_values) = cache.update(&good_keys, &good_values).unwrap();
    assert_eq!(cache.offset(), 2);
    assert_eq!(grown_keys.shape(), vec![1, 1, 2, 1]);
    assert_eq!(grown_values.shape(), vec![1, 1, 2, 2]);
    assert_eq!(grown_keys.to_vec::<f32>().unwrap(), vec![7.0, 1.0]);
    assert_eq!(
        grown_values.to_vec::<f32>().unwrap(),
        vec![7.0, 8.0, 1.0, 2.0]
    );
}
/// When keys and values have different sequence lengths, `maybe_trim_front`
/// windows each tensor to its own last `chunk_size` rows. In both cases the
/// start position is advanced by 2 (6 key rows minus chunk 4); a values
/// tensor already no longer than the chunk keeps all its rows.
#[test]
fn chunked_maybe_trim_front_seq_mismatched_kv_trims_each_independently() {
    // Case 1: values longer than keys (10 vs 6 rows).
    let long_keys = kv(&[10.0, 11.0, 12.0, 13.0, 14.0, 15.0]);
    let long_values = kv(&[20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0]);
    let mut cache = ChunkedKvCache::new(Some(4));
    cache
        .set_state(vec![
            long_keys.try_clone().unwrap(),
            long_values.try_clone().unwrap(),
        ])
        .unwrap();
    assert_eq!(cache.offset(), 6);
    assert_eq!(cache.meta_state(), vec!["4", "0"]);

    cache.maybe_trim_front().unwrap();
    assert_eq!(cache.offset(), 6);
    assert_eq!(cache.meta_state(), vec!["4", "2"]);

    let state = cache.state().unwrap();
    assert_eq!(state.len(), 2);
    let mut keys = state[0].try_clone().unwrap();
    let mut values = state[1].try_clone().unwrap();
    assert_eq!(keys.shape(), vec![1, 1, 4, 1]);
    assert_eq!(values.shape(), vec![1, 1, 4, 1]);
    assert_eq!(keys.to_vec::<f32>().unwrap(), vec![12.0, 13.0, 14.0, 15.0]);
    assert_eq!(
        values.to_vec::<f32>().unwrap(),
        vec![26.0, 27.0, 28.0, 29.0],
        "values must keep ITS OWN last chunk_size rows, not the keys-windowed rows"
    );

    // Case 2: values shorter than the chunk (3 rows) are kept whole.
    let keys_six = kv(&[10.0, 11.0, 12.0, 13.0, 14.0, 15.0]);
    let short_values = kv(&[40.0, 41.0, 42.0]);
    let mut short = ChunkedKvCache::new(Some(4));
    short
        .set_state(vec![
            keys_six.try_clone().unwrap(),
            short_values.try_clone().unwrap(),
        ])
        .unwrap();
    short.maybe_trim_front().unwrap();
    assert_eq!(short.meta_state(), vec!["4", "2"]);

    let short_state = short.state().unwrap();
    let mut short_keys = short_state[0].try_clone().unwrap();
    let mut short_vals = short_state[1].try_clone().unwrap();
    assert_eq!(
        short_keys.to_vec::<f32>().unwrap(),
        vec![12.0, 13.0, 14.0, 15.0]
    );
    assert_eq!(short_vals.shape(), vec![1, 1, 3, 1]);
    assert_eq!(short_vals.to_vec::<f32>().unwrap(), vec![40.0, 41.0, 42.0]);
}
/// With `chunk_size = 0`, trimming keeps the whole keys and values tensors.
/// This mirrors Python, where a `-0:` slice selects everything. The start
/// position still advances by the full row count. This holds whether the zero
/// chunk comes from the constructor or from `set_meta_state`.
#[test]
fn chunked_maybe_trim_front_chunk_size_zero_is_python_neg_zero_noop() {
    // Zero chunk size from the constructor.
    let keys_in = kv(&[10.0, 11.0, 12.0, 13.0]);
    let values_in = kv(&[20.0, 21.0, 22.0, 23.0]);
    let mut cache = ChunkedKvCache::new(Some(0));
    cache
        .set_state(vec![keys_in.try_clone().unwrap(), values_in.try_clone().unwrap()])
        .unwrap();
    assert_eq!(cache.offset(), 4);
    assert_eq!(cache.meta_state(), vec!["0", "0"]);

    cache.maybe_trim_front().unwrap();
    assert_eq!(cache.offset(), 4);
    assert_eq!(cache.meta_state(), vec!["0", "4"]);

    let state = cache.state().unwrap();
    assert_eq!(state.len(), 2);
    let mut keys = state[0].try_clone().unwrap();
    let mut values = state[1].try_clone().unwrap();
    assert_eq!(keys.shape(), vec![1, 1, 4, 1]);
    assert_eq!(values.shape(), vec![1, 1, 4, 1]);
    assert_eq!(
        keys.to_vec::<f32>().unwrap(),
        vec![10.0, 11.0, 12.0, 13.0],
        "chunk_size=0 trim must keep the WHOLE keys tensor (Python -0: slice)"
    );
    assert_eq!(
        values.to_vec::<f32>().unwrap(),
        vec![20.0, 21.0, 22.0, 23.0],
        "chunk_size=0 trim must keep the WHOLE values tensor (Python -0: slice)"
    );

    // Zero chunk size restored through metadata.
    let keys_meta = kv(&[100.0, 101.0, 102.0]);
    let values_meta = kv(&[200.0, 201.0, 202.0]);
    let mut restored = ChunkedKvCache::new(Some(99));
    restored
        .set_state(vec![
            keys_meta.try_clone().unwrap(),
            values_meta.try_clone().unwrap(),
        ])
        .unwrap();
    restored
        .set_meta_state(&["0".to_string(), "0".to_string()])
        .unwrap();
    assert_eq!(restored.meta_state(), vec!["0", "0"]);

    restored.maybe_trim_front().unwrap();
    assert_eq!(restored.meta_state(), vec!["0", "3"]);

    let restored_state = restored.state().unwrap();
    assert_eq!(
        restored_state[0].try_clone().unwrap().to_vec::<f32>().unwrap(),
        vec![100.0, 101.0, 102.0]
    );
    assert_eq!(
        restored_state[1].try_clone().unwrap().to_vec::<f32>().unwrap(),
        vec![200.0, 201.0, 202.0]
    );
}
/// Setup: a 4-row keys buffer, a 1-row values buffer, and start position 2.
/// An update here would write values out of bounds, so it must return `Err`
/// instead of silently truncating or appending. The offset, metadata and
/// both buffers must be left unchanged.
#[test]
fn chunked_update_out_of_bounds_values_write_is_err_not_silent_corrupt() {
    let keys_in = kv(&[10.0, 11.0, 12.0, 13.0]);
    let values_in = kv(&[20.0]);
    let mut cache = ChunkedKvCache::new(Some(4));
    cache
        .set_state(vec![keys_in.try_clone().unwrap(), values_in.try_clone().unwrap()])
        .unwrap();
    cache
        .set_meta_state(&["4".to_string(), "2".to_string()])
        .unwrap();
    assert_eq!(cache.offset(), 4);
    assert_eq!(cache.meta_state(), vec!["4", "2"]);

    let before = cache.state().unwrap();
    let keys_before = before[0].try_clone().unwrap().to_vec::<f32>().unwrap();
    let values_before = before[1].try_clone().unwrap().to_vec::<f32>().unwrap();

    let token = kv(&[99.0]);
    let outcome = cache.update(&token, &token);
    assert!(
        outcome.is_err(),
        "out-of-bounds values write must be Err, not silent truncation/append"
    );

    assert_eq!(cache.offset(), 4);
    assert_eq!(cache.meta_state(), vec!["4", "2"]);
    let after = cache.state().unwrap();
    let mut keys_after = after[0].try_clone().unwrap();
    let mut values_after = after[1].try_clone().unwrap();
    assert_eq!(keys_after.shape(), vec![1, 1, 4, 1]);
    assert_eq!(values_after.shape(), vec![1, 1, 1, 1]);
    assert_eq!(keys_after.to_vec::<f32>().unwrap(), keys_before);
    assert_eq!(values_after.to_vec::<f32>().unwrap(), values_before);
}
/// The full-window write path must reject an update whose batch axis (2)
/// does not match the buffer's (1). The test reaches this path by trimming
/// the cache back to offset 0 and then updating; the name calls it
/// "full-window set_seq".
///
/// The rejection must be a `ShapePairMismatch` reported from
/// `broadcast_write_rhs`. The cache must be left unchanged (offset, metadata
/// and state buffer shapes), and a well-shaped update must still succeed
/// afterwards.
#[test]
fn chunked_set_seq_full_window_rejects_mismatched_batch_dim() {
    let bad_kv2 = Array::from_slice::<f32>(&[9.0, 9.5], &(2usize, 1, 1, 1)).unwrap();
    let mut c = ChunkedKvCache::new(None);
    let seed = kv(&[7.0]);
    c.set_state(vec![seed.try_clone().unwrap(), seed.try_clone().unwrap()])
        .unwrap();
    assert_eq!(c.offset(), 1);
    // Rewind to offset 0 so the next update writes from the start of the buffer.
    assert_eq!(c.trim(1).unwrap(), 1);
    assert_eq!(c.offset(), 0);
    assert_eq!(c.meta_state(), vec!["None", "0"]);

    let pre = c.state().unwrap();
    let r = c.update(&bad_kv2, &bad_kv2);
    match &r {
        Err(mlxrs::Error::ShapePairMismatch(p)) => {
            assert!(
                p.context().contains("broadcast_write_rhs"),
                "expected broadcast_write_rhs context, got: {}",
                p.context()
            );
            assert_eq!(p.expected(), &[1, 1, 1, 1]);
            assert_eq!(p.actual(), &[2, 1, 1, 1]);
        }
        other => panic!(
            "full-window set_seq must reject batch-axis mismatch on the public update API \
             (closes #78), got {other:?}"
        ),
    }

    assert_eq!(c.offset(), 0, "offset must not advance on a failed update");
    assert_eq!(c.meta_state(), vec!["None", "0"]);
    // `state()` always yields a (keys, values) pair, so comparing lengths alone
    // proves nothing. Also require each buffer's shape to be untouched.
    let post = c.state().unwrap();
    assert_eq!(post.len(), pre.len());
    for (before, after) in pre.iter().zip(post.iter()) {
        assert_eq!(
            after.shape(),
            before.shape(),
            "state buffers must be unchanged by a failed update"
        );
    }

    // A correctly shaped update still works after the rejection.
    let ok = kv(&[42.0]);
    let (mut ok_k, _) = c.update(&ok, &ok).unwrap();
    assert_eq!(c.offset(), 1);
    assert_eq!(ok_k.shape(), vec![1, 1, 1, 1]);
    assert_eq!(ok_k.to_vec::<f32>().unwrap(), vec![42.0]);
}
/// The full-window write path (reached by trimming back to offset 0) must
/// reject mismatches on the n_kv_heads axis (1) and on the head_dim axis (3).
/// Each must be a `ShapePairMismatch` from `broadcast_write_rhs`, must not
/// advance the offset, and must leave the cache usable for a well-shaped
/// update.
#[test]
fn chunked_set_seq_full_window_rejects_mismatched_heads_and_head_dim() {
    let bad_kv_heads = Array::from_slice::<f32>(&[9.0, 9.5, 9.7], &(1usize, 3, 1, 1)).unwrap();
    let bad_kv_hd = Array::from_slice::<f32>(&[9.0, 9.5], &(1usize, 1, 1, 2)).unwrap();
    let seed = kv(&[7.0]);
    let ok = kv(&[8.0]);

    // Axis 1 (n_kv_heads) mismatch.
    let mut c1 = ChunkedKvCache::new(None);
    c1.set_state(vec![seed.try_clone().unwrap(), seed.try_clone().unwrap()])
        .unwrap();
    assert_eq!(c1.trim(1).unwrap(), 1);
    assert_eq!(c1.offset(), 0);
    let r1 = c1.update(&bad_kv_heads, &bad_kv_heads);
    match &r1 {
        Err(mlxrs::Error::ShapePairMismatch(p)) => {
            assert!(
                p.context().contains("broadcast_write_rhs"),
                "expected broadcast_write_rhs context, got: {}",
                p.context()
            );
            assert_eq!(p.expected(), &[1, 1, 1, 1]);
            assert_eq!(p.actual(), &[1, 3, 1, 1]);
        }
        other => panic!("full-window set_seq must reject n_kv_heads (axis 1) mismatch, got {other:?}"),
    }
    assert_eq!(c1.offset(), 0, "offset must not advance on a failed update");
    c1.update(&ok, &ok).unwrap();
    assert_eq!(c1.offset(), 1);

    // Axis 3 (head_dim) mismatch.
    let mut c2 = ChunkedKvCache::new(None);
    c2.set_state(vec![seed.try_clone().unwrap(), seed.try_clone().unwrap()])
        .unwrap();
    assert_eq!(c2.trim(1).unwrap(), 1);
    assert_eq!(c2.offset(), 0);
    let r2 = c2.update(&bad_kv_hd, &bad_kv_hd);
    match &r2 {
        Err(mlxrs::Error::ShapePairMismatch(p)) => {
            assert!(
                p.context().contains("broadcast_write_rhs"),
                "expected broadcast_write_rhs context, got: {}",
                p.context()
            );
            assert_eq!(p.expected(), &[1, 1, 1, 1]);
            assert_eq!(p.actual(), &[1, 1, 1, 2]);
        }
        other => panic!("full-window set_seq must reject head_dim (axis 3) mismatch, got {other:?}"),
    }
    assert_eq!(c2.offset(), 0, "offset must not advance on a failed update");
    c2.update(&ok, &ok).unwrap();
    assert_eq!(c2.offset(), 1);
}
/// When the buffer has batch 2, a batch-1 update written through the
/// full-window path must broadcast up to batch 2 instead of shrinking the
/// buffer. A following batch-2 update then appends along the third axis.
#[test]
fn chunked_set_seq_full_window_broadcasts_size_one_rhs() {
    let batch_two_seed = Array::from_slice::<f32>(&[10.0, 20.0], &(2usize, 1, 1, 1)).unwrap();
    let mut cache = ChunkedKvCache::new(None);
    cache
        .set_state(vec![
            batch_two_seed.try_clone().unwrap(),
            batch_two_seed.try_clone().unwrap(),
        ])
        .unwrap();
    assert_eq!(cache.offset(), 1);
    // Rewind so the next update writes from position 0.
    cache.trim(1).unwrap();
    assert_eq!(cache.offset(), 0);

    let single = kv(&[99.0]);
    let (keys, values) = cache.update(&single, &single).unwrap();
    assert_eq!(cache.offset(), 1);
    assert_eq!(
        keys.shape(),
        vec![2, 1, 1, 1],
        "size-1 batch RHS must broadcast to PRESERVE buffer batch dim 2 \
         (NOT silently shrink to [1, 1, 1, 1])"
    );
    assert_eq!(values.shape(), vec![2, 1, 1, 1]);

    let batch_two_rhs = Array::from_slice::<f32>(&[7.0, 8.0], &(2usize, 1, 1, 1)).unwrap();
    let (next_keys, _) = cache.update(&batch_two_rhs, &batch_two_rhs).unwrap();
    assert_eq!(cache.offset(), 2);
    assert_eq!(next_keys.shape(), vec![2, 1, 2, 1]);
}