use crate::Tensor;
use ndarray::{Array2, array};
crate::codegen_tests! {
// RMS-normalizes the single row [1, 2, 3, 4] over the last axis (eps = 1e-5)
// and compares every output element with the input scaled by
// 1 / sqrt(mean(x^2) + eps).
fn test_rms_norm_basic(config) {
    let input = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0, 4.0]]);
    let mut result = input.rms_norm(-1, 1e-5).unwrap();
    result.realize_with(&config).unwrap();

    let view = result.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[1, 4]);

    // mean(x^2) = (1 + 4 + 9 + 16) / 4 = 7.5
    let rms_inv = 1.0 / (7.5f32 + 1e-5).sqrt();
    let expected_row: Vec<f32> = (1..=4).map(|k| k as f32 * rms_inv).collect();
    for (col, &expected) in expected_row.iter().enumerate() {
        let got = view[[0, col]];
        assert!(
            (got - expected).abs() < 1e-4,
            "rms_norm[{i}]: got {}, expected {}",
            got,
            expected,
            i = col
        );
    }
}
// RMS-normalizes a 2x3 input over the last axis (eps = 1e-5) and checks that
// each row is scaled by its own 1 / sqrt(mean(row^2) + eps).
//
// Every element of both rows is compared, so a fault that only affects later
// columns or the second row cannot go unnoticed.
fn test_rms_norm_axis(config) {
    // Same values as the tensor below, kept as plain arrays to derive expectations.
    let rows = [[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
    let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]]);
    let mut result = x.rms_norm(-1, 1e-5).unwrap();
    result.realize_with(&config).unwrap();

    let view = result.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[2, 3]);

    for (r, row) in rows.iter().enumerate() {
        // Mean of squares per row: 14 / 3 for the first row, 77 / 3 for the second.
        let mean_sq = row.iter().map(|value| value * value).sum::<f32>() / row.len() as f32;
        let inv_rms = 1.0 / (mean_sq + 1e-5).sqrt();
        for (c, &value) in row.iter().enumerate() {
            let expected = value * inv_rms;
            let got = view[[r, c]];
            assert!(
                (got - expected).abs() < 1e-4,
                "rms_norm[{r}, {c}]: got {got}, expected {expected}"
            );
        }
    }
}
// Looks up rows 2 and 0 of a 3x4 table filled with 0..12 and checks the first
// and last column of each gathered row.
fn test_embedding_basic(config) {
    let table = Array2::from_shape_vec((3, 4), (0..12).map(|v| v as f32).collect::<Vec<f32>>()).unwrap();
    let weight = Tensor::from_ndarray(&table);
    let indices = Tensor::from_slice([2i32, 0]);

    let mut result = weight.embedding(&indices).unwrap();
    result.realize_with(&config).unwrap();

    let view = result.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[2, 4]);

    // (output position, expected value): table row 2 is [8, 9, 10, 11], row 0 is [0, 1, 2, 3].
    let checks = [([0usize, 0], 8.0f32), ([0, 3], 11.0), ([1, 0], 0.0), ([1, 3], 3.0)];
    for (pos, expected) in checks {
        assert_eq!(view[pos], expected);
    }
}
// Gathers from a 4x2 table with a 2x3 index tensor and expects a 2x3x2 result
// in which output[i][j] is the table row selected by indices[i][j].
//
// All twelve gathered values are verified rather than a sample, so a wrong
// row at any index position is reported together with its coordinates.
fn test_embedding_2d_indices(config) {
    // Same values as the index tensor below, kept as a plain array to derive expectations.
    let lookup = [[0usize, 1, 2], [3, 2, 1]];
    let weight = Tensor::from_ndarray(&array![[0.0f32, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]);
    let indices = Tensor::from_ndarray(&array![[0i32, 1, 2], [3, 2, 1]]);

    let mut result = weight.embedding(&indices).unwrap();
    result.realize_with(&config).unwrap();

    let view = result.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[2, 3, 2]);

    for (i, index_row) in lookup.iter().enumerate() {
        for (j, &row) in index_row.iter().enumerate() {
            for col in 0..2 {
                // Table row r holds [2r, 2r + 1], so the expectation follows from the index alone.
                let expected = (2 * row + col) as f32;
                assert_eq!(
                    view[[i, j, col]],
                    expected,
                    "embedding[{i}, {j}, {col}] should come from weight row {row}"
                );
            }
        }
    }
}
// Runs unmasked attention with q = k = 2x2 identity and v = [[1, 2], [3, 4]].
//
// Besides the output shape, this checks properties that hold for any
// softmax-weighted average of the value rows, independent of the score
// scaling in use (the default scale is not visible from this file, so no
// exact values are pinned):
//   * every output is finite,
//   * each output row is a convex combination of the value rows,
//   * the two output rows mirror each other, because the score matrix is symmetric.
fn test_sdpa_basic(config) {
    let q = Tensor::from_ndarray(&array![[[[1.0f32, 0.0], [0.0, 1.0]]]]);
    let k = q.clone();
    let v = Tensor::from_ndarray(&array![[[[1.0f32, 2.0], [3.0, 4.0]]]]);

    let mut result = q.scaled_dot_product_attention().key(&k).value(&v).call().unwrap();
    result.realize_with(&config).unwrap();

    let view = result.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[1, 1, 2, 2]);

    for row in 0..2 {
        let first = view[[0, 0, row, 0]];
        let second = view[[0, 0, row, 1]];
        assert!(
            first.is_finite() && second.is_finite(),
            "sdpa row {row} is not finite: [{first}, {second}]"
        );
        // Both value rows satisfy v[1] - v[0] = 1, so any weighted average
        // whose weights sum to 1 must satisfy it too.
        assert!(
            (second - first - 1.0).abs() < 1e-4,
            "sdpa row {row} is not a normalized mix of the value rows: [{first}, {second}]"
        );
        // Non-negative weights keep the first column between v[0][0] = 1 and v[1][0] = 3.
        assert!(
            first >= 1.0 - 1e-4 && first <= 3.0 + 1e-4,
            "sdpa row {row} lies outside the range of the value rows: {first}"
        );
    }

    // The scores are a scaled identity, so the weights of row 1 are those of
    // row 0 swapped: (a + 3b) + (b + 3a) = 4 when a + b = 1.
    let column_sum = view[[0, 0, 0, 0]] + view[[0, 0, 1, 0]];
    assert!(
        (column_sum - 4.0).abs() < 1e-4,
        "sdpa rows are not mirrored: first-column sum is {column_sum}"
    );
}
// Runs attention over three positions with `is_causal(true)` and checks that
// the first output row equals the first value row [1, 0].
fn test_sdpa_causal(config) {
    let query = Tensor::from_ndarray(&array![[[[1.0f32, 0.0], [0.0, 1.0], [1.0, 1.0]]]]);
    let key = query.clone();
    let value = Tensor::from_ndarray(&array![[[[1.0f32, 0.0], [0.0, 1.0], [0.0, 0.0]]]]);

    let mut result = query
        .scaled_dot_product_attention()
        .key(&key)
        .value(&value)
        .is_causal(true)
        .call()
        .unwrap();
    result.realize_with(&config).unwrap();

    let view = result.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[1, 1, 3, 2]);

    // Only the first query position is verified here.
    let first_row = [1.0f32, 0.0];
    for (col, &expected) in first_row.iter().enumerate() {
        assert!((view[[0, 0, 0, col]] - expected).abs() < 1e-4);
    }
}
// Runs attention with large-magnitude queries/keys and `softcap(1.0)`, and
// only requires that every output value is finite.
fn test_sdpa_softcap(config) {
    let query = Tensor::from_ndarray(&array![[[[10.0f32, 0.0], [0.0, 10.0]]]]);
    let key = query.clone();
    let value = Tensor::from_ndarray(&array![[[[1.0f32, 0.0], [0.0, 1.0]]]]);

    let mut result = query
        .scaled_dot_product_attention()
        .key(&key)
        .value(&value)
        .softcap(1.0)
        .call()
        .unwrap();
    result.realize_with(&config).unwrap();

    let outputs = result.as_vec::<f32>().unwrap();
    for val in outputs {
        assert!(val.is_finite(), "softcap produced non-finite value: {val}");
    }
}
// One query against two keys with the boolean mask [true, false].
// The expected output is the second value row [1, 10], i.e. the key flagged
// `true` contributes nothing to the result.
fn test_sdpa_bool_mask_true_masks_out(config) {
    let query = Tensor::from_ndarray(&array![[[[1.0f32, 0.0]]]]);
    let key = Tensor::from_ndarray(&array![[[[1.0f32, 0.0], [0.0, 1.0]]]]);
    let value = Tensor::from_ndarray(&array![[[[10.0f32, 1.0], [1.0, 10.0]]]]);
    let mask = Tensor::from_ndarray(&array![[[[true, false]]]]);

    let mut result = query
        .scaled_dot_product_attention()
        .key(&key)
        .value(&value)
        .maybe_attn_mask(Some(&mask))
        .call()
        .unwrap();
    result.realize_with(&config).unwrap();

    let view = result.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[1, 1, 1, 2]);

    let surviving_row = [1.0f32, 10.0];
    for (col, &expected) in surviving_row.iter().enumerate() {
        assert!((view[[0, 0, 0, col]] - expected).abs() < 1e-4);
    }
}
// One query against two keys where the boolean mask flags every key `true`.
// No particular output is expected; the result only has to be finite.
fn test_sdpa_bool_mask_all_masked_row_finite(config) {
    let query = Tensor::from_ndarray(&array![[[[1.0f32, 0.0]]]]);
    let key = Tensor::from_ndarray(&array![[[[1.0f32, 0.0], [0.0, 1.0]]]]);
    let value = Tensor::from_ndarray(&array![[[[10.0f32, 1.0], [1.0, 10.0]]]]);
    let mask = Tensor::from_ndarray(&array![[[[true, true]]]]);

    let mut result = query
        .scaled_dot_product_attention()
        .key(&key)
        .value(&value)
        .maybe_attn_mask(Some(&mask))
        .call()
        .unwrap();
    result.realize_with(&config).unwrap();

    let outputs = result.as_vec::<f32>().unwrap();
    for v in outputs {
        assert!(v.is_finite(), "expected finite attention output, got {v}");
    }
}
// Combines `is_causal(true)` with a boolean mask that flags every entry
// `true` for both query positions. The result only has to be finite.
fn test_sdpa_bool_mask_all_masked_with_causal_finite(config) {
    let query = Tensor::from_ndarray(&array![[[[1.0f32, 0.0], [0.0, 1.0]]]]);
    let key = query.clone();
    let value = Tensor::from_ndarray(&array![[[[10.0f32, 1.0], [1.0, 10.0]]]]);
    let mask = Tensor::from_ndarray(&array![[[[true, true], [true, true]]]]);

    let mut result = query
        .scaled_dot_product_attention()
        .key(&key)
        .value(&value)
        .is_causal(true)
        .maybe_attn_mask(Some(&mask))
        .call()
        .unwrap();
    result.realize_with(&config).unwrap();

    let outputs = result.as_vec::<f32>().unwrap();
    for v in outputs {
        assert!(v.is_finite(), "expected finite attention output with causal+mask, got {v}");
    }
}
fn test_sdpa_rejects_non_float_qkv(_config) {
let qf = Tensor::from_ndarray(&array![[[[1.0f32, 0.0]]]]);
let kf = Tensor::from_ndarray(&array![[[[1.0f32, 0.0], [0.0, 1.0]]]]);
let vf = Tensor::from_ndarray(&array![[[[10.0f32, 1.0], [1.0, 10.0]]]]);
let qi = Tensor::from_ndarray(&array![[[[1i32, 0]]]]);
let ki = Tensor::from_ndarray(&array![[[[1i32, 0], [0, 1]]]]);
let vi = Tensor::from_ndarray(&array![[[[10i32, 1], [1, 10]]]]);
let err_q = match qi.scaled_dot_product_attention().key(&kf).value(&vf).call() {
Ok(_) => panic!("expected query dtype error"),
Err(err) => err,
};
assert!(matches!(err_q, crate::Error::FloatDTypeRequired { arg: "query", .. }));
let err_k = match qf.scaled_dot_product_attention().key(&ki).value(&vf).call() {
Ok(_) => panic!("expected key dtype error"),
Err(err) => err,
};
assert!(matches!(err_k, crate::Error::FloatDTypeRequired { arg: "key", .. }));
let err_v = match qf.scaled_dot_product_attention().key(&kf).value(&vi).call() {
Ok(_) => panic!("expected value dtype error"),
Err(err) => err,
};
assert!(matches!(err_v, crate::Error::FloatDTypeRequired { arg: "value", .. }));
}
// Applies rotary embedding with the third argument `false` (presumably the
// non-interleaved, split-halves layout, judging by the test name) using
// cos = [1, 0] and sin = [0, 0].
// The expected output [1, 0, 3, 0] matches each half of the input being
// multiplied element-wise by cos.
fn test_rotary_emb_split(config) {
    let input = Tensor::from_ndarray(&array![[[1.0f32, 2.0, 3.0, 4.0]]]);
    let cos = Tensor::from_ndarray(&array![[[1.0f32, 0.0]]]);
    let sin = Tensor::from_ndarray(&array![[[0.0f32, 0.0]]]);

    let mut result = input.apply_rotary_emb(&cos, &sin, false).unwrap();
    result.realize_with(&config).unwrap();

    let view = result.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[1, 1, 4]);

    let expected = [1.0f32, 0.0, 3.0, 0.0];
    for (idx, &want) in expected.iter().enumerate() {
        assert!((view[[0, 0, idx]] - want).abs() < 1e-5);
    }
}
// Applies rotary embedding with the third argument `true` (presumably the
// interleaved layout, judging by the test name) using cos = 1 and sin = 0
// everywhere, and expects the input to come back unchanged.
fn test_rotary_emb_interleaved(config) {
    let input = Tensor::from_ndarray(&array![[[1.0f32, 2.0, 3.0, 4.0]]]);
    let cos = Tensor::from_ndarray(&array![[[1.0f32, 1.0]]]);
    let sin = Tensor::from_ndarray(&array![[[0.0f32, 0.0]]]);

    let mut result = input.apply_rotary_emb(&cos, &sin, true).unwrap();
    result.realize_with(&config).unwrap();

    let view = result.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[1, 1, 4]);

    let expected = [1.0f32, 2.0, 3.0, 4.0];
    for (idx, &want) in expected.iter().enumerate() {
        assert!((view[[0, 0, idx]] - want).abs() < 1e-5);
    }
}
// Applies rotary embedding (third argument `false`) with cos = 0 and sin = 1,
// so the result is determined by the sin term alone.
// For x = [1, 0, 0, 1] the expected output is [0, -1, 1, 0].
//
// The output shape is asserted first, matching the other rotary tests, so a
// shape regression is reported as such rather than as a value mismatch or an
// out-of-bounds index.
fn test_rotary_emb_rotation(config) {
    let x = Tensor::from_ndarray(&array![[[1.0f32, 0.0, 0.0, 1.0]]]);
    let cos = Tensor::from_ndarray(&array![[[0.0f32, 0.0]]]);
    let sin = Tensor::from_ndarray(&array![[[1.0f32, 1.0]]]);

    let mut result = x.apply_rotary_emb(&cos, &sin, false).unwrap();
    result.realize_with(&config).unwrap();

    let view = result.array_view::<f32>().unwrap();
    assert_eq!(view.shape(), &[1, 1, 4]);

    let expected = [0.0f32, -1.0, 1.0, 0.0];
    for (idx, &want) in expected.iter().enumerate() {
        let got = view[[0, 0, idx]];
        assert!((got - want).abs() < 1e-5, "rotary_emb[{idx}]: got {got}, expected {want}");
    }
}
}