pub struct MultiHeadAttention {
    pub w_q: Tensor,
    pub w_k: Tensor,
    pub w_v: Tensor,
    pub w_o: Tensor,
    pub b_q: Option<Tensor>,
    pub b_k: Option<Tensor>,
    pub b_v: Option<Tensor>,
    pub q_norm: Option<Tensor>,
    pub k_norm: Option<Tensor>,
    /* private fields */
}
Multi-head self-attention layer
Fields§
w_q: Tensor
Query projection weight (hidden_size x hidden_size)
w_k: Tensor
Key projection weight (hidden_size x kv_hidden_size)
w_v: Tensor
Value projection weight (hidden_size x kv_hidden_size)
w_o: Tensor
Output projection weight (hidden_size x hidden_size)
b_q: Option<Tensor>
Optional query bias (Qwen2 uses attention biases)
b_k: Option<Tensor>
Optional key bias
b_v: Option<Tensor>
Optional value bias
q_norm: Option<Tensor>
Optional Q RMSNorm weight (Qwen3 uses QK-norm, shape=[head_dim])
k_norm: Option<Tensor>
Optional K RMSNorm weight (Qwen3 uses QK-norm, shape=[head_dim])
Implementations§
impl MultiHeadAttention
pub fn new(config: &TransformerConfig) -> Self
Create a new attention layer with initialized weights.
When config.use_bias == true (Qwen2 family), allocates the Q/K/V projection biases as zero tensors. The forward pass already honors Option<Tensor> biases (lines 388-395 apply them via add_bias when Some). Without allocating them here, the biases stay None and populate_trainer_from_init_tensors silently drops the corresponding init tensors when fine-tuning from a Qwen APR checkpoint. See FALSIFY-APR-PRETRAIN-INIT-POPULATE-COVERAGE-001/002 in transformer::config::tests for the 290-vs-218 named-parameters gap that surfaced this bug.
Zero-init for biases matches the HuggingFace LLaMA / Qwen convention, where the model's weight-init hook zeroes nn.Linear biases. PyTorch's own default differs: torch.nn.modules.linear.Linear.reset_parameters initializes the weight with kaiming_uniform_ and draws the bias from a uniform distribution bounded by 1/sqrt(fan_in), not zeros.
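The bias allocation rule described for new can be sketched as follows. This is a minimal illustration, not the crate's implementation: Tensor, Config, and init_qkv_biases are hypothetical stand-ins, and the bias lengths (hidden_size for Q, kv_hidden_size for K and V) are an assumption based on the field docs.

```rust
/// Hypothetical stand-in for the crate's `Tensor`: a flat f32 buffer.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub data: Vec<f32>,
}

impl Tensor {
    /// All-zeros tensor with `len` elements.
    pub fn zeros(len: usize) -> Self {
        Tensor { data: vec![0.0; len] }
    }
}

/// Hypothetical subset of `TransformerConfig`; only `use_bias` is named in the docs.
pub struct Config {
    pub hidden_size: usize,
    pub kv_hidden_size: usize,
    pub use_bias: bool,
}

/// Q/K/V biases: `Some(zeros)` when `use_bias` is set, `None` otherwise.
/// Bias lengths are assumed, not taken from the crate.
pub fn init_qkv_biases(cfg: &Config) -> (Option<Tensor>, Option<Tensor>, Option<Tensor>) {
    if !cfg.use_bias {
        return (None, None, None);
    }
    (
        Some(Tensor::zeros(cfg.hidden_size)),
        Some(Tensor::zeros(cfg.kv_hidden_size)),
        Some(Tensor::zeros(cfg.kv_hidden_size)),
    )
}
```

Allocating the slots up front matters because a later loader can only overwrite a bias that already exists as Some.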
pub fn from_params(
    config: &TransformerConfig,
    params: &HashMap<String, Tensor>,
    prefix: &str,
) -> Option<Self>
Create an attention layer from a parameter map.
Expected parameter names (following HuggingFace convention):
{prefix}.q_proj.weight
{prefix}.k_proj.weight
{prefix}.v_proj.weight
{prefix}.o_proj.weight
§Contract (PMAT-331)
Validates the Q/K/V/O projection shapes against the config dimensions. Returns None if any key is missing or any shape does not match.
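The contract above (all four keys present, all shapes matching) might be checked along these lines. This is a sketch under assumptions: Tensor, Config, and check_projection_shapes are hypothetical names, and the two-dimensional shape representation mirrors the field docs and is not necessarily how the crate stores shapes.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for the crate's `Tensor`: only the shape matters here.
pub struct Tensor {
    pub shape: Vec<usize>,
}

/// Hypothetical subset of `TransformerConfig`.
pub struct Config {
    pub hidden_size: usize,
    pub kv_hidden_size: usize,
}

/// Returns `Some(())` only if all four projection weights exist under
/// `prefix` and have the shapes listed in the field docs.
pub fn check_projection_shapes(
    cfg: &Config,
    params: &HashMap<String, Tensor>,
    prefix: &str,
) -> Option<()> {
    let h = cfg.hidden_size;
    let kv = cfg.kv_hidden_size;
    let expected = [
        ("q_proj", [h, h]),
        ("k_proj", [h, kv]),
        ("v_proj", [h, kv]),
        ("o_proj", [h, h]),
    ];
    for (name, shape) in expected {
        let key = format!("{prefix}.{name}.weight");
        let tensor = params.get(&key)?; // missing key -> None
        if tensor.shape != shape {
            return None; // wrong shape -> None
        }
    }
    Some(())
}
```

Returning Option rather than panicking lets a caller fall back to fresh initialization when a checkpoint is incomplete.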
pub fn forward_with_lora(
    &self,
    x: &Tensor,
    seq_len: usize,
    lora_a_q: &Tensor,
    lora_b_q: &Tensor,
    lora_a_v: &Tensor,
    lora_b_v: &Tensor,
    lora_rank: usize,
    lora_scale: f32,
) -> Tensor
Forward pass with LoRA adjustments on the Q and V projections (KAIZEN-010).
Applies LoRA adapters to Q and V during the forward pass so that gradients flow through the LoRA A/B matrices on non-CUDA paths.
§Arguments
x - Input tensor (seq_len * hidden_size)
seq_len - Sequence length
lora_a_q, lora_b_q - Q projection LoRA matrices (rank×d_in, d_out×rank)
lora_a_v, lora_b_v - V projection LoRA matrices (rank×d_in, d_out×rank)
lora_rank - LoRA rank
lora_scale - LoRA scaling factor (alpha/rank)
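For a single projection, the LoRA adjustment with the shapes listed above (A as rank×d_in, B as d_out×rank) amounts to x W^T + lora_scale * (x A^T) B^T. The sketch below illustrates that arithmetic on plain row-major buffers. Mat, matmul_bt, and lora_project are hypothetical names, and the base weight is assumed to be stored as d_out×d_in purely so that one helper suffices; the crate's actual layout is not stated here.

```rust
/// Row-major matrix in a flat buffer (hypothetical stand-in for `Tensor`).
#[derive(Debug, Clone, PartialEq)]
pub struct Mat {
    pub rows: usize,
    pub cols: usize,
    pub data: Vec<f32>,
}

impl Mat {
    pub fn new(rows: usize, cols: usize, data: Vec<f32>) -> Self {
        assert_eq!(data.len(), rows * cols, "buffer length must equal rows * cols");
        Mat { rows, cols, data }
    }
}

/// Computes `a * b^T` for `a: m x k` and `b: n x k`, giving `m x n`.
pub fn matmul_bt(a: &Mat, b: &Mat) -> Mat {
    assert_eq!(a.cols, b.cols, "inner dimensions must match");
    let mut out = vec![0.0f32; a.rows * b.rows];
    for i in 0..a.rows {
        for j in 0..b.rows {
            let mut acc = 0.0f32;
            for p in 0..a.cols {
                acc += a.data[i * a.cols + p] * b.data[j * b.cols + p];
            }
            out[i * b.rows + j] = acc;
        }
    }
    Mat::new(a.rows, b.rows, out)
}

/// LoRA-adjusted projection: `x W^T + scale * (x A^T) B^T`.
/// `w` is assumed to be d_out x d_in; `lora_a` is rank x d_in; `lora_b` is d_out x rank.
pub fn lora_project(x: &Mat, w: &Mat, lora_a: &Mat, lora_b: &Mat, scale: f32) -> Mat {
    let base = matmul_bt(x, w); // seq_len x d_out
    let down = matmul_bt(x, lora_a); // seq_len x rank
    let up = matmul_bt(&down, lora_b); // seq_len x d_out
    let data: Vec<f32> = base
        .data
        .iter()
        .zip(&up.data)
        .map(|(b, u)| b + scale * u)
        .collect();
    Mat::new(base.rows, base.cols, data)
}
```

With scale set to 0.0 the result reduces to the base projection, which is a convenient sanity check when wiring adapters in.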
pub fn parameters(&self) -> Vec<&Tensor>
Get all parameters as a vector
pub fn parameters_mut(&mut self) -> Vec<&mut Tensor>
Get all parameters as mutable references for optimizer
pub fn has_biases(&self) -> bool
Whether this attention layer has QKV biases
pub fn named_parameters(&self, prefix: &str) -> Vec<(String, &Tensor)>
Get named parameters for checkpoint serialization
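As a rough illustration of how optional biases affect the named-parameter list, the sketch below emits the four projection weights and appends bias entries only when they are Some. Attn and Tensor are hypothetical stand-ins, the `.bias` key names are assumed from the HuggingFace convention, and q_norm / k_norm are omitted for brevity.

```rust
/// Hypothetical stand-in for the crate's `Tensor`.
#[derive(Debug, Default)]
pub struct Tensor;

/// Reduced attention layer: four projection weights plus optional Q/K/V biases.
#[derive(Default)]
pub struct Attn {
    pub w_q: Tensor,
    pub w_k: Tensor,
    pub w_v: Tensor,
    pub w_o: Tensor,
    pub b_q: Option<Tensor>,
    pub b_k: Option<Tensor>,
    pub b_v: Option<Tensor>,
}

impl Attn {
    /// Weights are always listed; a bias is listed only when it is `Some`.
    pub fn named_parameters(&self, prefix: &str) -> Vec<(String, &Tensor)> {
        let mut out = vec![
            (format!("{prefix}.q_proj.weight"), &self.w_q),
            (format!("{prefix}.k_proj.weight"), &self.w_k),
            (format!("{prefix}.v_proj.weight"), &self.w_v),
            (format!("{prefix}.o_proj.weight"), &self.w_o),
        ];
        let biases = [
            ("q_proj", &self.b_q),
            ("k_proj", &self.b_k),
            ("v_proj", &self.b_v),
        ];
        for (name, bias) in biases {
            if let Some(t) = bias {
                out.push((format!("{prefix}.{name}.bias"), t));
            }
        }
        out
    }
}
```

In this sketch a layer with biases contributes three more entries than one without. The 290-vs-218 gap mentioned in the docs for new is 72, which would be consistent with three biases across 24 layers, although the layer count is not stated in this page.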
pub fn set_named_parameter(&mut self, suffix: &str, value: Tensor) -> bool
ENT-282: Set a named parameter by suffix (the part of the name after "self_attn.").
Bias suffixes route to b_q / b_k / b_v only when those fields are already Some (i.e., MultiHeadAttention::new allocated them because config.use_bias == true). If the caller asks to set a bias on an attention layer that does not have one, the method returns false, the same result as for an unrecognized suffix. This keeps populate_trainer_from_init_tensors honest: the biases of a Qwen-init APR populate if and only if the target Transformer was built from a use_bias=true config.
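The routing rule can be sketched on a reduced layer with one weight and one optional bias. Attn and Tensor are hypothetical stand-ins and the suffix strings are assumed from the HuggingFace naming; only the "bias is set only if already Some" behavior comes from the description above.

```rust
/// Hypothetical stand-in for the crate's `Tensor`.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor(pub Vec<f32>);

/// Reduced attention layer: one weight and one optional bias.
pub struct Attn {
    pub w_q: Tensor,
    pub b_q: Option<Tensor>,
}

impl Attn {
    /// Returns true only if `suffix` names a parameter this layer actually holds.
    pub fn set_named_parameter(&mut self, suffix: &str, value: Tensor) -> bool {
        match suffix {
            "q_proj.weight" => {
                self.w_q = value;
                true
            }
            // A bias is overwritten only if the slot was allocated at construction.
            "q_proj.bias" => match self.b_q.as_mut() {
                Some(slot) => {
                    *slot = value;
                    true
                }
                None => false,
            },
            // Unrecognized suffix: same result as a missing bias slot.
            _ => false,
        }
    }
}
```

Because the setter never creates a bias slot, the boolean return value tells a loader exactly which init tensors were accepted.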
Auto Trait Implementations§
impl !RefUnwindSafe for MultiHeadAttention
impl !Send for MultiHeadAttention
impl !Sync for MultiHeadAttention
impl !UnwindSafe for MultiHeadAttention
impl Freeze for MultiHeadAttention
impl Unpin for MultiHeadAttention
impl UnsafeUnpin for MultiHeadAttention
Blanket Implementations§
impl<T> BorrowMut<T> for T where T: ?Sized
impl<T> FmtForward for T
impl<T> Instrument for T
impl<T> IntoEither for T
impl<T> Pipe for T where T: ?Sized
impl<T> Pointable for T
impl<T> PolicyExt for T where T: ?Sized
impl<T> Tap for T