1use thiserror::Error;
11
12use crate::layers::attention_fused::softmax_inplace;
13
14#[derive(Debug, Error)]
18pub enum SparseAttnError {
19 #[error("query/key/value length mismatch: q={q}, k={k}, v={v}")]
20 LengthMismatch { q: usize, k: usize, v: usize },
21 #[error("head_dim must be > 0")]
22 InvalidHeadDim,
23 #[error("window_size must be odd for symmetric windows")]
24 WindowSizeMustBeOdd,
25 #[error("empty attention: no valid (q,k) pairs")]
26 EmptyAttention,
27}
28
/// Sparsity pattern consumed by [`SparseAttentionMask::build`].
///
/// Every windowed variant requires an odd `window_size`, so the window can be
/// centred on the query position (`window_size / 2` keys on each side).
///
/// All fields are plain integers, so the pattern is `Eq + Hash` and can be used
/// as a map/set key (e.g. to cache built masks per pattern).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SparsePattern {
    /// Each query attends to the keys within `window_size / 2` positions on
    /// either side of it, clipped to the sequence.
    LocalWindow { window_size: usize },

    /// BigBird-style pattern: a local window, plus the first
    /// `num_global_tokens` positions acting as global tokens (they attend to
    /// everything and everything attends to them), plus up to
    /// `num_random_connections` pseudo-random keys per query derived
    /// deterministically from `seed`.
    BigBird {
        window_size: usize,
        num_global_tokens: usize,
        num_random_connections: usize,
        seed: u64,
    },

    /// Local window plus "anchor" positions at every multiple of `stride`;
    /// anchors attend to the whole sequence. A `stride` of zero falls back to
    /// dense attention.
    Strided { window_size: usize, stride: usize },

    /// Every query attends to every key.
    Dense,
}
48
49pub struct SparseAttentionMask {
53 pub seq_len: usize,
55 attend_to: Vec<Vec<usize>>,
57 pub pattern: SparsePattern,
59}
60
61impl SparseAttentionMask {
62 pub fn build(seq_len: usize, pattern: &SparsePattern) -> Result<Self, SparseAttnError> {
66 let attend_to = match pattern {
67 SparsePattern::Dense => build_dense(seq_len),
68 SparsePattern::LocalWindow { window_size } => {
69 build_local_window(seq_len, *window_size)?
70 }
71 SparsePattern::BigBird {
72 window_size,
73 num_global_tokens,
74 num_random_connections,
75 seed,
76 } => build_bigbird(
77 seq_len,
78 *window_size,
79 *num_global_tokens,
80 *num_random_connections,
81 *seed,
82 )?,
83 SparsePattern::Strided {
84 window_size,
85 stride,
86 } => build_strided(seq_len, *window_size, *stride)?,
87 };
88
89 Ok(Self {
90 seq_len,
91 attend_to,
92 pattern: pattern.clone(),
93 })
94 }
95
96 pub fn keys_for_query(&self, q: usize) -> &[usize] {
98 if q >= self.seq_len {
99 return &[];
100 }
101 &self.attend_to[q]
102 }
103
104 pub fn nnz(&self) -> usize {
106 self.attend_to.iter().map(|v| v.len()).sum()
107 }
108
109 pub fn density(&self) -> f32 {
111 let total = (self.seq_len as f64) * (self.seq_len as f64);
112 if total == 0.0 {
113 return 0.0;
114 }
115 (self.nnz() as f64 / total) as f32
116 }
117
118 pub fn can_attend(&self, q: usize, k: usize) -> bool {
120 if q >= self.seq_len || k >= self.seq_len {
121 return false;
122 }
123 self.attend_to[q].binary_search(&k).is_ok()
124 }
125
126 pub fn to_dense(&self) -> Vec<Vec<bool>> {
130 let n = self.seq_len;
131 let mut mask = vec![vec![false; n]; n];
132 for (q, keys) in self.attend_to.iter().enumerate() {
133 for &k in keys {
134 mask[q][k] = true;
135 }
136 }
137 mask
138 }
139}
140
/// Builds the fully connected pattern: every one of the `seq_len` rows lists
/// all key positions `0..seq_len` in ascending order.
fn build_dense(seq_len: usize) -> Vec<Vec<usize>> {
    let full_row: Vec<usize> = (0..seq_len).collect();
    vec![full_row; seq_len]
}
147
148fn build_local_window(
150 seq_len: usize,
151 window_size: usize,
152) -> Result<Vec<Vec<usize>>, SparseAttnError> {
153 if window_size % 2 == 0 {
154 return Err(SparseAttnError::WindowSizeMustBeOdd);
155 }
156 let half = window_size / 2;
157 let mut attend_to = Vec::with_capacity(seq_len);
158 for q in 0..seq_len {
159 let start = q.saturating_sub(half);
160 let end = (q + half + 1).min(seq_len);
161 attend_to.push((start..end).collect());
162 }
163 Ok(attend_to)
164}
165
166fn build_bigbird(
171 seq_len: usize,
172 window_size: usize,
173 num_global_tokens: usize,
174 num_random_connections: usize,
175 seed: u64,
176) -> Result<Vec<Vec<usize>>, SparseAttnError> {
177 if window_size % 2 == 0 {
178 return Err(SparseAttnError::WindowSizeMustBeOdd);
179 }
180 let half = window_size / 2;
181 let actual_global = num_global_tokens.min(seq_len);
183
184 let mut attend_to: Vec<Vec<usize>> = Vec::with_capacity(seq_len);
185 let mut lcg_state = seed.wrapping_add(0xDEAD_BEEF_CAFE_1234);
186
187 for q in 0..seq_len {
188 let mut keys: std::collections::BTreeSet<usize> = std::collections::BTreeSet::new();
189
190 for g in 0..actual_global {
192 keys.insert(g);
193 }
194 for g in 0..actual_global {
196 if q == g {
197 for k in 0..seq_len {
199 keys.insert(k);
200 }
201 }
202 }
203
204 let start = q.saturating_sub(half);
206 let end = (q + half + 1).min(seq_len);
207 for k in start..end {
208 keys.insert(k);
209 }
210
211 let num_rand = if seq_len > actual_global + window_size {
213 num_random_connections
214 } else {
215 0
216 };
217 for r in 0..num_rand {
218 lcg_state = lcg_state
220 .wrapping_mul(6_364_136_223_846_793_005)
221 .wrapping_add(1_442_695_040_888_963_407)
222 .wrapping_add((q as u64).wrapping_mul(137).wrapping_add(r as u64));
223 let k = (lcg_state >> 33) as usize % seq_len;
224 keys.insert(k);
225 }
226
227 attend_to.push(keys.into_iter().collect());
228 }
229
230 Ok(attend_to)
231}
232
233fn build_strided(
235 seq_len: usize,
236 window_size: usize,
237 stride: usize,
238) -> Result<Vec<Vec<usize>>, SparseAttnError> {
239 if window_size % 2 == 0 {
240 return Err(SparseAttnError::WindowSizeMustBeOdd);
241 }
242 if stride == 0 {
243 return Ok(build_dense(seq_len));
245 }
246 let half = window_size / 2;
247
248 let mut attend_to = Vec::with_capacity(seq_len);
249 for q in 0..seq_len {
250 let is_global = (q % stride) == 0;
251 let mut keys: Vec<usize> = if is_global {
252 (0..seq_len).collect()
254 } else {
255 let start = q.saturating_sub(half);
257 let end = (q + half + 1).min(seq_len);
258 let mut ks: std::collections::BTreeSet<usize> = (start..end).collect();
260 let mut g = 0usize;
261 while g < seq_len {
262 ks.insert(g);
263 g += stride;
264 }
265 ks.into_iter().collect()
266 };
267 keys.sort_unstable();
268 keys.dedup();
269 attend_to.push(keys);
270 }
271 Ok(attend_to)
272}
273
274pub fn sparse_attention_forward(
284 queries: &[f32],
285 keys: &[f32],
286 values: &[f32],
287 seq_len: usize,
288 head_dim: usize,
289 mask: &SparseAttentionMask,
290 scale: f32,
291) -> Result<Vec<f32>, SparseAttnError> {
292 validate_inputs(queries, keys, values, seq_len, head_dim)?;
293
294 if mask.nnz() == 0 {
295 return Err(SparseAttnError::EmptyAttention);
296 }
297
298 let mut output = vec![0.0f32; seq_len * head_dim];
299
300 for q in 0..seq_len {
301 let key_positions = mask.keys_for_query(q);
302 if key_positions.is_empty() {
303 continue;
305 }
306
307 let q_vec = &queries[q * head_dim..(q + 1) * head_dim];
308
309 let mut scores: Vec<f32> = key_positions
311 .iter()
312 .map(|&k| {
313 let k_vec = &keys[k * head_dim..(k + 1) * head_dim];
314 dot_scaled(q_vec, k_vec, scale)
315 })
316 .collect();
317
318 softmax_inplace(&mut scores);
320
321 let out_row = &mut output[q * head_dim..(q + 1) * head_dim];
323 for (weight, &k_pos) in scores.iter().zip(key_positions.iter()) {
324 let v_vec = &values[k_pos * head_dim..(k_pos + 1) * head_dim];
325 for (o, &v) in out_row.iter_mut().zip(v_vec.iter()) {
326 *o += weight * v;
327 }
328 }
329 }
330
331 Ok(output)
332}
333
334pub fn sparse_vs_dense_error(
339 queries: &[f32],
340 keys: &[f32],
341 values: &[f32],
342 seq_len: usize,
343 head_dim: usize,
344 mask: &SparseAttentionMask,
345) -> Result<f32, SparseAttnError> {
346 let scale = 1.0 / (head_dim as f32).sqrt();
347
348 let sparse_out =
349 sparse_attention_forward(queries, keys, values, seq_len, head_dim, mask, scale)?;
350
351 let dense_mask = SparseAttentionMask::build(seq_len, &SparsePattern::Dense)
352 .map_err(|_| SparseAttnError::EmptyAttention)?;
353 let dense_out =
354 sparse_attention_forward(queries, keys, values, seq_len, head_dim, &dense_mask, scale)?;
355
356 let total_elements = seq_len * head_dim;
357 if total_elements == 0 {
358 return Ok(0.0);
359 }
360
361 let mae = sparse_out
362 .iter()
363 .zip(dense_out.iter())
364 .map(|(s, d)| (s - d).abs())
365 .sum::<f32>()
366 / total_elements as f32;
367
368 Ok(mae)
369}
370
371pub fn memory_reduction(_seq_len: usize, mask: &SparseAttentionMask) -> f32 {
376 1.0 - mask.density()
377}
378
379fn validate_inputs(
383 queries: &[f32],
384 keys: &[f32],
385 values: &[f32],
386 seq_len: usize,
387 head_dim: usize,
388) -> Result<(), SparseAttnError> {
389 if head_dim == 0 {
390 return Err(SparseAttnError::InvalidHeadDim);
391 }
392 let expected = seq_len * head_dim;
393 if queries.len() != expected || keys.len() != expected || values.len() != expected {
394 return Err(SparseAttnError::LengthMismatch {
395 q: queries.len(),
396 k: keys.len(),
397 v: values.len(),
398 });
399 }
400 Ok(())
401}
402
/// Dot product of `a` and `b`, multiplied by `scale`.
///
/// Elements are paired with `zip`, so any surplus elements in the longer slice
/// are ignored.
#[inline]
fn dot_scaled(a: &[f32], b: &[f32], scale: f32) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot * scale
}
408
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, non-trivial Q/K/V buffers of `seq_len x head_dim` each.
    fn make_qkv(seq_len: usize, head_dim: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let n = seq_len * head_dim;
        let q: Vec<f32> = (0..n).map(|i| (i as f32 * 0.03) - 0.5).collect();
        let k: Vec<f32> = (0..n)
            .map(|i| ((i * 7 + 3) % 17) as f32 * 0.04 - 0.3)
            .collect();
        let v: Vec<f32> = (0..n)
            .map(|i| ((i * 11 + 5) % 13) as f32 * 0.05 - 0.3)
            .collect();
        (q, k, v)
    }

    /// Reference dense attention computed directly, independent of the mask code.
    fn naive_dense_attention(
        q: &[f32],
        k: &[f32],
        v: &[f32],
        seq_len: usize,
        head_dim: usize,
        scale: f32,
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; seq_len * head_dim];
        for i in 0..seq_len {
            let qi = &q[i * head_dim..(i + 1) * head_dim];
            let scores: Vec<f32> = (0..seq_len)
                .map(|j| {
                    let kj = &k[j * head_dim..(j + 1) * head_dim];
                    qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() * scale
                })
                .collect();
            let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
            let denom: f32 = exps.iter().sum();
            for (j, e) in exps.iter().enumerate() {
                let w = e / denom;
                for d in 0..head_dim {
                    out[i * head_dim + d] += w * v[j * head_dim + d];
                }
            }
        }
        out
    }

    #[test]
    fn dense_mask_full() {
        let seq_len = 8;
        let mask = SparseAttentionMask::build(seq_len, &SparsePattern::Dense)
            .expect("dense build should succeed");
        assert_eq!(mask.nnz(), seq_len * seq_len);
    }

    #[test]
    fn local_window_density_less_than_one() {
        let seq_len = 16;
        let mask =
            SparseAttentionMask::build(seq_len, &SparsePattern::LocalWindow { window_size: 3 })
                .expect("local window build should succeed");
        assert!(
            mask.density() < 1.0,
            "density should be < 1.0 for local window"
        );
    }

    #[test]
    fn local_window_rows_are_clipped() {
        let mask = SparseAttentionMask::build(5, &SparsePattern::LocalWindow { window_size: 3 })
            .expect("local window build should succeed");
        assert_eq!(mask.keys_for_query(0), &[0, 1][..]);
        assert_eq!(mask.keys_for_query(2), &[1, 2, 3][..]);
        assert_eq!(mask.keys_for_query(4), &[3, 4][..]);
        assert!(mask.keys_for_query(5).is_empty());
        assert!(mask.can_attend(2, 3));
        assert!(!mask.can_attend(0, 2));
        assert!(!mask.can_attend(0, 99));
    }

    #[test]
    fn even_window_is_rejected() {
        for pattern in [
            SparsePattern::LocalWindow { window_size: 4 },
            SparsePattern::LocalWindow { window_size: 0 },
            SparsePattern::Strided {
                window_size: 2,
                stride: 3,
            },
            SparsePattern::BigBird {
                window_size: 2,
                num_global_tokens: 1,
                num_random_connections: 1,
                seed: 0,
            },
        ] {
            let res = SparseAttentionMask::build(8, &pattern);
            assert!(matches!(res, Err(SparseAttnError::WindowSizeMustBeOdd)));
        }
    }

    #[test]
    fn strided_zero_stride_is_dense() {
        let mask = SparseAttentionMask::build(
            6,
            &SparsePattern::Strided {
                window_size: 1,
                stride: 0,
            },
        )
        .expect("strided build should succeed");
        assert!(mask.to_dense().iter().all(|row| row.iter().all(|&b| b)));
    }

    #[test]
    fn strided_rows_combine_window_and_anchors() {
        let mask = SparseAttentionMask::build(
            8,
            &SparsePattern::Strided {
                window_size: 1,
                stride: 4,
            },
        )
        .expect("strided build should succeed");
        assert_eq!(mask.keys_for_query(0).len(), 8);
        assert_eq!(mask.keys_for_query(4).len(), 8);
        assert_eq!(mask.keys_for_query(1), &[0, 1, 4][..]);
        assert!(!mask.can_attend(1, 2));
    }

    #[test]
    fn bigbird_is_deterministic_and_sorted() {
        let pattern = SparsePattern::BigBird {
            window_size: 3,
            num_global_tokens: 2,
            num_random_connections: 2,
            seed: 42,
        };
        let seq_len = 32;
        let a = SparseAttentionMask::build(seq_len, &pattern).expect("bigbird build");
        let b = SparseAttentionMask::build(seq_len, &pattern).expect("bigbird build");
        assert_eq!(a.to_dense(), b.to_dense());

        // Global rows attend everywhere; every row attends to the globals.
        assert_eq!(a.keys_for_query(0).len(), seq_len);
        assert_eq!(a.keys_for_query(1).len(), seq_len);
        for q in 0..seq_len {
            let row = a.keys_for_query(q);
            assert!(row.windows(2).all(|w| w[0] < w[1]), "row {q} not sorted/unique");
            assert!(row.iter().all(|&k| k < seq_len));
            assert!(a.can_attend(q, 0) && a.can_attend(q, 1));
        }
        assert!(a.can_attend(10, 9) && a.can_attend(10, 10) && a.can_attend(10, 11));
    }

    #[test]
    fn sparse_forward_dense_matches_naive_inline() {
        let seq_len = 4;
        let head_dim = 4;
        let (q, k, v) = make_qkv(seq_len, head_dim);
        let scale = 1.0 / (head_dim as f32).sqrt();
        let mask = SparseAttentionMask::build(seq_len, &SparsePattern::Dense).expect("dense mask");
        let out = sparse_attention_forward(&q, &k, &v, seq_len, head_dim, &mask, scale)
            .expect("sparse forward failed");
        assert_eq!(out.len(), seq_len * head_dim);

        let expected = naive_dense_attention(&q, &k, &v, seq_len, head_dim, scale);
        for (i, (got, want)) in out.iter().zip(&expected).enumerate() {
            assert!(
                (got - want).abs() < 1e-5,
                "element {i}: sparse={got}, naive={want}"
            );
        }
    }

    #[test]
    fn full_coverage_window_matches_dense() {
        let seq_len = 6;
        let head_dim = 4;
        let (q, k, v) = make_qkv(seq_len, head_dim);
        // window 11 => half 5 => every row covers the whole sequence.
        let mask =
            SparseAttentionMask::build(seq_len, &SparsePattern::LocalWindow { window_size: 11 })
                .expect("local window build");
        let err = sparse_vs_dense_error(&q, &k, &v, seq_len, head_dim, &mask)
            .expect("error computation failed");
        assert!(err.abs() < 1e-7, "expected ~0 error, got {err}");
    }

    #[test]
    fn forward_rejects_bad_inputs() {
        let seq_len = 4;
        let head_dim = 4;
        let (q, k, v) = make_qkv(seq_len, head_dim);
        let mask = SparseAttentionMask::build(seq_len, &SparsePattern::Dense).expect("dense mask");

        let short = &q[..q.len() - 1];
        let res = sparse_attention_forward(short, &k, &v, seq_len, head_dim, &mask, 1.0);
        assert!(matches!(res, Err(SparseAttnError::LengthMismatch { .. })));

        let res = sparse_attention_forward(&[], &[], &[], seq_len, 0, &mask, 1.0);
        assert!(matches!(res, Err(SparseAttnError::InvalidHeadDim)));
    }
}