use super::*;
use crate::{Error, array::Array};
/// Smallest valid KV fixture: one `0.0` element in a `[1, 1, 1, 1]` f32 tensor.
fn kv1() -> Array {
    let unit_shape = (1usize, 1, 1, 1);
    Array::from_slice::<f32>(&[0.0], &unit_shape).unwrap()
}
/// An `end` offset one past `i32::MAX` must be rejected as an arithmetic
/// overflow whose context and operands both identify the offending `end`.
#[test]
fn slice_seq_rejects_end_above_i32_max() {
    let a = kv1();
    let bad_end = (i32::MAX as usize) + 1;

    let payload = match slice_seq(&a, 0, bad_end) {
        Err(Error::ArithmeticOverflow(p)) => p,
        other => panic!("expected Err(ArithmeticOverflow), got {other:?}"),
    };

    let ctx = payload.context();
    assert!(
        ctx.contains("end") && ctx.contains("i32::MAX"),
        "expected context to name `end` and `i32::MAX`, got: {:?}",
        ctx
    );

    // The payload should record the rejected value under the name "end".
    let ops = payload.operands();
    let has_value = ops
        .iter()
        .any(|(name, value)| *name == "end" && *value == bad_end as u64);
    assert!(
        has_value,
        "expected operands to include `end` = {bad_end}, got: {:?}",
        ops
    );
}
/// A `start` offset past `i32::MAX` must be rejected as an arithmetic overflow.
#[test]
fn slice_seq_rejects_start_above_i32_max() {
    let a = kv1();
    let bad_start = (i32::MAX as usize) + 1;

    // start == end here, so either offset may be the one reported first;
    // the assertions below accept whichever name the context uses.
    let payload = match slice_seq(&a, bad_start, bad_start) {
        Err(Error::ArithmeticOverflow(p)) => p,
        other => panic!("expected Err(ArithmeticOverflow), got {other:?}"),
    };

    let ctx = payload.context();
    assert!(
        ctx.contains("i32::MAX"),
        "expected context to mention `i32::MAX`, got: {:?}",
        ctx
    );
    assert!(
        ctx.contains("start") || ctx.contains("end"),
        "expected context to name `start` or `end` offset, got: {:?}",
        ctx
    );
}
/// An empty `[0, 0)` window on a valid tensor is a legal slice.
#[test]
fn slice_seq_accepts_zero_window_at_origin() {
    let input = kv1();
    let r = slice_seq(&input, 0, 0);
    assert!(
        matches!(r, Ok(_)),
        "valid zero-window slice must succeed, got {r:?}"
    );
}
/// Slicing a rank-1 tensor must fail with `RankMismatch` describing the real shape.
#[test]
fn slice_seq_rejects_rank_mismatch() {
    let rank1: Array = Array::from_slice::<f32>(&[0.0, 1.0], &(2usize,)).unwrap();

    let payload = match slice_seq(&rank1, 0, 0) {
        Err(Error::RankMismatch(p)) => p,
        other => panic!("rank-1 must Err(RankMismatch), got {other:?}"),
    };

    let ctx = payload.context();
    assert!(
        ctx.contains("4-D") || ctx.contains("slice_seq"),
        "error context must name expected rank or call site; got: {:?}",
        ctx
    );
    assert_eq!(payload.actual(), 1, "expected actual rank 1");
    assert_eq!(
        payload.actual_shape(),
        &[2usize],
        "expected actual shape [2]"
    );
}
/// Builds a `[1, 1, vals.len(), 1]` f32 fixture from `vals`.
fn kv(vals: &[f32]) -> Array {
    let seq = vals.len();
    Array::from_slice::<f32>(vals, &(1usize, 1, seq, 1)).unwrap()
}
/// Builds an f32 fixture of shape `[b, h, s, d]` from `vals`, panicking on failure.
fn kv4(b: usize, h: usize, s: usize, d: usize, vals: &[f32]) -> Array {
    let dims = (b, h, s, d);
    Array::from_slice::<f32>(vals, &dims).unwrap()
}
/// Runs `a` through `ops::shape::contiguous` and returns its elements as a flat `Vec<f32>`.
fn rows(a: &Array) -> Vec<f32> {
    let dense = ops::shape::contiguous(a, false).unwrap();
    dense.to_vec::<f32>().unwrap()
}
/// `seq_len` succeeds on a 4-D tensor. For a name other than keys/values it
/// reports a rank mismatch using the generic context text.
#[test]
fn seq_len_rank_mismatch_default_name_context() {
    let rank4 = kv4(1, 1, 1, 1, &[0.0]);
    let rank3: Array =
        Array::from_slice::<f32>(&[0.0, 1.0, 2.0], &(1usize, 3, 1)).unwrap();

    // Happy path: a [1, 1, 1, 1] tensor yields 1.
    assert_eq!(seq_len("anything", &rank4).unwrap(), 1);

    let p = match seq_len("anything", &rank3) {
        Err(Error::RankMismatch(p)) => p,
        other => panic!("expected Err(RankMismatch), got {other:?}"),
    };
    assert_eq!(
        p.context(),
        "seq_len: KV cache expects 4-D [B, n_kv_heads, S, head_dim]",
        "non-keys/values name must select the generic context arm"
    );
    assert_eq!(p.actual(), 3);
    assert_eq!(p.actual_shape(), &[1usize, 3, 1]);
}
/// For a `[1, 1, 1, 4]` keys tensor, `head_dim` reports the trailing size 4.
#[test]
fn head_dim_returns_last_axis_for_valid_4d() {
    let vals: [f32; 4] = [0.0, 1.0, 2.0, 3.0];
    let keys = kv4(1, 1, 1, 4, &vals);
    let d = head_dim("keys", &keys).unwrap();
    assert_eq!(d, 4);
}
/// A rank-3 tensor named "keys" must yield the keys-specific rank-mismatch context.
#[test]
fn head_dim_rank_mismatch_keys_name_context() {
    let rank3: Array = Array::from_slice::<f32>(&[0.0, 1.0], &(1usize, 1, 2)).unwrap();

    let p = match head_dim("keys", &rank3) {
        Err(Error::RankMismatch(p)) => p,
        other => panic!("expected Err(RankMismatch) for keys, got {other:?}"),
    };
    assert_eq!(
        p.context(),
        "head_dim: KV cache expects 4-D keys [B, n_kv_heads, S, head_dim]"
    );
    assert_eq!(p.actual(), 3);
    assert_eq!(p.actual_shape(), &[1usize, 1, 2]);
}
/// A rank-5 tensor named "values" must yield the values-specific rank-mismatch context.
#[test]
fn head_dim_rank_mismatch_values_name_context() {
    let rank5: Array = Array::from_slice::<f32>(&[0.0], &(1usize, 1, 1, 1, 1)).unwrap();

    let p = match head_dim("values", &rank5) {
        Err(Error::RankMismatch(p)) => p,
        other => panic!("expected Err(RankMismatch) for values, got {other:?}"),
    };
    assert_eq!(
        p.context(),
        "head_dim: KV cache expects 4-D values [B, n_kv_heads, S, head_dim]"
    );
    assert_eq!(p.actual(), 5);
}
/// Any name other than keys/values falls through to the generic rank-mismatch context.
#[test]
fn head_dim_rank_mismatch_default_name_context() {
    let rank1: Array = Array::from_slice::<f32>(&[0.0, 1.0], &(2usize,)).unwrap();

    let p = match head_dim("other", &rank1) {
        Err(Error::RankMismatch(p)) => p,
        other => panic!("expected Err(RankMismatch) for default name, got {other:?}"),
    };
    assert_eq!(
        p.context(),
        "head_dim: KV cache expects 4-D [B, n_kv_heads, S, head_dim]"
    );
    assert_eq!(p.actual(), 1);
    assert_eq!(p.actual_shape(), &[2usize]);
}
/// A rank-3 destination buffer is rejected before the RHS is examined (keys naming).
#[test]
fn broadcast_write_rhs_buf_rank_mismatch_keys() {
    let buf_rank3: Array = Array::from_slice::<f32>(&[0.0], &(1usize, 1, 1)).unwrap();
    let rhs = kv4(1, 1, 1, 1, &[5.0]);

    let p = match broadcast_write_rhs("keys", &buf_rank3, 0, 1, &rhs) {
        Err(Error::RankMismatch(p)) => p,
        other => panic!("expected buf RankMismatch (keys), got {other:?}"),
    };
    assert_eq!(
        p.context(),
        "broadcast_write_rhs: KV cache expects 4-D keys [B, n_kv_heads, S, head_dim]"
    );
    assert_eq!(p.actual(), 3);
    assert_eq!(p.actual_shape(), &[1usize, 1, 1]);
}
/// Buffer rank mismatches use the values-specific context for "values" and
/// the generic context for any other name.
#[test]
fn broadcast_write_rhs_buf_rank_mismatch_values_and_default() {
    let buf3: Array = Array::from_slice::<f32>(&[0.0], &(1usize, 1, 1)).unwrap();
    let new = kv4(1, 1, 1, 1, &[5.0]);

    // (tensor name, expected context, label used in the failure message)
    let cases = [
        (
            "values",
            "broadcast_write_rhs: KV cache expects 4-D values [B, n_kv_heads, S, head_dim]",
            "values",
        ),
        (
            "xyz",
            "broadcast_write_rhs: KV cache expects 4-D [B, n_kv_heads, S, head_dim]",
            "default",
        ),
    ];
    for &(name, want, label) in cases.iter() {
        match broadcast_write_rhs(name, &buf3, 0, 1, &new) {
            Err(Error::RankMismatch(p)) => assert_eq!(p.context(), want),
            other => panic!("expected buf RankMismatch ({label}), got {other:?}"),
        }
    }
}
/// With a valid 4-D buffer, a rank-3 RHS is rejected with the keys "write RHS" context.
#[test]
fn broadcast_write_rhs_new_rank_mismatch_keys() {
    let buf = kv4(1, 1, 1, 1, &[0.0]);
    let rhs_rank3: Array = Array::from_slice::<f32>(&[5.0], &(1usize, 1, 1)).unwrap();

    let p = match broadcast_write_rhs("keys", &buf, 0, 1, &rhs_rank3) {
        Err(Error::RankMismatch(p)) => p,
        other => panic!("expected RHS RankMismatch (keys), got {other:?}"),
    };
    assert_eq!(
        p.context(),
        "broadcast_write_rhs: KV cache expects 4-D keys write RHS [B, n_kv_heads, S, head_dim]"
    );
    assert_eq!(p.actual(), 3);
    assert_eq!(p.actual_shape(), &[1usize, 1, 1]);
}
/// RHS rank mismatches use the values-specific context for "values" and the
/// generic "write RHS" context for any other name.
#[test]
fn broadcast_write_rhs_new_rank_mismatch_values_and_default() {
    let buf = kv4(1, 1, 1, 1, &[0.0]);
    let new3: Array = Array::from_slice::<f32>(&[5.0], &(1usize, 1, 1)).unwrap();

    // (tensor name, expected context, label used in the failure message)
    let cases = [
        (
            "values",
            "broadcast_write_rhs: KV cache expects 4-D values write RHS [B, n_kv_heads, S, head_dim]",
            "values",
        ),
        (
            "zzz",
            "broadcast_write_rhs: KV cache expects 4-D write RHS [B, n_kv_heads, S, head_dim]",
            "default",
        ),
    ];
    for &(name, want, label) in cases.iter() {
        match broadcast_write_rhs(name, &buf, 0, 1, &new3) {
            Err(Error::RankMismatch(p)) => assert_eq!(p.context(), want),
            other => panic!("expected RHS RankMismatch ({label}), got {other:?}"),
        }
    }
}
/// An inverted window (start = 5, end = 2) must be reported as an invariant
/// violation for keys.
#[test]
fn broadcast_write_rhs_end_before_start_keys() {
    let buf = kv4(1, 1, 8, 1, &[0.0; 8]);
    let new = kv4(1, 1, 1, 1, &[5.0]);

    let p = match broadcast_write_rhs("keys", &buf, 5, 2, &new) {
        Err(Error::InvariantViolation(p)) => p,
        other => panic!("expected InvariantViolation (keys), got {other:?}"),
    };
    assert_eq!(p.context(), "set_seq: keys write end < start");
    assert_eq!(p.requirement(), "must satisfy end >= start");
}
/// The inverted-window invariant violation carries a values-specific or a
/// generic context, depending on the tensor name.
#[test]
fn broadcast_write_rhs_end_before_start_values_and_default() {
    let buf = kv4(1, 1, 8, 1, &[0.0; 8]);
    let new = kv4(1, 1, 1, 1, &[5.0]);

    // (tensor name, expected context, label used in the failure message)
    let cases = [
        ("values", "set_seq: values write end < start", "values"),
        ("other", "set_seq: write end < start", "default"),
    ];
    for &(name, want, label) in cases.iter() {
        match broadcast_write_rhs(name, &buf, 5, 2, &new) {
            Err(Error::InvariantViolation(p)) => assert_eq!(p.context(), want),
            other => panic!("expected InvariantViolation ({label}), got {other:?}"),
        }
    }
}
/// An RHS batch of 3 cannot broadcast onto a buffer batch of 2. The keys path
/// must therefore report a `ShapePairMismatch` that carries both shapes.
#[test]
fn broadcast_write_rhs_non_broadcastable_batch_axis_keys() {
    let buf = kv4(2, 1, 4, 1, &[0.0; 8]);
    let new = kv4(3, 1, 1, 1, &[1.0, 2.0, 3.0]);
    match broadcast_write_rhs("keys", &buf, 0, 1, &new) {
        Err(Error::ShapePairMismatch(p)) => {
            // Identify the arm by its message text, not by implementation line
            // numbers, which go stale whenever the implementation file is edited.
            assert!(
                p.context().contains("keys write RHS non-broadcastable"),
                "keys context arm; got {:?}",
                p.context()
            );
            // `expected` is the buffer window for [0, 1): the S axis narrowed to 1.
            assert_eq!(p.expected(), &[2usize, 1, 1, 1]);
            assert_eq!(p.actual(), &[3usize, 1, 1, 1]);
        }
        other => panic!("expected ShapePairMismatch (keys), got {other:?}"),
    }
}
/// Non-broadcastable RHS errors use a values-specific context for "values".
/// Any other name gets a generic context that mentions neither keys nor values.
#[test]
fn broadcast_write_rhs_non_broadcastable_values_and_default() {
    let buf = kv4(2, 1, 4, 1, &[0.0; 8]);
    let new = kv4(3, 1, 1, 1, &[1.0, 2.0, 3.0]);
    match broadcast_write_rhs("values", &buf, 0, 1, &new) {
        Err(Error::ShapePairMismatch(p)) => assert!(
            p.context().contains("values write RHS non-broadcastable"),
            "values context arm; got {:?}",
            p.context()
        ),
        other => panic!("expected ShapePairMismatch (values), got {other:?}"),
    }
    match broadcast_write_rhs("kkk", &buf, 0, 1, &new) {
        // The default arm must not leak the keys/values wording. The arm is
        // identified by message text, not by implementation line numbers,
        // which go stale as the source evolves.
        Err(Error::ShapePairMismatch(p)) => assert!(
            p.context().contains("write RHS non-broadcastable")
                && !p.context().contains("keys")
                && !p.context().contains("values"),
            "default context arm; got {:?}",
            p.context()
        ),
        other => panic!("expected ShapePairMismatch (default), got {other:?}"),
    }
}
/// When the RHS already matches the `[1, 3)` window shape, it passes through
/// with its shape and data unchanged.
#[test]
fn broadcast_write_rhs_identity_returns_window_shape() {
    let buf = kv4(1, 1, 4, 1, &[0.0; 4]);
    // The window length (3 - 1 = 2) equals the RHS sequence length.
    let rhs = kv4(1, 1, 2, 1, &[7.0, 8.0]);
    let out = broadcast_write_rhs("keys", &buf, 1, 3, &rhs).unwrap();
    assert_eq!(out.shape(), vec![1, 1, 2, 1], "identity broadcast shape");
    assert_eq!(rows(&out), vec![7.0, 8.0], "identity broadcast data");
}
/// A `[1, 1, 1, 1]` RHS expands to the buffer's batch and head_dim over a
/// one-row window, and every element takes the marker value.
#[test]
fn broadcast_write_rhs_size1_axes_broadcast_up() {
    let buf = kv4(2, 1, 4, 3, &[0.0; 24]);
    let marker = kv4(1, 1, 1, 1, &[9.0]);
    let out = broadcast_write_rhs("values", &buf, 0, 1, &marker).unwrap();
    assert_eq!(
        out.shape(),
        vec![2, 1, 1, 3],
        "size-1 axes broadcast to [B,1,win,D]"
    );
    // 2 * 1 * 1 * 3 = 6 elements, all equal to the marker.
    assert_eq!(rows(&out), vec![9.0; 6], "all broadcast elements == marker");
}
/// `nbytes` scales element count by dtype width. The cases cover
/// 1-, 2-, 4- and 8-byte element types.
#[test]
fn nbytes_dtype_size_groups() {
    // 4 bools -> 4 bytes.
    let flags = Array::from_slice::<bool>(&[true, false, true, false], &(2usize, 2)).unwrap();
    assert_eq!(nbytes(&flags).unwrap(), 4);

    // 3 x u16 -> 6 bytes.
    let halfwords = Array::from_slice::<u16>(&[1, 2, 3], &(3usize,)).unwrap();
    assert_eq!(nbytes(&halfwords).unwrap(), 6);

    // 6 x f32 -> 24 bytes.
    let singles = kv4(1, 1, 6, 1, &[0.0; 6]);
    assert_eq!(nbytes(&singles).unwrap(), 24);

    // 2 x i64 -> 16 bytes.
    let wides = Array::from_slice::<i64>(&[1, 2], &(2usize,)).unwrap();
    assert_eq!(nbytes(&wides).unwrap(), 16);
}
/// `concat_seq` joins two fixtures along the S axis, in argument order.
#[test]
fn concat_seq_appends_on_sequence_axis() {
    let head = kv(&[10.0, 20.0]);
    let tail = kv(&[30.0, 40.0, 50.0]);
    let joined = concat_seq(&head, &tail).unwrap();
    assert_eq!(joined.shape(), vec![1, 1, 5, 1]);
    assert_eq!(rows(&joined), vec![10.0, 20.0, 30.0, 40.0, 50.0]);
}
/// An `end` beyond the sequence length is clamped, not rejected.
#[test]
fn seq_slice_clamps_overlong_end_to_length() {
    let src = kv(&[1.0, 2.0, 3.0, 4.0]);
    // end = 99 exceeds S = 4; the window keeps rows 1..4.
    let window = seq_slice(&src, 1, 99).unwrap();
    assert_eq!(window.shape(), vec![1, 1, 3, 1]);
    assert_eq!(rows(&window), vec![2.0, 3.0, 4.0]);
}
/// When both bounds lie past the sequence length, the clamped window is empty.
#[test]
fn seq_slice_start_clamped_to_end_is_empty() {
    let src = kv(&[1.0, 2.0, 3.0, 4.0]);
    let window = seq_slice(&src, 5, 99).unwrap();
    assert_eq!(window.shape(), vec![1, 1, 0, 1], "empty window after clamp");
}
/// A single rank-3 part is rejected with `RankMismatch`, even though no
/// actual concatenation would be needed.
#[test]
fn concat_parts_single_rank_invalid_part_is_rank_mismatch() {
    let rank3: Array = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(1usize, 3, 1)).unwrap();

    let p = match concat_parts(&[&rank3]) {
        Err(Error::RankMismatch(p)) => p,
        other => panic!("expected RankMismatch, got {other:?}"),
    };
    assert_eq!(
        p.context(),
        "concat_parts: KV cache concat expects 4-D [B, n_kv_heads, S, head_dim] parts"
    );
    assert_eq!(p.actual(), 3);
    assert_eq!(p.actual_shape(), &[1usize, 3, 1]);
}
/// When every part is empty, the result is an empty `[1, 1, 0, 1]` tensor.
///
/// NOTE(review): both parts here have identical shapes. This test therefore
/// cannot tell whether the first or the second part was returned.
#[test]
fn concat_parts_all_empty_returns_first_part() {
    let first = kv(&[]);
    let second = kv(&[]);
    let out = concat_parts(&[&first, &second]).unwrap();
    assert_eq!(
        out.shape(),
        vec![1, 1, 0, 1],
        "returns the first empty part"
    );
}
/// An empty parts list is an `EmptyInput` error, not an empty tensor.
#[test]
fn concat_parts_empty_slice_is_empty_input_error() {
    let result = concat_parts(&[]);
    match result {
        Err(Error::EmptyInput(p)) => assert_eq!(p.context(), "concat_parts: parts"),
        other => panic!("expected EmptyInput, got {other:?}"),
    }
}
/// Empty parts are skipped. Non-empty parts are concatenated in their original order.
#[test]
fn concat_parts_drops_empty_keeps_nonempty_order() {
    let blank = kv(&[]);
    let left = kv(&[1.0, 2.0]);
    let right = kv(&[3.0]);
    let merged = concat_parts(&[&blank, &left, &blank, &right]).unwrap();
    assert_eq!(merged.shape(), vec![1, 1, 3, 1]);
    assert_eq!(rows(&merged), vec![1.0, 2.0, 3.0]);
}
/// A single valid part comes back unchanged in shape and data.
#[test]
fn concat_parts_single_valid_part_is_identity() {
    let only = kv(&[4.0, 5.0]);
    let out = concat_parts(&[&only]).unwrap();
    assert_eq!(out.shape(), vec![1, 1, 2, 1]);
    assert_eq!(rows(&out), vec![4.0, 5.0]);
}