use crate::layers::linear::SparseLinear;
use crate::{CsrTensor, TorshResult};
use torsh_core::TorshError;
use torsh_tensor::{creation::zeros, Tensor};
/// Multi-head attention layer whose query, key, value and output projections
/// are all `SparseLinear` layers.
///
/// Instances are built by `SparseAttention::new`, which derives `head_dim` and
/// `scale` from `model_dim` and `num_heads`.
pub struct SparseAttention {
/// Projection applied to the query input before attention.
query_proj: SparseLinear,
/// Projection applied to the key input before attention.
key_proj: SparseLinear,
/// Projection applied to the value input before attention.
value_proj: SparseLinear,
/// Projection applied to the concatenated per-head attention output.
output_proj: SparseLinear,
/// Number of attention heads the model dimension is split into.
num_heads: usize,
/// Width of a single head (`model_dim / num_heads`).
head_dim: usize,
/// Size of the last axis expected on every input tensor.
model_dim: usize,
// Stored at construction time but not read by any method in this file,
// hence the lint suppression.
#[allow(dead_code)]
dropout: f32,
/// Factor applied to raw dot-product scores (`1 / sqrt(head_dim)`).
scale: f32,
}
impl SparseAttention {
/// Builds an attention layer with `num_heads` heads over a `model_dim`-wide model.
///
/// Four `SparseLinear` layers of shape `model_dim -> model_dim` are created for
/// the query, key, value and output projections; `sparsity` is forwarded to each
/// of them unchanged (the trailing `false` argument is presumably "no bias" —
/// confirm against `SparseLinear::new`). `dropout` is only stored.
///
/// # Errors
///
/// Returns `TorshError::InvalidArgument` when `num_heads` is zero or when
/// `model_dim` is not a multiple of `num_heads`, and propagates any error
/// produced by `SparseLinear::new`.
pub fn new(
    model_dim: usize,
    num_heads: usize,
    sparsity: f32,
    dropout: f32,
) -> TorshResult<Self> {
    // `model_dim % 0` would panic, so a zero head count is rejected explicitly
    // before the divisibility check.
    if num_heads == 0 {
        return Err(TorshError::InvalidArgument(
            "Number of heads must be greater than zero".to_string(),
        ));
    }
    if model_dim % num_heads != 0 {
        return Err(TorshError::InvalidArgument(
            "Model dimension must be divisible by number of heads".to_string(),
        ));
    }

    let head_dim = model_dim / num_heads;
    // Standard scaled dot-product factor: 1 / sqrt(head_dim).
    let scale = 1.0 / (head_dim as f32).sqrt();

    let query_proj = SparseLinear::new(model_dim, model_dim, sparsity, false)?;
    let key_proj = SparseLinear::new(model_dim, model_dim, sparsity, false)?;
    let value_proj = SparseLinear::new(model_dim, model_dim, sparsity, false)?;
    let output_proj = SparseLinear::new(model_dim, model_dim, sparsity, false)?;

    Ok(Self {
        query_proj,
        key_proj,
        value_proj,
        output_proj,
        num_heads,
        head_dim,
        model_dim,
        dropout,
        scale,
    })
}
/// Runs the full attention pass and returns a `[batch, seq_len, model_dim]` tensor.
///
/// Steps: validate the inputs, flatten each input to 2D and apply its sparse
/// projection, split the projections into heads, compute masked scaled
/// dot-product attention, merge the heads again and apply the output projection.
///
/// The batch size and sequence length are taken from `query`; `key` and `value`
/// are read using those same extents.
///
/// When `attention_mask` is given, position `i` attends to position `j` only if
/// `j` appears among the column indices returned by `attention_mask.get_row(i)`.
///
/// # Errors
///
/// Returns the error from `validate_inputs` when the inputs are malformed, and
/// propagates any error from the projections, the mask lookup or tensor access.
pub fn forward(
    &self,
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    attention_mask: Option<&CsrTensor>,
) -> TorshResult<Tensor> {
    // Validate first: indexing `dims()[0]` / `dims()[1]` on a tensor with fewer
    // than two dimensions would panic instead of reporting an error.
    self.validate_inputs(query, key, value)?;

    let batch_size = query.shape().dims()[0];
    let seq_len = query.shape().dims()[1];

    // The sparse projections operate on 2D input, so flatten batch and sequence.
    let query_2d = self.reshape_3d_to_2d(query, batch_size, seq_len)?;
    let key_2d = self.reshape_3d_to_2d(key, batch_size, seq_len)?;
    let value_2d = self.reshape_3d_to_2d(value, batch_size, seq_len)?;

    let q_2d = self.query_proj.forward(&query_2d)?;
    let k_2d = self.key_proj.forward(&key_2d)?;
    let v_2d = self.value_proj.forward(&value_2d)?;

    let q = self.reshape_2d_to_3d(&q_2d, batch_size, seq_len)?;
    let k = self.reshape_2d_to_3d(&k_2d, batch_size, seq_len)?;
    let v = self.reshape_2d_to_3d(&v_2d, batch_size, seq_len)?;

    // Split the model dimension into heads: [batch, heads, seq, head_dim].
    let q_reshaped = self.reshape_for_attention(&q, batch_size, seq_len)?;
    let k_reshaped = self.reshape_for_attention(&k, batch_size, seq_len)?;
    let v_reshaped = self.reshape_for_attention(&v, batch_size, seq_len)?;

    let attention_output = self.compute_sparse_attention(
        &q_reshaped,
        &k_reshaped,
        &v_reshaped,
        batch_size,
        seq_len,
        attention_mask,
    )?;

    // Merge the heads back and run the output projection on the flattened view.
    let output_reshaped =
        self.reshape_from_attention(&attention_output, batch_size, seq_len)?;
    let output_2d = self.reshape_3d_to_2d(&output_reshaped, batch_size, seq_len)?;
    let projected_2d = self.output_proj.forward(&output_2d)?;
    self.reshape_2d_to_3d(&projected_2d, batch_size, seq_len)
}
/// Checks that `query`, `key` and `value` are usable by `forward`.
///
/// All three tensors must be 3D with a last axis equal to `model_dim`, and
/// `key` / `value` must share the batch size and sequence length of `query`,
/// because `forward` reads all three tensors using the query's extents.
///
/// # Errors
///
/// Returns `TorshError::InvalidArgument` describing the first violated rule.
fn validate_inputs(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> TorshResult<()> {
    let q_shape = query.shape();
    let k_shape = key.shape();
    let v_shape = value.shape();

    if q_shape.ndim() != 3 || k_shape.ndim() != 3 || v_shape.ndim() != 3 {
        return Err(TorshError::InvalidArgument(
            "Input tensors must be 3D (batch, seq_len, model_dim)".to_string(),
        ));
    }
    if q_shape.dims()[2] != self.model_dim {
        return Err(TorshError::InvalidArgument(
            "Query dimension doesn't match model dimension".to_string(),
        ));
    }
    if k_shape.dims()[2] != self.model_dim || v_shape.dims()[2] != self.model_dim {
        return Err(TorshError::InvalidArgument(
            "Key/Value dimensions don't match model dimension".to_string(),
        ));
    }
    // Without these two checks a larger key/value tensor would be silently
    // truncated to the query's batch size and sequence length.
    if k_shape.dims()[0] != q_shape.dims()[0] || v_shape.dims()[0] != q_shape.dims()[0] {
        return Err(TorshError::InvalidArgument(
            "Key/Value batch size doesn't match query batch size".to_string(),
        ));
    }
    if k_shape.dims()[1] != q_shape.dims()[1] || v_shape.dims()[1] != q_shape.dims()[1] {
        return Err(TorshError::InvalidArgument(
            "Key/Value sequence length doesn't match query sequence length".to_string(),
        ));
    }
    Ok(())
}
/// Splits the model axis into heads.
///
/// Reads `tensor` as `[batch, seq, model_dim]` and returns a freshly allocated
/// `[batch, num_heads, seq, head_dim]` tensor in which head `h` holds the
/// model columns `h * head_dim .. (h + 1) * head_dim`.
fn reshape_for_attention(
    &self,
    tensor: &Tensor,
    batch_size: usize,
    seq_len: usize,
) -> TorshResult<Tensor> {
    let per_head = zeros::<f32>(&[batch_size, self.num_heads, seq_len, self.head_dim])?;
    for batch in 0..batch_size {
        for pos in 0..seq_len {
            for head in 0..self.num_heads {
                // First model column that belongs to this head.
                let head_start = head * self.head_dim;
                for offset in 0..self.head_dim {
                    let element = tensor.get(&[batch, pos, head_start + offset])?;
                    per_head.set(&[batch, head, pos, offset], element)?;
                }
            }
        }
    }
    Ok(per_head)
}
/// Merges the per-head axis back into the model axis.
///
/// Reads `tensor` as `[batch, num_heads, seq, head_dim]` and returns a freshly
/// allocated `[batch, seq, model_dim]` tensor; this is the inverse layout
/// change of `reshape_for_attention`.
fn reshape_from_attention(
    &self,
    tensor: &Tensor,
    batch_size: usize,
    seq_len: usize,
) -> TorshResult<Tensor> {
    let merged = zeros::<f32>(&[batch_size, seq_len, self.model_dim])?;
    for batch in 0..batch_size {
        for pos in 0..seq_len {
            for head in 0..self.num_heads {
                // First model column that this head writes to.
                let head_start = head * self.head_dim;
                for offset in 0..self.head_dim {
                    let element = tensor.get(&[batch, head, pos, offset])?;
                    merged.set(&[batch, pos, head_start + offset], element)?;
                }
            }
        }
    }
    Ok(merged)
}
#[allow(clippy::too_many_arguments)]
fn compute_sparse_attention(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
batch_size: usize,
seq_len: usize,
attention_mask: Option<&CsrTensor>,
) -> TorshResult<Tensor> {
let output = zeros::<f32>(&[batch_size, self.num_heads, seq_len, self.head_dim])?;
for b in 0..batch_size {
for h in 0..self.num_heads {
let scores = zeros::<f32>(&[seq_len, seq_len])?;
for i in 0..seq_len {
for j in 0..seq_len {
let should_compute = if let Some(mask) = attention_mask {
let (cols, _) = mask.get_row(i)?;
cols.contains(&j)
} else {
true };
if should_compute {
let mut score = 0.0;
for d in 0..self.head_dim {
score += query.get(&[b, h, i, d])? * key.get(&[b, h, j, d])?;
}
scores.set(&[i, j], score * self.scale)?;
} else {
scores.set(&[i, j], f32::NEG_INFINITY)?;
}
}
}
for i in 0..seq_len {
let mut max_score = f32::NEG_INFINITY;
for j in 0..seq_len {
let score = scores.get(&[i, j])?;
if score > max_score && score != f32::NEG_INFINITY {
max_score = score;
}
}
let mut sum_exp = 0.0;
let mut exp_scores = vec![0.0; seq_len];
#[allow(clippy::needless_range_loop)]
for j in 0..seq_len {
let score = scores.get(&[i, j])?;
if score != f32::NEG_INFINITY {
exp_scores[j] = (score - max_score).exp();
sum_exp += exp_scores[j];
}
}
for d in 0..self.head_dim {
let mut weighted_sum = 0.0;
#[allow(clippy::needless_range_loop)]
for j in 0..seq_len {
if exp_scores[j] > 0.0 {
let attention_weight = exp_scores[j] / sum_exp;
weighted_sum += attention_weight * value.get(&[b, h, j, d])?;
}
}
output.set(&[b, h, i, d], weighted_sum)?;
}
}
}
}
Ok(output)
}
/// Returns the combined parameter count of the query, key, value and output
/// projections, as reported by each layer's own `num_parameters`.
pub fn num_parameters(&self) -> usize {
    [
        &self.query_proj,
        &self.key_proj,
        &self.value_proj,
        &self.output_proj,
    ]
    .iter()
    .map(|projection| projection.num_parameters())
    .sum()
}
/// Returns the model dimension this layer was constructed with.
pub fn model_dim(&self) -> usize {
    self.model_dim
}
/// Returns the number of attention heads this layer was constructed with.
pub fn num_heads(&self) -> usize {
    self.num_heads
}
/// Flattens the batch and sequence axes into a single row axis.
///
/// Reads `tensor` as `[batch, seq, model_dim]` and returns a freshly allocated
/// `[batch_size * seq_len, model_dim]` tensor in which element `(b, s)` lands
/// on row `b * seq_len + s`.
fn reshape_3d_to_2d(
    &self,
    tensor: &Tensor,
    batch_size: usize,
    seq_len: usize,
) -> TorshResult<Tensor> {
    let flat = zeros::<f32>(&[batch_size * seq_len, self.model_dim])?;
    for batch in 0..batch_size {
        for pos in 0..seq_len {
            // Row of the flattened tensor that receives this (batch, pos) pair.
            let row = batch * seq_len + pos;
            for col in 0..self.model_dim {
                let element = tensor.get(&[batch, pos, col])?;
                flat.set(&[row, col], element)?;
            }
        }
    }
    Ok(flat)
}
/// Restores the separate batch and sequence axes of a flattened tensor.
///
/// Reads `tensor` as `[batch_size * seq_len, model_dim]` and returns a freshly
/// allocated `[batch, seq, model_dim]` tensor, taking element `(b, s)` from
/// row `b * seq_len + s`.
fn reshape_2d_to_3d(
    &self,
    tensor: &Tensor,
    batch_size: usize,
    seq_len: usize,
) -> TorshResult<Tensor> {
    let unflattened = zeros::<f32>(&[batch_size, seq_len, self.model_dim])?;
    for batch in 0..batch_size {
        for pos in 0..seq_len {
            // Row of the flattened tensor that holds this (batch, pos) pair.
            let row = batch * seq_len + pos;
            for col in 0..self.model_dim {
                let element = tensor.get(&[row, col])?;
                unflattened.set(&[batch, pos, col], element)?;
            }
        }
    }
    Ok(unflattened)
}
}