// gllm_kernels/ops/sparse_attention.rs

//! Sparse attention utilities with Lightning Indexer-style selection.

use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, TensorData};

/// Configuration for sparse attention selection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SparseAttentionConfig {
    /// Number of KV tokens selected per query (typical: 2048).
    ///
    /// Clamped to the KV length at selection time; must be non-zero.
    pub selected_kv_count: usize,
    /// Lightning Indexer block size used to rank candidate blocks.
    ///
    /// Must be non-zero. Ignored when [`SparsityPattern::BlockSparse`] supplies its own
    /// block size.
    pub block_size: usize,
    /// Sparsity pattern to apply.
    pub sparsity_pattern: SparsityPattern,
}

/// Sparse attention patterns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparsityPattern {
    /// Sliding window plus global tokens.
    ///
    /// For query row `q`, the KV positions `q - window ..= q + window` (clamped to the KV
    /// range) and the first `global_tokens` positions are always selected.
    SlidingWindowGlobal { window: usize, global_tokens: usize },
    /// Dynamic selection with Lightning Indexer.
    Dynamic,
    /// Block-sparse pattern; `block_size` overrides the config-level block size.
    BlockSparse { block_size: usize },
}

/// Selected indices for sparse attention.
///
/// Indices are stored flat, row-major, in `[batch, num_heads, query_len, selected_kv_count]`
/// order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SparseSelection {
    batch: usize,
    num_heads: usize,
    query_len: usize,
    selected_kv_count: usize,
    indices: Vec<usize>,
}

impl SparseSelection {
    /// Create a new sparse selection result.
    ///
    /// `indices` is expected to hold `batch * num_heads * query_len * selected_kv_count`
    /// entries in row-major order.
    pub fn new(
        batch: usize,
        num_heads: usize,
        query_len: usize,
        selected_kv_count: usize,
        indices: Vec<usize>,
    ) -> Self {
        Self {
            batch,
            num_heads,
            query_len,
            selected_kv_count,
            indices,
        }
    }

    /// Batch size of the selection.
    pub fn batch(&self) -> usize {
        self.batch
    }

    /// Number of attention heads in the selection.
    pub fn num_heads(&self) -> usize {
        self.num_heads
    }

    /// Number of query positions per head.
    pub fn query_len(&self) -> usize {
        self.query_len
    }

    /// Number of KV tokens selected per query.
    pub fn selected_kv_count(&self) -> usize {
        self.selected_kv_count
    }

    /// Slice of indices for a (batch, head, query) triplet, or `None` if any coordinate is out
    /// of range (or the backing buffer is shorter than the declared shape).
    pub fn get_indices_for(&self, batch: usize, head: usize, query: usize) -> Option<&[usize]> {
        // Check each coordinate separately: a valid flat offset can still belong to a
        // different row when, e.g., `query >= query_len`.
        if batch >= self.batch || head >= self.num_heads || query >= self.query_len {
            return None;
        }
        let stride = self.selected_kv_count;
        let start = ((batch * self.num_heads + head) * self.query_len + query) * stride;
        self.indices.get(start..start + stride)
    }

    /// Slice of indices for a (batch, head, query) triplet.
    ///
    /// # Panics
    /// Panics if `batch`, `head`, or `query` is out of range for this selection.
    pub fn indices_for(&self, batch: usize, head: usize, query: usize) -> &[usize] {
        self.get_indices_for(batch, head, query).unwrap_or_else(|| {
            panic!(
                "sparse selection index ({batch}, {head}, {query}) out of range for shape [{}, {}, {}]",
                self.batch, self.num_heads, self.query_len
            )
        })
    }

    /// Flat view of all indices.
    pub fn flat_indices(&self) -> &[usize] {
        &self.indices
    }
}

74/// Sparse attention selector with Lightning Indexer.
75#[derive(Debug, Clone)]
76pub struct SparseAttention {
77    config: SparseAttentionConfig,
78}
79
80impl SparseAttention {
81    /// Create a new sparse attention selector.
82    pub fn new(config: SparseAttentionConfig) -> Self {
83        Self { config }
84    }
85
86    /// Access the sparse attention configuration.
87    pub fn config(&self) -> &SparseAttentionConfig {
88        &self.config
89    }
90
91    /// Select sparse indices for attention scores.
92    ///
93    /// # Shapes
94    /// * `scores`: [batch, num_heads, query_len, kv_len]
95    pub fn select_indices<B: Backend>(
96        &self,
97        scores: Tensor<B, 4>,
98    ) -> Result<SparseSelection, &'static str> {
99        let dims = scores.dims();
100        let data = scores
101            .into_data()
102            .into_vec::<f32>()
103            .map_err(|_| "sparse attention expects f32 scores")?;
104        self.select_indices_from_data(&data, dims)
105    }
106
107    /// Apply sparse selection by masking unselected scores.
108    ///
109    /// # Shapes
110    /// * `scores`: [batch, num_heads, query_len, kv_len]
111    /// * returns: masked scores tensor + selection indices
112    pub fn sparsify_scores<B: Backend>(
113        &self,
114        scores: Tensor<B, 4>,
115    ) -> Result<(Tensor<B, 4>, SparseSelection), &'static str> {
116        let device = scores.device();
117        let dims = scores.dims();
118        let mut data = scores
119            .into_data()
120            .into_vec::<f32>()
121            .map_err(|_| "sparse attention expects f32 scores")?;
122        let selection = self.select_indices_from_data(&data, dims)?;
123
124        let [batch, num_heads, query_len, kv_len] = dims;
125        let stride_query = kv_len;
126        let stride_head = query_len * stride_query;
127        let stride_batch = num_heads * stride_head;
128
129        for b in 0..batch {
130            for h in 0..num_heads {
131                for q in 0..query_len {
132                    let offset = b * stride_batch + h * stride_head + q * stride_query;
133                    let selected = selection.indices_for(b, h, q);
134                    let mut keep = vec![false; kv_len];
135                    for &idx in selected {
136                        if idx < kv_len {
137                            keep[idx] = true;
138                        }
139                    }
140                    for idx in 0..kv_len {
141                        if !keep[idx] {
142                            data[offset + idx] = MASK_VALUE;
143                        }
144                    }
145                }
146            }
147        }
148
149        let masked = Tensor::<B, 4>::from_data(TensorData::new(data, dims), &device);
150        Ok((masked, selection))
151    }
152
153    fn select_indices_from_data(
154        &self,
155        data: &[f32],
156        dims: [usize; 4],
157    ) -> Result<SparseSelection, &'static str> {
158        self.validate_config()?;
159        let [batch, num_heads, query_len, kv_len] = dims;
160        if kv_len == 0 {
161            return Err("kv_len must be > 0");
162        }
163
164        let target = self.config.selected_kv_count.min(kv_len);
165        if target == 0 {
166            return Err("selected_kv_count must be > 0");
167        }
168
169        let mut indices = Vec::with_capacity(batch * num_heads * query_len * target);
170        let stride_query = kv_len;
171        let stride_head = query_len * stride_query;
172        let stride_batch = num_heads * stride_head;
173
174        for b in 0..batch {
175            for h in 0..num_heads {
176                for q in 0..query_len {
177                    let offset = b * stride_batch + h * stride_head + q * stride_query;
178                    let scores = &data[offset..offset + kv_len];
179                    let selected = self.select_for_query(scores, q, kv_len, target);
180                    indices.extend(selected);
181                }
182            }
183        }
184
185        Ok(SparseSelection::new(
186            batch,
187            num_heads,
188            query_len,
189            target,
190            indices,
191        ))
192    }
193
194    fn select_for_query(
195        &self,
196        scores: &[f32],
197        query_idx: usize,
198        kv_len: usize,
199        target: usize,
200    ) -> Vec<usize> {
201        let mut forced = Vec::new();
202        let mut forced_mask = vec![false; kv_len];
203
204        match self.config.sparsity_pattern {
205            SparsityPattern::SlidingWindowGlobal {
206                window,
207                global_tokens,
208            } => {
209                let start = query_idx.saturating_sub(window);
210                let end = (query_idx + window + 1).min(kv_len);
211                for idx in start..end {
212                    push_unique(idx, &mut forced, &mut forced_mask);
213                }
214                let global = global_tokens.min(kv_len);
215                for idx in 0..global {
216                    push_unique(idx, &mut forced, &mut forced_mask);
217                }
218            }
219            SparsityPattern::Dynamic | SparsityPattern::BlockSparse { .. } => {}
220        }
221
222        if forced.len() >= target {
223            return top_k_indices(scores, &forced, target);
224        }
225
226        let remaining = target - forced.len();
227        let block_size = self.block_size();
228        let block_count = (remaining + block_size - 1) / block_size;
229        let blocks = select_blocks(scores, kv_len, block_size, block_count, Some(&forced_mask));
230
231        let mut candidates = Vec::new();
232        for block in blocks {
233            let start = block * block_size;
234            let end = (start + block_size).min(kv_len);
235            for idx in start..end {
236                if !forced_mask[idx] {
237                    candidates.push(idx);
238                }
239            }
240        }
241
242        if candidates.len() < remaining {
243            for idx in 0..kv_len {
244                if !forced_mask[idx] {
245                    candidates.push(idx);
246                }
247            }
248        }
249        candidates.sort_unstable();
250        candidates.dedup();
251
252        let mut selected = forced;
253        if remaining > 0 {
254            let mut extra = top_k_indices(scores, &candidates, remaining);
255            selected.append(&mut extra);
256        }
257        selected.sort_unstable();
258        selected.truncate(target);
259        selected
260    }
261
262    fn validate_config(&self) -> Result<(), &'static str> {
263        if self.config.selected_kv_count == 0 {
264            return Err("selected_kv_count must be > 0");
265        }
266        if self.config.block_size == 0 {
267            return Err("block_size must be > 0");
268        }
269        if let SparsityPattern::BlockSparse { block_size } = self.config.sparsity_pattern {
270            if block_size == 0 {
271                return Err("block sparse block_size must be > 0");
272            }
273        }
274        Ok(())
275    }
276
277    fn block_size(&self) -> usize {
278        match self.config.sparsity_pattern {
279            SparsityPattern::BlockSparse { block_size } => block_size.max(1),
280            _ => self.config.block_size.max(1),
281        }
282    }
283}
284
/// Score written into every position that sparse selection did not keep.
///
/// A large finite negative number rather than `-inf`. It was presumably chosen so masked
/// scores stay representable in reduced-precision formats — TODO confirm.
const MASK_VALUE: f32 = -10_000.0;

/// Append `idx` to `list` the first time it is seen, using `mask` as the seen-set.
///
/// # Panics
/// Panics if `idx` is out of bounds for `mask`.
fn push_unique(idx: usize, list: &mut Vec<usize>, mask: &mut [bool]) {
    let seen = &mut mask[idx];
    if *seen {
        return;
    }
    *seen = true;
    list.push(idx);
}

/// Rank fixed-size blocks of a score row and return the indices of the best `block_count`.
///
/// `scores[..kv_len]` is split into consecutive blocks of `block_size` positions; the last
/// block may be shorter. Each block is scored by its maximum score. Positions flagged in
/// `skip_mask` are ignored and NaN is treated as `-inf`. Blocks whose best usable score is
/// `-inf` are dropped.
///
/// The result is ordered by descending block score. Ties keep ascending block order because
/// the sort is stable. It is empty when `block_count`, `kv_len`, or `block_size` is zero.
fn select_blocks(
    scores: &[f32],
    kv_len: usize,
    block_size: usize,
    block_count: usize,
    skip_mask: Option<&[bool]>,
) -> Vec<usize> {
    // `block_size == 0` previously divided by zero.
    if block_count == 0 || kv_len == 0 || block_size == 0 {
        return Vec::new();
    }
    // `div_ceil` avoids the `kv_len + block_size - 1` overflow for very large block sizes.
    let num_blocks = kv_len.div_ceil(block_size);
    let mut block_scores: Vec<(usize, f32)> = Vec::with_capacity(num_blocks);

    for block in 0..num_blocks {
        // `block < num_blocks` keeps `start < kv_len`; only the end can overflow.
        let start = block * block_size;
        let end = start.saturating_add(block_size).min(kv_len);
        let max_score = (start..end)
            .filter(|&idx| !skip_mask.is_some_and(|mask| mask[idx]))
            .map(|idx| {
                let score = scores[idx];
                if score.is_nan() {
                    f32::NEG_INFINITY
                } else {
                    score
                }
            })
            .fold(f32::NEG_INFINITY, f32::max);
        if max_score > f32::NEG_INFINITY {
            block_scores.push((block, max_score));
        }
    }

    // NaNs were mapped away above, so `partial_cmp` never actually fails here.
    block_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    block_scores.truncate(block_count);
    block_scores.into_iter().map(|(block, _)| block).collect()
}

/// Return up to `k` entries of `candidates` with the highest scores, sorted by index.
///
/// NaN scores rank as `-inf`. Among equal scores the candidate that appears earlier in
/// `candidates` wins, because the ranking sort is stable.
fn top_k_indices(scores: &[f32], candidates: &[usize], k: usize) -> Vec<usize> {
    if k == 0 || candidates.is_empty() {
        return Vec::new();
    }
    let rank_key = |idx: usize| {
        let raw = scores[idx];
        if raw.is_nan() {
            f32::NEG_INFINITY
        } else {
            raw
        }
    };
    let mut ranked: Vec<(usize, f32)> = Vec::with_capacity(candidates.len());
    for &idx in candidates {
        ranked.push((idx, rank_key(idx)));
    }
    ranked.sort_by(|lhs, rhs| match rhs.1.partial_cmp(&lhs.1) {
        Some(order) => order,
        None => std::cmp::Ordering::Equal,
    });
    let mut picked: Vec<usize> = ranked.iter().take(k).map(|&(idx, _)| idx).collect();
    picked.sort_unstable();
    picked
}

#[cfg(all(test, feature = "cpu"))]
mod tests {
    use super::*;
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_ndarray::NdArray;

    #[test]
    fn test_sliding_window_global_forced_tokens() {
        let config = SparseAttentionConfig {
            selected_kv_count: 4,
            block_size: 2,
            sparsity_pattern: SparsityPattern::SlidingWindowGlobal {
                window: 1,
                global_tokens: 1,
            },
        };
        let selector = SparseAttention::new(config);
        let device = <NdArray<f32> as Backend>::Device::default();
        // Scores are random: with window 1 + 1 global token, all 4 picks for query 2 are forced.
        let scores =
            Tensor::<NdArray<f32>, 4>::random([1, 1, 3, 6], Distribution::Uniform(0.0, 1.0), &device);

        let selection = selector.select_indices(scores).expect("selection");
        let indices = selection.indices_for(0, 0, 2);
        assert_eq!(indices, &[0, 1, 2, 3]);
    }

    #[test]
    fn test_block_sparse_selection() {
        let config = SparseAttentionConfig {
            selected_kv_count: 3,
            block_size: 4,
            sparsity_pattern: SparsityPattern::BlockSparse { block_size: 4 },
        };
        let selector = SparseAttention::new(config);
        let device = <NdArray<f32> as Backend>::Device::default();
        let data = vec![0.1, 0.2, 0.3, 0.4, 5.0, 4.0, 3.0, 2.0];
        let scores = Tensor::<NdArray<f32>, 4>::from_data(TensorData::new(data, [1, 1, 1, 8]), &device);

        let selection = selector.select_indices(scores).expect("selection");
        let indices = selection.indices_for(0, 0, 0);
        assert_eq!(indices.len(), 3);
        assert!(indices.iter().all(|&idx| idx >= 4));
    }

    #[test]
    fn test_sparsify_scores_masks_unselected() {
        let config = SparseAttentionConfig {
            selected_kv_count: 1,
            block_size: 2,
            sparsity_pattern: SparsityPattern::Dynamic,
        };
        let selector = SparseAttention::new(config);
        let device = <NdArray<f32> as Backend>::Device::default();
        let data = vec![0.1, 0.2, 5.0, 0.3, 0.4];
        let scores = Tensor::<NdArray<f32>, 4>::from_data(TensorData::new(data, [1, 1, 1, 5]), &device);

        let (masked, selection) = selector.sparsify_scores(scores).expect("sparsify");
        let masked_data = masked.into_data().into_vec::<f32>().expect("masked data");
        let indices = selection.indices_for(0, 0, 0);
        assert_eq!(indices, &[2]);
        for (idx, value) in masked_data.iter().enumerate() {
            if idx == 2 {
                assert!((value - 5.0).abs() < 1e-4);
            } else {
                assert!((*value - MASK_VALUE).abs() < 1e-4);
            }
        }
    }

    #[test]
    fn test_zero_kv_len_is_rejected() {
        let selector = SparseAttention::new(SparseAttentionConfig {
            selected_kv_count: 2,
            block_size: 2,
            sparsity_pattern: SparsityPattern::Dynamic,
        });
        let err = selector
            .select_indices_from_data(&[], [1, 1, 1, 0])
            .unwrap_err();
        assert_eq!(err, "kv_len must be > 0");
    }

    #[test]
    fn test_invalid_config_is_rejected() {
        let zero_block = SparseAttention::new(SparseAttentionConfig {
            selected_kv_count: 2,
            block_size: 0,
            sparsity_pattern: SparsityPattern::Dynamic,
        });
        assert_eq!(
            zero_block
                .select_indices_from_data(&[1.0, 2.0], [1, 1, 1, 2])
                .unwrap_err(),
            "block_size must be > 0"
        );

        let zero_block_sparse = SparseAttention::new(SparseAttentionConfig {
            selected_kv_count: 2,
            block_size: 4,
            sparsity_pattern: SparsityPattern::BlockSparse { block_size: 0 },
        });
        assert_eq!(
            zero_block_sparse
                .select_indices_from_data(&[1.0, 2.0], [1, 1, 1, 2])
                .unwrap_err(),
            "block sparse block_size must be > 0"
        );
    }

    #[test]
    fn test_selected_count_clamped_to_kv_len() {
        let selector = SparseAttention::new(SparseAttentionConfig {
            selected_kv_count: 10,
            block_size: 2,
            sparsity_pattern: SparsityPattern::Dynamic,
        });
        let selection = selector
            .select_indices_from_data(&[0.3, 0.1, 0.2], [1, 1, 1, 3])
            .expect("selection");
        assert_eq!(selection.selected_kv_count(), 3);
        assert_eq!(selection.indices_for(0, 0, 0), &[0, 1, 2]);
    }
}