
god_graph/transformer/sparse_attention/mod.rs

//! Sparse Attention module for efficient attention computation
//!
//! This module provides various sparse attention patterns:
//! - Sliding window attention (used in Mistral)
//! - Block sparse attention
//! - Star attention
//! - Head-wise sparse attention
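//!
//! # Example
//!
//! A minimal sketch of the intended usage (marked `ignore` because it assumes
//! the surrounding crate's tensor API):
//!
//! ```ignore
//! // Build a causal sliding-window mask and wrap it in a SparseAttention module.
//! let mask = SparseMask::sliding_window(128, 16, true);
//! assert!(mask.sparsity() > 0.5);
//!
//! let mut attn = SparseAttention::sliding_window(64, 16);
//! attn.build_mask(128);
//! ```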

use crate::tensor::DenseTensor;
use crate::tensor::traits::{TensorOps, TensorBase};
use crate::tensor::sparse::SparseTensor;

/// Sparse attention pattern types
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparsePattern {
    /// Sliding window attention
    SlidingWindow,
    /// Block sparse attention
    BlockSparse,
    /// Star attention (center node attends to all)
    Star,
    /// Head-wise sparse (different heads use different patterns)
    HeadSparse,
}

/// Configuration for sliding window attention
#[derive(Debug, Clone)]
pub struct SlidingWindowConfig {
    /// Window size (number of tokens to attend to)
    pub window_size: usize,
    /// Left-only (causal) or bidirectional
    pub causal: bool,
}

impl SlidingWindowConfig {
    /// Create a new sliding window config
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            causal: true,
        }
    }

    /// Create with bidirectional attention
    pub fn bidirectional(window_size: usize) -> Self {
        Self {
            window_size,
            causal: false,
        }
    }
}

/// Configuration for block sparse attention
#[derive(Debug, Clone)]
pub struct BlockSparseConfig {
    /// Block size
    pub block_size: usize,
    /// Number of blocks to attend to per query block
    pub num_blocks: usize,
}

impl BlockSparseConfig {
    /// Create a new block sparse config
    pub fn new(block_size: usize, num_blocks: usize) -> Self {
        Self {
            block_size,
            num_blocks,
        }
    }
}

/// Sparse attention mask
#[derive(Debug, Clone)]
pub struct SparseMask {
    /// Row offsets for CSR format [seq_len + 1]
    pub row_offsets: Vec<usize>,
    /// Column indices for CSR format [nnz]
    pub col_indices: Vec<usize>,
    /// Sequence length
    pub seq_len: usize,
    /// Number of non-zero elements
    pub nnz: usize,
}

impl SparseMask {
    /// Create a sliding window mask
    ///
    /// # Arguments
    /// * `seq_len` - Sequence length
    /// * `window_size` - Window size
    /// * `causal` - Whether to use causal masking
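    ///
    /// # Example
    ///
    /// A small illustrative sketch (marked `ignore`: it relies on this
    /// crate's types being in scope):
    ///
    /// ```ignore
    /// // With a causal window of 3, position 5 attends to positions 3, 4 and 5.
    /// let mask = SparseMask::sliding_window(8, 3, true);
    /// assert_eq!(&mask.col_indices[mask.row_offsets[5]..mask.row_offsets[6]], &[3, 4, 5]);
    /// ```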
    pub fn sliding_window(seq_len: usize, window_size: usize, causal: bool) -> Self {
        let mut row_offsets = Vec::with_capacity(seq_len + 1);
        let mut col_indices = Vec::new();

        row_offsets.push(0);

        for i in 0..seq_len {
            let start = if causal {
                (i + 1).saturating_sub(window_size)
            } else {
                i.saturating_sub(window_size)
            };
            let end = if causal {
                i + 1
            } else {
                (i + window_size).min(seq_len)
            };

            for j in start..end {
                col_indices.push(j);
            }

            row_offsets.push(col_indices.len());
        }

        let nnz = col_indices.len();

        Self {
            row_offsets,
            col_indices,
            seq_len,
            nnz,
        }
    }

    /// Create a block sparse mask
    ///
    /// # Arguments
    /// * `seq_len` - Sequence length
    /// * `block_size` - Block size
    /// * `num_blocks` - Number of blocks to attend to
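    ///
    /// # Example
    ///
    /// An illustrative sketch (marked `ignore`: it relies on this crate's
    /// types being in scope):
    ///
    /// ```ignore
    /// // Block size 4, attending to the current block and one previous block.
    /// let mask = SparseMask::block_sparse(16, 4, 2);
    /// // Position 5 lives in block 1, so it attends to blocks 0 and 1 (columns 0..8).
    /// assert_eq!(mask.row_offsets[6] - mask.row_offsets[5], 8);
    /// ```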
    pub fn block_sparse(seq_len: usize, block_size: usize, num_blocks: usize) -> Self {
        let mut row_offsets = Vec::with_capacity(seq_len + 1);
        let mut col_indices = Vec::new();

        row_offsets.push(0);

        for i in 0..seq_len {
            let block_id = i / block_size;

            // Attend to the current block and up to `num_blocks - 1` previous blocks,
            // walking from the oldest attended block forward so that each row's
            // column indices stay sorted.
            let first_block = (block_id + 1).saturating_sub(num_blocks);
            for src_block in first_block..=block_id {
                let start = src_block * block_size;
                let end = (start + block_size).min(seq_len);

                for j in start..end {
                    col_indices.push(j);
                }
            }

            row_offsets.push(col_indices.len());
        }

        let nnz = col_indices.len();

        Self {
            row_offsets,
            col_indices,
            seq_len,
            nnz,
        }
    }

    /// Create a star attention mask
    ///
    /// Center tokens attend to every position; all other tokens attend to the
    /// center tokens plus a local window of up to 64 positions on either side.
    ///
    /// # Arguments
    /// * `seq_len` - Sequence length
    /// * `center_ratio` - Ratio of center tokens (e.g., 0.1 for 10%)
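    ///
    /// # Example
    ///
    /// A small sketch (marked `ignore`: it relies on this crate's types
    /// being in scope):
    ///
    /// ```ignore
    /// let mask = SparseMask::star(256, 0.1);
    /// // ceil(256 * 0.1) = 26 center tokens, each attending to all 256 positions.
    /// assert_eq!(mask.row_offsets[1], 256);
    /// ```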
    pub fn star(seq_len: usize, center_ratio: f64) -> Self {
        let num_centers = (seq_len as f64 * center_ratio).ceil() as usize;
        let mut row_offsets = Vec::with_capacity(seq_len + 1);
        let mut col_indices = Vec::new();

        row_offsets.push(0);

        for i in 0..seq_len {
            // Center tokens attend to all
            if i < num_centers {
                for j in 0..seq_len {
                    col_indices.push(j);
                }
            } else {
                // Non-center tokens attend to centers and local window
                // Attend to centers
                for j in 0..num_centers {
                    col_indices.push(j);
                }
                // Attend to local window
                let window_start = i.saturating_sub(64);
                let window_end = (i + 64).min(seq_len);
                for j in window_start..window_end {
                    // The centers were already added above, so skip them here to
                    // avoid duplicate column indices within this row.
                    if j >= num_centers {
                        col_indices.push(j);
                    }
                }
            }

            row_offsets.push(col_indices.len());
        }

        let nnz = col_indices.len();

        Self {
            row_offsets,
            col_indices,
            seq_len,
            nnz,
        }
    }

    /// Convert to sparse tensor
    pub fn to_sparse_tensor(&self, values: Vec<f64>) -> SparseTensor {
        let values_tensor = DenseTensor::new(values, vec![self.nnz]);
        SparseTensor::csr(
            self.row_offsets.clone(),
            self.col_indices.clone(),
            values_tensor,
            [self.seq_len, self.seq_len],
        )
    }

    /// Get sparsity ratio
    pub fn sparsity(&self) -> f64 {
        let total = self.seq_len * self.seq_len;
        1.0 - (self.nnz as f64 / total as f64)
    }

    /// Apply mask to attention scores
    ///
    /// Masked-out positions are set to negative infinity so that a subsequent
    /// softmax assigns them zero weight. The mask is applied to every
    /// [seq_len, seq_len] slice of the input, so batched multi-head scores are
    /// handled as well.
    ///
    /// # Arguments
    /// * `scores` - Attention scores [batch, heads, seq_len, seq_len]
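    ///
    /// # Example
    ///
    /// A small sketch (marked `ignore`: it relies on this crate's tensor API):
    ///
    /// ```ignore
    /// let mask = SparseMask::sliding_window(4, 2, true);
    /// let scores = DenseTensor::ones(vec![1, 1, 4, 4]);
    /// let masked = mask.apply(&scores);
    /// // Position (0, 3) is outside the causal window, so it is masked out.
    /// assert_eq!(masked.data()[3], f64::NEG_INFINITY);
    /// ```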
    pub fn apply(&self, scores: &DenseTensor) -> DenseTensor {
        let mut masked = scores.clone();
        let data = masked.data_mut();

        // The scores may carry leading batch/head dimensions; apply the mask
        // to each [seq_len, seq_len] matrix in the flattened buffer.
        let matrix_size = self.seq_len * self.seq_len;
        let num_matrices = data.len() / matrix_size;

        for m in 0..num_matrices {
            let base = m * matrix_size;

            for i in 0..self.seq_len {
                let start = self.row_offsets[i];
                let end = self.row_offsets[i + 1];
                let allowed = &self.col_indices[start..end];

                for j in 0..self.seq_len {
                    // Positions not present in the sparse pattern are set to -inf
                    if !allowed.contains(&j) {
                        data[base + i * self.seq_len + j] = f64::NEG_INFINITY;
                    }
                }
            }
        }

        masked
    }
}

/// Sparse attention module
#[derive(Debug, Clone)]
pub struct SparseAttention {
    /// Sparse pattern
    pub pattern: SparsePattern,
    /// Sparse mask
    pub mask: Option<SparseMask>,
    /// Window size (for sliding window)
    pub window_size: Option<usize>,
    /// Block size (for block sparse)
    pub block_size: Option<usize>,
    /// Number of blocks (for block sparse)
    pub num_blocks: Option<usize>,
    /// Center token ratio (for star attention)
    pub center_ratio: Option<f64>,
    /// Scale factor
    pub scale: f64,
}

impl SparseAttention {
    /// Create a new sparse attention module
    ///
    /// # Arguments
    /// * `pattern` - Sparse pattern type
    /// * `head_dim` - Head dimension for scaling
    pub fn new(pattern: SparsePattern, head_dim: usize) -> Self {
        Self {
            pattern,
            mask: None,
            window_size: None,
            block_size: None,
            num_blocks: None,
            center_ratio: None,
            scale: 1.0 / (head_dim as f64).sqrt(),
        }
    }

    /// Create sliding window attention
    ///
    /// # Arguments
    /// * `head_dim` - Head dimension
    /// * `window_size` - Window size
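    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`: it relies on this crate's types):
    ///
    /// ```ignore
    /// let mut attn = SparseAttention::sliding_window(64, 128);
    /// attn.build_mask(1024);
    /// assert!(attn.sparsity() > 0.8);
    /// ```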
    pub fn sliding_window(head_dim: usize, window_size: usize) -> Self {
        let mut self_ = Self::new(SparsePattern::SlidingWindow, head_dim);
        self_.window_size = Some(window_size);
        self_
    }

    /// Create block sparse attention
    ///
    /// # Arguments
    /// * `head_dim` - Head dimension
    /// * `block_size` - Block size
    /// * `num_blocks` - Number of blocks to attend to
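    ///
    /// # Example
    ///
    /// A short sketch (marked `ignore`: it relies on this crate's types):
    ///
    /// ```ignore
    /// let mut attn = SparseAttention::block_sparse(64, 32, 4);
    /// attn.build_mask(512);
    /// assert_eq!(attn.pattern, SparsePattern::BlockSparse);
    /// ```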
    pub fn block_sparse(head_dim: usize, block_size: usize, num_blocks: usize) -> Self {
        let mut self_ = Self::new(SparsePattern::BlockSparse, head_dim);
        self_.block_size = Some(block_size);
        self_.num_blocks = Some(num_blocks);
        self_
    }

    /// Create star attention
    ///
    /// # Arguments
    /// * `head_dim` - Head dimension
    /// * `center_ratio` - Ratio of center tokens
    pub fn star(head_dim: usize, center_ratio: f64) -> Self {
        let mut self_ = Self::new(SparsePattern::Star, head_dim);
        self_.center_ratio = Some(center_ratio);
        self_
    }

    /// Build sparse mask for given sequence length
    ///
    /// # Arguments
    /// * `seq_len` - Sequence length
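    ///
    /// # Example
    ///
    /// A brief sketch (marked `ignore`: it relies on this crate's types):
    ///
    /// ```ignore
    /// let mut attn = SparseAttention::new(SparsePattern::Star, 64);
    /// attn.build_mask(256);
    /// assert_eq!(attn.mask.as_ref().unwrap().seq_len, 256);
    /// ```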
    pub fn build_mask(&mut self, seq_len: usize) {
        self.mask = Some(match self.pattern {
            SparsePattern::SlidingWindow => {
                let window_size = self.window_size.unwrap_or(seq_len);
                SparseMask::sliding_window(seq_len, window_size, true)
            }
            SparsePattern::BlockSparse => {
                let block_size = self.block_size.unwrap_or(64);
                let num_blocks = self.num_blocks.unwrap_or(4);
                SparseMask::block_sparse(seq_len, block_size, num_blocks)
            }
            SparsePattern::Star => {
                let center_ratio = self.center_ratio.unwrap_or(0.1);
                SparseMask::star(seq_len, center_ratio)
            }
            SparsePattern::HeadSparse => {
                // Default to sliding window for head-sparse
                SparseMask::sliding_window(seq_len, 64, true)
            }
        });
    }

    /// Compute sparse attention
    ///
    /// # Arguments
    /// * `query` - Query tensor [batch, heads, seq_len, head_dim]
    /// * `key` - Key tensor [batch, heads, seq_len, head_dim]
    /// * `value` - Value tensor [batch, heads, seq_len, head_dim]
    ///
    /// # Returns
    /// Attention output [batch, heads, seq_len, head_dim]
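    ///
    /// # Example
    ///
    /// A minimal end-to-end sketch (marked `ignore`: it assumes this crate's
    /// `DenseTensor::ones` constructor):
    ///
    /// ```ignore
    /// let mut attn = SparseAttention::sliding_window(16, 4);
    /// let q = DenseTensor::ones(vec![1, 2, 8, 16]);
    /// let k = DenseTensor::ones(vec![1, 2, 8, 16]);
    /// let v = DenseTensor::ones(vec![1, 2, 8, 16]);
    /// let out = attn.forward(&q, &k, &v);
    /// assert_eq!(out.shape(), &[1, 2, 8, 16]);
    /// ```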
    pub fn forward(
        &mut self,
        query: &DenseTensor,
        key: &DenseTensor,
        value: &DenseTensor,
    ) -> DenseTensor {
        let seq_len = query.shape()[2];

        // Build mask if not already built
        if self.mask.is_none() || self.mask.as_ref().unwrap().seq_len != seq_len {
            self.build_mask(seq_len);
        }

        // Compute attention scores
        let key_t = key.transpose(None);
        let mut scores = query.matmul(&key_t);
        scores = scores.scale(self.scale);

        // Apply sparse mask
        if let Some(mask) = &self.mask {
            scores = mask.apply(&scores);
        }

        // Apply softmax
        let attn_weights = scores.softmax(-1);

        // Apply attention to values
        attn_weights.matmul(value)
    }

    /// Get sparsity ratio
    pub fn sparsity(&self) -> f64 {
        self.mask.as_ref().map(|m| m.sparsity()).unwrap_or(0.0)
    }
}

/// Sliding window attention helper
pub struct SlidingWindowAttention {
    window_size: usize,
    scale: f64,
}

impl SlidingWindowAttention {
    /// Create a new sliding window attention
    pub fn new(window_size: usize, head_dim: usize) -> Self {
        Self {
            window_size,
            scale: 1.0 / (head_dim as f64).sqrt(),
        }
    }

    /// Compute sliding window attention without materializing the full
    /// [seq_len, seq_len] score matrix
    ///
    /// # Arguments
    /// * `query` - Query [batch, heads, seq_len, head_dim]
    /// * `key` - Key [batch, heads, seq_len, head_dim]
    /// * `value` - Value [batch, heads, seq_len, head_dim]
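    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`: it assumes this crate's
    /// `DenseTensor::ones` constructor):
    ///
    /// ```ignore
    /// let attn = SlidingWindowAttention::new(4, 16);
    /// let q = DenseTensor::ones(vec![1, 2, 8, 16]);
    /// let out = attn.forward(&q, &q, &q);
    /// assert_eq!(out.shape(), &[1, 2, 8, 16]);
    /// ```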
    pub fn forward(&self, query: &DenseTensor, key: &DenseTensor, value: &DenseTensor) -> DenseTensor {
        let batch_size = query.shape()[0];
        let num_heads = query.shape()[1];
        let seq_len = query.shape()[2];
        let head_dim = query.shape()[3];

        let mut output_data = Vec::with_capacity(batch_size * num_heads * seq_len * head_dim);

        for b in 0..batch_size {
            for h in 0..num_heads {
                // Offset of this (batch, head) slice in the flattened buffers
                let base = b * num_heads * seq_len * head_dim + h * seq_len * head_dim;

                for i in 0..seq_len {
                    // Causal window: position i attends to itself and the
                    // previous `window_size` positions
                    let start = i.saturating_sub(self.window_size);
                    let end = i + 1;

                    let q_slice = &query.data()[base + i * head_dim..base + (i + 1) * head_dim];

                    // First pass: scaled dot-product scores over the window
                    let mut scores = Vec::with_capacity(end - start);
                    for j in start..end {
                        let k_slice = &key.data()[base + j * head_dim..base + (j + 1) * head_dim];
                        let score: f64 = q_slice.iter().zip(k_slice).map(|(q, k)| q * k).sum();
                        scores.push(score * self.scale);
                    }

                    // Second pass: numerically stable softmax-weighted sum of values
                    // (subtract the max score before exponentiating)
                    let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                    let mut attn_output = vec![0.0; head_dim];
                    let mut total_weight = 0.0;

                    for (offset, j) in (start..end).enumerate() {
                        let weight = (scores[offset] - max_score).exp();
                        let v_slice = &value.data()[base + j * head_dim..base + (j + 1) * head_dim];
                        for (out, v) in attn_output.iter_mut().zip(v_slice) {
                            *out += weight * v;
                        }
                        total_weight += weight;
                    }

                    // Normalize by the total softmax weight
                    if total_weight > 0.0 {
                        for out in attn_output.iter_mut() {
                            *out /= total_weight;
                        }
                    }

                    output_data.extend(attn_output);
                }
            }
        }

        DenseTensor::new(output_data, vec![batch_size, num_heads, seq_len, head_dim])
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sliding_window_mask() {
        let mask = SparseMask::sliding_window(10, 3, true);

        assert_eq!(mask.seq_len, 10);
        assert!(mask.nnz < 10 * 10); // Should be sparse
        assert_eq!(mask.row_offsets.len(), 11);
    }

    #[test]
    fn test_block_sparse_mask() {
        let mask = SparseMask::block_sparse(16, 4, 2);

        assert_eq!(mask.seq_len, 16);
        assert!(mask.nnz < 16 * 16);
    }

    #[test]
    fn test_star_mask() {
        let mask = SparseMask::star(20, 0.1);

        assert_eq!(mask.seq_len, 20);
        // Center tokens (ceil(20 * 0.1) = 2) attend to every position
        assert_eq!(mask.row_offsets[1], 20);
        // Non-center tokens attend to the centers plus a local window
    }

    #[test]
    fn test_sparsity_calculation() {
        let mask = SparseMask::sliding_window(100, 10, true);
        let sparsity = mask.sparsity();

        // Should be approximately 90% sparse
        assert!(sparsity > 0.8);
        assert!(sparsity < 1.0);
    }

    #[test]
    fn test_sparse_attention_sliding_window() {
        let mut attn = SparseAttention::sliding_window(64, 10);
        attn.build_mask(20);

        assert_eq!(attn.pattern, SparsePattern::SlidingWindow);
        assert!(attn.mask.is_some());
    }

    #[test]
    fn test_sliding_window_attention_forward() {
        let batch_size = 1;
        let num_heads = 2;
        let seq_len = 8;
        let head_dim = 16;

        let query = DenseTensor::ones(vec![batch_size, num_heads, seq_len, head_dim]);
        let key = DenseTensor::ones(vec![batch_size, num_heads, seq_len, head_dim]);
        let value = DenseTensor::ones(vec![batch_size, num_heads, seq_len, head_dim]);

        let attn = SlidingWindowAttention::new(4, head_dim);
        let output = attn.forward(&query, &key, &value);

        assert_eq!(output.shape(), &[batch_size, num_heads, seq_len, head_dim]);
    }
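
    // Added coverage for SparseMask::apply. This assumes DenseTensor::ones and
    // DenseTensor::data behave as used by the other tests in this module.
    #[test]
    fn test_mask_apply_masks_out_of_window_positions() {
        let mask = SparseMask::sliding_window(4, 2, true);
        let scores = DenseTensor::ones(vec![1, 1, 4, 4]);

        let masked = mask.apply(&scores);
        let data = masked.data();

        // Row 0 may only attend to position 0, so (0, 3) must be masked out
        assert_eq!(data[3], f64::NEG_INFINITY);
        // The number of unmasked entries matches the mask's nnz
        let finite = data.iter().filter(|v| v.is_finite()).count();
        assert_eq!(finite, mask.nnz);
    }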
}