// oxigaf_diffusion/flash_attention.rs

//! Flash Attention: Memory-efficient attention with block-wise computation.
//!
//! This module implements the Flash Attention algorithm, which reduces memory
//! complexity from O(N^2) to O(N) by using tiled computation with online softmax.
//!
//! # Algorithm Overview
//!
//! Instead of computing the full N x N attention matrix, Flash Attention:
//! 1. Splits Q, K, V into blocks of size `block_size`
//! 2. For each query block, iterates over all key/value blocks
//! 3. Computes local attention scores for each block pair
//! 4. Uses online softmax with running max/sum for numerical stability
//! 5. Accumulates weighted values with proper rescaling
//!
//! # References
//!
//! - Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention
//!   with IO-Awareness", NeurIPS 2022

use candle_core::{DType, Result, Tensor, D};

/// Configuration for Flash Attention computation.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttentionConfig {
    /// Block size for tiled computation. Larger blocks use more memory but
    /// may be faster due to better cache utilization. Default: 64.
    pub block_size: usize,
    /// Whether to use causal masking (for autoregressive models). Default: false.
    pub causal: bool,
    /// Epsilon for numerical stability in softmax. Default: 1e-6.
    pub softmax_eps: f64,
}

34impl Default for FlashAttentionConfig {
35    fn default() -> Self {
36        Self {
37            block_size: 64,
38            causal: false,
39            softmax_eps: 1e-6,
40        }
41    }
42}
43
44impl FlashAttentionConfig {
45    /// Create a new Flash Attention config with specified block size.
46    pub fn with_block_size(block_size: usize) -> Self {
47        Self {
48            block_size,
49            ..Default::default()
50        }
51    }
52
53    /// Enable causal masking for autoregressive attention.
54    #[allow(dead_code)]
55    pub fn with_causal(mut self, causal: bool) -> Self {
56        self.causal = causal;
57        self
58    }
59}
60
61/// Flash Attention: Memory-efficient scaled dot-product attention.
62///
63/// This struct provides the flash attention computation without maintaining
64/// learned parameters. It is designed to be used within attention modules
65/// that handle Q, K, V projections separately.
66#[derive(Debug, Clone)]
67pub struct FlashAttention {
68    config: FlashAttentionConfig,
69    scale: f64,
70}
71
72impl FlashAttention {
73    /// Create a new Flash Attention module.
74    ///
75    /// # Arguments
76    ///
77    /// * `dim_head` - Dimension of each attention head (for scaling)
78    /// * `config` - Flash Attention configuration
79    pub fn new(dim_head: usize, config: FlashAttentionConfig) -> Self {
80        let scale = 1.0 / (dim_head as f64).sqrt();
81        Self { config, scale }
82    }
83
84    /// Create with default configuration.
85    pub fn with_dim_head(dim_head: usize) -> Self {
86        Self::new(dim_head, FlashAttentionConfig::default())
87    }
88
89    /// Compute flash attention.
90    ///
91    /// # Arguments
92    ///
93    /// * `q` - Query tensor of shape `(batch, heads, seq_q, dim_head)`
94    /// * `k` - Key tensor of shape `(batch, heads, seq_k, dim_head)`
95    /// * `v` - Value tensor of shape `(batch, heads, seq_k, dim_head)`
96    ///
97    /// # Returns
98    ///
99    /// Output tensor of shape `(batch, heads, seq_q, dim_head)`
100    pub fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
101        let (batch, heads, seq_q, dim_head) = q.dims4()?;
102        let (_, _, seq_k, _) = k.dims4()?;
103
104        // For small sequences, use standard attention (simpler and often faster)
105        let use_standard = seq_q <= self.config.block_size && seq_k <= self.config.block_size;
106        if use_standard {
107            return self.standard_attention(q, k, v);
108        }
109
110        // Compute in f32 for numerical stability
111        let in_dtype = q.dtype();
112        let q = q.to_dtype(DType::F32)?;
113        let k = k.to_dtype(DType::F32)?;
114        let v = v.to_dtype(DType::F32)?;
115
116        // Block-wise computation
117        let block_size = self.config.block_size;
118        let num_q_blocks = seq_q.div_ceil(block_size);
119        let num_k_blocks = seq_k.div_ceil(block_size);
120
121        // Initialize output accumulator and softmax statistics
122        let device = q.device();
123        let neg_inf = f32::NEG_INFINITY;
124
125        // Process each query block
126        let mut output_blocks: Vec<Tensor> = Vec::with_capacity(num_q_blocks);
127
128        for q_block_idx in 0..num_q_blocks {
129            let q_start = q_block_idx * block_size;
130            let q_end = (q_start + block_size).min(seq_q);
131            let q_len = q_end - q_start;
132
133            // Extract query block: (batch, heads, q_len, dim_head)
134            let q_block = q.narrow(2, q_start, q_len)?;
135
136            // Initialize accumulators for this query block
137            // m: running max, shape (batch, heads, q_len)
138            // l: running sum, shape (batch, heads, q_len)
139            // o: running output, shape (batch, heads, q_len, dim_head)
140            let mut m = Tensor::full(neg_inf, (batch, heads, q_len), device)?;
141            let mut l = Tensor::zeros((batch, heads, q_len), DType::F32, device)?;
142            let mut o = Tensor::zeros((batch, heads, q_len, dim_head), DType::F32, device)?;
143
144            // Iterate over key/value blocks
145            for k_block_idx in 0..num_k_blocks {
146                let k_start = k_block_idx * block_size;
147                let k_end = (k_start + block_size).min(seq_k);
148                let k_len = k_end - k_start;
149
150                // Check causal mask: skip if this KV block is entirely in the future
151                if self.config.causal && k_start >= q_end {
152                    continue;
153                }
154
155                // Extract key and value blocks
156                let k_block = k.narrow(2, k_start, k_len)?;
157                let v_block = v.narrow(2, k_start, k_len)?;
158
159                // Compute attention scores for this block: (batch, heads, q_len, k_len)
160                let k_t = k_block.transpose(D::Minus2, D::Minus1)?;
161                let scores = (q_block.matmul(&k_t)? * self.scale)?;
162
163                // Apply causal mask if needed
164                let scores = if self.config.causal {
165                    self.apply_causal_mask(&scores, q_start, k_start)?
166                } else {
167                    scores
168                };
169
170                // Online softmax update
171                let (m_new, l_new, o_new) =
172                    self.online_softmax_update(&m, &l, &o, &scores, &v_block)?;
173
174                m = m_new;
175                l = l_new;
176                o = o_new;
177            }
178
179            // Finalize output for this block: o = o / l
180            let l_expanded = l.unsqueeze(D::Minus1)?;
181            let l_safe = (l_expanded + self.config.softmax_eps)?;
182            let block_output = o.broadcast_div(&l_safe)?;
183
184            output_blocks.push(block_output);
185        }
186
187        // Concatenate all output blocks along sequence dimension
188        let output = Tensor::cat(&output_blocks, 2)?;
189        output.to_dtype(in_dtype)
190    }
191
192    /// Standard attention for small sequences (fallback path).
193    fn standard_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
194        let in_dtype = q.dtype();
195        let q = q.to_dtype(DType::F32)?;
196        let k = k.to_dtype(DType::F32)?;
197        let v = v.to_dtype(DType::F32)?;
198
199        let k_t = k.transpose(D::Minus2, D::Minus1)?.contiguous()?;
200        let attn = (q.matmul(&k_t)? * self.scale)?;
201        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
202        let out = attn.matmul(&v)?;
203        out.to_dtype(in_dtype)
204    }
205
206    /// Online softmax update step.
207    ///
208    /// Given current statistics (m, l, o) and new block scores, compute updated
209    /// statistics using the online softmax algorithm.
210    ///
211    /// # Arguments
212    ///
213    /// * `m` - Current running max, shape (batch, heads, q_len)
214    /// * `l` - Current running sum, shape (batch, heads, q_len)
215    /// * `o` - Current running output, shape (batch, heads, q_len, dim_head)
216    /// * `scores` - New block scores, shape (batch, heads, q_len, k_len)
217    /// * `v_block` - Value block, shape (batch, heads, k_len, dim_head)
218    fn online_softmax_update(
219        &self,
220        m: &Tensor,
221        l: &Tensor,
222        o: &Tensor,
223        scores: &Tensor,
224        v_block: &Tensor,
225    ) -> Result<(Tensor, Tensor, Tensor)> {
226        // Compute block max: max over k dimension
227        // scores: (batch, heads, q_len, k_len) -> m_block: (batch, heads, q_len)
228        let m_block = scores.max(D::Minus1)?;
229
230        // Compute new global max
231        let m_new = m.maximum(&m_block)?;
232
233        // Rescale old statistics
234        // exp(m - m_new) for old values
235        let m_diff_old = m.broadcast_sub(&m_new)?;
236        let rescale_old = m_diff_old.exp()?;
237
238        // exp(m_block - m_new) for new block (computed but kept for debugging/clarity)
239        // The actual rescaling is done implicitly when computing p_block below
240        let m_diff_new = m_block.broadcast_sub(&m_new)?;
241        let _rescale_new = m_diff_new.exp()?;
242
243        // Compute softmax for current block (unnormalized)
244        // p_block = exp(scores - m_new)
245        let m_new_expanded = m_new.unsqueeze(D::Minus1)?;
246        let p_block = scores.broadcast_sub(&m_new_expanded)?.exp()?;
247
248        // Sum over k dimension for new block contribution to l
249        let l_block = p_block.sum(D::Minus1)?;
250
251        // Update l: l_new = l * rescale_old + l_block
252        let l_new = (l.mul(&rescale_old)? + l_block)?;
253
254        // Update o: o_new = o * rescale_old + p_block @ v_block
255        let rescale_old_expanded = rescale_old.unsqueeze(D::Minus1)?;
256        let o_rescaled = o.broadcast_mul(&rescale_old_expanded)?;
257        let pv = p_block.matmul(v_block)?;
258        let o_new = (o_rescaled + pv)?;
259
260        Ok((m_new, l_new, o_new))
261    }
262
263    /// Apply causal mask to attention scores.
264    ///
265    /// Sets scores to -inf where key position > query position.
266    fn apply_causal_mask(&self, scores: &Tensor, q_start: usize, k_start: usize) -> Result<Tensor> {
267        let (batch, heads, q_len, k_len) = scores.dims4()?;
268        let device = scores.device();
269
270        // Create causal mask: mask[i,j] = true if k_start + j > q_start + i
271        let mut mask_data = vec![0.0f32; q_len * k_len];
272        let neg_inf = f32::NEG_INFINITY;
273
274        for i in 0..q_len {
275            let q_pos = q_start + i;
276            for j in 0..k_len {
277                let k_pos = k_start + j;
278                if k_pos > q_pos {
279                    mask_data[i * k_len + j] = neg_inf;
280                }
281            }
282        }
283
284        let mask = Tensor::from_vec(mask_data, (1, 1, q_len, k_len), device)?;
285        let mask = mask.broadcast_as((batch, heads, q_len, k_len))?;
286        scores.add(&mask)
287    }
288}
289
290/// Compute flash attention with default settings.
291///
292/// This is a convenience function for one-off attention computations.
293///
294/// # Arguments
295///
296/// * `q` - Query tensor of shape `(batch, heads, seq_q, dim_head)`
297/// * `k` - Key tensor of shape `(batch, heads, seq_k, dim_head)`
298/// * `v` - Value tensor of shape `(batch, heads, seq_k, dim_head)`
299/// * `dim_head` - Dimension of each attention head
300///
301/// # Returns
302///
303/// Output tensor of shape `(batch, heads, seq_q, dim_head)`
304pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, dim_head: usize) -> Result<Tensor> {
305    let flash = FlashAttention::with_dim_head(dim_head);
306    flash.forward(q, k, v)
307}
308
309/// Compute flash attention with custom configuration.
310pub fn flash_attention_with_config(
311    q: &Tensor,
312    k: &Tensor,
313    v: &Tensor,
314    dim_head: usize,
315    config: FlashAttentionConfig,
316) -> Result<Tensor> {
317    let flash = FlashAttention::new(dim_head, config);
318    flash.forward(q, k, v)
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use candle_core::Device;

    /// Build deterministic Q/K/V test tensors from sin/cos ramps so tests are
    /// reproducible without an RNG.
    fn create_test_tensors(
        batch: usize,
        heads: usize,
        seq_q: usize,
        seq_k: usize,
        dim_head: usize,
        device: &Device,
    ) -> Result<(Tensor, Tensor, Tensor)> {
        let q_size = batch * heads * seq_q * dim_head;
        let k_size = batch * heads * seq_k * dim_head;

        let q_data: Vec<f32> = (0..q_size).map(|i| (i as f32 * 0.01).sin()).collect();
        let k_data: Vec<f32> = (0..k_size).map(|i| (i as f32 * 0.02).cos()).collect();
        let v_data: Vec<f32> = (0..k_size).map(|i| (i as f32 * 0.03).sin()).collect();

        Ok((
            Tensor::from_vec(q_data, (batch, heads, seq_q, dim_head), device)?,
            Tensor::from_vec(k_data, (batch, heads, seq_k, dim_head), device)?,
            Tensor::from_vec(v_data, (batch, heads, seq_k, dim_head), device)?,
        ))
    }

    /// Reference implementation: full-matrix softmax attention in f32.
    fn standard_attention(q: &Tensor, k: &Tensor, v: &Tensor, scale: f64) -> Result<Tensor> {
        let q = q.to_dtype(DType::F32)?;
        let k = k.to_dtype(DType::F32)?;
        let v = v.to_dtype(DType::F32)?;

        let k_t = k.transpose(D::Minus2, D::Minus1)?;
        let attn = (q.matmul(&k_t)? * scale)?;
        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
        attn.matmul(&v)
    }

    /// Assert that two tensors are element-wise close within `eps`
    /// (shared by all comparison tests below).
    fn assert_all_close(actual: &Tensor, expected: &Tensor, eps: f32) -> Result<()> {
        let a: Vec<f32> = actual.to_dtype(DType::F32)?.flatten_all()?.to_vec1()?;
        let e: Vec<f32> = expected.to_dtype(DType::F32)?.flatten_all()?.to_vec1()?;
        assert_eq!(a.len(), e.len());
        for (x, y) in a.iter().zip(e.iter()) {
            assert_relative_eq!(x, y, epsilon = eps);
        }
        Ok(())
    }

    #[test]
    fn test_flash_attention_small_sequence() -> Result<()> {
        let device = Device::Cpu;
        let dim_head = 64;
        let (q, k, v) = create_test_tensors(2, 4, 32, 32, dim_head, &device)?;

        let flash_out = FlashAttention::with_dim_head(dim_head).forward(&q, &k, &v)?;

        let scale = 1.0 / (dim_head as f64).sqrt();
        let std_out = standard_attention(&q, &k, &v, scale)?;

        assert_all_close(&flash_out, &std_out, 1e-4)
    }

    #[test]
    fn test_flash_attention_large_sequence() -> Result<()> {
        let device = Device::Cpu;
        let dim_head = 32;
        // seq_len = 128 exceeds the default block size (64), exercising the
        // tiled code path.
        let (q, k, v) = create_test_tensors(1, 2, 128, 128, dim_head, &device)?;

        let flash_out = FlashAttention::with_dim_head(dim_head).forward(&q, &k, &v)?;

        let scale = 1.0 / (dim_head as f64).sqrt();
        let std_out = standard_attention(&q, &k, &v, scale)?;

        assert_all_close(&flash_out, &std_out, 1e-3)
    }

    #[test]
    fn test_flash_attention_asymmetric_sequences() -> Result<()> {
        let device = Device::Cpu;
        let dim_head = 32;
        let (q, k, v) = create_test_tensors(1, 2, 100, 150, dim_head, &device)?;

        let flash_out = FlashAttention::with_dim_head(dim_head).forward(&q, &k, &v)?;

        let scale = 1.0 / (dim_head as f64).sqrt();
        let std_out = standard_attention(&q, &k, &v, scale)?;

        assert_all_close(&flash_out, &std_out, 1e-3)
    }

    #[test]
    fn test_flash_attention_output_shape() -> Result<()> {
        let device = Device::Cpu;
        let (batch, heads, seq_q, seq_k, dim_head) = (2, 4, 96, 128, 64);
        let (q, k, v) = create_test_tensors(batch, heads, seq_q, seq_k, dim_head, &device)?;

        let out = FlashAttention::with_dim_head(dim_head).forward(&q, &k, &v)?;
        assert_eq!(out.dims(), &[batch, heads, seq_q, dim_head]);

        Ok(())
    }

    #[test]
    fn test_flash_attention_single_element() -> Result<()> {
        let device = Device::Cpu;
        let dim_head = 16;
        let (q, k, v) = create_test_tensors(1, 1, 1, 1, dim_head, &device)?;

        let flash_out = FlashAttention::with_dim_head(dim_head).forward(&q, &k, &v)?;

        // Softmax over a single key is exactly 1, so the output must equal V.
        assert_all_close(&flash_out, &v, 1e-5)
    }

    #[test]
    fn test_flash_attention_config_block_size() -> Result<()> {
        let device = Device::Cpu;
        let dim_head = 32;
        let (q, k, v) = create_test_tensors(1, 2, 200, 200, dim_head, &device)?;

        let scale = 1.0 / (dim_head as f64).sqrt();
        let std_out = standard_attention(&q, &k, &v, scale)?;

        // Results must not depend on the tiling granularity.
        for block_size in [32, 64, 128] {
            let config = FlashAttentionConfig::with_block_size(block_size);
            let flash_out = FlashAttention::new(dim_head, config).forward(&q, &k, &v)?;
            assert_all_close(&flash_out, &std_out, 1e-3)?;
        }

        Ok(())
    }

    #[test]
    fn test_flash_attention_convenience_function() -> Result<()> {
        let device = Device::Cpu;
        let (q, k, v) = create_test_tensors(1, 2, 64, 64, 32, &device)?;

        let out = flash_attention(&q, &k, &v, 32)?;
        assert_eq!(out.dims(), &[1, 2, 64, 32]);

        Ok(())
    }
}