use serde::{Deserialize, Serialize};
use tensorlogic_ir::{EinsumGraph, EinsumNode};
use crate::{
config::AttentionConfig,
error::{Result, TrustformerError},
};
/// Sparsity pattern applied by a [`SparseAttentionGraphConfig`].
///
/// All size parameters must be positive; `SparseAttentionGraphConfig::validate`
/// rejects a zero `stride`, `window_size`, `block_size` or `num_random`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum SparsePatternType {
/// Strided pattern parameterized by `stride`.
Strided { stride: usize },
/// Local (windowed) pattern parameterized by `window_size`.
Local { window_size: usize },
/// Local window combined with a set of global token positions.
GlobalLocal {
window_size: usize,
/// Positions of the global tokens; within this file only their count is
/// used (for sparsity estimates and descriptions).
global_positions: Vec<usize>,
},
/// Block-sparse pattern parameterized by `block_size`.
BlockSparse { block_size: usize },
/// Random pattern with `num_random` random connections (reported as `k`).
Random { num_random: usize },
}
/// Configuration for a [`SparseAttentionGraph`]: dense attention settings plus
/// the sparsity pattern to apply.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SparseAttentionGraphConfig {
/// Underlying attention configuration; its `d_k` sets the score scaling.
pub base_attention: AttentionConfig,
/// Sparsity pattern to apply.
pub pattern: SparsePatternType,
/// Flag presumably selecting exact (vs. approximate) sparse evaluation; the
/// constructors set it to `true`.
// NOTE(review): not read by `SparseAttentionGraph::build_sparse_attention_graph`
// — confirm which component consumes it.
pub exact_sparse: bool,
}
impl SparseAttentionGraphConfig {
pub fn strided(base_attention: AttentionConfig, stride: usize) -> Result<Self> {
if stride == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "stride must be positive".to_string(),
});
}
Ok(Self {
base_attention,
pattern: SparsePatternType::Strided { stride },
exact_sparse: true,
})
}
pub fn local(base_attention: AttentionConfig, window_size: usize) -> Result<Self> {
if window_size == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "window_size must be positive".to_string(),
});
}
Ok(Self {
base_attention,
pattern: SparsePatternType::Local { window_size },
exact_sparse: true,
})
}
pub fn global_local(
base_attention: AttentionConfig,
window_size: usize,
global_positions: Vec<usize>,
) -> Result<Self> {
if window_size == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "window_size must be positive".to_string(),
});
}
Ok(Self {
base_attention,
pattern: SparsePatternType::GlobalLocal {
window_size,
global_positions,
},
exact_sparse: true,
})
}
pub fn block_sparse(base_attention: AttentionConfig, block_size: usize) -> Result<Self> {
if block_size == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "block_size must be positive".to_string(),
});
}
Ok(Self {
base_attention,
pattern: SparsePatternType::BlockSparse { block_size },
exact_sparse: true,
})
}
pub fn with_exact_sparse(mut self, exact_sparse: bool) -> Self {
self.exact_sparse = exact_sparse;
self
}
pub fn validate(&self) -> Result<()> {
self.base_attention.validate()?;
match &self.pattern {
SparsePatternType::Strided { stride } => {
if *stride == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "stride must be positive".to_string(),
});
}
}
SparsePatternType::Local { window_size } => {
if *window_size == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "window_size must be positive".to_string(),
});
}
}
SparsePatternType::GlobalLocal {
window_size,
global_positions: _,
} => {
if *window_size == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "window_size must be positive".to_string(),
});
}
}
SparsePatternType::BlockSparse { block_size } => {
if *block_size == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "block_size must be positive".to_string(),
});
}
}
SparsePatternType::Random { num_random } => {
if *num_random == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "num_random must be positive".to_string(),
});
}
}
}
Ok(())
}
}
/// Builder that appends sparse scaled-dot-product attention subgraphs to an
/// [`EinsumGraph`].
#[derive(Clone, Debug)]
pub struct SparseAttentionGraph {
/// Configuration, validated by `SparseAttentionGraph::new`. The field is
/// public, so it can be changed after that validation.
pub config: SparseAttentionGraphConfig,
}
impl SparseAttentionGraph {
    /// Creates a builder after validating `config`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`SparseAttentionGraphConfig::validate`].
    pub fn new(config: SparseAttentionGraphConfig) -> Result<Self> {
        config.validate()?;
        Ok(Self { config })
    }

    /// Appends a masked attention subgraph to `graph` and returns a one-element
    /// vector holding the index of its output tensor.
    ///
    /// The subgraph reads existing tensors by fixed index: `0` = Q, `1` = K,
    /// `2` = V, and `3` = sparse mask. The caller must have added those four
    /// tensors, in that order, before calling this method.
    ///
    /// The pipeline is:
    /// - scores `bqd,bkd->bqk`;
    /// - scaling by `sqrt(d_k)`;
    /// - element-wise `mul` with the mask;
    /// - `sparse_softmax_k`;
    /// - output `bqk,bkv->bqv`.
    ///
    /// Neither `config.pattern` nor `config.exact_sparse` is read here. The
    /// sparsity structure must already be encoded in the mask tensor.
    ///
    /// # Errors
    ///
    /// Propagates any error from `EinsumGraph::add_node`.
    pub fn build_sparse_attention_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Input slots the caller must have populated (see doc comment).
        const QUERY: usize = 0;
        const KEY: usize = 1;
        const VALUE: usize = 2;
        const SPARSE_MASK: usize = 3;

        let scores_tensor = graph.add_tensor("sparse_attn_scores");
        let scores_node = EinsumNode::new("bqd,bkd->bqk", vec![QUERY, KEY], vec![scores_tensor]);
        graph.add_node(scores_node)?;

        // The divisor is encoded in the op name. `sparse_scale` is a placeholder
        // operand that no node in this subgraph produces.
        let scale_factor = (self.config.base_attention.d_k as f64).sqrt();
        let scale_tensor = graph.add_tensor("sparse_scale");
        let scaled_tensor = graph.add_tensor("sparse_scaled_scores");
        let scale_node = EinsumNode::elem_binary(
            format!("div_scalar_{}", scale_factor),
            scores_tensor,
            scale_tensor,
            scaled_tensor,
        );
        graph.add_node(scale_node)?;

        // NOTE(review): the mask is applied multiplicatively before softmax. With a
        // 0/1 mask, masked scores become 0 (not -inf) and would still receive
        // softmax weight, unless the backend's `sparse_softmax_k` treats them
        // specially. Confirm the intended mask encoding.
        let masked_tensor = graph.add_tensor("sparse_masked_scores");
        let mask_node = EinsumNode::elem_binary("mul", scaled_tensor, SPARSE_MASK, masked_tensor);
        graph.add_node(mask_node)?;

        let softmax_tensor = graph.add_tensor("sparse_attention_weights");
        let softmax_node =
            EinsumNode::elem_unary("sparse_softmax_k", masked_tensor, softmax_tensor);
        graph.add_node(softmax_node)?;

        let output_tensor = graph.add_tensor("sparse_attn_output");
        let output_node = EinsumNode::new(
            "bqk,bkv->bqv",
            vec![softmax_tensor, VALUE],
            vec![output_tensor],
        );
        graph.add_node(output_node)?;

        Ok(vec![output_tensor])
    }

    /// Estimates the fraction of the dense `seq_len x seq_len` score matrix that
    /// the configured pattern keeps. The result always lies in `[0, 1]`.
    ///
    /// For `seq_len == 0` this returns `1.0` (no savings) instead of dividing by
    /// zero. Parameters larger than `seq_len` saturate at `1.0`.
    pub fn sparsity_factor(&self, seq_len: usize) -> f64 {
        if seq_len == 0 {
            return 1.0;
        }
        let n = seq_len as f64;
        let fraction = match &self.config.pattern {
            SparsePatternType::Strided { stride } => 1.0 / (*stride as f64),
            SparsePatternType::Local { window_size } => (*window_size as f64).min(n) / n,
            SparsePatternType::GlobalLocal {
                window_size,
                global_positions,
            } => {
                let local_fraction = (*window_size as f64) / n;
                let global_fraction = (global_positions.len() as f64) / n;
                local_fraction + global_fraction
            }
            SparsePatternType::BlockSparse { block_size } => (*block_size as f64) / n,
            SparsePatternType::Random { num_random } => (*num_random as f64) / n,
        };
        // Clamp uniformly. A zero stride (only reachable by mutating the public
        // config after validation) yields +inf, which also saturates to 1.0.
        fraction.min(1.0)
    }

    /// Returns a short human-readable description of the configured pattern.
    pub fn pattern_description(&self) -> String {
        match &self.config.pattern {
            SparsePatternType::Strided { stride } => {
                format!("Strided(stride={})", stride)
            }
            SparsePatternType::Local { window_size } => {
                format!("Local(window={})", window_size)
            }
            SparsePatternType::GlobalLocal {
                window_size,
                global_positions,
            } => {
                format!(
                    "GlobalLocal(window={}, global_tokens={})",
                    window_size,
                    global_positions.len()
                )
            }
            SparsePatternType::BlockSparse { block_size } => {
                format!("BlockSparse(block={})", block_size)
            }
            SparsePatternType::Random { num_random } => {
                format!("Random(k={})", num_random)
            }
        }
    }
}
/// Builder for windowed (local) attention subgraphs over an [`EinsumGraph`].
#[derive(Clone, Debug)]
pub struct LocalAttention {
/// Attention configuration; its `d_k` sets the score scaling.
pub config: AttentionConfig,
/// Window parameter; `LocalAttention::new` rejects zero.
pub window_size: usize,
}
impl LocalAttention {
    /// Creates a local-attention builder.
    ///
    /// # Errors
    ///
    /// Propagates any error from `AttentionConfig::validate`. Returns
    /// [`TrustformerError::InvalidDimension`] if `window_size` is zero.
    pub fn new(config: AttentionConfig, window_size: usize) -> Result<Self> {
        config.validate()?;
        if window_size == 0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "window_size must be positive".to_string(),
            });
        }
        Ok(Self {
            config,
            window_size,
        })
    }

    /// Appends a windowed attention subgraph to `graph` and returns a one-element
    /// vector holding the index of its output tensor.
    ///
    /// The subgraph reads existing tensors by fixed index: `0` = Q, `1` = K,
    /// `2` = V. The caller must have added them in that order.
    ///
    /// It also creates a fresh `local_window_mask` tensor that no node here
    /// produces. That tensor must be bound by the caller or backend.
    ///
    /// # Errors
    ///
    /// Propagates any error from `EinsumGraph::add_node`.
    pub fn build_local_attention_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Input slots the caller must have populated (see doc comment).
        const QUERY: usize = 0;
        const KEY: usize = 1;
        const VALUE: usize = 2;

        let scores_tensor = graph.add_tensor("local_attn_scores");
        let scores_node = EinsumNode::new("bqd,bkd->bqk", vec![QUERY, KEY], vec![scores_tensor]);
        graph.add_node(scores_node)?;

        // The divisor is encoded in the op name. `local_scale` is a placeholder
        // operand that no node in this subgraph produces.
        let scale_factor = (self.config.d_k as f64).sqrt();
        let scale_tensor = graph.add_tensor("local_scale");
        let scaled_tensor = graph.add_tensor("local_scaled_scores");
        let scale_node = EinsumNode::elem_binary(
            format!("div_scalar_{}", scale_factor),
            scores_tensor,
            scale_tensor,
            scaled_tensor,
        );
        graph.add_node(scale_node)?;

        // NOTE(review): multiplicative masking before a plain `softmax_k`. With a
        // 0/1 mask, out-of-window scores become 0 rather than -inf and still get
        // softmax weight. Confirm the intended mask encoding.
        let window_mask_tensor = graph.add_tensor("local_window_mask");
        let masked_tensor = graph.add_tensor("local_masked_scores");
        let mask_node =
            EinsumNode::elem_binary("mul", scaled_tensor, window_mask_tensor, masked_tensor);
        graph.add_node(mask_node)?;

        let softmax_tensor = graph.add_tensor("local_attention_weights");
        let softmax_node = EinsumNode::elem_unary("softmax_k", masked_tensor, softmax_tensor);
        graph.add_node(softmax_node)?;

        let output_tensor = graph.add_tensor("local_attn_output");
        let output_node = EinsumNode::new(
            "bqk,bkv->bqv",
            vec![softmax_tensor, VALUE],
            vec![output_tensor],
        );
        graph.add_node(output_node)?;

        Ok(vec![output_tensor])
    }

    /// Number of keys each query may attend to: `2 * window_size + 1`. This reads
    /// as `window_size` positions on each side plus the query itself.
    ///
    /// Saturates at `usize::MAX` instead of overflowing for huge windows.
    pub fn attention_span(&self) -> usize {
        self.window_size.saturating_mul(2).saturating_add(1)
    }

    /// Estimated fraction of the dense `seq_len x seq_len` score matrix saved by
    /// windowing. Each row is counted as `min(attention_span, seq_len)` entries.
    ///
    /// Returns `0.0` for `seq_len == 0` instead of NaN. The ratio is computed
    /// directly, so `seq_len * seq_len` can no longer overflow `usize`.
    pub fn memory_savings(&self, seq_len: usize) -> f64 {
        if seq_len == 0 {
            return 0.0;
        }
        let kept_per_row = self.attention_span().min(seq_len) as f64;
        1.0 - kept_per_row / seq_len as f64
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_strided_sparse_config() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let config = SparseAttentionGraphConfig::strided(base, 4).expect("unwrap");
        assert!(matches!(
            config.pattern,
            SparsePatternType::Strided { stride: 4 }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_local_sparse_config() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let config = SparseAttentionGraphConfig::local(base, 128).expect("unwrap");
        assert!(matches!(
            config.pattern,
            SparsePatternType::Local { window_size: 128 }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_global_local_sparse_config() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let global_positions = vec![0, 1, 2];
        let config = SparseAttentionGraphConfig::global_local(base, 64, global_positions)
            .expect("unwrap");
        assert_eq!(
            config.pattern,
            SparsePatternType::GlobalLocal {
                window_size: 64,
                global_positions: vec![0, 1, 2],
            }
        );
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_block_sparse_config() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let config = SparseAttentionGraphConfig::block_sparse(base, 64).expect("unwrap");
        assert!(matches!(
            config.pattern,
            SparsePatternType::BlockSparse { block_size: 64 }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_random_pattern_validation() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let valid = SparseAttentionGraphConfig {
            base_attention: base.clone(),
            pattern: SparsePatternType::Random { num_random: 8 },
            exact_sparse: true,
        };
        assert!(valid.validate().is_ok());
        let invalid = SparseAttentionGraphConfig {
            base_attention: base,
            pattern: SparsePatternType::Random { num_random: 0 },
            exact_sparse: true,
        };
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_validate_catches_mutated_pattern() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let mut config = SparseAttentionGraphConfig::strided(base, 4).expect("unwrap");
        config.pattern = SparsePatternType::Strided { stride: 0 };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_with_exact_sparse() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let config = SparseAttentionGraphConfig::local(base, 16).expect("unwrap");
        assert!(config.exact_sparse);
        let config = config.with_exact_sparse(false);
        assert!(!config.exact_sparse);
    }

    #[test]
    fn test_sparse_attention_creation() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let config = SparseAttentionGraphConfig::strided(base, 2).expect("unwrap");
        let attn = SparseAttentionGraph::new(config).expect("unwrap");
        assert_eq!(attn.sparsity_factor(1024), 0.5);
    }

    #[test]
    fn test_sparse_attention_graph_building() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let config = SparseAttentionGraphConfig::local(base, 128).expect("unwrap");
        let attn = SparseAttentionGraph::new(config).expect("unwrap");
        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");
        graph.add_tensor("sparse_mask");
        let outputs = attn
            .build_sparse_attention_graph(&mut graph)
            .expect("unwrap");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_local_attention_creation() {
        let config = AttentionConfig::new(512, 8).expect("unwrap");
        let local = LocalAttention::new(config, 64).expect("unwrap");
        assert_eq!(local.window_size, 64);
        assert_eq!(local.attention_span(), 129);
    }

    #[test]
    fn test_local_attention_graph_building() {
        let config = AttentionConfig::new(512, 8).expect("unwrap");
        let local = LocalAttention::new(config, 64).expect("unwrap");
        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");
        let outputs = local
            .build_local_attention_graph(&mut graph)
            .expect("unwrap");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_sparsity_factors() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let strided = SparseAttentionGraphConfig::strided(base.clone(), 4).expect("unwrap");
        let attn = SparseAttentionGraph::new(strided).expect("unwrap");
        assert!((attn.sparsity_factor(1024) - 0.25).abs() < 1e-10);
        let local = SparseAttentionGraphConfig::local(base.clone(), 128).expect("unwrap");
        let attn = SparseAttentionGraph::new(local).expect("unwrap");
        assert!((attn.sparsity_factor(1024) - 0.125).abs() < 1e-10);
        let block = SparseAttentionGraphConfig::block_sparse(base.clone(), 64).expect("unwrap");
        let attn = SparseAttentionGraph::new(block).expect("unwrap");
        assert!((attn.sparsity_factor(1024) - 0.0625).abs() < 1e-10);
        let global_local =
            SparseAttentionGraphConfig::global_local(base, 64, vec![0, 1, 2]).expect("unwrap");
        let attn = SparseAttentionGraph::new(global_local).expect("unwrap");
        assert!((attn.sparsity_factor(1024) - 67.0 / 1024.0).abs() < 1e-10);
    }

    #[test]
    fn test_memory_savings() {
        let config = AttentionConfig::new(512, 8).expect("unwrap");
        let local = LocalAttention::new(config, 64).expect("unwrap");
        let savings = local.memory_savings(1024);
        assert!(savings > 0.87 && savings < 0.88);
        // A sequence shorter than the span is effectively dense: no savings.
        assert!(local.memory_savings(100).abs() < 1e-12);
    }

    #[test]
    fn test_pattern_descriptions() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let strided = SparseAttentionGraphConfig::strided(base.clone(), 4).expect("unwrap");
        let attn = SparseAttentionGraph::new(strided).expect("unwrap");
        assert_eq!(attn.pattern_description(), "Strided(stride=4)");
        let local = SparseAttentionGraphConfig::local(base.clone(), 128).expect("unwrap");
        let attn = SparseAttentionGraph::new(local).expect("unwrap");
        assert_eq!(attn.pattern_description(), "Local(window=128)");
        let block = SparseAttentionGraphConfig::block_sparse(base.clone(), 64).expect("unwrap");
        let attn = SparseAttentionGraph::new(block).expect("unwrap");
        assert_eq!(attn.pattern_description(), "BlockSparse(block=64)");
        let global_local = SparseAttentionGraphConfig::global_local(base.clone(), 64, vec![0, 1, 2])
            .expect("unwrap");
        let attn = SparseAttentionGraph::new(global_local).expect("unwrap");
        assert_eq!(
            attn.pattern_description(),
            "GlobalLocal(window=64, global_tokens=3)"
        );
        let random = SparseAttentionGraphConfig {
            base_attention: base,
            pattern: SparsePatternType::Random { num_random: 16 },
            exact_sparse: true,
        };
        let attn = SparseAttentionGraph::new(random).expect("unwrap");
        assert_eq!(attn.pattern_description(), "Random(k=16)");
    }

    #[test]
    fn test_invalid_configs() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let result = SparseAttentionGraphConfig::strided(base.clone(), 0);
        assert!(result.is_err());
        let result = SparseAttentionGraphConfig::local(base.clone(), 0);
        assert!(result.is_err());
        let result = SparseAttentionGraphConfig::global_local(base.clone(), 0, vec![0]);
        assert!(result.is_err());
        let result = SparseAttentionGraphConfig::block_sparse(base.clone(), 0);
        assert!(result.is_err());
        let result = LocalAttention::new(base, 0);
        assert!(result.is_err());
    }
}