pub struct MultiHeadAttention { /* private fields */ }
Multi-Head Attention (Vaswani et al., 2017).
Allows the model to jointly attend to information from different representation subspaces at different positions.
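The mechanism can be illustrated with a minimal, dependency-free sketch for a single batch element. It uses plain Vec<Vec<f32>> matrices instead of the crate's Tensor and omits the learned query, key, value, and output projections (treated as identity), so it shows only how embed_dim is split into num_heads slices of size head_dim = embed_dim / num_heads (64 for d_model = 512, num_heads = 8) and how each head computes softmax(Q·Kᵀ / √head_dim)·V. The function names and the choice to average the weights over heads are illustrative assumptions, not this crate's implementation.

```rust
/// Row-wise softmax, numerically stabilised by subtracting the row maximum.
fn softmax(row: &mut [f32]) {
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0;
    for x in row.iter_mut() {
        *x = (*x - max).exp();
        sum += *x;
    }
    for x in row.iter_mut() {
        *x /= sum;
    }
}

/// Multi-head scaled dot-product attention for ONE batch element.
/// ASSUMPTION: learned projections are omitted (identity) to keep the sketch short.
/// q: [target_len][embed_dim], k and v: [source_len][embed_dim].
/// Returns (output [target_len][embed_dim], weights averaged over heads [target_len][source_len]).
fn multi_head_attention(
    q: &[Vec<f32>],
    k: &[Vec<f32>],
    v: &[Vec<f32>],
    num_heads: usize,
) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let embed_dim = q[0].len();
    assert!(embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads");
    let head_dim = embed_dim / num_heads;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let (tq, ts) = (q.len(), k.len());
    let mut out = vec![vec![0.0f32; embed_dim]; tq];
    let mut avg_w = vec![vec![0.0f32; ts]; tq];

    for h in 0..num_heads {
        // Each head sees only its own contiguous slice of the embedding.
        let cols = h * head_dim..(h + 1) * head_dim;
        for i in 0..tq {
            // Scaled dot-product scores of query i against every key, for this head.
            let mut w: Vec<f32> = (0..ts)
                .map(|j| {
                    let dot: f32 = q[i][cols.clone()]
                        .iter()
                        .zip(&k[j][cols.clone()])
                        .map(|(a, b)| a * b)
                        .sum();
                    dot * scale
                })
                .collect();
            softmax(&mut w);
            for j in 0..ts {
                avg_w[i][j] += w[j] / num_heads as f32;
                // Weighted sum of values, written back into this head's columns (the "concat").
                for c in cols.clone() {
                    out[i][c] += w[j] * v[j][c];
                }
            }
        }
    }
    (out, avg_w)
}
```

When every key is identical, each head attends uniformly, so the output is simply the mean of the value rows.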
§Example
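let mha = MultiHeadAttention::new(512, 8); // d_model=512, num_heads=8
let q = Tensor::randn(&[32, 10, 512]); // query: [batch=32, target_len=10, embed_dim=512]
let k = Tensor::randn(&[32, 20, 512]); // key: [batch=32, source_len=20, embed_dim=512]
let v = Tensor::randn(&[32, 20, 512]); // value: same shape as key
let (output, attn_weights) = mha.forward_qkv(&q, &k, &v, None); // no attention mask
Implementations§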
impl MultiHeadAttention
pub fn with_dropout(self, dropout_p: f32) -> Self
Sets the dropout probability (builder style: consumes self and returns the modified module).
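This page does not say where the dropout is applied; PyTorch's nn.MultiheadAttention, for comparison, applies it to the attention weights. The sketch below shows standard inverted dropout on a slice of f32 values under that assumption. XorShift64 is a hypothetical helper included only so the example needs no external crates, and dropout is an illustrative function, not this crate's API.

```rust
/// Tiny xorshift64 PRNG (not cryptographic). The seed must be non-zero.
struct XorShift64(u64);

impl XorShift64 {
    /// Returns a uniform float in [0, 1) built from the top 24 bits of the state.
    fn next_f32(&mut self) -> f32 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        (x >> 40) as f32 / (1u64 << 24) as f32
    }
}

/// Inverted dropout: each element is zeroed with probability `p`, and survivors are
/// scaled by 1 / (1 - p) so the expected value is unchanged. Apply only during training.
fn dropout(values: &mut [f32], p: f32, rng: &mut XorShift64) {
    assert!((0.0..=1.0).contains(&p), "dropout probability must be in [0, 1]");
    if p == 0.0 {
        return; // nothing to drop
    }
    if p == 1.0 {
        values.iter_mut().for_each(|x| *x = 0.0); // everything dropped; avoid dividing by zero
        return;
    }
    let keep_scale = 1.0 / (1.0 - p);
    for x in values.iter_mut() {
        if rng.next_f32() < p {
            *x = 0.0;
        } else {
            *x *= keep_scale;
        }
    }
}
```

Scaling at training time (rather than at inference) is what lets the layer skip dropout entirely in evaluation mode without changing activation magnitudes.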
pub fn forward_qkv(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
attn_mask: Option<&Tensor>,
) -> (Tensor, Tensor)
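Forward pass with separate query, key, and value inputs, so the query sequence length (target_len) may differ from the key/value sequence length (source_len).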
§Arguments
query - Query tensor [batch, target_len, embed_dim]
key - Key tensor [batch, source_len, embed_dim]
value - Value tensor [batch, source_len, embed_dim]
attn_mask - Optional attention mask [batch, target_len, source_len]
§Returns
Tuple of (output, attention_weights)
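The shape contract in the argument list can be checked up front. This sketch works on plain shape slices because the crate's Tensor type is not available here; check_qkv_shapes is a hypothetical helper, and the returned output shape [batch, target_len, embed_dim] is the conventional multi-head attention output shape, assumed rather than stated on this page.

```rust
/// Validates the shapes documented for forward_qkv and returns the assumed
/// output shape [batch, target_len, embed_dim].
fn check_qkv_shapes(
    query: &[usize],
    key: &[usize],
    value: &[usize],
    attn_mask: Option<&[usize]>,
) -> Result<[usize; 3], String> {
    // query: [batch, target_len, embed_dim]
    let (batch, target_len, embed_dim) = match query {
        [b, t, e] => (*b, *t, *e),
        _ => return Err(format!("query must be 3-D, got {:?}", query)),
    };
    // key: [batch, source_len, embed_dim]
    let source_len = match key {
        [b, s, e] if *b == batch && *e == embed_dim => *s,
        _ => return Err(format!("key {:?} incompatible with query {:?}", key, query)),
    };
    // value must match key exactly.
    if value != [batch, source_len, embed_dim] {
        return Err(format!(
            "value must be {:?}, got {:?}",
            [batch, source_len, embed_dim],
            value
        ));
    }
    // Optional mask: [batch, target_len, source_len]
    if let Some(mask) = attn_mask {
        if mask != [batch, target_len, source_len] {
            return Err(format!(
                "attn_mask must be {:?}, got {:?}",
                [batch, target_len, source_len],
                mask
            ));
        }
    }
    Ok([batch, target_len, embed_dim])
}
```

With the shapes from the example above, the check passes and reports an output of [32, 10, 512]; a value tensor whose length differs from the key's is rejected.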
Trait Implementations§
impl Debug for MultiHeadAttention
impl Module for MultiHeadAttention
fn parameters_mut(&mut self) -> Vec<&mut Tensor>
Get mutable references to all learnable parameters.
fn refresh_caches(&mut self)
Refresh any cached computations after parameters have been modified.
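The two methods are meant to be used together: edit parameters in place through parameters_mut, then call refresh_caches once. The sketch below shows that pattern with a hypothetical stand-in trait (ToyModule), a toy layer (Scale), and a plain SGD step (sgd_step), where a "tensor" is just a Vec<f32>. None of these names belong to this crate, and its real Module trait may require more methods.

```rust
/// Hypothetical stand-in for a Module-like trait; tensors are modelled as Vec<f32>.
trait ToyModule {
    fn parameters_mut(&mut self) -> Vec<&mut Vec<f32>>;
    /// Rebuild any data derived from the parameters. Default: nothing cached.
    fn refresh_caches(&mut self) {}
}

/// Toy layer that caches the squared norm of its weights (stand-in for any derived cache).
struct Scale {
    weight: Vec<f32>,
    cached_sq_norm: f32,
}

impl ToyModule for Scale {
    fn parameters_mut(&mut self) -> Vec<&mut Vec<f32>> {
        vec![&mut self.weight]
    }
    fn refresh_caches(&mut self) {
        self.cached_sq_norm = self.weight.iter().map(|w| w * w).sum();
    }
}

/// Plain SGD: update every parameter in place, then refresh caches exactly once.
fn sgd_step<M: ToyModule>(module: &mut M, grads: &[Vec<f32>], lr: f32) {
    for (param, grad) in module.parameters_mut().into_iter().zip(grads) {
        for (p, g) in param.iter_mut().zip(grad) {
            *p -= lr * g;
        }
    }
    // The mutable borrows from parameters_mut have ended, so the module can be borrowed again.
    module.refresh_caches();
}
```

Forgetting the refresh_caches call would leave cached_sq_norm describing the old weights, which is the kind of stale state the trait method exists to prevent.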
fn num_parameters(&self) -> usize
Get the number of learnable parameters.
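As a rough cross-check of what num_parameters might report, here is the arithmetic for a conventional multi-head attention layer. It ASSUMES four separate d_model × d_model projections (query, key, value, output), each optionally with a bias of length d_model; whether this crate uses biases or fused projections is not stated on this page, so its count may differ. mha_param_count is a hypothetical helper. Note that num_heads does not appear: the heads partition d_model rather than adding weights.

```rust
/// Parameter count of a conventional MHA layer under the assumptions above.
fn mha_param_count(d_model: usize, with_bias: bool) -> usize {
    // One projection: a d_model x d_model weight matrix plus an optional bias vector.
    let per_projection = d_model * d_model + if with_bias { d_model } else { 0 };
    // Query, key, value, and output projections.
    4 * per_projection
}
```

For the d_model = 512 example this gives 4 × (512 × 512 + 512) = 1,050,624 with biases and 1,048,576 without.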
Auto Trait Implementations§
impl Freeze for MultiHeadAttention
impl !RefUnwindSafe for MultiHeadAttention
impl Send for MultiHeadAttention
impl Sync for MultiHeadAttention
impl Unpin for MultiHeadAttention
impl UnsafeUnpin for MultiHeadAttention
impl !UnwindSafe for MultiHeadAttention
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.