pub struct MultiHeadAttention {
    pub d_model: usize,
    pub n_heads: usize,
    pub d_head: usize,
    pub w_q: Vec<f64>,
    pub w_k: Vec<f64>,
    pub w_v: Vec<f64>,
    pub w_o: Vec<f64>,
    pub b_o: Vec<f64>,
}
Multi-head attention module.
Projects the input into queries (Q), keys (K), and values (V) with learned linear projections,
runs n_heads attention heads in parallel, then concatenates the head outputs and applies the output projection.
All weight matrices are stored flat row-major.
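Because every matrix is a flat row-major `Vec<f64>`, element (row, col) of a matrix with `cols` columns lives at index `row * cols + col`. The sketch below shows that indexing and how one head's slice could be pulled out of a [seq_len × d_model] activation. It assumes head `h` occupies the contiguous column block `h * d_head .. (h + 1) * d_head`, a common layout that the docs do not confirm; `at` and `head_slice` are hypothetical helpers, not part of this crate.

```rust
/// Hypothetical helper: reads element (row, col) of a flat row-major
/// matrix that has `cols` columns.
fn at(m: &[f64], cols: usize, row: usize, col: usize) -> f64 {
    m[row * cols + col]
}

/// Hypothetical helper: copies head `h`'s columns out of a flat row-major
/// [seq_len × d_model] matrix, producing a [seq_len × d_head] matrix.
/// Assumes heads occupy contiguous column blocks of width `d_head`.
fn head_slice(m: &[f64], seq_len: usize, d_model: usize, d_head: usize, h: usize) -> Vec<f64> {
    let mut out = Vec::with_capacity(seq_len * d_head);
    for t in 0..seq_len {
        // Start of head h's block within row t.
        let start = t * d_model + h * d_head;
        out.extend_from_slice(&m[start..start + d_head]);
    }
    out
}
```

For a 2 × 4 matrix holding 0.0 through 7.0 with two heads of width 2, head 1 yields the last two columns of each row: `[2.0, 3.0, 6.0, 7.0]`.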
Fields
d_model: usize
    Model dimensionality.
n_heads: usize
    Number of attention heads.
d_head: usize
    Dimensionality per head: d_model / n_heads.
w_q: Vec<f64>
    W_Q projection [d_model × d_model].
w_k: Vec<f64>
    W_K projection [d_model × d_model].
w_v: Vec<f64>
    W_V projection [d_model × d_model].
w_o: Vec<f64>
    W_O output projection [d_model × d_model].
b_o: Vec<f64>
    Output bias [d_model].
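Since all fields are `pub`, a caller can build or modify a module by hand, and the documented shapes then become invariants the caller must uphold. The following is a minimal sketch of a shape check derived only from the field docs above. `Mha` is a hypothetical mirror of the struct's fields (so the block compiles on its own), and `check_shapes` is a hypothetical helper; the crate may not validate shapes at all.

```rust
/// Hypothetical mirror of MultiHeadAttention's public fields.
struct Mha {
    d_model: usize,
    n_heads: usize,
    d_head: usize,
    w_q: Vec<f64>,
    w_k: Vec<f64>,
    w_v: Vec<f64>,
    w_o: Vec<f64>,
    b_o: Vec<f64>,
}

/// Hypothetical check of the invariants implied by the field docs:
/// d_head * n_heads == d_model, four [d_model × d_model] weight
/// matrices, and a [d_model] output bias.
fn check_shapes(m: &Mha) -> Result<(), String> {
    if m.n_heads == 0 || m.d_head * m.n_heads != m.d_model {
        return Err(format!(
            "d_head ({}) * n_heads ({}) != d_model ({})",
            m.d_head, m.n_heads, m.d_model
        ));
    }
    let sq = m.d_model * m.d_model;
    for (name, w) in [("w_q", &m.w_q), ("w_k", &m.w_k), ("w_v", &m.w_v), ("w_o", &m.w_o)] {
        if w.len() != sq {
            return Err(format!("{name} has length {}, expected {sq}", w.len()));
        }
    }
    if m.b_o.len() != m.d_model {
        return Err(format!("b_o has length {}, expected {}", m.b_o.len(), m.d_model));
    }
    Ok(())
}
```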
Implementations
impl MultiHeadAttention
pub fn new(d_model: usize, n_heads: usize) -> Self
Create a new MHA module with zero-initialised projections.
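A zero-initialising constructor could look like the sketch below, again written against a hypothetical mirror struct `Mha`. Two details are assumptions rather than documented behaviour: that `b_o` is also zeroed, and that a `d_model` not divisible by `n_heads` is rejected with a panic (the real `new` might truncate `d_head` or behave differently).

```rust
/// Hypothetical mirror of MultiHeadAttention's public fields.
struct Mha {
    d_model: usize,
    n_heads: usize,
    d_head: usize,
    w_q: Vec<f64>,
    w_k: Vec<f64>,
    w_v: Vec<f64>,
    w_o: Vec<f64>,
    b_o: Vec<f64>,
}

impl Mha {
    /// Sketch of a zero-initialised constructor.
    /// Assumption: panics unless d_model is divisible by n_heads.
    fn new(d_model: usize, n_heads: usize) -> Self {
        assert!(
            n_heads > 0 && d_model % n_heads == 0,
            "d_model must be divisible by n_heads"
        );
        let sq = d_model * d_model;
        Mha {
            d_model,
            n_heads,
            d_head: d_model / n_heads,
            // Four flat row-major [d_model × d_model] projections, all zeros.
            w_q: vec![0.0; sq],
            w_k: vec![0.0; sq],
            w_v: vec![0.0; sq],
            w_o: vec![0.0; sq],
            // Assumption: the output bias is zeroed as well.
            b_o: vec![0.0; d_model],
        }
    }
}
```

With all-zero weights every projection outputs zeros, so a module built this way is only useful once its weights are set (for example by `init_identity` or by loading trained values).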
pub fn init_identity(&mut self)
Initialise W_Q, W_K, W_V, W_O with identity-like weights for testing.
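The docs say only "identity-like", so the exact values are not specified. A minimal sketch, assuming each projection is set to the exact d_model × d_model identity matrix in flat row-major order, is below; `set_identity` is a hypothetical helper.

```rust
/// Hypothetical helper: overwrites a flat row-major d × d buffer
/// with the identity matrix (1.0 on the diagonal, 0.0 elsewhere).
fn set_identity(w: &mut [f64], d: usize) {
    assert_eq!(w.len(), d * d, "buffer must hold a d × d matrix");
    for r in 0..d {
        for c in 0..d {
            w[r * d + c] = if r == c { 1.0 } else { 0.0 };
        }
    }
}
```

With identity projections, Q, K, and V all equal the input, which makes the attention output easy to predict by hand in tests.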
pub fn forward(&self, x: &[f64], seq_len: usize) -> Vec<f64>
Forward pass.
x has shape [seq_len × d_model] (flat row-major).
Returns output of shape [seq_len × d_model].
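The docs give only the input and output shapes of `forward`. The sketch below shows one standard way such a pass can be computed over flat row-major buffers. It is written under several assumptions that the source does not confirm: a row-vector convention (`Q = X · W_Q`, with each weight read as [in × out]), heads occupying contiguous column blocks of width `d_head`, scaled dot-product attention with scale `1/sqrt(d_head)`, no causal mask, no biases on the Q/K/V projections, and `b_o` added after `W_O`. `matmul` and `mha_forward` are hypothetical free functions, not this crate's API.

```rust
/// Multiplies a [n × k] matrix by a [k × m] matrix, both flat row-major.
fn matmul(a: &[f64], b: &[f64], n: usize, k: usize, m: usize) -> Vec<f64> {
    let mut out = vec![0.0; n * m];
    for i in 0..n {
        for p in 0..k {
            let a_ip = a[i * k + p];
            for j in 0..m {
                out[i * m + j] += a_ip * b[p * m + j];
            }
        }
    }
    out
}

/// Sketch of a multi-head attention forward pass (hypothetical).
/// x is [seq_len × d_model]; every weight is [d_model × d_model]; b_o is [d_model].
fn mha_forward(
    x: &[f64], seq_len: usize, d_model: usize, n_heads: usize,
    w_q: &[f64], w_k: &[f64], w_v: &[f64], w_o: &[f64], b_o: &[f64],
) -> Vec<f64> {
    let d_head = d_model / n_heads;
    // Project the input into queries, keys and values.
    let q = matmul(x, w_q, seq_len, d_model, d_model);
    let k = matmul(x, w_k, seq_len, d_model, d_model);
    let v = matmul(x, w_v, seq_len, d_model, d_model);
    let scale = 1.0 / (d_head as f64).sqrt();
    // Concatenated head outputs, [seq_len × d_model].
    let mut concat = vec![0.0; seq_len * d_model];
    let mut scores = vec![0.0; seq_len];
    for h in 0..n_heads {
        let off = h * d_head;
        for i in 0..seq_len {
            // Scaled dot products of query i with every key, within head h.
            let mut row_max = f64::NEG_INFINITY;
            for j in 0..seq_len {
                let mut dot = 0.0;
                for c in 0..d_head {
                    dot += q[i * d_model + off + c] * k[j * d_model + off + c];
                }
                scores[j] = dot * scale;
                row_max = row_max.max(scores[j]);
            }
            // Numerically stable softmax over the keys.
            let mut sum = 0.0;
            for s in scores.iter_mut() {
                *s = (*s - row_max).exp();
                sum += *s;
            }
            // Weighted sum of value rows, written into head h's column block.
            for j in 0..seq_len {
                let w = scores[j] / sum;
                for c in 0..d_head {
                    concat[i * d_model + off + c] += w * v[j * d_model + off + c];
                }
            }
        }
    }
    // Output projection followed by the output bias.
    let mut out = matmul(&concat, w_o, seq_len, d_model, d_model);
    for t in 0..seq_len {
        for c in 0..d_model {
            out[t * d_model + c] += b_o[c];
        }
    }
    out
}
```

As a sanity check under these assumptions: with identity weights, a zero bias, and identical input rows, the attention weights are uniform and the output equals the input.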
pub fn num_params(&self) -> usize
Total number of trainable parameters.
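From the field shapes alone, the trainable parameters would be four d_model × d_model matrices plus a d_model bias, that is 4·d_model² + d_model; for d_model = 512 this is 4 · 262,144 + 512 = 1,049,088. The sketch below encodes that count, assuming `num_params` counts exactly `w_q`, `w_k`, `w_v`, `w_o`, and `b_o` (the docs do not say so explicitly). Both functions are hypothetical.

```rust
/// Hypothetical: parameter count implied by the documented field shapes.
fn expected_num_params(d_model: usize) -> usize {
    4 * d_model * d_model + d_model
}

/// Hypothetical: counts parameters by summing the lengths of the
/// weight and bias buffers actually held.
fn num_params_from_buffers(buffers: &[&[f64]]) -> usize {
    buffers.iter().map(|b| b.len()).sum()
}
```

Summing buffer lengths, rather than applying the formula, would also reflect any buffer that was resized by hand through the `pub` fields.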
Trait Implementations
impl Clone for MultiHeadAttention
fn clone(&self) -> MultiHeadAttention
Returns a duplicate of the value.
1.0.0 (const: unstable) · fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
Auto Trait Implementations
impl Freeze for MultiHeadAttention
impl RefUnwindSafe for MultiHeadAttention
impl Send for MultiHeadAttention
impl Sync for MultiHeadAttention
impl Unpin for MultiHeadAttention
impl UnsafeUnpin for MultiHeadAttention
impl UnwindSafe for MultiHeadAttention
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self> otherwise.