1use serde::{Deserialize, Serialize};
11
/// Tunable settings for [`FlashAttention`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlashAttentionConfig {
    /// Number of keys scored together per block by `FlashAttention::attention`
    /// (capped at the number of keys). Expected to be at least 1.
    pub block_size: usize,
    /// Vector width used by `FlashAttention::benchmark` (to generate data) and
    /// `FlashAttention::memory_estimate`; the attention methods themselves take
    /// the width from the length of the first query.
    pub dimensions: usize,
    /// Softmax temperature: attention scores are multiplied by
    /// `1 / (temperature * sqrt(dim))`, so (for positive values) a larger
    /// temperature flattens the attention weights.
    /// NOTE(review): 0 produces an infinite scale — presumably callers keep it positive.
    pub temperature: f32,
}
27
28impl Default for FlashAttentionConfig {
29 fn default() -> Self {
30 Self {
31 block_size: 64,
32 dimensions: 128,
33 temperature: 1.0,
34 }
35 }
36}
37
/// Timing and memory comparison between naive and blocked attention, as
/// produced by `FlashAttention::benchmark`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    /// Elapsed time of the naive (full score matrix) attention, in milliseconds.
    pub naive_time_ms: f64,
    /// Elapsed time of the blocked attention, in milliseconds.
    pub flash_time_ms: f64,
    /// `naive_time_ms / flash_time_ms`; `benchmark` reports infinity when the
    /// flash time is 0.
    pub speedup: f64,
    /// Estimated fraction of memory saved (`benchmark` clamps it to be
    /// non-negative); rendered as a percentage by `Display`.
    pub memory_reduction: f64,
    /// Number of vectors used (each serves as a query, key and value).
    pub num_queries: usize,
    /// Width of each vector.
    pub dimensions: usize,
}
58
59impl std::fmt::Display for BenchmarkResult {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 write!(
62 f,
63 "Flash Attention Benchmark: {} queries × {}d — {:.2}ms → {:.2}ms ({:.1}× speedup, {:.0}% memory reduction)",
64 self.num_queries,
65 self.dimensions,
66 self.naive_time_ms,
67 self.flash_time_ms,
68 self.speedup,
69 self.memory_reduction * 100.0,
70 )
71 }
72}
73
/// Attention engine offering a blocked, streaming-softmax implementation
/// (`attention`) alongside a full-matrix reference (`naive_attention`), plus
/// benchmarking and memory-estimation helpers.
#[derive(Debug)]
pub struct FlashAttention {
    /// Settings read by every method (block size, dimensions, temperature).
    config: FlashAttentionConfig,
}
87
88impl FlashAttention {
89 pub fn new(config: FlashAttentionConfig) -> Self {
91 Self { config }
92 }
93
94 pub fn with_dimensions(dimensions: usize) -> Self {
96 let config = FlashAttentionConfig {
97 dimensions,
98 ..Default::default()
99 };
100 Self { config }
101 }
102
103 pub fn config(&self) -> &FlashAttentionConfig {
105 &self.config
106 }
107
108 #[allow(clippy::needless_range_loop)]
122 pub fn attention(
123 &self,
124 queries: &[Vec<f32>],
125 keys: &[Vec<f32>],
126 values: &[Vec<f32>],
127 ) -> Vec<Vec<f32>> {
128 if queries.is_empty() || keys.is_empty() {
129 return Vec::new();
130 }
131
132 let dim = queries.first().map_or(0, |v| v.len());
134 if dim == 0 {
135 return vec![vec![]; queries.len()];
136 }
137 let scale = 1.0 / (self.config.temperature * (dim as f32).sqrt());
138 let block_size = self.config.block_size.min(keys.len());
139
140 let num_queries = queries.len();
141 let mut outputs = vec![vec![0.0f32; dim]; num_queries];
142
143 for (qi, query) in queries.iter().enumerate() {
145 let mut output_accum = vec![0.0f32; dim];
148 let mut max_score = f32::NEG_INFINITY; let mut sum_exp = 0.0f32; for k_block_start in (0..keys.len()).step_by(block_size) {
153 let k_block_end = (k_block_start + block_size).min(keys.len());
154
155 let mut block_max = max_score;
157 let mut block_scores = Vec::with_capacity(k_block_end - k_block_start);
158
159 for ki in k_block_start..k_block_end {
160 let score = dot_product(query, &keys[ki]) * scale;
161 block_scores.push(score);
162 if score > block_max {
163 block_max = score;
164 }
165 }
166
167 let old_max = max_score;
169 if block_max > max_score {
170 max_score = block_max;
171 }
172
173 let rescale_factor = if old_max == f32::NEG_INFINITY {
175 0.0
176 } else {
177 (old_max - max_score).exp()
178 };
179 sum_exp *= rescale_factor;
180 for v in output_accum.iter_mut() {
181 *v *= rescale_factor;
182 }
183
184 for (block_idx, &score) in block_scores.iter().enumerate() {
186 let ki = k_block_start + block_idx;
187 let weight = (score - max_score).exp();
188 sum_exp += weight;
189 for (d, v) in output_accum.iter_mut().enumerate() {
190 *v += weight * values[ki][d];
191 }
192 }
193 }
194
195 if sum_exp > 0.0 {
197 let inv_sum = 1.0 / sum_exp;
198 for v in output_accum.iter_mut() {
199 *v *= inv_sum;
200 }
201 }
202
203 outputs[qi] = output_accum;
204 }
205
206 outputs
207 }
208
209 pub fn naive_attention(
213 &self,
214 queries: &[Vec<f32>],
215 keys: &[Vec<f32>],
216 values: &[Vec<f32>],
217 ) -> Vec<Vec<f32>> {
218 if queries.is_empty() || keys.is_empty() {
219 return Vec::new();
220 }
221
222 let dim = queries.first().map_or(0, |v| v.len());
224 if dim == 0 {
225 return vec![vec![]; queries.len()];
226 }
227 let scale = 1.0 / (self.config.temperature * (dim as f32).sqrt());
228 let num_queries = queries.len();
229 let num_keys = keys.len();
230
231 let mut attention_weights = vec![vec![0.0f32; num_keys]; num_queries];
233
234 for (qi, query) in queries.iter().enumerate() {
236 let mut max_score = f32::NEG_INFINITY;
237 for (ki, key) in keys.iter().enumerate() {
238 let score = dot_product(query, key) * scale;
239 attention_weights[qi][ki] = score;
240 if score > max_score {
241 max_score = score;
242 }
243 }
244 let mut sum_exp = 0.0f32;
246 for w in &mut attention_weights[qi] {
247 *w = (*w - max_score).exp();
248 sum_exp += *w;
249 }
250 if sum_exp > 0.0 {
251 let inv = 1.0 / sum_exp;
252 for w in &mut attention_weights[qi] {
253 *w *= inv;
254 }
255 }
256 }
257
258 let mut outputs = vec![vec![0.0f32; dim]; num_queries];
260 for qi in 0..num_queries {
261 for ki in 0..num_keys {
262 let w = attention_weights[qi][ki];
263 for d in 0..dim {
264 outputs[qi][d] += w * values[ki][d];
265 }
266 }
267 }
268
269 outputs
270 }
271
272 pub fn benchmark(&self, num_vectors: usize) -> BenchmarkResult {
277 let vectors = generate_test_vectors(num_vectors, self.config.dimensions);
278
279 let naive_start = std::time::Instant::now();
280 let naive_result = self.naive_attention(&vectors, &vectors, &vectors);
281 let naive_duration = naive_start.elapsed();
282
283 let flash_start = std::time::Instant::now();
284 let flash_result = self.attention(&vectors, &vectors, &vectors);
285 let flash_duration = flash_start.elapsed();
286
287 let mut max_rel_err = 0.0f32;
289 for (f_row, n_row) in flash_result.iter().zip(naive_result.iter()) {
290 for (f, n) in f_row.iter().zip(n_row.iter()) {
291 let err = (f - n).abs() / f.abs().max(n.abs()).max(1e-6);
292 max_rel_err = max_rel_err.max(err);
293 }
294 }
295 if max_rel_err > 0.05 {
296 tracing::warn!(
297 max_relative_error = max_rel_err,
298 "Flash vs naive attention results diverge"
299 );
300 }
301
302 let naive_ms = naive_duration.as_secs_f64() * 1000.0;
303 let flash_ms = flash_duration.as_secs_f64() * 1000.0;
304 let speedup = if flash_ms > 0.0 {
305 naive_ms / flash_ms
306 } else {
307 f64::INFINITY
308 };
309
310 let naive_mem = num_vectors * num_vectors; let flash_mem = self.config.dimensions + 2; let memory_reduction = 1.0 - (flash_mem as f64 / naive_mem as f64);
314
315 BenchmarkResult {
316 naive_time_ms: naive_ms,
317 flash_time_ms: flash_ms,
318 speedup,
319 memory_reduction: memory_reduction.max(0.0),
320 num_queries: num_vectors,
321 dimensions: self.config.dimensions,
322 }
323 }
324
325 pub fn self_attention(&self, sequence: &[Vec<f32>]) -> Vec<Vec<f32>> {
329 self.attention(sequence, sequence, sequence)
330 }
331
332 pub fn cross_attention(&self, queries: &[Vec<f32>], kv_sequence: &[Vec<f32>]) -> Vec<Vec<f32>> {
336 self.attention(queries, kv_sequence, kv_sequence)
337 }
338
339 pub fn memory_estimate(&self, seq_len: usize) -> MemoryEstimate {
341 let dim = self.config.dimensions;
342 let element_size = std::mem::size_of::<f32>();
343
344 let naive_peak = seq_len * seq_len * element_size + seq_len * dim * element_size * 3 + seq_len * dim * element_size; let flash_peak = dim * element_size + self.config.block_size * element_size + seq_len * dim * element_size * 3 + seq_len * dim * element_size; MemoryEstimate {
356 naive_bytes: naive_peak,
357 flash_bytes: flash_peak,
358 reduction_ratio: 1.0 - (flash_peak as f64 / naive_peak as f64),
359 }
360 }
361}
362
/// Byte-level peak-memory estimate returned by `FlashAttention::memory_estimate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEstimate {
    /// Estimated peak bytes for naive attention (includes the full
    /// `seq_len × seq_len` weight matrix).
    pub naive_bytes: usize,
    /// Estimated peak bytes for blocked attention.
    pub flash_bytes: usize,
    /// `1 - flash_bytes / naive_bytes`; negative when the flash estimate is the
    /// larger of the two.
    pub reduction_ratio: f64,
}
373
/// Inner product of `a` and `b`.
///
/// Elements are paired by position and pairing stops at the end of the shorter
/// slice, so trailing elements of a longer input are ignored rather than rejected.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let pairs = a.iter().zip(b.iter());
    pairs.map(|(&lhs, &rhs)| lhs * rhs).sum::<f32>()
}
382
/// Produces `count` deterministic pseudo-random vectors of `dim` elements each.
///
/// A 64-bit linear congruential generator seeded with 42 is advanced once per
/// element, in row-major order, so the output is identical on every call. Each
/// element takes the top 31 bits of the state, divides them by 2^31 and
/// subtracts 1.0, which places values in [-1.0, 0.0] (the upper end is
/// reachable because the `f32` conversion rounds up).
/// NOTE(review): if a symmetric [-1, 1) range was intended, the quotient would
/// need doubling; it is left as-is so the generated data stays stable.
fn generate_test_vectors(count: usize, dim: usize) -> Vec<Vec<f32>> {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    let mut state = 42u64;
    let mut next_value = move || -> f32 {
        state = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        let top_bits = (state >> 33) as f32;
        top_bits / (1u64 << 31) as f32 - 1.0
    };

    (0..count)
        .map(|_| (0..dim).map(|_| next_value()).collect::<Vec<f32>>())
        .collect()
}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Relative difference of `a` and `b`, with the denominator floored at
    /// 1e-6 so near-zero outputs are not divided by (almost) nothing.
    fn relative_error(a: f32, b: f32) -> f32 {
        (a - b).abs() / a.abs().max(b.abs()).max(1e-6)
    }

    #[test]
    fn test_flash_vs_naive_small() {
        let fa = FlashAttention::with_dimensions(16);
        let queries = generate_test_vectors(4, 16);
        let keys = generate_test_vectors(4, 16);
        let values = generate_test_vectors(4, 16);

        let flash_output = fa.attention(&queries, &keys, &values);
        let naive_output = fa.naive_attention(&queries, &keys, &values);
        assert_eq!(flash_output.len(), naive_output.len());

        for (flash_row, naive_row) in flash_output.iter().zip(&naive_output) {
            for (&f, &n) in flash_row.iter().zip(naive_row) {
                assert!(
                    relative_error(f, n) < 0.01,
                    "Flash and naive outputs differ: flash={:.6}, naive={:.6}",
                    f,
                    n
                );
            }
        }
    }

    #[test]
    fn test_flash_attention_empty() {
        let fa = FlashAttention::with_dimensions(16);
        assert!(fa.attention(&[], &[], &[]).is_empty());
    }

    #[test]
    fn test_self_attention() {
        let fa = FlashAttention::with_dimensions(8);
        let sequence = generate_test_vectors(3, 8);
        let output = fa.self_attention(&sequence);
        assert_eq!(output.len(), 3);
        output.iter().for_each(|row| assert_eq!(row.len(), 8));
    }

    #[test]
    fn test_cross_attention() {
        let fa = FlashAttention::with_dimensions(8);
        let queries = generate_test_vectors(2, 8);
        let kv = generate_test_vectors(5, 8);
        let output = fa.cross_attention(&queries, &kv);
        assert_eq!(output.len(), 2);
        output.iter().for_each(|row| assert_eq!(row.len(), 8));
    }

    #[test]
    fn test_memory_estimate() {
        let estimate = FlashAttention::with_dimensions(128).memory_estimate(1000);
        assert!(estimate.flash_bytes < estimate.naive_bytes);
        assert!(
            estimate.reduction_ratio > 0.5,
            "Should achieve >50% memory reduction"
        );
    }

    #[test]
    fn test_benchmark_result_display() {
        let rendered = BenchmarkResult {
            naive_time_ms: 10.0,
            flash_time_ms: 3.0,
            speedup: 3.33,
            memory_reduction: 0.75,
            num_queries: 256,
            dimensions: 128,
        }
        .to_string();
        assert!(rendered.contains("256"));
        assert!(rendered.contains("3.3"));
        assert!(rendered.contains("75%"));
    }

    #[test]
    fn test_block_size_effect() {
        let small_blocks = FlashAttention::new(FlashAttentionConfig {
            dimensions: 16,
            block_size: 2,
            ..Default::default()
        });
        let large_blocks = FlashAttention::new(FlashAttentionConfig {
            dimensions: 16,
            block_size: 32,
            ..Default::default()
        });

        let vectors = generate_test_vectors(8, 16);
        let out_small = small_blocks.attention(&vectors, &vectors, &vectors);
        let out_large = large_blocks.attention(&vectors, &vectors, &vectors);

        for (row_small, row_large) in out_small.iter().zip(&out_large) {
            for (&v1, &v2) in row_small.iter().zip(row_large) {
                assert!(
                    (v1 - v2).abs() < 1e-4,
                    "Block size shouldn't affect output: {} vs {}",
                    v1,
                    v2
                );
            }
        }
    }

    #[test]
    fn test_temperature_scaling() {
        let fa_high = FlashAttention::new(FlashAttentionConfig {
            dimensions: 16,
            temperature: 2.0,
            ..Default::default()
        });
        let fa_low = FlashAttention::new(FlashAttentionConfig {
            dimensions: 16,
            temperature: 0.5,
            ..Default::default()
        });

        let vectors = generate_test_vectors(4, 16);
        let out_high = fa_high.attention(&vectors, &vectors, &vectors);
        let out_low = fa_low.attention(&vectors, &vectors, &vectors);

        let different = out_high.iter().zip(&out_low).any(|(r_high, r_low)| {
            r_high
                .iter()
                .zip(r_low)
                .any(|(v_high, v_low)| (v_high - v_low).abs() > 1e-4)
        });
        assert!(
            different,
            "Different temperatures should produce different outputs"
        );
    }

    #[test]
    fn test_large_sequence_correctness() {
        let fa = FlashAttention::with_dimensions(32);
        let vectors = generate_test_vectors(50, 32);

        let flash = fa.attention(&vectors, &vectors, &vectors);
        let naive = fa.naive_attention(&vectors, &vectors, &vectors);

        let max_relative_error = flash
            .iter()
            .zip(&naive)
            .flat_map(|(f_row, n_row)| f_row.iter().zip(n_row))
            .map(|(&f, &n)| relative_error(f, n))
            .fold(0.0f32, f32::max);
        assert!(
            max_relative_error < 0.02,
            "Max relative error: {:.4} — should be < 2%",
            max_relative_error
        );
    }
}