pub trait TryAttentionReshape<E: Dtype>: Storage<E> {
    // Required method
    fn try_attention_reshape<const THREE_HIDDEN_DIM: usize, const NUM_HEADS: usize, const HEAD_DIM: usize>(
        &self,
        qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
        past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
        past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>
    ) -> Result<(Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>, Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>, Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>), Self::Err>;

    // Provided method
    fn attention_reshape<const THREE_HIDDEN_DIM: usize, const NUM_HEADS: usize, const HEAD_DIM: usize>(
        &self,
        qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
        past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
        past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>
    ) -> (Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>, Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>, Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>) { ... }
}

Reshapes `qkv` + `past_key` + `past_value` into the `(q, k, v)` tensors used in the attention layer.

Required Methods

fn try_attention_reshape<const THREE_HIDDEN_DIM: usize, const NUM_HEADS: usize, const HEAD_DIM: usize>(
    &self,
    qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
    past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
    past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>
) -> Result<(Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>, Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>, Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>), Self::Err>

Provided Methods

fn attention_reshape<const THREE_HIDDEN_DIM: usize, const NUM_HEADS: usize, const HEAD_DIM: usize>(
    &self,
    qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
    past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
    past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>
) -> (Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>, Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>, Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>)

This is an inference-only kernel. Within the transformer architecture, a core component is the attention layer, which can be written in many forms.

This particular version expects a `qkv` tensor (obtained from a single `Linear` layer, corresponding to the stacked query, key, and value projections), plus `past_key` and `past_value`, which are the values cached within attention (this caching speeds up inference). For the first pass, when the cache does not yet exist, just send zero-width tensors.

Fusing this into a single kernel, instead of chaining many `cat`, `reshape`, and `permute` operations, makes it very efficient on GPU.

Implementors

impl<E: Dtype, D: AttentionReshapeKernel<E>> TryAttentionReshape<E> for D