// crates/axonml-nn/src/layers/attention.rs
1//! Attention mechanisms — `MultiHeadAttention` and `CrossAttention`.
2//!
3//! 946 lines. `MultiHeadAttention` (Q/K/V projections, scaled dot-product
4//! with optional causal mask + dropout, output projection, fused CUDA
5//! dispatch via `scaled_dot_product_attention_fused`). `CrossAttention`
6//! (separate Q vs K/V sources for encoder-decoder models). Both support
7//! `with_options` for custom dropout rate, bias, and output projection
8//! configuration.
9//!
10//! # File
11//! `crates/axonml-nn/src/layers/attention.rs`
12//!
13//! # Author
14//! Andrew Jewell Sr. — AutomataNexus LLC
15//! ORCID: 0009-0005-2158-7060
16//!
17//! # Updated
18//! April 14, 2026 11:15 PM EST
19//!
20//! # Disclaimer
21//! Use at own risk. This software is provided "as is", without warranty of any
22//! kind, express or implied. The author and AutomataNexus shall not be held
23//! liable for any damages arising from the use of this software.
24
25use std::collections::HashMap;
26
27use axonml_autograd::Variable;
28#[cfg(feature = "cuda")]
29use axonml_autograd::functions::FusedAttentionBackward;
30#[cfg(feature = "cuda")]
31use axonml_autograd::grad_fn::GradFn;
32use axonml_tensor::Tensor;
33
34use crate::layers::Linear;
35use crate::module::Module;
36use crate::parameter::Parameter;
37
38// =============================================================================
39// MultiHeadAttention
40// =============================================================================
41
/// Multi-Head Attention mechanism.
///
/// Allows the model to jointly attend to information from different
/// representation subspaces at different positions.
///
/// # Arguments
/// * `embed_dim` - Total dimension of the model
/// * `num_heads` - Number of parallel attention heads
/// * `dropout` - Dropout probability (default: 0.0)
///
/// # Shape
/// - Query: (L, N, E) or (N, L, E) if batch_first
/// - Key: (S, N, E) or (N, S, E) if batch_first
/// - Value: (S, N, E) or (N, S, E) if batch_first
/// - Output: (L, N, E) or (N, L, E) if batch_first
pub struct MultiHeadAttention {
    /// Query projection (embed_dim → embed_dim).
    q_proj: Linear,
    /// Key projection (embed_dim → embed_dim).
    k_proj: Linear,
    /// Value projection (embed_dim → embed_dim).
    v_proj: Linear,
    /// Output projection applied after head concatenation (embed_dim → embed_dim).
    out_proj: Linear,
    /// Embedding dimension E (total across all heads).
    embed_dim: usize,
    /// Number of attention heads H.
    num_heads: usize,
    /// Dimension per head: embed_dim / num_heads.
    head_dim: usize,
    /// Score scaling factor: 1 / sqrt(head_dim).
    scale: f32,
    /// Whether input is batch first, i.e. (N, seq, E) rather than (seq, N, E).
    batch_first: bool,
}
77
78impl MultiHeadAttention {
79    /// Creates a new MultiHeadAttention module.
80    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
81        Self::with_options(embed_dim, num_heads, 0.0, true)
82    }
83
84    /// Creates MultiHeadAttention with all options.
85    pub fn with_options(
86        embed_dim: usize,
87        num_heads: usize,
88        _dropout: f32,
89        batch_first: bool,
90    ) -> Self {
91        assert!(
92            embed_dim % num_heads == 0,
93            "embed_dim must be divisible by num_heads"
94        );
95
96        let head_dim = embed_dim / num_heads;
97        let scale = (head_dim as f32).sqrt().recip();
98
99        Self {
100            q_proj: Linear::new(embed_dim, embed_dim),
101            k_proj: Linear::new(embed_dim, embed_dim),
102            v_proj: Linear::new(embed_dim, embed_dim),
103            out_proj: Linear::new(embed_dim, embed_dim),
104            embed_dim,
105            num_heads,
106            head_dim,
107            scale,
108            batch_first,
109        }
110    }
111
112    /// Expand attention mask to match scores shape via broadcast.
113    ///
114    /// Handles common mask shapes:
115    /// - [T, S] → [B, H, T, S] (same mask for all batches/heads)
116    /// - [B, 1, T, S] → [B, H, T, S] (per-batch, shared across heads)
117    /// - [B, H, T, S] → no expansion needed
118    ///
119    /// Works on both CPU and GPU — uses Variable::expand which preserves device.
120    #[allow(dead_code)]
121    fn expand_mask(
122        mask: &Variable,
123        batch_size: usize,
124        num_heads: usize,
125        tgt_len: usize,
126        src_len: usize,
127    ) -> Variable {
128        let mask_shape = mask.shape();
129        let target = [batch_size, num_heads, tgt_len, src_len];
130
131        if mask_shape == target {
132            return mask.clone();
133        }
134
135        // [T, S] → [1, 1, T, S] → expand to [B, H, T, S]
136        if mask_shape.len() == 2 {
137            let reshaped = mask.reshape(&[1, 1, tgt_len, src_len]);
138            return reshaped.expand(&target);
139        }
140
141        // [B, 1, T, S] → expand heads dim
142        if mask_shape.len() == 4 && mask_shape[1] == 1 {
143            return mask.expand(&target);
144        }
145
146        // [1, 1, T, S] → expand both
147        if mask_shape.len() == 4 && mask_shape[0] == 1 && mask_shape[1] == 1 {
148            return mask.expand(&target);
149        }
150
151        // Fallback: just clone
152        mask.clone()
153    }
154
155    /// Computes attention using batched matmul (BLAS-accelerated).
156    pub fn attention(
157        &self,
158        query: &Variable,
159        key: &Variable,
160        value: &Variable,
161        attn_mask: Option<&Variable>,
162    ) -> Variable {
163        let q_shape = query.shape();
164        let (batch_size, tgt_len, _) = if self.batch_first {
165            (q_shape[0], q_shape[1], q_shape[2])
166        } else {
167            (q_shape[1], q_shape[0], q_shape[2])
168        };
169        let src_len = if self.batch_first {
170            key.shape()[1]
171        } else {
172            key.shape()[0]
173        };
174
175        // Project Q, K, V  (all tracked through autograd)
176        let q = self.q_proj.forward(query);
177        let k = self.k_proj.forward(key);
178        let v = self.v_proj.forward(value);
179
180        // Reshape to multi-head: [batch, seq, embed] → [batch, seq, heads, head_dim]
181        // Then transpose to:     [batch, heads, seq, head_dim]
182        let q = q
183            .reshape(&[batch_size, tgt_len, self.num_heads, self.head_dim])
184            .transpose(1, 2);
185        let k = k
186            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
187            .transpose(1, 2);
188        let v = v
189            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
190            .transpose(1, 2);
191
192        // GPU fast path: fused attention kernel avoids materializing the N*N
193        // attention matrix. Works for both inference and training when no mask
194        // is provided. The kernel computes Q@K^T * scale -> softmax -> @V in
195        // one pass per row. In training mode, a FusedAttentionBackward autograd
196        // function is attached that uses the CUDA backward kernel (or CPU fallback).
197        #[cfg(feature = "cuda")]
198        if q.data().device().is_gpu() && attn_mask.is_none() {
199            let is_training = axonml_autograd::no_grad::is_grad_enabled();
200            let q_tensor = q.data();
201            let k_tensor = k.data();
202            let v_tensor = v.data();
203
204            if let Some(attn_out) = q_tensor.fused_attention_cuda(
205                &k_tensor, &v_tensor, self.scale,
206                false, // not causal by default; causal mask would be in attn_mask
207            ) {
208                let attn_output = if is_training
209                    && (q.requires_grad() || k.requires_grad() || v.requires_grad())
210                {
211                    // Build autograd backward function for training
212                    let backward = FusedAttentionBackward::new(
213                        q.grad_fn().cloned(),
214                        k.grad_fn().cloned(),
215                        v.grad_fn().cloned(),
216                        q_tensor,
217                        k_tensor,
218                        v_tensor,
219                        attn_out.clone(),
220                        self.scale,
221                        false,
222                    );
223                    Variable::from_operation(attn_out, GradFn::new(backward), true)
224                } else {
225                    Variable::new(attn_out, false)
226                };
227                let attn_output =
228                    attn_output
229                        .transpose(1, 2)
230                        .reshape(&[batch_size, tgt_len, self.embed_dim]);
231                return self.out_proj.forward(&attn_output);
232            }
233            // Fall through to standard path if fused kernel fails
234        }
235
236        // Scaled dot-product attention: scores = Q @ K^T * scale
237        // K^T: [batch, heads, head_dim, src_len]
238        let k_t = k.transpose(2, 3);
239        // scores: [batch, heads, tgt_len, src_len]
240        let scores = q.matmul(&k_t).mul_scalar(self.scale);
241
242        // Apply attention mask (0 → -1e9 additive mask)
243        // Mask shapes: [tgt_len, src_len] (causal) or [batch, src_len] (padding)
244        // Scores shape: [batch, heads, tgt_len, src_len]
245        let scores = if let Some(mask) = attn_mask {
246            let mask_shape = mask.shape();
247            let mask_data = mask.data();
248            let scores_shape = scores.shape();
249            let total = scores_shape.iter().product::<usize>();
250
251            // GPU fast path: expand mask entirely on GPU via CUDA kernel
252            // Avoids GPU→CPU→GPU round-trip (9 mask expansions per forward pass)
253            #[cfg(feature = "cuda")]
254            if scores.data().device().is_gpu() {
255                // Ensure mask is on GPU (it's small, so upload is cheap if needed)
256                let mask_gpu = if mask_data.device().is_gpu() {
257                    mask_data.clone()
258                } else {
259                    mask_data.to_device(scores.data().device()).unwrap()
260                };
261
262                if let Some(expanded_tensor) = mask_gpu.mask_expand_cuda(
263                    &scores_shape,
264                    batch_size,
265                    self.num_heads,
266                    tgt_len,
267                    src_len,
268                ) {
269                    let additive_mask = Variable::new(expanded_tensor, false);
270                    return self.finish_attention(
271                        scores.add_var(&additive_mask),
272                        &v,
273                        batch_size,
274                        tgt_len,
275                    );
276                }
277                // Fall through to CPU path on unsupported shape
278            }
279
280            // CPU fallback: expand mask with nested loops
281            let mask_vec = mask_data.to_vec();
282            let additive: Vec<f32> = mask_vec
283                .iter()
284                .map(|&v| if v == 0.0 { -1e9 } else { 0.0 })
285                .collect();
286
287            let mut expanded = vec![0.0f32; total];
288
289            if mask_shape.len() == 2 && mask_shape[0] == tgt_len && mask_shape[1] == src_len {
290                // Causal mask [tgt_len, src_len] → broadcast over batch & heads
291                for b in 0..batch_size {
292                    for h in 0..self.num_heads {
293                        for i in 0..tgt_len {
294                            for j in 0..src_len {
295                                let idx = b * self.num_heads * tgt_len * src_len
296                                    + h * tgt_len * src_len
297                                    + i * src_len
298                                    + j;
299                                expanded[idx] = additive[i * src_len + j];
300                            }
301                        }
302                    }
303                }
304            } else if mask_shape.len() == 2
305                && mask_shape[0] == batch_size
306                && mask_shape[1] == src_len
307            {
308                // Padding mask [batch, src_len] → broadcast over heads & tgt positions
309                for b in 0..batch_size {
310                    for h in 0..self.num_heads {
311                        for i in 0..tgt_len {
312                            for j in 0..src_len {
313                                let idx = b * self.num_heads * tgt_len * src_len
314                                    + h * tgt_len * src_len
315                                    + i * src_len
316                                    + j;
317                                expanded[idx] = additive[b * src_len + j];
318                            }
319                        }
320                    }
321                }
322            } else {
323                // General: tile mask across scores using modular indexing
324                for (i, val) in expanded.iter_mut().enumerate() {
325                    *val = additive[i % additive.len()];
326                }
327            }
328
329            let mut additive_tensor =
330                Tensor::from_vec(expanded, &scores_shape).expect("tensor creation failed");
331            let scores_device = scores.data().device();
332            if scores_device.is_gpu() {
333                additive_tensor = additive_tensor
334                    .to_device(scores_device)
335                    .expect("device transfer failed");
336            }
337            let additive_mask = Variable::new(additive_tensor, false);
338            scores.add_var(&additive_mask)
339        } else {
340            scores
341        };
342
343        self.finish_attention(scores, &v, batch_size, tgt_len)
344    }
345
346    /// Softmax → weighted sum → reshape → output projection.
347    /// Shared by both GPU and CPU mask expansion paths.
348    fn finish_attention(
349        &self,
350        scores: Variable,
351        v: &Variable,
352        batch_size: usize,
353        tgt_len: usize,
354    ) -> Variable {
355        let attn_weights = scores.softmax(-1);
356        let attn_output = attn_weights.matmul(v);
357        let attn_output =
358            attn_output
359                .transpose(1, 2)
360                .reshape(&[batch_size, tgt_len, self.embed_dim]);
361        self.out_proj.forward(&attn_output)
362    }
363}
364
365impl Module for MultiHeadAttention {
366    fn forward(&self, input: &Variable) -> Variable {
367        // Self-attention: query = key = value = input
368        self.attention(input, input, input, None)
369    }
370
371    fn parameters(&self) -> Vec<Parameter> {
372        let mut params = Vec::new();
373        params.extend(self.q_proj.parameters());
374        params.extend(self.k_proj.parameters());
375        params.extend(self.v_proj.parameters());
376        params.extend(self.out_proj.parameters());
377        params
378    }
379
380    fn named_parameters(&self) -> HashMap<String, Parameter> {
381        let mut params = HashMap::new();
382        for (name, param) in self.q_proj.named_parameters() {
383            params.insert(format!("q_proj.{name}"), param);
384        }
385        for (name, param) in self.k_proj.named_parameters() {
386            params.insert(format!("k_proj.{name}"), param);
387        }
388        for (name, param) in self.v_proj.named_parameters() {
389            params.insert(format!("v_proj.{name}"), param);
390        }
391        for (name, param) in self.out_proj.named_parameters() {
392            params.insert(format!("out_proj.{name}"), param);
393        }
394        params
395    }
396
397    fn name(&self) -> &'static str {
398        "MultiHeadAttention"
399    }
400}
401
402// =============================================================================
403// CrossAttention
404// =============================================================================
405
/// Cross-Attention mechanism for encoder-decoder architectures.
///
/// Queries come from the decoder, keys and values come from the encoder.
/// This is the standard cross-attention used in Transformer decoders,
/// seq2seq models, and vision-language models.
///
/// # Shape (batch_first=true)
/// - Query (decoder): (N, T, E)
/// - Memory (encoder): (N, S, E)
/// - Output: (N, T, E)
///
/// where N=batch, T=target seq len, S=source seq len, E=embed_dim.
pub struct CrossAttention {
    /// Underlying multi-head attention; cross-attention is MHA with
    /// `key == value == encoder memory`.
    mha: MultiHeadAttention,
}
422
423impl CrossAttention {
424    /// Creates a new CrossAttention module.
425    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
426        Self {
427            mha: MultiHeadAttention::new(embed_dim, num_heads),
428        }
429    }
430
431    /// Creates CrossAttention with all options.
432    pub fn with_options(
433        embed_dim: usize,
434        num_heads: usize,
435        dropout: f32,
436        batch_first: bool,
437    ) -> Self {
438        Self {
439            mha: MultiHeadAttention::with_options(embed_dim, num_heads, dropout, batch_first),
440        }
441    }
442
443    /// Computes cross-attention.
444    ///
445    /// # Arguments
446    /// * `query` - Decoder hidden states (N, T, E)
447    /// * `memory` - Encoder output (N, S, E)
448    /// * `attn_mask` - Optional attention mask
449    pub fn cross_attention(
450        &self,
451        query: &Variable,
452        memory: &Variable,
453        attn_mask: Option<&Variable>,
454    ) -> Variable {
455        self.mha.attention(query, memory, memory, attn_mask)
456    }
457
458    /// Returns the embedding dimension.
459    pub fn embed_dim(&self) -> usize {
460        self.mha.embed_dim
461    }
462
463    /// Returns the number of heads.
464    pub fn num_heads(&self) -> usize {
465        self.mha.num_heads
466    }
467}
468
469impl Module for CrossAttention {
470    fn forward(&self, input: &Variable) -> Variable {
471        // When called as Module (single input), acts as self-attention.
472        // Use cross_attention() for encoder-decoder attention.
473        self.mha.forward(input)
474    }
475
476    fn parameters(&self) -> Vec<Parameter> {
477        self.mha.parameters()
478    }
479
480    fn named_parameters(&self) -> HashMap<String, Parameter> {
481        let mut params = HashMap::new();
482        for (name, param) in self.mha.named_parameters() {
483            params.insert(format!("mha.{name}"), param);
484        }
485        params
486    }
487
488    fn name(&self) -> &'static str {
489        "CrossAttention"
490    }
491}
492
493// =============================================================================
494// Fused Scaled Dot-Product Attention
495// =============================================================================
496
497/// Fused scaled dot-product attention on Tensors.
498///
499/// Computes `softmax(Q @ K^T * scale) @ V` without materializing the full
500/// N*N attention matrix. On GPU, this uses a CUDA kernel that processes one
501/// query row per thread. On CPU, falls back to standard matmul + softmax.
502///
503/// # Arguments
504/// * `q` - Query tensor `[B, H, Tq, D]`
505/// * `k` - Key tensor `[B, H, Tk, D]`
506/// * `v` - Value tensor `[B, H, Tk, D]`
507/// * `scale` - Scaling factor (typically `1/sqrt(head_dim)`)
508/// * `is_causal` - Whether to apply causal masking
509///
510/// # Returns
511/// Output tensor `[B, H, Tq, D]`
512///
513/// # Note on Flash Attention
514/// This is a fused kernel (Option B) — it avoids the N*N memory allocation
515/// but each thread still iterates over the full key sequence. True Flash
516/// Attention (Option A) would use shared memory tiling with online softmax
517/// (the Dao et al. algorithm), processing in blocks of ~64-128 and requiring:
518/// - Shared memory for Q/K/V tile loading
519/// - Online softmax with running max/sum correction across tiles
520/// - Two passes: forward for output + logsumexp, backward with recomputation
521/// - ~3x more complex kernel code but O(N) memory and better cache behavior
522///
523/// For sequences up to ~2048, this fused kernel provides most of the benefit.
524/// For longer sequences, use the tiled CPU Flash Attention in `axonml-llm`.
525pub fn scaled_dot_product_attention_fused(
526    q: &Tensor<f32>,
527    k: &Tensor<f32>,
528    v: &Tensor<f32>,
529    scale: f32,
530    is_causal: bool,
531) -> Tensor<f32> {
532    // Try GPU fused kernel
533    #[cfg(feature = "cuda")]
534    if q.device().is_gpu() {
535        if let Some(result) = q.fused_attention_cuda(k, v, scale, is_causal) {
536            return result;
537        }
538    }
539
540    // CPU fallback: standard matmul-based attention
541    let shape = q.shape();
542    let batch_size = shape[0];
543    let num_heads = shape[1];
544    let tgt_len = shape[2];
545    let head_dim = shape[3];
546    let src_len = k.shape()[2];
547
548    let q_data = q.to_vec();
549    let k_data = k.to_vec();
550    let v_data = v.to_vec();
551
552    let mut output = vec![0.0f32; batch_size * num_heads * tgt_len * head_dim];
553
554    for b in 0..batch_size {
555        for h in 0..num_heads {
556            for i in 0..tgt_len {
557                // Compute attention scores for row i
558                let mut scores = vec![0.0f32; src_len];
559                let mut max_score = f32::NEG_INFINITY;
560
561                for j in 0..src_len {
562                    if is_causal && j > i {
563                        scores[j] = f32::NEG_INFINITY;
564                        continue;
565                    }
566                    let mut score = 0.0f32;
567                    for d in 0..head_dim {
568                        let q_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
569                        let k_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
570                        score += q_data[q_idx] * k_data[k_idx];
571                    }
572                    score *= scale;
573                    scores[j] = score;
574                    if score > max_score {
575                        max_score = score;
576                    }
577                }
578
579                // Softmax
580                let mut sum_exp = 0.0f32;
581                for s in &mut scores {
582                    if *s > f32::NEG_INFINITY {
583                        *s = (*s - max_score).exp();
584                        sum_exp += *s;
585                    } else {
586                        *s = 0.0;
587                    }
588                }
589                let inv_sum = if sum_exp > 0.0 { 1.0 / sum_exp } else { 0.0 };
590
591                // Weighted sum of V
592                for d in 0..head_dim {
593                    let mut val = 0.0f32;
594                    for j in 0..src_len {
595                        let v_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
596                        val += scores[j] * v_data[v_idx];
597                    }
598                    let out_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
599                    output[out_idx] = val * inv_sum;
600                }
601            }
602        }
603    }
604
605    Tensor::from_vec(output, &[batch_size, num_heads, tgt_len, head_dim])
606        .expect("tensor creation failed")
607}
608
609// =============================================================================
610// Tests
611// =============================================================================
612
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor wires dims: head_dim = embed_dim / num_heads.
    #[test]
    fn test_multihead_attention_creation() {
        let mha = MultiHeadAttention::new(512, 8);
        assert_eq!(mha.embed_dim, 512);
        assert_eq!(mha.num_heads, 8);
        assert_eq!(mha.head_dim, 64);
    }

    // Self-attention forward preserves the (batch, seq, embed) shape.
    #[test]
    fn test_multihead_attention_forward() {
        let mha = MultiHeadAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    // Cross-attention via attention(): output takes the query's seq length (5),
    // not the key/value's (10).
    #[test]
    fn test_cross_attention() {
        let mha = MultiHeadAttention::new(64, 4);
        let query = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let key_value = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = mha.attention(&query, &key_value, &key_value, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_multihead_attention_parameters() {
        let mha = MultiHeadAttention::new(64, 4);
        let params = mha.parameters();
        // Q, K, V, Out projections each have weight + bias = 8 total
        assert_eq!(params.len(), 8);
    }

    #[test]
    fn test_cross_attention_creation() {
        let ca = CrossAttention::new(256, 8);
        assert_eq!(ca.embed_dim(), 256);
        assert_eq!(ca.num_heads(), 8);
    }

    // Output follows the decoder query's target length, not the memory's.
    #[test]
    fn test_cross_attention_forward() {
        let ca = CrossAttention::new(64, 4);
        // Decoder query: (batch=2, tgt_len=5, embed=64)
        let query = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        // Encoder memory: (batch=2, src_len=10, embed=64)
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = ca.cross_attention(&query, &memory, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_cross_attention_self_attention_fallback() {
        let ca = CrossAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 8 * 64], &[2, 8, 64]).expect("tensor creation failed"),
            false,
        );
        // Module::forward does self-attention
        let output = ca.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }

    // Delegated parameters keep count; named params gain the "mha." prefix.
    #[test]
    fn test_cross_attention_parameters() {
        let ca = CrossAttention::new(64, 4);
        let params = ca.parameters();
        assert_eq!(params.len(), 8); // Q, K, V, Out × (weight + bias)
        let named = ca.named_parameters();
        assert!(named.contains_key("mha.q_proj.weight"));
        assert!(named.contains_key("mha.out_proj.bias"));
    }

    #[test]
    fn test_fused_attention_cpu() {
        // Test fused attention on CPU (fallback path)
        let batch = 2;
        let heads = 4;
        let seq = 8;
        let dim = 16;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        // With uniform V=0.5, output should be close to 0.5
        // (softmax weights sum to 1, so the weighted average of 0.5s is 0.5).
        let out_vec = out.to_vec();
        for val in &out_vec {
            assert!((*val - 0.5).abs() < 0.01, "Expected ~0.5, got {}", val);
        }
    }

    #[test]
    fn test_fused_attention_causal() {
        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        // Q and K are identity-like so attention focuses on matching positions
        let q = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.1; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        // V is the 4x4 identity, one one-hot row per position.
        let v = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
            ],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);
        assert_eq!(out.shape(), &[batch, heads, seq, dim]);

        // First position can only attend to position 0, so output = V[0] = [1,0,0,0]
        let out_vec = out.to_vec();
        assert!(
            (out_vec[0] - 1.0).abs() < 1e-5,
            "row 0, col 0 should be 1.0"
        );
        assert!((out_vec[1]).abs() < 1e-5, "row 0, col 1 should be 0.0");
    }

    #[test]
    fn test_multihead_attention_backward_cpu() {
        // Test that gradients flow through MHA in training mode (CPU path)
        use axonml_autograd::backward;

        let mha = MultiHeadAttention::new(32, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 4 * 32], &[2, 4, 32]).expect("tensor creation failed"),
            true,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 32]);

        // Sum the output and backward
        let loss = output.sum();
        let ones = Tensor::from_vec(vec![1.0f32], &[1]).expect("tensor creation failed");
        backward(&loss, &ones);

        // Input should have gradients
        let grad = input.grad();
        assert!(grad.is_some(), "Input gradient should exist");
        let grad_data = grad.unwrap();
        assert_eq!(grad_data.shape(), &[2, 4, 32]);

        // Gradients should be non-zero
        let grad_vec = grad_data.to_vec();
        let non_zero = grad_vec.iter().any(|&v| v.abs() > 1e-10);
        assert!(non_zero, "Gradients should be non-zero");
    }

    #[test]
    fn test_fused_attention_backward_cpu() {
        // Test the FusedAttentionBackward autograd function directly on CPU
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 2;
        let seq = 4;
        let dim = 8;
        let scale = 1.0 / (dim as f32).sqrt();

        // Create random-ish tensors (deterministic trig patterns, no RNG dep)
        let q_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.01).sin())
            .collect();
        let k_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.02).cos())
            .collect();
        let v_data: Vec<f32> = (0..batch * heads * seq * dim)
            .map(|i| ((i as f32) * 0.03).sin() + 0.5)
            .collect();

        let q =
            Tensor::from_vec(q_data, &[batch, heads, seq, dim]).expect("tensor creation failed");
        let k =
            Tensor::from_vec(k_data, &[batch, heads, seq, dim]).expect("tensor creation failed");
        let v =
            Tensor::from_vec(v_data, &[batch, heads, seq, dim]).expect("tensor creation failed");

        // Compute forward output using the fused CPU path
        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
        assert_eq!(output.shape(), &[batch, heads, seq, dim]);

        // Create backward function (no upstream grad_fns: leaf inputs)
        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            false,
        );

        // Use ones as grad_output
        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let grads = backward_fn.apply(&grad_output);
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().expect("grad_Q should exist");
        let gk = grads[1].as_ref().expect("grad_K should exist");
        let gv = grads[2].as_ref().expect("grad_V should exist");

        assert_eq!(gq.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gk.shape(), &[batch, heads, seq, dim]);
        assert_eq!(gv.shape(), &[batch, heads, seq, dim]);

        // Gradients should be finite
        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }

        // grad_V should be non-zero (it's P^T @ grad_output)
        let gv_nonzero = gv.to_vec().iter().any(|&v| v.abs() > 1e-10);
        assert!(gv_nonzero, "grad_V should be non-zero");
    }

    #[test]
    fn test_fused_attention_backward_causal_cpu() {
        // Test the backward with causal masking
        use axonml_autograd::functions::FusedAttentionBackward;
        use axonml_autograd::grad_fn::GradientFunction;

        let batch = 1;
        let heads = 1;
        let seq = 4;
        let dim = 4;
        let scale = 1.0 / (dim as f32).sqrt();

        let q = Tensor::from_vec(
            vec![0.1f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let k = Tensor::from_vec(
            vec![0.2f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();
        let v = Tensor::from_vec(
            vec![0.5f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        // Forward with is_causal=true; backward must use the same flag.
        let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);

        let backward_fn = FusedAttentionBackward::new(
            None,
            None,
            None,
            q.clone(),
            k.clone(),
            v.clone(),
            output.clone(),
            scale,
            true,
        );

        let grad_output = Tensor::from_vec(
            vec![1.0f32; batch * heads * seq * dim],
            &[batch, heads, seq, dim],
        )
        .unwrap();

        let grads = backward_fn.apply(&grad_output);
        assert_eq!(grads.len(), 3);

        let gq = grads[0].as_ref().unwrap();
        let gk = grads[1].as_ref().unwrap();
        let gv = grads[2].as_ref().unwrap();

        // All grads should be finite
        for val in gq
            .to_vec()
            .iter()
            .chain(gk.to_vec().iter())
            .chain(gv.to_vec().iter())
        {
            assert!(val.is_finite(), "Gradient should be finite, got {}", val);
        }
    }
954}