Struct dfdx::nn::MultiHeadAttention
pub struct MultiHeadAttention<const EMBED_DIM: usize, const NUM_HEADS: usize, const K_DIM: usize = EMBED_DIM, const V_DIM: usize = EMBED_DIM> {
pub w_q: Linear<EMBED_DIM, K_DIM>,
pub w_k: Linear<EMBED_DIM, K_DIM>,
pub w_v: Linear<EMBED_DIM, V_DIM>,
pub w_o: Linear<V_DIM, EMBED_DIM>,
}
Expand description
Requires Nightly
A multi-head attention layer.
Generics:
- EMBED_DIM: The size of query vectors.
- NUM_HEADS: The number of heads to split query/key/value into.
- Optional K_DIM: The size of key vectors. Defaults to EMBED_DIM.
- Optional V_DIM: The size of value vectors. Defaults to EMBED_DIM.
Pytorch equivalent: torch.nn.MultiheadAttention(EMBED_DIM, NUM_HEADS, batch_first=True)
Examples
MultiHeadAttention<8, 2> is an attention layer with 2 heads and 8-dimensional embed, key, and value vectors.
MultiHeadAttention<8, 2, 6, 4> is an attention layer whose key dim (6) and value dim (4) differ from the embed dim.
TODO: Doctests fail for some reason
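The forward pass itself is provided by the Module implementations listed below. As a rough, dependency-free sketch of the computation such a layer performs, here is plain-Rust scaled dot-product attention over flat `Vec<f32>` buffers. This is not dfdx's API: the function names are hypothetical, and the `w_q`/`w_k`/`w_v` input projections and the `w_o` output projection are omitted to focus on the attention step.

```rust
// Single-head scaled dot-product attention.
// q: [s1, d], k: [s2, d], v: [s2, d], all row-major; returns [s1, d].
fn attention(q: &[f32], k: &[f32], v: &[f32], s1: usize, s2: usize, d: usize) -> Vec<f32> {
    let scale = 1.0 / (d as f32).sqrt();
    let mut out = vec![0.0; s1 * d];
    for i in 0..s1 {
        // scores[j] = <q_i, k_j> / sqrt(d)
        let mut scores: Vec<f32> = (0..s2)
            .map(|j| (0..d).map(|c| q[i * d + c] * k[j * d + c]).sum::<f32>() * scale)
            .collect();
        // Softmax over the key axis (max-subtracted for stability).
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let sum: f32 = scores.iter_mut().map(|s| { *s = (*s - max).exp(); *s }).sum();
        for s in scores.iter_mut() { *s /= sum; }
        // out_i = sum_j scores[j] * v_j
        for j in 0..s2 {
            for c in 0..d {
                out[i * d + c] += scores[j] * v[j * d + c];
            }
        }
    }
    out
}

// Multi-head wrapper: split the feature axis into `heads` chunks, attend
// per head, and write each head's output back into its column block.
// The divisibility requirement is why dfdx guards its Module impls with
// Assert<{ _ }>: ConstTrue bounds.
fn multi_head(
    q: &[f32], k: &[f32], v: &[f32],
    s1: usize, s2: usize, dim: usize, heads: usize,
) -> Vec<f32> {
    assert_eq!(dim % heads, 0, "dim must be divisible by heads");
    let hd = dim / heads;
    let mut out = vec![0.0; s1 * dim];
    for h in 0..heads {
        // Gather this head's column block from each row.
        let col = |m: &[f32], rows: usize| -> Vec<f32> {
            (0..rows)
                .flat_map(|r| (0..hd).map(move |c| m[r * dim + h * hd + c]))
                .collect()
        };
        let (qh, kh, vh) = (col(q, s1), col(k, s2), col(v, s2));
        let oh = attention(&qh, &kh, &vh, s1, s2, hd);
        for r in 0..s1 {
            for c in 0..hd {
                out[r * dim + h * hd + c] = oh[r * hd + c];
            }
        }
    }
    out
}
```

With a single key/value position the softmax weight is 1, so the output equals `v`; with all-zero queries the weights are uniform, so the output is the mean of the value rows.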
Fields
w_q: Linear<EMBED_DIM, K_DIM>
w_k: Linear<EMBED_DIM, K_DIM>
w_v: Linear<EMBED_DIM, V_DIM>
w_o: Linear<V_DIM, EMBED_DIM>
Trait Implementations
impl<const M: usize, const H: usize, const K: usize, const V: usize> CanUpdateWithGradients for MultiHeadAttention<M, H, K, V>
fn update<G: GradientProvider>(
    &mut self,
    grads: &mut G,
    unused: &mut UnusedTensors
)
Updates self given the GradientProvider. When a parameter is NOT present in G, this function should add the tensor's UniqueId to UnusedTensors. Read more
impl<const EMBED_DIM: usize, const NUM_HEADS: usize, const K_DIM: usize, const V_DIM: usize> Clone for MultiHeadAttention<EMBED_DIM, NUM_HEADS, K_DIM, V_DIM>
fn clone(&self) -> MultiHeadAttention<EMBED_DIM, NUM_HEADS, K_DIM, V_DIM>
Returns a copy of the value. Read more
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source. Read more
impl<const EMBED_DIM: usize, const NUM_HEADS: usize, const K_DIM: usize, const V_DIM: usize> Debug for MultiHeadAttention<EMBED_DIM, NUM_HEADS, K_DIM, V_DIM>
impl<const EMBED_DIM: usize, const NUM_HEADS: usize, const K_DIM: usize, const V_DIM: usize> Default for MultiHeadAttention<EMBED_DIM, NUM_HEADS, K_DIM, V_DIM>
fn default() -> MultiHeadAttention<EMBED_DIM, NUM_HEADS, K_DIM, V_DIM>
Returns the “default value” for a type. Read more
impl<const M: usize, const H: usize, const K: usize, const V: usize> LoadFromNpz for MultiHeadAttention<M, H, K, V>
impl<const M: usize, const H: usize, const K: usize, const V: usize, const S1: usize, const S2: usize, TAPE: 'static + Tape> Module<(Tensor2D<S1, M, TAPE>, Tensor2D<S2, M, NoneTape>, Tensor2D<S2, M, NoneTape>)> for MultiHeadAttention<M, H, K, V>
where
    Assert<{ _ }>: ConstTrue,
    Assert<{ _ }>: ConstTrue,
    Assert<{ _ }>: ConstTrue,
    Assert<{ _ }>: ConstTrue,
impl<const M: usize, const H: usize, const K: usize, const V: usize, const B: usize, const S1: usize, const S2: usize, TAPE: 'static + Tape> Module<(Tensor3D<B, S1, M, TAPE>, Tensor3D<B, S2, M, NoneTape>, Tensor3D<B, S2, M, NoneTape>)> for MultiHeadAttention<M, H, K, V>
where
    Assert<{ _ }>: ConstTrue,
    Assert<{ _ }>: ConstTrue,
    Assert<{ _ }>: ConstTrue,
    Assert<{ _ }>: ConstTrue,
impl<const M: usize, const H: usize, const K: usize, const V: usize, T> ModuleMut<T> for MultiHeadAttention<M, H, K, V>
where
    Self: Module<T>,
type Output = <MultiHeadAttention<M, H, K, V> as Module<T>>::Output
The type that this unit produces given Input.
fn forward_mut(&mut self, t: T) -> Self::Output
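This ModuleMut impl simply defers to Module when the layer needs no mutable state during the forward pass. A minimal, self-contained sketch of that delegation pattern follows; the trait names mirror dfdx's but the definitions are simplified stand-ins, and dfdx constrains the impl per type via `where Self: Module<T>` rather than using a blanket impl as shown here.

```rust
// Immutable forward pass.
trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

// Mutable forward pass (e.g. for layers that track state during training).
trait ModuleMut<Input> {
    type Output;
    fn forward_mut(&mut self, input: Input) -> Self::Output;
}

// A toy stateless layer standing in for MultiHeadAttention.
struct Scale(f32);

impl Module<f32> for Scale {
    type Output = f32;
    fn forward(&self, input: f32) -> f32 {
        self.0 * input
    }
}

// Delegation: anything that implements Module can serve as a ModuleMut
// by forwarding to its immutable forward.
impl<T, M: Module<T>> ModuleMut<T> for M {
    type Output = M::Output;
    fn forward_mut(&mut self, input: T) -> Self::Output {
        self.forward(input)
    }
}
```

The upshot is that a stateless layer only has to implement `forward` once; `forward_mut` comes for free with identical output.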
impl<const M: usize, const H: usize, const K: usize, const V: usize> ResetParams for MultiHeadAttention<M, H, K, V>
Auto Trait Implementations
impl<const EMBED_DIM: usize, const NUM_HEADS: usize, const K_DIM: usize, const V_DIM: usize> RefUnwindSafe for MultiHeadAttention<EMBED_DIM, NUM_HEADS, K_DIM, V_DIM>
impl<const EMBED_DIM: usize, const NUM_HEADS: usize, const K_DIM: usize, const V_DIM: usize> Send for MultiHeadAttention<EMBED_DIM, NUM_HEADS, K_DIM, V_DIM>
impl<const EMBED_DIM: usize, const NUM_HEADS: usize, const K_DIM: usize, const V_DIM: usize> Sync for MultiHeadAttention<EMBED_DIM, NUM_HEADS, K_DIM, V_DIM>
impl<const EMBED_DIM: usize, const NUM_HEADS: usize, const K_DIM: usize, const V_DIM: usize> Unpin for MultiHeadAttention<EMBED_DIM, NUM_HEADS, K_DIM, V_DIM>
impl<const EMBED_DIM: usize, const NUM_HEADS: usize, const K_DIM: usize, const V_DIM: usize> UnwindSafe for MultiHeadAttention<EMBED_DIM, NUM_HEADS, K_DIM, V_DIM>
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more