use crate::autograd::{matmul, matmul_nt, BackwardOp};
use crate::Tensor;
use ndarray::Array1;
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use super::config::TransformerConfig;
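/// Adds a row-broadcast bias to a `[seq_len, dim]` projection stored as a flat row-major
/// buffer. Note: the result only inherits `requires_grad` from `x`; no backward op is
/// attached, so gradients do not flow through this addition.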
fn add_bias(x: &Tensor, bias: &Tensor, seq_len: usize) -> Tensor {
let xd = x.data();
let x_slice = xd.as_slice().expect("contiguous projection");
let bd = bias.data();
let b_slice = bd.as_slice().expect("contiguous bias");
let dim = b_slice.len();
let mut out = Vec::with_capacity(x_slice.len());
for s in 0..seq_len {
let base = s * dim;
for d in 0..dim {
out.push(x_slice[base + d] + b_slice[d]);
}
}
Tensor::from_vec(out, x.requires_grad())
}
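/// Applies per-head RMSNorm (QK-norm) to a `[seq_len, num_heads * head_dim]` buffer:
/// each head slice is scaled by `1 / rms` and then by the shared `norm_weight` of
/// length `head_dim`. Like `add_bias`, the output carries no backward op.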
fn apply_qk_norm(
x: &Tensor,
norm_weight: &Tensor,
seq_len: usize,
num_heads: usize,
head_dim: usize,
) -> Tensor {
let xd = x.data();
let x_slice = xd.as_slice().expect("contiguous qk");
let wd = norm_weight.data();
let w_slice = wd.as_slice().expect("contiguous norm weight");
let total_dim = num_heads * head_dim;
let eps = 1e-6_f32;
let mut out = vec![0.0f32; seq_len * total_dim];
for s in 0..seq_len {
for h in 0..num_heads {
let offset = s * total_dim + h * head_dim;
let mut sum_sq = 0.0f32;
for d in 0..head_dim {
let v = x_slice[offset + d];
sum_sq += v * v;
}
let rms = (sum_sq / head_dim as f32 + eps).sqrt();
let inv_rms = 1.0 / rms;
for d in 0..head_dim {
out[offset + d] = x_slice[offset + d] * inv_rms * w_slice[d];
}
}
}
Tensor::from_vec(out, x.requires_grad())
}
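/// Applies rotary position embeddings (RoPE) in the "rotate-half" layout: within each
/// head, element `i` of the first half is paired with element `i` of the second half and
/// rotated by a position-dependent angle derived from `rope_theta`. The output has no
/// backward op, which is why the full-gradient test below is ignored (ENT-272).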
fn apply_rope(
x: &Tensor,
seq_len: usize,
num_heads: usize,
head_dim: usize,
rope_theta: f32,
) -> Tensor {
let xd = x.data();
let x_slice = xd.as_slice().expect("contiguous qk for rope");
let total_dim = num_heads * head_dim;
let half_dim = head_dim / 2;
let mut out = vec![0.0f32; seq_len * total_dim];
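// Inverse frequencies: 1 / theta^(2i / head_dim), one per rotated pair.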
let inv_freq: Vec<f32> =
(0..half_dim).map(|i| 1.0 / rope_theta.powf(2.0 * i as f32 / head_dim as f32)).collect();
for pos in 0..seq_len {
for h in 0..num_heads {
let offset = pos * total_dim + h * head_dim;
for i in 0..half_dim {
let freq = pos as f32 * inv_freq[i];
let cos_f = freq.cos();
let sin_f = freq.sin();
let x_first = x_slice[offset + i];
let x_second = x_slice[offset + i + half_dim];
out[offset + i] = x_first * cos_f - x_second * sin_f;
out[offset + i + half_dim] = x_second * cos_f + x_first * sin_f;
}
}
}
Tensor::from_vec(out, x.requires_grad())
}
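/// Backward op for the attention block: splits the upstream gradient of the concatenated
/// head outputs, runs each per-head attention backward, scatters the resulting per-head
/// Q/K/V gradients back into the full projection tensors, and then continues backprop
/// through the Q/K/V projections.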
struct AttentionBlockBackward {
q: Tensor,
k: Tensor,
v: Tensor,
head_q_tensors: Vec<Tensor>,
head_k_tensors: Vec<Tensor>,
head_v_tensors: Vec<Tensor>,
head_outputs: Vec<Tensor>,
head_kv_indices: Vec<usize>,
seq_len: usize,
head_dim: usize,
q_dim: usize,
kv_hidden_size: usize,
result_grad: Rc<RefCell<Option<Array1<f32>>>>,
}
impl BackwardOp for AttentionBlockBackward {
fn backward(&self) {
let Some(grad_out) = self.result_grad.borrow().as_ref().cloned() else { return };
let go = grad_out.as_slice().expect("grad contiguous");
let h = self.head_dim;
split_and_backward_heads(go, &self.head_outputs, self.seq_len, h, self.q_dim);
scatter_head_grads_q(&self.q, &self.head_q_tensors, self.seq_len, h, self.q_dim);
scatter_head_grads_kv(
&self.k,
&self.head_k_tensors,
&self.head_kv_indices,
self.seq_len,
h,
self.kv_hidden_size,
);
scatter_head_grads_kv(
&self.v,
&self.head_v_tensors,
&self.head_kv_indices,
self.seq_len,
h,
self.kv_hidden_size,
);
for proj in [&self.q, &self.k, &self.v] {
if let Some(op) = proj.backward_op() {
op.backward();
}
}
}
}
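/// Slices the gradient of the concatenated `[seq_len, q_dim]` output into per-head
/// `[seq_len, head_dim]` gradients, accumulates them on each head output, and triggers
/// that head's attention backward op.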
fn split_and_backward_heads(
go: &[f32],
head_outputs: &[Tensor],
seq_len: usize,
head_dim: usize,
q_dim: usize,
) {
for (head_idx, head_out) in head_outputs.iter().enumerate() {
let mut grad_head = vec![0.0_f32; seq_len * head_dim];
for s in 0..seq_len {
let src_base = s * q_dim + head_idx * head_dim;
let dst_base = s * head_dim;
grad_head[dst_base..dst_base + head_dim]
.copy_from_slice(&go[src_base..src_base + head_dim]);
}
head_out.accumulate_grad(Array1::from(grad_head));
if let Some(op) = head_out.backward_op() {
op.backward();
}
}
}
fn scatter_head_grads_q(
q: &Tensor,
head_q_tensors: &[Tensor],
seq_len: usize,
head_dim: usize,
q_dim: usize,
) {
if !q.requires_grad() {
return;
}
let mut grad_q = vec![0.0_f32; seq_len * q_dim];
for (head_idx, head_q) in head_q_tensors.iter().enumerate() {
if let Some(hgrad) = head_q.grad() {
let hg = hgrad.as_slice().expect("contiguous");
for s in 0..seq_len {
let src_base = s * head_dim;
let dst_base = s * q_dim + head_idx * head_dim;
for d in 0..head_dim {
grad_q[dst_base + d] += hg[src_base + d];
}
}
}
}
q.accumulate_grad(Array1::from(grad_q));
}
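/// Same as `scatter_head_grads_q`, but for K/V: with grouped-query attention several
/// query heads map onto one KV head (`kv_indices`), so their gradients sum into the same
/// slot of the `[seq_len, kv_hidden_size]` buffer.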
fn scatter_head_grads_kv(
target: &Tensor,
head_tensors: &[Tensor],
kv_indices: &[usize],
seq_len: usize,
head_dim: usize,
kv_hidden_size: usize,
) {
if !target.requires_grad() {
return;
}
let mut grad = vec![0.0_f32; seq_len * kv_hidden_size];
for (head_idx, head_t) in head_tensors.iter().enumerate() {
let kv_h = kv_indices[head_idx];
if let Some(hgrad) = head_t.grad() {
let hg = hgrad.as_slice().expect("contiguous");
for s in 0..seq_len {
let src_base = s * head_dim;
let dst_base = s * kv_hidden_size + kv_h * head_dim;
for d in 0..head_dim {
grad[dst_base + d] += hg[src_base + d];
}
}
}
}
target.accumulate_grad(Array1::from(grad));
}
pub struct MultiHeadAttention {
config: TransformerConfig,
pub w_q: Tensor,
pub w_k: Tensor,
pub w_v: Tensor,
pub w_o: Tensor,
pub b_q: Option<Tensor>,
pub b_k: Option<Tensor>,
pub b_v: Option<Tensor>,
pub q_norm: Option<Tensor>,
pub k_norm: Option<Tensor>,
}
impl MultiHeadAttention {
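/// Randomly initializes the four projection weights from a seeded normal distribution;
/// biases and QK-norm weights start unset.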
pub fn new(config: &TransformerConfig) -> Self {
use super::init::{get_init_seed, rand_normal_seeded};
let hidden_size = config.hidden_size;
let q_dim = config.q_dim();
let kv_hidden_size = config.num_kv_heads * config.head_dim();
let seed = get_init_seed();
Self {
config: config.clone(),
w_q: Tensor::from_vec(rand_normal_seeded(q_dim * hidden_size, seed, "w_q"), true),
w_k: Tensor::from_vec(
rand_normal_seeded(kv_hidden_size * hidden_size, seed, "w_k"),
true,
),
w_v: Tensor::from_vec(
rand_normal_seeded(kv_hidden_size * hidden_size, seed, "w_v"),
true,
),
w_o: Tensor::from_vec(rand_normal_seeded(hidden_size * q_dim, seed, "w_o"), true),
b_q: None,
b_k: None,
b_v: None,
q_norm: None,
k_norm: None,
}
}
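/// Builds the layer from a flat parameter map keyed by `"{prefix}.{q,k,v,o}_proj.weight"`.
/// Returns `None` if any projection is missing or has the wrong element count (PMAT-331);
/// biases and `{q,k}_norm.weight` tensors are picked up when present.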
pub fn from_params(
config: &TransformerConfig,
params: &HashMap<String, Tensor>,
prefix: &str,
) -> Option<Self> {
let w_q = params.get(&format!("{prefix}.q_proj.weight"))?.clone();
let w_k = params.get(&format!("{prefix}.k_proj.weight"))?.clone();
let w_v = params.get(&format!("{prefix}.v_proj.weight"))?.clone();
let w_o = params.get(&format!("{prefix}.o_proj.weight"))?.clone();
let hidden = config.hidden_size;
let q_dim = config.q_dim();
let kv_hidden = config.num_kv_heads * config.head_dim();
let checks: &[(&str, &Tensor, usize)] = &[
("q_proj", &w_q, q_dim * hidden),
("k_proj", &w_k, kv_hidden * hidden),
("v_proj", &w_v, kv_hidden * hidden),
("o_proj", &w_o, hidden * q_dim),
];
for &(name, tensor, expected) in checks {
if tensor.len() != expected {
eprintln!(
"[PMAT-331] {prefix}.{name}: shape mismatch — got {} elements, expected {expected}",
tensor.len()
);
return None;
}
}
let b_q = params.get(&format!("{prefix}.q_proj.bias")).cloned();
let b_k = params.get(&format!("{prefix}.k_proj.bias")).cloned();
let b_v = params.get(&format!("{prefix}.v_proj.bias")).cloned();
let q_norm = params.get(&format!("{prefix}.q_norm.weight")).cloned();
let k_norm = params.get(&format!("{prefix}.k_norm.weight")).cloned();
Some(Self { config: config.clone(), w_q, w_k, w_v, w_o, b_q, b_k, b_v, q_norm, k_norm })
}
pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
contract_pre_attention!(x.data());
let hidden_size = self.config.hidden_size;
let num_heads = self.config.num_attention_heads;
let num_kv_heads = self.config.num_kv_heads;
let head_dim = self.config.head_dim();
let q_dim = self.config.q_dim();
let kv_hidden_size = num_kv_heads * head_dim;
let mut q = matmul_nt(x, &self.w_q, seq_len, hidden_size, q_dim);
let mut k = matmul_nt(x, &self.w_k, seq_len, hidden_size, kv_hidden_size);
let mut v = matmul_nt(x, &self.w_v, seq_len, hidden_size, kv_hidden_size);
if let Some(ref b_q) = self.b_q {
q = add_bias(&q, b_q, seq_len);
}
if let Some(ref b_k) = self.b_k {
k = add_bias(&k, b_k, seq_len);
}
if let Some(ref b_v) = self.b_v {
v = add_bias(&v, b_v, seq_len);
}
if let Some(ref qn) = self.q_norm {
q = apply_qk_norm(&q, qn, seq_len, num_heads, head_dim);
}
if let Some(ref kn) = self.k_norm {
k = apply_qk_norm(&k, kn, seq_len, num_kv_heads, head_dim);
}
if self.config.rope_theta > 0.0 {
q = apply_rope(&q, seq_len, num_heads, head_dim, self.config.rope_theta);
k = apply_rope(&k, seq_len, num_kv_heads, head_dim, self.config.rope_theta);
}
let requires_grad = q.requires_grad() || k.requires_grad() || v.requires_grad();
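// GQA: each group of `heads_per_kv` query heads shares one KV head (`kv_h` below).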
let heads_per_kv = num_heads / num_kv_heads;
let q_data = q.data();
let q_slice = q_data.as_slice().expect("contiguous Q");
let k_data = k.data();
let k_slice = k_data.as_slice().expect("contiguous K");
let v_data = v.data();
let v_slice = v_data.as_slice().expect("contiguous V");
let mut head_q_tensors = Vec::with_capacity(num_heads);
let mut head_k_tensors = Vec::with_capacity(num_heads);
let mut head_v_tensors = Vec::with_capacity(num_heads);
let mut head_outputs = Vec::with_capacity(num_heads);
let mut head_kv_indices = Vec::with_capacity(num_heads);
for h in 0..num_heads {
let kv_h = h / heads_per_kv;
head_kv_indices.push(kv_h);
let mut q_head = Vec::with_capacity(seq_len * head_dim);
for s in 0..seq_len {
let start = s * q_dim + h * head_dim;
q_head.extend_from_slice(&q_slice[start..start + head_dim]);
}
let mut k_head = Vec::with_capacity(seq_len * head_dim);
for s in 0..seq_len {
let start = s * kv_hidden_size + kv_h * head_dim;
k_head.extend_from_slice(&k_slice[start..start + head_dim]);
}
let mut v_head = Vec::with_capacity(seq_len * head_dim);
for s in 0..seq_len {
let start = s * kv_hidden_size + kv_h * head_dim;
v_head.extend_from_slice(&v_slice[start..start + head_dim]);
}
let q_tensor = Tensor::from_vec(q_head, requires_grad);
let k_tensor = Tensor::from_vec(k_head, requires_grad);
let v_tensor = Tensor::from_vec(v_head, requires_grad);
let attn_out = crate::autograd::attention(
&q_tensor, &k_tensor, &v_tensor, seq_len, head_dim, seq_len, head_dim,
);
head_q_tensors.push(q_tensor);
head_k_tensors.push(k_tensor);
head_v_tensors.push(v_tensor);
head_outputs.push(attn_out);
}
let mut concat_output = vec![0.0; seq_len * q_dim];
for (h, head_out) in head_outputs.iter().enumerate() {
let hd = head_out.data();
let hdata = hd.as_slice().expect("contiguous attention output");
for s in 0..seq_len {
let src_base = s * head_dim;
let dst_base = s * q_dim + h * head_dim;
concat_output[dst_base..dst_base + head_dim]
.copy_from_slice(&hdata[src_base..src_base + head_dim]);
}
}
let mut concat_tensor = Tensor::from_vec(concat_output, requires_grad);
if requires_grad {
let backward_op = Rc::new(AttentionBlockBackward {
q: q.clone(),
k: k.clone(),
v: v.clone(),
head_q_tensors,
head_k_tensors,
head_v_tensors,
head_outputs,
head_kv_indices,
seq_len,
head_dim,
q_dim,
kv_hidden_size,
result_grad: concat_tensor.grad_cell(),
});
concat_tensor.set_backward_op(backward_op);
}
matmul_nt(&concat_tensor, &self.w_o, seq_len, q_dim, hidden_size)
}
pub fn forward_with_lora(
&self,
x: &Tensor,
seq_len: usize,
lora_a_q: &Tensor,
lora_b_q: &Tensor,
lora_a_v: &Tensor,
lora_b_v: &Tensor,
lora_rank: usize,
lora_scale: f32,
) -> Tensor {
contract_pre_lora_forward!();
let hidden_size = self.config.hidden_size;
let num_heads = self.config.num_attention_heads;
let num_kv_heads = self.config.num_kv_heads;
let head_dim = self.config.head_dim();
let q_dim = self.config.q_dim();
let kv_hidden_size = num_kv_heads * head_dim;
let q_base = matmul_nt(x, &self.w_q, seq_len, hidden_size, q_dim);
let q_mid = crate::autograd::matmul_nt(x, lora_a_q, seq_len, hidden_size, lora_rank);
let q_lora = crate::autograd::matmul_nt(&q_mid, lora_b_q, seq_len, lora_rank, q_dim);
let q = crate::autograd::add_scaled(&q_base, &q_lora, lora_scale);
let k = matmul_nt(x, &self.w_k, seq_len, hidden_size, kv_hidden_size);
let v_base = matmul_nt(x, &self.w_v, seq_len, hidden_size, kv_hidden_size);
let v_mid = crate::autograd::matmul_nt(x, lora_a_v, seq_len, hidden_size, lora_rank);
let v_lora =
crate::autograd::matmul_nt(&v_mid, lora_b_v, seq_len, lora_rank, kv_hidden_size);
let v = crate::autograd::add_scaled(&v_base, &v_lora, lora_scale);
let q = if let Some(ref qn) = self.q_norm {
apply_qk_norm(&q, qn, seq_len, num_heads, head_dim)
} else {
q
};
let k = if let Some(ref kn) = self.k_norm {
apply_qk_norm(&k, kn, seq_len, num_kv_heads, head_dim)
} else {
k
};
let (q, k) = if self.config.rope_theta > 0.0 {
(
apply_rope(&q, seq_len, num_heads, head_dim, self.config.rope_theta),
apply_rope(&k, seq_len, num_kv_heads, head_dim, self.config.rope_theta),
)
} else {
(q, k)
};
let requires_grad = q.requires_grad() || k.requires_grad() || v.requires_grad();
let heads_per_kv = num_heads / num_kv_heads;
let q_data = q.data();
let q_slice = q_data.as_slice().expect("contiguous Q");
let k_data = k.data();
let k_slice = k_data.as_slice().expect("contiguous K");
let v_data = v.data();
let v_slice = v_data.as_slice().expect("contiguous V");
let mut head_q_tensors = Vec::with_capacity(num_heads);
let mut head_k_tensors = Vec::with_capacity(num_heads);
let mut head_v_tensors = Vec::with_capacity(num_heads);
let mut head_outputs = Vec::with_capacity(num_heads);
let mut head_kv_indices = Vec::with_capacity(num_heads);
for h in 0..num_heads {
let kv_h = h / heads_per_kv;
head_kv_indices.push(kv_h);
let mut q_head = Vec::with_capacity(seq_len * head_dim);
for s in 0..seq_len {
let start = s * q_dim + h * head_dim;
q_head.extend_from_slice(&q_slice[start..start + head_dim]);
}
let mut k_head = Vec::with_capacity(seq_len * head_dim);
for s in 0..seq_len {
let start = s * kv_hidden_size + kv_h * head_dim;
k_head.extend_from_slice(&k_slice[start..start + head_dim]);
}
let mut v_head = Vec::with_capacity(seq_len * head_dim);
for s in 0..seq_len {
let start = s * kv_hidden_size + kv_h * head_dim;
v_head.extend_from_slice(&v_slice[start..start + head_dim]);
}
let q_tensor = Tensor::from_vec(q_head, requires_grad);
let k_tensor = Tensor::from_vec(k_head, requires_grad);
let v_tensor = Tensor::from_vec(v_head, requires_grad);
let attn_out = crate::autograd::attention(
&q_tensor, &k_tensor, &v_tensor, seq_len, head_dim, seq_len, head_dim,
);
head_q_tensors.push(q_tensor);
head_k_tensors.push(k_tensor);
head_v_tensors.push(v_tensor);
head_outputs.push(attn_out);
}
let mut concat_output = vec![0.0; seq_len * q_dim];
for (h, head_out) in head_outputs.iter().enumerate() {
let hd = head_out.data();
let hdata = hd.as_slice().expect("contiguous attention output");
for s in 0..seq_len {
let src_base = s * head_dim;
let dst_base = s * q_dim + h * head_dim;
concat_output[dst_base..dst_base + head_dim]
.copy_from_slice(&hdata[src_base..src_base + head_dim]);
}
}
let mut concat_tensor = Tensor::from_vec(concat_output, requires_grad);
if requires_grad {
let backward_op = Rc::new(AttentionBlockBackward {
q: q.clone(),
k: k.clone(),
v: v.clone(),
head_q_tensors,
head_k_tensors,
head_v_tensors,
head_outputs,
head_kv_indices,
seq_len,
head_dim,
q_dim,
kv_hidden_size,
result_grad: concat_tensor.grad_cell(),
});
concat_tensor.set_backward_op(backward_op);
}
matmul_nt(&concat_tensor, &self.w_o, seq_len, q_dim, hidden_size)
}
pub fn parameters(&self) -> Vec<&Tensor> {
let mut params = vec![&self.w_q, &self.w_k, &self.w_v, &self.w_o];
if let Some(ref b) = self.b_q {
params.push(b);
}
if let Some(ref b) = self.b_k {
params.push(b);
}
if let Some(ref b) = self.b_v {
params.push(b);
}
params
}
pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
let mut params = vec![&mut self.w_q, &mut self.w_k, &mut self.w_v, &mut self.w_o];
if let Some(ref mut b) = self.b_q {
params.push(b);
}
if let Some(ref mut b) = self.b_k {
params.push(b);
}
if let Some(ref mut b) = self.b_v {
params.push(b);
}
params
}
pub fn has_biases(&self) -> bool {
self.b_q.is_some()
}
pub fn named_parameters(&self, prefix: &str) -> Vec<(String, &Tensor)> {
let mut params = vec![
(format!("{prefix}.q_proj.weight"), &self.w_q),
(format!("{prefix}.k_proj.weight"), &self.w_k),
(format!("{prefix}.v_proj.weight"), &self.w_v),
(format!("{prefix}.o_proj.weight"), &self.w_o),
];
if let Some(ref b) = self.b_q {
params.push((format!("{prefix}.q_proj.bias"), b));
}
if let Some(ref b) = self.b_k {
params.push((format!("{prefix}.k_proj.bias"), b));
}
if let Some(ref b) = self.b_v {
params.push((format!("{prefix}.v_proj.bias"), b));
}
params
}
pub fn set_named_parameter(&mut self, suffix: &str, value: Tensor) -> bool {
match suffix {
"self_attn.q_proj.weight" => {
self.w_q = value;
true
}
"self_attn.k_proj.weight" => {
self.w_k = value;
true
}
"self_attn.v_proj.weight" => {
self.w_v = value;
true
}
"self_attn.o_proj.weight" => {
self.w_o = value;
true
}
_ => false,
}
}
}
pub struct LoRAProjection {
pub base_weight: Tensor,
pub lora_a: Tensor,
pub lora_b: Tensor,
pub d_in: usize,
pub d_out: usize,
pub rank: usize,
pub scale: f32,
}
impl LoRAProjection {
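/// Freezes the base weight, seeds `A` with small deterministic pseudo-random values, and
/// zero-initializes `B`, so the initial LoRA delta is exactly zero.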
pub fn new(base_weight: Tensor, d_in: usize, d_out: usize, rank: usize, alpha: f32) -> Self {
assert_eq!(base_weight.len(), d_in * d_out, "Base weight size mismatch");
let mut base_weight = base_weight;
base_weight.set_requires_grad(false);
let lora_a = Tensor::from_vec(
(0..d_in * rank).map(|i| ((i as f32 * 0.123).sin() * 0.01)).collect(),
true,
);
let lora_b = Tensor::zeros(rank * d_out, true);
Self { base_weight, lora_a, lora_b, d_in, d_out, rank, scale: alpha / rank as f32 }
}
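/// Computes `base_out + scale * lora_out` for a `[seq_len, d_in]` input using three
/// `matmul` calls and `add_scaled`; gradients flow into `A` and `B` but not the frozen base.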
pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
let base_out = matmul(x, &self.base_weight, seq_len, self.d_in, self.d_out);
let lora_intermediate = matmul(x, &self.lora_a, seq_len, self.d_in, self.rank);
let lora_out = matmul(&lora_intermediate, &self.lora_b, seq_len, self.rank, self.d_out);
crate::autograd::add_scaled(&base_out, &lora_out, self.scale)
}
pub fn lora_params(&self) -> Vec<&Tensor> {
vec![&self.lora_a, &self.lora_b]
}
pub fn lora_params_mut(&mut self) -> Vec<&mut Tensor> {
vec![&mut self.lora_a, &mut self.lora_b]
}
}
pub struct MultiHeadAttentionWithLoRA {
pub config: TransformerConfig,
pub q_proj: LoRAProjection,
pub k_proj: LoRAProjection,
pub v_proj: LoRAProjection,
pub o_proj: LoRAProjection,
}
impl MultiHeadAttentionWithLoRA {
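/// Wraps an existing `MultiHeadAttention`'s weights in LoRA projections, sharing its
/// config and using one rank/alpha for Q, K, V, and O.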
pub fn from_attention(attn: &MultiHeadAttention, rank: usize, alpha: f32) -> Self {
let hidden_size = attn.config.hidden_size;
let q_dim = attn.config.q_dim();
let kv_hidden_size = attn.config.num_kv_heads * attn.config.head_dim();
Self {
config: attn.config.clone(),
q_proj: LoRAProjection::new(attn.w_q.clone(), hidden_size, q_dim, rank, alpha),
k_proj: LoRAProjection::new(attn.w_k.clone(), hidden_size, kv_hidden_size, rank, alpha),
v_proj: LoRAProjection::new(attn.w_v.clone(), hidden_size, kv_hidden_size, rank, alpha),
o_proj: LoRAProjection::new(attn.w_o.clone(), q_dim, hidden_size, rank, alpha),
}
}
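/// LoRA forward: project with the four `LoRAProjection`s, split into heads using the same
/// GQA mapping as `MultiHeadAttention::forward`, run per-head attention, then concatenate
/// and apply the LoRA output projection.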
pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
let num_heads = self.config.num_attention_heads;
let num_kv_heads = self.config.num_kv_heads;
let head_dim = self.config.head_dim();
let q_dim = self.config.q_dim();
let kv_hidden_size = num_kv_heads * head_dim;
let q = self.q_proj.forward(x, seq_len);
let k = self.k_proj.forward(x, seq_len);
let v = self.v_proj.forward(x, seq_len);
let mut attn_outputs = Vec::with_capacity(num_heads * seq_len * head_dim);
let heads_per_kv = num_heads / num_kv_heads;
let q_data = q.data();
let q_slice = q_data.as_slice().expect("contiguous Q tensor");
let k_data = k.data();
let k_slice = k_data.as_slice().expect("contiguous K tensor");
let v_data = v.data();
let v_slice = v_data.as_slice().expect("contiguous V tensor");
for h in 0..num_heads {
let kv_h = h / heads_per_kv;
let mut q_head = Vec::with_capacity(seq_len * head_dim);
for s in 0..seq_len {
let start = s * q_dim + h * head_dim;
q_head.extend_from_slice(&q_slice[start..start + head_dim]);
}
let mut k_head = Vec::with_capacity(seq_len * head_dim);
for s in 0..seq_len {
let start = s * kv_hidden_size + kv_h * head_dim;
k_head.extend_from_slice(&k_slice[start..start + head_dim]);
}
let mut v_head = Vec::with_capacity(seq_len * head_dim);
for s in 0..seq_len {
let start = s * kv_hidden_size + kv_h * head_dim;
v_head.extend_from_slice(&v_slice[start..start + head_dim]);
}
let q_tensor = Tensor::from_vec(q_head, false);
let k_tensor = Tensor::from_vec(k_head, false);
let v_tensor = Tensor::from_vec(v_head, false);
let attn_out = crate::autograd::attention(
&q_tensor, &k_tensor, &v_tensor, seq_len, head_dim, seq_len, head_dim,
);
attn_outputs.extend_from_slice(
attn_out.data().as_slice().expect("contiguous attention output"),
);
}
let mut concat_output = vec![0.0; seq_len * q_dim];
for h in 0..num_heads {
for s in 0..seq_len {
let src_idx = h * seq_len * head_dim + s * head_dim;
let dst_idx = s * q_dim + h * head_dim;
concat_output[dst_idx..dst_idx + head_dim]
.copy_from_slice(&attn_outputs[src_idx..src_idx + head_dim]);
}
}
let concat_tensor = Tensor::from_vec(concat_output, true);
self.o_proj.forward(&concat_tensor, seq_len)
}
pub fn lora_params(&self) -> Vec<&Tensor> {
let mut params = Vec::new();
params.extend(self.q_proj.lora_params());
params.extend(self.k_proj.lora_params());
params.extend(self.v_proj.lora_params());
params.extend(self.o_proj.lora_params());
params
}
pub fn lora_params_mut(&mut self) -> Vec<&mut Tensor> {
let mut params = Vec::new();
params.extend(self.q_proj.lora_params_mut());
params.extend(self.k_proj.lora_params_mut());
params.extend(self.v_proj.lora_params_mut());
params.extend(self.o_proj.lora_params_mut());
params
}
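/// Number of trainable LoRA parameters across the four projections (the `A` and `B`
/// matrices only; frozen base weights are excluded).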
pub fn lora_param_count(&self) -> usize {
let hidden = self.config.hidden_size;
let q_dim = self.config.q_dim();
let kv_hidden = self.config.num_kv_heads * self.config.head_dim();
let rank = self.q_proj.rank;
// Each projection contributes A: [d_in, rank] plus B: [rank, d_out].
(hidden * rank + rank * q_dim)
+ (hidden * rank + rank * kv_hidden)
+ (hidden * rank + rank * kv_hidden)
+ (q_dim * rank + rank * hidden)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_multi_head_attention_tiny() {
let config = TransformerConfig::tiny();
let attn = MultiHeadAttention::new(&config);
let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
let output = attn.forward(&x, 2);
assert_eq!(output.len(), 2 * config.hidden_size);
}
#[test]
fn test_multi_head_attention_parameters() {
let config = TransformerConfig::tiny();
let attn = MultiHeadAttention::new(&config);
let params = attn.parameters();
assert_eq!(params.len(), 4);
}
#[test]
fn test_attention_longer_sequence() {
let config = TransformerConfig::tiny();
let attn = MultiHeadAttention::new(&config);
let x = Tensor::from_vec(vec![0.1; 8 * config.hidden_size], true);
let output = attn.forward(&x, 8);
assert_eq!(output.len(), 8 * config.hidden_size);
}
#[test]
fn test_attention_weight_sizes() {
let config = TransformerConfig::tiny();
let attn = MultiHeadAttention::new(&config);
let kv_hidden = config.num_kv_heads * config.head_dim();
assert_eq!(attn.w_q.len(), config.hidden_size * config.hidden_size);
assert_eq!(attn.w_k.len(), config.hidden_size * kv_hidden);
assert_eq!(attn.w_v.len(), config.hidden_size * kv_hidden);
assert_eq!(attn.w_o.len(), config.hidden_size * config.hidden_size);
}
#[test]
fn test_multi_head_attention_from_params_success() {
let config = TransformerConfig::tiny();
let hidden_size = config.hidden_size;
let kv_hidden_size = config.num_kv_heads * config.head_dim();
let mut params = HashMap::new();
params.insert(
"attn.q_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
);
params.insert(
"attn.k_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
);
params.insert(
"attn.v_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
);
params.insert(
"attn.o_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
);
let attn = MultiHeadAttention::from_params(&config, ¶ms, "attn");
assert!(attn.is_some());
let attn = attn.expect("operation should succeed");
assert_eq!(attn.w_q.len(), hidden_size * hidden_size);
}
#[test]
fn test_multi_head_attention_from_params_missing_key() {
let config = TransformerConfig::tiny();
let hidden_size = config.hidden_size;
let mut params = HashMap::new();
params.insert(
"attn.q_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
);
let attn = MultiHeadAttention::from_params(&config, ¶ms, "attn");
assert!(attn.is_none());
}
#[test]
fn test_attention_projections_backward() {
let config = TransformerConfig::tiny();
let attn = MultiHeadAttention::new(&config);
let hidden_size = config.hidden_size;
let seq_len = 2;
let x = Tensor::from_vec(vec![0.1; seq_len * hidden_size], true);
let mut q = crate::autograd::matmul(&x, &attn.w_q, seq_len, hidden_size, hidden_size);
let grad_out = ndarray::Array1::ones(seq_len * hidden_size);
crate::autograd::backward(&mut q, Some(grad_out));
assert!(attn.w_q.grad().is_some());
let grad_q = attn.w_q.grad().expect("gradient should be available");
assert!(grad_q.iter().all(|&v| v.is_finite()));
}
#[test]
fn test_output_projection_backward() {
let config = TransformerConfig::tiny();
let attn = MultiHeadAttention::new(&config);
let hidden_size = config.hidden_size;
let seq_len = 2;
let concat_out = Tensor::from_vec(vec![0.1; seq_len * hidden_size], true);
let mut output =
crate::autograd::matmul(&concat_out, &attn.w_o, seq_len, hidden_size, hidden_size);
let grad_out = ndarray::Array1::ones(seq_len * hidden_size);
crate::autograd::backward(&mut output, Some(grad_out));
assert!(attn.w_o.grad().is_some());
let grad_o = attn.w_o.grad().expect("gradient should be available");
assert!(grad_o.iter().all(|&v| v.is_finite()));
let sum: f32 = grad_o.iter().map(|v| v.abs()).sum();
assert!(sum > 0.0, "Output projection gradient should not be all zero");
}
#[test]
#[ignore = "apply_rope() severs autograd chain — needs backward op (ENT-272)"]
fn test_attention_full_forward_qkv_gradients() {
let config = TransformerConfig::tiny();
let attn = MultiHeadAttention::new(&config);
let hidden_size = config.hidden_size;
let seq_len = 3;
let x_data: Vec<f32> =
(0..seq_len * hidden_size).map(|i| ((i as f32) * 0.17).sin() * 0.5).collect();
let x = Tensor::from_vec(x_data, true);
let mut output = attn.forward(&x, seq_len);
let grad_out = ndarray::Array1::ones(seq_len * hidden_size);
crate::autograd::backward(&mut output, Some(grad_out));
for (name, param) in
[("w_q", &attn.w_q), ("w_k", &attn.w_k), ("w_v", &attn.w_v), ("w_o", &attn.w_o)]
{
assert!(
param.grad().is_some(),
"ALB-038: {name} must have gradient after full attention forward"
);
let grad = param.grad().expect("gradient available");
assert!(grad.iter().all(|&v| v.is_finite()), "ALB-038: {name} gradient must be finite");
assert!(
grad.iter().any(|&v| v.abs() > 1e-10),
"ALB-038: {name} gradient must be non-zero"
);
}
assert!(x.grad().is_some(), "ALB-038: input x must have gradient");
}
#[test]
fn test_lora_projection_new() {
let d_in = 32;
let d_out = 16;
let rank = 4;
let alpha = 8.0;
let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
let lora = LoRAProjection::new(base_weight, d_in, d_out, rank, alpha);
assert_eq!(lora.d_in, d_in);
assert_eq!(lora.d_out, d_out);
assert_eq!(lora.rank, rank);
assert!((lora.scale - 2.0).abs() < 1e-6);
assert_eq!(lora.lora_a.len(), d_in * rank);
assert_eq!(lora.lora_b.len(), rank * d_out);
}
#[test]
fn test_lora_projection_forward() {
let d_in = 32;
let d_out = 16;
let rank = 4;
let alpha = 8.0;
let seq_len = 2;
let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
let lora = LoRAProjection::new(base_weight, d_in, d_out, rank, alpha);
let x = Tensor::from_vec(vec![0.1; seq_len * d_in], false);
let output = lora.forward(&x, seq_len);
assert_eq!(output.len(), seq_len * d_out);
assert!(output.data().iter().all(|&v| v.is_finite()));
}
#[test]
fn test_lora_projection_params() {
let d_in = 32;
let d_out = 16;
let rank = 4;
let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
let lora = LoRAProjection::new(base_weight, d_in, d_out, rank, 8.0);
let params = lora.lora_params();
assert_eq!(params.len(), 2);
}
#[test]
fn test_lora_projection_params_mut() {
let d_in = 32;
let d_out = 16;
let rank = 4;
let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
let mut lora = LoRAProjection::new(base_weight, d_in, d_out, rank, 8.0);
let params = lora.lora_params_mut();
assert_eq!(params.len(), 2);
}
#[test]
#[should_panic(expected = "Base weight size mismatch")]
fn test_lora_projection_size_mismatch() {
let d_in = 32;
let d_out = 16;
let rank = 4;
let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out + 1], false);
let _ = LoRAProjection::new(base_weight, d_in, d_out, rank, 8.0);
}
#[test]
fn test_mha_with_lora_creation() {
let config = TransformerConfig::tiny();
let attn = MultiHeadAttention::new(&config);
let rank = 4;
let alpha = 8.0;
let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, rank, alpha);
assert_eq!(lora_attn.q_proj.rank, rank);
assert_eq!(lora_attn.k_proj.rank, rank);
assert_eq!(lora_attn.v_proj.rank, rank);
assert_eq!(lora_attn.o_proj.rank, rank);
}
#[test]
fn test_mha_with_lora_forward() {
let config = TransformerConfig::tiny();
let attn = MultiHeadAttention::new(&config);
let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
let seq_len = 2;
let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], false);
let output = lora_attn.forward(&x, seq_len);
assert_eq!(output.len(), seq_len * config.hidden_size);
assert!(output.data().iter().all(|&v| v.is_finite()));
}
#[test]
fn test_mha_with_lora_params() {
let config = TransformerConfig::tiny();
let attn = MultiHeadAttention::new(&config);
let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
let params = lora_attn.lora_params();
assert_eq!(params.len(), 8);
}
#[test]
fn test_mha_with_lora_params_mut() {
let config = TransformerConfig::tiny();
let attn = MultiHeadAttention::new(&config);
let mut lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
let params = lora_attn.lora_params_mut();
assert_eq!(params.len(), 8);
}
#[test]
fn test_mha_with_lora_param_count() {
let config = TransformerConfig::tiny();
let attn = MultiHeadAttention::new(&config);
let rank = 4;
let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, rank, 8.0);
let param_count = lora_attn.lora_param_count();
let hidden = config.hidden_size;
let q_dim = config.q_dim();
let kv_hidden = config.num_kv_heads * config.head_dim();
let expected = (hidden * rank + rank * q_dim)
+ (hidden * rank + rank * kv_hidden)
+ (hidden * rank + rank * kv_hidden)
+ (q_dim * rank + rank * hidden);
assert_eq!(param_count, expected);
assert!(param_count > 0);
}
#[test]
fn test_mha_with_lora_longer_sequence() {
let config = TransformerConfig::tiny();
let attn = MultiHeadAttention::new(&config);
let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
let seq_len = 8;
let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], false);
let output = lora_attn.forward(&x, seq_len);
assert_eq!(output.len(), seq_len * config.hidden_size);
}
#[test]
fn test_parameters_mut() {
let config = TransformerConfig::tiny();
let mut attn = MultiHeadAttention::new(&config);
let params = attn.parameters_mut();
assert_eq!(params.len(), 4);
}
#[test]
fn falsify_a1e_from_params_rejects_wrong_shape_q_weight() {
let config = TransformerConfig::tiny();
let hidden_size = config.hidden_size;
let kv_hidden_size = config.num_kv_heads * config.head_dim();
let mut params = HashMap::new();
params.insert("attn.q_proj.weight".to_string(), Tensor::from_vec(vec![0.1; 50], true));
params.insert(
"attn.k_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
);
params.insert(
"attn.v_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
);
params.insert(
"attn.o_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
);
let attn = MultiHeadAttention::from_params(&config, ¶ms, "attn");
assert!(
attn.is_none(),
"FALSIFY-A1e: PMAT-331 fix — from_params MUST reject wrong-shape q_proj"
);
}
#[test]
fn falsify_a2e_gqa_init_correct_kv_dimensions() {
let mut config = TransformerConfig::tiny();
config.num_kv_heads = 1;
let attn = MultiHeadAttention::new(&config);
let head_dim = config.head_dim();
let kv_hidden = config.num_kv_heads * head_dim;
assert_eq!(
attn.w_q.len(),
config.hidden_size * config.hidden_size,
"FALSIFY-A2e: Q projection must be hidden*hidden"
);
assert_eq!(
attn.w_k.len(),
config.hidden_size * kv_hidden,
"FALSIFY-A2e: K projection must use num_kv_heads, not num_heads"
);
assert_eq!(
attn.w_v.len(),
config.hidden_size * kv_hidden,
"FALSIFY-A2e: V projection must use num_kv_heads, not num_heads"
);
assert_eq!(
attn.w_o.len(),
config.hidden_size * config.hidden_size,
"FALSIFY-A2e: O projection must be hidden*hidden"
);
assert!(
attn.w_k.len() < attn.w_q.len(),
"FALSIFY-A2e: For GQA, K weight must be smaller than Q weight"
);
}
#[test]
fn falsify_a3e_gqa_forward_correct_output_dims() {
let mut config = TransformerConfig::tiny();
config.num_kv_heads = 1;
let attn = MultiHeadAttention::new(&config);
let seq_len = 3;
let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
let output = attn.forward(&x, seq_len);
assert_eq!(
output.len(),
seq_len * config.hidden_size,
"FALSIFY-A3e: GQA output must be seq_len * hidden_size, not seq_len * kv_hidden"
);
}
#[test]
fn falsify_a4e_init_produces_valid_attention_weights() {
let config = TransformerConfig::tiny();
let attn = MultiHeadAttention::new(&config);
for (name, w) in
[("w_q", &attn.w_q), ("w_k", &attn.w_k), ("w_v", &attn.w_v), ("w_o", &attn.w_o)]
{
let data = w.data();
let slice = data.as_slice().expect("data as slice");
let nan_count = slice.iter().filter(|v| v.is_nan()).count();
assert_eq!(nan_count, 0, "FALSIFY-A4e: {name} init must not contain NaN");
let inf_count = slice.iter().filter(|v| v.is_infinite()).count();
assert_eq!(inf_count, 0, "FALSIFY-A4e: {name} init must not contain Inf");
let min = slice.iter().copied().fold(f32::INFINITY, f32::min);
let max = slice.iter().copied().fold(f32::NEG_INFINITY, f32::max);
assert!(
(max - min).abs() > 1e-6,
"FALSIFY-A4e: {name} init values are constant ({min}..{max}) — degenerate weight"
);
}
}
#[test]
fn falsify_a5e_forward_produces_finite_output() {
let config = TransformerConfig::tiny();
let attn = MultiHeadAttention::new(&config);
let seq_len = 4;
let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
let output = attn.forward(&x, seq_len);
let data = output.data();
let nan_count = data.iter().filter(|v| v.is_nan()).count();
let inf_count = data.iter().filter(|v| v.is_infinite()).count();
assert_eq!(nan_count, 0, "FALSIFY-A5e: Attention output must not contain NaN");
assert_eq!(inf_count, 0, "FALSIFY-A5e: Attention output must not contain Inf");
}
#[test]
fn falsify_gq_001e_output_shape() {
for (num_heads, num_kv_heads) in [(2, 2), (4, 2), (4, 1), (2, 1)] {
let mut config = TransformerConfig::tiny();
config.num_attention_heads = num_heads;
config.num_kv_heads = num_kv_heads;
let attn = MultiHeadAttention::new(&config);
let seq_len = 3;
let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
let output = attn.forward(&x, seq_len);
assert_eq!(
output.len(),
seq_len * config.hidden_size,
"FALSIFIED GQ-001e: output len mismatch for heads={num_heads},kv={num_kv_heads}"
);
}
}
#[test]
fn falsify_gq_002e_mha_degeneration() {
let config = TransformerConfig::tiny();
assert_eq!(config.num_attention_heads, config.num_kv_heads);
let attn = MultiHeadAttention::new(&config);
let seq_len = 4;
let x = Tensor::from_vec(
(0..seq_len * config.hidden_size).map(|i| (i as f32 * 0.37).sin()).collect(),
true,
);
let output = attn.forward(&x, seq_len);
let data = output.data();
for (i, v) in data.iter().enumerate() {
assert!(v.is_finite(), "FALSIFIED GQ-002e: MHA output[{i}] = {v} (not finite)");
}
}
#[test]
fn falsify_gq_004e_head_divisibility() {
for (nh, nkv) in [(2, 1), (2, 2), (4, 1), (4, 2), (4, 4), (8, 2), (8, 4)] {
let mut config = TransformerConfig::tiny();
config.num_attention_heads = nh;
config.num_kv_heads = nkv;
assert_eq!(nh % nkv, 0, "FALSIFIED GQ-004e: test config has invalid head ratio");
let attn = MultiHeadAttention::new(&config);
let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
let _ = attn.forward(&x, 2);
}
}
#[test]
fn falsify_gq_006e_mqa_boundary() {
let mut config = TransformerConfig::tiny();
config.num_attention_heads = 4;
config.num_kv_heads = 1;
config.hidden_size = 64;
let attn = MultiHeadAttention::new(&config);
let seq_len = 3;
let x = Tensor::from_vec(
(0..seq_len * config.hidden_size).map(|i| (i as f32 * 0.73).cos()).collect(),
true,
);
let output = attn.forward(&x, seq_len);
assert_eq!(
output.len(),
seq_len * config.hidden_size,
"FALSIFIED GQ-006e: MQA output size wrong"
);
let data = output.data();
for (i, v) in data.iter().enumerate() {
assert!(v.is_finite(), "FALSIFIED GQ-006e: MQA output[{i}] = {v} (not finite)");
}
}
mod gq_proptest_falsify {
use super::*;
use proptest::prelude::*;
proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
#[test]
fn falsify_gq_001e_prop_output_shape(
config_idx in 0..4usize,
seq_len in 2..=6usize,
seed in 0..500u32,
) {
let configs: [(usize, usize); 4] = [
(2, 2), (2, 1), (4, 2), (4, 1),
];
let (num_heads, num_kv_heads) = configs[config_idx];
let mut config = TransformerConfig::tiny();
config.num_attention_heads = num_heads;
config.num_kv_heads = num_kv_heads;
let attn = MultiHeadAttention::new(&config);
let data: Vec<f32> = (0..seq_len * config.hidden_size)
.map(|i| ((i as f32 + seed as f32) * 0.37).sin())
.collect();
let x = Tensor::from_vec(data, true);
let output = attn.forward(&x, seq_len);
prop_assert_eq!(
output.len(),
seq_len * config.hidden_size,
"FALSIFIED GQ-001e-prop: output len mismatch"
);
for v in output.data() {
prop_assert!(
v.is_finite(),
"FALSIFIED GQ-001e-prop: non-finite output"
);
}
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(30))]
#[test]
fn falsify_gq_006e_prop_mqa_boundary(
seed in 0..500u32,
seq_len in 2..=5usize,
) {
let mut config = TransformerConfig::tiny();
config.num_attention_heads = 4;
config.num_kv_heads = 1;
config.hidden_size = 64;
let attn = MultiHeadAttention::new(&config);
let data: Vec<f32> = (0..seq_len * config.hidden_size)
.map(|i| ((i as f32 + seed as f32) * 0.73).cos())
.collect();
let x = Tensor::from_vec(data, true);
let output = attn.forward(&x, seq_len);
prop_assert_eq!(
output.len(),
seq_len * config.hidden_size,
"FALSIFIED GQ-006e-prop: MQA output len mismatch"
);
for v in output.data() {
prop_assert!(
v.is_finite(),
"FALSIFIED GQ-006e-prop: non-finite MQA output"
);
}
}
}
}
#[test]
fn test_attention_from_params_with_biases() {
let config = TransformerConfig::tiny();
let hidden_size = config.hidden_size;
let kv_hidden_size = config.num_kv_heads * config.head_dim();
let mut params = HashMap::new();
params.insert(
"attn.q_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
);
params.insert(
"attn.k_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
);
params.insert(
"attn.v_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
);
params.insert(
"attn.o_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
);
params.insert(
"attn.q_proj.bias".to_string(),
Tensor::from_vec(vec![0.01; hidden_size], true),
);
params.insert(
"attn.k_proj.bias".to_string(),
Tensor::from_vec(vec![0.01; kv_hidden_size], true),
);
params.insert(
"attn.v_proj.bias".to_string(),
Tensor::from_vec(vec![0.01; kv_hidden_size], true),
);
let attn = MultiHeadAttention::from_params(&config, ¶ms, "attn");
assert!(attn.is_some());
let attn = attn.expect("should load with biases");
assert!(attn.has_biases());
assert_eq!(attn.parameters().len(), 7);
}
#[test]
fn test_attention_named_parameters_with_biases() {
let config = TransformerConfig::tiny();
let hidden_size = config.hidden_size;
let kv_hidden_size = config.num_kv_heads * config.head_dim();
let mut params = HashMap::new();
params.insert(
"attn.q_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
);
params.insert(
"attn.k_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
);
params.insert(
"attn.v_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
);
params.insert(
"attn.o_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
);
params.insert(
"attn.q_proj.bias".to_string(),
Tensor::from_vec(vec![0.01; hidden_size], true),
);
params.insert(
"attn.k_proj.bias".to_string(),
Tensor::from_vec(vec![0.01; kv_hidden_size], true),
);
params.insert(
"attn.v_proj.bias".to_string(),
Tensor::from_vec(vec![0.01; kv_hidden_size], true),
);
let attn = MultiHeadAttention::from_params(&config, ¶ms, "attn").expect("should load");
let named = attn.named_parameters("attn");
assert_eq!(named.len(), 7);
let names: Vec<&str> = named.iter().map(|(n, _)| n.as_str()).collect();
assert!(names.contains(&"attn.q_proj.bias"));
assert!(names.contains(&"attn.k_proj.bias"));
assert!(names.contains(&"attn.v_proj.bias"));
}
#[test]
fn test_attention_forward_with_biases() {
let config = TransformerConfig::tiny();
let hidden_size = config.hidden_size;
let kv_hidden_size = config.num_kv_heads * config.head_dim();
let mut params = HashMap::new();
params.insert(
"attn.q_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
);
params.insert(
"attn.k_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
);
params.insert(
"attn.v_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
);
params.insert(
"attn.o_proj.weight".to_string(),
Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
);
params
.insert("attn.q_proj.bias".to_string(), Tensor::from_vec(vec![0.5; hidden_size], true));
params.insert(
"attn.k_proj.bias".to_string(),
Tensor::from_vec(vec![0.5; kv_hidden_size], true),
);
params.insert(
"attn.v_proj.bias".to_string(),
Tensor::from_vec(vec![0.5; kv_hidden_size], true),
);
let attn = MultiHeadAttention::from_params(&config, ¶ms, "attn").expect("should load");
let x = Tensor::from_vec(vec![0.1; 2 * hidden_size], false);
let output = attn.forward(&x, 2);
assert_eq!(output.len(), 2 * hidden_size);
assert!(output.data().iter().all(|v| v.is_finite()));
}
}