Skip to main content

entrenar/transformer/
attention.rs

1//! Multi-head attention module
2//!
3//! This module provides multi-head self-attention with grouped-query attention support.
4
5use crate::autograd::{matmul, matmul_nt, BackwardOp};
6use crate::Tensor;
7use ndarray::Array1;
8use std::cell::RefCell;
9use std::collections::HashMap;
10use std::rc::Rc;
11
12use super::config::TransformerConfig;
13
14/// Add a bias vector to a projected tensor: output[s] += bias for each sequence position.
15/// Input shape: (seq_len × dim) flattened. Bias shape: (dim).
16fn add_bias(x: &Tensor, bias: &Tensor, seq_len: usize) -> Tensor {
17    let xd = x.data();
18    let x_slice = xd.as_slice().expect("contiguous projection");
19    let bd = bias.data();
20    let b_slice = bd.as_slice().expect("contiguous bias");
21    let dim = b_slice.len();
22    let mut out = Vec::with_capacity(x_slice.len());
23    for s in 0..seq_len {
24        let base = s * dim;
25        for d in 0..dim {
26            out.push(x_slice[base + d] + b_slice[d]);
27        }
28    }
29    Tensor::from_vec(out, x.requires_grad())
30}
31
32/// Apply per-head RMSNorm to Q or K (Qwen3 QK-norm, ENT-269).
33///
34/// Input: [seq_len * total_dim] where total_dim = num_heads * head_dim.
35/// Norm weight: [head_dim]. Applied independently to each head's head_dim slice.
36fn apply_qk_norm(
37    x: &Tensor,
38    norm_weight: &Tensor,
39    seq_len: usize,
40    num_heads: usize,
41    head_dim: usize,
42) -> Tensor {
43    let xd = x.data();
44    let x_slice = xd.as_slice().expect("contiguous qk");
45    let wd = norm_weight.data();
46    let w_slice = wd.as_slice().expect("contiguous norm weight");
47    let total_dim = num_heads * head_dim;
48    let eps = 1e-6_f32;
49    let mut out = vec![0.0f32; seq_len * total_dim];
50
51    for s in 0..seq_len {
52        for h in 0..num_heads {
53            let offset = s * total_dim + h * head_dim;
54            // RMSNorm: x * weight / sqrt(mean(x^2) + eps)
55            let mut sum_sq = 0.0f32;
56            for d in 0..head_dim {
57                let v = x_slice[offset + d];
58                sum_sq += v * v;
59            }
60            let rms = (sum_sq / head_dim as f32 + eps).sqrt();
61            let inv_rms = 1.0 / rms;
62            for d in 0..head_dim {
63                out[offset + d] = x_slice[offset + d] * inv_rms * w_slice[d];
64            }
65        }
66    }
67
68    Tensor::from_vec(out, x.requires_grad())
69}
70
71/// Apply Rotary Position Embedding (RoPE) to Q or K tensor (ENT-269).
72///
73/// Uses Llama/Qwen3 half-rotation layout (NOT interleaved pairs):
74///   x1 = x[..., :half_dim], x2 = x[..., half_dim:]
75///   rotate_half(x) = [-x2, x1]
76///   result = x * cos + rotate_half(x) * sin
77///
78/// freq[i] = 1 / (theta ^ (2i / head_dim))
79fn apply_rope(
80    x: &Tensor,
81    seq_len: usize,
82    num_heads: usize,
83    head_dim: usize,
84    rope_theta: f32,
85) -> Tensor {
86    let xd = x.data();
87    let x_slice = xd.as_slice().expect("contiguous qk for rope");
88    let total_dim = num_heads * head_dim;
89    let half_dim = head_dim / 2;
90    let mut out = vec![0.0f32; seq_len * total_dim];
91
92    // Precompute inverse frequencies: 1 / (theta ^ (2i / head_dim))
93    let inv_freq: Vec<f32> =
94        (0..half_dim).map(|i| 1.0 / rope_theta.powf(2.0 * i as f32 / head_dim as f32)).collect();
95
96    for pos in 0..seq_len {
97        for h in 0..num_heads {
98            let offset = pos * total_dim + h * head_dim;
99            for i in 0..half_dim {
100                let freq = pos as f32 * inv_freq[i];
101                let cos_f = freq.cos();
102                let sin_f = freq.sin();
103                // Half-rotation: pair (x[i], x[i + half_dim])
104                let x_first = x_slice[offset + i];
105                let x_second = x_slice[offset + i + half_dim];
106                // rotate_half: [-x_second, x_first]
107                out[offset + i] = x_first * cos_f - x_second * sin_f;
108                out[offset + i + half_dim] = x_second * cos_f + x_first * sin_f;
109            }
110        }
111    }
112
113    let result = Tensor::from_vec(out, x.requires_grad());
114    contract_post_rope!(result.data().as_slice().unwrap_or(&[]));
115    result
116}
117
118// ---------------------------------------------------------------------------
119// AttentionBlockBackward: combined backward for multi-head attention
120//
121// Orchestrates gradient flow: concat → per-head attention → Q/K/V projections.
122// Calls each Q/K/V matmul backward exactly once to avoid gradient inflation.
123// ---------------------------------------------------------------------------
124
125struct AttentionBlockBackward {
126    q: Tensor,
127    k: Tensor,
128    v: Tensor,
129    head_q_tensors: Vec<Tensor>,
130    head_k_tensors: Vec<Tensor>,
131    head_v_tensors: Vec<Tensor>,
132    head_outputs: Vec<Tensor>,
133    head_kv_indices: Vec<usize>,
134    seq_len: usize,
135    head_dim: usize,
136    q_dim: usize,
137    kv_hidden_size: usize,
138    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
139}
140
141impl BackwardOp for AttentionBlockBackward {
142    fn backward(&self) {
143        let Some(grad_out) = self.result_grad.borrow().as_ref().cloned() else { return };
144        let go = grad_out.as_slice().expect("grad contiguous");
145        let h = self.head_dim;
146
147        // Step 1: Split concat grad per head and trigger each attention backward
148        split_and_backward_heads(go, &self.head_outputs, self.seq_len, h, self.q_dim);
149
150        // Step 2-4: Scatter per-head grads into full Q/K/V
151        scatter_head_grads_q(&self.q, &self.head_q_tensors, self.seq_len, h, self.q_dim);
152        scatter_head_grads_kv(
153            &self.k,
154            &self.head_k_tensors,
155            &self.head_kv_indices,
156            self.seq_len,
157            h,
158            self.kv_hidden_size,
159        );
160        scatter_head_grads_kv(
161            &self.v,
162            &self.head_v_tensors,
163            &self.head_kv_indices,
164            self.seq_len,
165            h,
166            self.kv_hidden_size,
167        );
168
169        // Step 5: Propagate backward through Q/K/V matmuls (once each)
170        for proj in [&self.q, &self.k, &self.v] {
171            if let Some(op) = proj.backward_op() {
172                op.backward();
173            }
174        }
175    }
176}
177
178/// Split concat gradient per head and trigger each head's attention backward
179fn split_and_backward_heads(
180    go: &[f32],
181    head_outputs: &[Tensor],
182    seq_len: usize,
183    head_dim: usize,
184    q_dim: usize,
185) {
186    for (head_idx, head_out) in head_outputs.iter().enumerate() {
187        let mut grad_head = vec![0.0_f32; seq_len * head_dim];
188        for s in 0..seq_len {
189            let src_base = s * q_dim + head_idx * head_dim;
190            let dst_base = s * head_dim;
191            grad_head[dst_base..dst_base + head_dim]
192                .copy_from_slice(&go[src_base..src_base + head_dim]);
193        }
194        head_out.accumulate_grad(Array1::from(grad_head));
195        if let Some(op) = head_out.backward_op() {
196            op.backward();
197        }
198    }
199}
200
201/// Scatter per-head Q gradients into the full Q projection tensor
202fn scatter_head_grads_q(
203    q: &Tensor,
204    head_q_tensors: &[Tensor],
205    seq_len: usize,
206    head_dim: usize,
207    q_dim: usize,
208) {
209    if !q.requires_grad() {
210        return;
211    }
212    let mut grad_q = vec![0.0_f32; seq_len * q_dim];
213    for (head_idx, head_q) in head_q_tensors.iter().enumerate() {
214        if let Some(hgrad) = head_q.grad() {
215            let hg = hgrad.as_slice().expect("contiguous");
216            for s in 0..seq_len {
217                let src_base = s * head_dim;
218                let dst_base = s * q_dim + head_idx * head_dim;
219                for d in 0..head_dim {
220                    grad_q[dst_base + d] += hg[src_base + d];
221                }
222            }
223        }
224    }
225    q.accumulate_grad(Array1::from(grad_q));
226}
227
228/// Scatter per-head K or V gradients into the full K/V projection tensor (GQA-correct)
229fn scatter_head_grads_kv(
230    target: &Tensor,
231    head_tensors: &[Tensor],
232    kv_indices: &[usize],
233    seq_len: usize,
234    head_dim: usize,
235    kv_hidden_size: usize,
236) {
237    if !target.requires_grad() {
238        return;
239    }
240    let mut grad = vec![0.0_f32; seq_len * kv_hidden_size];
241    for (head_idx, head_t) in head_tensors.iter().enumerate() {
242        let kv_h = kv_indices[head_idx];
243        if let Some(hgrad) = head_t.grad() {
244            let hg = hgrad.as_slice().expect("contiguous");
245            for s in 0..seq_len {
246                let src_base = s * head_dim;
247                let dst_base = s * kv_hidden_size + kv_h * head_dim;
248                for d in 0..head_dim {
249                    grad[dst_base + d] += hg[src_base + d];
250                }
251            }
252        }
253    }
254    target.accumulate_grad(Array1::from(grad));
255}
256
257/// Multi-head self-attention layer
258pub struct MultiHeadAttention {
259    /// Configuration
260    config: TransformerConfig,
261    /// Query projection weight (hidden_size x hidden_size)
262    pub w_q: Tensor,
263    /// Key projection weight (hidden_size x kv_hidden_size)
264    pub w_k: Tensor,
265    /// Value projection weight (hidden_size x kv_hidden_size)
266    pub w_v: Tensor,
267    /// Output projection weight (hidden_size x hidden_size)
268    pub w_o: Tensor,
269    /// Optional query bias (Qwen2 uses attention biases)
270    pub b_q: Option<Tensor>,
271    /// Optional key bias
272    pub b_k: Option<Tensor>,
273    /// Optional value bias
274    pub b_v: Option<Tensor>,
275    /// Optional Q RMSNorm weight (Qwen3 uses QK-norm, shape=[head_dim])
276    pub q_norm: Option<Tensor>,
277    /// Optional K RMSNorm weight (Qwen3 uses QK-norm, shape=[head_dim])
278    pub k_norm: Option<Tensor>,
279}
280
281impl MultiHeadAttention {
282    /// Create new attention layer with initialized weights
283    pub fn new(config: &TransformerConfig) -> Self {
284        use super::init::{get_init_seed, rand_normal_seeded};
285        let hidden_size = config.hidden_size;
286        let q_dim = config.q_dim();
287        let kv_hidden_size = config.num_kv_heads * config.head_dim();
288        let seed = get_init_seed();
289
290        // C-INIT-001: normal(0, 0.02) matching HuggingFace LLaMA
291        Self {
292            config: config.clone(),
293            w_q: Tensor::from_vec(rand_normal_seeded(q_dim * hidden_size, seed, "w_q"), true),
294            w_k: Tensor::from_vec(
295                rand_normal_seeded(kv_hidden_size * hidden_size, seed, "w_k"),
296                true,
297            ),
298            w_v: Tensor::from_vec(
299                rand_normal_seeded(kv_hidden_size * hidden_size, seed, "w_v"),
300                true,
301            ),
302            w_o: Tensor::from_vec(rand_normal_seeded(hidden_size * q_dim, seed, "w_o"), true),
303            b_q: None,
304            b_k: None,
305            b_v: None,
306            q_norm: None,
307            k_norm: None,
308        }
309    }
310
311    /// Create attention layer from parameter map
312    ///
313    /// Expected parameter names (following HuggingFace convention):
314    /// - `{prefix}.q_proj.weight`
315    /// - `{prefix}.k_proj.weight`
316    /// - `{prefix}.v_proj.weight`
317    /// - `{prefix}.o_proj.weight`
318    /// # Contract (PMAT-331)
319    /// Validates Q/K/V/O projection shapes against config dimensions.
320    /// Returns None if any key is missing or shape is wrong.
321    pub fn from_params(
322        config: &TransformerConfig,
323        params: &HashMap<String, Tensor>,
324        prefix: &str,
325    ) -> Option<Self> {
326        let w_q = params.get(&format!("{prefix}.q_proj.weight"))?.clone();
327        let w_k = params.get(&format!("{prefix}.k_proj.weight"))?.clone();
328        let w_v = params.get(&format!("{prefix}.v_proj.weight"))?.clone();
329        let w_o = params.get(&format!("{prefix}.o_proj.weight"))?.clone();
330
331        let hidden = config.hidden_size;
332        let q_dim = config.q_dim();
333        let kv_hidden = config.num_kv_heads * config.head_dim();
334
335        // PMAT-331: Shape validation for attention projections
336        // Q: [q_dim, hidden], K: [kv_hidden, hidden], V: [kv_hidden, hidden], O: [hidden, q_dim]
337        let checks: &[(&str, &Tensor, usize)] = &[
338            ("q_proj", &w_q, q_dim * hidden),
339            ("k_proj", &w_k, kv_hidden * hidden),
340            ("v_proj", &w_v, kv_hidden * hidden),
341            ("o_proj", &w_o, hidden * q_dim),
342        ];
343        for &(name, tensor, expected) in checks {
344            if tensor.len() != expected {
345                eprintln!(
346                    "[PMAT-331] {prefix}.{name}: shape mismatch — got {} elements, expected {expected}",
347                    tensor.len()
348                );
349                return None;
350            }
351        }
352
353        // Optional attention biases (Qwen2 uses Q/K/V biases)
354        let b_q = params.get(&format!("{prefix}.q_proj.bias")).cloned();
355        let b_k = params.get(&format!("{prefix}.k_proj.bias")).cloned();
356        let b_v = params.get(&format!("{prefix}.v_proj.bias")).cloned();
357
358        // Optional Q/K RMSNorm (Qwen3 uses QK-norm, ENT-269)
359        let q_norm = params.get(&format!("{prefix}.q_norm.weight")).cloned();
360        let k_norm = params.get(&format!("{prefix}.k_norm.weight")).cloned();
361
362        Some(Self { config: config.clone(), w_q, w_k, w_v, w_o, b_q, b_k, b_v, q_norm, k_norm })
363    }
364
365    /// Forward pass
366    ///
367    /// # Arguments
368    /// * `x` - Input tensor (seq_len * hidden_size, flattened)
369    /// * `seq_len` - Sequence length
370    ///
371    /// # Returns
372    /// Output tensor (seq_len * hidden_size, flattened)
373    pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
374        contract_pre_attention!(x.data());
375        let hidden_size = self.config.hidden_size;
376        let num_heads = self.config.num_attention_heads;
377        let num_kv_heads = self.config.num_kv_heads;
378        let head_dim = self.config.head_dim();
379        let q_dim = self.config.q_dim();
380        let kv_hidden_size = num_kv_heads * head_dim;
381
382        // Project Q, K, V — HF weights are [out_dim, in_dim], use matmul_nt (ENT-269)
383        let mut q = matmul_nt(x, &self.w_q, seq_len, hidden_size, q_dim);
384        let mut k = matmul_nt(x, &self.w_k, seq_len, hidden_size, kv_hidden_size);
385        let mut v = matmul_nt(x, &self.w_v, seq_len, hidden_size, kv_hidden_size);
386
387        // Apply attention biases if present (Qwen2 architecture)
388        if let Some(ref b_q) = self.b_q {
389            q = add_bias(&q, b_q, seq_len);
390        }
391        if let Some(ref b_k) = self.b_k {
392            k = add_bias(&k, b_k, seq_len);
393        }
394        if let Some(ref b_v) = self.b_v {
395            v = add_bias(&v, b_v, seq_len);
396        }
397
398        // Apply Q/K RMSNorm if present (Qwen3 QK-norm, ENT-269)
399        if let Some(ref qn) = self.q_norm {
400            q = apply_qk_norm(&q, qn, seq_len, num_heads, head_dim);
401        }
402        if let Some(ref kn) = self.k_norm {
403            k = apply_qk_norm(&k, kn, seq_len, num_kv_heads, head_dim);
404        }
405
406        // Apply Rotary Position Embedding (RoPE) to Q and K (ENT-269)
407        // Skip for encoder models (BERT/RoBERTa use learned positions, not RoPE)
408        if self.config.rope_theta > 0.0 {
409            q = apply_rope(&q, seq_len, num_heads, head_dim, self.config.rope_theta);
410            k = apply_rope(&k, seq_len, num_kv_heads, head_dim, self.config.rope_theta);
411        }
412
413        let requires_grad = q.requires_grad() || k.requires_grad() || v.requires_grad();
414        let heads_per_kv = num_heads / num_kv_heads;
415
416        // KAIZEN-016: Hoist data borrows outside the head loop to avoid
417        // num_heads × seq_len × 3 redundant RefCell borrows per attention call.
418        let q_data = q.data();
419        let q_slice = q_data.as_slice().expect("contiguous Q");
420        let k_data = k.data();
421        let k_slice = k_data.as_slice().expect("contiguous K");
422        let v_data = v.data();
423        let v_slice = v_data.as_slice().expect("contiguous V");
424
425        // Per-head attention with gradient tracking
426        let mut head_q_tensors = Vec::with_capacity(num_heads);
427        let mut head_k_tensors = Vec::with_capacity(num_heads);
428        let mut head_v_tensors = Vec::with_capacity(num_heads);
429        let mut head_outputs = Vec::with_capacity(num_heads);
430        let mut head_kv_indices = Vec::with_capacity(num_heads);
431
432        for h in 0..num_heads {
433            let kv_h = h / heads_per_kv;
434            head_kv_indices.push(kv_h);
435
436            // KAIZEN-016: Use extend_from_slice instead of flat_map+to_vec.
437            // Eliminates num_heads × 3 × seq_len intermediate Vec allocations
438            // (3.5M allocs/forward for Qwen3-4B).
439            let mut q_head = Vec::with_capacity(seq_len * head_dim);
440            for s in 0..seq_len {
441                let start = s * q_dim + h * head_dim;
442                q_head.extend_from_slice(&q_slice[start..start + head_dim]);
443            }
444
445            let mut k_head = Vec::with_capacity(seq_len * head_dim);
446            for s in 0..seq_len {
447                let start = s * kv_hidden_size + kv_h * head_dim;
448                k_head.extend_from_slice(&k_slice[start..start + head_dim]);
449            }
450
451            let mut v_head = Vec::with_capacity(seq_len * head_dim);
452            for s in 0..seq_len {
453                let start = s * kv_hidden_size + kv_h * head_dim;
454                v_head.extend_from_slice(&v_slice[start..start + head_dim]);
455            }
456
457            let q_tensor = Tensor::from_vec(q_head, requires_grad);
458            let k_tensor = Tensor::from_vec(k_head, requires_grad);
459            let v_tensor = Tensor::from_vec(v_head, requires_grad);
460
461            let attn_out = crate::autograd::attention(
462                &q_tensor, &k_tensor, &v_tensor, seq_len, head_dim, seq_len, head_dim,
463            );
464
465            head_q_tensors.push(q_tensor);
466            head_k_tensors.push(k_tensor);
467            head_v_tensors.push(v_tensor);
468            head_outputs.push(attn_out);
469        }
470
471        // Concatenate heads: reorder from per-head (head, seq, dim) to (seq, head*dim)
472        let mut concat_output = vec![0.0; seq_len * q_dim];
473        for (h, head_out) in head_outputs.iter().enumerate() {
474            let hd = head_out.data();
475            let hdata = hd.as_slice().expect("contiguous attention output");
476            for s in 0..seq_len {
477                let src_base = s * head_dim;
478                let dst_base = s * q_dim + h * head_dim;
479                concat_output[dst_base..dst_base + head_dim]
480                    .copy_from_slice(&hdata[src_base..src_base + head_dim]);
481            }
482        }
483
484        let mut concat_tensor = Tensor::from_vec(concat_output, requires_grad);
485
486        if requires_grad {
487            let backward_op = Rc::new(AttentionBlockBackward {
488                q: q.clone(),
489                k: k.clone(),
490                v: v.clone(),
491                head_q_tensors,
492                head_k_tensors,
493                head_v_tensors,
494                head_outputs,
495                head_kv_indices,
496                seq_len,
497                head_dim,
498                q_dim,
499                kv_hidden_size,
500                result_grad: concat_tensor.grad_cell(),
501            });
502            concat_tensor.set_backward_op(backward_op);
503        }
504
505        // Output projection — w_o is [hidden_size, q_dim] in HF (ENT-269)
506        let result = matmul_nt(&concat_tensor, &self.w_o, seq_len, q_dim, hidden_size);
507        contract_post_attention!(result.data().as_slice().unwrap_or(&[]));
508        result
509    }
510
511    /// Forward pass with LoRA adjusts on Q and V projections (KAIZEN-010).
512    ///
513    /// Applies LoRA adapters to Q and V during the forward pass so that
514    /// gradients flow through LoRA A/B matrices on non-CUDA paths.
515    ///
516    /// # Arguments
517    /// * `x` - Input tensor (seq_len * hidden_size)
518    /// * `seq_len` - Sequence length
519    /// * `lora_a_q`, `lora_b_q` - Q projection LoRA matrices (rank×d_in, d_out×rank)
520    /// * `lora_a_v`, `lora_b_v` - V projection LoRA matrices (rank×d_in, d_out×rank)
521    /// * `lora_rank` - LoRA rank
522    /// * `lora_scale` - LoRA scaling factor (alpha/rank)
523    pub fn forward_with_lora(
524        &self,
525        x: &Tensor,
526        seq_len: usize,
527        lora_a_q: &Tensor,
528        // contract_pre_attention applied via forward()
529        lora_b_q: &Tensor,
530        lora_a_v: &Tensor,
531        lora_b_v: &Tensor,
532        lora_rank: usize,
533        lora_scale: f32,
534    ) -> Tensor {
535        contract_pre_lora_forward!();
536        let hidden_size = self.config.hidden_size;
537        let num_heads = self.config.num_attention_heads;
538        let num_kv_heads = self.config.num_kv_heads;
539        let head_dim = self.config.head_dim();
540        let q_dim = self.config.q_dim();
541        let kv_hidden_size = num_kv_heads * head_dim;
542
543        // Q projection with LoRA: Q = x @ W_q + scale * (x @ A_q^T) @ B_q^T
544        //
545        // KAIZEN-011: Use matmul_nt to compute x @ A^T directly on the ORIGINAL
546        // LoRA tensors. Previous impl created transposed copies via Tensor::from_vec
547        // which broke gradient flow — gradients accumulated on ephemeral copies
548        // instead of the actual trainable LoRA parameters.
549        //
550        // LoRA layout: A is (rank, d_in), B is (d_out, rank)
551        // matmul_nt(x, A, seq, d_in, rank) computes x @ A^T = (seq, d_in) @ (d_in, rank) = (seq, rank)
552        // matmul_nt(mid, B, seq, rank, d_out) computes mid @ B^T = (seq, rank) @ (rank, d_out) = (seq, d_out)
553        let q_base = matmul_nt(x, &self.w_q, seq_len, hidden_size, q_dim);
554        let q_mid = crate::autograd::matmul_nt(x, lora_a_q, seq_len, hidden_size, lora_rank);
555        let q_lora = crate::autograd::matmul_nt(&q_mid, lora_b_q, seq_len, lora_rank, q_dim);
556        let q = crate::autograd::add_scaled(&q_base, &q_lora, lora_scale);
557
558        // K projection (no LoRA) — HF weights [out, in] (ENT-269)
559        let k = matmul_nt(x, &self.w_k, seq_len, hidden_size, kv_hidden_size);
560
561        // V projection with LoRA (same pattern as Q)
562        let v_base = matmul_nt(x, &self.w_v, seq_len, hidden_size, kv_hidden_size);
563        let v_mid = crate::autograd::matmul_nt(x, lora_a_v, seq_len, hidden_size, lora_rank);
564        let v_lora =
565            crate::autograd::matmul_nt(&v_mid, lora_b_v, seq_len, lora_rank, kv_hidden_size);
566        let v = crate::autograd::add_scaled(&v_base, &v_lora, lora_scale);
567
568        // Apply Q/K RMSNorm if present (Qwen3 QK-norm, ENT-269)
569        let q = if let Some(ref qn) = self.q_norm {
570            apply_qk_norm(&q, qn, seq_len, num_heads, head_dim)
571        } else {
572            q
573        };
574        let k = if let Some(ref kn) = self.k_norm {
575            apply_qk_norm(&k, kn, seq_len, num_kv_heads, head_dim)
576        } else {
577            k
578        };
579
580        // Apply Rotary Position Embedding (RoPE) to Q and K (ENT-269)
581        // Skip for encoder models (BERT/RoBERTa use learned positions, not RoPE)
582        let (q, k) = if self.config.rope_theta > 0.0 {
583            (
584                apply_rope(&q, seq_len, num_heads, head_dim, self.config.rope_theta),
585                apply_rope(&k, seq_len, num_kv_heads, head_dim, self.config.rope_theta),
586            )
587        } else {
588            (q, k)
589        };
590
591        let requires_grad = q.requires_grad() || k.requires_grad() || v.requires_grad();
592        let heads_per_kv = num_heads / num_kv_heads;
593
594        // KAIZEN-016: Hoist data borrows outside head loop (same optimization as forward())
595        let q_data = q.data();
596        let q_slice = q_data.as_slice().expect("contiguous Q");
597        let k_data = k.data();
598        let k_slice = k_data.as_slice().expect("contiguous K");
599        let v_data = v.data();
600        let v_slice = v_data.as_slice().expect("contiguous V");
601
602        // Per-head attention (same as forward())
603        let mut head_q_tensors = Vec::with_capacity(num_heads);
604        let mut head_k_tensors = Vec::with_capacity(num_heads);
605        let mut head_v_tensors = Vec::with_capacity(num_heads);
606        let mut head_outputs = Vec::with_capacity(num_heads);
607        let mut head_kv_indices = Vec::with_capacity(num_heads);
608
609        for h in 0..num_heads {
610            let kv_h = h / heads_per_kv;
611            head_kv_indices.push(kv_h);
612
613            // KAIZEN-016: extend_from_slice replaces flat_map+to_vec
614            let mut q_head = Vec::with_capacity(seq_len * head_dim);
615            for s in 0..seq_len {
616                let start = s * q_dim + h * head_dim;
617                q_head.extend_from_slice(&q_slice[start..start + head_dim]);
618            }
619
620            let mut k_head = Vec::with_capacity(seq_len * head_dim);
621            for s in 0..seq_len {
622                let start = s * kv_hidden_size + kv_h * head_dim;
623                k_head.extend_from_slice(&k_slice[start..start + head_dim]);
624            }
625
626            let mut v_head = Vec::with_capacity(seq_len * head_dim);
627            for s in 0..seq_len {
628                let start = s * kv_hidden_size + kv_h * head_dim;
629                v_head.extend_from_slice(&v_slice[start..start + head_dim]);
630            }
631
632            let q_tensor = Tensor::from_vec(q_head, requires_grad);
633            let k_tensor = Tensor::from_vec(k_head, requires_grad);
634            let v_tensor = Tensor::from_vec(v_head, requires_grad);
635
636            let attn_out = crate::autograd::attention(
637                &q_tensor, &k_tensor, &v_tensor, seq_len, head_dim, seq_len, head_dim,
638            );
639
640            head_q_tensors.push(q_tensor);
641            head_k_tensors.push(k_tensor);
642            head_v_tensors.push(v_tensor);
643            head_outputs.push(attn_out);
644        }
645
646        // Concatenate heads
647        let mut concat_output = vec![0.0; seq_len * q_dim];
648        for (h, head_out) in head_outputs.iter().enumerate() {
649            let hd = head_out.data();
650            let hdata = hd.as_slice().expect("contiguous attention output");
651            for s in 0..seq_len {
652                let src_base = s * head_dim;
653                let dst_base = s * q_dim + h * head_dim;
654                concat_output[dst_base..dst_base + head_dim]
655                    .copy_from_slice(&hdata[src_base..src_base + head_dim]);
656            }
657        }
658
659        let mut concat_tensor = Tensor::from_vec(concat_output, requires_grad);
660
661        if requires_grad {
662            let backward_op = Rc::new(AttentionBlockBackward {
663                q: q.clone(),
664                k: k.clone(),
665                v: v.clone(),
666                head_q_tensors,
667                head_k_tensors,
668                head_v_tensors,
669                head_outputs,
670                head_kv_indices,
671                seq_len,
672                head_dim,
673                q_dim,
674                kv_hidden_size,
675                result_grad: concat_tensor.grad_cell(),
676            });
677            concat_tensor.set_backward_op(backward_op);
678        }
679
680        // Output projection — w_o is [hidden_size, q_dim] in HF (ENT-269)
681        let result = matmul_nt(&concat_tensor, &self.w_o, seq_len, q_dim, hidden_size);
682        contract_post_lora_forward!(result);
683        result
684    }
685
686    /// Get all parameters as a vector
687    pub fn parameters(&self) -> Vec<&Tensor> {
688        let mut params = vec![&self.w_q, &self.w_k, &self.w_v, &self.w_o];
689        if let Some(ref b) = self.b_q {
690            params.push(b);
691        }
692        if let Some(ref b) = self.b_k {
693            params.push(b);
694        }
695        if let Some(ref b) = self.b_v {
696            params.push(b);
697        }
698        params
699    }
700
701    /// Get all parameters as mutable references for optimizer
702    pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
703        let mut params = vec![&mut self.w_q, &mut self.w_k, &mut self.w_v, &mut self.w_o];
704        if let Some(ref mut b) = self.b_q {
705            params.push(b);
706        }
707        if let Some(ref mut b) = self.b_k {
708            params.push(b);
709        }
710        if let Some(ref mut b) = self.b_v {
711            params.push(b);
712        }
713        params
714    }
715
716    /// Whether this attention layer has QKV biases
717    pub fn has_biases(&self) -> bool {
718        self.b_q.is_some()
719    }
720
721    /// Get named parameters for checkpoint serialization
722    pub fn named_parameters(&self, prefix: &str) -> Vec<(String, &Tensor)> {
723        let mut params = vec![
724            (format!("{prefix}.q_proj.weight"), &self.w_q),
725            (format!("{prefix}.k_proj.weight"), &self.w_k),
726            (format!("{prefix}.v_proj.weight"), &self.w_v),
727            (format!("{prefix}.o_proj.weight"), &self.w_o),
728        ];
729        if let Some(ref b) = self.b_q {
730            params.push((format!("{prefix}.q_proj.bias"), b));
731        }
732        if let Some(ref b) = self.b_k {
733            params.push((format!("{prefix}.k_proj.bias"), b));
734        }
735        if let Some(ref b) = self.b_v {
736            params.push((format!("{prefix}.v_proj.bias"), b));
737        }
738        params
739    }
740
741    /// ENT-282: Set a named parameter by suffix (after "self_attn.").
742    pub fn set_named_parameter(&mut self, suffix: &str, value: Tensor) -> bool {
743        match suffix {
744            "self_attn.q_proj.weight" => {
745                self.w_q = value;
746                true
747            }
748            "self_attn.k_proj.weight" => {
749                self.w_k = value;
750                true
751            }
752            "self_attn.v_proj.weight" => {
753                self.w_v = value;
754                true
755            }
756            "self_attn.o_proj.weight" => {
757                self.w_o = value;
758                true
759            }
760            _ => false,
761        }
762    }
763}
764
765/// LoRA-enabled linear projection
766///
767/// Computes: y = x @ W + scale * (x @ A) @ B
768/// Where W is frozen base weight, A and B are trainable LoRA adapters
769pub struct LoRAProjection {
770    /// Base weight (frozen), shape (d_in × d_out)
771    pub base_weight: Tensor,
772    /// LoRA A matrix (down-projection), shape (d_in × rank)
773    pub lora_a: Tensor,
774    /// LoRA B matrix (up-projection), shape (rank × d_out)
775    pub lora_b: Tensor,
776    /// Input dimension
777    pub d_in: usize,
778    /// Output dimension
779    pub d_out: usize,
780    /// LoRA rank
781    pub rank: usize,
782    /// Scaling factor (alpha / rank)
783    pub scale: f32,
784}
785
786impl LoRAProjection {
787    /// Create a new LoRA projection
788    ///
789    /// # Arguments
790    /// * `base_weight` - Frozen base weight [d_in × d_out]
791    /// * `d_in` - Input dimension
792    /// * `d_out` - Output dimension
793    /// * `rank` - LoRA rank (typically 4, 8, 16, 32, or 64)
794    /// * `alpha` - LoRA scaling parameter
795    pub fn new(base_weight: Tensor, d_in: usize, d_out: usize, rank: usize, alpha: f32) -> Self {
796        assert_eq!(base_weight.len(), d_in * d_out, "Base weight size mismatch");
797
798        // Freeze base weight — only LoRA adapters are trainable
799        let mut base_weight = base_weight;
800        base_weight.set_requires_grad(false);
801
802        // Initialize A with Kaiming uniform (standard LoRA paper)
803        let lora_a = Tensor::from_vec(
804            (0..d_in * rank).map(|i| (i as f32 * 0.123).sin() * 0.01).collect(),
805            true, // requires_grad
806        );
807
808        // Initialize B with zeros (LoRA invariant: ΔW = B @ A = 0 at init)
809        let lora_b = Tensor::zeros(rank * d_out, true);
810
811        Self { base_weight, lora_a, lora_b, d_in, d_out, rank, scale: alpha / rank as f32 }
812    }
813
814    /// Forward pass with LoRA
815    ///
816    /// Computes: y = x @ W + scale * (x @ A) @ B
817    ///
818    /// # Arguments
819    /// * `x` - Input tensor [seq_len × d_in]
820    /// * `seq_len` - Sequence length
821    ///
822    /// # Returns
823    /// Output tensor [seq_len × d_out]
824    pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
825        // Base projection: x @ W, (seq × d_in) @ (d_in × d_out) = (seq × d_out)
826        let base_out = matmul(x, &self.base_weight, seq_len, self.d_in, self.d_out);
827
828        // LoRA path: scale * (x @ A) @ B
829        // Step 1: x @ A, (seq × d_in) @ (d_in × rank) = (seq × rank)
830        let lora_intermediate = matmul(x, &self.lora_a, seq_len, self.d_in, self.rank);
831
832        // Step 2: (x @ A) @ B, (seq × rank) @ (rank × d_out) = (seq × d_out)
833        let lora_out = matmul(&lora_intermediate, &self.lora_b, seq_len, self.rank, self.d_out);
834
835        // Combine: base + scale * lora
836        // Use autograd-compatible addition
837        crate::autograd::add_scaled(&base_out, &lora_out, self.scale)
838    }
839
840    /// Get trainable LoRA parameters
841    pub fn lora_params(&self) -> Vec<&Tensor> {
842        vec![&self.lora_a, &self.lora_b]
843    }
844
845    /// Get mutable trainable LoRA parameters
846    pub fn lora_params_mut(&mut self) -> Vec<&mut Tensor> {
847        vec![&mut self.lora_a, &mut self.lora_b]
848    }
849}
850
851/// Multi-head attention with deep LoRA injection
852///
853/// LoRA adapters are applied to Q, K, V, O projections during forward pass
854pub struct MultiHeadAttentionWithLoRA {
855    /// Configuration
856    pub config: TransformerConfig,
857    /// Query projection with LoRA
858    pub q_proj: LoRAProjection,
859    /// Key projection with LoRA
860    pub k_proj: LoRAProjection,
861    /// Value projection with LoRA
862    pub v_proj: LoRAProjection,
863    /// Output projection with LoRA
864    pub o_proj: LoRAProjection,
865}
866
867impl MultiHeadAttentionWithLoRA {
868    /// Create LoRA-enabled attention from existing attention weights
869    ///
870    /// # Arguments
871    /// * `attn` - Base MultiHeadAttention with pretrained weights
872    /// * `rank` - LoRA rank
873    /// * `alpha` - LoRA alpha scaling factor
874    pub fn from_attention(attn: &MultiHeadAttention, rank: usize, alpha: f32) -> Self {
875        let hidden_size = attn.config.hidden_size;
876        let q_dim = attn.config.q_dim();
877        let kv_hidden_size = attn.config.num_kv_heads * attn.config.head_dim();
878
879        Self {
880            config: attn.config.clone(),
881            q_proj: LoRAProjection::new(attn.w_q.clone(), hidden_size, q_dim, rank, alpha),
882            k_proj: LoRAProjection::new(attn.w_k.clone(), hidden_size, kv_hidden_size, rank, alpha),
883            v_proj: LoRAProjection::new(attn.w_v.clone(), hidden_size, kv_hidden_size, rank, alpha),
884            o_proj: LoRAProjection::new(attn.w_o.clone(), q_dim, hidden_size, rank, alpha),
885        }
886    }
887
888    /// Forward pass with deep LoRA injection
889    ///
890    /// LoRA is applied to all Q, K, V, O projections
891    pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
892        let num_heads = self.config.num_attention_heads;
893        let num_kv_heads = self.config.num_kv_heads;
894        let head_dim = self.config.head_dim();
895        let q_dim = self.config.q_dim();
896        let kv_hidden_size = num_kv_heads * head_dim;
897
898        // Project Q, K, V with LoRA
899        let q = self.q_proj.forward(x, seq_len);
900        let k = self.k_proj.forward(x, seq_len);
901        let v = self.v_proj.forward(x, seq_len);
902
903        // Multi-head attention with grouped-query attention support
904        let mut attn_outputs = Vec::with_capacity(num_heads * seq_len * head_dim);
905        let heads_per_kv = num_heads / num_kv_heads;
906
907        // KAIZEN-016: Hoist data borrows outside head loop
908        let q_data = q.data();
909        let q_slice = q_data.as_slice().expect("contiguous Q tensor");
910        let k_data = k.data();
911        let k_slice = k_data.as_slice().expect("contiguous K tensor");
912        let v_data = v.data();
913        let v_slice = v_data.as_slice().expect("contiguous V tensor");
914
915        for h in 0..num_heads {
916            let kv_h = h / heads_per_kv;
917
918            // KAIZEN-016: extend_from_slice replaces flat_map+to_vec
919            let mut q_head = Vec::with_capacity(seq_len * head_dim);
920            for s in 0..seq_len {
921                let start = s * q_dim + h * head_dim;
922                q_head.extend_from_slice(&q_slice[start..start + head_dim]);
923            }
924
925            let mut k_head = Vec::with_capacity(seq_len * head_dim);
926            for s in 0..seq_len {
927                let start = s * kv_hidden_size + kv_h * head_dim;
928                k_head.extend_from_slice(&k_slice[start..start + head_dim]);
929            }
930
931            let mut v_head = Vec::with_capacity(seq_len * head_dim);
932            for s in 0..seq_len {
933                let start = s * kv_hidden_size + kv_h * head_dim;
934                v_head.extend_from_slice(&v_slice[start..start + head_dim]);
935            }
936
937            // Scaled dot-product attention
938            let q_tensor = Tensor::from_vec(q_head, false);
939            let k_tensor = Tensor::from_vec(k_head, false);
940            let v_tensor = Tensor::from_vec(v_head, false);
941
942            let attn_out = crate::autograd::attention(
943                &q_tensor, &k_tensor, &v_tensor, seq_len, head_dim, seq_len, head_dim,
944            );
945
946            attn_outputs.extend_from_slice(
947                attn_out.data().as_slice().expect("contiguous attention output"),
948            );
949        }
950
951        // Concatenate heads and reorder: (seq_len, q_dim)
952        let mut concat_output = vec![0.0; seq_len * q_dim];
953        for h in 0..num_heads {
954            for s in 0..seq_len {
955                let src_idx = h * seq_len * head_dim + s * head_dim;
956                let dst_idx = s * q_dim + h * head_dim;
957                concat_output[dst_idx..dst_idx + head_dim]
958                    .copy_from_slice(&attn_outputs[src_idx..src_idx + head_dim]);
959            }
960        }
961
962        let concat_tensor = Tensor::from_vec(concat_output, true);
963
964        // Output projection with LoRA: (seq_len, q_dim) -> (seq_len, hidden_size)
965        self.o_proj.forward(&concat_tensor, seq_len)
966    }
967
968    /// Get all trainable LoRA parameters
969    pub fn lora_params(&self) -> Vec<&Tensor> {
970        let mut params = Vec::new();
971        params.extend(self.q_proj.lora_params());
972        params.extend(self.k_proj.lora_params());
973        params.extend(self.v_proj.lora_params());
974        params.extend(self.o_proj.lora_params());
975        params
976    }
977
978    /// Get all trainable LoRA parameters as mutable references
979    pub fn lora_params_mut(&mut self) -> Vec<&mut Tensor> {
980        let mut params = Vec::new();
981        params.extend(self.q_proj.lora_params_mut());
982        params.extend(self.k_proj.lora_params_mut());
983        params.extend(self.v_proj.lora_params_mut());
984        params.extend(self.o_proj.lora_params_mut());
985        params
986    }
987
988    /// Count total LoRA parameters
989    pub fn lora_param_count(&self) -> usize {
990        // Each projection has A (d_in × rank) + B (rank × d_out)
991        let hidden = self.config.hidden_size;
992        let kv_hidden = self.config.num_kv_heads * self.config.head_dim();
993        let rank = self.q_proj.rank;
994
995        // Q: (hidden × rank) + (rank × hidden)
996        // K: (hidden × rank) + (rank × kv_hidden)
997        // V: (hidden × rank) + (rank × kv_hidden)
998        // O: (hidden × rank) + (rank × hidden)
999        (hidden * rank + rank * hidden)      // Q
1000            + (hidden * rank + rank * kv_hidden) // K
1001            + (hidden * rank + rank * kv_hidden) // V
1002            + (hidden * rank + rank * hidden) // O
1003    }
1004}
1005
1006#[cfg(test)]
1007mod tests {
1008    use super::*;
1009
1010    #[test]
1011    fn test_multi_head_attention_tiny() {
1012        let config = TransformerConfig::tiny();
1013        let attn = MultiHeadAttention::new(&config);
1014        let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
1015        let output = attn.forward(&x, 2);
1016        assert_eq!(output.len(), 2 * config.hidden_size);
1017    }
1018
1019    #[test]
1020    fn test_multi_head_attention_parameters() {
1021        let config = TransformerConfig::tiny();
1022        let attn = MultiHeadAttention::new(&config);
1023        let params = attn.parameters();
1024        assert_eq!(params.len(), 4); // w_q, w_k, w_v, w_o
1025    }
1026
1027    #[test]
1028    fn test_attention_longer_sequence() {
1029        let config = TransformerConfig::tiny();
1030        let attn = MultiHeadAttention::new(&config);
1031        let x = Tensor::from_vec(vec![0.1; 8 * config.hidden_size], true);
1032        let output = attn.forward(&x, 8);
1033        assert_eq!(output.len(), 8 * config.hidden_size);
1034    }
1035
1036    #[test]
1037    fn test_attention_weight_sizes() {
1038        let config = TransformerConfig::tiny();
1039        let attn = MultiHeadAttention::new(&config);
1040        let kv_hidden = config.num_kv_heads * config.head_dim();
1041        assert_eq!(attn.w_q.len(), config.hidden_size * config.hidden_size);
1042        assert_eq!(attn.w_k.len(), config.hidden_size * kv_hidden);
1043        assert_eq!(attn.w_v.len(), config.hidden_size * kv_hidden);
1044        assert_eq!(attn.w_o.len(), config.hidden_size * config.hidden_size);
1045    }
1046
1047    #[test]
1048    fn test_multi_head_attention_from_params_success() {
1049        let config = TransformerConfig::tiny();
1050        let hidden_size = config.hidden_size;
1051        let kv_hidden_size = config.num_kv_heads * config.head_dim();
1052
1053        let mut params = HashMap::new();
1054        params.insert(
1055            "attn.q_proj.weight".to_string(),
1056            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1057        );
1058        params.insert(
1059            "attn.k_proj.weight".to_string(),
1060            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1061        );
1062        params.insert(
1063            "attn.v_proj.weight".to_string(),
1064            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1065        );
1066        params.insert(
1067            "attn.o_proj.weight".to_string(),
1068            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1069        );
1070
1071        let attn = MultiHeadAttention::from_params(&config, &params, "attn");
1072        assert!(attn.is_some());
1073        let attn = attn.expect("operation should succeed");
1074        assert_eq!(attn.w_q.len(), hidden_size * hidden_size);
1075    }
1076
1077    #[test]
1078    fn test_multi_head_attention_from_params_missing_key() {
1079        let config = TransformerConfig::tiny();
1080        let hidden_size = config.hidden_size;
1081
1082        let mut params = HashMap::new();
1083        params.insert(
1084            "attn.q_proj.weight".to_string(),
1085            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1086        );
1087        // Missing k_proj, v_proj, o_proj
1088
1089        let attn = MultiHeadAttention::from_params(&config, &params, "attn");
1090        assert!(attn.is_none());
1091    }
1092
1093    #[test]
1094    fn test_attention_projections_backward() {
1095        // Test that Q, K, V projection matmuls have gradients
1096        // (isolated from the full attention which has intermediate tensor issues)
1097        let config = TransformerConfig::tiny();
1098        let attn = MultiHeadAttention::new(&config);
1099        let hidden_size = config.hidden_size;
1100        let seq_len = 2;
1101
1102        let x = Tensor::from_vec(vec![0.1; seq_len * hidden_size], true);
1103
1104        // Test Q projection
1105        let mut q = crate::autograd::matmul(&x, &attn.w_q, seq_len, hidden_size, hidden_size);
1106        let grad_out = ndarray::Array1::ones(seq_len * hidden_size);
1107        crate::autograd::backward(&mut q, Some(grad_out));
1108
1109        assert!(attn.w_q.grad().is_some());
1110        let grad_q = attn.w_q.grad().expect("gradient should be available");
1111        assert!(grad_q.iter().all(|&v| v.is_finite()));
1112    }
1113
1114    #[test]
1115    fn test_output_projection_backward() {
1116        // Test output projection in isolation
1117        let config = TransformerConfig::tiny();
1118        let attn = MultiHeadAttention::new(&config);
1119        let hidden_size = config.hidden_size;
1120        let seq_len = 2;
1121
1122        // Simulate concatenated attention output
1123        let concat_out = Tensor::from_vec(vec![0.1; seq_len * hidden_size], true);
1124
1125        // Output projection
1126        let mut output =
1127            crate::autograd::matmul(&concat_out, &attn.w_o, seq_len, hidden_size, hidden_size);
1128
1129        let grad_out = ndarray::Array1::ones(seq_len * hidden_size);
1130        crate::autograd::backward(&mut output, Some(grad_out));
1131
1132        assert!(attn.w_o.grad().is_some());
1133        let grad_o = attn.w_o.grad().expect("gradient should be available");
1134        assert!(grad_o.iter().all(|&v| v.is_finite()));
1135        let sum: f32 = grad_o.iter().map(|v| v.abs()).sum();
1136        assert!(sum > 0.0, "Output projection gradient should not be all zero");
1137    }
1138
1139    /// ALB-038: Full attention forward must propagate gradients to Q/K/V weights
1140    ///
1141    /// NOTE: Currently fails because apply_rope() has no backward op — it severs
1142    /// the autograd chain for Q and K. Needs a proper RoPE backward implementation
1143    /// (ENT-272). Skipped until then.
1144    #[test]
1145    #[ignore = "apply_rope() severs autograd chain — needs backward op (ENT-272)"]
1146    fn test_attention_full_forward_qkv_gradients() {
1147        let config = TransformerConfig::tiny();
1148        let attn = MultiHeadAttention::new(&config);
1149        let hidden_size = config.hidden_size;
1150        let seq_len = 3;
1151
1152        // Non-uniform input: different positions must have different representations
1153        // so softmax produces non-uniform weights with non-zero score gradients
1154        let x_data: Vec<f32> =
1155            (0..seq_len * hidden_size).map(|i| ((i as f32) * 0.17).sin() * 0.5).collect();
1156        let x = Tensor::from_vec(x_data, true);
1157        let mut output = attn.forward(&x, seq_len);
1158
1159        let grad_out = ndarray::Array1::ones(seq_len * hidden_size);
1160        crate::autograd::backward(&mut output, Some(grad_out));
1161
1162        // All four projection weights must receive gradients
1163        for (name, param) in
1164            [("w_q", &attn.w_q), ("w_k", &attn.w_k), ("w_v", &attn.w_v), ("w_o", &attn.w_o)]
1165        {
1166            assert!(
1167                param.grad().is_some(),
1168                "ALB-038: {name} must have gradient after full attention forward"
1169            );
1170            let grad = param.grad().expect("gradient available");
1171            assert!(grad.iter().all(|&v| v.is_finite()), "ALB-038: {name} gradient must be finite");
1172            assert!(
1173                grad.iter().any(|&v| v.abs() > 1e-10),
1174                "ALB-038: {name} gradient must be non-zero"
1175            );
1176        }
1177
1178        // Input must also receive gradient (enables gradient flow through model)
1179        assert!(x.grad().is_some(), "ALB-038: input x must have gradient");
1180    }
1181
1182    // ============================================================================
1183    // LoRAProjection tests
1184    // ============================================================================
1185
1186    #[test]
1187    fn test_lora_projection_new() {
1188        let d_in = 32;
1189        let d_out = 16;
1190        let rank = 4;
1191        let alpha = 8.0;
1192
1193        let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
1194        let lora = LoRAProjection::new(base_weight, d_in, d_out, rank, alpha);
1195
1196        assert_eq!(lora.d_in, d_in);
1197        assert_eq!(lora.d_out, d_out);
1198        assert_eq!(lora.rank, rank);
1199        assert!((lora.scale - 2.0).abs() < 1e-6); // alpha / rank = 8 / 4 = 2
1200        assert_eq!(lora.lora_a.len(), d_in * rank);
1201        assert_eq!(lora.lora_b.len(), rank * d_out);
1202    }
1203
1204    #[test]
1205    fn test_lora_projection_forward() {
1206        let d_in = 32;
1207        let d_out = 16;
1208        let rank = 4;
1209        let alpha = 8.0;
1210        let seq_len = 2;
1211
1212        let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
1213        let lora = LoRAProjection::new(base_weight, d_in, d_out, rank, alpha);
1214
1215        let x = Tensor::from_vec(vec![0.1; seq_len * d_in], false);
1216        let output = lora.forward(&x, seq_len);
1217
1218        assert_eq!(output.len(), seq_len * d_out);
1219        // Check output is finite
1220        assert!(output.data().iter().all(|&v| v.is_finite()));
1221    }
1222
1223    #[test]
1224    fn test_lora_projection_params() {
1225        let d_in = 32;
1226        let d_out = 16;
1227        let rank = 4;
1228
1229        let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
1230        let lora = LoRAProjection::new(base_weight, d_in, d_out, rank, 8.0);
1231
1232        let params = lora.lora_params();
1233        assert_eq!(params.len(), 2); // lora_a and lora_b
1234    }
1235
1236    #[test]
1237    fn test_lora_projection_params_mut() {
1238        let d_in = 32;
1239        let d_out = 16;
1240        let rank = 4;
1241
1242        let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
1243        let mut lora = LoRAProjection::new(base_weight, d_in, d_out, rank, 8.0);
1244
1245        let params = lora.lora_params_mut();
1246        assert_eq!(params.len(), 2);
1247    }
1248
1249    #[test]
1250    #[should_panic(expected = "Base weight size mismatch")]
1251    fn test_lora_projection_size_mismatch() {
1252        let d_in = 32;
1253        let d_out = 16;
1254        let rank = 4;
1255
1256        // Wrong base weight size
1257        let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out + 1], false);
1258        let _ = LoRAProjection::new(base_weight, d_in, d_out, rank, 8.0);
1259    }
1260
1261    // ============================================================================
1262    // MultiHeadAttentionWithLoRA tests
1263    // ============================================================================
1264
1265    #[test]
1266    fn test_mha_with_lora_creation() {
1267        let config = TransformerConfig::tiny();
1268        let attn = MultiHeadAttention::new(&config);
1269        let rank = 4;
1270        let alpha = 8.0;
1271
1272        let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, rank, alpha);
1273
1274        assert_eq!(lora_attn.q_proj.rank, rank);
1275        assert_eq!(lora_attn.k_proj.rank, rank);
1276        assert_eq!(lora_attn.v_proj.rank, rank);
1277        assert_eq!(lora_attn.o_proj.rank, rank);
1278    }
1279
1280    #[test]
1281    fn test_mha_with_lora_forward() {
1282        let config = TransformerConfig::tiny();
1283        let attn = MultiHeadAttention::new(&config);
1284        let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
1285
1286        let seq_len = 2;
1287        let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], false);
1288        let output = lora_attn.forward(&x, seq_len);
1289
1290        assert_eq!(output.len(), seq_len * config.hidden_size);
1291        // Check output is finite and non-zero
1292        assert!(output.data().iter().all(|&v| v.is_finite()));
1293    }
1294
1295    #[test]
1296    fn test_mha_with_lora_params() {
1297        let config = TransformerConfig::tiny();
1298        let attn = MultiHeadAttention::new(&config);
1299        let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
1300
1301        let params = lora_attn.lora_params();
1302        // 4 projections × 2 params each = 8
1303        assert_eq!(params.len(), 8);
1304    }
1305
1306    #[test]
1307    fn test_mha_with_lora_params_mut() {
1308        let config = TransformerConfig::tiny();
1309        let attn = MultiHeadAttention::new(&config);
1310        let mut lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
1311
1312        let params = lora_attn.lora_params_mut();
1313        assert_eq!(params.len(), 8);
1314    }
1315
1316    #[test]
1317    fn test_mha_with_lora_param_count() {
1318        let config = TransformerConfig::tiny();
1319        let attn = MultiHeadAttention::new(&config);
1320        let rank = 4;
1321        let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, rank, 8.0);
1322
1323        let param_count = lora_attn.lora_param_count();
1324
1325        // Calculate expected:
1326        let hidden = config.hidden_size;
1327        let kv_hidden = config.num_kv_heads * config.head_dim();
1328        let expected = (hidden * rank + rank * hidden)      // Q
1329            + (hidden * rank + rank * kv_hidden) // K
1330            + (hidden * rank + rank * kv_hidden) // V
1331            + (hidden * rank + rank * hidden); // O
1332
1333        assert_eq!(param_count, expected);
1334        assert!(param_count > 0);
1335    }
1336
1337    #[test]
1338    fn test_mha_with_lora_longer_sequence() {
1339        let config = TransformerConfig::tiny();
1340        let attn = MultiHeadAttention::new(&config);
1341        let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
1342
1343        let seq_len = 8;
1344        let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], false);
1345        let output = lora_attn.forward(&x, seq_len);
1346
1347        assert_eq!(output.len(), seq_len * config.hidden_size);
1348    }
1349
1350    #[test]
1351    fn test_parameters_mut() {
1352        let config = TransformerConfig::tiny();
1353        let mut attn = MultiHeadAttention::new(&config);
1354
1355        let params = attn.parameters_mut();
1356        assert_eq!(params.len(), 4);
1357    }
1358
1359    // =========================================================================
1360    // FALSIFY-A: §2.1.3 Attention Projections — Five-Whys Gap Analysis (Refs PMAT-331)
1361    //
1362    // Contract: tensor-layout-v1.yaml §tensors.q_proj/k_proj/v_proj/o_proj
1363    //   q_proj: [num_heads*head_dim, hidden] (= [hidden, hidden] for MHA)
1364    //   k_proj: [num_kv_heads*head_dim, hidden] (smaller for GQA)
1365    //   v_proj: [num_kv_heads*head_dim, hidden] (smaller for GQA)
1366    //   o_proj: [hidden, num_heads*head_dim]
1367    //
1368    // Five-Whys:
1369    //   Why 1: Trained model's attention weights could be wrong shape
1370    //   Why 2: from_params accepts any tensor without shape validation
1371    //   Why 3: No ValidatedWeight in entrenar
1372    //   Why 4: entrenar predates the Poka-Yoke contract
1373    //   Why 5: No cross-crate contract enforcement for training weights
1374    //
1375    // Popper (1959): "These tests attempt to falsify the claim that
1376    // entrenar's attention weight handling prevents degenerate models."
1377    // =========================================================================
1378
1379    /// FALSIFY-A1e: from_params rejects wrong-shape Q weight (PMAT-331 fix)
1380    ///
1381    /// from_params now validates Q projection shape against config dimensions.
1382    /// A tensor of 50 elements is rejected when hidden*hidden is expected.
1383    #[test]
1384    fn falsify_a1e_from_params_rejects_wrong_shape_q_weight() {
1385        let config = TransformerConfig::tiny();
1386        let hidden_size = config.hidden_size;
1387        let kv_hidden_size = config.num_kv_heads * config.head_dim();
1388
1389        let mut params = HashMap::new();
1390        // WRONG-SHAPE q_proj: 50 elements instead of hidden*hidden
1391        params.insert("attn.q_proj.weight".to_string(), Tensor::from_vec(vec![0.1; 50], true));
1392        // Correct k, v, o
1393        params.insert(
1394            "attn.k_proj.weight".to_string(),
1395            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1396        );
1397        params.insert(
1398            "attn.v_proj.weight".to_string(),
1399            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1400        );
1401        params.insert(
1402            "attn.o_proj.weight".to_string(),
1403            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1404        );
1405
1406        let attn = MultiHeadAttention::from_params(&config, &params, "attn");
1407        // FIXED (PMAT-331): now rejected
1408        assert!(
1409            attn.is_none(),
1410            "FALSIFY-A1e: PMAT-331 fix — from_params MUST reject wrong-shape q_proj"
1411        );
1412    }
1413
1414    /// FALSIFY-A2e: GQA init produces correct K/V dimensions
1415    ///
1416    /// For GQA (num_kv_heads < num_heads), K/V must be smaller than Q.
1417    /// If init uses num_heads for K/V, the shapes are wrong.
1418    #[test]
1419    fn falsify_a2e_gqa_init_correct_kv_dimensions() {
1420        let mut config = TransformerConfig::tiny();
1421        config.num_kv_heads = 1; // Force GQA: 1 KV head, but num_heads > 1
1422
1423        let attn = MultiHeadAttention::new(&config);
1424        let head_dim = config.head_dim();
1425        let kv_hidden = config.num_kv_heads * head_dim; // 1 * head_dim
1426
1427        // Q: hidden * hidden
1428        assert_eq!(
1429            attn.w_q.len(),
1430            config.hidden_size * config.hidden_size,
1431            "FALSIFY-A2e: Q projection must be hidden*hidden"
1432        );
1433
1434        // K: hidden * kv_hidden (smaller than Q for GQA)
1435        assert_eq!(
1436            attn.w_k.len(),
1437            config.hidden_size * kv_hidden,
1438            "FALSIFY-A2e: K projection must use num_kv_heads, not num_heads"
1439        );
1440
1441        // V: hidden * kv_hidden (same as K)
1442        assert_eq!(
1443            attn.w_v.len(),
1444            config.hidden_size * kv_hidden,
1445            "FALSIFY-A2e: V projection must use num_kv_heads, not num_heads"
1446        );
1447
1448        // O: hidden * hidden (matches Q output)
1449        assert_eq!(
1450            attn.w_o.len(),
1451            config.hidden_size * config.hidden_size,
1452            "FALSIFY-A2e: O projection must be hidden*hidden"
1453        );
1454
1455        // K/V must be SMALLER than Q for GQA
1456        assert!(
1457            attn.w_k.len() < attn.w_q.len(),
1458            "FALSIFY-A2e: For GQA, K weight must be smaller than Q weight"
1459        );
1460    }
1461
1462    /// FALSIFY-A3e: GQA forward produces correct output dimensions
1463    ///
1464    /// With num_kv_heads < num_heads, the forward pass must still produce
1465    /// [seq_len, hidden_size] output (not [seq_len, kv_hidden]).
1466    #[test]
1467    fn falsify_a3e_gqa_forward_correct_output_dims() {
1468        let mut config = TransformerConfig::tiny();
1469        config.num_kv_heads = 1; // Force GQA
1470
1471        let attn = MultiHeadAttention::new(&config);
1472        let seq_len = 3;
1473        let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
1474        let output = attn.forward(&x, seq_len);
1475
1476        assert_eq!(
1477            output.len(),
1478            seq_len * config.hidden_size,
1479            "FALSIFY-A3e: GQA output must be seq_len * hidden_size, not seq_len * kv_hidden"
1480        );
1481    }
1482
1483    /// FALSIFY-A4e: Attention init produces non-degenerate values
1484    ///
1485    /// Like FALSIFY-E7a for embeddings: init must produce varied, finite values.
1486    #[test]
1487    fn falsify_a4e_init_produces_valid_attention_weights() {
1488        let config = TransformerConfig::tiny();
1489        let attn = MultiHeadAttention::new(&config);
1490
1491        for (name, w) in
1492            [("w_q", &attn.w_q), ("w_k", &attn.w_k), ("w_v", &attn.w_v), ("w_o", &attn.w_o)]
1493        {
1494            let data = w.data();
1495            let slice = data.as_slice().expect("data as slice");
1496
1497            // No NaN
1498            let nan_count = slice.iter().filter(|v| v.is_nan()).count();
1499            assert_eq!(nan_count, 0, "FALSIFY-A4e: {name} init must not contain NaN");
1500
1501            // No Inf
1502            let inf_count = slice.iter().filter(|v| v.is_infinite()).count();
1503            assert_eq!(inf_count, 0, "FALSIFY-A4e: {name} init must not contain Inf");
1504
1505            // Values vary
1506            let min = slice.iter().copied().fold(f32::INFINITY, f32::min);
1507            let max = slice.iter().copied().fold(f32::NEG_INFINITY, f32::max);
1508            assert!(
1509                (max - min).abs() > 1e-6,
1510                "FALSIFY-A4e: {name} init values are constant ({min}..{max}) — degenerate weight"
1511            );
1512        }
1513    }
1514
1515    /// FALSIFY-A5e: Attention forward produces finite outputs
1516    ///
1517    /// If any attention weight is degenerate, output should still be finite
1518    /// (the init is designed to prevent this).
1519    #[test]
1520    fn falsify_a5e_forward_produces_finite_output() {
1521        let config = TransformerConfig::tiny();
1522        let attn = MultiHeadAttention::new(&config);
1523        let seq_len = 4;
1524        let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
1525        let output = attn.forward(&x, seq_len);
1526
1527        let data = output.data();
1528        let nan_count = data.iter().filter(|v| v.is_nan()).count();
1529        let inf_count = data.iter().filter(|v| v.is_infinite()).count();
1530        assert_eq!(nan_count, 0, "FALSIFY-A5e: Attention output must not contain NaN");
1531        assert_eq!(inf_count, 0, "FALSIFY-A5e: Attention output must not contain Inf");
1532    }
1533
1534    // =========================================================================
1535    // FALSIFY-GQ: gqa-kernel-v1.yaml contract (entrenar MultiHeadAttention GQA)
1536    //
1537    // Five-Whys (PMAT-354):
1538    //   Why 1: entrenar had FALSIFY-A tests but zero FALSIFY-GQ-* tests
1539    //   Why 2: FALSIFY-A tests verify projections/shapes, not GQA invariants
1540    //   Why 3: no mapping from gqa-kernel-v1.yaml to entrenar test names
1541    //   Why 4: entrenar's GQA support added after FALSIFY-A tests
1542    //   Why 5: GQA was "obviously correct" (just index K/V by h/heads_per_kv)
1543    //
1544    // References:
1545    //   - provable-contracts/contracts/gqa-kernel-v1.yaml
    //   - Ainslie et al. (2023) "GQA: Training Generalized Multi-Query Transformer
    //     Models from Multi-Head Checkpoints"
1547    // =========================================================================
1548
1549    /// FALSIFY-GQ-001e: GQA output shape correct for various head configs
1550    #[test]
1551    fn falsify_gq_001e_output_shape() {
1552        for (num_heads, num_kv_heads) in [(2, 2), (4, 2), (4, 1), (2, 1)] {
1553            let mut config = TransformerConfig::tiny();
1554            config.num_attention_heads = num_heads;
1555            config.num_kv_heads = num_kv_heads;
1556
1557            let attn = MultiHeadAttention::new(&config);
1558            let seq_len = 3;
1559            let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
1560            let output = attn.forward(&x, seq_len);
1561
1562            assert_eq!(
1563                output.len(),
1564                seq_len * config.hidden_size,
1565                "FALSIFIED GQ-001e: output len mismatch for heads={num_heads},kv={num_kv_heads}"
1566            );
1567        }
1568    }
1569
1570    /// FALSIFY-GQ-002e: MHA degeneration — kv_heads == num_heads produces finite output
1571    #[test]
1572    fn falsify_gq_002e_mha_degeneration() {
1573        let config = TransformerConfig::tiny(); // num_heads == num_kv_heads == 2
1574        assert_eq!(config.num_attention_heads, config.num_kv_heads);
1575
1576        let attn = MultiHeadAttention::new(&config);
1577        let seq_len = 4;
1578        let x = Tensor::from_vec(
1579            (0..seq_len * config.hidden_size).map(|i| (i as f32 * 0.37).sin()).collect(),
1580            true,
1581        );
1582        let output = attn.forward(&x, seq_len);
1583
1584        let data = output.data();
1585        for (i, v) in data.iter().enumerate() {
1586            assert!(v.is_finite(), "FALSIFIED GQ-002e: MHA output[{i}] = {v} (not finite)");
1587        }
1588    }
1589
1590    /// FALSIFY-GQ-004e: Head divisibility — GQA requires num_heads % num_kv_heads == 0
1591    #[test]
1592    fn falsify_gq_004e_head_divisibility() {
1593        // Valid configurations should not panic
1594        for (nh, nkv) in [(2, 1), (2, 2), (4, 1), (4, 2), (4, 4), (8, 2), (8, 4)] {
1595            let mut config = TransformerConfig::tiny();
1596            config.num_attention_heads = nh;
1597            config.num_kv_heads = nkv;
1598            assert_eq!(nh % nkv, 0, "FALSIFIED GQ-004e: test config has invalid head ratio");
1599            // Should not panic during construction or forward
1600            let attn = MultiHeadAttention::new(&config);
1601            let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
1602            let _ = attn.forward(&x, 2);
1603        }
1604    }
1605
1606    /// FALSIFY-GQ-006e: MQA boundary — kv_heads=1 broadcasts single KV to all heads
1607    #[test]
1608    fn falsify_gq_006e_mqa_boundary() {
1609        let mut config = TransformerConfig::tiny();
1610        config.num_attention_heads = 4;
1611        config.num_kv_heads = 1;
1612        // Adjust hidden_size to be divisible by 4 heads
1613        config.hidden_size = 64;
1614
1615        let attn = MultiHeadAttention::new(&config);
1616        let seq_len = 3;
1617        let x = Tensor::from_vec(
1618            (0..seq_len * config.hidden_size).map(|i| (i as f32 * 0.73).cos()).collect(),
1619            true,
1620        );
1621        let output = attn.forward(&x, seq_len);
1622
1623        assert_eq!(
1624            output.len(),
1625            seq_len * config.hidden_size,
1626            "FALSIFIED GQ-006e: MQA output size wrong"
1627        );
1628
1629        // All finite
1630        let data = output.data();
1631        for (i, v) in data.iter().enumerate() {
1632            assert!(v.is_finite(), "FALSIFIED GQ-006e: MQA output[{i}] = {v} (not finite)");
1633        }
1634    }
1635
1636    mod gq_proptest_falsify {
1637        use super::*;
1638        use proptest::prelude::*;
1639
1640        // FALSIFY-GQ-001e-prop: GQA output shape for random configs
1641        proptest! {
1642            #![proptest_config(ProptestConfig::with_cases(50))]
1643
1644            #[test]
1645            fn falsify_gq_001e_prop_output_shape(
1646                config_idx in 0..4usize,
1647                seq_len in 2..=6usize,
1648                seed in 0..500u32,
1649            ) {
1650                let configs: [(usize, usize); 4] = [
1651                    (2, 2), (2, 1), (4, 2), (4, 1),
1652                ];
1653                let (num_heads, num_kv_heads) = configs[config_idx];
1654                let mut config = TransformerConfig::tiny();
1655                config.num_attention_heads = num_heads;
1656                config.num_kv_heads = num_kv_heads;
1657
1658                let attn = MultiHeadAttention::new(&config);
1659                let data: Vec<f32> = (0..seq_len * config.hidden_size)
1660                    .map(|i| ((i as f32 + seed as f32) * 0.37).sin())
1661                    .collect();
1662                let x = Tensor::from_vec(data, true);
1663                let output = attn.forward(&x, seq_len);
1664
1665                prop_assert_eq!(
1666                    output.len(),
1667                    seq_len * config.hidden_size,
1668                    "FALSIFIED GQ-001e-prop: output len mismatch"
1669                );
1670
1671                // All finite
1672                for v in output.data() {
1673                    prop_assert!(
1674                        v.is_finite(),
1675                        "FALSIFIED GQ-001e-prop: non-finite output"
1676                    );
1677                }
1678            }
1679        }
1680
1681        // FALSIFY-GQ-006e-prop: MQA boundary with random inputs
1682        proptest! {
1683            #![proptest_config(ProptestConfig::with_cases(30))]
1684
1685            #[test]
1686            fn falsify_gq_006e_prop_mqa_boundary(
1687                seed in 0..500u32,
1688                seq_len in 2..=5usize,
1689            ) {
1690                let mut config = TransformerConfig::tiny();
1691                config.num_attention_heads = 4;
1692                config.num_kv_heads = 1;
1693                config.hidden_size = 64;
1694
1695                let attn = MultiHeadAttention::new(&config);
1696                let data: Vec<f32> = (0..seq_len * config.hidden_size)
1697                    .map(|i| ((i as f32 + seed as f32) * 0.73).cos())
1698                    .collect();
1699                let x = Tensor::from_vec(data, true);
1700                let output = attn.forward(&x, seq_len);
1701
1702                prop_assert_eq!(
1703                    output.len(),
1704                    seq_len * config.hidden_size,
1705                    "FALSIFIED GQ-006e-prop: MQA output len mismatch"
1706                );
1707
1708                for v in output.data() {
1709                    prop_assert!(
1710                        v.is_finite(),
1711                        "FALSIFIED GQ-006e-prop: non-finite MQA output"
1712                    );
1713                }
1714            }
1715        }
1716    }
1717
1718    #[test]
1719    fn test_attention_from_params_with_biases() {
1720        let config = TransformerConfig::tiny();
1721        let hidden_size = config.hidden_size;
1722        let kv_hidden_size = config.num_kv_heads * config.head_dim();
1723
1724        let mut params = HashMap::new();
1725        params.insert(
1726            "attn.q_proj.weight".to_string(),
1727            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1728        );
1729        params.insert(
1730            "attn.k_proj.weight".to_string(),
1731            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1732        );
1733        params.insert(
1734            "attn.v_proj.weight".to_string(),
1735            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1736        );
1737        params.insert(
1738            "attn.o_proj.weight".to_string(),
1739            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1740        );
1741        params.insert(
1742            "attn.q_proj.bias".to_string(),
1743            Tensor::from_vec(vec![0.01; hidden_size], true),
1744        );
1745        params.insert(
1746            "attn.k_proj.bias".to_string(),
1747            Tensor::from_vec(vec![0.01; kv_hidden_size], true),
1748        );
1749        params.insert(
1750            "attn.v_proj.bias".to_string(),
1751            Tensor::from_vec(vec![0.01; kv_hidden_size], true),
1752        );
1753
1754        let attn = MultiHeadAttention::from_params(&config, &params, "attn");
1755        assert!(attn.is_some());
1756        let attn = attn.expect("should load with biases");
1757        assert!(attn.has_biases());
1758        assert_eq!(attn.parameters().len(), 7);
1759    }
1760
1761    #[test]
1762    fn test_attention_named_parameters_with_biases() {
1763        let config = TransformerConfig::tiny();
1764        let hidden_size = config.hidden_size;
1765        let kv_hidden_size = config.num_kv_heads * config.head_dim();
1766
1767        let mut params = HashMap::new();
1768        params.insert(
1769            "attn.q_proj.weight".to_string(),
1770            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1771        );
1772        params.insert(
1773            "attn.k_proj.weight".to_string(),
1774            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1775        );
1776        params.insert(
1777            "attn.v_proj.weight".to_string(),
1778            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1779        );
1780        params.insert(
1781            "attn.o_proj.weight".to_string(),
1782            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1783        );
1784        params.insert(
1785            "attn.q_proj.bias".to_string(),
1786            Tensor::from_vec(vec![0.01; hidden_size], true),
1787        );
1788        params.insert(
1789            "attn.k_proj.bias".to_string(),
1790            Tensor::from_vec(vec![0.01; kv_hidden_size], true),
1791        );
1792        params.insert(
1793            "attn.v_proj.bias".to_string(),
1794            Tensor::from_vec(vec![0.01; kv_hidden_size], true),
1795        );
1796
1797        let attn = MultiHeadAttention::from_params(&config, &params, "attn").expect("should load");
1798        let named = attn.named_parameters("attn");
1799        assert_eq!(named.len(), 7);
1800        let names: Vec<&str> = named.iter().map(|(n, _)| n.as_str()).collect();
1801        assert!(names.contains(&"attn.q_proj.bias"));
1802        assert!(names.contains(&"attn.k_proj.bias"));
1803        assert!(names.contains(&"attn.v_proj.bias"));
1804    }
1805
1806    #[test]
1807    fn test_attention_forward_with_biases() {
1808        let config = TransformerConfig::tiny();
1809        let hidden_size = config.hidden_size;
1810        let kv_hidden_size = config.num_kv_heads * config.head_dim();
1811
1812        let mut params = HashMap::new();
1813        params.insert(
1814            "attn.q_proj.weight".to_string(),
1815            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1816        );
1817        params.insert(
1818            "attn.k_proj.weight".to_string(),
1819            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1820        );
1821        params.insert(
1822            "attn.v_proj.weight".to_string(),
1823            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1824        );
1825        params.insert(
1826            "attn.o_proj.weight".to_string(),
1827            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1828        );
1829        params
1830            .insert("attn.q_proj.bias".to_string(), Tensor::from_vec(vec![0.5; hidden_size], true));
1831        params.insert(
1832            "attn.k_proj.bias".to_string(),
1833            Tensor::from_vec(vec![0.5; kv_hidden_size], true),
1834        );
1835        params.insert(
1836            "attn.v_proj.bias".to_string(),
1837            Tensor::from_vec(vec![0.5; kv_hidden_size], true),
1838        );
1839
1840        let attn = MultiHeadAttention::from_params(&config, &params, "attn").expect("should load");
1841        let x = Tensor::from_vec(vec![0.1; 2 * hidden_size], false);
1842        let output = attn.forward(&x, 2);
1843        assert_eq!(output.len(), 2 * hidden_size);
1844        assert!(output.data().iter().all(|v| v.is_finite()));
1845    }
1846}