Skip to main content

oxios_memory/memory/
flash_attention.rs

1#![allow(missing_docs)]
2//! Flash Attention — block-wise attention for O(N) memory usage.
3//!
4//! Triton-inspired CPU implementation that processes attention in blocks
5//! to maximize L1/L2 cache efficiency. Achieves 2-5× speedup and ~75%
6//! memory reduction compared to naive attention for large sequence lengths.
7//!
8//! Reference: "FlashAttention: Fast and Memory-Efficient Exact Attention
9//! with IO-Awareness" (Dao et al., 2022)
10
11use serde::{Deserialize, Serialize};
12
13// ---------------------------------------------------------------------------
14// Configuration
15// ---------------------------------------------------------------------------
16
/// Configuration for Flash Attention.
///
/// Derives serde `Serialize`/`Deserialize` in addition to `Debug` and `Clone`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlashAttentionConfig {
    /// Block size for tiled computation (tune to L1 cache).
    /// Default 64 works well for typical f32 vectors.
    ///
    /// This is the number of key/value rows scored per block;
    /// `FlashAttention::attention` caps it at the number of keys.
    pub block_size: usize,
    /// Embedding dimensionality.
    ///
    /// Read by `FlashAttention::benchmark` and `FlashAttention::memory_estimate`.
    /// The attention methods themselves take the dimension from the length of
    /// the first query vector instead.
    pub dimensions: usize,
    /// Softmax temperature scaling.
    ///
    /// Raw dot products are divided by `temperature * sqrt(dim)` before the
    /// softmax is applied.
    pub temperature: f32,
}
28
29impl Default for FlashAttentionConfig {
30    fn default() -> Self {
31        Self {
32            block_size: 64,
33            dimensions: 128,
34            temperature: 1.0,
35        }
36    }
37}
38
39// ---------------------------------------------------------------------------
40// Benchmark result
41// ---------------------------------------------------------------------------
42
/// Result of a benchmark comparing naive vs flash attention.
///
/// Produced by `FlashAttention::benchmark`; the `Display` impl renders it as a
/// single human-readable summary line.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    /// Naive attention time in milliseconds.
    pub naive_time_ms: f64,
    /// Flash attention time in milliseconds.
    pub flash_time_ms: f64,
    /// Speedup ratio (naive / flash).
    /// Infinite when the flash run measured as zero milliseconds.
    pub speedup: f64,
    /// Memory reduction ratio (0.75 = 75% less memory).
    /// `FlashAttention::benchmark` clamps this at 0.0 from below.
    pub memory_reduction: f64,
    /// Number of query vectors.
    pub num_queries: usize,
    /// Embedding dimension.
    pub dimensions: usize,
}
59
60impl std::fmt::Display for BenchmarkResult {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        write!(
63            f,
64            "Flash Attention Benchmark: {} queries × {}d — {:.2}ms → {:.2}ms ({:.1}× speedup, {:.0}% memory reduction)",
65            self.num_queries,
66            self.dimensions,
67            self.naive_time_ms,
68            self.flash_time_ms,
69            self.speedup,
70            self.memory_reduction * 100.0,
71        )
72    }
73}
74
75// ---------------------------------------------------------------------------
76// Flash Attention
77// ---------------------------------------------------------------------------
78
/// Block-wise attention computation optimized for CPU cache locality.
///
/// Instead of materializing the full N×N attention matrix, processes
/// the computation in blocks that fit in L1/L2 cache, achieving
/// O(N) memory complexity instead of O(N²).
///
/// The struct only holds its configuration; every method takes `&self`.
#[derive(Debug)]
pub struct FlashAttention {
    // Block size, nominal dimensionality and softmax temperature.
    config: FlashAttentionConfig,
}
88
89impl FlashAttention {
90    /// Create a new FlashAttention with the given configuration.
91    pub fn new(config: FlashAttentionConfig) -> Self {
92        Self { config }
93    }
94
95    /// Create with default configuration.
96    pub fn with_dimensions(dimensions: usize) -> Self {
97        let config = FlashAttentionConfig {
98            dimensions,
99            ..Default::default()
100        };
101        Self { config }
102    }
103
    /// Returns a reference to the configuration.
    ///
    /// This is the configuration the instance was constructed with; there is
    /// no method on this type that mutates it.
    pub fn config(&self) -> &FlashAttentionConfig {
        &self.config
    }
108
109    /// Compute scaled dot-product attention using the block-wise algorithm.
110    ///
111    /// For sequences of length N with dimension D:
112    /// - Naive: O(N²) memory (full attention matrix)
113    /// - Flash: O(N) memory (block-wise accumulation via online softmax)
114    ///
115    /// # Arguments
116    /// * `queries` - Query vectors [N_q × D]
117    /// * `keys` - Key vectors [N_k × D]
118    /// * `values` - Value vectors [N_k × D]
119    ///
120    /// # Returns
121    /// Output vectors [N_q × D]
122    #[allow(clippy::needless_range_loop)]
123    pub fn attention(
124        &self,
125        queries: &[Vec<f32>],
126        keys: &[Vec<f32>],
127        values: &[Vec<f32>],
128    ) -> Vec<Vec<f32>> {
129        if queries.is_empty() || keys.is_empty() {
130            return Vec::new();
131        }
132
133        // Use actual vector length (may differ from config.dimensions)
134        let dim = queries.first().map_or(0, |v| v.len());
135        if dim == 0 {
136            return vec![vec![]; queries.len()];
137        }
138        let scale = 1.0 / (self.config.temperature * (dim as f32).sqrt());
139        let block_size = self.config.block_size.min(keys.len());
140
141        let num_queries = queries.len();
142        let mut outputs = vec![vec![0.0f32; dim]; num_queries];
143
144        // Process each query independently — each query only needs O(N_k) memory
145        for (qi, query) in queries.iter().enumerate() {
146            // Online softmax accumulators (Flash Attention core idea)
147            // Instead of storing all attention weights, we accumulate incrementally
148            let mut output_accum = vec![0.0f32; dim];
149            let mut max_score = f32::NEG_INFINITY; // Running max for numerical stability
150            let mut sum_exp = 0.0f32; // Running sum of exp(score - max)
151
152            // Process key/value pairs in blocks
153            for k_block_start in (0..keys.len()).step_by(block_size) {
154                let k_block_end = (k_block_start + block_size).min(keys.len());
155
156                // Compute attention scores for this block
157                let mut block_max = max_score;
158                let mut block_scores = Vec::with_capacity(k_block_end - k_block_start);
159
160                for ki in k_block_start..k_block_end {
161                    let score = dot_product(query, &keys[ki]) * scale;
162                    block_scores.push(score);
163                    if score > block_max {
164                        block_max = score;
165                    }
166                }
167
168                // Update running maximum and rescale previous accumulation
169                let old_max = max_score;
170                if block_max > max_score {
171                    max_score = block_max;
172                }
173
174                // Rescale the accumulated sum and output by the change in max
175                let rescale_factor = if old_max == f32::NEG_INFINITY {
176                    0.0
177                } else {
178                    (old_max - max_score).exp()
179                };
180                sum_exp *= rescale_factor;
181                for v in output_accum.iter_mut() {
182                    *v *= rescale_factor;
183                }
184
185                // Add block contributions
186                for (block_idx, &score) in block_scores.iter().enumerate() {
187                    let ki = k_block_start + block_idx;
188                    let weight = (score - max_score).exp();
189                    sum_exp += weight;
190                    for (d, v) in output_accum.iter_mut().enumerate() {
191                        *v += weight * values[ki][d];
192                    }
193                }
194            }
195
196            // Normalize by sum_exp
197            if sum_exp > 0.0 {
198                let inv_sum = 1.0 / sum_exp;
199                for v in output_accum.iter_mut() {
200                    *v *= inv_sum;
201                }
202            }
203
204            outputs[qi] = output_accum;
205        }
206
207        outputs
208    }
209
210    /// Naive attention implementation for benchmarking comparison.
211    ///
212    /// Materializes the full N×N attention matrix: O(N²) memory.
213    pub fn naive_attention(
214        &self,
215        queries: &[Vec<f32>],
216        keys: &[Vec<f32>],
217        values: &[Vec<f32>],
218    ) -> Vec<Vec<f32>> {
219        if queries.is_empty() || keys.is_empty() {
220            return Vec::new();
221        }
222
223        // Use actual vector length (may differ from config.dimensions)
224        let dim = queries.first().map_or(0, |v| v.len());
225        if dim == 0 {
226            return vec![vec![]; queries.len()];
227        }
228        let scale = 1.0 / (self.config.temperature * (dim as f32).sqrt());
229        let num_queries = queries.len();
230        let num_keys = keys.len();
231
232        // Materialize full attention matrix: O(N_q × N_k) memory
233        let mut attention_weights = vec![vec![0.0f32; num_keys]; num_queries];
234
235        // Compute all scores
236        for (qi, query) in queries.iter().enumerate() {
237            let mut max_score = f32::NEG_INFINITY;
238            for (ki, key) in keys.iter().enumerate() {
239                let score = dot_product(query, key) * scale;
240                attention_weights[qi][ki] = score;
241                if score > max_score {
242                    max_score = score;
243                }
244            }
245            // Softmax
246            let mut sum_exp = 0.0f32;
247            for w in &mut attention_weights[qi] {
248                *w = (*w - max_score).exp();
249                sum_exp += *w;
250            }
251            if sum_exp > 0.0 {
252                let inv = 1.0 / sum_exp;
253                for w in &mut attention_weights[qi] {
254                    *w *= inv;
255                }
256            }
257        }
258
259        // Weighted sum: output = attention_weights × values
260        let mut outputs = vec![vec![0.0f32; dim]; num_queries];
261        for qi in 0..num_queries {
262            for (ki, value_row) in values.iter().enumerate() {
263                let w = attention_weights[qi][ki];
264                for (d, val) in value_row.iter().enumerate().take(dim) {
265                    outputs[qi][d] += w * val;
266                }
267            }
268        }
269
270        outputs
271    }
272
273    /// Run a benchmark comparing naive vs flash attention.
274    ///
275    /// Generates random vectors and measures wall-clock time for both methods.
276    /// Also verifies that both implementations produce equivalent results.
277    pub fn benchmark(&self, num_vectors: usize) -> BenchmarkResult {
278        let vectors = generate_test_vectors(num_vectors, self.config.dimensions);
279
280        let naive_start = std::time::Instant::now();
281        let naive_result = self.naive_attention(&vectors, &vectors, &vectors);
282        let naive_duration = naive_start.elapsed();
283
284        let flash_start = std::time::Instant::now();
285        let flash_result = self.attention(&vectors, &vectors, &vectors);
286        let flash_duration = flash_start.elapsed();
287
288        // Verify results are similar (within 5% relative tolerance)
289        let mut max_rel_err = 0.0f32;
290        for (f_row, n_row) in flash_result.iter().zip(naive_result.iter()) {
291            for (f, n) in f_row.iter().zip(n_row.iter()) {
292                let err = (f - n).abs() / f.abs().max(n.abs()).max(1e-6);
293                max_rel_err = max_rel_err.max(err);
294            }
295        }
296        if max_rel_err > 0.05 {
297            tracing::warn!(
298                max_relative_error = max_rel_err,
299                "Flash vs naive attention results diverge"
300            );
301        }
302
303        let naive_ms = naive_duration.as_secs_f64() * 1000.0;
304        let flash_ms = flash_duration.as_secs_f64() * 1000.0;
305        let speedup = if flash_ms > 0.0 {
306            naive_ms / flash_ms
307        } else {
308            f64::INFINITY
309        };
310
311        // Memory reduction: naive stores N×N matrix, flash stores O(N) per query
312        let naive_mem = num_vectors * num_vectors; // attention matrix elements
313        let flash_mem = self.config.dimensions + 2; // per-query accumulators
314        let memory_reduction = 1.0 - (flash_mem as f64 / naive_mem as f64);
315
316        BenchmarkResult {
317            naive_time_ms: naive_ms,
318            flash_time_ms: flash_ms,
319            speedup,
320            memory_reduction: memory_reduction.max(0.0),
321            num_queries: num_vectors,
322            dimensions: self.config.dimensions,
323        }
324    }
325
326    /// Compute self-attention: a sequence attends to itself.
327    ///
328    /// Convenience wrapper around `attention(q, q, q)`.
329    pub fn self_attention(&self, sequence: &[Vec<f32>]) -> Vec<Vec<f32>> {
330        self.attention(sequence, sequence, sequence)
331    }
332
333    /// Compute cross-attention between two sequences.
334    ///
335    /// Queries from one sequence attend to keys/values from another.
336    pub fn cross_attention(&self, queries: &[Vec<f32>], kv_sequence: &[Vec<f32>]) -> Vec<Vec<f32>> {
337        self.attention(queries, kv_sequence, kv_sequence)
338    }
339
340    /// Estimate peak memory usage in bytes for a given sequence length.
341    pub fn memory_estimate(&self, seq_len: usize) -> MemoryEstimate {
342        let dim = self.config.dimensions;
343        let element_size = std::mem::size_of::<f32>();
344
345        // Naive: full N×N attention matrix + N×D output + N×D Q/K/V
346        let naive_peak = seq_len * seq_len * element_size // attention matrix
347            + seq_len * dim * element_size * 3 // Q, K, V
348            + seq_len * dim * element_size; // output
349
350        // Flash: D accumulators + 2 scalars per query, processed sequentially
351        let flash_peak = dim * element_size // output accumulator
352            + self.config.block_size * element_size // block scores
353            + seq_len * dim * element_size * 3 // Q, K, V (inputs)
354            + seq_len * dim * element_size; // output
355
356        MemoryEstimate {
357            naive_bytes: naive_peak,
358            flash_bytes: flash_peak,
359            reduction_ratio: 1.0 - (flash_peak as f64 / naive_peak as f64),
360        }
361    }
362}
363
/// Memory usage estimate.
///
/// Returned by `FlashAttention::memory_estimate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEstimate {
    /// Peak memory for naive attention (bytes).
    pub naive_bytes: usize,
    /// Peak memory for flash attention (bytes).
    pub flash_bytes: usize,
    /// Memory reduction ratio (0.75 = 75% less).
    /// Computed as `1 - flash_bytes / naive_bytes`.
    pub reduction_ratio: f64,
}
374
375// ---------------------------------------------------------------------------
376// Helpers
377// ---------------------------------------------------------------------------
378
/// Dot product of two f32 slices.
///
/// Elements are paired positionally; if the slices differ in length, the
/// surplus elements of the longer one are ignored.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b.iter()).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}
383
/// Generate `count` deterministic vectors of `dim` elements each.
///
/// Values come from a 64-bit linear congruential generator with a fixed seed
/// of 42, so every call with the same arguments returns the same data. Each
/// element is built from the top 31 bits of the state, divided by 2^31 and
/// shifted down by 1.0, which places it in `[-1.0, 0.0]`.
///
/// NOTE(review): every generated value is therefore non-positive. If a
/// symmetric `[-1, 1)` range was intended, the divisor should be 2^30 —
/// confirm before changing, since the benchmark data depends on it.
fn generate_test_vectors(count: usize, dim: usize) -> Vec<Vec<f32>> {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    let mut state: u64 = 42;
    // LCG step: x_{n+1} = (a * x_n + c) mod 2^64, via wrapping arithmetic.
    let mut next_value = || {
        state = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        ((state >> 33) as f32 / (1u64 << 31) as f32) - 1.0
    };

    (0..count)
        .map(|_| (0..dim).map(|_| next_value()).collect::<Vec<f32>>())
        .collect()
}
404
405// ---------------------------------------------------------------------------
406// Tests
407// ---------------------------------------------------------------------------
408
#[cfg(test)]
mod tests {
    use super::*;

    /// Flash and naive attention agree on a small input.
    #[test]
    fn test_flash_vs_naive_small() {
        let fa = FlashAttention::with_dimensions(16);
        let queries = generate_test_vectors(4, 16);
        let keys = generate_test_vectors(4, 16);
        let values = generate_test_vectors(4, 16);

        let flash_output = fa.attention(&queries, &keys, &values);
        let naive_output = fa.naive_attention(&queries, &keys, &values);

        assert_eq!(flash_output.len(), naive_output.len());

        // Results should be very close (within 1% relative tolerance)
        let element_pairs = flash_output
            .iter()
            .zip(&naive_output)
            .flat_map(|(flash_row, naive_row)| flash_row.iter().zip(naive_row));
        for (f, n) in element_pairs {
            let magnitude = f.abs().max(n.abs()).max(1e-6);
            assert!(
                (f - n).abs() / magnitude < 0.01,
                "Flash and naive outputs differ: flash={:.6}, naive={:.6}",
                f,
                n
            );
        }
    }

    /// Empty inputs produce an empty output.
    #[test]
    fn test_flash_attention_empty() {
        let fa = FlashAttention::with_dimensions(16);
        assert!(fa.attention(&[], &[], &[]).is_empty());
    }

    /// Self-attention returns one row per input, each of the input dimension.
    #[test]
    fn test_self_attention() {
        let fa = FlashAttention::with_dimensions(8);
        let seq = generate_test_vectors(3, 8);
        let result = fa.self_attention(&seq);
        assert_eq!(result.len(), 3);
        // Each output should have the correct dimension
        for row in &result {
            assert_eq!(row.len(), 8);
        }
    }

    /// Cross-attention output is sized by the queries, not the key/value sequence.
    #[test]
    fn test_cross_attention() {
        let fa = FlashAttention::with_dimensions(8);
        let queries = generate_test_vectors(2, 8);
        let kv = generate_test_vectors(5, 8);
        let result = fa.cross_attention(&queries, &kv);
        assert_eq!(result.len(), 2);
        for row in &result {
            assert_eq!(row.len(), 8);
        }
    }

    /// The flash estimate is well below the naive one for a long sequence.
    #[test]
    fn test_memory_estimate() {
        let fa = FlashAttention::with_dimensions(128);
        let estimate = fa.memory_estimate(1000);

        // For 1000×128: naive = 1000*1000*4 + overhead, flash much less
        // The attention matrix alone is 4MB for naive, ~0 for flash
        assert!(estimate.flash_bytes < estimate.naive_bytes);
        assert!(
            estimate.reduction_ratio > 0.5,
            "Should achieve >50% memory reduction"
        );
    }

    /// The summary line contains the size, speedup and percentage.
    #[test]
    fn test_benchmark_result_display() {
        let result = BenchmarkResult {
            naive_time_ms: 10.0,
            flash_time_ms: 3.0,
            speedup: 3.33,
            memory_reduction: 0.75,
            num_queries: 256,
            dimensions: 128,
        };
        let s = result.to_string();
        assert!(s.contains("256"));
        assert!(s.contains("3.3"));
        assert!(s.contains("75%"));
    }

    /// Different block sizes should produce the same result.
    #[test]
    fn test_block_size_effect() {
        let fa1 = FlashAttention::new(FlashAttentionConfig {
            dimensions: 16,
            block_size: 2,
            ..Default::default()
        });
        let fa2 = FlashAttention::new(FlashAttentionConfig {
            dimensions: 16,
            block_size: 32,
            ..Default::default()
        });

        let vectors = generate_test_vectors(8, 16);

        let out1 = fa1.attention(&vectors, &vectors, &vectors);
        let out2 = fa2.attention(&vectors, &vectors, &vectors);

        // Results should be identical regardless of block size
        let element_pairs = out1
            .iter()
            .zip(&out2)
            .flat_map(|(row1, row2)| row1.iter().zip(row2));
        for (v1, v2) in element_pairs {
            assert!(
                (v1 - v2).abs() < 1e-4,
                "Block size shouldn't affect output: {} vs {}",
                v1,
                v2
            );
        }
    }

    /// Changing the temperature changes the output.
    #[test]
    fn test_temperature_scaling() {
        let fa_high = FlashAttention::new(FlashAttentionConfig {
            dimensions: 16,
            temperature: 2.0,
            ..Default::default()
        });
        let fa_low = FlashAttention::new(FlashAttentionConfig {
            dimensions: 16,
            temperature: 0.5,
            ..Default::default()
        });

        let vectors = generate_test_vectors(4, 16);

        let out_high = fa_high.attention(&vectors, &vectors, &vectors);
        let out_low = fa_low.attention(&vectors, &vectors, &vectors);

        // Higher temperature → more uniform distribution → less peaked output
        // Lower temperature → sharper distribution → more peaked output
        // Check that they produce different results
        let different = out_high
            .iter()
            .zip(&out_low)
            .flat_map(|(r_high, r_low)| r_high.iter().zip(r_low))
            .any(|(v_high, v_low)| (v_high - v_low).abs() > 1e-4);
        assert!(
            different,
            "Different temperatures should produce different outputs"
        );
    }

    /// Flash and naive attention stay within 2% on a longer sequence.
    #[test]
    fn test_large_sequence_correctness() {
        let fa = FlashAttention::with_dimensions(32);
        let vectors = generate_test_vectors(50, 32);

        let flash = fa.attention(&vectors, &vectors, &vectors);
        let naive = fa.naive_attention(&vectors, &vectors, &vectors);

        // For larger sequences, allow slightly more tolerance
        let max_relative_error = flash
            .iter()
            .zip(&naive)
            .flat_map(|(f_row, n_row)| f_row.iter().zip(n_row))
            .map(|(f, n)| (f - n).abs() / f.abs().max(n.abs()).max(1e-6))
            .fold(0.0f32, f32::max);
        assert!(
            max_relative_error < 0.02,
            "Max relative error: {:.4} — should be < 2%",
            max_relative_error
        );
    }
}