Skip to main content

oxibonsai_model/layers/
sparse_attention.rs

1//! Sparse attention patterns for efficient long-sequence processing.
2//!
3//! Implements three attention mask patterns:
4//! - Local window attention (Longformer-style)
5//! - Global + local (BigBird-style)
6//! - Strided sparse attention (every k-th token attends globally)
7//!
8//! These reduce complexity from O(n²) to O(n√n) or O(n*w).
9
10use thiserror::Error;
11
12use crate::layers::attention_fused::softmax_inplace;
13
14// ─── Error type ──────────────────────────────────────────────────────────────
15
16/// Errors that can occur during sparse attention operations.
17#[derive(Debug, Error)]
18pub enum SparseAttnError {
19    #[error("query/key/value length mismatch: q={q}, k={k}, v={v}")]
20    LengthMismatch { q: usize, k: usize, v: usize },
21    #[error("head_dim must be > 0")]
22    InvalidHeadDim,
23    #[error("window_size must be odd for symmetric windows")]
24    WindowSizeMustBeOdd,
25    #[error("empty attention: no valid (q,k) pairs")]
26    EmptyAttention,
27}
28
29// ─── Sparse Pattern ───────────────────────────────────────────────────────────
30
/// Selects which connectivity pattern a sparse attention mask is built with.
#[derive(Debug, Clone, PartialEq)]
pub enum SparsePattern {
    /// Sliding window: each token attends to a window of `window_size`
    /// positions centred on itself (`window_size` must be odd).
    LocalWindow {
        window_size: usize,
    },
    /// BigBird: global tokens, a local window, and pseudo-random extra links.
    BigBird {
        /// Width of the local window (must be odd).
        window_size: usize,
        /// Number of leading positions treated as global tokens.
        num_global_tokens: usize,
        /// Number of pseudo-random keys drawn per query.
        num_random_connections: usize,
        /// Seed for the deterministic pseudo-random generator.
        seed: u64,
    },
    /// Strided: every `stride`-th token attends globally; all other tokens
    /// use a local window of `window_size` (must be odd).
    Strided {
        window_size: usize,
        stride: usize,
    },
    /// Fully dense attention, used as the baseline.
    Dense,
}
48
49// ─── Sparse Attention Mask ────────────────────────────────────────────────────
50
51/// A sparse attention mask: which positions each query can attend to.
52pub struct SparseAttentionMask {
53    /// Length of the sequence.
54    pub seq_len: usize,
55    /// For each query position: sorted list of key positions it can attend to.
56    attend_to: Vec<Vec<usize>>,
57    /// The pattern that was used to build this mask.
58    pub pattern: SparsePattern,
59}
60
61impl SparseAttentionMask {
62    /// Build a sparse mask for `seq_len` tokens using `pattern`.
63    ///
64    /// Returns an error if the pattern parameters are invalid.
65    pub fn build(seq_len: usize, pattern: &SparsePattern) -> Result<Self, SparseAttnError> {
66        let attend_to = match pattern {
67            SparsePattern::Dense => build_dense(seq_len),
68            SparsePattern::LocalWindow { window_size } => {
69                build_local_window(seq_len, *window_size)?
70            }
71            SparsePattern::BigBird {
72                window_size,
73                num_global_tokens,
74                num_random_connections,
75                seed,
76            } => build_bigbird(
77                seq_len,
78                *window_size,
79                *num_global_tokens,
80                *num_random_connections,
81                *seed,
82            )?,
83            SparsePattern::Strided {
84                window_size,
85                stride,
86            } => build_strided(seq_len, *window_size, *stride)?,
87        };
88
89        Ok(Self {
90            seq_len,
91            attend_to,
92            pattern: pattern.clone(),
93        })
94    }
95
96    /// Get the key positions that query `q` can attend to.
97    pub fn keys_for_query(&self, q: usize) -> &[usize] {
98        if q >= self.seq_len {
99            return &[];
100        }
101        &self.attend_to[q]
102    }
103
104    /// Total number of attended (q, k) pairs.
105    pub fn nnz(&self) -> usize {
106        self.attend_to.iter().map(|v| v.len()).sum()
107    }
108
109    /// Density: nnz / seq_len² (1.0 = dense).
110    pub fn density(&self) -> f32 {
111        let total = (self.seq_len as f64) * (self.seq_len as f64);
112        if total == 0.0 {
113            return 0.0;
114        }
115        (self.nnz() as f64 / total) as f32
116    }
117
118    /// Whether query `q` can attend to key `k`.
119    pub fn can_attend(&self, q: usize, k: usize) -> bool {
120        if q >= self.seq_len || k >= self.seq_len {
121            return false;
122        }
123        self.attend_to[q].binary_search(&k).is_ok()
124    }
125
126    /// Convert to a dense boolean mask matrix [seq_len × seq_len].
127    ///
128    /// `result[q * seq_len + k] == true` means query `q` attends to key `k`.
129    pub fn to_dense(&self) -> Vec<Vec<bool>> {
130        let n = self.seq_len;
131        let mut mask = vec![vec![false; n]; n];
132        for (q, keys) in self.attend_to.iter().enumerate() {
133            for &k in keys {
134                mask[q][k] = true;
135            }
136        }
137        mask
138    }
139}
140
141// ─── Pattern builders ─────────────────────────────────────────────────────────
142
/// Dense pattern: every one of the `seq_len` queries attends to all keys
/// `0..seq_len`.
fn build_dense(seq_len: usize) -> Vec<Vec<usize>> {
    let every_key: Vec<usize> = (0..seq_len).collect();
    vec![every_key; seq_len]
}
147
148/// Local sliding window of `window_size` (must be odd) centered at each token.
149fn build_local_window(
150    seq_len: usize,
151    window_size: usize,
152) -> Result<Vec<Vec<usize>>, SparseAttnError> {
153    if window_size % 2 == 0 {
154        return Err(SparseAttnError::WindowSizeMustBeOdd);
155    }
156    let half = window_size / 2;
157    let mut attend_to = Vec::with_capacity(seq_len);
158    for q in 0..seq_len {
159        let start = q.saturating_sub(half);
160        let end = (q + half + 1).min(seq_len);
161        attend_to.push((start..end).collect());
162    }
163    Ok(attend_to)
164}
165
166/// BigBird pattern: global tokens + local window + random sparse connections.
167///
168/// Uses a linear congruential generator (LCG) instead of rand to avoid
169/// external dependencies while producing deterministic pseudo-random connections.
170fn build_bigbird(
171    seq_len: usize,
172    window_size: usize,
173    num_global_tokens: usize,
174    num_random_connections: usize,
175    seed: u64,
176) -> Result<Vec<Vec<usize>>, SparseAttnError> {
177    if window_size % 2 == 0 {
178        return Err(SparseAttnError::WindowSizeMustBeOdd);
179    }
180    let half = window_size / 2;
181    // Clamp global tokens to seq_len
182    let actual_global = num_global_tokens.min(seq_len);
183
184    let mut attend_to: Vec<Vec<usize>> = Vec::with_capacity(seq_len);
185    let mut lcg_state = seed.wrapping_add(0xDEAD_BEEF_CAFE_1234);
186
187    for q in 0..seq_len {
188        let mut keys: std::collections::BTreeSet<usize> = std::collections::BTreeSet::new();
189
190        // 1. Global tokens: attend to all positions
191        for g in 0..actual_global {
192            keys.insert(g);
193        }
194        // 2. All queries attend to all global-token positions (global tokens attend back)
195        for g in 0..actual_global {
196            if q == g {
197                // global token q itself attends to every position
198                for k in 0..seq_len {
199                    keys.insert(k);
200                }
201            }
202        }
203
204        // 3. Local window
205        let start = q.saturating_sub(half);
206        let end = (q + half + 1).min(seq_len);
207        for k in start..end {
208            keys.insert(k);
209        }
210
211        // 4. Random sparse connections (LCG)
212        let num_rand = if seq_len > actual_global + window_size {
213            num_random_connections
214        } else {
215            0
216        };
217        for r in 0..num_rand {
218            // LCG: a=6364136223846793005, c=1442695040888963407 (Knuth)
219            lcg_state = lcg_state
220                .wrapping_mul(6_364_136_223_846_793_005)
221                .wrapping_add(1_442_695_040_888_963_407)
222                .wrapping_add((q as u64).wrapping_mul(137).wrapping_add(r as u64));
223            let k = (lcg_state >> 33) as usize % seq_len;
224            keys.insert(k);
225        }
226
227        attend_to.push(keys.into_iter().collect());
228    }
229
230    Ok(attend_to)
231}
232
233/// Strided pattern: stride positions attend globally; others use local window.
234fn build_strided(
235    seq_len: usize,
236    window_size: usize,
237    stride: usize,
238) -> Result<Vec<Vec<usize>>, SparseAttnError> {
239    if window_size % 2 == 0 {
240        return Err(SparseAttnError::WindowSizeMustBeOdd);
241    }
242    if stride == 0 {
243        // stride=0 degenerates to all-global; treat as dense
244        return Ok(build_dense(seq_len));
245    }
246    let half = window_size / 2;
247
248    let mut attend_to = Vec::with_capacity(seq_len);
249    for q in 0..seq_len {
250        let is_global = (q % stride) == 0;
251        let mut keys: Vec<usize> = if is_global {
252            // stride positions attend to every key
253            (0..seq_len).collect()
254        } else {
255            // local window
256            let start = q.saturating_sub(half);
257            let end = (q + half + 1).min(seq_len);
258            // plus all stride positions (global tokens)
259            let mut ks: std::collections::BTreeSet<usize> = (start..end).collect();
260            let mut g = 0usize;
261            while g < seq_len {
262                ks.insert(g);
263                g += stride;
264            }
265            ks.into_iter().collect()
266        };
267        keys.sort_unstable();
268        keys.dedup();
269        attend_to.push(keys);
270    }
271    Ok(attend_to)
272}
273
274// ─── Sparse attention forward ─────────────────────────────────────────────────
275
276/// Apply sparse attention: compute attention output using a sparse mask.
277///
278/// - `queries`: shape [seq_len, head_dim] (row-major)
279/// - `keys`:    shape [seq_len, head_dim] (row-major)
280/// - `values`:  shape [seq_len, head_dim] (row-major)
281///
282/// Returns: shape [seq_len, head_dim]
283pub fn sparse_attention_forward(
284    queries: &[f32],
285    keys: &[f32],
286    values: &[f32],
287    seq_len: usize,
288    head_dim: usize,
289    mask: &SparseAttentionMask,
290    scale: f32,
291) -> Result<Vec<f32>, SparseAttnError> {
292    validate_inputs(queries, keys, values, seq_len, head_dim)?;
293
294    if mask.nnz() == 0 {
295        return Err(SparseAttnError::EmptyAttention);
296    }
297
298    let mut output = vec![0.0f32; seq_len * head_dim];
299
300    for q in 0..seq_len {
301        let key_positions = mask.keys_for_query(q);
302        if key_positions.is_empty() {
303            // No attention positions: output stays zero for this query
304            continue;
305        }
306
307        let q_vec = &queries[q * head_dim..(q + 1) * head_dim];
308
309        // Compute raw scores for each attended key
310        let mut scores: Vec<f32> = key_positions
311            .iter()
312            .map(|&k| {
313                let k_vec = &keys[k * head_dim..(k + 1) * head_dim];
314                dot_scaled(q_vec, k_vec, scale)
315            })
316            .collect();
317
318        // In-place softmax over the sparse scores
319        softmax_inplace(&mut scores);
320
321        // Weighted sum over values
322        let out_row = &mut output[q * head_dim..(q + 1) * head_dim];
323        for (weight, &k_pos) in scores.iter().zip(key_positions.iter()) {
324            let v_vec = &values[k_pos * head_dim..(k_pos + 1) * head_dim];
325            for (o, &v) in out_row.iter_mut().zip(v_vec.iter()) {
326                *o += weight * v;
327            }
328        }
329    }
330
331    Ok(output)
332}
333
334/// Compare sparse vs dense attention output (MAE).
335///
336/// Computes both with the given mask and with a fully dense mask, then
337/// returns the mean absolute error between the two outputs.
338pub fn sparse_vs_dense_error(
339    queries: &[f32],
340    keys: &[f32],
341    values: &[f32],
342    seq_len: usize,
343    head_dim: usize,
344    mask: &SparseAttentionMask,
345) -> Result<f32, SparseAttnError> {
346    let scale = 1.0 / (head_dim as f32).sqrt();
347
348    let sparse_out =
349        sparse_attention_forward(queries, keys, values, seq_len, head_dim, mask, scale)?;
350
351    let dense_mask = SparseAttentionMask::build(seq_len, &SparsePattern::Dense)
352        .map_err(|_| SparseAttnError::EmptyAttention)?;
353    let dense_out =
354        sparse_attention_forward(queries, keys, values, seq_len, head_dim, &dense_mask, scale)?;
355
356    let total_elements = seq_len * head_dim;
357    if total_elements == 0 {
358        return Ok(0.0);
359    }
360
361    let mae = sparse_out
362        .iter()
363        .zip(dense_out.iter())
364        .map(|(s, d)| (s - d).abs())
365        .sum::<f32>()
366        / total_elements as f32;
367
368    Ok(mae)
369}
370
371/// Memory savings vs dense attention.
372///
373/// Returns the fraction of memory saved: `1.0 - density`.
374/// A value of 0.0 means no savings (dense), 1.0 means all connections removed.
375pub fn memory_reduction(_seq_len: usize, mask: &SparseAttentionMask) -> f32 {
376    1.0 - mask.density()
377}
378
379// ─── Private helpers ──────────────────────────────────────────────────────────
380
381/// Validate input buffer sizes.
382fn validate_inputs(
383    queries: &[f32],
384    keys: &[f32],
385    values: &[f32],
386    seq_len: usize,
387    head_dim: usize,
388) -> Result<(), SparseAttnError> {
389    if head_dim == 0 {
390        return Err(SparseAttnError::InvalidHeadDim);
391    }
392    let expected = seq_len * head_dim;
393    if queries.len() != expected || keys.len() != expected || values.len() != expected {
394        return Err(SparseAttnError::LengthMismatch {
395            q: queries.len(),
396            k: keys.len(),
397            v: values.len(),
398        });
399    }
400    Ok(())
401}
402
/// Dot product of `a` and `b`, multiplied by `scale`.
///
/// Only the overlapping prefix is used when the slices differ in length.
#[inline]
fn dot_scaled(a: &[f32], b: &[f32], scale: f32) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum();
    dot * scale
}
408
409// ─── Unit tests ───────────────────────────────────────────────────────────────
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, non-trivial Q/K/V buffers of shape [seq_len, head_dim].
    fn make_qkv(seq_len: usize, head_dim: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let n = seq_len * head_dim;
        let q: Vec<f32> = (0..n).map(|i| (i as f32 * 0.03) - 0.5).collect();
        let k: Vec<f32> = (0..n)
            .map(|i| ((i * 7 + 3) % 17) as f32 * 0.04 - 0.3)
            .collect();
        let v: Vec<f32> = (0..n)
            .map(|i| ((i * 11 + 5) % 13) as f32 * 0.05 - 0.3)
            .collect();
        (q, k, v)
    }

    #[test]
    fn dense_mask_full() {
        let seq_len = 8;
        let mask = SparseAttentionMask::build(seq_len, &SparsePattern::Dense)
            .expect("dense build should succeed");
        assert_eq!(mask.nnz(), seq_len * seq_len);
    }

    #[test]
    fn local_window_density_less_than_one() {
        let seq_len = 16;
        let mask =
            SparseAttentionMask::build(seq_len, &SparsePattern::LocalWindow { window_size: 3 })
                .expect("local window build should succeed");
        assert!(
            mask.density() < 1.0,
            "density should be < 1.0 for local window"
        );
    }

    #[test]
    fn sparse_forward_dense_matches_naive_inline() {
        let seq_len = 4;
        let head_dim = 4;
        let (q, k, v) = make_qkv(seq_len, head_dim);
        let scale = 1.0 / (head_dim as f32).sqrt();
        let mask = SparseAttentionMask::build(seq_len, &SparsePattern::Dense).expect("dense mask");
        let out = sparse_attention_forward(&q, &k, &v, seq_len, head_dim, &mask, scale)
            .expect("sparse forward failed");
        assert_eq!(out.len(), seq_len * head_dim);

        // Naive reference: full softmax(scale * Q·Kᵀ) · V computed inline.
        // NOTE(review): assumes `softmax_inplace` is a standard normalised
        // softmax; the tolerance leaves room for floating-point differences.
        for qi in 0..seq_len {
            let q_row = &q[qi * head_dim..(qi + 1) * head_dim];
            let logits: Vec<f32> = (0..seq_len)
                .map(|ki| {
                    let k_row = &k[ki * head_dim..(ki + 1) * head_dim];
                    q_row.iter().zip(k_row).map(|(a, b)| a * b).sum::<f32>() * scale
                })
                .collect();
            let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = logits.iter().map(|&s| (s - max_logit).exp()).collect();
            let denom: f32 = exps.iter().sum();

            for d in 0..head_dim {
                let expected: f32 = (0..seq_len)
                    .map(|ki| exps[ki] / denom * v[ki * head_dim + d])
                    .sum();
                let got = out[qi * head_dim + d];
                assert!(
                    (got - expected).abs() < 1e-4,
                    "mismatch at query {qi}, dim {d}: got {got}, expected {expected}"
                );
            }
        }
    }
}
458}