use std::collections::HashMap;

use axonml_autograd::Variable;
#[cfg(feature = "cuda")]
use axonml_autograd::functions::FusedAttentionBackward;
#[cfg(feature = "cuda")]
use axonml_autograd::grad_fn::GradFn;
use axonml_tensor::Tensor;

use crate::layers::Linear;
use crate::module::Module;
use crate::parameter::Parameter;

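/// Multi-head scaled dot-product attention with separate Q/K/V and output
/// projections.
///
/// A minimal usage sketch, assuming this type is reachable from the crate
/// root (marked `ignore` because the exact crate paths may differ):
///
/// ```ignore
/// let mha = MultiHeadAttention::new(512, 8);
/// // input: [batch, seq, embed_dim] with the default batch_first = true
/// let output = mha.forward(&input);
/// assert_eq!(output.shape(), vec![2, 10, 512]);
/// ```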
pub struct MultiHeadAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,
    /// Precomputed 1 / sqrt(head_dim), applied to the raw attention scores.
    scale: f32,
    /// If true, inputs are [batch, seq, embed]; otherwise [seq, batch, embed].
    batch_first: bool,
}

impl MultiHeadAttention {
    /// Creates a multi-head attention layer with dropout 0.0 and
    /// `batch_first = true`.
    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
        Self::with_options(embed_dim, num_heads, 0.0, true)
    }

    /// Creates a multi-head attention layer with explicit options.
    ///
    /// `dropout` is accepted for API parity but is not currently applied.
    ///
    /// # Panics
    ///
    /// Panics if `embed_dim` is not divisible by `num_heads`.
    pub fn with_options(
        embed_dim: usize,
        num_heads: usize,
        _dropout: f32,
        batch_first: bool,
    ) -> Self {
        assert!(
            embed_dim % num_heads == 0,
            "embed_dim must be divisible by num_heads"
        );

        let head_dim = embed_dim / num_heads;
        // Standard attention scaling: 1 / sqrt(head_dim).
        let scale = (head_dim as f32).sqrt().recip();

        Self {
            q_proj: Linear::new(embed_dim, embed_dim),
            k_proj: Linear::new(embed_dim, embed_dim),
            v_proj: Linear::new(embed_dim, embed_dim),
            out_proj: Linear::new(embed_dim, embed_dim),
            embed_dim,
            num_heads,
            head_dim,
            scale,
            batch_first,
        }
    }

    /// Placeholder for a GPU mask-expansion fast path; currently always
    /// returns `None`, so callers fall through to the CPU implementation.
    #[allow(unused_variables)]
    #[allow(dead_code)]
    fn try_gpu_mask_expand(
        mask_data: &Tensor<f32>,
        mask_shape: &[usize],
        scores_shape: &[usize],
        device: axonml_core::Device,
        total: usize,
        batch_size: usize,
        num_heads: usize,
        tgt_len: usize,
        src_len: usize,
    ) -> Option<Variable> {
        None
    }

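    /// Computes multi-head attention over `query`, `key`, and `value`.
    ///
    /// `attn_mask` entries equal to 0.0 are masked out (a large negative
    /// value is added to the corresponding scores before softmax); nonzero
    /// entries leave the scores unchanged. Masks shaped `[tgt_len, src_len]`
    /// or `[batch, src_len]` are broadcast; other shapes are tiled
    /// element-wise as a fallback.
    ///
    /// A minimal cross-attention sketch (crate paths assumed, hence
    /// `ignore`):
    ///
    /// ```ignore
    /// // query: [2, 5, 64], key/value: [2, 10, 64], batch_first = true
    /// let out = mha.attention(&query, &kv, &kv, None);
    /// assert_eq!(out.shape(), vec![2, 5, 64]);
    /// ```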
    pub fn attention(
        &self,
        query: &Variable,
        key: &Variable,
        value: &Variable,
        attn_mask: Option<&Variable>,
    ) -> Variable {
        let q_shape = query.shape();
        let (batch_size, tgt_len, _) = if self.batch_first {
            (q_shape[0], q_shape[1], q_shape[2])
        } else {
            (q_shape[1], q_shape[0], q_shape[2])
        };
        let src_len = if self.batch_first {
            key.shape()[1]
        } else {
            key.shape()[0]
        };

        let q = self.q_proj.forward(query);
        let k = self.k_proj.forward(key);
        let v = self.v_proj.forward(value);

        // Split embed_dim into heads: [batch, seq, embed]
        // -> [batch, num_heads, seq, head_dim].
        let q = q
            .reshape(&[batch_size, tgt_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let k = k
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let v = v
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);

        // Fused CUDA fast path (no mask support): compute attention in one
        // kernel launch and attach the fused backward when gradients are
        // needed.
        #[cfg(feature = "cuda")]
        if q.data().device().is_gpu() && attn_mask.is_none() {
            let grad_enabled = axonml_autograd::no_grad::is_grad_enabled();
            let q_tensor = q.data();
            let k_tensor = k.data();
            let v_tensor = v.data();

            if let Some(attn_out) =
                q_tensor.fused_attention_cuda(&k_tensor, &v_tensor, self.scale, false)
            {
                let attn_output = if grad_enabled
                    && (q.requires_grad() || k.requires_grad() || v.requires_grad())
                {
                    let backward = FusedAttentionBackward::new(
                        q.grad_fn().cloned(),
                        k.grad_fn().cloned(),
                        v.grad_fn().cloned(),
                        q_tensor,
                        k_tensor,
                        v_tensor,
                        attn_out.clone(),
                        self.scale,
                        false,
                    );
                    Variable::from_operation(attn_out, GradFn::new(backward), true)
                } else {
                    Variable::new(attn_out, false)
                };
                let attn_output = attn_output
                    .transpose(1, 2)
                    .reshape(&[batch_size, tgt_len, self.embed_dim]);
                return self.out_proj.forward(&attn_output);
            }
        }

        // Generic path: scores = (Q K^T) * scale, plus an optional additive mask.
        let k_t = k.transpose(2, 3);
        let scores = q.matmul(&k_t).mul_scalar(self.scale);

        let scores = if let Some(mask) = attn_mask {
            let mask_shape = mask.shape();
            let mask_data = mask.data();
            let scores_shape = scores.shape();
            let total = scores_shape.iter().product::<usize>();

            // GPU fast path: expand the mask to the score shape on-device.
            #[cfg(feature = "cuda")]
            if scores.data().device().is_gpu() {
                let mask_gpu = if mask_data.device().is_gpu() {
                    mask_data.clone()
                } else {
                    mask_data.to_device(scores.data().device()).unwrap()
                };

                if let Some(expanded_tensor) = mask_gpu.mask_expand_cuda(
                    &scores_shape,
                    batch_size,
                    self.num_heads,
                    tgt_len,
                    src_len,
                ) {
                    let additive_mask = Variable::new(expanded_tensor, false);
                    return self.finish_attention(
                        scores.add_var(&additive_mask),
                        &v,
                        batch_size,
                        tgt_len,
                    );
                }
            }

            // CPU fallback: convert the boolean-style mask (0.0 = masked) to
            // an additive mask (-1e9 = masked, 0.0 = visible).
            let mask_vec = mask_data.to_vec();
            let additive: Vec<f32> = mask_vec
                .iter()
                .map(|&v| if v == 0.0 { -1e9 } else { 0.0 })
                .collect();

            let mut expanded = vec![0.0f32; total];

            if mask_shape.len() == 2 && mask_shape[0] == tgt_len && mask_shape[1] == src_len {
                // [tgt_len, src_len] mask: broadcast over batch and heads.
                for b in 0..batch_size {
                    for h in 0..self.num_heads {
                        for i in 0..tgt_len {
                            for j in 0..src_len {
                                let idx = b * self.num_heads * tgt_len * src_len
                                    + h * tgt_len * src_len
                                    + i * src_len
                                    + j;
                                expanded[idx] = additive[i * src_len + j];
                            }
                        }
                    }
                }
            } else if mask_shape.len() == 2
                && mask_shape[0] == batch_size
                && mask_shape[1] == src_len
            {
                // [batch, src_len] key-padding mask: broadcast over heads and
                // query positions.
                for b in 0..batch_size {
                    for h in 0..self.num_heads {
                        for i in 0..tgt_len {
                            for j in 0..src_len {
                                let idx = b * self.num_heads * tgt_len * src_len
                                    + h * tgt_len * src_len
                                    + i * src_len
                                    + j;
                                expanded[idx] = additive[b * src_len + j];
                            }
                        }
                    }
                }
            } else {
                // Fallback: tile the mask cyclically over the score buffer.
                for (i, val) in expanded.iter_mut().enumerate() {
                    *val = additive[i % additive.len()];
                }
            }

            let mut additive_tensor = Tensor::from_vec(expanded, &scores_shape).unwrap();
            let scores_device = scores.data().device();
            if scores_device.is_gpu() {
                additive_tensor = additive_tensor.to_device(scores_device).unwrap();
            }
            let additive_mask = Variable::new(additive_tensor, false);
            scores.add_var(&additive_mask)
        } else {
            scores
        };

        self.finish_attention(scores, &v, batch_size, tgt_len)
    }

    /// Applies softmax over the last axis, weights the values, and merges the
    /// heads back into `[batch, tgt_len, embed_dim]` before the output
    /// projection.
    fn finish_attention(
        &self,
        scores: Variable,
        v: &Variable,
        batch_size: usize,
        tgt_len: usize,
    ) -> Variable {
        let attn_weights = scores.softmax(-1);
        let attn_output = attn_weights.matmul(v);
        let attn_output = attn_output
            .transpose(1, 2)
            .reshape(&[batch_size, tgt_len, self.embed_dim]);
        self.out_proj.forward(&attn_output)
    }
}

impl Module for MultiHeadAttention {
    /// Self-attention: query, key, and value are all `input`.
    fn forward(&self, input: &Variable) -> Variable {
        self.attention(input, input, input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.q_proj.parameters());
        params.extend(self.k_proj.parameters());
        params.extend(self.v_proj.parameters());
        params.extend(self.out_proj.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.q_proj.named_parameters() {
            params.insert(format!("q_proj.{name}"), param);
        }
        for (name, param) in self.k_proj.named_parameters() {
            params.insert(format!("k_proj.{name}"), param);
        }
        for (name, param) in self.v_proj.named_parameters() {
            params.insert(format!("v_proj.{name}"), param);
        }
        for (name, param) in self.out_proj.named_parameters() {
            params.insert(format!("out_proj.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "MultiHeadAttention"
    }
}

/// Cross-attention: queries attend over a separate `memory` sequence.
/// Wraps `MultiHeadAttention`, using `memory` as both key and value.
pub struct CrossAttention {
    mha: MultiHeadAttention,
}

impl CrossAttention {
    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
        Self {
            mha: MultiHeadAttention::new(embed_dim, num_heads),
        }
    }

    pub fn with_options(
        embed_dim: usize,
        num_heads: usize,
        dropout: f32,
        batch_first: bool,
    ) -> Self {
        Self {
            mha: MultiHeadAttention::with_options(embed_dim, num_heads, dropout, batch_first),
        }
    }

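    /// Attends `query` over `memory` (used as both key and value).
    ///
    /// A minimal sketch, assuming `[batch, seq, embed_dim]` inputs with the
    /// default `batch_first = true` (crate paths assumed, hence `ignore`):
    ///
    /// ```ignore
    /// let ca = CrossAttention::new(64, 4);
    /// // query: [2, 5, 64], memory: [2, 10, 64]
    /// let out = ca.cross_attention(&query, &memory, None);
    /// assert_eq!(out.shape(), vec![2, 5, 64]);
    /// ```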
    pub fn cross_attention(
        &self,
        query: &Variable,
        memory: &Variable,
        attn_mask: Option<&Variable>,
    ) -> Variable {
        self.mha.attention(query, memory, memory, attn_mask)
    }

    pub fn embed_dim(&self) -> usize {
        self.mha.embed_dim
    }

    pub fn num_heads(&self) -> usize {
        self.mha.num_heads
    }
}

impl Module for CrossAttention {
    /// Falls back to plain self-attention over `input`.
    fn forward(&self, input: &Variable) -> Variable {
        self.mha.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.mha.parameters()
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.mha.named_parameters() {
            params.insert(format!("mha.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "CrossAttention"
    }
}

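/// Reference scaled dot-product attention: softmax((Q K^T) * scale) V, with an
/// optional causal mask that hides keys at positions j > i.
///
/// Expects `q`, `k`, and `v` shaped `[batch, heads, seq, head_dim]` and
/// returns the same shape. Dispatches to the fused CUDA kernel when
/// available; otherwise runs a numerically stabilized CPU loop
/// (max-subtracted softmax). A minimal sketch (crate paths assumed, hence
/// `ignore`):
///
/// ```ignore
/// // q, k, v: [batch, heads, seq, head_dim]
/// let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
/// assert_eq!(out.shape(), q.shape());
/// ```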
pub fn scaled_dot_product_attention_fused(
    q: &Tensor<f32>,
    k: &Tensor<f32>,
    v: &Tensor<f32>,
    scale: f32,
    is_causal: bool,
) -> Tensor<f32> {
    #[cfg(feature = "cuda")]
    if q.device().is_gpu() {
        if let Some(result) = q.fused_attention_cuda(k, v, scale, is_causal) {
            return result;
        }
    }

    let shape = q.shape();
    let batch_size = shape[0];
    let num_heads = shape[1];
    let tgt_len = shape[2];
    let head_dim = shape[3];
    let src_len = k.shape()[2];

    let q_data = q.to_vec();
    let k_data = k.to_vec();
    let v_data = v.to_vec();

    let mut output = vec![0.0f32; batch_size * num_heads * tgt_len * head_dim];

    for b in 0..batch_size {
        for h in 0..num_heads {
            for i in 0..tgt_len {
                // Raw scores for query i against every key, tracking the row
                // maximum for a numerically stable softmax.
                let mut scores = vec![0.0f32; src_len];
                let mut max_score = f32::NEG_INFINITY;

                for j in 0..src_len {
                    if is_causal && j > i {
                        scores[j] = f32::NEG_INFINITY;
                        continue;
                    }
                    let mut score = 0.0f32;
                    for d in 0..head_dim {
                        let q_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
                        let k_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
                        score += q_data[q_idx] * k_data[k_idx];
                    }
                    score *= scale;
                    scores[j] = score;
                    if score > max_score {
                        max_score = score;
                    }
                }

                // Softmax: exponentiate shifted scores; masked entries stay 0.
                let mut sum_exp = 0.0f32;
                for s in &mut scores {
                    if *s > f32::NEG_INFINITY {
                        *s = (*s - max_score).exp();
                        sum_exp += *s;
                    } else {
                        *s = 0.0;
                    }
                }
                let inv_sum = if sum_exp > 0.0 { 1.0 / sum_exp } else { 0.0 };

                // Weighted sum of values, normalized by the softmax denominator.
                for d in 0..head_dim {
                    let mut val = 0.0f32;
                    for j in 0..src_len {
                        let v_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
                        val += scores[j] * v_data[v_idx];
                    }
                    let out_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
                    output[out_idx] = val * inv_sum;
                }
            }
        }
    }

    Tensor::from_vec(output, &[batch_size, num_heads, tgt_len, head_dim]).unwrap()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multihead_attention_creation() {
        let mha = MultiHeadAttention::new(512, 8);
        assert_eq!(mha.embed_dim, 512);
        assert_eq!(mha.num_heads, 8);
        assert_eq!(mha.head_dim, 64);
    }

    #[test]
    fn test_multihead_attention_forward() {
        let mha = MultiHeadAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }
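
    // A hedged addition: per the CPU mask path above, nonzero mask entries
    // contribute an additive 0.0, so an all-ones [tgt_len, src_len] mask
    // should reproduce the unmasked output.
    #[test]
    fn test_multihead_attention_all_ones_mask_is_noop() {
        let mha = MultiHeadAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let mask = Variable::new(
            Tensor::from_vec(vec![1.0; 10 * 10], &[10, 10]).unwrap(),
            false,
        );
        let unmasked = mha.attention(&input, &input, &input, None);
        let masked = mha.attention(&input, &input, &input, Some(&mask));
        let a = unmasked.data().to_vec();
        let b = masked.data().to_vec();
        for (x, y) in a.iter().zip(b.iter()) {
            assert!((x - y).abs() < 1e-5, "masked output diverged: {} vs {}", x, y);
        }
    }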

    #[test]
    fn test_cross_attention() {
        let mha = MultiHeadAttention::new(64, 4);
        let query = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let key_value = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = mha.attention(&query, &key_value, &key_value, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_multihead_attention_parameters() {
        let mha = MultiHeadAttention::new(64, 4);
        let params = mha.parameters();
        // Four Linear layers, each with a weight and a bias.
        assert_eq!(params.len(), 8);
    }

    #[test]
    fn test_cross_attention_creation() {
        let ca = CrossAttention::new(256, 8);
        assert_eq!(ca.embed_dim(), 256);
        assert_eq!(ca.num_heads(), 8);
    }

    #[test]
    fn test_cross_attention_forward() {
        let ca = CrossAttention::new(64, 4);
        let query = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = ca.cross_attention(&query, &memory, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_cross_attention_self_attention_fallback() {
        let ca = CrossAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 8 * 64], &[2, 8, 64]).unwrap(),
            false,
        );
        let output = ca.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }

    #[test]
    fn test_cross_attention_parameters() {
        let ca = CrossAttention::new(64, 4);
        let params = ca.parameters();
        assert_eq!(params.len(), 8);

        let named = ca.named_parameters();
        assert!(named.contains_key("mha.q_proj.weight"));
        assert!(named.contains_key("mha.out_proj.bias"));
    }

    #[test]
    fn test_fused_attention_cpu() {
        let batch = 2;
        let heads = 4;
        let seq = 8;
        let dim = 16;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        // Uniform keys give uniform attention weights, so every output element
        // is a convex combination of identical values: exactly 0.5.
        let out_vec = out.to_vec();
        for val in &out_vec {
            assert!((*val - 0.5).abs() < 0.01, "Expected ~0.5, got {}", val);
        }
    }

    #[test]
    fn test_fused_attention_causal() {
        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        // v is the 4x4 identity, so causal row i averages the first i + 1
        // identity rows.
        let v = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
            ],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        // Row 0 can only attend to itself, so it reproduces v's first row.
        let out_vec = out.to_vec();
        assert!(
            (out_vec[0] - 1.0).abs() < 1e-5,
            "row 0, col 0 should be 1.0"
        );
        assert!((out_vec[1]).abs() < 1e-5, "row 0, col 1 should be 0.0");
    }
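
    // A hedged addition extending the causal test: with uniform q and k the
    // visible scores are equal, so row i's weights are uniform (1 / (i + 1))
    // over positions 0..=i. With the identity v, row 1 should therefore be
    // [0.5, 0.5, 0.0, 0.0].
    #[test]
    fn test_fused_attention_causal_uniform_weights() {
        let scale = 1.0 / (4.0f32).sqrt();
        let q = Tensor::from_vec(vec![0.1; 16], &[1, 1, 4, 4]).unwrap();
        let k = Tensor::from_vec(vec![0.1; 16], &[1, 1, 4, 4]).unwrap();
        let v = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
            ],
            &[1, 1, 4, 4],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, true).to_vec();
        let expected = [0.5, 0.5, 0.0, 0.0];
        for (col, &want) in expected.iter().enumerate() {
            // Row 1 starts at index 4 in the flattened [1, 1, 4, 4] output.
            assert!(
                (out[4 + col] - want).abs() < 1e-5,
                "row 1, col {}: expected {}, got {}",
                col,
                want,
                out[4 + col]
            );
        }
    }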

    #[test]
    fn test_multihead_attention_backward_cpu() {
        use axonml_autograd::backward;

        let mha = MultiHeadAttention::new(32, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 4 * 32], &[2, 4, 32]).unwrap(),
            true,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 32]);

        let loss = output.sum();
        let ones = Tensor::from_vec(vec![1.0f32], &[1]).unwrap();
        backward(&loss, &ones);

        let grad = input.grad();
        assert!(grad.is_some(), "Input gradient should exist");
        let grad_data = grad.unwrap();
        assert_eq!(grad_data.shape(), &[2, 4, 32]);

        let grad_vec = grad_data.to_vec();
        let non_zero = grad_vec.iter().any(|&v| v.abs() > 1e-10);
        assert!(non_zero, "Gradients should be non-zero");
    }

    #[test]
    fn test_fused_attention_backward_cpu() {
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 2;
        let seq = 4;
        let dim = 8;
        let scale = 1.0 / (dim as f32).sqrt();

        let q_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.01).sin())
            .collect();
        let k_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.02).cos())
            .collect();
        let v_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.03).sin() + 0.5)
            .collect();

        let q = Tensor::from_vec(q_data, &[batch, heads, seq, dim]).unwrap();
        let k = Tensor::from_vec(k_data, &[batch, heads, seq, dim]).unwrap();
        let v = Tensor::from_vec(v_data, &[batch, heads, seq, dim]).unwrap();

        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(output.shape(), &[batch, heads, seq, dim]);

        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            false,
        );

        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let grads = backward_fn.apply(&grad_output);
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().expect("grad_Q should exist");
        let gk = grads[1].as_ref().expect("grad_K should exist");
        let gv = grads[2].as_ref().expect("grad_V should exist");

        assert_eq!(gq.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gk.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gv.shape(), &[batch, heads, seq, dim]);

        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }

        let gv_nonzero = gv.to_vec().iter().any(|&v| v.abs() > 1e-10);
        assert!(gv_nonzero, "grad_V should be non-zero");
    }

    #[test]
    fn test_fused_attention_backward_causal_cpu() {
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.2f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);

        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            true,
        );

        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let grads = backward_fn.apply(&grad_output);
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().unwrap();
        let gk = grads[1].as_ref().unwrap();
        let gv = grads[2].as_ref().unwrap();

        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }
    }
}