1use crate::cache::TensorCache;
11use crate::modules::{Linear, LinearConfig};
12use crate::{Dropout, DropoutConfig};
13use burn_core as burn;
14
15use burn::{
16 config::Config,
17 module::{Initializer, Module},
18 tensor::{
19 Bool, Tensor,
20 activation::{quiet_softmax, softmax},
21 backend::Backend,
22 },
23};
24
25#[cfg(not(feature = "std"))]
26#[allow(unused_imports)]
27use num_traits::Float as _;
28
/// Configuration to create a [cross-attention](CrossAttention) layer
/// using the [init function](CrossAttentionConfig::init).
#[derive(Config, Debug)]
pub struct CrossAttentionConfig {
    /// The size of the query embedding.
    pub d_model: usize,
    /// The size of the context (key/value) embedding.
    pub d_context: usize,
    /// The number of query heads.
    pub n_heads: usize,
    /// The number of key/value heads. Must evenly divide `n_heads`:
    /// equal to `n_heads` for standard multi-head attention, `1` for
    /// multi-query attention, anything in between for grouped-query attention.
    pub n_heads_kv: usize,
    /// The size of each attention head.
    pub d_head: usize,
    /// The dropout rate applied to the attention weights. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// The value written into masked score positions before the softmax,
    /// so that masked positions receive (near-)zero attention. Default: -1.0e4
    #[config(default = -1.0e4)]
    pub min_float: f64,
    /// Use "quiet softmax" instead of regular softmax. Default: false
    #[config(default = false)]
    pub quiet_softmax: bool,
}
52
/// Cross-attention module: attention computed between a query sequence and a
/// separate context sequence, with support for grouped-query / multi-query
/// attention (fewer key/value heads than query heads).
///
/// Should be created with [CrossAttentionConfig].
#[derive(Module, Debug)]
pub struct CrossAttention<B: Backend> {
    /// Projects the query input into `n_heads * d_head` features.
    query: Linear<B>,
    /// Projects the context into `n_heads_kv * d_head` key features.
    key: Linear<B>,
    /// Projects the context into `n_heads_kv * d_head` value features.
    value: Linear<B>,
    /// Output projection back to `d_model` features.
    output: Linear<B>,
    /// Dropout applied to the attention weights.
    dropout: Dropout,

    // Number of query heads.
    n_heads: usize,
    // Number of key/value heads (divides `n_heads`).
    n_heads_kv: usize,
    // Size of each attention head.
    d_head: usize,
    // 1 / sqrt(d_head), multiplied into the raw attention scores.
    scale: f64,
    // Fill value for masked score positions before the softmax.
    min_float: f64,
    // Whether to use quiet softmax instead of regular softmax.
    quiet_softmax: bool,
}
78
/// Cache for the projected key and value tensors of a [CrossAttention] layer.
///
/// The key/value projections depend only on the context, so when the context
/// is fixed across repeated calls (e.g. autoregressive decoding) they can be
/// computed once and reused via [forward_cache](CrossAttention::forward_cache).
pub struct CrossAttentionCache<B: Backend> {
    /// Cached key tensor, shape `[batch, n_heads_kv, seq_context, d_head]`.
    pub k: TensorCache<B, 4>,
    /// Cached value tensor, shape `[batch, n_heads_kv, seq_context, d_head]`.
    pub v: TensorCache<B, 4>,
}
88
89impl<B: Backend> CrossAttentionCache<B> {
90 pub fn new() -> Self {
92 Self {
93 k: TensorCache::empty(),
94 v: TensorCache::empty(),
95 }
96 }
97}
98
99impl<B: Backend> Default for CrossAttentionCache<B> {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105impl CrossAttentionConfig {
106 pub fn init<B: Backend>(&self, device: &B::Device) -> CrossAttention<B> {
116 assert_eq!(
118 self.n_heads % self.n_heads_kv,
119 0,
120 "Query heads must be divisible by KV heads"
121 );
122
123 let init_linear = |in_dim, out_dim| {
124 LinearConfig::new(in_dim, out_dim)
125 .with_initializer(Initializer::KaimingUniform {
126 gain: 1.0 / (self.d_head as f64).sqrt(),
127 fan_out_only: false,
128 })
129 .init(device)
130 };
131
132 CrossAttention {
133 query: init_linear(self.d_model, self.n_heads * self.d_head),
135 key: init_linear(self.d_context, self.n_heads_kv * self.d_head),
136 value: init_linear(self.d_context, self.n_heads_kv * self.d_head),
137 output: init_linear(self.n_heads * self.d_head, self.d_model),
138
139 dropout: DropoutConfig::new(self.dropout).init(),
140 n_heads: self.n_heads,
141 n_heads_kv: self.n_heads_kv,
142 d_head: self.d_head,
143 scale: (self.d_head as f64).sqrt().recip(),
144 min_float: self.min_float,
145 quiet_softmax: self.quiet_softmax,
146 }
147 }
148}
149
150impl<B: Backend> CrossAttention<B> {
151 pub fn forward(
163 &self,
164 query: Tensor<B, 3>,
165 context: Tensor<B, 3>,
166 mask: Option<Tensor<B, 2, Bool>>,
167 ) -> Tensor<B, 3> {
168 let [batch, l_q, _] = query.dims();
169 let [_, l_k, _] = context.dims();
170
171 let q = self.query.forward(query);
173 let k = self.key.forward(context.clone());
174 let v = self.value.forward(context);
175
176 let q = q
179 .reshape([batch, l_q, self.n_heads, self.d_head])
180 .swap_dims(1, 2);
181
182 let k = k
184 .reshape([batch, l_k, self.n_heads_kv, self.d_head])
185 .swap_dims(1, 2);
186 let v = v
187 .reshape([batch, l_k, self.n_heads_kv, self.d_head])
188 .swap_dims(1, 2);
189
190 let (k, v) = if self.n_heads != self.n_heads_kv {
193 let n_rep = self.n_heads / self.n_heads_kv;
194 (self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep))
195 } else {
196 (k, v)
197 };
198
199 let scores = q.matmul(k.transpose()) * self.scale;
201
202 let scores = if let Some(mask) = mask {
205 let mask = mask.reshape([batch, 1, 1, l_k]);
206 scores.mask_fill(mask, self.min_float)
207 } else {
208 scores
209 };
210
211 let weights = if self.quiet_softmax {
214 quiet_softmax(scores, 3)
215 } else {
216 softmax(scores, 3)
217 };
218
219 let weights = self.dropout.forward(weights);
220
221 let output = weights.matmul(v);
223 let output = output
224 .swap_dims(1, 2)
225 .reshape([batch, l_q, self.n_heads * self.d_head]);
226
227 self.output.forward(output)
228 }
229
230 pub fn forward_cache(
245 &self,
246 query: Tensor<B, 3>,
247 context: Tensor<B, 3>,
248 mask: Option<Tensor<B, 2, Bool>>,
249 cache: &mut CrossAttentionCache<B>,
250 ) -> Tensor<B, 3> {
251 let [batch, l_q, _] = query.dims();
252
253 let q = self.query.forward(query);
255
256 let k_compute = |context: Tensor<B, 3>| {
257 let [batch, l_k, _] = context.dims();
258 self.key
259 .forward(context)
260 .reshape([batch, l_k, self.n_heads_kv, self.d_head])
261 .swap_dims(1, 2)
262 };
263 let v_compute = |context: Tensor<B, 3>| {
264 let [batch, l_k, _] = context.dims();
265 self.value
266 .forward(context)
267 .reshape([batch, l_k, self.n_heads_kv, self.d_head])
268 .swap_dims(1, 2)
269 };
270
271 let k = cache.k.forward_full(context.clone(), k_compute);
272 let v = cache.v.forward_full(context, v_compute);
273
274 let [_, _, l_k, _] = k.dims();
275
276 let q = q
279 .reshape([batch, l_q, self.n_heads, self.d_head])
280 .swap_dims(1, 2);
281
282 let (k, v) = if self.n_heads != self.n_heads_kv {
287 let n_rep = self.n_heads / self.n_heads_kv;
288 (self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep))
289 } else {
290 (k, v)
291 };
292
293 let scores = q.matmul(k.transpose()) * self.scale;
295
296 let scores = if let Some(mask) = mask {
299 let mask = mask.reshape([batch, 1, 1, l_k]);
300 scores.mask_fill(mask, self.min_float)
301 } else {
302 scores
303 };
304
305 let weights = if self.quiet_softmax {
308 quiet_softmax(scores, 3)
309 } else {
310 softmax(scores, 3)
311 };
312
313 let weights = self.dropout.forward(weights);
314
315 let output = weights.matmul(v);
317 let output = output
318 .swap_dims(1, 2)
319 .reshape([batch, l_q, self.n_heads * self.d_head]);
320
321 self.output.forward(output)
322 }
323
324 fn repeat_kv(&self, x: Tensor<B, 4>, n_rep: usize) -> Tensor<B, 4> {
326 let [b, h, l, d] = x.dims();
327 x.reshape([b, h, 1, l, d])
328 .expand([b, h, n_rep, l, d])
329 .reshape([b, h * n_rep, l, d])
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{Distribution, Int, Shape, Tensor, Tolerance};

    // Standard multi-head attention (n_heads_kv == n_heads): output must stay
    // in the query's embedding space regardless of the context length/width.
    #[test]
    fn test_cross_attention_mha_shapes() {
        let [
            batch_size,
            seq_len_query,
            seq_len_context,
            d_model,
            d_context,
            n_heads,
            d_head,
        ] = [7, 13, 15, 32, 40, 4, 8];
        let device = Default::default();
        let config = CrossAttentionConfig {
            d_model,
            d_context,
            n_heads,
            n_heads_kv: n_heads, d_head,
            dropout: 0.1,
            min_float: -1.0e4,
            quiet_softmax: false,
        };
        let cross_attn = config.init::<TestBackend>(&device);

        let query = Tensor::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );
        let context = Tensor::random(
            [batch_size, seq_len_context, d_context],
            Distribution::Default,
            &device,
        );

        let output = cross_attn.forward(query, context, None);

        assert_eq!(
            output.shape(),
            Shape::new([batch_size, seq_len_query, d_model]),
            "Output should have the correct shape",
        );
    }

    // Grouped-query attention: 4 query heads sharing 2 KV heads. The output
    // shape must be identical to the MHA case.
    #[test]
    fn test_cross_attention_gqa_shapes() {
        let [
            batch_size,
            seq_len_query,
            seq_len_context,
            d_model,
            d_context,
            n_heads,
            n_heads_kv,
            d_head,
        ] = [7, 13, 15, 32, 40, 4, 2, 8];
        let device = Default::default();
        let config = CrossAttentionConfig {
            d_model,
            d_context,
            n_heads,
            n_heads_kv, d_head,
            dropout: 0.1,
            min_float: -1.0e4,
            quiet_softmax: false,
        };
        let cross_attn = config.init::<TestBackend>(&device);

        let query = Tensor::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );
        let context = Tensor::random(
            [batch_size, seq_len_context, d_context],
            Distribution::Default,
            &device,
        );

        let output = cross_attn.forward(query, context, None);

        assert_eq!(
            output.shape(),
            Shape::new([batch_size, seq_len_query, d_model]),
            "Output should have the correct shape",
        );
    }

    // Multi-query attention: a single KV head shared by all query heads.
    #[test]
    fn test_cross_attention_mqa_shapes() {
        let [
            batch_size,
            seq_len_query,
            seq_len_context,
            d_model,
            d_context,
            n_heads,
            d_head,
        ] = [7, 13, 15, 32, 40, 4, 8];
        let device = Default::default();
        let config = CrossAttentionConfig {
            d_model,
            d_context,
            n_heads,
            n_heads_kv: 1, d_head,
            dropout: 0.1,
            min_float: -1.0e4,
            quiet_softmax: false,
        };
        let cross_attn = config.init::<TestBackend>(&device);

        let query = Tensor::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );
        let context = Tensor::random(
            [batch_size, seq_len_context, d_context],
            Distribution::Default,
            &device,
        );

        let output = cross_attn.forward(query, context, None);

        assert_eq!(
            output.shape(),
            Shape::new([batch_size, seq_len_query, d_model]),
            "Output should have the correct shape",
        );
    }

    // Masked context positions must not influence the output: two contexts
    // that differ only in the masked (padded) positions must produce equal
    // outputs. Dropout is 0.0 so the two passes are deterministic.
    #[test]
    fn test_cross_attention_mask() {
        let [
            batch_size,
            seq_len_query,
            seq_len_context,
            d_model,
            d_context,
            n_heads,
            d_head,
        ] = [3, 6, 8, 12, 16, 4, 3];
        let num_padded = 2;
        let device = Default::default();
        let config = CrossAttentionConfig {
            d_model,
            d_context,
            n_heads,
            n_heads_kv: n_heads,
            d_head,
            dropout: 0.0, min_float: -1.0e4,
            quiet_softmax: false,
        };
        let cross_attn = config.init::<TestBackend>(&device);

        // Build a boolean mask marking the last `num_padded` context
        // positions of every batch entry as padding.
        let mut mask: Tensor<TestBackend, 2, Int> =
            Tensor::zeros([batch_size, seq_len_context], &device);
        mask = mask.slice_assign(
            [0..batch_size, seq_len_context - num_padded..seq_len_context],
            Tensor::ones([batch_size, num_padded], &device),
        );
        let mask_bool = mask.equal_elem(1);

        let query = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );

        let context_1 = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_len_context, d_context],
            Distribution::Default,
            &device,
        );

        // context_2 equals context_1 except in the masked positions, which
        // are overwritten with fresh random values.
        let context_2 = context_1.clone().slice_assign(
            [
                0..batch_size,
                seq_len_context - num_padded..seq_len_context,
                0..d_context,
            ],
            Tensor::random(
                [batch_size, num_padded, d_context],
                Distribution::Default,
                &device,
            ),
        );

        let output_1 = cross_attn.forward(query.clone(), context_1, Some(mask_bool.clone()));
        let output_2 = cross_attn.forward(query, context_2, Some(mask_bool));

        // Identical outputs prove the masked positions were ignored.
        output_1
            .into_data()
            .assert_approx_eq(&output_2.into_data(), Tolerance::<f32>::default());
    }

    // init() must reject head counts where n_heads is not a multiple of
    // n_heads_kv (5 query heads cannot be grouped onto 2 KV heads).
    #[test]
    #[should_panic]
    fn test_gqa_panic_if_n_heads_not_divisible_by_n_heads_kv() {
        let device = Default::default();
        let config = CrossAttentionConfig {
            d_model: 32,
            d_context: 32,
            n_heads: 5,
            n_heads_kv: 2,
            d_head: 8,
            dropout: 0.1,
            min_float: -1.0e4,
            quiet_softmax: false,
        };
        config.init::<TestBackend>(&device);
    }

    // forward_cache must match forward exactly, both on the first call (which
    // populates the cache) and on later calls (which reuse the cached K/V for
    // the same context with a different query). Dropout is 0.0 so outputs are
    // deterministic and comparable.
    #[test]
    fn test_cross_attention_cache() {
        let [
            batch_size,
            seq_len_query,
            seq_len_context,
            d_model,
            d_context,
            n_heads,
            d_head,
        ] = [3, 6, 8, 12, 16, 4, 3];
        let device = Default::default();
        let config = CrossAttentionConfig {
            d_model,
            d_context,
            n_heads,
            n_heads_kv: n_heads,
            d_head,
            dropout: 0.0, min_float: -1.0e4,
            quiet_softmax: false,
        };
        let cross_attn = config.init::<TestBackend>(&device);

        let query1 = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );
        let context = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_len_context, d_context],
            Distribution::Default,
            &device,
        );

        let output1 = cross_attn.forward(query1.clone(), context.clone(), None);

        // First cached call: populates the cache and must equal forward().
        let mut cache = CrossAttentionCache::new();
        let output2 = cross_attn.forward_cache(query1.clone(), context.clone(), None, &mut cache);

        output1
            .into_data()
            .assert_approx_eq(&output2.into_data(), Tolerance::<f32>::default());

        // Second cached call with a new query: reuses the cached K/V and must
        // still equal the uncached forward() on the same inputs.
        let query2 = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );
        let output3 = cross_attn.forward_cache(query2.clone(), context.clone(), None, &mut cache);

        let output4 = cross_attn.forward(query2.clone(), context.clone(), None);

        output3
            .into_data()
            .assert_approx_eq(&output4.into_data(), Tolerance::<f32>::default());
    }
}