use ghostflow_core::Tensor;

use crate::module::Module;
use crate::linear::Linear;
use crate::dropout::Dropout;

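/// Computes scaled dot-product attention: `softmax(Q K^T / sqrt(d_k)) V`.
///
/// Positions where `mask` is zero are excluded from attention. When `training`
/// is true and `dropout_p > 0`, dropout is applied to the attention weights.
/// Returns `(output, attention_weights)`.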
pub fn scaled_dot_product_attention(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    mask: Option<&Tensor>,
    dropout_p: f32,
    training: bool,
) -> (Tensor, Tensor) {
    let d_k = query.dims()[query.ndim() - 1] as f32;
    let scale = 1.0 / d_k.sqrt();

    let key_t = key.transpose(key.ndim() - 2, key.ndim() - 1).unwrap();
    let scores = query.matmul(&key_t).unwrap();
    let scaled_scores = scores.mul_scalar(scale);

    let masked_scores = if let Some(m) = mask {
        apply_attention_mask(&scaled_scores, m)
    } else {
        scaled_scores
    };

    let attn_weights = masked_scores.softmax(-1);

    let attn_weights = if training && dropout_p > 0.0 {
        let dropout = Dropout::new(dropout_p);
        dropout.forward(&attn_weights)
    } else {
        attn_weights
    };

    let output = attn_weights.matmul(value).unwrap();

    (output, attn_weights)
}

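/// Sets masked score positions (mask value <= 0.5) to negative infinity so
/// they receive zero weight after the softmax. The mask is broadcast over the
/// flattened scores by cycling its elements, so its length must evenly divide
/// the number of score elements.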
fn apply_attention_mask(scores: &Tensor, mask: &Tensor) -> Tensor {
    let mask_data = mask.data_f32();
    let scores_data = scores.data_f32();

    let result: Vec<f32> = scores_data.iter()
        .zip(mask_data.iter().cycle())
        .map(|(&s, &m)| {
            if m > 0.5 { s } else { f32::NEG_INFINITY }
        })
        .collect();

    Tensor::from_slice(&result, scores.dims()).unwrap()
}

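/// Multi-head attention layer: projects queries, keys, and values, splits them
/// into `num_heads` heads of size `head_dim`, applies scaled dot-product
/// attention per head, and recombines the heads through an output projection.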
pub struct MultiHeadAttention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,

    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,

    dropout_p: f32,
    training: bool,
}

impl MultiHeadAttention {
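    /// Creates a multi-head attention layer with `num_heads` heads.
    /// Panics if `embed_dim` is not divisible by `num_heads`.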
    pub fn new(embed_dim: usize, num_heads: usize, dropout: f32) -> Self {
        assert!(embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads");
        let head_dim = embed_dim / num_heads;

        MultiHeadAttention {
            embed_dim,
            num_heads,
            head_dim,
            q_proj: Linear::new(embed_dim, embed_dim),
            k_proj: Linear::new(embed_dim, embed_dim),
            v_proj: Linear::new(embed_dim, embed_dim),
            out_proj: Linear::new(embed_dim, embed_dim),
            dropout_p: dropout,
            training: true,
        }
    }

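    /// Forward pass with key/value caching for incremental decoding. If
    /// `past_key`/`past_value` are provided, the new keys and values are
    /// appended to them along the sequence dimension. Returns the attention
    /// output, the attention weights, and the updated key/value caches
    /// (shape `[batch, kv_seq_len, embed_dim]`) to pass back on the next call.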
    pub fn forward_with_cache(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        mask: Option<&Tensor>,
        past_key: Option<&Tensor>,
        past_value: Option<&Tensor>,
    ) -> (Tensor, Tensor, Tensor, Tensor) {
        let batch_size = query.dims()[0];
        let seq_len = query.dims()[1];

        let q = self.q_proj.forward(query);
        let mut k = self.k_proj.forward(key);
        let mut v = self.v_proj.forward(value);

        if let (Some(pk), Some(pv)) = (past_key, past_value) {
            k = concat_tensors(pk, &k, 1);
            v = concat_tensors(pv, &v, 1);
        }

        let kv_seq_len = k.dims()[1];

        // Snapshot the concatenated projections before reshaping so the caches
        // can be fed back as `past_key`/`past_value` on the next call.
        let k_cache = k.clone();
        let v_cache = v.clone();

        let q = self.reshape_for_attention(&q, batch_size, seq_len);
        let k = self.reshape_for_attention(&k, batch_size, kv_seq_len);
        let v = self.reshape_for_attention(&v, batch_size, kv_seq_len);

        let (attn_output, attn_weights) = scaled_dot_product_attention(
            &q, &k, &v, mask, self.dropout_p, self.training,
        );

        let attn_output = self.reshape_from_attention(&attn_output, batch_size, seq_len);

        let output = self.out_proj.forward(&attn_output);

        (output, attn_weights, k_cache, v_cache)
    }

    fn reshape_for_attention(&self, x: &Tensor, batch_size: usize, seq_len: usize) -> Tensor {
        let reshaped = x.reshape(&[batch_size, seq_len, self.num_heads, self.head_dim]).unwrap();
        reshaped.transpose(1, 2).unwrap()
    }

    fn reshape_from_attention(&self, x: &Tensor, batch_size: usize, seq_len: usize) -> Tensor {
        let transposed = x.transpose(1, 2).unwrap();
        transposed.reshape(&[batch_size, seq_len, self.embed_dim]).unwrap()
    }
}

impl Module for MultiHeadAttention {
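    /// Self-attention: the input is used as query, key, and value.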
    fn forward(&self, input: &Tensor) -> Tensor {
        let (output, _, _, _) = self.forward_with_cache(input, input, input, None, None, None);
        output
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = self.q_proj.parameters();
        params.extend(self.k_proj.parameters());
        params.extend(self.v_proj.parameters());
        params.extend(self.out_proj.parameters());
        params
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

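/// Cross-attention wrapper around `MultiHeadAttention`: queries come from one
/// sequence while keys and values come from another (e.g. encoder outputs).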
pub struct CrossAttention {
    mha: MultiHeadAttention,
}

impl CrossAttention {
    pub fn new(embed_dim: usize, num_heads: usize, dropout: f32) -> Self {
        CrossAttention {
            mha: MultiHeadAttention::new(embed_dim, num_heads, dropout),
        }
    }

    pub fn forward_cross(&self, query: &Tensor, key: &Tensor, value: &Tensor, mask: Option<&Tensor>) -> Tensor {
        let (output, _, _, _) = self.mha.forward_with_cache(query, key, value, mask, None, None);
        output
    }
}

impl Module for CrossAttention {
    fn forward(&self, input: &Tensor) -> Tensor {
        self.mha.forward(input)
    }

    fn parameters(&self) -> Vec<Tensor> {
        self.mha.parameters()
    }

    fn train(&mut self) { self.mha.train(); }
    fn eval(&mut self) { self.mha.eval(); }
    fn is_training(&self) -> bool { self.mha.is_training() }
}

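/// Concatenates two tensors along `dim`. Only `dim == 1` (the sequence
/// dimension of `[batch, seq, ...]` tensors) is currently supported.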
fn concat_tensors(a: &Tensor, b: &Tensor, dim: usize) -> Tensor {
    let a_dims = a.dims();
    let b_dims = b.dims();

    let mut new_dims = a_dims.to_vec();
    new_dims[dim] = a_dims[dim] + b_dims[dim];

    let a_data = a.data_f32();
    let b_data = b.data_f32();

    if dim == 1 {
        let batch = a_dims[0];
        let a_seq = a_dims[1];
        let b_seq = b_dims[1];
        let rest: usize = a_dims[2..].iter().product();

        let mut result = Vec::with_capacity(batch * (a_seq + b_seq) * rest);

        for b_idx in 0..batch {
            let a_start = b_idx * a_seq * rest;
            result.extend_from_slice(&a_data[a_start..a_start + a_seq * rest]);
            let b_start = b_idx * b_seq * rest;
            result.extend_from_slice(&b_data[b_start..b_start + b_seq * rest]);
        }

        Tensor::from_slice(&result, &new_dims).unwrap()
    } else {
        // Previously this silently returned `b` alone; fail loudly instead.
        unimplemented!("concat_tensors only supports concatenation along dim 1");
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scaled_dot_product_attention() {
        let q = Tensor::randn(&[2, 4, 8]);
        let k = Tensor::randn(&[2, 4, 8]);
        let v = Tensor::randn(&[2, 4, 8]);

        let (output, weights) = scaled_dot_product_attention(&q, &k, &v, None, 0.0, false);

        assert_eq!(output.dims(), &[2, 4, 8]);
        assert_eq!(weights.dims(), &[2, 4, 4]);
    }

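    // A minimal additional sketch exercising the mask path, assuming the mask
    // convention from `apply_attention_mask` (values <= 0.5 are masked out) and
    // that `Tensor::from_slice` accepts an `&[f32]` with explicit dimensions,
    // as used elsewhere in this file.
    #[test]
    fn test_attention_with_mask() {
        let q = Tensor::randn(&[1, 2, 4]);
        let k = Tensor::randn(&[1, 2, 4]);
        let v = Tensor::randn(&[1, 2, 4]);

        // Mask out the second key position for every query.
        let mask = Tensor::from_slice(&[1.0f32, 0.0, 1.0, 0.0], &[1, 2, 2]).unwrap();

        let (output, weights) = scaled_dot_product_attention(&q, &k, &v, Some(&mask), 0.0, false);

        assert_eq!(output.dims(), &[1, 2, 4]);
        assert_eq!(weights.dims(), &[1, 2, 2]);

        // Masked positions should carry (near-)zero weight after the softmax.
        let w = weights.data_f32();
        assert!(w[1].abs() < 1e-6);
        assert!(w[3].abs() < 1e-6);
    }
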
    #[test]
    fn test_multi_head_attention() {
        let mha = MultiHeadAttention::new(64, 8, 0.1);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = mha.forward(&input);

        assert_eq!(output.dims(), &[2, 10, 64]);
    }
}