use super::*;
/// Builds a `[B, 1, S, 1]` f32 test array from `B` equal-length rows of `S` values each.
///
/// The rows are concatenated in order into one flat buffer before being handed
/// to `Array::from_slice`. Presumably that lays row `b` out as batch entry `b`;
/// confirm against `Array::from_slice`'s layout.
///
/// # Panics
///
/// - If `rows` is empty, because there is no row to take `S` from.
/// - If the rows do not all have the same length.
/// - If `Array::from_slice` returns an error.
fn kvb(rows: &[&[f32]]) -> Array {
    // Fail with a clear message instead of an index-out-of-bounds on `rows[0]`.
    assert!(!rows.is_empty(), "kvb needs at least one row");
    let batch = rows.len();
    let seq = rows[0].len();
    for row in rows {
        assert_eq!(row.len(), seq, "ragged test rows");
    }
    let data: Vec<f32> = rows.concat();
    Array::from_slice::<f32>(&data, &(batch, 1usize, seq, 1usize)).unwrap()
}
/// Reads an `i32` array back to the host as a `Vec<i32>`.
///
/// `to_vec` is called on a clone, so the caller's array is only borrowed
/// shared and is left untouched. The original binds the array mutably before
/// `to_vec`, which suggests `to_vec` needs `&mut self`.
///
/// Panics if cloning or the host read-back fails.
fn iv(a: &Array) -> Vec<i32> {
    a.try_clone()
        .unwrap()
        .to_vec::<i32>()
        .unwrap()
}
/// `batch_head_dim` reports a rank-2 input as `RankMismatch` (not a panic) for
/// every tensor name tried, and accepts a 4-D array built by `kvb`.
#[test]
fn batch_head_dim_generic_name_rank_error_and_ok() {
    // A rank-2 array: not the 4-D layout `batch_head_dim` expects.
    let rank2 = Array::from_slice::<f32>(&[1.0, 2.0], &(1usize, 2usize)).unwrap();

    let offset_err = batch_head_dim("offset", &rank2).unwrap_err();
    assert!(
        matches!(offset_err, Error::RankMismatch(_)),
        "non-4-D batch_head_dim must be RankMismatch, not panic"
    );

    // The error kind must not depend on which tensor name is passed in.
    for &name in &["keys", "values"] {
        let err = batch_head_dim(name, &rank2).unwrap_err();
        assert!(matches!(err, Error::RankMismatch(_)));
    }

    // A single-batch 4-D array is accepted.
    let rank4 = kvb(&[&[1.0, 2.0, 3.0]]);
    let dim = batch_head_dim("keys", &rank4).unwrap();
    assert_eq!(dim, 1);
}
/// `dynamic_roll` returns `RankMismatch` for a rank-2 input, and `OutOfRange`
/// when a 4-D input is rolled along axis 1.
#[test]
fn dynamic_roll_rank_and_axis_guards() {
    let zero_shift = Array::from_slice::<i32>(&[0], &(1usize, 1usize)).unwrap();

    // Rank guard: a rank-2 input (rolled along axis 2) is rejected.
    let rank2 = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1usize, 3usize)).unwrap();
    let rank_err = dynamic_roll(&rank2, &zero_shift, 2).unwrap_err();
    assert!(matches!(rank_err, Error::RankMismatch(_)));

    // Axis guard: a 4-D input rolled along axis 1 is rejected.
    let rank4 = kvb(&[&[10.0, 20.0, 30.0]]);
    let axis_err = dynamic_roll(&rank4, &zero_shift, 1).unwrap_err();
    assert!(matches!(axis_err, Error::OutOfRange(_)));
}
/// `empty_ivec` yields an `I32` array of shape `[0]` that reads back as an
/// empty vector.
#[test]
fn empty_ivec_builds_zero_length_i32() {
    let mut empty = empty_ivec();

    let shape = empty.shape();
    assert_eq!(shape, vec![0usize]);

    let dtype = empty.dtype().unwrap();
    assert_eq!(dtype, Dtype::I32);

    let host = empty.to_vec::<i32>().unwrap();
    assert!(host.is_empty());
}
/// `state_kv` returns `InvariantViolation` on a cache that has never been
/// updated. After one update it returns the stored keys and values.
#[test]
fn state_kv_empty_errors_then_returns_pair() {
    // Never updated: asking for the key/value pair must error, not panic.
    let fresh = BatchKvCache::new(&[0, 0]);
    let empty_err = fresh.state_kv().unwrap_err();
    assert!(
        matches!(empty_err, Error::InvariantViolation(_)),
        "state_kv on an empty cache must be InvariantViolation, not panic"
    );

    // One update with the same array for keys and values.
    let mut cache = BatchKvCache::new(&[0, 0]);
    let step = kvb(&[&[1.0, 2.0], &[3.0, 4.0]]);
    cache.update(&step, &step).unwrap();

    let (mut keys, mut values) = cache.state_kv().unwrap();
    assert_eq!(keys.shape(), vec![2, 1, 2, 1]);
    let expected = vec![1.0, 2.0, 3.0, 4.0];
    assert_eq!(keys.to_vec::<f32>().unwrap(), expected);
    assert_eq!(values.to_vec::<f32>().unwrap(), expected);
}
/// `nbytes` is 0 for a fresh cache. After one update it equals the combined
/// byte size of the key buffer and the value buffer.
#[test]
fn nbytes_sums_key_and_value_buffers() {
    let fresh = BatchKvCache::new(&[0, 0]);
    assert_eq!(fresh.nbytes(), 0, "empty cache has 0 bytes");

    let mut cache = BatchKvCache::new(&[0, 0]);
    let step = kvb(&[&[1.0, 2.0], &[3.0, 4.0]]);
    cache.update(&step, &step).unwrap();

    // Keys and values each hold 2 x 2 = 4 f32 elements; there are two buffers.
    let bytes_per_elem = std::mem::size_of::<f32>();
    let expected = 2 * (2 * 2) * bytes_per_elem;
    assert_eq!(
        cache.nbytes(),
        expected,
        "keys.nbytes + values.nbytes (each 16 bytes here)"
    );
}
/// After update + `prepare_right_padding`, `materialize` succeeds and leaves
/// the offsets and key contents readable with the expected values. On a cache
/// that was never updated, it succeeds and the cache stays empty.
#[test]
fn materialize_evals_all_live_buffers() {
    let mut cache = BatchKvCache::new(&[1, 0]);
    let step = kvb(&[&[5.0, 6.0], &[7.0, 8.0]]);
    cache.update(&step, &step).unwrap();
    cache.prepare_right_padding(&[1, 1]).unwrap();
    cache.materialize().unwrap();

    assert_eq!(cache.offset(), 2);
    let per_row = iv(&cache.batch_offset().unwrap());
    assert_eq!(per_row, vec![1, 2]);

    let (mut keys, _) = cache.state_kv().unwrap();
    let key_values = keys.to_vec::<f32>().unwrap();
    assert_eq!(key_values, vec![5.0, 6.0, 7.0, 8.0]);

    // Never updated: materialize still succeeds and the cache remains empty.
    let mut untouched = BatchKvCache::new(&[0]);
    untouched.materialize().unwrap();
    assert!(untouched.is_empty());
}
/// When `idx + S` overflows, `update` returns `ArithmeticOverflow`. Neither the
/// scalar index nor the per-row offsets are changed on that error path.
#[test]
fn update_idx_overflow_is_rejected_without_partial_mutation() {
    // One stored step, with the write index pinned at usize::MAX so that
    // adding any further step length overflows.
    let stored = kvb(&[&[1.0]]);
    let left_padding = ivec(&[0]).unwrap();
    let offsets_before = ivec(&[5]).unwrap();
    let mut cache = BatchKvCache {
        keys: Some(stored.try_clone().unwrap()),
        values: Some(stored.try_clone().unwrap()),
        left_padding,
        pad_lengths: vec![0],
        offset: offsets_before,
        idx: usize::MAX,
        right_padding: None,
        right_padding_host: None,
    };

    let incoming = kvb(&[&[2.0]]);
    let err = cache.update(&incoming, &incoming).unwrap_err();
    assert!(
        matches!(err, Error::ArithmeticOverflow(_)),
        "_idx + S overflow must be a recoverable ArithmeticOverflow"
    );

    // Nothing may have been partially applied.
    assert_eq!(cache.offset(), usize::MAX, "_idx unchanged on the Err path");
    let offsets_after = iv(&cache.batch_offset().unwrap());
    assert_eq!(offsets_after, vec![5], "offset unchanged");
}
/// `trim(0)` reports 0 trimmed steps and leaves the offset alone. Trimming a
/// cache that was never updated also reports 0.
#[test]
fn trim_zero_is_noop_early_return() {
    let mut cache = BatchKvCache::new(&[0, 0]);
    let step = kvb(&[&[1.0], &[2.0]]);
    cache.update(&step, &step).unwrap();

    let trimmed = cache.trim(0).unwrap();
    assert_eq!(trimmed, 0, "trim(0) returns 0 immediately");
    assert_eq!(cache.offset(), 1, "offset untouched by trim(0)");

    let mut never_written = BatchKvCache::new(&[0]);
    let trimmed_empty = never_written.trim(3).unwrap();
    assert_eq!(trimmed_empty, 0);
}
/// With no key/value buffers but `idx = 5`, `trim(2)` trims 2 steps. It
/// lowers both the scalar index and every per-row offset by 2, and the cache
/// stays empty.
#[test]
fn trim_with_no_buffer_decrements_idx_and_offset() {
    let mut cache = BatchKvCache {
        keys: None,
        values: None,
        left_padding: ivec(&[0, 0]).unwrap(),
        pad_lengths: vec![0, 0],
        offset: ivec(&[5, 5]).unwrap(),
        idx: 5,
        right_padding: None,
        right_padding_host: None,
    };

    let trimmed = cache.trim(2).unwrap();
    assert_eq!(trimmed, 2, "trimmed = min(2, _idx=5)");
    assert_eq!(cache.offset(), 3, "_idx 5 -> 3");

    let per_row = iv(&cache.batch_offset().unwrap());
    assert_eq!(per_row, vec![3, 3]);
    assert!(cache.is_empty(), "keys stayed None (the `_ => None` slice arm)");
}
/// With pending right padding `[1, 2]`, no host mirror and no buffers,
/// `finalize` does three things:
/// - subtracts the padding from the offsets;
/// - adds it to the left padding;
/// - leaves `pad_lengths` unchanged.
///
/// Calling it again changes nothing.
#[test]
fn finalize_with_none_host_mirror_and_no_buffer() {
    let mut cache = BatchKvCache {
        keys: None,
        values: None,
        left_padding: ivec(&[0, 0]).unwrap(),
        pad_lengths: vec![0, 0],
        offset: ivec(&[4, 4]).unwrap(),
        idx: 0,
        right_padding: Some(ivec(&[1, 2]).unwrap()),
        right_padding_host: None,
    };

    cache.finalize().unwrap();

    let offsets = iv(&cache.batch_offset().unwrap());
    assert_eq!(offsets, vec![3, 2], "offset -= padding");
    let left_pad = iv(&cache.left_padding_arr().unwrap());
    assert_eq!(left_pad, vec![1, 2], "lp += padding");
    assert_eq!(
        cache.pad_lengths(),
        &[0, 0],
        "None host mirror -> pad_lengths preserved (line 400 arm)"
    );

    // A second finalize must not subtract the padding again.
    cache.finalize().unwrap();
    let offsets_again = iv(&cache.batch_offset().unwrap());
    assert_eq!(offsets_again, vec![3, 2]);
}
/// On a cache that was never updated, `finalize` after
/// `prepare_right_padding(&[1, 2])` updates both the device-side left padding
/// and the host-side `pad_lengths` to `[1, 2]`.
#[test]
fn finalize_no_buffer_refreshes_host_mirror() {
    let mut cache = BatchKvCache::new(&[0, 0]);
    cache.prepare_right_padding(&[1, 2]).unwrap();
    cache.finalize().unwrap();

    let left_pad = iv(&cache.left_padding_arr().unwrap());
    assert_eq!(left_pad, vec![1, 2]);

    let host_lengths = cache.pad_lengths();
    assert_eq!(
        host_lengths,
        &[1, 2],
        "host mirror updated elementwise (B==B arm)"
    );
}
/// `copy` yields a cache with the same offset, size, per-row offsets and key
/// contents as the source. Resetting the copy via `set_state` leaves the
/// source's offsets and left padding as they were.
#[test]
fn copy_clones_all_buffers_independently() {
    let mut source = BatchKvCache::new(&[1, 0]);
    let step = kvb(&[&[10.0, 20.0], &[30.0, 40.0]]);
    source.update(&step, &step).unwrap();
    source.prepare_right_padding(&[1, 1]).unwrap();

    let mut dup = source.copy().unwrap();

    // Same logical state as the source.
    assert_eq!(dup.offset(), 2);
    assert!(!dup.is_empty());
    assert_eq!(dup.nbytes(), source.nbytes());
    let dup_offsets = iv(&dup
        .as_batch_positioned()
        .unwrap()
        .batch_offset()
        .unwrap());
    assert_eq!(dup_offsets, vec![1, 2]);

    // Same buffer contents.
    let snapshot = dup.state().unwrap();
    assert_eq!(snapshot.len(), 4, "[keys, values, offset, left_padding]");
    let mut dup_keys = snapshot[0].try_clone().unwrap();
    assert_eq!(
        dup_keys.to_vec::<f32>().unwrap(),
        vec![10.0, 20.0, 30.0, 40.0],
        "copied keys are an exact independent duplicate"
    );

    // Independence: wiping the copy must not disturb the source.
    let source_offsets = iv(&source.batch_offset().unwrap());
    let source_left_pad = iv(&source.left_padding_arr().unwrap());
    dup.set_state(Vec::new()).unwrap();
    assert!(dup.is_empty(), "copy reset independently");
    assert_eq!(
        iv(&source.batch_offset().unwrap()),
        source_offsets,
        "original offset untouched by mutating the copy"
    );
    assert_eq!(
        iv(&source.left_padding_arr().unwrap()),
        source_left_pad,
        "original left_padding untouched by mutating the copy"
    );
}
/// Copying a cache that was never updated gives a cache with these properties:
/// - it is empty, with 0 offset, 0 bytes and an empty state;
/// - its per-row offsets are the negated left padding.
#[test]
fn copy_of_empty_cache_is_empty() {
    let source = BatchKvCache::new(&[2, 0, 1]);
    let dup = source.copy().unwrap();

    assert!(dup.is_empty());
    assert_eq!(dup.offset(), 0);
    assert_eq!(dup.nbytes(), 0);
    let snapshot = dup.state().unwrap();
    assert!(snapshot.is_empty());

    let offsets = iv(&dup
        .as_batch_positioned()
        .unwrap()
        .batch_offset()
        .unwrap());
    assert_eq!(
        offsets,
        vec![-2, 0, -1],
        "copied empty cache preserves -left_padding"
    );
}
/// A `usize::MAX` second argument (the offset, per the assertion message) makes
/// `offset + N` overflow. The mask builder must report this as
/// `ArithmeticOverflow` rather than panicking.
#[test]
fn causal_mask_batched_offset_overflow_is_rejected() {
    let result = create_causal_mask_batched(1, usize::MAX, None, None, None);
    assert!(
        matches!(result.unwrap_err(), Error::ArithmeticOverflow(_)),
        "offset + N overflow must be ArithmeticOverflow, not panic"
    );
}
/// Checks the windowed causal mask in two cases:
/// - window 2 over 4 positions gives a banded `[4, 4]` mask;
/// - a window larger than the sequence gives the plain lower-triangular mask.
#[test]
fn causal_mask_batched_windowed_term() {
    // Read a boolean mask back to the host as 0/1 bytes.
    let to_bits = |mask: &mut Array| -> Vec<u8> {
        mask.to_vec::<bool>()
            .unwrap()
            .into_iter()
            .map(u8::from)
            .collect()
    };

    let mut windowed = create_causal_mask_batched(4, 0, Some(2), None, None).unwrap();
    assert_eq!(windowed.shape(), vec![4, 4], "no batch term -> [N, total]");
    let expected_windowed: Vec<u8> = vec![
        1, 0, 0, 0,
        1, 1, 0, 0,
        0, 1, 1, 0,
        0, 0, 1, 1,
    ];
    assert_eq!(to_bits(&mut windowed), expected_windowed);

    let mut wide = create_causal_mask_batched(4, 0, Some(99), None, None).unwrap();
    let plain_causal: Vec<u8> = vec![
        1, 0, 0, 0,
        1, 1, 0, 0,
        1, 1, 1, 0,
        1, 1, 1, 1,
    ];
    assert_eq!(
        to_bits(&mut wide),
        plain_causal,
        "window_size >= total is a no-op (plain causal)"
    );
}
/// Passing per-row right padding `[0, 1]` gives a `[B, 1, N, total]` mask.
/// Row 0 is plain causal; row 1 also masks out its last key column.
#[test]
fn causal_mask_batched_right_padding_term() {
    let right_pad = ivec(&[0, 1]).unwrap();
    let mut mask = create_causal_mask_batched(3, 0, None, Some(&right_pad), None).unwrap();
    assert_eq!(
        mask.shape(),
        vec![2, 1, 3, 3],
        "right_padding -> [B,1,N,total]"
    );

    let bits: Vec<u8> = mask
        .to_vec::<bool>()
        .unwrap()
        .iter()
        .map(|&on| u8::from(on))
        .collect();
    let expected: Vec<u8> = vec![
        // batch 0: no right padding -> plain causal
        1, 0, 0,
        1, 1, 0,
        1, 1, 1,
        // batch 1: one padded column -> last column never attended
        1, 0, 0,
        1, 1, 0,
        1, 1, 0,
    ];
    assert_eq!(bits, expected);
}