// gllm_kernels/ops/flash_attention.rs

//! Hierarchical FlashAttention with direct paged KV access.
//!
//! ## Performance Optimizations
//!
//! This module includes several optimizations for ultra-long contexts:
//! - **Pre-allocated workspace buffers** to avoid repeated allocations
//! - **Causal mask caching** with relative offset keys for high hit rates
//! - **Reduced cloning** through careful use of references and slicing
//! - **Batched operations** to minimize intermediate tensor creation
//!
//! ## Mask Cache Strategy
//!
//! Causal masks follow a pattern where the mask value at (i, j) depends only on:
//! - `relative_offset = (q_start + position_offset) - kv_start`
//! - Position within the block (i, j)
//!
//! By using `relative_offset` as the cache key instead of absolute positions,
//! we achieve much higher cache hit rates for long sequences.
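//!
//! As an illustration (block positions chosen arbitrarily), two different block pairs
//! reduce to the same relative offset and therefore share one cached mask, provided
//! their block shapes match:
//!
//! ```text
//! q_start = 128, kv_start =  64, position_offset = 0  =>  relative_offset = 64
//! q_start = 192, kv_start = 128, position_offset = 0  =>  relative_offset = 64
//! ```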

use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};

use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, TensorData};

use crate::ops::stable_accumulator::AccumulatorConfig;

/// Default mask cache capacity (covers most practical scenarios).
/// For seq_len 4096 with block_q=64, block_kv=16: ~4096 unique relative offsets.
const DEFAULT_MASK_CACHE_CAPACITY: usize = 8192;

/// Global configuration for mask cache capacity.
static MASK_CACHE_CAPACITY: AtomicUsize = AtomicUsize::new(DEFAULT_MASK_CACHE_CAPACITY);

/// Cache key for causal masks using relative offset strategy.
///
/// The key insight is that causal masks depend on the *relative* position
/// between query and key, not their absolute positions. This dramatically
/// reduces the number of unique masks needed.
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
struct MaskCacheKey {
    /// Query block length.
    query_len: usize,
    /// Key block length.
    key_len: usize,
    /// Relative offset: (q_start + position_offset) - kv_start.
    /// This determines the causal boundary within the block.
    relative_offset: isize,
}

impl MaskCacheKey {
    /// Create a cache key from absolute positions.
    fn from_positions(
        query_len: usize,
        key_len: usize,
        q_start: usize,
        kv_start: usize,
        position_offset: usize,
    ) -> Self {
        // The causal condition is: kv_pos <= q_pos + position_offset
        // Rewritten: kv_start + j <= q_start + position_offset + i
        // Which is: j <= (q_start + position_offset - kv_start) + i
        // So the relative_offset determines where the causal boundary starts
        let relative_offset = (q_start + position_offset) as isize - kv_start as isize;
        Self {
            query_len,
            key_len,
            relative_offset,
        }
    }
}

// Thread-local cache for causal masks to avoid repeated allocations.
thread_local! {
    static MASK_CACHE: RefCell<MaskCache> = RefCell::new(
        MaskCache::new(MASK_CACHE_CAPACITY.load(Ordering::Relaxed))
    );
}

/// LRU-style mask cache with bounded capacity and cache statistics.
struct MaskCache {
    /// Cached mask data (query_len * key_len f32 values).
    cache: HashMap<MaskCacheKey, Vec<f32>>,
    /// Access order for LRU eviction (using VecDeque would be more efficient).
    access_order: Vec<MaskCacheKey>,
    /// Maximum number of cached masks.
    capacity: usize,
    /// Cache hit count (for diagnostics).
    hits: usize,
    /// Cache miss count (for diagnostics).
    misses: usize,
}

impl MaskCache {
    fn new(capacity: usize) -> Self {
        Self {
            cache: HashMap::with_capacity(capacity.min(1024)), // Don't over-allocate initially
            access_order: Vec::with_capacity(capacity.min(1024)),
            capacity,
            hits: 0,
            misses: 0,
        }
    }

    fn get_or_create(
        &mut self,
        key: MaskCacheKey,
        create_fn: impl FnOnce() -> Vec<f32>,
    ) -> &Vec<f32> {
        if self.cache.contains_key(&key) {
            self.hits += 1;
            // Move to end (most recently used) - only if not already at end
            if self.access_order.last() != Some(&key) {
                if let Some(pos) = self.access_order.iter().position(|k| k == &key) {
                    self.access_order.remove(pos);
                    self.access_order.push(key.clone());
                }
            }
        } else {
            self.misses += 1;
            // Evict oldest entries if at capacity
            while self.cache.len() >= self.capacity && !self.access_order.is_empty() {
                let oldest = self.access_order.remove(0);
                self.cache.remove(&oldest);
            }
            self.cache.insert(key.clone(), create_fn());
            self.access_order.push(key.clone());
        }
        self.cache.get(&key).unwrap()
    }

    /// Get cache statistics (hits, misses, hit_rate).
    #[allow(dead_code)]
    fn stats(&self) -> (usize, usize, f64) {
        let total = self.hits + self.misses;
        let hit_rate = if total > 0 {
            self.hits as f64 / total as f64
        } else {
            0.0
        };
        (self.hits, self.misses, hit_rate)
    }

    /// Clear the cache and reset statistics.
    fn clear(&mut self) {
        self.cache.clear();
        self.access_order.clear();
        self.hits = 0;
        self.misses = 0;
    }

    /// Resize the cache capacity.
    #[allow(dead_code)]
    fn resize(&mut self, new_capacity: usize) {
        self.capacity = new_capacity;
        // Evict excess entries if shrinking
        while self.cache.len() > self.capacity && !self.access_order.is_empty() {
            let oldest = self.access_order.remove(0);
            self.cache.remove(&oldest);
        }
    }

    /// Current number of cached entries.
    #[allow(dead_code)]
    fn len(&self) -> usize {
        self.cache.len()
    }
}

/// Configuration for mask caching behavior.
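///
/// The presets below are sized by expected context length; a minimal sketch of
/// picking one (the `expected_len` variable is hypothetical):
///
/// ```ignore
/// let mask_cfg = if expected_len <= 2_048 {
///     MaskCacheConfig::short_context()
/// } else if expected_len <= 8_192 {
///     MaskCacheConfig::medium_context()
/// } else if expected_len <= 32_768 {
///     MaskCacheConfig::long_context()
/// } else {
///     MaskCacheConfig::ultra_long_context()
/// };
/// ```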
#[derive(Clone, Debug)]
pub struct MaskCacheConfig {
    /// Maximum number of masks to cache per thread.
    pub capacity: usize,
    /// Enable cache statistics tracking.
    pub track_stats: bool,
}

impl Default for MaskCacheConfig {
    fn default() -> Self {
        Self {
            capacity: DEFAULT_MASK_CACHE_CAPACITY,
            track_stats: false,
        }
    }
}

impl MaskCacheConfig {
    /// Configuration optimized for short sequences (< 2K).
    pub fn short_context() -> Self {
        Self {
            capacity: 1024,
            track_stats: false,
        }
    }

    /// Configuration optimized for medium sequences (2K - 8K).
    pub fn medium_context() -> Self {
        Self {
            capacity: 8192,
            track_stats: false,
        }
    }

    /// Configuration optimized for long sequences (8K - 32K).
    pub fn long_context() -> Self {
        Self {
            capacity: 32768,
            track_stats: false,
        }
    }

    /// Configuration optimized for ultra-long sequences (32K+).
    /// Uses more memory but ensures high cache hit rates.
    pub fn ultra_long_context() -> Self {
        Self {
            capacity: 131072, // 128K entries, ~512MB for 64x16 masks
            track_stats: false,
        }
    }
}

/// Configuration for deterministic computation.
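///
/// A minimal sketch of the two extremes provided below:
///
/// ```ignore
/// let strict = DeterministicConfig::strict();
/// assert!(strict.is_deterministic());
///
/// let relaxed = DeterministicConfig::relaxed();
/// assert!(!relaxed.is_deterministic());
/// ```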
#[derive(Clone, Debug)]
pub struct DeterministicConfig {
    /// Enable deterministic mode.
    pub enabled: bool,
    /// Force fixed tile processing order for reproducibility.
    pub fixed_tile_order: bool,
    /// Fixed random seed for reproducibility.
    pub seed: Option<u64>,
    /// Disable GPU non-deterministic operations (use deterministic kernels).
    pub no_gpu_nondeterminism: bool,
    /// Enable verification of determinism (compare results of multiple runs).
    pub verify_determinism: bool,
}

impl Default for DeterministicConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            fixed_tile_order: false,
            seed: None,
            no_gpu_nondeterminism: false,
            verify_determinism: false,
        }
    }
}

impl DeterministicConfig {
    /// Create a configuration for maximum reproducibility.
    pub fn strict() -> Self {
        Self {
            enabled: true,
            fixed_tile_order: true,
            seed: Some(42),
            no_gpu_nondeterminism: true,
            verify_determinism: cfg!(debug_assertions),
        }
    }

    /// Create a configuration that allows some non-determinism for speed.
    pub fn relaxed() -> Self {
        Self {
            enabled: false,
            fixed_tile_order: false,
            seed: None,
            no_gpu_nondeterminism: false,
            verify_determinism: false,
        }
    }

    /// Create a configuration for 2M context (strict by default).
    pub fn ultra_long_context() -> Self {
        Self::strict()
    }

    /// Check if any deterministic guarantees are enabled.
    pub fn is_deterministic(&self) -> bool {
        self.enabled || self.fixed_tile_order || self.seed.is_some()
    }
}

/// Strict ordering iterator for deterministic processing.
pub struct StrictOrderIterator<I> {
    inner: I,
    index: usize,
}

impl<I: Iterator> StrictOrderIterator<I> {
    pub fn new(iter: I) -> Self {
        Self { inner: iter, index: 0 }
    }

    /// Get the current index (for verification).
    pub fn current_index(&self) -> usize {
        self.index
    }
}

impl<I: Iterator> Iterator for StrictOrderIterator<I> {
    type Item = (usize, I::Item);

    fn next(&mut self) -> Option<Self::Item> {
        let item = self.inner.next()?;
        let index = self.index;
        self.index += 1;

        std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst);

        Some((index, item))
    }
}

/// Extension trait for creating strict order iterators.
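///
/// A small illustration of the wrapper (it behaves like `enumerate`, with a
/// sequentially consistent fence between items):
///
/// ```ignore
/// let blocks = vec!["b0", "b1", "b2"];
/// for (index, block) in blocks.into_iter().strict_order() {
///     // Blocks are visited strictly in order: (0, "b0"), (1, "b1"), (2, "b2").
///     let _ = (index, block);
/// }
/// ```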
pub trait StrictOrderExt: Iterator + Sized {
    fn strict_order(self) -> StrictOrderIterator<Self> {
        StrictOrderIterator::new(self)
    }
}

impl<I: Iterator> StrictOrderExt for I {}

/// Configuration for hierarchical FlashAttention.
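///
/// A minimal sketch of overriding individual fields while keeping the rest of a
/// preset (the block sizes here are illustrative, not recommendations):
///
/// ```ignore
/// let config = FlashAttentionConfig {
///     block_q: 128,
///     block_kv: 32,
///     ..FlashAttentionConfig::ultra_long_context()
/// };
/// let attention = HierarchicalFlashAttention::new(config);
/// ```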
#[derive(Debug, Clone)]
pub struct FlashAttentionConfig {
    /// Query block size for tiling.
    pub block_q: usize,
    /// KV block size for tiling (should match PagedKVCache block size).
    pub block_kv: usize,
    /// Accumulator configuration for numerical stability.
    pub accumulator: AccumulatorConfig,
    /// Determinism configuration.
    pub determinism: DeterministicConfig,
    /// Use log-space accumulation (more stable but slightly slower).
    pub use_log_space: bool,
    /// Maximum sequence length to expect (for pre-allocation).
    pub max_seq_len: usize,
}

impl Default for FlashAttentionConfig {
    fn default() -> Self {
        Self {
            block_q: 64,
            block_kv: 16,
            accumulator: AccumulatorConfig::max_precision(),
            determinism: DeterministicConfig::strict(),
            use_log_space: true,
            max_seq_len: 2_000_000,
        }
    }
}

impl FlashAttentionConfig {
    /// Configuration optimized for 2M context.
    pub fn ultra_long_context() -> Self {
        Self {
            block_q: 64,
            block_kv: 16,
            accumulator: AccumulatorConfig::max_precision(),
            determinism: DeterministicConfig::ultra_long_context(),
            use_log_space: true,
            max_seq_len: 2_000_000,
        }
    }

    /// Configuration for shorter contexts (< 100K).
    pub fn short_context() -> Self {
        Self {
            block_q: 128,
            block_kv: 64,
            accumulator: AccumulatorConfig::short_context(),
            determinism: DeterministicConfig::relaxed(),
            use_log_space: false,
            max_seq_len: 100_000,
        }
    }
}

/// Backward-compatible alias.
pub type HierarchicalFlashConfig = FlashAttentionConfig;

/// Trait for fused paged attention computation.
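///
/// A sketch of the intended call shape. The `cache.iter_kv_blocks()` source is
/// hypothetical; any iterator yielding per-block `(K, V)` tensors shaped
/// `[heads, block_len, head_dim]` works, and tensor construction is elided:
///
/// ```ignore
/// let config = FlashAttentionConfig::ultra_long_context();
/// let attention = HierarchicalFlashAttention::new(config.clone());
/// let output = attention.forward_fused(q, cache.iter_kv_blocks(), &config, true, 0);
/// ```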
pub trait FusedPagedAttention<B: Backend> {
    /// Compute attention with direct access to paged KV blocks.
    fn forward_fused<'a, I>(
        &self,
        q: Tensor<B, 4>,
        kv_blocks: I,
        config: &FlashAttentionConfig,
        causal: bool,
        position_offset: usize,
    ) -> Tensor<B, 4>
    where
        I: Iterator<Item = (Tensor<B, 3>, Tensor<B, 3>)> + 'a;
}

/// Pre-allocated workspace for attention computation.
///
/// This structure holds temporary buffers that are reused across iterations
/// to avoid repeated memory allocations in the hot path.
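///
/// A minimal usage sketch (tensor construction elided):
///
/// ```ignore
/// let attention = HierarchicalFlashAttention::default_config();
/// let mut workspace = AttentionWorkspace::new();
/// // Reuse the same workspace across many calls to amortize allocations.
/// let out = attention.forward_with_workspace(q, k, v, true, 0, &mut workspace);
/// ```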
pub struct AttentionWorkspace<B: Backend> {
    /// Running maximum values [batch, heads, q_block_len, 1].
    pub m_buffer: Option<Tensor<B, 4>>,
    /// Running sum values [batch, heads, q_block_len, 1].
    pub l_buffer: Option<Tensor<B, 4>>,
    /// Output accumulator [batch, heads, q_block_len, head_dim].
    pub o_buffer: Option<Tensor<B, 4>>,
    /// Cached dimensions for validation.
    dims: Option<(usize, usize, usize, usize)>,
}

impl<B: Backend> Default for AttentionWorkspace<B> {
    fn default() -> Self {
        Self::new()
    }
}

impl<B: Backend> AttentionWorkspace<B> {
    /// Create an empty workspace.
    pub fn new() -> Self {
        Self {
            m_buffer: None,
            l_buffer: None,
            o_buffer: None,
            dims: None,
        }
    }

    /// Pre-allocate buffers for given dimensions.
    pub fn allocate(
        &mut self,
        device: &B::Device,
        batch_size: usize,
        num_heads: usize,
        q_block_len: usize,
        head_dim: usize,
    ) {
        let needs_realloc = self.dims.map_or(true, |(b, h, q, d)| {
            b != batch_size || h != num_heads || q < q_block_len || d != head_dim
        });

        if needs_realloc {
            self.m_buffer = Some(Tensor::zeros(
                [batch_size, num_heads, q_block_len, 1],
                device,
            ));
            self.l_buffer = Some(Tensor::zeros(
                [batch_size, num_heads, q_block_len, 1],
                device,
            ));
            self.o_buffer = Some(Tensor::zeros(
                [batch_size, num_heads, q_block_len, head_dim],
                device,
            ));
            self.dims = Some((batch_size, num_heads, q_block_len, head_dim));
        }
    }

    /// Reset buffers to initial values for a new Q block.
    pub fn reset(&mut self, device: &B::Device) {
        if let Some((batch_size, num_heads, q_block_len, _)) = self.dims {
            self.m_buffer = Some(Tensor::full(
                [batch_size, num_heads, q_block_len, 1],
                f32::NEG_INFINITY,
                device,
            ));
            if let Some(ref mut l) = self.l_buffer {
                *l = l.clone().zeros_like();
            }
            if let Some(ref mut o) = self.o_buffer {
                *o = o.clone().zeros_like();
            }
        }
    }

    /// Take ownership of output buffer.
    pub fn take_output(&mut self) -> Option<Tensor<B, 4>> {
        self.o_buffer.take()
    }
}

/// Hierarchical FlashAttention implementation.
#[derive(Debug, Clone)]
pub struct HierarchicalFlashAttention {
    config: FlashAttentionConfig,
}

impl HierarchicalFlashAttention {
    /// Create a new HierarchicalFlashAttention with the given configuration.
    pub fn new(config: FlashAttentionConfig) -> Self {
        Self { config }
    }

    /// Create with default configuration.
    pub fn default_config() -> Self {
        Self::new(FlashAttentionConfig::default())
    }

    /// Create optimized for 2M context.
    pub fn ultra_long_context() -> Self {
        Self::new(FlashAttentionConfig::ultra_long_context())
    }

    /// Get the configuration.
    pub fn config(&self) -> &FlashAttentionConfig {
        &self.config
    }

    /// Optimized forward pass with a pre-allocated workspace.
    ///
    /// The workspace is sized once for the largest Q block and kept between calls, so
    /// repeated invocations avoid re-allocating those buffers. For best performance,
    /// create a workspace once and reuse it across calls.
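    ///
    /// Each Q block is combined with successive KV blocks via the online-softmax
    /// recurrence (per query row, with `s_ij` the scaled, masked scores of the
    /// current KV block and `P = exp(s_ij - m_new)`):
    ///
    /// ```text
    /// m_new = max(m_i, max_j s_ij)
    /// l_new = exp(m_i - m_new) * l_i + sum_j P
    /// o_new = exp(m_i - m_new) * o_i + P @ V_block
    /// ```
    ///
    /// After the last KV block the block output is `o / l`.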
    pub fn forward_with_workspace<B: Backend>(
        &self,
        q: Tensor<B, 4>,
        k: Tensor<B, 4>,
        v: Tensor<B, 4>,
        causal: bool,
        position_offset: usize,
        workspace: &mut AttentionWorkspace<B>,
    ) -> Tensor<B, 4> {
        let device = q.device();
        let [batch_size, num_heads, query_len, head_dim] = q.dims();
        let key_len = k.dims()[2];

        if query_len == 0 || key_len == 0 {
            return Tensor::zeros([batch_size, num_heads, query_len, head_dim], &device);
        }

        let block_q = self.config.block_q.max(1);
        let block_kv = self.config.block_kv.max(1);
        let inv_scale = 1.0 / (head_dim as f32).sqrt();

        // Pre-allocate workspace for maximum block size
        workspace.allocate(&device, batch_size, num_heads, block_q, head_dim);

        let q_blocks = q.split(block_q, 2);
        let k_blocks = k.split(block_kv, 2);
        let v_blocks = v.split(block_kv, 2);
        let k_blocks_t: Vec<_> = k_blocks.into_iter().map(|block| block.transpose()).collect();
        let mut outputs = Vec::with_capacity(q_blocks.len());

        for (q_block_index, q_block) in q_blocks.into_iter().enumerate() {
            let q_block_len = q_block.dims()[2];
            let q_start = q_block_index * block_q;
            let q_block_scaled = q_block * inv_scale;

            // Initialize accumulators
            let mut m_i = Tensor::<B, 4>::full(
                [batch_size, num_heads, q_block_len, 1],
                f32::NEG_INFINITY,
                &device,
            );
            let mut l_i = Tensor::<B, 4>::zeros([batch_size, num_heads, q_block_len, 1], &device);
            let mut o_i = Tensor::<B, 4>::zeros(
                [batch_size, num_heads, q_block_len, head_dim],
                &device,
            );

            for (kv_index, (k_block_t, v_block)) in
                k_blocks_t.iter().zip(v_blocks.iter()).enumerate()
            {
                let kv_block_len = k_block_t.dims()[3];
                let kv_start = kv_index * block_kv;

                // Compute attention scores
                let mut scores = q_block_scaled.clone().matmul(k_block_t.clone());

                // Apply causal mask with caching
                if causal {
                    let mask = self.build_causal_mask_cached::<B>(
                        &device,
                        q_block_len,
                        kv_block_len,
                        q_start,
                        kv_start,
                        position_offset,
                    );
                    scores = scores + mask;
                }

                // Online softmax update (fused operations)
                let m_ij = scores.clone().max_dim(3);
                let m_new = m_i.clone().max_pair(m_ij);

                let m_scale = (m_i - m_new.clone()).exp();
                let p_ij = (scores - m_new.clone()).exp();
                let p_sum = p_ij.clone().sum_dim(3);

                // Update accumulators
                l_i = m_scale.clone() * l_i + p_sum;
                o_i = m_scale * o_i + p_ij.matmul(v_block.clone());
                m_i = m_new;
            }

            outputs.push(o_i / l_i);
        }

        Tensor::cat(outputs, 2)
    }

    /// Build a causal mask with thread-local caching using the relative offset strategy.
    ///
    /// The mask pattern depends only on `relative_offset = (q_start + position_offset) - kv_start`,
    /// not on absolute positions. This dramatically increases cache hit rates for long sequences.
    ///
    /// For example, with block_q=64, block_kv=16, seq_len=8192:
    /// - Absolute-position keys: 65536 unique masks (128 q_blocks × 512 kv_blocks), and more
    ///   once position offsets vary
    /// - Relative-offset keys: ~8192 unique masks (one per distinct relative offset)
    fn build_causal_mask_cached<B: Backend>(
        &self,
        device: &B::Device,
        query_len: usize,
        key_len: usize,
        q_start: usize,
        kv_start: usize,
        position_offset: usize,
    ) -> Tensor<B, 4> {
        // Use relative offset as key for better cache hit rate
        let key = MaskCacheKey::from_positions(
            query_len,
            key_len,
            q_start,
            kv_start,
            position_offset,
        );
        let relative_offset = key.relative_offset;

        let mask_value = -1.0e4_f32;

        let data = MASK_CACHE.with(|cache| {
            let mut cache = cache.borrow_mut();
            cache
                .get_or_create(key, || {
                    let mut data = Vec::with_capacity(query_len * key_len);
                    // The causal condition: kv_pos <= q_pos
                    // With relative_offset = q_start + position_offset - kv_start:
                    // j <= relative_offset + i
                    for i in 0..query_len {
                        let threshold = relative_offset + i as isize;
                        for j in 0..key_len {
                            let allowed = (j as isize) <= threshold;
                            data.push(if allowed { 0.0 } else { mask_value });
                        }
                    }
                    data
                })
                .clone()
        });

        Tensor::<B, 2>::from_data(TensorData::new(data, [query_len, key_len]), device)
            .reshape([1, 1, query_len, key_len])
    }

    /// Standard FlashAttention forward pass (non-fused, for reference/testing).
    pub fn forward<B: Backend>(
        &self,
        q: Tensor<B, 4>,
        k: Tensor<B, 4>,
        v: Tensor<B, 4>,
        causal: bool,
        position_offset: usize,
    ) -> Tensor<B, 4> {
        let device = q.device();
        let [batch_size, num_heads, query_len, head_dim] = q.dims();
        let key_len = k.dims()[2];

        if query_len == 0 || key_len == 0 {
            return Tensor::zeros([batch_size, num_heads, query_len, head_dim], &device);
        }

        let block_q = self.config.block_q.max(1);
        let block_kv = self.config.block_kv.max(1);
        let inv_scale = 1.0 / (head_dim as f32).sqrt();

        let k_blocks = k.split(block_kv, 2);
        let v_blocks = v.split(block_kv, 2);
        let k_blocks_t: Vec<_> = k_blocks.into_iter().map(|block| block.transpose()).collect();
        let q_blocks = q.split(block_q, 2);

        let mut outputs = Vec::with_capacity(q_blocks.len());
        let fixed_tile_order = self.config.determinism.fixed_tile_order;
        let kv_block_count = k_blocks_t.len();

        let process_q_block = |q_start: usize, q_block: Tensor<B, 4>, outputs: &mut Vec<Tensor<B, 4>>| {
            let q_block_len = q_block.dims()[2];
            let q_block_scaled = q_block * inv_scale;

            let mut m_i = Tensor::<B, 4>::full(
                [batch_size, num_heads, q_block_len, 1],
                f32::NEG_INFINITY,
                &device,
            );
            let mut l_i = Tensor::<B, 4>::zeros([batch_size, num_heads, q_block_len, 1], &device);
            let mut o_i = Tensor::<B, 4>::zeros(
                [batch_size, num_heads, q_block_len, head_dim],
                &device,
            );

            if fixed_tile_order {
                let mut kv_index = 0usize;
                while kv_index < kv_block_count {
                    let k_block_t = &k_blocks_t[kv_index];
                    let v_block = &v_blocks[kv_index];
                    let kv_block_len = k_block_t.dims()[3];
                    let kv_start = kv_index * block_kv;

                    let mut scores = q_block_scaled.clone().matmul(k_block_t.clone());

                    if causal {
                        let mask = self.build_causal_mask_cached::<B>(
                            &device,
                            q_block_len,
                            kv_block_len,
                            q_start,
                            kv_start,
                            position_offset,
                        );
                        scores = scores + mask;
                    }

                    let m_ij = scores.clone().max_dim(3);
                    let m_new = m_i.clone().max_pair(m_ij);

                    let m_scale = (m_i - m_new.clone()).exp();
                    let p_ij = (scores - m_new.clone()).exp();
                    let p_sum = p_ij.clone().sum_dim(3);

                    l_i = m_scale.clone() * l_i + p_sum;
                    o_i = m_scale * o_i + p_ij.matmul(v_block.clone());
                    m_i = m_new;

                    kv_index += 1;
                }
            } else {
                for kv_index in 0..kv_block_count {
                    let k_block_t = &k_blocks_t[kv_index];
                    let v_block = &v_blocks[kv_index];
                    let kv_block_len = k_block_t.dims()[3];
                    let kv_start = kv_index * block_kv;

                    let mut scores = q_block_scaled.clone().matmul(k_block_t.clone());

                    if causal {
                        let mask = self.build_causal_mask_cached::<B>(
                            &device,
                            q_block_len,
                            kv_block_len,
                            q_start,
                            kv_start,
                            position_offset,
                        );
                        scores = scores + mask;
                    }

                    let m_ij = scores.clone().max_dim(3);
                    let m_new = m_i.clone().max_pair(m_ij);

                    let m_scale = (m_i - m_new.clone()).exp();
                    let p_ij = (scores - m_new.clone()).exp();
                    let p_sum = p_ij.clone().sum_dim(3);

                    l_i = m_scale.clone() * l_i + p_sum;
                    o_i = m_scale * o_i + p_ij.matmul(v_block.clone());
                    m_i = m_new;
                }
            }

            outputs.push(o_i / l_i);
        };

        for (q_block_index, q_block) in q_blocks.into_iter().enumerate() {
            let q_start = q_block_index * block_q;
            process_q_block(q_start, q_block, &mut outputs);
        }

        let output = Tensor::cat(outputs, 2);
        B::sync(&output.device());
        output
    }

    /// Fused forward pass that directly iterates over KV blocks.
    pub fn forward_fused_iter<'a, B, I>(
        &self,
        q: Tensor<B, 4>,
        kv_blocks: I,
        causal: bool,
        position_offset: usize,
        total_kv_len: usize,
    ) -> Tensor<B, 4>
    where
        B: Backend,
        I: Iterator<Item = (Tensor<B, 3>, Tensor<B, 3>)> + 'a,
    {
        let device = q.device();
        let [batch_size, num_heads, query_len, head_dim] = q.dims();

        if query_len == 0 || total_kv_len == 0 {
            return Tensor::zeros([batch_size, num_heads, query_len, head_dim], &device);
        }

        let block_q = self.config.block_q.max(1);
        let inv_scale = 1.0 / (head_dim as f32).sqrt();

        let (kv_lower, _) = kv_blocks.size_hint();
        let mut kv_blocks_vec: Vec<(Tensor<B, 4>, Tensor<B, 4>)> = Vec::with_capacity(kv_lower);

        if self.config.determinism.fixed_tile_order {
            kv_blocks_vec.extend(
                kv_blocks
                    .strict_order()
                    .map(|(_, (k, v))| (k.unsqueeze_dim(0).transpose(), v.unsqueeze_dim(0))),
            );
        } else {
            kv_blocks_vec.extend(
                kv_blocks.map(|(k, v)| (k.unsqueeze_dim(0).transpose(), v.unsqueeze_dim(0))),
            );
        }

        let mut kv_starts = Vec::with_capacity(kv_blocks_vec.len());
        let mut kv_start = 0usize;
        for (k_block_t, _) in &kv_blocks_vec {
            kv_starts.push(kv_start);
            kv_start += k_block_t.dims()[3];
        }

        let q_blocks = q.split(block_q, 2);
        let mut outputs = Vec::with_capacity(q_blocks.len());

        for (q_block_index, q_block) in q_blocks.into_iter().enumerate() {
            let q_start = q_block_index * block_q;
            let q_block_scaled = q_block * inv_scale;

            let output = if self.config.use_log_space {
                self.process_q_block_log_space(
                    q_block_scaled,
                    &kv_blocks_vec,
                    &kv_starts,
                    causal,
                    q_start,
                    position_offset,
                )
            } else {
                self.process_q_block_standard(
                    q_block_scaled,
                    &kv_blocks_vec,
                    &kv_starts,
                    causal,
                    q_start,
                    position_offset,
                )
            };

            outputs.push(output);
        }

        Tensor::cat(outputs, 2)
    }

    fn process_q_block_standard<B: Backend>(
        &self,
        q_block: Tensor<B, 4>,
        kv_blocks: &[(Tensor<B, 4>, Tensor<B, 4>)],
        kv_starts: &[usize],
        causal: bool,
        q_start: usize,
        position_offset: usize,
    ) -> Tensor<B, 4> {
        let device = q_block.device();
        let [batch_size, num_heads, q_block_len, head_dim] = q_block.dims();

        let mut m_i = Tensor::<B, 4>::full(
            [batch_size, num_heads, q_block_len, 1],
            f32::NEG_INFINITY,
            &device,
        );
        let mut l_i = Tensor::<B, 4>::zeros([batch_size, num_heads, q_block_len, 1], &device);
        let mut o_i = Tensor::<B, 4>::zeros(
            [batch_size, num_heads, q_block_len, head_dim],
            &device,
        );

        for (kv_index, (k_block_t, v_block)) in kv_blocks.iter().enumerate() {
            let kv_block_len = k_block_t.dims()[3];
            let kv_start = kv_starts[kv_index];

            let mut scores = q_block.clone().matmul(k_block_t.clone());

            if causal {
                let mask = self.build_causal_mask_cached::<B>(
                    &device,
                    q_block_len,
                    kv_block_len,
                    q_start,
                    kv_start,
                    position_offset,
                );
                scores = scores + mask;
            }

            let m_ij = scores.clone().max_dim(3);
            let m_new = m_i.clone().max_pair(m_ij);

            let m_scale = (m_i - m_new.clone()).exp();
            let p_ij = (scores - m_new.clone()).exp();
            let p_sum = p_ij.clone().sum_dim(3);

            l_i = m_scale.clone() * l_i + p_sum;
            o_i = m_scale * o_i + p_ij.matmul(v_block.clone());
            m_i = m_new;
        }

        o_i / l_i
    }

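    /// Log-space variant of the per-Q-block accumulation.
    ///
    /// The running normalizer is kept as `log_l_i` and each KV block is merged with a
    /// log-sum-exp, which keeps the normalizer finite even when individual blocks would
    /// overflow or underflow `exp`:
    ///
    /// ```text
    /// log_l_new = logaddexp(log_l_i + (m_i - m_new),
    ///                       log(sum_j exp(s_ij - m_ij)) + (m_ij - m_new))
    /// ```
    ///
    /// The weighted-value accumulator `o_i` stays in linear space, rescaled to the new
    /// running maximum at every step, and is divided by `exp(log_l_i)` at the end.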
    fn process_q_block_log_space<B: Backend>(
        &self,
        q_block: Tensor<B, 4>,
        kv_blocks: &[(Tensor<B, 4>, Tensor<B, 4>)],
        kv_starts: &[usize],
        causal: bool,
        q_start: usize,
        position_offset: usize,
    ) -> Tensor<B, 4> {
        let device = q_block.device();
        let [batch_size, num_heads, q_block_len, head_dim] = q_block.dims();

        let mut m_i = Tensor::<B, 4>::full(
            [batch_size, num_heads, q_block_len, 1],
            f32::NEG_INFINITY,
            &device,
        );
        let mut log_l_i = Tensor::<B, 4>::full(
            [batch_size, num_heads, q_block_len, 1],
            f32::NEG_INFINITY,
            &device,
        );
        let mut o_i = Tensor::<B, 4>::zeros(
            [batch_size, num_heads, q_block_len, head_dim],
            &device,
        );

        for (kv_index, (k_block_t, v_block)) in kv_blocks.iter().enumerate() {
            let kv_block_len = k_block_t.dims()[3];
            let kv_start = kv_starts[kv_index];

            let mut scores = q_block.clone().matmul(k_block_t.clone());

            if causal {
                let mask = self.build_causal_mask_cached::<B>(
                    &device,
                    q_block_len,
                    kv_block_len,
                    q_start,
                    kv_start,
                    position_offset,
                );
                scores = scores + mask;
            }

            let m_ij = scores.clone().max_dim(3);
            let m_new = m_i.clone().max_pair(m_ij.clone());

            // Shift by the block-local maximum so the log of the block sum stays finite.
            let scores_shifted = scores - m_ij.clone();
            let p_ij = scores_shifted.exp();
            let sum_p = p_ij.clone().sum_dim(3);
            let log_sum_p = sum_p.log();

            let m_diff = m_i - m_new.clone();
            let log_prev = m_diff.clone() + log_l_i;
            let log_curr = (m_ij.clone() - m_new.clone()) + log_sum_p;

            let log_l_new = Self::tensor_log_add_exp(log_prev, log_curr);

            // Re-express both the previous accumulator and the current block relative to
            // the new running maximum `m_new`. `p_ij` was shifted by `m_ij`, so it needs
            // the extra `exp(m_ij - m_new)` factor to stay consistent with `log_l_i`.
            let m_scale = m_diff.exp();
            let block_scale = (m_ij - m_new.clone()).exp();
            o_i = m_scale * o_i + (block_scale * p_ij).matmul(v_block.clone());

            m_i = m_new;
            log_l_i = log_l_new;
        }

        let l_i = log_l_i.exp();
        o_i / l_i
    }

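    /// Numerically stable `log(exp(a) + exp(b))`, computed element-wise as
    /// `max(a, b) + log(exp(a - max) + exp(b - max))` so that neither operand is
    /// exponentiated at its full magnitude.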
    fn tensor_log_add_exp<B: Backend>(a: Tensor<B, 4>, b: Tensor<B, 4>) -> Tensor<B, 4> {
        let max = a.clone().max_pair(b.clone());
        let diff_a = a - max.clone();
        let diff_b = b - max.clone();
        max + (diff_a.exp() + diff_b.exp()).log()
    }

    /// Build causal mask without caching (fallback for testing).
    #[allow(dead_code)]
    fn build_causal_mask_uncached<B: Backend>(
        &self,
        device: &B::Device,
        query_len: usize,
        key_len: usize,
        q_start: usize,
        kv_start: usize,
        position_offset: usize,
    ) -> Tensor<B, 4> {
        let mut data = Vec::with_capacity(query_len * key_len);
        let mask_value = -1.0e4_f32;

        for i in 0..query_len {
            let absolute_pos = position_offset + q_start + i;
            for j in 0..key_len {
                let absolute_key = kv_start + j;
                let allowed = absolute_key <= absolute_pos;
                data.push(if allowed { 0.0 } else { mask_value });
            }
        }

        Tensor::<B, 2>::from_data(TensorData::new(data, [query_len, key_len]), device)
            .reshape([1, 1, query_len, key_len])
    }

    /// Clear the thread-local mask cache.
    ///
    /// Useful for memory pressure situations or testing.
    pub fn clear_mask_cache() {
        MASK_CACHE.with(|cache| cache.borrow_mut().clear());
    }
}

impl<B: Backend> FusedPagedAttention<B> for HierarchicalFlashAttention {
    fn forward_fused<'a, I>(
        &self,
        q: Tensor<B, 4>,
        kv_blocks: I,
        config: &FlashAttentionConfig,
        causal: bool,
        position_offset: usize,
    ) -> Tensor<B, 4>
    where
        I: Iterator<Item = (Tensor<B, 3>, Tensor<B, 3>)> + 'a,
    {
        let kv_blocks: Vec<_> = kv_blocks.collect();
        let total_kv_len: usize = kv_blocks.iter().map(|(k, _)| k.dims()[1]).sum();

        let attention = Self::new(config.clone());

        attention.forward_fused_iter(q, kv_blocks.into_iter(), causal, position_offset, total_kv_len)
    }
}

#[cfg(all(test, feature = "cpu"))]
mod tests {
    use super::*;
    use burn::tensor::activation::softmax;
    use burn_ndarray::NdArray;

    type TestBackend = NdArray<f32>;

    #[test]
    fn test_hierarchical_flash_basic() {
        let device = <TestBackend as Backend>::Device::default();
        let attention = HierarchicalFlashAttention::default_config();

        let batch_size = 1;
        let num_heads = 2;
        let seq_len = 16;
        let head_dim = 8;

        let q = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device,
        );
        let k = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device,
        );
        let v = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device,
        );

        let output = attention.forward(q, k, v, false, 0);
        assert_eq!(output.dims(), [batch_size, num_heads, seq_len, head_dim]);
    }

    #[test]
    fn test_hierarchical_flash_matches_standard() {
        let device = <TestBackend as Backend>::Device::default();
        let attention = HierarchicalFlashAttention::new(FlashAttentionConfig {
            block_q: 4,
            block_kv: 4,
            use_log_space: false,
            ..Default::default()
        });

        let batch_size = 1;
        let num_heads = 2;
        let seq_len = 8;
        let head_dim = 4;

        let q = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 0.5),
            &device,
        );
        let k = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 0.5),
            &device,
        );
        let v = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 0.5),
            &device,
        );

        let output_hier = attention.forward(q.clone(), k.clone(), v.clone(), false, 0);

        let scale = (head_dim as f32).sqrt();
        let scores = q.matmul(k.transpose()) / scale;
        let attn = softmax(scores, 3);
        let output_std = attn.matmul(v);

        let hier_data = output_hier
            .into_data()
            .into_vec::<f32>()
            .expect("output data");
        let std_data = output_std
            .into_data()
            .into_vec::<f32>()
            .expect("output data");

        for (i, (h, s)) in hier_data.iter().zip(std_data.iter()).enumerate() {
            let diff = (h - s).abs();
            assert!(
                diff < 1e-3,
                "Mismatch at {}: hier={}, std={}, diff={}",
                i,
                h,
                s,
                diff
            );
        }
    }

    #[test]
    fn test_fused_iter_matches_standard() {
        let device = <TestBackend as Backend>::Device::default();
        let attention = HierarchicalFlashAttention::new(FlashAttentionConfig {
            block_q: 4,
            block_kv: 4,
            use_log_space: false,
            ..Default::default()
        });

        let num_heads = 2;
        let seq_len = 16;
        let head_dim = 4;
        let block_size = 4;

        let q = Tensor::<TestBackend, 4>::random(
            [1, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 0.5),
            &device,
        );

        let num_blocks = seq_len / block_size;
        let kv_blocks: Vec<_> = (0..num_blocks)
            .map(|_| {
                let k = Tensor::<TestBackend, 3>::random(
                    [num_heads, block_size, head_dim],
                    burn::tensor::Distribution::Normal(0.0, 0.5),
                    &device,
                );
                let v = Tensor::<TestBackend, 3>::random(
                    [num_heads, block_size, head_dim],
                    burn::tensor::Distribution::Normal(0.0, 0.5),
                    &device,
                );
                (k, v)
            })
            .collect();

        let output_fused = attention.forward_fused_iter(
            q.clone(),
            kv_blocks.clone().into_iter(),
            false,
            0,
            seq_len,
        );

        let k_cat: Vec<_> = kv_blocks.iter().map(|(k, _)| k.clone()).collect();
        let v_cat: Vec<_> = kv_blocks.iter().map(|(_, v)| v.clone()).collect();

        let k_full = Tensor::cat(k_cat, 1).reshape([1, num_heads, seq_len, head_dim]);
        let v_full = Tensor::cat(v_cat, 1).reshape([1, num_heads, seq_len, head_dim]);

        let output_std = attention.forward(q, k_full, v_full, false, 0);

        let fused_data = output_fused
            .into_data()
            .into_vec::<f32>()
            .expect("output data");
        let std_data = output_std
            .into_data()
            .into_vec::<f32>()
            .expect("output data");

        for (i, (f, s)) in fused_data.iter().zip(std_data.iter()).enumerate() {
            let diff = (f - s).abs();
            assert!(
                diff < 1e-3,
                "Mismatch at {}: fused={}, std={}, diff={}",
                i,
                f,
                s,
                diff
            );
        }
    }

    #[test]
    fn test_causal_mask() {
        let device = <TestBackend as Backend>::Device::default();
        let attention = HierarchicalFlashAttention::default_config();

        let mask = attention.build_causal_mask_cached::<TestBackend>(&device, 4, 4, 0, 0, 0);

        let data = mask.into_data().into_vec::<f32>().expect("mask data");

        assert!(data[0].abs() < 1e-5);
        assert!(data[1] < -1000.0);
        assert!(data[4].abs() < 1e-5);
        assert!(data[5].abs() < 1e-5);
        assert!(data[6] < -1000.0);
    }

    #[test]
    fn test_forward_with_workspace() {
        let device = <TestBackend as Backend>::Device::default();
        let attention = HierarchicalFlashAttention::new(FlashAttentionConfig {
            block_q: 4,
            block_kv: 4,
            use_log_space: false,
            ..Default::default()
        });

        let batch_size = 1;
        let num_heads = 2;
        let seq_len = 8;
        let head_dim = 4;

        let q = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 0.5),
            &device,
        );
        let k = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 0.5),
            &device,
        );
        let v = Tensor::<TestBackend, 4>::random(
            [batch_size, num_heads, seq_len, head_dim],
            burn::tensor::Distribution::Normal(0.0, 0.5),
            &device,
        );

        // Test with workspace
        let mut workspace = AttentionWorkspace::new();
        let output_workspace = attention.forward_with_workspace(
            q.clone(),
            k.clone(),
            v.clone(),
            false,
            0,
            &mut workspace,
        );

        // Test without workspace (standard)
        let output_std = attention.forward(q, k, v, false, 0);

        let ws_data = output_workspace
            .into_data()
            .into_vec::<f32>()
            .expect("output data");
        let std_data = output_std
            .into_data()
            .into_vec::<f32>()
            .expect("output data");

        for (i, (w, s)) in ws_data.iter().zip(std_data.iter()).enumerate() {
            let diff = (w - s).abs();
            assert!(
                diff < 1e-3,
                "Mismatch at {}: workspace={}, std={}, diff={}",
                i,
                w,
                s,
                diff
            );
        }
    }

    #[test]
    fn test_mask_cache_hit() {
        let device = <TestBackend as Backend>::Device::default();
        let attention = HierarchicalFlashAttention::default_config();

        // Clear cache first
        HierarchicalFlashAttention::clear_mask_cache();

        // First call - cache miss
        let mask1 = attention.build_causal_mask_cached::<TestBackend>(&device, 4, 4, 0, 0, 0);
        // Second call - should hit cache
        let mask2 = attention.build_causal_mask_cached::<TestBackend>(&device, 4, 4, 0, 0, 0);

        let data1 = mask1.into_data().into_vec::<f32>().expect("mask data");
        let data2 = mask2.into_data().into_vec::<f32>().expect("mask data");

        assert_eq!(data1, data2);
    }
}