use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, TensorData};

/// Configuration for sparse attention index selection.
#[derive(Debug, Clone)]
pub struct SparseAttentionConfig {
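    /// Number of key/value positions to keep for each query (capped at the kv length).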
    pub selected_kv_count: usize,
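    /// Block size used for block-level scoring when the pattern does not supply its own.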
    pub block_size: usize,
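    /// Strategy that decides which key/value positions each query may attend to.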
    pub sparsity_pattern: SparsityPattern,
}

/// Supported sparsity patterns.
#[derive(Debug, Clone)]
pub enum SparsityPattern {
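    /// Local window around each query plus a fixed number of leading global tokens.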
    SlidingWindowGlobal { window: usize, global_tokens: usize },
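    /// Purely score-driven selection with no forced positions.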
    Dynamic,
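    /// Score-driven selection over fixed-size blocks of key/value positions.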
    BlockSparse { block_size: usize },
}

/// Key/value indices selected for every (batch, head, query) triple, stored as a flat
/// row-major buffer with `selected_kv_count` entries per query.
#[derive(Debug, Clone)]
pub struct SparseSelection {
    batch: usize,
    num_heads: usize,
    query_len: usize,
    selected_kv_count: usize,
    indices: Vec<usize>,
}

impl SparseSelection {
    pub fn new(
        batch: usize,
        num_heads: usize,
        query_len: usize,
        selected_kv_count: usize,
        indices: Vec<usize>,
    ) -> Self {
        Self {
            batch,
            num_heads,
            query_len,
            selected_kv_count,
            indices,
        }
    }

    pub fn selected_kv_count(&self) -> usize {
        self.selected_kv_count
    }

    /// Returns the selected kv indices for a single (batch, head, query) triple.
    pub fn indices_for(&self, batch: usize, head: usize, query: usize) -> &[usize] {
        let stride = self.selected_kv_count;
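        // Each query owns a contiguous run of `stride` indices in the flat buffer.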
        let idx = ((batch * self.num_heads + head) * self.query_len + query) * stride;
        &self.indices[idx..idx + stride]
    }

    /// Returns the full flat index buffer.
    pub fn flat_indices(&self) -> &[usize] {
        &self.indices
    }
}

/// Selects sparse key/value indices and masks attention scores according to its config.
#[derive(Debug, Clone)]
pub struct SparseAttention {
    config: SparseAttentionConfig,
}

impl SparseAttention {
    pub fn new(config: SparseAttentionConfig) -> Self {
        Self { config }
    }

    pub fn config(&self) -> &SparseAttentionConfig {
        &self.config
    }

    /// Computes the selected kv indices for a `[batch, heads, query_len, kv_len]` score tensor.
    pub fn select_indices<B: Backend>(
        &self,
        scores: Tensor<B, 4>,
    ) -> Result<SparseSelection, &'static str> {
        let dims = scores.dims();
        let data = scores
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "sparse attention expects f32 scores")?;
        self.select_indices_from_data(&data, dims)
    }

    /// Masks every unselected score with `MASK_VALUE` and returns the masked tensor
    /// together with the selection.
    pub fn sparsify_scores<B: Backend>(
        &self,
        scores: Tensor<B, 4>,
    ) -> Result<(Tensor<B, 4>, SparseSelection), &'static str> {
        let device = scores.device();
        let dims = scores.dims();
        let mut data = scores
            .into_data()
            .into_vec::<f32>()
            .map_err(|_| "sparse attention expects f32 scores")?;
        let selection = self.select_indices_from_data(&data, dims)?;

        let [batch, num_heads, query_len, kv_len] = dims;
        let stride_query = kv_len;
        let stride_head = query_len * stride_query;
        let stride_batch = num_heads * stride_head;

        for b in 0..batch {
            for h in 0..num_heads {
                for q in 0..query_len {
                    let offset = b * stride_batch + h * stride_head + q * stride_query;
                    let selected = selection.indices_for(b, h, q);
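                    // Mark the selected kv positions; every other position is overwritten with MASK_VALUE.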
                    let mut keep = vec![false; kv_len];
                    for &idx in selected {
                        if idx < kv_len {
                            keep[idx] = true;
                        }
                    }
                    for idx in 0..kv_len {
                        if !keep[idx] {
                            data[offset + idx] = MASK_VALUE;
                        }
                    }
                }
            }
        }

        let masked = Tensor::<B, 4>::from_data(TensorData::new(data, dims), &device);
        Ok((masked, selection))
    }

    /// Core selection routine shared by `select_indices` and `sparsify_scores`.
    fn select_indices_from_data(
        &self,
        data: &[f32],
        dims: [usize; 4],
    ) -> Result<SparseSelection, &'static str> {
        self.validate_config()?;
        let [batch, num_heads, query_len, kv_len] = dims;
        if kv_len == 0 {
            return Err("kv_len must be > 0");
        }

        let target = self.config.selected_kv_count.min(kv_len);
        if target == 0 {
            return Err("selected_kv_count must be > 0");
        }

        let mut indices = Vec::with_capacity(batch * num_heads * query_len * target);
        let stride_query = kv_len;
        let stride_head = query_len * stride_query;
        let stride_batch = num_heads * stride_head;

        for b in 0..batch {
            for h in 0..num_heads {
                for q in 0..query_len {
                    let offset = b * stride_batch + h * stride_head + q * stride_query;
                    let scores = &data[offset..offset + kv_len];
                    let selected = self.select_for_query(scores, q, kv_len, target);
                    indices.extend(selected);
                }
            }
        }

        Ok(SparseSelection::new(
            batch,
            num_heads,
            query_len,
            target,
            indices,
        ))
    }

    /// Selects the kv indices for a single query row.
    fn select_for_query(
        &self,
        scores: &[f32],
        query_idx: usize,
        kv_len: usize,
        target: usize,
    ) -> Vec<usize> {
        let mut forced = Vec::new();
        let mut forced_mask = vec![false; kv_len];

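        // Force-include any positions required by the sparsity pattern.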
        match self.config.sparsity_pattern {
            SparsityPattern::SlidingWindowGlobal {
                window,
                global_tokens,
            } => {
                let start = query_idx.saturating_sub(window);
                let end = (query_idx + window + 1).min(kv_len);
                for idx in start..end {
                    push_unique(idx, &mut forced, &mut forced_mask);
                }
                let global = global_tokens.min(kv_len);
                for idx in 0..global {
                    push_unique(idx, &mut forced, &mut forced_mask);
                }
            }
            SparsityPattern::Dynamic | SparsityPattern::BlockSparse { .. } => {}
        }

        if forced.len() >= target {
            return top_k_indices(scores, &forced, target);
        }

        let remaining = target - forced.len();
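        // Spend the remaining budget on the highest-scoring blocks.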
        let block_size = self.block_size();
        let block_count = (remaining + block_size - 1) / block_size;
        let blocks = select_blocks(scores, kv_len, block_size, block_count, Some(&forced_mask));

        let mut candidates = Vec::new();
        for block in blocks {
            let start = block * block_size;
            let end = (start + block_size).min(kv_len);
            for idx in start..end {
                if !forced_mask[idx] {
                    candidates.push(idx);
                }
            }
        }

        // Fall back to every unforced position if the chosen blocks cannot cover the budget.
        if candidates.len() < remaining {
            for idx in 0..kv_len {
                if !forced_mask[idx] {
                    candidates.push(idx);
                }
            }
        }
        candidates.sort_unstable();
        candidates.dedup();

        let mut selected = forced;
        if remaining > 0 {
            let mut extra = top_k_indices(scores, &candidates, remaining);
            selected.append(&mut extra);
        }
        selected.sort_unstable();
        selected.truncate(target);
        selected
    }

    fn validate_config(&self) -> Result<(), &'static str> {
        if self.config.selected_kv_count == 0 {
            return Err("selected_kv_count must be > 0");
        }
        if self.config.block_size == 0 {
            return Err("block_size must be > 0");
        }
        if let SparsityPattern::BlockSparse { block_size } = self.config.sparsity_pattern {
            if block_size == 0 {
                return Err("block sparse block_size must be > 0");
            }
        }
        Ok(())
    }

    /// Effective block size (at least 1): the pattern's own size for `BlockSparse`,
    /// otherwise the configured default.
    fn block_size(&self) -> usize {
        match self.config.sparsity_pattern {
            SparsityPattern::BlockSparse { block_size } => block_size.max(1),
            _ => self.config.block_size.max(1),
        }
    }
}

/// Large negative value written over scores that were not selected.
const MASK_VALUE: f32 = -1.0e4_f32;

/// Appends `idx` to `list` and marks it in `mask`, skipping indices already marked.
fn push_unique(idx: usize, list: &mut Vec<usize>, mask: &mut [bool]) {
    if !mask[idx] {
        mask[idx] = true;
        list.push(idx);
    }
}

/// Returns the indices of up to `block_count` blocks with the highest per-block maximum score.
fn select_blocks(
    scores: &[f32],
    kv_len: usize,
    block_size: usize,
    block_count: usize,
    skip_mask: Option<&[bool]>,
) -> Vec<usize> {
    if block_count == 0 || kv_len == 0 {
        return Vec::new();
    }
    let num_blocks = (kv_len + block_size - 1) / block_size;
    let mut block_scores = Vec::with_capacity(num_blocks);

    for block in 0..num_blocks {
        let start = block * block_size;
        let end = (start + block_size).min(kv_len);
        let mut max_score = f32::NEG_INFINITY;
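        // A block's score is the maximum over its non-skipped positions; NaN scores count as negative infinity.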
        for idx in start..end {
            if skip_mask.map_or(false, |mask| mask[idx]) {
                continue;
            }
            let score = scores[idx];
            let score = if score.is_nan() { f32::NEG_INFINITY } else { score };
            if score > max_score {
                max_score = score;
            }
        }
        if max_score > f32::NEG_INFINITY {
            block_scores.push((block, max_score));
        }
    }

    block_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    block_scores.truncate(block_count.min(block_scores.len()));
    block_scores.into_iter().map(|(block, _)| block).collect()
}

fn top_k_indices(scores: &[f32], candidates: &[usize], k: usize) -> Vec<usize> {
    if k == 0 || candidates.is_empty() {
        return Vec::new();
    }
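    // Rank candidates by score (NaN treated as lowest), then return the top k in ascending index order.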
    let mut scored: Vec<(usize, f32)> = candidates
        .iter()
        .map(|&idx| {
            let score = scores[idx];
            let score = if score.is_nan() { f32::NEG_INFINITY } else { score };
            (idx, score)
        })
        .collect();
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(k.min(scored.len()));
    let mut indices: Vec<usize> = scored.into_iter().map(|(idx, _)| idx).collect();
    indices.sort_unstable();
    indices
}

#[cfg(all(test, feature = "cpu"))]
mod tests {
    use super::*;
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_ndarray::NdArray;

    #[test]
    fn test_sliding_window_global_forced_tokens() {
        let config = SparseAttentionConfig {
            selected_kv_count: 4,
            block_size: 2,
            sparsity_pattern: SparsityPattern::SlidingWindowGlobal {
                window: 1,
                global_tokens: 1,
            },
        };
        let selector = SparseAttention::new(config);
        let device = <NdArray<f32> as Backend>::Device::default();
        let scores =
            Tensor::<NdArray<f32>, 4>::random([1, 1, 3, 6], Distribution::Uniform(0.0, 1.0), &device);

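        // Query 2 with window 1 forces kv positions 1..=3, and the single global token adds position 0, exactly filling the budget of 4.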
        let selection = selector.select_indices(scores).expect("selection");
        let indices = selection.indices_for(0, 0, 2);
        assert_eq!(indices, &[0, 1, 2, 3]);
    }

    #[test]
    fn test_block_sparse_selection() {
        let config = SparseAttentionConfig {
            selected_kv_count: 3,
            block_size: 4,
            sparsity_pattern: SparsityPattern::BlockSparse { block_size: 4 },
        };
        let selector = SparseAttention::new(config);
        let device = <NdArray<f32> as Backend>::Device::default();
        let data = vec![0.1, 0.2, 0.3, 0.4, 5.0, 4.0, 3.0, 2.0];
        let scores = Tensor::<NdArray<f32>, 4>::from_data(TensorData::new(data, [1, 1, 1, 8]), &device);

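        // Block 1 (kv positions 4..8) holds the highest scores, so all three picks should land there.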
        let selection = selector.select_indices(scores).expect("selection");
        let indices = selection.indices_for(0, 0, 0);
        assert_eq!(indices.len(), 3);
        assert!(indices.iter().all(|&idx| idx >= 4));
    }

    #[test]
    fn test_sparsify_scores_masks_unselected() {
        let config = SparseAttentionConfig {
            selected_kv_count: 1,
            block_size: 2,
            sparsity_pattern: SparsityPattern::Dynamic,
        };
        let selector = SparseAttention::new(config);
        let device = <NdArray<f32> as Backend>::Device::default();
        let data = vec![0.1, 0.2, 5.0, 0.3, 0.4];
        let scores = Tensor::<NdArray<f32>, 4>::from_data(TensorData::new(data, [1, 1, 1, 5]), &device);

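        // Only position 2 (score 5.0) is selected; every other score should be replaced by MASK_VALUE.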
        let (masked, selection) = selector.sparsify_scores(scores).expect("sparsify");
        let masked_data = masked.into_data().into_vec::<f32>().expect("masked data");
        let indices = selection.indices_for(0, 0, 0);
        assert_eq!(indices, &[2]);
        for (idx, value) in masked_data.iter().enumerate() {
            if idx == 2 {
                assert!((*value - 5.0).abs() < 1e-4);
            } else {
                assert!((*value - MASK_VALUE).abs() < 1e-4);
            }
        }
    }
}