Struct dfdx::nn::MultiHeadAttention
pub struct MultiHeadAttention<const M: usize, const N: usize, const K: usize, const V: usize, const H: usize> {
pub w_q: Linear<M, K>,
pub w_k: Linear<N, K>,
pub w_v: Linear<N, V>,
pub w_o: Linear<V, M>,
}
Requires Nightly. A multi-head attention layer.
Generics
M: The embedding size of token vectors from the decoder.
N: The embedding size of token vectors from the encoder.
K: The size of the keys in self attention.
V: The size of the values.
H: The number of attention heads.
Examples
MultiHeadAttention<8, 10, 10, 10, 2>
is an attention layer with 2 heads, a decoder embedding size of 8, and encoder embedding, key, and value dims of 10.
TODO: Doctests fail for some reason
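Since the doctests are currently disabled, here is a rough usage sketch of the self-attention case (unverified; it assumes a nightly toolchain with the dfdx prelude in scope, and constructor names such as `TensorCreator::zeros` may differ between dfdx versions):

```rust
use dfdx::prelude::*;

fn main() {
    // Self attention requires M == N, so the single input tensor
    // supplies queries, keys, and values. Sizes here are hypothetical:
    // 8-dim tokens, 10-dim keys/values, 2 heads.
    let model: MultiHeadAttention<8, 8, 10, 10, 2> = Default::default();

    // A sequence of 5 tokens, each an 8-dim embedding.
    let x: Tensor2D<5, 8> = TensorCreator::zeros();

    // w_o: Linear<V, M> projects back to M, so the output
    // keeps the (sequence, M) shape.
    let _y: Tensor2D<5, 8> = model.forward(x);
}
```

The output shape follows from the field types above: `w_o: Linear<V, M>` maps the concatenated head outputs back to the decoder embedding size `M`.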
Fields
w_q: Linear<M, K>
w_k: Linear<N, K>
w_v: Linear<N, V>
w_o: Linear<V, M>
Trait Implementations
impl<const M: usize, const N: usize, const K: usize, const V: usize, const H: usize> CanUpdateWithGradients for MultiHeadAttention<M, N, K, V, H>
fn update<G: GradientProvider>(
&mut self,
grads: &mut G,
unused: &mut UnusedTensors
)
Updates self given the GradientProvider. For any parameters that are NOT present in G, this function should add the tensor’s UniqueId to UnusedTensors.
impl<const M: usize, const N: usize, const K: usize, const V: usize, const H: usize> Clone for MultiHeadAttention<M, N, K, V, H>
fn clone(&self) -> MultiHeadAttention<M, N, K, V, H>
Returns a copy of the value.
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
impl<const M: usize, const N: usize, const K: usize, const V: usize, const H: usize> Debug for MultiHeadAttention<M, N, K, V, H>
impl<const M: usize, const N: usize, const K: usize, const V: usize, const H: usize> Default for MultiHeadAttention<M, N, K, V, H>
fn default() -> MultiHeadAttention<M, N, K, V, H>
Returns the “default value” for a type.
impl<const M: usize, const N: usize, const K: usize, const V: usize, const S1: usize, const S2: usize, const H: usize, T: 'static + Tape> Module<(Tensor2D<S1, M, T>, Tensor2D<S2, N, NoneTape>)> for MultiHeadAttention<M, N, K, V, H> where
Assert<{ _ }>: ConstTrue,
Assert<{ _ }>: ConstTrue,
Assert<{ _ }>: ConstTrue,
Assert<{ _ }>: ConstTrue,
fn forward(
&self,
(input, from_enc): (Tensor2D<S1, M, T>, Tensor2D<S2, N>)
) -> Self::Output
Encoder-Decoder style attention, where one set of tensors supplies the keys and values and the other supplies the queries.
fn forward_mut(&mut self, input: Input) -> Self::Output
Pass an Input through the unit and produce Self::Output. Can be implemented for multiple Input types.
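The tuple input maps onto the two sides of encoder-decoder attention. A hedged sketch of calling this impl (unverified, hypothetical sequence lengths, and constructor names may differ by dfdx version):

```rust
use dfdx::prelude::*;

fn main() {
    // Decoder tokens are 8-dim (M), encoder tokens are 10-dim (N);
    // keys and values are projected to 10 dims, with 2 heads.
    let model: MultiHeadAttention<8, 10, 10, 10, 2> = Default::default();

    // 5 decoder tokens become the queries; 7 encoder tokens
    // become the keys and values.
    let q: Tensor2D<5, 8> = TensorCreator::zeros();
    let enc: Tensor2D<7, 10> = TensorCreator::zeros();

    // The output takes the decoder's sequence length (S1) and
    // embedding size (M).
    let _y: Tensor2D<5, 8> = model.forward((q, enc));
}
```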
impl<const B: usize, const M: usize, const N: usize, const K: usize, const V: usize, const S1: usize, const S2: usize, const H: usize, T: 'static + Tape> Module<(Tensor3D<B, S1, M, T>, Tensor3D<B, S2, N, NoneTape>)> for MultiHeadAttention<M, N, K, V, H> where
Assert<{ _ }>: ConstTrue,
Assert<{ _ }>: ConstTrue,
Assert<{ _ }>: ConstTrue,
Assert<{ _ }>: ConstTrue,
fn forward(
&self,
(input, from_enc): (Tensor3D<B, S1, M, T>, Tensor3D<B, S2, N>)
) -> Self::Output
Batched Encoder-Decoder style attention, where one set of tensors supplies the keys and values and the other supplies the queries.
fn forward_mut(&mut self, input: Input) -> Self::Output
Pass an Input through the unit and produce Self::Output. Can be implemented for multiple Input types.
impl<const M: usize, const K: usize, const V: usize, const S: usize, const H: usize, T: 'static + Tape> Module<Tensor2D<S, M, T>> for MultiHeadAttention<M, M, K, V, H> where
Assert<{ _ }>: ConstTrue,
Assert<{ _ }>: ConstTrue,
Assert<{ _ }>: ConstTrue,
fn forward(&self, input: Tensor2D<S, M, T>) -> Self::Output
Normal self attention, where the same tensor is used for keys, queries, and values.
fn forward_mut(&mut self, input: Input) -> Self::Output
Pass an Input through the unit and produce Self::Output. Can be implemented for multiple Input types.
impl<const B: usize, const M: usize, const K: usize, const V: usize, const S: usize, const H: usize, T: 'static + Tape> Module<Tensor3D<B, S, M, T>> for MultiHeadAttention<M, M, K, V, H> where
Assert<{ _ }>: ConstTrue,
Assert<{ _ }>: ConstTrue,
Assert<{ _ }>: ConstTrue,
fn forward(&self, input: Tensor3D<B, S, M, T>) -> Self::Output
Batched normal self attention, where the same tensor is used for keys, queries, and values.
fn forward_mut(&mut self, input: Input) -> Self::Output
Pass an Input through the unit and produce Self::Output. Can be implemented for multiple Input types.
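The batched impls only add a leading batch dimension B; all shape constraints are otherwise the same. A minimal sketch of the batched self-attention call (unverified; sizes are hypothetical and constructor names may differ by dfdx version):

```rust
use dfdx::prelude::*;

fn main() {
    // Self attention again requires M == N.
    let model: MultiHeadAttention<8, 8, 10, 10, 2> = Default::default();

    // A batch of 4 sequences, 5 tokens each, 8-dim embeddings.
    let x: Tensor3D<4, 5, 8> = TensorCreator::zeros();

    // Batch and sequence dims pass through; the last dim stays M.
    let _y: Tensor3D<4, 5, 8> = model.forward(x);
}
```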
impl<const M: usize, const N: usize, const K: usize, const V: usize, const H: usize> ResetParams for MultiHeadAttention<M, N, K, V, H>
Auto Trait Implementations
impl<const M: usize, const N: usize, const K: usize, const V: usize, const H: usize> RefUnwindSafe for MultiHeadAttention<M, N, K, V, H>
impl<const M: usize, const N: usize, const K: usize, const V: usize, const H: usize> !Send for MultiHeadAttention<M, N, K, V, H>
impl<const M: usize, const N: usize, const K: usize, const V: usize, const H: usize> !Sync for MultiHeadAttention<M, N, K, V, H>
impl<const M: usize, const N: usize, const K: usize, const V: usize, const H: usize> Unpin for MultiHeadAttention<M, N, K, V, H>
impl<const M: usize, const N: usize, const K: usize, const V: usize, const H: usize> UnwindSafe for MultiHeadAttention<M, N, K, V, H>
Blanket Implementations
impl<T> BorrowMut<T> for T where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.