#![allow(unused_variables)]
use super::common::{
AttentionConfig, AttentionOptimizationHints, AttentionProjections, AttentionUtils,
};
use crate::device::Device;
use crate::errors::{Result, TrustformersError};
use crate::tensor::Tensor;
use crate::traits::Layer;
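/// Multi-head attention layer: projects its input into per-head query, key, and value
/// spaces, computes scaled dot-product attention, and recombines the heads through an
/// output projection.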
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
config: AttentionConfig,
projections: AttentionProjections,
optimization_hints: AttentionOptimizationHints,
}
impl MultiHeadAttention {
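    /// Creates a multi-head attention layer whose projection weights live on the given `device`.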
pub fn new_with_device(
hidden_size: usize,
num_heads: usize,
dropout_prob: f32,
bias: bool,
device: Device,
) -> Result<Self> {
let config = AttentionConfig::new(hidden_size, num_heads, dropout_prob, bias)?;
let projections = AttentionProjections::new_with_device(&config, device);
let optimization_hints = AttentionOptimizationHints::default();
Ok(Self {
config,
projections,
optimization_hints,
})
}
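    /// Creates a multi-head attention layer on the default CPU device.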
pub fn new(
hidden_size: usize,
num_heads: usize,
dropout_prob: f32,
bias: bool,
) -> Result<Self> {
Self::new_with_device(hidden_size, num_heads, dropout_prob, bias, Device::CPU)
}
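    /// Builds the layer from an existing `AttentionConfig` with default optimization hints.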
pub fn from_config(config: AttentionConfig) -> Result<Self> {
let projections = AttentionProjections::new(&config);
let optimization_hints = AttentionOptimizationHints::default();
Ok(Self {
config,
projections,
optimization_hints,
})
}
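    /// Creates the layer with caller-supplied optimization hints (for example, forcing
    /// the flash-attention path).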
pub fn new_with_hints(
hidden_size: usize,
num_heads: usize,
dropout_prob: f32,
bias: bool,
optimization_hints: AttentionOptimizationHints,
) -> Result<Self> {
let config = AttentionConfig::new(hidden_size, num_heads, dropout_prob, bias)?;
let projections = AttentionProjections::new(&config);
Ok(Self {
config,
projections,
optimization_hints,
})
}
pub fn config(&self) -> &AttentionConfig {
&self.config
}
pub fn optimization_hints(&self) -> &AttentionOptimizationHints {
&self.optimization_hints
}
pub fn set_optimization_hints(&mut self, hints: AttentionOptimizationHints) {
self.optimization_hints = hints;
}
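    /// Total number of parameters across the query, key, value, and output projections.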
pub fn parameter_count(&self) -> usize {
self.projections.query.parameter_count()
+ self.projections.key.parameter_count()
+ self.projections.value.parameter_count()
+ self.projections.out_proj.parameter_count()
}
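    // Setters for loading externally provided projection weights and biases.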
pub fn set_query_weight(&mut self, weight: Tensor) -> Result<()> {
self.projections.query.set_weight(weight)
}
pub fn set_query_bias(&mut self, bias: Tensor) -> Result<()> {
self.projections.query.set_bias(bias)
}
pub fn set_key_weight(&mut self, weight: Tensor) -> Result<()> {
self.projections.key.set_weight(weight)
}
pub fn set_key_bias(&mut self, bias: Tensor) -> Result<()> {
self.projections.key.set_bias(bias)
}
pub fn set_value_weight(&mut self, weight: Tensor) -> Result<()> {
self.projections.value.set_weight(weight)
}
pub fn set_value_bias(&mut self, bias: Tensor) -> Result<()> {
self.projections.value.set_bias(bias)
}
pub fn set_out_proj_weight(&mut self, weight: Tensor) -> Result<()> {
self.projections.out_proj.set_weight(weight)
}
pub fn set_out_proj_bias(&mut self, bias: Tensor) -> Result<()> {
self.projections.out_proj.set_bias(bias)
}
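    /// Self-attention: uses `input` as the query, key, and value source.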
pub fn forward_self_attention(
&self,
input: &Tensor,
attention_mask: Option<&Tensor>,
causal: bool,
) -> Result<Tensor> {
self.forward_attention(input, input, input, attention_mask, causal)
}
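    /// General (cross-)attention over separate query, key, and value inputs, with an
    /// optional attention mask and optional causal masking.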
pub fn forward_attention(
&self,
query_input: &Tensor,
key_input: &Tensor,
value_input: &Tensor,
attention_mask: Option<&Tensor>,
causal: bool,
) -> Result<Tensor> {
let query = self.projections.query.forward(query_input.clone())?;
let key = self.projections.key.forward(key_input.clone())?;
let value = self.projections.value.forward(value_input.clone())?;
let q = AttentionUtils::split_heads(&query, self.config.num_heads, self.config.head_dim)?;
let k = AttentionUtils::split_heads(&key, self.config.num_heads, self.config.head_dim)?;
let v = AttentionUtils::split_heads(&value, self.config.num_heads, self.config.head_dim)?;
AttentionUtils::validate_attention_dims(
&q,
&k,
&v,
self.config.num_heads,
self.config.head_dim,
)?;
let attention_output = self.compute_attention(&q, &k, &v, attention_mask, causal)?;
let combined = AttentionUtils::combine_heads(
&attention_output,
self.config.num_heads,
self.config.head_dim,
)?;
self.projections.out_proj.forward(combined)
}
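    /// Dispatches to the blockwise flash-style path or the standard path based on the
    /// optimization hints, using the usual 1/sqrt(head_dim) scaling in both cases.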
fn compute_attention(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
attention_mask: Option<&Tensor>,
causal: bool,
) -> Result<Tensor> {
let scale = 1.0 / (self.config.head_dim as f32).sqrt();
if self.optimization_hints.use_flash_attention {
self.compute_flash_attention(q, k, v, attention_mask, causal, scale)
} else {
self.compute_standard_attention(q, k, v, attention_mask, causal, scale)
}
}
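    /// Standard attention: attention weights, optional mask, dropout, then the weighted
    /// sum over the values.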
fn compute_standard_attention(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
attention_mask: Option<&Tensor>,
causal: bool,
scale: f32,
) -> Result<Tensor> {
let attention_weights = AttentionUtils::compute_attention_weights(q, k, scale, causal)?;
let masked_weights = if let Some(mask) = attention_mask {
self.apply_attention_mask(&attention_weights, mask)?
} else {
attention_weights
};
let dropped_weights = self.apply_dropout(&masked_weights)?;
AttentionUtils::apply_attention(&dropped_weights, v)
}
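    /// Blockwise (flash-style) attention: iterates over query and key/value blocks while
    /// maintaining running max/sum statistics so the full attention matrix is never
    /// materialized at once.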
fn compute_flash_attention(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
attention_mask: Option<&Tensor>,
causal: bool,
scale: f32,
) -> Result<Tensor> {
let shape = q.shape();
let batch_size = shape[0];
let num_heads = shape[1];
let seq_q = shape[2];
let head_dim = shape[3];
let seq_k = k.shape()[2];
let block_size = self.compute_flash_block_size(seq_q, seq_k, head_dim);
let output = Tensor::zeros(&[batch_size, num_heads, seq_q, head_dim])?;
let num_blocks_q = seq_q.div_ceil(block_size);
let num_blocks_k = seq_k.div_ceil(block_size);
for q_block_idx in 0..num_blocks_q {
let q_start = q_block_idx * block_size;
let q_end = (q_start + block_size).min(seq_q);
let q_block = q.slice_ranges(&[
(0, batch_size),
(0, num_heads),
(q_start, q_end),
(0, head_dim),
])?;
let mut block_output =
Tensor::zeros(&[batch_size, num_heads, q_end - q_start, head_dim])?;
let mut block_max = Tensor::full(
f32::NEG_INFINITY,
vec![batch_size, num_heads, q_end - q_start, 1],
)?;
let mut block_sum = Tensor::zeros(&[batch_size, num_heads, q_end - q_start, 1])?;
for k_block_idx in 0..num_blocks_k {
let k_start = k_block_idx * block_size;
let k_end = (k_start + block_size).min(seq_k);
let k_block = k.slice_ranges(&[
(0, batch_size),
(0, num_heads),
(k_start, k_end),
(0, head_dim),
])?;
let v_block = v.slice_ranges(&[
(0, batch_size),
(0, num_heads),
(k_start, k_end),
(0, head_dim),
])?;
let attention_scores = self.compute_block_scores(
&q_block,
&k_block,
scale,
q_start,
k_start,
attention_mask,
causal,
)?;
self.update_flash_statistics(
&mut block_output,
&mut block_max,
&mut block_sum,
&attention_scores,
&v_block,
)?;
}
            let normalized_output = self.normalize_flash_output(&block_output, &block_sum)?;
            // NOTE: the normalized result for this query block is not yet written back into
            // `output`, so the flash path currently returns the zero-initialized tensor
            // allocated above.
            let _ = normalized_output;
        }
Ok(output)
}
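    /// Applies the attention mask additively, reshaping a 3-D `[batch, seq_q, seq_k]`
    /// mask to broadcast over the head dimension when necessary.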
fn apply_attention_mask(&self, attention_weights: &Tensor, mask: &Tensor) -> Result<Tensor> {
let attention_shape = attention_weights.shape();
let mask_shape = mask.shape();
let compatible_mask = if mask_shape.len() == 3 && attention_shape.len() == 4 {
let batch_size = mask_shape[0];
let seq_q = mask_shape[1];
let seq_k = mask_shape[2];
mask.reshape(&[batch_size, 1, seq_q, seq_k])?
} else {
mask.clone()
};
        // Convert a binary keep-mask (assumed convention: 1 = attend, 0 = masked) into an
        // additive mask: kept positions contribute 0 and masked positions a large negative value.
        let mask_value = Tensor::scalar(-1e9)?;
        let inverted_mask = Tensor::ones(&compatible_mask.shape())?.sub(&compatible_mask)?;
        let mask_additive = inverted_mask.mul(&mask_value)?;
attention_weights.add(&mask_additive)
}
fn apply_dropout(&self, attention_weights: &Tensor) -> Result<Tensor> {
if self.config.dropout_prob > 0.0 {
attention_weights.dropout(self.config.dropout_prob)
} else {
Ok(attention_weights.clone())
}
}
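    /// Rough estimate, in bytes, of the activation memory needed for one forward pass
    /// (attention matrix, projection outputs, and per-head intermediates, assuming f32).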
pub fn estimate_memory_usage(&self, batch_size: usize, seq_len: usize) -> usize {
let attention_matrix_size = batch_size * self.config.num_heads * seq_len * seq_len;
        // Q, K, V, and output projections: four [batch, seq, hidden] activations.
        let projection_size = batch_size * seq_len * self.config.hidden_size * 4;
        // Q, K, and V after the head split: three [batch, heads, seq, head_dim] activations.
        let intermediate_size =
            batch_size * self.config.num_heads * seq_len * self.config.head_dim * 3;
        // Total element count times 4 bytes per f32 element.
        (attention_matrix_size + projection_size + intermediate_size) * 4
    }
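    /// Re-derives the optimization hints for the given workload and, when the memory
    /// estimate exceeds half of the available memory, enables flash attention and half
    /// precision.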
pub fn update_optimization_hints(
&mut self,
batch_size: usize,
seq_len: usize,
available_memory_mb: Option<usize>,
) {
self.optimization_hints =
AttentionOptimizationHints::for_sequence_length(seq_len, available_memory_mb);
let estimated_memory_mb = self.estimate_memory_usage(batch_size, seq_len) / (1024 * 1024);
if let Some(available_mb) = available_memory_mb {
if estimated_memory_mb > available_mb / 2 {
self.optimization_hints.use_flash_attention = true;
self.optimization_hints.use_half_precision = true;
}
}
}
fn compute_flash_block_size(&self, seq_q: usize, seq_k: usize, head_dim: usize) -> usize {
let base_size = 128;
if seq_q > 2048 || seq_k > 2048 {
            base_size * 2
        } else if seq_q < 128 && seq_k < 128 {
            base_size / 2
        } else {
base_size
}
}
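    /// Scaled Q·K^T scores for one (query block, key block) pair, with the causal and
    /// attention masks applied when requested.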
fn compute_block_scores(
&self,
q_block: &Tensor,
k_block: &Tensor,
scale: f32,
q_offset: usize,
k_offset: usize,
attention_mask: Option<&Tensor>,
causal: bool,
) -> Result<Tensor> {
let scores = q_block.matmul(&k_block.transpose(2, 3)?)?;
let scaled_scores = scores.mul(&Tensor::scalar(scale)?)?;
let masked_scores = if causal {
self.apply_block_causal_mask(&scaled_scores, q_offset, k_offset)?
} else {
scaled_scores
};
if let Some(mask) = attention_mask {
let mask_block = self.extract_attention_mask_block(
mask,
q_offset,
k_offset,
q_block.shape()[2],
k_block.shape()[2],
)?;
masked_scores.add(&mask_block)
} else {
Ok(masked_scores)
}
}
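    /// Adds negative infinity to score positions whose global key index exceeds the
    /// global query index, enforcing causality within a block pair.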
fn apply_block_causal_mask(
&self,
scores: &Tensor,
q_offset: usize,
k_offset: usize,
) -> Result<Tensor> {
let shape = scores.shape();
let q_block_size = shape[shape.len() - 2];
let k_block_size = shape[shape.len() - 1];
let mut mask_data = vec![0.0f32; q_block_size * k_block_size];
for i in 0..q_block_size {
for j in 0..k_block_size {
let global_q_pos = q_offset + i;
let global_k_pos = k_offset + j;
if global_k_pos > global_q_pos {
mask_data[i * k_block_size + j] = f32::NEG_INFINITY;
}
}
}
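        // The 2-D block mask is added to the 4-D scores, relying on the tensor `add`
        // broadcasting it over the batch and head dimensions.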
let causal_mask = Tensor::from_vec(mask_data, &[q_block_size, k_block_size])?;
scores.add(&causal_mask)
}
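    /// Slices the block of a 2-D, 3-D, or 4-D attention mask that corresponds to the
    /// current query/key block. The slice is added to the scores directly, so in this
    /// path the mask is expected to already be in additive (0 / large-negative) form.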
fn extract_attention_mask_block(
&self,
mask: &Tensor,
q_offset: usize,
k_offset: usize,
q_block_size: usize,
k_block_size: usize,
) -> Result<Tensor> {
let mask_shape = mask.shape();
if mask_shape.len() == 2 {
mask.slice_ranges(&[
(q_offset, q_offset + q_block_size),
(k_offset, k_offset + k_block_size),
])
} else if mask_shape.len() == 3 {
mask.slice_ranges(&[
(0, mask_shape[0]),
(q_offset, q_offset + q_block_size),
(k_offset, k_offset + k_block_size),
])
} else if mask_shape.len() == 4 {
mask.slice_ranges(&[
(0, mask_shape[0]),
(0, mask_shape[1]),
(q_offset, q_offset + q_block_size),
(k_offset, k_offset + k_block_size),
])
} else {
Err(TrustformersError::tensor_op_error(
&format!("Unsupported attention mask shape: {:?}", mask_shape),
"extract_attention_mask_block",
))
}
}
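    /// Online-softmax update for one key/value block: rescales the running output and
    /// denominator by exp(old_max - new_max) before adding this block's contribution.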
fn update_flash_statistics(
&self,
block_output: &mut Tensor,
block_max: &mut Tensor,
block_sum: &mut Tensor,
attention_scores: &Tensor,
v_block: &Tensor,
) -> Result<()> {
let scores_max_val = attention_scores.max_value()?;
let new_max = block_max.max(&scores_max_val)?;
let scores_shifted = attention_scores.sub(&new_max)?;
let scores_exp = scores_shifted.exp()?;
let old_sum_correction = block_max.sub(&new_max)?.exp()?;
let corrected_old_sum = block_sum.mul(&old_sum_correction)?;
        // Add this block's contribution to the running softmax denominator.
        let new_contribution = scores_exp.sum(None, false)?;
        let updated_sum = corrected_old_sum.add(&new_contribution)?;
let old_output_correction = block_output.mul(&old_sum_correction)?;
let new_output_contribution = scores_exp.matmul(v_block)?;
let updated_output = old_output_correction.add(&new_output_contribution)?;
*block_max = new_max;
*block_sum = updated_sum;
*block_output = updated_output;
Ok(())
}
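    /// Divides the accumulated block output by the accumulated softmax denominator,
    /// adding a small epsilon for numerical safety.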
fn normalize_flash_output(&self, block_output: &Tensor, block_sum: &Tensor) -> Result<Tensor> {
let epsilon = Tensor::scalar(1e-8)?;
let safe_sum = block_sum.add(&epsilon)?;
block_output.div(&safe_sum)
}
}
impl Layer for MultiHeadAttention {
type Input = Tensor;
type Output = Tensor;
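    /// Runs non-causal self-attention without an attention mask.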
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
self.forward_self_attention(&input, None, false)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::Tensor;
#[test]
fn test_multi_head_attention_creation() {
let attention =
            MultiHeadAttention::new(512, 8, 0.1, true).expect("failed to create MultiHeadAttention");
assert_eq!(attention.config.hidden_size, 512);
assert_eq!(attention.config.num_heads, 8);
assert_eq!(attention.config.head_dim, 64);
}
#[test]
fn test_invalid_head_configuration() {
let result = MultiHeadAttention::new(512, 7, 0.1, true);
assert!(result.is_err());
}
#[test]
fn test_self_attention_forward() {
let attention =
            MultiHeadAttention::new(512, 8, 0.1, true).expect("failed to create MultiHeadAttention");
let input = Tensor::randn(&[2, 10, 512]).expect("Failed to create random tensor");
let output = attention.forward(input).expect("Forward pass failed");
assert_eq!(output.shape(), vec![2, 10, 512]);
}
#[test]
fn test_memory_estimation() {
let attention =
            MultiHeadAttention::new(512, 8, 0.1, true).expect("failed to create MultiHeadAttention");
let memory_usage = attention.estimate_memory_usage(2, 100);
assert!(memory_usage > 0);
}
#[test]
fn test_optimization_hints_update() {
let mut attention =
            MultiHeadAttention::new(512, 8, 0.1, true).expect("failed to create MultiHeadAttention");
attention.update_optimization_hints(2, 2048, Some(1024));
assert!(attention.optimization_hints.use_flash_attention);
}
}