use crate::{Backend, Result, Tensor, WithDType};
/// Repeats a rank-4 tensor `n_rep` times and folds the repetitions into dimension 1.
///
/// When `n_rep == 1` the input is handed back untouched. Otherwise `n_rep` references to
/// `xs` are concatenated along dimension 2 and the result is reshaped from
/// `(d0, d1, n_rep * d2, d3)` to `(d0, d1 * n_rep, d2, d3)`.
///
/// NOTE(review): the intended use appears to be expanding key/value heads for grouped-query
/// attention, with `xs` laid out as `(batch, kv_heads, seq_len, head_dim)`; confirm against
/// callers.
///
/// # Errors
///
/// Propagates any error returned by `dims4` (e.g. when `xs` is not rank 4), by
/// `Tensor::cat`, or by `reshape`.
pub fn repeat_kv<T: WithDType, B: Backend>(xs: Tensor<T, B>, n_rep: usize) -> Result<Tensor<T, B>> {
    // Nothing to duplicate: return the tensor as-is.
    if n_rep == 1 {
        return Ok(xs);
    }

    let (batch, kv_heads, seq, dim) = xs.dims4()?;

    // Concatenate `n_rep` borrowed copies along dimension 2, then reinterpret the
    // enlarged dimension 2 as extra entries of dimension 1.
    let copies = vec![&xs; n_rep];
    let stacked = Tensor::cat(&copies, 2)?;
    stacked.reshape((batch, kv_heads * n_rep, seq, dim))
}