use crate::layers::linear::SparseLinear;
use crate::{CooTensor, CsrTensor, TorshResult};
use torsh_core::{Shape, TorshError};
use torsh_tensor::{creation::zeros, Tensor};
/// Multi-head attention built from sparse linear projections, with optional
/// CSR attention masks restricting which positions may attend to each other.
#[derive(Debug, Clone)]
pub struct SparseAttention {
// Q/K/V/output projections; each maps model_dim -> model_dim with weight sparsity.
query_proj: SparseLinear,
key_proj: SparseLinear,
value_proj: SparseLinear,
output_proj: SparseLinear,
// Number of attention heads; model_dim must be divisible by this.
num_heads: usize,
// Per-head dimension: model_dim / num_heads.
head_dim: usize,
// Total embedding dimension.
model_dim: usize,
// Dropout probability; validated and stored at construction but not applied anywhere in this file.
#[allow(dead_code)]
dropout: f32,
// Attention score scaling factor: 1 / sqrt(head_dim).
scale: f32,
}
impl SparseAttention {
/// Creates a new sparse multi-head attention layer.
///
/// # Arguments
/// * `model_dim` - Total embedding dimension; must be divisible by `num_heads`.
/// * `num_heads` - Number of attention heads; must be greater than 0.
/// * `sparsity` - Sparsity level for each projection, in [0.0, 1.0].
/// * `dropout` - Dropout probability in [0.0, 1.0] (stored but not yet applied).
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when any argument is out of range.
pub fn new(
    model_dim: usize,
    num_heads: usize,
    sparsity: f32,
    dropout: f32,
) -> TorshResult<Self> {
    // Guard first: `model_dim % num_heads` and `model_dim / num_heads` below
    // would panic on a zero divisor.
    if num_heads == 0 {
        return Err(TorshError::InvalidArgument(
            "Number of heads must be greater than 0".to_string(),
        ));
    }
    if model_dim % num_heads != 0 {
        return Err(TorshError::InvalidArgument(
            "Model dimension must be divisible by number of heads".to_string(),
        ));
    }
    if !(0.0..=1.0).contains(&sparsity) {
        return Err(TorshError::InvalidArgument(
            "Sparsity must be between 0.0 and 1.0".to_string(),
        ));
    }
    if !(0.0..=1.0).contains(&dropout) {
        return Err(TorshError::InvalidArgument(
            "Dropout must be between 0.0 and 1.0".to_string(),
        ));
    }
    let head_dim = model_dim / num_heads;
    // Standard scaled-dot-product factor: 1 / sqrt(head_dim).
    let scale = 1.0 / (head_dim as f32).sqrt();
    let query_proj = SparseLinear::new(model_dim, model_dim, sparsity, false)?;
    let key_proj = SparseLinear::new(model_dim, model_dim, sparsity, false)?;
    let value_proj = SparseLinear::new(model_dim, model_dim, sparsity, false)?;
    let output_proj = SparseLinear::new(model_dim, model_dim, sparsity, false)?;
    Ok(Self {
        query_proj,
        key_proj,
        value_proj,
        output_proj,
        num_heads,
        head_dim,
        model_dim,
        dropout,
        scale,
    })
}
/// Runs sparse multi-head attention.
///
/// # Arguments
/// * `query`, `key`, `value` - 3-D tensors shaped (batch, seq_len, model_dim).
/// * `attention_mask` - Optional CSR mask; a stored entry at (i, j) means
///   query position i may attend to key position j.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when shapes are invalid or mismatched.
pub fn forward(
    &self,
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    attention_mask: Option<&CsrTensor>,
) -> TorshResult<Tensor> {
    // Validate BEFORE reading dims: indexing `dims()[1]` on a tensor with
    // fewer than 2 axes would panic, so the 3-D rank check must come first.
    self.validate_inputs(query, key, value)?;
    let batch_size = query.shape().dims()[0];
    let seq_len = query.shape().dims()[1];
    // The sparse projections expect 2-D input: flatten (batch, seq) rows.
    let query_2d = self.reshape_3d_to_2d(query, batch_size, seq_len)?;
    let key_2d = self.reshape_3d_to_2d(key, batch_size, seq_len)?;
    let value_2d = self.reshape_3d_to_2d(value, batch_size, seq_len)?;
    let q_2d = self.query_proj.forward(&query_2d)?;
    let k_2d = self.key_proj.forward(&key_2d)?;
    let v_2d = self.value_proj.forward(&value_2d)?;
    let q = self.reshape_2d_to_3d(&q_2d, batch_size, seq_len)?;
    let k = self.reshape_2d_to_3d(&k_2d, batch_size, seq_len)?;
    let v = self.reshape_2d_to_3d(&v_2d, batch_size, seq_len)?;
    // Split model_dim into (num_heads, head_dim) for per-head attention.
    let q_reshaped = self.reshape_for_attention(&q, batch_size, seq_len)?;
    let k_reshaped = self.reshape_for_attention(&k, batch_size, seq_len)?;
    let v_reshaped = self.reshape_for_attention(&v, batch_size, seq_len)?;
    let attention_output = self.compute_sparse_attention(
        &q_reshaped,
        &k_reshaped,
        &v_reshaped,
        batch_size,
        seq_len,
        attention_mask,
    )?;
    // Merge heads back into model_dim and apply the output projection.
    let output_reshaped =
        self.reshape_from_attention(&attention_output, batch_size, seq_len)?;
    let output_2d = self.reshape_3d_to_2d(&output_reshaped, batch_size, seq_len)?;
    let projected_2d = self.output_proj.forward(&output_2d)?;
    self.reshape_2d_to_3d(&projected_2d, batch_size, seq_len)
}
/// Convenience wrapper for standard self-attention: query, key, and value
/// are all the same tensor.
pub fn self_attention(
    &self,
    input: &Tensor,
    attention_mask: Option<&CsrTensor>,
) -> TorshResult<Tensor> {
    let (q, k, v) = (input, input, input);
    self.forward(q, k, v, attention_mask)
}
/// Builds a sparse (CSR) mask where each position attends only to positions
/// within `window_size` of itself (a band around the diagonal).
///
/// # Errors
/// Propagates any error from sparse tensor construction.
pub fn create_local_attention_mask(
    seq_len: usize,
    window_size: usize,
) -> TorshResult<CsrTensor> {
    // Each row holds at most min(2 * window_size + 1, seq_len) entries;
    // reserve the triplet buffers once instead of growing them repeatedly.
    // Saturating arithmetic keeps the estimate panic-free for huge windows.
    let per_row = window_size
        .saturating_mul(2)
        .saturating_add(1)
        .min(seq_len);
    let max_nnz = seq_len.saturating_mul(per_row);
    let mut row_indices = Vec::with_capacity(max_nnz);
    let mut col_indices = Vec::with_capacity(max_nnz);
    let mut values = Vec::with_capacity(max_nnz);
    for i in 0..seq_len {
        // Clamp the window to [0, seq_len).
        let start = i.saturating_sub(window_size);
        let end = std::cmp::min(i + window_size + 1, seq_len);
        for j in start..end {
            row_indices.push(i);
            col_indices.push(j);
            values.push(1.0);
        }
    }
    let shape = Shape::new(vec![seq_len, seq_len]);
    let coo = CooTensor::new(row_indices, col_indices, values, shape)?;
    CsrTensor::from_coo(&coo)
}
/// Builds a sparse (CSR) mask combining a dense local band of width
/// `local_window` around the diagonal with strided global connections
/// at every `stride`-th position.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when `stride` is zero; otherwise
/// propagates errors from sparse tensor construction.
pub fn create_strided_attention_mask(
    seq_len: usize,
    stride: usize,
    local_window: usize,
) -> TorshResult<CsrTensor> {
    if stride == 0 {
        return Err(TorshError::InvalidArgument(
            "Stride must be greater than 0".to_string(),
        ));
    }
    let mut rows = Vec::new();
    let mut cols = Vec::new();
    let mut vals = Vec::new();
    for i in 0..seq_len {
        let window_lo = i.saturating_sub(local_window);
        let window_hi = std::cmp::min(i + local_window + 1, seq_len);
        // Dense local band around the diagonal.
        for j in window_lo..window_hi {
            rows.push(i);
            cols.push(j);
            vals.push(1.0);
        }
        // Strided connections in the same residue class as `i`, skipping
        // anything the local band already covers to avoid duplicates.
        for j in (i % stride..seq_len).step_by(stride) {
            if !(window_lo..window_hi).contains(&j) {
                rows.push(i);
                cols.push(j);
                vals.push(1.0);
            }
        }
    }
    let shape = Shape::new(vec![seq_len, seq_len]);
    let coo = CooTensor::new(rows, cols, vals, shape)?;
    CsrTensor::from_coo(&coo)
}
/// Checks that Q/K/V are 3-D, batch- and length-consistent, and sized to
/// `model_dim`.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` describing the first violated rule.
fn validate_inputs(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> TorshResult<()> {
    let q_shape = query.shape();
    let k_shape = key.shape();
    let v_shape = value.shape();
    // Rank check first: the dims()[...] accesses below assume 3 axes.
    if q_shape.ndim() != 3 || k_shape.ndim() != 3 || v_shape.ndim() != 3 {
        return Err(TorshError::InvalidArgument(
            "Input tensors must be 3D (batch, seq_len, model_dim)".to_string(),
        ));
    }
    if q_shape.dims()[0] != k_shape.dims()[0] || q_shape.dims()[0] != v_shape.dims()[0] {
        return Err(TorshError::InvalidArgument(
            "Batch sizes must match across Q, K, V".to_string(),
        ));
    }
    // `forward` reshapes key/value using the QUERY's sequence length, so an
    // unequal Q/K length would fail deep inside the reshape helpers with an
    // opaque indexing error; reject it here with a clear message instead.
    if q_shape.dims()[1] != k_shape.dims()[1] {
        return Err(TorshError::InvalidArgument(
            "Query and Key sequence lengths must match".to_string(),
        ));
    }
    if k_shape.dims()[1] != v_shape.dims()[1] {
        return Err(TorshError::InvalidArgument(
            "Key and Value sequence lengths must match".to_string(),
        ));
    }
    if q_shape.dims()[2] != self.model_dim {
        return Err(TorshError::InvalidArgument(
            "Query dimension doesn't match model dimension".to_string(),
        ));
    }
    if k_shape.dims()[2] != self.model_dim || v_shape.dims()[2] != self.model_dim {
        return Err(TorshError::InvalidArgument(
            "Key/Value dimensions don't match model dimension".to_string(),
        ));
    }
    Ok(())
}
/// Splits (batch, seq, model_dim) into (batch, heads, seq, head_dim), where
/// model_dim index `h * head_dim + d` maps to head `h`, component `d`.
fn reshape_for_attention(
    &self,
    tensor: &Tensor,
    batch_size: usize,
    seq_len: usize,
) -> TorshResult<Tensor> {
    let split = zeros::<f32>(&[batch_size, self.num_heads, seq_len, self.head_dim])?;
    for batch in 0..batch_size {
        for head in 0..self.num_heads {
            for pos in 0..seq_len {
                for dim in 0..self.head_dim {
                    // Flattened model-dim index for this (head, component) pair.
                    let flat = head * self.head_dim + dim;
                    let elem = tensor.get(&[batch, pos, flat])?;
                    split.set(&[batch, head, pos, dim], elem)?;
                }
            }
        }
    }
    Ok(split)
}
/// Inverse of `reshape_for_attention`: merges (batch, heads, seq, head_dim)
/// back into (batch, seq, model_dim).
fn reshape_from_attention(
    &self,
    tensor: &Tensor,
    batch_size: usize,
    seq_len: usize,
) -> TorshResult<Tensor> {
    let merged = zeros::<f32>(&[batch_size, seq_len, self.model_dim])?;
    for batch in 0..batch_size {
        for head in 0..self.num_heads {
            for pos in 0..seq_len {
                for dim in 0..self.head_dim {
                    // Head `head`, component `dim` lands at this model-dim slot.
                    let flat = head * self.head_dim + dim;
                    let elem = tensor.get(&[batch, head, pos, dim])?;
                    merged.set(&[batch, pos, flat], elem)?;
                }
            }
        }
    }
    Ok(merged)
}
/// Computes masked, scaled dot-product attention for every (batch, head) pair.
///
/// Pairs outside the optional CSR `attention_mask` get a score of -inf so the
/// softmax assigns them zero weight. When every position in a row is masked
/// out, that output row stays all zeros.
#[allow(clippy::too_many_arguments)]
fn compute_sparse_attention(
    &self,
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    batch_size: usize,
    seq_len: usize,
    attention_mask: Option<&CsrTensor>,
) -> TorshResult<Tensor> {
    let output = zeros::<f32>(&[batch_size, self.num_heads, seq_len, self.head_dim])?;
    for b in 0..batch_size {
        for h in 0..self.num_heads {
            let scores = zeros::<f32>(&[seq_len, seq_len])?;
            for i in 0..seq_len {
                // Fetch the mask row once per query position. Previously
                // `get_row(i)` ran inside the `j` loop — loop-invariant work
                // repeated seq_len times per row.
                let mask_cols = match attention_mask {
                    Some(mask) => Some(mask.get_row(i)?.0),
                    None => None,
                };
                for j in 0..seq_len {
                    let should_compute = mask_cols
                        .as_ref()
                        .map_or(true, |cols| cols.contains(&j));
                    if should_compute {
                        // Dot product of the i-th query and j-th key vectors.
                        let mut score = 0.0;
                        for d in 0..self.head_dim {
                            score += query.get(&[b, h, i, d])? * key.get(&[b, h, j, d])?;
                        }
                        scores.set(&[i, j], score * self.scale)?;
                    } else {
                        // Masked-out pair: -inf guarantees zero softmax weight.
                        scores.set(&[i, j], f32::NEG_INFINITY)?;
                    }
                }
            }
            // Numerically stable softmax over each row, then weighted sum of V.
            for i in 0..seq_len {
                // Subtracting the row max before exp() avoids overflow.
                let mut max_score = f32::NEG_INFINITY;
                for j in 0..seq_len {
                    let score = scores.get(&[i, j])?;
                    if score > max_score && score != f32::NEG_INFINITY {
                        max_score = score;
                    }
                }
                let mut sum_exp = 0.0;
                let mut exp_scores = vec![0.0; seq_len];
                #[allow(clippy::needless_range_loop)]
                for j in 0..seq_len {
                    let score = scores.get(&[i, j])?;
                    if score != f32::NEG_INFINITY {
                        exp_scores[j] = (score - max_score).exp();
                        sum_exp += exp_scores[j];
                    }
                }
                for d in 0..self.head_dim {
                    let mut weighted_sum = 0.0;
                    #[allow(clippy::needless_range_loop)]
                    for j in 0..seq_len {
                        // exp_scores[j] == 0.0 for masked pairs, so they are
                        // skipped; a fully-masked row leaves the output at 0.
                        if exp_scores[j] > 0.0 {
                            let attention_weight = exp_scores[j] / sum_exp;
                            weighted_sum += attention_weight * value.get(&[b, h, j, d])?;
                        }
                    }
                    output.set(&[b, h, i, d], weighted_sum)?;
                }
            }
        }
    }
    Ok(output)
}
/// Total number of learnable parameters across all four sparse projections.
pub fn num_parameters(&self) -> usize {
    [
        &self.query_proj,
        &self.key_proj,
        &self.value_proj,
        &self.output_proj,
    ]
    .iter()
    .map(|proj| proj.num_parameters())
    .sum()
}
/// Total embedding dimension of the layer.
pub fn model_dim(&self) -> usize {
self.model_dim
}
/// Number of attention heads.
pub fn num_heads(&self) -> usize {
self.num_heads
}
/// Per-head dimension (`model_dim / num_heads`).
pub fn head_dim(&self) -> usize {
self.head_dim
}
/// Score scaling factor, `1 / sqrt(head_dim)`.
pub fn scale(&self) -> f32 {
self.scale
}
/// Flattens (batch, seq, model_dim) into (batch * seq, model_dim), with
/// batch-major row ordering (row = b * seq_len + s).
fn reshape_3d_to_2d(
    &self,
    tensor: &Tensor,
    batch_size: usize,
    seq_len: usize,
) -> TorshResult<Tensor> {
    let flat = zeros::<f32>(&[batch_size * seq_len, self.model_dim])?;
    for row in 0..batch_size * seq_len {
        // Recover the (batch, seq) pair this flat row came from.
        let (b, s) = (row / seq_len, row % seq_len);
        for d in 0..self.model_dim {
            flat.set(&[row, d], tensor.get(&[b, s, d])?)?;
        }
    }
    Ok(flat)
}
/// Inverse of `reshape_3d_to_2d`: expands (batch * seq, model_dim) back into
/// (batch, seq, model_dim) assuming batch-major row ordering.
fn reshape_2d_to_3d(
    &self,
    tensor: &Tensor,
    batch_size: usize,
    seq_len: usize,
) -> TorshResult<Tensor> {
    let cube = zeros::<f32>(&[batch_size, seq_len, self.model_dim])?;
    for row in 0..batch_size * seq_len {
        // Flat row `row` holds position (b, s) of the 3-D output.
        let (b, s) = (row / seq_len, row % seq_len);
        for d in 0..self.model_dim {
            cube.set(&[b, s, d], tensor.get(&[row, d])?)?;
        }
    }
    Ok(cube)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SparseTensor;
    use torsh_tensor::creation::ones;

    #[test]
    fn test_sparse_attention_creation() {
        let attn = SparseAttention::new(64, 8, 0.5, 0.1).unwrap();
        assert_eq!(attn.model_dim(), 64);
        assert_eq!(attn.num_heads(), 8);
        assert_eq!(attn.head_dim(), 8);
        assert!(attn.num_parameters() > 0);
    }

    #[test]
    fn test_invalid_model_dim() {
        // 65 is not divisible by 8 heads.
        assert!(SparseAttention::new(65, 8, 0.5, 0.1).is_err());
    }

    #[test]
    fn test_invalid_sparsity() {
        // Sparsity outside [0, 1] in both directions.
        assert!(SparseAttention::new(64, 8, 1.5, 0.1).is_err());
        assert!(SparseAttention::new(64, 8, -0.1, 0.1).is_err());
    }

    #[test]
    fn test_sparse_attention_forward() {
        let attn = SparseAttention::new(32, 4, 0.3, 0.1).unwrap();
        let q = ones::<f32>(&[2, 5, 32]).unwrap();
        let k = ones::<f32>(&[2, 5, 32]).unwrap();
        let v = ones::<f32>(&[2, 5, 32]).unwrap();
        let out = attn.forward(&q, &k, &v, None).unwrap();
        // Output shape matches the input shape.
        assert_eq!(out.shape().dims(), &[2, 5, 32]);
    }

    #[test]
    fn test_self_attention() {
        let attn = SparseAttention::new(16, 2, 0.4, 0.0).unwrap();
        let input = ones::<f32>(&[1, 4, 16]).unwrap();
        let out = attn.self_attention(&input, None).unwrap();
        assert_eq!(out.shape().dims(), &[1, 4, 16]);
    }

    #[test]
    fn test_local_attention_mask() {
        let mask = SparseAttention::create_local_attention_mask(5, 1).unwrap();
        assert_eq!(mask.shape().dims(), &[5, 5]);
        assert!(mask.nnz() > 0);
        // 5 rows of at most 3 entries each.
        assert!(mask.nnz() <= 15);
    }

    #[test]
    fn test_strided_attention_mask() {
        let mask = SparseAttention::create_strided_attention_mask(8, 2, 1).unwrap();
        assert_eq!(mask.shape().dims(), &[8, 8]);
        assert!(mask.nnz() > 0);
    }

    #[test]
    fn test_attention_with_local_mask() {
        let attn = SparseAttention::new(16, 2, 0.2, 0.0).unwrap();
        let input = ones::<f32>(&[1, 4, 16]).unwrap();
        let mask = SparseAttention::create_local_attention_mask(4, 1).unwrap();
        let out = attn.self_attention(&input, Some(&mask)).unwrap();
        assert_eq!(out.shape().dims(), &[1, 4, 16]);
    }

    #[test]
    fn test_dimension_validation() {
        let attn = SparseAttention::new(32, 4, 0.3, 0.1).unwrap();
        // Query last dim (16) disagrees with model_dim (32).
        let bad_q = ones::<f32>(&[2, 5, 16]).unwrap();
        let k = ones::<f32>(&[2, 5, 32]).unwrap();
        let v = ones::<f32>(&[2, 5, 32]).unwrap();
        assert!(attn.forward(&bad_q, &k, &v, None).is_err());
    }

    #[test]
    fn test_batch_size_validation() {
        let attn = SparseAttention::new(16, 2, 0.3, 0.1).unwrap();
        let q = ones::<f32>(&[2, 5, 16]).unwrap();
        // Key batch (3) disagrees with query/value batch (2).
        let k = ones::<f32>(&[3, 5, 16]).unwrap();
        let v = ones::<f32>(&[2, 5, 16]).unwrap();
        assert!(attn.forward(&q, &k, &v, None).is_err());
    }

    #[test]
    fn test_sequence_length_validation() {
        let attn = SparseAttention::new(16, 2, 0.3, 0.1).unwrap();
        // Key/value sequence length (4) disagrees with the query's (5).
        let q = ones::<f32>(&[2, 5, 16]).unwrap();
        let k = ones::<f32>(&[2, 4, 16]).unwrap();
        let v = ones::<f32>(&[2, 4, 16]).unwrap();
        assert!(attn.forward(&q, &k, &v, None).is_err());
    }

    #[test]
    fn test_invalid_stride() {
        assert!(SparseAttention::create_strided_attention_mask(8, 0, 1).is_err());
    }
}