// attention.rs — axonml_nn/layers
//! Attention Mechanisms - Multi-Head Attention
//!
//! # File
//! `crates/axonml-nn/src/layers/attention.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::collections::HashMap;

use axonml_autograd::Variable;
#[cfg(feature = "cuda")]
use axonml_autograd::functions::FusedAttentionBackward;
#[cfg(feature = "cuda")]
use axonml_autograd::grad_fn::GradFn;
use axonml_tensor::Tensor;

use crate::layers::Linear;
use crate::module::Module;
use crate::parameter::Parameter;
// =============================================================================
// MultiHeadAttention
// =============================================================================

/// Multi-Head Attention mechanism.
///
/// Allows the model to jointly attend to information from different
/// representation subspaces at different positions.
///
/// # Arguments
/// * `embed_dim` - Total dimension of the model
/// * `num_heads` - Number of parallel attention heads
/// * `dropout` - Dropout probability (default: 0.0). NOTE(review): the
///   constructor currently discards this value (`_dropout`); no dropout is
///   applied anywhere in this module.
///
/// # Shape
/// - Query: (L, N, E) or (N, L, E) if batch_first
/// - Key: (S, N, E) or (N, S, E) if batch_first
/// - Value: (S, N, E) or (N, S, E) if batch_first
/// - Output: (L, N, E) or (N, L, E) if batch_first
pub struct MultiHeadAttention {
    /// Query projection (embed_dim → embed_dim).
    q_proj: Linear,
    /// Key projection (embed_dim → embed_dim).
    k_proj: Linear,
    /// Value projection (embed_dim → embed_dim).
    v_proj: Linear,
    /// Output projection applied after head concatenation.
    out_proj: Linear,
    /// Embedding dimension (E).
    embed_dim: usize,
    /// Number of attention heads (H).
    num_heads: usize,
    /// Dimension per head: embed_dim / num_heads.
    head_dim: usize,
    /// Scaling factor: 1 / sqrt(head_dim).
    scale: f32,
    /// Whether input is batch first, i.e. (N, L, E) rather than (L, N, E).
    batch_first: bool,
}

impl MultiHeadAttention {
    /// Creates a new MultiHeadAttention module.
    ///
    /// Uses defaults of dropout = 0.0 and batch_first = true
    /// (see [`Self::with_options`]).
    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
        Self::with_options(embed_dim, num_heads, 0.0, true)
    }

    /// Creates MultiHeadAttention with all options.
    ///
    /// # Arguments
    /// * `embed_dim` - Total model dimension; must divide evenly by `num_heads`.
    /// * `num_heads` - Number of parallel attention heads.
    /// * `_dropout` - Accepted for API compatibility but currently unused.
    /// * `batch_first` - Whether inputs are (N, L, E) rather than (L, N, E).
    ///
    /// # Panics
    /// Panics if `embed_dim` is not divisible by `num_heads`.
    pub fn with_options(
        embed_dim: usize,
        num_heads: usize,
        _dropout: f32,
        batch_first: bool,
    ) -> Self {
        assert!(
            embed_dim % num_heads == 0,
            "embed_dim must be divisible by num_heads"
        );

        let head_dim = embed_dim / num_heads;
        // Standard scaled-dot-product factor: 1 / sqrt(head_dim).
        let scale = (head_dim as f32).sqrt().recip();

        Self {
            q_proj: Linear::new(embed_dim, embed_dim),
            k_proj: Linear::new(embed_dim, embed_dim),
            v_proj: Linear::new(embed_dim, embed_dim),
            out_proj: Linear::new(embed_dim, embed_dim),
            embed_dim,
            num_heads,
            head_dim,
            scale,
            batch_first,
        }
    }

    /// Expand attention mask to match scores shape via broadcast.
    ///
    /// Handles common mask shapes:
    /// - [T, S] → [B, H, T, S] (same mask for all batches/heads)
    /// - [B, 1, T, S] → [B, H, T, S] (per-batch, shared across heads)
    /// - [B, H, T, S] → no expansion needed
    ///
    /// Works on both CPU and GPU — uses Variable::expand which preserves device.
    #[allow(dead_code)]
    fn expand_mask(
        mask: &Variable,
        batch_size: usize,
        num_heads: usize,
        tgt_len: usize,
        src_len: usize,
    ) -> Variable {
        let mask_shape = mask.shape();
        let target = [batch_size, num_heads, tgt_len, src_len];

        // Already the full [B, H, T, S] shape — nothing to do.
        if mask_shape == target {
            return mask.clone();
        }

        // [T, S] → [1, 1, T, S] → expand to [B, H, T, S]
        if mask_shape.len() == 2 {
            let reshaped = mask.reshape(&[1, 1, tgt_len, src_len]);
            return reshaped.expand(&target);
        }

        // [B, 1, T, S] → expand heads dim
        if mask_shape.len() == 4 && mask_shape[1] == 1 {
            return mask.expand(&target);
        }

        // [1, 1, T, S] → expand both
        // NOTE(review): this arm is unreachable — the previous arm already
        // matches any 4-D mask with shape[1] == 1 and performs the identical
        // expand, so behavior is unaffected. Dead branch kept as-is.
        if mask_shape.len() == 4 && mask_shape[0] == 1 && mask_shape[1] == 1 {
            return mask.expand(&target);
        }

        // Fallback: just clone
        // NOTE(review): unrecognized mask shapes are passed through untouched;
        // a later add against scores may then fail or broadcast unexpectedly.
        mask.clone()
    }

    /// Computes attention using batched matmul (BLAS-accelerated).
    ///
    /// Projects `query`/`key`/`value`, splits heads, applies scaled
    /// dot-product attention (optionally masked), and re-projects the result.
    /// `attn_mask` uses 0 = "blocked" semantics; zeros become a -1e9 additive
    /// penalty before softmax.
    pub fn attention(
        &self,
        query: &Variable,
        key: &Variable,
        value: &Variable,
        attn_mask: Option<&Variable>,
    ) -> Variable {
        let q_shape = query.shape();
        let (batch_size, tgt_len, _) = if self.batch_first {
            (q_shape[0], q_shape[1], q_shape[2])
        } else {
            (q_shape[1], q_shape[0], q_shape[2])
        };
        let src_len = if self.batch_first {
            key.shape()[1]
        } else {
            key.shape()[0]
        };

        // Project Q, K, V  (all tracked through autograd)
        let q = self.q_proj.forward(query);
        let k = self.k_proj.forward(key);
        let v = self.v_proj.forward(value);

        // Reshape to multi-head: [batch, seq, embed] → [batch, seq, heads, head_dim]
        // Then transpose to:     [batch, heads, seq, head_dim]
        // NOTE(review): this reshape assumes [batch, seq, embed] memory layout.
        // When batch_first == false the inputs are (L, N, E), so a transpose(0, 1)
        // would presumably be needed first — confirm with seq-first callers.
        let q = q
            .reshape(&[batch_size, tgt_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let k = k
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let v = v
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);

        // GPU fast path: fused attention kernel avoids materializing the N*N
        // attention matrix. Works for both inference and training when no mask
        // is provided. The kernel computes Q@K^T * scale -> softmax -> @V in
        // one pass per row. In training mode, a FusedAttentionBackward autograd
        // function is attached that uses the CUDA backward kernel (or CPU fallback).
        #[cfg(feature = "cuda")]
        if q.data().device().is_gpu() && attn_mask.is_none() {
            // "Training" here means autograd is currently enabled.
            let is_training = axonml_autograd::no_grad::is_grad_enabled();
            let q_tensor = q.data();
            let k_tensor = k.data();
            let v_tensor = v.data();

            if let Some(attn_out) = q_tensor.fused_attention_cuda(
                &k_tensor, &v_tensor, self.scale,
                false, // not causal by default; causal mask would be in attn_mask
            ) {
                let attn_output = if is_training
                    && (q.requires_grad() || k.requires_grad() || v.requires_grad())
                {
                    // Build autograd backward function for training
                    let backward = FusedAttentionBackward::new(
                        q.grad_fn().cloned(),
                        k.grad_fn().cloned(),
                        v.grad_fn().cloned(),
                        q_tensor,
                        k_tensor,
                        v_tensor,
                        attn_out.clone(),
                        self.scale,
                        false,
                    );
                    Variable::from_operation(attn_out, GradFn::new(backward), true)
                } else {
                    Variable::new(attn_out, false)
                };
                // Merge heads back: [B, H, T, D] → [B, T, H, D] → [B, T, E].
                let attn_output =
                    attn_output
                        .transpose(1, 2)
                        .reshape(&[batch_size, tgt_len, self.embed_dim]);
                return self.out_proj.forward(&attn_output);
            }
            // Fall through to standard path if fused kernel fails
        }

        // Scaled dot-product attention: scores = Q @ K^T * scale
        // K^T: [batch, heads, head_dim, src_len]
        let k_t = k.transpose(2, 3);
        // scores: [batch, heads, tgt_len, src_len]
        let scores = q.matmul(&k_t).mul_scalar(self.scale);

        // Apply attention mask (0 → -1e9 additive mask)
        // Mask shapes: [tgt_len, src_len] (causal) or [batch, src_len] (padding)
        // Scores shape: [batch, heads, tgt_len, src_len]
        let scores = if let Some(mask) = attn_mask {
            let mask_shape = mask.shape();
            let mask_data = mask.data();
            let scores_shape = scores.shape();
            let total = scores_shape.iter().product::<usize>();

            // GPU fast path: expand mask entirely on GPU via CUDA kernel
            // Avoids GPU→CPU→GPU round-trip (9 mask expansions per forward pass)
            #[cfg(feature = "cuda")]
            if scores.data().device().is_gpu() {
                // Ensure mask is on GPU (it's small, so upload is cheap if needed)
                let mask_gpu = if mask_data.device().is_gpu() {
                    mask_data.clone()
                } else {
                    mask_data.to_device(scores.data().device()).unwrap()
                };

                if let Some(expanded_tensor) = mask_gpu.mask_expand_cuda(
                    &scores_shape,
                    batch_size,
                    self.num_heads,
                    tgt_len,
                    src_len,
                ) {
                    // Expanded mask is a constant — no gradient tracking needed.
                    let additive_mask = Variable::new(expanded_tensor, false);
                    return self.finish_attention(
                        scores.add_var(&additive_mask),
                        &v,
                        batch_size,
                        tgt_len,
                    );
                }
                // Fall through to CPU path on unsupported shape
            }

            // CPU fallback: expand mask with nested loops
            let mask_vec = mask_data.to_vec();
            // Convert 0/1 mask to an additive mask: 0 → -1e9 (blocked), else 0.
            let additive: Vec<f32> = mask_vec
                .iter()
                .map(|&v| if v == 0.0 { -1e9 } else { 0.0 })
                .collect();

            let mut expanded = vec![0.0f32; total];

            if mask_shape.len() == 2 && mask_shape[0] == tgt_len && mask_shape[1] == src_len {
                // Causal mask [tgt_len, src_len] → broadcast over batch & heads
                for b in 0..batch_size {
                    for h in 0..self.num_heads {
                        for i in 0..tgt_len {
                            for j in 0..src_len {
                                let idx = b * self.num_heads * tgt_len * src_len
                                    + h * tgt_len * src_len
                                    + i * src_len
                                    + j;
                                expanded[idx] = additive[i * src_len + j];
                            }
                        }
                    }
                }
            } else if mask_shape.len() == 2
                && mask_shape[0] == batch_size
                && mask_shape[1] == src_len
            {
                // Padding mask [batch, src_len] → broadcast over heads & tgt positions
                for b in 0..batch_size {
                    for h in 0..self.num_heads {
                        for i in 0..tgt_len {
                            for j in 0..src_len {
                                let idx = b * self.num_heads * tgt_len * src_len
                                    + h * tgt_len * src_len
                                    + i * src_len
                                    + j;
                                expanded[idx] = additive[b * src_len + j];
                            }
                        }
                    }
                }
            } else {
                // General: tile mask across scores using modular indexing
                // NOTE(review): only exact for masks whose length divides the
                // scores volume with compatible layout — confirm intended use.
                for (i, val) in expanded.iter_mut().enumerate() {
                    *val = additive[i % additive.len()];
                }
            }

            let mut additive_tensor =
                Tensor::from_vec(expanded, &scores_shape).expect("tensor creation failed");
            let scores_device = scores.data().device();
            if scores_device.is_gpu() {
                additive_tensor = additive_tensor
                    .to_device(scores_device)
                    .expect("device transfer failed");
            }
            let additive_mask = Variable::new(additive_tensor, false);
            scores.add_var(&additive_mask)
        } else {
            scores
        };

        self.finish_attention(scores, &v, batch_size, tgt_len)
    }

    /// Softmax → weighted sum → reshape → output projection.
    /// Shared by both GPU and CPU mask expansion paths.
    ///
    /// `scores` is [B, H, T, S]; `v` is [B, H, S, D]. Returns [B, T, E]
    /// after merging heads and applying `out_proj`.
    fn finish_attention(
        &self,
        scores: Variable,
        v: &Variable,
        batch_size: usize,
        tgt_len: usize,
    ) -> Variable {
        // Normalize over the source dimension (last axis).
        let attn_weights = scores.softmax(-1);
        let attn_output = attn_weights.matmul(v);
        // [B, H, T, D] → [B, T, H, D] → [B, T, E]
        let attn_output =
            attn_output
                .transpose(1, 2)
                .reshape(&[batch_size, tgt_len, self.embed_dim]);
        self.out_proj.forward(&attn_output)
    }
}

357impl Module for MultiHeadAttention {
358    fn forward(&self, input: &Variable) -> Variable {
359        // Self-attention: query = key = value = input
360        self.attention(input, input, input, None)
361    }
362
363    fn parameters(&self) -> Vec<Parameter> {
364        let mut params = Vec::new();
365        params.extend(self.q_proj.parameters());
366        params.extend(self.k_proj.parameters());
367        params.extend(self.v_proj.parameters());
368        params.extend(self.out_proj.parameters());
369        params
370    }
371
372    fn named_parameters(&self) -> HashMap<String, Parameter> {
373        let mut params = HashMap::new();
374        for (name, param) in self.q_proj.named_parameters() {
375            params.insert(format!("q_proj.{name}"), param);
376        }
377        for (name, param) in self.k_proj.named_parameters() {
378            params.insert(format!("k_proj.{name}"), param);
379        }
380        for (name, param) in self.v_proj.named_parameters() {
381            params.insert(format!("v_proj.{name}"), param);
382        }
383        for (name, param) in self.out_proj.named_parameters() {
384            params.insert(format!("out_proj.{name}"), param);
385        }
386        params
387    }
388
389    fn name(&self) -> &'static str {
390        "MultiHeadAttention"
391    }
392}
393
// =============================================================================
// CrossAttention
// =============================================================================

/// Cross-Attention mechanism for encoder-decoder architectures.
///
/// Queries come from the decoder, keys and values come from the encoder.
/// This is the standard cross-attention used in Transformer decoders,
/// seq2seq models, and vision-language models.
///
/// Implemented as a thin wrapper that routes `(query, memory, memory)`
/// through an inner [`MultiHeadAttention`].
///
/// # Shape (batch_first=true)
/// - Query (decoder): (N, T, E)
/// - Memory (encoder): (N, S, E)
/// - Output: (N, T, E)
///
/// where N=batch, T=target seq len, S=source seq len, E=embed_dim.
pub struct CrossAttention {
    /// Underlying multi-head attention.
    mha: MultiHeadAttention,
}

415impl CrossAttention {
416    /// Creates a new CrossAttention module.
417    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
418        Self {
419            mha: MultiHeadAttention::new(embed_dim, num_heads),
420        }
421    }
422
423    /// Creates CrossAttention with all options.
424    pub fn with_options(
425        embed_dim: usize,
426        num_heads: usize,
427        dropout: f32,
428        batch_first: bool,
429    ) -> Self {
430        Self {
431            mha: MultiHeadAttention::with_options(embed_dim, num_heads, dropout, batch_first),
432        }
433    }
434
435    /// Computes cross-attention.
436    ///
437    /// # Arguments
438    /// * `query` - Decoder hidden states (N, T, E)
439    /// * `memory` - Encoder output (N, S, E)
440    /// * `attn_mask` - Optional attention mask
441    pub fn cross_attention(
442        &self,
443        query: &Variable,
444        memory: &Variable,
445        attn_mask: Option<&Variable>,
446    ) -> Variable {
447        self.mha.attention(query, memory, memory, attn_mask)
448    }
449
450    /// Returns the embedding dimension.
451    pub fn embed_dim(&self) -> usize {
452        self.mha.embed_dim
453    }
454
455    /// Returns the number of heads.
456    pub fn num_heads(&self) -> usize {
457        self.mha.num_heads
458    }
459}
460
461impl Module for CrossAttention {
462    fn forward(&self, input: &Variable) -> Variable {
463        // When called as Module (single input), acts as self-attention.
464        // Use cross_attention() for encoder-decoder attention.
465        self.mha.forward(input)
466    }
467
468    fn parameters(&self) -> Vec<Parameter> {
469        self.mha.parameters()
470    }
471
472    fn named_parameters(&self) -> HashMap<String, Parameter> {
473        let mut params = HashMap::new();
474        for (name, param) in self.mha.named_parameters() {
475            params.insert(format!("mha.{name}"), param);
476        }
477        params
478    }
479
480    fn name(&self) -> &'static str {
481        "CrossAttention"
482    }
483}
484
// =============================================================================
// Fused Scaled Dot-Product Attention
// =============================================================================

489/// Fused scaled dot-product attention on Tensors.
490///
491/// Computes `softmax(Q @ K^T * scale) @ V` without materializing the full
492/// N*N attention matrix. On GPU, this uses a CUDA kernel that processes one
493/// query row per thread. On CPU, falls back to standard matmul + softmax.
494///
495/// # Arguments
496/// * `q` - Query tensor `[B, H, Tq, D]`
497/// * `k` - Key tensor `[B, H, Tk, D]`
498/// * `v` - Value tensor `[B, H, Tk, D]`
499/// * `scale` - Scaling factor (typically `1/sqrt(head_dim)`)
500/// * `is_causal` - Whether to apply causal masking
501///
502/// # Returns
503/// Output tensor `[B, H, Tq, D]`
504///
505/// # Note on Flash Attention
506/// This is a fused kernel (Option B) — it avoids the N*N memory allocation
507/// but each thread still iterates over the full key sequence. True Flash
508/// Attention (Option A) would use shared memory tiling with online softmax
509/// (the Dao et al. algorithm), processing in blocks of ~64-128 and requiring:
510/// - Shared memory for Q/K/V tile loading
511/// - Online softmax with running max/sum correction across tiles
512/// - Two passes: forward for output + logsumexp, backward with recomputation
513/// - ~3x more complex kernel code but O(N) memory and better cache behavior
514///
515/// For sequences up to ~2048, this fused kernel provides most of the benefit.
516/// For longer sequences, use the tiled CPU Flash Attention in `axonml-llm`.
517pub fn scaled_dot_product_attention_fused(
518    q: &Tensor<f32>,
519    k: &Tensor<f32>,
520    v: &Tensor<f32>,
521    scale: f32,
522    is_causal: bool,
523) -> Tensor<f32> {
524    // Try GPU fused kernel
525    #[cfg(feature = "cuda")]
526    if q.device().is_gpu() {
527        if let Some(result) = q.fused_attention_cuda(k, v, scale, is_causal) {
528            return result;
529        }
530    }
531
532    // CPU fallback: standard matmul-based attention
533    let shape = q.shape();
534    let batch_size = shape[0];
535    let num_heads = shape[1];
536    let tgt_len = shape[2];
537    let head_dim = shape[3];
538    let src_len = k.shape()[2];
539
540    let q_data = q.to_vec();
541    let k_data = k.to_vec();
542    let v_data = v.to_vec();
543
544    let mut output = vec![0.0f32; batch_size * num_heads * tgt_len * head_dim];
545
546    for b in 0..batch_size {
547        for h in 0..num_heads {
548            for i in 0..tgt_len {
549                // Compute attention scores for row i
550                let mut scores = vec![0.0f32; src_len];
551                let mut max_score = f32::NEG_INFINITY;
552
553                for j in 0..src_len {
554                    if is_causal && j > i {
555                        scores[j] = f32::NEG_INFINITY;
556                        continue;
557                    }
558                    let mut score = 0.0f32;
559                    for d in 0..head_dim {
560                        let q_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
561                        let k_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
562                        score += q_data[q_idx] * k_data[k_idx];
563                    }
564                    score *= scale;
565                    scores[j] = score;
566                    if score > max_score {
567                        max_score = score;
568                    }
569                }
570
571                // Softmax
572                let mut sum_exp = 0.0f32;
573                for s in &mut scores {
574                    if *s > f32::NEG_INFINITY {
575                        *s = (*s - max_score).exp();
576                        sum_exp += *s;
577                    } else {
578                        *s = 0.0;
579                    }
580                }
581                let inv_sum = if sum_exp > 0.0 { 1.0 / sum_exp } else { 0.0 };
582
583                // Weighted sum of V
584                for d in 0..head_dim {
585                    let mut val = 0.0f32;
586                    for j in 0..src_len {
587                        let v_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
588                        val += scores[j] * v_data[v_idx];
589                    }
590                    let out_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
591                    output[out_idx] = val * inv_sum;
592                }
593            }
594        }
595    }
596
597    Tensor::from_vec(output, &[batch_size, num_heads, tgt_len, head_dim])
598        .expect("tensor creation failed")
599}
600
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Construction stores the derived head_dim = embed_dim / num_heads.
    #[test]
    fn test_multihead_attention_creation() {
        let mha = MultiHeadAttention::new(512, 8);
        assert_eq!(mha.embed_dim, 512);
        assert_eq!(mha.num_heads, 8);
        assert_eq!(mha.head_dim, 64);
    }

    /// Self-attention forward preserves the (batch, seq, embed) shape.
    #[test]
    fn test_multihead_attention_forward() {
        let mha = MultiHeadAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    /// Cross-attention output takes its sequence length from the query.
    #[test]
    fn test_cross_attention() {
        let mha = MultiHeadAttention::new(64, 4);
        let query = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let key_value = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = mha.attention(&query, &key_value, &key_value, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_multihead_attention_parameters() {
        let mha = MultiHeadAttention::new(64, 4);
        let params = mha.parameters();
        // Q, K, V, Out projections each have weight + bias = 8 total
        assert_eq!(params.len(), 8);
    }

    #[test]
    fn test_cross_attention_creation() {
        let ca = CrossAttention::new(256, 8);
        assert_eq!(ca.embed_dim(), 256);
        assert_eq!(ca.num_heads(), 8);
    }

    #[test]
    fn test_cross_attention_forward() {
        let ca = CrossAttention::new(64, 4);
        // Decoder query: (batch=2, tgt_len=5, embed=64)
        let query = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        // Encoder memory: (batch=2, src_len=10, embed=64)
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = ca.cross_attention(&query, &memory, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_cross_attention_self_attention_fallback() {
        let ca = CrossAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 8 * 64], &[2, 8, 64]).expect("tensor creation failed"),
            false,
        );
        // Module::forward does self-attention
        let output = ca.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }

    #[test]
    fn test_cross_attention_parameters() {
        let ca = CrossAttention::new(64, 4);
        let params = ca.parameters();
        assert_eq!(params.len(), 8); // Q, K, V, Out × (weight + bias)
        let named = ca.named_parameters();
        assert!(named.contains_key("mha.q_proj.weight"));
        assert!(named.contains_key("mha.out_proj.bias"));
    }

    #[test]
    fn test_fused_attention_cpu() {
        // Test fused attention on CPU (fallback path)
        let batch = 2;
        let heads = 4;
        let seq = 8;
        let dim = 16;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        // With uniform V=0.5, output should be close to 0.5
        // (softmax weights sum to 1, so any convex combination of 0.5 is 0.5).
        let out_vec = out.to_vec();
        for val in &out_vec {
            assert!((*val - 0.5).abs() < 0.01, "Expected ~0.5, got {}", val);
        }
    }

    #[test]
    fn test_fused_attention_causal() {
        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        // Q and K are identity-like so attention focuses on matching positions
        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
            ],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        // First position can only attend to position 0, so output = V[0] = [1,0,0,0]
        let out_vec = out.to_vec();
        assert!(
            (out_vec[0] - 1.0).abs() < 1e-5,
            "row 0, col 0 should be 1.0"
        );
        assert!((out_vec[1]).abs() < 1e-5, "row 0, col 1 should be 0.0");
    }

    #[test]
    fn test_multihead_attention_backward_cpu() {
        // Test that gradients flow through MHA in training mode (CPU path)
        use axonml_autograd::backward;

        let mha = MultiHeadAttention::new(32, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 4 * 32], &[2, 4, 32]).expect("tensor creation failed"),
            true,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 32]);

        // Sum the output and backward
        let loss = output.sum();
        let ones = Tensor::from_vec(vec![1.0f32], &[1]).expect("tensor creation failed");
        backward(&loss, &ones);

        // Input should have gradients
        let grad = input.grad();
        assert!(grad.is_some(), "Input gradient should exist");
        let grad_data = grad.unwrap();
        assert_eq!(grad_data.shape(), &[2, 4, 32]);

        // Gradients should be non-zero
        let grad_vec = grad_data.to_vec();
        let non_zero = grad_vec.iter().any(|&v| v.abs() > 1e-10);
        assert!(non_zero, "Gradients should be non-zero");
    }

    #[test]
    fn test_fused_attention_backward_cpu() {
        // Test the FusedAttentionBackward autograd function directly on CPU
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 2;
        let seq = 4;
        let dim = 8;
        let scale = 1.0 / (dim as f32).sqrt();

        // Create random-ish tensors (deterministic trig-based patterns)
        let q_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.01).sin())
            .collect();
        let k_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.02).cos())
            .collect();
        let v_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.03).sin() + 0.5)
            .collect();

        let q =
            Tensor::from_vec(q_data, &[batch, heads, seq, dim]).expect("tensor creation failed");
        let k =
            Tensor::from_vec(k_data, &[batch, heads, seq, dim]).expect("tensor creation failed");
        let v =
            Tensor::from_vec(v_data, &[batch, heads, seq, dim]).expect("tensor creation failed");

        // Compute forward output using the fused CPU path
        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(output.shape(), &[batch, heads, seq, dim]);

        // Create backward function
        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            false,
        );

        // Use ones as grad_output
        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let grads = backward_fn.apply(&grad_output);
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().expect("grad_Q should exist");
        let gk = grads[1].as_ref().expect("grad_K should exist");
        let gv = grads[2].as_ref().expect("grad_V should exist");

        assert_eq!(gq.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gk.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gv.shape(), &[batch, heads, seq, dim]);

        // Gradients should be finite
        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }

        // grad_V should be non-zero (it's P^T @ grad_output)
        let gv_nonzero = gv.to_vec().iter().any(|&v| v.abs() > 1e-10);
        assert!(gv_nonzero, "grad_V should be non-zero");
    }

    #[test]
    fn test_fused_attention_backward_causal_cpu() {
        // Test the backward with causal masking
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.2f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);

        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            true,
        );

        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let grads = backward_fn.apply(&grad_output);
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().unwrap();
        let gk = grads[1].as_ref().unwrap();
        let gv = grads[2].as_ref().unwrap();

        // All grads should be finite
        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }
    }
}