1use crate::autograd::{matmul, matmul_nt, BackwardOp};
6use crate::Tensor;
7use ndarray::Array1;
8use std::cell::RefCell;
9use std::collections::HashMap;
10use std::rc::Rc;
11
12use super::config::TransformerConfig;
13
14fn add_bias(x: &Tensor, bias: &Tensor, seq_len: usize) -> Tensor {
17 let xd = x.data();
18 let x_slice = xd.as_slice().expect("contiguous projection");
19 let bd = bias.data();
20 let b_slice = bd.as_slice().expect("contiguous bias");
21 let dim = b_slice.len();
22 let mut out = Vec::with_capacity(x_slice.len());
23 for s in 0..seq_len {
24 let base = s * dim;
25 for d in 0..dim {
26 out.push(x_slice[base + d] + b_slice[d]);
27 }
28 }
29 Tensor::from_vec(out, x.requires_grad())
30}
31
32fn apply_qk_norm(
37 x: &Tensor,
38 norm_weight: &Tensor,
39 seq_len: usize,
40 num_heads: usize,
41 head_dim: usize,
42) -> Tensor {
43 let xd = x.data();
44 let x_slice = xd.as_slice().expect("contiguous qk");
45 let wd = norm_weight.data();
46 let w_slice = wd.as_slice().expect("contiguous norm weight");
47 let total_dim = num_heads * head_dim;
48 let eps = 1e-6_f32;
49 let mut out = vec![0.0f32; seq_len * total_dim];
50
51 for s in 0..seq_len {
52 for h in 0..num_heads {
53 let offset = s * total_dim + h * head_dim;
54 let mut sum_sq = 0.0f32;
56 for d in 0..head_dim {
57 let v = x_slice[offset + d];
58 sum_sq += v * v;
59 }
60 let rms = (sum_sq / head_dim as f32 + eps).sqrt();
61 let inv_rms = 1.0 / rms;
62 for d in 0..head_dim {
63 out[offset + d] = x_slice[offset + d] * inv_rms * w_slice[d];
64 }
65 }
66 }
67
68 Tensor::from_vec(out, x.requires_grad())
69}
70
71fn apply_rope(
80 x: &Tensor,
81 seq_len: usize,
82 num_heads: usize,
83 head_dim: usize,
84 rope_theta: f32,
85) -> Tensor {
86 let xd = x.data();
87 let x_slice = xd.as_slice().expect("contiguous qk for rope");
88 let total_dim = num_heads * head_dim;
89 let half_dim = head_dim / 2;
90 let mut out = vec![0.0f32; seq_len * total_dim];
91
92 let inv_freq: Vec<f32> =
94 (0..half_dim).map(|i| 1.0 / rope_theta.powf(2.0 * i as f32 / head_dim as f32)).collect();
95
96 for pos in 0..seq_len {
97 for h in 0..num_heads {
98 let offset = pos * total_dim + h * head_dim;
99 for i in 0..half_dim {
100 let freq = pos as f32 * inv_freq[i];
101 let cos_f = freq.cos();
102 let sin_f = freq.sin();
103 let x_first = x_slice[offset + i];
105 let x_second = x_slice[offset + i + half_dim];
106 out[offset + i] = x_first * cos_f - x_second * sin_f;
108 out[offset + i + half_dim] = x_second * cos_f + x_first * sin_f;
109 }
110 }
111 }
112
113 let result = Tensor::from_vec(out, x.requires_grad());
114 contract_post_rope!(result.data().as_slice().unwrap_or(&[]));
115 result
116}
117
/// Backward op attached to the concatenated multi-head attention output.
///
/// It keeps the full Q/K/V projections together with the per-head tensors that were passed to
/// `crate::autograd::attention`. On `backward`, the upstream gradient is split per head, pushed
/// through each head's attention op, and the per-head input gradients are scattered back into
/// Q/K/V.
struct AttentionBlockBackward {
    /// Full Q projection (`seq_len * q_dim` elements); receives the scattered head gradients.
    q: Tensor,
    /// Full K projection (`seq_len * kv_hidden_size` elements).
    k: Tensor,
    /// Full V projection (`seq_len * kv_hidden_size` elements).
    v: Tensor,
    /// Per-head Q slices (`seq_len * head_dim` each), one per attention head.
    head_q_tensors: Vec<Tensor>,
    /// Per-head K slices, one per attention head; heads that share a KV head hold copies of
    /// the same columns.
    head_k_tensors: Vec<Tensor>,
    /// Per-head V slices, laid out like `head_k_tensors`.
    head_v_tensors: Vec<Tensor>,
    /// Output of `crate::autograd::attention` for each head.
    head_outputs: Vec<Tensor>,
    /// KV head index used by each attention head (`h / heads_per_kv` in the forward pass).
    head_kv_indices: Vec<usize>,
    /// Number of sequence positions (rows).
    seq_len: usize,
    /// Width of one head.
    head_dim: usize,
    /// Row width of `q` and of the concatenated output.
    q_dim: usize,
    /// Row width of `k` and `v`.
    kv_hidden_size: usize,
    /// Gradient cell of the concatenated output; `backward` returns early when it is empty.
    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
}
140
141impl BackwardOp for AttentionBlockBackward {
142 fn backward(&self) {
143 let Some(grad_out) = self.result_grad.borrow().as_ref().cloned() else { return };
144 let go = grad_out.as_slice().expect("grad contiguous");
145 let h = self.head_dim;
146
147 split_and_backward_heads(go, &self.head_outputs, self.seq_len, h, self.q_dim);
149
150 scatter_head_grads_q(&self.q, &self.head_q_tensors, self.seq_len, h, self.q_dim);
152 scatter_head_grads_kv(
153 &self.k,
154 &self.head_k_tensors,
155 &self.head_kv_indices,
156 self.seq_len,
157 h,
158 self.kv_hidden_size,
159 );
160 scatter_head_grads_kv(
161 &self.v,
162 &self.head_v_tensors,
163 &self.head_kv_indices,
164 self.seq_len,
165 h,
166 self.kv_hidden_size,
167 );
168
169 for proj in [&self.q, &self.k, &self.v] {
171 if let Some(op) = proj.backward_op() {
172 op.backward();
173 }
174 }
175 }
176}
177
178fn split_and_backward_heads(
180 go: &[f32],
181 head_outputs: &[Tensor],
182 seq_len: usize,
183 head_dim: usize,
184 q_dim: usize,
185) {
186 for (head_idx, head_out) in head_outputs.iter().enumerate() {
187 let mut grad_head = vec![0.0_f32; seq_len * head_dim];
188 for s in 0..seq_len {
189 let src_base = s * q_dim + head_idx * head_dim;
190 let dst_base = s * head_dim;
191 grad_head[dst_base..dst_base + head_dim]
192 .copy_from_slice(&go[src_base..src_base + head_dim]);
193 }
194 head_out.accumulate_grad(Array1::from(grad_head));
195 if let Some(op) = head_out.backward_op() {
196 op.backward();
197 }
198 }
199}
200
201fn scatter_head_grads_q(
203 q: &Tensor,
204 head_q_tensors: &[Tensor],
205 seq_len: usize,
206 head_dim: usize,
207 q_dim: usize,
208) {
209 if !q.requires_grad() {
210 return;
211 }
212 let mut grad_q = vec![0.0_f32; seq_len * q_dim];
213 for (head_idx, head_q) in head_q_tensors.iter().enumerate() {
214 if let Some(hgrad) = head_q.grad() {
215 let hg = hgrad.as_slice().expect("contiguous");
216 for s in 0..seq_len {
217 let src_base = s * head_dim;
218 let dst_base = s * q_dim + head_idx * head_dim;
219 for d in 0..head_dim {
220 grad_q[dst_base + d] += hg[src_base + d];
221 }
222 }
223 }
224 }
225 q.accumulate_grad(Array1::from(grad_q));
226}
227
228fn scatter_head_grads_kv(
230 target: &Tensor,
231 head_tensors: &[Tensor],
232 kv_indices: &[usize],
233 seq_len: usize,
234 head_dim: usize,
235 kv_hidden_size: usize,
236) {
237 if !target.requires_grad() {
238 return;
239 }
240 let mut grad = vec![0.0_f32; seq_len * kv_hidden_size];
241 for (head_idx, head_t) in head_tensors.iter().enumerate() {
242 let kv_h = kv_indices[head_idx];
243 if let Some(hgrad) = head_t.grad() {
244 let hg = hgrad.as_slice().expect("contiguous");
245 for s in 0..seq_len {
246 let src_base = s * head_dim;
247 let dst_base = s * kv_hidden_size + kv_h * head_dim;
248 for d in 0..head_dim {
249 grad[dst_base + d] += hg[src_base + d];
250 }
251 }
252 }
253 }
254 target.accumulate_grad(Array1::from(grad));
255}
256
/// Multi-head attention with grouped KV heads, optional Q/K/V projection biases, optional
/// per-head QK RMS-norm, and rotary position embeddings (applied when `config.rope_theta > 0`).
///
/// Projection weights are flat tensors consumed through `matmul_nt`, with element counts
/// `q_dim * hidden_size` (`w_q`), `num_kv_heads * head_dim * hidden_size` (`w_k`, `w_v`) and
/// `hidden_size * q_dim` (`w_o`). NOTE(review): this presumably means a row-major
/// `[out_features, in_features]` layout; confirm against `matmul_nt`.
pub struct MultiHeadAttention {
    /// Model configuration (hidden size, head counts, head_dim, rope_theta, ...).
    config: TransformerConfig,
    /// Query projection weight.
    pub w_q: Tensor,
    /// Key projection weight.
    pub w_k: Tensor,
    /// Value projection weight.
    pub w_v: Tensor,
    /// Output projection weight (maps the concatenated heads back to `hidden_size`).
    pub w_o: Tensor,
    /// Optional bias added to every row of the Q projection (`q_dim` elements).
    pub b_q: Option<Tensor>,
    /// Optional bias added to every row of the K projection (`num_kv_heads * head_dim` elements).
    pub b_k: Option<Tensor>,
    /// Optional bias added to every row of the V projection (`num_kv_heads * head_dim` elements).
    pub b_v: Option<Tensor>,
    /// Optional per-head RMS-norm weight for queries (`head_dim` elements, shared by all heads).
    pub q_norm: Option<Tensor>,
    /// Optional per-head RMS-norm weight for keys (`head_dim` elements, shared by all heads).
    pub k_norm: Option<Tensor>,
}
280
281impl MultiHeadAttention {
282 pub fn new(config: &TransformerConfig) -> Self {
284 use super::init::{get_init_seed, rand_normal_seeded};
285 let hidden_size = config.hidden_size;
286 let q_dim = config.q_dim();
287 let kv_hidden_size = config.num_kv_heads * config.head_dim();
288 let seed = get_init_seed();
289
290 Self {
292 config: config.clone(),
293 w_q: Tensor::from_vec(rand_normal_seeded(q_dim * hidden_size, seed, "w_q"), true),
294 w_k: Tensor::from_vec(
295 rand_normal_seeded(kv_hidden_size * hidden_size, seed, "w_k"),
296 true,
297 ),
298 w_v: Tensor::from_vec(
299 rand_normal_seeded(kv_hidden_size * hidden_size, seed, "w_v"),
300 true,
301 ),
302 w_o: Tensor::from_vec(rand_normal_seeded(hidden_size * q_dim, seed, "w_o"), true),
303 b_q: None,
304 b_k: None,
305 b_v: None,
306 q_norm: None,
307 k_norm: None,
308 }
309 }
310
311 pub fn from_params(
322 config: &TransformerConfig,
323 params: &HashMap<String, Tensor>,
324 prefix: &str,
325 ) -> Option<Self> {
326 let w_q = params.get(&format!("{prefix}.q_proj.weight"))?.clone();
327 let w_k = params.get(&format!("{prefix}.k_proj.weight"))?.clone();
328 let w_v = params.get(&format!("{prefix}.v_proj.weight"))?.clone();
329 let w_o = params.get(&format!("{prefix}.o_proj.weight"))?.clone();
330
331 let hidden = config.hidden_size;
332 let q_dim = config.q_dim();
333 let kv_hidden = config.num_kv_heads * config.head_dim();
334
335 let checks: &[(&str, &Tensor, usize)] = &[
338 ("q_proj", &w_q, q_dim * hidden),
339 ("k_proj", &w_k, kv_hidden * hidden),
340 ("v_proj", &w_v, kv_hidden * hidden),
341 ("o_proj", &w_o, hidden * q_dim),
342 ];
343 for &(name, tensor, expected) in checks {
344 if tensor.len() != expected {
345 eprintln!(
346 "[PMAT-331] {prefix}.{name}: shape mismatch — got {} elements, expected {expected}",
347 tensor.len()
348 );
349 return None;
350 }
351 }
352
353 let b_q = params.get(&format!("{prefix}.q_proj.bias")).cloned();
355 let b_k = params.get(&format!("{prefix}.k_proj.bias")).cloned();
356 let b_v = params.get(&format!("{prefix}.v_proj.bias")).cloned();
357
358 let q_norm = params.get(&format!("{prefix}.q_norm.weight")).cloned();
360 let k_norm = params.get(&format!("{prefix}.k_norm.weight")).cloned();
361
362 Some(Self { config: config.clone(), w_q, w_k, w_v, w_o, b_q, b_k, b_v, q_norm, k_norm })
363 }
364
365 pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
374 contract_pre_attention!(x.data());
375 let hidden_size = self.config.hidden_size;
376 let num_heads = self.config.num_attention_heads;
377 let num_kv_heads = self.config.num_kv_heads;
378 let head_dim = self.config.head_dim();
379 let q_dim = self.config.q_dim();
380 let kv_hidden_size = num_kv_heads * head_dim;
381
382 let mut q = matmul_nt(x, &self.w_q, seq_len, hidden_size, q_dim);
384 let mut k = matmul_nt(x, &self.w_k, seq_len, hidden_size, kv_hidden_size);
385 let mut v = matmul_nt(x, &self.w_v, seq_len, hidden_size, kv_hidden_size);
386
387 if let Some(ref b_q) = self.b_q {
389 q = add_bias(&q, b_q, seq_len);
390 }
391 if let Some(ref b_k) = self.b_k {
392 k = add_bias(&k, b_k, seq_len);
393 }
394 if let Some(ref b_v) = self.b_v {
395 v = add_bias(&v, b_v, seq_len);
396 }
397
398 if let Some(ref qn) = self.q_norm {
400 q = apply_qk_norm(&q, qn, seq_len, num_heads, head_dim);
401 }
402 if let Some(ref kn) = self.k_norm {
403 k = apply_qk_norm(&k, kn, seq_len, num_kv_heads, head_dim);
404 }
405
406 if self.config.rope_theta > 0.0 {
409 q = apply_rope(&q, seq_len, num_heads, head_dim, self.config.rope_theta);
410 k = apply_rope(&k, seq_len, num_kv_heads, head_dim, self.config.rope_theta);
411 }
412
413 let requires_grad = q.requires_grad() || k.requires_grad() || v.requires_grad();
414 let heads_per_kv = num_heads / num_kv_heads;
415
416 let q_data = q.data();
419 let q_slice = q_data.as_slice().expect("contiguous Q");
420 let k_data = k.data();
421 let k_slice = k_data.as_slice().expect("contiguous K");
422 let v_data = v.data();
423 let v_slice = v_data.as_slice().expect("contiguous V");
424
425 let mut head_q_tensors = Vec::with_capacity(num_heads);
427 let mut head_k_tensors = Vec::with_capacity(num_heads);
428 let mut head_v_tensors = Vec::with_capacity(num_heads);
429 let mut head_outputs = Vec::with_capacity(num_heads);
430 let mut head_kv_indices = Vec::with_capacity(num_heads);
431
432 for h in 0..num_heads {
433 let kv_h = h / heads_per_kv;
434 head_kv_indices.push(kv_h);
435
436 let mut q_head = Vec::with_capacity(seq_len * head_dim);
440 for s in 0..seq_len {
441 let start = s * q_dim + h * head_dim;
442 q_head.extend_from_slice(&q_slice[start..start + head_dim]);
443 }
444
445 let mut k_head = Vec::with_capacity(seq_len * head_dim);
446 for s in 0..seq_len {
447 let start = s * kv_hidden_size + kv_h * head_dim;
448 k_head.extend_from_slice(&k_slice[start..start + head_dim]);
449 }
450
451 let mut v_head = Vec::with_capacity(seq_len * head_dim);
452 for s in 0..seq_len {
453 let start = s * kv_hidden_size + kv_h * head_dim;
454 v_head.extend_from_slice(&v_slice[start..start + head_dim]);
455 }
456
457 let q_tensor = Tensor::from_vec(q_head, requires_grad);
458 let k_tensor = Tensor::from_vec(k_head, requires_grad);
459 let v_tensor = Tensor::from_vec(v_head, requires_grad);
460
461 let attn_out = crate::autograd::attention(
462 &q_tensor, &k_tensor, &v_tensor, seq_len, head_dim, seq_len, head_dim,
463 );
464
465 head_q_tensors.push(q_tensor);
466 head_k_tensors.push(k_tensor);
467 head_v_tensors.push(v_tensor);
468 head_outputs.push(attn_out);
469 }
470
471 let mut concat_output = vec![0.0; seq_len * q_dim];
473 for (h, head_out) in head_outputs.iter().enumerate() {
474 let hd = head_out.data();
475 let hdata = hd.as_slice().expect("contiguous attention output");
476 for s in 0..seq_len {
477 let src_base = s * head_dim;
478 let dst_base = s * q_dim + h * head_dim;
479 concat_output[dst_base..dst_base + head_dim]
480 .copy_from_slice(&hdata[src_base..src_base + head_dim]);
481 }
482 }
483
484 let mut concat_tensor = Tensor::from_vec(concat_output, requires_grad);
485
486 if requires_grad {
487 let backward_op = Rc::new(AttentionBlockBackward {
488 q: q.clone(),
489 k: k.clone(),
490 v: v.clone(),
491 head_q_tensors,
492 head_k_tensors,
493 head_v_tensors,
494 head_outputs,
495 head_kv_indices,
496 seq_len,
497 head_dim,
498 q_dim,
499 kv_hidden_size,
500 result_grad: concat_tensor.grad_cell(),
501 });
502 concat_tensor.set_backward_op(backward_op);
503 }
504
505 let result = matmul_nt(&concat_tensor, &self.w_o, seq_len, q_dim, hidden_size);
507 contract_post_attention!(result.data().as_slice().unwrap_or(&[]));
508 result
509 }
510
511 pub fn forward_with_lora(
524 &self,
525 x: &Tensor,
526 seq_len: usize,
527 lora_a_q: &Tensor,
528 lora_b_q: &Tensor,
530 lora_a_v: &Tensor,
531 lora_b_v: &Tensor,
532 lora_rank: usize,
533 lora_scale: f32,
534 ) -> Tensor {
535 contract_pre_lora_forward!();
536 let hidden_size = self.config.hidden_size;
537 let num_heads = self.config.num_attention_heads;
538 let num_kv_heads = self.config.num_kv_heads;
539 let head_dim = self.config.head_dim();
540 let q_dim = self.config.q_dim();
541 let kv_hidden_size = num_kv_heads * head_dim;
542
543 let q_base = matmul_nt(x, &self.w_q, seq_len, hidden_size, q_dim);
554 let q_mid = crate::autograd::matmul_nt(x, lora_a_q, seq_len, hidden_size, lora_rank);
555 let q_lora = crate::autograd::matmul_nt(&q_mid, lora_b_q, seq_len, lora_rank, q_dim);
556 let q = crate::autograd::add_scaled(&q_base, &q_lora, lora_scale);
557
558 let k = matmul_nt(x, &self.w_k, seq_len, hidden_size, kv_hidden_size);
560
561 let v_base = matmul_nt(x, &self.w_v, seq_len, hidden_size, kv_hidden_size);
563 let v_mid = crate::autograd::matmul_nt(x, lora_a_v, seq_len, hidden_size, lora_rank);
564 let v_lora =
565 crate::autograd::matmul_nt(&v_mid, lora_b_v, seq_len, lora_rank, kv_hidden_size);
566 let v = crate::autograd::add_scaled(&v_base, &v_lora, lora_scale);
567
568 let q = if let Some(ref qn) = self.q_norm {
570 apply_qk_norm(&q, qn, seq_len, num_heads, head_dim)
571 } else {
572 q
573 };
574 let k = if let Some(ref kn) = self.k_norm {
575 apply_qk_norm(&k, kn, seq_len, num_kv_heads, head_dim)
576 } else {
577 k
578 };
579
580 let (q, k) = if self.config.rope_theta > 0.0 {
583 (
584 apply_rope(&q, seq_len, num_heads, head_dim, self.config.rope_theta),
585 apply_rope(&k, seq_len, num_kv_heads, head_dim, self.config.rope_theta),
586 )
587 } else {
588 (q, k)
589 };
590
591 let requires_grad = q.requires_grad() || k.requires_grad() || v.requires_grad();
592 let heads_per_kv = num_heads / num_kv_heads;
593
594 let q_data = q.data();
596 let q_slice = q_data.as_slice().expect("contiguous Q");
597 let k_data = k.data();
598 let k_slice = k_data.as_slice().expect("contiguous K");
599 let v_data = v.data();
600 let v_slice = v_data.as_slice().expect("contiguous V");
601
602 let mut head_q_tensors = Vec::with_capacity(num_heads);
604 let mut head_k_tensors = Vec::with_capacity(num_heads);
605 let mut head_v_tensors = Vec::with_capacity(num_heads);
606 let mut head_outputs = Vec::with_capacity(num_heads);
607 let mut head_kv_indices = Vec::with_capacity(num_heads);
608
609 for h in 0..num_heads {
610 let kv_h = h / heads_per_kv;
611 head_kv_indices.push(kv_h);
612
613 let mut q_head = Vec::with_capacity(seq_len * head_dim);
615 for s in 0..seq_len {
616 let start = s * q_dim + h * head_dim;
617 q_head.extend_from_slice(&q_slice[start..start + head_dim]);
618 }
619
620 let mut k_head = Vec::with_capacity(seq_len * head_dim);
621 for s in 0..seq_len {
622 let start = s * kv_hidden_size + kv_h * head_dim;
623 k_head.extend_from_slice(&k_slice[start..start + head_dim]);
624 }
625
626 let mut v_head = Vec::with_capacity(seq_len * head_dim);
627 for s in 0..seq_len {
628 let start = s * kv_hidden_size + kv_h * head_dim;
629 v_head.extend_from_slice(&v_slice[start..start + head_dim]);
630 }
631
632 let q_tensor = Tensor::from_vec(q_head, requires_grad);
633 let k_tensor = Tensor::from_vec(k_head, requires_grad);
634 let v_tensor = Tensor::from_vec(v_head, requires_grad);
635
636 let attn_out = crate::autograd::attention(
637 &q_tensor, &k_tensor, &v_tensor, seq_len, head_dim, seq_len, head_dim,
638 );
639
640 head_q_tensors.push(q_tensor);
641 head_k_tensors.push(k_tensor);
642 head_v_tensors.push(v_tensor);
643 head_outputs.push(attn_out);
644 }
645
646 let mut concat_output = vec![0.0; seq_len * q_dim];
648 for (h, head_out) in head_outputs.iter().enumerate() {
649 let hd = head_out.data();
650 let hdata = hd.as_slice().expect("contiguous attention output");
651 for s in 0..seq_len {
652 let src_base = s * head_dim;
653 let dst_base = s * q_dim + h * head_dim;
654 concat_output[dst_base..dst_base + head_dim]
655 .copy_from_slice(&hdata[src_base..src_base + head_dim]);
656 }
657 }
658
659 let mut concat_tensor = Tensor::from_vec(concat_output, requires_grad);
660
661 if requires_grad {
662 let backward_op = Rc::new(AttentionBlockBackward {
663 q: q.clone(),
664 k: k.clone(),
665 v: v.clone(),
666 head_q_tensors,
667 head_k_tensors,
668 head_v_tensors,
669 head_outputs,
670 head_kv_indices,
671 seq_len,
672 head_dim,
673 q_dim,
674 kv_hidden_size,
675 result_grad: concat_tensor.grad_cell(),
676 });
677 concat_tensor.set_backward_op(backward_op);
678 }
679
680 let result = matmul_nt(&concat_tensor, &self.w_o, seq_len, q_dim, hidden_size);
682 contract_post_lora_forward!(result);
683 result
684 }
685
686 pub fn parameters(&self) -> Vec<&Tensor> {
688 let mut params = vec![&self.w_q, &self.w_k, &self.w_v, &self.w_o];
689 if let Some(ref b) = self.b_q {
690 params.push(b);
691 }
692 if let Some(ref b) = self.b_k {
693 params.push(b);
694 }
695 if let Some(ref b) = self.b_v {
696 params.push(b);
697 }
698 params
699 }
700
701 pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
703 let mut params = vec![&mut self.w_q, &mut self.w_k, &mut self.w_v, &mut self.w_o];
704 if let Some(ref mut b) = self.b_q {
705 params.push(b);
706 }
707 if let Some(ref mut b) = self.b_k {
708 params.push(b);
709 }
710 if let Some(ref mut b) = self.b_v {
711 params.push(b);
712 }
713 params
714 }
715
716 pub fn has_biases(&self) -> bool {
718 self.b_q.is_some()
719 }
720
721 pub fn named_parameters(&self, prefix: &str) -> Vec<(String, &Tensor)> {
723 let mut params = vec![
724 (format!("{prefix}.q_proj.weight"), &self.w_q),
725 (format!("{prefix}.k_proj.weight"), &self.w_k),
726 (format!("{prefix}.v_proj.weight"), &self.w_v),
727 (format!("{prefix}.o_proj.weight"), &self.w_o),
728 ];
729 if let Some(ref b) = self.b_q {
730 params.push((format!("{prefix}.q_proj.bias"), b));
731 }
732 if let Some(ref b) = self.b_k {
733 params.push((format!("{prefix}.k_proj.bias"), b));
734 }
735 if let Some(ref b) = self.b_v {
736 params.push((format!("{prefix}.v_proj.bias"), b));
737 }
738 params
739 }
740
741 pub fn set_named_parameter(&mut self, suffix: &str, value: Tensor) -> bool {
743 match suffix {
744 "self_attn.q_proj.weight" => {
745 self.w_q = value;
746 true
747 }
748 "self_attn.k_proj.weight" => {
749 self.w_k = value;
750 true
751 }
752 "self_attn.v_proj.weight" => {
753 self.w_v = value;
754 true
755 }
756 "self_attn.o_proj.weight" => {
757 self.w_o = value;
758 true
759 }
760 _ => false,
761 }
762 }
763}
764
/// A frozen linear projection augmented with a trainable low-rank (LoRA) adapter.
///
/// All matrices are used as right operands of `matmul` in `forward`, i.e. the projection
/// computes `x · W` with `W` shaped `[d_in, d_out]`.
pub struct LoRAProjection {
    /// Frozen base weight, `d_in * d_out` elements (`[d_in, d_out]` operand of `matmul`).
    pub base_weight: Tensor,
    /// Adapter down-projection `A`, `d_in * rank` elements (`[d_in, rank]`).
    pub lora_a: Tensor,
    /// Adapter up-projection `B`, `rank * d_out` elements (`[rank, d_out]`).
    pub lora_b: Tensor,
    /// Number of input features.
    pub d_in: usize,
    /// Number of output features.
    pub d_out: usize,
    /// Adapter rank.
    pub rank: usize,
    /// Multiplier applied to the adapter output (`alpha / rank` when built with `new`).
    pub scale: f32,
}
785
786impl LoRAProjection {
787 pub fn new(base_weight: Tensor, d_in: usize, d_out: usize, rank: usize, alpha: f32) -> Self {
796 assert_eq!(base_weight.len(), d_in * d_out, "Base weight size mismatch");
797
798 let mut base_weight = base_weight;
800 base_weight.set_requires_grad(false);
801
802 let lora_a = Tensor::from_vec(
804 (0..d_in * rank).map(|i| (i as f32 * 0.123).sin() * 0.01).collect(),
805 true, );
807
808 let lora_b = Tensor::zeros(rank * d_out, true);
810
811 Self { base_weight, lora_a, lora_b, d_in, d_out, rank, scale: alpha / rank as f32 }
812 }
813
814 pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
825 let base_out = matmul(x, &self.base_weight, seq_len, self.d_in, self.d_out);
827
828 let lora_intermediate = matmul(x, &self.lora_a, seq_len, self.d_in, self.rank);
831
832 let lora_out = matmul(&lora_intermediate, &self.lora_b, seq_len, self.rank, self.d_out);
834
835 crate::autograd::add_scaled(&base_out, &lora_out, self.scale)
838 }
839
840 pub fn lora_params(&self) -> Vec<&Tensor> {
842 vec![&self.lora_a, &self.lora_b]
843 }
844
845 pub fn lora_params_mut(&mut self) -> Vec<&mut Tensor> {
847 vec![&mut self.lora_a, &mut self.lora_b]
848 }
849}
850
/// Attention block whose Q/K/V/O projections are each a [`LoRAProjection`] (frozen base
/// weight plus a trainable low-rank adapter).
pub struct MultiHeadAttentionWithLoRA {
    /// Model configuration (head counts, head_dim, rope_theta, ...).
    pub config: TransformerConfig,
    /// Query projection (`hidden_size -> q_dim` when built by `from_attention`).
    pub q_proj: LoRAProjection,
    /// Key projection (`hidden_size -> num_kv_heads * head_dim` when built by `from_attention`).
    pub k_proj: LoRAProjection,
    /// Value projection (`hidden_size -> num_kv_heads * head_dim` when built by `from_attention`).
    pub v_proj: LoRAProjection,
    /// Output projection (`q_dim -> hidden_size` when built by `from_attention`).
    pub o_proj: LoRAProjection,
}
866
867impl MultiHeadAttentionWithLoRA {
868 pub fn from_attention(attn: &MultiHeadAttention, rank: usize, alpha: f32) -> Self {
875 let hidden_size = attn.config.hidden_size;
876 let q_dim = attn.config.q_dim();
877 let kv_hidden_size = attn.config.num_kv_heads * attn.config.head_dim();
878
879 Self {
880 config: attn.config.clone(),
881 q_proj: LoRAProjection::new(attn.w_q.clone(), hidden_size, q_dim, rank, alpha),
882 k_proj: LoRAProjection::new(attn.w_k.clone(), hidden_size, kv_hidden_size, rank, alpha),
883 v_proj: LoRAProjection::new(attn.w_v.clone(), hidden_size, kv_hidden_size, rank, alpha),
884 o_proj: LoRAProjection::new(attn.w_o.clone(), q_dim, hidden_size, rank, alpha),
885 }
886 }
887
888 pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
892 let num_heads = self.config.num_attention_heads;
893 let num_kv_heads = self.config.num_kv_heads;
894 let head_dim = self.config.head_dim();
895 let q_dim = self.config.q_dim();
896 let kv_hidden_size = num_kv_heads * head_dim;
897
898 let q = self.q_proj.forward(x, seq_len);
900 let k = self.k_proj.forward(x, seq_len);
901 let v = self.v_proj.forward(x, seq_len);
902
903 let mut attn_outputs = Vec::with_capacity(num_heads * seq_len * head_dim);
905 let heads_per_kv = num_heads / num_kv_heads;
906
907 let q_data = q.data();
909 let q_slice = q_data.as_slice().expect("contiguous Q tensor");
910 let k_data = k.data();
911 let k_slice = k_data.as_slice().expect("contiguous K tensor");
912 let v_data = v.data();
913 let v_slice = v_data.as_slice().expect("contiguous V tensor");
914
915 for h in 0..num_heads {
916 let kv_h = h / heads_per_kv;
917
918 let mut q_head = Vec::with_capacity(seq_len * head_dim);
920 for s in 0..seq_len {
921 let start = s * q_dim + h * head_dim;
922 q_head.extend_from_slice(&q_slice[start..start + head_dim]);
923 }
924
925 let mut k_head = Vec::with_capacity(seq_len * head_dim);
926 for s in 0..seq_len {
927 let start = s * kv_hidden_size + kv_h * head_dim;
928 k_head.extend_from_slice(&k_slice[start..start + head_dim]);
929 }
930
931 let mut v_head = Vec::with_capacity(seq_len * head_dim);
932 for s in 0..seq_len {
933 let start = s * kv_hidden_size + kv_h * head_dim;
934 v_head.extend_from_slice(&v_slice[start..start + head_dim]);
935 }
936
937 let q_tensor = Tensor::from_vec(q_head, false);
939 let k_tensor = Tensor::from_vec(k_head, false);
940 let v_tensor = Tensor::from_vec(v_head, false);
941
942 let attn_out = crate::autograd::attention(
943 &q_tensor, &k_tensor, &v_tensor, seq_len, head_dim, seq_len, head_dim,
944 );
945
946 attn_outputs.extend_from_slice(
947 attn_out.data().as_slice().expect("contiguous attention output"),
948 );
949 }
950
951 let mut concat_output = vec![0.0; seq_len * q_dim];
953 for h in 0..num_heads {
954 for s in 0..seq_len {
955 let src_idx = h * seq_len * head_dim + s * head_dim;
956 let dst_idx = s * q_dim + h * head_dim;
957 concat_output[dst_idx..dst_idx + head_dim]
958 .copy_from_slice(&attn_outputs[src_idx..src_idx + head_dim]);
959 }
960 }
961
962 let concat_tensor = Tensor::from_vec(concat_output, true);
963
964 self.o_proj.forward(&concat_tensor, seq_len)
966 }
967
968 pub fn lora_params(&self) -> Vec<&Tensor> {
970 let mut params = Vec::new();
971 params.extend(self.q_proj.lora_params());
972 params.extend(self.k_proj.lora_params());
973 params.extend(self.v_proj.lora_params());
974 params.extend(self.o_proj.lora_params());
975 params
976 }
977
978 pub fn lora_params_mut(&mut self) -> Vec<&mut Tensor> {
980 let mut params = Vec::new();
981 params.extend(self.q_proj.lora_params_mut());
982 params.extend(self.k_proj.lora_params_mut());
983 params.extend(self.v_proj.lora_params_mut());
984 params.extend(self.o_proj.lora_params_mut());
985 params
986 }
987
988 pub fn lora_param_count(&self) -> usize {
990 let hidden = self.config.hidden_size;
992 let kv_hidden = self.config.num_kv_heads * self.config.head_dim();
993 let rank = self.q_proj.rank;
994
995 (hidden * rank + rank * hidden) + (hidden * rank + rank * kv_hidden) + (hidden * rank + rank * kv_hidden) + (hidden * rank + rank * hidden) }
1004}
1005
1006#[cfg(test)]
1007mod tests {
1008 use super::*;
1009
1010 #[test]
1011 fn test_multi_head_attention_tiny() {
1012 let config = TransformerConfig::tiny();
1013 let attn = MultiHeadAttention::new(&config);
1014 let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
1015 let output = attn.forward(&x, 2);
1016 assert_eq!(output.len(), 2 * config.hidden_size);
1017 }
1018
1019 #[test]
1020 fn test_multi_head_attention_parameters() {
1021 let config = TransformerConfig::tiny();
1022 let attn = MultiHeadAttention::new(&config);
1023 let params = attn.parameters();
1024 assert_eq!(params.len(), 4); }
1026
1027 #[test]
1028 fn test_attention_longer_sequence() {
1029 let config = TransformerConfig::tiny();
1030 let attn = MultiHeadAttention::new(&config);
1031 let x = Tensor::from_vec(vec![0.1; 8 * config.hidden_size], true);
1032 let output = attn.forward(&x, 8);
1033 assert_eq!(output.len(), 8 * config.hidden_size);
1034 }
1035
1036 #[test]
1037 fn test_attention_weight_sizes() {
1038 let config = TransformerConfig::tiny();
1039 let attn = MultiHeadAttention::new(&config);
1040 let kv_hidden = config.num_kv_heads * config.head_dim();
1041 assert_eq!(attn.w_q.len(), config.hidden_size * config.hidden_size);
1042 assert_eq!(attn.w_k.len(), config.hidden_size * kv_hidden);
1043 assert_eq!(attn.w_v.len(), config.hidden_size * kv_hidden);
1044 assert_eq!(attn.w_o.len(), config.hidden_size * config.hidden_size);
1045 }
1046
1047 #[test]
1048 fn test_multi_head_attention_from_params_success() {
1049 let config = TransformerConfig::tiny();
1050 let hidden_size = config.hidden_size;
1051 let kv_hidden_size = config.num_kv_heads * config.head_dim();
1052
1053 let mut params = HashMap::new();
1054 params.insert(
1055 "attn.q_proj.weight".to_string(),
1056 Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1057 );
1058 params.insert(
1059 "attn.k_proj.weight".to_string(),
1060 Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1061 );
1062 params.insert(
1063 "attn.v_proj.weight".to_string(),
1064 Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1065 );
1066 params.insert(
1067 "attn.o_proj.weight".to_string(),
1068 Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1069 );
1070
1071 let attn = MultiHeadAttention::from_params(&config, ¶ms, "attn");
1072 assert!(attn.is_some());
1073 let attn = attn.expect("operation should succeed");
1074 assert_eq!(attn.w_q.len(), hidden_size * hidden_size);
1075 }
1076
1077 #[test]
1078 fn test_multi_head_attention_from_params_missing_key() {
1079 let config = TransformerConfig::tiny();
1080 let hidden_size = config.hidden_size;
1081
1082 let mut params = HashMap::new();
1083 params.insert(
1084 "attn.q_proj.weight".to_string(),
1085 Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1086 );
1087 let attn = MultiHeadAttention::from_params(&config, ¶ms, "attn");
1090 assert!(attn.is_none());
1091 }
1092
1093 #[test]
1094 fn test_attention_projections_backward() {
1095 let config = TransformerConfig::tiny();
1098 let attn = MultiHeadAttention::new(&config);
1099 let hidden_size = config.hidden_size;
1100 let seq_len = 2;
1101
1102 let x = Tensor::from_vec(vec![0.1; seq_len * hidden_size], true);
1103
1104 let mut q = crate::autograd::matmul(&x, &attn.w_q, seq_len, hidden_size, hidden_size);
1106 let grad_out = ndarray::Array1::ones(seq_len * hidden_size);
1107 crate::autograd::backward(&mut q, Some(grad_out));
1108
1109 assert!(attn.w_q.grad().is_some());
1110 let grad_q = attn.w_q.grad().expect("gradient should be available");
1111 assert!(grad_q.iter().all(|&v| v.is_finite()));
1112 }
1113
1114 #[test]
1115 fn test_output_projection_backward() {
1116 let config = TransformerConfig::tiny();
1118 let attn = MultiHeadAttention::new(&config);
1119 let hidden_size = config.hidden_size;
1120 let seq_len = 2;
1121
1122 let concat_out = Tensor::from_vec(vec![0.1; seq_len * hidden_size], true);
1124
1125 let mut output =
1127 crate::autograd::matmul(&concat_out, &attn.w_o, seq_len, hidden_size, hidden_size);
1128
1129 let grad_out = ndarray::Array1::ones(seq_len * hidden_size);
1130 crate::autograd::backward(&mut output, Some(grad_out));
1131
1132 assert!(attn.w_o.grad().is_some());
1133 let grad_o = attn.w_o.grad().expect("gradient should be available");
1134 assert!(grad_o.iter().all(|&v| v.is_finite()));
1135 let sum: f32 = grad_o.iter().map(|v| v.abs()).sum();
1136 assert!(sum > 0.0, "Output projection gradient should not be all zero");
1137 }
1138
1139 #[test]
1145 #[ignore = "apply_rope() severs autograd chain — needs backward op (ENT-272)"]
1146 fn test_attention_full_forward_qkv_gradients() {
1147 let config = TransformerConfig::tiny();
1148 let attn = MultiHeadAttention::new(&config);
1149 let hidden_size = config.hidden_size;
1150 let seq_len = 3;
1151
1152 let x_data: Vec<f32> =
1155 (0..seq_len * hidden_size).map(|i| ((i as f32) * 0.17).sin() * 0.5).collect();
1156 let x = Tensor::from_vec(x_data, true);
1157 let mut output = attn.forward(&x, seq_len);
1158
1159 let grad_out = ndarray::Array1::ones(seq_len * hidden_size);
1160 crate::autograd::backward(&mut output, Some(grad_out));
1161
1162 for (name, param) in
1164 [("w_q", &attn.w_q), ("w_k", &attn.w_k), ("w_v", &attn.w_v), ("w_o", &attn.w_o)]
1165 {
1166 assert!(
1167 param.grad().is_some(),
1168 "ALB-038: {name} must have gradient after full attention forward"
1169 );
1170 let grad = param.grad().expect("gradient available");
1171 assert!(grad.iter().all(|&v| v.is_finite()), "ALB-038: {name} gradient must be finite");
1172 assert!(
1173 grad.iter().any(|&v| v.abs() > 1e-10),
1174 "ALB-038: {name} gradient must be non-zero"
1175 );
1176 }
1177
1178 assert!(x.grad().is_some(), "ALB-038: input x must have gradient");
1180 }
1181
1182 #[test]
1187 fn test_lora_projection_new() {
1188 let d_in = 32;
1189 let d_out = 16;
1190 let rank = 4;
1191 let alpha = 8.0;
1192
1193 let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
1194 let lora = LoRAProjection::new(base_weight, d_in, d_out, rank, alpha);
1195
1196 assert_eq!(lora.d_in, d_in);
1197 assert_eq!(lora.d_out, d_out);
1198 assert_eq!(lora.rank, rank);
1199 assert!((lora.scale - 2.0).abs() < 1e-6); assert_eq!(lora.lora_a.len(), d_in * rank);
1201 assert_eq!(lora.lora_b.len(), rank * d_out);
1202 }
1203
1204 #[test]
1205 fn test_lora_projection_forward() {
1206 let d_in = 32;
1207 let d_out = 16;
1208 let rank = 4;
1209 let alpha = 8.0;
1210 let seq_len = 2;
1211
1212 let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
1213 let lora = LoRAProjection::new(base_weight, d_in, d_out, rank, alpha);
1214
1215 let x = Tensor::from_vec(vec![0.1; seq_len * d_in], false);
1216 let output = lora.forward(&x, seq_len);
1217
1218 assert_eq!(output.len(), seq_len * d_out);
1219 assert!(output.data().iter().all(|&v| v.is_finite()));
1221 }
1222
1223 #[test]
1224 fn test_lora_projection_params() {
1225 let d_in = 32;
1226 let d_out = 16;
1227 let rank = 4;
1228
1229 let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
1230 let lora = LoRAProjection::new(base_weight, d_in, d_out, rank, 8.0);
1231
1232 let params = lora.lora_params();
1233 assert_eq!(params.len(), 2); }
1235
1236 #[test]
1237 fn test_lora_projection_params_mut() {
1238 let d_in = 32;
1239 let d_out = 16;
1240 let rank = 4;
1241
1242 let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out], false);
1243 let mut lora = LoRAProjection::new(base_weight, d_in, d_out, rank, 8.0);
1244
1245 let params = lora.lora_params_mut();
1246 assert_eq!(params.len(), 2);
1247 }
1248
1249 #[test]
1250 #[should_panic(expected = "Base weight size mismatch")]
1251 fn test_lora_projection_size_mismatch() {
1252 let d_in = 32;
1253 let d_out = 16;
1254 let rank = 4;
1255
1256 let base_weight = Tensor::from_vec(vec![0.1; d_in * d_out + 1], false);
1258 let _ = LoRAProjection::new(base_weight, d_in, d_out, rank, 8.0);
1259 }
1260
1261 #[test]
1266 fn test_mha_with_lora_creation() {
1267 let config = TransformerConfig::tiny();
1268 let attn = MultiHeadAttention::new(&config);
1269 let rank = 4;
1270 let alpha = 8.0;
1271
1272 let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, rank, alpha);
1273
1274 assert_eq!(lora_attn.q_proj.rank, rank);
1275 assert_eq!(lora_attn.k_proj.rank, rank);
1276 assert_eq!(lora_attn.v_proj.rank, rank);
1277 assert_eq!(lora_attn.o_proj.rank, rank);
1278 }
1279
1280 #[test]
1281 fn test_mha_with_lora_forward() {
1282 let config = TransformerConfig::tiny();
1283 let attn = MultiHeadAttention::new(&config);
1284 let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
1285
1286 let seq_len = 2;
1287 let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], false);
1288 let output = lora_attn.forward(&x, seq_len);
1289
1290 assert_eq!(output.len(), seq_len * config.hidden_size);
1291 assert!(output.data().iter().all(|&v| v.is_finite()));
1293 }
1294
1295 #[test]
1296 fn test_mha_with_lora_params() {
1297 let config = TransformerConfig::tiny();
1298 let attn = MultiHeadAttention::new(&config);
1299 let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
1300
1301 let params = lora_attn.lora_params();
1302 assert_eq!(params.len(), 8);
1304 }
1305
1306 #[test]
1307 fn test_mha_with_lora_params_mut() {
1308 let config = TransformerConfig::tiny();
1309 let attn = MultiHeadAttention::new(&config);
1310 let mut lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
1311
1312 let params = lora_attn.lora_params_mut();
1313 assert_eq!(params.len(), 8);
1314 }
1315
1316 #[test]
1317 fn test_mha_with_lora_param_count() {
1318 let config = TransformerConfig::tiny();
1319 let attn = MultiHeadAttention::new(&config);
1320 let rank = 4;
1321 let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, rank, 8.0);
1322
1323 let param_count = lora_attn.lora_param_count();
1324
1325 let hidden = config.hidden_size;
1327 let kv_hidden = config.num_kv_heads * config.head_dim();
1328 let expected = (hidden * rank + rank * hidden) + (hidden * rank + rank * kv_hidden) + (hidden * rank + rank * kv_hidden) + (hidden * rank + rank * hidden); assert_eq!(param_count, expected);
1334 assert!(param_count > 0);
1335 }
1336
1337 #[test]
1338 fn test_mha_with_lora_longer_sequence() {
1339 let config = TransformerConfig::tiny();
1340 let attn = MultiHeadAttention::new(&config);
1341 let lora_attn = MultiHeadAttentionWithLoRA::from_attention(&attn, 4, 8.0);
1342
1343 let seq_len = 8;
1344 let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], false);
1345 let output = lora_attn.forward(&x, seq_len);
1346
1347 assert_eq!(output.len(), seq_len * config.hidden_size);
1348 }
1349
1350 #[test]
1351 fn test_parameters_mut() {
1352 let config = TransformerConfig::tiny();
1353 let mut attn = MultiHeadAttention::new(&config);
1354
1355 let params = attn.parameters_mut();
1356 assert_eq!(params.len(), 4);
1357 }
1358
1359 #[test]
1384 fn falsify_a1e_from_params_rejects_wrong_shape_q_weight() {
1385 let config = TransformerConfig::tiny();
1386 let hidden_size = config.hidden_size;
1387 let kv_hidden_size = config.num_kv_heads * config.head_dim();
1388
1389 let mut params = HashMap::new();
1390 params.insert("attn.q_proj.weight".to_string(), Tensor::from_vec(vec![0.1; 50], true));
1392 params.insert(
1394 "attn.k_proj.weight".to_string(),
1395 Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1396 );
1397 params.insert(
1398 "attn.v_proj.weight".to_string(),
1399 Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1400 );
1401 params.insert(
1402 "attn.o_proj.weight".to_string(),
1403 Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1404 );
1405
1406 let attn = MultiHeadAttention::from_params(&config, ¶ms, "attn");
1407 assert!(
1409 attn.is_none(),
1410 "FALSIFY-A1e: PMAT-331 fix — from_params MUST reject wrong-shape q_proj"
1411 );
1412 }
1413
1414 #[test]
1419 fn falsify_a2e_gqa_init_correct_kv_dimensions() {
1420 let mut config = TransformerConfig::tiny();
1421 config.num_kv_heads = 1; let attn = MultiHeadAttention::new(&config);
1424 let head_dim = config.head_dim();
1425 let kv_hidden = config.num_kv_heads * head_dim; assert_eq!(
1429 attn.w_q.len(),
1430 config.hidden_size * config.hidden_size,
1431 "FALSIFY-A2e: Q projection must be hidden*hidden"
1432 );
1433
1434 assert_eq!(
1436 attn.w_k.len(),
1437 config.hidden_size * kv_hidden,
1438 "FALSIFY-A2e: K projection must use num_kv_heads, not num_heads"
1439 );
1440
1441 assert_eq!(
1443 attn.w_v.len(),
1444 config.hidden_size * kv_hidden,
1445 "FALSIFY-A2e: V projection must use num_kv_heads, not num_heads"
1446 );
1447
1448 assert_eq!(
1450 attn.w_o.len(),
1451 config.hidden_size * config.hidden_size,
1452 "FALSIFY-A2e: O projection must be hidden*hidden"
1453 );
1454
1455 assert!(
1457 attn.w_k.len() < attn.w_q.len(),
1458 "FALSIFY-A2e: For GQA, K weight must be smaller than Q weight"
1459 );
1460 }
1461
1462 #[test]
1467 fn falsify_a3e_gqa_forward_correct_output_dims() {
1468 let mut config = TransformerConfig::tiny();
1469 config.num_kv_heads = 1; let attn = MultiHeadAttention::new(&config);
1472 let seq_len = 3;
1473 let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
1474 let output = attn.forward(&x, seq_len);
1475
1476 assert_eq!(
1477 output.len(),
1478 seq_len * config.hidden_size,
1479 "FALSIFY-A3e: GQA output must be seq_len * hidden_size, not seq_len * kv_hidden"
1480 );
1481 }
1482
1483 #[test]
1487 fn falsify_a4e_init_produces_valid_attention_weights() {
1488 let config = TransformerConfig::tiny();
1489 let attn = MultiHeadAttention::new(&config);
1490
1491 for (name, w) in
1492 [("w_q", &attn.w_q), ("w_k", &attn.w_k), ("w_v", &attn.w_v), ("w_o", &attn.w_o)]
1493 {
1494 let data = w.data();
1495 let slice = data.as_slice().expect("data as slice");
1496
1497 let nan_count = slice.iter().filter(|v| v.is_nan()).count();
1499 assert_eq!(nan_count, 0, "FALSIFY-A4e: {name} init must not contain NaN");
1500
1501 let inf_count = slice.iter().filter(|v| v.is_infinite()).count();
1503 assert_eq!(inf_count, 0, "FALSIFY-A4e: {name} init must not contain Inf");
1504
1505 let min = slice.iter().copied().fold(f32::INFINITY, f32::min);
1507 let max = slice.iter().copied().fold(f32::NEG_INFINITY, f32::max);
1508 assert!(
1509 (max - min).abs() > 1e-6,
1510 "FALSIFY-A4e: {name} init values are constant ({min}..{max}) — degenerate weight"
1511 );
1512 }
1513 }
1514
1515 #[test]
1520 fn falsify_a5e_forward_produces_finite_output() {
1521 let config = TransformerConfig::tiny();
1522 let attn = MultiHeadAttention::new(&config);
1523 let seq_len = 4;
1524 let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
1525 let output = attn.forward(&x, seq_len);
1526
1527 let data = output.data();
1528 let nan_count = data.iter().filter(|v| v.is_nan()).count();
1529 let inf_count = data.iter().filter(|v| v.is_infinite()).count();
1530 assert_eq!(nan_count, 0, "FALSIFY-A5e: Attention output must not contain NaN");
1531 assert_eq!(inf_count, 0, "FALSIFY-A5e: Attention output must not contain Inf");
1532 }
1533
1534 #[test]
1551 fn falsify_gq_001e_output_shape() {
1552 for (num_heads, num_kv_heads) in [(2, 2), (4, 2), (4, 1), (2, 1)] {
1553 let mut config = TransformerConfig::tiny();
1554 config.num_attention_heads = num_heads;
1555 config.num_kv_heads = num_kv_heads;
1556
1557 let attn = MultiHeadAttention::new(&config);
1558 let seq_len = 3;
1559 let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
1560 let output = attn.forward(&x, seq_len);
1561
1562 assert_eq!(
1563 output.len(),
1564 seq_len * config.hidden_size,
1565 "FALSIFIED GQ-001e: output len mismatch for heads={num_heads},kv={num_kv_heads}"
1566 );
1567 }
1568 }
1569
1570 #[test]
1572 fn falsify_gq_002e_mha_degeneration() {
1573 let config = TransformerConfig::tiny(); assert_eq!(config.num_attention_heads, config.num_kv_heads);
1575
1576 let attn = MultiHeadAttention::new(&config);
1577 let seq_len = 4;
1578 let x = Tensor::from_vec(
1579 (0..seq_len * config.hidden_size).map(|i| (i as f32 * 0.37).sin()).collect(),
1580 true,
1581 );
1582 let output = attn.forward(&x, seq_len);
1583
1584 let data = output.data();
1585 for (i, v) in data.iter().enumerate() {
1586 assert!(v.is_finite(), "FALSIFIED GQ-002e: MHA output[{i}] = {v} (not finite)");
1587 }
1588 }
1589
1590 #[test]
1592 fn falsify_gq_004e_head_divisibility() {
1593 for (nh, nkv) in [(2, 1), (2, 2), (4, 1), (4, 2), (4, 4), (8, 2), (8, 4)] {
1595 let mut config = TransformerConfig::tiny();
1596 config.num_attention_heads = nh;
1597 config.num_kv_heads = nkv;
1598 assert_eq!(nh % nkv, 0, "FALSIFIED GQ-004e: test config has invalid head ratio");
1599 let attn = MultiHeadAttention::new(&config);
1601 let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
1602 let _ = attn.forward(&x, 2);
1603 }
1604 }
1605
1606 #[test]
1608 fn falsify_gq_006e_mqa_boundary() {
1609 let mut config = TransformerConfig::tiny();
1610 config.num_attention_heads = 4;
1611 config.num_kv_heads = 1;
1612 config.hidden_size = 64;
1614
1615 let attn = MultiHeadAttention::new(&config);
1616 let seq_len = 3;
1617 let x = Tensor::from_vec(
1618 (0..seq_len * config.hidden_size).map(|i| (i as f32 * 0.73).cos()).collect(),
1619 true,
1620 );
1621 let output = attn.forward(&x, seq_len);
1622
1623 assert_eq!(
1624 output.len(),
1625 seq_len * config.hidden_size,
1626 "FALSIFIED GQ-006e: MQA output size wrong"
1627 );
1628
1629 let data = output.data();
1631 for (i, v) in data.iter().enumerate() {
1632 assert!(v.is_finite(), "FALSIFIED GQ-006e: MQA output[{i}] = {v} (not finite)");
1633 }
1634 }
1635
1636 mod gq_proptest_falsify {
1637 use super::*;
1638 use proptest::prelude::*;
1639
1640 proptest! {
1642 #![proptest_config(ProptestConfig::with_cases(50))]
1643
1644 #[test]
1645 fn falsify_gq_001e_prop_output_shape(
1646 config_idx in 0..4usize,
1647 seq_len in 2..=6usize,
1648 seed in 0..500u32,
1649 ) {
1650 let configs: [(usize, usize); 4] = [
1651 (2, 2), (2, 1), (4, 2), (4, 1),
1652 ];
1653 let (num_heads, num_kv_heads) = configs[config_idx];
1654 let mut config = TransformerConfig::tiny();
1655 config.num_attention_heads = num_heads;
1656 config.num_kv_heads = num_kv_heads;
1657
1658 let attn = MultiHeadAttention::new(&config);
1659 let data: Vec<f32> = (0..seq_len * config.hidden_size)
1660 .map(|i| ((i as f32 + seed as f32) * 0.37).sin())
1661 .collect();
1662 let x = Tensor::from_vec(data, true);
1663 let output = attn.forward(&x, seq_len);
1664
1665 prop_assert_eq!(
1666 output.len(),
1667 seq_len * config.hidden_size,
1668 "FALSIFIED GQ-001e-prop: output len mismatch"
1669 );
1670
1671 for v in output.data() {
1673 prop_assert!(
1674 v.is_finite(),
1675 "FALSIFIED GQ-001e-prop: non-finite output"
1676 );
1677 }
1678 }
1679 }
1680
1681 proptest! {
1683 #![proptest_config(ProptestConfig::with_cases(30))]
1684
1685 #[test]
1686 fn falsify_gq_006e_prop_mqa_boundary(
1687 seed in 0..500u32,
1688 seq_len in 2..=5usize,
1689 ) {
1690 let mut config = TransformerConfig::tiny();
1691 config.num_attention_heads = 4;
1692 config.num_kv_heads = 1;
1693 config.hidden_size = 64;
1694
1695 let attn = MultiHeadAttention::new(&config);
1696 let data: Vec<f32> = (0..seq_len * config.hidden_size)
1697 .map(|i| ((i as f32 + seed as f32) * 0.73).cos())
1698 .collect();
1699 let x = Tensor::from_vec(data, true);
1700 let output = attn.forward(&x, seq_len);
1701
1702 prop_assert_eq!(
1703 output.len(),
1704 seq_len * config.hidden_size,
1705 "FALSIFIED GQ-006e-prop: MQA output len mismatch"
1706 );
1707
1708 for v in output.data() {
1709 prop_assert!(
1710 v.is_finite(),
1711 "FALSIFIED GQ-006e-prop: non-finite MQA output"
1712 );
1713 }
1714 }
1715 }
1716 }
1717
1718 #[test]
1719 fn test_attention_from_params_with_biases() {
1720 let config = TransformerConfig::tiny();
1721 let hidden_size = config.hidden_size;
1722 let kv_hidden_size = config.num_kv_heads * config.head_dim();
1723
1724 let mut params = HashMap::new();
1725 params.insert(
1726 "attn.q_proj.weight".to_string(),
1727 Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1728 );
1729 params.insert(
1730 "attn.k_proj.weight".to_string(),
1731 Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1732 );
1733 params.insert(
1734 "attn.v_proj.weight".to_string(),
1735 Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1736 );
1737 params.insert(
1738 "attn.o_proj.weight".to_string(),
1739 Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1740 );
1741 params.insert(
1742 "attn.q_proj.bias".to_string(),
1743 Tensor::from_vec(vec![0.01; hidden_size], true),
1744 );
1745 params.insert(
1746 "attn.k_proj.bias".to_string(),
1747 Tensor::from_vec(vec![0.01; kv_hidden_size], true),
1748 );
1749 params.insert(
1750 "attn.v_proj.bias".to_string(),
1751 Tensor::from_vec(vec![0.01; kv_hidden_size], true),
1752 );
1753
1754 let attn = MultiHeadAttention::from_params(&config, ¶ms, "attn");
1755 assert!(attn.is_some());
1756 let attn = attn.expect("should load with biases");
1757 assert!(attn.has_biases());
1758 assert_eq!(attn.parameters().len(), 7);
1759 }
1760
1761 #[test]
1762 fn test_attention_named_parameters_with_biases() {
1763 let config = TransformerConfig::tiny();
1764 let hidden_size = config.hidden_size;
1765 let kv_hidden_size = config.num_kv_heads * config.head_dim();
1766
1767 let mut params = HashMap::new();
1768 params.insert(
1769 "attn.q_proj.weight".to_string(),
1770 Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1771 );
1772 params.insert(
1773 "attn.k_proj.weight".to_string(),
1774 Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1775 );
1776 params.insert(
1777 "attn.v_proj.weight".to_string(),
1778 Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1779 );
1780 params.insert(
1781 "attn.o_proj.weight".to_string(),
1782 Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1783 );
1784 params.insert(
1785 "attn.q_proj.bias".to_string(),
1786 Tensor::from_vec(vec![0.01; hidden_size], true),
1787 );
1788 params.insert(
1789 "attn.k_proj.bias".to_string(),
1790 Tensor::from_vec(vec![0.01; kv_hidden_size], true),
1791 );
1792 params.insert(
1793 "attn.v_proj.bias".to_string(),
1794 Tensor::from_vec(vec![0.01; kv_hidden_size], true),
1795 );
1796
1797 let attn = MultiHeadAttention::from_params(&config, ¶ms, "attn").expect("should load");
1798 let named = attn.named_parameters("attn");
1799 assert_eq!(named.len(), 7);
1800 let names: Vec<&str> = named.iter().map(|(n, _)| n.as_str()).collect();
1801 assert!(names.contains(&"attn.q_proj.bias"));
1802 assert!(names.contains(&"attn.k_proj.bias"));
1803 assert!(names.contains(&"attn.v_proj.bias"));
1804 }
1805
1806 #[test]
1807 fn test_attention_forward_with_biases() {
1808 let config = TransformerConfig::tiny();
1809 let hidden_size = config.hidden_size;
1810 let kv_hidden_size = config.num_kv_heads * config.head_dim();
1811
1812 let mut params = HashMap::new();
1813 params.insert(
1814 "attn.q_proj.weight".to_string(),
1815 Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1816 );
1817 params.insert(
1818 "attn.k_proj.weight".to_string(),
1819 Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1820 );
1821 params.insert(
1822 "attn.v_proj.weight".to_string(),
1823 Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
1824 );
1825 params.insert(
1826 "attn.o_proj.weight".to_string(),
1827 Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
1828 );
1829 params
1830 .insert("attn.q_proj.bias".to_string(), Tensor::from_vec(vec![0.5; hidden_size], true));
1831 params.insert(
1832 "attn.k_proj.bias".to_string(),
1833 Tensor::from_vec(vec![0.5; kv_hidden_size], true),
1834 );
1835 params.insert(
1836 "attn.v_proj.bias".to_string(),
1837 Tensor::from_vec(vec![0.5; kv_hidden_size], true),
1838 );
1839
1840 let attn = MultiHeadAttention::from_params(&config, ¶ms, "attn").expect("should load");
1841 let x = Tensor::from_vec(vec![0.1; 2 * hidden_size], false);
1842 let output = attn.forward(&x, 2);
1843 assert_eq!(output.len(), 2 * hidden_size);
1844 assert!(output.data().iter().all(|v| v.is_finite()));
1845 }
1846}