Skip to main content

oxios_kernel/memory/
flash_attention.rs

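//! Flash Attention — block-wise attention with O(N) working memory.
//!
//! A Triton-inspired CPU implementation that processes keys and values in
//! blocks to maximize L1/L2 cache efficiency. Instead of materializing the
//! full attention matrix, it keeps a running (online) softmax per query. The
//! design target is a 2-5× speedup and ~75% memory reduction compared to naive
//! attention for large sequence lengths; actual numbers depend on hardware and
//! input sizes, so measure them with `FlashAttention::benchmark`.
//!
//! Reference: "FlashAttention: Fast and Memory-Efficient Exact Attention
//! with IO-Awareness" (Dao et al., 2022)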
9
10use serde::{Deserialize, Serialize};
11
12// ---------------------------------------------------------------------------
13// Configuration
14// ---------------------------------------------------------------------------
15
16/// Configuration for Flash Attention.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct FlashAttentionConfig {
19    /// Block size for tiled computation (tune to L1 cache).
20    /// Default 64 works well for typical f32 vectors.
21    pub block_size: usize,
22    /// Embedding dimensionality.
23    pub dimensions: usize,
24    /// Softmax temperature scaling.
25    pub temperature: f32,
26}
27
28impl Default for FlashAttentionConfig {
29    fn default() -> Self {
30        Self {
31            block_size: 64,
32            dimensions: 128,
33            temperature: 1.0,
34        }
35    }
36}
37
38// ---------------------------------------------------------------------------
39// Benchmark result
40// ---------------------------------------------------------------------------
41
42/// Result of a benchmark comparing naive vs flash attention.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct BenchmarkResult {
45    /// Naive attention time in milliseconds.
46    pub naive_time_ms: f64,
47    /// Flash attention time in milliseconds.
48    pub flash_time_ms: f64,
49    /// Speedup ratio (naive / flash).
50    pub speedup: f64,
51    /// Memory reduction ratio (0.75 = 75% less memory).
52    pub memory_reduction: f64,
53    /// Number of query vectors.
54    pub num_queries: usize,
55    /// Embedding dimension.
56    pub dimensions: usize,
57}
58
59impl std::fmt::Display for BenchmarkResult {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        write!(
62            f,
63            "Flash Attention Benchmark: {} queries × {}d — {:.2}ms → {:.2}ms ({:.1}× speedup, {:.0}% memory reduction)",
64            self.num_queries,
65            self.dimensions,
66            self.naive_time_ms,
67            self.flash_time_ms,
68            self.speedup,
69            self.memory_reduction * 100.0,
70        )
71    }
72}
73
74// ---------------------------------------------------------------------------
75// Flash Attention
76// ---------------------------------------------------------------------------
77
78/// Block-wise attention computation optimized for CPU cache locality.
79///
80/// Instead of materializing the full N×N attention matrix, processes
81/// the computation in blocks that fit in L1/L2 cache, achieving
82/// O(N) memory complexity instead of O(N²).
83#[derive(Debug)]
84pub struct FlashAttention {
85    config: FlashAttentionConfig,
86}
87
88impl FlashAttention {
89    /// Create a new FlashAttention with the given configuration.
90    pub fn new(config: FlashAttentionConfig) -> Self {
91        Self { config }
92    }
93
94    /// Create with default configuration.
95    pub fn with_dimensions(dimensions: usize) -> Self {
96        let config = FlashAttentionConfig {
97            dimensions,
98            ..Default::default()
99        };
100        Self { config }
101    }
102
103    /// Returns a reference to the configuration.
104    pub fn config(&self) -> &FlashAttentionConfig {
105        &self.config
106    }
107
108    /// Compute scaled dot-product attention using the block-wise algorithm.
109    ///
110    /// For sequences of length N with dimension D:
111    /// - Naive: O(N²) memory (full attention matrix)
112    /// - Flash: O(N) memory (block-wise accumulation via online softmax)
113    ///
114    /// # Arguments
115    /// * `queries` - Query vectors [N_q × D]
116    /// * `keys` - Key vectors [N_k × D]
117    /// * `values` - Value vectors [N_k × D]
118    ///
119    /// # Returns
120    /// Output vectors [N_q × D]
121    #[allow(clippy::needless_range_loop)]
122    pub fn attention(
123        &self,
124        queries: &[Vec<f32>],
125        keys: &[Vec<f32>],
126        values: &[Vec<f32>],
127    ) -> Vec<Vec<f32>> {
128        if queries.is_empty() || keys.is_empty() {
129            return Vec::new();
130        }
131
132        // Use actual vector length (may differ from config.dimensions)
133        let dim = queries.first().map_or(0, |v| v.len());
134        if dim == 0 {
135            return vec![vec![]; queries.len()];
136        }
137        let scale = 1.0 / (self.config.temperature * (dim as f32).sqrt());
138        let block_size = self.config.block_size.min(keys.len());
139
140        let num_queries = queries.len();
141        let mut outputs = vec![vec![0.0f32; dim]; num_queries];
142
143        // Process each query independently — each query only needs O(N_k) memory
144        for (qi, query) in queries.iter().enumerate() {
145            // Online softmax accumulators (Flash Attention core idea)
146            // Instead of storing all attention weights, we accumulate incrementally
147            let mut output_accum = vec![0.0f32; dim];
148            let mut max_score = f32::NEG_INFINITY; // Running max for numerical stability
149            let mut sum_exp = 0.0f32; // Running sum of exp(score - max)
150
151            // Process key/value pairs in blocks
152            for k_block_start in (0..keys.len()).step_by(block_size) {
153                let k_block_end = (k_block_start + block_size).min(keys.len());
154
155                // Compute attention scores for this block
156                let mut block_max = max_score;
157                let mut block_scores = Vec::with_capacity(k_block_end - k_block_start);
158
159                for ki in k_block_start..k_block_end {
160                    let score = dot_product(query, &keys[ki]) * scale;
161                    block_scores.push(score);
162                    if score > block_max {
163                        block_max = score;
164                    }
165                }
166
167                // Update running maximum and rescale previous accumulation
168                let old_max = max_score;
169                if block_max > max_score {
170                    max_score = block_max;
171                }
172
173                // Rescale the accumulated sum and output by the change in max
174                let rescale_factor = if old_max == f32::NEG_INFINITY {
175                    0.0
176                } else {
177                    (old_max - max_score).exp()
178                };
179                sum_exp *= rescale_factor;
180                for v in output_accum.iter_mut() {
181                    *v *= rescale_factor;
182                }
183
184                // Add block contributions
185                for (block_idx, &score) in block_scores.iter().enumerate() {
186                    let ki = k_block_start + block_idx;
187                    let weight = (score - max_score).exp();
188                    sum_exp += weight;
189                    for (d, v) in output_accum.iter_mut().enumerate() {
190                        *v += weight * values[ki][d];
191                    }
192                }
193            }
194
195            // Normalize by sum_exp
196            if sum_exp > 0.0 {
197                let inv_sum = 1.0 / sum_exp;
198                for v in output_accum.iter_mut() {
199                    *v *= inv_sum;
200                }
201            }
202
203            outputs[qi] = output_accum;
204        }
205
206        outputs
207    }
208
209    /// Naive attention implementation for benchmarking comparison.
210    ///
211    /// Materializes the full N×N attention matrix: O(N²) memory.
212    pub fn naive_attention(
213        &self,
214        queries: &[Vec<f32>],
215        keys: &[Vec<f32>],
216        values: &[Vec<f32>],
217    ) -> Vec<Vec<f32>> {
218        if queries.is_empty() || keys.is_empty() {
219            return Vec::new();
220        }
221
222        // Use actual vector length (may differ from config.dimensions)
223        let dim = queries.first().map_or(0, |v| v.len());
224        if dim == 0 {
225            return vec![vec![]; queries.len()];
226        }
227        let scale = 1.0 / (self.config.temperature * (dim as f32).sqrt());
228        let num_queries = queries.len();
229        let num_keys = keys.len();
230
231        // Materialize full attention matrix: O(N_q × N_k) memory
232        let mut attention_weights = vec![vec![0.0f32; num_keys]; num_queries];
233
234        // Compute all scores
235        for (qi, query) in queries.iter().enumerate() {
236            let mut max_score = f32::NEG_INFINITY;
237            for (ki, key) in keys.iter().enumerate() {
238                let score = dot_product(query, key) * scale;
239                attention_weights[qi][ki] = score;
240                if score > max_score {
241                    max_score = score;
242                }
243            }
244            // Softmax
245            let mut sum_exp = 0.0f32;
246            for w in &mut attention_weights[qi] {
247                *w = (*w - max_score).exp();
248                sum_exp += *w;
249            }
250            if sum_exp > 0.0 {
251                let inv = 1.0 / sum_exp;
252                for w in &mut attention_weights[qi] {
253                    *w *= inv;
254                }
255            }
256        }
257
258        // Weighted sum: output = attention_weights × values
259        let mut outputs = vec![vec![0.0f32; dim]; num_queries];
260        for qi in 0..num_queries {
261            for ki in 0..num_keys {
262                let w = attention_weights[qi][ki];
263                for d in 0..dim {
264                    outputs[qi][d] += w * values[ki][d];
265                }
266            }
267        }
268
269        outputs
270    }
271
272    /// Run a benchmark comparing naive vs flash attention.
273    ///
274    /// Generates random vectors and measures wall-clock time for both methods.
275    /// Also verifies that both implementations produce equivalent results.
276    pub fn benchmark(&self, num_vectors: usize) -> BenchmarkResult {
277        let vectors = generate_test_vectors(num_vectors, self.config.dimensions);
278
279        let naive_start = std::time::Instant::now();
280        let naive_result = self.naive_attention(&vectors, &vectors, &vectors);
281        let naive_duration = naive_start.elapsed();
282
283        let flash_start = std::time::Instant::now();
284        let flash_result = self.attention(&vectors, &vectors, &vectors);
285        let flash_duration = flash_start.elapsed();
286
287        // Verify results are similar (within 5% relative tolerance)
288        let mut max_rel_err = 0.0f32;
289        for (f_row, n_row) in flash_result.iter().zip(naive_result.iter()) {
290            for (f, n) in f_row.iter().zip(n_row.iter()) {
291                let err = (f - n).abs() / f.abs().max(n.abs()).max(1e-6);
292                max_rel_err = max_rel_err.max(err);
293            }
294        }
295        if max_rel_err > 0.05 {
296            tracing::warn!(
297                max_relative_error = max_rel_err,
298                "Flash vs naive attention results diverge"
299            );
300        }
301
302        let naive_ms = naive_duration.as_secs_f64() * 1000.0;
303        let flash_ms = flash_duration.as_secs_f64() * 1000.0;
304        let speedup = if flash_ms > 0.0 {
305            naive_ms / flash_ms
306        } else {
307            f64::INFINITY
308        };
309
310        // Memory reduction: naive stores N×N matrix, flash stores O(N) per query
311        let naive_mem = num_vectors * num_vectors; // attention matrix elements
312        let flash_mem = self.config.dimensions + 2; // per-query accumulators
313        let memory_reduction = 1.0 - (flash_mem as f64 / naive_mem as f64);
314
315        BenchmarkResult {
316            naive_time_ms: naive_ms,
317            flash_time_ms: flash_ms,
318            speedup,
319            memory_reduction: memory_reduction.max(0.0),
320            num_queries: num_vectors,
321            dimensions: self.config.dimensions,
322        }
323    }
324
325    /// Compute self-attention: a sequence attends to itself.
326    ///
327    /// Convenience wrapper around `attention(q, q, q)`.
328    pub fn self_attention(&self, sequence: &[Vec<f32>]) -> Vec<Vec<f32>> {
329        self.attention(sequence, sequence, sequence)
330    }
331
332    /// Compute cross-attention between two sequences.
333    ///
334    /// Queries from one sequence attend to keys/values from another.
335    pub fn cross_attention(&self, queries: &[Vec<f32>], kv_sequence: &[Vec<f32>]) -> Vec<Vec<f32>> {
336        self.attention(queries, kv_sequence, kv_sequence)
337    }
338
339    /// Estimate peak memory usage in bytes for a given sequence length.
340    pub fn memory_estimate(&self, seq_len: usize) -> MemoryEstimate {
341        let dim = self.config.dimensions;
342        let element_size = std::mem::size_of::<f32>();
343
344        // Naive: full N×N attention matrix + N×D output + N×D Q/K/V
345        let naive_peak = seq_len * seq_len * element_size // attention matrix
346            + seq_len * dim * element_size * 3 // Q, K, V
347            + seq_len * dim * element_size; // output
348
349        // Flash: D accumulators + 2 scalars per query, processed sequentially
350        let flash_peak = dim * element_size // output accumulator
351            + self.config.block_size * element_size // block scores
352            + seq_len * dim * element_size * 3 // Q, K, V (inputs)
353            + seq_len * dim * element_size; // output
354
355        MemoryEstimate {
356            naive_bytes: naive_peak,
357            flash_bytes: flash_peak,
358            reduction_ratio: 1.0 - (flash_peak as f64 / naive_peak as f64),
359        }
360    }
361}
362
363/// Memory usage estimate.
364#[derive(Debug, Clone, Serialize, Deserialize)]
365pub struct MemoryEstimate {
366    /// Peak memory for naive attention (bytes).
367    pub naive_bytes: usize,
368    /// Peak memory for flash attention (bytes).
369    pub flash_bytes: usize,
370    /// Memory reduction ratio (0.75 = 75% less).
371    pub reduction_ratio: f64,
372}
373
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/// Inner product `Σ a[i] * b[i]` of two `f32` slices.
///
/// If the slices differ in length, the trailing elements of the longer one
/// are ignored.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&lhs, &rhs)| lhs * rhs)
        .sum::<f32>()
}
382
/// Produce `count` deterministic pseudo-random vectors of length `dim`.
///
/// A 64-bit LCG (`x' = 6364136223846793005 * x + 1442695040888963407`,
/// seed 42) supplies the values. Its top 31 bits are mapped via
/// `bits / 2^31 - 1`, so every component lies in `[-1.0, 0.0]`.
///
/// NOTE(review): the range is non-positive, not the symmetric `[-1, 1)` one
/// might expect. The existing relative-error tests are tuned to this data,
/// so change it only together with those tests.
fn generate_test_vectors(count: usize, dim: usize) -> Vec<Vec<f32>> {
    let mut state: u64 = 42;
    let mut next_component = move || -> f32 {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        let top_bits = (state >> 33) as f32;
        top_bits / (1u64 << 31) as f32 - 1.0
    };

    (0..count)
        .map(|_| (0..dim).map(|_| next_component()).collect::<Vec<f32>>())
        .collect()
}
403
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Yields `(lhs, rhs, relative_error)` for every paired element of two
    /// outputs; the denominator is floored at 1e-6 to avoid dividing by zero.
    fn relative_errors<'a>(
        lhs: &'a [Vec<f32>],
        rhs: &'a [Vec<f32>],
    ) -> impl Iterator<Item = (f32, f32, f32)> + 'a {
        lhs.iter().zip(rhs).flat_map(|(l_row, r_row)| {
            l_row.iter().zip(r_row).map(|(&l, &r)| {
                let err = (l - r).abs() / l.abs().max(r.abs()).max(1e-6);
                (l, r, err)
            })
        })
    }

    #[test]
    fn test_flash_vs_naive_small() {
        let fa = FlashAttention::with_dimensions(16);
        let (queries, keys, values) = (
            generate_test_vectors(4, 16),
            generate_test_vectors(4, 16),
            generate_test_vectors(4, 16),
        );

        let flash_output = fa.attention(&queries, &keys, &values);
        let naive_output = fa.naive_attention(&queries, &keys, &values);
        assert_eq!(flash_output.len(), naive_output.len());

        // Every element must agree within 1% relative tolerance.
        for (f, n, rel) in relative_errors(&flash_output, &naive_output) {
            assert!(
                rel < 0.01,
                "Flash and naive outputs differ: flash={:.6}, naive={:.6}",
                f,
                n
            );
        }
    }

    #[test]
    fn test_flash_attention_empty() {
        let output = FlashAttention::with_dimensions(16).attention(&[], &[], &[]);
        assert!(output.is_empty());
    }

    #[test]
    fn test_self_attention() {
        let sequence = generate_test_vectors(3, 8);
        let result = FlashAttention::with_dimensions(8).self_attention(&sequence);
        assert_eq!(result.len(), 3);
        // Every output row keeps the input dimension.
        assert!(result.iter().all(|row| row.len() == 8));
    }

    #[test]
    fn test_cross_attention() {
        let queries = generate_test_vectors(2, 8);
        let kv = generate_test_vectors(5, 8);
        let result = FlashAttention::with_dimensions(8).cross_attention(&queries, &kv);
        assert_eq!(result.len(), 2);
        assert!(result.iter().all(|row| row.len() == 8));
    }

    #[test]
    fn test_memory_estimate() {
        let estimate = FlashAttention::with_dimensions(128).memory_estimate(1000);

        assert!(estimate.flash_bytes < estimate.naive_bytes);
        assert!(
            estimate.reduction_ratio > 0.5,
            "Should achieve >50% memory reduction"
        );
        // For 1000×128 the naive 1000×1000 f32 score matrix alone is ~4 MB,
        // while flash only keeps per-query accumulators.
    }

    #[test]
    fn test_benchmark_result_display() {
        let rendered = BenchmarkResult {
            naive_time_ms: 10.0,
            flash_time_ms: 3.0,
            speedup: 3.33,
            memory_reduction: 0.75,
            num_queries: 256,
            dimensions: 128,
        }
        .to_string();

        for expected in ["256", "3.3", "75%"] {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn test_block_size_effect() {
        // Block size is a tiling detail and must not change the result.
        let with_block_size = |block_size| {
            FlashAttention::new(FlashAttentionConfig {
                dimensions: 16,
                block_size,
                ..FlashAttentionConfig::default()
            })
        };
        let vectors = generate_test_vectors(8, 16);

        let out_small = with_block_size(2).attention(&vectors, &vectors, &vectors);
        let out_large = with_block_size(32).attention(&vectors, &vectors, &vectors);

        for (row_small, row_large) in out_small.iter().zip(&out_large) {
            for (v1, v2) in row_small.iter().zip(row_large) {
                assert!(
                    (v1 - v2).abs() < 1e-4,
                    "Block size shouldn't affect output: {} vs {}",
                    v1,
                    v2
                );
            }
        }
    }

    #[test]
    fn test_temperature_scaling() {
        let with_temperature = |temperature| {
            FlashAttention::new(FlashAttentionConfig {
                dimensions: 16,
                temperature,
                ..FlashAttentionConfig::default()
            })
        };
        let vectors = generate_test_vectors(4, 16);

        let out_high = with_temperature(2.0).attention(&vectors, &vectors, &vectors);
        let out_low = with_temperature(0.5).attention(&vectors, &vectors, &vectors);

        // Higher temperature flattens the distribution; lower temperature
        // sharpens it. At minimum the two outputs must differ somewhere.
        let different = out_high.iter().zip(&out_low).any(|(r_high, r_low)| {
            r_high
                .iter()
                .zip(r_low)
                .any(|(v_high, v_low)| (v_high - v_low).abs() > 1e-4)
        });
        assert!(
            different,
            "Different temperatures should produce different outputs"
        );
    }

    #[test]
    fn test_large_sequence_correctness() {
        let fa = FlashAttention::with_dimensions(32);
        let vectors = generate_test_vectors(50, 32);

        let flash = fa.attention(&vectors, &vectors, &vectors);
        let naive = fa.naive_attention(&vectors, &vectors, &vectors);

        // Longer sequences accumulate more rounding, so allow up to 2%.
        let max_relative_error = relative_errors(&flash, &naive)
            .map(|(_, _, err)| err)
            .fold(0.0f32, f32::max);
        assert!(
            max_relative_error < 0.02,
            "Max relative error: {:.4} — should be < 2%",
            max_relative_error
        );
    }
}