burn-nn 0.20.1

Neural network building blocks for the Burn deep learning framework
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
//! Cross-Attention Module for Burn
//!
//! Features:
//! - Asymmetric Input Shapes (Query vs Context)
//! - Grouped Query Attention (GQA) & Multi-Query Attention (MQA) support
//! - Quantization-Safe Masking (min_float)
//! - Sparse-Ready (quiet_softmax)
//! - KV Caching for Streaming Inference

use crate::cache::TensorCache;
use crate::modules::{Linear, LinearConfig};
use crate::{Dropout, DropoutConfig};
use burn_core as burn;

use burn::{
    config::Config,
    module::{Initializer, Module},
    tensor::{
        Bool, Tensor,
        activation::{quiet_softmax, softmax},
        backend::Backend,
    },
};

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

/// Configuration for a [`CrossAttention`] layer; build the layer with
/// [`CrossAttentionConfig::init`].
///
/// The query side (`d_model`) and the context side (`d_context`) may have different
/// feature widths. Setting `n_heads_kv` below `n_heads` selects grouped-query
/// (or, with `1`, multi-query) attention.
#[derive(Config, Debug)]
pub struct CrossAttentionConfig {
    /// Feature width of the query input (e.g. the decoder state).
    pub d_model: usize,
    /// Feature width of the context input (e.g. encoder audio embeddings).
    pub d_context: usize,
    /// Number of attention heads on the query side.
    pub n_heads: usize,
    /// Number of key/value heads: `1` gives multi-query attention, `n_heads` gives
    /// standard multi-head attention. Must divide `n_heads`.
    pub n_heads_kv: usize,
    /// Width of each individual attention head.
    pub d_head: usize,
    /// Dropout probability applied to the attention weights (default `0.1`).
    #[config(default = 0.1)]
    pub dropout: f64,
    /// Score written into masked positions before the softmax. The default `-1.0e4`
    /// is chosen so it stays safe for f16/bf16.
    #[config(default = -1.0e4)]
    pub min_float: f64,
    /// When `true`, `quiet_softmax` is used so a query may attend to nothing
    /// (good for sparse/quantized models). Defaults to `false`.
    #[config(default = false)]
    pub quiet_softmax: bool,
}

/// Cross-attention layer: queries come from one sequence, keys and values from another
/// sequence (the context), which may have a different feature width.
///
/// # Params
///
/// - `query`: [`Linear`] layer with `d_model` input features and
///   `n_heads * d_head` output features.
/// - `key`: [`Linear`] layer with `d_context` input features and
///   `n_heads_kv * d_head` output features.
/// - `value`: [`Linear`] layer with `d_context` input features and
///   `n_heads_kv * d_head` output features.
/// - `output`: [`Linear`] layer with `n_heads * d_head` input features and
///   `d_model` output features.
///
/// Should be created with [`CrossAttentionConfig`].
#[derive(Module, Debug)]
pub struct CrossAttention<B: Backend> {
    query: Linear<B>,
    key: Linear<B>,
    value: Linear<B>,
    output: Linear<B>,
    /// Dropout applied to the attention weights.
    dropout: Dropout,

    /// Number of query heads.
    n_heads: usize,
    /// Number of key/value heads; divides `n_heads`.
    n_heads_kv: usize,
    /// Width of a single head.
    d_head: usize,
    /// Attention score scale, `1 / sqrt(d_head)`.
    scale: f64,
    /// Value written into masked score positions before the softmax.
    min_float: f64,
    /// Use `quiet_softmax` instead of `softmax` for the attention weights.
    quiet_softmax: bool,
}

/// Key/value cache used by [`CrossAttention::forward_cache`].
///
/// Meant for inference when the context stays the same across calls, so the key and
/// value projections need not be recomputed each time.
pub struct CrossAttentionCache<B: Backend> {
    /// Cached key tensor.
    pub k: TensorCache<B, 4>,
    /// Cached value tensor.
    pub v: TensorCache<B, 4>,
}

impl<B: Backend> Default for CrossAttentionCache<B> {
    /// Returns a cache that holds no key or value tensor yet.
    fn default() -> Self {
        let k = TensorCache::empty();
        let v = TensorCache::empty();
        Self { k, v }
    }
}

impl<B: Backend> CrossAttentionCache<B> {
    /// Create a new empty cache (equivalent to [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }
}

impl CrossAttentionConfig {
    /// Initializes a new cross-attention module.
    ///
    /// The query projection maps `d_model -> n_heads * d_head`. The key and value
    /// projections map `d_context -> n_heads_kv * d_head`. The output projection maps
    /// `n_heads * d_head -> d_model`. All four use Kaiming-uniform initialization with
    /// a gain of `1 / sqrt(d_head)`.
    ///
    /// # Arguments
    ///
    /// * `device` - The device on which to initialize the module.
    ///
    /// # Returns
    ///
    /// A new [CrossAttention] module.
    ///
    /// # Panics
    ///
    /// Panics if `n_heads_kv` is zero, or if `n_heads` is not a multiple of `n_heads_kv`.
    pub fn init<B: Backend>(&self, device: &B::Device) -> CrossAttention<B> {
        // Checked before the `%` below so a zero value yields a descriptive message
        // instead of a generic "remainder with a divisor of zero" panic.
        assert!(
            self.n_heads_kv > 0,
            "Number of KV heads must be greater than zero"
        );
        // GQA/MQA share each KV head across `n_heads / n_heads_kv` query heads, so the
        // ratio must be an integer.
        assert_eq!(
            self.n_heads % self.n_heads_kv,
            0,
            "Query heads must be divisible by KV heads"
        );

        let init_linear = |d_input: usize, d_output: usize| {
            LinearConfig::new(d_input, d_output)
                .with_initializer(Initializer::KaimingUniform {
                    gain: 1.0 / (self.d_head as f64).sqrt(),
                    fan_out_only: false,
                })
                .init(device)
        };

        let d_query_heads = self.n_heads * self.d_head;
        let d_kv_heads = self.n_heads_kv * self.d_head;

        CrossAttention {
            // Asymmetric projections: query/output live in `d_model` space, while
            // key/value read from the `d_context` space of the context tensor.
            query: init_linear(self.d_model, d_query_heads),
            key: init_linear(self.d_context, d_kv_heads),
            value: init_linear(self.d_context, d_kv_heads),
            output: init_linear(d_query_heads, self.d_model),

            dropout: DropoutConfig::new(self.dropout).init(),
            n_heads: self.n_heads,
            n_heads_kv: self.n_heads_kv,
            d_head: self.d_head,
            // Scaled dot-product factor 1 / sqrt(d_head).
            scale: (self.d_head as f64).sqrt().recip(),
            min_float: self.min_float,
            quiet_softmax: self.quiet_softmax,
        }
    }
}

impl<B: Backend> CrossAttention<B> {
    /// Applies cross-attention to query using context as key and value.
    ///
    /// # Arguments
    ///
    /// * `query` - Query tensor of shape `[batch, seq_len_query, d_model]`.
    /// * `context` - Context tensor of shape `[batch, seq_len_context, d_context]`.
    /// * `mask` - Optional attention mask of shape `[batch, seq_len_context]` where `true` indicates positions to mask.
    ///
    /// # Returns
    ///
    /// Output tensor of shape `[batch, seq_len_query, d_model]`.
    pub fn forward(
        &self,
        query: Tensor<B, 3>,
        context: Tensor<B, 3>,
        mask: Option<Tensor<B, 2, Bool>>,
    ) -> Tensor<B, 3> {
        let [batch, l_q, _] = query.dims();
        let [_, l_k, _] = context.dims();

        // 1. Projections
        let q = self.query.forward(query);
        let k = self.key.forward(context.clone());
        let v = self.value.forward(context);

        // 2. Reshape Heads
        // Q: [Batch, Heads, L_q, D_head]
        let q = q
            .reshape([batch, l_q, self.n_heads, self.d_head])
            .swap_dims(1, 2);

        // K, V: [Batch, Heads_KV, L_k, D_head]
        let k = k
            .reshape([batch, l_k, self.n_heads_kv, self.d_head])
            .swap_dims(1, 2);
        let v = v
            .reshape([batch, l_k, self.n_heads_kv, self.d_head])
            .swap_dims(1, 2);

        // 3. GQA Expansion
        // ADVICE: Handle GQA by repeating KV heads to match Query heads
        let (k, v) = if self.n_heads != self.n_heads_kv {
            let n_rep = self.n_heads / self.n_heads_kv;
            (self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep))
        } else {
            (k, v)
        };

        // 4. Score Calculation
        let scores = q.matmul(k.transpose()) * self.scale;

        // 5. Masking
        // ADVICE: Use min_float for F16/FP8 safety
        let scores = if let Some(mask) = mask {
            let mask = mask.reshape([batch, 1, 1, l_k]);
            scores.mask_fill(mask, self.min_float)
        } else {
            scores
        };

        // 6. Softmax
        // ADVICE: Optional Quiet Softmax for sparse networks
        let weights = if self.quiet_softmax {
            quiet_softmax(scores, 3)
        } else {
            softmax(scores, 3)
        };

        let weights = self.dropout.forward(weights);

        // 7. Aggregate & Output
        let output = weights.matmul(v);
        let output = output
            .swap_dims(1, 2)
            .reshape([batch, l_q, self.n_heads * self.d_head]);

        self.output.forward(output)
    }

    /// Applies cross-attention to query using context as key and value.
    ///
    /// This method uses a cache to avoid recomputing key and value tensors when the context is the same.
    ///
    /// # Arguments
    ///
    /// * `query` - Query tensor of shape `[batch, seq_len_query, d_model]`.
    /// * `context` - Context tensor of shape `[batch, seq_len_context, d_context]`.
    /// * `mask` - Optional attention mask of shape `[batch, seq_len_context]` where `true` indicates positions to mask.
    /// * `cache` - The cache to use.
    ///
    /// # Returns
    ///
    /// Output tensor of shape `[batch, seq_len_query, d_model]`.
    pub fn forward_cache(
        &self,
        query: Tensor<B, 3>,
        context: Tensor<B, 3>,
        mask: Option<Tensor<B, 2, Bool>>,
        cache: &mut CrossAttentionCache<B>,
    ) -> Tensor<B, 3> {
        let [batch, l_q, _] = query.dims();

        // 1. Projections
        let q = self.query.forward(query);

        let k_compute = |context: Tensor<B, 3>| {
            let [batch, l_k, _] = context.dims();
            self.key
                .forward(context)
                .reshape([batch, l_k, self.n_heads_kv, self.d_head])
                .swap_dims(1, 2)
        };
        let v_compute = |context: Tensor<B, 3>| {
            let [batch, l_k, _] = context.dims();
            self.value
                .forward(context)
                .reshape([batch, l_k, self.n_heads_kv, self.d_head])
                .swap_dims(1, 2)
        };

        let k = cache.k.forward_full(context.clone(), k_compute);
        let v = cache.v.forward_full(context, v_compute);

        let [_, _, l_k, _] = k.dims();

        // 2. Reshape Heads
        // Q: [Batch, Heads, L_q, D_head]
        let q = q
            .reshape([batch, l_q, self.n_heads, self.d_head])
            .swap_dims(1, 2);

        // K, V are already in their correct shape from k_compute and v_compute

        // 3. GQA Expansion
        // ADVICE: Handle GQA by repeating KV heads to match Query heads
        let (k, v) = if self.n_heads != self.n_heads_kv {
            let n_rep = self.n_heads / self.n_heads_kv;
            (self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep))
        } else {
            (k, v)
        };

        // 4. Score Calculation
        let scores = q.matmul(k.transpose()) * self.scale;

        // 5. Masking
        // ADVICE: Use min_float for F16/FP8 safety
        let scores = if let Some(mask) = mask {
            let mask = mask.reshape([batch, 1, 1, l_k]);
            scores.mask_fill(mask, self.min_float)
        } else {
            scores
        };

        // 6. Softmax
        // ADVICE: Optional Quiet Softmax for sparse networks
        let weights = if self.quiet_softmax {
            quiet_softmax(scores, 3)
        } else {
            softmax(scores, 3)
        };

        let weights = self.dropout.forward(weights);

        // 7. Aggregate & Output
        let output = weights.matmul(v);
        let output = output
            .swap_dims(1, 2)
            .reshape([batch, l_q, self.n_heads * self.d_head]);

        self.output.forward(output)
    }

    /// Helper for Grouped Query Attention
    fn repeat_kv(&self, x: Tensor<B, 4>, n_rep: usize) -> Tensor<B, 4> {
        let [b, h, l, d] = x.dims();
        x.reshape([b, h, 1, l, d])
            .expand([b, h, n_rep, l, d])
            .reshape([b, h * n_rep, l, d])
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{Distribution, Int, Shape, Tensor, Tolerance};

    /// Builds a config with the given dimensions, using the default `min_float`.
    fn test_config(
        d_model: usize,
        d_context: usize,
        n_heads: usize,
        n_heads_kv: usize,
        d_head: usize,
        dropout: f64,
        quiet_softmax: bool,
    ) -> CrossAttentionConfig {
        CrossAttentionConfig {
            d_model,
            d_context,
            n_heads,
            n_heads_kv,
            d_head,
            dropout,
            min_float: -1.0e4,
            quiet_softmax,
        }
    }

    /// Runs `forward` on random inputs and asserts the output is
    /// `[batch, seq_len_query, d_model]`.
    fn check_output_shape(n_heads: usize, n_heads_kv: usize, quiet_softmax: bool) {
        let [batch_size, seq_len_query, seq_len_context, d_model, d_context, d_head] =
            [7, 13, 15, 32, 40, 8];
        let device = Default::default();
        let cross_attn = test_config(
            d_model,
            d_context,
            n_heads,
            n_heads_kv,
            d_head,
            0.1,
            quiet_softmax,
        )
        .init::<TestBackend>(&device);

        let query = Tensor::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );
        let context = Tensor::random(
            [batch_size, seq_len_context, d_context],
            Distribution::Default,
            &device,
        );

        let output = cross_attn.forward(query, context, None);

        assert_eq!(
            output.shape(),
            Shape::new([batch_size, seq_len_query, d_model]),
            "Output should have the correct shape",
        );
    }

    /// Asserts that changing masked (padded) context positions does not change the output.
    fn check_mask_invariance(n_heads_kv: usize, quiet_softmax: bool) {
        let [
            batch_size,
            seq_len_query,
            seq_len_context,
            d_model,
            d_context,
            n_heads,
            d_head,
        ] = [3, 6, 8, 12, 16, 4, 3];
        let num_padded = 2;
        let device = Default::default();
        // No dropout, so both forward passes are deterministic.
        let cross_attn = test_config(
            d_model,
            d_context,
            n_heads,
            n_heads_kv,
            d_head,
            0.0,
            quiet_softmax,
        )
        .init::<TestBackend>(&device);

        // Mask the last `num_padded` context positions of every batch entry.
        let mask = Tensor::<TestBackend, 2, Int>::zeros([batch_size, seq_len_context], &device);
        let mask = mask.slice_assign(
            [0..batch_size, seq_len_context - num_padded..seq_len_context],
            Tensor::ones([batch_size, num_padded], &device),
        );
        let mask_bool = mask.equal_elem(1);

        let query = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );

        let context_1 = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_len_context, d_context],
            Distribution::Default,
            &device,
        );

        // Change only the padded part of the context tensor.
        let context_2 = context_1.clone().slice_assign(
            [
                0..batch_size,
                seq_len_context - num_padded..seq_len_context,
                0..d_context,
            ],
            Tensor::random(
                [batch_size, num_padded, d_context],
                Distribution::Default,
                &device,
            ),
        );

        // The outputs should be the same since the changed part is masked.
        let output_1 = cross_attn.forward(query.clone(), context_1, Some(mask_bool.clone()));
        let output_2 = cross_attn.forward(query, context_2, Some(mask_bool));

        output_1
            .into_data()
            .assert_approx_eq(&output_2.into_data(), Tolerance::<f32>::default());
    }

    #[test]
    fn test_cross_attention_mha_shapes() {
        check_output_shape(4, 4, false);
    }

    #[test]
    fn test_cross_attention_gqa_shapes() {
        check_output_shape(4, 2, false);
    }

    #[test]
    fn test_cross_attention_mqa_shapes() {
        check_output_shape(4, 1, false);
    }

    #[test]
    fn test_cross_attention_quiet_softmax_shapes() {
        check_output_shape(4, 2, true);
    }

    #[test]
    fn test_cross_attention_mask() {
        check_mask_invariance(4, false);
    }

    #[test]
    fn test_cross_attention_mask_gqa_quiet_softmax() {
        check_mask_invariance(2, true);
    }

    #[test]
    #[should_panic(expected = "Query heads must be divisible by KV heads")]
    fn test_gqa_panic_if_n_heads_not_divisible_by_n_heads_kv() {
        let device = Default::default();
        test_config(32, 32, 5, 2, 8, 0.1, false).init::<TestBackend>(&device);
    }

    #[test]
    #[should_panic]
    fn test_panic_if_n_heads_kv_is_zero() {
        let device = Default::default();
        test_config(32, 32, 4, 0, 8, 0.1, false).init::<TestBackend>(&device);
    }

    #[test]
    fn test_cross_attention_cache() {
        let [
            batch_size,
            seq_len_query,
            seq_len_context,
            d_model,
            d_context,
            n_heads,
            d_head,
        ] = [3, 6, 8, 12, 16, 4, 3];
        let device = Default::default();
        // No dropout, so cached and uncached passes are deterministic.
        let cross_attn = test_config(d_model, d_context, n_heads, n_heads, d_head, 0.0, false)
            .init::<TestBackend>(&device);

        let query1 = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );
        let context = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_len_context, d_context],
            Distribution::Default,
            &device,
        );

        // First forward pass, no cache.
        let output1 = cross_attn.forward(query1.clone(), context.clone(), None);

        // Second forward pass with cache.
        let mut cache = CrossAttentionCache::new();
        let output2 = cross_attn.forward_cache(query1, context.clone(), None, &mut cache);

        // The two outputs should be identical.
        output1
            .into_data()
            .assert_approx_eq(&output2.into_data(), Tolerance::<f32>::default());

        // Third forward pass with a different query, but the same context and cache.
        let query2 = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_len_query, d_model],
            Distribution::Default,
            &device,
        );
        let output3 = cross_attn.forward_cache(query2.clone(), context.clone(), None, &mut cache);

        // For control, do a forward pass without cache with query2.
        let output4 = cross_attn.forward(query2, context, None);

        // output3 and output4 should be identical.
        output3
            .into_data()
            .assert_approx_eq(&output4.into_data(), Tolerance::<f32>::default());
    }
}