1use rayon::prelude::*;
10
/// Tuning knobs for tiled ("flash") single-query attention decoding.
#[derive(Debug, Clone)]
pub struct FlashDecodeConfig {
    /// Requested number of KV tiles. `flash_decode_single_head` clamps this
    /// into `1..=seq_len` before use.
    pub num_tiles: usize,
    /// Multiplier applied to every query·key dot product before the softmax.
    pub scale: f32,
}

impl FlashDecodeConfig {
    /// Creates a config with 4 tiles and the conventional `1 / sqrt(head_dim)`
    /// softmax scale. A `head_dim` of zero falls back to a scale of `1.0`
    /// (avoiding a division by zero).
    pub fn new(head_dim: usize) -> Self {
        let scale = match head_dim {
            0 => 1.0_f32,
            dim => 1.0_f32 / (dim as f32).sqrt(),
        };
        Self {
            num_tiles: 4,
            scale,
        }
    }

    /// Returns this config with `num_tiles` replaced by `n`; other fields are
    /// carried over unchanged.
    #[must_use]
    pub fn with_num_tiles(self, n: usize) -> Self {
        Self {
            num_tiles: n,
            ..self
        }
    }
}
43
/// Softmax-attends `query` over one contiguous tile of keys/values.
///
/// `keys_tile` and `values_tile` hold `tile_len` rows of `head_dim` floats
/// each (row-major). Returns `(output, max_score, log_sum_exp)`:
/// - `output`: the softmax-weighted sum of the tile's value rows, already
///   divided by the tile's own softmax denominator;
/// - `max_score`: the largest scaled score in the tile (chosen with an
///   `f32::max` fold, so NaN scores are skipped when picking it);
/// - `log_sum_exp`: `max_score + ln(sum(exp(score - max_score)))`.
///
/// If that maximum is not finite (empty tile, every score `-inf`/NaN, or a
/// `+inf` score), returns a zero vector with both statistics set to `-inf`.
///
/// # Panics
/// If `keys_tile` or `values_tile` holds fewer than `tile_len * head_dim` floats.
fn flash_decode_tile(
    query: &[f32],
    keys_tile: &[f32],
    values_tile: &[f32],
    tile_len: usize,
    head_dim: usize,
    scale: f32,
) -> (Vec<f32>, f32, f32) {
    // Scaled query·key score per row of the tile.
    let mut weights: Vec<f32> = Vec::with_capacity(tile_len);
    for row in 0..tile_len {
        let key = &keys_tile[row * head_dim..(row + 1) * head_dim];
        let dot = query.iter().zip(key).map(|(q, k)| q * k).sum::<f32>();
        weights.push(dot * scale);
    }

    let peak = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if !peak.is_finite() {
        // Sentinel for "this tile contributes nothing".
        return (vec![0.0_f32; head_dim], f32::NEG_INFINITY, f32::NEG_INFINITY);
    }

    // Shift by the maximum before exponentiating for numerical stability.
    weights.iter_mut().for_each(|w| *w = (*w - peak).exp());
    let denom: f32 = weights.iter().sum();
    let lse = peak + denom.ln();

    let mut acc = vec![0.0_f32; head_dim];
    for (row, &w) in weights.iter().enumerate() {
        let value = &values_tile[row * head_dim..(row + 1) * head_dim];
        for (a, &v) in acc.iter_mut().zip(value) {
            *a += w * v;
        }
    }

    if denom > 0.0 {
        acc.iter_mut().for_each(|a| *a /= denom);
    }

    (acc, peak, lse)
}
117
/// Merges per-tile partial attention results into the final output vector.
///
/// Each participating tile `i` contributes `tile_outputs[i]` (already
/// normalised by that tile's own softmax denominator) weighted by
/// `exp(tile_lse[i] - global_lse)`, where `global_lse` is the log-sum-exp of
/// all participating tiles' `tile_lse` values.
///
/// A tile whose log-sum-exp is exactly `-inf` carries no probability mass
/// (`flash_decode_tile` returns that sentinel when a tile has no finite
/// maximum) and is skipped. Every other value, including NaN and `+inf`,
/// participates. A poisoned score therefore shows up in the output, just
/// as it does with a single tile, instead of being silently dropped when the
/// sequence happens to be split into several tiles.
///
/// `tile_max_scores` is only used for a debug-build length check.
///
/// Returns a zero vector of length `head_dim` when there are no tiles or no
/// tile carries mass. When there is exactly one tile (or one participating
/// tile), its output is returned unchanged.
///
/// # Panics
/// If a participating tile output is shorter than `head_dim`, or if
/// `tile_lse` is longer than `tile_outputs`.
fn combine_tile_outputs(
    tile_outputs: &[Vec<f32>],
    tile_max_scores: &[f32],
    tile_lse: &[f32],
    head_dim: usize,
) -> Vec<f32> {
    debug_assert_eq!(tile_outputs.len(), tile_lse.len());
    debug_assert_eq!(tile_outputs.len(), tile_max_scores.len());

    if tile_outputs.is_empty() {
        return vec![0.0_f32; head_dim];
    }
    if tile_outputs.len() == 1 {
        return tile_outputs[0].clone();
    }

    // Skip only the "no mass" sentinel; NaN / +inf must propagate.
    let contributing: Vec<usize> = tile_lse
        .iter()
        .enumerate()
        .filter(|&(_, &lse)| lse != f32::NEG_INFINITY)
        .map(|(i, _)| i)
        .collect();

    match contributing.as_slice() {
        [] => return vec![0.0_f32; head_dim],
        [only] => return tile_outputs[*only].clone(),
        _ => {}
    }

    // `f32::max` ignores NaN, so this is the largest non-NaN LSE; any NaN
    // entry still poisons `global_sum` (and hence every weight) below.
    let lse_max = contributing
        .iter()
        .map(|&i| tile_lse[i])
        .fold(f32::NEG_INFINITY, f32::max);

    let global_sum: f32 = contributing
        .iter()
        .map(|&i| (tile_lse[i] - lse_max).exp())
        .sum();
    let global_lse = lse_max + global_sum.ln();

    let mut combined = vec![0.0_f32; head_dim];
    for &i in &contributing {
        let weight = (tile_lse[i] - global_lse).exp();
        for (acc, &x) in combined.iter_mut().zip(&tile_outputs[i][..head_dim]) {
            *acc += weight * x;
        }
    }

    combined
}
176
177pub fn flash_decode_single_head(
186 query: &[f32],
187 keys: &[f32],
188 values: &[f32],
189 seq_len: usize,
190 head_dim: usize,
191 config: &FlashDecodeConfig,
192) -> Result<Vec<f32>, FlashDecodeError> {
193 if seq_len == 0 {
194 return Err(FlashDecodeError::EmptyKv);
195 }
196 if query.len() != head_dim {
197 return Err(FlashDecodeError::DimMismatch {
198 q_dim: query.len(),
199 k_dim: head_dim,
200 });
201 }
202 if keys.len() != seq_len * head_dim {
203 return Err(FlashDecodeError::DimMismatch {
204 q_dim: query.len(),
205 k_dim: keys.len() / seq_len.max(1),
206 });
207 }
208 if values.len() != seq_len * head_dim {
209 return Err(FlashDecodeError::DimMismatch {
210 q_dim: query.len(),
211 k_dim: values.len() / seq_len.max(1),
212 });
213 }
214
215 let num_tiles = config.num_tiles.min(seq_len).max(1);
217
218 let tile_size_base = seq_len / num_tiles;
219 let remainder = seq_len % num_tiles;
220
221 let mut tile_outputs: Vec<Vec<f32>> = Vec::with_capacity(num_tiles);
222 let mut tile_max_scores: Vec<f32> = Vec::with_capacity(num_tiles);
223 let mut tile_lse: Vec<f32> = Vec::with_capacity(num_tiles);
224
225 let mut offset = 0usize;
226 for tile_idx in 0..num_tiles {
227 let tile_len = tile_size_base + if tile_idx < remainder { 1 } else { 0 };
229 if tile_len == 0 {
230 break;
231 }
232
233 let k_start = offset * head_dim;
234 let k_end = k_start + tile_len * head_dim;
235 let v_start = offset * head_dim;
236 let v_end = v_start + tile_len * head_dim;
237
238 let (out, max_s, lse) = flash_decode_tile(
239 query,
240 &keys[k_start..k_end],
241 &values[v_start..v_end],
242 tile_len,
243 head_dim,
244 config.scale,
245 );
246 tile_outputs.push(out);
247 tile_max_scores.push(max_s);
248 tile_lse.push(lse);
249
250 offset += tile_len;
251 }
252
253 Ok(combine_tile_outputs(
254 &tile_outputs,
255 &tile_max_scores,
256 &tile_lse,
257 head_dim,
258 ))
259}
260
261pub fn flash_decode_multi_head(
271 queries: &[f32],
272 keys: &[f32],
273 values: &[f32],
274 num_heads: usize,
275 seq_len: usize,
276 head_dim: usize,
277 config: &FlashDecodeConfig,
278) -> Result<Vec<f32>, FlashDecodeError> {
279 if seq_len == 0 {
280 return Err(FlashDecodeError::EmptyKv);
281 }
282 if queries.len() != num_heads * head_dim {
283 return Err(FlashDecodeError::DimMismatch {
284 q_dim: queries.len(),
285 k_dim: head_dim,
286 });
287 }
288
289 let per_head_keys: Vec<Vec<f32>> = (0..num_heads)
293 .map(|h| {
294 let mut buf = vec![0.0_f32; seq_len * head_dim];
295 for t in 0..seq_len {
296 let src_start = t * num_heads * head_dim + h * head_dim;
297 let dst_start = t * head_dim;
298 buf[dst_start..dst_start + head_dim]
299 .copy_from_slice(&keys[src_start..src_start + head_dim]);
300 }
301 buf
302 })
303 .collect();
304
305 let per_head_values: Vec<Vec<f32>> = (0..num_heads)
306 .map(|h| {
307 let mut buf = vec![0.0_f32; seq_len * head_dim];
308 for t in 0..seq_len {
309 let src_start = t * num_heads * head_dim + h * head_dim;
310 let dst_start = t * head_dim;
311 buf[dst_start..dst_start + head_dim]
312 .copy_from_slice(&values[src_start..src_start + head_dim]);
313 }
314 buf
315 })
316 .collect();
317
318 let results: Vec<Result<Vec<f32>, FlashDecodeError>> = (0..num_heads)
320 .into_par_iter()
321 .map(|h| {
322 let q_start = h * head_dim;
323 let q_vec = &queries[q_start..q_start + head_dim];
324 flash_decode_single_head(
325 q_vec,
326 &per_head_keys[h],
327 &per_head_values[h],
328 seq_len,
329 head_dim,
330 config,
331 )
332 })
333 .collect();
334
335 let mut output = vec![0.0_f32; num_heads * head_dim];
337 for (h, res) in results.into_iter().enumerate() {
338 let head_out = res?;
339 let start = h * head_dim;
340 output[start..start + head_dim].copy_from_slice(&head_out);
341 }
342
343 Ok(output)
344}
345
346pub fn flash_vs_naive_error(
352 query: &[f32],
353 keys: &[f32],
354 values: &[f32],
355 seq_len: usize,
356 head_dim: usize,
357) -> Result<f32, FlashDecodeError> {
358 if seq_len == 0 {
359 return Err(FlashDecodeError::EmptyKv);
360 }
361 if query.len() != head_dim {
362 return Err(FlashDecodeError::DimMismatch {
363 q_dim: query.len(),
364 k_dim: head_dim,
365 });
366 }
367
368 let config = FlashDecodeConfig::new(head_dim);
370 let flash_out = flash_decode_single_head(query, keys, values, seq_len, head_dim, &config)?;
371
372 let scale = config.scale;
374 let mut scores: Vec<f32> = (0..seq_len)
375 .map(|t| {
376 let k_start = t * head_dim;
377 let k_vec = &keys[k_start..k_start + head_dim];
378 query
379 .iter()
380 .zip(k_vec.iter())
381 .map(|(q, k)| q * k)
382 .sum::<f32>()
383 * scale
384 })
385 .collect();
386
387 let max_s = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
389 for s in scores.iter_mut() {
390 *s = (*s - max_s).exp();
391 }
392 let sum: f32 = scores.iter().sum();
393 if sum > 0.0 {
394 for s in scores.iter_mut() {
395 *s /= sum;
396 }
397 }
398
399 let mut naive_out = vec![0.0_f32; head_dim];
401 for (t, &w) in scores.iter().enumerate() {
402 let v_start = t * head_dim;
403 for d in 0..head_dim {
404 naive_out[d] += w * values[v_start + d];
405 }
406 }
407
408 let mae = flash_out
410 .iter()
411 .zip(naive_out.iter())
412 .map(|(a, b)| (a - b).abs())
413 .sum::<f32>()
414 / head_dim as f32;
415
416 Ok(mae)
417}
418
/// Errors returned by the flash-decode entry points in this module.
#[derive(Debug, thiserror::Error)]
pub enum FlashDecodeError {
    /// The KV sequence is empty (`seq_len == 0`).
    #[error("empty KV sequence")]
    EmptyKv,

    /// An input slice's length does not match the declared dimensions.
    ///
    /// NOTE(review): in the visible code this variant is used for both kinds of
    /// mismatch. For a query-length mismatch, `k_dim` is `head_dim`. For a
    /// keys/values length mismatch, `k_dim` is the observed per-token width
    /// (`len / seq_len`), and for a values mismatch it reports the values'
    /// width even though the message says "keys".
    #[error("dimension mismatch: query has {q_dim}, keys have {k_dim}")]
    DimMismatch { q_dim: usize, k_dim: usize },

    /// The requested tile count exceeds the sequence length (`num_tiles`, `seq_len`).
    ///
    /// NOTE(review): not constructed anywhere in the visible code;
    /// `flash_decode_single_head` clamps the tile count instead of erroring.
    #[error("num_tiles ({0}) exceeds seq_len ({1})")]
    TooManyTiles(usize, usize),

    /// A configuration value is invalid; the payload describes why.
    ///
    /// NOTE(review): not constructed anywhere in the visible code.
    #[error("invalid config: {0}")]
    InvalidConfig(String),
}
436
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic `(query, keys, values)` triple:
    /// - the query ramps as `0.1 * i` over `head_dim` elements;
    /// - keys ramp as `0.05 * i + 0.01` over `seq_len * head_dim` elements;
    /// - values ramp as `0.02 * i + 0.1` over `seq_len * head_dim` elements.
    fn make_deterministic_data(seq_len: usize, head_dim: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let query: Vec<f32> = (0..head_dim).map(|i| 0.1 * i as f32).collect();
        let keys: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| 0.05 * i as f32 + 0.01)
            .collect();
        let values: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| 0.02 * i as f32 + 0.1)
            .collect();
        (query, keys, values)
    }

    // `new` should choose the 1/sqrt(head_dim) scale and the default 4 tiles.
    #[test]
    fn flash_decode_config_default() {
        let head_dim = 64usize;
        let cfg = FlashDecodeConfig::new(head_dim);
        let expected_scale = 1.0_f32 / (head_dim as f32).sqrt();
        assert!(
            (cfg.scale - expected_scale).abs() < 1e-6,
            "scale mismatch: {} vs {}",
            cfg.scale,
            expected_scale
        );
        assert_eq!(cfg.num_tiles, 4);
    }

    // Tiled output (default config) should agree with the untiled reference.
    #[test]
    fn flash_decode_single_head_matches_naive() {
        let head_dim = 16;
        let seq_len = 32;
        let (q, k, v) = make_deterministic_data(seq_len, head_dim);
        let mae = flash_vs_naive_error(&q, &k, &v, seq_len, head_dim)
            .expect("flash_vs_naive_error failed");
        assert!(
            mae < 1e-5,
            "MAE between flash and naive exceeds threshold: {mae}"
        );
    }

    // seq_len == 0 must be rejected with EmptyKv rather than producing output.
    #[test]
    fn flash_decode_empty_kv_error() {
        let head_dim = 8;
        let config = FlashDecodeConfig::new(head_dim);
        let q = vec![0.1f32; head_dim];
        let result = flash_decode_single_head(&q, &[], &[], 0, head_dim, &config);
        assert!(
            matches!(result, Err(FlashDecodeError::EmptyKv)),
            "expected EmptyKv, got {result:?}"
        );
    }

    // A query longer than head_dim must be rejected with DimMismatch.
    #[test]
    fn flash_decode_dim_mismatch_error() {
        let head_dim = 8;
        let config = FlashDecodeConfig::new(head_dim);
        let q = vec![0.1f32; head_dim + 2];
        let k = vec![0.1f32; head_dim];
        let v = vec![0.1f32; head_dim];
        let result = flash_decode_single_head(&q, &k, &v, 1, head_dim, &config);
        assert!(
            matches!(result, Err(FlashDecodeError::DimMismatch { .. })),
            "expected DimMismatch, got {result:?}"
        );
    }

    // With a single key the softmax weight is 1, so the output equals that value row.
    #[test]
    fn flash_decode_single_token() {
        let head_dim = 4;
        let config = FlashDecodeConfig::new(head_dim);
        let q = vec![1.0f32, 0.0, 0.0, 0.0];
        let k = vec![0.5f32, 0.5, 0.5, 0.5];
        let v = vec![3.0f32, 1.0, 2.0, 4.0];
        let out = flash_decode_single_head(&q, &k, &v, 1, head_dim, &config)
            .expect("flash_decode_single_head failed");

        for (i, (&o, &expected)) in out.iter().zip(v.iter()).enumerate() {
            assert!(
                (o - expected).abs() < 1e-5,
                "output[{i}] = {o}, expected {expected}"
            );
        }
    }

    // Identical keys give uniform weights, so the output is the mean of the
    // value rows (1, 2, 3, 4) = 2.5 in every component.
    #[test]
    fn flash_decode_uniform_keys() {
        let head_dim = 4;
        let seq_len = 4;
        let config = FlashDecodeConfig::new(head_dim);
        let q = vec![0.1f32; head_dim];
        let k = vec![0.1f32; seq_len * head_dim];
        let v: Vec<f32> = (0..seq_len)
            .flat_map(|t| vec![(t + 1) as f32; head_dim])
            .collect();

        let out = flash_decode_single_head(&q, &k, &v, seq_len, head_dim, &config)
            .expect("flash_decode_single_head failed");

        let expected = 2.5_f32;
        for (i, &o) in out.iter().enumerate() {
            assert!(
                (o - expected).abs() < 1e-4,
                "output[{i}] = {o}, expected {expected}"
            );
        }
    }

    // num_tiles = 1 (no tiling) is a valid configuration.
    #[test]
    fn flash_decode_tile_count_1() {
        let head_dim = 8;
        let seq_len = 16;
        let config = FlashDecodeConfig::new(head_dim).with_num_tiles(1);
        let (q, k, v) = make_deterministic_data(seq_len, head_dim);
        let result = flash_decode_single_head(&q, &k, &v, seq_len, head_dim, &config);
        assert!(result.is_ok(), "num_tiles=1 should be valid: {result:?}");
    }

    // More tiles than the default (8 tiles of 2 rows each) still succeeds.
    #[test]
    fn flash_decode_tile_count_many() {
        let head_dim = 8;
        let seq_len = 16;
        let config = FlashDecodeConfig::new(head_dim).with_num_tiles(8);
        let (q, k, v) = make_deterministic_data(seq_len, head_dim);
        let result = flash_decode_single_head(&q, &k, &v, seq_len, head_dim, &config);
        assert!(
            result.is_ok(),
            "num_tiles=8 with seq_len=16 failed: {result:?}"
        );
    }

    // Flash vs reference at a larger size, with a looser tolerance.
    #[test]
    fn flash_vs_naive_error_small() {
        let head_dim = 32;
        let seq_len = 64;
        let (q, k, v) = make_deterministic_data(seq_len, head_dim);
        let mae = flash_vs_naive_error(&q, &k, &v, seq_len, head_dim)
            .expect("flash_vs_naive_error failed");
        assert!(mae < 1e-4, "MAE too large: {mae}");
    }

    // Multi-head output has shape [num_heads * head_dim]; inputs use the
    // [seq_len, num_heads, head_dim] KV layout.
    #[test]
    fn flash_decode_multi_head_shape() {
        let num_heads = 4;
        let head_dim = 8;
        let seq_len = 16;
        let config = FlashDecodeConfig::new(head_dim);

        let queries = vec![0.1f32; num_heads * head_dim];
        let keys = vec![0.05f32; seq_len * num_heads * head_dim];
        let values = vec![0.2f32; seq_len * num_heads * head_dim];

        let out = flash_decode_multi_head(
            &queries, &keys, &values, num_heads, seq_len, head_dim, &config,
        )
        .expect("multi_head flash decode failed");

        assert_eq!(
            out.len(),
            num_heads * head_dim,
            "output shape mismatch: {} vs {}",
            out.len(),
            num_heads * head_dim
        );
    }

    // Each head of the multi-head result should match a single-head run on
    // that head's gathered KV rows. Note: the "naive" side here is
    // flash_decode_single_head with num_tiles = 1, not flash_vs_naive_error.
    #[test]
    fn flash_decode_multi_head_matches_naive_per_head() {
        let num_heads = 2;
        let head_dim = 8;
        let seq_len = 16;
        let config = FlashDecodeConfig::new(head_dim);

        let queries: Vec<f32> = (0..num_heads * head_dim).map(|i| 0.1 * i as f32).collect();
        let keys: Vec<f32> = (0..seq_len * num_heads * head_dim)
            .map(|i| 0.05 * (i % 17) as f32 + 0.01)
            .collect();
        let values: Vec<f32> = (0..seq_len * num_heads * head_dim)
            .map(|i| 0.02 * (i % 13) as f32 + 0.1)
            .collect();

        let flash_out = flash_decode_multi_head(
            &queries, &keys, &values, num_heads, seq_len, head_dim, &config,
        )
        .expect("multi_head flash decode failed");

        for h in 0..num_heads {
            let q_vec = &queries[h * head_dim..(h + 1) * head_dim];

            // De-interleave head h's rows out of the [seq_len, num_heads, head_dim] layout.
            let mut k_head = vec![0.0f32; seq_len * head_dim];
            let mut v_head = vec![0.0f32; seq_len * head_dim];
            for t in 0..seq_len {
                let src_k = t * num_heads * head_dim + h * head_dim;
                let src_v = t * num_heads * head_dim + h * head_dim;
                let dst = t * head_dim;
                k_head[dst..dst + head_dim].copy_from_slice(&keys[src_k..src_k + head_dim]);
                v_head[dst..dst + head_dim].copy_from_slice(&values[src_v..src_v + head_dim]);
            }

            let naive_config = FlashDecodeConfig::new(head_dim).with_num_tiles(1);
            let naive_out =
                flash_decode_single_head(q_vec, &k_head, &v_head, seq_len, head_dim, &naive_config)
                    .expect("naive single head failed");

            let head_flash = &flash_out[h * head_dim..(h + 1) * head_dim];
            let mae: f32 = head_flash
                .iter()
                .zip(naive_out.iter())
                .map(|(a, b)| (a - b).abs())
                .sum::<f32>()
                / head_dim as f32;
            assert!(
                mae < 1e-4,
                "head {h}: MAE between multi_head flash and single-head naive = {mae}"
            );
        }
    }

    // A single tile is passed through unchanged by the combiner.
    #[test]
    fn combine_tiles_single_tile() {
        let head_dim = 4;
        let tile_out = vec![1.0f32, 2.0, 3.0, 4.0];
        let combined = combine_tile_outputs(
            std::slice::from_ref(&tile_out),
            &[0.5_f32],
            &[1.0_f32],
            head_dim,
        );
        for (i, (&c, &t)) in combined.iter().zip(tile_out.iter()).enumerate() {
            assert!((c - t).abs() < 1e-5, "combined[{i}] = {c}, expected {t}");
        }
    }

    // Longer sequence with 8 tiles: result has head_dim finite entries.
    #[test]
    fn flash_decode_long_sequence() {
        let head_dim = 16;
        let seq_len = 128;
        let config = FlashDecodeConfig::new(head_dim).with_num_tiles(8);
        let (q, k, v) = make_deterministic_data(seq_len, head_dim);
        let result = flash_decode_single_head(&q, &k, &v, seq_len, head_dim, &config);
        assert!(
            result.is_ok(),
            "long sequence (seq_len=128) failed: {result:?}"
        );
        let out = result.expect("already checked");
        assert_eq!(out.len(), head_dim);
        for (i, &o) in out.iter().enumerate() {
            assert!(o.is_finite(), "output[{i}] = {o} is not finite");
        }
    }
}