god_graph/transformer/sparse_attention/mod.rs

use crate::tensor::DenseTensor;
use crate::tensor::traits::{TensorOps, TensorBase};
use crate::tensor::sparse::SparseTensor;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparsePattern {
    /// Each position attends to a local window of nearby positions.
    SlidingWindow,
    /// Each position attends to its own block and a few preceding blocks.
    BlockSparse,
    /// A few "center" positions attend globally; the rest attend to the
    /// centers plus a local window.
    Star,
    /// Per-head sparsity; currently built as a causal sliding window
    /// (see `SparseAttention::build_mask`).
    HeadSparse,
}

#[derive(Debug, Clone)]
pub struct SlidingWindowConfig {
    /// Number of positions each query may attend to.
    pub window_size: usize,
    /// If true, the window covers only the current and earlier positions.
    pub causal: bool,
}

impl SlidingWindowConfig {
    /// Causal sliding window (the default).
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            causal: true,
        }
    }

    /// Bidirectional sliding window.
    pub fn bidirectional(window_size: usize) -> Self {
        Self {
            window_size,
            causal: false,
        }
    }
}

#[derive(Debug, Clone)]
pub struct BlockSparseConfig {
    /// Side length of each attention block.
    pub block_size: usize,
    /// Number of blocks (the current one included) each block attends to.
    pub num_blocks: usize,
}

impl BlockSparseConfig {
    pub fn new(block_size: usize, num_blocks: usize) -> Self {
        Self {
            block_size,
            num_blocks,
        }
    }
}

/// An attention mask in CSR (compressed sparse row) form: row `i`'s allowed
/// columns are `col_indices[row_offsets[i]..row_offsets[i + 1]]`.
#[derive(Debug, Clone)]
pub struct SparseMask {
    pub row_offsets: Vec<usize>,
    pub col_indices: Vec<usize>,
    pub seq_len: usize,
    /// Number of unmasked entries; always equals `col_indices.len()`.
    pub nnz: usize,
}
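
// Illustrative CSR layout, derived from `sliding_window` below: for
// seq_len = 4, window_size = 2, causal, each row keeps at most the two most
// recent positions, giving
//
//     row_offsets = [0, 1, 3, 5, 7]
//     col_indices = [0, 0, 1, 1, 2, 2, 3]
//
// so row 2 (entries 3..5) attends to columns 1 and 2.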

impl SparseMask {
    /// Sliding-window mask. Causal rows cover at most the `window_size` most
    /// recent positions ending at `i`; bidirectional rows cover up to
    /// `window_size` positions on each side of `i`.
    pub fn sliding_window(seq_len: usize, window_size: usize, causal: bool) -> Self {
        let mut row_offsets = Vec::with_capacity(seq_len + 1);
        let mut col_indices = Vec::new();

        row_offsets.push(0);

        for i in 0..seq_len {
            let start = if causal {
                (i + 1).saturating_sub(window_size)
            } else {
                i.saturating_sub(window_size)
            };
            let end = if causal {
                i + 1
            } else {
                // `+ 1` keeps the bidirectional window symmetric around `i`.
                (i + window_size + 1).min(seq_len)
            };

            for j in start..end {
                col_indices.push(j);
            }

            row_offsets.push(col_indices.len());
        }

        let nnz = col_indices.len();

        Self {
            row_offsets,
            col_indices,
            seq_len,
            nnz,
        }
    }

    /// Block-sparse mask: row `i` attends to its own block and up to
    /// `num_blocks - 1` preceding blocks of `block_size` positions each.
    pub fn block_sparse(seq_len: usize, block_size: usize, num_blocks: usize) -> Self {
        let mut row_offsets = Vec::with_capacity(seq_len + 1);
        let mut col_indices = Vec::new();

        row_offsets.push(0);

        for i in 0..seq_len {
            let block_id = i / block_size;

            // Walk backwards from the current block over at most `num_blocks` blocks.
            for b in 0..num_blocks.min(block_id + 1) {
                let src_block = block_id - b;
                let start = src_block * block_size;
                let end = (start + block_size).min(seq_len);

                for j in start..end {
                    col_indices.push(j);
                }
            }

            row_offsets.push(col_indices.len());
        }

        let nnz = col_indices.len();

        Self {
            row_offsets,
            col_indices,
            seq_len,
            nnz,
        }
    }
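
    // Illustrative layout, derived from `block_sparse` above with
    // seq_len = 16, block_size = 4, num_blocks = 2: row 0 sits in block 0 and
    // attends only to columns 0..4, while row 5 sits in block 1 and attends to
    // block 1 (columns 4..8) followed by block 0 (columns 0..4).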

    /// Star mask: roughly `center_ratio * seq_len` leading "center" positions
    /// attend to the whole sequence; every other position attends to the
    /// centers plus a fixed local window of 64 positions on each side.
    pub fn star(seq_len: usize, center_ratio: f64) -> Self {
        let num_centers = (seq_len as f64 * center_ratio).ceil() as usize;
        let mut row_offsets = Vec::with_capacity(seq_len + 1);
        let mut col_indices = Vec::new();

        row_offsets.push(0);

        for i in 0..seq_len {
            let row_start = col_indices.len();

            if i < num_centers {
                // Center positions attend globally.
                for j in 0..seq_len {
                    col_indices.push(j);
                }
            } else {
                // Everyone attends to the centers...
                for j in 0..num_centers {
                    col_indices.push(j);
                }
                // ...plus a local window, deduplicated against this row only.
                // (Scanning all of `col_indices` here would also match entries
                // from earlier rows and silently drop the window.)
                let window_start = i.saturating_sub(64);
                let window_end = (i + 64).min(seq_len);
                for j in window_start..window_end {
                    if !col_indices[row_start..].contains(&j) {
                        col_indices.push(j);
                    }
                }
            }

            row_offsets.push(col_indices.len());
        }

        let nnz = col_indices.len();

        Self {
            row_offsets,
            col_indices,
            seq_len,
            nnz,
        }
    }

    /// Pack the mask and per-entry `values` (length `self.nnz`) into a CSR
    /// `SparseTensor`.
    pub fn to_sparse_tensor(&self, values: Vec<f64>) -> SparseTensor {
        let values_tensor = DenseTensor::new(values, vec![self.nnz]);
        SparseTensor::csr(
            self.row_offsets.clone(),
            self.col_indices.clone(),
            values_tensor,
            [self.seq_len, self.seq_len],
        )
    }

    /// Fraction of the full `seq_len * seq_len` score matrix that is masked out.
    pub fn sparsity(&self) -> f64 {
        let total = self.seq_len * self.seq_len;
        1.0 - (self.nnz as f64 / total as f64)
    }

    /// Set every masked-out score to negative infinity so that softmax drives
    /// its weight to zero. Treats `scores` as one or more stacked
    /// `seq_len * seq_len` matrices (e.g. `[batch, heads, seq, seq]`) and
    /// applies the same mask to each.
    pub fn apply(&self, scores: &DenseTensor) -> DenseTensor {
        let mut masked = scores.clone();
        let data = masked.data_mut();
        let matrix_size = self.seq_len * self.seq_len;
        let num_matrices = data.len() / matrix_size;

        for m in 0..num_matrices {
            let base = m * matrix_size;

            for i in 0..self.seq_len {
                let start = self.row_offsets[i];
                let end = self.row_offsets[i + 1];

                for j in 0..self.seq_len {
                    // Linear scan over this row's allowed columns; fine for
                    // the short rows these patterns produce.
                    if !self.col_indices[start..end].contains(&j) {
                        data[base + i * self.seq_len + j] = f64::NEG_INFINITY;
                    }
                }
            }
        }

        masked
    }
}

#[derive(Debug, Clone)]
pub struct SparseAttention {
    pub pattern: SparsePattern,
    /// Built lazily (and rebuilt on sequence-length changes) by `build_mask`.
    pub mask: Option<SparseMask>,
    pub window_size: Option<usize>,
    pub block_size: Option<usize>,
    pub num_blocks: Option<usize>,
    pub center_ratio: Option<f64>,
    /// `1 / sqrt(head_dim)`, the usual scaled-dot-product factor.
    pub scale: f64,
}

impl SparseAttention {
    pub fn new(pattern: SparsePattern, head_dim: usize) -> Self {
        Self {
            pattern,
            mask: None,
            window_size: None,
            block_size: None,
            num_blocks: None,
            center_ratio: None,
            scale: 1.0 / (head_dim as f64).sqrt(),
        }
    }

    pub fn sliding_window(head_dim: usize, window_size: usize) -> Self {
        let mut attn = Self::new(SparsePattern::SlidingWindow, head_dim);
        attn.window_size = Some(window_size);
        attn
    }

    pub fn block_sparse(head_dim: usize, block_size: usize, num_blocks: usize) -> Self {
        let mut attn = Self::new(SparsePattern::BlockSparse, head_dim);
        attn.block_size = Some(block_size);
        attn.num_blocks = Some(num_blocks);
        attn
    }

    pub fn star(head_dim: usize, center_ratio: f64) -> Self {
        let mut attn = Self::new(SparsePattern::Star, head_dim);
        attn.center_ratio = Some(center_ratio);
        attn
    }

    /// Build (or rebuild) the mask for `seq_len`, falling back to defaults
    /// for any unset pattern parameters.
    pub fn build_mask(&mut self, seq_len: usize) {
        self.mask = Some(match self.pattern {
            SparsePattern::SlidingWindow => {
                let window_size = self.window_size.unwrap_or(seq_len);
                SparseMask::sliding_window(seq_len, window_size, true)
            }
            SparsePattern::BlockSparse => {
                let block_size = self.block_size.unwrap_or(64);
                let num_blocks = self.num_blocks.unwrap_or(4);
                SparseMask::block_sparse(seq_len, block_size, num_blocks)
            }
            SparsePattern::Star => {
                let center_ratio = self.center_ratio.unwrap_or(0.1);
                SparseMask::star(seq_len, center_ratio)
            }
            SparsePattern::HeadSparse => {
                // No dedicated head-sparse mask yet; fall back to a causal
                // sliding window.
                SparseMask::sliding_window(seq_len, 64, true)
            }
        });
    }

    /// Scaled dot-product attention with the sparse mask applied to the
    /// pre-softmax scores. Expects `[batch, heads, seq, head_dim]` inputs.
    pub fn forward(
        &mut self,
        query: &DenseTensor,
        key: &DenseTensor,
        value: &DenseTensor,
    ) -> DenseTensor {
        let seq_len = query.shape()[2];

        // Rebuild the mask if it is missing or was built for another length.
        if self.mask.is_none() || self.mask.as_ref().unwrap().seq_len != seq_len {
            self.build_mask(seq_len);
        }

        let key_t = key.transpose(None);
        let mut scores = query.matmul(&key_t);
        scores = scores.scale(self.scale);

        if let Some(mask) = &self.mask {
            scores = mask.apply(&scores);
        }

        let attn_weights = scores.softmax(-1);

        attn_weights.matmul(value)
    }

    /// Sparsity of the current mask, or `0.0` if no mask has been built yet.
    pub fn sparsity(&self) -> f64 {
        self.mask.as_ref().map(|m| m.sparsity()).unwrap_or(0.0)
    }
}
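
// Typical usage, sketched with the constructors above (the tensor shapes are
// illustrative):
//
//     let mut attn = SparseAttention::sliding_window(64, 128);
//     let q = DenseTensor::ones(vec![1, 8, 512, 64]);
//     let (k, v) = (q.clone(), q.clone());
//     let out = attn.forward(&q, &k, &v); // mask built lazily for seq_len 512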

/// Causal sliding-window attention computed directly per position, without
/// materializing (and then masking) the full dense score matrix.
pub struct SlidingWindowAttention {
    window_size: usize,
    scale: f64,
}

impl SlidingWindowAttention {
    pub fn new(window_size: usize, head_dim: usize) -> Self {
        Self {
            window_size,
            scale: 1.0 / (head_dim as f64).sqrt(),
        }
    }

    /// Expects `[batch, heads, seq, head_dim]` inputs; each query position `i`
    /// attends to positions `i - window_size ..= i`.
    pub fn forward(&self, query: &DenseTensor, key: &DenseTensor, value: &DenseTensor) -> DenseTensor {
        let batch_size = query.shape()[0];
        let num_heads = query.shape()[1];
        let seq_len = query.shape()[2];
        let head_dim = query.shape()[3];

        let mut output_data = Vec::with_capacity(batch_size * num_heads * seq_len * head_dim);

        for b in 0..batch_size {
            for h in 0..num_heads {
                let base = b * num_heads * seq_len * head_dim + h * seq_len * head_dim;

                for i in 0..seq_len {
                    let start = i.saturating_sub(self.window_size);
                    let end = i + 1;
                    let q_slice = &query.data()[base + i * head_dim..];

                    // First pass: scaled dot-product scores over the window,
                    // tracking the maximum for a numerically stable softmax.
                    let mut scores = Vec::with_capacity(end - start);
                    let mut max_score = f64::NEG_INFINITY;
                    for j in start..end {
                        let k_slice = &key.data()[base + j * head_dim..];
                        let mut score = 0.0;
                        for d in 0..head_dim {
                            score += q_slice[d] * k_slice[d];
                        }
                        score *= self.scale;
                        max_score = max_score.max(score);
                        scores.push(score);
                    }

                    // Second pass: softmax-weighted sum of the window's values.
                    let mut attn_output = vec![0.0; head_dim];
                    let mut total_weight = 0.0;
                    for (s, j) in (start..end).enumerate() {
                        let weight = (scores[s] - max_score).exp();
                        let v_slice = &value.data()[base + j * head_dim..];
                        #[allow(clippy::needless_range_loop)]
                        for d in 0..head_dim {
                            attn_output[d] += weight * v_slice[d];
                        }
                        total_weight += weight;
                    }

                    if total_weight > 0.0 {
                        #[allow(clippy::needless_range_loop)]
                        for d in 0..head_dim {
                            attn_output[d] /= total_weight;
                        }
                    }

                    output_data.extend(attn_output);
                }
            }
        }

        DenseTensor::new(output_data, vec![batch_size, num_heads, seq_len, head_dim])
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sliding_window_mask() {
        let mask = SparseMask::sliding_window(10, 3, true);

        assert_eq!(mask.seq_len, 10);
        assert!(mask.nnz < 10 * 10);
        assert_eq!(mask.row_offsets.len(), 11);
    }
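
    // Pins down the worked CSR example given above the `SparseMask` impl
    // (seq_len = 4, window_size = 2, causal).
    #[test]
    fn test_sliding_window_mask_layout() {
        let mask = SparseMask::sliding_window(4, 2, true);

        assert_eq!(mask.row_offsets, vec![0, 1, 3, 5, 7]);
        assert_eq!(mask.col_indices, vec![0, 0, 1, 1, 2, 2, 3]);
    }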

    #[test]
    fn test_block_sparse_mask() {
        let mask = SparseMask::block_sparse(16, 4, 2);

        assert_eq!(mask.seq_len, 16);
        assert!(mask.nnz < 16 * 16);
    }
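
    // Checks the per-row widths sketched in the comment after `block_sparse`:
    // row 0 sees only its own block; row 5 also sees the previous block.
    #[test]
    fn test_block_sparse_row_widths() {
        let mask = SparseMask::block_sparse(16, 4, 2);

        assert_eq!(mask.row_offsets[1] - mask.row_offsets[0], 4);
        assert_eq!(mask.row_offsets[6] - mask.row_offsets[5], 8);
    }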

    #[test]
    fn test_star_mask() {
        let mask = SparseMask::star(20, 0.1);

        assert_eq!(mask.seq_len, 20);
    }
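
    // Exercises the star pattern at a length where the fixed 64-wide local
    // window does not cover the whole sequence: with seq_len = 200 and
    // center_ratio = 0.05 there are 10 centers, so row 0 is global (200
    // entries) and row 100 sees the 10 centers plus window positions 36..164.
    #[test]
    fn test_star_mask_row_widths() {
        let mask = SparseMask::star(200, 0.05);

        assert_eq!(mask.row_offsets[1] - mask.row_offsets[0], 200);
        assert_eq!(mask.row_offsets[101] - mask.row_offsets[100], 10 + 128);
    }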

    #[test]
    fn test_sparsity_calculation() {
        let mask = SparseMask::sliding_window(100, 10, true);
        let sparsity = mask.sparsity();

        assert!(sparsity > 0.8);
        assert!(sparsity < 1.0);
    }

    #[test]
    fn test_sparse_attention_sliding_window() {
        let mut attn = SparseAttention::sliding_window(64, 10);
        attn.build_mask(20);

        assert_eq!(attn.pattern, SparsePattern::SlidingWindow);
        assert!(attn.mask.is_some());
    }
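
    // A narrow window over a longer sequence should report high sparsity once
    // the mask is built; before that, `sparsity` falls back to 0.0.
    #[test]
    fn test_sparse_attention_sparsity() {
        let mut attn = SparseAttention::sliding_window(64, 4);
        assert_eq!(attn.sparsity(), 0.0);

        attn.build_mask(64);
        assert!(attn.sparsity() > 0.9);
    }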

    #[test]
    fn test_sliding_window_attention_forward() {
        let batch_size = 1;
        let num_heads = 2;
        let seq_len = 8;
        let head_dim = 16;

        let query = DenseTensor::ones(vec![batch_size, num_heads, seq_len, head_dim]);
        let key = DenseTensor::ones(vec![batch_size, num_heads, seq_len, head_dim]);
        let value = DenseTensor::ones(vec![batch_size, num_heads, seq_len, head_dim]);

        let attn = SlidingWindowAttention::new(4, head_dim);
        let output = attn.forward(&query, &key, &value);

        assert_eq!(output.shape(), &[batch_size, num_heads, seq_len, head_dim]);
    }
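
    // Checks `SparseMask::apply` on a raw seq_len x seq_len score matrix,
    // using only the `DenseTensor` constructors and accessors already used
    // elsewhere in this module.
    #[test]
    fn test_mask_apply_blocks_out_of_window() {
        let mask = SparseMask::sliding_window(4, 2, true);
        let scores = DenseTensor::new(vec![1.0; 16], vec![4, 4]);
        let masked = mask.apply(&scores);

        // (0, 3) is a future position and must be masked out...
        assert_eq!(masked.data()[3], f64::NEG_INFINITY);
        // ...while the diagonal entry (3, 3) survives.
        assert_eq!(masked.data()[15], 1.0);
    }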
}