use trustformers_core::{errors::Result, tensor::Tensor};
/// Result of a cross-attention forward pass.
#[derive(Debug, Clone)]
pub struct CrossAttentionOutput {
    /// Attended values (attention weights applied to the value tensor).
    pub output: Tensor,
    /// Post-softmax attention weights, when the producer chose to keep them.
    pub attention_weights: Option<Tensor>,
    /// Optional summary statistics over the attention weights.
    pub attention_stats: Option<AttentionStats>,
}
/// Aggregate diagnostics computed from an attention-weight tensor
/// (see `compute_attention_stats`).
#[derive(Debug, Clone)]
pub struct AttentionStats {
    /// Entropy of the attention distribution (see `compute_entropy`).
    pub entropy: f32,
    /// Largest attention weight observed.
    pub max_weight: f32,
    /// Smallest attention weight observed.
    pub min_weight: f32,
    /// Fraction of near-zero weights (see `compute_sparsity`).
    pub sparsity: f32,
    /// Per-head breakdown of the same diagnostics.
    pub head_stats: Vec<HeadStats>,
}
/// Attention diagnostics for a single head.
#[derive(Debug, Clone)]
pub struct HeadStats {
    /// Index of the head these statistics describe.
    pub head_idx: usize,
    /// Entropy of this head's attention distribution.
    pub entropy: f32,
    /// Fraction of near-zero weights in this head.
    pub sparsity: f32,
    /// Positions receiving the most attention (see `compute_top_positions`).
    pub top_positions: Vec<usize>,
}
/// Build an additive attention mask of shape `[query_len, key_len]`.
///
/// Allowed positions hold `0.0`; disallowed positions hold `-inf` so they
/// vanish after softmax when the mask is added to the attention scores.
pub fn create_attention_mask(
    query_len: usize,
    key_len: usize,
    mask_type: MaskType,
) -> Result<Tensor> {
    match mask_type {
        // No masking: every query position may attend to every key position.
        MaskType::None => Ok(Tensor::zeros(&[query_len, key_len])?),
        MaskType::Causal => {
            // Query i may only attend to keys j <= i.
            // BUG FIX: the previous upper bound `key_len.min(query_len)` left
            // every column beyond `query_len` unmasked when key_len > query_len,
            // leaking "future" key positions; mask the full key range instead.
            let mut mask = vec![vec![0.0f32; key_len]; query_len];
            for (i, row) in mask.iter_mut().enumerate() {
                for slot in row.iter_mut().skip(i + 1) {
                    *slot = f32::NEG_INFINITY;
                }
            }
            let flattened: Vec<f32> = mask.into_iter().flatten().collect();
            Ok(Tensor::from_vec(flattened, &[query_len, key_len])?)
        },
        MaskType::Local(window_size) => {
            // Sliding window: query i attends to keys within window_size / 2
            // of the diagonal.
            let mut mask = vec![vec![f32::NEG_INFINITY; key_len]; query_len];
            for (i, row) in mask.iter_mut().enumerate() {
                let start = i.saturating_sub(window_size / 2);
                let end = (i + window_size / 2 + 1).min(key_len);
                for j in start..end {
                    row[j] = 0.0;
                }
            }
            let flattened: Vec<f32> = mask.into_iter().flatten().collect();
            Ok(Tensor::from_vec(flattened, &[query_len, key_len])?)
        },
        // Caller-provided mask is passed through unchanged (shape unchecked).
        MaskType::Custom(custom_mask) => Ok(custom_mask),
    }
}
/// Which additive mask `create_attention_mask` should build.
#[derive(Debug, Clone)]
pub enum MaskType {
    /// No masking: all-zeros mask.
    None,
    /// Autoregressive mask: query i attends only to keys j <= i.
    Causal,
    /// Sliding window of the given size centered on the diagonal.
    Local(usize),
    /// Caller-supplied mask tensor, returned unchanged.
    Custom(Tensor),
}
pub fn create_sparse_mask(
query_len: usize,
key_len: usize,
sparsity_ratio: f32,
pattern: SparsePattern,
) -> Result<Tensor> {
match pattern {
SparsePattern::Random => create_random_sparse_mask(query_len, key_len, sparsity_ratio),
SparsePattern::Block(block_size) => {
create_block_sparse_mask(query_len, key_len, block_size)
},
SparsePattern::Strided(stride) => create_strided_sparse_mask(query_len, key_len, stride),
SparsePattern::TopK(k) => create_topk_sparse_mask(query_len, key_len, k),
}
}
/// Sparsity structure used by `create_sparse_mask`.
#[derive(Debug, Clone)]
pub enum SparsePattern {
    /// Pseudo-random keep/drop pattern driven by the sparsity ratio.
    Random,
    /// Block pattern with the given block size (block plus neighbours kept).
    Block(usize),
    /// Strided pattern: keeps positions sharing a residue class mod the stride.
    Strided(usize),
    /// Window of size k around the diagonal (a top-k approximation).
    TopK(usize),
}
/// Build a pseudo-random sparse mask keeping roughly `1 - sparsity_ratio`
/// of positions (kept positions are `0.0`, dropped ones `-inf`).
///
/// NOTE(review): deterministic — no RNG is involved, so the same inputs
/// always yield the same mask.
fn create_random_sparse_mask(
    query_len: usize,
    key_len: usize,
    sparsity_ratio: f32,
) -> Result<Tensor> {
    // FIX: the old `(i + j) % 10` test quantized the keep ratio to 10% steps
    // and produced diagonal bands rather than a scattered pattern. A
    // multiplicative hash of the coordinate spreads kept positions evenly and
    // honours the ratio to ~0.1% granularity; the ratio is clamped to [0, 1].
    let keep_per_mille = ((1.0 - sparsity_ratio).clamp(0.0, 1.0) * 1000.0) as usize;
    let mut flattened = Vec::with_capacity(query_len * key_len);
    for i in 0..query_len {
        for j in 0..key_len {
            // Knuth-style multiplicative hash of the (i, j) coordinate.
            let h = i
                .wrapping_mul(0x9E37_79B1)
                .wrapping_add(j.wrapping_mul(0x85EB_CA77));
            flattened.push(if h % 1000 < keep_per_mille { 0.0 } else { f32::NEG_INFINITY });
        }
    }
    Tensor::from_vec(flattened, &[query_len, key_len])
}
/// Build a block-sparse mask: query block b attends to key blocks b-1, b, b+1.
///
/// Kept positions are `0.0`, dropped ones `-inf`.
fn create_block_sparse_mask(query_len: usize, key_len: usize, block_size: usize) -> Result<Tensor> {
    // Guard against block_size == 0, which would divide by zero below.
    let block = block_size.max(1);
    let mut flattened = Vec::with_capacity(query_len * key_len);
    for i in 0..query_len {
        let qi = i / block;
        for j in 0..key_len {
            let kj = j / block;
            // Same block or an immediate neighbour block is kept.
            // (The original `qi == kj || qi.abs_diff(kj) <= 1` was redundant:
            // `abs_diff(kj) <= 1` already covers equality.)
            flattened.push(if qi.abs_diff(kj) <= 1 { 0.0 } else { f32::NEG_INFINITY });
        }
    }
    Tensor::from_vec(flattened, &[query_len, key_len])
}
/// Build a strided sparse mask: query i attends only to keys sharing its
/// residue class modulo `stride` (kept positions `0.0`, dropped `-inf`).
fn create_strided_sparse_mask(query_len: usize, key_len: usize, stride: usize) -> Result<Tensor> {
    // Guard against stride == 0, which would make `%` panic.
    let stride = stride.max(1);
    let mut flattened = Vec::with_capacity(query_len * key_len);
    for i in 0..query_len {
        // Hoist the loop-invariant query residue out of the inner loop.
        let phase = i % stride;
        for j in 0..key_len {
            flattened.push(if j % stride == phase { 0.0 } else { f32::NEG_INFINITY });
        }
    }
    Tensor::from_vec(flattened, &[query_len, key_len])
}
/// Build a "top-k" sparse mask (kept positions `0.0`, dropped `-inf`).
///
/// NOTE(review): without attention scores this cannot pick the true top-k
/// keys; it approximates top-k with a size-k window centered on the diagonal.
fn create_topk_sparse_mask(query_len: usize, key_len: usize, k: usize) -> Result<Tensor> {
    let half = k / 2;
    let mut flattened = Vec::with_capacity(query_len * key_len);
    for i in 0..query_len {
        // Window [start, end) of kept keys for query position i.
        let start = i.saturating_sub(half);
        let end = (i + half + 1).min(key_len);
        for j in 0..key_len {
            flattened.push(if (start..end).contains(&j) { 0.0 } else { f32::NEG_INFINITY });
        }
    }
    Tensor::from_vec(flattened, &[query_len, key_len])
}
/// Build one attention mask per pyramid level.
///
/// Level `l` pools both sequence lengths by `pooling_factor^l` (ceiling
/// division, so partial windows still get a slot).
///
/// NOTE(review): every level currently uses `MaskType::None`; presumably
/// level-specific masking is intended eventually — confirm with callers.
pub fn create_hierarchical_mask(
    query_len: usize,
    key_len: usize,
    num_levels: usize,
    pooling_factor: usize,
) -> Result<Vec<Tensor>> {
    // Guard against a zero pooling factor, which would otherwise make
    // `div_ceil` divide by zero and panic from level 1 onward.
    let factor = pooling_factor.max(1);
    // Fallible collect short-circuits on the first failing level.
    (0..num_levels)
        .map(|level| {
            let level_pooling = factor.pow(level as u32);
            create_attention_mask(
                query_len.div_ceil(level_pooling),
                key_len.div_ceil(level_pooling),
                MaskType::None,
            )
        })
        .collect()
}
/// Compute global and per-head diagnostics over an attention-weight tensor.
///
/// Per-head slices are taken with `select(1, head)`, i.e. dimension 1 is
/// assumed to be the head dimension — TODO(review): confirm layout with
/// callers.
pub fn compute_attention_stats(
    attention_weights: &Tensor,
    num_heads: usize,
) -> Result<AttentionStats> {
    // FIX: removed dead `_batch_size`/`_seq_len` bindings; they were unused
    // and indexing `shape[0]`/`shape[2]` could panic on lower-rank tensors.
    let entropy = compute_entropy(attention_weights)?;
    let (min_weight, max_weight) = compute_min_max(attention_weights)?;
    let sparsity = compute_sparsity(attention_weights, 1e-6)?;

    let mut head_stats = Vec::with_capacity(num_heads);
    for head in 0..num_heads {
        let head_weights = attention_weights.select(1, head as i64)?;
        head_stats.push(HeadStats {
            head_idx: head,
            entropy: compute_entropy(&head_weights)?,
            sparsity: compute_sparsity(&head_weights, 1e-6)?,
            top_positions: compute_top_positions(&head_weights, 5)?,
        });
    }
    Ok(AttentionStats {
        entropy,
        max_weight,
        min_weight,
        sparsity,
        head_stats,
    })
}
/// Placeholder entropy statistic.
/// TODO(review): returns a fixed 0.5 and ignores the tensor; replace with a
/// real Shannon-entropy computation over the attention distribution.
fn compute_entropy(_tensor: &Tensor) -> Result<f32> {
Ok(0.5) }
/// Placeholder min/max statistic.
/// TODO(review): returns a fixed (0.0, 1.0) and ignores the tensor; replace
/// with a real reduction once tensor element access is available.
fn compute_min_max(_tensor: &Tensor) -> Result<(f32, f32)> {
Ok((0.0, 1.0)) }
/// Placeholder sparsity statistic.
/// TODO(review): returns a fixed 0.1 and ignores both the tensor and
/// `_threshold`; replace with the fraction of weights below the threshold.
fn compute_sparsity(_tensor: &Tensor, _threshold: f32) -> Result<f32> {
Ok(0.1) }
/// Placeholder top-attended-position lookup.
///
/// FIX: the stub ignored `k` and always returned five positions; it now
/// returns the first `k` positions so the requested count is honoured.
/// TODO(review): rank positions by attention weight once tensor reduction
/// helpers exist — the tensor argument is still unused.
fn compute_top_positions(_tensor: &Tensor, k: usize) -> Result<Vec<usize>> {
    Ok((0..k).collect())
}
/// Apply dropout to attention weights during training.
///
/// Dropout is an inference-time no-op, and is also skipped when the rate is
/// zero (or negative), in which case the weights pass through untouched.
pub fn apply_attention_dropout(
    attention_weights: Tensor,
    dropout_rate: f32,
    training: bool,
) -> Result<Tensor> {
    if !training || dropout_rate <= 0.0 {
        return Ok(attention_weights);
    }
    attention_weights.dropout(dropout_rate)
}
/// Standard scaled dot-product attention: softmax(scale * Q K^T + mask) V.
///
/// `mask` is additive (masked positions hold -inf), applied before softmax.
/// Returns the attended output plus the post-dropout attention weights;
/// statistics are left for the caller to compute.
pub fn scaled_dot_product_attention(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    mask: Option<&Tensor>,
    scale: f32,
    dropout_rate: f32,
    training: bool,
) -> Result<CrossAttentionOutput> {
    // Transpose the last two axes of `key` so Q @ K^T forms the score matrix.
    let rank = key.shape().len();
    let key_t = key.transpose(rank.saturating_sub(2), rank.saturating_sub(1))?;
    let mut scores = query.matmul(&key_t)?.mul_scalar(scale)?;
    // Additive mask: -inf entries vanish under the softmax below.
    if let Some(m) = mask {
        scores = scores.add(m)?;
    }
    let weights = scores.softmax(-1)?;
    let weights = apply_attention_dropout(weights, dropout_rate, training)?;
    Ok(CrossAttentionOutput {
        output: weights.matmul(value)?,
        attention_weights: Some(weights),
        attention_stats: None,
    })
}
/// Split the hidden dimension into heads: [B, S, H*D] -> [B, H, S, D].
/// Assumes dims 0 and 1 are batch and sequence — TODO confirm with callers.
pub fn reshape_for_multihead(tensor: Tensor, num_heads: usize, head_dim: usize) -> Result<Tensor> {
    let (batch, seq) = {
        let dims = tensor.shape();
        (dims[0], dims[1])
    };
    tensor
        .reshape(&[batch, seq, num_heads, head_dim])?
        .transpose(1, 2)
}
/// Merge heads back into the hidden dimension: [B, H, S, D] -> [B, S, H*D].
/// Inverse of `reshape_for_multihead`; `hidden_size` must equal H * D.
pub fn reshape_from_multihead(tensor: Tensor, hidden_size: usize) -> Result<Tensor> {
    let (batch, seq) = {
        let dims = tensor.shape();
        (dims[0], dims[2])
    };
    tensor.transpose(1, 2)?.reshape(&[batch, seq, hidden_size])
}
/// Down-sample a tensor along its sequence axis with the chosen method.
///
/// `Learnable` pooling is not implemented yet and currently falls back to
/// average pooling.
pub fn pool_tensor(tensor: Tensor, pooling_factor: usize, method: PoolingMethod) -> Result<Tensor> {
    match method {
        PoolingMethod::Max => max_pool_1d(tensor, pooling_factor),
        PoolingMethod::Average | PoolingMethod::Learnable => {
            average_pool_1d(tensor, pooling_factor)
        },
    }
}
/// Pooling strategy for `pool_tensor`.
#[derive(Debug, Clone)]
pub enum PoolingMethod {
    /// Mean over each pooling window.
    Average,
    /// Maximum over each pooling window.
    Max,
    /// Learned pooling; currently falls back to average pooling.
    Learnable,
}
/// Placeholder 1-D average pooling.
/// TODO(review): returns the input unchanged and ignores `_pooling_factor`;
/// implement real window averaging.
fn average_pool_1d(tensor: Tensor, _pooling_factor: usize) -> Result<Tensor> {
    Ok(tensor)
}
/// Placeholder 1-D max pooling.
/// TODO(review): returns the input unchanged and ignores `_pooling_factor`;
/// implement a real windowed maximum.
fn max_pool_1d(tensor: Tensor, _pooling_factor: usize) -> Result<Tensor> {
    Ok(tensor)
}
/// Resample a tensor to `target_length` with the chosen interpolation method.
pub fn interpolate_tensor(
    tensor: Tensor,
    target_length: usize,
    method: InterpolationMethod,
) -> Result<Tensor> {
    // Dispatch to the concrete interpolation kernel.
    match method {
        InterpolationMethod::Nearest => nearest_interpolate(tensor, target_length),
        InterpolationMethod::Linear => linear_interpolate(tensor, target_length),
    }
}
/// Interpolation strategy for `interpolate_tensor`.
#[derive(Debug, Clone)]
pub enum InterpolationMethod {
    /// Linear interpolation between neighbouring samples.
    Linear,
    /// Nearest-neighbour sampling.
    Nearest,
}
/// Placeholder linear interpolation.
/// TODO(review): returns the input unchanged and ignores `_target_length`;
/// implement real linear resampling.
fn linear_interpolate(tensor: Tensor, _target_length: usize) -> Result<Tensor> {
    Ok(tensor)
}
/// Placeholder nearest-neighbour interpolation.
/// TODO(review): returns the input unchanged and ignores `_target_length`;
/// implement real nearest-neighbour resampling.
fn nearest_interpolate(tensor: Tensor, _target_length: usize) -> Result<Tensor> {
    Ok(tensor)
}