Skip to main content

axonml_nn/layers/
attention.rs

//! Attention Mechanisms - Multi-Head Attention
//!
//! Implements scaled dot-product multi-head attention and an
//! encoder-decoder cross-attention wrapper built on top of it.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team
7
8use std::collections::HashMap;
9
10use axonml_autograd::Variable;
11use axonml_tensor::Tensor;
12
13use crate::layers::Linear;
14use crate::module::Module;
15use crate::parameter::Parameter;
16
17// =============================================================================
18// MultiHeadAttention
19// =============================================================================
20
/// Multi-Head Attention mechanism.
///
/// Allows the model to jointly attend to information from different
/// representation subspaces at different positions. The layer owns four
/// `Linear` projections (query, key, value, output), each mapping
/// `embed_dim -> embed_dim`, plus the head geometry derived from
/// `embed_dim` and `num_heads`.
///
/// # Arguments
/// Constructor arguments (see `new` and `with_options`):
/// * `embed_dim` - Total dimension of the model; must be divisible by
///   `num_heads`
/// * `num_heads` - Number of parallel attention heads
/// * `dropout` - Dropout probability. Accepted by `with_options` for API
///   compatibility, but it is not stored in this struct.
/// * `batch_first` - Input layout flag; `new` sets it to `true`
///
/// # Shape
/// With L = target length, S = source length, N = batch size and
/// E = `embed_dim`:
/// - Query: (L, N, E) or (N, L, E) if batch_first
/// - Key: (S, N, E) or (N, S, E) if batch_first
/// - Value: (S, N, E) or (N, S, E) if batch_first
/// - Output: (L, N, E) or (N, L, E) if batch_first
pub struct MultiHeadAttention {
    /// Query projection (`embed_dim -> embed_dim`).
    q_proj: Linear,
    /// Key projection (`embed_dim -> embed_dim`).
    k_proj: Linear,
    /// Value projection (`embed_dim -> embed_dim`).
    v_proj: Linear,
    /// Output projection applied to the concatenated head outputs.
    out_proj: Linear,
    /// Embedding dimension (E).
    embed_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Dimension per head, `embed_dim / num_heads`.
    head_dim: usize,
    /// Scaling factor applied to the raw dot-product scores,
    /// `1 / sqrt(head_dim)`.
    scale: f32,
    /// Whether the batch dimension comes first in the inputs, i.e. (N, L, E)
    /// rather than (L, N, E).
    batch_first: bool,
}
56
57impl MultiHeadAttention {
58    /// Creates a new MultiHeadAttention module.
59    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
60        Self::with_options(embed_dim, num_heads, 0.0, true)
61    }
62
63    /// Creates MultiHeadAttention with all options.
64    pub fn with_options(
65        embed_dim: usize,
66        num_heads: usize,
67        _dropout: f32,
68        batch_first: bool,
69    ) -> Self {
70        assert!(
71            embed_dim % num_heads == 0,
72            "embed_dim must be divisible by num_heads"
73        );
74
75        let head_dim = embed_dim / num_heads;
76        let scale = (head_dim as f32).sqrt().recip();
77
78        Self {
79            q_proj: Linear::new(embed_dim, embed_dim),
80            k_proj: Linear::new(embed_dim, embed_dim),
81            v_proj: Linear::new(embed_dim, embed_dim),
82            out_proj: Linear::new(embed_dim, embed_dim),
83            embed_dim,
84            num_heads,
85            head_dim,
86            scale,
87            batch_first,
88        }
89    }
90
91    /// Computes attention.
92    pub fn attention(
93        &self,
94        query: &Variable,
95        key: &Variable,
96        value: &Variable,
97        attn_mask: Option<&Variable>,
98    ) -> Variable {
99        let q_shape = query.shape();
100        let (batch_size, tgt_len, _) = if self.batch_first {
101            (q_shape[0], q_shape[1], q_shape[2])
102        } else {
103            (q_shape[1], q_shape[0], q_shape[2])
104        };
105        let src_len = if self.batch_first {
106            key.shape()[1]
107        } else {
108            key.shape()[0]
109        };
110
111        // Project Q, K, V
112        let q = self.q_proj.forward(query);
113        let k = self.k_proj.forward(key);
114        let v = self.v_proj.forward(value);
115
116        // Reshape for multi-head: (batch, seq, embed) -> (batch, heads, seq, head_dim)
117        // For simplicity, we'll work with the flat representation
118        let q_vec = q.data().to_vec();
119        let k_vec = k.data().to_vec();
120        let v_vec = v.data().to_vec();
121
122        // Compute attention scores: Q @ K^T / sqrt(d_k)
123        let mut attn_scores = vec![0.0f32; batch_size * self.num_heads * tgt_len * src_len];
124
125        for b in 0..batch_size {
126            for h in 0..self.num_heads {
127                for i in 0..tgt_len {
128                    for j in 0..src_len {
129                        let mut score = 0.0f32;
130                        for d in 0..self.head_dim {
131                            let q_idx = b * tgt_len * self.embed_dim
132                                + i * self.embed_dim
133                                + h * self.head_dim
134                                + d;
135                            let k_idx = b * src_len * self.embed_dim
136                                + j * self.embed_dim
137                                + h * self.head_dim
138                                + d;
139                            score += q_vec[q_idx] * k_vec[k_idx];
140                        }
141                        let attn_idx = b * self.num_heads * tgt_len * src_len
142                            + h * tgt_len * src_len
143                            + i * src_len
144                            + j;
145                        attn_scores[attn_idx] = score * self.scale;
146                    }
147                }
148            }
149        }
150
151        // Apply attention mask if provided
152        if let Some(mask) = attn_mask {
153            let mask_vec = mask.data().to_vec();
154            for (i, score) in attn_scores.iter_mut().enumerate() {
155                if mask_vec[i % mask_vec.len()] == 0.0 {
156                    *score = f32::NEG_INFINITY;
157                }
158            }
159        }
160
161        // Softmax over source sequence
162        let mut attn_weights = vec![0.0f32; batch_size * self.num_heads * tgt_len * src_len];
163        for b in 0..batch_size {
164            for h in 0..self.num_heads {
165                for i in 0..tgt_len {
166                    let base = b * self.num_heads * tgt_len * src_len
167                        + h * tgt_len * src_len
168                        + i * src_len;
169
170                    // Find max for numerical stability
171                    let max_score = (0..src_len)
172                        .map(|j| attn_scores[base + j])
173                        .fold(f32::NEG_INFINITY, f32::max);
174
175                    // Compute exp and sum
176                    let mut sum = 0.0f32;
177                    for j in 0..src_len {
178                        let exp_val = (attn_scores[base + j] - max_score).exp();
179                        attn_weights[base + j] = exp_val;
180                        sum += exp_val;
181                    }
182
183                    // Normalize
184                    for j in 0..src_len {
185                        attn_weights[base + j] /= sum;
186                    }
187                }
188            }
189        }
190
191        // Apply attention to values
192        let mut output_vec = vec![0.0f32; batch_size * tgt_len * self.embed_dim];
193        for b in 0..batch_size {
194            for h in 0..self.num_heads {
195                for i in 0..tgt_len {
196                    for d in 0..self.head_dim {
197                        let mut weighted_sum = 0.0f32;
198                        for j in 0..src_len {
199                            let attn_idx = b * self.num_heads * tgt_len * src_len
200                                + h * tgt_len * src_len
201                                + i * src_len
202                                + j;
203                            let v_idx = b * src_len * self.embed_dim
204                                + j * self.embed_dim
205                                + h * self.head_dim
206                                + d;
207                            weighted_sum += attn_weights[attn_idx] * v_vec[v_idx];
208                        }
209                        let out_idx = b * tgt_len * self.embed_dim
210                            + i * self.embed_dim
211                            + h * self.head_dim
212                            + d;
213                        output_vec[out_idx] = weighted_sum;
214                    }
215                }
216            }
217        }
218
219        let output_shape = if self.batch_first {
220            vec![batch_size, tgt_len, self.embed_dim]
221        } else {
222            vec![tgt_len, batch_size, self.embed_dim]
223        };
224
225        let output = Variable::new(
226            Tensor::from_vec(output_vec, &output_shape).unwrap(),
227            query.requires_grad(),
228        );
229
230        // Output projection
231        self.out_proj.forward(&output)
232    }
233}
234
235impl Module for MultiHeadAttention {
236    fn forward(&self, input: &Variable) -> Variable {
237        // Self-attention: query = key = value = input
238        self.attention(input, input, input, None)
239    }
240
241    fn parameters(&self) -> Vec<Parameter> {
242        let mut params = Vec::new();
243        params.extend(self.q_proj.parameters());
244        params.extend(self.k_proj.parameters());
245        params.extend(self.v_proj.parameters());
246        params.extend(self.out_proj.parameters());
247        params
248    }
249
250    fn named_parameters(&self) -> HashMap<String, Parameter> {
251        let mut params = HashMap::new();
252        for (name, param) in self.q_proj.named_parameters() {
253            params.insert(format!("q_proj.{name}"), param);
254        }
255        for (name, param) in self.k_proj.named_parameters() {
256            params.insert(format!("k_proj.{name}"), param);
257        }
258        for (name, param) in self.v_proj.named_parameters() {
259            params.insert(format!("v_proj.{name}"), param);
260        }
261        for (name, param) in self.out_proj.named_parameters() {
262            params.insert(format!("out_proj.{name}"), param);
263        }
264        params
265    }
266
267    fn name(&self) -> &'static str {
268        "MultiHeadAttention"
269    }
270}
271
272// =============================================================================
273// CrossAttention
274// =============================================================================
275
/// Cross-Attention mechanism for encoder-decoder architectures.
///
/// Queries come from the decoder, keys and values come from the encoder.
/// This is the standard cross-attention used in Transformer decoders,
/// seq2seq models, and vision-language models.
///
/// The type is a thin wrapper around [`MultiHeadAttention`]: the encoder
/// memory is passed as both key and value.
///
/// # Shape (batch_first=true)
/// - Query (decoder): (N, T, E)
/// - Memory (encoder): (N, S, E)
/// - Output: (N, T, E)
///
/// where N=batch, T=target seq len, S=source seq len, E=embed_dim.
pub struct CrossAttention {
    /// Underlying multi-head attention that performs the actual computation.
    mha: MultiHeadAttention,
}
292
293impl CrossAttention {
294    /// Creates a new CrossAttention module.
295    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
296        Self {
297            mha: MultiHeadAttention::new(embed_dim, num_heads),
298        }
299    }
300
301    /// Creates CrossAttention with all options.
302    pub fn with_options(embed_dim: usize, num_heads: usize, dropout: f32, batch_first: bool) -> Self {
303        Self {
304            mha: MultiHeadAttention::with_options(embed_dim, num_heads, dropout, batch_first),
305        }
306    }
307
308    /// Computes cross-attention.
309    ///
310    /// # Arguments
311    /// * `query` - Decoder hidden states (N, T, E)
312    /// * `memory` - Encoder output (N, S, E)
313    /// * `attn_mask` - Optional attention mask
314    pub fn cross_attention(
315        &self,
316        query: &Variable,
317        memory: &Variable,
318        attn_mask: Option<&Variable>,
319    ) -> Variable {
320        self.mha.attention(query, memory, memory, attn_mask)
321    }
322
323    /// Returns the embedding dimension.
324    pub fn embed_dim(&self) -> usize {
325        self.mha.embed_dim
326    }
327
328    /// Returns the number of heads.
329    pub fn num_heads(&self) -> usize {
330        self.mha.num_heads
331    }
332}
333
334impl Module for CrossAttention {
335    fn forward(&self, input: &Variable) -> Variable {
336        // When called as Module (single input), acts as self-attention.
337        // Use cross_attention() for encoder-decoder attention.
338        self.mha.forward(input)
339    }
340
341    fn parameters(&self) -> Vec<Parameter> {
342        self.mha.parameters()
343    }
344
345    fn named_parameters(&self) -> HashMap<String, Parameter> {
346        let mut params = HashMap::new();
347        for (name, param) in self.mha.named_parameters() {
348            params.insert(format!("mha.{name}"), param);
349        }
350        params
351    }
352
353    fn name(&self) -> &'static str {
354        "CrossAttention"
355    }
356}
357
358// =============================================================================
359// Tests
360// =============================================================================
361
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Variable` of the given shape with every element set to
    /// `value` and gradient tracking disabled.
    fn filled(value: f32, shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(Tensor::from_vec(vec![value; numel], shape).unwrap(), false)
    }

    #[test]
    fn test_multihead_attention_creation() {
        let mha = MultiHeadAttention::new(512, 8);
        assert_eq!(mha.embed_dim, 512);
        assert_eq!(mha.num_heads, 8);
        assert_eq!(mha.head_dim, 64);
    }

    #[test]
    fn test_multihead_attention_forward() {
        let mha = MultiHeadAttention::new(64, 4);
        let input = filled(1.0, &[2, 10, 64]);

        let output = mha.forward(&input);

        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_cross_attention() {
        let mha = MultiHeadAttention::new(64, 4);
        let query = filled(1.0, &[2, 5, 64]);
        let key_value = filled(1.0, &[2, 10, 64]);

        let output = mha.attention(&query, &key_value, &key_value, None);

        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_multihead_attention_parameters() {
        let mha = MultiHeadAttention::new(64, 4);
        // Q, K, V, Out projections each have weight + bias = 8 total
        assert_eq!(mha.parameters().len(), 8);
    }

    #[test]
    fn test_cross_attention_creation() {
        let ca = CrossAttention::new(256, 8);
        assert_eq!(ca.embed_dim(), 256);
        assert_eq!(ca.num_heads(), 8);
    }

    #[test]
    fn test_cross_attention_forward() {
        let ca = CrossAttention::new(64, 4);
        // Decoder query: (batch=2, tgt_len=5, embed=64)
        let query = filled(0.1, &[2, 5, 64]);
        // Encoder memory: (batch=2, src_len=10, embed=64)
        let memory = filled(0.2, &[2, 10, 64]);

        let output = ca.cross_attention(&query, &memory, None);

        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_cross_attention_self_attention_fallback() {
        let ca = CrossAttention::new(64, 4);
        let input = filled(1.0, &[2, 8, 64]);

        // Module::forward does self-attention
        let output = ca.forward(&input);

        assert_eq!(output.shape(), vec![2, 8, 64]);
    }

    #[test]
    fn test_cross_attention_parameters() {
        let ca = CrossAttention::new(64, 4);
        // Q, K, V, Out × (weight + bias)
        assert_eq!(ca.parameters().len(), 8);

        let named = ca.named_parameters();
        for key in ["mha.q_proj.weight", "mha.out_proj.bias"].iter() {
            assert!(named.contains_key(*key));
        }
    }
}