use crate::tensor::DenseTensor;
use crate::tensor::traits::{TensorOps, TensorBase};
use crate::tensor::sparse::SparseTensor;
/// Sparsity pattern selected for a `SparseAttention` layer.
///
/// `SparseAttention::build_mask` maps each variant to one of the `SparseMask`
/// constructors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparsePattern {
/// Causal sliding window (`SparseMask::sliding_window` with `causal = true`).
SlidingWindow,
/// Block pattern (`SparseMask::block_sparse`).
BlockSparse,
/// Global "center" positions plus a local window (`SparseMask::star`).
Star,
/// Currently built by `build_mask` as a causal sliding window of size 64.
HeadSparse,
}
/// Parameters describing a sliding-window mask.
///
/// NOTE(review): not consumed anywhere visible in this module; presumably mirrors the
/// arguments of `SparseMask::sliding_window` — confirm against callers.
#[derive(Debug, Clone)]
pub struct SlidingWindowConfig {
    /// Size of the attention window.
    pub window_size: usize,
    /// `true` for configs made by `new`, `false` for configs made by `bidirectional`.
    pub causal: bool,
}

impl SlidingWindowConfig {
    /// Causal configuration with the given window size.
    pub fn new(window_size: usize) -> Self {
        SlidingWindowConfig {
            window_size,
            causal: true,
        }
    }

    /// Non-causal configuration with the given window size.
    pub fn bidirectional(window_size: usize) -> Self {
        SlidingWindowConfig {
            causal: false,
            ..Self::new(window_size)
        }
    }
}
/// Parameters describing a block-sparse mask.
///
/// NOTE(review): not consumed anywhere visible in this module; presumably mirrors the
/// `block_size` / `num_blocks` arguments of `SparseMask::block_sparse` — confirm.
#[derive(Debug, Clone)]
pub struct BlockSparseConfig {
    /// Size of each block.
    pub block_size: usize,
    /// Number of blocks.
    pub num_blocks: usize,
}

impl BlockSparseConfig {
    /// Bundles the two block parameters into a config.
    pub fn new(block_size: usize, num_blocks: usize) -> Self {
        BlockSparseConfig {
            num_blocks,
            block_size,
        }
    }
}
/// Attention mask over a `seq_len × seq_len` score matrix, stored in CSR form.
///
/// The allowed key columns of row `i` are
/// `col_indices[row_offsets[i]..row_offsets[i + 1]]`; every other entry is masked out.
/// The constructors on this type always produce `row_offsets.len() == seq_len + 1`,
/// `row_offsets[0] == 0` and `nnz == col_indices.len()`.
#[derive(Debug, Clone)]
pub struct SparseMask {
/// CSR row pointers: a leading 0, then the running column count after each row.
pub row_offsets: Vec<usize>,
/// Allowed column indices, grouped by row. Order within a row is whatever the
/// constructor emitted.
pub col_indices: Vec<usize>,
/// Side length of the square score matrix the mask describes.
pub seq_len: usize,
/// Number of allowed entries (equals `col_indices.len()` for constructed masks).
pub nnz: usize,
}
impl SparseMask {
    /// Reach of the local window used by [`SparseMask::star`] for non-center rows:
    /// such a row `i` keeps columns in `[i - STAR_LOCAL_WINDOW, i + STAR_LOCAL_WINDOW)`.
    const STAR_LOCAL_WINDOW: usize = 64;

    /// Assembles a CSR mask row by row.
    ///
    /// `fill_row(i, cols)` must *append* the allowed columns of row `i` to `cols`
    /// (never remove), so the offsets already recorded for earlier rows stay valid.
    fn from_rows(seq_len: usize, mut fill_row: impl FnMut(usize, &mut Vec<usize>)) -> Self {
        let mut row_offsets = Vec::with_capacity(seq_len + 1);
        let mut col_indices = Vec::new();
        row_offsets.push(0);
        for i in 0..seq_len {
            fill_row(i, &mut col_indices);
            row_offsets.push(col_indices.len());
        }
        let nnz = col_indices.len();
        Self {
            row_offsets,
            col_indices,
            seq_len,
            nnz,
        }
    }

    /// Sliding-window mask.
    ///
    /// * `causal = true`: row `i` keeps `i + 1 - window_size ..= i` (clamped at 0), i.e.
    ///   the `window_size` positions ending at `i`.
    /// * `causal = false`: row `i` keeps the half-open range
    ///   `i - window_size .. i + window_size`, clamped to `[0, seq_len)`.
    ///
    /// `window_size == 0` yields empty rows, which [`SparseMask::apply`] masks entirely.
    ///
    /// NOTE(review): the bidirectional window reaches `window_size` positions back but
    /// only `window_size - 1` forward, and its backward reach differs from the causal
    /// mode's (`window_size - 1`). This looks unintentional. It is kept as-is because
    /// the intended convention cannot be determined from this module — confirm.
    pub fn sliding_window(seq_len: usize, window_size: usize, causal: bool) -> Self {
        Self::from_rows(seq_len, |i, cols| {
            let (start, end) = if causal {
                ((i + 1).saturating_sub(window_size), i + 1)
            } else {
                (i.saturating_sub(window_size), (i + window_size).min(seq_len))
            };
            cols.extend(start..end);
        })
    }

    /// Block mask.
    ///
    /// Positions are grouped into consecutive blocks of `block_size`. Row `i`, which
    /// lives in block `b = i / block_size`, keeps every column of blocks
    /// `b + 1 - num_blocks ..= b` (clamped at block 0). This includes positions after `i`
    /// inside its own block. Columns are emitted in ascending order (canonical CSR).
    /// `num_blocks == 0` yields empty rows.
    ///
    /// # Panics
    ///
    /// Panics if `block_size == 0`.
    pub fn block_sparse(seq_len: usize, block_size: usize, num_blocks: usize) -> Self {
        assert!(block_size > 0, "SparseMask::block_sparse: block_size must be non-zero");
        Self::from_rows(seq_len, |i, cols| {
            let block_id = i / block_size;
            let visible = num_blocks.min(block_id + 1);
            // The visible blocks are consecutive, so together they form one contiguous
            // column range. Emitting it oldest-first keeps the columns sorted.
            let first_block = block_id + 1 - visible;
            let start = first_block * block_size;
            let end = ((block_id + 1) * block_size).min(seq_len);
            cols.extend(start..end);
        })
    }

    /// Star mask: a set of global "center" positions plus a local window.
    ///
    /// `num_centers = ceil(seq_len * center_ratio)`, clamped to `[0, seq_len]`. A negative
    /// or NaN ratio gives 0 through the saturating float-to-int cast. Rows
    /// `i < num_centers` keep every column. Every other row keeps the center columns
    /// `0..num_centers` plus the (non-causal) local window
    /// `[i - STAR_LOCAL_WINDOW, i + STAR_LOCAL_WINDOW)`, without repeating centers.
    /// Columns within each row are ascending.
    pub fn star(seq_len: usize, center_ratio: f64) -> Self {
        let num_centers = ((seq_len as f64 * center_ratio).ceil() as usize).min(seq_len);
        Self::from_rows(seq_len, |i, cols| {
            if i < num_centers {
                cols.extend(0..seq_len);
            } else {
                cols.extend(0..num_centers);
                // Columns below `num_centers` were already added as centers. Starting the
                // window there deduplicates within *this* row only.
                let window_start = i
                    .saturating_sub(Self::STAR_LOCAL_WINDOW)
                    .max(num_centers);
                let window_end = (i + Self::STAR_LOCAL_WINDOW).min(seq_len);
                cols.extend(window_start..window_end);
            }
        })
    }

    /// Builds a CSR [`SparseTensor`] with this mask's structure and the given values.
    ///
    /// # Panics
    ///
    /// Panics if `values.len() != self.nnz`, because each stored entry needs exactly one
    /// value.
    pub fn to_sparse_tensor(&self, values: Vec<f64>) -> SparseTensor {
        assert_eq!(
            values.len(),
            self.nnz,
            "SparseMask::to_sparse_tensor: expected one value per stored entry"
        );
        let values_tensor = DenseTensor::new(values, vec![self.nnz]);
        SparseTensor::csr(
            self.row_offsets.clone(),
            self.col_indices.clone(),
            values_tensor,
            [self.seq_len, self.seq_len],
        )
    }

    /// Fraction of the `seq_len × seq_len` entries that are masked out.
    ///
    /// Returns `0.0` for an empty mask (`seq_len == 0`) instead of NaN.
    pub fn sparsity(&self) -> f64 {
        if self.seq_len == 0 {
            return 0.0;
        }
        // Multiply in f64 so very large `seq_len` cannot overflow `usize`.
        let total = self.seq_len as f64 * self.seq_len as f64;
        1.0 - self.nnz as f64 / total
    }

    /// Returns a copy of `scores` with every disallowed entry set to `f64::NEG_INFINITY`.
    ///
    /// `scores` is treated as a contiguous stack of row-major `seq_len × seq_len`
    /// matrices, and the mask is applied to each one. A 2-D score matrix is one matrix;
    /// a `[batch, heads, seq, seq]` tensor (presumably what `SparseAttention::forward`
    /// produces — confirm) is `batch * heads` matrices. A trailing partial matrix is
    /// masked over the entries it contains. Rows with no allowed columns become entirely
    /// `-inf`.
    pub fn apply(&self, scores: &DenseTensor) -> DenseTensor {
        let mut masked = scores.clone();
        let n = self.seq_len;
        if n == 0 {
            return masked;
        }
        // Per-row "allowed" flags. They are set from the row's CSR columns and reset
        // afterwards, so each row costs O(n + nnz_row) instead of O(n * nnz_row).
        let mut allowed = vec![false; n];
        let data = masked.data_mut();
        for matrix in data.chunks_mut(n * n) {
            for (i, row) in matrix.chunks_mut(n).enumerate() {
                let cols = &self.col_indices[self.row_offsets[i]..self.row_offsets[i + 1]];
                for &j in cols.iter().filter(|&&j| j < n) {
                    allowed[j] = true;
                }
                for (value, &keep) in row.iter_mut().zip(&allowed) {
                    if !keep {
                        *value = f64::NEG_INFINITY;
                    }
                }
                for &j in cols.iter().filter(|&&j| j < n) {
                    allowed[j] = false;
                }
            }
        }
        masked
    }
}
/// Attention layer that computes full dense scores and then masks them with a
/// `SparseMask` before the softmax.
#[derive(Debug, Clone)]
pub struct SparseAttention {
/// Selects which `SparseMask` constructor `build_mask` uses.
pub pattern: SparsePattern,
/// Cached mask. Built by `build_mask`, and by `forward` when it is absent or was
/// built for a different `seq_len`.
pub mask: Option<SparseMask>,
/// Window for `SparsePattern::SlidingWindow`; `build_mask` uses `seq_len` when `None`.
pub window_size: Option<usize>,
/// Block size for `SparsePattern::BlockSparse`; `build_mask` uses 64 when `None`.
pub block_size: Option<usize>,
/// Blocks per row for `SparsePattern::BlockSparse`; `build_mask` uses 4 when `None`.
pub num_blocks: Option<usize>,
/// Multiplier applied to raw scores; set to `1 / sqrt(head_dim)` by `new`.
pub scale: f64,
}
impl SparseAttention {
    /// Creates a layer with the given pattern and no pattern parameters set.
    ///
    /// The score scale is `1 / sqrt(head_dim)`. The mask is built lazily by
    /// [`SparseAttention::build_mask`], which `forward` calls when needed.
    pub fn new(pattern: SparsePattern, head_dim: usize) -> Self {
        Self {
            pattern,
            mask: None,
            window_size: None,
            block_size: None,
            num_blocks: None,
            scale: 1.0 / (head_dim as f64).sqrt(),
        }
    }

    /// Sliding-window layer with an explicit window size.
    pub fn sliding_window(head_dim: usize, window_size: usize) -> Self {
        Self {
            window_size: Some(window_size),
            ..Self::new(SparsePattern::SlidingWindow, head_dim)
        }
    }

    /// Block-sparse layer with explicit block size and visible-block count.
    pub fn block_sparse(head_dim: usize, block_size: usize, num_blocks: usize) -> Self {
        Self {
            block_size: Some(block_size),
            num_blocks: Some(num_blocks),
            ..Self::new(SparsePattern::BlockSparse, head_dim)
        }
    }

    /// Star-pattern layer.
    ///
    /// NOTE(review): `_center_ratio` is currently ignored. `build_mask` always uses a
    /// ratio of 0.1 for `SparsePattern::Star`. Honouring it needs a new field on
    /// `SparseAttention`, and adding a pub field would break external struct literals.
    pub fn star(head_dim: usize, _center_ratio: f64) -> Self {
        Self::new(SparsePattern::Star, head_dim)
    }

    /// (Re)builds and caches the mask for `seq_len` according to `pattern`.
    ///
    /// # Panics
    ///
    /// Panics for `SparsePattern::BlockSparse` when `block_size` is `Some(0)`.
    pub fn build_mask(&mut self, seq_len: usize) {
        let mask = match self.pattern {
            SparsePattern::SlidingWindow => {
                SparseMask::sliding_window(seq_len, self.window_size.unwrap_or(seq_len), true)
            }
            SparsePattern::BlockSparse => SparseMask::block_sparse(
                seq_len,
                self.block_size.unwrap_or(64),
                self.num_blocks.unwrap_or(4),
            ),
            SparsePattern::Star => SparseMask::star(seq_len, 0.1),
            SparsePattern::HeadSparse => SparseMask::sliding_window(seq_len, 64, true),
        };
        self.mask = Some(mask);
    }

    /// Masked scaled-dot-product attention: `softmax(mask(q · kᵀ · scale)) · v`.
    ///
    /// `query.shape()[2]` is taken as the sequence length. The inputs are presumably
    /// `[batch, heads, seq, head_dim]`, and `key.transpose(None)` is expected to swap
    /// the last two axes — TODO confirm against `DenseTensor`. The cached mask is
    /// rebuilt only when it is missing or `seq_len` changed. Editing `window_size`,
    /// `block_size` or `num_blocks` afterwards has no effect until `build_mask` runs
    /// again.
    pub fn forward(
        &mut self,
        query: &DenseTensor,
        key: &DenseTensor,
        value: &DenseTensor,
    ) -> DenseTensor {
        let seq_len = query.shape()[2];
        let mask_is_current = matches!(&self.mask, Some(m) if m.seq_len == seq_len);
        if !mask_is_current {
            self.build_mask(seq_len);
        }
        let key_t = key.transpose(None);
        let mut scores = query.matmul(&key_t).scale(self.scale);
        if let Some(mask) = &self.mask {
            scores = mask.apply(&scores);
        }
        scores.softmax(-1).matmul(value)
    }

    /// Sparsity of the cached mask, or `0.0` if no mask has been built yet.
    pub fn sparsity(&self) -> f64 {
        self.mask.as_ref().map_or(0.0, SparseMask::sparsity)
    }
}
/// Causal sliding-window attention computed directly over each window, without
/// building a full `seq_len × seq_len` score matrix.
///
/// Query position `i` attends to key positions `i.saturating_sub(window_size) ..= i`,
/// i.e. up to `window_size + 1` keys including itself.
///
/// NOTE(review): for the same `window_size`, `SparseMask::sliding_window(.., true)` keeps
/// only `window_size` keys per row (`i + 1 - window_size ..= i`). Confirm which
/// convention is intended before treating the two as interchangeable.
pub struct SlidingWindowAttention {
    window_size: usize,
    scale: f64,
}

impl SlidingWindowAttention {
    /// Creates the layer. Scores are scaled by `1 / sqrt(head_dim)`.
    pub fn new(window_size: usize, head_dim: usize) -> Self {
        Self {
            window_size,
            scale: 1.0 / (head_dim as f64).sqrt(),
        }
    }

    /// Computes windowed softmax attention.
    ///
    /// `query`, `key` and `value` are indexed as row-major
    /// `[batch, heads, seq_len, head_dim]` tensors, with the dimensions read from
    /// `query.shape()`. The output has the same shape.
    ///
    /// The softmax subtracts each row's maximum score before exponentiating. Large
    /// scores therefore cannot overflow to `inf` and turn the row into `NaN`, and the
    /// result is mathematically unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `query` has fewer than four dimensions, or if `key` or `value` does not
    /// have exactly the same shape as `query`.
    pub fn forward(&self, query: &DenseTensor, key: &DenseTensor, value: &DenseTensor) -> DenseTensor {
        let shape = query.shape();
        let (batch_size, num_heads, seq_len, head_dim) = (shape[0], shape[1], shape[2], shape[3]);
        // Keys and values are addressed with the query's strides, so their layouts must match.
        assert_eq!(
            key.shape(),
            shape,
            "SlidingWindowAttention: key shape must match query shape"
        );
        assert_eq!(
            value.shape(),
            shape,
            "SlidingWindowAttention: value shape must match query shape"
        );

        let (q, k, v) = (query.data(), key.data(), value.data());
        let mut output_data = Vec::with_capacity(batch_size * num_heads * seq_len * head_dim);
        // Scratch buffers reused across query positions.
        let mut scores: Vec<f64> =
            Vec::with_capacity(self.window_size.saturating_add(1).min(seq_len));
        let mut acc = vec![0.0; head_dim];

        // Each (batch, head) pair is a contiguous `seq_len × head_dim` slab. Slabs are
        // visited batch-major, matching the output layout.
        for slab in 0..batch_size * num_heads {
            let base = slab * seq_len * head_dim;
            let row = |t: usize| base + t * head_dim..base + (t + 1) * head_dim;
            for i in 0..seq_len {
                let q_row = &q[row(i)];
                let start = i.saturating_sub(self.window_size);

                scores.clear();
                scores.extend((start..=i).map(|j| {
                    let dot: f64 = q_row.iter().zip(&k[row(j)]).map(|(a, b)| a * b).sum();
                    dot * self.scale
                }));
                let max_score = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);

                acc.iter_mut().for_each(|o| *o = 0.0);
                let mut total_weight = 0.0;
                for (j, &score) in (start..=i).zip(&scores) {
                    let weight = (score - max_score).exp();
                    for (o, &x) in acc.iter_mut().zip(&v[row(j)]) {
                        *o += weight * x;
                    }
                    total_weight += weight;
                }
                if total_weight > 0.0 {
                    acc.iter_mut().for_each(|o| *o /= total_weight);
                }
                output_data.extend_from_slice(&acc);
            }
        }
        DenseTensor::new(output_data, vec![batch_size, num_heads, seq_len, head_dim])
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the CSR invariants every `SparseMask` constructor is expected to uphold.
    fn assert_csr_consistent(mask: &SparseMask) {
        assert_eq!(mask.row_offsets.len(), mask.seq_len + 1);
        assert_eq!(mask.row_offsets.first(), Some(&0));
        assert_eq!(mask.row_offsets.last(), Some(&mask.nnz));
        assert_eq!(mask.col_indices.len(), mask.nnz);
        assert!(mask.row_offsets.windows(2).all(|w| w[0] <= w[1]));
        assert!(mask.col_indices.iter().all(|&j| j < mask.seq_len));
    }

    /// Column indices of row `i`, sorted, so assertions do not depend on emission order.
    fn row_cols(mask: &SparseMask, i: usize) -> Vec<usize> {
        let mut cols = mask.col_indices[mask.row_offsets[i]..mask.row_offsets[i + 1]].to_vec();
        cols.sort_unstable();
        cols
    }

    #[test]
    fn test_sliding_window_mask() {
        let mask = SparseMask::sliding_window(10, 3, true);
        assert_eq!(mask.seq_len, 10);
        assert!(mask.nnz < 10 * 10);
        assert_eq!(mask.row_offsets.len(), 11);
        assert_csr_consistent(&mask);
        // Rows 0 and 1 are truncated by the sequence start; the other 8 rows keep 3 each.
        assert_eq!(mask.nnz, 1 + 2 + 8 * 3);
        assert_eq!(row_cols(&mask, 5), vec![3, 4, 5]);
    }

    #[test]
    fn test_block_sparse_mask() {
        let mask = SparseMask::block_sparse(16, 4, 2);
        assert_eq!(mask.seq_len, 16);
        assert!(mask.nnz < 16 * 16);
        assert_csr_consistent(&mask);
        // Block 0 sees only itself (4 columns); blocks 1..=3 see themselves plus one
        // earlier block (8 columns).
        assert_eq!(mask.nnz, 4 * 4 + 12 * 8);
        assert_eq!(row_cols(&mask, 0), (0..4).collect::<Vec<_>>());
        assert_eq!(row_cols(&mask, 5), (0..8).collect::<Vec<_>>());
        assert_eq!(row_cols(&mask, 15), (8..16).collect::<Vec<_>>());
    }

    #[test]
    fn test_star_mask() {
        let mask = SparseMask::star(20, 0.1);
        assert_eq!(mask.seq_len, 20);
        assert_csr_consistent(&mask);
        // ceil(20 * 0.1) = 2 center rows attend to every column.
        assert_eq!(row_cols(&mask, 0), (0..20).collect::<Vec<_>>());
        assert_eq!(row_cols(&mask, 1), (0..20).collect::<Vec<_>>());
        // Every other row still includes both center columns.
        for i in 2..20 {
            let cols = row_cols(&mask, i);
            assert!(cols.starts_with(&[0, 1]), "row {i} is missing a center column");
        }
    }

    #[test]
    fn test_sparsity_calculation() {
        let mask = SparseMask::sliding_window(100, 10, true);
        let sparsity = mask.sparsity();
        assert!(sparsity > 0.8);
        assert!(sparsity < 1.0);
        // nnz = (1 + 2 + ... + 10) + 90 * 10 = 955.
        assert!((sparsity - (1.0 - 955.0 / 10_000.0)).abs() < 1e-12);
    }

    #[test]
    fn test_sparse_attention_sliding_window() {
        let mut attn = SparseAttention::sliding_window(64, 10);
        attn.build_mask(20);
        assert_eq!(attn.pattern, SparsePattern::SlidingWindow);
        assert!(attn.mask.is_some());
        assert_eq!(attn.mask.as_ref().map(|m| m.seq_len), Some(20));
        assert!(attn.sparsity() > 0.0);
    }

    #[test]
    fn test_sliding_window_attention_forward() {
        let batch_size = 1;
        let num_heads = 2;
        let seq_len = 8;
        let head_dim = 16;
        let query = DenseTensor::ones(vec![batch_size, num_heads, seq_len, head_dim]);
        let key = DenseTensor::ones(vec![batch_size, num_heads, seq_len, head_dim]);
        let value = DenseTensor::ones(vec![batch_size, num_heads, seq_len, head_dim]);
        let attn = SlidingWindowAttention::new(4, head_dim);
        let output = attn.forward(&query, &key, &value);
        assert_eq!(output.shape(), &[batch_size, num_heads, seq_len, head_dim]);
        // All-ones values: any convex combination of them is 1.
        assert!(output.data().iter().all(|&x| (x - 1.0).abs() < 1e-12));
    }
}