//! Attention Mechanisms - Multi-Head Attention
//!
//! # File
//! `crates/axonml-nn/src/layers/attention.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::collections::HashMap;

use axonml_autograd::Variable;
// CUDA-only imports: autograd plumbing for the fused attention kernel path.
#[cfg(feature = "cuda")]
use axonml_autograd::functions::FusedAttentionBackward;
#[cfg(feature = "cuda")]
use axonml_autograd::grad_fn::GradFn;
use axonml_tensor::Tensor;

use crate::layers::Linear;
use crate::module::Module;
use crate::parameter::Parameter;
30// =============================================================================
31// MultiHeadAttention
32// =============================================================================
33
/// Multi-Head Attention mechanism.
///
/// Allows the model to jointly attend to information from different
/// representation subspaces at different positions.
///
/// # Arguments
/// * `embed_dim` - Total dimension of the model
/// * `num_heads` - Number of parallel attention heads
/// * `dropout` - Dropout probability (default: 0.0).
///   NOTE(review): the dropout argument is currently ignored by
///   `with_options` (bound as `_dropout`); no dropout is applied.
///
/// # Shape
/// - Query: (L, N, E) or (N, L, E) if batch_first
/// - Key: (S, N, E) or (N, S, E) if batch_first
/// - Value: (S, N, E) or (N, S, E) if batch_first
/// - Output: (L, N, E) or (N, L, E) if batch_first
pub struct MultiHeadAttention {
    /// Query projection (embed_dim → embed_dim).
    q_proj: Linear,
    /// Key projection (embed_dim → embed_dim).
    k_proj: Linear,
    /// Value projection (embed_dim → embed_dim).
    v_proj: Linear,
    /// Output projection applied after head concatenation.
    out_proj: Linear,
    /// Embedding dimension (E).
    embed_dim: usize,
    /// Number of attention heads (H).
    num_heads: usize,
    /// Dimension per head (E / H).
    head_dim: usize,
    /// Scaling factor, `1 / sqrt(head_dim)`.
    scale: f32,
    /// Whether input is batch first (N, L, E) rather than (L, N, E).
    batch_first: bool,
}
69
impl MultiHeadAttention {
    /// Creates a new MultiHeadAttention module with default options
    /// (dropout = 0.0, batch_first = true).
    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
        Self::with_options(embed_dim, num_heads, 0.0, true)
    }

    /// Creates MultiHeadAttention with all options.
    ///
    /// NOTE(review): `_dropout` is accepted for API compatibility but is
    /// currently discarded — no dropout is applied by this layer.
    ///
    /// # Panics
    /// Panics if `embed_dim` is not divisible by `num_heads`.
    pub fn with_options(
        embed_dim: usize,
        num_heads: usize,
        _dropout: f32,
        batch_first: bool,
    ) -> Self {
        assert!(
            embed_dim % num_heads == 0,
            "embed_dim must be divisible by num_heads"
        );

        let head_dim = embed_dim / num_heads;
        // Standard transformer scaling: 1 / sqrt(head_dim).
        let scale = (head_dim as f32).sqrt().recip();

        Self {
            q_proj: Linear::new(embed_dim, embed_dim),
            k_proj: Linear::new(embed_dim, embed_dim),
            v_proj: Linear::new(embed_dim, embed_dim),
            out_proj: Linear::new(embed_dim, embed_dim),
            embed_dim,
            num_heads,
            head_dim,
            scale,
            batch_first,
        }
    }

    /// Try to expand attention mask on GPU via CUDA kernels.
    /// Returns None to fall through to CPU expansion.
    ///
    /// Currently a stub: always returns `None` (see TODO below), so callers
    /// always take the CPU expansion path through this function.
    #[allow(unused_variables)]
    #[allow(dead_code)]
    fn try_gpu_mask_expand(
        mask_data: &Tensor<f32>,
        mask_shape: &[usize],
        scores_shape: &[usize],
        device: axonml_core::Device,
        total: usize,
        batch_size: usize,
        num_heads: usize,
        tgt_len: usize,
        src_len: usize,
    ) -> Option<Variable> {
        // TODO: CUDA kernel for mask broadcast expansion
        None
    }

    /// Computes attention using batched matmul (BLAS-accelerated).
    ///
    /// # Arguments
    /// * `query` - Query tensor; target sequence length is read from it
    /// * `key` / `value` - Source tensors; source sequence length from `key`
    /// * `attn_mask` - Optional mask; zeros become a -1e9 additive penalty
    ///
    /// # Returns
    /// Attention output after the final output projection.
    pub fn attention(
        &self,
        query: &Variable,
        key: &Variable,
        value: &Variable,
        attn_mask: Option<&Variable>,
    ) -> Variable {
        // Batch and sequence axes depend on the batch_first layout.
        let q_shape = query.shape();
        let (batch_size, tgt_len, _) = if self.batch_first {
            (q_shape[0], q_shape[1], q_shape[2])
        } else {
            (q_shape[1], q_shape[0], q_shape[2])
        };
        let src_len = if self.batch_first {
            key.shape()[1]
        } else {
            key.shape()[0]
        };

        // Project Q, K, V  (all tracked through autograd)
        let q = self.q_proj.forward(query);
        let k = self.k_proj.forward(key);
        let v = self.v_proj.forward(value);

        // Reshape to multi-head: [batch, seq, embed] → [batch, seq, heads, head_dim]
        // Then transpose to:     [batch, heads, seq, head_dim]
        //
        // NOTE(review): these reshapes assume batch-first memory layout. When
        // batch_first == false the projected tensors are still (L, N, E), and
        // reshaping directly to [batch, tgt_len, heads, head_dim] looks like it
        // scrambles batch and sequence — confirm whether a transpose(0, 1)
        // should precede this for the seq-first layout.
        let q = q
            .reshape(&[batch_size, tgt_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let k = k
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let v = v
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);

        // GPU fast path: fused attention kernel avoids materializing the N*N
        // attention matrix. Works for both inference and training when no mask
        // is provided. The kernel computes Q@K^T * scale -> softmax -> @V in
        // one pass per row. In training mode, a FusedAttentionBackward autograd
        // function is attached that uses the CUDA backward kernel (or CPU fallback).
        #[cfg(feature = "cuda")]
        if q.data().device().is_gpu() && attn_mask.is_none() {
            // "Training" here means gradient recording is currently enabled.
            let is_training = axonml_autograd::no_grad::is_grad_enabled();
            let q_tensor = q.data();
            let k_tensor = k.data();
            let v_tensor = v.data();

            if let Some(attn_out) = q_tensor.fused_attention_cuda(
                &k_tensor, &v_tensor, self.scale,
                false, // not causal by default; causal mask would be in attn_mask
            ) {
                let attn_output = if is_training
                    && (q.requires_grad() || k.requires_grad() || v.requires_grad())
                {
                    // Build autograd backward function for training
                    let backward = FusedAttentionBackward::new(
                        q.grad_fn().cloned(),
                        k.grad_fn().cloned(),
                        v.grad_fn().cloned(),
                        q_tensor,
                        k_tensor,
                        v_tensor,
                        attn_out.clone(),
                        self.scale,
                        false,
                    );
                    Variable::from_operation(attn_out, GradFn::new(backward), true)
                } else {
                    // Inference: no backward graph needed.
                    Variable::new(attn_out, false)
                };
                // Merge heads back: [B, H, T, Dh] → [B, T, E], then project.
                let attn_output =
                    attn_output
                        .transpose(1, 2)
                        .reshape(&[batch_size, tgt_len, self.embed_dim]);
                return self.out_proj.forward(&attn_output);
            }
            // Fall through to standard path if fused kernel fails
        }

        // Scaled dot-product attention: scores = Q @ K^T * scale
        // K^T: [batch, heads, head_dim, src_len]
        let k_t = k.transpose(2, 3);
        // scores: [batch, heads, tgt_len, src_len]
        let scores = q.matmul(&k_t).mul_scalar(self.scale);

        // Apply attention mask (0 → -1e9 additive mask)
        // Mask shapes: [tgt_len, src_len] (causal) or [batch, src_len] (padding)
        // Scores shape: [batch, heads, tgt_len, src_len]
        let scores = if let Some(mask) = attn_mask {
            let mask_shape = mask.shape();
            let mask_data = mask.data();
            let scores_shape = scores.shape();
            let total = scores_shape.iter().product::<usize>();

            // GPU fast path: expand mask entirely on GPU via CUDA kernel
            // Avoids GPU→CPU→GPU round-trip (9 mask expansions per forward pass)
            #[cfg(feature = "cuda")]
            if scores.data().device().is_gpu() {
                // Ensure mask is on GPU (it's small, so upload is cheap if needed)
                let mask_gpu = if mask_data.device().is_gpu() {
                    mask_data.clone()
                } else {
                    mask_data.to_device(scores.data().device()).unwrap()
                };

                if let Some(expanded_tensor) = mask_gpu.mask_expand_cuda(
                    &scores_shape,
                    batch_size,
                    self.num_heads,
                    tgt_len,
                    src_len,
                ) {
                    // The expanded mask is a constant (no gradient through it).
                    let additive_mask = Variable::new(expanded_tensor, false);
                    return self.finish_attention(
                        scores.add_var(&additive_mask),
                        &v,
                        batch_size,
                        tgt_len,
                    );
                }
                // Fall through to CPU path on unsupported shape
            }

            // CPU fallback: expand mask with nested loops.
            // Convention: mask value 0.0 blocks attention (-1e9), non-zero allows.
            let mask_vec = mask_data.to_vec();
            let additive: Vec<f32> = mask_vec
                .iter()
                .map(|&v| if v == 0.0 { -1e9 } else { 0.0 })
                .collect();

            let mut expanded = vec![0.0f32; total];

            if mask_shape.len() == 2 && mask_shape[0] == tgt_len && mask_shape[1] == src_len {
                // Causal mask [tgt_len, src_len] → broadcast over batch & heads
                for b in 0..batch_size {
                    for h in 0..self.num_heads {
                        for i in 0..tgt_len {
                            for j in 0..src_len {
                                let idx = b * self.num_heads * tgt_len * src_len
                                    + h * tgt_len * src_len
                                    + i * src_len
                                    + j;
                                expanded[idx] = additive[i * src_len + j];
                            }
                        }
                    }
                }
            } else if mask_shape.len() == 2
                && mask_shape[0] == batch_size
                && mask_shape[1] == src_len
            {
                // Padding mask [batch, src_len] → broadcast over heads & tgt positions
                for b in 0..batch_size {
                    for h in 0..self.num_heads {
                        for i in 0..tgt_len {
                            for j in 0..src_len {
                                let idx = b * self.num_heads * tgt_len * src_len
                                    + h * tgt_len * src_len
                                    + i * src_len
                                    + j;
                                expanded[idx] = additive[b * src_len + j];
                            }
                        }
                    }
                }
            } else {
                // General: tile mask across scores using modular indexing.
                // NOTE(review): this wraparound tiling is only correct when the
                // mask's layout happens to align with the trailing axes of the
                // scores tensor — verify callers never rely on other shapes.
                for (i, val) in expanded.iter_mut().enumerate() {
                    *val = additive[i % additive.len()];
                }
            }

            let mut additive_tensor = Tensor::from_vec(expanded, &scores_shape).unwrap();
            let scores_device = scores.data().device();
            if scores_device.is_gpu() {
                additive_tensor = additive_tensor.to_device(scores_device).unwrap();
            }
            let additive_mask = Variable::new(additive_tensor, false);
            scores.add_var(&additive_mask)
        } else {
            scores
        };

        self.finish_attention(scores, &v, batch_size, tgt_len)
    }

    /// Softmax → weighted sum → reshape → output projection.
    /// Shared by both GPU and CPU mask expansion paths.
    ///
    /// `scores` is [batch, heads, tgt_len, src_len]; `v` is
    /// [batch, heads, src_len, head_dim]. Returns [batch, tgt_len, embed_dim]
    /// after the output projection.
    fn finish_attention(
        &self,
        scores: Variable,
        v: &Variable,
        batch_size: usize,
        tgt_len: usize,
    ) -> Variable {
        // Softmax over the source (last) axis, then attend over V.
        let attn_weights = scores.softmax(-1);
        let attn_output = attn_weights.matmul(v);
        // Merge heads: [B, H, T, Dh] → [B, T, H, Dh] → [B, T, E].
        let attn_output =
            attn_output
                .transpose(1, 2)
                .reshape(&[batch_size, tgt_len, self.embed_dim]);
        self.out_proj.forward(&attn_output)
    }
}
329
330impl Module for MultiHeadAttention {
331    fn forward(&self, input: &Variable) -> Variable {
332        // Self-attention: query = key = value = input
333        self.attention(input, input, input, None)
334    }
335
336    fn parameters(&self) -> Vec<Parameter> {
337        let mut params = Vec::new();
338        params.extend(self.q_proj.parameters());
339        params.extend(self.k_proj.parameters());
340        params.extend(self.v_proj.parameters());
341        params.extend(self.out_proj.parameters());
342        params
343    }
344
345    fn named_parameters(&self) -> HashMap<String, Parameter> {
346        let mut params = HashMap::new();
347        for (name, param) in self.q_proj.named_parameters() {
348            params.insert(format!("q_proj.{name}"), param);
349        }
350        for (name, param) in self.k_proj.named_parameters() {
351            params.insert(format!("k_proj.{name}"), param);
352        }
353        for (name, param) in self.v_proj.named_parameters() {
354            params.insert(format!("v_proj.{name}"), param);
355        }
356        for (name, param) in self.out_proj.named_parameters() {
357            params.insert(format!("out_proj.{name}"), param);
358        }
359        params
360    }
361
362    fn name(&self) -> &'static str {
363        "MultiHeadAttention"
364    }
365}
366
367// =============================================================================
368// CrossAttention
369// =============================================================================
370
/// Cross-Attention mechanism for encoder-decoder architectures.
///
/// Queries come from the decoder, keys and values come from the encoder.
/// This is the standard cross-attention used in Transformer decoders,
/// seq2seq models, and vision-language models.
///
/// # Shape (batch_first=true)
/// - Query (decoder): (N, T, E)
/// - Memory (encoder): (N, S, E)
/// - Output: (N, T, E)
///
/// where N=batch, T=target seq len, S=source seq len, E=embed_dim.
pub struct CrossAttention {
    /// Underlying multi-head attention; cross-attention is implemented by
    /// feeding the encoder memory as both key and value.
    mha: MultiHeadAttention,
}
387
388impl CrossAttention {
389    /// Creates a new CrossAttention module.
390    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
391        Self {
392            mha: MultiHeadAttention::new(embed_dim, num_heads),
393        }
394    }
395
396    /// Creates CrossAttention with all options.
397    pub fn with_options(
398        embed_dim: usize,
399        num_heads: usize,
400        dropout: f32,
401        batch_first: bool,
402    ) -> Self {
403        Self {
404            mha: MultiHeadAttention::with_options(embed_dim, num_heads, dropout, batch_first),
405        }
406    }
407
408    /// Computes cross-attention.
409    ///
410    /// # Arguments
411    /// * `query` - Decoder hidden states (N, T, E)
412    /// * `memory` - Encoder output (N, S, E)
413    /// * `attn_mask` - Optional attention mask
414    pub fn cross_attention(
415        &self,
416        query: &Variable,
417        memory: &Variable,
418        attn_mask: Option<&Variable>,
419    ) -> Variable {
420        self.mha.attention(query, memory, memory, attn_mask)
421    }
422
423    /// Returns the embedding dimension.
424    pub fn embed_dim(&self) -> usize {
425        self.mha.embed_dim
426    }
427
428    /// Returns the number of heads.
429    pub fn num_heads(&self) -> usize {
430        self.mha.num_heads
431    }
432}
433
434impl Module for CrossAttention {
435    fn forward(&self, input: &Variable) -> Variable {
436        // When called as Module (single input), acts as self-attention.
437        // Use cross_attention() for encoder-decoder attention.
438        self.mha.forward(input)
439    }
440
441    fn parameters(&self) -> Vec<Parameter> {
442        self.mha.parameters()
443    }
444
445    fn named_parameters(&self) -> HashMap<String, Parameter> {
446        let mut params = HashMap::new();
447        for (name, param) in self.mha.named_parameters() {
448            params.insert(format!("mha.{name}"), param);
449        }
450        params
451    }
452
453    fn name(&self) -> &'static str {
454        "CrossAttention"
455    }
456}
457
458// =============================================================================
459// Fused Scaled Dot-Product Attention
460// =============================================================================
461
462/// Fused scaled dot-product attention on Tensors.
463///
464/// Computes `softmax(Q @ K^T * scale) @ V` without materializing the full
465/// N*N attention matrix. On GPU, this uses a CUDA kernel that processes one
466/// query row per thread. On CPU, falls back to standard matmul + softmax.
467///
468/// # Arguments
469/// * `q` - Query tensor `[B, H, Tq, D]`
470/// * `k` - Key tensor `[B, H, Tk, D]`
471/// * `v` - Value tensor `[B, H, Tk, D]`
472/// * `scale` - Scaling factor (typically `1/sqrt(head_dim)`)
473/// * `is_causal` - Whether to apply causal masking
474///
475/// # Returns
476/// Output tensor `[B, H, Tq, D]`
477///
478/// # Note on Flash Attention
479/// This is a fused kernel (Option B) — it avoids the N*N memory allocation
480/// but each thread still iterates over the full key sequence. True Flash
481/// Attention (Option A) would use shared memory tiling with online softmax
482/// (the Dao et al. algorithm), processing in blocks of ~64-128 and requiring:
483/// - Shared memory for Q/K/V tile loading
484/// - Online softmax with running max/sum correction across tiles
485/// - Two passes: forward for output + logsumexp, backward with recomputation
486/// - ~3x more complex kernel code but O(N) memory and better cache behavior
487///
488/// For sequences up to ~2048, this fused kernel provides most of the benefit.
489/// For longer sequences, use the tiled CPU Flash Attention in `axonml-llm`.
490pub fn scaled_dot_product_attention_fused(
491    q: &Tensor<f32>,
492    k: &Tensor<f32>,
493    v: &Tensor<f32>,
494    scale: f32,
495    is_causal: bool,
496) -> Tensor<f32> {
497    // Try GPU fused kernel
498    #[cfg(feature = "cuda")]
499    if q.device().is_gpu() {
500        if let Some(result) = q.fused_attention_cuda(k, v, scale, is_causal) {
501            return result;
502        }
503    }
504
505    // CPU fallback: standard matmul-based attention
506    let shape = q.shape();
507    let batch_size = shape[0];
508    let num_heads = shape[1];
509    let tgt_len = shape[2];
510    let head_dim = shape[3];
511    let src_len = k.shape()[2];
512
513    let q_data = q.to_vec();
514    let k_data = k.to_vec();
515    let v_data = v.to_vec();
516
517    let mut output = vec![0.0f32; batch_size * num_heads * tgt_len * head_dim];
518
519    for b in 0..batch_size {
520        for h in 0..num_heads {
521            for i in 0..tgt_len {
522                // Compute attention scores for row i
523                let mut scores = vec![0.0f32; src_len];
524                let mut max_score = f32::NEG_INFINITY;
525
526                for j in 0..src_len {
527                    if is_causal && j > i {
528                        scores[j] = f32::NEG_INFINITY;
529                        continue;
530                    }
531                    let mut score = 0.0f32;
532                    for d in 0..head_dim {
533                        let q_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
534                        let k_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
535                        score += q_data[q_idx] * k_data[k_idx];
536                    }
537                    score *= scale;
538                    scores[j] = score;
539                    if score > max_score {
540                        max_score = score;
541                    }
542                }
543
544                // Softmax
545                let mut sum_exp = 0.0f32;
546                for s in &mut scores {
547                    if *s > f32::NEG_INFINITY {
548                        *s = (*s - max_score).exp();
549                        sum_exp += *s;
550                    } else {
551                        *s = 0.0;
552                    }
553                }
554                let inv_sum = if sum_exp > 0.0 { 1.0 / sum_exp } else { 0.0 };
555
556                // Weighted sum of V
557                for d in 0..head_dim {
558                    let mut val = 0.0f32;
559                    for j in 0..src_len {
560                        let v_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
561                        val += scores[j] * v_data[v_idx];
562                    }
563                    let out_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
564                    output[out_idx] = val * inv_sum;
565                }
566            }
567        }
568    }
569
570    Tensor::from_vec(output, &[batch_size, num_heads, tgt_len, head_dim]).unwrap()
571}
572
573// =============================================================================
574// Tests
575// =============================================================================
576
#[cfg(test)]
mod tests {
    //! Unit tests covering construction, forward shapes, parameter
    //! registration, the fused CPU attention path, and gradient flow.
    use super::*;

    #[test]
    fn test_multihead_attention_creation() {
        let mha = MultiHeadAttention::new(512, 8);
        assert_eq!(mha.embed_dim, 512);
        assert_eq!(mha.num_heads, 8);
        // head_dim = embed_dim / num_heads = 512 / 8
        assert_eq!(mha.head_dim, 64);
    }

    #[test]
    fn test_multihead_attention_forward() {
        let mha = MultiHeadAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = mha.forward(&input);
        // Self-attention preserves the input shape.
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_cross_attention() {
        let mha = MultiHeadAttention::new(64, 4);
        let query = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let key_value = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = mha.attention(&query, &key_value, &key_value, None);
        // Output length follows the query (target) sequence length.
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_multihead_attention_parameters() {
        let mha = MultiHeadAttention::new(64, 4);
        let params = mha.parameters();
        // Q, K, V, Out projections each have weight + bias = 8 total
        assert_eq!(params.len(), 8);
    }

    #[test]
    fn test_cross_attention_creation() {
        let ca = CrossAttention::new(256, 8);
        assert_eq!(ca.embed_dim(), 256);
        assert_eq!(ca.num_heads(), 8);
    }

    #[test]
    fn test_cross_attention_forward() {
        let ca = CrossAttention::new(64, 4);
        // Decoder query: (batch=2, tgt_len=5, embed=64)
        let query = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        // Encoder memory: (batch=2, src_len=10, embed=64)
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = ca.cross_attention(&query, &memory, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_cross_attention_self_attention_fallback() {
        let ca = CrossAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 8 * 64], &[2, 8, 64]).unwrap(),
            false,
        );
        // Module::forward does self-attention
        let output = ca.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }

    #[test]
    fn test_cross_attention_parameters() {
        let ca = CrossAttention::new(64, 4);
        let params = ca.parameters();
        assert_eq!(params.len(), 8); // Q, K, V, Out × (weight + bias)
        // Named parameters are re-keyed under the "mha." prefix.
        let named = ca.named_parameters();
        assert!(named.contains_key("mha.q_proj.weight"));
        assert!(named.contains_key("mha.out_proj.bias"));
    }

    #[test]
    fn test_fused_attention_cpu() {
        // Test fused attention on CPU (fallback path)
        let batch = 2;
        let heads = 4;
        let seq = 8;
        let dim = 16;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        // With uniform V=0.5, output should be close to 0.5
        // (softmax weights sum to 1, so any convex combination of 0.5 is 0.5).
        let out_vec = out.to_vec();
        for val in &out_vec {
            assert!((*val - 0.5).abs() < 0.01, "Expected ~0.5, got {}", val);
        }
    }

    #[test]
    fn test_fused_attention_causal() {
        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        // Q and K are identity-like so attention focuses on matching positions
        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        // V rows are one-hot so each output row reveals which V rows were mixed.
        let v = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
            ],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        // First position can only attend to position 0, so output = V[0] = [1,0,0,0]
        let out_vec = out.to_vec();
        assert!(
            (out_vec[0] - 1.0).abs() < 1e-5,
            "row 0, col 0 should be 1.0"
        );
        assert!((out_vec[1]).abs() < 1e-5, "row 0, col 1 should be 0.0");
    }

    #[test]
    fn test_multihead_attention_backward_cpu() {
        // Test that gradients flow through MHA in training mode (CPU path)
        use axonml_autograd::backward;

        let mha = MultiHeadAttention::new(32, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 4 * 32], &[2, 4, 32]).unwrap(),
            true,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 32]);

        // Sum the output and backward
        let loss = output.sum();
        let ones = Tensor::from_vec(vec![1.0f32], &[1]).unwrap();
        backward(&loss, &ones);

        // Input should have gradients
        let grad = input.grad();
        assert!(grad.is_some(), "Input gradient should exist");
        let grad_data = grad.unwrap();
        assert_eq!(grad_data.shape(), &[2, 4, 32]);

        // Gradients should be non-zero
        let grad_vec = grad_data.to_vec();
        let non_zero = grad_vec.iter().any(|&v| v.abs() > 1e-10);
        assert!(non_zero, "Gradients should be non-zero");
    }

    #[test]
    fn test_fused_attention_backward_cpu() {
        // Test the FusedAttentionBackward autograd function directly on CPU
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 2;
        let seq = 4;
        let dim = 8;
        let scale = 1.0 / (dim as f32).sqrt();

        // Create random-ish tensors (deterministic trig patterns, no RNG dep)
        let q_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.01).sin())
            .collect();
        let k_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.02).cos())
            .collect();
        let v_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.03).sin() + 0.5)
            .collect();

        let q = Tensor::from_vec(q_data, &[batch, heads, seq, dim]).unwrap();
        let k = Tensor::from_vec(k_data, &[batch, heads, seq, dim]).unwrap();
        let v = Tensor::from_vec(v_data, &[batch, heads, seq, dim]).unwrap();

        // Compute forward output using the fused CPU path
        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(output.shape(), &[batch, heads, seq, dim]);

        // Create backward function (no upstream grad_fns: leaf inputs)
        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            false,
        );

        // Use ones as grad_output
        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let grads = backward_fn.apply(&grad_output);
        // Expect exactly grad_Q, grad_K, grad_V.
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().expect("grad_Q should exist");
        let gk = grads[1].as_ref().expect("grad_K should exist");
        let gv = grads[2].as_ref().expect("grad_V should exist");

        assert_eq!(gq.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gk.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gv.shape(), &[batch, heads, seq, dim]);

        // Gradients should be finite
        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }

        // grad_V should be non-zero (it's P^T @ grad_output)
        let gv_nonzero = gv.to_vec().iter().any(|&v| v.abs() > 1e-10);
        assert!(gv_nonzero, "grad_V should be non-zero");
    }

    #[test]
    fn test_fused_attention_backward_causal_cpu() {
        // Test the backward with causal masking
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.2f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);

        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            true,
        );

        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let grads = backward_fn.apply(&grad_output);
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().unwrap();
        let gk = grads[1].as_ref().unwrap();
        let gv = grads[2].as_ref().unwrap();

        // All grads should be finite
        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }
    }
}