
axonml_nn/layers/attention.rs

//! Attention Mechanisms - Multi-Head Attention
//!
//! Implements scaled dot-product and multi-head attention.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::layers::Linear;
use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// MultiHeadAttention
// =============================================================================

/// Multi-Head Attention mechanism.
///
/// Allows the model to jointly attend to information from different
/// representation subspaces at different positions.
///
/// # Arguments
/// * `embed_dim` - Total dimension of the model
/// * `num_heads` - Number of parallel attention heads
/// * `dropout` - Dropout probability (default: 0.0; not yet applied by this implementation)
///
/// # Shape
/// - Query: (L, N, E) or (N, L, E) if batch_first
/// - Key: (S, N, E) or (N, S, E) if batch_first
/// - Value: (S, N, E) or (N, S, E) if batch_first
/// - Output: (L, N, E) or (N, L, E) if batch_first
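///
/// # Examples
///
/// A minimal self-attention sketch (it mirrors the unit tests at the bottom
/// of this file and assumes the imports shown above are in scope):
///
/// ```ignore
/// let mha = MultiHeadAttention::new(64, 4);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
///     false,
/// );
/// // batch_first is the default, so the input is (N = 2, L = 10, E = 64).
/// let output = mha.forward(&input);
/// assert_eq!(output.shape(), vec![2, 10, 64]);
/// ```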
pub struct MultiHeadAttention {
    /// Query projection.
    q_proj: Linear,
    /// Key projection.
    k_proj: Linear,
    /// Value projection.
    v_proj: Linear,
    /// Output projection.
    out_proj: Linear,
    /// Embedding dimension.
    embed_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Dimension per head.
    head_dim: usize,
    /// Scaling factor.
    scale: f32,
    /// Whether input is batch first.
    batch_first: bool,
}

impl MultiHeadAttention {
    /// Creates a new MultiHeadAttention module.
    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
        Self::with_options(embed_dim, num_heads, 0.0, true)
    }

    /// Creates MultiHeadAttention with all options.
    pub fn with_options(
        embed_dim: usize,
        num_heads: usize,
        _dropout: f32,
        batch_first: bool,
    ) -> Self {
        assert!(
            embed_dim % num_heads == 0,
            "embed_dim must be divisible by num_heads"
        );

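        // Scaled dot-product attention divides each Q·K^T score by sqrt(d_k),
        // where d_k is the per-head dimension; `scale` precomputes 1 / sqrt(head_dim).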
        let head_dim = embed_dim / num_heads;
        let scale = (head_dim as f32).sqrt().recip();

        Self {
            q_proj: Linear::new(embed_dim, embed_dim),
            k_proj: Linear::new(embed_dim, embed_dim),
            v_proj: Linear::new(embed_dim, embed_dim),
            out_proj: Linear::new(embed_dim, embed_dim),
            embed_dim,
            num_heads,
            head_dim,
            scale,
            batch_first,
        }
    }

    /// Computes multi-head scaled dot-product attention for the given
    /// query, key, and value. Positions where `attn_mask` is `0.0` are
    /// excluded from attention.
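    ///
    /// # Examples
    ///
    /// A cross-attention sketch (it mirrors `test_cross_attention` below; the
    /// query and key/value sequences may have different lengths):
    ///
    /// ```ignore
    /// let mha = MultiHeadAttention::new(64, 4);
    /// let query = Variable::new(
    ///     Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
    ///     false,
    /// );
    /// let key_value = Variable::new(
    ///     Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
    ///     false,
    /// );
    /// let output = mha.attention(&query, &key_value, &key_value, None);
    /// assert_eq!(output.shape(), vec![2, 5, 64]);
    /// ```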
    pub fn attention(
        &self,
        query: &Variable,
        key: &Variable,
        value: &Variable,
        attn_mask: Option<&Variable>,
    ) -> Variable {
        let q_shape = query.shape();
        let (batch_size, tgt_len, _) = if self.batch_first {
            (q_shape[0], q_shape[1], q_shape[2])
        } else {
            (q_shape[1], q_shape[0], q_shape[2])
        };
        let src_len = if self.batch_first {
            key.shape()[1]
        } else {
            key.shape()[0]
        };

        // Strides for indexing the flat, row-major buffers, which are laid out
        // (N, L, E) when `batch_first` and (L, N, E) otherwise.
        let (q_batch_stride, q_seq_stride, kv_batch_stride, kv_seq_stride) = if self.batch_first {
            (
                tgt_len * self.embed_dim,
                self.embed_dim,
                src_len * self.embed_dim,
                self.embed_dim,
            )
        } else {
            (
                self.embed_dim,
                batch_size * self.embed_dim,
                self.embed_dim,
                batch_size * self.embed_dim,
            )
        };

        // Project Q, K, V
        let q = self.q_proj.forward(query);
        let k = self.k_proj.forward(key);
        let v = self.v_proj.forward(value);

        // Reshape for multi-head: (batch, seq, embed) -> (batch, heads, seq, head_dim)
        // For simplicity, we'll work with the flat representation
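        // Each head `h` reads the contiguous slice [h * head_dim, (h + 1) * head_dim)
        // of the embedding dimension, so no explicit reshape or transpose is needed.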
        let q_vec = q.data().to_vec();
        let k_vec = k.data().to_vec();
        let v_vec = v.data().to_vec();

        // Compute attention scores: Q @ K^T / sqrt(d_k)
        let mut attn_scores = vec![0.0f32; batch_size * self.num_heads * tgt_len * src_len];

        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..tgt_len {
                    for j in 0..src_len {
                        let mut score = 0.0f32;
                        for d in 0..self.head_dim {
                            let q_idx = b * q_batch_stride
                                + i * q_seq_stride
                                + h * self.head_dim
                                + d;
                            let k_idx = b * kv_batch_stride
                                + j * kv_seq_stride
                                + h * self.head_dim
                                + d;
                            score += q_vec[q_idx] * k_vec[k_idx];
                        }
                        let attn_idx = b * self.num_heads * tgt_len * src_len
                            + h * tgt_len * src_len
                            + i * src_len
                            + j;
                        attn_scores[attn_idx] = score * self.scale;
                    }
                }
            }
        }

        // Apply attention mask if provided
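        // The flattened mask is broadcast by cycling over its values; a
        // (tgt_len, src_len) mask is thus repeated across every batch and head.
        // Entries equal to 0.0 are masked out by setting the score to -inf,
        // which becomes a zero weight after the softmax.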
        if let Some(mask) = attn_mask {
            let mask_vec = mask.data().to_vec();
            for (i, score) in attn_scores.iter_mut().enumerate() {
                if mask_vec[i % mask_vec.len()] == 0.0 {
                    *score = f32::NEG_INFINITY;
                }
            }
        }

        // Softmax over source sequence
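        // Numerically stable softmax: softmax(x)_j = exp(x_j - m) / sum_k exp(x_k - m)
        // with m = max_k x_k, so the exponentials cannot overflow.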
        let mut attn_weights = vec![0.0f32; batch_size * self.num_heads * tgt_len * src_len];
        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..tgt_len {
                    let base = b * self.num_heads * tgt_len * src_len
                        + h * tgt_len * src_len
                        + i * src_len;

                    // Find max for numerical stability
                    let max_score = (0..src_len)
                        .map(|j| attn_scores[base + j])
                        .fold(f32::NEG_INFINITY, f32::max);

                    // Compute exp and sum
                    let mut sum = 0.0f32;
                    for j in 0..src_len {
                        let exp_val = (attn_scores[base + j] - max_score).exp();
                        attn_weights[base + j] = exp_val;
                        sum += exp_val;
                    }

                    // Normalize
                    for j in 0..src_len {
                        attn_weights[base + j] /= sum;
                    }
                }
            }
        }

        // Apply attention to values
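        // For each (batch, head, target position): output slice = sum over source
        // positions j of attn_weights[b, h, i, j] * V[b, j, head slice].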
        let mut output_vec = vec![0.0f32; batch_size * tgt_len * self.embed_dim];
        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..tgt_len {
                    for d in 0..self.head_dim {
                        let mut weighted_sum = 0.0f32;
                        for j in 0..src_len {
                            let attn_idx = b * self.num_heads * tgt_len * src_len
                                + h * tgt_len * src_len
                                + i * src_len
                                + j;
                            let v_idx = b * kv_batch_stride
                                + j * kv_seq_stride
                                + h * self.head_dim
                                + d;
                            weighted_sum += attn_weights[attn_idx] * v_vec[v_idx];
                        }
                        let out_idx = b * q_batch_stride
                            + i * q_seq_stride
                            + h * self.head_dim
                            + d;
                        output_vec[out_idx] = weighted_sum;
                    }
                }
            }
        }

        let output_shape = if self.batch_first {
            vec![batch_size, tgt_len, self.embed_dim]
        } else {
            vec![tgt_len, batch_size, self.embed_dim]
        };

        let output = Variable::new(
            Tensor::from_vec(output_vec, &output_shape).unwrap(),
            query.requires_grad(),
        );

        // Output projection
        self.out_proj.forward(&output)
    }
}

impl Module for MultiHeadAttention {
    fn forward(&self, input: &Variable) -> Variable {
        // Self-attention: query = key = value = input
        self.attention(input, input, input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.q_proj.parameters());
        params.extend(self.k_proj.parameters());
        params.extend(self.v_proj.parameters());
        params.extend(self.out_proj.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.q_proj.named_parameters() {
            params.insert(format!("q_proj.{name}"), param);
        }
        for (name, param) in self.k_proj.named_parameters() {
            params.insert(format!("k_proj.{name}"), param);
        }
        for (name, param) in self.v_proj.named_parameters() {
            params.insert(format!("v_proj.{name}"), param);
        }
        for (name, param) in self.out_proj.named_parameters() {
            params.insert(format!("out_proj.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "MultiHeadAttention"
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multihead_attention_creation() {
        let mha = MultiHeadAttention::new(512, 8);
        assert_eq!(mha.embed_dim, 512);
        assert_eq!(mha.num_heads, 8);
        assert_eq!(mha.head_dim, 64);
    }

    #[test]
    fn test_multihead_attention_forward() {
        let mha = MultiHeadAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_cross_attention() {
        let mha = MultiHeadAttention::new(64, 4);
        let query = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let key_value = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = mha.attention(&query, &key_value, &key_value, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_multihead_attention_parameters() {
        let mha = MultiHeadAttention::new(64, 4);
        let params = mha.parameters();
        // Q, K, V, Out projections each have weight + bias = 8 total
        assert_eq!(params.len(), 8);
    }
}