use super::mask::AttentionMaskBuilder;
use super::types::{SparseAttentionConfig, SparseAttentionMask};
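/// Sparse multi-head attention: each query position attends only to the
/// key/value positions permitted by the mask produced by `AttentionMaskBuilder`.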
pub struct SparseAttention {
config: SparseAttentionConfig,
mask_builder: AttentionMaskBuilder,
}
impl SparseAttention {
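/// Creates a sparse attention module; the mask builder is constructed from a clone of the config.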
pub fn new(config: SparseAttentionConfig) -> Self {
let mask_builder = AttentionMaskBuilder::new(config.clone());
Self {
config,
mask_builder,
}
}
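/// Runs sparse attention over flattened `[seq_len, n_heads, head_dim]` buffers.
///
/// `q`, `k`, and `v` are laid out row-major as `[position][head][dim]`. If any
/// buffer is shorter than `seq_len * n_heads * head_dim`, a zero vector of that
/// length is returned. Positions whose mask row is empty keep an all-zero output
/// row. `global_indices` is forwarded to the mask builder.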
pub fn forward(
&self,
q: &[f64],
k: &[f64],
v: &[f64],
seq_len: usize,
global_indices: &[usize],
) -> Vec<f64> {
let n_heads = self.config.n_heads;
let head_dim = self.config.head_dim;
let total = seq_len * n_heads * head_dim;
if total == 0 || q.len() < total || k.len() < total || v.len() < total {
return vec![0.0; total];
}
let mask = self.mask_builder.build(seq_len, global_indices);
let scale = 1.0 / (head_dim as f64).sqrt();
let mut output = vec![0.0f64; total];
for i in 0..seq_len {
let allowed = &mask.attend_to[i];
if allowed.is_empty() {
continue;
}
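// For each head, gather the allowed key/value rows and apply softmax attention.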
for h in 0..n_heads {
let q_offset = i * n_heads * head_dim + h * head_dim;
let q_vec: &[f64] = &q[q_offset..q_offset + head_dim];
let k_positions: Vec<Vec<f64>> = allowed
.iter()
.map(|&j| {
let k_off = j * n_heads * head_dim + h * head_dim;
k[k_off..k_off + head_dim].to_vec()
})
.collect();
let v_positions: Vec<Vec<f64>> = allowed
.iter()
.map(|&j| {
let v_off = j * n_heads * head_dim + h * head_dim;
v[v_off..v_off + head_dim].to_vec()
})
.collect();
let weights = Self::softmax_attend(q_vec, &k_positions, scale);
let mut attn_out = vec![0.0f64; head_dim];
for (w, v_vec) in weights.iter().zip(v_positions.iter()) {
for d in 0..head_dim {
attn_out[d] += w * v_vec[d];
}
}
let out_offset = i * n_heads * head_dim + h * head_dim;
output[out_offset..out_offset + head_dim].copy_from_slice(&attn_out);
}
}
output
}
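/// Attends a single query vector to an explicit list of key/value vectors,
/// returning the weighted sum of the values. The output length matches `q_i`.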
pub fn attend_to_positions(
q_i: &[f64],
k_positions: &[Vec<f64>],
v_positions: &[Vec<f64>],
scale: f64,
) -> Vec<f64> {
if k_positions.is_empty() || q_i.is_empty() {
return vec![0.0; q_i.len()];
}
let dim = q_i.len();
let weights = Self::softmax_attend(q_i, k_positions, scale);
let mut out = vec![0.0f64; dim];
for (w, v_vec) in weights.iter().zip(v_positions.iter()) {
let d = v_vec.len().min(dim);
for i in 0..d {
out[i] += w * v_vec[i];
}
}
out
}
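/// Computes scaled dot-product attention weights of `q` against `keys`
/// (softmax over `q · k * scale` for each key).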
pub fn softmax_attend(q: &[f64], keys: &[Vec<f64>], scale: f64) -> Vec<f64> {
if keys.is_empty() {
return Vec::new();
}
let scores: Vec<f64> = keys
.iter()
.map(|k| {
let dot: f64 = q.iter().zip(k.iter()).map(|(&qi, &ki)| qi * ki).sum();
dot * scale
})
.collect();
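// Subtract the max score before exponentiating for numerical stability.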
let max_score = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
let exps: Vec<f64> = scores.iter().map(|&s| (s - max_score).exp()).collect();
let sum: f64 = exps.iter().sum();
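// Fall back to uniform weights if the denominator underflows or is NaN.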
if sum <= 0.0 || sum.is_nan() {
let u = 1.0 / keys.len() as f64;
return vec![u; keys.len()];
}
exps.iter().map(|&e| e / sum).collect()
}
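/// Returns the attention configuration.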
pub fn config(&self) -> &SparseAttentionConfig {
&self.config
}
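/// Builds the sparse attention mask for `seq_len` positions with the given global indices.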
pub fn build_mask(&self, seq_len: usize, global_indices: &[usize]) -> SparseAttentionMask {
self.mask_builder.build(seq_len, global_indices)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::attention::sparse::{SparseAttentionConfig, SparsePattern};
fn make_sa(
pattern: SparsePattern,
window: usize,
n_heads: usize,
head_dim: usize,
) -> SparseAttention {
let cfg = SparseAttentionConfig {
pattern,
window_size: window,
n_heads,
head_dim,
..Default::default()
};
SparseAttention::new(cfg)
}
fn zeros(seq_len: usize, n_heads: usize, head_dim: usize) -> Vec<f64> {
vec![0.0; seq_len * n_heads * head_dim]
}
#[test]
fn forward_output_shape_correct() {
let sa = make_sa(SparsePattern::LocalWindow, 2, 4, 8);
let seq_len = 10;
let n = seq_len * 4 * 8;
let q: Vec<f64> = (0..n).map(|i| i as f64 * 0.001).collect();
let k = q.clone();
let v = q.clone();
let out = sa.forward(&q, &k, &v, seq_len, &[]);
assert_eq!(out.len(), n, "output length should equal seq * heads * dim");
}
#[test]
fn forward_empty_sequence_returns_empty() {
let sa = make_sa(SparsePattern::LocalWindow, 2, 2, 4);
let out = sa.forward(&[], &[], &[], 0, &[]);
assert!(out.is_empty());
}
#[test]
fn softmax_attend_weights_sum_to_one() {
let q = vec![1.0, 0.0, -1.0];
let keys = vec![
vec![1.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0],
vec![0.0, 0.0, 1.0],
];
let weights = SparseAttention::softmax_attend(&q, &keys, 1.0);
let sum: f64 = weights.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-10,
"attention weights should sum to 1.0, got {sum}"
);
for &w in &weights {
assert!(w >= 0.0, "weights must be non-negative");
}
}
#[test]
fn softmax_attend_no_keys_returns_empty() {
let q = vec![1.0, 2.0];
let weights = SparseAttention::softmax_attend(&q, &[], 1.0);
assert!(weights.is_empty());
}
#[test]
fn softmax_attend_single_key_weight_is_one() {
let q = vec![0.5, -0.5];
let keys = vec![vec![1.0, 0.5]];
let weights = SparseAttention::softmax_attend(&q, &keys, 0.5);
assert_eq!(weights.len(), 1);
assert!((weights[0] - 1.0).abs() < 1e-10);
}
#[test]
fn attend_to_positions_output_length() {
let dim = 6;
let q = vec![0.1f64; dim];
let k = vec![vec![0.2f64; dim], vec![0.3f64; dim]];
let v = k.clone();
let out = SparseAttention::attend_to_positions(&q, &k, &v, 1.0);
assert_eq!(out.len(), dim);
}
#[test]
fn forward_v_zeros_output_zeros() {
let (seq_len, n_heads, head_dim) = (6, 2, 4);
let sa = make_sa(SparsePattern::LocalWindow, 1, n_heads, head_dim);
let n = seq_len * n_heads * head_dim;
let q: Vec<f64> = (0..n).map(|i| i as f64).collect();
let k = q.clone();
let v = zeros(seq_len, n_heads, head_dim);
let out = sa.forward(&q, &k, &v, seq_len, &[]);
for &x in &out {
assert!(x.abs() < 1e-12, "V=0 → output should be 0, got {x}");
}
}
#[test]
fn forward_single_token_equal_to_v() {
let (n_heads, head_dim) = (1, 4);
let seq_len = 1;
let sa = make_sa(SparsePattern::LocalWindow, 1, n_heads, head_dim);
let v: Vec<f64> = (0..head_dim).map(|i| (i + 1) as f64).collect();
let out = sa.forward(&v, &v, &v, seq_len, &[]);
assert_eq!(out.len(), n_heads * head_dim);
for (a, b) in out.iter().zip(v.iter()) {
assert!((a - b).abs() < 1e-9, "single token: out={a}, v={b}");
}
}
#[test]
fn build_mask_local_window() {
let sa = make_sa(SparsePattern::LocalWindow, 2, 2, 4);
let mask = sa.build_mask(8, &[]);
assert_eq!(mask.attend_to[0], vec![0, 1, 2]);
}
#[test]
fn sparse_attention_config_default_works() {
let cfg = SparseAttentionConfig::default();
let _sa = SparseAttention::new(cfg);
}
}