Skip to main content

entrenar/transformer/
attention.rs

1//! Multi-head attention module
2//!
3//! This module provides multi-head self-attention with grouped-query attention support.
4
5use crate::autograd::{matmul, matmul_nt, BackwardOp};
6use crate::Tensor;
7use ndarray::Array1;
8use std::cell::RefCell;
9use std::collections::HashMap;
10use std::rc::Rc;
11
12use super::config::TransformerConfig;
13
14/// Add a bias vector to a projected tensor: output[s] += bias for each sequence position.
15/// Input shape: (seq_len × dim) flattened. Bias shape: (dim).
16fn add_bias(x: &Tensor, bias: &Tensor, seq_len: usize) -> Tensor {
17    let xd = x.data();
18    let x_slice = xd.as_slice().expect("contiguous projection");
19    let bd = bias.data();
20    let b_slice = bd.as_slice().expect("contiguous bias");
21    let dim = b_slice.len();
22    let mut out = Vec::with_capacity(x_slice.len());
23    for s in 0..seq_len {
24        let base = s * dim;
25        for d in 0..dim {
26            out.push(x_slice[base + d] + b_slice[d]);
27        }
28    }
29    Tensor::from_vec(out, x.requires_grad())
30}
31
32/// Apply per-head RMSNorm to Q or K (Qwen3 QK-norm, ENT-269).
33///
34/// Input: [seq_len * total_dim] where total_dim = num_heads * head_dim.
35/// Norm weight: [head_dim]. Applied independently to each head's head_dim slice.
36fn apply_qk_norm(
37    x: &Tensor,
38    norm_weight: &Tensor,
39    seq_len: usize,
40    num_heads: usize,
41    head_dim: usize,
42) -> Tensor {
43    let xd = x.data();
44    let x_slice = xd.as_slice().expect("contiguous qk");
45    let wd = norm_weight.data();
46    let w_slice = wd.as_slice().expect("contiguous norm weight");
47    let total_dim = num_heads * head_dim;
48    let eps = 1e-6_f32;
49    let mut out = vec![0.0f32; seq_len * total_dim];
50
51    for s in 0..seq_len {
52        for h in 0..num_heads {
53            let offset = s * total_dim + h * head_dim;
54            // RMSNorm: x * weight / sqrt(mean(x^2) + eps)
55            let mut sum_sq = 0.0f32;
56            for d in 0..head_dim {
57                let v = x_slice[offset + d];
58                sum_sq += v * v;
59            }
60            let rms = (sum_sq / head_dim as f32 + eps).sqrt();
61            let inv_rms = 1.0 / rms;
62            for d in 0..head_dim {
63                out[offset + d] = x_slice[offset + d] * inv_rms * w_slice[d];
64            }
65        }
66    }
67
68    Tensor::from_vec(out, x.requires_grad())
69}
70
71/// Apply Rotary Position Embedding (RoPE) to Q or K tensor (ENT-269).
72///
73/// Uses Llama/Qwen3 half-rotation layout (NOT interleaved pairs):
74///   x1 = x[..., :half_dim], x2 = x[..., half_dim:]
75///   rotate_half(x) = [-x2, x1]
76///   result = x * cos + rotate_half(x) * sin
77///
78/// freq[i] = 1 / (theta ^ (2i / head_dim))
79fn apply_rope(
80    x: &Tensor,
81    seq_len: usize,
82    num_heads: usize,
83    head_dim: usize,
84    rope_theta: f32,
85) -> Tensor {
86    let xd = x.data();
87    let x_slice = xd.as_slice().expect("contiguous qk for rope");
88    let total_dim = num_heads * head_dim;
89    let half_dim = head_dim / 2;
90    let mut out = vec![0.0f32; seq_len * total_dim];
91
92    // Precompute inverse frequencies: 1 / (theta ^ (2i / head_dim))
93    let inv_freq: Vec<f32> =
94        (0..half_dim).map(|i| 1.0 / rope_theta.powf(2.0 * i as f32 / head_dim as f32)).collect();
95
96    for pos in 0..seq_len {
97        for h in 0..num_heads {
98            let offset = pos * total_dim + h * head_dim;
99            for i in 0..half_dim {
100                let freq = pos as f32 * inv_freq[i];
101                let cos_f = freq.cos();
102                let sin_f = freq.sin();
103                // Half-rotation: pair (x[i], x[i + half_dim])
104                let x_first = x_slice[offset + i];
105                let x_second = x_slice[offset + i + half_dim];
106                // rotate_half: [-x_second, x_first]
107                out[offset + i] = x_first * cos_f - x_second * sin_f;
108                out[offset + i + half_dim] = x_second * cos_f + x_first * sin_f;
109            }
110        }
111    }
112
113    let result = Tensor::from_vec(out, x.requires_grad());
114    contract_post_rope!(result.data().as_slice().unwrap_or(&[]));
115    result
116}
117
118// ---------------------------------------------------------------------------
119// AttentionBlockBackward: combined backward for multi-head attention
120//
121// Orchestrates gradient flow: concat → per-head attention → Q/K/V projections.
122// Calls each Q/K/V matmul backward exactly once to avoid gradient inflation.
123// ---------------------------------------------------------------------------
124
125struct AttentionBlockBackward {
126    q: Tensor,
127    k: Tensor,
128    v: Tensor,
129    head_q_tensors: Vec<Tensor>,
130    head_k_tensors: Vec<Tensor>,
131    head_v_tensors: Vec<Tensor>,
132    head_outputs: Vec<Tensor>,
133    head_kv_indices: Vec<usize>,
134    seq_len: usize,
135    head_dim: usize,
136    q_dim: usize,
137    kv_hidden_size: usize,
138    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
139}
140
141impl BackwardOp for AttentionBlockBackward {
142    fn backward(&self) {
143        let Some(grad_out) = self.result_grad.borrow().as_ref().cloned() else { return };
144        let go = grad_out.as_slice().expect("grad contiguous");
145        let h = self.head_dim;
146
147        // Step 1: Split concat grad per head and trigger each attention backward
148        split_and_backward_heads(go, &self.head_outputs, self.seq_len, h, self.q_dim);
149
150        // Step 2-4: Scatter per-head grads into full Q/K/V
151        scatter_head_grads_q(&self.q, &self.head_q_tensors, self.seq_len, h, self.q_dim);
152        scatter_head_grads_kv(
153            &self.k,
154            &self.head_k_tensors,
155            &self.head_kv_indices,
156            self.seq_len,
157            h,
158            self.kv_hidden_size,
159        );
160        scatter_head_grads_kv(
161            &self.v,
162            &self.head_v_tensors,
163            &self.head_kv_indices,
164            self.seq_len,
165            h,
166            self.kv_hidden_size,
167        );
168
169        // Step 5: Propagate backward through Q/K/V matmuls (once each)
170        for proj in [&self.q, &self.k, &self.v] {
171            if let Some(op) = proj.backward_op() {
172                op.backward();
173            }
174        }
175    }
176}
177
178/// Split concat gradient per head and trigger each head's attention backward
179fn split_and_backward_heads(
180    go: &[f32],
181    head_outputs: &[Tensor],
182    seq_len: usize,
183    head_dim: usize,
184    q_dim: usize,
185) {
186    for (head_idx, head_out) in head_outputs.iter().enumerate() {
187        let mut grad_head = vec![0.0_f32; seq_len * head_dim];
188        for s in 0..seq_len {
189            let src_base = s * q_dim + head_idx * head_dim;
190            let dst_base = s * head_dim;
191            grad_head[dst_base..dst_base + head_dim]
192                .copy_from_slice(&go[src_base..src_base + head_dim]);
193        }
194        head_out.accumulate_grad(Array1::from(grad_head));
195        if let Some(op) = head_out.backward_op() {
196            op.backward();
197        }
198    }
199}
200
201/// Scatter per-head Q gradients into the full Q projection tensor
202fn scatter_head_grads_q(
203    q: &Tensor,
204    head_q_tensors: &[Tensor],
205    seq_len: usize,
206    head_dim: usize,
207    q_dim: usize,
208) {
209    if !q.requires_grad() {
210        return;
211    }
212    let mut grad_q = vec![0.0_f32; seq_len * q_dim];
213    for (head_idx, head_q) in head_q_tensors.iter().enumerate() {
214        if let Some(hgrad) = head_q.grad() {
215            let hg = hgrad.as_slice().expect("contiguous");
216            for s in 0..seq_len {
217                let src_base = s * head_dim;
218                let dst_base = s * q_dim + head_idx * head_dim;
219                for d in 0..head_dim {
220                    grad_q[dst_base + d] += hg[src_base + d];
221                }
222            }
223        }
224    }
225    q.accumulate_grad(Array1::from(grad_q));
226}
227
228/// Scatter per-head K or V gradients into the full K/V projection tensor (GQA-correct)
229fn scatter_head_grads_kv(
230    target: &Tensor,
231    head_tensors: &[Tensor],
232    kv_indices: &[usize],
233    seq_len: usize,
234    head_dim: usize,
235    kv_hidden_size: usize,
236) {
237    if !target.requires_grad() {
238        return;
239    }
240    let mut grad = vec![0.0_f32; seq_len * kv_hidden_size];
241    for (head_idx, head_t) in head_tensors.iter().enumerate() {
242        let kv_h = kv_indices[head_idx];
243        if let Some(hgrad) = head_t.grad() {
244            let hg = hgrad.as_slice().expect("contiguous");
245            for s in 0..seq_len {
246                let src_base = s * head_dim;
247                let dst_base = s * kv_hidden_size + kv_h * head_dim;
248                for d in 0..head_dim {
249                    grad[dst_base + d] += hg[src_base + d];
250                }
251            }
252        }
253    }
254    target.accumulate_grad(Array1::from(grad));
255}
256
257/// Multi-head self-attention layer
258pub struct MultiHeadAttention {
259    /// Configuration
260    config: TransformerConfig,
261    /// Query projection weight (hidden_size x hidden_size)
262    pub w_q: Tensor,
263    /// Key projection weight (hidden_size x kv_hidden_size)
264    pub w_k: Tensor,
265    /// Value projection weight (hidden_size x kv_hidden_size)
266    pub w_v: Tensor,
267    /// Output projection weight (hidden_size x hidden_size)
268    pub w_o: Tensor,
269    /// Optional query bias (Qwen2 uses attention biases)
270    pub b_q: Option<Tensor>,
271    /// Optional key bias
272    pub b_k: Option<Tensor>,
273    /// Optional value bias
274    pub b_v: Option<Tensor>,
275    /// Optional Q RMSNorm weight (Qwen3 uses QK-norm, shape=[head_dim])
276    pub q_norm: Option<Tensor>,
277    /// Optional K RMSNorm weight (Qwen3 uses QK-norm, shape=[head_dim])
278    pub k_norm: Option<Tensor>,
279}
280
281impl MultiHeadAttention {
282    /// Create new attention layer with initialized weights.
283    ///
284    /// When `config.use_bias == true` (Qwen2 family), allocates Q/K/V
285    /// projection biases as zero tensors. The forward pass already
286    /// honors `Option<Tensor>` biases (lines 388-395 add via
287    /// `add_bias` when `Some`); without allocating them here, biases
288    /// stay `None` and `populate_trainer_from_init_tensors` silently
289    /// drops the corresponding init tensors during fine-tune from a
290    /// Qwen APR checkpoint — see FALSIFY-APR-PRETRAIN-INIT-POPULATE-
291    /// COVERAGE-001/002 in `transformer::config::tests` for the
292    /// 290-vs-218 named-parameters gap that surfaced this bug.
293    ///
294    /// Zero-init for biases matches HuggingFace LLaMA / Qwen
295    /// convention (PyTorch `nn.Linear(bias=True)` initializes the
296    /// weight with `kaiming_uniform_` but the bias as the all-zeros
297    /// tensor — see `torch.nn.modules.linear.Linear.reset_parameters`).
298    pub fn new(config: &TransformerConfig) -> Self {
299        use super::init::{get_init_seed, rand_normal_seeded};
300        let hidden_size = config.hidden_size;
301        let q_dim = config.q_dim();
302        let kv_hidden_size = config.num_kv_heads * config.head_dim();
303        let seed = get_init_seed();
304
305        let (b_q, b_k, b_v) = if config.use_bias {
306            (
307                Some(Tensor::from_vec(vec![0.0_f32; q_dim], true)),
308                Some(Tensor::from_vec(vec![0.0_f32; kv_hidden_size], true)),
309                Some(Tensor::from_vec(vec![0.0_f32; kv_hidden_size], true)),
310            )
311        } else {
312            (None, None, None)
313        };
314
315        // C-INIT-001: normal(0, 0.02) matching HuggingFace LLaMA
316        Self {
317            config: config.clone(),
318            w_q: Tensor::from_vec(rand_normal_seeded(q_dim * hidden_size, seed, "w_q"), true),
319            w_k: Tensor::from_vec(
320                rand_normal_seeded(kv_hidden_size * hidden_size, seed, "w_k"),
321                true,
322            ),
323            w_v: Tensor::from_vec(
324                rand_normal_seeded(kv_hidden_size * hidden_size, seed, "w_v"),
325                true,
326            ),
327            w_o: Tensor::from_vec(rand_normal_seeded(hidden_size * q_dim, seed, "w_o"), true),
328            b_q,
329            b_k,
330            b_v,
331            q_norm: None,
332            k_norm: None,
333        }
334    }
335
336    /// Create attention layer from parameter map
337    ///
338    /// Expected parameter names (following HuggingFace convention):
339    /// - `{prefix}.q_proj.weight`
340    /// - `{prefix}.k_proj.weight`
341    /// - `{prefix}.v_proj.weight`
342    /// - `{prefix}.o_proj.weight`
343    /// # Contract (PMAT-331)
344    /// Validates Q/K/V/O projection shapes against config dimensions.
345    /// Returns None if any key is missing or shape is wrong.
346    pub fn from_params(
347        config: &TransformerConfig,
348        params: &HashMap<String, Tensor>,
349        prefix: &str,
350    ) -> Option<Self> {
351        let w_q = params.get(&format!("{prefix}.q_proj.weight"))?.clone();
352        let w_k = params.get(&format!("{prefix}.k_proj.weight"))?.clone();
353        let w_v = params.get(&format!("{prefix}.v_proj.weight"))?.clone();
354        let w_o = params.get(&format!("{prefix}.o_proj.weight"))?.clone();
355
356        let hidden = config.hidden_size;
357        let q_dim = config.q_dim();
358        let kv_hidden = config.num_kv_heads * config.head_dim();
359
360        // PMAT-331: Shape validation for attention projections
361        // Q: [q_dim, hidden], K: [kv_hidden, hidden], V: [kv_hidden, hidden], O: [hidden, q_dim]
362        let checks: &[(&str, &Tensor, usize)] = &[
363            ("q_proj", &w_q, q_dim * hidden),
364            ("k_proj", &w_k, kv_hidden * hidden),
365            ("v_proj", &w_v, kv_hidden * hidden),
366            ("o_proj", &w_o, hidden * q_dim),
367        ];
368        for &(name, tensor, expected) in checks {
369            if tensor.len() != expected {
370                eprintln!(
371                    "[PMAT-331] {prefix}.{name}: shape mismatch — got {} elements, expected {expected}",
372                    tensor.len()
373                );
374                return None;
375            }
376        }
377
378        // Optional attention biases (Qwen2 uses Q/K/V biases)
379        let b_q = params.get(&format!("{prefix}.q_proj.bias")).cloned();
380        let b_k = params.get(&format!("{prefix}.k_proj.bias")).cloned();
381        let b_v = params.get(&format!("{prefix}.v_proj.bias")).cloned();
382
383        // Optional Q/K RMSNorm (Qwen3 uses QK-norm, ENT-269)
384        let q_norm = params.get(&format!("{prefix}.q_norm.weight")).cloned();
385        let k_norm = params.get(&format!("{prefix}.k_norm.weight")).cloned();
386
387        Some(Self { config: config.clone(), w_q, w_k, w_v, w_o, b_q, b_k, b_v, q_norm, k_norm })
388    }
389
390    /// Forward pass
391    ///
392    /// # Arguments
393    /// * `x` - Input tensor (seq_len * hidden_size, flattened)
394    /// * `seq_len` - Sequence length
395    ///
396    /// # Returns
397    /// Output tensor (seq_len * hidden_size, flattened)
398    pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
399        contract_pre_attention!(x.data());
400        let hidden_size = self.config.hidden_size;
401        let num_heads = self.config.num_attention_heads;
402        let num_kv_heads = self.config.num_kv_heads;
403        let head_dim = self.config.head_dim();
404        let q_dim = self.config.q_dim();
405        let kv_hidden_size = num_kv_heads * head_dim;
406
407        // Project Q, K, V — HF weights are [out_dim, in_dim], use matmul_nt (ENT-269)
408        let mut q = matmul_nt(x, &self.w_q, seq_len, hidden_size, q_dim);
409        let mut k = matmul_nt(x, &self.w_k, seq_len, hidden_size, kv_hidden_size);
410        let mut v = matmul_nt(x, &self.w_v, seq_len, hidden_size, kv_hidden_size);
411
412        // Apply attention biases if present (Qwen2 architecture)
413        if let Some(ref b_q) = self.b_q {
414            q = add_bias(&q, b_q, seq_len);
415        }
416        if let Some(ref b_k) = self.b_k {
417            k = add_bias(&k, b_k, seq_len);
418        }
419        if let Some(ref b_v) = self.b_v {
420            v = add_bias(&v, b_v, seq_len);
421        }
422
423        // Apply Q/K RMSNorm if present (Qwen3 QK-norm, ENT-269)
424        if let Some(ref qn) = self.q_norm {
425            q = apply_qk_norm(&q, qn, seq_len, num_heads, head_dim);
426        }
427        if let Some(ref kn) = self.k_norm {
428            k = apply_qk_norm(&k, kn, seq_len, num_kv_heads, head_dim);
429        }
430
431        // Apply Rotary Position Embedding (RoPE) to Q and K (ENT-269)
432        // Skip for encoder models (BERT/RoBERTa use learned positions, not RoPE)
433        if self.config.rope_theta > 0.0 {
434            q = apply_rope(&q, seq_len, num_heads, head_dim, self.config.rope_theta);
435            k = apply_rope(&k, seq_len, num_kv_heads, head_dim, self.config.rope_theta);
436        }
437
438        let requires_grad = q.requires_grad() || k.requires_grad() || v.requires_grad();
439        let heads_per_kv = num_heads / num_kv_heads;
440
441        // KAIZEN-016: Hoist data borrows outside the head loop to avoid
442        // num_heads × seq_len × 3 redundant RefCell borrows per attention call.
443        let q_data = q.data();
444        let q_slice = q_data.as_slice().expect("contiguous Q");
445        let k_data = k.data();
446        let k_slice = k_data.as_slice().expect("contiguous K");
447        let v_data = v.data();
448        let v_slice = v_data.as_slice().expect("contiguous V");
449
450        // Per-head attention with gradient tracking
451        let mut head_q_tensors = Vec::with_capacity(num_heads);
452        let mut head_k_tensors = Vec::with_capacity(num_heads);
453        let mut head_v_tensors = Vec::with_capacity(num_heads);
454        let mut head_outputs = Vec::with_capacity(num_heads);
455        let mut head_kv_indices = Vec::with_capacity(num_heads);
456
457        for h in 0..num_heads {
458            let kv_h = h / heads_per_kv;
459            head_kv_indices.push(kv_h);
460
461            // KAIZEN-016: Use extend_from_slice instead of flat_map+to_vec.
462            // Eliminates num_heads × 3 × seq_len intermediate Vec allocations
463            // (3.5M allocs/forward for Qwen3-4B).
464            let mut q_head = Vec::with_capacity(seq_len * head_dim);
465            for s in 0..seq_len {
466                let start = s * q_dim + h * head_dim;
467                q_head.extend_from_slice(&q_slice[start..start + head_dim]);
468            }
469
470            let mut k_head = Vec::with_capacity(seq_len * head_dim);
471            for s in 0..seq_len {
472                let start = s * kv_hidden_size + kv_h * head_dim;
473                k_head.extend_from_slice(&k_slice[start..start + head_dim]);
474            }
475
476            let mut v_head = Vec::with_capacity(seq_len * head_dim);
477            for s in 0..seq_len {
478                let start = s * kv_hidden_size + kv_h * head_dim;
479                v_head.extend_from_slice(&v_slice[start..start + head_dim]);
480            }
481
482            let q_tensor = Tensor::from_vec(q_head, requires_grad);
483            let k_tensor = Tensor::from_vec(k_head, requires_grad);
484            let v_tensor = Tensor::from_vec(v_head, requires_grad);
485
486            let attn_out = crate::autograd::attention(
487                &q_tensor, &k_tensor, &v_tensor, seq_len, head_dim, seq_len, head_dim,
488            );
489
490            head_q_tensors.push(q_tensor);
491            head_k_tensors.push(k_tensor);
492            head_v_tensors.push(v_tensor);
493            head_outputs.push(attn_out);
494        }
495
496        // Concatenate heads: reorder from per-head (head, seq, dim) to (seq, head*dim)
497        let mut concat_output = vec![0.0; seq_len * q_dim];
498        for (h, head_out) in head_outputs.iter().enumerate() {
499            let hd = head_out.data();
500            let hdata = hd.as_slice().expect("contiguous attention output");
501            for s in 0..seq_len {
502                let src_base = s * head_dim;
503                let dst_base = s * q_dim + h * head_dim;
504                concat_output[dst_base..dst_base + head_dim]
505                    .copy_from_slice(&hdata[src_base..src_base + head_dim]);
506            }
507        }
508
509        let mut concat_tensor = Tensor::from_vec(concat_output, requires_grad);
510
511        if requires_grad {
512            let backward_op = Rc::new(AttentionBlockBackward {
513                q: q.clone(),
514                k: k.clone(),
515                v: v.clone(),
516                head_q_tensors,
517                head_k_tensors,
518                head_v_tensors,
519                head_outputs,
520                head_kv_indices,
521                seq_len,
522                head_dim,
523                q_dim,
524                kv_hidden_size,
525                result_grad: concat_tensor.grad_cell(),
526            });
527            concat_tensor.set_backward_op(backward_op);
528        }
529
530        // Output projection — w_o is [hidden_size, q_dim] in HF (ENT-269)
531        let result = matmul_nt(&concat_tensor, &self.w_o, seq_len, q_dim, hidden_size);
532        contract_post_attention!(result.data().as_slice().unwrap_or(&[]));
533        result
534    }
535
536    /// Forward pass with LoRA adjusts on Q and V projections (KAIZEN-010).
537    ///
538    /// Applies LoRA adapters to Q and V during the forward pass so that
539    /// gradients flow through LoRA A/B matrices on non-CUDA paths.
540    ///
541    /// # Arguments
542    /// * `x` - Input tensor (seq_len * hidden_size)
543    /// * `seq_len` - Sequence length
544    /// * `lora_a_q`, `lora_b_q` - Q projection LoRA matrices (rank×d_in, d_out×rank)
545    /// * `lora_a_v`, `lora_b_v` - V projection LoRA matrices (rank×d_in, d_out×rank)
546    /// * `lora_rank` - LoRA rank
547    /// * `lora_scale` - LoRA scaling factor (alpha/rank)
548    pub fn forward_with_lora(
549        &self,
550        x: &Tensor,
551        seq_len: usize,
552        lora_a_q: &Tensor,
553        // contract_pre_attention applied via forward()
554        lora_b_q: &Tensor,
555        lora_a_v: &Tensor,
556        lora_b_v: &Tensor,
557        lora_rank: usize,
558        lora_scale: f32,
559    ) -> Tensor {
560        contract_pre_lora_forward!();
561        let hidden_size = self.config.hidden_size;
562        let num_heads = self.config.num_attention_heads;
563        let num_kv_heads = self.config.num_kv_heads;
564        let head_dim = self.config.head_dim();
565        let q_dim = self.config.q_dim();
566        let kv_hidden_size = num_kv_heads * head_dim;
567
568        // Q projection with LoRA: Q = x @ W_q + scale * (x @ A_q^T) @ B_q^T
569        //
570        // KAIZEN-011: Use matmul_nt to compute x @ A^T directly on the ORIGINAL
571        // LoRA tensors. Previous impl created transposed copies via Tensor::from_vec
572        // which broke gradient flow — gradients accumulated on ephemeral copies
573        // instead of the actual trainable LoRA parameters.
574        //
575        // LoRA layout: A is (rank, d_in), B is (d_out, rank)
576        // matmul_nt(x, A, seq, d_in, rank) computes x @ A^T = (seq, d_in) @ (d_in, rank) = (seq, rank)
577        // matmul_nt(mid, B, seq, rank, d_out) computes mid @ B^T = (seq, rank) @ (rank, d_out) = (seq, d_out)
578        let q_base = matmul_nt(x, &self.w_q, seq_len, hidden_size, q_dim);
579        let q_mid = crate::autograd::matmul_nt(x, lora_a_q, seq_len, hidden_size, lora_rank);
580        let q_lora = crate::autograd::matmul_nt(&q_mid, lora_b_q, seq_len, lora_rank, q_dim);
581        let q = crate::autograd::add_scaled(&q_base, &q_lora, lora_scale);
582
583        // K projection (no LoRA) — HF weights [out, in] (ENT-269)
584        let k = matmul_nt(x, &self.w_k, seq_len, hidden_size, kv_hidden_size);
585
586        // V projection with LoRA (same pattern as Q)
587        let v_base = matmul_nt(x, &self.w_v, seq_len, hidden_size, kv_hidden_size);
588        let v_mid = crate::autograd::matmul_nt(x, lora_a_v, seq_len, hidden_size, lora_rank);
589        let v_lora =
590            crate::autograd::matmul_nt(&v_mid, lora_b_v, seq_len, lora_rank, kv_hidden_size);
591        let v = crate::autograd::add_scaled(&v_base, &v_lora, lora_scale);
592
593        // Apply Q/K RMSNorm if present (Qwen3 QK-norm, ENT-269)
594        let q = if let Some(ref qn) = self.q_norm {
595            apply_qk_norm(&q, qn, seq_len, num_heads, head_dim)
596        } else {
597            q
598        };
599        let k = if let Some(ref kn) = self.k_norm {
600            apply_qk_norm(&k, kn, seq_len, num_kv_heads, head_dim)
601        } else {
602            k
603        };
604
605        // Apply Rotary Position Embedding (RoPE) to Q and K (ENT-269)
606        // Skip for encoder models (BERT/RoBERTa use learned positions, not RoPE)
607        let (q, k) = if self.config.rope_theta > 0.0 {
608            (
609                apply_rope(&q, seq_len, num_heads, head_dim, self.config.rope_theta),
610                apply_rope(&k, seq_len, num_kv_heads, head_dim, self.config.rope_theta),
611            )
612        } else {
613            (q, k)
614        };
615
616        let requires_grad = q.requires_grad() || k.requires_grad() || v.requires_grad();
617        let heads_per_kv = num_heads / num_kv_heads;
618
619        // KAIZEN-016: Hoist data borrows outside head loop (same optimization as forward())
620        let q_data = q.data();
621        let q_slice = q_data.as_slice().expect("contiguous Q");
622        let k_data = k.data();
623        let k_slice = k_data.as_slice().expect("contiguous K");
624        let v_data = v.data();
625        let v_slice = v_data.as_slice().expect("contiguous V");
626
627        // Per-head attention (same as forward())
628        let mut head_q_tensors = Vec::with_capacity(num_heads);
629        let mut head_k_tensors = Vec::with_capacity(num_heads);
630        let mut head_v_tensors = Vec::with_capacity(num_heads);
631        let mut head_outputs = Vec::with_capacity(num_heads);
632        let mut head_kv_indices = Vec::with_capacity(num_heads);
633
634        for h in 0..num_heads {
635            let kv_h = h / heads_per_kv;
636            head_kv_indices.push(kv_h);
637
638            // KAIZEN-016: extend_from_slice replaces flat_map+to_vec
639            let mut q_head = Vec::with_capacity(seq_len * head_dim);
640            for s in 0..seq_len {
641                let start = s * q_dim + h * head_dim;
642                q_head.extend_from_slice(&q_slice[start..start + head_dim]);
643            }
644
645            let mut k_head = Vec::with_capacity(seq_len * head_dim);
646            for s in 0..seq_len {
647                let start = s * kv_hidden_size + kv_h * head_dim;
648                k_head.extend_from_slice(&k_slice[start..start + head_dim]);
649            }
650
651            let mut v_head = Vec::with_capacity(seq_len * head_dim);
652            for s in 0..seq_len {
653                let start = s * kv_hidden_size + kv_h * head_dim;
654                v_head.extend_from_slice(&v_slice[start..start + head_dim]);
655            }
656
657            let q_tensor = Tensor::from_vec(q_head, requires_grad);
658            let k_tensor = Tensor::from_vec(k_head, requires_grad);
659            let v_tensor = Tensor::from_vec(v_head, requires_grad);
660
661            let attn_out = crate::autograd::attention(
662                &q_tensor, &k_tensor, &v_tensor, seq_len, head_dim, seq_len, head_dim,
663            );
664
665            head_q_tensors.push(q_tensor);
666            head_k_tensors.push(k_tensor);
667            head_v_tensors.push(v_tensor);
668            head_outputs.push(attn_out);
669        }
670
671        // Concatenate heads
672        let mut concat_output = vec![0.0; seq_len * q_dim];
673        for (h, head_out) in head_outputs.iter().enumerate() {
674            let hd = head_out.data();
675            let hdata = hd.as_slice().expect("contiguous attention output");
676            for s in 0..seq_len {
677                let src_base = s * head_dim;
678                let dst_base = s * q_dim + h * head_dim;
679                concat_output[dst_base..dst_base + head_dim]
680                    .copy_from_slice(&hdata[src_base..src_base + head_dim]);
681            }
682        }
683
684        let mut concat_tensor = Tensor::from_vec(concat_output, requires_grad);
685
686        if requires_grad {
687            let backward_op = Rc::new(AttentionBlockBackward {
688                q: q.clone(),
689                k: k.clone(),
690                v: v.clone(),
691                head_q_tensors,
692                head_k_tensors,
693                head_v_tensors,
694                head_outputs,
695                head_kv_indices,
696                seq_len,
697                head_dim,
698                q_dim,
699                kv_hidden_size,
700                result_grad: concat_tensor.grad_cell(),
701            });
702            concat_tensor.set_backward_op(backward_op);
703        }
704
705        // Output projection — w_o is [hidden_size, q_dim] in HF (ENT-269)
706        let result = matmul_nt(&concat_tensor, &self.w_o, seq_len, q_dim, hidden_size);
707        contract_post_lora_forward!(result);
708        result
709    }
710
711    /// Get all parameters as a vector
712    pub fn parameters(&self) -> Vec<&Tensor> {
713        let mut params = vec![&self.w_q, &self.w_k, &self.w_v, &self.w_o];
714        if let Some(ref b) = self.b_q {
715            params.push(b);
716        }
717        if let Some(ref b) = self.b_k {
718            params.push(b);
719        }
720        if let Some(ref b) = self.b_v {
721            params.push(b);
722        }
723        params
724    }
725
726    /// Get all parameters as mutable references for optimizer
727    pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
728        let mut params = vec![&mut self.w_q, &mut self.w_k, &mut self.w_v, &mut self.w_o];
729        if let Some(ref mut b) = self.b_q {
730            params.push(b);
731        }
732        if let Some(ref mut b) = self.b_k {
733            params.push(b);
734        }
735        if let Some(ref mut b) = self.b_v {
736            params.push(b);
737        }
738        params
739    }
740
741    /// Whether this attention layer has QKV biases
742    pub fn has_biases(&self) -> bool {
743        self.b_q.is_some()
744    }
745
746    /// Get named parameters for checkpoint serialization
747    pub fn named_parameters(&self, prefix: &str) -> Vec<(String, &Tensor)> {
748        let mut params = vec![
749            (format!("{prefix}.q_proj.weight"), &self.w_q),
750            (format!("{prefix}.k_proj.weight"), &self.w_k),
751            (format!("{prefix}.v_proj.weight"), &self.w_v),
752            (format!("{prefix}.o_proj.weight"), &self.w_o),
753        ];
754        if let Some(ref b) = self.b_q {
755            params.push((format!("{prefix}.q_proj.bias"), b));
756        }
757        if let Some(ref b) = self.b_k {
758            params.push((format!("{prefix}.k_proj.bias"), b));
759        }
760        if let Some(ref b) = self.b_v {
761            params.push((format!("{prefix}.v_proj.bias"), b));
762        }
763        params
764    }
765
766    /// ENT-282: Set a named parameter by suffix (after "self_attn.").
767    ///
768    /// Bias suffixes route to `b_q` / `b_k` / `b_v` only when those
769    /// fields are already `Some` (i.e., `MultiHeadAttention::new`
770    /// allocated them because `config.use_bias == true`). If the
771    /// caller asks to set a bias on an attention that doesn't have
772    /// one, return false — same semantic as setting an unrecognized
773    /// suffix. This keeps `populate_trainer_from_init_tensors`
774    /// honest: a Qwen-init APR's biases populate iff the target
775    /// `Transformer` was built from a `use_bias=true` config.
776    pub fn set_named_parameter(&mut self, suffix: &str, value: Tensor) -> bool {
777        match suffix {
778            "self_attn.q_proj.weight" => {
779                self.w_q = value;
780                true
781            }
782            "self_attn.k_proj.weight" => {
783                self.w_k = value;
784                true
785            }
786            "self_attn.v_proj.weight" => {
787                self.w_v = value;
788                true
789            }
790            "self_attn.o_proj.weight" => {
791                self.w_o = value;
792                true
793            }
794            "self_attn.q_proj.bias" => {
795                if self.b_q.is_some() {
796                    self.b_q = Some(value);
797                    true
798                } else {
799                    false
800                }
801            }
802            "self_attn.k_proj.bias" => {
803                if self.b_k.is_some() {
804                    self.b_k = Some(value);
805                    true
806                } else {
807                    false
808                }
809            }
810            "self_attn.v_proj.bias" => {
811                if self.b_v.is_some() {
812                    self.b_v = Some(value);
813                    true
814                } else {
815                    false
816                }
817            }
818            _ => false,
819        }
820    }
821}
822
/// LoRA-enabled linear projection
///
/// Computes: y = x @ W + scale * (x @ A) @ B
/// Where W is frozen base weight, A and B are trainable LoRA adapters.
///
/// Tensors are stored flat (built from a `Vec`); the shapes below are the
/// logical matrix shapes handed to `matmul`. The size relationships between
/// the tensors and `d_in` / `d_out` / `rank` are established by
/// `LoRAProjection::new`. Because the fields are public, they are not
/// re-checked if mutated afterwards.
pub struct LoRAProjection {
    /// Base weight (frozen), shape (d_in × d_out)
    pub base_weight: Tensor,
    /// LoRA A matrix (down-projection), shape (d_in × rank)
    pub lora_a: Tensor,
    /// LoRA B matrix (up-projection), shape (rank × d_out)
    pub lora_b: Tensor,
    /// Input dimension
    pub d_in: usize,
    /// Output dimension
    pub d_out: usize,
    /// LoRA rank (inner width of the A/B factorization)
    pub rank: usize,
    /// Scaling factor (alpha / rank) applied to the adapter output
    pub scale: f32,
}
843
844impl LoRAProjection {
845    /// Create a new LoRA projection
846    ///
847    /// # Arguments
848    /// * `base_weight` - Frozen base weight [d_in × d_out]
849    /// * `d_in` - Input dimension
850    /// * `d_out` - Output dimension
851    /// * `rank` - LoRA rank (typically 4, 8, 16, 32, or 64)
852    /// * `alpha` - LoRA scaling parameter
853    pub fn new(base_weight: Tensor, d_in: usize, d_out: usize, rank: usize, alpha: f32) -> Self {
854        assert_eq!(base_weight.len(), d_in * d_out, "Base weight size mismatch");
855
856        // Freeze base weight — only LoRA adapters are trainable
857        let mut base_weight = base_weight;
858        base_weight.set_requires_grad(false);
859
860        // Initialize A with Kaiming uniform (standard LoRA paper)
861        let lora_a = Tensor::from_vec(
862            (0..d_in * rank).map(|i| (i as f32 * 0.123).sin() * 0.01).collect(),
863            true, // requires_grad
864        );
865
866        // Initialize B with zeros (LoRA invariant: ΔW = B @ A = 0 at init)
867        let lora_b = Tensor::zeros(rank * d_out, true);
868
869        Self { base_weight, lora_a, lora_b, d_in, d_out, rank, scale: alpha / rank as f32 }
870    }
871
872    /// Forward pass with LoRA
873    ///
874    /// Computes: y = x @ W + scale * (x @ A) @ B
875    ///
876    /// # Arguments
877    /// * `x` - Input tensor [seq_len × d_in]
878    /// * `seq_len` - Sequence length
879    ///
880    /// # Returns
881    /// Output tensor [seq_len × d_out]
882    pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
883        // Base projection: x @ W, (seq × d_in) @ (d_in × d_out) = (seq × d_out)
884        let base_out = matmul(x, &self.base_weight, seq_len, self.d_in, self.d_out);
885
886        // LoRA path: scale * (x @ A) @ B
887        // Step 1: x @ A, (seq × d_in) @ (d_in × rank) = (seq × rank)
888        let lora_intermediate = matmul(x, &self.lora_a, seq_len, self.d_in, self.rank);
889
890        // Step 2: (x @ A) @ B, (seq × rank) @ (rank × d_out) = (seq × d_out)
891        let lora_out = matmul(&lora_intermediate, &self.lora_b, seq_len, self.rank, self.d_out);
892
893        // Combine: base + scale * lora
894        // Use autograd-compatible addition
895        crate::autograd::add_scaled(&base_out, &lora_out, self.scale)
896    }
897
898    /// Get trainable LoRA parameters
899    pub fn lora_params(&self) -> Vec<&Tensor> {
900        vec![&self.lora_a, &self.lora_b]
901    }
902
903    /// Get mutable trainable LoRA parameters
904    pub fn lora_params_mut(&mut self) -> Vec<&mut Tensor> {
905        vec![&mut self.lora_a, &mut self.lora_b]
906    }
907}
908
/// Multi-head attention with deep LoRA injection
///
/// LoRA adapters are applied to Q, K, V, O projections during forward pass.
/// Each projection is a `LoRAProjection`: a frozen base weight plus trainable
/// A/B adapters. The dimensions noted below are those set by `from_attention`.
pub struct MultiHeadAttentionWithLoRA {
    /// Configuration (head counts and head dimension read by `forward`)
    pub config: TransformerConfig,
    /// Query projection with LoRA: hidden_size -> q_dim
    pub q_proj: LoRAProjection,
    /// Key projection with LoRA: hidden_size -> num_kv_heads * head_dim
    pub k_proj: LoRAProjection,
    /// Value projection with LoRA: hidden_size -> num_kv_heads * head_dim
    pub v_proj: LoRAProjection,
    /// Output projection with LoRA: q_dim -> hidden_size
    pub o_proj: LoRAProjection,
}
924
925impl MultiHeadAttentionWithLoRA {
926    /// Create LoRA-enabled attention from existing attention weights
927    ///
928    /// # Arguments
929    /// * `attn` - Base MultiHeadAttention with pretrained weights
930    /// * `rank` - LoRA rank
931    /// * `alpha` - LoRA alpha scaling factor
932    pub fn from_attention(attn: &MultiHeadAttention, rank: usize, alpha: f32) -> Self {
933        let hidden_size = attn.config.hidden_size;
934        let q_dim = attn.config.q_dim();
935        let kv_hidden_size = attn.config.num_kv_heads * attn.config.head_dim();
936
937        Self {
938            config: attn.config.clone(),
939            q_proj: LoRAProjection::new(attn.w_q.clone(), hidden_size, q_dim, rank, alpha),
940            k_proj: LoRAProjection::new(attn.w_k.clone(), hidden_size, kv_hidden_size, rank, alpha),
941            v_proj: LoRAProjection::new(attn.w_v.clone(), hidden_size, kv_hidden_size, rank, alpha),
942            o_proj: LoRAProjection::new(attn.w_o.clone(), q_dim, hidden_size, rank, alpha),
943        }
944    }
945
946    /// Forward pass with deep LoRA injection
947    ///
948    /// LoRA is applied to all Q, K, V, O projections
949    pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
950        let num_heads = self.config.num_attention_heads;
951        let num_kv_heads = self.config.num_kv_heads;
952        let head_dim = self.config.head_dim();
953        let q_dim = self.config.q_dim();
954        let kv_hidden_size = num_kv_heads * head_dim;
955
956        // Project Q, K, V with LoRA
957        let q = self.q_proj.forward(x, seq_len);
958        let k = self.k_proj.forward(x, seq_len);
959        let v = self.v_proj.forward(x, seq_len);
960
961        // Multi-head attention with grouped-query attention support
962        let mut attn_outputs = Vec::with_capacity(num_heads * seq_len * head_dim);
963        let heads_per_kv = num_heads / num_kv_heads;
964
965        // KAIZEN-016: Hoist data borrows outside head loop
966        let q_data = q.data();
967        let q_slice = q_data.as_slice().expect("contiguous Q tensor");
968        let k_data = k.data();
969        let k_slice = k_data.as_slice().expect("contiguous K tensor");
970        let v_data = v.data();
971        let v_slice = v_data.as_slice().expect("contiguous V tensor");
972
973        for h in 0..num_heads {
974            let kv_h = h / heads_per_kv;
975
976            // KAIZEN-016: extend_from_slice replaces flat_map+to_vec
977            let mut q_head = Vec::with_capacity(seq_len * head_dim);
978            for s in 0..seq_len {
979                let start = s * q_dim + h * head_dim;
980                q_head.extend_from_slice(&q_slice[start..start + head_dim]);
981            }
982
983            let mut k_head = Vec::with_capacity(seq_len * head_dim);
984            for s in 0..seq_len {
985                let start = s * kv_hidden_size + kv_h * head_dim;
986                k_head.extend_from_slice(&k_slice[start..start + head_dim]);
987            }
988
989            let mut v_head = Vec::with_capacity(seq_len * head_dim);
990            for s in 0..seq_len {
991                let start = s * kv_hidden_size + kv_h * head_dim;
992                v_head.extend_from_slice(&v_slice[start..start + head_dim]);
993            }
994
995            // Scaled dot-product attention
996            let q_tensor = Tensor::from_vec(q_head, false);
997            let k_tensor = Tensor::from_vec(k_head, false);
998            let v_tensor = Tensor::from_vec(v_head, false);
999
1000            let attn_out = crate::autograd::attention(
1001                &q_tensor, &k_tensor, &v_tensor, seq_len, head_dim, seq_len, head_dim,
1002            );
1003
1004            attn_outputs.extend_from_slice(
1005                attn_out.data().as_slice().expect("contiguous attention output"),
1006            );
1007        }
1008
1009        // Concatenate heads and reorder: (seq_len, q_dim)
1010        let mut concat_output = vec![0.0; seq_len * q_dim];
1011        for h in 0..num_heads {
1012            for s in 0..seq_len {
1013                let src_idx = h * seq_len * head_dim + s * head_dim;
1014                let dst_idx = s * q_dim + h * head_dim;
1015                concat_output[dst_idx..dst_idx + head_dim]
1016                    .copy_from_slice(&attn_outputs[src_idx..src_idx + head_dim]);
1017            }
1018        }
1019
1020        let concat_tensor = Tensor::from_vec(concat_output, true);
1021
1022        // Output projection with LoRA: (seq_len, q_dim) -> (seq_len, hidden_size)
1023        self.o_proj.forward(&concat_tensor, seq_len)
1024    }
1025
1026    /// Get all trainable LoRA parameters
1027    pub fn lora_params(&self) -> Vec<&Tensor> {
1028        let mut params = Vec::new();
1029        params.extend(self.q_proj.lora_params());
1030        params.extend(self.k_proj.lora_params());
1031        params.extend(self.v_proj.lora_params());
1032        params.extend(self.o_proj.lora_params());
1033        params
1034    }
1035
1036    /// Get all trainable LoRA parameters as mutable references
1037    pub fn lora_params_mut(&mut self) -> Vec<&mut Tensor> {
1038        let mut params = Vec::new();
1039        params.extend(self.q_proj.lora_params_mut());
1040        params.extend(self.k_proj.lora_params_mut());
1041        params.extend(self.v_proj.lora_params_mut());
1042        params.extend(self.o_proj.lora_params_mut());
1043        params
1044    }
1045
1046    /// Count total LoRA parameters
1047    pub fn lora_param_count(&self) -> usize {
1048        // Each projection has A (d_in × rank) + B (rank × d_out)
1049        let hidden = self.config.hidden_size;
1050        let kv_hidden = self.config.num_kv_heads * self.config.head_dim();
1051        let rank = self.q_proj.rank;
1052
1053        // Q: (hidden × rank) + (rank × hidden)
1054        // K: (hidden × rank) + (rank × kv_hidden)
1055        // V: (hidden × rank) + (rank × kv_hidden)
1056        // O: (hidden × rank) + (rank × hidden)
1057        (hidden * rank + rank * hidden)      // Q
1058            + (hidden * rank + rank * kv_hidden) // K
1059            + (hidden * rank + rank * kv_hidden) // V
1060            + (hidden * rank + rank * hidden) // O
1061    }
1062}
1063
1064#[cfg(test)]
1065mod tests {
1066    use super::*;
1067
1068    #[test]
1069    fn test_multi_head_attention_tiny() {
1070        let config = TransformerConfig::tiny();
1071        let attn = MultiHeadAttention::new(&config);
1072        let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
1073        let output = attn.forward(&x, 2);
1074        assert_eq!(output.len(), 2 * config.hidden_size);
1075    }
1076
1077    #[test]
1078    fn test_multi_head_attention_parameters() {
1079        let config = TransformerConfig::tiny();
1080        let attn = MultiHeadAttention::new(&config);
1081        let params = attn.parameters();
1082        assert_eq!(params.len(), 4); // w_q, w_k, w_v, w_o
1083    }
1084
1085    #[test]
1086    fn test_attention_longer_sequence() {
1087        let config = TransformerConfig::tiny();
1088        let attn = MultiHeadAttention::new(&config);
1089        let x = Tensor::from_vec(vec![0.1; 8 * config.hidden_size], true);
1090        let output = attn.forward(&x, 8);
1091        assert_eq!(output.len(), 8 * config.hidden_size);
1092    }
1093
1094    #[test]
1095    fn test_attention_weight_sizes() {
1096        let config = TransformerConfig::tiny();
1097        let attn = MultiHeadAttention::new(&config);
1098        let kv_hidden = config.num_kv_heads * config.head_dim();
1099        assert_eq!(attn.w_q.len(), config.hidden_size * config.hidden_size);
1100        assert_eq!(attn.w_k.len(), config.hidden_size * kv_hidden);
1101        assert_eq!(attn.w_v.len(), config.hidden_size * kv_hidden);
1102        assert_eq!(attn.w_o.len(), config.hidden_size * config.hidden_size);
1103    }
1104
1105    #[test]
1106    fn test_multi_head_attention_from_params_success() {
1107        let config = TransformerConfig::tiny();
1108        let hidden_size = config.hidden_size;
1109        let kv_hidden_size = config.num_kv_heads * config.head_dim();
1110
1111        let mut params = HashMap::new();
1112        params.insert(
1113            "attn.q_proj.weight".to_string(),
1114            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1115        );
1116        params.insert(
1117            "attn.k_proj.weight".to_string(),
1118            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1119        );
1120        params.insert(
1121            "attn.v_proj.weight".to_string(),
1122            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1123        );
1124        params.insert(
1125            "attn.o_proj.weight".to_string(),
1126            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1127        );
1128
1129        let attn = MultiHeadAttention::from_params(&config, &params, "attn");
1130        assert!(attn.is_some());
1131        let attn = attn.expect("operation should succeed");
1132        assert_eq!(attn.w_q.len(), hidden_size * hidden_size);
1133    }
1134
1135    #[test]
1136    fn test_multi_head_attention_from_params_missing_key() {
1137        let config = TransformerConfig::tiny();
1138        let hidden_size = config.hidden_size;
1139
1140        let mut params = HashMap::new();
1141        params.insert(
1142            "attn.q_proj.weight".to_string(),
1143            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1144        );
1145        // Missing k_proj, v_proj, o_proj
1146
1147        let attn = MultiHeadAttention::from_params(&config, &params, "attn");
1148        assert!(attn.is_none());
1149    }
1150
1151    #[test]
1152    fn test_attention_projections_backward() {
1153        // Test that Q, K, V projection matmuls have gradients
1154        // (isolated from the full attention which has intermediate tensor issues)
1155        let config = TransformerConfig::tiny();
1156        let attn = MultiHeadAttention::new(&config);
1157        let hidden_size = config.hidden_size;
1158        let seq_len = 2;
1159
1160        let x = Tensor::from_vec(vec![0.1; seq_len * hidden_size], true);
1161
1162        // Test Q projection
1163        let mut q = crate::autograd::matmul(&x, &attn.w_q, seq_len, hidden_size, hidden_size);
1164        let grad_out = ndarray::Array1::ones(seq_len * hidden_size);
1165        crate::autograd::backward(&mut q, Some(grad_out));
1166
1167        assert!(attn.w_q.grad().is_some());
1168        let grad_q = attn.w_q.grad().expect("gradient should be available");
1169        assert!(grad_q.iter().all(|&v| v.is_finite()));
1170    }
1171
1172    #[test]
1173    fn test_output_projection_backward() {
1174        // Test output projection in isolation
1175        let config = TransformerConfig::tiny();
1176        let attn = MultiHeadAttention::new(&config);
1177        let hidden_size = config.hidden_size;
1178        let seq_len = 2;
1179
1180        // Simulate concatenated attention output
1181        let concat_out = Tensor::from_vec(vec![0.1; seq_len * hidden_size], true);
1182
1183        // Output projection
1184        let mut output =
1185            crate::autograd::matmul(&concat_out, &attn.w_o, seq_len, hidden_size, hidden_size);
1186
1187        let grad_out = ndarray::Array1::ones(seq_len * hidden_size);
1188        crate::autograd::backward(&mut output, Some(grad_out));
1189
1190        assert!(attn.w_o.grad().is_some());
1191        let grad_o = attn.w_o.grad().expect("gradient should be available");
1192        assert!(grad_o.iter().all(|&v| v.is_finite()));
1193        let sum: f32 = grad_o.iter().map(|v| v.abs()).sum();
1194        assert!(sum > 0.0, "Output projection gradient should not be all zero");
1195    }
1196
1197    /// ALB-038: Full attention forward must propagate gradients to Q/K/V weights
1198    ///
1199    /// NOTE: Currently fails because apply_rope() has no backward op — it severs
1200    /// the autograd chain for Q and K. Needs a proper RoPE backward implementation
1201    /// (ENT-272). Skipped until then.
1202    #[test]
1203    #[ignore = "apply_rope() severs autograd chain — needs backward op (ENT-272)"]
1204    fn test_attention_full_forward_qkv_gradients() {
1205        let config = TransformerConfig::tiny();
1206        let attn = MultiHeadAttention::new(&config);
1207        let hidden_size = config.hidden_size;
1208        let seq_len = 3;
1209
1210        // Non-uniform input: different positions must have different representations
1211        // so softmax produces non-uniform weights with non-zero score gradients
1212        let x_data: Vec<f32> =
1213            (0..seq_len * hidden_size).map(|i| ((i as f32) * 0.17).sin() * 0.5).collect();
1214        let x = Tensor::from_vec(x_data, true);
1215        let mut output = attn.forward(&x, seq_len);
1216
1217        let grad_out = ndarray::Array1::ones(seq_len * hidden_size);
1218        crate::autograd::backward(&mut output, Some(grad_out));
1219
1220        // All four projection weights must receive gradients
1221        for (name, param) in
1222            [("w_q", &attn.w_q), ("w_k", &attn.w_k), ("w_v", &attn.w_v), ("w_o", &attn.w_o)]
1223        {
1224            assert!(
1225                param.grad().is_some(),
1226                "ALB-038: {name} must have gradient after full attention forward"
1227            );
1228            let grad = param.grad().expect("gradient available");
1229            assert!(grad.iter().all(|&v| v.is_finite()), "ALB-038: {name} gradient must be finite");
1230            assert!(
1231                grad.iter().any(|&v| v.abs() > 1e-10),
1232                "ALB-038: {name} gradient must be non-zero"
1233            );
1234        }
1235
1236        // Input must also receive gradient (enables gradient flow through model)
1237        assert!(x.grad().is_some(), "ALB-038: input x must have gradient");
1238    }
1239
1240    // ============================================================================
1241    // LoRAProjection tests
1242    // ============================================================================
1243
1244    #[test]
1245    fn test_lora_projection_new() {
1246        let d_in = 32;
1247        let d_out = 16;
1248        let rank = 4;
1249        let alpha = 8.0;
1250
1251        let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
1252        let lora = LoRAProjection::new(base_weight, d_in, d_out, rank, alpha);
1253
1254        assert_eq!(lora.d_in, d_in);
1255        assert_eq!(lora.d_out, d_out);
1256        assert_eq!(lora.rank, rank);
1257        assert!((lora.scale - 2.0).abs() < 1e-6); // alpha / rank = 8 / 4 = 2
1258        assert_eq!(lora.lora_a.len(), d_in * rank);
1259        assert_eq!(lora.lora_b.len(), rank * d_out);
1260    }
1261
1262    #[test]
1263    fn test_lora_projection_forward() {
1264        let d_in = 32;
1265        let d_out = 16;
1266        let rank = 4;
1267        let alpha = 8.0;
1268        let seq_len = 2;
1269
1270        let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
1271        let lora = LoRAProjection::new(base_weight, d_in, d_out, rank, alpha);
1272
1273        let x = Tensor::from_vec(vec![0.1; seq_len * d_in], false);
1274        let output = lora.forward(&x, seq_len);
1275
1276        assert_eq!(output.len(), seq_len * d_out);
1277        // Check output is finite
1278        assert!(output.data().iter().all(|&v| v.is_finite()));
1279    }
1280
1281    #[test]
1282    fn test_lora_projection_params() {
1283        let d_in = 32;
1284        let d_out = 16;
1285        let rank = 4;
1286
1287        let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
1288        let lora = LoRAProjection::new(base_weight, d_in, d_out, rank, 8.0);
1289
1290        let params = lora.lora_params();
1291        assert_eq!(params.len(), 2); // lora_a and lora_b
1292    }
1293
1294    #[test]
1295    fn test_lora_projection_params_mut() {
1296        let d_in = 32;
1297        let d_out = 16;
1298        let rank = 4;
1299
1300        let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
1301        let mut lora = LoRAProjection::new(base_weight, d_in, d_out, rank, 8.0);
1302
1303        let params = lora.lora_params_mut();
1304        assert_eq!(params.len(), 2);
1305    }
1306
1307    #[test]
1308    #[should_panic(expected = "Base weight size mismatch")]
1309    fn test_lora_projection_size_mismatch() {
1310        let d_in = 32;
1311        let d_out = 16;
1312        let rank = 4;
1313
1314        // Wrong base weight size
1315        let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out + 1], false);
1316        let _ = LoRAProjection::new(base_weight, d_in, d_out, rank, 8.0);
1317    }
1318
1319    // ============================================================================
1320    // MultiHeadAttentionWithLoRA tests
1321    // ============================================================================
1322
1323    #[test]
1324    fn test_mha_with_lora_creation() {
1325        let config = TransformerConfig::tiny();
1326        let attn = MultiHeadAttention::new(&config);
1327        let rank = 4;
1328        let alpha = 8.0;
1329
1330        let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, rank, alpha);
1331
1332        assert_eq!(lora_attn.q_proj.rank, rank);
1333        assert_eq!(lora_attn.k_proj.rank, rank);
1334        assert_eq!(lora_attn.v_proj.rank, rank);
1335        assert_eq!(lora_attn.o_proj.rank, rank);
1336    }
1337
1338    #[test]
1339    fn test_mha_with_lora_forward() {
1340        let config = TransformerConfig::tiny();
1341        let attn = MultiHeadAttention::new(&config);
1342        let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
1343
1344        let seq_len = 2;
1345        let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], false);
1346        let output = lora_attn.forward(&x, seq_len);
1347
1348        assert_eq!(output.len(), seq_len * config.hidden_size);
1349        // Check output is finite and non-zero
1350        assert!(output.data().iter().all(|&v| v.is_finite()));
1351    }
1352
1353    #[test]
1354    fn test_mha_with_lora_params() {
1355        let config = TransformerConfig::tiny();
1356        let attn = MultiHeadAttention::new(&config);
1357        let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
1358
1359        let params = lora_attn.lora_params();
1360        // 4 projections × 2 params each = 8
1361        assert_eq!(params.len(), 8);
1362    }
1363
1364    #[test]
1365    fn test_mha_with_lora_params_mut() {
1366        let config = TransformerConfig::tiny();
1367        let attn = MultiHeadAttention::new(&config);
1368        let mut lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
1369
1370        let params = lora_attn.lora_params_mut();
1371        assert_eq!(params.len(), 8);
1372    }
1373
1374    #[test]
1375    fn test_mha_with_lora_param_count() {
1376        let config = TransformerConfig::tiny();
1377        let attn = MultiHeadAttention::new(&config);
1378        let rank = 4;
1379        let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, rank, 8.0);
1380
1381        let param_count = lora_attn.lora_param_count();
1382
1383        // Calculate expected:
1384        let hidden = config.hidden_size;
1385        let kv_hidden = config.num_kv_heads * config.head_dim();
1386        let expected = (hidden * rank + rank * hidden)      // Q
1387            + (hidden * rank + rank * kv_hidden) // K
1388            + (hidden * rank + rank * kv_hidden) // V
1389            + (hidden * rank + rank * hidden); // O
1390
1391        assert_eq!(param_count, expected);
1392        assert!(param_count > 0);
1393    }
1394
1395    #[test]
1396    fn test_mha_with_lora_longer_sequence() {
1397        let config = TransformerConfig::tiny();
1398        let attn = MultiHeadAttention::new(&config);
1399        let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
1400
1401        let seq_len = 8;
1402        let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], false);
1403        let output = lora_attn.forward(&x, seq_len);
1404
1405        assert_eq!(output.len(), seq_len * config.hidden_size);
1406    }
1407
1408    #[test]
1409    fn test_parameters_mut() {
1410        let config = TransformerConfig::tiny();
1411        let mut attn = MultiHeadAttention::new(&config);
1412
1413        let params = attn.parameters_mut();
1414        assert_eq!(params.len(), 4);
1415    }
1416
1417    // =========================================================================
1418    // FALSIFY-A: §2.1.3 Attention Projections — Five-Whys Gap Analysis (Refs PMAT-331)
1419    //
1420    // Contract: tensor-layout-v1.yaml §tensors.q_proj/k_proj/v_proj/o_proj
1421    //   q_proj: [num_heads*head_dim, hidden] (= [hidden, hidden] for MHA)
1422    //   k_proj: [num_kv_heads*head_dim, hidden] (smaller for GQA)
1423    //   v_proj: [num_kv_heads*head_dim, hidden] (smaller for GQA)
1424    //   o_proj: [hidden, num_heads*head_dim]
1425    //
1426    // Five-Whys:
1427    //   Why 1: Trained model's attention weights could be wrong shape
1428    //   Why 2: from_params accepts any tensor without shape validation
1429    //   Why 3: No ValidatedWeight in entrenar
1430    //   Why 4: entrenar predates the Poka-Yoke contract
1431    //   Why 5: No cross-crate contract enforcement for training weights
1432    //
1433    // Popper (1959): "These tests attempt to falsify the claim that
1434    // entrenar's attention weight handling prevents degenerate models."
1435    // =========================================================================
1436
1437    /// FALSIFY-A1e: from_params rejects wrong-shape Q weight (PMAT-331 fix)
1438    ///
1439    /// from_params now validates Q projection shape against config dimensions.
1440    /// A tensor of 50 elements is rejected when hidden*hidden is expected.
1441    #[test]
1442    fn falsify_a1e_from_params_rejects_wrong_shape_q_weight() {
1443        let config = TransformerConfig::tiny();
1444        let hidden_size = config.hidden_size;
1445        let kv_hidden_size = config.num_kv_heads * config.head_dim();
1446
1447        let mut params = HashMap::new();
1448        // WRONG-SHAPE q_proj: 50 elements instead of hidden*hidden
1449        params.insert("attn.q_proj.weight".to_string(), Tensor::from_vec(vec![0.1; 50], true));
1450        // Correct k, v, o
1451        params.insert(
1452            "attn.k_proj.weight".to_string(),
1453            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1454        );
1455        params.insert(
1456            "attn.v_proj.weight".to_string(),
1457            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1458        );
1459        params.insert(
1460            "attn.o_proj.weight".to_string(),
1461            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1462        );
1463
1464        let attn = MultiHeadAttention::from_params(&config, &params, "attn");
1465        // FIXED (PMAT-331): now rejected
1466        assert!(
1467            attn.is_none(),
1468            "FALSIFY-A1e: PMAT-331 fix — from_params MUST reject wrong-shape q_proj"
1469        );
1470    }
1471
1472    /// FALSIFY-A2e: GQA init produces correct K/V dimensions
1473    ///
1474    /// For GQA (num_kv_heads < num_heads), K/V must be smaller than Q.
1475    /// If init uses num_heads for K/V, the shapes are wrong.
1476    #[test]
1477    fn falsify_a2e_gqa_init_correct_kv_dimensions() {
1478        let mut config = TransformerConfig::tiny();
1479        config.num_kv_heads = 1; // Force GQA: 1 KV head, but num_heads > 1
1480
1481        let attn = MultiHeadAttention::new(&config);
1482        let head_dim = config.head_dim();
1483        let kv_hidden = config.num_kv_heads * head_dim; // 1 * head_dim
1484
1485        // Q: hidden * hidden
1486        assert_eq!(
1487            attn.w_q.len(),
1488            config.hidden_size * config.hidden_size,
1489            "FALSIFY-A2e: Q projection must be hidden*hidden"
1490        );
1491
1492        // K: hidden * kv_hidden (smaller than Q for GQA)
1493        assert_eq!(
1494            attn.w_k.len(),
1495            config.hidden_size * kv_hidden,
1496            "FALSIFY-A2e: K projection must use num_kv_heads, not num_heads"
1497        );
1498
1499        // V: hidden * kv_hidden (same as K)
1500        assert_eq!(
1501            attn.w_v.len(),
1502            config.hidden_size * kv_hidden,
1503            "FALSIFY-A2e: V projection must use num_kv_heads, not num_heads"
1504        );
1505
1506        // O: hidden * hidden (matches Q output)
1507        assert_eq!(
1508            attn.w_o.len(),
1509            config.hidden_size * config.hidden_size,
1510            "FALSIFY-A2e: O projection must be hidden*hidden"
1511        );
1512
1513        // K/V must be SMALLER than Q for GQA
1514        assert!(
1515            attn.w_k.len() < attn.w_q.len(),
1516            "FALSIFY-A2e: For GQA, K weight must be smaller than Q weight"
1517        );
1518    }
1519
1520    /// FALSIFY-A3e: GQA forward produces correct output dimensions
1521    ///
1522    /// With num_kv_heads < num_heads, the forward pass must still produce
1523    /// [seq_len, hidden_size] output (not [seq_len, kv_hidden]).
1524    #[test]
1525    fn falsify_a3e_gqa_forward_correct_output_dims() {
1526        let mut config = TransformerConfig::tiny();
1527        config.num_kv_heads = 1; // Force GQA
1528
1529        let attn = MultiHeadAttention::new(&config);
1530        let seq_len = 3;
1531        let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
1532        let output = attn.forward(&x, seq_len);
1533
1534        assert_eq!(
1535            output.len(),
1536            seq_len * config.hidden_size,
1537            "FALSIFY-A3e: GQA output must be seq_len * hidden_size, not seq_len * kv_hidden"
1538        );
1539    }
1540
1541    /// FALSIFY-A4e: Attention init produces non-degenerate values
1542    ///
1543    /// Like FALSIFY-E7a for embeddings: init must produce varied, finite values.
1544    #[test]
1545    fn falsify_a4e_init_produces_valid_attention_weights() {
1546        let config = TransformerConfig::tiny();
1547        let attn = MultiHeadAttention::new(&config);
1548
1549        for (name, w) in
1550            [("w_q", &attn.w_q), ("w_k", &attn.w_k), ("w_v", &attn.w_v), ("w_o", &attn.w_o)]
1551        {
1552            let data = w.data();
1553            let slice = data.as_slice().expect("data as slice");
1554
1555            // No NaN
1556            let nan_count = slice.iter().filter(|v| v.is_nan()).count();
1557            assert_eq!(nan_count, 0, "FALSIFY-A4e: {name} init must not contain NaN");
1558
1559            // No Inf
1560            let inf_count = slice.iter().filter(|v| v.is_infinite()).count();
1561            assert_eq!(inf_count, 0, "FALSIFY-A4e: {name} init must not contain Inf");
1562
1563            // Values vary
1564            let min = slice.iter().copied().fold(f32::INFINITY, f32::min);
1565            let max = slice.iter().copied().fold(f32::NEG_INFINITY, f32::max);
1566            assert!(
1567                (max - min).abs() > 1e-6,
1568                "FALSIFY-A4e: {name} init values are constant ({min}..{max}) — degenerate weight"
1569            );
1570        }
1571    }
1572
1573    /// FALSIFY-A5e: Attention forward produces finite outputs
1574    ///
1575    /// If any attention weight is degenerate, output should still be finite
1576    /// (the init is designed to prevent this).
1577    #[test]
1578    fn falsify_a5e_forward_produces_finite_output() {
1579        let config = TransformerConfig::tiny();
1580        let attn = MultiHeadAttention::new(&config);
1581        let seq_len = 4;
1582        let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
1583        let output = attn.forward(&x, seq_len);
1584
1585        let data = output.data();
1586        let nan_count = data.iter().filter(|v| v.is_nan()).count();
1587        let inf_count = data.iter().filter(|v| v.is_infinite()).count();
1588        assert_eq!(nan_count, 0, "FALSIFY-A5e: Attention output must not contain NaN");
1589        assert_eq!(inf_count, 0, "FALSIFY-A5e: Attention output must not contain Inf");
1590    }
1591
1592    // =========================================================================
1593    // FALSIFY-GQ: gqa-kernel-v1.yaml contract (entrenar MultiHeadAttention GQA)
1594    //
1595    // Five-Whys (PMAT-354):
1596    //   Why 1: entrenar had FALSIFY-A tests but zero FALSIFY-GQ-* tests
1597    //   Why 2: FALSIFY-A tests verify projections/shapes, not GQA invariants
1598    //   Why 3: no mapping from gqa-kernel-v1.yaml to entrenar test names
1599    //   Why 4: entrenar's GQA support added after FALSIFY-A tests
1600    //   Why 5: GQA was "obviously correct" (just index K/V by h/heads_per_kv)
1601    //
1602    // References:
1603    //   - provable-contracts/contracts/gqa-kernel-v1.yaml
    //   - Ainslie et al. (2023) "GQA: Training Generalized Multi-Query Transformer
    //     Models from Multi-Head Checkpoints"
1605    // =========================================================================
1606
1607    /// FALSIFY-GQ-001e: GQA output shape correct for various head configs
1608    #[test]
1609    fn falsify_gq_001e_output_shape() {
1610        for (num_heads, num_kv_heads) in [(2, 2), (4, 2), (4, 1), (2, 1)] {
1611            let mut config = TransformerConfig::tiny();
1612            config.num_attention_heads = num_heads;
1613            config.num_kv_heads = num_kv_heads;
1614
1615            let attn = MultiHeadAttention::new(&config);
1616            let seq_len = 3;
1617            let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
1618            let output = attn.forward(&x, seq_len);
1619
1620            assert_eq!(
1621                output.len(),
1622                seq_len * config.hidden_size,
1623                "FALSIFIED GQ-001e: output len mismatch for heads={num_heads},kv={num_kv_heads}"
1624            );
1625        }
1626    }
1627
1628    /// FALSIFY-GQ-002e: MHA degeneration — kv_heads == num_heads produces finite output
1629    #[test]
1630    fn falsify_gq_002e_mha_degeneration() {
1631        let config = TransformerConfig::tiny(); // num_heads == num_kv_heads == 2
1632        assert_eq!(config.num_attention_heads, config.num_kv_heads);
1633
1634        let attn = MultiHeadAttention::new(&config);
1635        let seq_len = 4;
1636        let x = Tensor::from_vec(
1637            (0..seq_len * config.hidden_size).map(|i| (i as f32 * 0.37).sin()).collect(),
1638            true,
1639        );
1640        let output = attn.forward(&x, seq_len);
1641
1642        let data = output.data();
1643        for (i, v) in data.iter().enumerate() {
1644            assert!(v.is_finite(), "FALSIFIED GQ-002e: MHA output[{i}] = {v} (not finite)");
1645        }
1646    }
1647
1648    /// FALSIFY-GQ-004e: Head divisibility — GQA requires num_heads % num_kv_heads == 0
1649    #[test]
1650    fn falsify_gq_004e_head_divisibility() {
1651        // Valid configurations should not panic
1652        for (nh, nkv) in [(2, 1), (2, 2), (4, 1), (4, 2), (4, 4), (8, 2), (8, 4)] {
1653            let mut config = TransformerConfig::tiny();
1654            config.num_attention_heads = nh;
1655            config.num_kv_heads = nkv;
1656            assert_eq!(nh % nkv, 0, "FALSIFIED GQ-004e: test config has invalid head ratio");
1657            // Should not panic during construction or forward
1658            let attn = MultiHeadAttention::new(&config);
1659            let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
1660            let _ = attn.forward(&x, 2);
1661        }
1662    }
1663
1664    /// FALSIFY-GQ-006e: MQA boundary — kv_heads=1 broadcasts single KV to all heads
1665    #[test]
1666    fn falsify_gq_006e_mqa_boundary() {
1667        let mut config = TransformerConfig::tiny();
1668        config.num_attention_heads = 4;
1669        config.num_kv_heads = 1;
1670        // Adjust hidden_size to be divisible by 4 heads
1671        config.hidden_size = 64;
1672
1673        let attn = MultiHeadAttention::new(&config);
1674        let seq_len = 3;
1675        let x = Tensor::from_vec(
1676            (0..seq_len * config.hidden_size).map(|i| (i as f32 * 0.73).cos()).collect(),
1677            true,
1678        );
1679        let output = attn.forward(&x, seq_len);
1680
1681        assert_eq!(
1682            output.len(),
1683            seq_len * config.hidden_size,
1684            "FALSIFIED GQ-006e: MQA output size wrong"
1685        );
1686
1687        // All finite
1688        let data = output.data();
1689        for (i, v) in data.iter().enumerate() {
1690            assert!(v.is_finite(), "FALSIFIED GQ-006e: MQA output[{i}] = {v} (not finite)");
1691        }
1692    }
1693
1694    mod gq_proptest_falsify {
1695        use super::*;
1696        use proptest::prelude::*;
1697
1698        // FALSIFY-GQ-001e-prop: GQA output shape for random configs
1699        proptest! {
1700            #![proptest_config(ProptestConfig::with_cases(50))]
1701
1702            #[test]
1703            fn falsify_gq_001e_prop_output_shape(
1704                config_idx in 0..4usize,
1705                seq_len in 2..=6usize,
1706                seed in 0..500u32,
1707            ) {
1708                let configs: [(usize, usize); 4] = [
1709                    (2, 2), (2, 1), (4, 2), (4, 1),
1710                ];
1711                let (num_heads, num_kv_heads) = configs[config_idx];
1712                let mut config = TransformerConfig::tiny();
1713                config.num_attention_heads = num_heads;
1714                config.num_kv_heads = num_kv_heads;
1715
1716                let attn = MultiHeadAttention::new(&config);
1717                let data: Vec<f32> = (0..seq_len * config.hidden_size)
1718                    .map(|i| ((i as f32 + seed as f32) * 0.37).sin())
1719                    .collect();
1720                let x = Tensor::from_vec(data, true);
1721                let output = attn.forward(&x, seq_len);
1722
1723                prop_assert_eq!(
1724                    output.len(),
1725                    seq_len * config.hidden_size,
1726                    "FALSIFIED GQ-001e-prop: output len mismatch"
1727                );
1728
1729                // All finite
1730                for v in output.data() {
1731                    prop_assert!(
1732                        v.is_finite(),
1733                        "FALSIFIED GQ-001e-prop: non-finite output"
1734                    );
1735                }
1736            }
1737        }
1738
1739        // FALSIFY-GQ-006e-prop: MQA boundary with random inputs
1740        proptest! {
1741            #![proptest_config(ProptestConfig::with_cases(30))]
1742
1743            #[test]
1744            fn falsify_gq_006e_prop_mqa_boundary(
1745                seed in 0..500u32,
1746                seq_len in 2..=5usize,
1747            ) {
1748                let mut config = TransformerConfig::tiny();
1749                config.num_attention_heads = 4;
1750                config.num_kv_heads = 1;
1751                config.hidden_size = 64;
1752
1753                let attn = MultiHeadAttention::new(&config);
1754                let data: Vec<f32> = (0..seq_len * config.hidden_size)
1755                    .map(|i| ((i as f32 + seed as f32) * 0.73).cos())
1756                    .collect();
1757                let x = Tensor::from_vec(data, true);
1758                let output = attn.forward(&x, seq_len);
1759
1760                prop_assert_eq!(
1761                    output.len(),
1762                    seq_len * config.hidden_size,
1763                    "FALSIFIED GQ-006e-prop: MQA output len mismatch"
1764                );
1765
1766                for v in output.data() {
1767                    prop_assert!(
1768                        v.is_finite(),
1769                        "FALSIFIED GQ-006e-prop: non-finite MQA output"
1770                    );
1771                }
1772            }
1773        }
1774    }
1775
1776    #[test]
1777    fn test_attention_from_params_with_biases() {
1778        let config = TransformerConfig::tiny();
1779        let hidden_size = config.hidden_size;
1780        let kv_hidden_size = config.num_kv_heads * config.head_dim();
1781
1782        let mut params = HashMap::new();
1783        params.insert(
1784            "attn.q_proj.weight".to_string(),
1785            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1786        );
1787        params.insert(
1788            "attn.k_proj.weight".to_string(),
1789            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1790        );
1791        params.insert(
1792            "attn.v_proj.weight".to_string(),
1793            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1794        );
1795        params.insert(
1796            "attn.o_proj.weight".to_string(),
1797            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1798        );
1799        params.insert(
1800            "attn.q_proj.bias".to_string(),
1801            Tensor::from_vec(vec![0.01; hidden_size], true),
1802        );
1803        params.insert(
1804            "attn.k_proj.bias".to_string(),
1805            Tensor::from_vec(vec![0.01; kv_hidden_size], true),
1806        );
1807        params.insert(
1808            "attn.v_proj.bias".to_string(),
1809            Tensor::from_vec(vec![0.01; kv_hidden_size], true),
1810        );
1811
1812        let attn = MultiHeadAttention::from_params(&config, &params, "attn");
1813        assert!(attn.is_some());
1814        let attn = attn.expect("should load with biases");
1815        assert!(attn.has_biases());
1816        assert_eq!(attn.parameters().len(), 7);
1817    }
1818
1819    #[test]
1820    fn test_attention_named_parameters_with_biases() {
1821        let config = TransformerConfig::tiny();
1822        let hidden_size = config.hidden_size;
1823        let kv_hidden_size = config.num_kv_heads * config.head_dim();
1824
1825        let mut params = HashMap::new();
1826        params.insert(
1827            "attn.q_proj.weight".to_string(),
1828            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1829        );
1830        params.insert(
1831            "attn.k_proj.weight".to_string(),
1832            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1833        );
1834        params.insert(
1835            "attn.v_proj.weight".to_string(),
1836            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1837        );
1838        params.insert(
1839            "attn.o_proj.weight".to_string(),
1840            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1841        );
1842        params.insert(
1843            "attn.q_proj.bias".to_string(),
1844            Tensor::from_vec(vec![0.01; hidden_size], true),
1845        );
1846        params.insert(
1847            "attn.k_proj.bias".to_string(),
1848            Tensor::from_vec(vec![0.01; kv_hidden_size], true),
1849        );
1850        params.insert(
1851            "attn.v_proj.bias".to_string(),
1852            Tensor::from_vec(vec![0.01; kv_hidden_size], true),
1853        );
1854
1855        let attn = MultiHeadAttention::from_params(&config, &params, "attn").expect("should load");
1856        let named = attn.named_parameters("attn");
1857        assert_eq!(named.len(), 7);
1858        let names: Vec<&str> = named.iter().map(|(n, _)| n.as_str()).collect();
1859        assert!(names.contains(&"attn.q_proj.bias"));
1860        assert!(names.contains(&"attn.k_proj.bias"));
1861        assert!(names.contains(&"attn.v_proj.bias"));
1862    }
1863
1864    #[test]
1865    fn test_attention_forward_with_biases() {
1866        let config = TransformerConfig::tiny();
1867        let hidden_size = config.hidden_size;
1868        let kv_hidden_size = config.num_kv_heads * config.head_dim();
1869
1870        let mut params = HashMap::new();
1871        params.insert(
1872            "attn.q_proj.weight".to_string(),
1873            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1874        );
1875        params.insert(
1876            "attn.k_proj.weight".to_string(),
1877            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1878        );
1879        params.insert(
1880            "attn.v_proj.weight".to_string(),
1881            Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1882        );
1883        params.insert(
1884            "attn.o_proj.weight".to_string(),
1885            Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1886        );
1887        params
1888            .insert("attn.q_proj.bias".to_string(), Tensor::from_vec(vec![0.5; hidden_size], true));
1889        params.insert(
1890            "attn.k_proj.bias".to_string(),
1891            Tensor::from_vec(vec![0.5; kv_hidden_size], true),
1892        );
1893        params.insert(
1894            "attn.v_proj.bias".to_string(),
1895            Tensor::from_vec(vec![0.5; kv_hidden_size], true),
1896        );
1897
1898        let attn = MultiHeadAttention::from_params(&config, &params, "attn").expect("should load");
1899        let x = Tensor::from_vec(vec![0.1; 2 * hidden_size], false);
1900        let output = attn.forward(&x, 2);
1901        assert_eq!(output.len(), 2 * hidden_size);
1902        assert!(output.data().iter().all(|v| v.is_finite()));
1903    }
1904}