use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};

use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, TensorData};

use crate::ops::stable_accumulator::AccumulatorConfig;

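/// Default number of cached causal-mask tiles kept per thread.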
const DEFAULT_MASK_CACHE_CAPACITY: usize = 8192;

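/// Capacity read when a thread initializes its mask cache for the first time.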
static MASK_CACHE_CAPACITY: AtomicUsize = AtomicUsize::new(DEFAULT_MASK_CACHE_CAPACITY);

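/// Key identifying one cached causal-mask tile.
///
/// Two tiles can share a mask whenever they have the same query/key lengths and the same
/// relative offset between query and key start positions, so only those three values are
/// stored instead of the absolute positions.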
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
struct MaskCacheKey {
    query_len: usize,
    key_len: usize,
    relative_offset: isize,
}

impl MaskCacheKey {
    fn from_positions(
        query_len: usize,
        key_len: usize,
        q_start: usize,
        kv_start: usize,
        position_offset: usize,
    ) -> Self {
        let relative_offset = (q_start + position_offset) as isize - kv_start as isize;
        Self {
            query_len,
            key_len,
            relative_offset,
        }
    }
}

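// Per-thread LRU cache of causal-mask data, keyed by tile shape and relative offset.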
thread_local! {
    static MASK_CACHE: RefCell<MaskCache> = RefCell::new(
        MaskCache::new(MASK_CACHE_CAPACITY.load(Ordering::Relaxed))
    );
}

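/// Small LRU cache for mask data with hit/miss counters.
///
/// Recency is tracked with a `Vec` of keys, which keeps the implementation simple at the
/// cost of O(n) reordering; eviction drops the least recently used entry.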
struct MaskCache {
    cache: HashMap<MaskCacheKey, Vec<f32>>,
    access_order: Vec<MaskCacheKey>,
    capacity: usize,
    hits: usize,
    misses: usize,
}

impl MaskCache {
    fn new(capacity: usize) -> Self {
        Self {
            cache: HashMap::with_capacity(capacity.min(1024)),
            access_order: Vec::with_capacity(capacity.min(1024)),
            capacity,
            hits: 0,
            misses: 0,
        }
    }

    fn get_or_create(
        &mut self,
        key: MaskCacheKey,
        create_fn: impl FnOnce() -> Vec<f32>,
    ) -> &Vec<f32> {
        if self.cache.contains_key(&key) {
            self.hits += 1;
            if self.access_order.last() != Some(&key) {
                if let Some(pos) = self.access_order.iter().position(|k| k == &key) {
                    self.access_order.remove(pos);
                    self.access_order.push(key.clone());
                }
            }
        } else {
            self.misses += 1;
            while self.cache.len() >= self.capacity && !self.access_order.is_empty() {
                let oldest = self.access_order.remove(0);
                self.cache.remove(&oldest);
            }
            self.cache.insert(key.clone(), create_fn());
            self.access_order.push(key.clone());
        }
        self.cache.get(&key).unwrap()
    }

    #[allow(dead_code)]
    fn stats(&self) -> (usize, usize, f64) {
        let total = self.hits + self.misses;
        let hit_rate = if total > 0 {
            self.hits as f64 / total as f64
        } else {
            0.0
        };
        (self.hits, self.misses, hit_rate)
    }

    fn clear(&mut self) {
        self.cache.clear();
        self.access_order.clear();
        self.hits = 0;
        self.misses = 0;
    }

    #[allow(dead_code)]
    fn resize(&mut self, new_capacity: usize) {
        self.capacity = new_capacity;
        while self.cache.len() > self.capacity && !self.access_order.is_empty() {
            let oldest = self.access_order.remove(0);
            self.cache.remove(&oldest);
        }
    }

    #[allow(dead_code)]
    fn len(&self) -> usize {
        self.cache.len()
    }
}

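/// Configuration of the causal-mask cache: capacity in cached tiles and optional
/// statistics tracking.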
#[derive(Clone, Debug)]
pub struct MaskCacheConfig {
    pub capacity: usize,
    pub track_stats: bool,
}

impl Default for MaskCacheConfig {
    fn default() -> Self {
        Self {
            capacity: DEFAULT_MASK_CACHE_CAPACITY,
            track_stats: false,
        }
    }
}

impl MaskCacheConfig {
    pub fn short_context() -> Self {
        Self {
            capacity: 1024,
            track_stats: false,
        }
    }

    pub fn medium_context() -> Self {
        Self {
            capacity: 8192,
            track_stats: false,
        }
    }

    pub fn long_context() -> Self {
        Self {
            capacity: 32768,
            track_stats: false,
        }
    }

    pub fn ultra_long_context() -> Self {
        Self {
            capacity: 131072,
            track_stats: false,
        }
    }
}

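/// Controls reproducibility of the attention computation.
///
/// [`DeterministicConfig::strict`] fixes the tile order and seed (and verifies determinism
/// in debug builds); [`DeterministicConfig::relaxed`] disables all of it.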
#[derive(Clone, Debug)]
pub struct DeterministicConfig {
    pub enabled: bool,
    pub fixed_tile_order: bool,
    pub seed: Option<u64>,
    pub no_gpu_nondeterminism: bool,
    pub verify_determinism: bool,
}

impl Default for DeterministicConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            fixed_tile_order: false,
            seed: None,
            no_gpu_nondeterminism: false,
            verify_determinism: false,
        }
    }
}

impl DeterministicConfig {
    pub fn strict() -> Self {
        Self {
            enabled: true,
            fixed_tile_order: true,
            seed: Some(42),
            no_gpu_nondeterminism: true,
            verify_determinism: cfg!(debug_assertions),
        }
    }

    pub fn relaxed() -> Self {
        Self {
            enabled: false,
            fixed_tile_order: false,
            seed: None,
            no_gpu_nondeterminism: false,
            verify_determinism: false,
        }
    }

    pub fn ultra_long_context() -> Self {
        Self::strict()
    }

    pub fn is_deterministic(&self) -> bool {
        self.enabled || self.fixed_tile_order || self.seed.is_some()
    }
}

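/// Iterator adapter that pairs each item with its position and issues a `SeqCst` fence per
/// step; used when a fixed tile-processing order is requested.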
pub struct StrictOrderIterator<I> {
    inner: I,
    index: usize,
}

impl<I: Iterator> StrictOrderIterator<I> {
    pub fn new(iter: I) -> Self {
        Self { inner: iter, index: 0 }
    }

    pub fn current_index(&self) -> usize {
        self.index
    }
}

impl<I: Iterator> Iterator for StrictOrderIterator<I> {
    type Item = (usize, I::Item);

    fn next(&mut self) -> Option<Self::Item> {
        let item = self.inner.next()?;
        let index = self.index;
        self.index += 1;

        std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst);

        Some((index, item))
    }
}

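/// Extension trait adding `.strict_order()` to any iterator.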
pub trait StrictOrderExt: Iterator + Sized {
    fn strict_order(self) -> StrictOrderIterator<Self> {
        StrictOrderIterator::new(self)
    }
}

impl<I: Iterator> StrictOrderExt for I {}

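/// Configuration for hierarchical flash attention: tile sizes, accumulator precision,
/// determinism settings, and whether the softmax normalizer is tracked in log space.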
#[derive(Debug, Clone)]
pub struct FlashAttentionConfig {
    pub block_q: usize,
    pub block_kv: usize,
    pub accumulator: AccumulatorConfig,
    pub determinism: DeterministicConfig,
    pub use_log_space: bool,
    pub max_seq_len: usize,
}

impl Default for FlashAttentionConfig {
    fn default() -> Self {
        Self {
            block_q: 64,
            block_kv: 16,
            accumulator: AccumulatorConfig::max_precision(),
            determinism: DeterministicConfig::strict(),
            use_log_space: true,
            max_seq_len: 2_000_000,
        }
    }
}

impl FlashAttentionConfig {
    pub fn ultra_long_context() -> Self {
        Self {
            block_q: 64,
            block_kv: 16,
            accumulator: AccumulatorConfig::max_precision(),
            determinism: DeterministicConfig::ultra_long_context(),
            use_log_space: true,
            max_seq_len: 2_000_000,
        }
    }

    pub fn short_context() -> Self {
        Self {
            block_q: 128,
            block_kv: 64,
            accumulator: AccumulatorConfig::short_context(),
            determinism: DeterministicConfig::relaxed(),
            use_log_space: false,
            max_seq_len: 100_000,
        }
    }
}

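/// Alias for [`FlashAttentionConfig`] used by the hierarchical attention API.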
pub type HierarchicalFlashConfig = FlashAttentionConfig;

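/// Attention over paged KV storage: keys and values arrive as an iterator of per-block
/// `[num_heads, block_len, head_dim]` tensors rather than contiguous 4D tensors.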
pub trait FusedPagedAttention<B: Backend> {
    fn forward_fused<'a, I>(
        &self,
        q: Tensor<B, 4>,
        kv_blocks: I,
        config: &FlashAttentionConfig,
        causal: bool,
        position_offset: usize,
    ) -> Tensor<B, 4>
    where
        I: Iterator<Item = (Tensor<B, 3>, Tensor<B, 3>)> + 'a;
}

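/// Reusable buffers for the running max (`m`), normalizer (`l`), and output (`o`) of the
/// online-softmax recurrence, so repeated calls can avoid reallocation.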
pub struct AttentionWorkspace<B: Backend> {
    pub m_buffer: Option<Tensor<B, 4>>,
    pub l_buffer: Option<Tensor<B, 4>>,
    pub o_buffer: Option<Tensor<B, 4>>,
    dims: Option<(usize, usize, usize, usize)>,
}

impl<B: Backend> Default for AttentionWorkspace<B> {
    fn default() -> Self {
        Self::new()
    }
}

impl<B: Backend> AttentionWorkspace<B> {
    pub fn new() -> Self {
        Self {
            m_buffer: None,
            l_buffer: None,
            o_buffer: None,
            dims: None,
        }
    }

    pub fn allocate(
        &mut self,
        device: &B::Device,
        batch_size: usize,
        num_heads: usize,
        q_block_len: usize,
        head_dim: usize,
    ) {
        let needs_realloc = self.dims.map_or(true, |(b, h, q, d)| {
            b != batch_size || h != num_heads || q < q_block_len || d != head_dim
        });

        if needs_realloc {
            self.m_buffer = Some(Tensor::zeros(
                [batch_size, num_heads, q_block_len, 1],
                device,
            ));
            self.l_buffer = Some(Tensor::zeros(
                [batch_size, num_heads, q_block_len, 1],
                device,
            ));
            self.o_buffer = Some(Tensor::zeros(
                [batch_size, num_heads, q_block_len, head_dim],
                device,
            ));
            self.dims = Some((batch_size, num_heads, q_block_len, head_dim));
        }
    }

    pub fn reset(&mut self, device: &B::Device) {
        if let Some((batch_size, num_heads, q_block_len, _)) = self.dims {
            self.m_buffer = Some(Tensor::full(
                [batch_size, num_heads, q_block_len, 1],
                f32::NEG_INFINITY,
                device,
            ));
            if let Some(ref mut l) = self.l_buffer {
                *l = l.clone().zeros_like();
            }
            if let Some(ref mut o) = self.o_buffer {
                *o = o.clone().zeros_like();
            }
        }
    }

    pub fn take_output(&mut self) -> Option<Tensor<B, 4>> {
        self.o_buffer.take()
    }
}

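/// Hierarchical (tiled) flash attention with an online-softmax recurrence, cached causal
/// masks, and optional deterministic tile ordering.
///
/// A minimal usage sketch (shapes as used by the tests below; `q`, `k`, `v` are assumed to
/// already exist):
///
/// ```ignore
/// let attention = HierarchicalFlashAttention::new(FlashAttentionConfig::default());
/// // q, k, v: Tensor<B, 4> of shape [batch, num_heads, seq_len, head_dim]
/// let output = attention.forward(q, k, v, /* causal */ true, /* position_offset */ 0);
/// ```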
#[derive(Debug, Clone)]
pub struct HierarchicalFlashAttention {
    config: FlashAttentionConfig,
}

impl HierarchicalFlashAttention {
    pub fn new(config: FlashAttentionConfig) -> Self {
        Self { config }
    }

    pub fn default_config() -> Self {
        Self::new(FlashAttentionConfig::default())
    }

    pub fn ultra_long_context() -> Self {
        Self::new(FlashAttentionConfig::ultra_long_context())
    }

    pub fn config(&self) -> &FlashAttentionConfig {
        &self.config
    }

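    /// Tiled attention forward pass that takes an [`AttentionWorkspace`].
    ///
    /// The workspace buffers are allocated up front; the per-block running max, normalizer,
    /// and output accumulators are currently created fresh for each query block.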
    pub fn forward_with_workspace<B: Backend>(
        &self,
        q: Tensor<B, 4>,
        k: Tensor<B, 4>,
        v: Tensor<B, 4>,
        causal: bool,
        position_offset: usize,
        workspace: &mut AttentionWorkspace<B>,
    ) -> Tensor<B, 4> {
        let device = q.device();
        let [batch_size, num_heads, query_len, head_dim] = q.dims();
        let key_len = k.dims()[2];

        if query_len == 0 || key_len == 0 {
            return Tensor::zeros([batch_size, num_heads, query_len, head_dim], &device);
        }

        let block_q = self.config.block_q.max(1);
        let block_kv = self.config.block_kv.max(1);
        let inv_scale = 1.0 / (head_dim as f32).sqrt();

        workspace.allocate(&device, batch_size, num_heads, block_q, head_dim);

        let q_blocks = q.split(block_q, 2);
        let k_blocks = k.split(block_kv, 2);
        let v_blocks = v.split(block_kv, 2);
        let k_blocks_t: Vec<_> = k_blocks.into_iter().map(|block| block.transpose()).collect();
        let mut outputs = Vec::with_capacity(q_blocks.len());

        for (q_block_index, q_block) in q_blocks.into_iter().enumerate() {
            let q_block_len = q_block.dims()[2];
            let q_start = q_block_index * block_q;
            let q_block_scaled = q_block * inv_scale;

            let mut m_i = Tensor::<B, 4>::full(
                [batch_size, num_heads, q_block_len, 1],
                f32::NEG_INFINITY,
                &device,
            );
            let mut l_i = Tensor::<B, 4>::zeros([batch_size, num_heads, q_block_len, 1], &device);
            let mut o_i = Tensor::<B, 4>::zeros(
                [batch_size, num_heads, q_block_len, head_dim],
                &device,
            );

            for (kv_index, (k_block_t, v_block)) in
                k_blocks_t.iter().zip(v_blocks.iter()).enumerate()
            {
                let kv_block_len = k_block_t.dims()[3];
                let kv_start = kv_index * block_kv;

                let mut scores = q_block_scaled.clone().matmul(k_block_t.clone());

                if causal {
                    let mask = self.build_causal_mask_cached::<B>(
                        &device,
                        q_block_len,
                        kv_block_len,
                        q_start,
                        kv_start,
                        position_offset,
                    );
                    scores = scores + mask;
                }

                let m_ij = scores.clone().max_dim(3);
                let m_new = m_i.clone().max_pair(m_ij);

                let m_scale = (m_i - m_new.clone()).exp();
                let p_ij = (scores - m_new.clone()).exp();
                let p_sum = p_ij.clone().sum_dim(3);

                l_i = m_scale.clone() * l_i + p_sum;
                o_i = m_scale * o_i + p_ij.matmul(v_block.clone());
                m_i = m_new;
            }

            outputs.push(o_i / l_i);
        }

        Tensor::cat(outputs, 2)
    }

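    /// Builds, or fetches from the thread-local cache, the additive causal mask for one
    /// `query_len x key_len` tile. Allowed positions get `0.0`; masked positions get a large
    /// negative value (`-1e4`) rather than `-inf`.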
    fn build_causal_mask_cached<B: Backend>(
        &self,
        device: &B::Device,
        query_len: usize,
        key_len: usize,
        q_start: usize,
        kv_start: usize,
        position_offset: usize,
    ) -> Tensor<B, 4> {
        let key = MaskCacheKey::from_positions(
            query_len,
            key_len,
            q_start,
            kv_start,
            position_offset,
        );
        let relative_offset = key.relative_offset;

        let mask_value = -1.0e4_f32;

        let data = MASK_CACHE.with(|cache| {
            let mut cache = cache.borrow_mut();
            cache
                .get_or_create(key, || {
                    let mut data = Vec::with_capacity(query_len * key_len);
                    for i in 0..query_len {
                        let threshold = relative_offset + i as isize;
                        for j in 0..key_len {
                            let allowed = (j as isize) <= threshold;
                            data.push(if allowed { 0.0 } else { mask_value });
                        }
                    }
                    data
                })
                .clone()
        });

        Tensor::<B, 2>::from_data(TensorData::new(data, [query_len, key_len]), device)
            .reshape([1, 1, query_len, key_len])
    }

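    /// Tiled attention over contiguous `q`, `k`, `v` of shape
    /// `[batch, num_heads, seq_len, head_dim]`, using the online-softmax recurrence.
    ///
    /// With `determinism.fixed_tile_order` the KV tiles are walked with an explicit index
    /// loop; both paths visit tiles in ascending order.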
    pub fn forward<B: Backend>(
        &self,
        q: Tensor<B, 4>,
        k: Tensor<B, 4>,
        v: Tensor<B, 4>,
        causal: bool,
        position_offset: usize,
    ) -> Tensor<B, 4> {
        let device = q.device();
        let [batch_size, num_heads, query_len, head_dim] = q.dims();
        let key_len = k.dims()[2];

        if query_len == 0 || key_len == 0 {
            return Tensor::zeros([batch_size, num_heads, query_len, head_dim], &device);
        }

        let block_q = self.config.block_q.max(1);
        let block_kv = self.config.block_kv.max(1);
        let inv_scale = 1.0 / (head_dim as f32).sqrt();

        let k_blocks = k.split(block_kv, 2);
        let v_blocks = v.split(block_kv, 2);
        let k_blocks_t: Vec<_> = k_blocks.into_iter().map(|block| block.transpose()).collect();
        let q_blocks = q.split(block_q, 2);

        let mut outputs = Vec::with_capacity(q_blocks.len());
        let fixed_tile_order = self.config.determinism.fixed_tile_order;
        let kv_block_count = k_blocks_t.len();

        let process_q_block = |q_start: usize, q_block: Tensor<B, 4>, outputs: &mut Vec<Tensor<B, 4>>| {
            let q_block_len = q_block.dims()[2];
            let q_block_scaled = q_block * inv_scale;

            let mut m_i = Tensor::<B, 4>::full(
                [batch_size, num_heads, q_block_len, 1],
                f32::NEG_INFINITY,
                &device,
            );
            let mut l_i = Tensor::<B, 4>::zeros([batch_size, num_heads, q_block_len, 1], &device);
            let mut o_i = Tensor::<B, 4>::zeros(
                [batch_size, num_heads, q_block_len, head_dim],
                &device,
            );

            if fixed_tile_order {
                let mut kv_index = 0usize;
                while kv_index < kv_block_count {
                    let k_block_t = &k_blocks_t[kv_index];
                    let v_block = &v_blocks[kv_index];
                    let kv_block_len = k_block_t.dims()[3];
                    let kv_start = kv_index * block_kv;

                    let mut scores = q_block_scaled.clone().matmul(k_block_t.clone());

                    if causal {
                        let mask = self.build_causal_mask_cached::<B>(
                            &device,
                            q_block_len,
                            kv_block_len,
                            q_start,
                            kv_start,
                            position_offset,
                        );
                        scores = scores + mask;
                    }

                    let m_ij = scores.clone().max_dim(3);
                    let m_new = m_i.clone().max_pair(m_ij);

                    let m_scale = (m_i - m_new.clone()).exp();
                    let p_ij = (scores - m_new.clone()).exp();
                    let p_sum = p_ij.clone().sum_dim(3);

                    l_i = m_scale.clone() * l_i + p_sum;
                    o_i = m_scale * o_i + p_ij.matmul(v_block.clone());
                    m_i = m_new;

                    kv_index += 1;
                }
            } else {
                for kv_index in 0..kv_block_count {
                    let k_block_t = &k_blocks_t[kv_index];
                    let v_block = &v_blocks[kv_index];
                    let kv_block_len = k_block_t.dims()[3];
                    let kv_start = kv_index * block_kv;

                    let mut scores = q_block_scaled.clone().matmul(k_block_t.clone());

                    if causal {
                        let mask = self.build_causal_mask_cached::<B>(
                            &device,
                            q_block_len,
                            kv_block_len,
                            q_start,
                            kv_start,
                            position_offset,
                        );
                        scores = scores + mask;
                    }

                    let m_ij = scores.clone().max_dim(3);
                    let m_new = m_i.clone().max_pair(m_ij);

                    let m_scale = (m_i - m_new.clone()).exp();
                    let p_ij = (scores - m_new.clone()).exp();
                    let p_sum = p_ij.clone().sum_dim(3);

                    l_i = m_scale.clone() * l_i + p_sum;
                    o_i = m_scale * o_i + p_ij.matmul(v_block.clone());
                    m_i = m_new;
                }
            }

            outputs.push(o_i / l_i);
        };

        for (q_block_index, q_block) in q_blocks.into_iter().enumerate() {
            let q_start = q_block_index * block_q;
            process_q_block(q_start, q_block, &mut outputs);
        }

        let output = Tensor::cat(outputs, 2);
        B::sync(&output.device());
        output
    }

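    /// Attention over paged KV storage: `kv_blocks` yields per-block
    /// `[num_heads, block_len, head_dim]` key/value tensors, which are collected, transposed,
    /// and then processed per query tile either in log space or with the standard
    /// online-softmax recurrence, depending on `use_log_space`.
    ///
    /// A minimal sketch of the call shape (names are illustrative, mirroring the tests):
    ///
    /// ```ignore
    /// // kv_blocks: Vec<(Tensor<B, 3>, Tensor<B, 3>)>, one (K, V) pair per page
    /// let out = attention.forward_fused_iter(q, kv_blocks.into_iter(), true, 0, total_kv_len);
    /// ```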
    pub fn forward_fused_iter<'a, B, I>(
        &self,
        q: Tensor<B, 4>,
        kv_blocks: I,
        causal: bool,
        position_offset: usize,
        total_kv_len: usize,
    ) -> Tensor<B, 4>
    where
        B: Backend,
        I: Iterator<Item = (Tensor<B, 3>, Tensor<B, 3>)> + 'a,
    {
        let device = q.device();
        let [batch_size, num_heads, query_len, head_dim] = q.dims();

        if query_len == 0 || total_kv_len == 0 {
            return Tensor::zeros([batch_size, num_heads, query_len, head_dim], &device);
        }

        let block_q = self.config.block_q.max(1);
        let inv_scale = 1.0 / (head_dim as f32).sqrt();

        let (kv_lower, _) = kv_blocks.size_hint();
        let mut kv_blocks_vec: Vec<(Tensor<B, 4>, Tensor<B, 4>)> = Vec::with_capacity(kv_lower);

        if self.config.determinism.fixed_tile_order {
            kv_blocks_vec.extend(
                kv_blocks
                    .strict_order()
                    .map(|(_, (k, v))| (k.unsqueeze_dim(0).transpose(), v.unsqueeze_dim(0))),
            );
        } else {
            kv_blocks_vec.extend(
                kv_blocks.map(|(k, v)| (k.unsqueeze_dim(0).transpose(), v.unsqueeze_dim(0))),
            );
        }

        let mut kv_starts = Vec::with_capacity(kv_blocks_vec.len());
        let mut kv_start = 0usize;
        for (k_block_t, _) in &kv_blocks_vec {
            kv_starts.push(kv_start);
            kv_start += k_block_t.dims()[3];
        }

        let q_blocks = q.split(block_q, 2);
        let mut outputs = Vec::with_capacity(q_blocks.len());

        for (q_block_index, q_block) in q_blocks.into_iter().enumerate() {
            let q_start = q_block_index * block_q;
            let q_block_scaled = q_block * inv_scale;

            let output = if self.config.use_log_space {
                self.process_q_block_log_space(
                    q_block_scaled,
                    &kv_blocks_vec,
                    &kv_starts,
                    causal,
                    q_start,
                    position_offset,
                )
            } else {
                self.process_q_block_standard(
                    q_block_scaled,
                    &kv_blocks_vec,
                    &kv_starts,
                    causal,
                    q_start,
                    position_offset,
                )
            };

            outputs.push(output);
        }

        Tensor::cat(outputs, 2)
    }

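    /// Online-softmax accumulation for one query tile: tracks the running row max `m_i`,
    /// normalizer `l_i`, and unnormalized output `o_i`, rescaling previous partial results
    /// whenever a new tile raises the running max.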
    fn process_q_block_standard<B: Backend>(
        &self,
        q_block: Tensor<B, 4>,
        kv_blocks: &[(Tensor<B, 4>, Tensor<B, 4>)],
        kv_starts: &[usize],
        causal: bool,
        q_start: usize,
        position_offset: usize,
    ) -> Tensor<B, 4> {
        let device = q_block.device();
        let [batch_size, num_heads, q_block_len, head_dim] = q_block.dims();

        let mut m_i = Tensor::<B, 4>::full(
            [batch_size, num_heads, q_block_len, 1],
            f32::NEG_INFINITY,
            &device,
        );
        let mut l_i = Tensor::<B, 4>::zeros([batch_size, num_heads, q_block_len, 1], &device);
        let mut o_i = Tensor::<B, 4>::zeros(
            [batch_size, num_heads, q_block_len, head_dim],
            &device,
        );

        for (kv_index, (k_block_t, v_block)) in kv_blocks.iter().enumerate() {
            let kv_block_len = k_block_t.dims()[3];
            let kv_start = kv_starts[kv_index];

            let mut scores = q_block.clone().matmul(k_block_t.clone());

            if causal {
                let mask = self.build_causal_mask_cached::<B>(
                    &device,
                    q_block_len,
                    kv_block_len,
                    q_start,
                    kv_start,
                    position_offset,
                );
                scores = scores + mask;
            }

            let m_ij = scores.clone().max_dim(3);
            let m_new = m_i.clone().max_pair(m_ij);

            let m_scale = (m_i - m_new.clone()).exp();
            let p_ij = (scores - m_new.clone()).exp();
            let p_sum = p_ij.clone().sum_dim(3);

            l_i = m_scale.clone() * l_i + p_sum;
            o_i = m_scale * o_i + p_ij.matmul(v_block.clone());
            m_i = m_new;
        }

        o_i / l_i
    }

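    /// Variant of the online-softmax recurrence that keeps the normalizer in log space
    /// (`log_l_i`) and combines partial sums with log-add-exp, which resists underflow when
    /// many KV tiles are accumulated.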
    fn process_q_block_log_space<B: Backend>(
        &self,
        q_block: Tensor<B, 4>,
        kv_blocks: &[(Tensor<B, 4>, Tensor<B, 4>)],
        kv_starts: &[usize],
        causal: bool,
        q_start: usize,
        position_offset: usize,
    ) -> Tensor<B, 4> {
        let device = q_block.device();
        let [batch_size, num_heads, q_block_len, head_dim] = q_block.dims();

        let mut m_i = Tensor::<B, 4>::full(
            [batch_size, num_heads, q_block_len, 1],
            f32::NEG_INFINITY,
            &device,
        );
        let mut log_l_i = Tensor::<B, 4>::full(
            [batch_size, num_heads, q_block_len, 1],
            f32::NEG_INFINITY,
            &device,
        );
        let mut o_i = Tensor::<B, 4>::zeros(
            [batch_size, num_heads, q_block_len, head_dim],
            &device,
        );

        for (kv_index, (k_block_t, v_block)) in kv_blocks.iter().enumerate() {
            let kv_block_len = k_block_t.dims()[3];
            let kv_start = kv_starts[kv_index];

            let mut scores = q_block.clone().matmul(k_block_t.clone());

            if causal {
                let mask = self.build_causal_mask_cached::<B>(
                    &device,
                    q_block_len,
                    kv_block_len,
                    q_start,
                    kv_start,
                    position_offset,
                );
                scores = scores + mask;
            }

            let m_ij = scores.clone().max_dim(3);
            let m_new = m_i.clone().max_pair(m_ij.clone());

            let scores_shifted = scores - m_ij.clone();
            let p_ij = scores_shifted.exp();
            let sum_p = p_ij.clone().sum_dim(3);
            let log_sum_p = sum_p.log();

            let m_diff = m_i - m_new.clone();
            let m_ij_diff = m_ij - m_new.clone();
            let log_prev = m_diff.clone() + log_l_i;
            let log_curr = m_ij_diff.clone() + log_sum_p;

            let log_l_new = Self::tensor_log_add_exp(log_prev, log_curr);

            let m_scale = m_diff.exp();
            // `p_ij` is normalized against the tile max `m_ij`; rescale it to the running
            // max `m_new` so the output accumulator stays consistent with `log_l_i`.
            o_i = m_scale * o_i + (p_ij * m_ij_diff.exp()).matmul(v_block.clone());

            m_i = m_new;
            log_l_i = log_l_new;
        }

        let l_i = log_l_i.exp();
        o_i / l_i
    }

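    /// Numerically stable element-wise `log(exp(a) + exp(b))`.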
    fn tensor_log_add_exp<B: Backend>(a: Tensor<B, 4>, b: Tensor<B, 4>) -> Tensor<B, 4> {
        let max = a.clone().max_pair(b.clone());
        let diff_a = a - max.clone();
        let diff_b = b - max.clone();
        max + (diff_a.exp() + diff_b.exp()).log()
    }

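    /// Uncached variant of the causal-mask construction, written in terms of absolute
    /// query/key positions.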
    #[allow(dead_code)]
    fn build_causal_mask_uncached<B: Backend>(
        &self,
        device: &B::Device,
        query_len: usize,
        key_len: usize,
        q_start: usize,
        kv_start: usize,
        position_offset: usize,
    ) -> Tensor<B, 4> {
        let mut data = Vec::with_capacity(query_len * key_len);
        let mask_value = -1.0e4_f32;

        for i in 0..query_len {
            let absolute_pos = position_offset + q_start + i;
            for j in 0..key_len {
                let absolute_key = kv_start + j;
                let allowed = absolute_key <= absolute_pos;
                data.push(if allowed { 0.0 } else { mask_value });
            }
        }

        Tensor::<B, 2>::from_data(TensorData::new(data, [query_len, key_len]), device)
            .reshape([1, 1, query_len, key_len])
    }

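    /// Clears the calling thread's mask cache and resets its hit/miss counters.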
    pub fn clear_mask_cache() {
        MASK_CACHE.with(|cache| cache.borrow_mut().clear());
    }
}

impl<B: Backend> FusedPagedAttention<B> for HierarchicalFlashAttention {
    fn forward_fused<'a, I>(
        &self,
        q: Tensor<B, 4>,
        kv_blocks: I,
        config: &FlashAttentionConfig,
        causal: bool,
        position_offset: usize,
    ) -> Tensor<B, 4>
    where
        I: Iterator<Item = (Tensor<B, 3>, Tensor<B, 3>)> + 'a,
    {
        let kv_blocks: Vec<_> = kv_blocks.collect();
        let total_kv_len: usize = kv_blocks.iter().map(|(k, _)| k.dims()[1]).sum();

        let attention = Self::new(config.clone());

        attention.forward_fused_iter(q, kv_blocks.into_iter(), causal, position_offset, total_kv_len)
    }
}

#[cfg(all(test, feature = "cpu"))]
mod tests {
    use super::*;
    use burn::tensor::activation::softmax;
    use burn_ndarray::NdArray;

    type TestBackend = NdArray<f32>;

    #[test]
    fn test_hierarchical_flash_basic() {
        let device = <TestBackend as Backend>::Device::default();
        let attention = HierarchicalFlashAttention::default_config();

        let batch_size = 1;
        let num_heads = 2;
        let seq_len = 16;
        let head_dim = 8;

        let q = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device,
        );
        let k = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device,
        );
        let v = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device,
        );

        let output = attention.forward(q, k, v, false, 0);
        assert_eq!(output.dims(), [batch_size, num_heads, seq_len, head_dim]);
    }

    #[test]
    fn test_hierarchical_flash_matches_standard() {
        let device = <TestBackend as Backend>::Device::default();
        let attention = HierarchicalFlashAttention::new(FlashAttentionConfig {
            block_q: 4,
            block_kv: 4,
            use_log_space: false,
            ..Default::default()
        });

        let batch_size = 1;
        let num_heads = 2;
        let seq_len = 8;
        let head_dim = 4;

        let q = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 0.5),
            &device,
        );
        let k = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 0.5),
            &device,
        );
        let v = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 0.5),
            &device,
        );

        let output_hier = attention.forward(q.clone(), k.clone(), v.clone(), false, 0);

        let scale = (head_dim as f32).sqrt();
        let scores = q.matmul(k.transpose()) / scale;
        let attn = softmax(scores, 3);
        let output_std = attn.matmul(v);

        let hier_data = output_hier
            .into_data()
            .into_vec::<f32>()
            .expect("output data");
        let std_data = output_std
            .into_data()
            .into_vec::<f32>()
            .expect("output data");

        for (i, (h, s)) in hier_data.iter().zip(std_data.iter()).enumerate() {
            let diff = (h - s).abs();
            assert!(
                diff < 1e-3,
                "Mismatch at {}: hier={}, std={}, diff={}",
                i,
                h,
                s,
                diff
            );
        }
    }

    #[test]
    fn test_fused_iter_matches_standard() {
        let device = <TestBackend as Backend>::Device::default();
        let attention = HierarchicalFlashAttention::new(FlashAttentionConfig {
            block_q: 4,
            block_kv: 4,
            use_log_space: false,
            ..Default::default()
        });

        let num_heads = 2;
        let seq_len = 16;
        let head_dim = 4;
        let block_size = 4;

        let q = Tensor::<TestBackend, 4>::random(
            [1, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 0.5),
            &device,
        );

        let num_blocks = seq_len / block_size;
        let kv_blocks: Vec<_> = (0..num_blocks)
            .map(|_| {
                let k = Tensor::<TestBackend, 3>::random(
                    [num_heads, block_size, head_dim],
                    burn::tensor::Distribution::Normal(0.0, 0.5),
                    &device,
                );
                let v = Tensor::<TestBackend, 3>::random(
                    [num_heads, block_size, head_dim],
                    burn::tensor::Distribution::Normal(0.0, 0.5),
                    &device,
                );
                (k, v)
            })
            .collect();

        let output_fused = attention.forward_fused_iter(
            q.clone(),
            kv_blocks.clone().into_iter(),
            false,
            0,
            seq_len,
        );

        let k_cat: Vec<_> = kv_blocks.iter().map(|(k, _)| k.clone()).collect();
        let v_cat: Vec<_> = kv_blocks.iter().map(|(_, v)| v.clone()).collect();

        let k_full = Tensor::cat(k_cat, 1).reshape([1, num_heads, seq_len, head_dim]);
        let v_full = Tensor::cat(v_cat, 1).reshape([1, num_heads, seq_len, head_dim]);

        let output_std = attention.forward(q, k_full, v_full, false, 0);

        let fused_data = output_fused
            .into_data()
            .into_vec::<f32>()
            .expect("output data");
        let std_data = output_std
            .into_data()
            .into_vec::<f32>()
            .expect("output data");

        for (i, (f, s)) in fused_data.iter().zip(std_data.iter()).enumerate() {
            let diff = (f - s).abs();
            assert!(
                diff < 1e-3,
                "Mismatch at {}: fused={}, std={}, diff={}",
                i,
                f,
                s,
                diff
            );
        }
    }

    #[test]
    fn test_causal_mask() {
        let device = <TestBackend as Backend>::Device::default();
        let attention = HierarchicalFlashAttention::default_config();

        let mask = attention.build_causal_mask_cached::<TestBackend>(&device, 4, 4, 0, 0, 0);

        let data = mask.into_data().into_vec::<f32>().expect("mask data");

        assert!(data[0].abs() < 1e-5);
        assert!(data[1] < -1000.0);
        assert!(data[4].abs() < 1e-5);
        assert!(data[5].abs() < 1e-5);
        assert!(data[6] < -1000.0);
    }

    #[test]
    fn test_forward_with_workspace() {
        let device = <TestBackend as Backend>::Device::default();
        let attention = HierarchicalFlashAttention::new(FlashAttentionConfig {
            block_q: 4,
            block_kv: 4,
            use_log_space: false,
            ..Default::default()
        });

        let batch_size = 1;
        let num_heads = 2;
        let seq_len = 8;
        let head_dim = 4;

        let q = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 0.5),
            &device,
        );
        let k = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 0.5),
            &device,
        );
        let v = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 0.5),
            &device,
        );

        let mut workspace = AttentionWorkspace::new();
        let output_workspace = attention.forward_with_workspace(
            q.clone(),
            k.clone(),
            v.clone(),
            false,
            0,
            &mut workspace,
        );

        let output_std = attention.forward(q, k, v, false, 0);

        let ws_data = output_workspace
            .into_data()
            .into_vec::<f32>()
            .expect("output data");
        let std_data = output_std
            .into_data()
            .into_vec::<f32>()
            .expect("output data");

        for (i, (w, s)) in ws_data.iter().zip(std_data.iter()).enumerate() {
            let diff = (w - s).abs();
            assert!(
                diff < 1e-3,
                "Mismatch at {}: workspace={}, std={}, diff={}",
                i,
                w,
                s,
                diff
            );
        }
    }

    #[test]
    fn test_mask_cache_hit() {
        let device = <TestBackend as Backend>::Device::default();
        let attention = HierarchicalFlashAttention::default_config();

        HierarchicalFlashAttention::clear_mask_cache();

        let mask1 = attention.build_causal_mask_cached::<TestBackend>(&device, 4, 4, 0, 0, 0);
        let mask2 = attention.build_causal_mask_cached::<TestBackend>(&device, 4, 4, 0, 0, 0);

        let data1 = mask1.into_data().into_vec::<f32>().expect("mask data");
        let data2 = mask2.into_data().into_vec::<f32>().expect("mask data");

        assert_eq!(data1, data2);
    }
}