#![cfg(feature = "lm")]
use mlxrs::{
Array,
lm::cache::{
BatchKvCache, BatchPositionedKvCache, BatchRotatingKvCache, KvCache, MaskMode, RopeOffset,
dynamic_roll, from_state,
},
};
/// Packs `B` equal-length rows of `S` values into an f32 array of shape `[B, 1, S, 1]`.
///
/// Panics if `rows` is empty (the first row defines `S`) or if any row is ragged.
fn kvb(rows: &[&[f32]]) -> Array {
    let batch = rows.len();
    let seq_len = rows[0].len();
    let flat: Vec<f32> = rows
        .iter()
        .flat_map(|row| {
            assert_eq!(row.len(), seq_len, "ragged test rows");
            row.iter().copied()
        })
        .collect();
    Array::from_slice::<f32>(&flat, &(batch, 1usize, seq_len, 1usize)).unwrap()
}
/// Copies an i32 array to the host via a cloned handle, leaving the caller's array untouched.
fn iv(a: &Array) -> Vec<i32> {
    a.try_clone().unwrap().to_vec::<i32>().unwrap()
}
#[test]
fn dynamic_roll_per_row_shift() {
    // Two rows rolled along axis 2 by their own shift: row 0 by 0, row 1 by 2.
    let two_rows = kvb(&[&[10.0, 20.0, 30.0], &[40.0, 50.0, 60.0]]);
    let per_row = Array::from_slice::<i32>(&[0, 2], &(2usize, 1usize)).unwrap();
    let mut out = dynamic_roll(&two_rows, &per_row, 2).unwrap();
    assert_eq!(out.shape(), vec![2, 1, 3, 1]);
    let expected: Vec<f32> = vec![10.0, 20.0, 30.0, 50.0, 60.0, 40.0];
    assert_eq!(out.to_vec::<f32>().unwrap(), expected);

    // A single row shifted by 1: the last element wraps to the front.
    let one_row = kvb(&[&[10.0, 20.0, 30.0]]);
    let shift_one = Array::from_slice::<i32>(&[1], &(1usize, 1usize)).unwrap();
    let mut out_one = dynamic_roll(&one_row, &shift_one, 2).unwrap();
    let expected_one: Vec<f32> = vec![30.0, 10.0, 20.0];
    assert_eq!(out_one.to_vec::<f32>().unwrap(), expected_one);
}
#[test]
fn batch_kv_init_offsets_and_rope_offset() {
    let cache = BatchKvCache::new(&[1, 3, 0]);
    let expected_offsets = vec![-1, -3, 0];

    // A fresh cache is empty and starts every sequence at -left_padding.
    assert!(cache.is_empty());
    assert_eq!(iv(&cache.batch_offset().unwrap()), expected_offsets);
    assert!(cache.as_batch_positioned().is_some());

    // RoPE positions are reported per sequence and agree with batch_offset.
    let rope = match cache.rope_offset().unwrap() {
        RopeOffset::Batch(per_seq) => iv(&per_seq),
        RopeOffset::Scalar(_) => panic!("batch cache must use a per-seq RoPE offset"),
    };
    assert_eq!(rope, expected_offsets);

    // The scalar view reports nothing cached.
    assert_eq!(cache.offset(), 0);
    assert_eq!(cache.nbytes(), 0);
    assert!(cache.state().unwrap().is_empty());
}
#[test]
fn batch_kv_left_padded_update_grows_and_concats() {
    // Left padding [1, 3, 0]: the leading slots of rows 0 and 1 are padding.
    let mut cache = BatchKvCache::new(&[1, 3, 0]);
    let prefill = kvb(&[
        &[0.0, 1.0, 3.0, 5.0],
        &[0.0, 0.0, 0.0, 7.0],
        &[2.0, 6.0, 8.0, 9.0],
    ]);
    let prefill_flat: Vec<f32> = vec![
        0.0, 1.0, 3.0, 5.0, 0.0, 0.0, 0.0, 7.0, 2.0, 6.0, 8.0, 9.0,
    ];

    // The first update stores the prompt verbatim for keys and values alike.
    let (mut keys, mut values) = cache.update(&prefill, &prefill).unwrap();
    assert_eq!(keys.shape(), vec![3, 1, 4, 1]);
    let keys_flat = keys.to_vec::<f32>().unwrap();
    assert_eq!(keys_flat, prefill_flat);
    assert_eq!(values.to_vec::<f32>().unwrap(), keys_flat);

    // Per-sequence offsets started at -left_padding and advanced by S = 4.
    assert_eq!(iv(&cache.batch_offset().unwrap()), vec![3, 1, 4]);
    assert_eq!(cache.offset(), 4);
    assert!(!cache.is_empty());
    assert!(cache.is_trimmable());

    // A single decode step appends one column to every row.
    let step = kvb(&[&[10.0], &[11.0], &[12.0]]);
    let (mut grown, _) = cache.update(&step, &step).unwrap();
    assert_eq!(grown.shape(), vec![3, 1, 5, 1]);
    let grown_flat: Vec<f32> = vec![
        0.0, 1.0, 3.0, 5.0, 10.0, 0.0, 0.0, 0.0, 7.0, 11.0, 2.0, 6.0, 8.0, 9.0, 12.0,
    ];
    assert_eq!(grown.to_vec::<f32>().unwrap(), grown_flat);
    assert_eq!(iv(&cache.batch_offset().unwrap()), vec![4, 2, 5]);
    assert_eq!(cache.offset(), 5);

    // trim returns how much it removed; an oversized request is clamped.
    assert_eq!(cache.trim(2).unwrap(), 2);
    assert_eq!(cache.offset(), 3);
    assert_eq!(iv(&cache.batch_offset().unwrap()), vec![2, 0, 3]);
    assert_eq!(cache.trim(99).unwrap(), 3);
    assert_eq!(cache.offset(), 0);
}
#[test]
fn batch_kv_right_padding_finalize_rolls() {
    // No left padding; row 0 ends in one pad slot, row 1 in two.
    let mut cache = BatchKvCache::new(&[0, 0]);
    let right_padded = kvb(&[&[1.0, 2.0, 3.0, 0.0], &[4.0, 5.0, 0.0, 0.0]]);
    cache.update(&right_padded, &right_padded).unwrap();
    cache.prepare_right_padding(&[1, 2]).unwrap();

    // Before finalize, the offsets still count the trailing pad slots.
    let offsets_before_finalize = iv(&cache.batch_offset().unwrap());
    assert_eq!(offsets_before_finalize, vec![4, 4]);

    // finalize rolls each row by its right padding so the pads lead the row.
    cache.finalize().unwrap();
    let (mut keys, _) = cache.state_kv().unwrap();
    let expected_keys: Vec<f32> = vec![0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 4.0, 5.0];
    assert_eq!(keys.to_vec::<f32>().unwrap(), expected_keys);
    assert_eq!(iv(&cache.batch_offset().unwrap()), vec![3, 2]);
    assert_eq!(iv(&cache.left_padding_arr().unwrap()), vec![1, 2]);
}
#[test]
fn batch_kv_make_mask_is_left_padded_causal() {
    // Row 0 has one slot of left padding, row 1 none.
    let mut cache = BatchKvCache::new(&[1, 0]);
    let prompt = kvb(&[&[0.0, 1.0], &[2.0, 3.0]]);
    cache.update(&prompt, &prompt).unwrap();

    let MaskMode::Array(mut mask) = cache.make_mask(2, None, false).unwrap() else {
        panic!("BatchKVCache.make_mask(N>1) must materialize a left-padded causal mask");
    };
    assert_eq!(mask.shape(), vec![2, 1, 2, 4]);
    let bits: Vec<u8> = mask
        .to_vec::<bool>()
        .unwrap()
        .into_iter()
        .map(u8::from)
        .collect();
    // Causal over 4 key slots; row 0 additionally masks its padded key 0.
    let expected: Vec<u8> = vec![
        0, 1, 1, 0, 0, 1, 1, 1, // batch 0
        1, 1, 1, 0, 1, 1, 1, 1, // batch 1
    ];
    assert_eq!(bits, expected);
}
#[test]
fn batch_kv_state_roundtrip_and_from_state() {
    // Left padding [1, 0]: row 0's first slot is padding.
    let mut cache = BatchKvCache::new(&[1, 0]);
    let prompt = kvb(&[&[0.0, 5.0, 6.0], &[1.0, 2.0, 3.0]]);
    cache.update(&prompt, &prompt).unwrap();

    // The saved state has four entries (keys, values, offset, left padding).
    let saved = cache.state().unwrap();
    assert_eq!(saved.len(), 4);
    let restored = from_state("BatchKVCache", saved, &[]).unwrap();
    assert_eq!(restored.offset(), 3);

    // The restored keys (first state entry) are the prompt verbatim.
    let mut restored_state = restored.state().unwrap().into_iter();
    let mut keys = restored_state.next().unwrap();
    let _values = restored_state.next().unwrap();
    let expected_keys: Vec<f32> = vec![0.0, 5.0, 6.0, 1.0, 2.0, 3.0];
    assert_eq!(keys.to_vec::<f32>().unwrap(), expected_keys);

    // Per-sequence offsets survive the round-trip.
    let positioned = restored.as_batch_positioned();
    assert!(positioned.is_some());
    assert_eq!(iv(&positioned.unwrap().batch_offset().unwrap()), vec![2, 3]);
}
#[test]
fn batch_kv_wrong_rank_errors_not_panic() {
    let mut cache = BatchKvCache::new(&[0, 0]);
    // A rank-2 array where the cache expects a rank-4 tensor.
    let rank2 = Array::from_slice::<f32>(&[1.0, 2.0], &(1usize, 2usize)).unwrap();
    assert!(cache.update(&rank2, &rank2).is_err());
    // Well-formed keys do not excuse malformed values.
    let rank4 = kvb(&[&[1.0], &[2.0]]);
    assert!(cache.update(&rank4, &rank2).is_err());
}
#[test]
fn batch_rotating_init_and_rope_offset() {
    let cache = BatchRotatingKvCache::new(4, &[2, 0]);
    assert!(cache.is_empty());
    assert_eq!(cache.max_size(), Some(4));

    // Fresh per-sequence offsets are -left_padding.
    assert_eq!(iv(&cache.batch_offset().unwrap()), vec![-2, 0]);
    assert!(cache.as_batch_positioned().is_some());

    let rope = match cache.rope_offset().unwrap() {
        RopeOffset::Batch(per_seq) => iv(&per_seq),
        RopeOffset::Scalar(_) => panic!("batch-rotating cache must use a per-seq RoPE offset"),
    };
    assert_eq!(rope, vec![-2, 0]);

    // An empty ring reports itself as trimmable.
    assert!(cache.is_trimmable());
}
/// Walks a max_size=4, two-sequence ring through each update path in turn:
/// - a 3-token concat prefill;
/// - a token that fills the buffer;
/// - an in-place write after the wrap;
/// - a 2-token update on the rotated ring.
///
/// Assertions check physical buffer contents and the per-sequence offsets / left padding.
#[test]
fn batch_rotating_active_ring_then_concat_mixed() {
let mut c = BatchRotatingKvCache::new(4, &[0, 0]);
// Prefill of 3 tokens (< max_size): the buffer holds exactly the prompt.
let p = kvb(&[&[0.0, 1.0, 2.0], &[10.0, 11.0, 12.0]]);
let (mut k, _) = c.update(&p, &p).unwrap();
assert_eq!(k.shape(), vec![2, 1, 3, 1]);
assert_eq!(
k.to_vec::<f32>().unwrap(),
vec![0.0, 1.0, 2.0, 10.0, 11.0, 12.0]
);
assert_eq!(iv(&c.batch_offset().unwrap()), vec![3, 3]);
// A 4th token grows the buffer to exactly max_size.
let d1 = kvb(&[&[3.0], &[13.0]]);
let (mut k1, _) = c.update(&d1, &d1).unwrap();
assert_eq!(k1.shape(), vec![2, 1, 4, 1]);
assert_eq!(
k1.to_vec::<f32>().unwrap(),
vec![0.0, 1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 13.0]
);
assert_eq!(iv(&c.batch_offset().unwrap()), vec![4, 4]);
// Ring is full: the 5th token overwrites physical slot 0, so the returned buffer
// is in physical (not temporal) order.
let d2 = kvb(&[&[4.0], &[14.0]]);
let (mut k2, _) = c.update(&d2, &d2).unwrap();
assert_eq!(
k2.to_vec::<f32>().unwrap(),
vec![4.0, 1.0, 2.0, 3.0, 14.0, 11.0, 12.0, 13.0],
"physical ring order after in-place rotated write"
);
assert_eq!(iv(&c.batch_offset().unwrap()), vec![5, 5]);
// left_padding_arr is observed going negative once the oldest token is overwritten.
assert_eq!(iv(&c.left_padding_arr().unwrap()), vec![-1, -1]);
// 2-token update on the rotated ring (the "mixed" path):
// - put the ring back in temporal order ([1, 2, 3, 4]);
// - drop the oldest entry;
// - append both new tokens, for max_size + S - 1 = 5 retained slots.
let p2 = kvb(&[&[5.0, 6.0], &[15.0, 16.0]]);
let (mut k3, mut v3) = c.update(&p2, &p2).unwrap();
assert_eq!(k3.shape(), vec![2, 1, 5, 1], "over-retain max_size+S-1=5");
assert_eq!(
k3.to_vec::<f32>().unwrap(),
vec![2.0, 3.0, 4.0, 5.0, 6.0, 12.0, 13.0, 14.0, 15.0, 16.0],
"temporal-order then trim-1 then append (mixed path)"
);
assert_eq!(v3.to_vec::<f32>().unwrap(), k3.to_vec::<f32>().unwrap());
assert_eq!(iv(&c.batch_offset().unwrap()), vec![7, 7]);
assert_eq!(iv(&c.left_padding_arr().unwrap()), vec![-2, -2]);
}
#[test]
fn batch_rotating_b1_parity_with_single_seq_rotating() {
    use mlxrs::lm::cache::RotatingKvCache;
    // With one sequence and no padding, the batch ring must mirror the scalar ring exactly.
    let mut batched = BatchRotatingKvCache::new(4, &[0]);
    let mut single = RotatingKvCache::new(4, 0);

    let prompt = Array::from_slice::<f32>(&[0.0, 1.0, 2.0], &(1usize, 1, 3, 1)).unwrap();
    let (mut batched_keys, _) = batched.update(&prompt, &prompt).unwrap();
    let (mut single_keys, _) = single.update(&prompt, &prompt).unwrap();
    assert_eq!(
        batched_keys.to_vec::<f32>().unwrap(),
        single_keys.to_vec::<f32>().unwrap()
    );

    // Decode through the wrap point (max_size 4) and beyond it.
    for step in 3..8 {
        let token = Array::from_slice::<f32>(&[step as f32], &(1usize, 1, 1, 1)).unwrap();
        let (mut from_batched, _) = batched.update(&token, &token).unwrap();
        let (mut from_single, _) = single.update(&token, &token).unwrap();
        assert_eq!(
            from_batched.to_vec::<f32>().unwrap(),
            from_single.to_vec::<f32>().unwrap(),
            "physical buffer parity at step {step}"
        );
        assert_eq!(batched.offset(), single.offset(), "offset parity step {step}");
    }
}
#[test]
fn batch_rotating_make_mask_distinct_override() {
    let mut cache = BatchRotatingKvCache::new(4, &[0, 0]);

    // Empty cache, N=3: each batch row gets a plain 3x3 lower-triangular causal mask.
    {
        let MaskMode::Array(mut causal) = cache.make_mask(3, None, false).unwrap() else {
            panic!("BatchRotatingKVCache.make_mask must materialize its own mask");
        };
        assert_eq!(causal.shape(), vec![2, 1, 3, 3]);
        let bits: Vec<u8> = causal
            .to_vec::<bool>()
            .unwrap()
            .into_iter()
            .map(u8::from)
            .collect();
        let lower_triangular: [u8; 9] = [1, 0, 0, 1, 1, 0, 1, 1, 1];
        let expected: Vec<u8> = lower_triangular
            .iter()
            .chain(lower_triangular.iter())
            .copied()
            .collect();
        assert_eq!(bits, expected);
    }

    // Fill the 4-slot ring, then wrap it with a fifth token.
    let prompt = kvb(&[&[0.0, 1.0, 2.0, 3.0], &[0.0, 1.0, 2.0, 3.0]]);
    cache.update(&prompt, &prompt).unwrap();
    let token = kvb(&[&[4.0], &[4.0]]);
    cache.update(&token, &token).unwrap();

    // After rotation a single-query mask (with `Some(4)`) is still an explicit array.
    let MaskMode::Array(rolled) = cache.make_mask(1, Some(4), false).unwrap() else {
        panic!("rotated N==1 must still return a (rolled) mask array");
    };
    assert_eq!(rolled.shape().last().copied(), Some(4));
}
#[test]
fn batch_rotating_state_meta_roundtrip_and_from_state() {
    // Prefill 3 then one token: the 4-slot ring is exactly full but not yet rotated.
    let mut cache = BatchRotatingKvCache::new(4, &[0, 0]);
    let prompt = kvb(&[&[0.0, 1.0, 2.0], &[10.0, 11.0, 12.0]]);
    cache.update(&prompt, &prompt).unwrap();
    let token = kvb(&[&[3.0], &[13.0]]);
    cache.update(&token, &token).unwrap();

    // meta_state layout: [max_size, offset, idx, rotated].
    let meta = cache.meta_state();
    assert_eq!(meta, vec!["4", "4", "4", "false"]);

    let saved = cache.state().unwrap();
    assert_eq!(saved.len(), 4);
    let restored = from_state("BatchRotatingKVCache", saved, &meta).unwrap();
    assert_eq!(restored.offset(), 4);
    assert_eq!(restored.max_size(), Some(4));

    let positioned = restored.as_batch_positioned();
    assert!(positioned.is_some());
    assert_eq!(iv(&positioned.unwrap().batch_offset().unwrap()), vec![4, 4]);
}
#[test]
fn batch_rotating_trim_semantics() {
    // Below max_size the ring can be trimmed, which moves every offset back.
    let mut roomy = BatchRotatingKvCache::new(8, &[0, 0]);
    let prompt = kvb(&[&[0.0, 1.0, 2.0], &[10.0, 11.0, 12.0]]);
    roomy.update(&prompt, &prompt).unwrap();
    assert!(roomy.is_trimmable());
    assert_eq!(roomy.trim(2).unwrap(), 2);
    assert_eq!(iv(&roomy.batch_offset().unwrap()), vec![1, 1]);

    // A prompt longer than max_size leaves a cache that refuses trimming.
    let mut tiny = BatchRotatingKvCache::new(2, &[0]);
    let oversized =
        Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0], &(1usize, 1, 4, 1)).unwrap();
    tiny.update(&oversized, &oversized).unwrap();
    assert!(!tiny.is_trimmable());
}
#[test]
fn batch_rotating_wrong_rank_errors_not_panic() {
    let rank2 = Array::from_slice::<f32>(&[1.0, 2.0], &(1usize, 2usize)).unwrap();
    let three_tokens =
        Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1usize, 1, 3, 1)).unwrap();

    // Multi-token update on a fresh ring: malformed values must Err.
    let mut fresh = BatchRotatingKvCache::new(4, &[0]);
    assert!(fresh.update(&three_tokens, &rank2).is_err());

    // Single-token update on a primed ring: malformed values or keys must Err.
    let mut primed = BatchRotatingKvCache::new(4, &[0]);
    primed.update(&three_tokens, &three_tokens).unwrap();
    let one_token = Array::from_slice::<f32>(&[4.0], &(1usize, 1, 1, 1)).unwrap();
    assert!(primed.update(&one_token, &rank2).is_err());
    assert!(primed.update(&rank2, &one_token).is_err());
}
#[test]
fn batch_from_state_empty_with_nonzero_offset_is_rejected() {
    // BatchRotatingKVCache meta layout: [max_size, offset, idx, rotated].
    let meta = |fields: [&str; 4]| -> Vec<String> {
        fields.iter().map(|field| field.to_string()).collect()
    };

    // BatchKVCache accepts both a real saved state and an empty one.
    let mut cache = BatchKvCache::new(&[0, 0]);
    let prompt = kvb(&[&[1.0, 2.0], &[3.0, 4.0]]);
    cache.update(&prompt, &prompt).unwrap();
    let saved = cache.state().unwrap();
    assert!(from_state("BatchKVCache", saved, &[]).is_ok());
    assert!(from_state("BatchKVCache", Vec::new(), &[]).is_ok());

    // An empty rotating buffer is only consistent with all-zero positional metadata.
    let nonzero_offset = meta(["4", "3", "3", "false"]);
    assert!(from_state("BatchRotatingKVCache", Vec::new(), &nonzero_offset).is_err());

    let rotated = meta(["4", "0", "0", "true"]);
    assert!(
        from_state("BatchRotatingKVCache", Vec::new(), &rotated).is_err(),
        "empty state with rotated=true must be rejected"
    );

    let nonzero_idx = meta(["4", "0", "2", "false"]);
    assert!(
        from_state("BatchRotatingKVCache", Vec::new(), &nonzero_idx).is_err(),
        "empty state with non-zero _idx must be rejected"
    );

    let all_zero = meta(["4", "0", "0", "false"]);
    assert!(from_state("BatchRotatingKVCache", Vec::new(), &all_zero).is_ok());
}
/// Restores a ring whose scalar offset is `usize::MAX`, so any further append overflows.
///
/// For both a 1-token and a 2-token update, `update` must return `Err` and must leave
/// everything observable unchanged: offset, meta, keys, values and state arity.
#[test]
fn batch_rotating_offset_overflow_is_rejected_without_partial_mutation() {
    let max = usize::MAX.to_string();
    for &n in &[1usize, 2usize] {
        // Seed a genuine 2-token state, then restore it with a saturated `_offset`.
        let mut seed = BatchRotatingKvCache::new(8, &[0, 0]);
        let p = kvb(&[&[1.0, 2.0], &[3.0, 4.0]]);
        seed.update(&p, &p).unwrap();
        let st = seed.state().unwrap();
        let meta = vec![
            "8".to_string(),
            max.clone(),
            "2".to_string(),
            "false".to_string(),
        ];
        let mut c = from_state("BatchRotatingKVCache", st, &meta).unwrap();

        // Snapshot everything the rejected update must leave untouched.
        let off_before = c.offset();
        let meta_before = c.meta_state();
        let st_before = c.state().unwrap();
        let keys_before = st_before[0].try_clone().unwrap().to_vec::<f32>().unwrap();
        let values_before = st_before[1].try_clone().unwrap().to_vec::<f32>().unwrap();

        // n = 1 and n = 2 cover the single-token and multi-token update paths.
        let row: Vec<f32> = (0..n).map(|i| 100.0 + i as f32).collect();
        let upd = kvb(&[&row, &row]);
        assert!(
            c.update(&upd, &upd).is_err(),
            "overflow must be a recoverable Err (n={n})"
        );

        assert_eq!(c.offset(), off_before, "offset unchanged on Err (n={n})");
        assert_eq!(c.meta_state(), meta_before, "meta unchanged on Err (n={n})");
        let st_after = c.state().unwrap();
        // Check arity first so the buffer comparisons below cannot index out of range.
        assert_eq!(
            st_after.len(),
            st_before.len(),
            "state arity unchanged (n={n})"
        );
        assert_eq!(
            st_after[0].try_clone().unwrap().to_vec::<f32>().unwrap(),
            keys_before,
            "keys buffer unchanged on Err (n={n})"
        );
        assert_eq!(
            st_after[1].try_clone().unwrap().to_vec::<f32>().unwrap(),
            values_before,
            "values buffer unchanged on Err (n={n})"
        );
    }
}
#[test]
fn batch_from_state_rank_invalid_values_is_err_not_panic() {
    let good_k = kvb(&[&[1.0, 2.0], &[3.0, 4.0]]);
    let off = Array::from_slice::<i32>(&[2, 2], &(2usize,)).unwrap();
    let lp = Array::from_slice::<i32>(&[0, 0], &(2usize,)).unwrap();
    // A rank-2 array standing where a rank-4 keys/values tensor belongs.
    let bad_v = Array::from_slice::<f32>(&[1.0, 2.0], &(2usize, 1usize)).unwrap();
    // Assembles a fresh 4-entry state (keys, values, offset, left padding) from clones.
    let assemble = |keys: &Array, values: &Array| -> Vec<Array> {
        [keys, values, &off, &lp]
            .iter()
            .map(|part| part.try_clone().unwrap())
            .collect()
    };

    assert!(
        from_state("BatchKVCache", assemble(&good_k, &bad_v), &[]).is_err(),
        "BatchKVCache from_state with rank-invalid values must Err, not panic"
    );

    let rotating_meta: Vec<String> = ["8", "2", "2", "false"]
        .iter()
        .map(|field| field.to_string())
        .collect();
    assert!(
        from_state("BatchRotatingKVCache", assemble(&good_k, &bad_v), &rotating_meta).is_err(),
        "BatchRotatingKVCache from_state with rank-invalid values must Err, not panic"
    );

    // Roles swapped: rank-invalid keys with valid values.
    let bad_k = Array::from_slice::<f32>(&[1.0, 2.0], &(2usize, 1usize)).unwrap();
    let good_v = kvb(&[&[1.0, 2.0], &[3.0, 4.0]]);
    assert!(from_state("BatchKVCache", assemble(&bad_k, &good_v), &[]).is_err());
}
/// An inconsistent state (keys with B=2, values with B=3) must never corrupt a cache.
///
/// Part 1 is optional: `from_state` may reject the state up front. If it accepts, a
/// `make_mask` call must not disturb the restored offset or keys. Parts 2 and 3 check
/// that `finalize` on both batch cache types returns a recoverable `Err` and leaves
/// offsets and left padding untouched.
#[test]
fn batch_finalize_batch_mismatch_err_leaves_state_unchanged() {
    let k_b2 = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2usize, 1, 2, 1)).unwrap();
    let v_b3 =
        Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &(3usize, 1, 2, 1)).unwrap();
    let off2 = Array::from_slice::<i32>(&[2, 2], &(2usize,)).unwrap();
    let lp2 = Array::from_slice::<i32>(&[0, 0], &(2usize,)).unwrap();
    // Fresh copies of the mismatched state; each restore consumes its Vec.
    let mismatched_state = || {
        vec![
            k_b2.try_clone().unwrap(),
            v_b3.try_clone().unwrap(),
            off2.try_clone().unwrap(),
            lp2.try_clone().unwrap(),
        ]
    };

    // Part 1. An Err here is acceptable, but it must NOT end the test early: the
    // finalize checks below are the point of this test and have to run regardless.
    if let Ok(c) = from_state("BatchKVCache", mismatched_state(), &[]) {
        let off_before = c.offset();
        let st_before = c.state().unwrap();
        let kb = st_before[0].try_clone().unwrap().to_vec::<f32>().unwrap();
        let _ = c.make_mask(1, None, false);
        assert_eq!(c.offset(), off_before);
        let st_after = c.state().unwrap();
        assert_eq!(st_after[0].try_clone().unwrap().to_vec::<f32>().unwrap(), kb);
    }

    // Part 2: BatchKvCache::finalize must Err without touching offsets or left padding.
    let mut tc = BatchKvCache::new(&[0, 0]);
    tc.set_state(mismatched_state()).unwrap();
    tc.prepare_right_padding(&[1, 1]).unwrap();
    let bo = iv(&tc.batch_offset().unwrap());
    let lpb = iv(&tc.left_padding_arr().unwrap());
    assert!(
        tc.finalize().is_err(),
        "B-mismatched finalize must be a recoverable Err"
    );
    assert_eq!(
        iv(&tc.batch_offset().unwrap()),
        bo,
        "offset unchanged on Err"
    );
    assert_eq!(
        iv(&tc.left_padding_arr().unwrap()),
        lpb,
        "left_padding unchanged on Err"
    );

    // Part 3: the same contract for BatchRotatingKvCache.
    let mut rc = BatchRotatingKvCache::new(8, &[0, 0]);
    rc.set_state(mismatched_state()).unwrap();
    rc.prepare_right_padding(&[2, 2], &[1, 1]).unwrap();
    let rbo = iv(&rc.batch_offset().unwrap());
    let rlpb = iv(&rc.left_padding_arr().unwrap());
    assert!(
        rc.finalize().is_err(),
        "B-mismatched batch-rotating finalize must be a recoverable Err"
    );
    assert_eq!(iv(&rc.batch_offset().unwrap()), rbo);
    assert_eq!(iv(&rc.left_padding_arr().unwrap()), rlpb);
}
#[test]
fn batch_update_kv_shape_mismatch_is_err_not_desync() {
    // All-zero B=2 tensor with `s` positions.
    let zeros_b2 = |s: usize| {
        let row = vec![0.0f32; s];
        kvb(&[row.as_slice(), row.as_slice()])
    };
    // All-zero B=3 tensor with `s` positions; never matches a B=2 cache.
    let zeros_b3 = |s: usize| {
        let flat = vec![0.0f32; 3 * s];
        Array::from_slice::<f32>(&flat, &(3usize, 1, s, 1)).unwrap()
    };
    let keys_d1 = kvb(&[&[1.0, 2.0], &[3.0, 4.0]]);
    let values_d4 = Array::from_slice::<f32>(&[9.0f32; 16], &(2usize, 1, 2, 4)).unwrap();

    // BatchKvCache, empty: the mismatch is rejected before anything is stored.
    let mut empty_cache = BatchKvCache::new(&[0, 0]);
    assert!(
        empty_cache.update(&zeros_b2(2), &zeros_b3(2)).is_err(),
        "empty BatchKvCache update with B-mismatched values must Err"
    );
    assert!(empty_cache.is_empty(), "no partial mutation on the rejected update");
    // Only B/H/S need to agree; keys and values may differ in head_dim.
    assert!(
        empty_cache.update(&keys_d1, &values_d4).is_ok(),
        "B/H/S match with differing head_dim must be accepted"
    );

    // BatchKvCache, already populated.
    let mut primed_cache = BatchKvCache::new(&[0, 0]);
    primed_cache.update(&zeros_b2(2), &zeros_b2(2)).unwrap();
    assert!(primed_cache.update(&zeros_b2(1), &zeros_b3(1)).is_err());

    // BatchRotatingKvCache, multi-token update on an empty ring.
    let mut empty_ring = BatchRotatingKvCache::new(8, &[0, 0]);
    assert!(
        empty_ring.update(&zeros_b2(3), &zeros_b3(3)).is_err(),
        "batch-rotating S>1 with B-mismatched values must Err"
    );
    assert!(empty_ring.is_empty());

    // BatchRotatingKvCache, single-token update on a primed ring.
    let mut primed_ring = BatchRotatingKvCache::new(8, &[0, 0]);
    primed_ring.update(&zeros_b2(3), &zeros_b2(3)).unwrap();
    assert!(
        primed_ring.update(&zeros_b2(1), &zeros_b3(1)).is_err(),
        "batch-rotating S==1 with B-mismatched values must Err"
    );
}
/// Each restore below pairs a genuine 3-token, B=2 state with metadata that breaks one
/// invariant; `from_state` must return `Err` (not panic) for each. The last case is a
/// valid full-ring round-trip that must still accept a further update.
///
/// Meta layout used throughout: `[max_size, _offset, _idx, rotated]`.
#[test]
fn batch_rotating_idx_overflow_is_rejected_not_panic() {
let p = kvb(&[&[0.0, 1.0, 2.0], &[10.0, 11.0, 12.0]]);
let mut s = BatchRotatingKvCache::new(8, &[0, 0]);
s.update(&p, &p).unwrap();
let good_state = s.state().unwrap();
// Case 1: _idx = usize::MAX.
let meta_idx_max = vec![
"8".to_string(),
"3".to_string(),
usize::MAX.to_string(),
"false".to_string(),
];
assert!(
from_state("BatchRotatingKVCache", good_state, &meta_idx_max).is_err(),
"non-empty restore with _idx=usize::MAX must be rejected at from_state"
);
// Case 2: _idx = 5 exceeds the 3 stored positions.
let mut s2 = BatchRotatingKvCache::new(8, &[0, 0]);
s2.update(&p, &p).unwrap();
let st2 = s2.state().unwrap();
let meta_oob = vec![
"8".to_string(),
"3".to_string(),
"5".to_string(), "false".to_string(),
];
assert!(
from_state("BatchRotatingKVCache", st2, &meta_oob).is_err(),
"non-empty restore with _idx > keys.shape[-2] must be rejected at from_state"
);
// Case 3: rotated = true although the buffer (3) is shorter than max_size (8).
let mut s3 = BatchRotatingKvCache::new(8, &[0, 0]);
s3.update(&p, &p).unwrap();
let st3 = s3.state().unwrap(); let meta_rot = vec![
"8".to_string(),
"3".to_string(),
"3".to_string(),
"true".to_string(),
];
assert!(
from_state("BatchRotatingKVCache", st3, &meta_rot).is_err(),
"non-empty restore with rotated=true but buffer!=max_size must be rejected"
);
// Case 4: _offset = 1 is smaller than the 3 positions actually stored.
let mut s5 = BatchRotatingKvCache::new(8, &[0, 0]);
s5.update(&p, &p).unwrap(); let st5 = s5.state().unwrap();
let meta_loff = vec![
"8".to_string(),
"1".to_string(), "0".to_string(),
"false".to_string(),
];
assert!(
from_state("BatchRotatingKVCache", st5, &meta_loff).is_err(),
"non-empty restore with keys length > _offset must be rejected at from_state"
);
// Control: a genuine, exactly-full max_size=4 ring restores from its own meta and
// keeps accepting tokens.
let mut s4 = BatchRotatingKvCache::new(4, &[0, 0]);
s4.update(&p, &p).unwrap(); let d = kvb(&[&[3.0], &[13.0]]);
s4.update(&d, &d).unwrap(); let meta4 = s4.meta_state();
let st4 = s4.state().unwrap();
let mut ok = from_state("BatchRotatingKVCache", st4, &meta4).unwrap();
assert_eq!(ok.offset(), 4);
let d2 = kvb(&[&[4.0], &[14.0]]);
assert!(ok.update(&d2, &d2).is_ok());
}
#[test]
fn batch_kv_set_state_empty_resets_offset_and_padding() {
    let left_padding = [1i32, 3, 0];
    let baseline = vec![-1, -3, 0];
    let mut cache = BatchKvCache::new(&left_padding);

    // Two single-token steps move every per-sequence offset away from -left_padding.
    let first = kvb(&[&[1.0], &[2.0], &[3.0]]);
    cache.update(&first, &first).unwrap();
    let second = kvb(&[&[4.0], &[5.0], &[6.0]]);
    cache.update(&second, &second).unwrap();
    let advanced = cache.batch_offset().unwrap().to_vec::<i32>().unwrap();
    assert_ne!(advanced, baseline, "sanity: updates advanced offset");

    // Restoring an empty state clears the buffers and rewinds offsets to -left_padding.
    cache.set_state(Vec::new()).unwrap();
    assert!(cache.is_empty(), "keys/values cleared");
    assert_eq!(cache.offset(), 0, "_idx cleared");
    let rewound = cache.batch_offset().unwrap().to_vec::<i32>().unwrap();
    assert_eq!(rewound, baseline, "offset reset to -left_padding");
}
#[test]
fn batch_rotating_set_state_empty_resets_all_metadata() {
    let mut cache = BatchRotatingKvCache::new(4, &[0i32, 0]);
    // Five single-token steps fill the 4-slot ring and wrap it.
    for token_idx in 0..5 {
        let step = kvb(&[&[token_idx as f32], &[(10 + token_idx) as f32]]);
        cache.update(&step, &step).unwrap();
    }
    assert!(cache.offset() >= 5, "sanity: 5 updates advanced offset");

    // An empty restore must clear buffers, the scalar offset and the per-sequence offsets.
    cache.set_state(Vec::new()).unwrap();
    assert!(cache.is_empty(), "(a) keys/values cleared");
    assert_eq!(cache.offset(), 0, "(b) scalar _off cleared");
    let offsets = cache.batch_offset().unwrap().to_vec::<i32>().unwrap();
    let padding = cache.left_padding_arr().unwrap().to_vec::<i32>().unwrap();
    let negated_padding: Vec<i32> = padding.into_iter().map(|pad| -pad).collect();
    assert_eq!(
        offsets, negated_padding,
        "(c) offset reset to -self.left_padding at reset time"
    );

    // The cache stays usable after the reset.
    let next = kvb(&[&[99.0], &[199.0]]);
    assert!(cache.update(&next, &next).is_ok(), "(d) update after reset works");
}
#[test]
fn dynamic_roll_rejects_n_above_f32_exact_int_max() {
    // 2^24 is the largest range of integers exactly representable in f32.
    const LIMIT: usize = 1usize << 24;
    let zero_shift = Array::from_slice::<i32>(&[0], &(1usize, 1)).unwrap();

    // One past the limit along the roll axis is rejected with OutOfRange.
    let oversized = Array::zeros::<f32>(&(1usize, 1, LIMIT + 1, 1)).unwrap();
    let over = dynamic_roll(&oversized, &zero_shift, 2);
    assert!(
        matches!(&over, Err(mlxrs::Error::OutOfRange(_))),
        "dynamic_roll(n=2^24+1) must Err(OutOfRange), got {over:?}"
    );

    // A small axis is fine.
    let tiny = Array::zeros::<f32>(&(1usize, 1, 3, 1)).unwrap();
    let under = dynamic_roll(&tiny, &zero_shift, 2);
    assert!(
        under.is_ok(),
        "dynamic_roll on small input must succeed, got {under:?}"
    );
}
#[test]
fn dynamic_roll_rejects_rank_mismatch_shifts() {
    let input = kvb(&[&[10.0, 20.0, 30.0], &[40.0, 50.0, 60.0]]);

    // shifts must be rank 2: rank 1 and rank 3 are RankMismatch with the actual shape.
    let rank1 = Array::from_slice::<i32>(&[0, 1], &(2usize,)).unwrap();
    match dynamic_roll(&input, &rank1, 2).expect_err("rank-1 shifts must Err") {
        mlxrs::Error::RankMismatch(payload) => {
            assert_eq!(payload.actual(), 1, "rank-1 shifts: payload.actual must be 1");
            assert_eq!(
                payload.actual_shape(),
                &[2usize][..],
                "rank-1 shifts: payload.actual_shape must be [B]"
            );
        }
        other => panic!("expected RankMismatch, got {other:?}"),
    }

    let rank3 = Array::from_slice::<i32>(&[0, 1], &(2usize, 1usize, 1usize)).unwrap();
    match dynamic_roll(&input, &rank3, 2).expect_err("rank-3 shifts must Err") {
        mlxrs::Error::RankMismatch(payload) => {
            assert_eq!(payload.actual(), 3, "rank-3 shifts: payload.actual must be 3");
            assert_eq!(
                payload.actual_shape(),
                &[2usize, 1, 1][..],
                "rank-3 shifts: payload.actual_shape must be [B, 1, 1]"
            );
        }
        other => panic!("expected RankMismatch, got {other:?}"),
    }

    // Rank 2 but not [B, 1]: ShapePairMismatch reporting expected vs actual.
    let wrong_batch = Array::from_slice::<i32>(&[0, 0, 0], &(3usize, 1usize)).unwrap();
    match dynamic_roll(&input, &wrong_batch, 2).expect_err("[3,1] shifts on B=2 must Err") {
        mlxrs::Error::ShapePairMismatch(payload) => {
            assert_eq!(payload.expected(), &[2usize, 1][..]);
            assert_eq!(payload.actual(), &[3usize, 1][..]);
        }
        other => panic!("expected ShapePairMismatch, got {other:?}"),
    }

    let wrong_width = Array::from_slice::<i32>(&[0, 0, 1, 1], &(2usize, 2usize)).unwrap();
    match dynamic_roll(&input, &wrong_width, 2).expect_err("[B, 2] shifts must Err") {
        mlxrs::Error::ShapePairMismatch(payload) => {
            assert_eq!(payload.expected(), &[2usize, 1][..]);
            assert_eq!(payload.actual(), &[2usize, 2][..]);
        }
        other => panic!("expected ShapePairMismatch, got {other:?}"),
    }
}
#[test]
fn dynamic_roll_accepts_scalar_broadcast_shifts() {
    let input = kvb(&[&[10.0, 20.0, 30.0], &[40.0, 50.0, 60.0]]);

    // A [1, 1] shifts array applies the same shift to every row.
    let shift_one = Array::from_slice::<i32>(&[1], &(1usize, 1usize)).unwrap();
    let outcome = dynamic_roll(&input, &shift_one, 2);
    assert!(
        outcome.is_ok(),
        "scalar broadcast [1,1] shifts must be accepted; got {outcome:?}"
    );
    let mut rolled = outcome.unwrap();
    assert_eq!(
        rolled.shape(),
        vec![2, 1, 3, 1],
        "broadcast result shape must match input shape"
    );
    let rolled_expected: Vec<f32> = vec![30.0, 10.0, 20.0, 60.0, 40.0, 50.0];
    assert_eq!(
        rolled.to_vec::<f32>().unwrap(),
        rolled_expected,
        "scalar broadcast shift=1 must roll every row by 1"
    );

    // A broadcast shift of 0 is the identity.
    let shift_zero = Array::from_slice::<i32>(&[0], &(1usize, 1usize)).unwrap();
    let mut unchanged =
        dynamic_roll(&input, &shift_zero, 2).expect("scalar broadcast [1,1] shift=0");
    let unchanged_expected: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0];
    assert_eq!(
        unchanged.to_vec::<f32>().unwrap(),
        unchanged_expected,
        "scalar broadcast shift=0 must be identity"
    );
}
/// Each rejected `from_state` must say which invariant failed. Meta layout is
/// `[max_size, _offset, _idx, rotated]`. The cases are:
/// 1. empty buffer with offset 5;
/// 2. `_idx` (7) greater than the stored length L (3);
/// 3. `rotated = true` with L != max_size;
/// 4. L (5) greater than `_offset` (3).
#[test]
fn batch_rotating_from_state_error_message_names_violated_invariant() {
// Zero-filled [2, 1, seq_len, 1] keys/values.
let kv =
|seq_len: usize| -> Array { Array::zeros::<f32>(&(2usize, 1usize, seq_len, 1usize)).unwrap() };
// Case 1: empty state with a non-zero offset must be OutOfRange naming the empty arm
// (context) and the offending offset (value).
let bad_empty = vec![
"8".to_string(), "5".to_string(), "0".to_string(), "false".to_string(),
];
let r = from_state("BatchRotatingKVCache", vec![], &bad_empty);
match r {
Err(mlxrs::Error::OutOfRange(p)) => {
assert!(
p.context().contains("empty buffer"),
"context must name the empty-arm condition; got: {}",
p.context()
);
assert!(
p.value().contains("offset=5"),
"value must name the offending offset; got: {}",
p.value()
);
}
Err(other) => panic!("expected OutOfRange empty-arm, got {other:?}"),
Ok(_) => panic!("from_state empty + offset=5 must Err"),
}
// Case 2: _idx = 7 with only L = 3 stored positions. The per-sequence offset and
// left-padding arrays (zeros) are reused unchanged across cases 2-4.
let k = kv(3);
let v = kv(3);
let off = Array::from_slice::<i32>(&[0, 0], &(2usize,)).unwrap();
let lp = Array::from_slice::<i32>(&[0, 0], &(2usize,)).unwrap();
let st = vec![
k.try_clone().unwrap(),
v.try_clone().unwrap(),
off.try_clone().unwrap(),
lp.try_clone().unwrap(),
];
let bad_idx = vec![
"8".to_string(),
"5".to_string(),
"7".to_string(),
"false".to_string(),
];
let r = from_state("BatchRotatingKVCache", st, &bad_idx);
let err_msg = match r {
Err(e) => format!("{e}"),
Ok(_) => panic!("from_state with _idx > L must Err"),
};
// NOTE(review): "7"/"3" are plain substring checks on the whole message, so they can be
// satisfied by unrelated digits elsewhere in the text.
assert!(
err_msg.contains("_idx") && err_msg.contains("7") && err_msg.contains("3"),
"error must name _idx and the offending values (idx=7, L=3); got: {err_msg}"
);
// Case 3: rotated = true while L = 3 != max_size = 8.
let k = kv(3);
let v = kv(3);
let st = vec![
k.try_clone().unwrap(),
v.try_clone().unwrap(),
off.try_clone().unwrap(),
lp.try_clone().unwrap(),
];
let bad_rot = vec![
"8".to_string(),
"3".to_string(),
"3".to_string(),
"true".to_string(),
];
let r = from_state("BatchRotatingKVCache", st, &bad_rot);
let err_msg = match r {
Err(e) => format!("{e}"),
Ok(_) => panic!("from_state with rotated=true && L != max_size must Err"),
};
assert!(
err_msg.contains("rotated") && err_msg.contains("max_size"),
"error must name `rotated` and `max_size`; got: {err_msg}"
);
// Case 4: L = 5 stored positions but _offset = 3.
let k = kv(5);
let v = kv(5);
let st = vec![
k.try_clone().unwrap(),
v.try_clone().unwrap(),
off.try_clone().unwrap(),
lp.try_clone().unwrap(),
];
let bad_off = vec![
"8".to_string(),
"3".to_string(),
"3".to_string(),
"false".to_string(),
];
let r = from_state("BatchRotatingKVCache", st, &bad_off);
let err_msg = match r {
Err(e) => format!("{e}"),
Ok(_) => panic!("from_state with L > _offset must Err"),
};
// NOTE(review): contains("L") / contains("5") / contains("3") are weak substring
// checks; any capital L or stray digit in the message satisfies them.
assert!(
err_msg.contains("L")
&& err_msg.contains("5")
&& err_msg.contains("_offset")
&& err_msg.contains("3"),
"error must name L, _offset, and the offending values (L=5, _offset=3); got: {err_msg}"
);
}
#[test]
fn dynamic_roll_n_zero_is_noop_clone() {
    // A zero-length roll axis: any shift must succeed and keep the shape.
    let zero_len = Array::zeros::<f32>(&(1usize, 1, 0, 1)).unwrap();
    let shift_three = Array::from_slice::<i32>(&[3], &(1usize, 1)).unwrap();
    let result = dynamic_roll(&zero_len, &shift_three, 2);
    assert!(result.is_ok(), "dynamic_roll(n=0) must Ok-clone, got {result:?}");
    assert_eq!(result.unwrap().shape(), &[1, 1, 0, 1]);
}
#[test]
fn batch_rotating_update_concat_clears_rotated_after_mixed_path() {
    let mut cache = BatchRotatingKvCache::new(4, &[0i32, 0]);

    // Prefill 3, then two single-token steps: the 4-slot ring fills and wraps once.
    let prefill = kvb(&[&[0.0, 1.0, 2.0], &[10.0, 11.0, 12.0]]);
    cache.update(&prefill, &prefill).unwrap();
    for tok in 3..5 {
        let step = kvb(&[&[tok as f32], &[(10 + tok) as f32]]);
        cache.update(&step, &step).unwrap();
    }

    // A 2-token update on the rotated ring takes the mixed path.
    let pair = kvb(&[&[5.0, 6.0], &[15.0, 16.0]]);
    cache.update(&pair, &pair).unwrap();

    // The metadata left behind must be coherent enough to restore from.
    let restored = from_state("BatchRotatingKVCache", cache.state().unwrap(), &cache.meta_state())
        .unwrap_or_else(|e| {
            panic!("save/load round-trip after mixed-path update must succeed, got Err({e})")
        });
    assert_eq!(restored.offset(), cache.offset());
}
#[test]
fn batch_kv_pad_lengths_constructor_is_borrowed_slice() {
    // pad_lengths mirrors the constructor's left padding verbatim.
    let left_padding = [3i32, 1, 0];
    let cache = BatchKvCache::new(&left_padding);
    assert_eq!(cache.pad_lengths(), &left_padding);

    // An empty batch yields an empty mirror.
    let empty = BatchKvCache::new(&[]);
    assert!(empty.pad_lengths().is_empty());
}
#[test]
fn batch_kv_pad_lengths_updates_after_set_state() {
    let mut cache = BatchKvCache::new(&[0i32, 0]);

    // Restore a state whose left padding differs from the constructor's.
    let restored_lp_vals = [2i32, 4];
    let restored_state = vec![
        kvb(&[&[1.0], &[2.0]]),
        kvb(&[&[3.0], &[4.0]]),
        Array::from_slice::<i32>(&[-2, -4], &(2usize,)).unwrap(),
        Array::from_slice::<i32>(&restored_lp_vals, &(2usize,)).unwrap(),
    ];
    cache.set_state(restored_state).unwrap();

    assert_eq!(
        cache.pad_lengths(),
        &restored_lp_vals,
        "set_state must materialize the host mirror once at restore time"
    );
}
#[test]
fn batch_kv_pad_lengths_updates_after_finalize() {
    let mut cache = BatchKvCache::new(&[1i32, 3]);
    assert_eq!(cache.pad_lengths(), &[1, 3]);

    let step = kvb(&[&[10.0], &[20.0]]);
    cache.update(&step, &step).unwrap();

    // finalize folds right padding into left padding: [1 + 2, 3 + 0].
    cache.prepare_right_padding(&[2, 0]).unwrap();
    cache.finalize().unwrap();
    assert_eq!(
        cache.pad_lengths(),
        &[3, 3],
        "left_padding += right_padding mirrored in host pad_lengths"
    );
}
#[test]
fn batch_rotating_pad_lengths_constructor_and_set_state() {
    // Constructor: the mirror is the given left padding.
    let ctor_lp = [2i32, 0];
    let fresh = BatchRotatingKvCache::new(4, &ctor_lp);
    assert_eq!(fresh.pad_lengths(), &ctor_lp);

    // set_state: the mirror is replaced by the restored left padding.
    let mut restored = BatchRotatingKvCache::new(4, &[0, 0]);
    let restored_lp = [1i32, 5];
    restored
        .set_state(vec![
            kvb(&[&[1.0], &[2.0]]),
            kvb(&[&[3.0], &[4.0]]),
            Array::from_slice::<i32>(&[-1, -5], &(2usize,)).unwrap(),
            Array::from_slice::<i32>(&restored_lp, &(2usize,)).unwrap(),
        ])
        .unwrap();
    assert_eq!(restored.pad_lengths(), &restored_lp);
}
#[test]
fn batch_rotating_rotated_flag_observable_through_meta_state() {
    // A 3-slot ring exactly filled by a 3-token prompt has not wrapped yet.
    let mut cache = BatchRotatingKvCache::new(3, &[0i32]);
    let prompt = kvb(&[&[0.0, 1.0, 2.0]]);
    cache.update(&prompt, &prompt).unwrap();
    // meta_state()[3] is the `rotated` flag.
    assert_eq!(
        cache.meta_state()[3],
        "false",
        "pre-wrap meta_state.rotated == false"
    );

    // One more token wraps the ring and flips the flag.
    let token = kvb(&[&[3.0]]);
    cache.update(&token, &token).unwrap();
    let wrapped_meta = cache.meta_state();
    assert_eq!(
        wrapped_meta[3], "true",
        "post-wrap meta_state.rotated == true"
    );
    assert_eq!(cache.offset(), 4);

    // A rotated ring restores cleanly from its own state and metadata.
    let restored = from_state("BatchRotatingKVCache", cache.state().unwrap(), &wrapped_meta)
        .unwrap_or_else(|e| {
            panic!("post-rotation round-trip must succeed (proves rotated coherence), got Err({e})")
        });
    assert_eq!(restored.offset(), cache.offset());
}
#[test]
fn batch_kv_finalize_with_scalar_right_padding_broadcasts_or_errs() {
    // A length-1 right padding broadcasts over every sequence: [1 + 5, 3 + 5].
    let mut broadcast = BatchKvCache::new(&[1i32, 3]);
    assert_eq!(broadcast.pad_lengths(), &[1, 3]);
    let step = kvb(&[&[10.0], &[20.0]]);
    broadcast.update(&step, &step).unwrap();
    broadcast.prepare_right_padding(&[5]).unwrap();
    broadcast.finalize().unwrap();
    assert_eq!(
        broadcast.pad_lengths(),
        &[6, 8],
        "length-1 right_padding MUST broadcast across pad_lengths (was: silently stale [1, 3])"
    );

    // A right padding whose length is neither 1 nor the batch size is rejected atomically.
    let mut mismatched = BatchKvCache::new(&[1i32, 3]);
    let step2 = kvb(&[&[10.0], &[20.0]]);
    mismatched.update(&step2, &step2).unwrap();
    let left_padding_snapshot = iv(&mismatched.left_padding_arr().unwrap());
    let pad_lengths_snapshot: Vec<i32> = mismatched.pad_lengths().to_vec();
    mismatched.prepare_right_padding(&[1, 1, 1]).unwrap();
    assert!(
        mismatched.finalize().is_err(),
        "right_padding length 3 vs pad_lengths length 2 MUST Err"
    );
    assert_eq!(
        mismatched.pad_lengths(),
        pad_lengths_snapshot.as_slice(),
        "pad_lengths unchanged on Err"
    );
    assert_eq!(
        iv(&mismatched.left_padding_arr().unwrap()),
        left_padding_snapshot,
        "left_padding unchanged on Err"
    );
}
/// Each malformed `set_state` below must return `Err` and leave both the host mirror
/// (`pad_lengths`) and the device-side `left_padding_arr` exactly as they were. The
/// malformed left-padding entries are: an f32 array, a 2-D array, and a length-3 array
/// for a B=2 state. A final well-formed restore must then take effect.
#[test]
fn batch_kv_set_state_propagates_to_vec_failure() {
let mut c = BatchKvCache::new(&[1i32, 2]);
// Snapshot of the pre-restore padding, compared after every rejected call.
let pl_before: Vec<i32> = c.pad_lengths().to_vec();
let lp_before = iv(&c.left_padding_arr().unwrap());
// Case 1: left_padding with f32 dtype instead of i32.
let lp_f32 = Array::from_slice::<f32>(&[1.0, 2.0], &(2usize,)).unwrap();
let off = Array::from_slice::<i32>(&[-1, -2], &(2usize,)).unwrap();
let k = kvb(&[&[1.0], &[2.0]]);
let v = kvb(&[&[3.0], &[4.0]]);
assert!(
c.set_state(vec![k, v, off, lp_f32]).is_err(),
"non-I32 left_padding MUST Err"
);
assert_eq!(
c.pad_lengths(),
pl_before.as_slice(),
"pad_lengths unchanged on Err"
);
assert_eq!(
iv(&c.left_padding_arr().unwrap()),
lp_before,
"left_padding unchanged on Err"
);
// Case 2: left_padding of rank 2 instead of rank 1.
let lp_2d = Array::from_slice::<i32>(&[1, 2, 3, 4], &(2usize, 2)).unwrap();
let off2 = Array::from_slice::<i32>(&[-1, -2], &(2usize,)).unwrap();
let k2 = kvb(&[&[1.0], &[2.0]]);
let v2 = kvb(&[&[3.0], &[4.0]]);
assert!(
c.set_state(vec![k2, v2, off2, lp_2d]).is_err(),
"2-D left_padding MUST Err (rank validation)"
);
assert_eq!(c.pad_lengths(), pl_before.as_slice());
assert_eq!(iv(&c.left_padding_arr().unwrap()), lp_before);
// Case 3: left_padding of length 3 alongside B=2 keys/values and a length-2 offset.
let lp_3 = Array::from_slice::<i32>(&[1, 2, 3], &(3usize,)).unwrap();
let off3 = Array::from_slice::<i32>(&[-1, -2], &(2usize,)).unwrap();
let k3 = kvb(&[&[1.0], &[2.0]]);
let v3 = kvb(&[&[3.0], &[4.0]]);
assert!(
c.set_state(vec![k3, v3, off3, lp_3]).is_err(),
"length-mismatched left_padding MUST Err"
);
assert_eq!(c.pad_lengths(), pl_before.as_slice());
assert_eq!(iv(&c.left_padding_arr().unwrap()), lp_before);
// Control: a well-formed restore succeeds and updates the host mirror.
let good_lp = Array::from_slice::<i32>(&[3, 4], &(2usize,)).unwrap();
let good_off = Array::from_slice::<i32>(&[-3, -4], &(2usize,)).unwrap();
let good_k = kvb(&[&[1.0], &[2.0]]);
let good_v = kvb(&[&[3.0], &[4.0]]);
c.set_state(vec![good_k, good_v, good_off, good_lp])
.unwrap();
assert_eq!(
c.pad_lengths(),
&[3, 4],
"well-formed restore must update pad_lengths"
);
}
/// `BatchRotatingKvCache::set_state` must reject a malformed `left_padding`
/// entry by returning `Err`. The three malformed cases are a non-I32 dtype, a
/// 2-D array, and a length mismatch with the batch.
///
/// After every rejection, both the host-side `pad_lengths` mirror and the
/// `left_padding` array must be unchanged. A well-formed restore must then
/// update *both* of them.
#[test]
fn batch_rotating_set_state_propagates_to_vec_failure() {
    let mut c = BatchRotatingKvCache::new(4, &[1i32, 2]);
    let pl_before: Vec<i32> = c.pad_lengths().to_vec();
    let lp_before = iv(&c.left_padding_arr().unwrap());

    // Wrong dtype: left_padding stored as f32.
    let lp_f32 = Array::from_slice::<f32>(&[1.0, 2.0], &(2usize,)).unwrap();
    let off = Array::from_slice::<i32>(&[-1, -2], &(2usize,)).unwrap();
    let k = kvb(&[&[1.0], &[2.0]]);
    let v = kvb(&[&[3.0], &[4.0]]);
    assert!(
        c.set_state(vec![k, v, off, lp_f32]).is_err(),
        "rotating: non-I32 left_padding MUST Err"
    );
    assert_eq!(c.pad_lengths(), pl_before.as_slice());
    assert_eq!(iv(&c.left_padding_arr().unwrap()), lp_before);

    // Wrong rank: a 2x2 left_padding.
    let lp_2d = Array::from_slice::<i32>(&[1, 2, 3, 4], &(2usize, 2)).unwrap();
    let off2 = Array::from_slice::<i32>(&[-1, -2], &(2usize,)).unwrap();
    let k2 = kvb(&[&[1.0], &[2.0]]);
    let v2 = kvb(&[&[3.0], &[4.0]]);
    assert!(
        c.set_state(vec![k2, v2, off2, lp_2d]).is_err(),
        "rotating: 2-D left_padding MUST Err"
    );
    assert_eq!(c.pad_lengths(), pl_before.as_slice());
    assert_eq!(iv(&c.left_padding_arr().unwrap()), lp_before);

    // Wrong length: three entries for a two-sequence batch.
    let lp_3 = Array::from_slice::<i32>(&[1, 2, 3], &(3usize,)).unwrap();
    let off3 = Array::from_slice::<i32>(&[-1, -2], &(2usize,)).unwrap();
    let k3 = kvb(&[&[1.0], &[2.0]]);
    let v3 = kvb(&[&[3.0], &[4.0]]);
    assert!(
        c.set_state(vec![k3, v3, off3, lp_3]).is_err(),
        "rotating: length-mismatched left_padding MUST Err"
    );
    assert_eq!(c.pad_lengths(), pl_before.as_slice());
    assert_eq!(iv(&c.left_padding_arr().unwrap()), lp_before);

    // Well-formed restore.
    let good_lp = Array::from_slice::<i32>(&[3, 4], &(2usize,)).unwrap();
    let good_off = Array::from_slice::<i32>(&[-3, -4], &(2usize,)).unwrap();
    let good_k = kvb(&[&[1.0], &[2.0]]);
    let good_v = kvb(&[&[3.0], &[4.0]]);
    c.set_state(vec![good_k, good_v, good_off, good_lp])
        .unwrap();
    assert_eq!(
        c.pad_lengths(),
        &[3, 4],
        "rotating: well-formed restore must update pad_lengths"
    );
    // Checking only the host mirror would let a restore that forgets the
    // array itself slip through; the Err cases above check both, so do the same here.
    assert_eq!(
        iv(&c.left_padding_arr().unwrap()),
        vec![3, 4],
        "rotating: well-formed restore must update the left_padding array"
    );
}
/// Structural regression check: `BatchRotatingKvCache` must never swallow a
/// `to_vec::<i32>` failure with a chained `unwrap_or_else` fallback.
///
/// The source is compared with all whitespace removed. A single-line chain and
/// a chain wrapped onto the next line with any amount of indentation therefore
/// both collapse to the same token run and are both caught.
#[test]
fn batch_rotating_finalize_propagates_to_vec_failure() {
    let src = std::fs::read_to_string(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/src/lm/cache/batch_rotating.rs"
    ))
    .expect("batch_rotating.rs must be readable for structural regression check");
    // A literal such as "to_vec::<i32>()\n .unwrap_or_else" only matches one
    // specific indentation, which rustfmt will not reproduce. Stripping
    // whitespace makes the check independent of formatting.
    let compact: String = src.chars().filter(|ch| !ch.is_whitespace()).collect();
    assert!(
        !compact.contains("to_vec::<i32>().unwrap_or_else"),
        "BatchRotatingKvCache must not swallow `to_vec::<i32>` failures via \
         a chained `to_vec::<i32>().unwrap_or_else` (e.g. \
         `unwrap_or_else(|_| self.pad_lengths.clone())`): such a fallback commits a \
         stale `pad_lengths` host mirror against a freshly-rolled `left_padding` \
         Array. Use `?` propagation BEFORE the infallible commit tail."
    );
}
/// Structural regression check on `batch_rotating.rs`.
///
/// The dirty-`left_padding` paths (finalize, update_concat, update_in_place,
/// set_state) must propagate `to_vec::<i32>` failures with `?`. The file must
/// therefore contain at least four literal `to_vec::<i32>()?` call sites.
#[test]
fn batch_rotating_update_concat_propagates_to_vec_failure() {
    const MIN_STRICT_SITES: usize = 4;
    let path = concat!(env!("CARGO_MANIFEST_DIR"), "/src/lm/cache/batch_rotating.rs");
    let src = std::fs::read_to_string(path)
        .expect("batch_rotating.rs must be readable for structural regression check");
    let strict_count = src.match_indices("to_vec::<i32>()?").count();
    assert!(
        strict_count >= MIN_STRICT_SITES,
        "Expected ≥4 strict `to_vec::<i32>()?` sites in batch_rotating.rs \
         (finalize + update_concat + update_in_place + set_state) — found {strict_count}. \
         The dirty-left_padding paths must use `?` propagation."
    );
}
/// Structural regression check: no code path in `batch_rotating.rs` may fall
/// back to the stale `pad_lengths` host mirror when extracting `left_padding`
/// fails.
///
/// The source is compared with all whitespace removed, so the check holds
/// however the closure is spaced or wrapped. Both the expression-bodied closure
/// and the block-bodied closure that rustfmt emits on long lines are checked.
#[test]
fn batch_rotating_update_in_place_propagates_to_vec_failure() {
    // Whitespace-free spellings of the forbidden fallback.
    const FORBIDDEN_FALLBACKS: [&str; 2] = [
        ".unwrap_or_else(|_|self.pad_lengths.clone())",
        ".unwrap_or_else(|_|{self.pad_lengths.clone()})",
    ];
    let src = std::fs::read_to_string(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/src/lm/cache/batch_rotating.rs"
    ))
    .expect("batch_rotating.rs must be readable for structural regression check");
    let compact: String = src.chars().filter(|ch| !ch.is_whitespace()).collect();
    assert!(
        !FORBIDDEN_FALLBACKS
            .iter()
            .any(|pattern| compact.contains(*pattern)),
        "BatchRotatingKvCache must not swallow extraction failures with \
         `unwrap_or_else(|_| self.pad_lengths.clone())` in any of \
         finalize / update_concat / update_in_place."
    );
}