use sapient_core::{DType, Shape, Tensor};
use sapient_models::forward::common::update_kv_cache;
// Exercises `update_kv_cache` with two consecutive writes into a zeroed cache
// and then inspects the raw cache buffer.
//
// The dimension names (batch, KV heads, sequence capacity, head dim) are taken
// from the original local names `b`, `n_kv`, `max_seq`, `hd`; the expected
// buffer contents below assume the cache is stored contiguously in row-major
// order over those four axes.
#[test]
fn test_update_kv_cache() {
    let batch = 1;
    let kv_heads = 2;
    let seq_capacity = 4;
    let head_dim = 2;

    let mut cache =
        Tensor::zeros(Shape::new([batch, kv_heads, seq_capacity, head_dim]), DType::F32).unwrap();

    // First write: a single position per head, placed at offset 0.
    let first_values = vec![
        1.0, 1.0, // head 0
        2.0, 2.0, // head 1
    ];
    let first_step = Tensor::from_f32(&first_values, Shape::new([batch, kv_heads, 1, head_dim]))
        .unwrap();
    {
        // The returned view is scoped so it is released before the cache is
        // touched again.
        let first_view = update_kv_cache(&mut cache, 0, &first_step).unwrap();
        assert_eq!(first_view.shape().dims(), &[1, 2, 1, 2]);
    }

    // Second write: two positions per head, placed at offset 1.
    let second_values = vec![
        3.0, 3.0, 4.0, 4.0, // head 0
        5.0, 5.0, 6.0, 6.0, // head 1
    ];
    let second_step = Tensor::from_f32(&second_values, Shape::new([batch, kv_heads, 2, head_dim]))
        .unwrap();
    {
        // The view is expected to cover everything written so far: 1 + 2 = 3
        // positions along the third axis.
        let second_view = update_kv_cache(&mut cache, 1, &second_step).unwrap();
        assert_eq!(second_view.shape().dims(), &[1, 2, 3, 2]);
    }

    // Each head owns seq_capacity * head_dim = 8 consecutive floats.
    let head_stride = 8;
    let flat = cache.as_f32_slice();

    // Head 0: positions 0..3 filled, the last position still zero.
    let head0 = &flat[..head_stride];
    assert_eq!(head0, &[1.0, 1.0, 3.0, 3.0, 4.0, 4.0, 0.0, 0.0]);

    // Head 1: same pattern with the second head's values.
    let head1 = &flat[head_stride..2 * head_stride];
    assert_eq!(head1, &[2.0, 2.0, 5.0, 5.0, 6.0, 6.0, 0.0, 0.0]);
}