use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::layers::Linear;
use crate::module::Module;
use crate::parameter::Parameter;

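/// Multi-head attention in the style of "Attention Is All You Need": queries,
/// keys, and values are linearly projected, split into `num_heads` heads of
/// size `head_dim`, combined via scaled dot-product attention, and passed
/// through a final output projection.
///
/// A minimal usage sketch (mirroring the tests below):
///
/// ```ignore
/// let mha = MultiHeadAttention::new(64, 4);
/// let x = Variable::new(
///     Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
///     false,
/// );
/// let y = mha.forward(&x); // self-attention, output shape [2, 10, 64]
/// ```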
pub struct MultiHeadAttention {
    /// Linear projection applied to the queries.
    q_proj: Linear,
    /// Linear projection applied to the keys.
    k_proj: Linear,
    /// Linear projection applied to the values.
    v_proj: Linear,
    /// Final projection applied to the concatenated head outputs.
    out_proj: Linear,
    /// Total embedding dimension of the layer.
    embed_dim: usize,
    /// Number of parallel attention heads.
    num_heads: usize,
    /// Per-head dimension, `embed_dim / num_heads`.
    head_dim: usize,
    /// Scaling factor `1 / sqrt(head_dim)` applied to the attention scores.
    scale: f32,
    /// `true` if inputs are `[batch, seq, embed]`, `false` for `[seq, batch, embed]`.
    batch_first: bool,
}

impl MultiHeadAttention {
    /// Creates a multi-head attention layer with `batch_first == true` and no dropout.
    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
        Self::with_options(embed_dim, num_heads, 0.0, true)
    }

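    /// Creates a multi-head attention layer with explicit options.
    ///
    /// `embed_dim` must be divisible by `num_heads`. The `_dropout` argument is
    /// accepted for API compatibility but is currently ignored. `batch_first`
    /// selects between `[batch, seq, embed]` and `[seq, batch, embed]` input layouts.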
    pub fn with_options(
        embed_dim: usize,
        num_heads: usize,
        _dropout: f32,
        batch_first: bool,
    ) -> Self {
        assert!(
            embed_dim % num_heads == 0,
            "embed_dim must be divisible by num_heads"
        );

        let head_dim = embed_dim / num_heads;
        let scale = (head_dim as f32).sqrt().recip();

        Self {
            q_proj: Linear::new(embed_dim, embed_dim),
            k_proj: Linear::new(embed_dim, embed_dim),
            v_proj: Linear::new(embed_dim, embed_dim),
            out_proj: Linear::new(embed_dim, embed_dim),
            embed_dim,
            num_heads,
            head_dim,
            scale,
            batch_first,
        }
    }

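    /// Computes multi-head attention over `query`, `key`, and `value`.
    ///
    /// Shapes follow `batch_first`: `[batch, tgt_len, embed_dim]` for the query and
    /// `[batch, src_len, embed_dim]` for key/value (sequence-first otherwise). If
    /// `attn_mask` is given, positions whose mask value is `0.0` are excluded from
    /// the softmax. The attention weights are computed on raw `f32` buffers, so
    /// gradients do not flow through the attention computation itself.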
    pub fn attention(
        &self,
        query: &Variable,
        key: &Variable,
        value: &Variable,
        attn_mask: Option<&Variable>,
    ) -> Variable {
        // Recover batch and sequence lengths from the query/key shapes.
        let q_shape = query.shape();
        let (batch_size, tgt_len, _) = if self.batch_first {
            (q_shape[0], q_shape[1], q_shape[2])
        } else {
            (q_shape[1], q_shape[0], q_shape[2])
        };
        let src_len = if self.batch_first {
            key.shape()[1]
        } else {
            key.shape()[0]
        };

        // The linear projections act on the last (embedding) dimension, so the
        // input layout is preserved.
        let q = self.q_proj.forward(query);
        let k = self.k_proj.forward(key);
        let v = self.v_proj.forward(value);

        // The attention math below runs on raw f32 buffers rather than on
        // Variables, so it does not participate in autograd.
        let q_vec = q.data().to_vec();
        let k_vec = k.data().to_vec();
        let v_vec = v.data().to_vec();

        // Strides for indexing the flattened buffers: the layout is
        // [batch, seq, embed] when batch_first, [seq, batch, embed] otherwise.
        let (q_batch_stride, q_seq_stride) = if self.batch_first {
            (tgt_len * self.embed_dim, self.embed_dim)
        } else {
            (self.embed_dim, batch_size * self.embed_dim)
        };
        let (kv_batch_stride, kv_seq_stride) = if self.batch_first {
            (src_len * self.embed_dim, self.embed_dim)
        } else {
            (self.embed_dim, batch_size * self.embed_dim)
        };

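        // 1. Scaled dot-product scores: for every (batch, head) pair, build a
        //    [tgt_len, src_len] matrix of dot products scaled by 1/sqrt(head_dim).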
        let mut attn_scores = vec![0.0f32; batch_size * self.num_heads * tgt_len * src_len];

        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..tgt_len {
                    for j in 0..src_len {
                        let mut score = 0.0f32;
                        for d in 0..self.head_dim {
                            let q_idx =
                                b * q_batch_stride + i * q_seq_stride + h * self.head_dim + d;
                            let k_idx =
                                b * kv_batch_stride + j * kv_seq_stride + h * self.head_dim + d;
                            score += q_vec[q_idx] * k_vec[k_idx];
                        }
                        let attn_idx = b * self.num_heads * tgt_len * src_len
                            + h * tgt_len * src_len
                            + i * src_len
                            + j;
                        attn_scores[attn_idx] = score * self.scale;
                    }
                }
            }
        }

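        // 2. Optional mask: positions whose mask value is 0.0 are blocked by
        //    setting their score to -inf before the softmax; the mask is
        //    broadcast by cycling over its elements.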
        if let Some(mask) = attn_mask {
            let mask_vec = mask.data().to_vec();
            for (i, score) in attn_scores.iter_mut().enumerate() {
                if mask_vec[i % mask_vec.len()] == 0.0 {
                    *score = f32::NEG_INFINITY;
                }
            }
        }

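        // 3. Row-wise softmax over src_len, shifted by the row maximum for
        //    numerical stability. (A fully masked row would still yield NaNs,
        //    since its normalizing sum is zero.)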
        let mut attn_weights = vec![0.0f32; batch_size * self.num_heads * tgt_len * src_len];
        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..tgt_len {
                    let base = b * self.num_heads * tgt_len * src_len
                        + h * tgt_len * src_len
                        + i * src_len;

                    let max_score = (0..src_len)
                        .map(|j| attn_scores[base + j])
                        .fold(f32::NEG_INFINITY, f32::max);

                    let mut sum = 0.0f32;
                    for j in 0..src_len {
                        let exp_val = (attn_scores[base + j] - max_score).exp();
                        attn_weights[base + j] = exp_val;
                        sum += exp_val;
                    }

                    for j in 0..src_len {
                        attn_weights[base + j] /= sum;
                    }
                }
            }
        }

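        // 4. Weighted sum of the value vectors. Writing each head into its own
        //    head_dim slice of the embedding dimension implicitly concatenates
        //    the heads.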
        let mut output_vec = vec![0.0f32; batch_size * tgt_len * self.embed_dim];
        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..tgt_len {
                    for d in 0..self.head_dim {
                        let mut weighted_sum = 0.0f32;
                        for j in 0..src_len {
                            let attn_idx = b * self.num_heads * tgt_len * src_len
                                + h * tgt_len * src_len
                                + i * src_len
                                + j;
                            let v_idx =
                                b * kv_batch_stride + j * kv_seq_stride + h * self.head_dim + d;
                            weighted_sum += attn_weights[attn_idx] * v_vec[v_idx];
                        }
                        let out_idx =
                            b * q_batch_stride + i * q_seq_stride + h * self.head_dim + d;
                        output_vec[out_idx] = weighted_sum;
                    }
                }
            }
        }

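        // 5. Wrap the result in a new Variable with the requested layout and
        //    apply the output projection.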
        let output_shape = if self.batch_first {
            vec![batch_size, tgt_len, self.embed_dim]
        } else {
            vec![tgt_len, batch_size, self.embed_dim]
        };

        let output = Variable::new(
            Tensor::from_vec(output_vec, &output_shape).unwrap(),
            query.requires_grad(),
        );

        self.out_proj.forward(&output)
    }
}

impl Module for MultiHeadAttention {
    /// Self-attention: the input is used as query, key, and value, with no mask.
    fn forward(&self, input: &Variable) -> Variable {
        self.attention(input, input, input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.q_proj.parameters());
        params.extend(self.k_proj.parameters());
        params.extend(self.v_proj.parameters());
        params.extend(self.out_proj.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.q_proj.named_parameters() {
            params.insert(format!("q_proj.{name}"), param);
        }
        for (name, param) in self.k_proj.named_parameters() {
            params.insert(format!("k_proj.{name}"), param);
        }
        for (name, param) in self.v_proj.named_parameters() {
            params.insert(format!("v_proj.{name}"), param);
        }
        for (name, param) in self.out_proj.named_parameters() {
            params.insert(format!("out_proj.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "MultiHeadAttention"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multihead_attention_creation() {
        let mha = MultiHeadAttention::new(512, 8);
        assert_eq!(mha.embed_dim, 512);
        assert_eq!(mha.num_heads, 8);
        assert_eq!(mha.head_dim, 64);
    }

    #[test]
    fn test_multihead_attention_forward() {
        let mha = MultiHeadAttention::new(64, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = mha.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_cross_attention() {
        let mha = MultiHeadAttention::new(64, 4);
        let query = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let key_value = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = mha.attention(&query, &key_value, &key_value, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_multihead_attention_parameters() {
        let mha = MultiHeadAttention::new(64, 4);
        let params = mha.parameters();
        // Four Linear layers (q, k, v, out), each with a weight and a bias.
        assert_eq!(params.len(), 8);
    }
}