pub struct MultiHeadAttention { /* private fields */ }
Multi-head attention.
Projects input through Q, K, V linear layers, splits into n_heads heads,
runs per-head scaled dot-product attention (via the fused Attention op),
concatenates heads, and applies an output projection.
Implementations
impl MultiHeadAttention
pub fn new(
    wq: Linear,
    wk: Linear,
    wv: Linear,
    wo: Linear,
    n_heads: usize,
) -> Self
Create a new MultiHeadAttention from pre-built linear layers.
wq, wk, wv: projection layers with weight shape [n_heads * head_dim, model_dim]
wo: output projection with weight shape [model_dim, n_heads * head_dim]
n_heads: number of attention heads
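The weight shapes above are mutually constrained: the Q/K/V projections map model_dim to n_heads * head_dim, and the output projection maps back. A small standalone sketch of that invariant (illustrative dimensions only, not this crate's API):

```rust
/// Returns true when the projection shapes are mutually consistent:
/// wq/wk/wv map model_dim -> n_heads * head_dim, and wo maps back.
fn shapes_consistent(
    qkv_weight: (usize, usize), // [n_heads * head_dim, model_dim]
    wo_weight: (usize, usize),  // [model_dim, n_heads * head_dim]
    n_heads: usize,
) -> bool {
    let (proj_out, model_dim) = qkv_weight;
    // The concatenated head width must split evenly across heads,
    // and wo must be the transpose shape of the Q/K/V weights.
    proj_out % n_heads == 0 && wo_weight == (model_dim, proj_out)
}

fn main() {
    // Hypothetical numbers: 4 heads of width 16, 64-dim model.
    assert!(shapes_consistent((4 * 16, 64), (64, 4 * 16), 4));
    println!("shapes consistent");
}
```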
pub fn forward_causal(&self, x: &Tensor) -> Result<Tensor>
Forward pass with causal masking (self-attention, auto-regressive).
x has shape [seq_len, model_dim].
Returns [seq_len, model_dim].
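The per-head computation behind forward_causal can be sketched with plain slices standing in for the crate's Tensor (the real implementation uses the fused Attention op; this is an illustrative reimplementation, not the crate's code):

```rust
/// Causal scaled dot-product attention for a single head.
/// q, k, v are row-major [seq_len, head_dim]; returns [seq_len, head_dim].
fn causal_attention(q: &[f32], k: &[f32], v: &[f32], seq_len: usize, head_dim: usize) -> Vec<f32> {
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut out = vec![0.0f32; seq_len * head_dim];
    for i in 0..seq_len {
        // Causal mask: position i only attends to keys 0..=i.
        let mut scores: Vec<f32> = (0..=i)
            .map(|j| {
                (0..head_dim)
                    .map(|d| q[i * head_dim + d] * k[j * head_dim + d])
                    .sum::<f32>()
                    * scale
            })
            .collect();
        // Numerically stable softmax over the unmasked positions.
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0;
        for s in scores.iter_mut() {
            *s = (*s - max).exp();
            sum += *s;
        }
        for s in scores.iter_mut() {
            *s /= sum;
        }
        // Weighted sum of value rows.
        for (j, w) in scores.iter().enumerate() {
            for d in 0..head_dim {
                out[i * head_dim + d] += w * v[j * head_dim + d];
            }
        }
    }
    out
}

fn main() {
    // seq_len = 2, head_dim = 2. Row 0 can only attend to itself,
    // so its output equals the first value row exactly.
    let q = [1.0, 0.0, 0.0, 1.0];
    let k = [1.0, 0.0, 0.0, 1.0];
    let v = [1.0, 2.0, 3.0, 4.0];
    let out = causal_attention(&q, &k, &v, 2, 2);
    assert_eq!(&out[0..2], &[1.0, 2.0]);
    println!("{:?}", out);
}
```

MultiHeadAttention runs this for each of n_heads slices of the projected input, concatenates the per-head outputs, and applies wo.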
Auto Trait Implementations
impl Freeze for MultiHeadAttention
impl !RefUnwindSafe for MultiHeadAttention
impl Send for MultiHeadAttention
impl Sync for MultiHeadAttention
impl Unpin for MultiHeadAttention
impl UnsafeUnpin for MultiHeadAttention
impl !UnwindSafe for MultiHeadAttention