1use ferrotorch_core::grad_fns::activation::softmax;
36use ferrotorch_core::grad_fns::arithmetic::{add, mul};
37use ferrotorch_core::grad_fns::linalg::{bmm_differentiable, mm_differentiable};
38use ferrotorch_core::grad_fns::shape::{expand, transpose_2d};
39use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
40
41use crate::init::{xavier_uniform, zeros};
42use crate::module::Module;
43use crate::parameter::Parameter;
44
45#[derive(Debug)]
64pub struct MultiheadAttention<T: Float> {
65 pub embed_dim: usize,
66 pub num_heads: usize,
67 pub num_kv_heads: usize,
71 pub head_dim: usize,
72
73 pub q_proj: Parameter<T>,
75 pub k_proj: Parameter<T>,
77 pub v_proj: Parameter<T>,
79 pub out_proj: Parameter<T>,
81
82 pub q_bias: Option<Parameter<T>>,
84 pub k_bias: Option<Parameter<T>>,
85 pub v_bias: Option<Parameter<T>>,
86 pub out_bias: Option<Parameter<T>>,
87
88 pub training: bool,
89}
90
91impl<T: Float> MultiheadAttention<T> {
92 pub fn new(embed_dim: usize, num_heads: usize, bias: bool) -> FerrotorchResult<Self> {
105 Self::with_gqa(embed_dim, num_heads, num_heads, bias)
106 }
107
108 pub fn with_gqa(
127 embed_dim: usize,
128 num_heads: usize,
129 num_kv_heads: usize,
130 bias: bool,
131 ) -> FerrotorchResult<Self> {
132 if embed_dim == 0 || num_heads == 0 || num_kv_heads == 0 {
133 return Err(FerrotorchError::InvalidArgument {
134 message: "embed_dim, num_heads, num_kv_heads must be positive".into(),
135 });
136 }
137 if embed_dim % num_heads != 0 {
138 return Err(FerrotorchError::InvalidArgument {
139 message: format!(
140 "embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
141 ),
142 });
143 }
144 if num_heads % num_kv_heads != 0 {
145 return Err(FerrotorchError::InvalidArgument {
146 message: format!(
147 "num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
148 ),
149 });
150 }
151
152 let head_dim = embed_dim / num_heads;
153 let kv_dim = num_kv_heads * head_dim;
154
155 let mut q_proj = Parameter::zeros(&[embed_dim, embed_dim])?;
156 let mut k_proj = Parameter::zeros(&[kv_dim, embed_dim])?;
157 let mut v_proj = Parameter::zeros(&[kv_dim, embed_dim])?;
158 let mut out_proj = Parameter::zeros(&[embed_dim, embed_dim])?;
159
160 xavier_uniform(&mut q_proj)?;
161 xavier_uniform(&mut k_proj)?;
162 xavier_uniform(&mut v_proj)?;
163 xavier_uniform(&mut out_proj)?;
164
165 let (q_bias, k_bias, v_bias, out_bias) = if bias {
166 let mut qb = Parameter::zeros(&[embed_dim])?;
167 let mut kb = Parameter::zeros(&[kv_dim])?;
168 let mut vb = Parameter::zeros(&[kv_dim])?;
169 let mut ob = Parameter::zeros(&[embed_dim])?;
170 zeros(&mut qb)?;
171 zeros(&mut kb)?;
172 zeros(&mut vb)?;
173 zeros(&mut ob)?;
174 (Some(qb), Some(kb), Some(vb), Some(ob))
175 } else {
176 (None, None, None, None)
177 };
178
179 Ok(Self {
180 embed_dim,
181 num_heads,
182 num_kv_heads,
183 head_dim,
184 q_proj,
185 k_proj,
186 v_proj,
187 out_proj,
188 q_bias,
189 k_bias,
190 v_bias,
191 out_bias,
192 training: true,
193 })
194 }
195
196 pub fn forward_qkv(
211 &self,
212 query: &Tensor<T>,
213 key: &Tensor<T>,
214 value: &Tensor<T>,
215 causal_mask: bool,
216 ) -> FerrotorchResult<Tensor<T>> {
217 if query.ndim() != 3 || key.ndim() != 3 || value.ndim() != 3 {
219 return Err(FerrotorchError::InvalidArgument {
220 message: format!(
221 "MultiheadAttention expects 3-D inputs [batch, seq, embed_dim], \
222 got query {:?}, key {:?}, value {:?}",
223 query.shape(),
224 key.shape(),
225 value.shape()
226 ),
227 });
228 }
229
230 let batch = query.shape()[0];
231 let seq_q = query.shape()[1];
232 let seq_k = key.shape()[1];
233
234 if query.shape()[2] != self.embed_dim
235 || key.shape()[2] != self.embed_dim
236 || value.shape()[2] != self.embed_dim
237 {
238 return Err(FerrotorchError::ShapeMismatch {
239 message: format!(
240 "embed_dim mismatch: expected {}, got query={}, key={}, value={}",
241 self.embed_dim,
242 query.shape()[2],
243 key.shape()[2],
244 value.shape()[2]
245 ),
246 });
247 }
248
249 if key.shape()[0] != batch || value.shape()[0] != batch {
250 return Err(FerrotorchError::ShapeMismatch {
251 message: format!(
252 "batch size mismatch: query batch={}, key batch={}, value batch={}",
253 batch,
254 key.shape()[0],
255 value.shape()[0]
256 ),
257 });
258 }
259
260 if key.shape()[1] != value.shape()[1] {
261 return Err(FerrotorchError::ShapeMismatch {
262 message: format!(
263 "key and value seq_len must match: key={}, value={}",
264 key.shape()[1],
265 value.shape()[1]
266 ),
267 });
268 }
269
270 if causal_mask && seq_q != seq_k {
271 return Err(FerrotorchError::InvalidArgument {
272 message: format!(
273 "causal mask requires seq_q == seq_k, got seq_q={seq_q}, seq_k={seq_k}"
274 ),
275 });
276 }
277
278 if seq_q == 1 && seq_k == 1 && !causal_mask && self.num_kv_heads == self.num_heads {
289 use ferrotorch_core::grad_fns::linalg::linear_fused;
290
291 let v_2d = value.reshape_t(&[batch as isize, self.embed_dim as isize])?;
293
294 let v_proj = linear_fused(
296 &v_2d,
297 self.v_proj.tensor(),
298 self.v_bias.as_ref().map(|b| b.tensor()),
299 )?;
300
301 let output = linear_fused(
303 &v_proj,
304 self.out_proj.tensor(),
305 self.out_bias.as_ref().map(|b| b.tensor()),
306 )?;
307
308 return output.reshape_t(&[batch as isize, 1, self.embed_dim as isize]);
310 }
311
312 let nh = self.num_heads;
319 let nkv = self.num_kv_heads;
320 let hd = self.head_dim;
321 let group_size = nh / nkv;
322
323 let wq_t = transpose_2d(self.q_proj.tensor())?;
325 let wk_t = transpose_2d(self.k_proj.tensor())?;
326 let wv_t = transpose_2d(self.v_proj.tensor())?;
327 let wo_t = transpose_2d(self.out_proj.tensor())?;
328
329 let flat_q = query.reshape_t(&[-1, self.embed_dim as isize])?;
330 let flat_k = key.reshape_t(&[-1, self.embed_dim as isize])?;
331 let flat_v = value.reshape_t(&[-1, self.embed_dim as isize])?;
332
333 let mut q_proj = mm_differentiable(&flat_q, &wq_t)?;
334 let mut k_proj = mm_differentiable(&flat_k, &wk_t)?;
335 let mut v_proj = mm_differentiable(&flat_v, &wv_t)?;
336
337 if let Some(ref qb) = self.q_bias {
338 let b = expand_bias_to_2d(qb.tensor(), batch * seq_q)?;
339 q_proj = add(&q_proj, &b)?;
340 }
341 if let Some(ref kb) = self.k_bias {
342 let b = expand_bias_to_2d(kb.tensor(), batch * seq_k)?;
343 k_proj = add(&k_proj, &b)?;
344 }
345 if let Some(ref vb) = self.v_bias {
346 let b = expand_bias_to_2d(vb.tensor(), batch * seq_k)?;
347 v_proj = add(&v_proj, &b)?;
348 }
349
350 let q = q_proj
354 .reshape_t(&[batch as isize, seq_q as isize, nh as isize, hd as isize])?
355 .permute(&[0, 2, 1, 3])?
356 .contiguous()?
357 .reshape_t(&[(batch * nh) as isize, seq_q as isize, hd as isize])?;
358
359 let mut k = k_proj
360 .reshape_t(&[batch as isize, seq_k as isize, nkv as isize, hd as isize])?
361 .permute(&[0, 2, 1, 3])?
362 .contiguous()?;
363 let mut v = v_proj
364 .reshape_t(&[batch as isize, seq_k as isize, nkv as isize, hd as isize])?
365 .permute(&[0, 2, 1, 3])?
366 .contiguous()?;
367
368 if group_size > 1 {
370 k = k.reshape_t(&[batch as isize, nkv as isize, 1, seq_k as isize, hd as isize])?;
373 k = expand(&k, &[batch, nkv, group_size, seq_k, hd])?;
374 k = k.reshape_t(&[batch as isize, nh as isize, seq_k as isize, hd as isize])?;
375
376 v = v.reshape_t(&[batch as isize, nkv as isize, 1, seq_k as isize, hd as isize])?;
377 v = expand(&v, &[batch, nkv, group_size, seq_k, hd])?;
378 v = v.reshape_t(&[batch as isize, nh as isize, seq_k as isize, hd as isize])?;
379 }
380
381 let k = k.reshape_t(&[(batch * nh) as isize, seq_k as isize, hd as isize])?;
382 let v = v.reshape_t(&[(batch * nh) as isize, seq_k as isize, hd as isize])?;
383
384 let k_t = k.permute(&[0, 2, 1])?.contiguous()?;
387 let scores = bmm_differentiable(&q, &k_t)?;
388
389 let scale_val = T::from(1.0 / (hd as f64).sqrt()).unwrap();
390 let scale_tensor = Tensor::from_storage(
391 TensorStorage::on_device(vec![scale_val], scores.device())?,
392 vec![1],
393 false,
394 )?;
395 let scaled = mul(&scores, &scale_tensor)?;
396
397 let masked = if causal_mask {
399 let neg_inf = T::from(-1e9).unwrap();
400 let zero = <T as num_traits::Zero>::zero();
401 let mut mask_data = vec![zero; seq_q * seq_k];
402 for i in 0..seq_q {
403 for j in (i + 1)..seq_k {
404 mask_data[i * seq_k + j] = neg_inf;
405 }
406 }
407 let mask =
408 Tensor::from_storage(TensorStorage::cpu(mask_data), vec![1, seq_q, seq_k], false)?;
409 let mask = if scaled.is_cuda() {
410 mask.to(scaled.device())?
411 } else {
412 mask
413 };
414 add(&scaled, &mask)?
415 } else {
416 scaled
417 };
418
419 let weights = softmax(&masked)?;
421 let context = bmm_differentiable(&weights, &v)?;
422
423 let context = context
425 .reshape_t(&[batch as isize, nh as isize, seq_q as isize, hd as isize])?
426 .permute(&[0, 2, 1, 3])?
427 .contiguous()?
428 .reshape_t(&[(batch * seq_q) as isize, self.embed_dim as isize])?;
429
430 let mut output = mm_differentiable(&context, &wo_t)?;
432 if let Some(ref ob) = self.out_bias {
433 let b = expand_bias_to_2d(ob.tensor(), batch * seq_q)?;
434 output = add(&output, &b)?;
435 }
436
437 output.reshape_t(&[batch as isize, seq_q as isize, self.embed_dim as isize])
438 }
439
440 #[inline]
442 pub fn embed_dim(&self) -> usize {
443 self.embed_dim
444 }
445
446 #[inline]
448 pub fn num_heads(&self) -> usize {
449 self.num_heads
450 }
451
452 #[inline]
455 pub fn num_kv_heads(&self) -> usize {
456 self.num_kv_heads
457 }
458
459 #[inline]
461 pub fn head_dim(&self) -> usize {
462 self.head_dim
463 }
464
465 #[inline]
468 pub fn is_gqa(&self) -> bool {
469 self.num_kv_heads != self.num_heads
470 }
471
472 pub fn forward_2d(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
482 use ferrotorch_core::grad_fns::linalg::linear_fused;
483
484 if self.is_gqa() {
485 return Err(FerrotorchError::InvalidArgument {
486 message:
487 "forward_2d is MHA-only; use forward_qkv for GQA (num_kv_heads != num_heads)"
488 .into(),
489 });
490 }
491
492 let v_proj = linear_fused(
493 input,
494 self.v_proj.tensor(),
495 self.v_bias.as_ref().map(|b| b.tensor()),
496 )?;
497 linear_fused(
498 &v_proj,
499 self.out_proj.tensor(),
500 self.out_bias.as_ref().map(|b| b.tensor()),
501 )
502 }
503}
504
505impl<T: Float> Module<T> for MultiheadAttention<T> {
506 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
508 self.forward_qkv(input, input, input, false)
509 }
510
511 fn parameters(&self) -> Vec<&Parameter<T>> {
512 let mut params = vec![&self.q_proj, &self.k_proj, &self.v_proj, &self.out_proj];
513 if let Some(ref b) = self.q_bias {
514 params.push(b);
515 }
516 if let Some(ref b) = self.k_bias {
517 params.push(b);
518 }
519 if let Some(ref b) = self.v_bias {
520 params.push(b);
521 }
522 if let Some(ref b) = self.out_bias {
523 params.push(b);
524 }
525 params
526 }
527
528 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
529 let mut params: Vec<&mut Parameter<T>> = vec![
530 &mut self.q_proj,
531 &mut self.k_proj,
532 &mut self.v_proj,
533 &mut self.out_proj,
534 ];
535 if let Some(ref mut b) = self.q_bias {
536 params.push(b);
537 }
538 if let Some(ref mut b) = self.k_bias {
539 params.push(b);
540 }
541 if let Some(ref mut b) = self.v_bias {
542 params.push(b);
543 }
544 if let Some(ref mut b) = self.out_bias {
545 params.push(b);
546 }
547 params
548 }
549
550 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
551 let mut params = vec![
552 ("q_proj.weight".to_string(), &self.q_proj),
553 ("k_proj.weight".to_string(), &self.k_proj),
554 ("v_proj.weight".to_string(), &self.v_proj),
555 ("out_proj.weight".to_string(), &self.out_proj),
556 ];
557 if let Some(ref b) = self.q_bias {
558 params.push(("q_proj.bias".to_string(), b));
559 }
560 if let Some(ref b) = self.k_bias {
561 params.push(("k_proj.bias".to_string(), b));
562 }
563 if let Some(ref b) = self.v_bias {
564 params.push(("v_proj.bias".to_string(), b));
565 }
566 if let Some(ref b) = self.out_bias {
567 params.push(("out_proj.bias".to_string(), b));
568 }
569 params
570 }
571
572 fn train(&mut self) {
573 self.training = true;
574 }
575
576 fn eval(&mut self) {
577 self.training = false;
578 }
579
580 fn is_training(&self) -> bool {
581 self.training
582 }
583}
584
585fn expand_bias_to_2d<T: Float>(bias: &Tensor<T>, rows: usize) -> FerrotorchResult<Tensor<T>> {
594 let dim = bias.shape()[0];
595 let bias_2d = bias.reshape_t(&[1, dim as isize])?;
597 expand(&bias_2d, &[rows, dim])
598}
599
600pub fn reshape_to_heads<T: Float>(
607 tensor: &Tensor<T>,
608 num_heads: usize,
609 seq_len: usize,
610 head_dim: usize,
611) -> FerrotorchResult<Tensor<T>> {
612 let data = tensor.data()?;
613 let mut result = vec![<T as num_traits::Zero>::zero(); num_heads * seq_len * head_dim];
616
617 for s in 0..seq_len {
618 for h in 0..num_heads {
619 for d in 0..head_dim {
620 let src_idx = s * (num_heads * head_dim) + h * head_dim + d;
621 let dst_idx = h * (seq_len * head_dim) + s * head_dim + d;
622 result[dst_idx] = data[src_idx];
623 }
624 }
625 }
626
627 Tensor::from_storage(
628 TensorStorage::cpu(result),
629 vec![num_heads, seq_len, head_dim],
630 tensor.requires_grad(),
631 )
632}
633
634pub fn transpose_heads_to_2d<T: Float>(
638 tensor: &Tensor<T>,
639 num_heads: usize,
640 seq_len: usize,
641 head_dim: usize,
642) -> FerrotorchResult<Tensor<T>> {
643 let embed_dim = num_heads * head_dim;
644 let data = tensor.data_vec()?;
645 let mut result = vec![<T as num_traits::Zero>::zero(); seq_len * embed_dim];
646
647 for h in 0..num_heads {
648 for s in 0..seq_len {
649 for d in 0..head_dim {
650 let src_idx = h * (seq_len * head_dim) + s * head_dim + d;
651 let dst_idx = s * embed_dim + h * head_dim + d;
652 result[dst_idx] = data[src_idx];
653 }
654 }
655 }
656
657 let device = tensor.device();
658 Tensor::from_storage(
659 TensorStorage::on_device(result, device)?,
660 vec![seq_len, embed_dim],
661 false,
662 )
663}
664
665pub fn repeat_kv<T: Float>(kv: &Tensor<T>, group_size: usize) -> FerrotorchResult<Tensor<T>> {
683 if group_size == 1 {
684 return Ok(kv.clone());
685 }
686 let shape = kv.shape();
687 if shape.len() != 3 {
688 return Err(FerrotorchError::ShapeMismatch {
689 message: format!(
690 "repeat_kv expects 3-D [num_kv_heads, seq, head_dim], got {:?}",
691 shape
692 ),
693 });
694 }
695 let num_kv_heads = shape[0];
696 let seq = shape[1];
697 let head_dim = shape[2];
698 let num_q_heads = num_kv_heads * group_size;
699 let data = kv.data_vec()?;
700 let head_stride = seq * head_dim;
701 let mut out = vec![<T as num_traits::Zero>::zero(); num_q_heads * head_stride];
702 for h in 0..num_q_heads {
703 let kv_h = h / group_size;
704 let src_start = kv_h * head_stride;
705 let dst_start = h * head_stride;
706 out[dst_start..dst_start + head_stride]
707 .copy_from_slice(&data[src_start..src_start + head_stride]);
708 }
709 let device = kv.device();
710 Tensor::from_storage(
711 TensorStorage::on_device(out, device)?,
712 vec![num_q_heads, seq, head_dim],
713 kv.requires_grad(),
714 )
715}
716
#[cfg(test)]
mod tests {
    use super::*;

    // ---- construction -----------------------------------------------------

    #[test]
    fn test_new_valid() {
        let built = MultiheadAttention::<f32>::new(64, 8, true);
        assert!(built.is_ok());
        let layer = built.unwrap();
        assert_eq!(layer.embed_dim(), 64);
        assert_eq!(layer.num_heads(), 8);
        assert_eq!(layer.head_dim(), 8);
    }

    #[test]
    fn test_new_invalid_divisibility() {
        // 65 is not a multiple of 8.
        assert!(MultiheadAttention::<f32>::new(65, 8, true).is_err());
    }

    #[test]
    fn test_new_zero_dims() {
        let zero_embed = MultiheadAttention::<f32>::new(0, 4, false);
        let zero_heads = MultiheadAttention::<f32>::new(64, 0, false);
        assert!(zero_embed.is_err());
        assert!(zero_heads.is_err());
    }

    // ---- parameter bookkeeping --------------------------------------------

    #[test]
    fn test_parameter_count_with_bias() {
        let embed_dim = 16usize;
        let layer = MultiheadAttention::<f32>::new(embed_dim, 4, true).unwrap();
        let params = layer.parameters();
        let total = params.iter().fold(0usize, |acc, p| acc + p.numel());
        // Four square weights plus four bias vectors.
        assert_eq!(total, 4 * embed_dim * embed_dim + 4 * embed_dim);
        assert_eq!(params.len(), 8);
    }

    #[test]
    fn test_parameter_count_without_bias() {
        let embed_dim = 16usize;
        let layer = MultiheadAttention::<f32>::new(embed_dim, 4, false).unwrap();
        let params = layer.parameters();
        let total = params.iter().fold(0usize, |acc, p| acc + p.numel());
        // Four square weights only.
        assert_eq!(total, 4 * embed_dim * embed_dim);
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_named_parameters() {
        let layer = MultiheadAttention::<f32>::new(8, 2, true).unwrap();
        let named = layer.named_parameters();
        let names: Vec<&str> = named.iter().map(|(name, _)| name.as_str()).collect();
        assert!(names.contains(&"q_proj.weight"));
        assert!(names.contains(&"k_proj.weight"));
        assert!(names.contains(&"v_proj.weight"));
        assert!(names.contains(&"out_proj.weight"));
        assert!(names.contains(&"q_proj.bias"));
        assert!(names.contains(&"k_proj.bias"));
        assert!(names.contains(&"v_proj.bias"));
        assert!(names.contains(&"out_proj.bias"));
    }

    // ---- forward pass: shapes and values ----------------------------------

    #[test]
    fn test_output_shape() {
        let layer = MultiheadAttention::<f32>::new(16, 4, true).unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
        let y = layer.forward(&x).unwrap();
        assert_eq!(y.shape(), &[2, 5, 16]);
    }

    #[test]
    fn test_output_shape_no_bias() {
        let layer = MultiheadAttention::<f32>::new(8, 2, false).unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
        let y = layer.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 3, 8]);
    }

    #[test]
    fn test_self_attention_basic_forward() {
        let layer = MultiheadAttention::<f64>::new(4, 2, true).unwrap();
        let x = ferrotorch_core::ones::<f64>(&[1, 2, 4]).unwrap();
        let y = layer.forward(&x).unwrap();

        assert_eq!(y.shape(), &[1, 2, 4]);
        let values = y.data().unwrap();
        for &v in values {
            assert!(v.is_finite(), "output contains non-finite value: {v}");
        }
    }

    #[test]
    fn test_cross_attention_shape() {
        let layer = MultiheadAttention::<f32>::new(8, 2, true).unwrap();
        let query = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
        let memory = ferrotorch_core::zeros::<f32>(&[1, 5, 8]).unwrap();
        // The output follows the query's sequence length, not the key's.
        let y = layer.forward_qkv(&query, &memory, &memory, false).unwrap();
        assert_eq!(y.shape(), &[1, 3, 8]);
    }

    #[test]
    fn test_causal_mask_different_seq_lens_error() {
        let layer = MultiheadAttention::<f32>::new(8, 2, false).unwrap();
        let query = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
        let memory = ferrotorch_core::zeros::<f32>(&[1, 5, 8]).unwrap();
        assert!(layer.forward_qkv(&query, &memory, &memory, true).is_err());
    }

    #[test]
    fn test_train_eval_toggle() {
        let mut layer = MultiheadAttention::<f32>::new(8, 2, true).unwrap();
        assert!(layer.is_training());
        layer.eval();
        assert!(!layer.is_training());
        layer.train();
        assert!(layer.is_training());
    }

    #[test]
    fn test_wrong_embed_dim_input() {
        let layer = MultiheadAttention::<f32>::new(8, 2, true).unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[1, 3, 4]).unwrap();
        assert!(layer.forward(&x).is_err());
    }

    #[test]
    fn test_2d_input_rejected() {
        let layer = MultiheadAttention::<f32>::new(8, 2, true).unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[3, 8]).unwrap();
        assert!(layer.forward(&x).is_err());
    }

    #[test]
    fn test_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MultiheadAttention<f32>>();
        assert_send_sync::<MultiheadAttention<f64>>();
    }

    // ---- grouped-query attention: construction ----------------------------

    #[test]
    fn test_with_gqa_valid_construction() {
        let layer = MultiheadAttention::<f32>::with_gqa(4096, 32, 8, false).unwrap();
        assert_eq!(layer.embed_dim(), 4096);
        assert_eq!(layer.num_heads(), 32);
        assert_eq!(layer.num_kv_heads(), 8);
        assert_eq!(layer.head_dim(), 128);
        assert!(layer.is_gqa());
    }

    #[test]
    fn test_with_gqa_kv_proj_shapes() {
        let layer = MultiheadAttention::<f32>::with_gqa(64, 8, 2, true).unwrap();
        // num_kv_heads * head_dim
        let kv_dim = 2 * (64 / 8);

        assert_eq!(layer.q_proj.shape(), &[64, 64]);
        assert_eq!(layer.k_proj.shape(), &[kv_dim, 64]);
        assert_eq!(layer.v_proj.shape(), &[kv_dim, 64]);
        assert_eq!(layer.out_proj.shape(), &[64, 64]);

        assert_eq!(layer.q_bias.as_ref().unwrap().shape(), &[64]);
        assert_eq!(layer.k_bias.as_ref().unwrap().shape(), &[kv_dim]);
        assert_eq!(layer.v_bias.as_ref().unwrap().shape(), &[kv_dim]);
        assert_eq!(layer.out_bias.as_ref().unwrap().shape(), &[64]);
    }

    #[test]
    fn test_with_gqa_rejects_non_divisible_kv_heads() {
        // 8 query heads cannot be split evenly over 3 KV heads.
        assert!(MultiheadAttention::<f32>::with_gqa(64, 8, 3, false).is_err());
    }

    #[test]
    fn test_with_gqa_rejects_zero_kv_heads() {
        assert!(MultiheadAttention::<f32>::with_gqa(64, 8, 0, false).is_err());
    }

    #[test]
    fn test_with_gqa_equivalent_to_new_when_kv_equals_q() {
        let via_gqa = MultiheadAttention::<f32>::with_gqa(32, 4, 4, true).unwrap();
        let via_new = MultiheadAttention::<f32>::new(32, 4, true).unwrap();
        assert_eq!(via_gqa.num_kv_heads(), via_new.num_kv_heads());
        assert_eq!(via_gqa.k_proj.shape(), via_new.k_proj.shape());
        assert_eq!(via_gqa.v_proj.shape(), via_new.v_proj.shape());
        assert!(!via_gqa.is_gqa());
    }

    // ---- repeat_kv ---------------------------------------------------------

    #[test]
    fn test_repeat_kv_noop_on_group_size_1() {
        let values: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let kv = ferrotorch_core::from_slice::<f32>(&values, &[2, 3, 4]).unwrap();
        let out = repeat_kv(&kv, 1).unwrap();
        assert_eq!(out.shape(), kv.shape());
        assert_eq!(out.data_vec().unwrap(), kv.data_vec().unwrap());
    }

    #[test]
    fn test_repeat_kv_copies_correct_heads() {
        // Two KV heads, each [seq = 2, head_dim = 3].
        let data: Vec<f32> = vec![
            10.0, 11.0, 12.0, 13.0, 14.0, 15.0, // head 0
            20.0, 21.0, 22.0, 23.0, 24.0, 25.0, // head 1
        ];
        let kv = ferrotorch_core::from_slice::<f32>(&data, &[2, 2, 3]).unwrap();
        let out = repeat_kv(&kv, 3).unwrap();
        assert_eq!(out.shape(), &[6, 2, 3]);

        let out_data = out.data_vec().unwrap();
        let head_stride = 2 * 3;
        // Output heads 0..3 replicate KV head 0.
        for h in 0..3 {
            let start = h * head_stride;
            assert_eq!(&out_data[start..start + head_stride], &data[0..head_stride]);
        }
        // Output heads 3..6 replicate KV head 1.
        for h in 3..6 {
            let start = h * head_stride;
            assert_eq!(
                &out_data[start..start + head_stride],
                &data[head_stride..2 * head_stride]
            );
        }
    }

    #[test]
    fn test_repeat_kv_rejects_wrong_rank() {
        let two_d = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
        assert!(repeat_kv(&two_d, 2).is_err());
    }

    // ---- grouped-query attention: forward ----------------------------------

    #[test]
    fn test_gqa_forward_output_shape_preserved() {
        let layer = MultiheadAttention::<f32>::with_gqa(16, 4, 2, true).unwrap();
        let x = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
        let y = layer.forward(&x).unwrap();
        assert_eq!(y.shape(), &[2, 5, 16]);
    }

    #[test]
    fn test_gqa_forward_produces_finite_values() {
        let layer = MultiheadAttention::<f64>::with_gqa(8, 4, 2, true).unwrap();
        let x = ferrotorch_core::ones::<f64>(&[1, 3, 8]).unwrap();
        let y = layer.forward(&x).unwrap();
        let values = y.data().unwrap();
        for &v in values {
            assert!(v.is_finite(), "GQA output non-finite: {v}");
        }
    }

    #[test]
    fn test_gqa_forward_decoder_style_single_token() {
        let layer = MultiheadAttention::<f32>::with_gqa(32, 8, 2, false).unwrap();
        // A single token per sequence, as in an autoregressive decode step.
        let x = ferrotorch_core::ones::<f32>(&[1, 1, 32]).unwrap();
        let y = layer.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 1, 32]);
        for &v in y.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn test_gqa_forward_with_causal_mask() {
        let layer = MultiheadAttention::<f32>::with_gqa(16, 4, 2, false).unwrap();
        let x = ferrotorch_core::ones::<f32>(&[1, 4, 16]).unwrap();
        let y = layer.forward_qkv(&x, &x, &x, true).unwrap();
        assert_eq!(y.shape(), &[1, 4, 16]);
        for &v in y.data().unwrap() {
            assert!(v.is_finite());
        }
    }
}