Skip to main content

axonml_nn/layers/
attention.rs

1//! Attention Mechanisms - Multi-Head Attention
2//!
3//! # File
4//! `crates/axonml-nn/src/layers/attention.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20#[cfg(feature = "cuda")]
21use axonml_autograd::functions::FusedAttentionBackward;
22#[cfg(feature = "cuda")]
23use axonml_autograd::grad_fn::GradFn;
24use axonml_tensor::Tensor;
25
26use crate::layers::Linear;
27use crate::module::Module;
28use crate::parameter::Parameter;
29
30// =============================================================================
31// MultiHeadAttention
32// =============================================================================
33
/// Multi-Head Attention mechanism.
///
/// Allows the model to jointly attend to information from different
/// representation subspaces at different positions.
///
/// # Arguments
/// * `embed_dim` - Total dimension of the model
/// * `num_heads` - Number of parallel attention heads
/// * `dropout` - Dropout probability (default: 0.0)
///   NOTE(review): the dropout value is accepted by `with_options` but is not
///   applied anywhere in this file — confirm whether dropout is intended.
///
/// # Shape
/// - Query: (L, N, E) or (N, L, E) if batch_first
/// - Key: (S, N, E) or (N, S, E) if batch_first
/// - Value: (S, N, E) or (N, S, E) if batch_first
/// - Output: (L, N, E) or (N, L, E) if batch_first
pub struct MultiHeadAttention {
    /// Query projection (embed_dim → embed_dim).
    q_proj: Linear,
    /// Key projection (embed_dim → embed_dim).
    k_proj: Linear,
    /// Value projection (embed_dim → embed_dim).
    v_proj: Linear,
    /// Output projection applied after head concatenation.
    out_proj: Linear,
    /// Embedding dimension (E).
    embed_dim: usize,
    /// Number of attention heads (H).
    num_heads: usize,
    /// Dimension per head: embed_dim / num_heads.
    head_dim: usize,
    /// Scaling factor 1/sqrt(head_dim) applied to Q·K^T.
    scale: f32,
    /// Whether inputs/outputs use (N, L, E) layout instead of (L, N, E).
    batch_first: bool,
}
69
70impl MultiHeadAttention {
71    /// Creates a new MultiHeadAttention module.
72    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
73        Self::with_options(embed_dim, num_heads, 0.0, true)
74    }
75
76    /// Creates MultiHeadAttention with all options.
77    pub fn with_options(
78        embed_dim: usize,
79        num_heads: usize,
80        _dropout: f32,
81        batch_first: bool,
82    ) -> Self {
83        assert!(
84            embed_dim % num_heads == 0,
85            "embed_dim must be divisible by num_heads"
86        );
87
88        let head_dim = embed_dim / num_heads;
89        let scale = (head_dim as f32).sqrt().recip();
90
91        Self {
92            q_proj: Linear::new(embed_dim, embed_dim),
93            k_proj: Linear::new(embed_dim, embed_dim),
94            v_proj: Linear::new(embed_dim, embed_dim),
95            out_proj: Linear::new(embed_dim, embed_dim),
96            embed_dim,
97            num_heads,
98            head_dim,
99            scale,
100            batch_first,
101        }
102    }
103
104    /// Expand attention mask to match scores shape via broadcast.
105    ///
106    /// Handles common mask shapes:
107    /// - [T, S] → [B, H, T, S] (same mask for all batches/heads)
108    /// - [B, 1, T, S] → [B, H, T, S] (per-batch, shared across heads)
109    /// - [B, H, T, S] → no expansion needed
110    ///
111    /// Works on both CPU and GPU — uses Variable::expand which preserves device.
112    fn expand_mask(
113        mask: &Variable,
114        batch_size: usize,
115        num_heads: usize,
116        tgt_len: usize,
117        src_len: usize,
118    ) -> Variable {
119        let mask_shape = mask.shape();
120        let target = [batch_size, num_heads, tgt_len, src_len];
121
122        if mask_shape == target {
123            return mask.clone();
124        }
125
126        // [T, S] → [1, 1, T, S] → expand to [B, H, T, S]
127        if mask_shape.len() == 2 {
128            let reshaped = mask.reshape(&[1, 1, tgt_len, src_len]);
129            return reshaped.expand(&target);
130        }
131
132        // [B, 1, T, S] → expand heads dim
133        if mask_shape.len() == 4 && mask_shape[1] == 1 {
134            return mask.expand(&target);
135        }
136
137        // [1, 1, T, S] → expand both
138        if mask_shape.len() == 4 && mask_shape[0] == 1 && mask_shape[1] == 1 {
139            return mask.expand(&target);
140        }
141
142        // Fallback: just clone
143        mask.clone()
144    }
145
146    /// Computes attention using batched matmul (BLAS-accelerated).
147    pub fn attention(
148        &self,
149        query: &Variable,
150        key: &Variable,
151        value: &Variable,
152        attn_mask: Option<&Variable>,
153    ) -> Variable {
154        let q_shape = query.shape();
155        let (batch_size, tgt_len, _) = if self.batch_first {
156            (q_shape[0], q_shape[1], q_shape[2])
157        } else {
158            (q_shape[1], q_shape[0], q_shape[2])
159        };
160        let src_len = if self.batch_first {
161            key.shape()[1]
162        } else {
163            key.shape()[0]
164        };
165
166        // Project Q, K, V  (all tracked through autograd)
167        let q = self.q_proj.forward(query);
168        let k = self.k_proj.forward(key);
169        let v = self.v_proj.forward(value);
170
171        // Reshape to multi-head: [batch, seq, embed] → [batch, seq, heads, head_dim]
172        // Then transpose to:     [batch, heads, seq, head_dim]
173        let q = q
174            .reshape(&[batch_size, tgt_len, self.num_heads, self.head_dim])
175            .transpose(1, 2);
176        let k = k
177            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
178            .transpose(1, 2);
179        let v = v
180            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
181            .transpose(1, 2);
182
183        // GPU fast path: fused attention kernel avoids materializing the N*N
184        // attention matrix. Works for both inference and training when no mask
185        // is provided. The kernel computes Q@K^T * scale -> softmax -> @V in
186        // one pass per row. In training mode, a FusedAttentionBackward autograd
187        // function is attached that uses the CUDA backward kernel (or CPU fallback).
188        #[cfg(feature = "cuda")]
189        if q.data().device().is_gpu() && attn_mask.is_none() {
190            let is_training = axonml_autograd::no_grad::is_grad_enabled();
191            let q_tensor = q.data();
192            let k_tensor = k.data();
193            let v_tensor = v.data();
194
195            if let Some(attn_out) = q_tensor.fused_attention_cuda(
196                &k_tensor, &v_tensor, self.scale,
197                false, // not causal by default; causal mask would be in attn_mask
198            ) {
199                let attn_output = if is_training
200                    && (q.requires_grad() || k.requires_grad() || v.requires_grad())
201                {
202                    // Build autograd backward function for training
203                    let backward = FusedAttentionBackward::new(
204                        q.grad_fn().cloned(),
205                        k.grad_fn().cloned(),
206                        v.grad_fn().cloned(),
207                        q_tensor,
208                        k_tensor,
209                        v_tensor,
210                        attn_out.clone(),
211                        self.scale,
212                        false,
213                    );
214                    Variable::from_operation(attn_out, GradFn::new(backward), true)
215                } else {
216                    Variable::new(attn_out, false)
217                };
218                let attn_output =
219                    attn_output
220                        .transpose(1, 2)
221                        .reshape(&[batch_size, tgt_len, self.embed_dim]);
222                return self.out_proj.forward(&attn_output);
223            }
224            // Fall through to standard path if fused kernel fails
225        }
226
227        // Scaled dot-product attention: scores = Q @ K^T * scale
228        // K^T: [batch, heads, head_dim, src_len]
229        let k_t = k.transpose(2, 3);
230        // scores: [batch, heads, tgt_len, src_len]
231        let scores = q.matmul(&k_t).mul_scalar(self.scale);
232
233        // Apply attention mask (0 → -1e9 additive mask)
234        // Mask shapes: [tgt_len, src_len] (causal) or [batch, src_len] (padding)
235        // Scores shape: [batch, heads, tgt_len, src_len]
236        let scores = if let Some(mask) = attn_mask {
237            let mask_shape = mask.shape();
238            let mask_data = mask.data();
239            let scores_shape = scores.shape();
240            let total = scores_shape.iter().product::<usize>();
241
242            // GPU fast path: expand mask entirely on GPU via CUDA kernel
243            // Avoids GPU→CPU→GPU round-trip (9 mask expansions per forward pass)
244            #[cfg(feature = "cuda")]
245            if scores.data().device().is_gpu() {
246                // Ensure mask is on GPU (it's small, so upload is cheap if needed)
247                let mask_gpu = if mask_data.device().is_gpu() {
248                    mask_data.clone()
249                } else {
250                    mask_data.to_device(scores.data().device()).unwrap()
251                };
252
253                if let Some(expanded_tensor) = mask_gpu.mask_expand_cuda(
254                    &scores_shape,
255                    batch_size,
256                    self.num_heads,
257                    tgt_len,
258                    src_len,
259                ) {
260                    let additive_mask = Variable::new(expanded_tensor, false);
261                    return self.finish_attention(
262                        scores.add_var(&additive_mask),
263                        &v,
264                        batch_size,
265                        tgt_len,
266                    );
267                }
268                // Fall through to CPU path on unsupported shape
269            }
270
271            // CPU fallback: expand mask with nested loops
272            let mask_vec = mask_data.to_vec();
273            let additive: Vec<f32> = mask_vec
274                .iter()
275                .map(|&v| if v == 0.0 { -1e9 } else { 0.0 })
276                .collect();
277
278            let mut expanded = vec![0.0f32; total];
279
280            if mask_shape.len() == 2 && mask_shape[0] == tgt_len && mask_shape[1] == src_len {
281                // Causal mask [tgt_len, src_len] → broadcast over batch & heads
282                for b in 0..batch_size {
283                    for h in 0..self.num_heads {
284                        for i in 0..tgt_len {
285                            for j in 0..src_len {
286                                let idx = b * self.num_heads * tgt_len * src_len
287                                    + h * tgt_len * src_len
288                                    + i * src_len
289                                    + j;
290                                expanded[idx] = additive[i * src_len + j];
291                            }
292                        }
293                    }
294                }
295            } else if mask_shape.len() == 2
296                && mask_shape[0] == batch_size
297                && mask_shape[1] == src_len
298            {
299                // Padding mask [batch, src_len] → broadcast over heads & tgt positions
300                for b in 0..batch_size {
301                    for h in 0..self.num_heads {
302                        for i in 0..tgt_len {
303                            for j in 0..src_len {
304                                let idx = b * self.num_heads * tgt_len * src_len
305                                    + h * tgt_len * src_len
306                                    + i * src_len
307                                    + j;
308                                expanded[idx] = additive[b * src_len + j];
309                            }
310                        }
311                    }
312                }
313            } else {
314                // General: tile mask across scores using modular indexing
315                for (i, val) in expanded.iter_mut().enumerate() {
316                    *val = additive[i % additive.len()];
317                }
318            }
319
320            let mut additive_tensor = Tensor::from_vec(expanded, &scores_shape).expect("tensor creation failed");
321            let scores_device = scores.data().device();
322            if scores_device.is_gpu() {
323                additive_tensor = additive_tensor.to_device(scores_device).expect("device transfer failed");
324            }
325            let additive_mask = Variable::new(additive_tensor, false);
326            scores.add_var(&additive_mask)
327        } else {
328            scores
329        };
330
331        self.finish_attention(scores, &v, batch_size, tgt_len)
332    }
333
334    /// Softmax → weighted sum → reshape → output projection.
335    /// Shared by both GPU and CPU mask expansion paths.
336    fn finish_attention(
337        &self,
338        scores: Variable,
339        v: &Variable,
340        batch_size: usize,
341        tgt_len: usize,
342    ) -> Variable {
343        let attn_weights = scores.softmax(-1);
344        let attn_output = attn_weights.matmul(v);
345        let attn_output =
346            attn_output
347                .transpose(1, 2)
348                .reshape(&[batch_size, tgt_len, self.embed_dim]);
349        self.out_proj.forward(&attn_output)
350    }
351}
352
353impl Module for MultiHeadAttention {
354    fn forward(&self, input: &Variable) -> Variable {
355        // Self-attention: query = key = value = input
356        self.attention(input, input, input, None)
357    }
358
359    fn parameters(&self) -> Vec<Parameter> {
360        let mut params = Vec::new();
361        params.extend(self.q_proj.parameters());
362        params.extend(self.k_proj.parameters());
363        params.extend(self.v_proj.parameters());
364        params.extend(self.out_proj.parameters());
365        params
366    }
367
368    fn named_parameters(&self) -> HashMap<String, Parameter> {
369        let mut params = HashMap::new();
370        for (name, param) in self.q_proj.named_parameters() {
371            params.insert(format!("q_proj.{name}"), param);
372        }
373        for (name, param) in self.k_proj.named_parameters() {
374            params.insert(format!("k_proj.{name}"), param);
375        }
376        for (name, param) in self.v_proj.named_parameters() {
377            params.insert(format!("v_proj.{name}"), param);
378        }
379        for (name, param) in self.out_proj.named_parameters() {
380            params.insert(format!("out_proj.{name}"), param);
381        }
382        params
383    }
384
385    fn name(&self) -> &'static str {
386        "MultiHeadAttention"
387    }
388}
389
390// =============================================================================
391// CrossAttention
392// =============================================================================
393
/// Cross-Attention mechanism for encoder-decoder architectures.
///
/// Queries come from the decoder, keys and values come from the encoder.
/// This is the standard cross-attention used in Transformer decoders,
/// seq2seq models, and vision-language models.
///
/// # Shape (batch_first=true)
/// - Query (decoder): (N, T, E)
/// - Memory (encoder): (N, S, E)
/// - Output: (N, T, E)
///
/// where N=batch, T=target seq len, S=source seq len, E=embed_dim.
pub struct CrossAttention {
    /// Underlying multi-head attention; cross-attention is MHA with
    /// key = value = encoder memory.
    mha: MultiHeadAttention,
}
410
411impl CrossAttention {
412    /// Creates a new CrossAttention module.
413    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
414        Self {
415            mha: MultiHeadAttention::new(embed_dim, num_heads),
416        }
417    }
418
419    /// Creates CrossAttention with all options.
420    pub fn with_options(
421        embed_dim: usize,
422        num_heads: usize,
423        dropout: f32,
424        batch_first: bool,
425    ) -> Self {
426        Self {
427            mha: MultiHeadAttention::with_options(embed_dim, num_heads, dropout, batch_first),
428        }
429    }
430
431    /// Computes cross-attention.
432    ///
433    /// # Arguments
434    /// * `query` - Decoder hidden states (N, T, E)
435    /// * `memory` - Encoder output (N, S, E)
436    /// * `attn_mask` - Optional attention mask
437    pub fn cross_attention(
438        &self,
439        query: &Variable,
440        memory: &Variable,
441        attn_mask: Option<&Variable>,
442    ) -> Variable {
443        self.mha.attention(query, memory, memory, attn_mask)
444    }
445
446    /// Returns the embedding dimension.
447    pub fn embed_dim(&self) -> usize {
448        self.mha.embed_dim
449    }
450
451    /// Returns the number of heads.
452    pub fn num_heads(&self) -> usize {
453        self.mha.num_heads
454    }
455}
456
457impl Module for CrossAttention {
458    fn forward(&self, input: &Variable) -> Variable {
459        // When called as Module (single input), acts as self-attention.
460        // Use cross_attention() for encoder-decoder attention.
461        self.mha.forward(input)
462    }
463
464    fn parameters(&self) -> Vec<Parameter> {
465        self.mha.parameters()
466    }
467
468    fn named_parameters(&self) -> HashMap<String, Parameter> {
469        let mut params = HashMap::new();
470        for (name, param) in self.mha.named_parameters() {
471            params.insert(format!("mha.{name}"), param);
472        }
473        params
474    }
475
476    fn name(&self) -> &'static str {
477        "CrossAttention"
478    }
479}
480
481// =============================================================================
482// Fused Scaled Dot-Product Attention
483// =============================================================================
484
485/// Fused scaled dot-product attention on Tensors.
486///
487/// Computes `softmax(Q @ K^T * scale) @ V` without materializing the full
488/// N*N attention matrix. On GPU, this uses a CUDA kernel that processes one
489/// query row per thread. On CPU, falls back to standard matmul + softmax.
490///
491/// # Arguments
492/// * `q` - Query tensor `[B, H, Tq, D]`
493/// * `k` - Key tensor `[B, H, Tk, D]`
494/// * `v` - Value tensor `[B, H, Tk, D]`
495/// * `scale` - Scaling factor (typically `1/sqrt(head_dim)`)
496/// * `is_causal` - Whether to apply causal masking
497///
498/// # Returns
499/// Output tensor `[B, H, Tq, D]`
500///
501/// # Note on Flash Attention
502/// This is a fused kernel (Option B) — it avoids the N*N memory allocation
503/// but each thread still iterates over the full key sequence. True Flash
504/// Attention (Option A) would use shared memory tiling with online softmax
505/// (the Dao et al. algorithm), processing in blocks of ~64-128 and requiring:
506/// - Shared memory for Q/K/V tile loading
507/// - Online softmax with running max/sum correction across tiles
508/// - Two passes: forward for output + logsumexp, backward with recomputation
509/// - ~3x more complex kernel code but O(N) memory and better cache behavior
510///
511/// For sequences up to ~2048, this fused kernel provides most of the benefit.
512/// For longer sequences, use the tiled CPU Flash Attention in `axonml-llm`.
513pub fn scaled_dot_product_attention_fused(
514    q: &Tensor<f32>,
515    k: &Tensor<f32>,
516    v: &Tensor<f32>,
517    scale: f32,
518    is_causal: bool,
519) -> Tensor<f32> {
520    // Try GPU fused kernel
521    #[cfg(feature = "cuda")]
522    if q.device().is_gpu() {
523        if let Some(result) = q.fused_attention_cuda(k, v, scale, is_causal) {
524            return result;
525        }
526    }
527
528    // CPU fallback: standard matmul-based attention
529    let shape = q.shape();
530    let batch_size = shape[0];
531    let num_heads = shape[1];
532    let tgt_len = shape[2];
533    let head_dim = shape[3];
534    let src_len = k.shape()[2];
535
536    let q_data = q.to_vec();
537    let k_data = k.to_vec();
538    let v_data = v.to_vec();
539
540    let mut output = vec![0.0f32; batch_size * num_heads * tgt_len * head_dim];
541
542    for b in 0..batch_size {
543        for h in 0..num_heads {
544            for i in 0..tgt_len {
545                // Compute attention scores for row i
546                let mut scores = vec![0.0f32; src_len];
547                let mut max_score = f32::NEG_INFINITY;
548
549                for j in 0..src_len {
550                    if is_causal && j > i {
551                        scores[j] = f32::NEG_INFINITY;
552                        continue;
553                    }
554                    let mut score = 0.0f32;
555                    for d in 0..head_dim {
556                        let q_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
557                        let k_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
558                        score += q_data[q_idx] * k_data[k_idx];
559                    }
560                    score *= scale;
561                    scores[j] = score;
562                    if score > max_score {
563                        max_score = score;
564                    }
565                }
566
567                // Softmax
568                let mut sum_exp = 0.0f32;
569                for s in &mut scores {
570                    if *s > f32::NEG_INFINITY {
571                        *s = (*s - max_score).exp();
572                        sum_exp += *s;
573                    } else {
574                        *s = 0.0;
575                    }
576                }
577                let inv_sum = if sum_exp > 0.0 { 1.0 / sum_exp } else { 0.0 };
578
579                // Weighted sum of V
580                for d in 0..head_dim {
581                    let mut val = 0.0f32;
582                    for j in 0..src_len {
583                        let v_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
584                        val += scores[j] * v_data[v_idx];
585                    }
586                    let out_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
587                    output[out_idx] = val * inv_sum;
588                }
589            }
590        }
591    }
592
593    Tensor::from_vec(output, &[batch_size, num_heads, tgt_len, head_dim]).expect("tensor creation failed")
594}
595
596// =============================================================================
597// Tests
598// =============================================================================
599
#[cfg(test)]
mod tests {
    use super::*;

    // Construction stores derived head_dim = embed_dim / num_heads.
    #[test]
    fn test_multihead_attention_creation() {
        let mha = MultiHeadAttention::new(512, 8);
        assert_eq!(mha.embed_dim, 512);
        assert_eq!(mha.num_heads, 8);
        assert_eq!(mha.head_dim, 64);
    }

    // Self-attention forward preserves the input shape (N, L, E).
    #[test]
    fn test_multihead_attention_forward() {
        let mha = MultiHeadAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    // Cross-attention output takes the query's sequence length (T=5),
    // regardless of the key/value sequence length (S=10).
    #[test]
    fn test_cross_attention() {
        let mha = MultiHeadAttention::new(64, 4);
        let query = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let key_value = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = mha.attention(&query, &key_value, &key_value, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_multihead_attention_parameters() {
        let mha = MultiHeadAttention::new(64, 4);
        let params = mha.parameters();
        // Q, K, V, Out projections each have weight + bias = 8 total
        assert_eq!(params.len(), 8);
    }

    #[test]
    fn test_cross_attention_creation() {
        let ca = CrossAttention::new(256, 8);
        assert_eq!(ca.embed_dim(), 256);
        assert_eq!(ca.num_heads(), 8);
    }

    #[test]
    fn test_cross_attention_forward() {
        let ca = CrossAttention::new(64, 4);
        // Decoder query: (batch=2, tgt_len=5, embed=64)
        let query = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        // Encoder memory: (batch=2, src_len=10, embed=64)
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = ca.cross_attention(&query, &memory, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_cross_attention_self_attention_fallback() {
        let ca = CrossAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 8 * 64], &[2, 8, 64]).expect("tensor creation failed"),
            false,
        );
        // Module::forward does self-attention
        let output = ca.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }

    #[test]
    fn test_cross_attention_parameters() {
        let ca = CrossAttention::new(64, 4);
        let params = ca.parameters();
        assert_eq!(params.len(), 8); // Q, K, V, Out × (weight + bias)
        let named = ca.named_parameters();
        assert!(named.contains_key("mha.q_proj.weight"));
        assert!(named.contains_key("mha.out_proj.bias"));
    }

    #[test]
    fn test_fused_attention_cpu() {
        // Test fused attention on CPU (fallback path)
        let batch = 2;
        let heads = 4;
        let seq = 8;
        let dim = 16;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        // With uniform V=0.5, output should be close to 0.5
        // (softmax weights sum to 1, so the weighted sum of identical rows is V).
        let out_vec = out.to_vec();
        for val in &out_vec {
            assert!((*val - 0.5).abs() < 0.01, "Expected ~0.5, got {}", val);
        }
    }

    #[test]
    fn test_fused_attention_causal() {
        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        // Q and K are identity-like so attention focuses on matching positions
        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        // V is the 4x4 identity so each output row reveals which positions
        // were attended to.
        let v = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
            ],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        // First position can only attend to position 0, so output = V[0] = [1,0,0,0]
        let out_vec = out.to_vec();
        assert!(
            (out_vec[0] - 1.0).abs() < 1e-5,
            "row 0, col 0 should be 1.0"
        );
        assert!((out_vec[1]).abs() < 1e-5, "row 0, col 1 should be 0.0");
    }

    #[test]
    fn test_multihead_attention_backward_cpu() {
        // Test that gradients flow through MHA in training mode (CPU path)
        use axonml_autograd::backward;

        let mha = MultiHeadAttention::new(32, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 4 * 32], &[2, 4, 32]).expect("tensor creation failed"),
            true,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 32]);

        // Sum the output and backward
        let loss = output.sum();
        let ones = Tensor::from_vec(vec![1.0f32], &[1]).expect("tensor creation failed");
        backward(&loss, &ones);

        // Input should have gradients
        let grad = input.grad();
        assert!(grad.is_some(), "Input gradient should exist");
        let grad_data = grad.unwrap();
        assert_eq!(grad_data.shape(), &[2, 4, 32]);

        // Gradients should be non-zero
        let grad_vec = grad_data.to_vec();
        let non_zero = grad_vec.iter().any(|&v| v.abs() > 1e-10);
        assert!(non_zero, "Gradients should be non-zero");
    }

    #[test]
    fn test_fused_attention_backward_cpu() {
        // Test the FusedAttentionBackward autograd function directly on CPU
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 2;
        let seq = 4;
        let dim = 8;
        let scale = 1.0 / (dim as f32).sqrt();

        // Create random-ish tensors (deterministic trig patterns).
        let q_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.01).sin())
            .collect();
        let k_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.02).cos())
            .collect();
        let v_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.03).sin() + 0.5)
            .collect();

        let q = Tensor::from_vec(q_data, &[batch, heads, seq, dim]).expect("tensor creation failed");
        let k = Tensor::from_vec(k_data, &[batch, heads, seq, dim]).expect("tensor creation failed");
        let v = Tensor::from_vec(v_data, &[batch, heads, seq, dim]).expect("tensor creation failed");

        // Compute forward output using the fused CPU path
        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(output.shape(), &[batch, heads, seq, dim]);

        // Create backward function (no upstream grad_fns: leaf inputs).
        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            false,
        );

        // Use ones as grad_output
        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let grads = backward_fn.apply(&grad_output);
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().expect("grad_Q should exist");
        let gk = grads[1].as_ref().expect("grad_K should exist");
        let gv = grads[2].as_ref().expect("grad_V should exist");

        assert_eq!(gq.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gk.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gv.shape(), &[batch, heads, seq, dim]);

        // Gradients should be finite
        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }

        // grad_V should be non-zero (it's P^T @ grad_output)
        let gv_nonzero = gv.to_vec().iter().any(|&v| v.abs() > 1e-10);
        assert!(gv_nonzero, "grad_V should be non-zero");
    }

    #[test]
    fn test_fused_attention_backward_causal_cpu() {
        // Test the backward with causal masking
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.2f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);

        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            true,
        );

        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let grads = backward_fn.apply(&grad_output);
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().unwrap();
        let gk = grads[1].as_ref().unwrap();
        let gv = grads[2].as_ref().unwrap();

        // All grads should be finite
        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }
    }
}