1#![allow(missing_docs)]
2use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct FlashAttentionConfig {
20 pub block_size: usize,
23 pub dimensions: usize,
25 pub temperature: f32,
27}
28
29impl Default for FlashAttentionConfig {
30 fn default() -> Self {
31 Self {
32 block_size: 64,
33 dimensions: 128,
34 temperature: 1.0,
35 }
36 }
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct BenchmarkResult {
46 pub naive_time_ms: f64,
48 pub flash_time_ms: f64,
50 pub speedup: f64,
52 pub memory_reduction: f64,
54 pub num_queries: usize,
56 pub dimensions: usize,
58}
59
60impl std::fmt::Display for BenchmarkResult {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(
63 f,
64 "Flash Attention Benchmark: {} queries × {}d — {:.2}ms → {:.2}ms ({:.1}× speedup, {:.0}% memory reduction)",
65 self.num_queries,
66 self.dimensions,
67 self.naive_time_ms,
68 self.flash_time_ms,
69 self.speedup,
70 self.memory_reduction * 100.0,
71 )
72 }
73}
74
75#[derive(Debug)]
85pub struct FlashAttention {
86 config: FlashAttentionConfig,
87}
88
89impl FlashAttention {
90 pub fn new(config: FlashAttentionConfig) -> Self {
92 Self { config }
93 }
94
95 pub fn with_dimensions(dimensions: usize) -> Self {
97 let config = FlashAttentionConfig {
98 dimensions,
99 ..Default::default()
100 };
101 Self { config }
102 }
103
104 pub fn config(&self) -> &FlashAttentionConfig {
106 &self.config
107 }
108
109 #[allow(clippy::needless_range_loop)]
123 pub fn attention(
124 &self,
125 queries: &[Vec<f32>],
126 keys: &[Vec<f32>],
127 values: &[Vec<f32>],
128 ) -> Vec<Vec<f32>> {
129 if queries.is_empty() || keys.is_empty() {
130 return Vec::new();
131 }
132
133 let dim = queries.first().map_or(0, |v| v.len());
135 if dim == 0 {
136 return vec![vec![]; queries.len()];
137 }
138 let scale = 1.0 / (self.config.temperature * (dim as f32).sqrt());
139 let block_size = self.config.block_size.min(keys.len());
140
141 let num_queries = queries.len();
142 let mut outputs = vec![vec![0.0f32; dim]; num_queries];
143
144 for (qi, query) in queries.iter().enumerate() {
146 let mut output_accum = vec![0.0f32; dim];
149 let mut max_score = f32::NEG_INFINITY; let mut sum_exp = 0.0f32; for k_block_start in (0..keys.len()).step_by(block_size) {
154 let k_block_end = (k_block_start + block_size).min(keys.len());
155
156 let mut block_max = max_score;
158 let mut block_scores = Vec::with_capacity(k_block_end - k_block_start);
159
160 for ki in k_block_start..k_block_end {
161 let score = dot_product(query, &keys[ki]) * scale;
162 block_scores.push(score);
163 if score > block_max {
164 block_max = score;
165 }
166 }
167
168 let old_max = max_score;
170 if block_max > max_score {
171 max_score = block_max;
172 }
173
174 let rescale_factor = if old_max == f32::NEG_INFINITY {
176 0.0
177 } else {
178 (old_max - max_score).exp()
179 };
180 sum_exp *= rescale_factor;
181 for v in output_accum.iter_mut() {
182 *v *= rescale_factor;
183 }
184
185 for (block_idx, &score) in block_scores.iter().enumerate() {
187 let ki = k_block_start + block_idx;
188 let weight = (score - max_score).exp();
189 sum_exp += weight;
190 for (d, v) in output_accum.iter_mut().enumerate() {
191 *v += weight * values[ki][d];
192 }
193 }
194 }
195
196 if sum_exp > 0.0 {
198 let inv_sum = 1.0 / sum_exp;
199 for v in output_accum.iter_mut() {
200 *v *= inv_sum;
201 }
202 }
203
204 outputs[qi] = output_accum;
205 }
206
207 outputs
208 }
209
210 pub fn naive_attention(
214 &self,
215 queries: &[Vec<f32>],
216 keys: &[Vec<f32>],
217 values: &[Vec<f32>],
218 ) -> Vec<Vec<f32>> {
219 if queries.is_empty() || keys.is_empty() {
220 return Vec::new();
221 }
222
223 let dim = queries.first().map_or(0, |v| v.len());
225 if dim == 0 {
226 return vec![vec![]; queries.len()];
227 }
228 let scale = 1.0 / (self.config.temperature * (dim as f32).sqrt());
229 let num_queries = queries.len();
230 let num_keys = keys.len();
231
232 let mut attention_weights = vec![vec![0.0f32; num_keys]; num_queries];
234
235 for (qi, query) in queries.iter().enumerate() {
237 let mut max_score = f32::NEG_INFINITY;
238 for (ki, key) in keys.iter().enumerate() {
239 let score = dot_product(query, key) * scale;
240 attention_weights[qi][ki] = score;
241 if score > max_score {
242 max_score = score;
243 }
244 }
245 let mut sum_exp = 0.0f32;
247 for w in &mut attention_weights[qi] {
248 *w = (*w - max_score).exp();
249 sum_exp += *w;
250 }
251 if sum_exp > 0.0 {
252 let inv = 1.0 / sum_exp;
253 for w in &mut attention_weights[qi] {
254 *w *= inv;
255 }
256 }
257 }
258
259 let mut outputs = vec![vec![0.0f32; dim]; num_queries];
261 for qi in 0..num_queries {
262 for (ki, value_row) in values.iter().enumerate() {
263 let w = attention_weights[qi][ki];
264 for (d, val) in value_row.iter().enumerate().take(dim) {
265 outputs[qi][d] += w * val;
266 }
267 }
268 }
269
270 outputs
271 }
272
273 pub fn benchmark(&self, num_vectors: usize) -> BenchmarkResult {
278 let vectors = generate_test_vectors(num_vectors, self.config.dimensions);
279
280 let naive_start = std::time::Instant::now();
281 let naive_result = self.naive_attention(&vectors, &vectors, &vectors);
282 let naive_duration = naive_start.elapsed();
283
284 let flash_start = std::time::Instant::now();
285 let flash_result = self.attention(&vectors, &vectors, &vectors);
286 let flash_duration = flash_start.elapsed();
287
288 let mut max_rel_err = 0.0f32;
290 for (f_row, n_row) in flash_result.iter().zip(naive_result.iter()) {
291 for (f, n) in f_row.iter().zip(n_row.iter()) {
292 let err = (f - n).abs() / f.abs().max(n.abs()).max(1e-6);
293 max_rel_err = max_rel_err.max(err);
294 }
295 }
296 if max_rel_err > 0.05 {
297 tracing::warn!(
298 max_relative_error = max_rel_err,
299 "Flash vs naive attention results diverge"
300 );
301 }
302
303 let naive_ms = naive_duration.as_secs_f64() * 1000.0;
304 let flash_ms = flash_duration.as_secs_f64() * 1000.0;
305 let speedup = if flash_ms > 0.0 {
306 naive_ms / flash_ms
307 } else {
308 f64::INFINITY
309 };
310
311 let naive_mem = num_vectors * num_vectors; let flash_mem = self.config.dimensions + 2; let memory_reduction = 1.0 - (flash_mem as f64 / naive_mem as f64);
315
316 BenchmarkResult {
317 naive_time_ms: naive_ms,
318 flash_time_ms: flash_ms,
319 speedup,
320 memory_reduction: memory_reduction.max(0.0),
321 num_queries: num_vectors,
322 dimensions: self.config.dimensions,
323 }
324 }
325
326 pub fn self_attention(&self, sequence: &[Vec<f32>]) -> Vec<Vec<f32>> {
330 self.attention(sequence, sequence, sequence)
331 }
332
333 pub fn cross_attention(&self, queries: &[Vec<f32>], kv_sequence: &[Vec<f32>]) -> Vec<Vec<f32>> {
337 self.attention(queries, kv_sequence, kv_sequence)
338 }
339
340 pub fn memory_estimate(&self, seq_len: usize) -> MemoryEstimate {
342 let dim = self.config.dimensions;
343 let element_size = std::mem::size_of::<f32>();
344
345 let naive_peak = seq_len * seq_len * element_size + seq_len * dim * element_size * 3 + seq_len * dim * element_size; let flash_peak = dim * element_size + self.config.block_size * element_size + seq_len * dim * element_size * 3 + seq_len * dim * element_size; MemoryEstimate {
357 naive_bytes: naive_peak,
358 flash_bytes: flash_peak,
359 reduction_ratio: 1.0 - (flash_peak as f64 / naive_peak as f64),
360 }
361 }
362}
363
364#[derive(Debug, Clone, Serialize, Deserialize)]
366pub struct MemoryEstimate {
367 pub naive_bytes: usize,
369 pub flash_bytes: usize,
371 pub reduction_ratio: f64,
373}
374
/// Sums the pairwise products of `a` and `b`.
///
/// When the slices differ in length, only the leading elements both slices
/// have are used; the surplus of the longer slice is ignored.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let pairs = a.iter().zip(b.iter());
    pairs.map(|(&lhs, &rhs)| lhs * rhs).sum()
}
383
/// Produces `count` vectors of `dim` pseudo-random components each.
///
/// The sequence comes from a 64-bit linear congruential generator with a
/// fixed seed of 42, so every call returns the same data for the same
/// arguments. Components are filled row by row.
///
/// Each component is the generator's top 31 bits divided by 2^31, minus one,
/// which places it in `[-1.0, 0.0]`.
/// NOTE(review): if a symmetric `[-1, 1]` range was intended, the scaling
/// would need revisiting — confirm with the author before changing it.
fn generate_test_vectors(count: usize, dim: usize) -> Vec<Vec<f32>> {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    let mut state = 42u64;
    let mut next_component = move || -> f32 {
        state = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        let top_bits = state >> 33;
        (top_bits as f32 / (1u64 << 31) as f32) - 1.0
    };

    (0..count)
        .map(|_| (0..dim).map(|_| next_component()).collect())
        .collect()
}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks two equally shaped outputs in lockstep, yielding element pairs.
    fn paired_elements<'a>(
        lhs: &'a [Vec<f32>],
        rhs: &'a [Vec<f32>],
    ) -> impl Iterator<Item = (f32, f32)> + 'a {
        lhs.iter()
            .zip(rhs)
            .flat_map(|(l_row, r_row)| l_row.iter().copied().zip(r_row.iter().copied()))
    }

    /// Largest element-wise relative error; the denominator is floored at
    /// 1e-6 so that near-zero entries do not dominate.
    fn max_relative_error(lhs: &[Vec<f32>], rhs: &[Vec<f32>]) -> f32 {
        paired_elements(lhs, rhs)
            .map(|(a, b)| (a - b).abs() / a.abs().max(b.abs()).max(1e-6))
            .fold(0.0f32, f32::max)
    }

    /// Engine with the given width, block size and temperature.
    fn engine(dimensions: usize, block_size: usize, temperature: f32) -> FlashAttention {
        FlashAttention::new(FlashAttentionConfig {
            block_size,
            dimensions,
            temperature,
        })
    }

    /// Blocked and naive attention agree within 1% on a tiny input.
    #[test]
    fn test_flash_vs_naive_small() {
        let fa = FlashAttention::with_dimensions(16);
        let queries = generate_test_vectors(4, 16);
        let keys = generate_test_vectors(4, 16);
        let values = generate_test_vectors(4, 16);

        let flash_output = fa.attention(&queries, &keys, &values);
        let naive_output = fa.naive_attention(&queries, &keys, &values);

        assert_eq!(flash_output.len(), naive_output.len());

        for (f, n) in paired_elements(&flash_output, &naive_output) {
            let diff = (f - n).abs();
            let max_val = f.abs().max(n.abs()).max(1e-6);
            assert!(
                diff / max_val < 0.01,
                "Flash and naive outputs differ: flash={:.6}, naive={:.6}",
                f,
                n
            );
        }
    }

    /// Empty inputs produce an empty output rather than panicking.
    #[test]
    fn test_flash_attention_empty() {
        let fa = FlashAttention::with_dimensions(16);
        let nothing: Vec<Vec<f32>> = Vec::new();
        let result = fa.attention(&nothing, &nothing, &nothing);
        assert!(result.is_empty());
    }

    /// Self-attention preserves sequence length and vector width.
    #[test]
    fn test_self_attention() {
        let fa = FlashAttention::with_dimensions(8);
        let seq = generate_test_vectors(3, 8);
        let result = fa.self_attention(&seq);
        assert_eq!(result.len(), 3);
        result.iter().for_each(|row| {
            assert_eq!(row.len(), 8);
        });
    }

    /// Cross-attention yields one row per query, regardless of key count.
    #[test]
    fn test_cross_attention() {
        let fa = FlashAttention::with_dimensions(8);
        let queries = generate_test_vectors(2, 8);
        let kv = generate_test_vectors(5, 8);
        let result = fa.cross_attention(&queries, &kv);
        assert_eq!(result.len(), 2);
        result.iter().for_each(|row| {
            assert_eq!(row.len(), 8);
        });
    }

    /// For a long sequence the blocked variant is estimated to need less
    /// than half the memory.
    #[test]
    fn test_memory_estimate() {
        let estimate = FlashAttention::with_dimensions(128).memory_estimate(1000);

        assert!(estimate.flash_bytes < estimate.naive_bytes);
        assert!(
            estimate.reduction_ratio > 0.5,
            "Should achieve >50% memory reduction"
        );
    }

    /// The summary line contains the key figures in the expected format.
    #[test]
    fn test_benchmark_result_display() {
        let rendered = BenchmarkResult {
            naive_time_ms: 10.0,
            flash_time_ms: 3.0,
            speedup: 3.33,
            memory_reduction: 0.75,
            num_queries: 256,
            dimensions: 128,
        }
        .to_string();

        for needle in ["256", "3.3", "75%"] {
            assert!(rendered.contains(needle));
        }
    }

    /// The block size is an implementation detail and must not change results.
    #[test]
    fn test_block_size_effect() {
        let small_blocks = engine(16, 2, 1.0);
        let large_blocks = engine(16, 32, 1.0);

        let vectors = generate_test_vectors(8, 16);

        let out1 = small_blocks.attention(&vectors, &vectors, &vectors);
        let out2 = large_blocks.attention(&vectors, &vectors, &vectors);

        for (v1, v2) in paired_elements(&out1, &out2) {
            assert!(
                (v1 - v2).abs() < 1e-4,
                "Block size shouldn't affect output: {} vs {}",
                v1,
                v2
            );
        }
    }

    /// Changing the temperature changes the attention distribution.
    #[test]
    fn test_temperature_scaling() {
        let fa_high = engine(16, 64, 2.0);
        let fa_low = engine(16, 64, 0.5);

        let vectors = generate_test_vectors(4, 16);

        let out_high = fa_high.attention(&vectors, &vectors, &vectors);
        let out_low = fa_low.attention(&vectors, &vectors, &vectors);

        let different = paired_elements(&out_high, &out_low)
            .any(|(v_high, v_low)| (v_high - v_low).abs() > 1e-4);
        assert!(
            different,
            "Different temperatures should produce different outputs"
        );
    }

    /// Blocked and naive attention stay within 2% on a longer sequence.
    #[test]
    fn test_large_sequence_correctness() {
        let fa = FlashAttention::with_dimensions(32);
        let vectors = generate_test_vectors(50, 32);

        let flash = fa.attention(&vectors, &vectors, &vectors);
        let naive = fa.naive_attention(&vectors, &vectors, &vectors);

        let max_relative_error = max_relative_error(&flash, &naive);
        assert!(
            max_relative_error < 0.02,
            "Max relative error: {:.4} — should be < 2%",
            max_relative_error
        );
    }
}