use std::fmt;
/// Kind of sparse attention pattern, selected via `SparseAttentionConfig::pattern`.
///
/// Marked `#[non_exhaustive]` so new patterns can be added without a breaking
/// change; `match`es in other crates must include a wildcard arm.
///
/// Derives `Hash` (in addition to `Eq`) so patterns can be used as keys in
/// hash maps/sets, e.g. for per-pattern caches or statistics.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SparsePattern {
    /// Local window attention. NOTE(review): presumably sized by
    /// `SparseAttentionConfig::window_size` — confirm.
    LocalWindow,
    /// Global plus local attention. NOTE(review): presumably uses
    /// `n_global_tokens` together with `window_size` — confirm.
    GlobalLocal,
    /// Random attention. NOTE(review): presumably `n_random` targets per
    /// query — confirm.
    Random,
    /// Block-sparse attention. NOTE(review): presumably in blocks of
    /// `block_size` — confirm.
    BlockSparse,
    /// Sliding attention. NOTE(review): how this differs from `LocalWindow`
    /// is not visible here — document once confirmed.
    Sliding,
}
impl fmt::Display for SparsePattern {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SparsePattern::LocalWindow => write!(f, "LocalWindow"),
SparsePattern::GlobalLocal => write!(f, "GlobalLocal"),
SparsePattern::Random => write!(f, "Random"),
SparsePattern::BlockSparse => write!(f, "BlockSparse"),
SparsePattern::Sliding => write!(f, "Sliding"),
}
}
}
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct SparseAttentionConfig {
pub pattern: SparsePattern,
pub window_size: usize,
pub n_global_tokens: usize,
pub n_random: usize,
pub block_size: usize,
pub n_heads: usize,
pub head_dim: usize,
}
impl Default for SparseAttentionConfig {
    /// Local-window pattern with window and block size 64, no global tokens,
    /// 3 random targets, and 8 heads of dimension 64.
    fn default() -> Self {
        let pattern = SparsePattern::LocalWindow;
        let (window_size, block_size) = (64, 64);
        let (n_global_tokens, n_random) = (0, 3);
        let (n_heads, head_dim) = (8, 64);
        Self {
            pattern,
            window_size,
            n_global_tokens,
            n_random,
            block_size,
            n_heads,
            head_dim,
        }
    }
}
impl fmt::Display for SparseAttentionConfig {
    /// Renders every field on a single line, e.g. for the default config:
    /// `SparseAttentionConfig(pattern=LocalWindow, window=64, n_global=0, n_random=3, block=64, heads=8, head_dim=64)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let cfg = self;
        write!(
            f,
            "SparseAttentionConfig(pattern={pattern}, window={window}, n_global={n_global}, n_random={n_random}, block={block}, heads={heads}, head_dim={head_dim})",
            pattern = cfg.pattern,
            window = cfg.window_size,
            n_global = cfg.n_global_tokens,
            n_random = cfg.n_random,
            block = cfg.block_size,
            heads = cfg.n_heads,
            head_dim = cfg.head_dim,
        )
    }
}
/// Explicit sparse attention mask for one sequence: one row of attended
/// positions per query position.
///
/// NOTE(review): `attend_to[i]` presumably lists the key positions that query
/// `i` attends to. Nothing here enforces `attend_to.len() == seq_len`, that
/// indices are `< seq_len`, or that rows are free of duplicates; confirm that
/// mask builders guarantee this.
///
/// `Default` yields the empty mask (`seq_len == 0`, no rows).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SparseAttentionMask {
    /// Sequence length; the dense attention matrix is `seq_len × seq_len`.
    pub seq_len: usize,
    /// Per-query lists of attended positions.
    pub attend_to: Vec<Vec<usize>>,
}
impl SparseAttentionMask {
    /// Total number of entries in the mask, i.e. the sum of all row lengths
    /// in `attend_to`.
    ///
    /// Rows are counted as-is: a duplicated index is counted once per
    /// occurrence.
    pub fn n_pairs(&self) -> usize {
        self.attend_to.iter().map(Vec::len).sum()
    }

    /// Fraction of the dense `seq_len × seq_len` attention matrix covered by
    /// this mask: `n_pairs() / seq_len²`.
    ///
    /// Returns `0.0` for an empty sequence (`seq_len == 0`) instead of
    /// dividing by zero. The result can exceed `1.0` if `attend_to` holds more
    /// than `seq_len²` entries (e.g. duplicate indices); this is not checked.
    pub fn density(&self) -> f64 {
        if self.seq_len == 0 {
            return 0.0;
        }
        // Square in floating point. Squaring in `usize` overflows once
        // `seq_len` exceeds 65_535 on 32-bit targets (2^32 on 64-bit): that
        // panics in debug builds and silently wraps in release builds. For
        // every non-overflowing input the f64 result is identical, because
        // both forms round the exact product exactly once.
        let n = self.seq_len as f64;
        self.n_pairs() as f64 / (n * n)
    }
}
/// Unit tests for the sparse attention configuration and mask types.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sparse_attention_config_default() {
        let cfg = SparseAttentionConfig::default();
        assert_eq!(cfg.pattern, SparsePattern::LocalWindow);
        assert_eq!(cfg.window_size, 64);
        assert_eq!(cfg.n_global_tokens, 0);
        assert_eq!(cfg.n_random, 3);
        assert_eq!(cfg.block_size, 64);
        assert_eq!(cfg.n_heads, 8);
        assert_eq!(cfg.head_dim, 64);
    }

    #[test]
    fn sparse_pattern_display() {
        // Cover every variant so a typo in any Display arm is caught.
        let cases = [
            (SparsePattern::LocalWindow, "LocalWindow"),
            (SparsePattern::GlobalLocal, "GlobalLocal"),
            (SparsePattern::Random, "Random"),
            (SparsePattern::BlockSparse, "BlockSparse"),
            (SparsePattern::Sliding, "Sliding"),
        ];
        for &(pattern, expected) in &cases {
            assert_eq!(pattern.to_string(), expected);
        }
    }

    #[test]
    fn sparse_attention_config_display() {
        let cfg = SparseAttentionConfig::default();
        assert_eq!(
            cfg.to_string(),
            "SparseAttentionConfig(pattern=LocalWindow, window=64, n_global=0, n_random=3, block=64, heads=8, head_dim=64)"
        );
    }

    #[test]
    fn mask_density_empty() {
        let mask = SparseAttentionMask {
            seq_len: 0,
            attend_to: Vec::new(),
        };
        assert_eq!(mask.n_pairs(), 0);
        assert!(mask.density().abs() < 1e-12);
    }

    #[test]
    fn mask_density_full() {
        let seq_len = 4;
        let attend_to = vec![vec![0, 1, 2, 3]; seq_len];
        let mask = SparseAttentionMask { seq_len, attend_to };
        assert_eq!(mask.n_pairs(), 16);
        assert!((mask.density() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn mask_density_partial() {
        let mask = SparseAttentionMask {
            seq_len: 4,
            attend_to: vec![vec![0], vec![0, 1], vec![1, 2], vec![2, 3]],
        };
        assert_eq!(mask.n_pairs(), 7);
        assert!((mask.density() - 7.0 / 16.0).abs() < 1e-12);
    }
}