#![cfg(feature = "lm")]
use mlxrs::{
Array, Error, Stream,
lm::cache::{
ArraysCache, BatchKvCache, BatchRotatingKvCache, CacheConfig, CacheList, ChunkedKvCache,
KvCache, MaskMode, RopeOffset, RotatingKvCache, StandardKvCache, StandardQuantizedKvCache,
create_attention_mask, create_causal_mask, from_state, make_prompt_cache,
},
};
/// Builds an f32 tensor of shape `[1, 1, S, 1]` (one batch, one head,
/// `S = ids.len()` sequence positions, head_dim 1) holding `ids` in order.
///
/// The tests pass it as both keys and values so cached contents can be read
/// back as a flat `Vec<f32>` of token ids.
fn kv(ids: &[f32]) -> Array {
    let seq_len = ids.len();
    Array::from_slice::<f32>(ids, &(1usize, 1, seq_len, 1)).unwrap()
}
/// `StandardKvCache` appends along the sequence axis and tracks `offset`.
/// `trim(n)` rewinds the offset so the next update writes after the kept
/// prefix. An over-sized trim is clamped to the number of cached tokens.
#[test]
fn standard_append_offset_trim() {
    let mut cache = StandardKvCache::new();
    assert!(cache.is_empty());

    // Prefill four tokens: keys and values echo exactly what went in.
    let (mut keys, mut values) = cache
        .update(&kv(&[0.0, 1.0, 2.0, 3.0]), &kv(&[0.0, 1.0, 2.0, 3.0]))
        .unwrap();
    assert!(!cache.is_empty());
    assert_eq!(keys.shape(), vec![1, 1, 4, 1]);
    assert_eq!(cache.offset(), 4);
    assert_eq!(keys.to_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0, 3.0]);
    assert_eq!(values.to_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0, 3.0]);

    // Append two more: the returned keys cover the whole six-token history.
    let (mut keys, _) = cache.update(&kv(&[4.0, 5.0]), &kv(&[4.0, 5.0])).unwrap();
    assert_eq!(keys.shape(), vec![1, 1, 6, 1]);
    assert_eq!(cache.offset(), 6);
    assert_eq!(
        keys.to_vec::<f32>().unwrap(),
        vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
    );

    // Drop the last two tokens; the next write lands right after token 3.
    assert_eq!(cache.trim(2).unwrap(), 2);
    assert_eq!(cache.offset(), 4);
    let (mut keys, _) = cache.update(&kv(&[9.0]), &kv(&[9.0])).unwrap();
    assert_eq!(cache.offset(), 5);
    assert_eq!(keys.to_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0, 3.0, 9.0]);

    // Trimming more than is cached only removes what exists.
    let mut short = StandardKvCache::new();
    short.update(&kv(&[0.0, 1.0]), &kv(&[0.0, 1.0])).unwrap();
    assert_eq!(short.trim(10).unwrap(), 2);
    assert_eq!(short.offset(), 0);
}
/// A rank-2 tensor is not a valid KV input, so `StandardKvCache::update` must
/// return `Err` instead of panicking.
#[test]
fn standard_wrong_rank_errors() {
    let mut cache = StandardKvCache::new();
    let rank2 = Array::from_slice::<f32>(&[1.0, 2.0], &(1usize, 2)).unwrap();
    let outcome = cache.update(&rank2, &rank2);
    assert!(outcome.is_err());
}
/// Valid single-token keys paired with rank-2 values (the S == 1 path) must
/// be rejected as a recoverable `RankMismatch` that names the values tensor.
#[test]
fn rotating_update_in_place_rank_invalid_values_errors_no_panic() {
    let mut cache = RotatingKvCache::new(8, 4);
    let single_key = kv(&[0.0]);
    let rank2_values = Array::from_slice::<f32>(&[0.0, 1.0], &(1usize, 2)).unwrap();

    let result = cache.update(&single_key, &rank2_values);
    let mismatch = match &result {
        Err(Error::RankMismatch(p)) => p,
        _ => panic!(
            "rank-invalid values on the S==1 path must be a recoverable \
             RankMismatch, got {result:?}"
        ),
    };
    assert_eq!(
        mismatch.context(),
        "seq_len: KV cache expects 4-D values [B, n_kv_heads, S, head_dim]"
    );
    assert_eq!(mismatch.actual(), 2);
    assert_eq!(mismatch.actual_shape(), &[1usize, 2]);
}
/// On an empty rotating cache, a multi-token chunk (S > 1) whose values are
/// rank-2 must be rejected deterministically as `RankMismatch`.
///
/// The rejection must leave the cache untouched: offset stays 0 and a
/// following valid two-token update lands at offset 2.
#[test]
fn rotating_update_concat_rank_invalid_values_errors_no_panic() {
    let mut c = RotatingKvCache::new(6, 2);
    let keys = kv(&[2.0, 3.0]);
    let bad_values = Array::from_slice::<f32>(&[2.0, 3.0], &(1usize, 2)).unwrap();
    let r = c.update(&keys, &bad_values);
    match &r {
        Err(mlxrs::Error::RankMismatch(p)) => {
            assert_eq!(
                p.context(),
                "seq_len: KV cache expects 4-D values [B, n_kv_heads, S, head_dim]"
            );
            assert_eq!(p.actual(), 2);
            assert_eq!(p.actual_shape(), &[1usize, 2]);
        }
        _ => panic!(
            "rank-invalid values on the empty-cache S>1 path must be a DETERMINISTIC \
             recoverable RankMismatch (per-tensor rank guard at update entry), got \
             {r:?}"
        ),
    }
    // The guard rejects before any write, so the cursor must not have moved.
    assert_eq!(
        c.offset(),
        0,
        "a rejected rank-invalid update must not advance the offset"
    );
    // Recovery: a valid update afterwards behaves like the first update.
    c.update(&keys, &keys).unwrap();
    assert_eq!(
        c.offset(),
        2,
        "a valid update after a rejected one must land as if the failure never happened"
    );
}
/// A seeded single-slot rotating cache (max_size 1, keep 0) receiving
/// two-token keys with rank-2 values must reject them as `RankMismatch`.
///
/// The cache must not be corrupted: the offset stays 1 after the rejection,
/// and a later valid single-token update succeeds and advances it to 2.
#[test]
fn rotating_update_concat_single_part_fast_path_rank_invalid_no_corruption() {
    let mut c = RotatingKvCache::new(1, 0);
    let seed = kv(&[0.0]);
    c.update(&seed, &seed).unwrap();
    let keys = kv(&[1.0, 2.0]);
    let bad_values = Array::from_slice::<f32>(&[1.0, 2.0], &(1usize, 2)).unwrap();
    let r = c.update(&keys, &bad_values);
    match &r {
        Err(mlxrs::Error::RankMismatch(p)) => {
            assert_eq!(
                p.context(),
                "seq_len: KV cache expects 4-D values [B, n_kv_heads, S, head_dim]"
            );
            assert_eq!(p.actual(), 2);
            assert_eq!(p.actual_shape(), &[1usize, 2]);
        }
        _ => panic!(
            "rank-invalid lone-surviving values must be a DETERMINISTIC recoverable \
             RankMismatch (per-tensor rank guard at update entry rejects it before \
             dispatch), got {r:?}"
        ),
    }
    // Rejected before dispatch: only the seed token is accounted for.
    assert_eq!(
        c.offset(),
        1,
        "a rejected rank-invalid update must not advance the offset"
    );
    let good = kv(&[3.0]);
    let r2 = c.update(&good, &good);
    assert!(
        r2.is_ok(),
        "a valid update after a rejected rank-invalid one must succeed (no \
         cache corruption / no Result-path panic), got {r2:?}"
    );
    // Raw offset counts every accepted token: seed + the valid one.
    assert_eq!(c.offset(), 2);
}
/// A seeded `StandardKvCache` given valid keys with rank-2 values must
/// reject the update as `RankMismatch`.
///
/// The cache must stay intact: offset remains 2, and a following valid
/// append produces the uncorrupted history `[0, 1, 2]`.
#[test]
fn standard_rank_invalid_values_errors_no_panic() {
    let mut c = StandardKvCache::new();
    let seed = kv(&[0.0, 1.0]);
    c.update(&seed, &seed).unwrap();
    let keys = kv(&[2.0]);
    let bad_values = Array::from_slice::<f32>(&[2.0, 3.0], &(1usize, 2)).unwrap();
    let r = c.update(&keys, &bad_values);
    match &r {
        Err(mlxrs::Error::RankMismatch(p)) => {
            assert_eq!(
                p.context(),
                "seq_len: KV cache expects 4-D values [B, n_kv_heads, S, head_dim]"
            );
            assert_eq!(p.actual(), 2);
            assert_eq!(p.actual_shape(), &[1usize, 2]);
        }
        _ => panic!(
            "rank-invalid values must be a DETERMINISTIC recoverable RankMismatch \
             (per-tensor rank guard at update entry), got {r:?}"
        ),
    }
    // The guard runs at update entry, so nothing may have been appended.
    assert_eq!(
        c.offset(),
        2,
        "a rejected rank-invalid update must not advance the offset"
    );
    // Recovery: a valid append continues from the seeded history.
    let (mut k, _) = c.update(&keys, &keys).unwrap();
    assert_eq!(c.offset(), 3);
    assert_eq!(
        k.to_vec::<f32>().unwrap(),
        vec![0.0, 1.0, 2.0],
        "cache contents must be uncorrupted by the rejected update"
    );
}
/// A rotating cache with max_size 8 and keep 4 fills linearly for eight
/// single-token steps. After that it overwrites ring slots 4..8 in order
/// while the first four tokens stay pinned. The raw offset keeps counting
/// every token, and the cache is no longer trimmable after rotation.
#[test]
fn rotating_keeps_prefix_and_window() {
    let mut cache = RotatingKvCache::new(8, 4);
    assert!(cache.is_empty());
    assert!(cache.is_trimmable());

    // Cache contents (ring order) observed after each single-token update.
    let windows: [&[f32]; 12] = [
        &[0.0],
        &[0.0, 1.0],
        &[0.0, 1.0, 2.0],
        &[0.0, 1.0, 2.0, 3.0],
        &[0.0, 1.0, 2.0, 3.0, 4.0],
        &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
        &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
        &[0.0, 1.0, 2.0, 3.0, 8.0, 5.0, 6.0, 7.0],
        &[0.0, 1.0, 2.0, 3.0, 8.0, 9.0, 6.0, 7.0],
        &[0.0, 1.0, 2.0, 3.0, 8.0, 9.0, 10.0, 7.0],
        &[0.0, 1.0, 2.0, 3.0, 8.0, 9.0, 10.0, 11.0],
    ];

    for (step, window) in windows.iter().copied().enumerate() {
        let token = kv(&[step as f32]);
        let (mut keys, mut values) = cache.update(&token, &token).unwrap();
        assert_eq!(keys.to_vec::<f32>().unwrap(), window, "keys, step {step}");
        assert_eq!(values.to_vec::<f32>().unwrap(), window, "values, step {step}");
        // The raw offset is the total number of tokens ever written.
        assert_eq!(cache.offset(), step + 1, "raw offset, step {step}");
        if step == 11 {
            assert_eq!(keys.shape(), vec![1, 1, 8, 1], "full-window shape");
        }
    }

    assert_eq!(cache.offset(), 12);
    assert!(!cache.is_trimmable());
}
/// A five-token prefill into a (max_size 6, keep 2) rotating cache is stored
/// verbatim. The next token fills the window, and the one after overwrites
/// the first slot following the two kept tokens.
#[test]
fn rotating_multi_token_prefill_then_decode() {
    let mut cache = RotatingKvCache::new(6, 2);

    let prefill = kv(&[0.0, 1.0, 2.0, 3.0, 4.0]);
    let (mut keys, _) = cache.update(&prefill, &prefill).unwrap();
    assert_eq!(keys.to_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);
    assert_eq!(cache.offset(), 5);

    // (decoded token, expected ring contents, expected raw offset)
    let decode_steps: [(f32, [f32; 6], usize); 2] = [
        (5.0, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0], 6),
        (6.0, [0.0, 1.0, 6.0, 3.0, 4.0, 5.0], 7),
    ];
    for &(token, window, offset) in decode_steps.iter() {
        let (mut keys, _) = cache.update(&kv(&[token]), &kv(&[token])).unwrap();
        assert_eq!(keys.to_vec::<f32>().unwrap(), window.to_vec());
        assert_eq!(cache.offset(), offset);
    }
}
/// With keep 0, a rotating cache of max_size 3 is a plain ring buffer. Once
/// full, each new token overwrites the oldest slot in ring order.
#[test]
fn rotating_keep_zero_is_pure_sliding_window() {
    let mut cache = RotatingKvCache::new(3, 0);
    let windows: [&[f32]; 6] = [
        &[0.0],
        &[0.0, 1.0],
        &[0.0, 1.0, 2.0],
        &[3.0, 1.0, 2.0],
        &[3.0, 4.0, 2.0],
        &[3.0, 4.0, 5.0],
    ];
    for (step, window) in windows.iter().copied().enumerate() {
        let token = kv(&[step as f32]);
        let (mut keys, _) = cache.update(&token, &token).unwrap();
        assert_eq!(keys.to_vec::<f32>().unwrap(), window, "step {step}");
        assert_eq!(cache.offset(), step + 1, "raw offset, step {step}");
    }
}
/// After nine single tokens have rotated a (max_size 8, keep 4) ring, a
/// two-token append returns the history in temporal order. The result is
/// the kept prefix followed by the most recent tokens: `max_size + S - 1`
/// entries in total.
#[test]
fn rotating_active_ring_then_concat() {
    let mut cache = RotatingKvCache::new(8, 4);
    for token in (0..=8).map(|id| kv(&[id as f32])) {
        cache.update(&token, &token).unwrap();
    }
    assert_eq!(cache.offset(), 9, "after single-token fill+rotate");

    let chunk = kv(&[9.0, 10.0]);
    let (mut keys, mut values) = cache.update(&chunk, &chunk).unwrap();
    let expected: Vec<f32> = vec![0.0, 1.0, 2.0, 3.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    assert_eq!(
        keys.to_vec::<f32>().unwrap(),
        expected,
        "keys: temporal-order trim"
    );
    assert_eq!(
        values.to_vec::<f32>().unwrap(),
        expected,
        "values: temporal-order trim"
    );
    assert_eq!(keys.shape(), vec![1, 1, 9, 1], "max_size + S - 1 = 9");
    assert_eq!(cache.offset(), 11, "raw offset += S");
}
/// `make_prompt_cache` builds one empty cache per hidden layer. Without a
/// sliding window the caches are unbounded (`max_size() == None`). With one,
/// every cache is bounded by that window.
#[test]
fn make_prompt_cache_one_per_layer_and_kind() {
    // (sliding_window in the config, max_size expected on every layer)
    let configurations = [(None, None), (Some(8), Some(8))];
    for &(window, expected_max) in configurations.iter() {
        let config = CacheConfig {
            num_hidden_layers: 3,
            sliding_window: window,
        };
        let layers = make_prompt_cache(&config);
        assert_eq!(layers.len(), 3);
        assert!(layers.iter().all(|layer| layer.max_size() == expected_max));
        assert!(layers.iter().all(|layer| layer.is_empty()));
    }
}
/// Standard and rotating caches can both be driven through
/// `Box<dyn KvCache>`. Update, offset, emptiness and trim all dispatch
/// dynamically.
#[test]
fn kvcache_trait_dispatch() {
    let mut standard: Box<dyn KvCache> = Box::new(StandardKvCache::new());
    assert!(standard.is_empty());
    let (keys, _) = standard
        .update(&kv(&[0.0, 1.0, 2.0]), &kv(&[0.0, 1.0, 2.0]))
        .unwrap();
    assert_eq!(keys.shape(), vec![1, 1, 3, 1]);
    assert_eq!(standard.offset(), 3);
    assert!(!standard.is_empty());
    assert_eq!(standard.trim(1).unwrap(), 1);
    assert_eq!(standard.offset(), 2);

    let mut rotating: Box<dyn KvCache> = Box::new(RotatingKvCache::new(4, 2));
    (0..6).map(|i| kv(&[i as f32])).for_each(|token| {
        rotating.update(&token, &token).unwrap();
    });
    assert_eq!(rotating.offset(), 6);
}
/// With no cached prefix and no window, the causal mask is lower-triangular:
/// query row `i` may attend to key column `j` iff `j <= i`.
#[test]
fn causal_mask_no_window_offset_zero() {
    let mut mask = create_causal_mask(3, 0, None).unwrap();
    assert_eq!(mask.shape(), vec![3, 3]);
    let lower_triangle: Vec<bool> = (0..3usize)
        .flat_map(|row| (0..3usize).map(move |col| col <= row))
        .collect();
    assert_eq!(mask.to_vec::<bool>().unwrap(), lower_triangle);
}
/// Two new queries after three cached tokens produce a `[2, 5]` mask. Query
/// row `i` sits at absolute position `i + 3` and sees every key up to it.
#[test]
fn causal_mask_with_offset() {
    let mut mask = create_causal_mask(2, 3, None).unwrap();
    assert_eq!(mask.shape(), vec![2, 5]);
    let expected: Vec<bool> = (0..2usize)
        .flat_map(|row| (0..5usize).map(move |col| col <= row + 3))
        .collect();
    assert_eq!(mask.to_vec::<bool>().unwrap(), expected);
}
/// A sliding window of 2 restricts each query to itself and the immediately
/// preceding key: `j <= i && i - j < 2`.
#[test]
fn causal_mask_windowed() {
    let mut mask = create_causal_mask(4, 0, Some(2)).unwrap();
    assert_eq!(mask.shape(), vec![4, 4]);
    let banded: Vec<bool> = (0..4usize)
        .flat_map(|row| (0..4usize).map(move |col| col <= row && row - col < 2))
        .collect();
    assert_eq!(mask.to_vec::<bool>().unwrap(), banded);
}
/// Checks which `MaskMode` `create_attention_mask` selects for each input
/// combination.
#[test]
fn attention_mask_mode_decision_tree() {
    // A sliding window yields an explicit Array mask.
    let windowed = create_attention_mask(4, 0, false, Some(2)).unwrap();
    assert!(matches!(windowed, MaskMode::Array(_)));

    // A single query token needs no mask.
    let single_query = create_attention_mask(1, 7, false, None).unwrap();
    assert!(matches!(single_query, MaskMode::None));

    // Setting the boolean flag requests a materialised Array mask.
    let explicit = create_attention_mask(3, 0, true, None).unwrap();
    assert!(matches!(explicit, MaskMode::Array(_)));

    // Otherwise a multi-token query uses the implicit causal mode.
    let implicit = create_attention_mask(3, 0, false, None).unwrap();
    assert!(matches!(implicit, MaskMode::Causal));
}
/// Default trait behavior for the standard cache: scalar RoPE offset equal
/// to `offset`, empty meta_state, and no quantized or batch-positioned view.
/// The rotating cache reports `(keep, max_size, offset, idx)` as meta_state
/// and also has a scalar RoPE offset.
#[test]
fn trait_defaults_rope_offset_and_meta_state() {
    let mut standard = StandardKvCache::new();
    standard.update(&kv(&[0.0, 1.0]), &kv(&[0.0, 1.0])).unwrap();
    if let RopeOffset::Scalar(o) = standard.rope_offset().unwrap() {
        assert_eq!(o, 2);
    } else {
        panic!("standard cache must use a scalar RoPE offset");
    }
    assert!(standard.meta_state().is_empty());
    assert!(standard.as_quantized().is_none());
    assert!(standard.as_batch_positioned().is_none());

    let mut rotating = RotatingKvCache::new(8, 4);
    for token in (0..3).map(|i| kv(&[i as f32])) {
        rotating.update(&token, &token).unwrap();
    }
    assert_eq!(rotating.meta_state(), vec!["4", "8", "3", "3"]);
    if let RopeOffset::Scalar(o) = rotating.rope_offset().unwrap() {
        assert_eq!(o, 3);
    } else {
        panic!("rotating cache must use a scalar RoPE offset");
    }
}
/// An empty standard cache has no state and zero bytes. After four tokens,
/// the state is `[keys, values]` and the byte count covers both tensors.
/// `copy()` yields an independent cache whose updates do not affect the
/// original.
#[test]
fn standard_state_nbytes_copy() {
    let mut cache = StandardKvCache::new();
    assert!(cache.state().unwrap().is_empty());
    assert_eq!(cache.nbytes(), 0);

    cache
        .update(&kv(&[0.0, 1.0, 2.0, 3.0]), &kv(&[0.0, 1.0, 2.0, 3.0]))
        .unwrap();
    let mut snapshot = cache.state().unwrap();
    assert_eq!(snapshot.len(), 2);
    assert_eq!(snapshot[0].to_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0, 3.0]);
    // 8 f32 elements in total (4 keys + 4 values), 4 bytes each.
    assert_eq!(cache.nbytes(), 8 * 4);

    let mut forked = cache.copy().unwrap();
    forked.update(&kv(&[9.0]), &kv(&[9.0])).unwrap();
    assert_eq!(forked.offset(), 5);
    assert_eq!(cache.offset(), 4);
}
/// With offset 3, `StandardKvCache::make_mask` makes the same choices as
/// `create_attention_mask`: no mask for one query, implicit causal for
/// several, and an explicit Array when requested or when a window applies.
/// The windowed Array equals `create_causal_mask` at the cache offset.
#[test]
fn standard_make_mask_matches_create_attention_mask() {
    let mut cache = StandardKvCache::new();
    cache
        .update(&kv(&[0.0, 1.0, 2.0]), &kv(&[0.0, 1.0, 2.0]))
        .unwrap();
    assert_eq!(cache.offset(), 3);

    assert!(matches!(
        cache.make_mask(1, None, false).unwrap(),
        MaskMode::None
    ));
    assert!(matches!(
        cache.make_mask(3, None, false).unwrap(),
        MaskMode::Causal
    ));

    // Explicit causal mask over offset + N = 6 keys.
    let mut explicit = match cache.make_mask(3, None, true).unwrap() {
        MaskMode::Array(m) => m,
        _ => panic!("make_mask(3,None,true) must be an Array (cache.py:122)"),
    };
    assert_eq!(explicit.shape(), vec![3, 6]);
    assert_eq!(
        explicit.to_vec::<bool>().unwrap(),
        vec![
            true, true, true, true, false, false, //
            true, true, true, true, true, false, //
            true, true, true, true, true, true,
        ]
    );

    // A window smaller than offset + N forces an Array.
    let mut windowed = match cache.make_mask(2, Some(4), false).unwrap() {
        MaskMode::Array(m) => m,
        _ => panic!("make_mask(2,Some(4),_) must be an Array (cache.py:118)"),
    };
    let mut reference = create_causal_mask(2, cache.offset(), Some(4)).unwrap();
    assert_eq!(windowed.shape(), vec![2, 5]);
    let windowed_bits = windowed.to_vec::<bool>().unwrap();
    assert_eq!(windowed_bits, reference.to_vec::<bool>().unwrap());
    assert_eq!(
        windowed_bits,
        vec![
            true, true, true, true, false, //
            false, true, true, true, true,
        ]
    );
}
/// Covers `RotatingKvCache::make_mask` in three states:
/// - small offset: implicit causal for N > 1 and no mask for N == 1;
/// - offset beyond the window: a windowed `[N, offset + N]` Array;
/// - rolled ring with a narrower window: an N == 1 Array in ring order.
#[test]
fn rotating_make_mask_windowed_and_rolled() {
    // Builds a (max_size, keep) rotating cache fed `tokens` single tokens.
    let primed = |max_size, keep, tokens: usize| {
        let mut cache = RotatingKvCache::new(max_size, keep);
        for id in 0..tokens {
            let token = kv(&[id as f32]);
            cache.update(&token, &token).unwrap();
        }
        cache
    };

    let early = primed(8, 4, 2);
    assert_eq!(early.offset(), 2);
    assert!(matches!(
        early.make_mask(3, None, false).unwrap(),
        MaskMode::Causal
    ));
    assert!(matches!(
        early.make_mask(1, None, false).unwrap(),
        MaskMode::None
    ));
    assert!(matches!(
        early.make_mask(1, None, true).unwrap(),
        MaskMode::None
    ));

    let late = primed(8, 4, 10);
    assert_eq!(late.offset(), 10);
    match late.make_mask(3, None, false).unwrap() {
        MaskMode::Array(m) => assert_eq!(m.shape(), vec![3, 10]),
        _ => panic!("large-offset N>1 must be a windowed Array (cache.py:561)"),
    }

    let rolled = primed(4, 2, 6);
    assert_eq!(rolled.offset(), 6);
    match rolled.make_mask(1, Some(2), false).unwrap() {
        MaskMode::Array(mut m) => {
            assert_eq!(m.shape(), vec![4]);
            assert_eq!(m.to_vec::<bool>().unwrap(), vec![true, false, false, true]);
        }
        _ => panic!("N==1 windowed rolled case must be an Array (cache.py:578)"),
    }
    assert!(matches!(
        rolled.make_mask(1, Some(8), false).unwrap(),
        MaskMode::None
    ));
}
/// `from_state` rebuilds caches from `(state, meta_state)`.
/// - The standard cache round-trips under the "KVCache" tag.
/// - The rotating cache round-trips under both "RotatingKVCache" and
///   "RotatingKvCache".
/// - The "ChunkedKvCache" tag with empty state and meta is an error.
#[test]
fn from_state_roundtrip_and_unknown_kind() {
    let mut standard = StandardKvCache::new();
    standard
        .update(&kv(&[0.0, 1.0, 2.0]), &kv(&[0.0, 1.0, 2.0]))
        .unwrap();
    let standard_state = standard.state().unwrap();
    let standard_meta = standard.meta_state();
    let restored = from_state("KVCache", standard_state, &standard_meta).unwrap();
    assert_eq!(restored.offset(), 3);
    assert!(!restored.is_empty());
    let restored_state = restored.state().unwrap();
    assert_eq!(restored_state.len(), 2);
    let mut restored_keys = restored_state[0].try_clone().unwrap();
    assert_eq!(restored_keys.to_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0]);

    let mut rotating = RotatingKvCache::new(8, 4);
    for token in (0..5).map(|i| kv(&[i as f32])) {
        rotating.update(&token, &token).unwrap();
    }
    let rotating_state = rotating.state().unwrap();
    let rotating_meta = rotating.meta_state();
    let cloned_state: Vec<Array> = rotating_state
        .iter()
        .map(Array::try_clone)
        .collect::<mlxrs::Result<Vec<_>>>()
        .unwrap();
    let restored_rotating = from_state("RotatingKVCache", cloned_state, &rotating_meta).unwrap();
    assert_eq!(restored_rotating.offset(), 5);
    assert_eq!(restored_rotating.max_size(), Some(8));
    let restored_rotating_state = restored_rotating.state().unwrap();
    assert_eq!(restored_rotating_state.len(), 2);
    let mut after_keys = restored_rotating_state[0].try_clone().unwrap();
    let mut before_keys = rotating_state[0].try_clone().unwrap();
    assert_eq!(
        after_keys.to_vec::<f32>().unwrap(),
        before_keys.to_vec::<f32>().unwrap()
    );

    // The Rust-style spelling of the rotating tag is accepted too.
    let aliased = from_state(
        "RotatingKvCache",
        rotating.state().unwrap(),
        &rotating.meta_state(),
    )
    .unwrap();
    assert_eq!(aliased.offset(), 5);
    assert_eq!(aliased.max_size(), Some(8));

    // Per the test name this is the "unknown kind" case; it must be an Err.
    assert!(from_state("ChunkedKvCache", Vec::new(), &[]).is_err());
}
/// An empty state paired with meta_state that claims a non-zero offset or
/// idx is inconsistent and must be rejected. An empty state with all-zero
/// counters is a valid empty rotating cache.
#[test]
fn from_state_rotating_empty_with_nonzero_meta_is_invalid() {
    // Fields are (keep, max_size, offset, idx).
    let meta = |fields: [&str; 4]| -> Vec<String> {
        fields.iter().map(|field| field.to_string()).collect()
    };

    let nonzero_offset = from_state("RotatingKVCache", Vec::new(), &meta(["4", "8", "5", "5"]));
    assert!(
        nonzero_offset.is_err(),
        "empty state + non-zero offset/idx must not yield a usable cache"
    );

    let nonzero_idx = from_state("RotatingKVCache", Vec::new(), &meta(["0", "8", "0", "3"]));
    assert!(
        nonzero_idx.is_err(),
        "empty state + non-zero idx must be rejected"
    );

    let empty = from_state("RotatingKVCache", Vec::new(), &meta(["4", "8", "0", "0"])).unwrap();
    assert!(
        empty.is_empty(),
        "empty state + zero meta is a valid empty cache"
    );
    assert_eq!(empty.offset(), 0);
    assert_eq!(empty.max_size(), Some(8));
}
/// With the raw offset forced to `usize::MAX`, both the concat path
/// (S = 2) and the in-place path (S = 1) overflow on `offset + S`. Each must
/// return `ArithmeticOverflow` and leave meta_state and state length exactly
/// as they were.
#[test]
fn rotating_offset_overflow_is_rejected_without_partial_mutation() {
    let mut cache = RotatingKvCache::new(8, 4);
    let first_token = kv(&[0.0]);
    cache.update(&first_token, &first_token).unwrap();
    cache
        .set_meta_state(&[
            "4".to_string(),
            "8".to_string(),
            usize::MAX.to_string(),
            "1".to_string(),
        ])
        .unwrap();
    let meta_before = cache.meta_state();
    let state_len_before = cache.state().unwrap().len();

    // Two tokens at once: concat path.
    let pair = kv(&[1.0, 2.0]);
    let concat_result = cache.update(&pair, &pair);
    let concat_overflow = match &concat_result {
        Err(Error::ArithmeticOverflow(p)) => p,
        _ => panic!(
            "concat-path offset overflow must be Err(ArithmeticOverflow), got {concat_result:?}"
        ),
    };
    assert_eq!(
        concat_overflow.context(),
        "RotatingKvCache::update_concat: offset + S"
    );
    assert_eq!(concat_overflow.op_type(), "usize");
    assert_eq!(
        concat_overflow.operands(),
        &[("offset", usize::MAX as u64), ("S", 2u64)]
    );
    assert_eq!(
        cache.meta_state(),
        meta_before,
        "overflow (concat) must not partially mutate ring state"
    );
    assert_eq!(cache.state().unwrap().len(), state_len_before);

    // A single token: in-place path.
    let single = kv(&[3.0]);
    let in_place_result = cache.update(&single, &single);
    let in_place_overflow = match &in_place_result {
        Err(Error::ArithmeticOverflow(p)) => p,
        _ => panic!(
            "in-place-path offset overflow must be Err(ArithmeticOverflow), got {in_place_result:?}"
        ),
    };
    assert_eq!(
        in_place_overflow.context(),
        "RotatingKvCache::update_in_place: offset + S (S=1)"
    );
    assert_eq!(in_place_overflow.op_type(), "usize");
    assert_eq!(
        in_place_overflow.operands(),
        &[("offset", usize::MAX as u64), ("S", 1u64)]
    );
    assert_eq!(
        cache.meta_state(),
        meta_before,
        "overflow (in-place) must not partially mutate ring state"
    );
    assert_eq!(cache.state().unwrap().len(), state_len_before);
}
/// A `set_meta_state` call with an unparsable field must fail without
/// touching any of `(keep, max_size, offset, idx)`. A well-formed call
/// replaces all four.
#[test]
fn rotating_set_meta_state_is_atomic_on_malformed_input() {
    let to_meta = |fields: [&str; 4]| -> Vec<String> {
        fields.iter().map(|field| field.to_string()).collect()
    };

    let mut cache = RotatingKvCache::new(8, 4);
    for token in (0..5).map(|i| kv(&[i as f32])) {
        cache.update(&token, &token).unwrap();
    }
    let snapshot = cache.meta_state();
    assert_eq!(snapshot.len(), 4, "(keep, max_size, offset, idx)");

    let malformed = cache.set_meta_state(&to_meta(["4", "8", "not-a-number", "2"]));
    assert!(
        malformed.is_err(),
        "malformed offset must make set_meta_state fail"
    );
    assert_eq!(
        cache.meta_state(),
        snapshot,
        "a failed set_meta_state must leave keep/max_size/offset/idx unchanged"
    );

    let replacement = to_meta(["2", "16", "9", "3"]);
    cache.set_meta_state(&replacement).unwrap();
    assert_eq!(cache.meta_state(), replacement);
}
/// Mask construction is exact for small sizes. It must reject a key range
/// whose stop exceeds 2^24, where f32 stops representing every integer.
/// A stop of exactly 2^24 is accepted. The rotating cache's `make_mask`
/// must propagate the same rejection.
#[test]
fn iarange_mask_exact_in_range_and_rejects_past_f32_limit() {
    // Every integer in [0, 2^24] is exactly representable as f32; 2^24 + 1 is not.
    const F32_EXACT_INT_MAX: usize = 1usize << 24;

    let mut square = create_causal_mask(3, 0, None).unwrap();
    assert_eq!(square.shape(), vec![3, 3]);
    assert_eq!(
        square.to_vec::<bool>().unwrap(),
        vec![true, false, false, true, true, false, true, true, true]
    );

    let mut shifted = create_causal_mask(2, 3, None).unwrap();
    assert_eq!(shifted.shape(), vec![2, 5]);
    assert_eq!(
        shifted.to_vec::<bool>().unwrap(),
        vec![true, true, true, true, false, true, true, true, true, true]
    );

    // offset + N = (2^24 - 1) + 2 = 2^24 + 1.
    let past_limit = create_causal_mask(2, F32_EXACT_INT_MAX - 1, None);
    assert!(
        past_limit.is_err(),
        "stop == 2^24+1 must be rejected (it rounds to 2^24 -> short/corrupt mask), not returned Ok"
    );
    // offset + N = (2^24 - 1) + 1 = 2^24.
    let at_limit = create_causal_mask(1, F32_EXACT_INT_MAX - 1, None);
    assert!(
        at_limit.is_ok(),
        "stop == 2^24 is exactly representable in f32 (cast + every index exact) -> must be Ok"
    );

    let ring_capacity = F32_EXACT_INT_MAX + 8;
    let mut ring = RotatingKvCache::new(ring_capacity, 4);
    ring.set_meta_state(&[
        "4".to_string(),
        ring_capacity.to_string(),
        (F32_EXACT_INT_MAX + 4).to_string(),
        "5".to_string(),
    ])
    .unwrap();
    let ring_mask = ring.make_mask(1, Some(2), false);
    assert!(
        ring_mask.is_err(),
        "rotating make_mask must propagate the iarange f32-limit Err, not build a corrupted ring mask"
    );
}
/// `offset + N` overflowing `usize` must surface as `ArithmeticOverflow`,
/// both without and with a window, rather than panicking or wrapping.
#[test]
fn create_causal_mask_offset_plus_n_overflow_is_err_not_panic() {
    // (N, window_size, failure message prefix)
    let cases = [
        (
            2usize,
            None,
            "offset + N overflow must be Err::ArithmeticOverflow (no debug panic, no release wrap)",
        ),
        (
            3,
            Some(4),
            "windowed create_causal_mask must also reject offset + N overflow as \
             ArithmeticOverflow before any range",
        ),
    ];
    for &(n, window, failure) in cases.iter() {
        let result = create_causal_mask(n, usize::MAX, window);
        match &result {
            Err(Error::ArithmeticOverflow(p)) => {
                assert_eq!(p.context(), "create_causal_mask: offset + N");
                assert_eq!(p.op_type(), "usize");
                assert_eq!(
                    p.operands(),
                    &[("offset", usize::MAX as u64), ("N", n as u64)]
                );
            }
            _ => panic!("{}, got {:?}", failure, result),
        }
    }
}
/// A window at least as large as `offset + N` (including `usize::MAX`) must
/// produce the plain causal mask. A smaller window still applies the band.
#[test]
fn create_causal_mask_huge_window_is_unwindowed_noop() {
    let lower_triangle: Vec<bool> = (0..4usize)
        .flat_map(|row| (0..4usize).map(move |col| col <= row))
        .collect();

    let mut unwindowed = create_causal_mask(4, 0, None).unwrap();
    let reference = unwindowed.to_vec::<bool>().unwrap();
    assert_eq!(reference, lower_triangle);

    let mut exact_window = create_causal_mask(4, 0, Some(4)).unwrap();
    assert_eq!(exact_window.shape(), vec![4, 4]);
    assert_eq!(
        exact_window.to_vec::<bool>().unwrap(),
        reference,
        "window_size == total must be the unwindowed causal mask (mlx-lm no-op)"
    );

    let mut max_window = create_causal_mask(4, 0, Some(usize::MAX)).unwrap();
    assert_eq!(max_window.shape(), vec![4, 4]);
    assert_eq!(
        max_window.to_vec::<bool>().unwrap(),
        reference,
        "window_size = usize::MAX must be the unwindowed causal mask (no lossy i32 wrap)"
    );

    let mut offset_plain = create_causal_mask(2, 3, None).unwrap();
    let mut offset_windowed = create_causal_mask(2, 3, Some(5)).unwrap();
    assert_eq!(offset_windowed.shape(), vec![2, 5]);
    assert_eq!(
        offset_windowed.to_vec::<bool>().unwrap(),
        offset_plain.to_vec::<bool>().unwrap(),
        "offset!=0: window_size == offset+N must equal the unwindowed causal mask"
    );

    let banded: Vec<bool> = (0..4usize)
        .flat_map(|row| (0..4usize).map(move |col| col <= row && row - col < 2))
        .collect();
    let mut narrow = create_causal_mask(4, 0, Some(2)).unwrap();
    assert_eq!(
        narrow.to_vec::<bool>().unwrap(),
        banded,
        "window_size < total must still apply the sliding window"
    );
}
/// A rotating cache restored with a saturated offset must report
/// `offset + N` overflow in `make_mask` for N > 1, whatever the return-array
/// flag. A fresh cache still takes the normal causal decision.
#[test]
fn rotating_make_mask_n_gt_1_offset_plus_n_overflow_is_err_not_panic() {
    // meta_state fields are (keep, max_size, offset, idx).
    let saturated = from_state(
        "RotatingKVCache",
        vec![kv(&[0.0]), kv(&[0.0])],
        &[
            "0".to_string(),
            usize::MAX.to_string(),
            usize::MAX.to_string(),
            "0".to_string(),
        ],
    )
    .unwrap();

    // (return_array, wrong-error message prefix, unexpected-Ok message)
    let cases = [
        (
            false,
            "RotatingKvCache::make_mask N>1 offset+N overflow must be Err::ArithmeticOverflow \
             (no panic, no wrap-then-wrong-Causal-decision)",
            "RotatingKvCache::make_mask N>1 offset+N overflow must be Err, got Ok(<MaskMode>)",
        ),
        (
            true,
            "RotatingKvCache::make_mask N>1 offset+N overflow must be Err::ArithmeticOverflow \
             even with return_array=true",
            "RotatingKvCache::make_mask N>1 offset+N overflow (return_array=true) must be Err, \
             got Ok(<MaskMode>)",
        ),
    ];
    for &(return_array, err_why, ok_why) in cases.iter() {
        let outcome = saturated.make_mask(2, None, return_array);
        match &outcome {
            Err(Error::ArithmeticOverflow(p)) => {
                assert_eq!(p.context(), "RotatingKvCache::make_mask: offset + N");
                assert_eq!(p.op_type(), "usize");
                // NOTE(review): the reported offset is usize::MAX - 1 although
                // meta_state carried usize::MAX; presumably adjusted on restore.
                assert_eq!(
                    p.operands(),
                    &[("offset", (usize::MAX - 1) as u64), ("N", 2u64)]
                );
            }
            Err(other) => panic!("{}, got Err({:?})", err_why, other),
            Ok(_) => panic!("{}", ok_why),
        }
    }

    let fresh = RotatingKvCache::new(8, 4);
    assert!(
        matches!(fresh.make_mask(3, None, false).unwrap(), MaskMode::Causal),
        "valid (non-overflowing) N>1 decision must be unchanged (cache.py:560-563 -> \"causal\")"
    );
}
/// For N > 1, a window of `Some(0)` behaves like `None` and like
/// `Some(max_size)`. This holds for the small-offset causal decision and
/// for the large-offset windowed Array alike.
#[test]
fn rotating_make_mask_n_gt_1_window_size_zero_is_max_size() {
    // A (max_size 8, keep 4) rotating cache fed `tokens` single tokens.
    let primed = |tokens: usize| {
        let mut cache = RotatingKvCache::new(8, 4);
        for id in 0..tokens {
            let token = kv(&[id as f32]);
            cache.update(&token, &token).unwrap();
        }
        cache
    };

    let small = primed(2);
    assert_eq!(small.offset(), 2);
    let small_offset_cases = [
        (
            None,
            "N>1 small-offset, window_size None -> \"causal\" (cache.py:563)",
        ),
        (
            Some(0),
            "N>1 small-offset, window_size Some(0) must match None (0 is falsy -> max_size, cache.py:558)",
        ),
        (
            Some(8),
            "N>1 small-offset, window_size Some(8==max_size) -> \"causal\" (cache.py:563)",
        ),
    ];
    for &(window, why) in small_offset_cases.iter() {
        assert!(
            matches!(small.make_mask(3, window, false).unwrap(), MaskMode::Causal),
            "{}",
            why
        );
    }

    let large = primed(10);
    assert_eq!(large.offset(), 10);
    let unpack = |mode: MaskMode| -> (Vec<usize>, Vec<bool>) {
        match mode {
            MaskMode::Array(mut a) => (a.shape(), a.to_vec::<bool>().unwrap()),
            _ => panic!("large-offset N>1 must be a windowed Array (cache.py:561)"),
        }
    };
    let (none_shape, none_bits) = unpack(large.make_mask(3, None, false).unwrap());
    let (zero_shape, zero_bits) = unpack(large.make_mask(3, Some(0), false).unwrap());
    let (max_shape, max_bits) = unpack(large.make_mask(3, Some(8), false).unwrap());
    assert_eq!(none_shape, vec![3, 10], "cache.py:561 shape [N, offset+N]");
    assert_eq!(
        (zero_shape, zero_bits),
        (none_shape.clone(), none_bits.clone()),
        "Some(0) must yield the SAME windowed mask as None (cache.py:558 `or max_size`)"
    );
    assert_eq!(
        (max_shape, max_bits),
        (none_shape, none_bits),
        "Some(max_size) must yield the SAME windowed mask as None (cache.py:558)"
    );
}
/// Non-batch caches are not batch-positioned. Their default `rope_offset`
/// is a scalar equal to the current raw offset and follows further updates.
#[test]
fn rope_offset_default_is_scalar_for_non_batch_caches() {
    let mut standard = StandardKvCache::new();
    assert!(
        standard.as_batch_positioned().is_none(),
        "StandardKvCache must not be a batch-positioned refinement"
    );

    standard
        .update(&kv(&[0.0, 1.0, 2.0]), &kv(&[0.0, 1.0, 2.0]))
        .unwrap();
    if let RopeOffset::Scalar(o) = standard.rope_offset().unwrap() {
        assert_eq!(o, 3, "Standard rope_offset == offset (3)");
    } else {
        panic!("non-batch Standard cache must still yield RopeOffset::Scalar");
    }

    standard.update(&kv(&[3.0, 4.0]), &kv(&[3.0, 4.0])).unwrap();
    if let RopeOffset::Scalar(o) = standard.rope_offset().unwrap() {
        assert_eq!(o, 5, "Standard rope_offset tracks offset (5)");
    } else {
        panic!("non-batch Standard cache must still yield RopeOffset::Scalar");
    }

    let mut rotating = RotatingKvCache::new(8, 4);
    assert!(
        rotating.as_batch_positioned().is_none(),
        "RotatingKvCache must not be a batch-positioned refinement"
    );
    for token in (0..5).map(|id| kv(&[id as f32])) {
        rotating.update(&token, &token).unwrap();
    }
    if let RopeOffset::Scalar(o) = rotating.rope_offset().unwrap() {
        assert_eq!(o, 5, "Rotating rope_offset == raw offset (5)");
    } else {
        panic!("non-batch Rotating cache must still yield RopeOffset::Scalar");
    }
}
/// Every tag that maps to the no-meta standard cache must reject a
/// non-empty meta_state as `LengthMismatch` with expected count 0. With an
/// empty meta_state it must restore the offset from the state and report an
/// empty meta_state.
#[test]
fn from_state_no_meta_cache_rejects_truthy_meta_state() {
    const NO_META_CTX: &str = "KvCache::set_meta_state: meta_state value count for a no-meta cache (mirrors mlx-lm `_BaseCache.meta_state` setter cache.py:142-145)";
    const NO_META_KINDS: [&str; 4] = [
        "KVCache",
        "ConcatenateKVCache",
        "KVCacheSimple",
        "StandardKvCache",
    ];
    // A three-token [keys, values] state.
    let fresh_state = || vec![kv(&[0.0, 1.0, 2.0]), kv(&[0.0, 1.0, 2.0])];

    for &kind in NO_META_KINDS.iter() {
        match from_state(kind, fresh_state(), &["x".to_string()]) {
            Err(Error::LengthMismatch(p)) => {
                assert_eq!(p.context(), NO_META_CTX, "kind {kind}: context mismatch");
                assert_eq!(p.expected(), 0, "kind {kind}: expected count must be 0");
                assert_eq!(p.actual(), 1, "kind {kind}: actual count must be 1");
            }
            Err(other) => panic!(
                "kind {kind}: truthy meta_state must be rejected as Error::LengthMismatch \
                 (mirroring mlx-lm _BaseCache.meta_state setter cache.py:142-145), got {other:?}"
            ),
            Ok(_) => panic!(
                "kind {kind}: truthy meta_state must NOT silently round-trip (issue #76 — \
                 a permissive Ok trait default would let it through)"
            ),
        }

        let three_values = ["x".to_string(), "y".to_string(), "z".to_string()];
        match from_state(kind, fresh_state(), &three_values) {
            Err(Error::LengthMismatch(p)) => {
                assert_eq!(p.context(), NO_META_CTX, "kind {kind}: context mismatch");
                assert_eq!(p.expected(), 0, "kind {kind}: expected count must be 0");
                assert_eq!(p.actual(), 3, "kind {kind}: actual count must be 3");
            }
            Err(other) => panic!(
                "kind {kind}: multi-value truthy meta_state must also be rejected as \
                 Error::LengthMismatch, got Err({other:?})"
            ),
            Ok(_) => panic!("kind {kind}: multi-value truthy meta_state must NOT silently round-trip"),
        }
    }

    for &kind in NO_META_KINDS.iter() {
        let restored = from_state(kind, fresh_state(), &[])
            .unwrap_or_else(|e| panic!("kind {kind}: empty meta_state must succeed, got {e:?}"));
        assert_eq!(
            restored.offset(),
            3,
            "kind {kind}: offset restored from state.shape[2]"
        );
        assert!(!restored.is_empty(), "kind {kind}: state was non-empty");
        assert!(
            restored.meta_state().is_empty(),
            "kind {kind}: no-meta cache must report empty meta_state"
        );
    }
}
/// A single-slot rotating cache (max_size 1, keep 0) holding a batch-1 token
/// has its offset and idx reset to 0. It must then reject a batch-2 write as
/// `ShapePairMismatch` and still accept a correctly shaped token afterwards.
#[test]
fn rotating_set_seq_full_window_rejects_mismatched_batch_dim() {
    const WRITE_RHS_CTX: &str = "broadcast_write_rhs: keys write RHS non-broadcastable (mlx-lm \
         slice-assignment raises on non-broadcastable non-seq axes; seq-axis \
         target is the slice window length)";

    let mut cache = RotatingKvCache::new(1, 0);
    let seed = kv(&[7.0]);
    cache.update(&seed, &seed).unwrap();
    assert_eq!(cache.offset(), 1);

    // (keep, max_size, offset, idx) = (0, 1, 0, 0)
    let rewind: Vec<String> = ["0", "1", "0", "0"]
        .iter()
        .map(|field| field.to_string())
        .collect();
    cache.set_meta_state(&rewind).unwrap();
    assert_eq!(cache.offset(), 0);

    let batch_two = Array::from_slice::<f32>(&[9.0, 9.5], &(2usize, 1, 1, 1)).unwrap();
    let result = cache.update(&batch_two, &batch_two);
    let mismatch = match &result {
        Err(Error::ShapePairMismatch(p)) => p,
        _ => panic!(
            "rotating full-window set_seq must reject batch-axis mismatch on the public \
             update API as ShapePairMismatch (closes #78), got {result:?}"
        ),
    };
    assert_eq!(mismatch.context(), WRITE_RHS_CTX);
    assert_eq!(mismatch.expected(), &[1usize, 1, 1, 1]);
    assert_eq!(mismatch.actual(), &[2usize, 1, 1, 1]);

    let valid = kv(&[42.0]);
    cache.update(&valid, &valid).unwrap();
    assert_eq!(cache.offset(), 1);
}
/// Same setup as the batch-dim case, but the rejected RHS differs in
/// n_kv_heads (axis 1) or head_dim (axis 3). Each mismatch is a
/// `ShapePairMismatch`, and the cache still accepts the seed afterwards.
#[test]
fn rotating_set_seq_full_window_rejects_mismatched_heads_and_head_dim() {
    const WRITE_RHS_CTX: &str = "broadcast_write_rhs: keys write RHS non-broadcastable (mlx-lm \
         slice-assignment raises on non-broadcastable non-seq axes; seq-axis \
         target is the slice window length)";

    let wrong_heads = Array::from_slice::<f32>(&[9.0, 9.5, 9.7], &(1usize, 3, 1, 1)).unwrap();
    let wrong_head_dim = Array::from_slice::<f32>(&[9.0, 9.5], &(1usize, 1, 1, 2)).unwrap();
    let seed = kv(&[7.0]);
    // (keep, max_size, offset, idx) = (0, 1, 0, 0)
    let rewind: Vec<String> = ["0", "1", "0", "0"]
        .iter()
        .map(|field| field.to_string())
        .collect();

    // (offending RHS, its shape, axis name used in the failure message)
    let cases = [
        (&wrong_heads, [1usize, 3, 1, 1], "n_kv_heads"),
        (&wrong_head_dim, [1usize, 1, 1, 2], "head_dim"),
    ];
    for &(bad, actual_shape, axis) in cases.iter() {
        let mut cache = RotatingKvCache::new(1, 0);
        cache.update(&seed, &seed).unwrap();
        cache.set_meta_state(&rewind).unwrap();

        let result = cache.update(bad, bad);
        match &result {
            Err(Error::ShapePairMismatch(p)) => {
                assert_eq!(p.context(), WRITE_RHS_CTX);
                assert_eq!(p.expected(), &[1usize, 1, 1, 1]);
                assert_eq!(p.actual(), &actual_shape);
            }
            _ => panic!(
                "rotating full-window set_seq must reject {axis} mismatch as \
                 ShapePairMismatch, got {result:?}"
            ),
        }

        cache.update(&seed, &seed).unwrap();
        assert_eq!(cache.offset(), 1);
    }
}
/// A `ChunkedKvCache` restored from a four-token `[keys, values]` state
/// accepts a further single-token update, ending at offset 5.
#[test]
fn set_seq_write_shape_compat_helper_semantics_via_chunked() {
    let mut cache = ChunkedKvCache::new(None);
    let buffer = kv(&[10.0, 11.0, 12.0, 13.0]);
    let restored_state = vec![buffer.try_clone().unwrap(), buffer.try_clone().unwrap()];
    cache.set_state(restored_state).unwrap();

    let token = kv(&[99.0]);
    cache.update(&token, &token).unwrap();
    assert_eq!(cache.offset(), 5);
}
/// A single-slot rotating cache seeded with batch 2 must broadcast a batch-1
/// RHS into its buffer when writing the full window. The returned keys and
/// values keep batch dim 2.
#[test]
fn rotating_set_seq_full_window_broadcasts_size_one_rhs() {
    let batch_two_seed = Array::from_slice::<f32>(&[10.0, 20.0], &(2usize, 1, 1, 1)).unwrap();
    let batch_one = kv(&[99.0]);

    let mut cache = RotatingKvCache::new(1, 0);
    cache.update(&batch_two_seed, &batch_two_seed).unwrap();
    assert_eq!(cache.offset(), 1);

    // (keep, max_size, offset, idx) = (0, 1, 0, 0)
    let rewind: Vec<String> = ["0", "1", "0", "0"]
        .iter()
        .map(|field| field.to_string())
        .collect();
    cache.set_meta_state(&rewind).unwrap();
    assert_eq!(cache.offset(), 0);

    let (keys, values) = cache.update(&batch_one, &batch_one).unwrap();
    assert_eq!(
        keys.shape(),
        vec![2, 1, 1, 1],
        "rotating size-1 batch RHS must broadcast to PRESERVE buffer batch dim 2 \
         (NOT silently shrink to [1, 1, 1, 1])"
    );
    assert_eq!(values.shape(), vec![2, 1, 1, 1]);
}
/// A failed full-window write (batch-2 RHS into a batch-1 buffer) must not
/// partially mutate the cache. Offset, meta_state and state length must be
/// unchanged, and a later valid token still advances the offset.
#[test]
fn rotating_update_in_place_partial_mutation_on_set_seq_err_is_rejected() {
    const WRITE_RHS_CTX: &str = "broadcast_write_rhs: keys write RHS non-broadcastable (mlx-lm \
         slice-assignment raises on non-broadcastable non-seq axes; seq-axis \
         target is the slice window length)";

    let mut cache = RotatingKvCache::new(1, 0);
    let prefill = kv(&[10.0, 11.0]);
    cache.update(&prefill, &prefill).unwrap();
    assert_eq!(cache.offset(), 2);
    let meta_before = cache.meta_state();
    // (keep, max_size, offset, idx)
    assert_eq!(meta_before, vec!["0", "1", "2", "2"]);
    let state_len_before = cache.state().unwrap().len();

    let batch_two = Array::from_slice::<f32>(&[9.0, 9.5], &(2usize, 1, 1, 1)).unwrap();
    let result = cache.update(&batch_two, &batch_two);
    let mismatch = match &result {
        Err(Error::ShapePairMismatch(p)) => p,
        _ => panic!(
            "non-broadcastable full-window RHS must be Err::ShapePairMismatch on the \
             public update API (follow-up to #78), got {result:?}"
        ),
    };
    assert_eq!(mismatch.context(), WRITE_RHS_CTX);
    assert_eq!(mismatch.expected(), &[1usize, 1, 1, 1]);
    assert_eq!(mismatch.actual(), &[2usize, 1, 1, 1]);

    assert_eq!(cache.offset(), 2, "offset must NOT advance on a failed update");
    assert_eq!(
        cache.meta_state(),
        meta_before,
        "no partial trim / cursor reset on a failed update"
    );
    assert_eq!(cache.state().unwrap().len(), state_len_before);

    let valid = kv(&[7.0]);
    cache.update(&valid, &valid).unwrap();
    assert_eq!(cache.offset(), 3);
}
/// `as_any_mut` on a boxed `StandardKvCache` must let `downcast_mut` reach the
/// concrete type, and must refuse a downcast to an unrelated cache type.
#[test]
fn kvcache_as_any_mut_downcasts_to_standard() {
    let mut boxed: Box<dyn KvCache> = Box::new(StandardKvCache::new());

    let reaches_concrete = boxed
        .as_any_mut()
        .downcast_mut::<StandardKvCache>()
        .is_some();
    assert!(
        reaches_concrete,
        "as_any_mut + downcast_mut must reach the concrete type"
    );

    let wrong_is_none = boxed
        .as_any_mut()
        .downcast_mut::<RotatingKvCache>()
        .is_none();
    assert!(wrong_is_none, "wrong-type downcast must return None");
}
/// `as_any_mut` on a boxed `RotatingKvCache` downcasts to `RotatingKvCache`
/// and not to `StandardKvCache`.
#[test]
fn kvcache_as_any_mut_downcasts_to_rotating() {
    let mut boxed: Box<dyn KvCache> = Box::new(RotatingKvCache::new(8, 1));
    let as_rotating = boxed
        .as_any_mut()
        .downcast_mut::<RotatingKvCache>()
        .is_some();
    assert!(as_rotating);
    let as_standard = boxed
        .as_any_mut()
        .downcast_mut::<StandardKvCache>()
        .is_some();
    assert!(!as_standard);
}
/// `as_any_mut` on a boxed `ChunkedKvCache` must reach the concrete type and
/// must reject an unrelated cache type. The negative check mirrors the
/// standard/rotating downcast tests, which this one previously omitted.
#[test]
fn kvcache_as_any_mut_downcasts_to_chunked() {
    let mut boxed: Box<dyn KvCache> = Box::new(ChunkedKvCache::new(Some(8)));
    let any = boxed.as_any_mut();
    assert!(
        any.downcast_mut::<ChunkedKvCache>().is_some(),
        "as_any_mut + downcast_mut must reach the concrete type"
    );
    let any = boxed.as_any_mut();
    assert!(
        any.downcast_mut::<StandardKvCache>().is_none(),
        "wrong-type downcast must return None"
    );
}
/// Every in-tree `KvCache` implementation must support `as_any_mut`
/// downcasting. The box's concrete type downcasts successfully, and a
/// different cache type yields `None`.
#[test]
fn kvcache_as_any_mut_downcasts_for_all_8_in_tree_impls() {
    /// Asserts that `boxed` downcasts to `Concrete` but not to `Other`;
    /// `label` prefixes both failure messages.
    fn check<Concrete: std::any::Any, Other: std::any::Any>(
        label: &str,
        mut boxed: Box<dyn KvCache>,
    ) {
        assert!(
            boxed.as_any_mut().downcast_mut::<Concrete>().is_some(),
            "{label}: positive downcast must succeed"
        );
        assert!(
            boxed.as_any_mut().downcast_mut::<Other>().is_none(),
            "{label}: wrong-type downcast must return None"
        );
    }

    check::<StandardKvCache, RotatingKvCache>(
        "StandardKvCache",
        Box::new(StandardKvCache::new()),
    );
    check::<RotatingKvCache, StandardKvCache>(
        "RotatingKvCache",
        Box::new(RotatingKvCache::new(8, 1)),
    );
    check::<ChunkedKvCache, StandardKvCache>(
        "ChunkedKvCache",
        Box::new(ChunkedKvCache::new(Some(8))),
    );
    check::<StandardQuantizedKvCache, StandardKvCache>(
        "StandardQuantizedKvCache",
        Box::new(StandardQuantizedKvCache::new(64, 8).unwrap()),
    );
    check::<BatchKvCache, BatchRotatingKvCache>(
        "BatchKvCache",
        Box::new(BatchKvCache::new(&[0, 0])),
    );
    check::<BatchRotatingKvCache, BatchKvCache>(
        "BatchRotatingKvCache",
        Box::new(BatchRotatingKvCache::new(4, &[0, 0])),
    );
    check::<ArraysCache, StandardKvCache>("ArraysCache", Box::new(ArraysCache::new(4)));

    let child: Box<dyn KvCache> = Box::new(StandardKvCache::new());
    check::<CacheList, StandardKvCache>("CacheList", Box::new(CacheList::new(vec![child])));
}
/// `materialize` on an empty `StandardKvCache` is a no-op. On a populated
/// cache it leaves offset, stored byte size, and contents unchanged, and the
/// cache still accepts a further append afterwards.
#[test]
fn standard_materialize_noop_empty_and_eval_populated() {
    let prefill: [f32; 4] = [0.0, 1.0, 2.0, 3.0];
    let mut cache = StandardKvCache::new();

    // Empty cache: nothing to evaluate and nothing stored.
    cache.materialize().unwrap();
    assert!(cache.is_empty());
    assert_eq!(cache.offset(), 0);
    assert_eq!(cache.nbytes(), 0, "empty cache stores no buffer");

    cache.update(&kv(&prefill), &kv(&prefill)).unwrap();
    // 4 rows x one f32 per row x (keys + values) = 32 bytes.
    assert_eq!(
        cache.nbytes(),
        32,
        "StandardKvCache stores exactly offset-length buffers (no over-allocation)"
    );

    cache.materialize().unwrap();
    assert!(!cache.is_empty());
    assert_eq!(cache.offset(), 4);
    assert_eq!(
        cache.nbytes(),
        32,
        "materialize is value/shape-preserving: stored buffer stays 4 rows"
    );

    let state = cache.state().unwrap();
    assert_eq!(state.len(), 2);
    let mut keys = state[0].try_clone().unwrap();
    let mut values = state[1].try_clone().unwrap();
    assert_eq!(keys.to_vec::<f32>().unwrap(), prefill.to_vec());
    assert_eq!(values.to_vec::<f32>().unwrap(), prefill.to_vec());

    // A post-materialize append extends the same history.
    let (mut extended, _) = cache.update(&kv(&[4.0]), &kv(&[4.0])).unwrap();
    assert_eq!(cache.offset(), 5);
    assert_eq!(
        extended.to_vec::<f32>().unwrap(),
        vec![0.0, 1.0, 2.0, 3.0, 4.0]
    );
}
/// `materialize` on an empty `RotatingKvCache` is a no-op. On a populated
/// ring it preserves the over-allocated 8-row buffer, the meta state, and the
/// contents, and decode continues in the same ring afterwards.
#[test]
fn rotating_materialize_noop_empty_and_eval_populated() {
    let mut ring = RotatingKvCache::new(8, 4);

    ring.materialize().unwrap();
    assert!(ring.is_empty());
    assert_eq!(ring.offset(), 0);
    assert_eq!(ring.nbytes(), 0, "empty cache stores no buffer");

    // Three single-token decode steps: 0.0, 1.0, 2.0.
    for token in [0.0_f32, 1.0, 2.0] {
        let t = kv(&[token]);
        ring.update(&t, &t).unwrap();
    }
    // 8 rows x one f32 per row x (keys + values) = 64 bytes.
    assert_eq!(
        ring.nbytes(),
        64,
        "stored ring is the full 8-row buffer (over-allocated), not the 3-row state() slice"
    );
    assert_eq!(
        ring.state().unwrap()[0].shape(),
        vec![1, 1, 3, 1],
        "state() slices the full ring down to the 3-row offset view"
    );

    ring.materialize().unwrap();
    assert!(!ring.is_empty());
    assert_eq!(ring.offset(), 3);
    assert_eq!(ring.meta_state(), vec!["4", "8", "3", "3"]);
    assert_eq!(
        ring.nbytes(),
        64,
        "materialize is value/shape-preserving: stored ring stays 8 rows"
    );

    let snapshot = ring.state().unwrap();
    let mut keys = snapshot[0].try_clone().unwrap();
    assert_eq!(keys.to_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0]);

    let (mut window, _) = ring.update(&kv(&[3.0]), &kv(&[3.0])).unwrap();
    assert_eq!(ring.offset(), 4);
    assert_eq!(window.to_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0, 3.0]);
    assert_eq!(ring.meta_state(), vec!["4", "8", "4", "4"]);
    assert_eq!(ring.nbytes(), 64, "decode reuses the same 8-row ring");
}
/// Proves `StandardKvCache::materialize` really evaluates its buffers. It is
/// run on a thread whose streams were cleared, and the test expects the
/// cleared-stream panic, which a no-op `materialize` would never trigger.
#[test]
fn standard_materialize_is_not_a_noop_evals_buffers() {
    /// Extracts the text of a `&str` / `String` panic payload, or a
    /// placeholder for any other payload type.
    fn panic_text(payload: Box<dyn std::any::Any + Send>) -> String {
        if let Some(text) = payload.downcast_ref::<&'static str>() {
            return (*text).to_string();
        }
        match payload.downcast_ref::<String>() {
            Some(text) => text.clone(),
            None => "<non-string panic payload>".to_string(),
        }
    }

    let outcome = std::thread::spawn(|| {
        let mut cache = StandardKvCache::new();
        cache
            .update(&kv(&[0.0, 1.0, 2.0, 3.0]), &kv(&[0.0, 1.0, 2.0, 3.0]))
            .unwrap();
        Stream::clear_current_thread_streams()
            .expect("test setup: clear_current_thread_streams must succeed");
        cache.materialize().unwrap();
    })
    .join();

    let msg = panic_text(outcome.expect_err("materialize must panic on a cleared-stream thread"));
    assert!(
        msg.contains("Stream::clear_current_thread_streams() was called on this thread"),
        "unexpected panic payload: {msg}"
    );
}
/// Proves `RotatingKvCache::materialize` really evaluates its buffers. It is
/// run on a thread whose streams were cleared, and the test expects the
/// cleared-stream panic, which a no-op `materialize` would never trigger.
#[test]
fn rotating_materialize_is_not_a_noop_evals_buffers() {
    /// Extracts the text of a `&str` / `String` panic payload, or a
    /// placeholder for any other payload type.
    fn panic_text(payload: Box<dyn std::any::Any + Send>) -> String {
        if let Some(text) = payload.downcast_ref::<&'static str>() {
            return (*text).to_string();
        }
        match payload.downcast_ref::<String>() {
            Some(text) => text.clone(),
            None => "<non-string panic payload>".to_string(),
        }
    }

    let outcome = std::thread::spawn(|| {
        let mut ring = RotatingKvCache::new(8, 4);
        for token in [0.0_f32, 1.0, 2.0] {
            let t = kv(&[token]);
            ring.update(&t, &t).unwrap();
        }
        Stream::clear_current_thread_streams()
            .expect("test setup: clear_current_thread_streams must succeed");
        ring.materialize().unwrap();
    })
    .join();

    let msg = panic_text(outcome.expect_err("materialize must panic on a cleared-stream thread"));
    assert!(
        msg.contains("Stream::clear_current_thread_streams() was called on this thread"),
        "unexpected panic payload: {msg}"
    );
}
/// `StandardKvCache::set_state` behavior by array count:
/// - 2 arrays (keys, values): offset is taken from the keys' seq length.
/// - 0 arrays: the cache is reset to empty.
/// - any other count: rejected with `OutOfRange`, leaving the cache unchanged.
#[test]
fn standard_set_state_direct_arity_and_offset() {
    let mut cache = StandardKvCache::new();

    let restored = kv(&[10.0, 11.0, 12.0, 13.0, 14.0]);
    cache
        .set_state(vec![
            restored.try_clone().unwrap(),
            restored.try_clone().unwrap(),
        ])
        .unwrap();
    assert_eq!(cache.offset(), 5, "offset = keys.shape[-2]");
    assert!(!cache.is_empty());
    let mut keys = cache.state().unwrap()[0].try_clone().unwrap();
    assert_eq!(
        keys.to_vec::<f32>().unwrap(),
        vec![10.0, 11.0, 12.0, 13.0, 14.0]
    );

    // Empty state clears the cache.
    cache.set_state(Vec::new()).unwrap();
    assert!(cache.is_empty());
    assert_eq!(cache.offset(), 0);
    assert!(cache.state().unwrap().is_empty());

    let one = cache.set_state(vec![kv(&[0.0])]);
    let Err(Error::OutOfRange(p)) = &one else {
        panic!("1-array state must be Err(OutOfRange), got {one:?}");
    };
    assert_eq!(p.context(), "StandardKvCache::set_state: state array count");
    assert_eq!(p.requirement(), "must be 0 or 2");
    assert_eq!(p.value(), "1");

    let three = cache.set_state(vec![kv(&[0.0]), kv(&[0.0]), kv(&[0.0])]);
    let Err(Error::OutOfRange(p)) = &three else {
        panic!("3-array state must be Err(OutOfRange), got {three:?}");
    };
    assert_eq!(p.value(), "3");

    // Rejected arities must not have populated the cache.
    assert!(cache.is_empty());
    assert_eq!(cache.offset(), 0);
}
/// `RotatingKvCache::set_state` replaces only the key/value buffers:
/// - It accepts 0 or 2 arrays and rejects any other count with `OutOfRange`.
/// - It never changes the meta state (keep, max_size, offset, _idx).
#[test]
fn rotating_set_state_direct_arity_and_buffer_only() {
    let decoded_meta = vec!["4", "8", "3", "3"];
    let mut ring = RotatingKvCache::new(8, 4);
    for token in [0.0_f32, 1.0, 2.0] {
        let t = kv(&[token]);
        ring.update(&t, &t).unwrap();
    }
    assert_eq!(ring.meta_state(), decoded_meta);

    let replacement = kv(&[20.0, 21.0, 22.0, 23.0, 24.0]);
    ring.set_state(vec![
        replacement.try_clone().unwrap(),
        replacement.try_clone().unwrap(),
    ])
    .unwrap();
    assert_eq!(
        ring.meta_state(),
        decoded_meta,
        "set_state must not change offset/idx (they come from set_meta_state)"
    );
    // offset is still 3, so state() shows only the first 3 rows of the new buffer.
    let mut keys = ring.state().unwrap()[0].try_clone().unwrap();
    assert_eq!(keys.to_vec::<f32>().unwrap(), vec![20.0, 21.0, 22.0]);

    ring.set_state(Vec::new()).unwrap();
    assert!(ring.is_empty());
    assert_eq!(
        ring.meta_state(),
        decoded_meta,
        "empty set_state nulls only the buffers; offset/idx/keep/max_size persist"
    );

    let single = ring.set_state(vec![kv(&[0.0])]);
    let Err(Error::OutOfRange(p)) = &single else {
        panic!("1-array rotating state must be Err(OutOfRange), got {single:?}");
    };
    assert_eq!(p.context(), "RotatingKvCache::set_state: state array count");
    assert_eq!(p.requirement(), "must be 0 or 2");
    assert_eq!(p.value(), "1");
}
/// `StandardKvCache::trim` reports how many tokens it removed:
/// - On an empty cache, or with `trim(0)`, it removes nothing.
/// - Otherwise it drops the newest rows from BOTH the key and value buffers.
#[test]
fn standard_trim_no_op_empty_and_buffer_sync() {
    let mut fresh = StandardKvCache::new();
    assert_eq!(fresh.trim(5).unwrap(), 0);
    assert!(fresh.is_empty());
    assert_eq!(fresh.offset(), 0);

    let tokens: [f32; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];
    let mut cache = StandardKvCache::new();
    cache.update(&kv(&tokens), &kv(&tokens)).unwrap();
    assert_eq!(cache.offset(), 5);

    // trim(0): offset and stored keys stay intact.
    assert_eq!(cache.trim(0).unwrap(), 0);
    assert_eq!(cache.offset(), 5);
    let mut untouched = cache.state().unwrap()[0].try_clone().unwrap();
    assert_eq!(untouched.to_vec::<f32>().unwrap(), tokens.to_vec());

    // trim(2): the two newest rows disappear from keys and values alike.
    assert_eq!(cache.trim(2).unwrap(), 2);
    assert_eq!(cache.offset(), 3);
    let state = cache.state().unwrap();
    let mut keys = state[0].try_clone().unwrap();
    let mut values = state[1].try_clone().unwrap();
    assert_eq!(keys.to_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0]);
    assert_eq!(values.to_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0]);
    assert_eq!(keys.shape(), vec![1, 1, 3, 1]);
}
/// `RotatingKvCache::trim` rewinds both offset and `_idx`, and clamps to the
/// number of tokens held. A ring whose offset has reached `max_size` is not
/// trimmable.
#[test]
fn rotating_trim_adjusts_offset_and_idx() {
    /// Feeds single-token updates with ids `0..count` into `cache`.
    fn feed(cache: &mut RotatingKvCache, count: i32) {
        for id in 0..count {
            let t = kv(&[id as f32]);
            cache.update(&t, &t).unwrap();
        }
    }

    let mut ring = RotatingKvCache::new(8, 2);
    feed(&mut ring, 5);
    assert!(ring.is_trimmable());
    assert_eq!(ring.meta_state(), vec!["2", "8", "5", "5"]);

    assert_eq!(ring.trim(2).unwrap(), 2);
    assert_eq!(ring.offset(), 3);
    assert_eq!(
        ring.meta_state(),
        vec!["2", "8", "3", "3"],
        "trim rewinds BOTH offset and _idx by the trimmed count"
    );

    // Over-trimming removes only what is there.
    assert_eq!(ring.trim(100).unwrap(), 3);
    assert_eq!(ring.offset(), 0);
    assert_eq!(ring.meta_state(), vec!["2", "8", "0", "0"]);

    let mut saturated = RotatingKvCache::new(4, 2);
    feed(&mut saturated, 4);
    assert_eq!(saturated.offset(), 4);
    assert!(
        !saturated.is_trimmable(),
        "offset == max_size -> not trimmable (cache.py is_trimmable)"
    );
}
/// While a `RotatingKvCache` is still in linear fill (offset 3, max_size 8),
/// `make_mask(1, Some(2), false)` must return `MaskMode::Array`. The array
/// spans offset + 1 positions, and only the newest two are `true`.
#[test]
fn rotating_make_mask_n1_window_unrolled_during_linear_fill() {
    let mut ring = RotatingKvCache::new(8, 0);
    for token in [0.0_f32, 1.0, 2.0] {
        let t = kv(&[token]);
        ring.update(&t, &t).unwrap();
    }
    assert_eq!(ring.offset(), 3);
    assert_eq!(ring.meta_state(), vec!["0", "8", "3", "3"]);

    let MaskMode::Array(mut mask) = ring.make_mask(1, Some(2), false).unwrap() else {
        panic!("N==1 windowed (offset >= ws, max_size > ws) must be an Array");
    };
    assert_eq!(mask.shape(), vec![4]);
    assert_eq!(
        mask.to_vec::<bool>().unwrap(),
        vec![false, false, true, true],
        "shift == mask_size -> s == 0 -> roll_1d identity (unrolled mask)"
    );
}