1use crate::autograd::{matmul, matmul_nt, BackwardOp};
6use crate::Tensor;
7use ndarray::Array1;
8use std::cell::RefCell;
9use std::collections::HashMap;
10use std::rc::Rc;
11
12use super::config::TransformerConfig;
13
14fn add_bias(x: &Tensor, bias: &Tensor, seq_len: usize) -> Tensor {
17 let xd = x.data();
18 let x_slice = xd.as_slice().expect("contiguous projection");
19 let bd = bias.data();
20 let b_slice = bd.as_slice().expect("contiguous bias");
21 let dim = b_slice.len();
22 let mut out = Vec::with_capacity(x_slice.len());
23 for s in 0..seq_len {
24 let base = s * dim;
25 for d in 0..dim {
26 out.push(x_slice[base + d] + b_slice[d]);
27 }
28 }
29 Tensor::from_vec(out, x.requires_grad())
30}
31
32fn apply_qk_norm(
37 x: &Tensor,
38 norm_weight: &Tensor,
39 seq_len: usize,
40 num_heads: usize,
41 head_dim: usize,
42) -> Tensor {
43 let xd = x.data();
44 let x_slice = xd.as_slice().expect("contiguous qk");
45 let wd = norm_weight.data();
46 let w_slice = wd.as_slice().expect("contiguous norm weight");
47 let total_dim = num_heads * head_dim;
48 let eps = 1e-6_f32;
49 let mut out = vec![0.0f32; seq_len * total_dim];
50
51 for s in 0..seq_len {
52 for h in 0..num_heads {
53 let offset = s * total_dim + h * head_dim;
54 let mut sum_sq = 0.0f32;
56 for d in 0..head_dim {
57 let v = x_slice[offset + d];
58 sum_sq += v * v;
59 }
60 let rms = (sum_sq / head_dim as f32 + eps).sqrt();
61 let inv_rms = 1.0 / rms;
62 for d in 0..head_dim {
63 out[offset + d] = x_slice[offset + d] * inv_rms * w_slice[d];
64 }
65 }
66 }
67
68 Tensor::from_vec(out, x.requires_grad())
69}
70
71fn apply_rope(
80 x: &Tensor,
81 seq_len: usize,
82 num_heads: usize,
83 head_dim: usize,
84 rope_theta: f32,
85) -> Tensor {
86 let xd = x.data();
87 let x_slice = xd.as_slice().expect("contiguous qk for rope");
88 let total_dim = num_heads * head_dim;
89 let half_dim = head_dim / 2;
90 let mut out = vec![0.0f32; seq_len * total_dim];
91
92 let inv_freq: Vec<f32> =
94 (0..half_dim).map(|i| 1.0 / rope_theta.powf(2.0 * i as f32 / head_dim as f32)).collect();
95
96 for pos in 0..seq_len {
97 for h in 0..num_heads {
98 let offset = pos * total_dim + h * head_dim;
99 for i in 0..half_dim {
100 let freq = pos as f32 * inv_freq[i];
101 let cos_f = freq.cos();
102 let sin_f = freq.sin();
103 let x_first = x_slice[offset + i];
105 let x_second = x_slice[offset + i + half_dim];
106 out[offset + i] = x_first * cos_f - x_second * sin_f;
108 out[offset + i + half_dim] = x_second * cos_f + x_first * sin_f;
109 }
110 }
111 }
112
113 let result = Tensor::from_vec(out, x.requires_grad());
114 contract_post_rope!(result.data().as_slice().unwrap_or(&[]));
115 result
116}
117
118struct AttentionBlockBackward {
126 q: Tensor,
127 k: Tensor,
128 v: Tensor,
129 head_q_tensors: Vec<Tensor>,
130 head_k_tensors: Vec<Tensor>,
131 head_v_tensors: Vec<Tensor>,
132 head_outputs: Vec<Tensor>,
133 head_kv_indices: Vec<usize>,
134 seq_len: usize,
135 head_dim: usize,
136 q_dim: usize,
137 kv_hidden_size: usize,
138 result_grad: Rc<RefCell<Option<Array1<f32>>>>,
139}
140
141impl BackwardOp for AttentionBlockBackward {
142 fn backward(&self) {
143 let Some(grad_out) = self.result_grad.borrow().as_ref().cloned() else { return };
144 let go = grad_out.as_slice().expect("grad contiguous");
145 let h = self.head_dim;
146
147 split_and_backward_heads(go, &self.head_outputs, self.seq_len, h, self.q_dim);
149
150 scatter_head_grads_q(&self.q, &self.head_q_tensors, self.seq_len, h, self.q_dim);
152 scatter_head_grads_kv(
153 &self.k,
154 &self.head_k_tensors,
155 &self.head_kv_indices,
156 self.seq_len,
157 h,
158 self.kv_hidden_size,
159 );
160 scatter_head_grads_kv(
161 &self.v,
162 &self.head_v_tensors,
163 &self.head_kv_indices,
164 self.seq_len,
165 h,
166 self.kv_hidden_size,
167 );
168
169 for proj in [&self.q, &self.k, &self.v] {
171 if let Some(op) = proj.backward_op() {
172 op.backward();
173 }
174 }
175 }
176}
177
178fn split_and_backward_heads(
180 go: &[f32],
181 head_outputs: &[Tensor],
182 seq_len: usize,
183 head_dim: usize,
184 q_dim: usize,
185) {
186 for (head_idx, head_out) in head_outputs.iter().enumerate() {
187 let mut grad_head = vec![0.0_f32; seq_len * head_dim];
188 for s in 0..seq_len {
189 let src_base = s * q_dim + head_idx * head_dim;
190 let dst_base = s * head_dim;
191 grad_head[dst_base..dst_base + head_dim]
192 .copy_from_slice(&go[src_base..src_base + head_dim]);
193 }
194 head_out.accumulate_grad(Array1::from(grad_head));
195 if let Some(op) = head_out.backward_op() {
196 op.backward();
197 }
198 }
199}
200
201fn scatter_head_grads_q(
203 q: &Tensor,
204 head_q_tensors: &[Tensor],
205 seq_len: usize,
206 head_dim: usize,
207 q_dim: usize,
208) {
209 if !q.requires_grad() {
210 return;
211 }
212 let mut grad_q = vec![0.0_f32; seq_len * q_dim];
213 for (head_idx, head_q) in head_q_tensors.iter().enumerate() {
214 if let Some(hgrad) = head_q.grad() {
215 let hg = hgrad.as_slice().expect("contiguous");
216 for s in 0..seq_len {
217 let src_base = s * head_dim;
218 let dst_base = s * q_dim + head_idx * head_dim;
219 for d in 0..head_dim {
220 grad_q[dst_base + d] += hg[src_base + d];
221 }
222 }
223 }
224 }
225 q.accumulate_grad(Array1::from(grad_q));
226}
227
228fn scatter_head_grads_kv(
230 target: &Tensor,
231 head_tensors: &[Tensor],
232 kv_indices: &[usize],
233 seq_len: usize,
234 head_dim: usize,
235 kv_hidden_size: usize,
236) {
237 if !target.requires_grad() {
238 return;
239 }
240 let mut grad = vec![0.0_f32; seq_len * kv_hidden_size];
241 for (head_idx, head_t) in head_tensors.iter().enumerate() {
242 let kv_h = kv_indices[head_idx];
243 if let Some(hgrad) = head_t.grad() {
244 let hg = hgrad.as_slice().expect("contiguous");
245 for s in 0..seq_len {
246 let src_base = s * head_dim;
247 let dst_base = s * kv_hidden_size + kv_h * head_dim;
248 for d in 0..head_dim {
249 grad[dst_base + d] += hg[src_base + d];
250 }
251 }
252 }
253 }
254 target.accumulate_grad(Array1::from(grad));
255}
256
257pub struct MultiHeadAttention {
259 config: TransformerConfig,
261 pub w_q: Tensor,
263 pub w_k: Tensor,
265 pub w_v: Tensor,
267 pub w_o: Tensor,
269 pub b_q: Option<Tensor>,
271 pub b_k: Option<Tensor>,
273 pub b_v: Option<Tensor>,
275 pub q_norm: Option<Tensor>,
277 pub k_norm: Option<Tensor>,
279}
280
281impl MultiHeadAttention {
282 pub fn new(config: &TransformerConfig) -> Self {
299 use super::init::{get_init_seed, rand_normal_seeded};
300 let hidden_size = config.hidden_size;
301 let q_dim = config.q_dim();
302 let kv_hidden_size = config.num_kv_heads * config.head_dim();
303 let seed = get_init_seed();
304
305 let (b_q, b_k, b_v) = if config.use_bias {
306 (
307 Some(Tensor::from_vec(vec![0.0_f32; q_dim], true)),
308 Some(Tensor::from_vec(vec![0.0_f32; kv_hidden_size], true)),
309 Some(Tensor::from_vec(vec![0.0_f32; kv_hidden_size], true)),
310 )
311 } else {
312 (None, None, None)
313 };
314
315 Self {
317 config: config.clone(),
318 w_q: Tensor::from_vec(rand_normal_seeded(q_dim * hidden_size, seed, "w_q"), true),
319 w_k: Tensor::from_vec(
320 rand_normal_seeded(kv_hidden_size * hidden_size, seed, "w_k"),
321 true,
322 ),
323 w_v: Tensor::from_vec(
324 rand_normal_seeded(kv_hidden_size * hidden_size, seed, "w_v"),
325 true,
326 ),
327 w_o: Tensor::from_vec(rand_normal_seeded(hidden_size * q_dim, seed, "w_o"), true),
328 b_q,
329 b_k,
330 b_v,
331 q_norm: None,
332 k_norm: None,
333 }
334 }
335
336 pub fn from_params(
347 config: &TransformerConfig,
348 params: &HashMap<String, Tensor>,
349 prefix: &str,
350 ) -> Option<Self> {
351 let w_q = params.get(&format!("{prefix}.q_proj.weight"))?.clone();
352 let w_k = params.get(&format!("{prefix}.k_proj.weight"))?.clone();
353 let w_v = params.get(&format!("{prefix}.v_proj.weight"))?.clone();
354 let w_o = params.get(&format!("{prefix}.o_proj.weight"))?.clone();
355
356 let hidden = config.hidden_size;
357 let q_dim = config.q_dim();
358 let kv_hidden = config.num_kv_heads * config.head_dim();
359
360 let checks: &[(&str, &Tensor, usize)] = &[
363 ("q_proj", &w_q, q_dim * hidden),
364 ("k_proj", &w_k, kv_hidden * hidden),
365 ("v_proj", &w_v, kv_hidden * hidden),
366 ("o_proj", &w_o, hidden * q_dim),
367 ];
368 for &(name, tensor, expected) in checks {
369 if tensor.len() != expected {
370 eprintln!(
371 "[PMAT-331] {prefix}.{name}: shape mismatch — got {} elements, expected {expected}",
372 tensor.len()
373 );
374 return None;
375 }
376 }
377
378 let b_q = params.get(&format!("{prefix}.q_proj.bias")).cloned();
380 let b_k = params.get(&format!("{prefix}.k_proj.bias")).cloned();
381 let b_v = params.get(&format!("{prefix}.v_proj.bias")).cloned();
382
383 let q_norm = params.get(&format!("{prefix}.q_norm.weight")).cloned();
385 let k_norm = params.get(&format!("{prefix}.k_norm.weight")).cloned();
386
387 Some(Self { config: config.clone(), w_q, w_k, w_v, w_o, b_q, b_k, b_v, q_norm, k_norm })
388 }
389
390 pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
399 contract_pre_attention!(x.data());
400 let hidden_size = self.config.hidden_size;
401 let num_heads = self.config.num_attention_heads;
402 let num_kv_heads = self.config.num_kv_heads;
403 let head_dim = self.config.head_dim();
404 let q_dim = self.config.q_dim();
405 let kv_hidden_size = num_kv_heads * head_dim;
406
407 let mut q = matmul_nt(x, &self.w_q, seq_len, hidden_size, q_dim);
409 let mut k = matmul_nt(x, &self.w_k, seq_len, hidden_size, kv_hidden_size);
410 let mut v = matmul_nt(x, &self.w_v, seq_len, hidden_size, kv_hidden_size);
411
412 if let Some(ref b_q) = self.b_q {
414 q = add_bias(&q, b_q, seq_len);
415 }
416 if let Some(ref b_k) = self.b_k {
417 k = add_bias(&k, b_k, seq_len);
418 }
419 if let Some(ref b_v) = self.b_v {
420 v = add_bias(&v, b_v, seq_len);
421 }
422
423 if let Some(ref qn) = self.q_norm {
425 q = apply_qk_norm(&q, qn, seq_len, num_heads, head_dim);
426 }
427 if let Some(ref kn) = self.k_norm {
428 k = apply_qk_norm(&k, kn, seq_len, num_kv_heads, head_dim);
429 }
430
431 if self.config.rope_theta > 0.0 {
434 q = apply_rope(&q, seq_len, num_heads, head_dim, self.config.rope_theta);
435 k = apply_rope(&k, seq_len, num_kv_heads, head_dim, self.config.rope_theta);
436 }
437
438 let requires_grad = q.requires_grad() || k.requires_grad() || v.requires_grad();
439 let heads_per_kv = num_heads / num_kv_heads;
440
441 let q_data = q.data();
444 let q_slice = q_data.as_slice().expect("contiguous Q");
445 let k_data = k.data();
446 let k_slice = k_data.as_slice().expect("contiguous K");
447 let v_data = v.data();
448 let v_slice = v_data.as_slice().expect("contiguous V");
449
450 let mut head_q_tensors = Vec::with_capacity(num_heads);
452 let mut head_k_tensors = Vec::with_capacity(num_heads);
453 let mut head_v_tensors = Vec::with_capacity(num_heads);
454 let mut head_outputs = Vec::with_capacity(num_heads);
455 let mut head_kv_indices = Vec::with_capacity(num_heads);
456
457 for h in 0..num_heads {
458 let kv_h = h / heads_per_kv;
459 head_kv_indices.push(kv_h);
460
461 let mut q_head = Vec::with_capacity(seq_len * head_dim);
465 for s in 0..seq_len {
466 let start = s * q_dim + h * head_dim;
467 q_head.extend_from_slice(&q_slice[start..start + head_dim]);
468 }
469
470 let mut k_head = Vec::with_capacity(seq_len * head_dim);
471 for s in 0..seq_len {
472 let start = s * kv_hidden_size + kv_h * head_dim;
473 k_head.extend_from_slice(&k_slice[start..start + head_dim]);
474 }
475
476 let mut v_head = Vec::with_capacity(seq_len * head_dim);
477 for s in 0..seq_len {
478 let start = s * kv_hidden_size + kv_h * head_dim;
479 v_head.extend_from_slice(&v_slice[start..start + head_dim]);
480 }
481
482 let q_tensor = Tensor::from_vec(q_head, requires_grad);
483 let k_tensor = Tensor::from_vec(k_head, requires_grad);
484 let v_tensor = Tensor::from_vec(v_head, requires_grad);
485
486 let attn_out = crate::autograd::attention(
487 &q_tensor, &k_tensor, &v_tensor, seq_len, head_dim, seq_len, head_dim,
488 );
489
490 head_q_tensors.push(q_tensor);
491 head_k_tensors.push(k_tensor);
492 head_v_tensors.push(v_tensor);
493 head_outputs.push(attn_out);
494 }
495
496 let mut concat_output = vec![0.0; seq_len * q_dim];
498 for (h, head_out) in head_outputs.iter().enumerate() {
499 let hd = head_out.data();
500 let hdata = hd.as_slice().expect("contiguous attention output");
501 for s in 0..seq_len {
502 let src_base = s * head_dim;
503 let dst_base = s * q_dim + h * head_dim;
504 concat_output[dst_base..dst_base + head_dim]
505 .copy_from_slice(&hdata[src_base..src_base + head_dim]);
506 }
507 }
508
509 let mut concat_tensor = Tensor::from_vec(concat_output, requires_grad);
510
511 if requires_grad {
512 let backward_op = Rc::new(AttentionBlockBackward {
513 q: q.clone(),
514 k: k.clone(),
515 v: v.clone(),
516 head_q_tensors,
517 head_k_tensors,
518 head_v_tensors,
519 head_outputs,
520 head_kv_indices,
521 seq_len,
522 head_dim,
523 q_dim,
524 kv_hidden_size,
525 result_grad: concat_tensor.grad_cell(),
526 });
527 concat_tensor.set_backward_op(backward_op);
528 }
529
530 let result = matmul_nt(&concat_tensor, &self.w_o, seq_len, q_dim, hidden_size);
532 contract_post_attention!(result.data().as_slice().unwrap_or(&[]));
533 result
534 }
535
536 pub fn forward_with_lora(
549 &self,
550 x: &Tensor,
551 seq_len: usize,
552 lora_a_q: &Tensor,
553 lora_b_q: &Tensor,
555 lora_a_v: &Tensor,
556 lora_b_v: &Tensor,
557 lora_rank: usize,
558 lora_scale: f32,
559 ) -> Tensor {
560 contract_pre_lora_forward!();
561 let hidden_size = self.config.hidden_size;
562 let num_heads = self.config.num_attention_heads;
563 let num_kv_heads = self.config.num_kv_heads;
564 let head_dim = self.config.head_dim();
565 let q_dim = self.config.q_dim();
566 let kv_hidden_size = num_kv_heads * head_dim;
567
568 let q_base = matmul_nt(x, &self.w_q, seq_len, hidden_size, q_dim);
579 let q_mid = crate::autograd::matmul_nt(x, lora_a_q, seq_len, hidden_size, lora_rank);
580 let q_lora = crate::autograd::matmul_nt(&q_mid, lora_b_q, seq_len, lora_rank, q_dim);
581 let q = crate::autograd::add_scaled(&q_base, &q_lora, lora_scale);
582
583 let k = matmul_nt(x, &self.w_k, seq_len, hidden_size, kv_hidden_size);
585
586 let v_base = matmul_nt(x, &self.w_v, seq_len, hidden_size, kv_hidden_size);
588 let v_mid = crate::autograd::matmul_nt(x, lora_a_v, seq_len, hidden_size, lora_rank);
589 let v_lora =
590 crate::autograd::matmul_nt(&v_mid, lora_b_v, seq_len, lora_rank, kv_hidden_size);
591 let v = crate::autograd::add_scaled(&v_base, &v_lora, lora_scale);
592
593 let q = if let Some(ref qn) = self.q_norm {
595 apply_qk_norm(&q, qn, seq_len, num_heads, head_dim)
596 } else {
597 q
598 };
599 let k = if let Some(ref kn) = self.k_norm {
600 apply_qk_norm(&k, kn, seq_len, num_kv_heads, head_dim)
601 } else {
602 k
603 };
604
605 let (q, k) = if self.config.rope_theta > 0.0 {
608 (
609 apply_rope(&q, seq_len, num_heads, head_dim, self.config.rope_theta),
610 apply_rope(&k, seq_len, num_kv_heads, head_dim, self.config.rope_theta),
611 )
612 } else {
613 (q, k)
614 };
615
616 let requires_grad = q.requires_grad() || k.requires_grad() || v.requires_grad();
617 let heads_per_kv = num_heads / num_kv_heads;
618
619 let q_data = q.data();
621 let q_slice = q_data.as_slice().expect("contiguous Q");
622 let k_data = k.data();
623 let k_slice = k_data.as_slice().expect("contiguous K");
624 let v_data = v.data();
625 let v_slice = v_data.as_slice().expect("contiguous V");
626
627 let mut head_q_tensors = Vec::with_capacity(num_heads);
629 let mut head_k_tensors = Vec::with_capacity(num_heads);
630 let mut head_v_tensors = Vec::with_capacity(num_heads);
631 let mut head_outputs = Vec::with_capacity(num_heads);
632 let mut head_kv_indices = Vec::with_capacity(num_heads);
633
634 for h in 0..num_heads {
635 let kv_h = h / heads_per_kv;
636 head_kv_indices.push(kv_h);
637
638 let mut q_head = Vec::with_capacity(seq_len * head_dim);
640 for s in 0..seq_len {
641 let start = s * q_dim + h * head_dim;
642 q_head.extend_from_slice(&q_slice[start..start + head_dim]);
643 }
644
645 let mut k_head = Vec::with_capacity(seq_len * head_dim);
646 for s in 0..seq_len {
647 let start = s * kv_hidden_size + kv_h * head_dim;
648 k_head.extend_from_slice(&k_slice[start..start + head_dim]);
649 }
650
651 let mut v_head = Vec::with_capacity(seq_len * head_dim);
652 for s in 0..seq_len {
653 let start = s * kv_hidden_size + kv_h * head_dim;
654 v_head.extend_from_slice(&v_slice[start..start + head_dim]);
655 }
656
657 let q_tensor = Tensor::from_vec(q_head, requires_grad);
658 let k_tensor = Tensor::from_vec(k_head, requires_grad);
659 let v_tensor = Tensor::from_vec(v_head, requires_grad);
660
661 let attn_out = crate::autograd::attention(
662 &q_tensor, &k_tensor, &v_tensor, seq_len, head_dim, seq_len, head_dim,
663 );
664
665 head_q_tensors.push(q_tensor);
666 head_k_tensors.push(k_tensor);
667 head_v_tensors.push(v_tensor);
668 head_outputs.push(attn_out);
669 }
670
671 let mut concat_output = vec![0.0; seq_len * q_dim];
673 for (h, head_out) in head_outputs.iter().enumerate() {
674 let hd = head_out.data();
675 let hdata = hd.as_slice().expect("contiguous attention output");
676 for s in 0..seq_len {
677 let src_base = s * head_dim;
678 let dst_base = s * q_dim + h * head_dim;
679 concat_output[dst_base..dst_base + head_dim]
680 .copy_from_slice(&hdata[src_base..src_base + head_dim]);
681 }
682 }
683
684 let mut concat_tensor = Tensor::from_vec(concat_output, requires_grad);
685
686 if requires_grad {
687 let backward_op = Rc::new(AttentionBlockBackward {
688 q: q.clone(),
689 k: k.clone(),
690 v: v.clone(),
691 head_q_tensors,
692 head_k_tensors,
693 head_v_tensors,
694 head_outputs,
695 head_kv_indices,
696 seq_len,
697 head_dim,
698 q_dim,
699 kv_hidden_size,
700 result_grad: concat_tensor.grad_cell(),
701 });
702 concat_tensor.set_backward_op(backward_op);
703 }
704
705 let result = matmul_nt(&concat_tensor, &self.w_o, seq_len, q_dim, hidden_size);
707 contract_post_lora_forward!(result);
708 result
709 }
710
711 pub fn parameters(&self) -> Vec<&Tensor> {
713 let mut params = vec![&self.w_q, &self.w_k, &self.w_v, &self.w_o];
714 if let Some(ref b) = self.b_q {
715 params.push(b);
716 }
717 if let Some(ref b) = self.b_k {
718 params.push(b);
719 }
720 if let Some(ref b) = self.b_v {
721 params.push(b);
722 }
723 params
724 }
725
726 pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
728 let mut params = vec![&mut self.w_q, &mut self.w_k, &mut self.w_v, &mut self.w_o];
729 if let Some(ref mut b) = self.b_q {
730 params.push(b);
731 }
732 if let Some(ref mut b) = self.b_k {
733 params.push(b);
734 }
735 if let Some(ref mut b) = self.b_v {
736 params.push(b);
737 }
738 params
739 }
740
741 pub fn has_biases(&self) -> bool {
743 self.b_q.is_some()
744 }
745
746 pub fn named_parameters(&self, prefix: &str) -> Vec<(String, &Tensor)> {
748 let mut params = vec![
749 (format!("{prefix}.q_proj.weight"), &self.w_q),
750 (format!("{prefix}.k_proj.weight"), &self.w_k),
751 (format!("{prefix}.v_proj.weight"), &self.w_v),
752 (format!("{prefix}.o_proj.weight"), &self.w_o),
753 ];
754 if let Some(ref b) = self.b_q {
755 params.push((format!("{prefix}.q_proj.bias"), b));
756 }
757 if let Some(ref b) = self.b_k {
758 params.push((format!("{prefix}.k_proj.bias"), b));
759 }
760 if let Some(ref b) = self.b_v {
761 params.push((format!("{prefix}.v_proj.bias"), b));
762 }
763 params
764 }
765
766 pub fn set_named_parameter(&mut self, suffix: &str, value: Tensor) -> bool {
777 match suffix {
778 "self_attn.q_proj.weight" => {
779 self.w_q = value;
780 true
781 }
782 "self_attn.k_proj.weight" => {
783 self.w_k = value;
784 true
785 }
786 "self_attn.v_proj.weight" => {
787 self.w_v = value;
788 true
789 }
790 "self_attn.o_proj.weight" => {
791 self.w_o = value;
792 true
793 }
794 "self_attn.q_proj.bias" => {
795 if self.b_q.is_some() {
796 self.b_q = Some(value);
797 true
798 } else {
799 false
800 }
801 }
802 "self_attn.k_proj.bias" => {
803 if self.b_k.is_some() {
804 self.b_k = Some(value);
805 true
806 } else {
807 false
808 }
809 }
810 "self_attn.v_proj.bias" => {
811 if self.b_v.is_some() {
812 self.b_v = Some(value);
813 true
814 } else {
815 false
816 }
817 }
818 _ => false,
819 }
820 }
821}
822
823pub struct LoRAProjection {
828 pub base_weight: Tensor,
830 pub lora_a: Tensor,
832 pub lora_b: Tensor,
834 pub d_in: usize,
836 pub d_out: usize,
838 pub rank: usize,
840 pub scale: f32,
842}
843
844impl LoRAProjection {
845 pub fn new(base_weight: Tensor, d_in: usize, d_out: usize, rank: usize, alpha: f32) -> Self {
854 assert_eq!(base_weight.len(), d_in * d_out, "Base weight size mismatch");
855
856 let mut base_weight = base_weight;
858 base_weight.set_requires_grad(false);
859
860 let lora_a = Tensor::from_vec(
862 (0..d_in * rank).map(|i| (i as f32 * 0.123).sin() * 0.01).collect(),
863 true, );
865
866 let lora_b = Tensor::zeros(rank * d_out, true);
868
869 Self { base_weight, lora_a, lora_b, d_in, d_out, rank, scale: alpha / rank as f32 }
870 }
871
872 pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
883 let base_out = matmul(x, &self.base_weight, seq_len, self.d_in, self.d_out);
885
886 let lora_intermediate = matmul(x, &self.lora_a, seq_len, self.d_in, self.rank);
889
890 let lora_out = matmul(&lora_intermediate, &self.lora_b, seq_len, self.rank, self.d_out);
892
893 crate::autograd::add_scaled(&base_out, &lora_out, self.scale)
896 }
897
898 pub fn lora_params(&self) -> Vec<&Tensor> {
900 vec![&self.lora_a, &self.lora_b]
901 }
902
903 pub fn lora_params_mut(&mut self) -> Vec<&mut Tensor> {
905 vec![&mut self.lora_a, &mut self.lora_b]
906 }
907}
908
909pub struct MultiHeadAttentionWithLoRA {
913 pub config: TransformerConfig,
915 pub q_proj: LoRAProjection,
917 pub k_proj: LoRAProjection,
919 pub v_proj: LoRAProjection,
921 pub o_proj: LoRAProjection,
923}
924
925impl MultiHeadAttentionWithLoRA {
926 pub fn from_attention(attn: &MultiHeadAttention, rank: usize, alpha: f32) -> Self {
933 let hidden_size = attn.config.hidden_size;
934 let q_dim = attn.config.q_dim();
935 let kv_hidden_size = attn.config.num_kv_heads * attn.config.head_dim();
936
937 Self {
938 config: attn.config.clone(),
939 q_proj: LoRAProjection::new(attn.w_q.clone(), hidden_size, q_dim, rank, alpha),
940 k_proj: LoRAProjection::new(attn.w_k.clone(), hidden_size, kv_hidden_size, rank, alpha),
941 v_proj: LoRAProjection::new(attn.w_v.clone(), hidden_size, kv_hidden_size, rank, alpha),
942 o_proj: LoRAProjection::new(attn.w_o.clone(), q_dim, hidden_size, rank, alpha),
943 }
944 }
945
946 pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
950 let num_heads = self.config.num_attention_heads;
951 let num_kv_heads = self.config.num_kv_heads;
952 let head_dim = self.config.head_dim();
953 let q_dim = self.config.q_dim();
954 let kv_hidden_size = num_kv_heads * head_dim;
955
956 let q = self.q_proj.forward(x, seq_len);
958 let k = self.k_proj.forward(x, seq_len);
959 let v = self.v_proj.forward(x, seq_len);
960
961 let mut attn_outputs = Vec::with_capacity(num_heads * seq_len * head_dim);
963 let heads_per_kv = num_heads / num_kv_heads;
964
965 let q_data = q.data();
967 let q_slice = q_data.as_slice().expect("contiguous Q tensor");
968 let k_data = k.data();
969 let k_slice = k_data.as_slice().expect("contiguous K tensor");
970 let v_data = v.data();
971 let v_slice = v_data.as_slice().expect("contiguous V tensor");
972
973 for h in 0..num_heads {
974 let kv_h = h / heads_per_kv;
975
976 let mut q_head = Vec::with_capacity(seq_len * head_dim);
978 for s in 0..seq_len {
979 let start = s * q_dim + h * head_dim;
980 q_head.extend_from_slice(&q_slice[start..start + head_dim]);
981 }
982
983 let mut k_head = Vec::with_capacity(seq_len * head_dim);
984 for s in 0..seq_len {
985 let start = s * kv_hidden_size + kv_h * head_dim;
986 k_head.extend_from_slice(&k_slice[start..start + head_dim]);
987 }
988
989 let mut v_head = Vec::with_capacity(seq_len * head_dim);
990 for s in 0..seq_len {
991 let start = s * kv_hidden_size + kv_h * head_dim;
992 v_head.extend_from_slice(&v_slice[start..start + head_dim]);
993 }
994
995 let q_tensor = Tensor::from_vec(q_head, false);
997 let k_tensor = Tensor::from_vec(k_head, false);
998 let v_tensor = Tensor::from_vec(v_head, false);
999
1000 let attn_out = crate::autograd::attention(
1001 &q_tensor, &k_tensor, &v_tensor, seq_len, head_dim, seq_len, head_dim,
1002 );
1003
1004 attn_outputs.extend_from_slice(
1005 attn_out.data().as_slice().expect("contiguous attention output"),
1006 );
1007 }
1008
1009 let mut concat_output = vec![0.0; seq_len * q_dim];
1011 for h in 0..num_heads {
1012 for s in 0..seq_len {
1013 let src_idx = h * seq_len * head_dim + s * head_dim;
1014 let dst_idx = s * q_dim + h * head_dim;
1015 concat_output[dst_idx..dst_idx + head_dim]
1016 .copy_from_slice(&attn_outputs[src_idx..src_idx + head_dim]);
1017 }
1018 }
1019
1020 let concat_tensor = Tensor::from_vec(concat_output, true);
1021
1022 self.o_proj.forward(&concat_tensor, seq_len)
1024 }
1025
1026 pub fn lora_params(&self) -> Vec<&Tensor> {
1028 let mut params = Vec::new();
1029 params.extend(self.q_proj.lora_params());
1030 params.extend(self.k_proj.lora_params());
1031 params.extend(self.v_proj.lora_params());
1032 params.extend(self.o_proj.lora_params());
1033 params
1034 }
1035
1036 pub fn lora_params_mut(&mut self) -> Vec<&mut Tensor> {
1038 let mut params = Vec::new();
1039 params.extend(self.q_proj.lora_params_mut());
1040 params.extend(self.k_proj.lora_params_mut());
1041 params.extend(self.v_proj.lora_params_mut());
1042 params.extend(self.o_proj.lora_params_mut());
1043 params
1044 }
1045
1046 pub fn lora_param_count(&self) -> usize {
1048 let hidden = self.config.hidden_size;
1050 let kv_hidden = self.config.num_kv_heads * self.config.head_dim();
1051 let rank = self.q_proj.rank;
1052
1053 (hidden * rank + rank * hidden) + (hidden * rank + rank * kv_hidden) + (hidden * rank + rank * kv_hidden) + (hidden * rank + rank * hidden) }
1062}
1063
1064#[cfg(test)]
1065mod tests {
1066 use super::*;
1067
1068 #[test]
1069 fn test_multi_head_attention_tiny() {
1070 let config = TransformerConfig::tiny();
1071 let attn = MultiHeadAttention::new(&config);
1072 let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
1073 let output = attn.forward(&x, 2);
1074 assert_eq!(output.len(), 2 * config.hidden_size);
1075 }
1076
1077 #[test]
1078 fn test_multi_head_attention_parameters() {
1079 let config = TransformerConfig::tiny();
1080 let attn = MultiHeadAttention::new(&config);
1081 let params = attn.parameters();
1082 assert_eq!(params.len(), 4); }
1084
1085 #[test]
1086 fn test_attention_longer_sequence() {
1087 let config = TransformerConfig::tiny();
1088 let attn = MultiHeadAttention::new(&config);
1089 let x = Tensor::from_vec(vec![0.1; 8 * config.hidden_size], true);
1090 let output = attn.forward(&x, 8);
1091 assert_eq!(output.len(), 8 * config.hidden_size);
1092 }
1093
1094 #[test]
1095 fn test_attention_weight_sizes() {
1096 let config = TransformerConfig::tiny();
1097 let attn = MultiHeadAttention::new(&config);
1098 let kv_hidden = config.num_kv_heads * config.head_dim();
1099 assert_eq!(attn.w_q.len(), config.hidden_size * config.hidden_size);
1100 assert_eq!(attn.w_k.len(), config.hidden_size * kv_hidden);
1101 assert_eq!(attn.w_v.len(), config.hidden_size * kv_hidden);
1102 assert_eq!(attn.w_o.len(), config.hidden_size * config.hidden_size);
1103 }
1104
1105 #[test]
1106 fn test_multi_head_attention_from_params_success() {
1107 let config = TransformerConfig::tiny();
1108 let hidden_size = config.hidden_size;
1109 let kv_hidden_size = config.num_kv_heads * config.head_dim();
1110
1111 let mut params = HashMap::new();
1112 params.insert(
1113 "attn.q_proj.weight".to_string(),
1114 Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1115 );
1116 params.insert(
1117 "attn.k_proj.weight".to_string(),
1118 Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1119 );
1120 params.insert(
1121 "attn.v_proj.weight".to_string(),
1122 Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1123 );
1124 params.insert(
1125 "attn.o_proj.weight".to_string(),
1126 Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1127 );
1128
1129 let attn = MultiHeadAttention::from_params(&config, ¶ms, "attn");
1130 assert!(attn.is_some());
1131 let attn = attn.expect("operation should succeed");
1132 assert_eq!(attn.w_q.len(), hidden_size * hidden_size);
1133 }
1134
1135 #[test]
1136 fn test_multi_head_attention_from_params_missing_key() {
1137 let config = TransformerConfig::tiny();
1138 let hidden_size = config.hidden_size;
1139
1140 let mut params = HashMap::new();
1141 params.insert(
1142 "attn.q_proj.weight".to_string(),
1143 Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1144 );
1145 let attn = MultiHeadAttention::from_params(&config, ¶ms, "attn");
1148 assert!(attn.is_none());
1149 }
1150
1151 #[test]
1152 fn test_attention_projections_backward() {
1153 let config = TransformerConfig::tiny();
1156 let attn = MultiHeadAttention::new(&config);
1157 let hidden_size = config.hidden_size;
1158 let seq_len = 2;
1159
1160 let x = Tensor::from_vec(vec![0.1; seq_len * hidden_size], true);
1161
1162 let mut q = crate::autograd::matmul(&x, &attn.w_q, seq_len, hidden_size, hidden_size);
1164 let grad_out = ndarray::Array1::ones(seq_len * hidden_size);
1165 crate::autograd::backward(&mut q, Some(grad_out));
1166
1167 assert!(attn.w_q.grad().is_some());
1168 let grad_q = attn.w_q.grad().expect("gradient should be available");
1169 assert!(grad_q.iter().all(|&v| v.is_finite()));
1170 }
1171
1172 #[test]
1173 fn test_output_projection_backward() {
1174 let config = TransformerConfig::tiny();
1176 let attn = MultiHeadAttention::new(&config);
1177 let hidden_size = config.hidden_size;
1178 let seq_len = 2;
1179
1180 let concat_out = Tensor::from_vec(vec![0.1; seq_len * hidden_size], true);
1182
1183 let mut output =
1185 crate::autograd::matmul(&concat_out, &attn.w_o, seq_len, hidden_size, hidden_size);
1186
1187 let grad_out = ndarray::Array1::ones(seq_len * hidden_size);
1188 crate::autograd::backward(&mut output, Some(grad_out));
1189
1190 assert!(attn.w_o.grad().is_some());
1191 let grad_o = attn.w_o.grad().expect("gradient should be available");
1192 assert!(grad_o.iter().all(|&v| v.is_finite()));
1193 let sum: f32 = grad_o.iter().map(|v| v.abs()).sum();
1194 assert!(sum > 0.0, "Output projection gradient should not be all zero");
1195 }
1196
1197 #[test]
1203 #[ignore = "apply_rope() severs autograd chain — needs backward op (ENT-272)"]
1204 fn test_attention_full_forward_qkv_gradients() {
1205 let config = TransformerConfig::tiny();
1206 let attn = MultiHeadAttention::new(&config);
1207 let hidden_size = config.hidden_size;
1208 let seq_len = 3;
1209
1210 let x_data: Vec<f32> =
1213 (0..seq_len * hidden_size).map(|i| ((i as f32) * 0.17).sin() * 0.5).collect();
1214 let x = Tensor::from_vec(x_data, true);
1215 let mut output = attn.forward(&x, seq_len);
1216
1217 let grad_out = ndarray::Array1::ones(seq_len * hidden_size);
1218 crate::autograd::backward(&mut output, Some(grad_out));
1219
1220 for (name, param) in
1222 [("w_q", &attn.w_q), ("w_k", &attn.w_k), ("w_v", &attn.w_v), ("w_o", &attn.w_o)]
1223 {
1224 assert!(
1225 param.grad().is_some(),
1226 "ALB-038: {name} must have gradient after full attention forward"
1227 );
1228 let grad = param.grad().expect("gradient available");
1229 assert!(grad.iter().all(|&v| v.is_finite()), "ALB-038: {name} gradient must be finite");
1230 assert!(
1231 grad.iter().any(|&v| v.abs() > 1e-10),
1232 "ALB-038: {name} gradient must be non-zero"
1233 );
1234 }
1235
1236 assert!(x.grad().is_some(), "ALB-038: input x must have gradient");
1238 }
1239
/// Constructing a LoRA projection records its dimensions, rank and scale, and allocates
/// adapter matrices of `d_in * rank` and `rank * d_out` elements.
#[test]
fn test_lora_projection_new() {
    let (d_in, d_out, rank, alpha) = (32, 16, 4, 8.0);

    let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
    let lora = LoRAProjection::new(base_weight, d_in, d_out, rank, alpha);

    assert_eq!(lora.d_in, d_in);
    assert_eq!(lora.d_out, d_out);
    assert_eq!(lora.rank, rank);

    // alpha = 8 with rank = 4 is expected to give a scale of 2.0
    // (presumably scale = alpha / rank — confirm against LoRAProjection::new).
    assert!((lora.scale - 2.0).abs() < 1e-6);

    assert_eq!(lora.lora_a.len(), d_in * rank);
    assert_eq!(lora.lora_b.len(), rank * d_out);
}
1261
/// A LoRA forward pass maps `seq_len x d_in` inputs to `seq_len x d_out` finite outputs.
#[test]
fn test_lora_projection_forward() {
    let (d_in, d_out) = (32, 16);
    let (rank, alpha) = (4, 8.0);
    let seq_len = 2;

    let lora = LoRAProjection::new(
        Tensor::from_vec(vec![0.1; d_in * d_out], false),
        d_in,
        d_out,
        rank,
        alpha,
    );

    let x = Tensor::from_vec(vec![0.1; seq_len * d_in], false);
    let output = lora.forward(&x, seq_len);

    assert_eq!(output.len(), seq_len * d_out);
    assert!(output.data().iter().all(|&v| v.is_finite()));
}
1280
/// `lora_params()` exposes exactly two trainable tensors.
#[test]
fn test_lora_projection_params() {
    let (d_in, d_out, rank) = (32, 16, 4);
    let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
    let lora = LoRAProjection::new(base_weight, d_in, d_out, rank, 8.0);

    // Presumably the two adapter matrices (lora_a / lora_b).
    assert_eq!(lora.lora_params().len(), 2);
}
1293
/// The mutable accessor `lora_params_mut()` exposes the same two tensors as `lora_params()`.
#[test]
fn test_lora_projection_params_mut() {
    let (d_in, d_out, rank) = (32, 16, 4);
    let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
    let mut lora = LoRAProjection::new(base_weight, d_in, d_out, rank, 8.0);

    assert_eq!(lora.lora_params_mut().len(), 2);
}
1306
/// A base weight whose element count differs from `d_in * d_out` must be rejected with a panic.
#[test]
#[should_panic(expected = "Base weight size mismatch")]
fn test_lora_projection_size_mismatch() {
    let (d_in, d_out, rank) = (32, 16, 4);

    // One element more than a d_in x d_out matrix holds.
    let oversized = Tensor::from_vec(vec![0.1; d_in * d_out + 1], false);
    let _ = LoRAProjection::new(oversized, d_in, d_out, rank, 8.0);
}
1318
/// Wrapping an attention layer with LoRA gives each of the Q/K/V/O projections the requested rank.
#[test]
fn test_mha_with_lora_creation() {
    let config = TransformerConfig::tiny();
    let base_attn = MultiHeadAttention::new(&config);
    let (rank, alpha) = (4, 8.0);

    let wrapped = MultiHeadAttentionWithLoRA::from_attention(&base_attn, rank, alpha);

    assert_eq!(wrapped.q_proj.rank, rank);
    assert_eq!(wrapped.k_proj.rank, rank);
    assert_eq!(wrapped.v_proj.rank, rank);
    assert_eq!(wrapped.o_proj.rank, rank);
}
1337
/// A LoRA-wrapped attention forward pass preserves the `seq_len * hidden_size` shape and stays finite.
#[test]
fn test_mha_with_lora_forward() {
    let config = TransformerConfig::tiny();
    let base_attn = MultiHeadAttention::new(&config);
    let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&base_attn, 4, 8.0);

    let seq_len = 2;
    let expected_len = seq_len * config.hidden_size;
    let x = Tensor::from_vec(vec![0.1; expected_len], false);
    let output = lora_attn.forward(&x, seq_len);

    assert_eq!(output.len(), expected_len);
    assert!(output.data().iter().all(|&v| v.is_finite()));
}
1352
/// A LoRA-wrapped attention layer exposes eight trainable tensors.
#[test]
fn test_mha_with_lora_params() {
    let config = TransformerConfig::tiny();
    let base_attn = MultiHeadAttention::new(&config);
    let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&base_attn, 4, 8.0);

    // Presumably 4 projections (q/k/v/o) x 2 adapter matrices each.
    assert_eq!(lora_attn.lora_params().len(), 8);
}
1363
/// The mutable accessor `lora_params_mut()` on the wrapped layer also exposes eight tensors.
#[test]
fn test_mha_with_lora_params_mut() {
    let config = TransformerConfig::tiny();
    let base_attn = MultiHeadAttention::new(&config);
    let mut lora_attn = MultiHeadAttentionWithLoRA::from_attention(&base_attn, 4, 8.0);

    assert_eq!(lora_attn.lora_params_mut().len(), 8);
}
1373
/// `lora_param_count()` equals the sum of adapter sizes for Q/K/V/O. K and V project to
/// the (possibly smaller) KV width `num_kv_heads * head_dim`.
#[test]
fn test_mha_with_lora_param_count() {
    let config = TransformerConfig::tiny();
    let attn = MultiHeadAttention::new(&config);
    let rank = 4;
    let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, rank, 8.0);

    let hidden = config.hidden_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();

    // One adapter pair: A is d_in x rank, B is rank x d_out.
    let adapter = |d_in: usize, d_out: usize| d_in * rank + rank * d_out;
    let expected = adapter(hidden, hidden) // q_proj
        + adapter(hidden, kv_hidden) // k_proj
        + adapter(hidden, kv_hidden) // v_proj
        + adapter(hidden, hidden); // o_proj

    let param_count = lora_attn.lora_param_count();
    assert_eq!(param_count, expected);
    assert!(param_count > 0);
}
1394
/// The LoRA-wrapped layer handles a longer (8-token) sequence without shape drift.
#[test]
fn test_mha_with_lora_longer_sequence() {
    let config = TransformerConfig::tiny();
    let base_attn = MultiHeadAttention::new(&config);
    let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&base_attn, 4, 8.0);

    const SEQ_LEN: usize = 8;
    let x = Tensor::from_vec(vec![0.1; SEQ_LEN * config.hidden_size], false);

    assert_eq!(lora_attn.forward(&x, SEQ_LEN).len(), SEQ_LEN * config.hidden_size);
}
1407
/// Without biases, the plain attention layer exposes four mutable weight tensors (Q, K, V, O).
#[test]
fn test_parameters_mut() {
    let mut attn = MultiHeadAttention::new(&TransformerConfig::tiny());
    assert_eq!(attn.parameters_mut().len(), 4);
}
1416
/// FALSIFY-A1e: `from_params` must return `None` when `q_proj.weight` has the wrong
/// element count, even though every other projection is correctly shaped.
#[test]
fn falsify_a1e_from_params_rejects_wrong_shape_q_weight() {
    let config = TransformerConfig::tiny();
    let hidden_size = config.hidden_size;
    let kv_hidden_size = config.num_kv_heads * config.head_dim();

    // (name, element count). q_proj is deliberately 50 elements rather than
    // hidden_size * hidden_size, so it is the only possible reason for rejection.
    let mut params = HashMap::new();
    for (name, len) in [
        ("attn.q_proj.weight", 50),
        ("attn.k_proj.weight", hidden_size * kv_hidden_size),
        ("attn.v_proj.weight", hidden_size * kv_hidden_size),
        ("attn.o_proj.weight", hidden_size * hidden_size),
    ] {
        params.insert(name.to_string(), Tensor::from_vec(vec![0.1; len], true));
    }

    let attn = MultiHeadAttention::from_params(&config, &params, "attn");
    assert!(
        attn.is_none(),
        "FALSIFY-A1e: PMAT-331 fix — from_params MUST reject wrong-shape q_proj"
    );
}
1471
/// FALSIFY-A2e: with GQA (a single KV head), Q/O stay `hidden x hidden` while K/V shrink
/// to `hidden x (num_kv_heads * head_dim)`.
#[test]
fn falsify_a2e_gqa_init_correct_kv_dimensions() {
    let mut config = TransformerConfig::tiny();
    config.num_kv_heads = 1;
    let attn = MultiHeadAttention::new(&config);

    let hidden = config.hidden_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();

    let shape_checks = [
        (attn.w_q.len(), hidden * hidden, "FALSIFY-A2e: Q projection must be hidden*hidden"),
        (
            attn.w_k.len(),
            hidden * kv_hidden,
            "FALSIFY-A2e: K projection must use num_kv_heads, not num_heads",
        ),
        (
            attn.w_v.len(),
            hidden * kv_hidden,
            "FALSIFY-A2e: V projection must use num_kv_heads, not num_heads",
        ),
        (attn.w_o.len(), hidden * hidden, "FALSIFY-A2e: O projection must be hidden*hidden"),
    ];
    for (actual, expected, msg) in shape_checks {
        assert_eq!(actual, expected, "{msg}");
    }

    assert!(
        attn.w_k.len() < attn.w_q.len(),
        "FALSIFY-A2e: For GQA, K weight must be smaller than Q weight"
    );
}
1519
/// FALSIFY-A3e: a GQA forward pass projects back to the full hidden width, not the KV width.
#[test]
fn falsify_a3e_gqa_forward_correct_output_dims() {
    let mut config = TransformerConfig::tiny();
    config.num_kv_heads = 1;
    let attn = MultiHeadAttention::new(&config);

    let seq_len = 3;
    let expected = seq_len * config.hidden_size;
    let x = Tensor::from_vec(vec![0.1; expected], true);

    assert_eq!(
        attn.forward(&x, seq_len).len(),
        expected,
        "FALSIFY-A3e: GQA output must be seq_len * hidden_size, not seq_len * kv_hidden"
    );
}
1540
/// FALSIFY-A4e: freshly initialised Q/K/V/O weights must contain no NaN/Inf values and
/// must not be a constant fill; the test treats a constant fill as a degenerate weight.
#[test]
fn falsify_a4e_init_produces_valid_attention_weights() {
    let config = TransformerConfig::tiny();
    let attn = MultiHeadAttention::new(&config);

    let weights =
        [("w_q", &attn.w_q), ("w_k", &attn.w_k), ("w_v", &attn.w_v), ("w_o", &attn.w_o)];
    for (name, w) in weights {
        let data = w.data();
        let slice = data.as_slice().expect("data as slice");

        let (nan_count, inf_count) = slice.iter().fold((0usize, 0usize), |(nans, infs), v| {
            (nans + usize::from(v.is_nan()), infs + usize::from(v.is_infinite()))
        });
        assert_eq!(nan_count, 0, "FALSIFY-A4e: {name} init must not contain NaN");
        assert_eq!(inf_count, 0, "FALSIFY-A4e: {name} init must not contain Inf");

        // Value range in one pass. An empty slice leaves (+inf, -inf), which the
        // `.abs()` below turns into a passing +inf spread.
        let (min, max) = slice
            .iter()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &v| (lo.min(v), hi.max(v)));
        assert!(
            (max - min).abs() > 1e-6,
            "FALSIFY-A4e: {name} init values are constant ({min}..{max}) — degenerate weight"
        );
    }
}
1572
/// FALSIFY-A5e: a forward pass over a constant input yields no NaN or Inf values.
#[test]
fn falsify_a5e_forward_produces_finite_output() {
    let config = TransformerConfig::tiny();
    let attn = MultiHeadAttention::new(&config);
    let seq_len = 4;
    let input = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);

    let output = attn.forward(&input, seq_len);
    let values = output.data();
    let (nan_count, inf_count) = values.iter().fold((0usize, 0usize), |(nans, infs), v| {
        (nans + usize::from(v.is_nan()), infs + usize::from(v.is_infinite()))
    });

    assert_eq!(nan_count, 0, "FALSIFY-A5e: Attention output must not contain NaN");
    assert_eq!(inf_count, 0, "FALSIFY-A5e: Attention output must not contain Inf");
}
1591
/// GQ-001e: for MHA, GQA and MQA head layouts, the output is always `seq_len * hidden_size`.
#[test]
fn falsify_gq_001e_output_shape() {
    const SEQ_LEN: usize = 3;
    let cases = [(2, 2), (4, 2), (4, 1), (2, 1)];

    for (num_heads, num_kv_heads) in cases {
        let config = {
            let mut cfg = TransformerConfig::tiny();
            cfg.num_attention_heads = num_heads;
            cfg.num_kv_heads = num_kv_heads;
            cfg
        };
        let attn = MultiHeadAttention::new(&config);

        let expected_len = SEQ_LEN * config.hidden_size;
        let x = Tensor::from_vec(vec![0.1; expected_len], true);
        let output = attn.forward(&x, SEQ_LEN);

        assert_eq!(
            output.len(),
            expected_len,
            "FALSIFIED GQ-001e: output len mismatch for heads={num_heads},kv={num_kv_heads}"
        );
    }
}
1627
/// GQ-002e: when every query head has its own KV head (plain MHA), the grouped code path
/// must still produce finite outputs.
#[test]
fn falsify_gq_002e_mha_degeneration() {
    let config = TransformerConfig::tiny();
    // Precondition: the tiny config is plain MHA.
    assert_eq!(config.num_attention_heads, config.num_kv_heads);

    let attn = MultiHeadAttention::new(&config);
    let seq_len = 4;
    let input: Vec<f32> =
        (0..seq_len * config.hidden_size).map(|i| (i as f32 * 0.37).sin()).collect();
    let x = Tensor::from_vec(input, true);
    let output = attn.forward(&x, seq_len);

    let data = output.data();
    if let Some((i, v)) = data.iter().enumerate().find(|(_, v)| !v.is_finite()) {
        panic!("FALSIFIED GQ-002e: MHA output[{i}] = {v} (not finite)");
    }
}
1647
/// GQ-004e: every valid (num_heads, num_kv_heads) ratio must build an attention layer
/// whose forward pass runs and returns the full `seq_len * hidden_size` output.
///
/// Previously the forward result was discarded, so the test only checked for the
/// absence of a panic. It now also pins the output length, matching GQ-001e.
#[test]
fn falsify_gq_004e_head_divisibility() {
    let seq_len = 2;
    for (nh, nkv) in [(2, 1), (2, 2), (4, 1), (4, 2), (4, 4), (8, 2), (8, 4)] {
        // Guard on the test table itself: every case must be a valid grouping ratio.
        assert_eq!(nh % nkv, 0, "FALSIFIED GQ-004e: test config has invalid head ratio");

        let mut config = TransformerConfig::tiny();
        config.num_attention_heads = nh;
        config.num_kv_heads = nkv;
        let attn = MultiHeadAttention::new(&config);

        let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
        let output = attn.forward(&x, seq_len);
        assert_eq!(
            output.len(),
            seq_len * config.hidden_size,
            "FALSIFIED GQ-004e: output len mismatch for heads={nh},kv={nkv}"
        );
    }
}
1663
/// GQ-006e: the MQA boundary (four query heads sharing one KV head, hidden_size = 64)
/// must still produce a full-size, finite output.
#[test]
fn falsify_gq_006e_mqa_boundary() {
    let mut config = TransformerConfig::tiny();
    config.num_attention_heads = 4;
    config.num_kv_heads = 1;
    config.hidden_size = 64;
    let attn = MultiHeadAttention::new(&config);

    let seq_len = 3;
    let total = seq_len * config.hidden_size;
    let input: Vec<f32> = (0..total).map(|i| (i as f32 * 0.73).cos()).collect();
    let x = Tensor::from_vec(input, true);
    let output = attn.forward(&x, seq_len);

    assert_eq!(output.len(), total, "FALSIFIED GQ-006e: MQA output size wrong");

    let data = output.data();
    if let Some((i, v)) = data.iter().enumerate().find(|(_, v)| !v.is_finite()) {
        panic!("FALSIFIED GQ-006e: MQA output[{i}] = {v} (not finite)");
    }
}
1693
1694 mod gq_proptest_falsify {
1695 use super::*;
1696 use proptest::prelude::*;
1697
1698 proptest! {
1700 #![proptest_config(ProptestConfig::with_cases(50))]
1701
1702 #[test]
1703 fn falsify_gq_001e_prop_output_shape(
1704 config_idx in 0..4usize,
1705 seq_len in 2..=6usize,
1706 seed in 0..500u32,
1707 ) {
1708 let configs: [(usize, usize); 4] = [
1709 (2, 2), (2, 1), (4, 2), (4, 1),
1710 ];
1711 let (num_heads, num_kv_heads) = configs[config_idx];
1712 let mut config = TransformerConfig::tiny();
1713 config.num_attention_heads = num_heads;
1714 config.num_kv_heads = num_kv_heads;
1715
1716 let attn = MultiHeadAttention::new(&config);
1717 let data: Vec<f32> = (0..seq_len * config.hidden_size)
1718 .map(|i| ((i as f32 + seed as f32) * 0.37).sin())
1719 .collect();
1720 let x = Tensor::from_vec(data, true);
1721 let output = attn.forward(&x, seq_len);
1722
1723 prop_assert_eq!(
1724 output.len(),
1725 seq_len * config.hidden_size,
1726 "FALSIFIED GQ-001e-prop: output len mismatch"
1727 );
1728
1729 for v in output.data() {
1731 prop_assert!(
1732 v.is_finite(),
1733 "FALSIFIED GQ-001e-prop: non-finite output"
1734 );
1735 }
1736 }
1737 }
1738
1739 proptest! {
1741 #![proptest_config(ProptestConfig::with_cases(30))]
1742
1743 #[test]
1744 fn falsify_gq_006e_prop_mqa_boundary(
1745 seed in 0..500u32,
1746 seq_len in 2..=5usize,
1747 ) {
1748 let mut config = TransformerConfig::tiny();
1749 config.num_attention_heads = 4;
1750 config.num_kv_heads = 1;
1751 config.hidden_size = 64;
1752
1753 let attn = MultiHeadAttention::new(&config);
1754 let data: Vec<f32> = (0..seq_len * config.hidden_size)
1755 .map(|i| ((i as f32 + seed as f32) * 0.73).cos())
1756 .collect();
1757 let x = Tensor::from_vec(data, true);
1758 let output = attn.forward(&x, seq_len);
1759
1760 prop_assert_eq!(
1761 output.len(),
1762 seq_len * config.hidden_size,
1763 "FALSIFIED GQ-006e-prop: MQA output len mismatch"
1764 );
1765
1766 for v in output.data() {
1767 prop_assert!(
1768 v.is_finite(),
1769 "FALSIFIED GQ-006e-prop: non-finite MQA output"
1770 );
1771 }
1772 }
1773 }
1774 }
1775
/// Loading from a parameter map that holds the four projection weights plus q/k/v biases
/// (no o_proj bias) must succeed, report `has_biases()`, and expose seven parameters.
#[test]
fn test_attention_from_params_with_biases() {
    let config = TransformerConfig::tiny();
    let hidden_size = config.hidden_size;
    let kv_hidden_size = config.num_kv_heads * config.head_dim();

    // (name, element count, fill value)
    let mut params = HashMap::new();
    for (name, len, value) in [
        ("attn.q_proj.weight", hidden_size * hidden_size, 0.1),
        ("attn.k_proj.weight", hidden_size * kv_hidden_size, 0.1),
        ("attn.v_proj.weight", hidden_size * kv_hidden_size, 0.1),
        ("attn.o_proj.weight", hidden_size * hidden_size, 0.1),
        ("attn.q_proj.bias", hidden_size, 0.01),
        ("attn.k_proj.bias", kv_hidden_size, 0.01),
        ("attn.v_proj.bias", kv_hidden_size, 0.01),
    ] {
        params.insert(name.to_string(), Tensor::from_vec(vec![value; len], true));
    }

    let attn = MultiHeadAttention::from_params(&config, &params, "attn")
        .expect("should load with biases");
    assert!(attn.has_biases());
    // 4 projection weights + 3 biases.
    assert_eq!(attn.parameters().len(), 7);
}
1818
/// `named_parameters` on a bias-loaded layer lists all seven tensors, including the
/// q/k/v bias entries under their prefixed names.
#[test]
fn test_attention_named_parameters_with_biases() {
    let config = TransformerConfig::tiny();
    let hidden_size = config.hidden_size;
    let kv_hidden_size = config.num_kv_heads * config.head_dim();

    // (name, element count, fill value)
    let mut params = HashMap::new();
    for (name, len, value) in [
        ("attn.q_proj.weight", hidden_size * hidden_size, 0.1),
        ("attn.k_proj.weight", hidden_size * kv_hidden_size, 0.1),
        ("attn.v_proj.weight", hidden_size * kv_hidden_size, 0.1),
        ("attn.o_proj.weight", hidden_size * hidden_size, 0.1),
        ("attn.q_proj.bias", hidden_size, 0.01),
        ("attn.k_proj.bias", kv_hidden_size, 0.01),
        ("attn.v_proj.bias", kv_hidden_size, 0.01),
    ] {
        params.insert(name.to_string(), Tensor::from_vec(vec![value; len], true));
    }

    let attn = MultiHeadAttention::from_params(&config, &params, "attn").expect("should load");
    let named = attn.named_parameters("attn");
    assert_eq!(named.len(), 7);

    let names: Vec<&str> = named.iter().map(|(n, _)| n.as_str()).collect();
    for bias in ["attn.q_proj.bias", "attn.k_proj.bias", "attn.v_proj.bias"] {
        assert!(names.contains(&bias), "named_parameters is missing {bias}");
    }
}
1863
/// A forward pass through a layer loaded with relatively large (0.5) q/k/v biases keeps
/// the `seq_len * hidden_size` shape and produces only finite values.
#[test]
fn test_attention_forward_with_biases() {
    let config = TransformerConfig::tiny();
    let hidden_size = config.hidden_size;
    let kv_hidden_size = config.num_kv_heads * config.head_dim();

    // (name, element count, fill value)
    let mut params = HashMap::new();
    for (name, len, value) in [
        ("attn.q_proj.weight", hidden_size * hidden_size, 0.1),
        ("attn.k_proj.weight", hidden_size * kv_hidden_size, 0.1),
        ("attn.v_proj.weight", hidden_size * kv_hidden_size, 0.1),
        ("attn.o_proj.weight", hidden_size * hidden_size, 0.1),
        ("attn.q_proj.bias", hidden_size, 0.5),
        ("attn.k_proj.bias", kv_hidden_size, 0.5),
        ("attn.v_proj.bias", kv_hidden_size, 0.5),
    ] {
        params.insert(name.to_string(), Tensor::from_vec(vec![value; len], true));
    }

    let attn = MultiHeadAttention::from_params(&config, &params, "attn").expect("should load");

    let seq_len = 2;
    let x = Tensor::from_vec(vec![0.1; seq_len * hidden_size], false);
    let output = attn.forward(&x, seq_len);
    assert_eq!(output.len(), seq_len * hidden_size);
    assert!(output.data().iter().all(|v| v.is_finite()));
}
1904}