Trait dfdx::tensor_ops::TryAttentionReshape
pub trait TryAttentionReshape<E: Dtype>: Storage<E> {
// Required method
fn try_attention_reshape<const THREE_HIDDEN_DIM: usize, const NUM_HEADS: usize, const HEAD_DIM: usize>(
&self,
qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>
) -> Result<(Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>, Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>, Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>), Self::Err>;
// Provided method
fn attention_reshape<const THREE_HIDDEN_DIM: usize, const NUM_HEADS: usize, const HEAD_DIM: usize>(
&self,
qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>
) -> (Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>, Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>, Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>) { ... }
}
Reshapes qkv + past_key + past_value into the (q, k, v) tensors used in an attention layer.
Required Methods
fn try_attention_reshape<const THREE_HIDDEN_DIM: usize, const NUM_HEADS: usize, const HEAD_DIM: usize>(
&self,
qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>
) -> Result<(Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>, Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>, Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>), Self::Err>
Fallible version of TryAttentionReshape::attention_reshape
Provided Methods
fn attention_reshape<const THREE_HIDDEN_DIM: usize, const NUM_HEADS: usize, const HEAD_DIM: usize>(
&self,
qkv: &Tensor<(usize, Const<THREE_HIDDEN_DIM>), E, Self>,
past_key: &Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>,
past_value: &Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>
) -> (Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>, Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), E, Self>, Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), E, Self>)
This is an inference-only kernel.

Within the transformer architecture, a core component is the attention layer, which can be written in many forms. This particular version expects a qkv tensor (produced by a single Linear layer) containing the stacked query, key, and value, along with past_key and past_value, the key/value entries cached by the attention layer from previous steps (caching speeds up inference). For the first pass, when no cache exists yet, just send zero-width tensors.

Fusing this into a single operation, instead of many cat, reshape, and permute calls, makes it very efficient on GPU.
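As a rough illustration of the semantics only (not dfdx's actual fused kernel), the sketch below uses plain nested `Vec`s instead of dfdx tensors: it splits the stacked qkv columns into per-head q, k, v, stores k transposed (head_dim before sequence, as in the trait's return shape), and appends the new keys and values to the caches. The function name `attention_reshape_sketch` is hypothetical.

```rust
// Hypothetical sketch of what attention_reshape computes, using nested Vecs.
// qkv:        (seq, 3 * num_heads * head_dim) -- stacked [query | key | value]
// past_key:   (num_heads, head_dim, past_seq) -- cached keys, transposed
// past_value: (num_heads, past_seq, head_dim) -- cached values
fn attention_reshape_sketch(
    qkv: &[Vec<f32>],
    past_key: &[Vec<Vec<f32>>],
    past_value: &[Vec<Vec<f32>>],
    num_heads: usize,
    head_dim: usize,
) -> (Vec<Vec<Vec<f32>>>, Vec<Vec<Vec<f32>>>, Vec<Vec<Vec<f32>>>) {
    let seq = qkv.len();
    let hidden = num_heads * head_dim;

    // q: (num_heads, seq, head_dim), taken from the first `hidden` columns.
    let q: Vec<Vec<Vec<f32>>> = (0..num_heads)
        .map(|h| {
            (0..seq)
                .map(|s| (0..head_dim).map(|d| qkv[s][h * head_dim + d]).collect())
                .collect()
        })
        .collect();

    // k: (num_heads, head_dim, past_seq + seq) -- cached keys plus the new
    // keys from the middle `hidden` columns, kept transposed.
    let k: Vec<Vec<Vec<f32>>> = (0..num_heads)
        .map(|h| {
            (0..head_dim)
                .map(|d| {
                    let mut row = past_key[h][d].clone();
                    row.extend((0..seq).map(|s| qkv[s][hidden + h * head_dim + d]));
                    row
                })
                .collect()
        })
        .collect();

    // v: (num_heads, past_seq + seq, head_dim) -- cached values plus the new
    // values from the last `hidden` columns.
    let v: Vec<Vec<Vec<f32>>> = (0..num_heads)
        .map(|h| {
            let mut rows = past_value[h].clone();
            rows.extend((0..seq).map(|s| {
                (0..head_dim)
                    .map(|d| qkv[s][2 * hidden + h * head_dim + d])
                    .collect::<Vec<f32>>()
            }));
            rows
        })
        .collect();

    (q, k, v)
}
```

With dfdx itself you would instead call `dev.attention_reshape(&qkv, &past_key, &past_value)` on a device implementing this trait, and the const parameters fix THREE_HIDDEN_DIM = 3 * NUM_HEADS * HEAD_DIM at compile time.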