burn_nn/modules/attention/cross_attention.rs

//! Cross-Attention Module for Burn
//!
//! Features:
//! - Asymmetric Input Shapes (Query vs Context)
//! - Grouped Query Attention (GQA) & Multi-Query Attention (MQA) support
//! - Quantization-Safe Masking (min_float)
//! - Sparse-Ready (quiet_softmax)
//! - KV Caching for Streaming Inference
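//!
//! # Example
//!
//! A minimal sketch of building and running a GQA layer (`ignore`d here since it
//! assumes a concrete backend `B`, a `device`, and `query`/`context` tensors are
//! already in scope):
//!
//! ```rust,ignore
//! // 8 query heads share 2 KV heads (GQA); use n_heads_kv = 1 for MQA.
//! let config = CrossAttentionConfig::new(512, 768, 8, 2, 64);
//! let attn = config.init::<B>(&device);
//!
//! // query: [batch, l_q, 512], context: [batch, l_k, 768]
//! let output = attn.forward(query, context, None); // [batch, l_q, 512]
//! ```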
use crate::cache::TensorCache;
use crate::modules::{Linear, LinearConfig};
use crate::{Dropout, DropoutConfig};
use burn_core as burn;

use burn::{
    config::Config,
    module::{Initializer, Module},
    tensor::{
        Bool, Tensor,
        activation::{quiet_softmax, softmax},
        backend::Backend,
    },
};

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

#[derive(Config, Debug)]
/// Configuration to create a [CrossAttention](CrossAttention) layer using the [init function](CrossAttentionConfig::init).
pub struct CrossAttentionConfig {
    /// Dimension of the Query (e.g., decoder state).
    pub d_model: usize,
    /// Dimension of the Context (e.g., encoder audio embeddings).
    pub d_context: usize,
    /// Number of heads for the Query.
    pub n_heads: usize,
    /// Number of heads for Key/Value (set to 1 for MQA, to `n_heads` for MHA).
    pub n_heads_kv: usize,
    /// Dimension of a single head.
    pub d_head: usize,
    /// Dropout rate.
    #[config(default = 0.1)]
    pub dropout: f64,
    /// Value used to fill masked positions. `-1.0e4` stays finite in f16/bf16.
    #[config(default = -1.0e4)]
    pub min_float: f64,
    /// Use quiet_softmax so heads can attend to nothing (good for sparse/quantized models).
    #[config(default = false)]
    pub quiet_softmax: bool,
}

#[derive(Module, Debug)]
/// The cross-attention module.
///
/// # Params
///
/// - `query`: [`Linear`] layer with `d_model` input and `n_heads * d_head` output features.
/// - `key`: [`Linear`] layer with `d_context` input and `n_heads_kv * d_head` output features.
/// - `value`: [`Linear`] layer with `d_context` input and `n_heads_kv * d_head` output features.
/// - `output`: [`Linear`] layer with `n_heads * d_head` input and `d_model` output features.
///
/// Should be created with [CrossAttentionConfig].
pub struct CrossAttention<B: Backend> {
    query: Linear<B>,
    key: Linear<B>,
    value: Linear<B>,
    output: Linear<B>,
    dropout: Dropout,

    n_heads: usize,
    n_heads_kv: usize,
    d_head: usize,
    scale: f64,
    min_float: f64,
    quiet_softmax: bool,
}

/// Cache for the [Cross Attention](CrossAttention) layer.
///
/// To be used during inference when the context is constant.
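///
/// # Example
///
/// A minimal sketch of a streaming decode loop (assumes `attn`, `context`, and
/// per-step `query` tensors already exist):
///
/// ```rust,ignore
/// let mut cache = CrossAttentionCache::new();
/// for query in decoder_steps {
///     // The K/V projections of `context` are computed on the first call
///     // and reused on every subsequent step.
///     let output = attn.forward_cache(query, context.clone(), None, &mut cache);
/// }
/// ```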
pub struct CrossAttentionCache<B: Backend> {
    /// Cached key tensor.
    pub k: TensorCache<B, 4>,
    /// Cached value tensor.
    pub v: TensorCache<B, 4>,
}

impl<B: Backend> CrossAttentionCache<B> {
    /// Create a new empty cache.
    pub fn new() -> Self {
        Self {
            k: TensorCache::empty(),
            v: TensorCache::empty(),
        }
    }
}

impl<B: Backend> Default for CrossAttentionCache<B> {
    fn default() -> Self {
        Self::new()
    }
}

impl CrossAttentionConfig {
    /// Initializes a new cross-attention module.
    ///
    /// # Arguments
    ///
    /// * `device` - The device on which to initialize the module.
    ///
    /// # Returns
    ///
    /// A new [CrossAttention] module.
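    ///
    /// For example, with `d_model = 512`, `d_context = 768`, `n_heads = 8`,
    /// `n_heads_kv = 2`, and `d_head = 64`, the projections are asymmetric:
    /// query maps `512 -> 512` (8 * 64), key/value map `768 -> 128` (2 * 64),
    /// and output maps `512 -> 512`.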
    pub fn init<B: Backend>(&self, device: &B::Device) -> CrossAttention<B> {
        // GQA safety rail: query heads must split evenly across KV heads.
        assert_eq!(
            self.n_heads % self.n_heads_kv,
            0,
            "Query heads must be divisible by KV heads"
        );

        let init_linear = |in_dim, out_dim| {
            LinearConfig::new(in_dim, out_dim)
                .with_initializer(Initializer::KaimingUniform {
                    gain: 1.0 / (self.d_head as f64).sqrt(),
                    fan_out_only: false,
                })
                .init(device)
        };

        CrossAttention {
            // Asymmetric projections: Q from d_model, K/V from d_context.
            query: init_linear(self.d_model, self.n_heads * self.d_head),
            key: init_linear(self.d_context, self.n_heads_kv * self.d_head),
            value: init_linear(self.d_context, self.n_heads_kv * self.d_head),
            output: init_linear(self.n_heads * self.d_head, self.d_model),

            dropout: DropoutConfig::new(self.dropout).init(),
            n_heads: self.n_heads,
            n_heads_kv: self.n_heads_kv,
            d_head: self.d_head,
            scale: (self.d_head as f64).sqrt().recip(),
            min_float: self.min_float,
            quiet_softmax: self.quiet_softmax,
        }
    }
}

impl<B: Backend> CrossAttention<B> {
    /// Applies cross-attention to `query` using `context` as key and value.
    ///
    /// # Arguments
    ///
    /// * `query` - Query tensor of shape `[batch, seq_len_query, d_model]`.
    /// * `context` - Context tensor of shape `[batch, seq_len_context, d_context]`.
    /// * `mask` - Optional attention mask of shape `[batch, seq_len_context]` where `true` indicates positions to mask.
    ///
    /// # Returns
    ///
    /// Output tensor of shape `[batch, seq_len_query, d_model]`.
    pub fn forward(
        &self,
        query: Tensor<B, 3>,
        context: Tensor<B, 3>,
        mask: Option<Tensor<B, 2, Bool>>,
    ) -> Tensor<B, 3> {
        let [batch, l_q, _] = query.dims();
        let [_, l_k, _] = context.dims();

        // 1. Projections
        let q = self.query.forward(query);
        let k = self.key.forward(context.clone());
        let v = self.value.forward(context);

        // 2. Reshape Heads
        // Q: [Batch, Heads, L_q, D_head]
        let q = q
            .reshape([batch, l_q, self.n_heads, self.d_head])
            .swap_dims(1, 2);

        // K, V: [Batch, Heads_KV, L_k, D_head]
        let k = k
            .reshape([batch, l_k, self.n_heads_kv, self.d_head])
            .swap_dims(1, 2);
        let v = v
            .reshape([batch, l_k, self.n_heads_kv, self.d_head])
            .swap_dims(1, 2);

        // 3. GQA Expansion
        // Repeat KV heads so they match the number of query heads.
        let (k, v) = if self.n_heads != self.n_heads_kv {
            let n_rep = self.n_heads / self.n_heads_kv;
            (self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep))
        } else {
            (k, v)
        };

        // 4. Score Calculation
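        // scores: [Batch, Heads, L_q, L_k]; the 1/sqrt(d_head) scaling keeps logits in a stable range.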
        let scores = q.matmul(k.transpose()) * self.scale;

        // 5. Masking
        // min_float stays finite in reduced precision (f16/bf16), avoiding NaNs after softmax.
        let scores = if let Some(mask) = mask {
            let mask = mask.reshape([batch, 1, 1, l_k]);
            scores.mask_fill(mask, self.min_float)
        } else {
            scores
        };

        // 6. Softmax
        // Optional quiet softmax lets a head attend to nothing (useful for sparse networks).
        let weights = if self.quiet_softmax {
            quiet_softmax(scores, 3)
        } else {
            softmax(scores, 3)
        };

        let weights = self.dropout.forward(weights);

        // 7. Aggregate & Output
        let output = weights.matmul(v);
        let output = output
            .swap_dims(1, 2)
            .reshape([batch, l_q, self.n_heads * self.d_head]);

        self.output.forward(output)
    }

    /// Applies cross-attention to `query` using `context` as key and value.
    ///
    /// This method uses a cache to avoid recomputing the key and value tensors when the context stays the same.
    ///
    /// # Arguments
    ///
    /// * `query` - Query tensor of shape `[batch, seq_len_query, d_model]`.
    /// * `context` - Context tensor of shape `[batch, seq_len_context, d_context]`.
    /// * `mask` - Optional attention mask of shape `[batch, seq_len_context]` where `true` indicates positions to mask.
    /// * `cache` - The cache to use.
    ///
    /// # Returns
    ///
    /// Output tensor of shape `[batch, seq_len_query, d_model]`.
    pub fn forward_cache(
        &self,
        query: Tensor<B, 3>,
        context: Tensor<B, 3>,
        mask: Option<Tensor<B, 2, Bool>>,
        cache: &mut CrossAttentionCache<B>,
    ) -> Tensor<B, 3> {
        let [batch, l_q, _] = query.dims();

        // 1. Projections
        let q = self.query.forward(query);

        let k_compute = |context: Tensor<B, 3>| {
            let [batch, l_k, _] = context.dims();
            self.key
                .forward(context)
                .reshape([batch, l_k, self.n_heads_kv, self.d_head])
                .swap_dims(1, 2)
        };
        let v_compute = |context: Tensor<B, 3>| {
            let [batch, l_k, _] = context.dims();
            self.value
                .forward(context)
                .reshape([batch, l_k, self.n_heads_kv, self.d_head])
                .swap_dims(1, 2)
        };

        let k = cache.k.forward_full(context.clone(), k_compute);
        let v = cache.v.forward_full(context, v_compute);

        let [_, _, l_k, _] = k.dims();

        // 2. Reshape Heads
        // Q: [Batch, Heads, L_q, D_head]
        let q = q
            .reshape([batch, l_q, self.n_heads, self.d_head])
            .swap_dims(1, 2);

        // K, V are already in their final shape from k_compute and v_compute.

        // 3. GQA Expansion
        // Repeat KV heads so they match the number of query heads.
        let (k, v) = if self.n_heads != self.n_heads_kv {
            let n_rep = self.n_heads / self.n_heads_kv;
            (self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep))
        } else {
            (k, v)
        };

        // 4. Score Calculation
        let scores = q.matmul(k.transpose()) * self.scale;

        // 5. Masking
        // min_float stays finite in reduced precision (f16/bf16), avoiding NaNs after softmax.
        let scores = if let Some(mask) = mask {
            let mask = mask.reshape([batch, 1, 1, l_k]);
            scores.mask_fill(mask, self.min_float)
        } else {
            scores
        };

        // 6. Softmax
        // Optional quiet softmax lets a head attend to nothing (useful for sparse networks).
        let weights = if self.quiet_softmax {
            quiet_softmax(scores, 3)
        } else {
            softmax(scores, 3)
        };

        let weights = self.dropout.forward(weights);

        // 7. Aggregate & Output
        let output = weights.matmul(v);
        let output = output
            .swap_dims(1, 2)
            .reshape([batch, l_q, self.n_heads * self.d_head]);

        self.output.forward(output)
    }

    /// Helper for Grouped Query Attention.
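    ///
    /// Repeats each KV head `n_rep` times so that K/V match the query head count:
    /// `[batch, n_heads_kv, len, d_head]` becomes `[batch, n_heads_kv * n_rep, len, d_head]`,
    /// e.g. `[2, 2, 5, 8]` with `n_rep = 2` becomes `[2, 4, 5, 8]`. The `expand`
    /// broadcasts the inserted axis; the final `reshape` folds it into the head dimension.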
    fn repeat_kv(&self, x: Tensor<B, 4>, n_rep: usize) -> Tensor<B, 4> {
        let [b, h, l, d] = x.dims();
        x.reshape([b, h, 1, l, d])
            .expand([b, h, n_rep, l, d])
            .reshape([b, h * n_rep, l, d])
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{Distribution, Int, Shape, Tensor, Tolerance};

    #[test]
    fn test_cross_attention_mha_shapes() {
        let [
            batch_size,
            seq_len_query,
            seq_len_context,
            d_model,
            d_context,
            n_heads,
            d_head,
        ] = [7, 13, 15, 32, 40, 4, 8];
        let device = Default::default();
        let config = CrossAttentionConfig {
            d_model,
            d_context,
            n_heads,
            n_heads_kv: n_heads, // MHA case
            d_head,
            dropout: 0.1,
            min_float: -1.0e4,
            quiet_softmax: false,
        };
        let cross_attn = config.init::<TestBackend>(&device);

        let query = Tensor::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );
        let context = Tensor::random(
            [batch_size, seq_len_context, d_context],
            Distribution::Default,
            &device,
        );

        let output = cross_attn.forward(query, context, None);

        assert_eq!(
            output.shape(),
            Shape::new([batch_size, seq_len_query, d_model]),
            "Output should have the correct shape",
        );
    }

    #[test]
    fn test_cross_attention_gqa_shapes() {
        let [
            batch_size,
            seq_len_query,
            seq_len_context,
            d_model,
            d_context,
            n_heads,
            n_heads_kv,
            d_head,
        ] = [7, 13, 15, 32, 40, 4, 2, 8];
        let device = Default::default();
        let config = CrossAttentionConfig {
            d_model,
            d_context,
            n_heads,
            n_heads_kv, // GQA case
            d_head,
            dropout: 0.1,
            min_float: -1.0e4,
            quiet_softmax: false,
        };
        let cross_attn = config.init::<TestBackend>(&device);

        let query = Tensor::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );
        let context = Tensor::random(
            [batch_size, seq_len_context, d_context],
            Distribution::Default,
            &device,
        );

        let output = cross_attn.forward(query, context, None);

        assert_eq!(
            output.shape(),
            Shape::new([batch_size, seq_len_query, d_model]),
            "Output should have the correct shape",
        );
    }

    #[test]
    fn test_cross_attention_mqa_shapes() {
        let [
            batch_size,
            seq_len_query,
            seq_len_context,
            d_model,
            d_context,
            n_heads,
            d_head,
        ] = [7, 13, 15, 32, 40, 4, 8];
        let device = Default::default();
        let config = CrossAttentionConfig {
            d_model,
            d_context,
            n_heads,
            n_heads_kv: 1, // MQA case
            d_head,
            dropout: 0.1,
            min_float: -1.0e4,
            quiet_softmax: false,
        };
        let cross_attn = config.init::<TestBackend>(&device);

        let query = Tensor::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );
        let context = Tensor::random(
            [batch_size, seq_len_context, d_context],
            Distribution::Default,
            &device,
        );

        let output = cross_attn.forward(query, context, None);

        assert_eq!(
            output.shape(),
            Shape::new([batch_size, seq_len_query, d_model]),
            "Output should have the correct shape",
        );
    }
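
    // A minimal sketch of a shape test for the quiet_softmax path; it mirrors
    // the MHA test above and assumes the same TestBackend.
    #[test]
    fn test_cross_attention_quiet_softmax_shapes() {
        let [batch_size, seq_len_query, seq_len_context, d_model, d_context, n_heads, d_head] =
            [2, 4, 6, 16, 24, 4, 4];
        let device = Default::default();
        let config = CrossAttentionConfig {
            d_model,
            d_context,
            n_heads,
            n_heads_kv: n_heads,
            d_head,
            dropout: 0.0,
            min_float: -1.0e4,
            quiet_softmax: true,
        };
        let cross_attn = config.init::<TestBackend>(&device);

        let query = Tensor::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );
        let context = Tensor::random(
            [batch_size, seq_len_context, d_context],
            Distribution::Default,
            &device,
        );

        let output = cross_attn.forward(query, context, None);

        assert_eq!(
            output.shape(),
            Shape::new([batch_size, seq_len_query, d_model]),
            "quiet_softmax should preserve the output shape",
        );
    }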

    #[test]
    fn test_cross_attention_mask() {
        let [
            batch_size,
            seq_len_query,
            seq_len_context,
            d_model,
            d_context,
            n_heads,
            d_head,
        ] = [3, 6, 8, 12, 16, 4, 3];
        let num_padded = 2;
        let device = Default::default();
        let config = CrossAttentionConfig {
            d_model,
            d_context,
            n_heads,
            n_heads_kv: n_heads,
            d_head,
            dropout: 0.0, // No dropout for a deterministic test
            min_float: -1.0e4,
            quiet_softmax: false,
        };
        let cross_attn = config.init::<TestBackend>(&device);

        // Create a padding mask for the context
        let mut mask: Tensor<TestBackend, 2, Int> =
            Tensor::zeros([batch_size, seq_len_context], &device);
        mask = mask.slice_assign(
            [0..batch_size, seq_len_context - num_padded..seq_len_context],
            Tensor::ones([batch_size, num_padded], &device),
        );
        let mask_bool = mask.equal_elem(1);

        let query = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );

        let context_1 = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_len_context, d_context],
            Distribution::Default,
            &device,
        );

        // Change the padded part of the context tensor
        let context_2 = context_1.clone().slice_assign(
            [
                0..batch_size,
                seq_len_context - num_padded..seq_len_context,
                0..d_context,
            ],
            Tensor::random(
                [batch_size, num_padded, d_context],
                Distribution::Default,
                &device,
            ),
        );

        // The outputs should be the same since the changed part is masked.
        let output_1 = cross_attn.forward(query.clone(), context_1, Some(mask_bool.clone()));
        let output_2 = cross_attn.forward(query, context_2, Some(mask_bool));

        output_1
            .into_data()
            .assert_approx_eq(&output_2.into_data(), Tolerance::<f32>::default());
    }

    #[test]
    #[should_panic]
    fn test_gqa_panic_if_n_heads_not_divisible_by_n_heads_kv() {
        let device = Default::default();
        let config = CrossAttentionConfig {
            d_model: 32,
            d_context: 32,
            n_heads: 5,
            n_heads_kv: 2,
            d_head: 8,
            dropout: 0.1,
            min_float: -1.0e4,
            quiet_softmax: false,
        };
        config.init::<TestBackend>(&device);
    }

    #[test]
    fn test_cross_attention_cache() {
        let [
            batch_size,
            seq_len_query,
            seq_len_context,
            d_model,
            d_context,
            n_heads,
            d_head,
        ] = [3, 6, 8, 12, 16, 4, 3];
        let device = Default::default();
        let config = CrossAttentionConfig {
            d_model,
            d_context,
            n_heads,
            n_heads_kv: n_heads,
            d_head,
            dropout: 0.0, // No dropout for a deterministic test
            min_float: -1.0e4,
            quiet_softmax: false,
        };
        let cross_attn = config.init::<TestBackend>(&device);

        let query1 = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );
        let context = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_len_context, d_context],
            Distribution::Default,
            &device,
        );

        // First forward pass, no cache
        let output1 = cross_attn.forward(query1.clone(), context.clone(), None);

        // Second forward pass with cache
        let mut cache = CrossAttentionCache::new();
        let output2 = cross_attn.forward_cache(query1.clone(), context.clone(), None, &mut cache);

        // The two outputs should be identical
        output1
            .into_data()
            .assert_approx_eq(&output2.into_data(), Tolerance::<f32>::default());

        // Third forward pass with a different query, but the same context and cache
        let query2 = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );
        let output3 = cross_attn.forward_cache(query2.clone(), context.clone(), None, &mut cache);

        // As a control, do a forward pass without the cache using query2
        let output4 = cross_attn.forward(query2.clone(), context.clone(), None);

        // output3 and output4 should be identical
        output3
            .into_data()
            .assert_approx_eq(&output4.into_data(), Tolerance::<f32>::default());
    }
}
621}