xn 0.1.1

Another minimalist deep-learning framework optimized for inference.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
use crate::{Backend, Result, Tensor, WithDType};

/// Repeats a key or value tensor along its head dimension for grouped query attention.
///
/// The input tensor should have a shape `(batch, num_kv_heads, seq_len, head_dim)`.
/// When `n_rep` is 1 the input is returned as is; otherwise the result is reshaped to
/// `(batch, num_kv_heads * n_rep, seq_len, head_dim)`.
///
/// # Errors
///
/// Propagates any error reported by `dims4`, `Tensor::cat` or `reshape`.
pub fn repeat_kv<T: WithDType, B: Backend>(xs: Tensor<T, B>, n_rep: usize) -> Result<Tensor<T, B>> {
    // Nothing to repeat: hand the tensor back without touching it.
    if n_rep == 1 {
        return Ok(xs);
    }

    let (batch, kv_heads, seq_len, head_dim) = xs.dims4()?;

    // Concatenate `n_rep` references to the same tensor along dim 2 and reshape afterwards.
    // This is preferred over a broadcast because cat is faster: it avoids going through a
    // potentially strided copy.
    // https://github.com/huggingface/candle/pull/2043
    let copies = vec![&xs; n_rep];
    let repeated = Tensor::cat(&copies, 2)?;
    repeated.reshape((batch, kv_heads * n_rep, seq_len, head_dim))
}